Vectorising The Inside Algorithm
This one goes by several names: CYK, Inside, Matrix chain ordering problem. Whatever you call it, the “shape” of the algorithm looks the same:
And it’s ultimately used to enumerate over all possible full binary trees. In the Matrix chain ordering problem, the tree defines the pairwise order in which matrices are multiplied,
$$(A(BC))(DE)$$
while CYK constructs a tree from the bottom up with Context-Free Grammar rules that would generate the observed sentence. The Inside algorithm simply computes the possible subtrees and their corresponding probabilities in order to give the probability of the sentence.
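To fix notation (my shorthand here, glossing over the grammar nonterminals): the inside score $\alpha(i, j)$ of a span is built from the scores of its two halves, combined over every possible split point $k$,

$$\alpha(i, j) = \sum_{k=i}^{j-1} \alpha(i, k)\,\alpha(k+1, j)\,s(i, k, j)$$

where $s(i, k, j)$ is whatever application-specific score goes with combining those halves. Swap the sum-product pair for max-product and you get Viterbi-style CYK; swap it for min-plus and you get matrix chain ordering.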
Who even does this anymore?
It’s been making a little bit of a comeback in deep learning, despite everything rushing headlong towards Transformer-based models.
For one thing, a lot of these dynamic-programming losses are differentiable if you relax them. DIORA, for instance, uses the predicted probabilities kinda like attention scores. Compound PCFG, on the other hand, is a bit more ‘traditional’: it manages to train a Neural PCFG augmented with a continuous latent variable.
There are more examples in recent literature, but you get the point: it’s not dead! (yet.)
It’s $O(N^3)$. Waaay too slow.
It is. Right now, the pseudo-code looks something like this:
for span_length in range(2, sentence_length + 1):
    for start in range(sentence_length - span_length + 1):
        end = start + span_length - 1
        current_val = None
        for split in range(span_length - 1):  # Innermost loop
            current_val = aggregator(
                current_val,
                f(start, split, end,
                  M[start, start + split],
                  M[start + split + 1, end])
            )
        M[start, end] = current_val
or, you can just refer back to that awesome GIF at the start of the post.
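If you'd rather run something than squint at pseudo-code, here is a toy instantiation of that loop (my own stand-in, ignoring grammar rules): f just adds the two subtree log-scores, and the aggregator is logsumexp.

import torch

def naive_inside(leaf_logprobs):
    # leaf_logprobs: (sentence_length,) log-scores for the length-1 spans.
    n = leaf_logprobs.shape[0]
    M = torch.full((n, n), float('-inf'))
    for i in range(n):
        M[i, i] = leaf_logprobs[i]
    for span_length in range(2, n + 1):
        for start in range(n - span_length + 1):
            end = start + span_length - 1
            # Innermost loop: one value per split point, then aggregate.
            vals = torch.stack([M[start, start + split] + M[start + split + 1, end]
                                for split in range(span_length - 1)])
            M[start, end] = torch.logsumexp(vals, dim=0)
    return M[0, n - 1]  # log "inside" score of the whole sentence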
We can make that a little better though.
What can we parallelise in this process?
Notice that, generally, applications of Inside do not have the “Innermost loop” depend on previous iterations of the loop.
The aggregator is typically a sum, max, or logsumexp.
This means all the separate f(start, split, end, M[start, split], M[split+1, end]) calls can be computed in parallel:
This is fairly convenient for a GPU, as we regularly compute max or mean pool operations over tensors. $O(N^2)$!
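In code, the innermost loop then collapses into a single vectorised reduction per cell. A sketch (pooled_cell is my own stand-in name, and it assumes f is itself vectorised over the split dimension):

def pooled_cell(M, start, end, span_length, f, pool):
    # All split points for the (start, end) cell, handled in one shot.
    splits = start + torch.arange(span_length - 1)
    vals = f(start, splits, end,
             M[start, splits],      # all left subtrees at once
             M[splits + 1, end])    # all right subtrees at once
    return pool(vals)               # e.g. torch.sum, torch.max, or a logsumexp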
But we can go further. Notice there’s some similarity here to how convolutions work: as long as the computations don’t depend on each other, they can be parallelised.
The writing operations of each row do not depend on other elements in that row. This means that the computation of the row can also be parallelised: for all possible spans of span_length, pool across all possible splits, at the same time.
Or in GIF-ese,
We’re at $O(N)$ now, if you consider each parallel operation constant time. There’s no escaping the $O(N^3)$ complexity in memory though, since all possible intermediate values have to be kept in order to compute the gradients through the algorithm.
Ok, but in PyTorch?
Well, for starters, we can compute all start and end indices in parallel:
import torch

def start_end_idxs(sentence_length, span_length):
    start = torch.arange(sentence_length - span_length + 1)
    end = start + span_length - 1
    return start, end
start_end_idxs(10, 2)
# (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]),
# tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))
start, end = start_end_idxs(10, 4)
# (tensor([0, 1, 2, 3, 4, 5, 6]),
# tensor([3, 4, 5, 6, 7, 8, 9]))
Great. Now we need to deal with the splits (Innermost loop). What should this look like?
Well, for each row in this triangle, we care about one span_length. For each span_length, we enumerate through the possible spans in the sentence. This gives us the above start and end vectors of indices, each of dimension (sentence_length - span_length + 1,). For each index in start and end, there are span_length - 1 possible splits: from start to start + span_length - 2.
This gives us a clue that the dimensions should look something like:
(sentence_length - span_length + 1, span_length - 1)
def split_idxs(sentence_length, span_length):
    start = torch.arange(sentence_length - span_length + 1)
    splits = torch.arange(span_length - 1)
    split_idxs = start[:, None] + splits
    return split_idxs
split = split_idxs(10, 4)
# tensor([[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4],
# [3, 4, 5],
# [4, 5, 6],
# [5, 6, 7],
# [6, 7, 8]])
Alright, so let’s list what we have up to this point:
- start: (sentence_length-span_length+1,)
- end: (sentence_length-span_length+1,)
- split: (sentence_length-span_length+1, span_length-1)
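If you want to double-check that the shapes line up as advertised:

start, end = start_end_idxs(10, 4)
split = split_idxs(10, 4)
assert start.shape == (7,) and end.shape == (7,)  # sentence_length - span_length + 1
assert split.shape == (7, 3)                      # (..., span_length - 1)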
How do we use these indices?
In the naive version, M[i,j] stores the value that represents a span starting at i and ending at j.
It then composes the values at M[start, start + split] (we’ll call this the left subtree) and M[start + split + 1, end] (the right subtree) in an application-specific way, before aggregating everything along those splits, again in an application-specific way.
We can extract the values for all possible splits too, which should result in two tensors with dimensions (sentence_length-span_length+1, span_length-1), one for the left subtree and one for the right subtree.
For illustration purposes, let’s fill our intermediate M table with these values:
M = torch.tensor(
    [[ 0,  1,  2,  0,  0,  0,  0,  0,  0,  0],
     [ 0, 11, 12, 13,  0,  0,  0,  0,  0,  0],
     [ 0,  0, 22, 23, 24,  0,  0,  0,  0,  0],
     [ 0,  0,  0, 33, 34, 35,  0,  0,  0,  0],
     [ 0,  0,  0,  0, 44, 45, 46,  0,  0,  0],
     [ 0,  0,  0,  0,  0, 55, 56, 57,  0,  0],
     [ 0,  0,  0,  0,  0,  0, 66, 67, 68,  0],
     [ 0,  0,  0,  0,  0,  0,  0, 77, 78, 79],
     [ 0,  0,  0,  0,  0,  0,  0,  0, 88, 89],
     [ 0,  0,  0,  0,  0,  0,  0,  0,  0, 99]])
(we only care about stuff in the upper triangle of the table)
If we’re computing values for span_length = 4, then for the first inner loop we need [0, 1, 2] on the left and, correspondingly, [13, 23, 33] on the right.
To extract all of them at once, we can fall back on some indexing black magic:
l_vals = M[start[:, None], split]
# tensor([[ 0, 1, 2],
# [11, 12, 13],
# [22, 23, 24],
# [33, 34, 35],
# [44, 45, 46],
# [55, 56, 57],
# [66, 67, 68]])
r_vals = M[split + 1, end[:, None]]
# tensor([[13, 23, 33],
# [24, 34, 44],
# [35, 45, 55],
# [46, 56, 66],
# [57, 67, 77],
# [68, 78, 88],
# [79, 89, 99]])
Yes! Indices broadcast too!
Now we can aggregate some function of l_vals and r_vals across dim=1, which will result in a (sentence_length-span_length+1,) tensor.
These values should be written into M[start, end].
Since start and end are themselves tensors of indices, this will give you the corresponding cells to fill in the next diagonal row.
M[start, end] = aggregate(start, end, split, l_vals, r_vals)
So you can now implement aggregate, and changing that implementation will result in the different algorithms listed above.
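For a concrete (and deliberately boring) example, here is what a scalar, log-space Inside-style aggregate and the assembled row-parallel loop could look like, reusing the index helpers from above. This is a sketch with my own naming, not the minibatched implementation linked below:

def aggregate(start, end, split, l_vals, r_vals):
    # Inside-style: add the log-scores of the left and right subtrees,
    # then logsumexp over all possible splits (dim=1).
    # start, end, split are unused here but kept to match the signature above.
    return torch.logsumexp(l_vals + r_vals, dim=1)

def inside(leaf_logprobs):
    # leaf_logprobs: (sentence_length,) log-scores for the length-1 spans.
    n = leaf_logprobs.shape[0]
    M = torch.full((n, n), float('-inf'))
    M[torch.arange(n), torch.arange(n)] = leaf_logprobs
    for span_length in range(2, n + 1):          # the only remaining Python loop
        start, end = start_end_idxs(n, span_length)
        split = split_idxs(n, span_length)
        l_vals = M[start[:, None], split]
        r_vals = M[split + 1, end[:, None]]
        M[start, end] = aggregate(start, end, split, l_vals, r_vals)
    return M[0, n - 1]                           # log "inside" score of the full sentence

Swapping that logsumexp for a max over dim=1 would give you Viterbi-style CYK scores instead.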
I’ve written up a simple version of this with minibatching here. Some things to consider in a v2.0:
- What if your minibatch has varying lengths? How would masking work?
- Say your combination function is:

      lin_1 = nn.Linear(hidden_size, hidden_size)
      lin_2 = nn.Linear(hidden_size, hidden_size)

      def op(l_val, r_val):
          return torch.tanh(lin_1(l_val) + lin_2(r_val))

  Is there some redundancy here? Are there ways to reduce that redundancy?
- M is huge if you have a hidden representation that is large (we spoke here about scalars, but in DIORA it’s hidden layers of a “soft” tree structure). Can we do better than $N \times N$?
If we’re scaling up models like we do these days, it’s worth putting some thought into speedups, especially with novel architectures that involve algorithms like this one. Hopefully some of the techniques here can be useful to you in thinking about parallelisation with frameworks like PyTorch.
@misc{tan2021-10-08,
title = {Vectorising The Inside Algorithm},
author = {Tan, Shawn},
howpublished = {\url{https://blog.wtf.sg/posts/2021-10-08-vectorising-the-inside-algorithm/}},
year = {2021}
}