A nimble implementation of the Direct Preference Optimization (DPO) algorithm with Causal Transformer and LSTM models, inspired by the DPO paper on fine-tuning unsupervised language models.

jamesliu/nanoDPO

nanoDPO: Direct Preference Optimization

Introduction

Welcome to nanoDPO, a library for Direct Preference Optimization (DPO). It is inspired by the use of DPO to fine-tune unsupervised language models (LMs) so that they align with user preferences.

Installation

To get started with nanoDPO, simply install the package using pip:

pip install nanoDPO

Key Features

  • Causal Transformer & Simple Sequence Model: Includes both a Causal Transformer and an LSTM-based Simple Sequence Model.
  • Preference Data Simulation: Provides a function, simulate_dpo_dataset_noise, to generate synthetic preference-based data.
  • Sequence Data Preparation: Prepares data for training with prepare_sequence_datasets, aligning data with the DPO framework.
  • DPO Training with PyTorch: Trains models with PyTorch using customizable parameters.
  • MulticlassTrainer: Provides an additional approach for traditional multiclass classification tasks.
  • Cross-Entropy Loss for Multiclass Classification: Used by the multiclass training path to handle multiple classes in data.
  • Customizable Training and Evaluation: Flexible parameters for epochs, batch size, and learning rate.
  • Model Evaluation and Visualization: Offers tools for model evaluation and metrics visualization.
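
To make the DPO training feature more concrete, here is a minimal sketch of the per-example DPO loss as defined in the paper cited under Citation. It is written in plain Python rather than PyTorch so it can be read and run in isolation. The function names (`log_sigmoid`, `dpo_loss`) and the default `beta=0.1` are assumptions made for illustration; this is not nanoDPO's internal implementation, and DPOOneModelTrainer may compute its loss differently.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-example DPO loss from the paper (illustrative helper, not nanoDPO API).

    Each argument is the log-probability that the policy or the frozen
    reference model assigns to the chosen or rejected response.
    """
    # Log-ratio of policy to reference for each response.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # Scaled margin by which the policy favors the chosen response
    # more strongly than the reference does.
    margin = beta * (chosen_logratio - rejected_logratio)
    return -log_sigmoid(margin)


# When the policy equals the reference, the margin is 0 and the loss is log(2).
baseline = dpo_loss(-2.0, -2.0, -2.0, -2.0)
# Raising the chosen log-probability and lowering the rejected one reduces the loss.
improved = dpo_loss(-1.0, -3.0, -2.0, -2.0)
print(baseline, improved)
```

The loss only depends on how the policy's preference margin differs from the reference model's, which is why no separate reward model is needed.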

Usage

Import the necessary modules from nanoDPO, including CausalTransformer, SimpleSequenceModel, and the dataset preparation functions. Use DPOOneModelTrainer for Direct Preference Optimization or MulticlassTrainer for conventional multiclass training.

import torch
from nanodpo.causal_transformer import CausalTransformer
from nanodpo.simple_sequence_model import SimpleSequenceModel
from nanodpo.preference_data import simulate_dpo_dataset_noise
from nanodpo.sequence_data import prepare_sequence_datasets
from nanodpo.dpo_onemodel_trainer import DPOOneModelTrainer
from nanodpo.multiclass_trainer import MulticlassTrainer

# Initialize and train your model
...
model = CausalTransformer(d_feature=feature_dim, d_model=d_model, n_head=n_head, n_layer=n_layer,
                          num_actions=num_actions, max_len=sequence_len,
                          device=device).to(device)
...
trainer = DPOOneModelTrainer(model=model, model_dir=f"dpo_{model_type}_model/", device=device,
                             learning_rate=learning_rate, batch_size=batch_size)
trainer.train(train_dataset, test_dataset, epochs=epochs, eval_interval=eval_interval)

# Evaluate and visualize the results
trainer.plot_metrics()
trainer.evaluate(test_dataset)
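
The snippet above leaves train_dataset and test_dataset undefined. As a rough illustration of what noisy synthetic preference data can look like, the sketch below builds (features, chosen, rejected) triples in plain Python. Everything in it is hypothetical: the function name `simulate_noisy_preferences`, its parameters, and the rule that the preferred action is the index of the largest feature are assumptions for illustration only. It does not reproduce the signature or behavior of nanoDPO's simulate_dpo_dataset_noise or prepare_sequence_datasets.

```python
import random


def simulate_noisy_preferences(num_samples, num_actions, noise=0.1, seed=0):
    """Hypothetical helper: build (features, chosen, rejected) triples.

    The "true" preferred action is assumed to be the index of the largest
    feature. With probability `noise`, the label pair is swapped.
    Requires num_actions >= 2.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(num_samples):
        # One random feature per action (an assumption made for simplicity).
        features = [rng.gauss(0.0, 1.0) for _ in range(num_actions)]
        best = max(range(num_actions), key=lambda i: features[i])
        other = rng.choice([a for a in range(num_actions) if a != best])
        chosen, rejected = best, other
        # Label noise: occasionally record the wrong preference.
        if rng.random() < noise:
            chosen, rejected = rejected, chosen
        data.append((features, chosen, rejected))
    return data


pairs = simulate_noisy_preferences(num_samples=5, num_actions=3, noise=0.2)
for features, chosen, rejected in pairs:
    print(chosen, rejected)
```

Each triple says that, for the given features, the `chosen` action was preferred over the `rejected` one; the noise rate controls how often that recorded preference contradicts the underlying rule.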

[Figure: wandb metrics, DPO with causal_transformer]

License

nanoDPO is licensed under the Apache License 2.0. See the LICENSE file for more details.

Acknowledgments

nanoDPO is inspired by the paper "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" and applies the DPO concept to sequence data.

Citation

@misc{rafailov2023direct,
  title   = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model}, 
  author  = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
  year    = {2023},
  eprint  = {2305.18290},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG}
}
