Skip to content

Commit

Permalink
Adding support of domains to config
Browse files Browse the repository at this point in the history
  • Loading branch information
paultltc committed Jan 22, 2025
1 parent 4d6c1d3 commit 93cf438
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
18 changes: 10 additions & 8 deletions src/nanotron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,12 @@ def __post_init__(self):
class MultilingualNanosetDatasetsArgs:
training_folder: Union[str, dict, List[str]]
validation_folder: Union[str, List[str]]
languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
domains: Optional[List[str]] = None # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
languages: Optional[List[str]] = None # NOTE(@paultltc): For back-compatibility

def __post_init__(self):
if self.languages is not None and self.domains is None:
self.domains = self.languages
if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder
self.training_folder = [self.training_folder]
self.validation_folder = [self.validation_folder]
Expand All @@ -125,13 +128,13 @@ def __post_init__(self):
self.training_folder = list(tmp_training_folder.keys())
self.dataset_weights = list(tmp_training_folder.values())

assert len(self.training_folder) == len(
self.languages
), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"
# assert len(self.training_folder) == len(
# self.domains
# ), f"The sizes of training_folder and domains mismatch ({len(self.training_folder)} vs {len(self.domains)})"

assert len(self.training_folder) == len(
self.validation_folder
), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
# assert len(self.training_folder) == len(
# self.validation_folder
# ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"


@dataclass
Expand Down Expand Up @@ -189,7 +192,6 @@ class GeneralArgs:
Args:
project: Name of the project (a project gather several runs in common tensorboard/hub-folders)
entity: Weights and bias entity name (optional)
run: Name of the run
step: Global step (updated when we save the checkpoint)
consumed_train_samples: Number of samples consumed during training (should be actually just step*batch_size)
Expand Down
2 changes: 1 addition & 1 deletion src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
)

lang_losses = {
lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.languages
lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.domains
}
lang_losses_list = list(lang_losses.keys())

Expand Down

0 comments on commit 93cf438

Please sign in to comment.