Neural Turing Machines – A First Look

Some time last week, a paper from Google DeepMind caught my attention.

The paper is of particular interest to me because I’ve been thinking about how a recurrent neural network could learn to access an external form of memory. The approach taken here is interesting because it balances two ways of addressing memory: seeking by similarity of content, and shifting from that position by location.

My focus this time is on some of the details needed for implementation. Some of these specifics are glossed over in the paper, so I’ll try to infer whatever I can and, perhaps in the next post, have code (in Theano, what else?) to present.

$$
\newcommand{\memory}{\mathbf{M}}
\newcommand{\read}{\mathbf{r}}
\newcommand{\erase}{\mathbf{e}}
\newcommand{\add}{\mathbf{a}}
\newcommand{\weight}{\mathbf{w}}
\newcommand{\key}{\mathbf{k}}
\newcommand{\shift}{\mathbf{s}}
\newcommand{\Shift}{\mathbf{S}}
$$

Reading and Writing

The memory matrix, $\memory_t$, is modified at every time step, and the controller interacts with it via heads: a read head, which produces $\read_t$, and a write head, which applies an erase vector ($\erase_t$) and an add vector ($\add_t$). This is similar to the gating mechanisms used in Long Short-term Memory (LSTM) units, with a slight change in naming: the write is called add here, since $\weight$ is already taken by the new weighting vector ($\weight_t$).

$$
\begin{align}
\memory_t
&= \left[\begin{array}{ccc}
\memory_t(0)_0 & \cdots & \memory_t(0)_{N-1}\\
\vdots & \ddots & \vdots \\
\memory_t(i)_0 & \cdots & \memory_t(i)_{N-1}\\
\vdots & \ddots & \vdots \\
\memory_t(M-1)_0 & \cdots & \memory_t(M-1)_{N-1}\\
\end{array}\right]
&
\weight_t &= \left[ \begin{array}{c}
w_t(0)\\
\vdots\\
w_t(i)\\
\vdots\\
w_t(M-1)
\end{array}\right]
\\
\read_t
&= ~\left[ \begin{array}{ccc}
~~~r_{0}~~~ & ~~~\cdots~~~ & ~~~r_{N-1}~~~
\end{array}\right]
\\
\erase_t
&= ~\left[ \begin{array}{ccc}
~~~e_{0}~~~ & ~~~\cdots~~~ & ~~~e_{N-1}~~~
\end{array}\right]
\\
\add_t
&= ~\left[ \begin{array}{ccc}
~~~a_{0}~~~ & ~~~\cdots~~~ & ~~~a_{N-1}~~~
\end{array}\right]
\end{align}
$$

Hopefully that gives you an idea of what goes where, and a clue as to what each part is for. The general idea is to use $\weight_t$ as a focusing mechanism that selects a row, while $\read_t$, $\erase_t$ and $\add_t$ perform their individual roles element-wise on that row, $\memory_t(i)$.

Here’s where I find this interesting. If we were to write some sort of automaton that does this, it would read or write at exactly one location. However, for the whole system to be “end-to-end” differentiable, the authors make $\weight_t$ a distribution over locations: $0 \le w_t(i) \le 1$ and $\sum_i w_t(i) = 1$. A read then returns a weighted average of the rows rather than a single row. This trick of replacing a hard choice with a distribution is used pretty consistently throughout the different mechanisms in the NTM.
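To make this concrete, here is a minimal sketch of a read and a write in plain Python (lists standing in for tensors), using the layout above: $M$ rows of length $N$. The equations follow the paper: the read is $\read_t = \sum_i w_t(i)\memory_t(i)$, and the write first erases, $\widetilde{\memory}_t(i) = \memory_{t-1}(i) \odot (\mathbf{1} - w_t(i)\erase_t)$, then adds, $\memory_t(i) = \widetilde{\memory}_t(i) + w_t(i)\add_t$. The function names `read` and `write` are my own, and this is for intuition rather than an efficient implementation.

```python
# Minimal NTM read/write sketch (plain Python; `read` and `write` are hypothetical names).
# memory: list of M rows, each a list of N floats; weights: list of M floats summing to 1.

def read(memory, weights):
    """Blurry read: r = sum_i w(i) * M(i), i.e. a weighted average of the rows."""
    n = len(memory[0])
    return [sum(w * row[k] for w, row in zip(weights, memory)) for k in range(n)]

def write(memory, weights, erase, add):
    """Erase, then add, with each row's change scaled by its weight. Returns new memory."""
    new_memory = []
    for w, row in zip(weights, memory):
        # erase step: M(i) * (1 - w(i) * e), element-wise
        erased = [m * (1.0 - w * e) for m, e in zip(row, erase)]
        # add step: + w(i) * a, element-wise
        new_memory.append([m + w * a for m, a in zip(erased, add)])
    return new_memory

memory = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # M = 3 rows of length N = 2

# A one-hot weighting behaves like a discrete read of row 1...
print(read(memory, [0.0, 1.0, 0.0]))  # [3.0, 4.0]
# ...while a spread-out weighting blends rows 0 and 1.
print(read(memory, [0.5, 0.5, 0.0]))  # [2.0, 3.0]

# Fully erase row 2 and write [9, 9] there; the other rows are untouched.
print(write(memory, [0.0, 0.0, 1.0], erase=[1.0, 1.0], add=[9.0, 9.0]))
# [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
```

With a one-hot $\weight_t$ these reduce to the discrete read and overwrite an automaton would do. Anything softer blends rows, which is what keeps everything differentiable.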

Addressing

Still on the topic of $\weight_t$, let me try and give a brief overview of what the authors are trying to achieve with their very… convoluted (forgive the pun) way of computing it at every time step.

In order for the controller to perform any useful action on $\memory_t$, it needs to take in an input from the input sequence and translate (by which I mean predict) it into some sort of key ($\key_t$) that it can look up. Merely picking out the entry that matches that key may not be particularly useful, though, since the controller could have just used its own prediction. It may need to shift ($\shift_t$) before or after that memory location to find something useful for computation. Imagine a table where keys and values are stored in succession, $k_0, v_0, k_1, v_1, \cdots$: if we look up $k_i$ and find it at location $2i$, then $v_i$ is at location $2i+1$, requiring an address shift by 1.
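For comparison, here is the hard, non-differentiable version of that table lookup, as a tiny Python sketch with a made-up table of strings: find the key by content, then move from that location by a fixed shift, wrapping around at the end. The NTM replaces each of these two steps with a soft, weighted version, as described next.

```python
# Hard (non-differentiable) key/value lookup: keys and values are interleaved,
# so key k_i sits at location 2i and its value v_i at location 2i + 1.
table = ["k0", "v0", "k1", "v1", "k2", "v2"]  # hypothetical example table

def lookup_then_shift(table, key, shift):
    location = table.index(key)                    # content-based: where is the key?
    return table[(location + shift) % len(table)]  # location-based: move, wrapping around

print(lookup_then_shift(table, "k1", 1))  # v1
print(lookup_then_shift(table, "k2", 0))  # k2 (no shift just returns the key itself)
```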

Let’s now take a closer look at how this $\weight_t$ is computed at each time step.

Content-based using $\key_t$ and $\beta_t$ to $\weight_t^c$

So, say our mysterious controller predicts a vector $\key_t$. Again, with the table analogy, we look up $\memory_t$ to find the entry most similar to $\key_t$, but we have to do this probabilistically. The authors do this by computing a similarity measure $K$ (in this case cosine similarity) between $\key_t$ and every row of $\memory_t$, scaling the results by the key strength $\beta_t > 0$, and then normalising by slapping a softmax over the scaled values:
$$w^c_t(i) = \frac{\exp\left(\beta_t K(\key_t, \memory_t(i))\right)}{\sum_j \exp\left(\beta_t K(\key_t, \memory_t(j))\right)}$$
A large $\beta_t$ concentrates the weighting on the best match, while a small one spreads it out. This gives us $\weight_t^c$, which is just an intermediate value before the final weight vector is computed.
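A minimal sketch of content-based addressing in plain Python, assuming cosine similarity for $K$ as in the paper. The small epsilon in the denominator is my own addition to avoid dividing by zero for an all-zero row. Subtracting the maximum score before exponentiating is the usual trick for a numerically stable softmax and does not change the result.

```python
import math

def cosine_similarity(u, v, eps=1e-8):
    """K(u, v) = u . v / (|u| |v|); eps avoids dividing by zero for an all-zero row."""
    dot = sum(a * b for a, b in zip(u, v))
    norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norms + eps)

def content_weighting(memory, key, beta):
    """w^c(i) = softmax over i of beta * K(key, M(i))."""
    scores = [beta * cosine_similarity(key, row) for row in memory]
    top = max(scores)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # M = 3 rows of length N = 2
key = [0.0, 1.0]

soft = content_weighting(memory, key, beta=1.0)    # most weight on row 1, some left on rows 2 and 0
sharp = content_weighting(memory, key, beta=50.0)  # nearly all weight on row 1
print([round(w, 3) for w in soft])
print([round(w, 3) for w in sharp])
```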

Shifting values using $\shift_t$ to $\weight_t$

Now that we have a weighting for content-based addressing, we need to let the controller decide whether to use it, or to just shift from the weighting of the previous time step. Predictably, the authors achieve this with an “expected” value: the interpolation gate $g_t \in (0,1)$ weights $\weight^c_t$ by $g_t$ and $\weight_{t-1}$ by $1-g_t$,
$$\weight^g_t = g_t\weight^c_t + (1-g_t)\weight_{t-1}$$
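The gate itself is a one-liner. Here is a sketch with hypothetical four-location weightings, showing the two limiting cases and a half-way blend.

```python
def interpolate(w_content, w_prev, g):
    """Gated blend w^g = g * w^c + (1 - g) * w_{t-1}, with gate g in (0, 1)."""
    return [g * wc + (1.0 - g) * wp for wc, wp in zip(w_content, w_prev)]

w_prev = [0.0, 1.0, 0.0, 0.0]     # last step was focused on location 1
w_content = [0.0, 0.0, 0.0, 1.0]  # the key lookup points at location 3

print(interpolate(w_content, w_prev, g=0.0))  # [0.0, 1.0, 0.0, 0.0]: ignore content, keep previous focus
print(interpolate(w_content, w_prev, g=1.0))  # [0.0, 0.0, 0.0, 1.0]: jump to the content match
print(interpolate(w_content, w_prev, g=0.5))  # [0.0, 0.5, 0.0, 0.5]: half of each
```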

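Okay, so time to shift $\weight^g_t$ probabilistically. It goes without saying, then, that $\shift_t$ is a distribution over shifts, $\sum_i s_t(i) = 1$. The paper describes this shift as a circular convolution,
$$\widetilde{w}_t(i) = \sum^{M-1}_{j=0} w^g_t(j)\cdot s_t(i-j)$$
where, with $M$ memory locations as above, the index $i-j$ is taken modulo $M$, so a shift past the last row wraps around to the first.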
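A direct, loop-based sketch of this circular convolution in Python, with a made-up four-location example. Here `s[k]` is the weight on shifting by $k$ places, so with $M = 4$, `s[3]` plays the role of a shift by $-1$.

```python
def circular_shift(w_gated, s):
    """w~(i) = sum_j w^g(j) * s((i - j) mod M): circular convolution over M locations."""
    m = len(w_gated)
    return [sum(w_gated[j] * s[(i - j) % m] for j in range(m)) for i in range(m)]

w_gated = [0.0, 1.0, 0.0, 0.0]      # focused on location 1

s_forward = [0.0, 1.0, 0.0, 0.0]    # all mass on a shift of +1
print(circular_shift(w_gated, s_forward))  # [0.0, 0.0, 1.0, 0.0]

s_blurry = [0.5, 0.25, 0.0, 0.25]   # stay (0), move +1, or move -1 (index M - 1)
print(circular_shift(w_gated, s_blurry))   # [0.25, 0.5, 0.25, 0.0]
```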

There might be a simpler way to see how this works in vectorised form. We could use $\shift_t$ to build a circulant matrix $\Shift_t$, whose entry in row $i$ and column $j$ is $s_t(i-j)$ (again modulo $M$),

$$\Shift_t =
\left[\begin{array}{ccccc}
s_t(0) & s_t(M-1) & \cdots & s_t(2) & s_t(1) \\
s_t(1) & s_t(0) & s_t(M-1) & \cdots & s_t(2) \\
\vdots & s_t(1) & s_t(0) & \ddots & \vdots \\
s_t(M-2) & & \ddots & \ddots & s_t(M-1) \\
s_t(M-1) & s_t(M-2) & \cdots & s_t(1) & s_t(0) \\
\end{array}\right]
$$

in which case the whole shift is just a matrix-vector product,
$$ \widetilde{\weight}_t = \Shift_t\weight^g_t$$
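The same operation as a matrix, sketched in plain Python: `shift_matrix` builds the circulant $\Shift_t$ from $\shift_t$ using $\Shift_t(i, j) = s_t((i - j) \bmod M)$, and multiplying it with $\weight^g_t$ gives the same result as evaluating the sum directly. The asymmetric `s` below is made up so the circulant structure is visible, and `matvec` is a helper written here; in a tensor library you would use its own matrix product instead.

```python
def shift_matrix(s):
    """Circulant matrix S with S[i][j] = s((i - j) mod M)."""
    m = len(s)
    return [[s[(i - j) % m] for j in range(m)] for i in range(m)]

def matvec(matrix, vector):
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]

s = [0.5, 0.3, 0.0, 0.2]  # stay: 0.5, shift +1: 0.3, shift -1: 0.2
for row in shift_matrix(s):
    print(row)
# [0.5, 0.2, 0.0, 0.3]
# [0.3, 0.5, 0.2, 0.0]
# [0.0, 0.3, 0.5, 0.2]
# [0.2, 0.0, 0.3, 0.5]

print(matvec(shift_matrix(s), [0.0, 1.0, 0.0, 0.0]))  # [0.2, 0.5, 0.3, 0.0]
```

Each row is the one above it rotated one place to the right, which is exactly what makes the matrix product equal to the circular convolution.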

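As a final step, since the convolution can blur the weighting, they sharpen it: each element of $\widetilde{\weight}_t$ is raised to the power $\gamma_t \ge 1$, and the result is renormalised to get the final $\weight_t$,
$$w_t(i) = \frac{\widetilde{w}_t(i)^{\gamma_t}}{\sum_j \widetilde{w}_t(j)^{\gamma_t}}$$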
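A sketch of the sharpening step in plain Python, applied to a blurred weighting like the one a soft shift produces. With $\gamma_t = 1$ nothing changes, and larger values push mass toward the peak.

```python
def sharpen(w_tilde, gamma):
    """w(i) = w~(i)^gamma / sum_j w~(j)^gamma, for gamma >= 1."""
    powered = [w ** gamma for w in w_tilde]
    total = sum(powered)
    return [p / total for p in powered]

w_tilde = [0.25, 0.5, 0.25, 0.0]    # a blurred weighting, e.g. after a soft shift
print(sharpen(w_tilde, gamma=1.0))  # [0.25, 0.5, 0.25, 0.0] (unchanged)
print(sharpen(w_tilde, gamma=2.0))  # roughly [0.167, 0.667, 0.167, 0.0]
```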

The Controller

So what’s this controller that keeps popping up? We could think of it as the CPU of the entire system, deciding what to read and write to memory. In the paper, the authors experiment with using an LSTM neural network and a standard feed-forward network for this purpose.

More importantly for us, we need to figure out what the inputs and outputs of this network are in order to implement anything. So, to summarise the above, here are the inputs and outputs, along with whatever constraints may be attached to them.

Inputs:
$$
\begin{align}
\mathbf{i}_t,\read_t &\in \mathbb{R}^N \\
\end{align}
$$

Outputs:
$$
\begin{align}
\erase_t,\add_t,\key_t &\in (0,1)^N \\
\shift_t &\in (0,1)^M, & \sum_i& s_t(i) = 1 \\
\beta_t &\in \mathbb{R}^+ & \gamma_t & \in\mathbb{R}^{\ge 1} \\
& & g_t & \in (0,1) \\
\end{align}
$$

At every time step, the inputs are fed into the controller, and its outputs are what we use to manipulate $\memory_t$.
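The constraints above have to be enforced somehow, since a network layer just emits unconstrained real numbers. Below is a minimal sketch of one common way to do it, following the constraints listed above: a sigmoid for everything in $(0,1)$, a softmax for $\shift_t$, a softplus for $\beta_t$, and $1 + \text{softplus}$ for $\gamma_t$. The flat layout of the raw outputs and the function names are hypothetical choices for illustration, not something the paper prescribes.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function, mapping R -> (0, 1)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def softplus(x):
    """log(1 + e^x), written to avoid overflow for large x; maps R -> (0, inf)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def head_parameters(raw, n, m):
    """Split raw controller outputs into constrained head parameters.

    Hypothetical layout: [e (n), a (n), k (n), s (m), beta, gamma, g],
    where n is the row length and m the number of memory locations.
    """
    assert len(raw) == 3 * n + m + 3
    e = [sigmoid(x) for x in raw[0:n]]           # erase, in (0, 1)^n
    a = [sigmoid(x) for x in raw[n:2 * n]]       # add, in (0, 1)^n
    k = [sigmoid(x) for x in raw[2 * n:3 * n]]   # key, in (0, 1)^n
    s = softmax(raw[3 * n:3 * n + m])            # shift distribution, sums to 1
    beta = softplus(raw[3 * n + m])              # key strength, > 0
    gamma = 1.0 + softplus(raw[3 * n + m + 1])   # sharpening exponent, >= 1
    g = sigmoid(raw[3 * n + m + 2])              # interpolation gate, in (0, 1)
    return e, a, k, s, beta, gamma, g

raw = [0.3, -1.2, 2.0, 0.0, -0.5, 0.7, 1.0, -1.0, 0.5, -2.0, 0.4, 0.0]  # n = 2, m = 3
e, a, k, s, beta, gamma, g = head_parameters(raw, n=2, m=3)
print(e, s, beta, gamma, g)
```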

And I think that’s all we need! I’ll look at implementing this over the weekend, and we’ll see if we can achieve similar results.

Comments

  1. Hey Man, Love the blog,

    I understand their differentiable memory concept, but how have you been able to implement it, and how does that not make a read probabilistic because of overwrites? Like, if memory is spread across rows, then how are insert conflicts managed?

    Thanks!

    • Also, how do the controller heads interact with the memory matrix? Do they do so like an LSTM, with a gate squashing the input to the memory matrix and a gate squashing the output, and the memory matrix just a node with a loop to itself?

      • I’ve implemented all of this in Theano, a library in python that allows me to do symbolic differentiation. You can take a look at the code here: https://github.com/shawntan/neural-turing-machines

        The assumption is that the ‘blurry’ reads eventually get ‘sharper’ after training, which is what actually happens in the case of the copy task. The write happens after the read IIRC, so there’s no conflict, assuming a pointwise read/write.

        Yes, the interactions with the memory matrix kind of look like those of the LSTM. The difference here is that there is a probabilistic weight vector across the memory locations. So the memory at time $t$ is just the memory at $t-1$ after the reads and writes have been performed.

          • None of the values emitted by the controller are directly optimised, except the output. The idea is to learn the values of the read/write heads in order to give the correct output.

  2. Hi Shawn,

    Thank you for this informative blog. I am implementing the NTM in Matlab from scratch, using a feedforward NN with one hidden layer. I couldn’t figure out how to update the weights right behind the controller outputs (a, e, beta, gamma, s, k, g). Did you propagate the error through the network that is in front of the read input?

    Thanks!

    • The error is propagated from the expected output, so in the case of copy, it is propagated all the way from the output “labels”.

      This question keeps getting asked, and I think it would totally miss the point if some kind of signal is passed back from the heads. The idea should be to be able to just give it the input, give it the output, and let it learn the procedure.

  3. Hey Shawn,

    Thanks for the blog. It was useful in understanding several parts. Could you please tell me how you calculated the values for $a_t$ and $e_t$? From my understanding, if the memory is $N \times M$, then I understand how to calculate the $N$ values $w(i)$, but I am unclear on how to obtain $a_t, e_t \in \mathbb{R}^M$.
