Data Augmentation
To demonstrate the usage of RAG-FiT data augmentation, we follow the experiments presented in the paper, using the ASQA Q&A dataset and the Phi-3 model. We compare a baseline configuration with four other configurations:
1. Retrieval augmentation using a corpus, inserting the documents in the prompt after the question.
2. Similar to (1), but fine-tuning the model on the completions.
3. Similar to (1), adding a Chain-of-Thought instruction for the model to explain its reasoning and format its answer.
4. Similar to (3), but fine-tuning the model on the completions while implementing a technique from RAFT where distracting documents are used.
The ASQA dataset has two types of answers: a long answer and lists of short answers (actually a list of lists). Additionally, the data contains only a minimal amount of context, so we augment it using a corpus stored as a vector DB; we use Qdrant.
In order to train configuration (4), we need well-reasoned CoT responses as labels, so we use the OpenAI GPT-4 model to augment the dataset with these synthetic labels.
Note: all the configurations mentioned here, which implement the experiments done in the paper, are saved in configs/paper/. They don't run by default; they need to be specified by running:
Retrieval
The first step is to augment the entire dataset (train, dev) with relevant documents, based on the questions; see processing-asqa-retrieval.yaml. Let's focus on the different steps:
- _target_: ragfit.processing.dataset_loaders.loaders.HFLoader
  inputs: train
  dataset_config:
    path: din0s/asqa
    split: train
- _target_: ragfit.processing.dataset_loaders.loaders.HFLoader
  inputs: dev
  dataset_config:
    path: din0s/asqa
    split: dev
We load the train and dev splits to be used in the pipeline; later steps refer to them by the names given in the inputs keyword of this step.
The examples use keys such as query, answers, positive_passages, etc. Feel free to add your own types of pre-processing.
Note that the inputs keyword can also accept a list of strings, meaning the step will run over each of the datasets specified.
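To make the string-or-list behaviour concrete, here is a minimal sketch of how a step could normalize its inputs and run over each named dataset. The run_step helper and the in-memory datasets dictionary are hypothetical illustrations, not the RAG-FiT implementation.

```python
from typing import Callable, Dict, List, Union

Example = Dict[str, object]


def run_step(
    datasets: Dict[str, List[Example]],
    inputs: Union[str, List[str]],
    fn: Callable[[Example], Example],
) -> None:
    """Apply `fn` to every example of each dataset named in `inputs`."""
    # A single name is treated as a one-element list.
    names = [inputs] if isinstance(inputs, str) else list(inputs)
    for name in names:
        datasets[name] = [fn(dict(example)) for example in datasets[name]]


def add_length(example: Example) -> Example:
    """Toy processing function: record the query length."""
    example["query_len"] = len(str(example["query"]))
    return example


datasets = {
    "train": [{"query": "who wrote hamlet"}],
    "dev": [{"query": "when was asqa released"}],
}

run_step(datasets, "train", add_length)           # a single dataset
run_step(datasets, ["train", "dev"], add_length)  # several datasets
```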
- _target_: ragfit.processing.local_steps.retrievers.haystack.HaystackRetriever
  inputs: [train, dev]
  pipeline_or_yaml_path: ./configs/external/haystack/qdrant.yaml
  docs_key: positive_passages
  query_key: query
The retrieval step stores the most relevant documents (k=5) under the key named by docs_key; the query text is read from the key named by query_key.
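The roles of query_key and docs_key can be illustrated with a toy retriever. This sketch ranks passages by simple word overlap instead of querying Qdrant through Haystack; the retrieve and overlap_score functions and the tiny corpus are assumptions made for illustration only.

```python
from typing import Dict, List


def overlap_score(query: str, passage: str) -> int:
    """Count distinct query words that also occur in the passage."""
    return len(set(query.lower().split()) & set(passage.lower().split()))


def retrieve(
    example: Dict,
    corpus: List[str],
    query_key: str,
    docs_key: str,
    k: int = 5,
) -> Dict:
    """Read the query from `query_key`, store the top-k passages in `docs_key`."""
    query = example[query_key]
    ranked = sorted(corpus, key=lambda p: overlap_score(query, p), reverse=True)
    return {**example, docs_key: ranked[:k]}


corpus = [
    "Hamlet is a tragedy written by William Shakespeare.",
    "Qdrant is a vector database.",
    "Shakespeare also wrote Macbeth.",
]
example = {"query": "who wrote Hamlet"}
augmented = retrieve(
    example, corpus, query_key="query", docs_key="positive_passages", k=2
)
```

A real vector search would rank by embedding similarity; only the way the two keys are read and written carries over.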
- _target_: ragfit.processing.local_steps.context.ContextHandler
  inputs: [train, dev]
  docs_key: positive_passages
- _target_: ragfit.processing.global_steps.sampling.Sampler
  inputs: [train, dev]
  k: 1
  input_key: positive_passages
  output_key: negative_passages
The Sampler class deals with sampling examples from the same dataset or from others. In order to train the RAFT-based model on a combination of relevant and distracting documents, we need to collect these distracting documents. Here we chose to collect positive documents from other examples and use them as negative documents. The Sampler is run with k=1; it collects only the positive_passages from the examples it samples and stores them under a new key called negative_passages.
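The idea of borrowing positive passages from other examples can be sketched as follows. The sample_negatives function is a hypothetical illustration; in particular, excluding the current example from the candidates and using a fixed seed are assumptions, not a description of the actual Sampler class.

```python
import random
from typing import Dict, List


def sample_negatives(
    dataset: List[Dict],
    k: int = 1,
    input_key: str = "positive_passages",
    output_key: str = "negative_passages",
    seed: int = 0,
) -> List[Dict]:
    """For each example, copy `input_key` from k other examples into `output_key`."""
    rng = random.Random(seed)
    result = []
    for i, example in enumerate(dataset):
        # Assumption: an example never samples itself.
        others = [j for j in range(len(dataset)) if j != i]
        picked = rng.sample(others, k)
        negatives = [doc for j in picked for doc in dataset[j][input_key]]
        result.append({**example, output_key: negatives})
    return result


dataset = [
    {"query": "q0", "positive_passages": ["a0", "a1"]},
    {"query": "q1", "positive_passages": ["b0", "b1"]},
    {"query": "q2", "positive_passages": ["c0", "c1"]},
]
sampled = sample_negatives(dataset, k=1)
```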
Finally, we write the two resulting datasets to disk. They represent the retrieval-augmented datasets, ready to be processed for the different tasks.
To run this process:
Baseline Configuration
For the baseline, there is no context; only the question is presented to the model. We use instruction-following models that have a built-in chat template. The framework populates the chat template using the inputs and outputs we generate, so we don't need to worry about roles and special tokens. Additionally, the system instruction is specified only during training and inference; it need not be part of the dataset, so the next steps mainly deal with prompt generation.
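As a rough picture of why roles need not be stored in the dataset, the sketch below assembles a chat-style message list from a prompt, an optional completion, and a system instruction supplied separately. The to_messages helper and the role names follow the common chat-message convention and are assumptions; the framework's actual handling may differ.

```python
from typing import Dict, List, Optional


def to_messages(
    prompt: str,
    system_instruction: str,
    completion: Optional[str] = None,
) -> List[Dict[str, str]]:
    """Build chat messages; the completion is only present for training."""
    messages = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": prompt},
    ]
    if completion is not None:
        messages.append({"role": "assistant", "content": completion})
    return messages


# Inference: prompt only. Training: prompt plus the target completion.
inference_messages = to_messages("Question: who wrote Hamlet?", "Answer briefly.")
training_messages = to_messages(
    "Question: who wrote Hamlet?", "Answer briefly.", "William Shakespeare"
)
```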
These are the interesting steps:
- _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader
  inputs: dev
  filename: asqa-dev.jsonl
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: dev
  prompt_file: ragfit/processing/prompts/qa-short.txt
  output_key: prompt
  mapping:
    query: query
We load the retrieval-augmented files that we generated locally in the previous section.
The TextPrompter populates a template file containing placeholders in Python format; see the short template. The step replaces the placeholders with variables using the provided mapping. The result is a string, saved under the key named by output_key.
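The placeholder substitution can be sketched with Python's built-in str.format. The build_prompt function and the template text below are hypothetical stand-ins (the real content of qa-short.txt is not reproduced here); the mapping is read as placeholder name to example key, as in the configuration above.

```python
from typing import Dict


def build_prompt(
    example: Dict,
    template: str,
    mapping: Dict[str, str],
    output_key: str = "prompt",
) -> Dict:
    """Fill `template` placeholders from the example and store the result."""
    # mapping: placeholder name in the template -> key in the example
    values = {placeholder: example[key] for placeholder, key in mapping.items()}
    return {**example, output_key: template.format(**values)}


# Hypothetical template; stands in for the contents of a prompt file.
template = "Question: {query}\nAnswer:"
example = {"query": "who wrote Hamlet?"}
prompted = build_prompt(example, template, mapping={"query": "query"})
```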
To run this process:
Context
Preparing for configurations (1) and (2), we want to augment the examples with the top 5 documents we collected in the first step.
- _target_: ragfit.processing.local_steps.context.DocumentsJoiner
  inputs: [train, dev]
  docs_key: positive_passages
  k: 5
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: [train, dev]
  prompt_file: ragfit/processing/prompts/qa.txt
  output_key: prompt
  mapping:
    question: query
    context: positive_passages
DocumentsJoiner joins a list of strings into a single string and is needed before the TextPrompter we saw in the previous section. We prepare a dev file, for testing the model with retrieved documents, and also a training file, in order to run fine-tuning. Both configurations will be evaluated on the dev dataset.
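A minimal sketch of the joining step, assuming the joined string overwrites the list under docs_key (the prompter mapping above reads the context from that same key). The join_documents function, the newline separator, and treating k=None as "keep all documents" are assumptions for illustration.

```python
from typing import Dict, Optional


def join_documents(
    example: Dict,
    docs_key: str,
    k: Optional[int] = None,
    separator: str = "\n",
) -> Dict:
    """Join the first k documents under `docs_key` into one string."""
    docs = example[docs_key]
    selected = docs if k is None else docs[:k]  # k=None keeps every document
    return {**example, docs_key: separator.join(selected)}


example = {
    "query": "who wrote Hamlet?",
    "positive_passages": ["doc A", "doc B", "doc C"],
}
top_two = join_documents(example, docs_key="positive_passages", k=2)
all_docs = join_documents(example, docs_key="positive_passages")
```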
To run this process:
Chain-of-Thought
We prepare a dev set with a CoT reasoning prompt. The configuration is similar to the Context configuration, except that here we use a different prompt template:
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: dev
  prompt_file: ragfit/processing/prompts/cot.txt
  output_key: prompt
  mapping:
    question: query
    context: positive_passages
To run this process:
Chain-of-Thought Training Dataset
In order to train a model on a CoT-based prompt, we need to collect well-reasoned responses; we use GPT-4 for that. Additionally, we implement a technique from RAFT where some percentage of the examples have purely distractor documents, so that the model learns to filter out noise. Here are the relevant steps:
- _target_: ragfit.processing.local_steps.raft.RAFTStep
  inputs: train
  k: 5
  raft_p: 0.5
  neg_docs_num: 2
  output_key: raft_docs
RAFTStep implements the logic presented in the RAFT paper; the share of examples that receive purely distractor documents is defined by raft_p. The list of documents, some relevant and some distracting, is saved under the key named by output_key.
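One way to picture the mixing logic is the sketch below. It assumes that, with probability raft_p, an example gets k distractors only, and otherwise gets k - neg_docs_num relevant documents followed by neg_docs_num distractors. The raft_documents function and this exact split are assumptions for illustration and may differ from the actual RAFTStep.

```python
import random
from typing import Dict, List, Optional


def raft_documents(
    example: Dict,
    k: int = 5,
    raft_p: float = 0.5,
    neg_docs_num: int = 2,
    output_key: str = "raft_docs",
    rng: Optional[random.Random] = None,
) -> Dict:
    """Build a k-document context that is either all distractors or a mix."""
    rng = rng or random.Random()
    positives: List[str] = example["positive_passages"]
    negatives: List[str] = example["negative_passages"]

    if rng.random() < raft_p:
        # Purely distracting context: the model must cope without evidence.
        docs = negatives[:k]
    else:
        # Mixed context: relevant documents plus a few distractors.
        docs = positives[: k - neg_docs_num] + negatives[:neg_docs_num]
    return {**example, output_key: docs}


example = {
    "query": "who wrote Hamlet?",
    "positive_passages": ["p0", "p1", "p2", "p3", "p4"],
    "negative_passages": ["n0", "n1", "n2", "n3", "n4"],
}
only_distractors = raft_documents(example, raft_p=1.0)["raft_docs"]
mixed = raft_documents(example, raft_p=0.0)["raft_docs"]
```

Setting raft_p to 1.0 or 0.0 forces one branch, which is handy for checking each case in isolation.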
- _target_: ragfit.processing.local_steps.context.DocumentsJoiner
  inputs: train
  docs_key: raft_docs
  k:
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: train
  prompt_file: ragfit/processing/prompts/cot.txt
  output_key: prompt
  mapping:
    question: query
    context: raft_docs
With k left unset in the DocumentsJoiner step, all documents are used. The prompt used is the same as when building the dev dataset.
Next comes the interaction with OpenAI; we implemented an OpenAI class using Azure, but one can implement it using other abstractions. The step itself needs the prompt_key and an instruction file, and the results are saved under the key named by answer_key.
- _target_: ragfit.processing.local_steps.api.openai.OpenAIChat
  inputs: train
  prompt_key: prompt
  answer_key: generated_answer
  instruction: ragfit/processing/prompts/prompt_instructions/qa.txt
  model:
    azure_endpoint: azure.endpoint.com
    api_version: 2024-05-01-preview
    model: GPT-4-32k-Bot
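The data flow of this step can be sketched without calling any external service. Here the model call is passed in as a plain function, so a stub can stand in for the real client; the annotate and fake_chat names are hypothetical, and no OpenAI or Azure API is used.

```python
from typing import Callable, Dict


def annotate(
    example: Dict,
    chat: Callable[[str, str], str],
    prompt_key: str,
    answer_key: str,
    instruction: str,
) -> Dict:
    """Send instruction + prompt to `chat` and store the reply in `answer_key`."""
    reply = chat(instruction, example[prompt_key])
    return {**example, answer_key: reply}


def fake_chat(instruction: str, prompt: str) -> str:
    """Stub standing in for a real chat-completion call."""
    return f"[stub answer to: {prompt}]"


example = {"prompt": "Question: who wrote Hamlet?"}
labeled = annotate(
    example,
    chat=fake_chat,
    prompt_key="prompt",
    answer_key="generated_answer",
    instruction="Explain your reasoning, then answer.",
)
```

Swapping fake_chat for a function that wraps a real client is the only change needed to produce actual synthetic labels.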
To run this process: