From 33bb49da47957dc678885c44c259fe09e5084d66 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Thu, 13 Jun 2024 14:35:18 +0000 Subject: [PATCH] Default to using tokenizer chat template if available --- Dockerfile | 2 +- Dockerfile-notebook | 2 +- config-base.yaml | 2 +- data_utils.py | 15 ++++++++------- train.py | 6 ++---- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/Dockerfile b/Dockerfile index 92c5a06..e4fd816 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,7 +9,7 @@ RUN mkdir -p /packages && \ cd /packages && \ git clone https://github.com/truefoundry/axolotl && \ cd axolotl/ && \ - git checkout 0711bfeb6af7d359deb4ee2cae81ceb6890ebf80 + git checkout 5ba183d302ed1c91912555b76e423786acaccae8 RUN cd /packages/axolotl/ && \ MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --no-build-isolation -e .[flash-attn,mamba-ssm,fused-dense-lib] && \ pip install --no-cache-dir -U -r /tmp/requirements.txt && \ diff --git a/Dockerfile-notebook b/Dockerfile-notebook index 31a3c41..71966c2 100644 --- a/Dockerfile-notebook +++ b/Dockerfile-notebook @@ -21,7 +21,7 @@ USER jovyan RUN cd /packages && \ git clone https://github.com/truefoundry/axolotl && \ cd axolotl/ && \ - git checkout 0711bfeb6af7d359deb4ee2cae81ceb6890ebf80 + git checkout 5ba183d302ed1c91912555b76e423786acaccae8 RUN cd /packages/axolotl/ && \ MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --no-build-isolation -e .[flash-attn,mamba-ssm,fused-dense-lib] && \ pip install --no-cache-dir -U -r /tmp/llm-finetune/notebook-requirements.txt diff --git a/config-base.yaml b/config-base.yaml index 7ed9691..9d8598c 100644 --- a/config-base.yaml +++ b/config-base.yaml @@ -3,7 +3,7 @@ adapter: qlora base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 debug: False micro_batch_size: 1 -revision_of_model: +revision_of_model: null sequence_len: 2048 val_set_size: 0.1 ## Added by TrueFoundry, not native to Axolotl diff --git a/data_utils.py b/data_utils.py index f66904f..acdee88 100644 --- a/data_utils.py +++ b/data_utils.py @@ -26,7 +26,9 @@ class DatasetType(str, enum.Enum): def _make_dataset_file_source( - path, split="train", dataset_type: DatasetType = DatasetType.completion, chat_template: str = "chatml" + path, + split="train", + dataset_type: DatasetType = DatasetType.completion, ): """ Axolotl dynamically loads prompt strategies based on the `type` key @@ -56,7 +58,6 @@ def _make_dataset_file_source( "path": path, "ds_type": "json", "type": "chat_template", - "chat_template": chat_template, "field_messages": "messages", "message_field_role": "role", "message_field_content": "content", @@ -68,7 +69,9 @@ def _make_dataset_file_source( def dataset_uri_to_axolotl_datasources( - uri, download_dir, dataset_type: DatasetType = DatasetType.completion, chat_template: str = "chatml" + uri, + download_dir, + dataset_type: DatasetType = DatasetType.completion, ): # TODO: Add support for HF datasets if uri.startswith("https://"): @@ -88,11 +91,9 @@ def dataset_uri_to_axolotl_datasources( datasources = [] if os.path.isdir(uri): for filepath in find_all_jsonl_files(uri): - datasources.append( - _make_dataset_file_source(path=filepath, dataset_type=dataset_type, chat_template=chat_template) - ) + datasources.append(_make_dataset_file_source(path=filepath, dataset_type=dataset_type)) else: - datasources = [_make_dataset_file_source(path=uri, dataset_type=dataset_type, chat_template=chat_template)] + datasources = [_make_dataset_file_source(path=uri, dataset_type=dataset_type)] return datasources else: raise ValueError("Unsupported data uri or path does not exist: {uri}") diff --git a/train.py b/train.py index 6488ed3..394815a 100644 --- a/train.py +++ b/train.py @@ -178,7 +178,7 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): if cfg.chat_template == "auto": model_type = getattr(model_hf_config, "model_type", None) - chat_template = MODEL_TYPE_TO_CHAT_TEMPLATE.get(model_type, "chatml") + chat_template = "tokenizer_default_fallback_" + MODEL_TYPE_TO_CHAT_TEMPLATE.get(model_type, "chatml") set_cfg_option_if_auto(cfg, "chat_template", chat_template) if cfg.datasets == "auto": @@ -188,7 +188,6 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): uri=cfg.train_data_uri, download_dir=cfg.data_dir, dataset_type=cfg.dataset_type, - chat_template=cfg.chat_template, ) if cfg.test_datasets == "auto": if cfg.val_data_uri and str(cfg.val_data_uri).lower() != "na": @@ -196,7 +195,6 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): uri=cfg.val_data_uri, download_dir=cfg.data_dir, dataset_type=cfg.dataset_type, - chat_template=chat_template, ) elif cfg.val_set_size: set_cfg_option_if_auto(cfg, "test_datasets", None, force=True) @@ -220,7 +218,7 @@ def make_axolotl_config(config_base, kwargs, timestamp=None): cfg["special_tokens"]["pad_token"] = tokenizer.eos_token set_cfg_option_if_auto(cfg, "lora_modules_to_save", []) logger.info(f"Prepared config: {cfg}") - # This hack is needed because yaml dump refuses to tread DictDefault as dict + # This hack is needed because yaml dump refuses to treat DictDefault as dict yaml.add_representer( DictDefault, lambda dumper, data: dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) )