Commit
Fix llama3 quantization for DPO models (#145)
natolambert authored Jun 24, 2024
1 parent 1ce7c87 commit e7b62d8
Showing 6 changed files with 48 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -236,7 +236,7 @@ When updating the `Dockerfile`, make sure to see the instructions at the top to

In development, we have the following docker images (most recent first as it's likely what you need).
TODO: Update it so one image has VLLM (for generative RM only) and one without. Without will load much faster.
- `nathanl/rb_v19`: Fixes to DPO handling (minor)
- `nathanl/rb_v20`: Fixes to DPO handling (minor) + llama 3 not quantized for dpo
- `nathanl/rb_v18`: Improvements to RewardBench CLI
- `nathanl/rb_v17` (with VLLM): add support for vllm + llm as a judge, `rb_v16` is similar without prometheus and some OpenAI models
- `nathanl/rb_v12`: add support for llama3
8 changes: 7 additions & 1 deletion rewardbench/rewardbench.py
@@ -126,7 +126,13 @@ def main():
if not is_dpo:
quantized = config["quantized"] # only Starling isn't quantized for now
# if llama-3 in name, switch quantized to False (severely degrades performance)
if "llama-3" in args.model or args.not_quantized:
if (
("llama-3" in args.model)
or ("Llama3" in args.model)
or ("Llama-3" in args.model)
or ("LLaMA3" in args.model)
or args.not_quantized
):
quantized = False
logger.info(f"Disabling quantization for llama-3 or override flag (--not_quantized: {args.not_quantized})")
custom_dialogue = config["custom_dialogue"]
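Note: the explicit OR-chain above covers several spellings ("llama-3", "Llama3", "Llama-3", "LLaMA3") but would still miss variants such as "llama3" or "LLAMA-3". A hypothetical, equivalent alternative (not part of this commit) is to normalize the model name before checking:

# Hypothetical sketch (not in this commit): normalize the name so any
# capitalization/hyphenation of "llama 3" disables quantization.
normalized = args.model.lower().replace("-", "").replace("_", "")
if "llama3" in normalized or args.not_quantized:
    quantized = False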
4 changes: 2 additions & 2 deletions scripts/configs/eval_configs.yaml
@@ -629,7 +629,7 @@ allenai/llama-3-tulu-2-dpo-8b:
ref_model: allenai/llama-3-tulu-2-8b
tokenizer: allenai/llama-3-tulu-2-dpo-8b
chat_template: # none for tokenizer
batch_size: 2
batch_size: 1
num_gpus: 2
trust_remote_code: False
dpo: True
@@ -638,7 +638,7 @@ allenai/llama-3-tulu-2-dpo-70b:
ref_model: allenai/llama-3-tulu-2-70b
tokenizer: allenai/llama-3-tulu-2-dpo-70b
chat_template: # none for tokenizer
batch_size: 2
batch_size: 1
num_gpus: 4
trust_remote_code: False
dpo: True
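The batch_size reductions (2 → 1) for both tulu-2-dpo configs presumably offset the extra memory needed once these llama-3 models are loaded unquantized in fp16.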
40 changes: 30 additions & 10 deletions scripts/run_dpo.py
@@ -63,6 +63,9 @@ def get_args():
parser.add_argument(
"--disable_beaker_save", action="store_true", help="disable saving the main results in a file for AI2 Beaker"
)
parser.add_argument(
"--not_quantized", action="store_true", help="disable quantization for models that are quantized by default"
)

args = parser.parse_args()
return args
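With this addition, passing `--not_quantized` to run_dpo.py forces full-precision (non-8-bit) loading regardless of the model name, matching the flag that run_rm.py and the rewardbench CLI already expose.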
@@ -144,11 +147,33 @@ def main():
############################
BATCH_SIZE = args.batch_size

model_kwargs = {
"load_in_8bit": True,
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
if (
("llama-3" in args.model)
or ("Llama3" in args.model)
or ("Llama-3" in args.model)
or ("LLaMA3" in args.model)
or args.not_quantized
):
model_kwargs = {
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
model_kwargs_ref = {
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
else:
model_kwargs = {
"load_in_8bit": True,
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
model_kwargs_ref = {
"load_in_8bit": True,
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}

model = model_builder(
args.model,
trust_remote_code=args.trust_remote_code,
@@ -158,11 +183,6 @@ def main():
if ref_free:
ref_model = None
else:
model_kwargs_ref = {
"load_in_8bit": True,
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
ref_model = model_builder(
args.ref_model,
trust_remote_code=args.trust_remote_code,
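The quantized and unquantized branches above build identical kwargs for the policy and reference models. A hypothetical refactor (not part of this commit) could construct them once and reuse the dict:

import torch

def build_model_kwargs(model_name: str, not_quantized: bool) -> dict:
    # Hypothetical helper (not in this commit): load in 8-bit by default,
    # but fall back to plain fp16 for llama-3 variants or when quantization
    # is explicitly disabled, mirroring the branching added above.
    kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
    }
    if "llama3" not in model_name.lower().replace("-", "") and not not_quantized:
        kwargs["load_in_8bit"] = True
    return kwargs

# Usage inside main(), assuming args comes from get_args():
#   model_kwargs = build_model_kwargs(args.model, args.not_quantized)
#   model_kwargs_ref = dict(model_kwargs)  # same settings for the reference model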
8 changes: 7 additions & 1 deletion scripts/run_rm.py
@@ -121,7 +121,13 @@ def main():

quantized = config["quantized"] # only Starling isn't quantized for now
# if llama-3 in name, switch quantized to False (severely degrades performance)
if "llama-3" in args.model or args.not_quantized:
if (
("llama-3" in args.model)
or ("Llama3" in args.model)
or ("Llama-3" in args.model)
or ("LLaMA3" in args.model)
or args.not_quantized
):
quantized = False
logger.info(f"Disabling quantization for llama-3 or override flag (--not_quantized: {args.not_quantized})")

2 changes: 1 addition & 1 deletion scripts/submit_eval_jobs.py
@@ -28,7 +28,7 @@
"--eval_on_pref_sets", action="store_true", default=False, help="Evaluate on preference sets rather than core set"
)
argparser.add_argument("--eval_on_bon", action="store_true", default=False, help="Evaluate on BON preference sets")
argparser.add_argument("--image", type=str, default="nathanl/rb_v19", help="Beaker image to use")
argparser.add_argument("--image", type=str, default="nathanl/rb_v20", help="Beaker image to use")
argparser.add_argument("--cluster", type=str, default="ai2/allennlp-cirrascale", help="Beaker cluster to use")
argparser.add_argument("--priority", type=str, default="normal", help="Priority of the job")
argparser.add_argument("--upload_to_hub", action="store_false", default=True, help="Upload to results to HF hub")
