This directory conducts federated instruction tuning of a pretrained Mistral-7B-Instruct model on a Medical dataset. We use Flower Datasets to download, partition, and preprocess the dataset, and Flower's Simulation Engine to simulate the LLM fine-tuning process in a federated way, which allows users to perform the training on a single GPU.
The fine-tuning results have been submitted as a PEFT adapter and can be accessed here:
FlowerTune-Mistral-7B-Instruct-v0.3-Medical-PEFT
This experiment performs federated LLM fine-tuning with LoRA using the 🤗PEFT library. The clients' models are aggregated with the FedProx strategy.
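For reference, a minimal sketch of how such an aggregation strategy could be instantiated with Flower's built-in `FedProx` (the `proximal_mu` value is an illustrative assumption, not taken from this project's config):

```python
# Minimal server-side strategy sketch using Flower's built-in FedProx.
# fraction_fit mirrors the configuration below; proximal_mu is illustrative.
from flwr.server.strategy import FedProx

strategy = FedProx(
    fraction_fit=0.15,  # fraction of clients sampled for training each round
    proximal_mu=0.1,    # weight of the proximal term added to each client's loss
)
```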
For the Mistral-7B-Instruct-v0.3 model, we adopted the following fine-tuning methodology:
- Precision: bf16 for model weights, tf32 for gradients and optimizer states.
- Quantization: 4-bit quantization for reduced memory usage.
- Optimizer: Paged AdamW 8-bit for effective optimization under constrained resources.
- LoRA Configuration:
  - Rank (r): 8
  - Alpha: 32
  - Target Modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj, lm_head
- Training Configuration:
  - Batch size: 16
  - Maximum number of steps: 6
  - Warmup steps: 2
  - Total number of rounds: 100
  - Fraction fit per round: 0.15
- Learning Rate Scheduler: Constant learning rate scheduler with warmup steps, where:
  - Maximum LR: 5e-5
  - Minimum LR: 1e-6
- Strategy: FedProx
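As an illustration, the client-side model setup implied by these settings roughly maps onto 🤗 Transformers and PEFT as follows (a sketch under the assumptions above; exact arguments in the project code may differ):

```python
# Sketch of client-side model setup: 4-bit quantization + LoRA adapters.
# Hyperparameters mirror the list above; everything else is an assumption.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # bf16 compute
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    quantization_config=quant_config,
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj", "lm_head",
    ],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```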
When bf16 and tf32 are enabled, model weights are stored in bf16 format, while gradients are computed in half-precision and converted to full 32-bit precision for updates.
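These precision and optimizer choices can be expressed with 🤗 Transformers `TrainingArguments` roughly as follows (a sketch, not the project's exact trainer setup; `output_dir` is a placeholder):

```python
# Sketch of trainer settings matching the precision, optimizer and training
# configuration listed above; output_dir is a placeholder.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs",              # placeholder path
    bf16=True,                         # bf16 model weights
    tf32=True,                         # TF32 for float32 matmul math
    optim="paged_adamw_8bit",          # Paged AdamW 8-bit optimizer
    per_device_train_batch_size=16,
    max_steps=6,                       # local steps per federated round
    warmup_steps=2,
    learning_rate=5e-5,
    lr_scheduler_type="constant_with_warmup",
)
```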
Below is the training loss plot from the experiment:
This methodology enabled efficient fine-tuning within constrained resources while ensuring competitive performance.
Evaluation results (accuracy):

- PubMedQA: 0.638
- MedQA: 0.3983
- MedMCQA: 0.256
- Average: 0.4308

Communication budget: 46228 Megabytes
Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:
```bash
pip install -e .
```
The dataset is divided into 20 partitions in an IID fashion, and one partition is assigned to each ClientApp. We randomly sample a fraction (0.15) of the total nodes to participate in each round, for a total of 100 rounds.
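Conceptually, the IID partitioning looks like the following Flower Datasets sketch (the dataset name is a placeholder; the actual dataset is configured in `pyproject.toml`):

```python
# Sketch of IID partitioning with Flower Datasets; the dataset name below is
# a placeholder, not the project's actual configuration value.
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

fds = FederatedDataset(
    dataset="medical-dataset-name",  # placeholder
    partitioners={"train": IidPartitioner(num_partitions=20)},
)
partition = fds.load_partition(0)  # each ClientApp loads its own partition
```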
All settings are defined in `pyproject.toml`.
Important
Please note that `[tool.flwr.app.config.static]` and `options.num-supernodes` under `[tool.flwr.federations.local-simulation]` are not allowed to be modified for fair competition if you plan to participate in the LLM leaderboard.
First, make sure that you have access to the Mistral-7B-Instruct model with your Hugging Face account. You can request access directly from the Hugging Face website. Then, follow the instructions here to log in to your account. Note that you only need to complete this step once on your development machine:
```bash
huggingface-cli login
```
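Alternatively, you can authenticate programmatically with the `huggingface_hub` library (the token value is a placeholder):

```python
# Programmatic alternative to `huggingface-cli login`; token is a placeholder.
from huggingface_hub import login

login(token="hf_your_token_here")
```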
Run the challenge with default config values. The configs are defined in the `[tool.flwr.app.config]` entry of `pyproject.toml` and are loaded automatically:
```bash
flwr run
```
We use the Mistral-7B model with 4-bit quantization by default. The estimated VRAM consumption per client for each challenge is shown below:
| Challenges | GeneralNLP | Finance | Medical | Code |
| :---: | :---: | :---: | :---: | :---: |
| VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
You can adjust the CPU/GPU resources assigned to each client based on your device; they are specified with `options.backend.client-resources.num-cpus` and `options.backend.client-resources.num-gpus` under the `[tool.flwr.federations.local-simulation]` entry in `pyproject.toml`.
The global PEFT model checkpoints are saved every 5 rounds after aggregation on the server side by default, which can be specified with `train.save-every-round` under the `[tool.flwr.app.config]` entry in `pyproject.toml`.
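A saved adapter can later be loaded back onto the base model for inspection or inference, roughly as follows (the checkpoint path is a placeholder):

```python
# Sketch of loading a saved global PEFT checkpoint; the path is a placeholder.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
model = PeftModel.from_pretrained(base, "path/to/peft-checkpoint")  # placeholder
```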
Note
Please provide the last PEFT checkpoint if you plan to participate in the LLM leaderboard.