NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
I've seen the posts about SuperHOT and, more recently, the paper from Meta that uses RoPE interpolation, and I've noticed an immediate improvement that can be brought to this method. Basically, if you apply Neural Tangent Kernel (NTK) theory to this problem, it becomes clear that simply interpolating the RoPE's Fourier space "linearly" is very sub-optimal, as it prevents the network from distinguishing the order and positions of tokens that are very close by. Borrowing from the NTK literature, scaling down the Fourier features too much will eventually even prevent successful fine-tunes (this is corroborated by the recent paper by Meta that suggests an upper bound of ~600x).
Instead of the simple linear interpolation scheme, I've tried to design a nonlinear interpolation scheme using tools from the NTK literature. Basically, this interpolation scheme changes the base of the RoPE instead of the scale, which intuitively changes the "spinning" speed of each of the RoPE's dimension vectors relative to the next. Because it does not scale the Fourier features directly, all the positions remain perfectly distinguishable from each other, even when taken to the extreme (e.g. stretched 1 million times, which is effectively a context size of 2 billion).
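To make the "spinning speed" intuition concrete, here is a minimal pure-Python sketch comparing the per-dimension rotation speeds under linear interpolation and under a base change. It assumes LLaMA's defaults (head dimension 128, RoPE base 10000) and assumes the base is rescaled as base * alpha^(dim/(dim-2)); that exact formula is not stated in this post, and the function names are hypothetical.

```python
def rope_frequencies(dim, base):
    """Rotation speed (radians per token) of each dimension pair: base**(-2i/dim)."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]

def linear_interpolation_frequencies(dim, base, scale):
    # Linear position interpolation divides every position by `scale`,
    # which is the same as dividing every pair's rotation speed by `scale`.
    return [theta / scale for theta in rope_frequencies(dim, base)]

def ntk_aware_frequencies(dim, base, alpha):
    # ASSUMPTION: the base is rescaled as base * alpha**(dim / (dim - 2)).
    # Only the base changes; the positions themselves are left untouched.
    new_base = base * alpha ** (dim / (dim - 2))
    return rope_frequencies(dim, new_base)

dim, base, factor = 128, 10000.0, 8.0
original = rope_frequencies(dim, base)
linear = linear_interpolation_frequencies(dim, base, factor)
ntk = ntk_aware_frequencies(dim, base, factor)

# Fastest pair: angle separating two adjacent tokens under each scheme.
print("fastest pair:", original[0], linear[0], ntk[0])
# Slowest pair: how much each scheme slows it down relative to the original.
print("slowest pair slowdown:", original[-1] / linear[-1], original[-1] / ntk[-1])
```

With these settings the fastest pair still advances 1 radian per token under the base change but only 0.125 radians under linear interpolation, so adjacent tokens stay easier to tell apart, while both schemes slow the slowest pair by the same factor of 8 (up to floating-point rounding). The exponent dim/(dim-2) is what makes the slowest pair's slowdown land exactly on alpha.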
To my surprise, this method works extremely well, so much so that you don't even need to fine-tune the LLaMA 7B model for a 4096 context size! The perplexity degradation is minimal. I'm sure that with fine-tuning this would become even better.
Enough explanations, here are some empirical results. All the perplexity measurements are done on LLaMA 7B with the tau/scrolls dataset on Hugging Face (I only used a subset of gov_report).
Here's a graph showing the average perplexity of LLaMA 7B on a set of 40 very long prompts (12k+ context size). Compared to changing the scale (as in SuperHOT, where it was set to 4), we change a factor alpha, which, when equal to 8, provides the same context size increase but with much less perplexity degradation. All without any fine-tuning!
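For reference, perplexity is the exponential of the mean negative log-likelihood per predicted token. Below is a minimal sketch, assuming you already have per-token natural-log probabilities from the model. The post does not say whether its "average perplexity" averages per-prompt perplexities or pools all tokens together, so this sketch (with hypothetical helper names) assumes the former.

```python
import math

def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood over one prompt's predicted tokens."""
    if not token_logprobs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)

def average_perplexity(prompts_logprobs):
    # ASSUMPTION: average of per-prompt perplexities, not pooled over all tokens.
    values = [perplexity(lp) for lp in prompts_logprobs]
    return sum(values) / len(values)

# Toy log-probabilities for illustration only, not real measurements.
example = [[math.log(0.5)] * 4, [math.log(0.25)] * 2]
print(average_perplexity(example))
```

Pooling all tokens instead would weight long prompts more heavily, which matters when prompt lengths vary a lot.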
Code can be found in a Colab notebook with a test example: NTKAwareScaledRotaryEmbedding.ipynb (Google Colaboratory).
Again, the changes to the RoPE code are only 3 lines.
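As an illustration of how small the change can be, here is a hedged pure-Python sketch of applying RoPE to a single query or key vector, where the only NTK-aware addition is one line that rescales the base before the inverse frequencies are computed. It assumes the "rotate half" pairing used by common LLaMA implementations and the base-rescaling formula base * alpha^(dim/(dim-2)), which this post does not spell out; `apply_rope` and its parameters are hypothetical names, not the notebook's code.

```python
import math

def apply_rope(x, position, base=10000.0, alpha=1.0):
    """Rotate vector `x` (even length) for token `position` with rotary embeddings."""
    dim = len(x)
    half = dim // 2
    # NTK-aware change (assumed form): enlarge the base; alpha=1.0 is plain RoPE.
    base = base * alpha ** (dim / (dim - 2))
    inv_freq = [base ** (-2 * i / dim) for i in range(half)]
    out = [0.0] * dim
    for i, freq in enumerate(inv_freq):
        angle = position * freq
        c, s = math.cos(angle), math.sin(angle)
        x1, x2 = x[i], x[i + half]  # "rotate half" pairing: i with i + dim/2
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

q = [0.3, -1.2, 0.5, 2.0, 0.1, -0.7, 1.1, 0.4]
k = [1.0, 0.2, -0.4, 0.9, -1.3, 0.6, 0.05, -0.2]
# The query-key score depends only on the relative offset (here 5), with or without alpha.
s1 = dot(apply_rope(q, 105, alpha=8.0), apply_rope(k, 100, alpha=8.0))
s2 = dot(apply_rope(q, 5, alpha=8.0), apply_rope(k, 0, alpha=8.0))
print(s1, s2)
```

Because the base change only alters the rotation speeds, every pair is still a pure rotation, so the relative-position property of RoPE is preserved.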
Also, just for fun, I pasted the entirety of the Meta paper, Extending Context Window of Large Language Models via Positional Interpolation (arXiv:2306.15595), with the abstract and conclusions removed, and asked OASST-sft-6-llama-30b some questions. I think it fared pretty well given that it hasn't been fine-tuned at all for a 6000+ context size. I used an alpha of 8 for this prompt.
The entire paper is about 6200 tokens. These prompts can be found in the Colab notebook linked above. The following questions are simply appended to the end of the prompt that contains the entire research paper.
QUESTION
Please give me a brief summary of this research paper in a few bullet points.
ANSWER
- Large Language models (LLMs) often have a pre-defined context window size, but sometimes they need to operate on longer contexts, e.g., in conversation systems or summarization tasks.
- Extending the context window of an existing pre-trained LLM is challenging because direct extrapolation of positional encodings can lead to poor performance.
- Researchers propose "Position Interpolation" to extend the context window of LLMs by downscaling the positional encodings.
- They prove that interpolated positional encodings have a much smaller upper bound than extrapolated ones, allowing the model to adapt to longer contexts more efficiently.
- They experimentally validate the effectiveness of Position Interpolation on LLaMA models, reducing the perplexity gap between the original and extended contexts.
- They compare Position Interpolation with direct extrapolation and find that Position Interpolation works better.
- They apply Position Interpolation to LLaMa models and fine-tune them on the Pile dataset, finding that the models can effectively utilize longer contexts.
- They also evaluate the models' abilities to attend to longer contexts and perform well on benchmarks within the original context size.
- They apply Position Interpolation to long document summarization and obtain competitive results.
- They provide insights into the interpolation/extrapolation tradeoff and discuss potential improvements.
- They plan to explore more advanced methods like data augmentation and multi-task learning to improve the interpolation.
QUESTION
Please write me the abstract for this paper.
ANSWER
This paper presents a technique called "Position Interpolation" for extending the context window size of pre-trained Large Language Models (LLMs) without modifying their architecture. Previous methods like ALiBi and LeX were unable to efficiently extend the context window size of LLMs, leading to poor performance. The authors propose Position Interpolation, which involves downscaling the positional encoding of the LLM's attention mechanism to match the desired longer context window size. They prove mathematically that interpolated attention scores are more stable than extrapolated ones, allowing the model to adapt to longer contexts with just a few hundred fine-tuning steps. Experiments showed that Position Interpolation enables LLMs to perform well on longer contexts, reducing perplexity by up to 0.8 points on average, and generating strong summarizations. Extending the context window size to 32768 led to slight degradation on the original benchmarks, but the models adapted to longer texts. The authors conclude that interpolating positional encodings allows the model to attend to relevant evidence in the longer contexts, and the stability of the interpolated attention score helps the model avoid catastrophic behaviors.
QUESTION
How many steps was the model fine tuned for the final results? Give a short answer.
ANSWER
The authors of the paper "Extending Context Windows of Large Language Models via Position Interpolation" fine-tuned the model for 1000 batches for the experiments presented in the paper.
QUESTION
How big is the interpolation bound compared to the extrapolation bound? Give a short answer.
ANSWER
The interpolation bound is at least 2 - 29.7 * 10^(-6) times smaller than the extrapolation bound.
Sorry for the large dump of information on Reddit; I'm too lazy to write a blog post for this. I might give a more detailed explanation of how I derived the formula used to calculate the base if enough people are interested.
I did not test fine-tuning performance, as I have neither the resources nor the time to fine-tune an LLM; I just derived this formula during lunch and experimented with it. However, I think this method will do even better with fine-tuning. Also, thanks to the people behind the SuperHOT blog post: it was their hard work that inspired me and allowed me to make this contribution for everyone!
Finally, I really hope this post will inspire others to start experimenting on ways to improve LLMs. There's so much to learn and so much left to discover! What a time to be alive!