Skip to content

Commit

Permalink
[feat] Support multiple datasets tracking for HF trainer (#2518)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamohannes authored Feb 3, 2023
1 parent b8fe44a commit 4acafa9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Add Prophet integration (grigoryan-davit)
- Add 'Dataset' type support for hf/datasets (tmynn)
- Add HuggingFace Transformers model info (tmynn)
- Add multi-dataset logging support for HuggingFace transformers (tmynn)

### Fixes

Expand Down
59 changes: 43 additions & 16 deletions aim/sdk/adapters/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from logging import getLogger
from typing import Optional
from typing import Optional, List, Dict
from difflib import SequenceMatcher
from collections import defaultdict

from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
from aim.sdk.num_utils import is_number
Expand All @@ -17,13 +19,13 @@


class AimCallback(TrainerCallback):
def __init__(self,
repo: Optional[str] = None,
experiment: Optional[str] = None,
system_tracking_interval: Optional[int]
= DEFAULT_SYSTEM_TRACKING_INT,
log_system_params: bool = True,
):
def __init__(
self,
repo: Optional[str] = None,
experiment: Optional[str] = None,
system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
log_system_params: bool = True,
):
self._repo_path = repo
self._experiment_name = experiment
self._system_tracking_interval = system_tracking_interval
Expand Down Expand Up @@ -63,15 +65,14 @@ def setup(self, args=None, state=None, model=None):
for key, value in combined_dict.items():
self._run.set(('hparams', key), value, strict=False)
if model:
self._run.set("model", {**vars(model.config), "num_labels": model.num_labels})
self._run.set('model', {**vars(model.config), 'num_labels': model.num_labels})

# Store model configs as well
# if hasattr(model, 'config') and model.config is not None:
# model_config = model.config.to_dict()
# self._run['model'] = model_config

def on_train_begin(self, args, state, control,
model=None, **kwargs):
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not state.is_world_process_zero:
return
if not self._run:
Expand All @@ -82,8 +83,7 @@ def on_train_end(self, args, state, control, **kwargs):
return
self.close()

def on_log(self, args, state, control,
model=None, logs=None, **kwargs):
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not state.is_world_process_zero:
return

Expand All @@ -97,6 +97,13 @@ def on_log(self, args, state, control,
if log_name.startswith(prefix):
log_name = log_name[len(prefix):]
context = {'subset': prefix[:-1]}
if '_' in log_name:
sub_dataset = AimCallback.find_most_common_substring(
list(logs.keys())
).split(prefix)[-1]
if sub_dataset != prefix.rstrip('_'):
log_name = log_name.split(sub_dataset)[-1].lstrip('_')
context['sub_dataset'] = sub_dataset
break
if not is_number(log_value):
if not self._log_value_warned:
Expand All @@ -108,15 +115,35 @@ def on_log(self, args, state, control,
)
continue

self._run.track(log_value,
name=log_name, context=context,
step=state.global_step, epoch=state.epoch)
self._run.track(
log_value,
name=log_name,
context=context,
step=state.global_step,
epoch=state.epoch,
)

def close(self):
    """Close the active run, if any, and reset ``self._run`` to ``None``."""
    active_run = self._run
    if not active_run:
        # Nothing to close: no run is set (or it is falsy).
        return

    active_run.close()
    # Drop the attribute before rebinding it, mirroring the original sequence.
    del self._run
    self._run = None

@staticmethod
def find_most_common_substring(names: List[str]) -> Dict[str, int]:
substring_counts = defaultdict(lambda: 0)

for i in range(0, len(names)):
for j in range(i + 1, len(names)):
string1 = names[i]
string2 = names[j]
match = SequenceMatcher(None, string1, string2).find_longest_match(
0, len(string1), 0, len(string2)
)
matching_substring = string1[match.a:match.a + match.size]
substring_counts[matching_substring] += 1

return max(substring_counts, key= lambda x: substring_counts[x]).rstrip('_')

def __del__(self):
    """Finalizer: delegate to ``close()`` when the callback is destroyed."""
    self.close()

0 comments on commit 4acafa9

Please sign in to comment.