
Commit

Add custom.
knc6 committed Nov 17, 2024
1 parent 7843b87 commit 40d2bb3
Showing 2 changed files with 145 additions and 115 deletions.
atomgpt/inverse_models/custom_trainer.py: 169 changes (138 additions, 31 deletions)

@@ -2,6 +2,29 @@
import torch
import torch.nn as nn
from trl import SFTTrainer
from atomgpt.inverse_models.utils import text2atoms


def extract_atomic_structure(target_texts):
"""
Extracts the atomic structure description from a list of target texts.
:param target_texts: List of strings containing target texts with atomic structure details.
:return: List of strings with only the atomic structure descriptions.
"""
atomic_structures = []

for text in target_texts:
# Split the text at "### Output:"
if "### Output:" in text:
structure_part = text2atoms(
text.split("### Output:")[1]
) # .strip()
atomic_structures.append(structure_part)
else:
print("No '### Output:' found in the text.")

return atomic_structures
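# Usage sketch (hypothetical example, not part of this commit): given a decoded
# target text that contains an "### Output:" section, the helper parses the
# structure with text2atoms. The sample text and the .density access below are
# assumptions for illustration only.
#
#   sample = "### Output:\n3.80 3.80 3.80\n90 90 90\nMg 0.000 0.000 0.000"
#   atoms_list = extract_atomic_structure([sample])
#   print(atoms_list[0].density)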


class CustomSFTTrainer(SFTTrainer):
@@ -44,6 +67,24 @@ def __init__(
self.loss_type = loss_type.lower()
# self.use_bare_trainer = use_bare_trainer

def calculate_density(self, atomic_structure):
# Example of a function to calculate density (or any other feature from atomic structure)
# You can implement this based on your domain knowledge.
return len(
atomic_structure
) # Placeholder: use actual calculation logic
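# A possible non-placeholder implementation (assumption, not from this commit):
# parse the text into an Atoms object and use its density, e.g.
#
#   def calculate_density(self, atomic_structure):
#       atoms = text2atoms("\n" + atomic_structure)
#       return atoms.density
#
# This relies on the .density attribute that is also used on parsed structures
# further below in compute_loss.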

def extract_atomic_structure(self, target_texts):
atomic_structures = []
for text in target_texts:
# Split the text at "### Output:"
if "### Output:" in text:
structure_part = text.split("### Output:")[1].strip()
atomic_structures.append(structure_part)
else:
print("No '### Output:' found in the text.")
return atomic_structures

def compute_loss(self, model, inputs, return_outputs=False):
"""
Custom loss computation based on the selected loss type or the bare trainer.
@@ -74,50 +115,116 @@ def compute_loss(self, model, inputs, return_outputs=False):

if labels is not None:
# labels = labels.cpu().numpy()
print("self.tokenizer", self.tokenizer)
# print("self.tokenizer", self.tokenizer)
# print("inputs", inputs,inputs['input_ids'].shape)
# print('logits',logits,logits.shape)
# print('labels1',labels,labels.shape)
# Need to make generalized
labels[labels == -100] = 0
# print('labels2',labels,labels.shape)
# Generate outputs
# Decode generated text (example for illustration)
target_texts = self.tokenizer.batch_decode(
    labels
)  # , skip_special_tokens=True)
target_inputs = self.tokenizer(
    target_texts,
    return_tensors="pt",
    padding=True,
    truncation=True,
).to(self.model.device)
target_texts = self.tokenizer.batch_decode(
    labels, skip_special_tokens=True
)
pred_texts = self.tokenizer.batch_decode(
    logits.argmax(-1), skip_special_tokens=True
)
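# Note: logits.argmax(-1) greedily selects the highest-scoring token id at each
# position of the forward pass, and batch_decode turns those ids into the
# predicted text that is compared against the targets below.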

# Generate outputs
outputs = self.model.generate(
    input_ids=target_inputs["input_ids"],
    max_new_tokens=1024,
    use_cache=True,
)
# Extract atomic structures (or manipulate the texts)
target_atomic_structures = self.extract_atomic_structure(
    target_texts
)
pred_atomic_structures = self.extract_atomic_structure(
    pred_texts
)

# Decode the generated outputs for analysis or debugging
generated_texts = self.tokenizer.batch_decode(
    outputs, skip_special_tokens=True
)
# Compare target and predicted atomic structures via an L1 loss on a scalar
# feature (here, a density-like value) computed from each structure.
target_densities = torch.tensor(
[
self.calculate_density(struct)
for struct in target_atomic_structures
]
)
pred_densities = torch.tensor(
[
self.calculate_density(struct)
for struct in pred_atomic_structures
]
)
print("Generated Texts:", generated_texts)

# Ensure the tensors are on the correct device
target_densities = target_densities.to(logits.device)
pred_densities = pred_densities.to(logits.device)

# Custom loss: L1 loss between target and predicted densities
loss_fn = nn.L1Loss()
loss = loss_fn(pred_densities, target_densities)
print(loss)
return loss
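# Note: target_densities and pred_densities are plain tensors built from Python
# floats, so this L1 loss is detached from the model's computation graph and
# does not backpropagate through the logits.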
import sys

sys.exit()
x = logits # .view(-1, logits.size(-1))
y = labels # .view(-1)
print("x", x, x.shape, logits.shape)
print("y", y, y.shape, labels.shape)
outputs = self.model.generate(
    target, max_new_tokens=1024, use_cache=True
)
target_out = self.model.generate(
input_ids=inputs["input_ids"],
max_new_tokens=2024,
use_cache=True,
)
# print("target_out", target_out)

# Decode the generated outputs for analysis or debugging
target_texts = self.tokenizer.batch_decode(
target_out, skip_special_tokens=True
)
target_atom_texts = extract_atomic_structure(target_texts)
# print("Target Texts:", target_texts,target_atom_texts)

gen_out = self.model.generate(
input_ids=labels,
max_new_tokens=2024,
use_cache=True,
)
# print("gen_out", gen_out)

# Decode the generated outputs for analysis or debugging
gen_texts = self.tokenizer.batch_decode(
gen_out, skip_special_tokens=True
)
gen_atom_texts = extract_atomic_structure(gen_texts)
# print("Generated Texts:", gen_texts,gen_atom_texts)
loss_fn = nn.L1Loss()
target = torch.tensor(
[i.density for i in target_atom_texts],
device=labels.device,
dtype=torch.float,
requires_grad=False,
)
print("outputs", outputs)
response = self.tokenizer.batch_decode(
labels
) # [0].split("# Output:")[1]
# loss_fn = nn.L1Loss()
# target = labels.float()
# loss = loss_fn(logits.view(-1), target.view(-1))
pred = torch.tensor(
[i.density for i in gen_atom_texts],
device=labels.device,
dtype=torch.float,
requires_grad=True,
)

# target = torch.tensor([i.density for i in target_atom_texts]).to(labels.device)
# pred = torch.tensor([i.density for i in gen_atom_texts]).to(labels.device)
loss = loss_fn(target, pred)
print("target", target)
print("pred", pred)
print("loss", loss)
return loss

elif self.loss_type == "cross_entropy":
loss_fn = nn.CrossEntropyLoss()
x = logits.view(-1, logits.size(-1))
y = labels.view(-1)
# print('x',x.shape)
# print('y',y.shape)
loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
# print('loss',loss,loss.shape)
else:
raise ValueError(f"Unsupported loss type: {self.loss_type}")
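A minimal usage sketch (hypothetical, not part of this commit): the constructor arguments below follow the usual SFTTrainer pattern, and the loss_type value that activates the density-based L1 branch is defined in code elided from this diff, so only the cross_entropy value shown above is certain.

    trainer = CustomSFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        loss_type="cross_entropy",  # swap for the value that selects the density-based branch
    )
    trainer.train()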

atomgpt/inverse_models/inverse_models.py: 91 changes (7 additions, 84 deletions)

@@ -2,6 +2,11 @@
from typing import Optional
from atomgpt.inverse_models.loader import FastLanguageModel
from atomgpt.inverse_models.custom_trainer import CustomSFTTrainer
from atomgpt.inverse_models.utils import (
gen_atoms,
text2atoms,
get_crystal_string_t,
)
import torch
from peft import PeftModel
from datasets import load_dataset
@@ -38,7 +43,7 @@
class TrainingPropConfig(BaseSettings):
"""Training config defaults and validation."""

id_prop_path: Optional[str] = "id_prop.csv"
id_prop_path: Optional[str] = "atomgpt/examples/inverse_model/id_prop.csv"
prefix: str = "atomgpt_run"
model_name: str = "knc6/atomgpt_mistral_tc_supercon"
batch_size: int = 2
@@ -84,29 +89,6 @@ class TrainingPropConfig(BaseSettings):
{}"""


def get_crystal_string_t(atoms):
lengths = atoms.lattice.abc # structure.lattice.parameters[:3]
angles = atoms.lattice.angles
atom_ids = atoms.elements
frac_coords = atoms.frac_coords

crystal_str = (
" ".join(["{0:.2f}".format(x) for x in lengths])
+ "\n"
+ " ".join([str(int(x)) for x in angles])
+ "\n"
+ "\n".join(
[
str(t) + " " + " ".join(["{0:.3f}".format(x) for x in c])
for t, c in zip(atom_ids, frac_coords)
]
)
)

# crystal_str = atoms_describer(atoms) + "\n*\n" + crystal_str
return crystal_str
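# For illustration (hypothetical structure, not from the repository): for a cubic
# cell with a = b = c = 3.80 Angstrom, 90-degree angles, and a single Mg atom at
# the origin, the returned string would look like:
#
#   3.80 3.80 3.80
#   90 90 90
#   Mg 0.000 0.000 0.000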


def make_alpaca_json(
dataset=[], jids=[], prop="Tc_supercon", include_jid=False
):
@@ -149,65 +131,6 @@ def formatting_prompts_func(examples):
}


def text2atoms(response):
tmp_atoms_array = response.strip("</s>").split("\n")
# tmp_atoms_array= [element for element in tmp_atoms_array if element != '']
# print("tmp_atoms_array", tmp_atoms_array)
lat_lengths = np.array(tmp_atoms_array[1].split(), dtype="float")
lat_angles = np.array(tmp_atoms_array[2].split(), dtype="float")

lat = Lattice.from_parameters(
lat_lengths[0],
lat_lengths[1],
lat_lengths[2],
lat_angles[0],
lat_angles[1],
lat_angles[2],
)
elements = []
coords = []
for ii, i in enumerate(tmp_atoms_array):
if ii > 2 and ii < len(tmp_atoms_array):
# if ii>2 and ii<len(tmp_atoms_array)-1:
tmp = i.split()
elements.append(tmp[0])
coords.append([float(tmp[1]), float(tmp[2]), float(tmp[3])])

atoms = Atoms(
coords=coords,
elements=elements,
lattice_mat=lat.lattice(),
cartesian=False,
)
return atoms


def gen_atoms(prompt="", max_new_tokens=512, model="", tokenizer=""):
inputs = tokenizer(
[
alpaca_prompt.format(
instruction,
prompt, # input
"", # output - leave this blank for generation!
)
],
return_tensors="pt",
).to("cuda")

outputs = model.generate(
**inputs, max_new_tokens=max_new_tokens, use_cache=True
)
response = tokenizer.batch_decode(outputs)[0].split("# Output:")[1]
atoms = None
try:
atoms = text2atoms(response)
except Exception as exp:

print(exp)
pass
return atoms
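# Usage sketch (hypothetical prompt and variable names): generate an Atoms object
# from a property description with a model/tokenizer pair loaded elsewhere, e.g.
# via FastLanguageModel.
#
#   atoms = gen_atoms(
#       prompt="The chemical formula is MgB2. The Tc_supercon is 39.",
#       model=model,
#       tokenizer=tokenizer,
#   )
#   if atoms is not None:
#       print(atoms)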


#######################################


@@ -222,7 +145,7 @@ def run_atomgpt_inverse(config_file="config.json"):
num_test = config.num_test
model_name = config.model_name
# loss_function = config.loss_function
id_prop_path = os.path.join(run_path, id_prop_path)
# id_prop_path = os.path.join(run_path, id_prop_path)
with open(id_prop_path, "r") as f:
reader = csv.reader(f)
dt = [row for row in reader]
