Refactor training logs & warmup_proportion
Now we display the number of unique pairs (instead of number of examples, which was broken).
tomaarsen committed Jan 16, 2024
1 parent 9729aa0 commit 23cd04d
Showing 2 changed files with 10 additions and 11 deletions.
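Why the old count was broken: len(train_dataloader) is the number of batches per epoch, not the number of training examples, so the previous "Num examples" log line under-reported. A minimal standalone sketch (toy data, not part of this commit) showing the difference:

    from torch.utils.data import DataLoader

    dataset = list(range(100))  # stand-in for 100 contrastive pairs
    loader = DataLoader(dataset, batch_size=16, drop_last=False)

    print(len(loader))          # 7 -- batches per epoch, i.e. ceil(100 / 16)
    print(len(loader.dataset))  # 100 -- the pair count the new log line reports

The fix below threads the sampler length out of get_dataloader as num_unique_pairs and logs that instead.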
src/setfit/trainer.py (17 changes: 8 additions & 9 deletions)
@@ -440,24 +440,23 @@ def train_embeddings(
         self.state.total_flos = 0

         train_max_pairs = -1 if args.max_steps == -1 else args.max_steps * args.embedding_batch_size
-        train_dataloader, loss_func, batch_size = self.get_dataloader(
+        train_dataloader, loss_func, batch_size, num_unique_pairs = self.get_dataloader(
             x_train, y_train, args=args, max_pairs=train_max_pairs
         )
         if x_eval is not None and args.evaluation_strategy != IntervalStrategy.NO:
             eval_max_pairs = -1 if args.eval_max_steps == -1 else args.eval_max_steps * args.embedding_batch_size
-            eval_dataloader, _, _ = self.get_dataloader(x_eval, y_eval, args=args, max_pairs=eval_max_pairs)
+            eval_dataloader, _, _, _ = self.get_dataloader(x_eval, y_eval, args=args, max_pairs=eval_max_pairs)
         else:
             eval_dataloader = None

+        total_train_steps = len(train_dataloader) * args.embedding_num_epochs
         if args.max_steps > 0:
-            total_train_steps = args.max_steps
-        else:
-            total_train_steps = len(train_dataloader) * args.embedding_num_epochs
+            total_train_steps = min(args.max_steps, total_train_steps)
         logger.info("***** Running training *****")
-        logger.info(f"  Num examples = {len(train_dataloader)}")
+        logger.info(f"  Num unique pairs = {num_unique_pairs}")
+        logger.info(f"  Batch size = {batch_size}")
         logger.info(f"  Num epochs = {args.embedding_num_epochs}")
         logger.info(f"  Total optimization steps = {total_train_steps}")
-        logger.info(f"  Total train batch size = {batch_size}")

         warmup_steps = math.ceil(total_train_steps * args.warmup_proportion)
         self._train_sentence_transformer(
@@ -471,7 +470,7 @@ def train_embeddings(

     def get_dataloader(
         self, x: List[str], y: Union[List[int], List[List[int]]], args: TrainingArguments, max_pairs: int = -1
-    ) -> Tuple[DataLoader, nn.Module, int]:
+    ) -> Tuple[DataLoader, nn.Module, int, int]:
         # sentence-transformers adaptation
         input_data = [InputExample(texts=[text], label=label) for text, label in zip(x, y)]

@@ -511,7 +510,7 @@ def get_dataloader(
         dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False)
         loss = args.loss(self.model.model_body)

-        return dataloader, loss, batch_size
+        return dataloader, loss, batch_size, len(data_sampler)

     def log(self, args: TrainingArguments, logs: Dict[str, float]) -> None:
         """
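The total_train_steps change above also affects warmup: previously, when args.max_steps was set, total_train_steps was taken as max_steps even if the dataloader supplied fewer steps, so warmup_steps = math.ceil(total_train_steps * args.warmup_proportion) could be sized for steps that never run. A rough worked sketch with assumed numbers (not from the commit):

    import math

    steps_per_epoch = 50     # len(train_dataloader), assumed
    num_epochs = 1           # args.embedding_num_epochs, assumed
    max_steps = 1000         # args.max_steps, assumed
    warmup_proportion = 0.1  # args.warmup_proportion, assumed

    # Before: max_steps won outright, so warmup outlasted the 50-step run.
    old_total = max_steps if max_steps > 0 else steps_per_epoch * num_epochs
    print(math.ceil(old_total * warmup_proportion))  # 100 warmup steps

    # After: capped by the number of steps that will actually execute.
    new_total = steps_per_epoch * num_epochs
    if max_steps > 0:
        new_total = min(max_steps, new_total)
    print(math.ceil(new_total * warmup_proportion))  # 5 warmup steps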
src/setfit/trainer_distillation.py (4 changes: 2 additions & 2 deletions)
@@ -83,7 +83,7 @@ def get_dataloader(
         y: Optional[Union[List[int], List[List[int]]]],
         args: TrainingArguments,
         max_pairs: int = -1,
-    ) -> Tuple[DataLoader, nn.Module, int]:
+    ) -> Tuple[DataLoader, nn.Module, int, int]:
         x_embd_student = self.teacher_model.model_body.encode(
             x, convert_to_tensor=self.teacher_model.has_differentiable_head
         )
@@ -96,7 +96,7 @@ def get_dataloader(
         batch_size = min(args.embedding_batch_size, len(data_sampler))
         dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False)
         loss = args.loss(self.model.model_body)
-        return dataloader, loss, batch_size
+        return dataloader, loss, batch_size, len(data_sampler)

     def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None:
         """
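Both trainers now return len(data_sampler) as a fourth value; the sampler enumerates sentence pairs, so its length is a pair count, not a sentence count. A toy illustration of that distinction (using itertools as a stand-in for the real pair sampler, which is an assumption here):

    from itertools import combinations

    sentences = ["a", "b", "c", "d"]
    unique_pairs = list(combinations(sentences, 2))
    print(len(sentences))     # 4 sentences
    print(len(unique_pairs))  # 6 unique pairs: n * (n - 1) / 2

Note that any subclass or caller unpacking get_dataloader's result must now expect four values instead of three.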
