Parallel scans

This post is inspired by the recent paper on Mamba. Mamba introduces a simplified, linear RNN and shows that it can be computed in \(\mathcal{O}(\log n)\) time using a parallel scan. It’s not immediately obvious how the parallel scan algorithm can be applied to this recurrence, so I set out to understand the approach and see if it could be generalized.


The parallel scan

We have a sequence \(x_0, \ldots, x_{n-1}\) and we want to compute the sequence \(y_0, \ldots, y_{n-1}\), where \(y_0\) is given and, for \(i > 0\), \(y_i = f_i(x_i, y_{i-1})\).

If \(y_0 = 0\) and \(f_i(x,y) = x + y\), the scan computes the prefix sums, and due to the associativity of addition, it can be computed in parallel. Compute, in parallel, the prefix sums for the first half and for the second half. Then add the last element of the first half (the sum) to each element of the second half, which can be done entirely in parallel. With enough parallel processors, this means the prefix sum can be computed in \(\mathcal{O}(\log n)\).
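As a concrete illustration, here is a minimal recursive sketch in Python/NumPy (the names are mine). The recursion is written sequentially, but the two recursive calls are independent and the final shift is elementwise, so both parallelize:

```python
import numpy as np

def parallel_prefix_sum(x):
    """Divide-and-conquer prefix sum. The two recursive calls are
    independent and the final shift is elementwise, so both parallelize."""
    n = len(x)
    if n == 1:
        return x.copy()
    left = parallel_prefix_sum(x[: n // 2])
    right = parallel_prefix_sum(x[n // 2 :])
    # Add the total of the first half to every element of the second half.
    return np.concatenate([left, right + left[-1]])

x = np.arange(1, 9)
assert np.array_equal(parallel_prefix_sum(x), np.cumsum(x))
```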

A more complicated example

Let’s consider a more complicated recurrence. We will make the \(x_i\) and the \(y_i\) vectors, and take \(f_i(x, y) = A_i x + B_i y\) for index-dependent matrices \(A_i\) and \(B_i\).

A priori, this is not an associative operation… But let’s try anyway. We’ll process the first half starting from \(y_0\), and the second half starting from some guess for \(y_{n/2-1}\). If we knew the real \(y_{n/2-1}\), it would be easy, but we don’t, and guessing doesn’t really help: if the guess for \(y_{n/2-1}\) changes, we have to recompute the whole scan for the second half.

The trick is to lift the sequence \(y\). Notice that, given the \(A_i\), \(B_i\), and \(x_i\), each \(y_i\) is an affine function of \(y_j\) for \(j < i\). But we know how to represent affine maps between vectors: they are matrices (in homogeneous coordinates, to absorb the constant term). So if, instead of calculating the sequence of \(y_i\), we compute the sequence of \(Y_i\) representing the function that gives \(y_i\) as a function of \(y_0\), then we are getting somewhere!

Indeed, we can compute the second half of the scan as the maps \(Y_{n/2}, Y_{n/2+1}, \ldots\) expressing each \(y_i\) in terms of \(y_{n/2-1}\) and, once the first half has completed, apply each of them to \(y_{n/2-1}\), which can be done entirely in parallel.
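Here is a minimal sketch of this lifting in Python/NumPy (all names are mine). Each step is stored as an affine map, a pair \((M, c)\) meaning \(y \mapsto M y + c\); composing such pairs is associative, which is exactly what a parallel scan needs. The fold below is sequential for clarity, but it could be replaced by a tree-shaped reduction:

```python
import numpy as np

def affine_compose(f, g):
    """Compose affine maps: f after g, where (M, c) means y -> M @ y + c."""
    M_g, c_g = g
    M_f, c_f = f
    return (M_f @ M_g, M_f @ c_g + c_f)

def lifted_scan(As, Bs, xs, y0):
    """y_i = A_i x_i + B_i y_{i-1}, computed through composed affine maps."""
    acc = (np.eye(len(y0)), np.zeros_like(y0))  # identity map
    ys = []
    for A, B, x in zip(As, Bs, xs):
        acc = affine_compose((B, A @ x), acc)   # associative step
        M, c = acc
        ys.append(M @ y0 + c)                   # apply the lifted map to y0
    return np.array(ys)

# Check against the naive sequential recurrence.
rng = np.random.default_rng(0)
n, d = 8, 3
As, Bs = rng.normal(size=(n, d, d)), rng.normal(size=(n, d, d))
xs, y0 = rng.normal(size=(n, d)), rng.normal(size=d)
y, naive = y0, []
for i in range(n):
    y = As[i] @ xs[i] + Bs[i] @ y
    naive.append(y)
np.testing.assert_allclose(lifted_scan(As, Bs, xs, y0), np.array(naive))
```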

Generalization

In fact, this lifting applies generically to any function \(f_i\)! The algorithm is the following:

  • Compute the first half of the sequence.
  • For the second half of the sequence, compute \(F_i : y \rightarrow f_i(x_i, F_{i-1}(y))\), starting from the identity function, then apply the \(F_i\) to the last value of the first half once it is known (a sketch follows this list).
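A minimal Python sketch of this general lifting (all names are mine), with an important caveat: black-box closures do not actually save work, since evaluating \(F_i\) still costs \(\mathcal{O}(i)\). The savings require a compact representation of the \(F_i\), which is the subject of the rest of the post.

```python
def two_half_scan(fs, xs, y0):
    """Scan y_i = f_i(x_i, y_{i-1}). The two loops below are independent,
    so they could run in parallel on two workers."""
    n = len(xs)
    # First half: ordinary sequential scan from y0.
    ys, y = [], y0
    for i in range(n // 2):
        y = fs[i](xs[i], y)
        ys.append(y)
    # Second half, lifted: F_i(y) = f_i(x_i, F_{i-1}(y)), from the identity.
    Fs, F = [], lambda y: y
    for i in range(n // 2, n):
        F = (lambda f, x, G: lambda y: f(x, G(y)))(fs[i], xs[i], F)
        Fs.append(F)
    # Once the midpoint value is known, applying the F_i is fully parallel.
    midpoint = ys[-1]
    ys.extend(F_i(midpoint) for F_i in Fs)
    return ys

# Toy usage with a nonlinear recurrence.
import math
fs = [lambda x, y: math.tanh(x + 0.5 * y)] * 8
print(two_half_scan(fs, [0.1 * i for i in range(8)], 0.0))
```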

In the linear state-space model, the functions collapse and simplify, and we have a neat representation for them: matrices. In the general case, those functions do not simplify, but perhaps we can approximate them?

For instance, in the one-dimensional case, we could posit that \(F_i\) is well approximated by a Chebyshev polynomial and maintain each \(F_i\) as the values it takes at the Chebyshev nodes.
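This representation has a pleasant property: updating \(F_i\) from \(F_{i-1}\) is just a pointwise application of \(f_i\) to the stored node values, since \(F_i(t_j) = f_i(x_i, F_{i-1}(t_j))\). A hypothetical one-dimensional sketch in Python/NumPy (the toy \(f_i\) and all names are mine; I assume the state stays in \([-1, 1]\) so the node domain remains valid):

```python
import numpy as np
from numpy.polynomial import chebyshev as C

deg = 16
# Chebyshev nodes of the first kind on [-1, 1].
nodes = np.cos(np.pi * (np.arange(deg + 1) + 0.5) / (deg + 1))

# Represent F by its values at the nodes; start from the identity function.
values = nodes.copy()

# Toy f_i: tanh keeps the state inside (-1, 1), so the nodes still cover it.
f = lambda x, y: np.tanh(0.9 * y + x)

# Lifting step: F_i(nodes) = f(x_i, F_{i-1}(nodes)) is a pointwise update.
for x in [0.1, -0.3, 0.2]:
    values = f(x, values)

# To evaluate F at an arbitrary point, interpolate through the node values.
coeffs = C.chebfit(nodes, values, deg)
print(C.chebval(0.5, coeffs))
```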

The Chebyshev trick can be extended to the multivariate case, but the size of the representation unfortunately blows up with the dimension \(d\): a polynomial of degree \(n\) in \(d\) variables has \(\binom{n+d}{d}\) coefficients.

Alternatively, we could start with a random cloud of points, propagate it through the \(f_i\), and assume that \(F\) is well approximated by a Gaussian process over those points, moving the points as we go to better represent the function, for example slightly towards zones of higher curvature.

Gradients and conclusion

Another matter is how to deal with gradients. Since the \(f_i\) are known, it’s straightforward to compute the gradient with respect to their parameters. However, these are the gradients of the exact scan. Should we use those gradients, or the gradient of our approximation? We know the cloud of points at step \(i-1\), so we know how a change in our approximation of \(y_{i-1}\), or a change in \(f_i\), would change our approximation of \(y_i\).

Ultimately, this whole field seems to be more of an experimental science than anything else, and mathematics is just there to suggest reasonable things worth trying. I believe this is one of them.

It also gives a framework to detect parallel scans: do the functions \(F\) belong to some set with a simple parametrization that is stable under composition with the \(f_i\)? Affine functions and simple categorical functions are the only examples I can think of, but there could be more.
