diff --git a/docs/config.qmd b/docs/config.qmd index ecb571040f..114c7c3e67 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -137,6 +137,8 @@ datasets: message_field_role: role # Key for content in each message (default: "content") message_field_content: content + # Mapping of properties from the input dataset to the chat template. (default: None) + message_property_mappings: # Optional[Dict[str, List]]. Roles mapping in the messages. The default is: roles: diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py index 6210b11388..8ae1e256d3 100644 --- a/scripts/chat_datasets.py +++ b/scripts/chat_datasets.py @@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"): ds_cfg["field_messages"] = field_messages message_fields = features[field_messages][0].keys() - message_field_role = None + + message_property_mappings = {"role": None, "content": None} for key in ["from", "role"]: if key in message_fields: - message_field_role = key + message_property_mappings["role"] = key break - if not message_field_role: + if not message_property_mappings["role"]: raise ValueError( f'No role field found in messages: {", ".join(message_fields)}' ) - ds_cfg["message_field_role"] = message_field_role - message_field_content = None for key in ["content", "text", "value"]: if key in message_fields: - message_field_content = key + message_property_mappings["content"] = key break - if not message_field_content: + if not message_property_mappings["content"]: raise ValueError( f'No content field found in messages: {", ".join(message_fields)}' ) - ds_cfg["message_field_content"] = message_field_content + ds_cfg["message_property_mappings"] = message_property_mappings print(yaml.dump({"datasets": [ds_cfg]})) diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 74da20c5e1..d71a5fe261 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -30,6 +30,7 @@ def 
load(strategy, tokenizer, cfg, ds_cfg, processor=None): load_kwargs["ds_cfg"] = ds_cfg if "processor" in sig.parameters: load_kwargs["processor"] = processor + return func(tokenizer, cfg, **load_kwargs) except ModuleNotFoundError: return None diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 4f60842c5f..16ba1c4b40 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -95,8 +95,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): prompter_params = { "tokenizer": tokenizer, "chat_template": chat_template_string, - "message_field_role": ds_cfg.get("message_field_role", "role"), - "message_field_content": ds_cfg.get("message_field_content", "content"), + "message_property_mappings": ds_cfg.get("message_property_mappings", {}), "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( "message_field_training_detail", None diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 5b12130d75..63801da246 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -3,10 +3,11 @@ """ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set from transformers import ProcessorMixin +from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import get_chat_template_from_config @@ -25,10 +26,10 @@ def __init__( processor=None, chat_template=None, max_length=2048, - message_field_role: str = "role", - message_field_content: str = "content", + message_property_mappings: 
Optional[Dict[str, str]] = None, message_field_training: Optional[str] = None, message_field_training_detail: Optional[str] = None, + messages_array_name: str = "messages", roles: Optional[Dict[str, List[str]]] = None, drop_system_message: bool = False, ): @@ -44,8 +45,10 @@ def __init__( "tool": "tool", } - self.message_field_role = message_field_role - self.message_field_content = message_field_content + self._chat_template_msg_variables = self.get_chat_template_msg_variables( + chat_template, messages_array_name + ) + self.message_property_mappings = message_property_mappings self.message_field_training = message_field_training self.message_field_training_detail = message_field_training_detail self.tokenizer = tokenizer @@ -54,6 +57,10 @@ def __init__( self.max_length = max_length self.drop_system_message = drop_system_message + @property + def chat_template_msg_variables(self) -> Set[str]: + return self._chat_template_msg_variables + def build_prompt(self, conversation, add_generation_prompt=False, images=None): if self.processor: text = self.processor.apply_chat_template( @@ -183,6 +190,12 @@ def adjust_train_details( return adjusted_details + def get_chat_template_msg_variables( + self, chat_template: str, messages_array_name: str + ) -> Set[str]: + template_analyzer = JinjaTemplateAnalyzer(chat_template) + return template_analyzer.get_message_vars(messages_array_name) + class ChatTemplateStrategy(PromptTokenizingStrategy): """ @@ -212,6 +225,10 @@ def __init__( self.train_on_eos = train_on_eos self.images = "images" + LOG.info( + f"The chat template uses the following properties on the message: {self.prompter.chat_template_msg_variables}" + ) + @property def messages(self): return self._messages @@ -424,30 +441,17 @@ def find_turn(self, turns: list[dict], turn_idx: int): def get_conversation_thread(self, prompt): turns = [] - optional_keys = [ - "tool_calls", # tool that 'assistant' calls - "name", # name of tool given by 'tool' - "tool_call_id", # 
mistral/mixtral requires this - ] for message in prompt[self.messages]: + transformed_message = self.transform_message(message) + turn = { - "role": self.prompter.roles[message[self.prompter.message_field_role]], + **transformed_message, "training": message.get(self.prompter.message_field_training), "training_detail": message.get( self.prompter.message_field_training_detail ), } - # do not add content if None as it may conflict with some templates due to tools - content = message.get(self.prompter.message_field_content, None) - if content is not None: - turn["content"] = content - - for key in optional_keys: - value = message.get(key, None) - if value is not None: - turn[key] = value - turns.append(turn) if self.prompter.drop_system_message and turns[0]["role"] == "system": @@ -455,30 +459,64 @@ def get_conversation_thread(self, prompt): return turns + def transform_message(self, message): + # Build the initial transformed message from the mappings + transformed_message = { + key: message[value] + for key, value in self.prompter.message_property_mappings.items() + if message.get(value) is not None + } + + # Map the role if necessary + if "role" in transformed_message: + transformed_message["role"] = self.prompter.roles.get( + transformed_message["role"], transformed_message["role"] + ) + + # Determine which keys in the original message were not mapped + mapped_values = set(self.prompter.message_property_mappings.values()) + remaining_keys = set(message) - mapped_values + + # Keep only the properties defined in the chat template + # and not already mapped + for key in self.prompter.chat_template_msg_variables: + if key in remaining_keys: + val = message.get(key) + if val is not None: + transformed_message[key] = val + + return transformed_message + def get_images(self, prompt): return prompt.get(self.images, None) -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): - # pylint: disable=duplicate-code - ds_cfg = ds_cfg or {} +def load( 
+ tokenizer, + cfg, + ds_cfg: Optional[Dict[str, Any]] = None, + processor=None, +): + dataset_config = ds_cfg if ds_cfg else {} chat_template_string = get_chat_template_from_config( - cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer ) LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") prompter_params = { "tokenizer": tokenizer, "chat_template": chat_template_string, - "message_field_role": ds_cfg.get("message_field_role", "role"), - "message_field_content": ds_cfg.get("message_field_content", "content"), - "message_field_training": ds_cfg.get("message_field_training", None), - "message_field_training_detail": ds_cfg.get( + "message_property_mappings": dataset_config.get( + "message_property_mappings", {} + ), + "message_field_training": dataset_config.get("message_field_training", None), + "message_field_training_detail": dataset_config.get( "message_field_training_detail", None, ), - "roles": ds_cfg.get("roles"), - "drop_system_message": ds_cfg.get("drop_system_message", False), + "messages_array_name": dataset_config.get("field_messages", "messages"), + "roles": dataset_config.get("roles"), + "drop_system_message": dataset_config.get("drop_system_message", False), # we need to add one for detecting sequences with exceeding the `sequence_len` limit. 
"max_length": cfg.sequence_len + 1, "processor": processor, @@ -487,15 +525,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + "roles_to_train": dataset_config.get("roles_to_train", ["assistant"]), + "train_on_eos": dataset_config.get("train_on_eos", "turn"), } strategy = ChatTemplateStrategy( ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params ) - if "field_messages" in ds_cfg and hasattr(strategy, "messages"): - strategy.messages = ds_cfg["field_messages"] + if "field_messages" in dataset_config and hasattr(strategy, "messages"): + strategy.messages = dataset_config["field_messages"] return strategy diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index 489b864851..9e4986dd4c 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -3,20 +3,28 @@ """ from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template +from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic def default( cfg, dataset_idx=0, **kwargs ): # pylint: disable=possibly-unused-variable,unused-argument ds_cfg = cfg["datasets"][dataset_idx] + ds_cfg = handle_legacy_message_fields_logic(ds_cfg) + chat_template_choice, chat_template_jinja = extract_chat_template_args( cfg=cfg, ds_cfg=ds_cfg ) field_messages = ds_cfg.get("field_messages", "messages") field_chosen = ds_cfg.get("field_chosen", "chosen") field_rejected = ds_cfg.get("field_rejected", "rejected") - field_message_role = ds_cfg.get("message_field_role", "role") - field_message_content = ds_cfg.get("message_field_content", "content") + message_property_mappings = ds_cfg.get( + 
"message_property_mappings", + { + "role": "role", + "content": "content", + }, + ) role_map_inv = ds_cfg.get( "roles", { @@ -40,18 +48,18 @@ def transform_fn(sample, tokenizer=None): messages = sample[field_messages] messages = [ { - "role": role_map[m[field_message_role]], - "content": m[field_message_content], + "role": role_map[m[message_property_mappings["role"]]], + "content": m[message_property_mappings["content"]], } for m in messages ] chosen = { - "role": role_map[sample[field_chosen][field_message_role]], - "content": sample[field_chosen][field_message_content], + "role": role_map[sample[field_chosen][message_property_mappings["role"]]], + "content": sample[field_chosen][message_property_mappings["content"]], } rejected = { - "role": role_map[sample[field_rejected][field_message_role]], - "content": sample[field_rejected][field_message_content], + "role": role_map[sample[field_rejected][message_property_mappings["role"]]], + "content": sample[field_rejected][message_property_mappings["content"]], } dummy_user_message = {"role": "user", "content": "[[dummy_message]]"} diff --git a/src/axolotl/prompt_strategies/jinja_template_analyzer.py b/src/axolotl/prompt_strategies/jinja_template_analyzer.py new file mode 100644 index 0000000000..bcf9a4128a --- /dev/null +++ b/src/axolotl/prompt_strategies/jinja_template_analyzer.py @@ -0,0 +1,318 @@ +"""Module for inspecting jinja templates for the variables they use""" +from typing import Dict, Optional, Set, TypedDict, Union + +from jinja2 import Environment, meta, nodes + + +class JinjaTemplateAnalysis(TypedDict): + """ + Represents the detailed analysis of a Jinja template variable. + + Attributes: + accessed_properties (Set[str]): A set of properties accessed from the variable + (e.g., `foo.bar` results in 'bar' being accessed for 'foo'). + accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable. 
+ is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop. + is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block). + iteration_source (Optional[str]): The name of the variable being iterated over, if applicable. + iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration. + """ + + accessed_properties: Set[str] + accessed_indices: Set[Union[int, float]] + is_iterated: bool + is_conditional: bool + iteration_source: Optional[str] + iteration_target: Optional[Union[str, list[str]]] + + +class JinjaTemplateAnalyzer: + """ + Analyzes Jinja templates to extract information about variable usage, + including accessed properties, iteration, and conditional references. + + Attributes: + env (jinja2.Environment): The Jinja2 environment used for parsing templates. + property_access (Dict[str, Set[str]]): Tracks accessed properties for variables. + iteration_targets (Dict[str, str]): Maps iteration target variables to their sources. + + Methods: + get_template_variables(template: str) -> Dict[str, Set[str]]: + Parse a Jinja template and return a mapping of variables to their accessed properties. + + analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]: + Perform a detailed analysis of the template, including variable usage, + iteration, and conditional references. + + Private Methods: + _visit_node(node) -> None: + Recursively visit AST nodes to detect attribute access and iteration targets. + + _get_base_name(node) -> Optional[str]: + Extract the base variable name from a node. + + _get_target_name(node) -> Optional[Union[str, list[str]]]: + Extract the target name(s) from a `For` node. 
+ """ + + def __init__(self, template: str): + self.env: Environment = Environment(autoescape=True) + self.property_access: Dict[str, Set[str]] = {} + self.iteration_targets: Dict[str, Union[str, list[str]]] = {} + self.index_access: Dict[str, Set[Union[int, float]]] = {} + self.ast: nodes.Node = self.env.parse(template) + self.template: str = template + self.variable_assignments: Dict[str, str] = {} + + def _visit_node(self, node) -> None: + """Recursively visit AST nodes to find attribute access.""" + # Handle attribute access (dot notation) + if isinstance(node, nodes.Getattr): + base_name = self._get_base_name(node.node) + if base_name: + self.property_access.setdefault(base_name, set()).add(node.attr) + + # Handle dictionary access (subscript notation) + elif isinstance(node, nodes.Getitem): + base_name = self._get_base_name(node.node) + if base_name and isinstance(node.arg, nodes.Const): + value = node.arg.value + if isinstance(value, (int, float)): + self.index_access.setdefault(base_name, set()).add(value) + else: + self.property_access.setdefault(base_name, set()).add(value) + + elif isinstance(node, nodes.Test) and node.name == "defined": + base_name = self._get_base_name(node.node) + if base_name: + if isinstance(node.node, nodes.Getattr): + self.property_access.setdefault(base_name, set()).add( + node.node.attr + ) + + # Handle loop variables + elif isinstance(node, nodes.For): + iter_name = self._get_base_name(node.iter) + target_name = self._get_target_name(node.target) + if iter_name and target_name: + self.iteration_targets[target_name] = iter_name + self.property_access.setdefault(iter_name, set()) + + elif isinstance(node, nodes.Assign): + target_name = self._get_target_name(node.target) + source_name = self._get_base_name(node.node) + if target_name and source_name: + self.variable_assignments[target_name] = source_name + + elif isinstance(node, nodes.Filter): + if node.name == "selectattr": + target = self._get_base_name(node.node) + if target: 
+ self.variable_assignments[f"filtered_{target}"] = target + + for child in node.iter_child_nodes(): + self._visit_node(child) + + def _get_target_name(self, node) -> Optional[str]: + """Get the target variable name from a For node. + + Args: + node: A Jinja AST node representing either a Name or Tuple node + + Returns: + - str: For simple variable targets (e.g., "item" in "for item in items") + - None: If the node type is not recognized or is a tuple + """ + if isinstance(node, nodes.Name): + return node.name + return None + + def _get_target_names(self, node) -> list[str]: + """Get all target variable names from a For node, including tuple unpacking. + + Args: + node: A Jinja AST node representing either a Name or Tuple node + + Returns: + List of target variable names + """ + if isinstance(node, nodes.Name): + return [node.name] + + if isinstance(node, nodes.Tuple): + names = [] + for n in node.items: + if isinstance(n, nodes.Name): + names.append(n.name) + return names + + return [] + + def _get_base_name(self, node) -> Optional[str]: + """Get the base variable name from a node.""" + if isinstance(node, nodes.Name): + return node.name + + if isinstance(node, nodes.Getattr): + return self._get_base_name(node.node) + + if isinstance(node, nodes.Getitem): + return self._get_base_name(node.node) + + return None + + def get_template_variables(self) -> Dict[str, Set[str]]: + """ + Parse a Jinja template and return both variables and their accessed properties. 
+ + Args: + template (str): The Jinja template string + + Returns: + Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties + """ + # Parse the template + ast = self.env.parse(self.template) + + # Get all undeclared variables + variables = meta.find_undeclared_variables(ast) + + # Reset property access tracking + self.property_access = {} + + # Visit all nodes to find property access + self._visit_node(ast) + + # Create result dictionary + result: Dict[str, Set[str]] = {var: set() for var in variables} + # Merge in any discovered sub-properties + for var, props in self.property_access.items(): + if var not in result: + result[var] = set() + result[var].update(props) + + return result + + def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]: + """ + Provide a detailed analysis of template variables and their usage. + """ + variables = self.get_template_variables() + self.iteration_targets = {} + + analysis: Dict[str, JinjaTemplateAnalysis] = { + var: JinjaTemplateAnalysis( + accessed_properties=props, + accessed_indices=set(), + is_iterated=False, + is_conditional=False, + iteration_source=None, + iteration_target=None, + ) + for var, props in variables.items() + } + + for var, indices in self.index_access.items(): + if var in analysis: + analysis[var]["accessed_indices"] = indices + + def visit_node(node): + if isinstance(node, nodes.If): + + def find_test_vars(test_node): + if isinstance(test_node, nodes.Name): + if test_node.name in analysis: + analysis[test_node.name]["is_conditional"] = True + for child in test_node.iter_child_nodes(): + find_test_vars(child) + + find_test_vars(node.test) + + if isinstance(node, nodes.For): + iter_target = self._get_base_name(node.iter) + target_name = self._get_target_name(node.target) + if iter_target in analysis: + analysis[iter_target]["is_iterated"] = True + if target_name: + analysis[iter_target]["iteration_target"] = target_name + if isinstance(target_name, str) and target_name 
not in analysis: + analysis[target_name] = { + "accessed_properties": set(), + "is_iterated": False, + "is_conditional": False, + "iteration_source": iter_target, + "iteration_target": None, + } + + for child in node.iter_child_nodes(): + visit_node(child) + + visit_node(self.ast) + return analysis + + def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]: + """ + Get all properties accessed on a variable and its downstream assignments. + + Args: + start_var: The starting variable to trace + + Returns: + Dict mapping variable names to their accessed properties + """ + visited = set() + properties = {} + + def trace_variable(var_name: str): + if var_name in visited: + return + visited.add(var_name) + + # Get direct properties + if var_name in self.property_access: + properties[var_name] = self.property_access[var_name] + + # Get properties from iteration targets + if var_name in self.iteration_targets: + target = self.iteration_targets[var_name] + if isinstance(target, str): + trace_variable(target) + elif isinstance(target, list): + for t in target: + trace_variable(t) + + # Follow assignments + for target, source in self.variable_assignments.items(): + if source == var_name: + trace_variable(target) + + # Check for array slicing + analysis = self.analyze_template() + if var_name in analysis: + var_info = analysis[var_name] + if var_info["accessed_indices"]: + # If this variable is sliced, follow the resulting assignment + slice_result = f"{var_name}_slice" + if slice_result in self.property_access: + trace_variable(slice_result) + + trace_variable(start_var) + return properties + + def get_message_vars(self, messages_array_name: str = "messages") -> Set[str]: + """ + Get all properties accessed on messages and derived variables. 
+ """ + all_properties = self.get_downstream_properties(messages_array_name) + + # Combine all properties from all related variables + combined_properties = set() + for properties in all_properties.values(): + combined_properties.update(properties) + + # Also include properties from the message iteration variable + analysis = self.analyze_template() + if "message" in analysis: + combined_properties.update(analysis["message"]["accessed_properties"]) + + return combined_properties diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 7ddff62196..2e95606674 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -17,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import ( AxolotlInputConfig as AxolotlInputConfigBase, ) +from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model_config @@ -249,7 +250,7 @@ def validate_config( cfg: DictDefault, capabilities: Optional[dict] = None, env_capabilities: Optional[dict] = None, -): +) -> DictDefault: AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase AxolotlInputConfig = AxolotlInputConfigBase @@ -259,6 +260,16 @@ def validate_config( AxolotlInputConfig, # pylint: disable=invalid-name ) = merge_input_args() + # Convert datasets to proper format if needed + if cfg.get("datasets"): + for idx, ds_cfg in enumerate(cfg["datasets"]): + if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset): + cfg["datasets"][idx] = DPODataset(**ds_cfg) + elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset): + cfg["datasets"][idx] = KTODataset(**ds_cfg) + elif not isinstance(ds_cfg, SFTDataset): + cfg["datasets"][idx] = SFTDataset(**ds_cfg) + if capabilities or env_capabilities: if (capabilities and env_capabilities is None) or ( env_capabilities and capabilities is None diff --git 
a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index dc8897863d..d286f3caac 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -15,6 +15,7 @@ Field, StringConstraints, conlist, + field_serializer, field_validator, model_validator, ) @@ -183,6 +184,7 @@ class SFTDataset(BaseModel): field_messages: Optional[str] = None message_field_role: Optional[str] = None message_field_content: Optional[str] = None + message_property_mappings: Optional[Dict[str, str]] = None message_field_training: Optional[str] = None message_field_training_detail: Optional[str] = None roles_to_train: Optional[List[str]] = None @@ -192,9 +194,18 @@ class SFTDataset(BaseModel): trust_remote_code: Optional[bool] = False revision: Optional[str] = None + @model_validator(mode="before") + @classmethod + def handle_legacy_message_fields(cls, data): + """Handle backwards compatibility between legacy message field mapping and new property mapping system.""" + return handle_legacy_message_fields_logic(data) + @model_validator(mode="before") @classmethod def check_chat_template_config(cls, data): + if isinstance(data, BaseModel): + data = data.model_dump() + # Set chat_template to tokenizer_default if not set if data.get("type") == "chat_template" and not data.get("chat_template"): data["chat_template"] = ChatTemplate.tokenizer_default @@ -214,9 +225,14 @@ def check_chat_template_config(cls, data): return data -class UserDefinedDPOType(BaseModel): - """User defined typing for DPO""" +class DPODataset(BaseModel): + """DPO configuration subset""" + path: Optional[str] = None + split: Optional[str] = None + type: Optional[str] = None + data_files: Optional[List[str]] = None + revision: Optional[str] = None field_system: Optional[str] = None field_prompt: Optional[str] = None field_chosen: Optional[str] = None @@ -224,16 +240,25 @@ class 
UserDefinedDPOType(BaseModel): prompt_format: Optional[str] = None chosen_format: Optional[str] = None rejected_format: Optional[str] = None + field_messages: Optional[str] = None -class DPODataset(BaseModel): - """DPO configuration subset""" +class KTODataset(BaseModel): + """KTO configuration subset""" path: Optional[str] = None split: Optional[str] = None - type: Optional[Union[UserDefinedDPOType, str]] = None + type: Optional[str] = None data_files: Optional[List[str]] = None + trust_remote_code: Optional[bool] = False revision: Optional[str] = None + field_system: Optional[str] = None + field_prompt: Optional[str] = None + field_completion: Optional[str] = None + field_messages: Optional[str] = None + field_label: Optional[bool] = None + prompt_format: Optional[str] = None + completion_format: Optional[str] = None class StepwiseSupervisedDataset(BaseModel): @@ -248,26 +273,7 @@ class StepwiseSupervisedDataset(BaseModel): train_on_last_step_only: Optional[bool] = None -class UserDefinedKTOType(BaseModel): - """User defined typing for KTO""" - - field_system: Optional[str] = None - field_prompt: Optional[str] = None - field_completion: Optional[str] = None - field_label: Optional[bool] = None - prompt_format: Optional[str] = None - completion_format: Optional[str] = None - - -class KTODataset(BaseModel): - """KTO configuration subset""" - - path: Optional[str] = None - split: Optional[str] = None - type: Optional[Union[UserDefinedKTOType, str]] = None - data_files: Optional[List[str]] = None - trust_remote_code: Optional[bool] = False - revision: Optional[str] = None +DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset] class LoftQConfig(BaseModel): @@ -669,17 +675,15 @@ class Config: bool ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. 
- datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore - test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore + datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore + test_datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore shuffle_merged_datasets: Optional[bool] = True dataset_prepared_path: Optional[str] = None dataset_shard_num: Optional[int] = None dataset_shard_idx: Optional[int] = None skip_prepare_dataset: Optional[bool] = False - pretraining_dataset: Optional[ # type: ignore - conlist(Union[PretrainingDataset, SFTDataset], min_length=1) - ] = Field( + pretraining_dataset: Optional[conlist(Union[PretrainingDataset, SFTDataset], min_length=1)] = Field( # type: ignore default=None, json_schema_extra={"description": "streaming dataset to use for pretraining"}, ) @@ -878,10 +882,15 @@ class Config: @classmethod def deprecate_sharegpt_datasets(cls, datasets): for _, ds_cfg in enumerate(datasets): - if not ds_cfg.get("type"): + # Handle both dict and pydantic model cases + ds_type = ( + ds_cfg.get("type") + if isinstance(ds_cfg, dict) + else getattr(ds_cfg, "type", None) + ) + if not ds_type: continue - ds_type = ds_cfg["type"] # skip if it's a dict (for custom user instruction prompt) if isinstance(ds_type, dict): continue @@ -893,6 +902,14 @@ def deprecate_sharegpt_datasets(cls, datasets): return datasets + @field_serializer("datasets") + def datasets_serializer( + self, ds_configs: Optional[List[DatasetConfig]] + ) -> Optional[List[Dict[str, Any]]]: + if ds_configs: + return [ds_config.model_dump() for ds_config in ds_configs] + return None + @model_validator(mode="before") @classmethod def check_batch_size_fields(cls, data): @@ -1692,3 +1709,73 @@ def check_torch_compile_auto(cls, data): else: data["torch_compile"] = False return data + + +def 
handle_legacy_message_fields_logic(data: dict) -> dict: + """ + Handle backwards compatibility between legacy message field mapping and new property mapping system. + + Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options: + - message_field_role: Mapped to the role field + - message_field_content: Mapped to the content field + + The new system uses message_property_mappings to support arbitrary field mappings: + message_property_mappings: + role: source_role_field + content: source_content_field + additional_field: source_field + + Args: + data: Dictionary containing configuration data + + Returns: + Updated dictionary with message field mappings consolidated + + Raises: + ValueError: If there are conflicts between legacy and new mappings + """ + data = data.copy() # Create a copy to avoid modifying the original + + if data.get("message_property_mappings") is None: + data["message_property_mappings"] = {} + + # Check for conflicts and handle role + if "message_field_role" in data: + LOG.warning( + "message_field_role is deprecated, use message_property_mappings instead. " + f"Example: message_property_mappings: {{role: {data['message_field_role']}}}" + ) + if ( + "role" in data["message_property_mappings"] + and data["message_property_mappings"]["role"] != data["message_field_role"] + ): + raise ValueError( + f"Conflicting message role fields: message_field_role='{data['message_field_role']}' " + f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'" + ) + data["message_property_mappings"]["role"] = data["message_field_role"] or "role" + elif "role" not in data["message_property_mappings"]: + data["message_property_mappings"]["role"] = "role" + + # Check for conflicts and handle content + if "message_field_content" in data: + LOG.warning( + "message_field_content is deprecated, use message_property_mappings instead. 
" + f"Example: message_property_mappings: {{content: {data['message_field_content']}}}" + ) + if ( + "content" in data["message_property_mappings"] + and data["message_property_mappings"]["content"] + != data["message_field_content"] + ): + raise ValueError( + f"Conflicting message content fields: message_field_content='{data['message_field_content']}' " + f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'" + ) + data["message_property_mappings"]["content"] = ( + data["message_field_content"] or "content" + ) + elif "content" not in data["message_property_mappings"]: + data["message_property_mappings"]["content"] = "content" + + return data diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index ba5d0c54d1..cee1ead4d5 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -173,6 +173,7 @@ def load_tokenized_prepared_datasets( ) -> Tuple[DatasetDict, List[Prompter]]: cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets tokenizer_name = cfg.tokenizer_config + ds_hash = str( md5( ( diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py index 409d088e6d..f24f7c4a98 100644 --- a/src/axolotl/utils/dict.py +++ b/src/axolotl/utils/dict.py @@ -13,3 +13,26 @@ def __missing__(self, key): def __or__(self, other): return DictDefault(super().__ror__(other)) + + def __setitem__(self, name, value): + # workaround for pickle/unpickle issues and __frozen not being available + try: + isFrozen = hasattr( # pylint: disable=invalid-name + self, "__frozen" + ) and object.__getattribute__(self, "__frozen") + except AttributeError: + isFrozen = False # pylint: disable=invalid-name + + if isFrozen and name not in super().keys(): + raise KeyError(name) + super(Dict, self).__setitem__(name, value) # pylint: disable=bad-super-call + try: + p = object.__getattribute__(self, "__parent") + key = object.__getattribute__(self, "__key") + except AttributeError: + p = None 
+ key = None + if p is not None: + p[key] = self + object.__delattr__(self, "__parent") + object.__delattr__(self, "__key") diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index 2bfd36d155..f71e4fb4af 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -11,7 +11,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, check_tensorboard @@ -76,7 +76,9 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_ste else: cfg.fp16 = True + cfg = validate_config(cfg) normalize_config(cfg) + cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py index 191f76f647..53a239c517 100644 --- a/tests/e2e/solo/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -10,7 +10,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir @@ -73,6 +73,8 @@ def test_relora(self, temp_dir): "use_tensorboard": True, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 2d0baceeef..cf7335805c 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs from 
axolotl.common.datasets import load_preference_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, with_temp_dir @@ -63,6 +63,8 @@ def test_dpo_lora(self, temp_dir): "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) @@ -108,6 +110,8 @@ def test_dpo_nll_lora(self, temp_dir): "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) @@ -153,6 +157,8 @@ def test_dpo_use_weighting(self, temp_dir): "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) @@ -198,6 +204,8 @@ def test_kto_pair_lora(self, temp_dir): "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) @@ -242,6 +250,8 @@ def test_ipo_lora(self, temp_dir): "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) @@ -289,6 +299,8 @@ def test_orpo_lora(self, temp_dir): "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) @@ -353,6 +365,8 @@ def test_kto_lora(self, temp_dir): 
"gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 4261ccc266..e9962f6680 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -9,7 +9,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, check_tensorboard, with_temp_dir @@ -56,6 +56,8 @@ def test_train_w_embedding_lr_scale(self, temp_dir): "use_tensorboard": True, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index ddcb662755..3fdc2e1bf4 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -9,7 +9,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, with_temp_dir @@ -65,6 +65,8 @@ def test_lora(self, temp_dir): "bf16": "auto", } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -118,6 +120,8 @@ def test_lora_added_vocab(self, temp_dir): "bf16": "auto", } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -157,6 +161,8 @@ 
def test_ft(self, temp_dir): "bf16": "auto", } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index a948284904..77e70d8c24 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -10,7 +10,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault LOG = logging.getLogger("axolotl.tests.e2e") @@ -56,6 +56,8 @@ def test_fft_trust_remote_code(self, temp_dir): "save_safetensors": True, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -99,6 +101,8 @@ def test_fix_untrained_tokens(self, temp_dir): "save_safetensors": True, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -138,6 +142,8 @@ def test_batch_flattening(self, temp_dir): "save_safetensors": True, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index c1f024b872..7939884f74 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -10,7 +10,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, check_tensorboard @@ -69,6 +69,8 @@ def 
test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn): "use_tensorboard": True, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index 91f101e44c..c4a41f521d 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -9,7 +9,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, with_temp_dir @@ -62,6 +62,8 @@ def test_lora_llama_vision_text_only_dataset(self, temp_dir): "bf16": True, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 696c47aed7..6bb118470d 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -9,7 +9,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, with_temp_dir @@ -59,6 +59,8 @@ def test_lora(self, temp_dir): "max_steps": 20, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index 4b4db30585..8884f42f95 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -11,7 +11,7 @@ from axolotl.cli.args import TrainerCliArgs 
from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, with_temp_dir @@ -59,6 +59,8 @@ def test_fft(self, temp_dir): "save_safetensors": False, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index a304e9b4a5..59e1ba7015 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -11,7 +11,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, with_temp_dir @@ -63,6 +63,8 @@ def test_lora(self, temp_dir): "eval_steps": 10, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -106,6 +108,8 @@ def test_ft(self, temp_dir): cfg.bf16 = True else: cfg.fp16 = True + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 6e06626f6e..5de5ab4036 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import 
check_model_output_exists, with_temp_dir @@ -69,6 +69,8 @@ def test_qlora_w_fa2(self, temp_dir): "eval_steps": 10, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -123,6 +125,8 @@ def test_qlora_wo_fa2(self, temp_dir): "eval_steps": 10, } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -180,6 +184,8 @@ def test_16bit_lora_w_fa2(self, temp_dir): cfg.bf16 = True else: cfg.fp16 = True + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -233,6 +239,8 @@ def test_16bit_lora_wo_fa2(self, temp_dir): "eval_steps": 10, } ) + + cfg = validate_config(cfg) normalize_config(cfg) if is_torch_bf16_gpu_available(): cfg.bf16 = True @@ -281,6 +289,8 @@ def test_ft(self, temp_dir): cfg.bf16 = True else: cfg.fp16 = True + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 453872538a..4b0ad1142a 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -9,7 +9,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir @@ -59,6 +59,8 @@ def test_optimi_adamw(self, temp_dir): "lr_scheduler": "cosine", } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -103,6 +105,8 @@ def test_adopt_adamw(self, temp_dir): "lr_scheduler": 
"cosine", } ) + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) @@ -139,6 +143,8 @@ def test_fft_schedule_free_adamw(self, temp_dir): } ) # pylint: disable=duplicate-code + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index 13244a2152..077ea57dd7 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -11,7 +11,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_tensorboard, with_temp_dir @@ -59,6 +59,8 @@ def test_loss_packed(self, temp_dir): cfg.bf16 = True else: cfg.fp16 = True + + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index 54f564d0e7..49f9261c9f 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -9,7 +9,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, with_temp_dir @@ -61,6 +61,7 @@ def test_phi_ft(self, temp_dir): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py index 7a343f4d3f..39d55603f5 100644 --- 
a/tests/e2e/test_qwen.py +++ b/tests/e2e/test_qwen.py @@ -40,8 +40,10 @@ def test_dpo(self, base_model, temp_dir): "field_messages": "conversation", "field_chosen": "chosen", "field_rejected": "rejected", - "message_field_role": "role", - "message_field_content": "content", + "message_property_mappings": { + "role": "role", + "content": "content", + }, "roles": { "system": ["system"], "user": ["user"], diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 7360a99dc8..a7357b5b62 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -9,7 +9,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train -from axolotl.utils.config import normalize_config +from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault from .utils import check_model_output_exists, check_tensorboard, with_temp_dir @@ -66,6 +66,7 @@ def test_rm_lora(self, temp_dir): "use_tensorboard": True, } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py index fdfcbff438..9864a6fecd 100644 --- a/tests/prompt_strategies/conftest.py +++ b/tests/prompt_strategies/conftest.py @@ -7,6 +7,7 @@ from huggingface_hub import hf_hub_download from transformers import AutoTokenizer +from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer from axolotl.utils.chat_templates import _CHAT_TEMPLATES @@ -174,3 +175,32 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str: modified_template = template.replace(old_date_logic, new_date_logic) return modified_template + + +@pytest.fixture(name="chat_template_jinja_with_optional_fields") +def fixture_chat_template_jinja_with_optional_fields() -> str: + return """{% for 
message in messages %} +{{'<|im_start|>'}}{{ message['role'] }} +{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %} +{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %} +{{ message['content'] }}{{'<|im_end|>'}} +{% endfor %}""" + + +@pytest.fixture(name="basic_jinja_template_analyzer") +def basic_jinja_template_analyzer(): + return JinjaTemplateAnalyzer( + """{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'user' %}{{'<|user|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|> +' + message['content'] + '<|end|> +'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> +' }}{% else %}{{ eos_token }}{% endif %}""" + ) + + +@pytest.fixture(name="mistral_jinja_template_analyzer") +def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja): + return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja) diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 8ec4fa1191..84e0576c01 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -38,6 +38,10 @@ def test_llama3_load(self, llama3_tokenizer, assistant_dataset): "chat_template": "llama3", "message_field_role": "role", "message_field_content": "content", + "message_property_mappings": { + "role": "role", + "content": "content", + }, "roles": { "user": ["user"], "assistant": ["assistant"], @@ -74,8 +78,10 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): ChatTemplatePrompter( llama3_tokenizer, chat_template=get_chat_template("llama3"), - message_field_role="role", - message_field_content="content", + message_property_mappings={ + "role": "role", + "content": "content", + }, roles={ "user": ["user"], 
"assistant": ["assistant"], @@ -114,8 +120,10 @@ def test_phi35(self, phi35_tokenizer, assistant_dataset): ChatTemplatePrompter( phi35_tokenizer, chat_template=get_chat_template("phi_35"), - message_field_role="role", - message_field_content="content", + message_property_mappings={ + "role": "role", + "content": "content", + }, roles={ "user": ["user"], "assistant": ["assistant"], @@ -170,9 +178,11 @@ def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): ChatTemplatePrompter( llama3_tokenizer, chat_template=get_chat_template("llama3"), - message_field_role="role", - message_field_content="content", message_field_training="training", + message_property_mappings={ + "role": "role", + "content": "content", + }, roles={ "user": ["user"], "assistant": ["assistant"], @@ -230,8 +240,10 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): ChatTemplatePrompter( llama3_tokenizer, chat_template=get_chat_template("llama3"), - message_field_role="from", - message_field_content="value", + message_property_mappings={ + "role": "from", + "content": "value", + }, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -287,8 +299,10 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): ChatTemplatePrompter( llama3_tokenizer, chat_template=get_chat_template("llama3"), - message_field_role="from", - message_field_content="value", + message_property_mappings={ + "role": "from", + "content": "value", + }, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -344,8 +358,10 @@ def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): ChatTemplatePrompter( llama3_tokenizer, chat_template=get_chat_template("llama3"), - message_field_role="from", - message_field_content="value", + message_property_mappings={ + "role": "from", + "content": "value", + }, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -417,8 +433,7 @@ def test_llama32vision_train_on_assistant( chat_template=get_chat_template( "jinja", 
jinja_template=llama3_2_vision_chat_template_jinja ), - message_field_role="role", - message_field_content="content", + message_property_mappings={"role": "role", "content": "content"}, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -486,8 +501,7 @@ def test_llama32vision_train_on_tools( chat_template=get_chat_template( "jinja", jinja_template=llama3_2_vision_chat_template_jinja ), - message_field_role="role", - message_field_content="content", + message_property_mappings={"role": "role", "content": "content"}, ), tokenizer=llama3_tokenizer, train_on_inputs=False, diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 7d09b059cc..e5aaa632d1 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -3,7 +3,6 @@ """ import logging -import unittest from copy import deepcopy import pytest @@ -123,8 +122,7 @@ def test_train_on_inputs_true( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=True, @@ -180,8 +178,7 @@ def test_train_on_inputs_false( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -241,8 +238,7 @@ def test_roles_to_train_human_assistant_only( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -307,8 +303,7 @@ def test_roles_to_train_all( chat_template=get_chat_template( 
chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=True, @@ -360,8 +355,7 @@ def test_empty_roles_to_train( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -400,8 +394,7 @@ def test_train_on_eos_all( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -446,8 +439,7 @@ def test_train_on_eos_turn( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -526,8 +518,7 @@ def test_train_on_eos_last( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -578,8 +569,7 @@ def test_train_on_eos_none( chat_template=get_chat_template( chat_template, jinja_template=chat_template_jinja ), - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -624,8 +614,7 @@ def test_drop_system_message( chat_template, jinja_template=chat_template_jinja ), drop_system_message=True, - message_field_role="from", - message_field_content="value", + 
message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -668,8 +657,7 @@ def test_custom_roles( chat_template, jinja_template=chat_template_jinja ), roles=custom_roles, - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -741,8 +729,7 @@ def test_message_field_training( ), message_field_training="train", message_field_training_detail="train_detail", - message_field_role="from", - message_field_content="value", + message_property_mappings={"role": "from", "content": "value"}, ), tokenizer=tokenizer, train_on_inputs=False, @@ -911,6 +898,64 @@ def verify_labels(labels_span, should_train, context_message): LOG.debug(f"Final labels: {labels}") LOG.debug(f"Final input_ids: {input_ids}") + def test_get_chat_template_variables( + self, tokenizer, chat_template, chat_template_jinja, eos_token, request + ): + LOG.info("Testing get_chat_template_variables") + + actual_tokenizer, actual_jinja_template = self.setup_tokenizer( + tokenizer, chat_template, chat_template_jinja, eos_token, request + ) + + prompter = ChatTemplatePrompter( + actual_tokenizer, + chat_template=get_chat_template( + chat_template, jinja_template=actual_jinja_template + ), + message_property_mappings={"from": "role", "value": "content"}, + ) + + variables = prompter.get_chat_template_msg_variables( + actual_jinja_template + if actual_jinja_template + else actual_tokenizer.get_chat_template(), + "messages", + ) -if __name__ == "__main__": - unittest.main() + if chat_template == "llama3": + assert variables == {"role", "content"}, ( + f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" + f"Got: {variables}\n" + f"Chat template: {actual_jinja_template}" + ) + elif chat_template == "chatml": + assert variables == {"role", "content"}, ( + f"Expected variables: {'role', 'content'} from 
{tokenizer}/{chat_template}\n" + f"Got: {variables}\n" + f"Chat template: {actual_jinja_template}" + ) + elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer": + assert variables == {"role", "content", "tool_call_id", "tool_calls"}, ( + f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n" + f"Got: {variables}\n" + f"Chat template: {actual_jinja_template}" + ) + elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer": + assert variables == {"role", "content"}, ( + f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" + f"Got: {variables}\n" + f"Chat template: {actual_jinja_template}" + ) + elif chat_template == "phi_35": + assert variables == {"role", "content"}, ( + f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" + f"Got: {variables}\n" + f"Chat template: {actual_jinja_template}" + ) + else: + LOG.warning( + f"Unsupported chat template: {chat_template} with {chat_template_jinja}" + ) + raise ValueError( + f"Unsupported chat template: {chat_template} with {chat_template_jinja}" + ) diff --git a/tests/prompt_strategies/test_jinja_template_analyzer.py b/tests/prompt_strategies/test_jinja_template_analyzer.py new file mode 100644 index 0000000000..004f810999 --- /dev/null +++ b/tests/prompt_strategies/test_jinja_template_analyzer.py @@ -0,0 +1,159 @@ +""" +tests for jinja_template_analyzer +""" +import logging + +import pytest + +from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer + +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + + +class TestJinjaTemplateAnalyzer: + """ + tests for jinja_template_analyzer + """ + + def test_basic_variable_extraction(self, basic_jinja_template_analyzer): + """Test that all top-level variables are correctly extracted.""" + LOG.info("Testing with train_on_inputs=True") + + variables = basic_jinja_template_analyzer.get_template_variables() + 
expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"} + assert set(variables.keys()) == expected_vars + + def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer): + """Test that all top-level variables are correctly extracted.""" + LOG.info("Testing with train_on_inputs=True") + + variables = mistral_jinja_template_analyzer.get_template_variables() + expected_vars = { + "messages", + "content", + "eos_token", + "message", + "tools", + "system_message", + "loop_messages", + "ns", + "tool_call", + "tool", + "loop", + "bos_token", + "raise_exception", + } + assert set(variables.keys()) == expected_vars + message_vars = variables["message"] + assert message_vars == {"role", "content", "tool_calls", "tool_call_id"} + + def test_message_property_access(self, basic_jinja_template_analyzer): + """Test that properties accessed on 'message' variable are correctly identified.""" + LOG.info("Testing message property access") + + variables = basic_jinja_template_analyzer.get_template_variables() + assert "messages" in variables + assert "message" in variables + assert "role" in variables["message"] + assert "content" in variables["message"] + + def test_detailed_analysis(self, basic_jinja_template_analyzer): + """Test the detailed analysis of variable usage.""" + LOG.info("Testing detailed analysis") + + analysis = basic_jinja_template_analyzer.analyze_template() + + assert analysis["messages"]["is_iterated"] is True + assert "role" in analysis["message"]["accessed_properties"] + assert "content" in analysis["message"]["accessed_properties"] + + assert analysis["add_generation_prompt"]["is_conditional"] is True + assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0 + + assert not analysis["eos_token"]["is_iterated"] + assert len(analysis["eos_token"]["accessed_properties"]) == 0 + + def test_nested_property_access(self): + """Test handling of nested property access.""" + LOG.info("Testing nested property access") 
+ + template = """{{ user.profile.name }}{{ user.settings['preference'] }}""" + analyzer = JinjaTemplateAnalyzer(template) + variables = analyzer.get_template_variables() + + assert "user" in variables + assert "profile" in variables["user"] + assert "settings" in variables["user"] + + def test_loop_variable_handling(self): + """Test handling of loop variables and their properties.""" + LOG.info("Testing loop variable handling") + + template = """ + {% for item in items %} + {{ item.name }} + {% for subitem in item.subitems %} + {{ subitem.value }} + {% endfor %} + {% endfor %} + """ + analyzer = JinjaTemplateAnalyzer(template) + analysis = analyzer.analyze_template() + + assert analysis["items"]["is_iterated"] + assert "name" in analysis["item"]["accessed_properties"] + assert "subitems" in analysis["item"]["accessed_properties"] + + def test_conditional_variable_usage(self): + """Test detection of variables used in conditional statements.""" + LOG.info("Testing conditional variable usage") + + template = """ + {% if user.is_admin and config.debug_mode %} + {{ debug_info }} + {% endif %} + """ + analyzer = JinjaTemplateAnalyzer(template) + analysis = analyzer.analyze_template() + + assert analysis["user"]["is_conditional"] + assert analysis["config"]["is_conditional"] + assert "is_admin" in analysis["user"]["accessed_properties"] + assert "debug_mode" in analysis["config"]["accessed_properties"] + + def test_complex_expressions(self): + """Test handling of complex expressions and filters.""" + LOG.info("Testing complex expressions and filters") + + template = """ + {{ user.name | upper }} + {{ messages | length > 0 and messages[0].content }} + {{ data['key'].nested['value'] }} + """ + analyzer = JinjaTemplateAnalyzer(template) + variables = analyzer.get_template_variables() + + assert "user" in variables + assert "name" in variables["user"] + assert "messages" in variables + assert "content" in variables["messages"] + assert "data" in variables + + def 
test_basic_msg_vars(self, basic_jinja_template_analyzer): + """Test that the basic message variables are correctly identified.""" + LOG.info("Testing basic message variables") + + variables = basic_jinja_template_analyzer.get_message_vars() + assert variables == {"role", "content"} + + def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer): + """Test that the mixtral message variables are correctly identified.""" + LOG.info("Testing mixtral message variables") + + variables = mistral_jinja_template_analyzer.get_message_vars() + assert variables == {"role", "content", "tool_calls", "tool_call_id"} + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py index 89f642051b..5c1b5a1f77 100644 --- a/tests/test_validation_dataset.py +++ b/tests/test_validation_dataset.py @@ -302,3 +302,22 @@ def test_dataset_sharegpt_deprecation(self, minimal_cfg): ) validate_config(cfg) + + def test_message_property_mappings(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + "message_property_mappings": { + "role": "role", + "content": "content", + }, + } + ], + } + ) + + validate_config(cfg) diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py index 31698f05fb..e424b39464 100644 --- a/tests/utils/test_models.py +++ b/tests/utils/test_models.py @@ -116,3 +116,79 @@ def test_set_quantization_config( assert self.model_loader.model_kwargs.get( "quantization_config", BitsAndBytesConfig ) + + def test_message_property_mapping(self): + """Test message property mapping configuration validation""" + from axolotl.utils.config.models.input.v0_4_1 import SFTDataset + + # Test legacy fields are mapped correctly + dataset = SFTDataset( + path="test_path", + message_field_role="role_field", + message_field_content="content_field", + ) + assert dataset.message_property_mappings == { + "role": "role_field", 
+ "content": "content_field", + } + + # Test direct message_property_mappings works + dataset = SFTDataset( + path="test_path", + message_property_mappings={ + "role": "custom_role", + "content": "custom_content", + }, + ) + assert dataset.message_property_mappings == { + "role": "custom_role", + "content": "custom_content", + } + + # Test both legacy and new fields work when they match + dataset = SFTDataset( + path="test_path", + message_field_role="same_role", + message_property_mappings={"role": "same_role"}, + ) + assert dataset.message_property_mappings == { + "role": "same_role", + "content": "content", + } + + # Test both legacy and new fields work when they don't overlap + dataset = SFTDataset( + path="test_path", + message_field_role="role_field", + message_property_mappings={"content": "content_field"}, + ) + assert dataset.message_property_mappings == { + "role": "role_field", + "content": "content_field", + } + + # Test no role or content provided + dataset = SFTDataset( + path="test_path", + ) + assert dataset.message_property_mappings == { + "role": "role", + "content": "content", + } + + # Test error when legacy and new fields conflict + with pytest.raises(ValueError) as exc_info: + SFTDataset( + path="test_path", + message_field_role="legacy_role", + message_property_mappings={"role": "different_role"}, + ) + assert "Conflicting message role fields" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + SFTDataset( + path="test_path", + message_field_content="legacy_content", + message_property_mappings={"content": "different_content"}, + ) + assert "Conflicting message content fields" in str(exc_info.value)