Exploring Beyond Regular Transformers
A review of transformer variants: RWKV, Longnet, MegaByte, Hyena and more...
In this article, I will explore various alternatives to transformers, considering their architectural improvements, computational efficiency, and performance results across different benchmarks. I intend to continually update this post with new models in the future. If you believe there are any models or important points that should be included or any corrections that need to be made, please feel free to reach out.
Transformer
Space: O(T^2 + Td)
Time: O(T^2 d)
Traditional sequential models, like recurrent neural networks (RNNs) and long short-term memory networks (LSTMs), face challenges in effectively capturing long-range dependencies and parallelizing computations. The Transformer architecture addresses these issues by relying on self-attention mechanisms.
At the core of the Transformer is the self-attention mechanism. Unlike traditional approaches, where each element in a sequence is processed one at a time, self-attention allows the model to weigh the importance of different elements relative to each other. This enables capturing relationships between distant words in a sentence.
The Transformer has some limitations and constraints in terms of computation and storage. It is based on dot-product attention that computes softmax(Q*K.t), which is computationally heavy, and it needs to store a KV cache that is also heavy in memory at inference. This is a limiting factor, especially in problems with extended context sizes. The Transformer's space complexity increases quadratically with the context size.
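To make this concrete, below is a minimal PyTorch sketch of single-head scaled dot-product attention (my own illustration, not any particular implementation); the (T, T) score matrix is the term that grows quadratically with the context length:

import torch

def dot_product_attention(q, k, v):
    # q, k, v: (T, d). Materializing the (T, T) score matrix is what makes both
    # time and memory grow quadratically with the context length T.
    scores = q @ k.T / (k.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

At inference, the keys and values of all previous tokens are kept in the KV cache so that only the new token's scores need to be computed, which is why the cache itself grows linearly with the context length.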
The Transformer is a key component of the current LLM revolution, and researchers are actively seeking alternatives to address its limitations. While several alternatives have been proposed, none has yet been as successful as the original model. Nevertheless, considering the scale of state-of-the-art LLMs and the high cost of training these models, even a slight improvement can have a significant impact.
RWKV
Space: O(Td)
Time: O(Td)
RWKV is a new approach that combines the advantages of RNNs and Transformers while mitigating their known limitations. It introduces several key strategies that allow it to capture local and long-range dependencies. RWKV offers a promising and viable solution for modeling with billions of parameters, exhibiting competitive performance at a fraction of the computational cost.
The aim of RWKV is to harness the advantages of both RNNs and Transformers while addressing their shortcomings. In comparison to RNNs, RWKV provides more efficient parallelizable training and improved performance in capturing long-range dependencies. This is achieved by eliminating the reliance on a single vector to transmit the context between different time steps.
Compared to Transformers, RWKV offers linear attention and constant computational and memory complexity during inference, making it more efficient for large-scale models.
There are two primary components of a RWKV block: time-mixing and channel-mixing. Time-mixing operates by using linear interpolation to blend the current input with the input from the previous time step. This process effectively combines and controls the information in the input channels. The time-mixing block is composed of three equations that compute the values of r, k, and v at each time step, which are then used to calculate WKV which plays the role of Transformer's attention.
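Below is a minimal recurrent sketch of time-mixing in PyTorch. This is my own simplification: parameter names such as mu_r, W_r, the decay w, and the bonus u follow the paper, and the numerical-stabilization tricks of the real implementation are omitted:

import torch

def rwkv_time_mixing_step(x_t, x_prev, state, p):
    # Linear interpolation (token shift) between the current and previous input.
    r = torch.sigmoid((p.mu_r * x_t + (1 - p.mu_r) * x_prev) @ p.W_r)
    k = (p.mu_k * x_t + (1 - p.mu_k) * x_prev) @ p.W_k
    v = (p.mu_v * x_t + (1 - p.mu_v) * x_prev) @ p.W_v
    num, den = state                               # exponentially decayed sums over past tokens
    wkv = (num + torch.exp(p.u + k) * v) / (den + torch.exp(p.u + k))
    # Decay the state and fold in the current token for the next step.
    state = (torch.exp(-p.w) * (num + torch.exp(k) * v),
             torch.exp(-p.w) * (den + torch.exp(k)))
    return (r * wkv) @ p.W_o, state                # sigmoid(r) gates the WKV "attention"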
Channel-mixing aids in capturing local information effectively. The channel-mixing block comprises three equations that compute the values of r, k, and o at each time step, which together produce the final output vector. The sigmoid of the receptance r is used as a "forget gate" to eliminate unnecessary historical information, and the final output vector is computed by multiplying the sigmoid of r with the result of a squared ReLU activation on k.
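A corresponding sketch of channel-mixing under the same simplifying assumptions:

import torch

def rwkv_channel_mixing_step(x_t, x_prev, p):
    # Same token-shift interpolation as in time-mixing.
    r = torch.sigmoid((p.mu_r * x_t + (1 - p.mu_r) * x_prev) @ p.W_r)
    k = (p.mu_k * x_t + (1 - p.mu_k) * x_prev) @ p.W_k
    return r * (torch.relu(k) ** 2 @ p.W_v)        # sigmoid(r) acts as the forget gate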
RWKV has certain limitations. For instance, it may struggle with tasks that require recalling information over a long context. This is because RWKV funnels information through a limited state carried between time steps, whereas Transformers have access to all the information at each step through attention. Another limitation is the increased importance of prompt engineering in RWKV: the linear attention mechanism restricts how much of the prompt information is passed on to the model. Empirical evidence supports this, showing that when prompts were adjusted to suit RWKV, there was a significant increase in performance, with the F1 measure improving from 44.2% to 74.8%.
The results demonstrate that RWKV delivers impressive performance and outperforms other models. Nevertheless, when tasks require a greater emphasis on context, RWKV's performance tends to decline.
RWKV is developed fully in the open. They provide many models for different languages and tasks, all under the Apache-2.0 license. The community is very active, and efforts are already underway to address the limitations of RWKV. You can join their Discord to participate in the process.
Hyena
Time: O(NdT(log T + d))
s.t. N is the number of projections
Space: O(Td)
Hyena addresses the attention in Transformers, which becomes computationally expensive with longer sequences. Hyena offers a subquadratic alternative to attention by combining long convolutions with data-controlled gating. In various tasks involving recall and reasoning with sequences containing thousands to hundreds of thousands of tokens, Hyena has demonstrated significant improvements in accuracy. It achieves Transformer-level quality while reducing required compute in training by 20% at a sequence length of 2K. Notably, Hyena operators are also faster, offering twice the efficiency of highly optimized attention operators.
Hyena first projects the input into a set of vectors v, x_1, ..., x_n, where v acts like the value vector in attention. It then processes these projections with learnable filters h_1, ..., h_n. Hyena applies a multiplicative gating interaction to the projected vectors, similar to LSTMs. This gating is used to control the information flow through the recurrence.
Hyena uses an implicit long convolution to the gated input, using a set of Hyena filters that are parametrized by a feedforward network. This convolution is used to capture long-range dependencies in the input.
Below is the overall Hyena operator in Python as described in the blog post:
def hyena_filter(t):
    # implicit filter: positional encodings passed through an FFN, modulated by a decaying window
    return window(t) * ffn(t) * positional_encoding(t)

x, v = input_projections(u)   # v acts as the value; x holds the gating projections
for o in range(hyena_orders):
    h = hyena_filter(L)       # long conv filter parameterized via an MLP
    v = x[o] * fftconv(h, v)  # elem-wise mult & fftconv
Regarding language modeling, Hyena is compared to GPTNeo and RWKV. Hyena outperforms both in few-shot learning, but RWKV achieves better zero-shot accuracy on SuperGLUE tasks. Moreover, Hyena performs on par with a Transformer on language modeling on WikiText103.
Regarding runtime, the cross-over point between Hyena and attention occurs at a sequence length of 2048, while the cross-over with FlashAttention lies between 4096 and 8196.
My 2 cents: Hyena is an interesting approach for extending input length through scalable computing. Nonetheless, further investigations on a larger scale are necessary to confirm its efficacy as a viable alternative to the Transformer model. For now, the RWKV model offers better value in terms of both complexity and performance.
Attention Free Transformer
Time: AFT-simple O(Td), AFT-full O(T^2 d)
Space: O(Td)
📎 Paper
👩💻 Code (unofficial)
Attention Free Transformer (AFT) eliminates the need for dot product self-attention, making it scalable with long inputs and large model sizes. AFT takes advantage of locality and spatial weight sharing while maintaining global connectivity, resulting in excellent efficiency. The paper presents experiments on autoregressive modeling tasks and image recognition, demonstrating competitive performance compared to other models.
AFT is a weighted average over values, combined with the queries by element-wise multiplication instead of a heavy attention matrix. In an Attention Free Transformer (AFT) layer, learned position biases are added to the keys, the result is combined with the values by element-wise multiplication, and the outcome is finally multiplied with the queries element-wise. Thus, it avoids the computationally heavy softmax(Q*K.t) of a Transformer. "AFT can be interpreted as performing implicit attention with as many heads as feature dimensions, where the attention matrices take a factorized form."
There are four different versions of AFT. The first version is AFT-simple, which does not utilize position encoding. The second version is AFT-full, which includes regular position encoding. The third version is AFT-local, incorporating a learned set of relative position biases within a specified window. The fourth version is AFT-conv, which utilizes depth-wise separable convolution and is proposed especially for image tasks.
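Below is a minimal, non-causal sketch of AFT-full in PyTorch (my own simplification; the actual implementation reparametrizes the position biases and computes the weights in a numerically stable way):

import torch

def aft_full(q, k, v, w):
    # q, k, v: (T, d); w: (T, T) learned pairwise position biases.
    # The softmax(Q*K.t) matrix is replaced by weights built only from the keys
    # and the position biases, combined with the values element-wise.
    weights = torch.exp(w).unsqueeze(-1) * torch.exp(k).unsqueeze(0)    # (T, T, d)
    context = (weights * v.unsqueeze(0)).sum(dim=1) / weights.sum(dim=1)
    return torch.sigmoid(q) * context                                   # queries gate the result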
In terms of results, the paper shows that AFT achieves comparable or better accuracy than traditional Transformers on various autoregressive modeling tasks and image recognition tasks while using much smaller memory footprints. AFT also outperforms other efficient Transformer variants such as Linformer and Performer. The paper also demonstrates the effectiveness of AFT on variable-length inputs and shows that it is well-suited for pretraining and finetuning workflows in vision tasks.
My 2 cents: AFT shows potential as a substitute for conventional Transformers. It substantially reduces computational requirements and memory usage, all while maintaining high performance. Moreover, AFT serves as the foundation for the development of both Hyena and RWKV.
Retentive Network
Time: O(Td(b + h))
s.t. b is the chunk size and h is the head dimension
Space: O(T)
📎 Paper
👩💻 Official Code
👩💻 Code 1
👩💻 Code 2
RetNet borrows recurrent inference from RNNs and parallel training from the Transformer, combining them to achieve an efficient model. Recurrent models facilitate O(1) per-token inference as they do not require modeling the relationship between each input and every other input in the sequence. RetNet applies chunk-wise recurrence to alleviate the representational bottleneck of RNNs while remaining efficient at inference.
RetNet introduces a novel approach that replaces softmax attention: the attention scores are combined with a decay matrix D through a Hadamard product and normalized with a GroupNorm operation, which determines the relative weight assigned to each token in the input sequence.
In RetNet, training and inference use different computation graphs that compute the same math: a parallel formulation is used during training, while a recurrent formulation is used at inference.
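To illustrate this duality, here is a minimal single-head PyTorch sketch with a scalar decay gamma; it omits the xPos-style rotation, scaling, and GroupNorm used in the paper:

import torch

def retention_parallel(q, k, v, gamma):
    # q, k, v: (T, d). D is the causal decay mask: D[n, m] = gamma**(n - m) for n >= m, else 0.
    T = q.shape[0]
    n = torch.arange(T)
    delta = (n[:, None] - n[None, :]).float()
    D = (gamma ** delta.clamp(min=0)) * (delta >= 0)
    return (q @ k.T * D) @ v                       # (Q K^T ⊙ D) V, no softmax

def retention_recurrent(q, k, v, gamma):
    # The same computation as an RNN: S_n = gamma * S_{n-1} + k_n^T v_n, out_n = q_n S_n.
    # This form is what enables O(1)-per-token inference.
    S = torch.zeros(q.shape[1], q.shape[1])
    out = []
    for n in range(q.shape[0]):
        S = gamma * S + k[n][:, None] @ v[n][None, :]
        out.append(q[n][None, :] @ S)
    return torch.cat(out, dim=0)

Both functions return the same result up to floating-point error: the first is the parallel form used for training, the second the recurrent form used at inference.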
I suggest you check this post by Shantanu Chandra who did a better job than the paper explaining how things work.
When we compare RetNet to attention-free transformers and RWKV, it retains the element-wise interactions in the sequence with the retention operation. It keeps a high-dimensional state of the encoded sequence information, which they claim contributes to the model performance.
Results show that beyond ~2.7B parameters, RetNet achieves lower perplexity and outperforms the Transformer. Most of the results are reported based on the 6.7B model. RetNet is significantly better than the Transformer at this scale in zero-shot and few-shot learning.
RetNet replaces the KV cache of Transformers with recurrence and saves memory. Also, chunk-wise retention makes inference significantly scalable with increasing batch size and input length.
They also show that RetNet is computationally far more efficient than the Transformer and almost on par with Transformer + FlashAttention-1 (a comparison with FlashAttention-2 is still needed). Results show that it uses 3.4x lower memory, 8.4x higher throughput, and 15.6x lower latency compared to a Transformer model.
When compared to the other Transformer alternatives, RetNet outperforms all the different models by a big margin on language modeling.
Longnet
Time: O(Td)
Space: O((T/r) log(T/r) d)
s.t. r is the attention dilation rate
LONGNET is designed to tackle longer sequence problems. It can handle sequences with over 1 billion tokens while maintaining performance on shorter sequences. This is accomplished through dilated attention, which enhances the model's ability to attend to distant tokens. LONGNET has advantages such as linear time complexity, and the capability to serve as an efficient distributed trainer for long sequences. Experiments confirm its effectiveness.
To simplify the self-attention layers, LONGNET utilizes dilated attention. This approach involves dividing the input sequence into segments and dilating each segment at a specific rate. The model uses different segment lengths and dilation rates to improve its modeling abilities. The outputs of each segment size and dilation rate pair are then combined through a weighted sum. These weights are determined based on the softmax denominators of each output. The combination of using segments and dilated attention strikes a balance between considering the global context and maintaining efficiency, as dilated attention serves as an efficient approximation of the dense attention matrix.
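As a rough illustration, the sparse positions for one (segment length w, dilation rate r) pair could be selected as follows (a hypothetical helper, not the paper's implementation); attention is then computed only among the selected positions within each segment, and the outputs of the different (w, r) pairs are mixed with weights derived from their softmax denominators:

import torch

def dilated_segment_indices(T, w, r):
    # Split positions 0..T-1 into segments of length w and keep every r-th position
    # in each segment. Attending only within these sparse segments is what brings
    # the dense O(T^2) attention cost down toward linear.
    indices = []
    for start in range(0, T, w):
        segment = torch.arange(start, min(start + w, T))
        indices.append(segment[::r])
    return indices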
LONGNET employs two more tricks for better modeling. It incorporates varying dilation rates in each attention head for more diversity. It also gradually increases the segment lengths and dilation rates in successive layers, allowing the model to process extremely long input sequences with an increasing receptive field in later layers.
To train LONGNET on 1 billion tokens, distributed training is necessary. LONGNET divides the inputs into segments, which are then distributed across different GPUs. These segments are processed simultaneously, with a constant communication overhead.
They used the Stack dataset, a source-code collection covering over 300 programming languages, to test the model. They showed that LONGNET outperforms a vanilla Transformer by a large margin in both perplexity and computational cost. They were able to train LONGNET with a 32k context size, whereas the Transformer could only be trained with 16k.
My 2 cents: Consider using LONGNET when processing or streaming long sequences.
MegaByte
Time: O(T^(4/3) d)
Space: O(T log Td)
MEGABYTE is a "multiscale decoder architecture that enables end-to-end differentiable modeling of sequences of over one million bytes". MEGABYTE learns from raw bytes, requiring the ability to effectively capture a lengthy context. To achieve this, it divides sequences into patches and employs a local model for each patch, while also incorporating a global model between patches. By doing so, MEGABYTE achieves sub-quadratic attention, facilitates larger feedforward layers without incurring additional computational costs, and enhances parallelism during decoding. As a result, MEGABYTE delivers improved performance for training and generation efficiently.
MEGABYTE offers several advantages, including sub-quadratic self-attention, per-patch feedforward layers, and parallel decoding. The sub-quadratic self-attention is achieved by dividing the input into smaller "patches", which reduces the self-attention cost to O(T^(4/3) d).
They note that in a Transformer, the feedforward layers consume about 98% of the FLOPs. MEGABYTE addresses this by replacing multiple passes of these layers with a single pass, utilizing a larger linear layer.
Furthermore, the use of patches also introduces a level of parallelism. As a result, they found that their 1.5B parameter model is 40% faster than a 350M Transformer model.
The MEGABYTE system is composed of three main components (a rough sketch follows the list):
a patch embedder, which converts the patch sequences into a representation that considers the context;
a global Transformer that encodes the contextualized inputs;
a local Transformer that takes each output from the global model and predicts the output tokens in an auto-regressive manner.
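A rough sketch of how the three parts fit together (my own simplification; patch_embed, global_model, and local_model are hypothetical stand-ins for the actual modules, and the shift between input and target bytes is omitted):

def megabyte_forward(byte_ids, patch_size, patch_embed, global_model, local_model):
    # 1) Patch embedder: embed each byte and group the embeddings into patches.
    T = byte_ids.shape[0]                          # assumes T is a multiple of patch_size
    emb = patch_embed(byte_ids)                    # (T, d)
    patches = emb.reshape(T // patch_size, -1)     # (T/patch_size, patch_size * d)
    # 2) Global Transformer: attends across patches, a sequence patch_size times shorter.
    context = global_model(patches)                # (T/patch_size, patch_size * d)
    # 3) Local Transformer: predicts the bytes inside each patch auto-regressively,
    #    conditioned on that patch's global context; patches can be decoded in parallel.
    local_inputs = context.reshape(T // patch_size, patch_size, -1)
    return local_model(local_inputs)               # per-byte logits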
MEGABYTE is applied to language modeling, image modeling, and audio modeling. The cool thing is that it is trained on raw byte values (hence the name). It is compared to PerceiverAR and a Transformer baseline. In all tasks, it outperforms both and is competitive with models that use tokenizers to discretize the input.
The ablation analysis shows the importance of having both local and global models. If one of these components is absent, there is a notable decline in performance.
My 2 cents: I find learning from raw bytes and utilizing multi-stage transformers intriguing. This approach can potentially revolutionize large language models (LLMs). By eliminating tokenization models, we can bridge the gap between computers and models, paving the way for developing a new generation of LLM-based operating systems.
In addition, I'd like to try MEGABYTE for text-to-speech. I believe it is well-suited to learn local and global relations better than Transformers for TTS.
Edit: Looks like UniAudio did it.
Noteworthy Mentions
Here are a few other noteworthy models that I won't delve into further since they have yet to gain much traction in the community or are simple tricks that don't require much explanation.
Multi-Query Attention
Using shared key and value vectors among attention heads reduces the memory overhead at inference by reducing the size of the KV cache.
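A minimal PyTorch sketch of what sharing the keys and values across heads looks like for a single example:

import torch

def multi_query_attention(q, k, v):
    # q: (T, heads, d_head); k, v: (T, d_head), shared by every head.
    scores = torch.einsum("thd,sd->ths", q, k) / (k.shape[-1] ** 0.5)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("ths,sd->thd", probs, v)   # the KV cache shrinks by a factor of heads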
Linformer
A linear time self-attention is achieved by breaking down the scaled dot-product attention into multiple smaller attentions using linear projections. Together, these operations create a low-rank factorization of the original attention mechanism.
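A minimal single-head PyTorch sketch of the idea, where E and F stand for the learned projections that compress the keys and values along the sequence dimension:

import torch

def linformer_attention(q, k, v, E, F):
    # q, k, v: (T, d); E, F: (k_dim, T) with k_dim much smaller than T.
    k_proj, v_proj = E @ k, F @ v                                        # (k_dim, d)
    scores = torch.softmax(q @ k_proj.T / (q.shape[-1] ** 0.5), dim=-1)  # (T, k_dim)
    return scores @ v_proj                                               # (T, d), linear in T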
Roformer
"Rotary Position Embedding, or RoPE, is a type of position embedding which encodes absolute positional information with rotation matrix and naturally incorporates explicit relative position dependency in the self-attention formulation."
One Wide Feedforward is All You Need
📎 Paper
It is suggested that the Feedforward Networks (FFNs) in the Transformer are partly unnecessary and redundant. As a result, the FFN is removed from the decoder layers and a single FFN is shared across the encoder layers. Even though this change causes a small drop in accuracy, scaling the model back to its original size leads to improved accuracy and reduced latency. They report an 18.5% speed-up using this technique.
Performer
Time: O(Td^2 log d)
Space: O(Td log d + d^2 log d)
Performer can "estimate" regular dot-product attention using an approach called "Fast attention via positive orthogonal random features" (FAVOR+). FAVOR+ combines low-rank approximation, matrix factorization, and matrix decomposition, making the space and time complexity linear in the sequence length.
Reformer
Time: O(T log Td)
Space: O(T log T + Td)
The Reformer model incorporates three techniques to improve efficiency. First, it uses "reversible residuals" to reduce memory consumption by storing only a single copy of the intermediate activations, from which the activations of earlier layers can be recomputed using the model parameters. This helps minimize the memory overhead. Second, it splits the values in the feed-forward layers into chunks and processes them chunk by chunk, saving memory and making inference more efficient. Lastly, Reformer uses locality-sensitive hashing to approximate the attention matrix for a more efficient runtime.
Monarch Mixer
"Monarch Mixer uses monarch matrices for a sub-quadratic model in sequence length and model dimension. The idea is to replace the major elements of a Transformer with Monarch matrices — which are a class of structured matrices that "generalize the FFT and are sub-quadratic, hardware-efficient, and expressive." In Monarch Mixer, they use layers built up from Monarch matrices to mix across the sequence (replacing the Attention operation) and across the model dimension (replacing the dense MLP).
Conformers
📎 Paper
👩💻 Code (unofficial)
The Conformer is a variant designed for speech recognition. While the Transformer excels at capturing global relationships, it is less effective than convolutional layers in capturing local information. To address this, the Conformer augments the Transformer model by adding convolutional layers between the attention module and the final feedforward layer. As a result, the Conformer achieves significantly better performance than previous Transformer and CNN-based models, setting new state-of-the-art on ASR.
Efficient Streaming LMs with Attention Sinks
📎 Paper
This looks similar to Longnet, but they keep a set of learnable tokens - sinks - at the beginning of the generated sequence, observing that it improves stability and performance even if you window the attention computation.
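A tiny sketch of the cache policy this implies (a hypothetical helper; the actual method also re-assigns positional indices within the rolled window):

def evict_kv_cache(cache, num_sinks, window):
    # Keep the first `num_sinks` entries (the attention sinks) plus the most
    # recent `window` entries; everything in between is dropped.
    if len(cache) <= num_sinks + window:
        return cache
    return cache[:num_sinks] + cache[-window:]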