
Commit

Merge pull request #28 from YerevaNN/safe2
final safe
FilyaGeikyan authored Sep 27, 2024
2 parents d584b5a + e82a73e commit be97fc3
Showing 7 changed files with 487 additions and 6 deletions.
5 changes: 5 additions & 0 deletions torchtitan/config_manager.py
@@ -270,6 +270,11 @@ def __init__(self):
             default=True,
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+        self.parser.add_argument(
+            "--training.representation_type",
+            default="SMILES",
+            help="The representation type of the molecule for training the model.",
+        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
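
For reference, a minimal sketch (not part of the commit) of how the new flag surfaces after parsing, assuming torchtitan's usual JobConfig.parse_args entry point and its per-section attribute access; the same key would ordinarily also be settable under a [training] table of the TOML job config, though that file is not part of this diff.

    from torchtitan.config_manager import JobConfig

    job_config = JobConfig()
    # Parse an explicit argument list; with no flag given, the default "SMILES" applies.
    job_config.parse_args(["--training.representation_type", "SMILES"])
    print(job_config.training.representation_type)  # -> "SMILES"
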
8 changes: 5 additions & 3 deletions torchtitan/datasets/hf_datasets.py
@@ -87,6 +87,7 @@ def __init__(
         dataset_path: Optional[str],
         data_processing_style: str,
         tokenizer: Tokenizer,
+        representation_type: str = "SMILES",
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
@@ -135,14 +136,14 @@ def __init__(
         self.infinite = infinite
         self.rank = rank
         self.world_size = world_size
+        self.representation_type = representation_type
 
         # for non sync communication between ranks
         if not self.infinite and store:
             self.store = store
         else:
             self.store = None
-
 
         # variables for checkpointing
         self._sample_idx = 0
         self._all_tokens: List[int] = []
@@ -172,7 +173,7 @@ def __iter__(self):
             for sample_json in self._get_data_iter():
                 if self._some_rank_finished():
                     break
-                sample_text = self.data_processing_fn(sample_json, self.rng)
+                sample_text = self.data_processing_fn(sample_json, self.rng, self.representation_type)
                 sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
                 self._all_tokens.extend(sample_tokens)
                 self._sample_idx += 1
@@ -255,6 +256,7 @@ def build_hf_data_loader(
     seq_len: int,
     world_size,
     rank,
+    representation_type,
     infinite: bool = True,
     pin_memory: bool = False,
     num_workers: int = 2,
@@ -268,7 +270,7 @@
         data_completion_store = None
 
     hf_ds = HuggingFaceDataset(
-        dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite, special_mode,store = data_completion_store
+        dataset_name, dataset_path, data_processing_style, tokenizer, representation_type, seq_len, world_size, rank, infinite, special_mode,store = data_completion_store
     )
 
     return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers)
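
A hedged usage sketch of the updated builder (not part of the commit). The parameters ahead of seq_len are truncated in the hunk above, so the keyword names before it are inferred from the call site inside the function body; every value is a placeholder, and the tokenizer is assumed to be built elsewhere in the training script.

    data_loader = build_hf_data_loader(
        dataset_name="chemlactica_train",            # hypothetical dataset key
        dataset_path=None,
        data_processing_style="chemlactica_style",   # assumed to select chemlactica_style_data_processing
        tokenizer=tokenizer,                         # Tokenizer instance from the surrounding script
        batch_size=8,
        seq_len=2048,
        world_size=1,
        rank=0,
        representation_type="SMILES",                # the argument added in this commit
    )
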
4 changes: 2 additions & 2 deletions torchtitan/utils/dataset_utils.py
@@ -30,12 +30,12 @@ def load_jsonl_line(jsonl_line):
         raise ValueError(f"Error decoding JSON: {e}")
 
 
-def chemlactica_style_data_processing(sample_json, rng):
+def chemlactica_style_data_processing(sample_json, rng, representation_type):
     try:
         sample_json = json.loads(sample_json["text"])
         compound = delete_empty_tags(sample_json)
         sample_json = generate_formatted_string(
-            compound, rng
+            compound, rng, representation_type
         )
     except Exception as e:
         print(e)
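
The formatting itself happens inside delete_empty_tags and generate_formatted_string, which are not shown in this diff. For illustration only, here is a self-contained toy with the same (sample_json, rng, representation_type) call signature; the tag layout and field names below are assumptions, not the repository's actual format.

    import json
    import random


    def toy_chemlactica_style_processing(sample_json, rng, representation_type):
        # Parse the raw JSON record, pull out the molecule string for the requested
        # representation (assumed to be a top-level key such as "SMILES"), and
        # shuffle the remaining property tags with the caller-supplied rng.
        compound = json.loads(sample_json["text"])
        molecule = compound.pop(representation_type, "")
        tags = [f"[{key}]{value}" for key, value in compound.items() if value not in ("", None)]
        rng.shuffle(tags)
        return f"[{representation_type}]{molecule}" + "".join(tags)


    sample = {"text": json.dumps({"SMILES": "CCO", "CID": 702, "logP": -0.14})}  # made-up record
    print(toy_chemlactica_style_processing(sample, random.Random(0), "SMILES"))
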
(Diffs for the remaining 4 changed files are not shown here.)
