Lift Yourself Up: Retrieval-augmented Text Generation with Self Memory


This repository contains the source code for the paper Lift Yourself Up: Retrieval-augmented Text Generation with Self Memory.

With direct access to human-written references as memory, retrieval-augmented generation has made substantial progress on a wide range of text generation tasks. Since better memory typically prompts better generation (we call this the primal problem), previous work has mainly focused on how to retrieve better memory.

However, a fundamental limitation remains in the current literature: the memory is retrieved from a fixed corpus and is therefore bounded by the corpus's quality. With a finite retrieval space, such bounded memory greatly limits the potential of memory-augmented generation models.

In this paper, by exploring the dual of the primal problem (better generation also prompts better memory), we propose Selfmem, a framework that iteratively uses the retrieval-augmented generator itself to produce an unbounded memory pool, and a memory selector to pick one generated output as the memory for the next generation round.

By combining the primal and dual problems, a retrieval-augmented generation model can lift itself up with its own output in an infinite generation space.


Setup

Our code is mainly based on ⚡ PyTorch Lightning and 🤗 Transformers.

Specifically, the model definition and tokenizer are based on 🤗, and the Trainer is from ⚡.

```bash
## first install torch matching your CUDA version
pip install transformers==4.24.0 \
            pytorch-lightning==1.8.0.post1 \
            sacrebleu==2.2.0 \
            gputil==1.4.0

git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
```

Data

For the pretrained language models, we use BART-base and Pegasus; download them from the Hugging Face model hub and put them in the pretrained_model folder.
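
If you prefer to fetch the checkpoints programmatically, a minimal sketch is below. The hub model IDs and local folder names are assumptions; adjust them to the checkpoints you actually use.

```python
# Sketch: download the PLM checkpoints and save them under pretrained_model/.
# The hub IDs and local folder names are illustrative assumptions.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

for hub_id, local_dir in [
    ("facebook/bart-base", "pretrained_model/bart-base"),
    ("google/pegasus-xsum", "pretrained_model/pegasus-xsum"),
]:
    AutoTokenizer.from_pretrained(hub_id).save_pretrained(local_dir)
    AutoModelForSeq2SeqLM.from_pretrained(hub_id).save_pretrained(local_dir)
```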

For dataset we use:

  • JRC-Acquis. We use the data version from this paper; see this LINK for downloading the data and this script for data pre-processing.

  • XSum is downloaded from this LINK.

  • DailyDialog is downloaded from this LINK.

  • BigPatent is available here.

After downloading the data, convert it to JSON Lines (jsonl) format and put it in the data folder, as in the sketch below.
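
A minimal conversion sketch is shown here; the field names `src`/`trg` and file paths are hypothetical, so use whatever keys and paths the config files expect:

```python
# Sketch: convert a parallel corpus into JSON Lines, one example per line.
# The field names "src"/"trg" and the file paths are illustrative, not the repo's required schema.
import json

with open("train.source") as fs, open("train.target") as ft, \
     open("data/jrc_ende/train.jsonl", "w") as fout:
    for src, trg in zip(fs, ft):
        fout.write(json.dumps({"src": src.strip(), "trg": trg.strip()}, ensure_ascii=False) + "\n")
```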

For initial memory retrieval, we use Elasticsearch to perform first-stage memory retrieval based on the BM25 score.
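
A minimal sketch with the official Python client is below; the index name, field name, and top-k are assumptions, and the repo's own retrieval scripts may differ:

```python
# Sketch: first-stage memory retrieval with Elasticsearch (BM25 is the default similarity).
# Index name, field name, and top-k are illustrative assumptions.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")

memory_corpus = ["example target sentence 1", "example target sentence 2"]  # e.g., training-set targets
source_sentence = "source sentence to retrieve memory for"

for i, text in enumerate(memory_corpus):
    es.index(index="jrc_ende_memory", id=i, document={"text": text})
es.indices.refresh(index="jrc_ende_memory")

resp = es.search(index="jrc_ende_memory",
                 query={"match": {"text": source_sentence}},
                 size=5)
top_memories = [hit["_source"]["text"] for hit in resp["hits"]["hits"]]
```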

We also provide the final hypotheses and references in the output directory. For evaluation scripts, please refer to metrics_utils.py.
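
For a quick corpus-level BLEU check with sacrebleu (installed above), a sketch might look like the following; the file paths are placeholders for the files in the output directory:

```python
# Sketch: corpus-level BLEU with sacrebleu over hypothesis/reference files.
# The file paths are placeholders, not the repo's actual file names.
import sacrebleu

with open("output/hypothesis.txt") as fh, open("output/reference.txt") as fr:
    hyps = [line.strip() for line in fh]
    refs = [line.strip() for line in fr]

print(sacrebleu.corpus_bleu(hyps, [refs]).score)
```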


Retrieval-Augmented Generator

Here we use the JRC-Acquis En-De dataset as an example:

```bash
cd your/work/dir/src

## train a vanilla Transformer model
python train_generator.py \
    --config config/jrc_ende/train_generator.yaml \
    --precision 16

## Transformer-Joint
python train_generator.py \
    --config config/jrc_ende/train_generator.yaml \
    --precision 16 \
    --memory_encoding concate \
    --memory_dir ../data/jrc_ende/memory/bm25

## Transformer-Dual
python train_generator.py \
    --config config/jrc_ende/train_generator.yaml \
    --precision 16 \
    --memory_encoding separate \
    --memory_dir ../data/jrc_ende/memory/bm25
```

Memory Selector

First, we use the trained generator to generate candidates:

```bash
cd your/work/dir/src

python generate_hyps.py \
    --config config/jrc_ende/generate_hyps.yaml \
    --num_beams 50 --num_return_sequences 50 \
    --data_path ../data/jrc_ende/test.jsonl \
    --pretrained_model_path your/trained/model/path \
    --memory_encoding concate \
    --memory_path ../data/jrc_ende/memory/bm25/test.txt \
    --output_path output.txt
```

Then we use this code to train a memory selector.
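
At inference time, the selector scores each (source, candidate) pair and keeps the top-scoring candidate as the next-round memory. The sketch below only illustrates that idea; the sequence-classification architecture and paths are assumptions, not the repo's exact implementation:

```python
# Sketch: pick one generated candidate as the next-round memory.
# The selector architecture and checkpoint path are illustrative assumptions;
# the model is assumed to output a single quality score per (source, candidate) pair.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your/trained/selector/path")
selector = AutoModelForSequenceClassification.from_pretrained("your/trained/selector/path")

def select_memory(source, candidates):
    inputs = tokenizer([source] * len(candidates), candidates,
                       padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        scores = selector(**inputs).logits.squeeze(-1)  # one score per candidate
    return candidates[scores.argmax().item()]
```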

Lift Yourself Up

Here is the pseudocode for the whole process:

```python
# Train the retrieval-augmented generator with the initially retrieved memory,
# then train the memory selector on the generator's candidate pool.
generator = Trainer(model_input, memory)
candidates = generator(model_input, memory)
selector = Trainer(model_input, candidates)

# Iteratively lift the model up with its own output as memory.
for _ in range(iteration_k):
    candidates = generator(model_input, memory)     # generate a candidate memory pool
    memory = selector(model_input, candidates)      # select the memory for the next round
    hypothesis = generator(model_input, memory)     # generate with the selected memory
    current_score = metrics(hypothesis, reference)  # evaluate this round
```