Retrieval Augmented Generation
How to improve LLM-generated text with RAG
Retrieving documents from external sources gives language models access to data that they didn’t see during training. This is particularly useful for question-answering systems where you want the model to answer based on specific documents rather than its training corpus, for example when the answers live in private data sources.
The process works by retrieving the relevant documents and having the model answer questions from them. Letting a language model do this saves time: manually searching for answers across hundreds of documents would be quite cumbersome.
Retrieval Augmented Generation involves the following steps, sketched in code after the list:
Loading relevant documents
Creating embeddings for the documents using an embedding model
Storing the documents and embeddings in a vector database
Querying the vector database using a retriever to obtain the relevant documents based on certain criteria such as cosine similarity
Passing relevant documents to the language model for question-answering
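Here is a minimal sketch of those steps in Python, using the sentence-transformers library for embeddings and a plain numpy matrix standing in for a vector database. The model name, documents, and top-k value are illustrative choices, and the final call to a language model is left out:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

# Load an embedding model (all-MiniLM-L6-v2 is a common small choice).
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# 1. Load the documents (a hard-coded list stands in for a real loader).
documents = [
    "RAG combines a retriever with a generator.",
    "Dense Passage Retriever encodes queries and passages with BERT.",
    "Vector databases store document embeddings for similarity search.",
]

# 2. Create embeddings; 3. "store" them (here, just a numpy matrix).
doc_embeddings = embedder.encode(documents, normalize_embeddings=True)

# 4. Retrieve: embed the query and rank documents by cosine similarity,
#    which for normalized vectors is just the dot product.
query = "How does RAG retrieve context?"
query_embedding = embedder.encode([query], normalize_embeddings=True)[0]
scores = doc_embeddings @ query_embedding
top_k = np.argsort(scores)[::-1][:2]
context = "\n".join(documents[i] for i in top_k)

# 5. Pass the retrieved documents to the language model for question-answering.
prompt = f"Answer using only this context:\n{context}\n\nQuestion: {query}"
print(prompt)  # send this prompt to the LLM of your choice
```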
Augmenting a language model with external data is critical to keeping its responses up to date. For example, a model trained before COVID-19 will make up information when asked about COVID-19. And since language models are trained on general-domain corpora, they may not perform well on domain-specific tasks.
Retrieval Augmented Generation reduces the chances of the language model hallucinating or answering from its training data by grounding its responses in the provided context. The retriever is therefore a critical part of the process, since the language model answers the user’s query based on the documents it is given. An effective system needs both a good retriever and a good language model.
How Retrieval Augmented Generation Works
Retrieval Augmented Generation works in two phases:
Retrieving information relevant to the user’s query
Response generation based on the relevant information
RAG was proposed in the paper Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks.
The method combines a parametric sequence-to-sequence transformer model with a non-parametric dense vector index of Wikipedia. It uses the input sequence x to fetch text documents z, which serve as additional context when generating the output sequence y.
The method is made up of two main components: a retriever and a generator.
The purpose of the retriever is to fetch the documents that will be used to answer the question. You can think of it as a smart search engine that returns the top documents relevant to a given query. The language model then acts as a writer, crafting an answer based on those documents.
The system uses a dense encoder to build a vector index over all the passages for use at retrieval time. The retriever is the Dense Passage Retriever (DPR), which returns a distribution over latent documents conditioned on the input.
The Dense Passage Retriever is defined as

p_η(z|x) ∝ exp( d(z)ᵀ q(x) ),  with d(z) = BERT_d(z) and q(x) = BERT_q(x)

where:
d(z) is the dense representation of a document, produced by a BERT_BASE document encoder
q(x) is the query representation, produced by a BERT_BASE query encoder
Dense Passage Retriever works by indexing the passages in a low-dimensional, continuous space and returning the passages most relevant to the question at run time.
RAG uses a pre-trained retriever and generator combined with Maximum Inner Product Search (MIPS) to obtain the top-K documents for each query. The documents that a user’s query is to be answered from are usually chunked, embedded, and stored in a vector database, which holds the numerical representations of the documents.
An embedding of the user’s query is created in the same way and compared against the document embeddings in the vector store. The comparison can be based on different criteria; here it is maximum inner product.
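As a rough sketch of this retrieval step, the following uses random vectors in place of real DPR embeddings; a production system would use an approximate-MIPS library such as FAISS (which the paper uses) rather than a raw numpy matrix:

```python
import numpy as np

# Random vectors stand in for the DPR embeddings: d(z) for each passage
# and q(x) for the query, which would come from two BERT_BASE encoders.
rng = np.random.default_rng(0)
num_passages, dim, k = 10_000, 768, 5
passage_index = rng.standard_normal((num_passages, dim))  # one row per passage
query_vec = rng.standard_normal(dim)

# MIPS: score every passage by inner product and keep the top-k.
scores = passage_index @ query_vec
top_k = np.argpartition(scores, -k)[-k:]
top_k = top_k[np.argsort(scores[top_k])[::-1]]  # sort the top-k by score

# p(z|x) ∝ exp(d(z)ᵀ q(x)), normalized here over the retrieved top-k.
logits = scores[top_k]
p_z_given_x = np.exp(logits - logits.max())
p_z_given_x /= p_z_given_x.sum()
print(top_k, p_z_given_x)
```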
The generator is a 400M parameter encoder-decoder BART model. It conditions on both the input and the retrieved documents, which are combined with the input sequence via simple concatenation.
The generator and retriever are trained jointly by minimizing the negative marginal log-likelihood of each target. The document encoder is not updated during training because that would be costly (the document index would have to be rebuilt as the encoder changes); hence only the query encoder and the generator are fine-tuned.
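Written out, the training objective from the paper is, for training pairs (x_j, y_j):

```latex
\min_{\eta,\,\theta} \; \sum_{j} -\log p(y_j \mid x_j)
```

where p(y|x) marginalizes over the retrieved documents as defined by the two model variants below.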
Two models are used for producing distributions over the generated text:
RAG-Sequence, which uses the same retrieved document to predict every target token
RAG-Token, which can predict each token based on a different document
The RAG-Sequence Model generates the complete sequence using the same retrieved document, treating the document as a single latent variable that is marginalized out to obtain the seq2seq probability p(y|x) via a top-K approximation. The retriever produces the top-K documents and the generator computes the output sequence probability for each document, which are then marginalized. The model’s formula is:

p(y|x) ≈ Σ_{z ∈ top-k(p(·|x))} p_η(z|x) Π_i p_θ(y_i | x, z, y_{1:i−1})
In the RAG-Token Model, the retriever obtains the top-K documents and the generator produces a distribution for the next token from each document, marginalizing at every token:

p(y|x) ≈ Π_i Σ_{z ∈ top-k(p(·|x))} p_η(z|x) p_θ(y_i | x, z, y_{1:i−1})

This lets the generator pick content from different sources when producing the answer.
The final output is decoded by treating the RAG-Token model as a standard autoregressive seq2seq model with the marginal transition probability above, and by running beam search for each document in the RAG-Sequence model.
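To make the difference between the two marginalizations concrete, here is a toy computation with made-up probabilities for two retrieved documents and a three-token target:

```python
import numpy as np

# Toy numbers: 2 retrieved documents, a 3-token target sequence.
# p_z[z]      = p(z|x), the retriever's probability of document z
# p_tok[z, i] = p(y_i | x, z, y_{<i}), the generator's token probabilities
p_z = np.array([0.6, 0.4])
p_tok = np.array([
    [0.9, 0.8, 0.7],  # token probabilities conditioned on document 0
    [0.2, 0.3, 0.4],  # token probabilities conditioned on document 1
])

# RAG-Sequence marginalizes at the sequence level:
# p(y|x) = sum_z p(z|x) * prod_i p(y_i | x, z, y_{<i})
p_sequence = np.sum(p_z * np.prod(p_tok, axis=1))

# RAG-Token marginalizes at every token:
# p(y|x) = prod_i sum_z p(z|x) * p(y_i | x, z, y_{<i})
p_token = np.prod(p_z @ p_tok)

print(f"RAG-Sequence p(y|x) = {p_sequence:.4f}")  # 0.3120
print(f"RAG-Token    p(y|x) = {p_token:.4f}")     # 0.2158
```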
Retrieval Augmented Generation systems have shown good performance on:
Open-domain Question Answering
Abstractive Question Answering
Jeopardy Question Generation
RAG Evaluation Metrics
Evaluating retrieval augmented systems is quite challenging. However, several tools have cropped up to make this possible. One such tool is LlamaIndex, which supports evaluating language model responses on various criteria, such as whether:
The responses came from the provided documents
The response matches the given query
The language model is able to follow the given guidelines
The generated answer is relevant and correct
LlamaIndex provides a batch runner that you can use to run several evaluations at the same time and aggregate the results into an average for a final score.
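As a sketch of what that looks like, the snippet below follows the BatchEvalRunner API from recent llama-index versions; import paths and defaults vary across releases, and the tiny index built here is a stand-in for your real documents. It assumes a default LLM and embedding model are configured (e.g. via the OPENAI_API_KEY environment variable):

```python
import asyncio

from llama_index.core import Document, VectorStoreIndex
from llama_index.core.evaluation import (
    BatchEvalRunner,
    FaithfulnessEvaluator,
    RelevancyEvaluator,
)

# Build a tiny index to evaluate against (stand-in for your real corpus).
docs = [Document(text="RAG combines a retriever with a generator model.")]
query_engine = VectorStoreIndex.from_documents(docs).as_query_engine()

# Run two evaluations at once; the evaluators use the default LLM.
runner = BatchEvalRunner(
    {"faithfulness": FaithfulnessEvaluator(), "relevancy": RelevancyEvaluator()},
    workers=4,
)
queries = ["What is RAG?"]
results = asyncio.run(runner.aevaluate_queries(query_engine, queries=queries))

# Aggregate each metric by averaging pass/fail results into a final score.
for metric, evals in results.items():
    score = sum(r.passing for r in evals) / len(evals)
    print(f"{metric}: {score:.2f}")
```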
You can also evaluate the language model itself using standard language model evaluation metrics.
Retrieval Augmented Generation is important in solving problems with language models because the provided sources are verifiable. It also grounds the model in the provided context, reducing the chances that it hallucinates or leaks sensitive data it may have learned during training. RAG also lets us control the behavior of the language model by restricting it to a certain context.
However, your system is limited by the quality of the retriever and the language model. If the data sent to the language model is poor, the user’s query won’t be answered sufficiently. On the other hand, the data may be of high quality while the language model is inadequate, so you have to strike a balance between the two.