Supersizing Transformers: Going Beyond RAG with Extended Minds for LLMs

In this blog post we discuss how the transformer architecture naturally extends over external memories, and share empirical results that leverage this capability. These methods are innate (they require no fine-tuning) and outperform popular retrieval-augmented generation methods.
Author

Phoebe Klett, Thomas Ahle

Published

October 24, 2023

Today’s popularized large language models are optimized for the task of producing sequences of tokens that look like they could’ve been present in the training corpus. This is quite distinct from the ways in which LLMs are wielded in such user interfaces as ChatGPT or Perplexity.ai, where users expect the model to perform complex reasoning tasks and faithfully retrieve factual, topical information. If we hope to use the model as a general reasoning agent and not as a stochastic parrot, we need to provide it with any relevant data at inference time, rather than rely on (1) the salient data having appeared in the training corpus and (2) the model being able to recall said data. Further, surfacing references or citations that highlight which content the model used during its generation is crucial for building applications that truly augment human workflows.

This has prompted much development on methods colloquially referred to as “retrieval”: methods that help LLMs make use of pertinent documents. In-context learning, or placing the relevant documents in the context window before the prompt, is the obvious first step. However, in many cases we’re faced with documents longer than the context window of the model. RAG attempts to sidestep this by selecting the best subset of documents to include alongside the user’s query. While often effective, RAG is fundamentally limited by the need for a separate search engine. We can’t, for instance, ask the model questions which require synthesizing the entire set of documents. Further, since the retrieval happens before the generation, the best we can do with respect to explainability is report which text was included in the prompt itself. This says nothing about what text the model actually used during generation.

Fine-tuning seeks to extend the length of the context window itself. Running even a few epochs of training can be a non-trivial undertaking for today’s large models, even with a dedicated ML team. Further, these methods don’t contribute to the model’s interpretability. Other methods suggest structural changes to the model. Many of these are exciting, but most require training from scratch or fine-tuning, making them difficult to leverage with pre-trained models.

In this post, we propose and open source extended mind transformers, which generalize RAG internally. This simple mathematical generalization buys us the performance gains (and more) of RAG, as well as introducing net-new generation controls and granular causal citations. We also get the best of both worlds when it comes to ease of use: seamless integrations (everything is internal to the model), and no fine-tuning required!

Credits: Buchen

Aesthetics for Extended Mind Transformers

As motivation, we provide context from the Philosophy of Mind which served as inspiration for the naming convention and methodology. In “The Extended Mind”, Clark and Chalmers (1998) present the thesis that external information which is constantly and immediately accessible, and automatically endorsed, should be considered part of memory, and further, that this extension should be considered part of the mind. They term this idea active externalism. The story of Otto functions as an intuition pump:

“[L]ike many Alzheimer’s patients, [Otto] relies on information in the environment to help structure his life. Otto carries a notebook around with him everywhere he goes. When he learns new information, he writes it down. When he needs some old information, he looks it up. For Otto, his notebook plays the role usually played by a biological memory. … The information in the notebook functions just like information constituting an ordinary non-occurrent belief; it just happens that this information lies beyond the skin.”

In this piece, we present active externalism for LLMs, a mechanism for bolstering the memory of transformers that is aesthetically inspired by the Extended Mind Thesis. We call transformers that implement active externalism extended mind transformers.

Extended Mind Transformers

Definition

Our proposed method, which closely resembles the work of Wu et al. (2022), is a simple change to the self-attention mechanism. In addition to the causal self-attention integral to transformers, we also allow each query token to attend to a fixed number of “external memories”. These memories are stored in a non-differentiable cache. The choice of which memories to attend to is made using cosine similarity within each decoder layer and attention head. More precisely, our attention computation is described by:

$$\text{softmax}\left(\frac{Q\,(K_R \,\Vert\, K_L)^T}{\sqrt{d}}\right) \times (V_R \,\Vert\, V_L)$$

where $(K_L, V_L)$ are key-value pairs from the local context, $(K_R, V_R)$ are key-value pairs from the external memories, and $\Vert$ refers to tensor concatenation. We mask the attention weights such that each query token can only attend to its own retrieved keys, and not those retrieved by previous or following query tokens. In the experiments we present below, we use models trained with linear biases (ALiBi) rather than positional encodings. When we apply these linear biases to our attention weights, we assign the same position index to all retrieved memories.

Importantly, active externalism retrieves memories exactly - it doesn’t summarize or otherwise dampen memories except through the linear biases.

We generate the external memories (key-value pairs) once, and then pass the representations to each decoder layer, much as we would pass previously “cached” key-values. To speed up the top-k cosine similarity computation, we can use a vector database designed exactly for this purpose.

We argue that this way of attending to external memories or beliefs is the natural and optimal generalization of methods like RAG, and that it closely mimics the kind of relationship Otto has with his notebook. The information is constantly and immediately accessible, automatically endorsed, and reliably referenced. We always reference our external memories (for every generated token, within every decoder layer), but discard retrieved keys that fall below a low similarity threshold, to avoid confusing the model with irrelevant information.
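To make this concrete, below is a minimal sketch of extended-mind attention for a single head, written in PyTorch. The function name, the similarity threshold value, and the tensor layout are illustrative assumptions; the sketch uses brute-force cosine similarity in place of a vector index and omits the ALiBi biases and the per-layer, per-head batching of the released models.

import torch
import torch.nn.functional as F

def extended_attention(q, k_local, v_local, k_mem, v_mem, topk=5, sim_threshold=0.25):
    # q: (n_q, d) queries for the current tokens; k_local/v_local: (n_l, d)
    # keys/values for the local causal context (which includes the query tokens);
    # k_mem/v_mem: (n_m, d) cached key-value pairs for the external memories.
    n_q, d = q.shape
    n_l = k_local.size(0)

    # Per-query top-k retrieval from the external memories by cosine similarity.
    sims = F.normalize(q, dim=-1) @ F.normalize(k_mem, dim=-1).T   # (n_q, n_m)
    top_sims, top_idx = sims.topk(topk, dim=-1)                    # (n_q, topk)
    k_ret, v_ret = k_mem[top_idx], v_mem[top_idx]                  # (n_q, topk, d)

    # Causal attention scores over the local context.
    local_scores = q @ k_local.T / d**0.5                          # (n_q, n_l)
    pos_q = torch.arange(n_l - n_q, n_l).unsqueeze(1)              # query positions
    pos_k = torch.arange(n_l).unsqueeze(0)
    local_scores = local_scores.masked_fill(pos_k > pos_q, float('-inf'))

    # Scores over each query's own retrieved memories; drop weak matches.
    mem_scores = (q.unsqueeze(1) * k_ret).sum(-1) / d**0.5         # (n_q, topk)
    mem_scores = mem_scores.masked_fill(top_sims < sim_threshold, float('-inf'))

    # Softmax over the concatenation of retrieved and local keys, then mix values.
    weights = torch.softmax(torch.cat([mem_scores, local_scores], dim=-1), dim=-1)
    w_mem, w_loc = weights[:, :topk], weights[:, topk:]
    return (w_mem.unsqueeze(-1) * v_ret).sum(dim=1) + w_loc @ v_local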

Active externalism is not conceptually difficult to implement, but does require getting familiar with a particular model’s implementation since details like the way key-value pairs are stored and read into the self-attention computation need to be hijacked.

Benchmark Results

Perplexity Experiments

We use perplexity as a metric for model performance. Perplexity is a measure of uncertainty of the model over each generated token, closely related to our cross-entropy loss function. For a full explanation of perplexity as a metric, we suggest checking out this excellent post.
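Concretely, for a tokenized sequence $x_1, \ldots, x_N$, perplexity is the exponentiated average negative log-likelihood under the model:

$$\mathrm{PPL}(x) = \exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log p_\theta(x_i \mid x_{<i})\right)$$

Lower is better; a model that is certain of every next token approaches a perplexity of 1.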

We show results below for perplexity experiments on the Wikitext-103 benchmark using Mosaic’s MPT-7b model. We use a stride of 512 tokens in our perplexity experiments, meaning each token is conditioned on at least 512 previous tokens, given that there are indeed 512 tokens to condition on.
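For reference, a sliding-window evaluation in the standard style looks roughly like the sketch below (the helper name is hypothetical). Our experiments additionally batch documents into the chunk lengths shown on the x-axis and route tokens beyond the last 2048 into external memories, which this sketch omits.

import torch

def strided_perplexity(model, input_ids, window=2048, stride=512):
    # Sliding-window evaluation: the window advances by `stride` tokens and only
    # the newly revealed tokens are scored, so each scored token retains some
    # preceding context. `input_ids` has shape (1, seq_len).
    nll_sum, n_scored, prev_end = 0.0, 0, 0
    seq_len = input_ids.size(1)
    for begin in range(0, seq_len, stride):
        end = min(begin + window, seq_len)
        target_len = end - prev_end                    # tokens scored this step
        ids = input_ids[:, begin:end]
        labels = ids.clone()
        labels[:, :-target_len] = -100                 # ignore the overlapping context
        with torch.no_grad():
            out = model(ids, labels=labels)            # mean NLL over scored tokens
        nll_sum += out.loss.item() * target_len
        n_scored += target_len
        prev_end = end
        if end == seq_len:
            break
    return torch.exp(torch.tensor(nll_sum / n_scored))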

Our active externalism method batches each sequence into chunks of increasing length (x-axis), and attends to tokens previous to the last 2048 (the maximum sequence length) as external memories. We show results for varying k, where k is the number of memories we retrieve per query token. We compare active externalism to two baseline methods: the “truncated” baseline simply throws out any tokens previous to the last 2048 during perplexity computations, and the “naive” method uses all input-length tokens, no matter how long the sequences become.

In the case of the naive method, we observe exactly the phenomenon active externalism seeks to ameliorate: once sequences exceed 2-3k tokens, performance quickly drops off (in this case, perplexity blows up).

Perplexity results for Naive and Extended Mind MPT-7b, using a stride length of 512 tokens. Documents are batched into lengths of “Input Length” and we report average PPL on the y-axis.

We can see that active externalism provides clear benefits over simply doing local attention, i.e., over the truncated baseline. Even more exciting, perplexity continues to decrease as we increase the number of retrieved memories per query token.

Perplexity results for Truncated and Extended Mind MPT-7b, using a stride length of 512 tokens. Documents are batched into lengths of “Input Length” and we report average PPL on the y-axis.

Retrieval Experiments

We also measure performance on retrieval benchmarks, and compare with RAG and simple baselines. Our dataset is a modified version of the recently released Long context WikiQA benchmark from Abacus.AI.

Our goal is to measure retrieval abilities over varying document lengths, but we also want to control for facts memorized during training, so we edit the dataset by changing the labeled answers to realistic but incorrect ones. For example, we replace every instance of “Lee Hazlewood” with “Terry Allen” in the Wikipedia entry for the song “These Boots Are Made for Walkin’”, and then ask the model to produce the songwriter’s name, with the correct answer now being “Terry Allen”.
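As a toy illustration of this fact-swapping edit (the field names below are hypothetical, not the benchmark’s actual schema):

# Hypothetical record; the real Long context WikiQA schema may differ.
example = {
    "document": "... the song was written by Lee Hazlewood ...",
    "question": "Who wrote the song?",
    "answer": "Lee Hazlewood",
}

# Swap the true answer for a realistic but wrong one, everywhere it appears.
swaps = {"Lee Hazlewood": "Terry Allen"}
for old, new in swaps.items():
    example["document"] = example["document"].replace(old, new)
    example["answer"] = example["answer"].replace(old, new)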

Our intention is to measure the model’s ability to prioritize in-context or in-memory facts over those it memorized during training. Again, we feel this is an important ability if we’re asking LLMs to be reasoning agents in an evolving world.

In the results below, baseline receives no context at all for the question (we ask it point-blank), RAG selects the best ~2-3k tokens out of the document to include in-context, and active externalism puts the entire document in memory and uses it as Otto uses his notebook.

Retrieval Benchmark Results, by Document Length

We see that while RAG methods drop off with input length, active externalism continues to be effective. While models finetuned to use longer contexts do currently outperform active externalism on some long-range retrieval tasks, active externalism appears to be a more effective way to do retrieval over long contexts for smaller models.

Where active externalism clearly outperforms RAG in large models is precisely where the model has memorized facts before overfitting; that is, where the model’s weights encode factual information even as its performance on held-out data continues to improve. Depending on your application, this memorization could be seen as a strength or a shortcoming. Certainly when we use LLMs as reasoning agents, it is a shortcoming.

Using active externalism also appears to reduce reliance on prompting. Whereas usually we’d need to include examples of the kind of responses we hope to see in the prompt (or use a “chat” model that has been RLHF’ed), we observe experimentally that this isn’t necessary when using active externalism.

Impact on reasoning engine

We discuss two important consequences of active externalism on the LLM’s ability as a reasoning agent: uncertainty awareness and abstraction levers.

If we prompt the model with a question it’s unsure about, it may not respond in a way that’s transparent about that uncertainty. Active externalism provides a new method for revealing when a model is uncertain about its answer.

Let’s look at an example. We load our model easily from Hugging Face, and pass a paragraph from Wikipedia’s entry on Grothendieck as external memories.

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

wikipedia = """Alexander Grothendieck (/ˈɡroʊtəndiːk/; German pronunciation: [ˌalɛˈksandɐ ˈɡʁoːtn̩ˌdiːk] (listen); French: [ɡʁɔtɛndik]; 28 March 1928 – 13 November 2014) was a stateless (and then, since 1971, French) mathematician who became the leading figure in the creation of modern algebraic geometry.[7][8] His research extended the scope of the field and added elements of commutative algebra, homological algebra, sheaf theory, and category theory to its foundations, while his so-called "relative" perspective led to revolutionary advances in many areas of pure mathematics.[7][9] He is considered by many to be the greatest mathematician of the twentieth century.[10][11]

Grothendieck began his productive and public career as a mathematician in 1949. In 1958, he was appointed a research professor at the Institut des hautes études scientifiques (IHÉS) and remained there until 1970, when, driven by personal and political convictions, he left following a dispute over military funding. He received the Fields Medal in 1966 for advances in algebraic geometry, homological algebra, and K-theory.[12] He later became professor at the University of Montpellier[1] and, while still producing relevant mathematical work, he withdrew from the mathematical community and devoted himself to political and religious pursuits (first Buddhism and later, a more Christian vision).[13] In 1991, he moved to the French village of Lasserre in the Pyrenees, where he lived in seclusion, still working tirelessly on mathematics and his philosophical and religious thoughts until his death in 2014.[14]
"""

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
memory_ids = tokenizer(wikipedia, return_tensors='pt')['input_ids']

model = AutoModelForCausalLM.from_pretrained("normalcomputing/extended-mind-mpt-7b", external_memories=memory_ids, trust_remote_code=True)

Now, let’s ask the model a question we know is answered (albeit a little obscurely) in the above paragraph, first without using active externalism. We can achieve this by setting the parameter model.use_active_externalism = False, or simply by passing topk=0. Hint: the correct answer is 1971.

prompt = "When did Alexander Grothendieck get his French citizenship?"
input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']

out = model.generate(input_ids, max_length=input_ids.size(-1)+50, topk=0)
print('Baseline Generation: ', tokenizer.decode(out[0]))
Baseline Generation:  When did Alexander Grothendieck get his French citizenship?
I am trying to find out when Alexander Grothendieck got his French citizenship. I know that he was born in Germany and that he got his French citizenship in the late 1950s. I am trying to find out when he got his

Now let’s enable active externalism, slowly cranking up the number of memories each query token is allowed to attend to using the topk parameter.

out = model.generate(input_ids, max_length=input_ids.size(-1)+15, topk=5)
print('Generation for k=5: ', tokenizer.decode(out[0][input_ids.size(-1):]).strip())

out = model.generate(input_ids, max_length=input_ids.size(-1)+15, topk=6)
print('Generation for k=6: ',tokenizer.decode(out[0][input_ids.size(-1):]).strip())

out = model.generate(input_ids, max_length=input_ids.size(-1)+20, topk=7)
print('Generation for k=7: ',tokenizer.decode(out[0][input_ids.size(-1):]).strip())

out = model.generate(input_ids, max_length=input_ids.size(-1)+15, topk=8)
print('Generation for k=8: ',tokenizer.decode(out[0][input_ids.size(-1):]).strip())

out = model.generate(input_ids, max_length=input_ids.size(-1)+20, topk=30)
print('Generation for k=30: ',tokenizer.decode(out[0][input_ids.size(-1):]).strip())
Generation for k=5:  A: I think he got it in the early 1960s.
Generation for k=6:  A: I think he got it in the early 1970s.
Generation for k=7:  A: He was born in France, and he was naturalized in 1971.
<|endoftext|>
Generation for k=8:  A: I think he got it in 1971.
<|endoftext|>Q
Generation for k=30:  A: He was born in Germany, and became a French citizen in 1971.

Not only did the model produce the correct answer, but it also expressed increasing certainty about its answer. This evolution of generations signals the model’s original uncertainty.

In cases where the model is certain about the answer, the generations are stable as we increase k over the external context.

prompt = "What was did Alexander Grothendieck's profession?"
input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']

out = model.generate(input_ids, max_length=input_ids.size(-1)+25, topk=0)
print('Baseline Generation: ', tokenizer.decode(out[0][input_ids.size(-1):]).strip())

out = model.generate(input_ids, max_length=input_ids.size(-1)+15, topk=2)
print('Generation for k=2: ', tokenizer.decode(out[0][input_ids.size(-1):]).strip())

out = model.generate(input_ids, max_length=input_ids.size(-1)+15, topk=8)
print('Generation for k=8: ', tokenizer.decode(out[0][input_ids.size(-1):]).strip())
Baseline Generation:  What was did Alexander Grothendieck's profession?
Alexander Grothendieck was a French mathematician
Generation for k=2:  Alexander Grothendieck was a mathematician.

What
Generation for k=8:  A: He was a mathematician.
<|endoftext|>Q: What

A natural extension of this principle might be a metric based on similarity or attention weights that communicates this uncertainty in a more compact form; this work is currently under development at Normal.

The parameter topk also serves as a useful lever for the level of abstraction in the model’s output, i.e., the extent to which we’d like the model to synthesize the memories versus quote verbatim from the source. We see this clearly in question-answering tasks over code. We show an example using the chat model here, which is better equipped to handle more free-form question-answering tasks.

code_snippet = """def sieve_of_eratosthenes(limit):
    sieve = [True] * (limit + 1)
    sieve[0] = sieve[1] = False
    primes = []
    
    for current in range(2, int(limit**0.5) + 1):
        if sieve[current]:
            primes.append(current)
            for multiple in range(current*current, limit + 1, current):
                sieve[multiple] = False
    
    for num in range(int(limit**0.5) + 1, limit + 1):
        if sieve[num]:
            primes.append(num)
    
    return primes
"""
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
memory_ids = tokenizer(code_snippet, return_tensors='pt')['input_ids']

model = AutoModelForCausalLM.from_pretrained("normalcomputing/extended-mind-mpt-7b-chat", external_memories=memory_ids, trust_remote_code=True)

We ask the model to recall what our function does, first with a small topk.

prompt =  "What does the function sieve_of_eratosthenes do?"
input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']

out = model.generate(input_ids, max_length=input_ids.size(-1)+100, topk=2)
print(tokenizer.decode(out[0]))
What does the function sieve_of_eratosthenes do?
The function sieve_of_eratosthenes is a Python function that implements the Sieve of Eratosthenes algorithm to generate all prime numbers up to a given limit.

The Sieve of Eratosthenes is a simple algorithm that generates all prime numbers up to a given limit. It works by creating a list of all integers from 2 to the given limit, and then iteratively marking the multiples of each prime number as composite (not prime).

We see that with a small topk the model abstracts away the details from the code, providing a natural language description of what the code does. Now let’s try with a larger topk.

out = model.generate(input_ids, max_length=input_ids.size(-1)+100, topk=14)
print(tokenizer.decode(out[0]))
What does the function sieve_of_eratosthenes do?(limit):
        primes.append(True)
        for i in range(2, int(limit**0.5) + 1):
            if sieve[i]:
                break
        else:
            for i in range(2, int(limit**0.5) + 1):
                if i % 2 == 0:
                    sieve[i] = False
    
    return primes
```

This implementation of the S

Now the model’s output is much closer to verbatim code, while it still abstracts away some variable names. This kind of nuanced stylistic choice is very hard to achieve using naive prompting and RAG methods without developing many point solutions specific to the data and prompt. More importantly, this kind of experiment gives us small clues into how the model actually reasons over these key-value pairs. At Normal, we hope to combine work on mechanistic interpretability methods with extended mind transformers, building a unified system for understanding how models store facts and reason over them.

Explainability

Clark and Chalmers write in their paper: “By embracing an active externalism, we allow a more natural explanation of all sorts of actions”, and indeed this is true for our active externalism as well. Using attention weights, we can highlight which memories were used during each generation step. Here we highlight the memories used when generating the correct token “1971”. Since we retrieve memories per layer, per head, we display the mode.
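As a rough sketch of how such a citation can be computed, suppose the implementation exposes, for a single generation step, the indices of the retrieved memories and the attention mass they received, per decoder layer and head (the function and tensor names below are hypothetical):

import torch

def cite_memory(retrieved_indices, attention_weights):
    # retrieved_indices: (layers, heads, topk) indices into the external memory.
    # attention_weights: (layers, heads, topk) attention mass on those memories.
    # For each layer/head, find the retrieved token that received the most attention.
    best = attention_weights.argmax(dim=-1)                          # (layers, heads)
    top_tokens = torch.gather(retrieved_indices, -1,
                              best.unsqueeze(-1)).squeeze(-1)        # (layers, heads)
    # Since memories are retrieved per layer and per head, report the mode.
    return torch.mode(top_tokens.flatten()).values.item()

Mapping the returned index back into the original text gives the highlighted span for that generation step.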