Neural Turing Machines – Implementation Hell

I’ve been struggling with the implementation of the NTM for the past week and a half now.

There are various problems I’ve been trying to deal with. The paper is relatively sparse on details of the architecture, and even briefer about the training process. Alex Graves trains RNNs a lot in his work, and it seems to me that some of the tricks he uses here may be scattered throughout his previous work.

Controller Outputs

A huge part of my problems stems from not knowing what the architecture of the controller looks like. The paper describes its inputs as coming from the input sequence and the read head of the memory, and its outputs as the write heads and the output sequence.

It then goes on to describe the reading and writing mechanism, sometimes referring to “multiple heads”. This left me unsure whether there were separate read, erase, and write heads, each with its own weighting. Another possible way to implement this was to give each head an erase vector and an add vector that share one set of weight parameters (the parameters from which we calculate the current weighting from the previous timestep's). This is the version I finally settled on.
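For reference, here is a minimal NumPy sketch of that design. It follows the paper's read equation and its erase-then-add write equations: a head has a single weighting $w$, and both the erase vector $e$ and the add vector $a$ use it. The function names and array shapes are my own hypothetical choices.

```python
import numpy as np

# Hypothetical helper names; memory is an (N, M) array of N rows of width M.

def read(memory, w):
    """Read vector r = sum_i w(i) * M(i): a weighted sum of memory rows."""
    return w @ memory

def write(memory, w, erase, add):
    """Erase-then-add update for one head; erase and add share the weighting w."""
    # erase: (M,) vector with entries in (0, 1); add: (M,) vector
    erased = memory * (1.0 - np.outer(w, erase))   # M~(i) = M(i) * (1 - w(i) e)
    return erased + np.outer(w, add)               # M(i) = M~(i) + w(i) a

memory = np.zeros((4, 3))
w = np.array([0.0, 1.0, 0.0, 0.0])                 # focus entirely on row 1
memory = write(memory, w, erase=np.ones(3), add=np.array([1.0, 2.0, 3.0]))
print(read(memory, w))  # → [1. 2. 3.]
```

Because erase and add share one weighting, a head wipes and rewrites exactly the same locations within a single timestep.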

The other problem was deciding where the heads would go. In previous LSTM papers, the write/forget gates depend directly on the input data as well. That led me to think the heads might be a single-layer network that produced read, write, and erase outputs directly from the input. I eventually settled on connecting the heads after the hidden layer.
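The wiring I settled on can be sketched in NumPy as follows. This assumes a single tanh hidden layer, which the paper does not specify. The input and the previous read vector feed the hidden layer, and both the sequence output and the raw head parameters are linear functions of that layer. All sizes and parameter names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: input, read vector, hidden, output, raw head-parameter widths.
X, R, H, Y, P = 8, 20, 100, 8, 30

params = {
    "W_xh": rng.normal(0, 0.1, (X, H)),
    "W_rh": rng.normal(0, 0.1, (R, H)),
    "b_h": np.zeros(H),
    "W_hy": rng.normal(0, 0.1, (H, Y)),
    "b_y": np.zeros(Y),
    "W_hp": rng.normal(0, 0.1, (H, P)),
    "b_p": np.zeros(P),
}

def controller_step(x, r_prev, p):
    """One step of a single-hidden-layer feedforward controller (an assumption)."""
    # The current input and the previous read vector both feed the hidden layer.
    h = np.tanh(x @ p["W_xh"] + r_prev @ p["W_rh"] + p["b_h"])
    # Output and head parameters hang off h rather than off the raw input.
    y_hat = h @ p["W_hy"] + p["b_y"]
    head_raw = h @ p["W_hp"] + p["b_p"]
    return y_hat, head_raw

y_hat, head_raw = controller_step(np.ones(X), np.zeros(R), params)
print(y_hat.shape, head_raw.shape)  # → (8,) (30,)
```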

I took a look at this paper to see how Graves usually models outputs with different domains. This is what I’ve come up with:

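$$\hat{y} = \left(\hat{\mathbf{k}},\hat{\beta},\hat{g},\hat{\mathbf{s}},\hat{\gamma}\right) = \mathbf{h}^\top\mathbf{W} + \mathbf{b}$$

Here $\mathbf{h}$ is the controller's hidden-layer activation, and $\mathbf{W}$ and $\mathbf{b}$ are the head's output weights and bias. The raw outputs are then mapped into the ranges the paper requires for the key $\mathbf{k}$, key strength $\beta$, interpolation gate $g$, shift weighting $\mathbf{s}$, and sharpening exponent $\gamma$:

$$
\begin{aligned}
\mathbf{k} &= \hat{\mathbf{k}} & & \\
\beta &= \exp\left(\hat{\beta}\right) &\Longrightarrow\ & \beta > 0 \\
g &= \sigma\left(\hat{g}\right) &\Longrightarrow\ & g \in (0,1) \\
(\mathbf{s})_i &= \frac{\exp\left((\hat{\mathbf{s}})_i\right)}{\sum_j \exp\left((\hat{\mathbf{s}})_j\right)} &\Longrightarrow\ & (\mathbf{s})_i \in (0,1),\ \sum_i (\mathbf{s})_i = 1 \\
\gamma &= \log\left(e^{\hat{\gamma}} + 1\right) + 1 &\Longrightarrow\ & \gamma \geq 1
\end{aligned}
$$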

I have no rigorous justification for choosing the softplus over $\exp$ for $\gamma$. My intuition is that exponentiating is more likely to be numerically unstable than the softplus, which grows only linearly for large inputs.
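A minimal NumPy sketch of these activations follows. It assumes the raw head vector is laid out as $\hat{\mathbf{k}}, \hat{\beta}, \hat{g}, \hat{\mathbf{s}}, \hat{\gamma}$ in that order; the ordering and the function names are hypothetical choices. The softplus is computed with `np.logaddexp(0, x)`, which evaluates $\log(1 + e^x)$ without overflowing.

```python
import numpy as np

def softplus(x):
    # log(1 + e^x) computed without overflow for large x
    return np.logaddexp(0.0, x)

def split_head_params(y_hat, key_size, n_shifts):
    """Map a raw head vector onto (k, beta, g, s, gamma).

    Assumed (hypothetical) layout: [k_hat, beta_hat, g_hat, s_hat, gamma_hat].
    """
    k_hat = y_hat[:key_size]
    beta_hat = y_hat[key_size]
    g_hat = y_hat[key_size + 1]
    s_hat = y_hat[key_size + 2 : key_size + 2 + n_shifts]
    gamma_hat = y_hat[key_size + 2 + n_shifts]

    k = k_hat                                # key: left unconstrained
    beta = np.exp(beta_hat)                  # beta > 0
    g = 1.0 / (1.0 + np.exp(-g_hat))         # g in (0, 1)
    s = np.exp(s_hat - s_hat.max())          # softmax, shifted for stability
    s = s / s.sum()
    gamma = softplus(gamma_hat) + 1.0        # gamma >= 1
    return k, beta, g, s, gamma

raw = np.zeros(20 + 3 + 3)                   # key size 20, three shifts
raw[-1] = 1000.0                             # a huge gamma_hat
k, beta, g, s, gamma = split_head_params(raw, key_size=20, n_shifts=3)
print(gamma)  # → 1001.0
```

With $\hat{\gamma} = 1000$, the softplus gives a finite $\gamma = 1001$, whereas `np.exp(1000.0)` would overflow to `inf`. That is the kind of instability I had in mind above.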

Looking at equation (5) in the paper, if we let $\beta = 1$ and $K$ be the dot product, equation (5) reduces to a plain softmax over the dot products between the key and each memory row. In other words, (5) is a generalized form of the softmax function.
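To make the reduction concrete, here is a small NumPy sketch of (5) with the similarity measure passed in as a function. The helper names are hypothetical. The paper uses cosine similarity for $K$; swapping in a plain dot product with $\beta = 1$ gives back the ordinary softmax.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def cosine_similarity(memory, k, eps=1e-8):
    # K[k, M(i)] for every row i; eps guards against division by zero
    return memory @ k / (np.linalg.norm(memory, axis=1) * np.linalg.norm(k) + eps)

def content_weighting(memory, k, beta, similarity=cosine_similarity):
    """Equation (5): a softmax over beta-scaled similarities between k and each row."""
    return softmax(beta * similarity(memory, k))

def dot(memory, k):
    return memory @ k

rng = np.random.default_rng(0)
memory = rng.normal(size=(5, 4))
k = rng.normal(size=4)

# With beta = 1 and K = dot product, (5) is exactly softmax(M k).
print(np.allclose(content_weighting(memory, k, 1.0, dot), softmax(memory @ k)))  # → True
```

Raising $\beta$ concentrates the weighting on the best-matching row, so $\beta$ plays the role of an inverse temperature.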

It’s not entirely clear to me why (5) takes the softmax form while (9) doesn’t. My guess is that it has something to do with (9) operating on values that are already supposed to be probabilities.
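One observation seems to support this guess. For strictly positive weights, (9) is itself a softmax, applied to $\gamma$-scaled log-weights rather than to raw scores, because $\tilde{w}(i)^\gamma = \exp(\gamma \log \tilde{w}(i))$. Here is a quick NumPy check; the function names are hypothetical.

```python
import numpy as np

def sharpen(w, gamma):
    """Equation (9): raise each weight to gamma, then renormalize."""
    p = w ** gamma
    return p / p.sum()

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

w = np.array([0.1, 0.2, 0.7])   # an already-normalized weighting
gamma = 3.0
print(np.allclose(sharpen(w, gamma), softmax(gamma * np.log(w))))  # → True
```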


I initially used adadelta for training, but because I was unable to get good results, I tried implementing the version of rmsprop Graves used in his previous paper.

Graves’ variant of rmsprop is interesting, and slightly different from the one Hinton presents in his Coursera lecture. Instead of dividing the gradient by a running root mean square, it essentially divides by the standard deviation of the gradient over recent history. You can take a look at the code here.
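A minimal NumPy sketch of the variant as I understand it follows. It keeps running averages of both the gradient and its square, divides by $\sqrt{E[g^2] - E[g]^2 + \epsilon}$ (an estimate of the gradient's standard deviation), and adds a momentum term on top. Hinton's version divides by $\sqrt{E[g^2]}$ alone. The class name and default hyperparameters are illustrative assumptions to be tuned, not values I am claiming from the paper.

```python
import numpy as np

class GravesRMSProp:
    """Hypothetical name; divides by a running std. dev. estimate of the gradient.

    The default hyperparameters are illustrative assumptions, not tuned values.
    """

    def __init__(self, shape, lr=1e-4, decay=0.95, momentum=0.9, eps=1e-4):
        self.lr, self.decay, self.momentum, self.eps = lr, decay, momentum, eps
        self.n = np.zeros(shape)       # running mean of grad ** 2
        self.g = np.zeros(shape)       # running mean of grad
        self.delta = np.zeros(shape)   # momentum buffer

    def step(self, w, grad):
        self.n = self.decay * self.n + (1 - self.decay) * grad ** 2
        self.g = self.decay * self.g + (1 - self.decay) * grad
        # n - g^2 estimates the variance, so this divides by a std. dev.
        std = np.sqrt(self.n - self.g ** 2 + self.eps)
        self.delta = self.momentum * self.delta - self.lr * grad / std
        return w + self.delta

opt = GravesRMSProp(shape=1)
w = opt.step(np.zeros(1), np.ones(1))   # one step against a positive gradient
print(w[0] < 0)  # → True
```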

I’ve run into a lot of numerical-stability problems and have tried adjusting learning rates to cope with them. The paper mentions using gradient clipping, but since I had no easy way of implementing it, I left it out.
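Once the gradients are available as arrays, clipping itself is cheap. The sketch below shows two common forms in NumPy: clipping each component into a fixed range, and rescaling all gradients together when their combined norm is too large. I am not assuming which form or threshold the paper uses. The limits below are placeholders, and the function names are hypothetical.

```python
import numpy as np

def clip_elementwise(grads, limit=10.0):
    """Clip every gradient component into [-limit, limit] (placeholder limit)."""
    return [np.clip(g, -limit, limit) for g in grads]

def clip_by_global_norm(grads, max_norm=10.0):
    """Rescale all gradients together if their combined L2 norm exceeds max_norm."""
    total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-12))
    return [g * scale for g in grads]

grads = [np.array([30.0, -0.5]), np.array([[-12.0]])]
print(clip_elementwise(grads)[0])  # → [10.  -0.5]
```

Either function would sit between computing the gradients and calling the optimizer's update.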

One hack that seemed to work was ‘squeezing’ the sigmoids so that large activations don’t lead to $\log(0)$ in the cross-entropy:

$$\frac{\epsilon}{2} + (1-\epsilon)\cdot\sigma(x)$$

This keeps the outputs within $\left(\frac{\epsilon}{2},1-\frac{\epsilon}{2}\right)$, so the probabilities in the cross-entropy calculation are never zero.
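A minimal NumPy sketch of the squeezed sigmoid feeding a binary cross-entropy follows. The value $\epsilon = 10^{-6}$ is a hypothetical choice, and the sigmoid is written via `np.tanh` so that very large $|x|$ does not overflow `np.exp`.

```python
import numpy as np

def sigmoid(x):
    # tanh form avoids overflow in exp for large |x|
    return 0.5 * (1.0 + np.tanh(0.5 * x))

def squeezed_sigmoid(x, eps=1e-6):
    """Map sigmoid outputs into (eps/2, 1 - eps/2); eps here is a hypothetical choice."""
    return eps / 2 + (1 - eps) * sigmoid(x)

def binary_cross_entropy(targets, p):
    return -np.sum(targets * np.log(p) + (1 - targets) * np.log(1 - p))

x = np.array([-1000.0, 0.0, 1000.0])   # saturated activations at both ends
t = np.array([1.0, 1.0, 0.0])          # worst case: confident and wrong
print(np.isfinite(binary_cross_entropy(t, squeezed_sigmoid(x))))  # → True
```

With the plain `sigmoid` in place of `squeezed_sigmoid`, the same inputs hit $\log(0)$ and the loss becomes infinite.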

If you look at the file I use for training, you’ll see commented-out code from all the things I’ve tried.
