From 40d2bb3a2d7de9e69689b6a50077d66e90db0bb4 Mon Sep 17 00:00:00 2001
From: knc6
Date: Sun, 17 Nov 2024 07:11:05 -0500
Subject: [PATCH] Add custom.

---
 atomgpt/inverse_models/custom_trainer.py | 169 ++++++++++++++++++-----
 atomgpt/inverse_models/inverse_models.py |  91 +-----------
 2 files changed, 145 insertions(+), 115 deletions(-)

diff --git a/atomgpt/inverse_models/custom_trainer.py b/atomgpt/inverse_models/custom_trainer.py
index 1cff078..c3043f3 100644
--- a/atomgpt/inverse_models/custom_trainer.py
+++ b/atomgpt/inverse_models/custom_trainer.py
@@ -2,6 +2,29 @@
 import torch
 import torch.nn as nn
 from trl import SFTTrainer
+from atomgpt.inverse_models.utils import text2atoms
+
+
+def extract_atomic_structure(target_texts):
+    """
+    Extract the atomic structure from each text in a list of target texts.
+
+    :param target_texts: List of strings containing "### Output:"-delimited structure details.
+    :return: List of parsed atomic structures, one per input text.
+    """
+    atomic_structures = []
+
+    for text in target_texts:
+        # Parse the structure that follows the "### Output:" marker.
+        if "### Output:" in text:
+            structure_part = text2atoms(
+                text.split("### Output:")[1]
+            )  # .strip()
+            atomic_structures.append(structure_part)
+        else:
+            print("No '### Output:' found in the text.")
+
+    return atomic_structures
 
 
 class CustomSFTTrainer(SFTTrainer):
@@ -44,6 +67,24 @@ def __init__(
         )
         self.loss_type = loss_type.lower()
         # self.use_bare_trainer = use_bare_trainer
 
+    def calculate_density(self, atomic_structure):
+        # Placeholder feature extractor: replace with a real density
+        # (or other property) calculation for your domain.
+        return len(
+            atomic_structure
+        )  # Placeholder: use actual calculation logic
+
+    def extract_atomic_structure(self, target_texts):
+        atomic_structures = []
+        for text in target_texts:
+            # Keep the raw text that follows the "### Output:" marker.
+            if "### Output:" in text:
+                structure_part = text.split("### Output:")[1].strip()
+                atomic_structures.append(structure_part)
+            else:
+                print("No '### Output:' found in the text.")
+        return atomic_structures
+
     def compute_loss(self, model, inputs, return_outputs=False):
         """
         Custom loss computation based on the selected loss type or the bare trainer.
@@ -74,50 +115,116 @@ def compute_loss(self, model, inputs, return_outputs=False):
 
         if labels is not None:
             # labels = labels.cpu().numpy()
-            print("self.tokenizer", self.tokenizer)
+            # print("self.tokenizer", self.tokenizer)
+            # print("inputs", inputs, inputs['input_ids'].shape)
+            # print('logits', logits, logits.shape)
+            # print('labels1', labels, labels.shape)
+            # Need to make generalized
+            labels[labels == -100] = 0
+            # print('labels2', labels, labels.shape)
+            # Generate outputs
+            # Decode generated text (example for illustration)
             target_texts = self.tokenizer.batch_decode(
-                labels
-            )  # , skip_special_tokens=True)
-            target_inputs = self.tokenizer(
-                target_texts,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-            ).to(self.model.device)
+                labels, skip_special_tokens=True
+            )
+            pred_texts = self.tokenizer.batch_decode(
+                logits.argmax(-1), skip_special_tokens=True
+            )
 
-            # Generate outputs
-            outputs = self.model.generate(
-                input_ids=target_inputs["input_ids"],
-                max_new_tokens=1024,
-                use_cache=True,
+            # Extract atomic structures (or manipulate the texts)
+            target_atomic_structures = self.extract_atomic_structure(
+                target_texts
+            )
+            pred_atomic_structures = self.extract_atomic_structure(
+                pred_texts
             )
 
-            # Decode the generated outputs for analysis or debugging
-            generated_texts = self.tokenizer.batch_decode(
-                outputs, skip_special_tokens=True
+            # For demonstration, calculate the L1 loss between target and predicted atomic structures,
+            # assuming the atomic structures are numerical or encoded in a way that can be compared directly.
+            # Here a helper derives a comparable feature (e.g. density) from each structure,
+            # and the loss compares that feature between the predicted and target structures.
+            target_densities = torch.tensor(
+                [
+                    self.calculate_density(struct)
+                    for struct in target_atomic_structures
+                ]
+            )
+            pred_densities = torch.tensor(
+                [
+                    self.calculate_density(struct)
+                    for struct in pred_atomic_structures
+                ]
             )
-            print("Generated Texts:", generated_texts)
+
+            # Ensure the tensors are on the correct device
+            target_densities = target_densities.to(logits.device)
+            pred_densities = pred_densities.to(logits.device)
+
+            # Custom loss: L1 loss between target and predicted densities
+            loss_fn = nn.L1Loss()
+            loss = loss_fn(pred_densities, target_densities)
+            print(loss)
+            return loss
 
             import sys
 
             sys.exit()
-            x = logits  # .view(-1, logits.size(-1))
-            y = labels  # .view(-1)
-            print("x", x, x.shape, logits.shape)
-            print("y", y, y.shape, labels.shape)
-            outputs = self.model.generate(
-                target, max_new_tokens=1024, use_cache=True
+            target_out = self.model.generate(
+                input_ids=inputs["input_ids"],
+                max_new_tokens=2024,
+                use_cache=True,
+            )
+            # print("target_out", target_out)
+
+            # Decode the generated outputs for analysis or debugging
+            target_texts = self.tokenizer.batch_decode(
+                target_out, skip_special_tokens=True
+            )
+            target_atom_texts = extract_atomic_structure(target_texts)
+            # print("Target Texts:", target_texts, target_atom_texts)
+
+            gen_out = self.model.generate(
+                input_ids=labels,
+                max_new_tokens=2024,
+                use_cache=True,
+            )
+            # print("gen_out", gen_out)
+
+            # Decode the generated outputs for analysis or debugging
+            gen_texts = self.tokenizer.batch_decode(
+                gen_out, skip_special_tokens=True
+            )
+            gen_atom_texts = extract_atomic_structure(gen_texts)
+            # print("Generated Texts:", gen_texts, gen_atom_texts)
+            loss_fn = nn.L1Loss()
+            target = torch.tensor(
+                [i.density for i in target_atom_texts],
+                device=labels.device,
+                dtype=torch.float,
+                requires_grad=False,
             )
-            print("outputs", outputs)
-            response = self.tokenizer.batch_decode(
-                labels
-            )  # [0].split("# Output:")[1]
-            # loss_fn = nn.L1Loss()
-            # target = labels.float()
-            # loss = loss_fn(logits.view(-1), target.view(-1))
+            pred = torch.tensor(
+                [i.density for i in gen_atom_texts],
+                device=labels.device,
+                dtype=torch.float,
+                requires_grad=True,
+            )
+
+            # target = torch.tensor([i.density for i in target_atom_texts]).to(labels.device)
+            # pred = torch.tensor([i.density for i in gen_atom_texts]).to(labels.device)
+            loss = loss_fn(target, pred)
+            print("target", target)
+            print("pred", pred)
+            print("loss", loss)
+            return loss
 
         elif self.loss_type == "cross_entropy":
             loss_fn = nn.CrossEntropyLoss()
+            x = logits.view(-1, logits.size(-1))
+            y = labels.view(-1)
+            # print('x', x.shape)
+            # print('y', y.shape)
             loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
+            # print('loss', loss, loss.shape)
         else:
             raise ValueError(f"Unsupported loss type: {self.loss_type}")

diff --git a/atomgpt/inverse_models/inverse_models.py b/atomgpt/inverse_models/inverse_models.py
index aeb84f4..5cc8f89 100644
--- a/atomgpt/inverse_models/inverse_models.py
+++ b/atomgpt/inverse_models/inverse_models.py
@@ -2,6 +2,11 @@
 from typing import Optional
 from atomgpt.inverse_models.loader import FastLanguageModel
 from atomgpt.inverse_models.custom_trainer import CustomSFTTrainer
+from atomgpt.inverse_models.utils import (
+    gen_atoms,
+    text2atoms,
+    get_crystal_string_t,
+)
 import torch
 from peft import PeftModel
 from datasets import load_dataset
@@ -38,7 +43,7 @@
 class TrainingPropConfig(BaseSettings):
     """Training config defaults and validation."""
 
-    id_prop_path: Optional[str] = "id_prop.csv"
+    id_prop_path: Optional[str] = "atomgpt/examples/inverse_model/id_prop.csv"
     prefix: str = "atomgpt_run"
     model_name: str = "knc6/atomgpt_mistral_tc_supercon"
     batch_size: int = 2
@@ -84,29 +89,6 @@ class TrainingPropConfig(BaseSettings):
 {}"""
 
 
-def get_crystal_string_t(atoms):
-    lengths = atoms.lattice.abc  # structure.lattice.parameters[:3]
-    angles = atoms.lattice.angles
-    atom_ids = atoms.elements
-    frac_coords = atoms.frac_coords
-
-    crystal_str = (
-        " ".join(["{0:.2f}".format(x) for x in lengths])
-        + "\n"
-        + " ".join([str(int(x)) for x in angles])
-        + "\n"
-        + "\n".join(
-            [
-                str(t) + " " + " ".join(["{0:.3f}".format(x) for x in c])
-                for t, c in zip(atom_ids, frac_coords)
-            ]
-        )
-    )
-
-    # crystal_str = atoms_describer(atoms) + "\n*\n" + crystal_str
-    return crystal_str
-
-
 def make_alpaca_json(
     dataset=[], jids=[], prop="Tc_supercon", include_jid=False
 ):
@@ -149,65 +131,6 @@
     }
 
 
-def text2atoms(response):
-    tmp_atoms_array = response.strip("").split("\n")
-    # tmp_atoms_array = [element for element in tmp_atoms_array if element != '']
-    # print("tmp_atoms_array", tmp_atoms_array)
-    lat_lengths = np.array(tmp_atoms_array[1].split(), dtype="float")
-    lat_angles = np.array(tmp_atoms_array[2].split(), dtype="float")
-
-    lat = Lattice.from_parameters(
-        lat_lengths[0],
-        lat_lengths[1],
-        lat_lengths[2],
-        lat_angles[0],
-        lat_angles[1],
-        lat_angles[2],
-    )
-    elements = []
-    coords = []
-    for ii, i in enumerate(tmp_atoms_array):
-        if ii > 2 and ii < len(tmp_atoms_array):
-            # if ii>2 and ii
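
Note on the density loss path: the densities compared in compute_loss are plain
Python floats recovered from decoded text (via text2atoms), so the L1 loss has
no gradient path back to the model's logits; requires_grad=True on `pred` only
makes that tensor a gradient-collecting leaf. A minimal torch-only sketch, with
made-up density values, of what this loss does and does not backpropagate:

    import torch
    import torch.nn as nn

    # Stand-in densities; in the patch these come from decoded text, so
    # they are detached from the model's computation graph.
    target = torch.tensor([6.51, 7.87], dtype=torch.float)
    pred = torch.tensor([6.40, 8.02], dtype=torch.float, requires_grad=True)

    loss = nn.L1Loss()(pred, target)  # mean absolute error over the batch
    loss.backward()

    print(loss.item())  # 0.13 = (|6.40 - 6.51| + |8.02 - 7.87|) / 2
    print(pred.grad)    # sign(pred - target) / batch_size -> [-0.5, 0.5]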
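
Note on the relocated helpers: this patch moves get_crystal_string_t and
text2atoms into atomgpt/inverse_models/utils.py. The sketch below is a
self-contained approximation of the string format they emit and parse; the
crystal_string helper and the example cell are hypothetical stand-ins, not
atomgpt code:

    def crystal_string(lengths, angles, elements, frac_coords):
        # Line 1: a, b, c to 2 decimals; line 2: integer angles; then one
        # "Element fx fy fz" line per site, as in get_crystal_string_t.
        return (
            " ".join("{0:.2f}".format(x) for x in lengths)
            + "\n"
            + " ".join(str(int(x)) for x in angles)
            + "\n"
            + "\n".join(
                t + " " + " ".join("{0:.3f}".format(x) for x in c)
                for t, c in zip(elements, frac_coords)
            )
        )

    # Hypothetical rock-salt-like cell, for illustration only.
    s = crystal_string(
        [4.16, 4.16, 4.16],
        [90, 90, 90],
        ["Nb", "N"],
        [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]],
    )

    # text2atoms reads lattice lengths from line index 1 and angles from
    # index 2, which suggests it expects a leading line (e.g. the newline
    # right after the "### Output:" marker) ahead of the numbers.
    lines = ("\n" + s).split("\n")
    lat_lengths = [float(x) for x in lines[1].split()]
    lat_angles = [float(x) for x in lines[2].split()]
    print(lat_lengths, lat_angles)  # [4.16, 4.16, 4.16] [90.0, 90.0, 90.0]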