diff --git a/atomgpt/inverse_models/inverse_models.py b/atomgpt/inverse_models/inverse_models.py
index 957e30c..df658b3 100644
--- a/atomgpt/inverse_models/inverse_models.py
+++ b/atomgpt/inverse_models/inverse_models.py
@@ -19,6 +19,8 @@ from pydantic_settings import BaseSettings
 import sys
 import argparse
 
+from peft import PeftModelForCausalLM
+
 parser = argparse.ArgumentParser(
     description="Atomistic Generative Pre-trained Transformer."
 )
@@ -38,7 +40,7 @@ class TrainingPropConfig(BaseSettings):
     prefix: str = "atomgpt_run"
     model_name: str = "unsloth/mistral-7b-bnb-4bit"
     batch_size: int = 2
-    num_epochs: int = 5
+    num_epochs: int = 2
     seed_val: int = 42
     num_train: Optional[int] = 2
     num_val: Optional[int] = 2
@@ -164,7 +166,7 @@ def text2atoms(response):
     return atoms
 
 
-def gen_atoms(prompt="", max_new_tokens=512, model="", tokenizer=""):
+def gen_atoms(prompt="", max_new_tokens=2048, model="", tokenizer=""):
     inputs = tokenizer(
         [
             alpaca_prompt.format(
@@ -179,9 +181,7 @@ def gen_atoms(prompt="", max_new_tokens=512, model="", tokenizer=""):
     outputs = model.generate(
         **inputs, max_new_tokens=max_new_tokens, use_cache=True
     )
-    response = tokenizer.batch_decode(outputs)
-    print("response", response)
-    response = response[0].split("# Output:")[1]
+    response = tokenizer.batch_decode(outputs)[0].split("# Output:")[1]
     atoms = None
     try:
         atoms = text2atoms(response)
@@ -263,26 +263,28 @@ def run_atomgpt_inverse(config_file="config.json"):
         # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
     )
 
-    model = FastLanguageModel.get_peft_model(
-        model,
-        r=16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-        target_modules=[
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ],
-        lora_alpha=16,
-        lora_dropout=0,  # Supports any, but = 0 is optimized
-        bias="none",  # Supports any, but = "none" is optimized
-        use_gradient_checkpointing=True,
-        random_state=3407,
-        use_rslora=False,  # We support rank stabilized LoRA
-        loftq_config=None,  # And LoftQ
-    )
+    if not isinstance(model, PeftModelForCausalLM):
+
+        model = FastLanguageModel.get_peft_model(
+            model,
+            r=16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+            target_modules=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ],
+            lora_alpha=16,
+            lora_dropout=0,  # Supports any, but = 0 is optimized
+            bias="none",  # Supports any, but = "none" is optimized
+            use_gradient_checkpointing=True,
+            random_state=3407,
+            use_rslora=False,  # We support rank stabilized LoRA
+            loftq_config=None,  # And LoftQ
+        )
 
     EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
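
Note on the new isinstance guard: when config.model_name points at a checkpoint that was already saved with LoRA adapters, FastLanguageModel.from_pretrained can return a model that is already a PeftModelForCausalLM, and wrapping it again with get_peft_model would attach a second set of adapters. A minimal sketch of the pattern in isolation, assuming the unsloth-style API used in this file; the import path and model name are illustrative:

    from peft import PeftModelForCausalLM
    from unsloth import FastLanguageModel  # assumed import path for this sketch

    # Load either a base model or a previously fine-tuned LoRA checkpoint.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/mistral-7b-bnb-4bit",  # or a saved adapter directory
        max_seq_length=2048,
        load_in_4bit=True,
    )

    # Attach LoRA adapters only if the checkpoint does not already carry them;
    # an adapter-bearing checkpoint loads as a PeftModelForCausalLM.
    if not isinstance(model, PeftModelForCausalLM):
        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_alpha=16,
        )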
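
Note on the gen_atoms changes: the default generation budget grows from 512 to 2048 new tokens so longer structure descriptions are not truncated before parsing, and the decode/split is collapsed into one expression, dropping the intermediate debug print of the raw response. A hypothetical call, assuming a model and tokenizer prepared as above; the prompt text is illustrative:

    # Generate a structure from a property/composition prompt; gen_atoms
    # parses the text after the "# Output:" marker into an Atoms object
    # (it starts from atoms = None, so a failed parse yields None).
    atoms = gen_atoms(
        prompt="The chemical formula is MgB2. The bandgap is 0.0 eV.",
        model=model,
        tokenizer=tokenizer,  # max_new_tokens defaults to 2048 after this change
    )
    print(atoms)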