diff --git a/LICENSE b/LICENSE index 0a041280..c64a3e3d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,165 +1,72 @@ - GNU LESSER GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 - - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - - This version of the GNU Lesser General Public License incorporates -the terms and conditions of version 3 of the GNU General Public -License, supplemented by the additional permissions listed below. - - 0. Additional Definitions. - - As used herein, "this License" refers to version 3 of the GNU Lesser -General Public License, and the "GNU GPL" refers to version 3 of the GNU -General Public License. - - "The Library" refers to a covered work governed by this License, -other than an Application or a Combined Work as defined below. - - An "Application" is any work that makes use of an interface provided -by the Library, but which is not otherwise based on the Library. -Defining a subclass of a class defined by the Library is deemed a mode -of using an interface provided by the Library. - - A "Combined Work" is a work produced by combining or linking an -Application with the Library. The particular version of the Library -with which the Combined Work was made is also called the "Linked -Version". - - The "Minimal Corresponding Source" for a Combined Work means the -Corresponding Source for the Combined Work, excluding any source code -for portions of the Combined Work that, considered in isolation, are -based on the Application, and not on the Linked Version. - - The "Corresponding Application Code" for a Combined Work means the -object code and/or source code for the Application, including any data -and utility programs needed for reproducing the Combined Work from the -Application, but excluding the System Libraries of the Combined Work. - - 1. Exception to Section 3 of the GNU GPL. - - You may convey a covered work under sections 3 and 4 of this License -without being bound by section 3 of the GNU GPL. - - 2. Conveying Modified Versions. - - If you modify a copy of the Library, and, in your modifications, a -facility refers to a function or data to be supplied by an Application -that uses the facility (other than as an argument passed when the -facility is invoked), then you may convey a copy of the modified -version: - - a) under this License, provided that you make a good faith effort to - ensure that, in the event an Application does not supply the - function or data, the facility still operates, and performs - whatever part of its purpose remains meaningful, or - - b) under the GNU GPL, with none of the additional permissions of - this License applicable to that copy. - - 3. Object Code Incorporating Material from Library Header Files. - - The object code form of an Application may incorporate material from -a header file that is part of the Library. You may convey such object -code under terms of your choice, provided that, if the incorporated -material is not limited to numerical parameters, data structure -layouts and accessors, or small macros, inline functions and templates -(ten or fewer lines in length), you do both of the following: - - a) Give prominent notice with each copy of the object code that the - Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the object code with a copy of the GNU GPL and this license - document. - - 4. Combined Works. 
- - You may convey a Combined Work under terms of your choice that, -taken together, effectively do not restrict modification of the -portions of the Library contained in the Combined Work and reverse -engineering for debugging such modifications, if you also do each of -the following: - - a) Give prominent notice with each copy of the Combined Work that - the Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the Combined Work with a copy of the GNU GPL and this license - document. - - c) For a Combined Work that displays copyright notices during - execution, include the copyright notice for the Library among - these notices, as well as a reference directing the user to the - copies of the GNU GPL and this license document. - - d) Do one of the following: - - 0) Convey the Minimal Corresponding Source under the terms of this - License, and the Corresponding Application Code in a form - suitable for, and under terms that permit, the user to - recombine or relink the Application with a modified version of - the Linked Version to produce a modified Combined Work, in the - manner specified by section 6 of the GNU GPL for conveying - Corresponding Source. - - 1) Use a suitable shared library mechanism for linking with the - Library. A suitable mechanism is one that (a) uses at run time - a copy of the Library already present on the user's computer - system, and (b) will operate properly with a modified version - of the Library that is interface-compatible with the Linked - Version. - - e) Provide Installation Information, but only if you would otherwise - be required to provide such information under section 6 of the - GNU GPL, and only to the extent that such information is - necessary to install and execute a modified version of the - Combined Work produced by recombining or relinking the - Application with a modified version of the Linked Version. (If - you use option 4d0, the Installation Information must accompany - the Minimal Corresponding Source and Corresponding Application - Code. If you use option 4d1, you must provide the Installation - Information in the manner specified by section 6 of the GNU GPL - for conveying Corresponding Source.) - - 5. Combined Libraries. - - You may place library facilities that are a work based on the -Library side by side in a single library together with other library -facilities that are not Applications and are not covered by this -License, and convey such a combined library under terms of your -choice, if you do both of the following: - - a) Accompany the combined library with a copy of the same work based - on the Library, uncombined with any other library facilities, - conveyed under the terms of this License. - - b) Give prominent notice with the combined library that part of it - is a work based on the Library, and explaining where to find the - accompanying uncombined form of the same work. - - 6. Revised Versions of the GNU Lesser General Public License. - - The Free Software Foundation may publish revised and/or new versions -of the GNU Lesser General Public License from time to time. Such new -versions will be similar in spirit to the present version, but may -differ in detail to address new problems or concerns. - - Each version is given a distinguishing version number. 
If the -Library as you received it specifies that a certain numbered version -of the GNU Lesser General Public License "or any later version" -applies to it, you have the option of following the terms and -conditions either of that published version or of any later version -published by the Free Software Foundation. If the Library as you -received it does not specify a version number of the GNU Lesser -General Public License, you may choose any version of the GNU Lesser -General Public License ever published by the Free Software Foundation. - - If the Library as you received it specifies that a proxy can decide -whether future versions of the GNU Lesser General Public License shall -apply, that proxy's public statement of acceptance of any version is -permanent authorization for you to choose that version for the -Library. +License text copyright (c) 2020 MariaDB Corporation Ab, All Rights Reserved. +"Business Source License" is a trademark of MariaDB Corporation Ab. + +Parameters + +Licensor: Arcee AI +Licensed Work: MergeKit Version 0.1.0 or later. The Licensed Work is (c) 2025 + Arcee AI. +Additional Use Grant: You may make production use of the Licensed Work so long as + you and your affiliates, considered both individually and in + the aggregate, do not meet any of the following criteria: + (i) have more than 100 total full-time equivalent employees + and/or contractors, (ii) have more than USD $10 million in + annual recurring revenue, (iii) make available products and + services with more than 1 million daily active users, (iv) are + a publicly traded company with a market capitalization of + greater than USD $300 million, or (v) are a private company + whose most recent financing post-money valuation was greater + than USD $300 million. + + For clarity, the use of the Licensed Work to create models for + production use constitutes production use. +Change Date: Two years from the date the Licensed Work is published. +Change License: GNU Lesser General Public License v3.0 or later + +For information about alternative licensing arrangements for the Licensed Work, +please contact licensing@arcee.ai. + +Notice + +Business Source License 1.1 + +Terms + +The Licensor hereby grants you the right to copy, modify, create derivative +works, redistribute, and make non-production use of the Licensed Work. The +Licensor may make an Additional Use Grant, above, permitting limited production use. + +Effective on the Change Date, or the fourth anniversary of the first publicly +available distribution of a specific version of the Licensed Work under this +License, whichever comes first, the Licensor hereby grants you rights under +the terms of the Change License, and the rights granted in the paragraph +above terminate. + +If your use of the Licensed Work does not comply with the requirements +currently in effect as described in this License, you must purchase a +commercial license from the Licensor, its affiliated entities, or authorized +resellers, or you must refrain from using the Licensed Work. + +All copies of the original and modified Licensed Work, and derivative works +of the Licensed Work, are subject to this License. This License applies +separately for each version of the Licensed Work and the Change Date may vary +for each version of the Licensed Work released by Licensor. + +You must conspicuously display this License on each original or modified copy +of the Licensed Work. 
If you receive the Licensed Work in original or +modified form from a third party, the terms and conditions set forth in this +License apply to your use of that work. + +Any use of the Licensed Work in violation of this License will automatically +terminate your rights under this License for the current and all other +versions of the Licensed Work. + +This License does not grant you any right in any trademark or logo of +Licensor or its affiliates (provided that you may use a trademark or logo of +Licensor as expressly required by this License). + +TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON +AN "AS IS" BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, +EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND +TITLE. diff --git a/examples/arcee_fusion.yml b/examples/arcee_fusion.yml new file mode 100644 index 00000000..58d50ccd --- /dev/null +++ b/examples/arcee_fusion.yml @@ -0,0 +1,6 @@ +models: + - model: model_a + - model: model_b +merge_method: arcee_fusion +base_model: model_a +dtype: bfloat16 diff --git a/mergekit/_data/architectures/gpt2.json b/mergekit/_data/architectures/gpt2.json index fc7a3201..915beb9d 100644 --- a/mergekit/_data/architectures/gpt2.json +++ b/mergekit/_data/architectures/gpt2.json @@ -44,6 +44,9 @@ "num_layers_config_key": "n_layer", "layer_templates": { "weights": [ + { + "name": "h.${layer_index}.attn.bias" + }, { "name": "transformer.h.${layer_index}.attn.c_attn.weight", "aliases": [ diff --git a/mergekit/architecture.py b/mergekit/architecture.py index af098169..49840b73 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -1,28 +1,23 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import importlib.resources +import logging +import re import string +import warnings from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path from typing import ClassVar, Dict, List, Optional, Tuple, Union +from huggingface_hub import snapshot_download from pydantic import BaseModel, Field from transformers import PretrainedConfig from typing_extensions import Literal import mergekit._data.architectures +from mergekit.io.lazy_tensor_loader import ShardedTensorIndex class WeightInfo(BaseModel, frozen=True): @@ -203,6 +198,115 @@ def _template_substitution( return TemplateWithArithmetic(template).substitute(substitutions) +def _hierarchy(names, layer_prefix=r"\.\d+\.") -> Dict[str, List[str]]: + hierarchy = defaultdict(list) + + # Regular expression to match layers (denoted by .{integer}. 
by default) + layer_pattern = re.compile(layer_prefix) + + if names: + for name in names: + # Find the layer part of the string (e.g., 'model.layers.0.') + match = layer_pattern.search(name) + if match: + # Extract everything up to the layer identifier + layer_prefix = name[: match.end() - 1] # e.g., 'model.layers.0' + # Extract the parameter name after the layer identifier + param_name = name[match.end() :] # e.g., 'input_layernorm.weight' + # Add the parameter name to the corresponding layer in the hierarchy + hierarchy[layer_prefix].append(param_name) + else: + hierarchy[name].append("") + + return hierarchy + + +class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel): + arch_name: str = Field(default="") + parameter_names: List[str] = Field(default_factory=list) + embed: List[str] = Field(default_factory=list) + layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict) + prefix_tracker: Dict[str, str] = Field(default_factory=dict) + post_fill_parameters: bool = False + + def __init__( + self, + arch_name: str, + parameter_names: List[str], + prefix_tracker: Optional[Dict[str, str]] = None, + post_fill_parameters: bool = False, + ): + super().__init__() + self.arch_name = arch_name + self.parameter_names = parameter_names + self.layered_parameter_names = _hierarchy(self.parameter_names) + self.prefix_tracker = prefix_tracker or {} + self.embed = self._find_embed_params() + self.post_fill_parameters = post_fill_parameters + + def _find_embed_params(self) -> List[str]: + """Identify embedding parameters (e.g., 'lm_head', 'embed') that may require special handling.""" + embed_params = [] + for name in self.parameter_names: + if any(embedding_name in name for embedding_name in ["lm_head", "embed"]): + embed_params.append(name) + return embed_params + + def name(self) -> str: + """Returns the architecture name.""" + return self.arch_name + + def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + """This architecture does not distinguish pre-weights.""" + return [] + + def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + """This architecture does not distinguish post-weights.""" + return [] + + def layer_weights( + self, index: int, config: PretrainedConfig + ) -> Optional[List[WeightInfo]]: + """ + Retrieves the weights for a specified layer, adjusting names for prefixes if applicable. 
+ """ + layer_name = list(self.layered_parameter_names.keys())[index] + adjusted_layer_name = self._adjust_layer_name(layer_name, config) + + weights = [ + WeightInfo( + name=f"{adjusted_layer_name}.{param}" if param else adjusted_layer_name, + is_embed=(layer_name in self.embed), + ) + for param in self.layered_parameter_names[layer_name] + ] + return ( + weights + if weights + else [ + WeightInfo( + name=adjusted_layer_name, is_embed=(layer_name in self.embed) + ) + ] + ) + + def _adjust_layer_name(self, layer_name: str, config: PretrainedConfig) -> str: + """Adjust layer names by removing any prefix as indicated in the prefix tracker.""" + if config and config.name_or_path in self.prefix_tracker: + prefix = self.prefix_tracker.get(config.name_or_path, "") + if layer_name.startswith(prefix): + return layer_name[len(prefix) :] + return layer_name + + def sliceable(self) -> bool: + """Indicates if the architecture supports slicing.""" + return True + + def num_layers(self, config: PretrainedConfig) -> int: + """Returns the number of layers based on layered parameter names.""" + return len(self.layered_parameter_names) + + class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True): definition: JSONArchitectureDefinition @@ -359,26 +463,317 @@ def _load_all_architectures() -> ( QWEN2_INFO = _load_json_arch("qwen2.json") -def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: - if len(config.architectures) != 1: - raise RuntimeError("More than one architecture in config?") +class ArchitectureInfoUtils: + """Functions for inferring architecture information from a merge configuration.""" - arch_name = config.architectures[0] + @staticmethod + def get_architecture_info(config: PretrainedConfig) -> Optional[ArchitectureInfo]: + """Get architecture info from an existing model config.""" + if len(config.architectures) != 1: + raise RuntimeError("More than one architecture in config?") - if arch_name == MixtralTensorNames.ARCHITECTURE_NAME: - return MixtralTensorNames.from_config(config) + arch_name = config.architectures[0] - if arch_name not in NAME_TO_ARCH: - raise RuntimeError(f"Unsupported architecture {arch_name}") + if arch_name == MixtralTensorNames.ARCHITECTURE_NAME: + return MixtralTensorNames.from_config(config) - candidates = list(NAME_TO_ARCH[arch_name]) - if len(candidates) == 1: - return candidates[0] + if arch_name in NAME_TO_ARCH: + candidates = list(NAME_TO_ARCH[arch_name]) + if len(candidates) == 1: + return candidates[0] - for c in candidates: - if c.definition.expected_model_type == config.model_type: - return c + for c in candidates: + if c.definition.expected_model_type == config.model_type: + return c - raise RuntimeError( - f"Unsupported model_type {config.model_type} for architecture {arch_name}" - ) + warnings.warn(f"No architecture config available for: {arch_name}.") + return None + + @staticmethod + def infer_architecture_info(merge_config) -> AutomaticArchitectureInfo: + """ + Infer architecture info and prefixes for alignment. + Prefixes typically denote where a model is used as a subcomponent of another model. + e.g., [layer.0, layer.1, ...] and []'vision_tower.layer.0', vision_tower.layer.1', ...] + inferring ßprefix = 'vision_tower' is required to align the two models. + + Usage: + Similar to `get_architecture_info`, but requires a merge configuration object rather than a model config. + This is so the common parameter names between all models can be inferred. 
+ """ + param_names = [ + ParameterNamesUtils.get_model_parameter_names(source_model.model.path) + for source_model in merge_config.referenced_models() + ] + base_model = merge_config.base_model + + paired_list = list(zip(param_names, merge_config.referenced_models())) + paired_list.sort(key=lambda x: len(x[0]), reverse=True) + for i, (_, model_name) in enumerate(paired_list): + if model_name == base_model: + paired_list.insert(0, paired_list.pop(i)) + break + param_names, referenced_models = zip(*paired_list) + logging.info(f"Base model selected: {referenced_models[0].model.path}") + + prefixes = [""] + for i in range(1, len(param_names)): + assert len(param_names[0]) >= len( + param_names[i] + ), f"base model names list can't be shorter than model {i} names list" + prefixes.append( + ParameterNamesUtils.find_prefix(param_names[0], param_names[i]) + ) + + common_names = ParameterNamesUtils.find_common_ordered_names( + param_names, prefixes + ) + + common_names = ParameterNamesUtils.remove_size_conflicts( + common_names, referenced_models, prefixes + ) + + ArchitectureInfoUtils.log_info(common_names, param_names, referenced_models) + + if not common_names or any([p is None for p in prefixes]): + raise ValueError("Could not resolve model architecture automatically.") + + prefix_tracker = { + model.model.path: f"{prefix}." if prefix else "" + for model, prefix in zip(referenced_models, prefixes) + } + + arch_name = referenced_models[0].model.path + parameter_names = common_names + + return AutomaticArchitectureInfo( + arch_name=arch_name, + parameter_names=parameter_names, + prefix_tracker=prefix_tracker, + post_fill_parameters=( + referenced_models[0].model.path # base model name + if len(common_names) != len(param_names[0]) + else None # no post-fill needed + ), + ) + + @staticmethod + def log_info(common_names, param_names, referenced_models): + for i in range(1, len(param_names)): + prefix, case_message = ParameterNamesUtils.report_names_similarity( + param_names[0], param_names[i] + ) + logging.info( + f"Model {referenced_models[i].model.path}: \ + \n {f'Best prefix found: {prefix}' if prefix else 'No prefix found'}\ + \n {case_message.replace('MODEL_ID', referenced_models[i].model.path)}" + ) + + if len(common_names) != len(param_names[0]): + warnings.warn( + f"Merging {len(common_names)}/{len(param_names[0])} base model parameters. \ + \n Base model selected: {referenced_models[0].model.path} \ + \n copy_and_fill_missing_params will run when merge is complete, to fill in missing params from base model." + ) + + if len(common_names) < 0.3 * len(param_names[0]): + warnings.warn( + "Not many common parameters found. Are you sure you are merging the correct models?" 
+ ) + + +class ParameterNamesUtils: + """Utility functions for handling parameter names.""" + + @staticmethod + def resolve_model_directory(repo_id: str) -> Path: + """Resolve the model directory (local or Hugging Face Hub).""" + if Path(repo_id).is_dir(): + return Path(repo_id) + + return Path(snapshot_download(repo_id)) + + @staticmethod + def get_model_parameter_names(repo_id: str) -> List[str]: + """Get parameter names of a model from a Hugging Face repo or local directory.""" + model_dir = ParameterNamesUtils.resolve_model_directory(repo_id) + return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys()) + + @staticmethod + def strip_prefix(name: str, prefix: str) -> str: + """Remove a single prefix from the start of a name.""" + if prefix != "" and name.startswith(prefix + "."): + return name[len(prefix) + 1 :] + return name + + @staticmethod + def find_prefix(list1: List[str], list2: List[str]) -> Optional[str]: + """ + Find a prefix in list1 that, after removal, makes list2 an ordered sublist. + """ + assert len(list1) >= len(list2), "params name list1 can't be shorter than list2" + + possible_prefixes = {item.split(".")[0] for item in list1 if "." in item} + possible_prefixes = [""] + list(possible_prefixes) + + prefix_matches = {} + best_prefix = "" # Default to no prefix + for prefix in possible_prefixes: + stripped_list1 = [ + ParameterNamesUtils.strip_prefix(item, prefix) for item in list1 + ] + prefix_matches[prefix] = len( + [item for item in list2 if item in stripped_list1] + ) + + if max(prefix_matches.values()) > prefix_matches[""]: + best_prefix = max(prefix_matches, key=prefix_matches.get) + + return best_prefix + + @staticmethod + def find_common_ordered_names( + param_names: List[List[str]], prefixes: List[str] + ) -> List[str]: + """Identify and return common parameter names across all models, ensuring correct order. Also account for prefix.""" + common_names = set(param_names[0]) + for i in range(1, len(param_names)): + prefix = f"{prefixes[i]}." if prefixes[i] else "" + common_names.intersection_update({prefix + name for name in param_names[i]}) + return [name for name in param_names[0] if name in common_names] + + @staticmethod + def remove_size_conflicts(common_names, referenced_models, prefixes): + model_dirs = [ + ParameterNamesUtils.resolve_model_directory(m.model.path) + for m in referenced_models + ] + model_indices = [ShardedTensorIndex.from_disk(str(dir)) for dir in model_dirs] + + common_name_and_shape = common_names.copy() + removed_names = [] + + for name in common_names: + base_shape = ParameterNamesUtils.tensor_shape(name, model_indices[0]) + + for i in range(1, len(referenced_models)): + other_name = name + prefix = f"{prefixes[i]}." if prefixes[i] else "" + if name.startswith(prefix) and prefix != "": + other_name = name[len(prefix) :] + shape = ParameterNamesUtils.tensor_shape(other_name, model_indices[i]) + + if base_shape != shape: + common_name_and_shape.remove(name) + removed_names.append((name, base_shape, shape, i)) + break + + size_mismatch_count = len(removed_names) + if size_mismatch_count > 0: + logging.warning( + f"Size mismatch detected for {size_mismatch_count}/{size_mismatch_count + len(common_names)} tensors. " + "These names were removed from the merge list." 
+ ) + logging.info( + "The following tensors have different shapes across models and were removed from the merge list:" + ) + for name, base_shape, shape, i in removed_names: + logging.info( + f"Tensor name: {name}, Base model shape: {base_shape}, Mismatched shape: {shape} in model {referenced_models[i].model.path}" + ) + + return common_name_and_shape + + @staticmethod + def are_common_params_ordered(list1: List[str], list2: List[str]) -> bool: + """ + Check if common elements of list2 maintain their relative order in list1. + """ + common_params = set(list1).intersection(set(list2)) + last_index = -1 + + for param in list2: + if param in common_params: + current_index = list1.index(param) + if current_index < last_index: + return False + last_index = current_index + return True + + @staticmethod + def ordered_sublist(list1: List[str], list2: List[str]) -> bool: + """ + Check if list2 is a contiguous ordered sublist of list1. + """ + n, m = len(list1), len(list2) + + for i in range(n - m + 1): + if list1[i : i + m] == list2: + return True + return False + + @staticmethod + def report_names_similarity( + base_names: List[str], other_names: List[str] + ) -> Tuple[Optional[str], str]: + """ + Analyze similarity between parameter names of two models and identify shared prefixes. + + Returns: + best_prefix (str): Best matching prefix for parameter names. + case_message (str): Explanation of the structural relationship. + """ + possible_prefixes = {""} + possible_prefixes.update( + {item.split(".")[0] for item in base_names if "." in item} + ) + + prefixes_subset_overlap = {} + best_prefix = None + case_message = "No common parameter names found for any prefix" + + for prefix in possible_prefixes: + base_names_stripped = [ + ParameterNamesUtils.strip_prefix(name, prefix) for name in base_names + ] + + if ParameterNamesUtils.ordered_sublist(base_names_stripped, other_names): + return prefix, "All params in model have exact match in base model." + + intersection = set(base_names_stripped).intersection(set(other_names)) + prefixes_subset_overlap[prefix] = intersection + + if prefixes_subset_overlap: + best_prefix = max( + prefixes_subset_overlap, key=lambda x: len(prefixes_subset_overlap[x]) + ) + base_names_stripped = [ + ParameterNamesUtils.strip_prefix(name, best_prefix) + for name in base_names + ] + + overlap = len(prefixes_subset_overlap[best_prefix]) + ordered = ParameterNamesUtils.are_common_params_ordered( + base_names_stripped, other_names + ) + mismatched = [ + item for item in other_names if item not in base_names_stripped + ] + mismatched = "\n ".join(mismatched) + case_message = ( + f"{overlap}/{len(other_names)} ({100 * overlap / len(other_names):.2f}%) " + f"of model parameters are in the base model. 
\n" + f" Name ordering is {'preserved' if ordered else 'not preserved'}.\n" + f" Missing parameters:\n {mismatched}" + ) + + return best_prefix, case_message + + @staticmethod + def tensor_shape(name, index) -> Tuple[int]: + from safetensors import safe_open + + with safe_open( + Path(index.base_path) / index.tensor_paths[name], framework="pt" + ) as f: + return f.get_slice(name).get_shape() diff --git a/mergekit/card.py b/mergekit/card.py index f4ef9436..4195bc27 100644 --- a/mergekit/card.py +++ b/mergekit/card.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging import os diff --git a/mergekit/common.py b/mergekit/common.py index 1667a2ca..5b066d4e 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import binascii import logging diff --git a/mergekit/config.py b/mergekit/config.py index 46265bfd..7c031e54 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -37,7 +25,7 @@ class ConditionalParameter(BaseModel): def evaluate_setting( tensor_name: str, setting: ParameterSetting, t: float = 0 -) -> float: +) -> Optional[float]: if isinstance(setting, (float, int, bool, str)): return setting elif isinstance(setting, list): diff --git a/mergekit/evo/actors.py b/mergekit/evo/actors.py index 2acc1e3e..a3b2ba9a 100644 --- a/mergekit/evo/actors.py +++ b/mergekit/evo/actors.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import gc import logging @@ -35,7 +23,7 @@ vllm = None -from mergekit.architecture import ConfiguredArchitectureInfo, get_architecture_info +from mergekit.architecture import ArchitectureInfoUtils, ConfiguredArchitectureInfo from mergekit.config import MergeConfiguration from mergekit.evo.config import EvolMergeConfiguration from mergekit.evo.genome import InvalidGenotypeError, ModelGenome @@ -150,7 +138,9 @@ def __init__( super().__init__(*args, vllm=vllm, **kwargs) def _maybe_init_model(self, config: MergeConfiguration): - ai = get_architecture_info(self.genome._input_config_example) + ai = ArchitectureInfoUtils.get_architecture_info( + self.genome._input_config_example + ) cfg_out = _model_out_config( config, ai, diff --git a/mergekit/evo/config.py b/mergekit/evo/config.py index 73cc9c8d..2c3961df 100644 --- a/mergekit/evo/config.py +++ b/mergekit/evo/config.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging from typing import List, Optional diff --git a/mergekit/evo/genome.py b/mergekit/evo/genome.py index c43037df..0c569ab7 100644 --- a/mergekit/evo/genome.py +++ b/mergekit/evo/genome.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. 
-# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging import os diff --git a/mergekit/evo/helpers.py b/mergekit/evo/helpers.py index 1d165628..e51e4d99 100644 --- a/mergekit/evo/helpers.py +++ b/mergekit/evo/helpers.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging import os diff --git a/mergekit/evo/monkeypatch.py b/mergekit/evo/monkeypatch.py index 9f120a03..72a93139 100644 --- a/mergekit/evo/monkeypatch.py +++ b/mergekit/evo/monkeypatch.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import torch diff --git a/mergekit/evo/strategy.py b/mergekit/evo/strategy.py index fe6fb2cc..fe6b2b03 100644 --- a/mergekit/evo/strategy.py +++ b/mergekit/evo/strategy.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import asyncio import logging diff --git a/mergekit/graph.py b/mergekit/graph.py index 4f0e5a95..b032243c 100644 --- a/mergekit/graph.py +++ b/mergekit/graph.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 """ Module for computational graph execution. @@ -106,6 +94,12 @@ def uses_accelerator(self) -> bool: """ return False + def main_thread_only(self) -> bool: + """ + Returns True if the task should only be executed on the main thread. + """ + return False + class Executor: """ @@ -119,6 +113,7 @@ class Executor: targets (List[Task]): List of target tasks to be executed. schedule (List[Task]): Calculated execution schedule of tasks. dependencies (Dict[Task, Set[Task]]): Dependencies of each task. + cached_values (Optional[Dict[Task, Any]]): Cached values for tasks that have been executed before in a different context. """ math_device: torch.device @@ -126,12 +121,14 @@ class Executor: targets: List[Task] schedule: List[Task] dependencies: Dict[Task, Set[Task]] + cached_values: Optional[Dict[Task, Any]] def __init__( self, tasks: List[Task], math_device: torch.device = torch.device("cpu"), storage_device: torch.device = torch.device("cpu"), + cached_values: Optional[Dict[Task, Any]] = None, ): """ Initializes the Executor with a list of tasks and device configurations. @@ -141,12 +138,20 @@ def __init__( math_device (torch.device, optional): The device for tensor computations. Defaults to CPU. storage_device (torch.device, optional): The device for storing results. Defaults to CPU. """ + self.cached_values = cached_values + self.targets = tasks + if isinstance(math_device, str): + math_device = torch.device(math_device) + if isinstance(storage_device, str): + storage_device = torch.device(storage_device) self.math_device = math_device self.storage_device = storage_device self.schedule = self._make_schedule(tasks) - self.targets = tasks - def run(self, quiet: bool = False) -> Iterator[Tuple[Task, Any]]: + def run( + self, + quiet: bool = False, + ) -> Iterator[Tuple[Task, Any]]: """ Execute the computed schedule and yield the target values. 
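A note on the graph.py hunks above and below: Executor now accepts a cached_values mapping, coerces string device names to torch.device, and Task gains a main_thread_only() hook. The following is a minimal sketch of how pre-seeded values could be used. The Constant and Add task classes are hypothetical illustrations that assume the Task subclassing pattern used elsewhere in this diff (pydantic fields plus arguments()/execute()); they are not part of mergekit.

from typing import Dict

from mergekit.graph import Executor, Task


class Constant(Task[int]):
    value: int

    def arguments(self) -> Dict[str, Task]:
        return {}

    def execute(self) -> int:
        return self.value


class Add(Task[int]):
    lhs: Task[int]
    rhs: Task[int]

    def arguments(self) -> Dict[str, Task]:
        return {"a": self.lhs, "b": self.rhs}

    def execute(self, a: int, b: int) -> int:
        return a + b


a, b = Constant(value=1), Constant(value=2)
total = Add(lhs=a, rhs=b)

# Seed the executor with a precomputed value for `a`; per the hunks in this
# diff, cached tasks are dropped from the schedule and their values are fed
# directly to dependent tasks. String device names are accepted as well.
executor = Executor(tasks=[total], cached_values={a: 1}, math_device="cpu")
for task, value in executor.run():
    print(task, value)  # prints something like: Add(...) 3
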
@@ -156,13 +161,18 @@ def run(self, quiet: bool = False) -> Iterator[Tuple[Task, Any]]: # determine last usage of each value, so they can be evicted afterwards last_use_index = {} for idx, task in reversed(list(enumerate(self.schedule))): - for t in self.dependencies[task]: + for t in self.dependencies.get(task, []): if t not in last_use_index: last_use_index[t] = idx if task not in last_use_index: last_use_index[task] = idx + for task in self.cached_values or []: + if task not in last_use_index: + last_use_index[task] = len(self.schedule) + 1 values: Dict[Task, Any] = {} + if self.cached_values: + values.update(self.cached_values) for idx, task in ( pbar := tqdm.tqdm( list(enumerate(self.schedule)), @@ -178,27 +188,14 @@ def run(self, quiet: bool = False) -> Iterator[Tuple[Task, Any]]: # ensure any input tensors are on math device if task asks for it if use_math_device: - if ( - isinstance(value, torch.Tensor) - and value.device != self.math_device - ): - value = value.to(self.math_device) - elif isinstance(value, dict): - for key in value: - if ( - isinstance(value[key], torch.Tensor) - and value[key].device != self.math_device - ): - value[key] = value[key].to(self.math_device) + value = self._move_tensors(value, self.math_device) arguments[name] = value del value res = task.execute(**arguments) del arguments - - if isinstance(res, torch.Tensor) and res.device != self.storage_device: - res = res.to(self.storage_device) + res = self._move_tensors(res, self.storage_device) values[task] = res del res @@ -225,6 +222,23 @@ def execute(self) -> None: for task, value in self.run(): pass + def _move_tensors( + self, value: Any, device: torch.device, non_blocking: Optional[bool] = None + ) -> Any: + if non_blocking is None: + non_blocking = device.type == "cuda" + if isinstance(value, torch.Tensor): + return value.to(device=device, non_blocking=non_blocking) + elif isinstance(value, dict): + return { + k: self._move_tensors(v, device, non_blocking) for k, v in value.items() + } + elif isinstance(value, list): + return [self._move_tensors(v, device, non_blocking) for v in value] + elif isinstance(value, tuple): + return tuple(self._move_tensors(v, device, non_blocking) for v in value) + return value + DUMMY_TASK_VALUE = "!!DUMMY!!" def _make_schedule(self, targets: List[Task]) -> List[Task]: @@ -253,7 +267,8 @@ def _compare_key(task: Union[Task, str]): res = [ t for t in networkx.lexicographical_topological_sort(graph, key=_compare_key) - if t != Executor.DUMMY_TASK_VALUE + if (t != Executor.DUMMY_TASK_VALUE) + and (t not in (self.cached_values or {})) ] return res @@ -266,6 +281,8 @@ def _build_dependencies(self, targets: List[Task]) -> Dict[Task, Set[Task]]: continue task_dependencies[child] = set() + if child in (self.cached_values or {}): + continue for _, dep in child.arguments().items(): task_dependencies[child].add(dep) to_process.append(dep) diff --git a/mergekit/io/__init__.py b/mergekit/io/__init__.py index 520b513c..38acbc2a 100644 --- a/mergekit/io/__init__.py +++ b/mergekit/io/__init__.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from mergekit.io.lazy_tensor_loader import ( LazyTensorLoader, diff --git a/mergekit/io/lazy_tensor_loader.py b/mergekit/io/lazy_tensor_loader.py index bdb74f74..753bc0f8 100644 --- a/mergekit/io/lazy_tensor_loader.py +++ b/mergekit/io/lazy_tensor_loader.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import json import logging @@ -76,30 +64,41 @@ def from_disk(cls, base_path: str) -> "ShardedTensorIndex": ) shards.append(info) - elif os.path.exists(model_path): - shard_name = os.path.basename(model_path) - - # get list of tensors contained in single-file checkpoint - if model_path.lower().endswith(".safetensors"): - with safetensors.safe_open(model_path, framework="pt") as st: - tensor_paths = {key: shard_name for key in st.keys()} - else: - # this is ugly but not much else can be done - shard = torch.load(model_path, map_location="meta") - if "state_dict" in shard: - shard = shard["state_dict"] - - tensor_paths = {key: shard_name for key in shard} - - shards.append( - ShardInfo(os.path.basename(model_path), list(tensor_paths.keys())) + return ShardedTensorIndex( + base_path=base_path, + is_safetensors=is_safetensors, + tensor_paths=tensor_paths, + shards=shards, ) + elif os.path.exists(model_path): + return ShardedTensorIndex.from_file(model_path) + + else: + raise RuntimeError(f"Unable to find model files at {base_path}") + + @classmethod + def from_file(cls, file_path: str) -> "ShardedTensorIndex": + if not os.path.exists(file_path): + raise FileNotFoundError(file_path) + + lower = file_path.lower() + shard_name = os.path.basename(file_path) + if lower.endswith(".safetensors"): + with safetensors.safe_open(file_path, framework="pt") as st: + tensor_paths = {key: shard_name for key in st.keys()} + else: + shard = torch.load(file_path, map_location="meta") + if "state_dict" in shard: + shard = shard["state_dict"] + + tensor_paths = {key: shard_name for key in shard} + return ShardedTensorIndex( - base_path=base_path, - is_safetensors=is_safetensors, + base_path=os.path.dirname(file_path), + is_safetensors=lower.endswith(".safetensors"), tensor_paths=tensor_paths, - shards=shards, + shards=[ShardInfo(shard_name, list(tensor_paths.keys()))], ) diff --git a/mergekit/io/lazy_unpickle.py b/mergekit/io/lazy_unpickle.py index 3c634751..81fd9313 100644 --- a/mergekit/io/lazy_unpickle.py +++ b/mergekit/io/lazy_unpickle.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either 
version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import codecs import collections @@ -115,6 +103,14 @@ def persistent_load(self, pid: Any) -> Any: return DeferredLoad(name=key, location=location, dtype=get_dtype(storage_type)) +class LazyUnpickleModule: + Unpickler = LazyTorchUnpickler + + @staticmethod + def load(*args, **kwargs): + return LazyTorchUnpickler(*args, **kwargs).load() + + class TorchArchiveReader: """ Class for lazily reading (sections of) files from a torch ZIP archive. diff --git a/mergekit/io/loader.py b/mergekit/io/loader.py index d3fb3612..bff78ca9 100644 --- a/mergekit/io/loader.py +++ b/mergekit/io/loader.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from abc import ABC, abstractmethod from typing import Dict, Optional, Sequence @@ -19,7 +7,12 @@ import safetensors import torch -from mergekit.io.lazy_unpickle import DeferredLoad, TorchArchiveReader, torch_lazy_load +from mergekit.io.lazy_unpickle import ( + DeferredLoad, + LazyUnpickleModule, + TorchArchiveReader, + torch_lazy_load, +) class TensorLoader(ABC): @@ -61,7 +54,7 @@ def __init__(self, path: str, device: Optional[str] = None): self.zip_reader = TorchArchiveReader(path) self.device = device with torch_lazy_load(): - self.index = torch.load(path) + self.index = torch.load(path, pickle_module=LazyUnpickleModule) def get_tensor(self, key: str) -> torch.Tensor: if key not in self.index: diff --git a/mergekit/io/tasks.py b/mergekit/io/tasks.py index 5a680207..d12c0b86 100644 --- a/mergekit/io/tasks.py +++ b/mergekit/io/tasks.py @@ -1,20 +1,9 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import os import re +import threading from typing import Dict, Optional, Tuple import torch @@ -34,13 +23,13 @@ class LoaderCache: lazy_unpickle: bool = False trust_remote_code: bool = False - # singleton instance - _instance: Optional["LoaderCache"] = None + # singleton instance per thread + _instance = threading.local() def __new__(cls) -> "LoaderCache": - if cls._instance is None: - cls._instance = super(LoaderCache, cls).__new__(cls) - return cls._instance + if not hasattr(cls._instance, "value"): + cls._instance.value = super(LoaderCache, cls).__new__(cls) + return cls._instance.value def get(self, model: ModelReference) -> LazyTensorLoader: if model not in self.loaders: @@ -173,6 +162,9 @@ def execute(self, **_kwargs) -> TensorWriter: safe_serialization=self.safe_serialization, ) + def main_thread_only(self): + return True + class SaveTensor(Task[None]): tensor_name: str @@ -214,15 +206,8 @@ def arguments(self) -> Dict[str, Task]: def execute(self, writer: TensorWriter, **kwargs) -> None: writer.finalize() - -class BuildStateDict(Task[Dict[str, torch.Tensor]]): - tensors: ImmutableMap[WeightInfo, Task[torch.Tensor]] - - def arguments(self) -> Dict[str, Task]: - return {str(wi): t for wi, t in self.tensors.items()} - - def execute(self, **kwargs) -> Dict[str, torch.Tensor]: - return {str(wi): t for wi, t in self.tensors.items()} + def main_thread_only(self): + return True class ReturnTensor(Task[torch.Tensor]): diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py index e9dd922e..e6b0e459 100644 --- a/mergekit/io/tensor_writer.py +++ b/mergekit/io/tensor_writer.py @@ -1,26 +1,17 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import json import logging import os +import threading from typing import Dict import safetensors import torch +logger = logging.getLogger(__name__) + class TensorWriter: out_path: str @@ -30,6 +21,7 @@ class TensorWriter: current_shard: Dict[str, torch.Tensor] current_shard_size: int safe_serialization: bool + lock: threading.Lock def __init__( self, @@ -46,29 +38,30 @@ def __init__( self.weight_map = {} self.current_shard = {} self.current_shard_size = 0 + self.lock = threading.Lock() def save_tensor(self, name: str, tensor: torch.Tensor, clone: bool = False): if not tensor.is_contiguous(): tensor = tensor.contiguous() - - tensor_size = tensor.numel() * tensor.element_size() - if ( - self.current_shard - and self.current_shard_size + tensor_size > self.max_shard_size - ): - self.flush_current_shard() - if clone: tensor = tensor.clone() - self.current_shard[name] = tensor - self.current_shard_size += tensor_size + tensor_size = tensor.numel() * tensor.element_size() + with self.lock: + if ( + self.current_shard + and self.current_shard_size + tensor_size > self.max_shard_size + ): + self._flush_current_shard() + + self.current_shard[name] = tensor + self.current_shard_size += tensor_size - def flush_current_shard(self): + def _flush_current_shard(self): if not self.current_shard: return - logging.info(f"Writing shard #{self.shards_written+1} to disk") + logger.info(f"Writing shard #{self.shards_written+1} to disk") prefix, extension = self._get_name_components() shard_name = f"{prefix}-{self.shards_written+1}.{extension}" @@ -87,43 +80,50 @@ def flush_current_shard(self): self.shards_written = self.shards_written + 1 def finalize(self): - self.flush_current_shard() + with self.lock: + self._flush_current_shard() - logging.info("Finalizing shard names") + logger.info("Finalizing shard names") - prefix, extension = self._get_name_components() + prefix, extension = self._get_name_components() - # standardize shard names to hf format - total_shards = self.shards_written - name_remap = {} - for idx in range(total_shards): - name_remap[ - f"{prefix}-{idx+1}.{extension}" - ] = f"{prefix}-{idx+1:05d}-of-{total_shards:05d}.{extension}" - - for old_name, new_name in name_remap.items(): - os.rename( - os.path.join(self.out_path, old_name), - os.path.join(self.out_path, new_name), - ) + # standardize shard names to hf format + total_shards = self.shards_written + name_remap = {} + for idx in range(total_shards): + name_remap[ + f"{prefix}-{idx+1}.{extension}" + ] = f"{prefix}-{idx+1:05d}-of-{total_shards:05d}.{extension}" - for key in self.weight_map: - self.weight_map[key] = name_remap[self.weight_map[key]] - - with open( - os.path.join(self.out_path, f"{prefix}.{extension}.index.json"), - "w", - encoding="utf-8", - ) as file: - json.dump( - { - "metadata": { - "mergekit_version": "0.0.6", + if total_shards < 2: + name_remap[f"{prefix}-1.{extension}"] = f"{prefix}.{extension}" + + for old_name, new_name in name_remap.items(): + os.rename( + os.path.join(self.out_path, old_name), + os.path.join(self.out_path, new_name), + ) + + if total_shards < 2: + return + + for key in self.weight_map: + self.weight_map[key] = name_remap[self.weight_map[key]] + + with open( + os.path.join(self.out_path, f"{prefix}.{extension}.index.json"), + "w", + encoding="utf-8", + ) as file: + json.dump( + { + "metadata": { + "mergekit_version": "0.1.0", + }, + "weight_map": self.weight_map, }, - "weight_map": self.weight_map, - }, - file, - ) + file, + ) def _get_name_components(self): if 
self.safe_serialization: @@ -146,7 +146,7 @@ def _do_save(): and isinstance(e.args[0], str) and "share memory" in e.args[0] ): - logging.warning( + logger.warning( "Your model has duplicated tensors but the --clone-tensors " "flag is not set." ) diff --git a/mergekit/merge.py b/mergekit/merge.py index 111c68a9..994774e3 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -1,23 +1,12 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import importlib import importlib.resources import logging import os import shutil +import warnings from collections import Counter from typing import Optional @@ -25,15 +14,18 @@ import transformers from mergekit._data import chat_templates -from mergekit.architecture import ArchitectureInfo, get_architecture_info +from mergekit.architecture import ArchitectureInfo, ArchitectureInfoUtils from mergekit.card import generate_card from mergekit.config import MergeConfiguration from mergekit.graph import Executor from mergekit.io.tasks import LoaderCache +from mergekit.multigpu_executor import MultiGPUExecutor from mergekit.options import MergeOptions from mergekit.plan import MergePlanner from mergekit.tokenizer import TokenizerInfo +logger = logging.getLogger(__name__) + def run_merge( merge_config: MergeConfiguration, @@ -47,16 +39,7 @@ def run_merge( if not merge_config.models and not merge_config.slices: raise RuntimeError("No output requested") - model_arch_info = [ - get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) - for m in merge_config.referenced_models() - ] - if not options.allow_crimes: - if not all(a == model_arch_info[0] for a in model_arch_info[1:]): - raise RuntimeError( - "Must specify --allow-crimes to attempt to mix different architectures" - ) - arch_info = model_arch_info[0] + arch_info = _load_arch_info(merge_config, options) # initialize loader cache and set options loader_cache = LoaderCache() @@ -78,7 +61,7 @@ def run_merge( loader_cache.get(model) del pbar - logging.info("Planning operations") + logger.info("Planning operations") targets = MergePlanner( merge_config, arch_info, @@ -86,11 +69,17 @@ def run_merge( out_model_config=cfg_out, ).plan_to_disk(out_path=out_path) - exec = Executor( - tasks=targets, - math_device="cuda" if options.cuda else "cpu", - storage_device="cuda" if options.low_cpu_memory else "cpu", - ) + if options.multi_gpu: + exec = MultiGPUExecutor( + tasks=targets, + storage_device=None if options.low_cpu_memory else "cpu", + ) + else: + exec = Executor( + tasks=targets, + math_device="cuda" if options.cuda else "cpu", + storage_device="cuda" if options.low_cpu_memory else "cpu", + ) tokenizer = None for _task, value in exec.run(quiet=options.quiet): @@ -103,7 +92,7 @@ def run_merge( pad_to_multiple_of = merge_config.tokenizer.pad_to_multiple_of _update_config_vocab(cfg_out, tokenizer, 
pad_to_multiple_of=pad_to_multiple_of) - logging.info("Saving config") + logger.info("Saving config") cfg_out.save_pretrained(out_path) if options.write_model_card: @@ -130,20 +119,33 @@ def run_merge( merge_config, out_path, trust_remote_code=options.trust_remote_code ) except Exception as e: - logging.error( + logger.error( "Failed to copy tokenizer. The merge was still successful, just copy it from somewhere else.", exc_info=e, ) elif merge_config.chat_template: - logging.warning( + logger.warning( "Chat template specified but no tokenizer found. Chat template will not be saved." ) if tokenizer: - logging.info("Saving tokenizer") + logger.info("Saving tokenizer") _set_chat_template(tokenizer, merge_config) tokenizer.save_pretrained(out_path, safe_serialization=True) + if getattr(arch_info, "post_fill_parameters", False): + from mergekit.scripts.fill_missing_params import copy_and_fill_missing_params + + logging.info( + f"Filling missing parameters from base model {arch_info.post_fill_parameters} into new directory" + ) + copy_and_fill_missing_params( + base_model_repo_id=arch_info.post_fill_parameters, + sub_model_dir=out_path, + ) + logging.info("Deleting initial merge directory: " + out_path) + shutil.rmtree(out_path) + def _set_chat_template( tokenizer: transformers.PreTrainedTokenizerBase, @@ -170,13 +172,13 @@ def _set_chat_template( if template: model_templates.append(template.strip()) except Exception as e: - logging.warning(f"Unable to load tokenizer for {model}", exc_info=e) + logger.warning(f"Unable to load tokenizer for {model}", exc_info=e) if not model_templates: return chat_template = Counter(model_templates).most_common(1)[0][0] - logging.info(f"Auto-selected chat template: {chat_template}") + logger.info(f"Auto-selected chat template: {chat_template}") elif importlib.resources.is_resource(chat_templates, chat_template + ".jinja"): with importlib.resources.open_text( @@ -205,7 +207,7 @@ def _copy_tokenizer( or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model")) ) ): - logging.info(f"Copying tokenizer from {donor_model}") + logger.info(f"Copying tokenizer from {donor_model}") for file_name in [ "tokenizer_config.json", @@ -222,7 +224,7 @@ def _copy_tokenizer( return # fallback: try actually loading the tokenizer and saving it - logging.info(f"Reserializing tokenizer from {donor_model}") + logger.info(f"Reserializing tokenizer from {donor_model}") tokenizer = transformers.AutoTokenizer.from_pretrained( donor_model.model.path, revision=donor_model.model.revision, @@ -255,7 +257,7 @@ def _model_out_config( ) setattr(res, arch_info.num_layers_config_key(), num_layers) except Exception as e: - logging.warning( + logger.warning( "Unable to set number of layers in output config - you may need to manually correct it.", exc_info=e, ) @@ -274,10 +276,38 @@ def _update_config_vocab( try: config.vocab_size = vocab_size except Exception as e: - logging.warning( + logger.warning( "Unable to set vocabulary size in output config - you may need to manually correct it.", exc_info=e, ) +def _load_arch_info( + merge_config: MergeConfiguration, options: MergeOptions +) -> ArchitectureInfo: + """ + Loads architecture information, handling cases where models lack predefined architecture info. 
+ """ + model_arch_info = [ + ArchitectureInfoUtils.get_architecture_info( + m.config(trust_remote_code=options.trust_remote_code) + ) + for m in merge_config.referenced_models() + ] + + if all(a is not None for a in model_arch_info): + if not options.allow_crimes and not all( + a == model_arch_info[0] for a in model_arch_info[1:] + ): + raise RuntimeError( + "Must specify --allow-crimes to attempt to mix different architectures" + ) + return model_arch_info[0] + else: + warnings.warn("Attempting Automatic Merge.") + model_arch_info = ArchitectureInfoUtils.infer_architecture_info(merge_config) + + return model_arch_info + + __all__ = ["MergeOptions", "run_merge"] diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index d9abd71f..84cf6588 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import mergekit.merge_methods.multislerp from mergekit.merge_methods.base import MergeMethod diff --git a/mergekit/merge_methods/arcee_fusion.py b/mergekit/merge_methods/arcee_fusion.py new file mode 100644 index 00000000..ef38f4ab --- /dev/null +++ b/mergekit/merge_methods/arcee_fusion.py @@ -0,0 +1,134 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +from typing import Dict, List, Optional + +import torch +import torch.nn.functional as F +from typing_extensions import override + +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.graph import Task +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) +from mergekit.merge_methods.rectify_embed import rectify_embed_sizes + + +class DynamicThresholdFusion: + def approximate_quantiles(self, tensor, q): + # Flatten the tensor + flat_tensor = tensor.view(-1) + + # If tensor is too large, sample it + if flat_tensor.numel() > 1e6: + flat_tensor = flat_tensor[torch.randperm(flat_tensor.numel())[:1000000]] + + # Sort the (possibly sampled) tensor + sorted_tensor, _ = torch.sort(flat_tensor) + + # Compute quantile indices + quantile_indices = (q * (sorted_tensor.numel() - 1)).long() + + # Return quantiles + return sorted_tensor[quantile_indices] + + def calculate_dynamic_threshold(self, importance_scores): + # Approximate median and quantiles + median = self.approximate_quantiles(importance_scores, torch.tensor([0.5]))[0] + q1, q3 = self.approximate_quantiles( + importance_scores, torch.tensor([0.25, 0.75]) + ) + + # Calculate IQR + iqr = q3 - q1 + + # Set threshold as median + 1.5 * IQR + dynamic_threshold = median + 1.5 * iqr + + return dynamic_threshold + + def compute_fusion_mask(self, importance_scores): + threshold = self.calculate_dynamic_threshold(importance_scores) + fusion_mask = (importance_scores >= threshold).float() + 
return fusion_mask, threshold + + +class ArceeFusionMergeTask(Task[torch.Tensor]): + gather_tensors: MergeTensorInput + base_model: ModelReference + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: + if len(tensors) == 1: + return list(tensors.values())[0] + elif len(tensors) != 2: + raise RuntimeError("ArceeFusion merge expects exactly two models") + elif self.base_model not in tensors: + raise RuntimeError("Base model not in input tensors") + + [a, b] = list(tensors.items()) + if a[0] != self.base_model: + [a, b] = [b, a] + prepped_tensors = [a[1], b[1]] + + rectify_embed_sizes(self.weight_info, prepped_tensors) + + importance_scores = self._compute_importance( + prepped_tensors[1], prepped_tensors[0] + ) + dynamic_threshold_fusion = DynamicThresholdFusion() + fusion_mask, _threshold = dynamic_threshold_fusion.compute_fusion_mask( + importance_scores + ) + + delta = prepped_tensors[1] - prepped_tensors[0] + masked_delta = delta * fusion_mask + fused = prepped_tensors[0] + masked_delta + + return fused + + def _compute_importance( + self, params: torch.Tensor, base_params: torch.Tensor, eps: float = 1e-8 + ) -> torch.Tensor: + diff = (params - base_params).abs() + p = F.softmax(params, dim=-1) + eps + q = F.softmax(base_params, dim=-1) + eps + kl_div = torch.sum(p * torch.log(p / q), dim=-1) + return diff * kl_div.unsqueeze(-1) + + +class ArceeFusionMerge(MergeMethod): + def name(self) -> str: + return "arcee_fusion" + + @override + def pretty_name(self) -> Optional[str]: + return "Arcee Fusion" + + @override + def reference_url(self) -> Optional[str]: + return "https://arcee.ai" + + def parameters(self) -> List[ConfigParameterDef]: + return [] + + def make_task( + self, + output_weight: WeightInfo, + tensors: MergeTensorInput, + base_model: Optional[ModelReference], + **kwargs, + ) -> Task[torch.Tensor]: + return ArceeFusionMergeTask( + gather_tensors=tensors, weight_info=output_weight, base_model=base_model + ) diff --git a/mergekit/merge_methods/base.py b/mergekit/merge_methods/base.py index 0bf10133..685dc1ee 100644 --- a/mergekit/merge_methods/base.py +++ b/mergekit/merge_methods/base.py @@ -1,21 +1,10 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
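# Illustrative sketch of how the arcee_fusion mask gates parameter deltas; the
# tensors are toy values and the importance scores are a stand-in for the
# |delta| * KL term that ArceeFusionMergeTask computes.
import torch

from mergekit.merge_methods.arcee_fusion import DynamicThresholdFusion

base = torch.randn(4, 16)
other = torch.randn(4, 16)
importance = (other - base).abs()  # stand-in importance scores

mask, threshold = DynamicThresholdFusion().compute_fusion_mask(importance)
fused = base + (other - base) * mask  # only deltas above the dynamic threshold survive
print(f"threshold={threshold.item():.4f}, kept {int(mask.sum())}/{mask.numel()} elements")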
+# SPDX-License-Identifier: BUSL-1.1 from abc import ABC, abstractmethod -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union +import torch from pydantic import BaseModel from typing_extensions import TypeAlias @@ -25,7 +14,25 @@ from mergekit.io.tasks import GatherTensors from mergekit.tokenizer import PermutedEmbeddings -MergeTensorInput: TypeAlias = Union[GatherTensors, PermutedEmbeddings] + +class TensorDictWrapper(Task[Dict[ModelReference, torch.Tensor]]): + tensors: ImmutableMap[ModelReference, Task[torch.Tensor]] + + def arguments(self) -> Dict[str, Task]: + return { + k.model_dump_json( + exclude_none=True, exclude_defaults=True, round_trip=True + ): v + for k, v in self.tensors.items() + } + + def execute(self, **kwargs) -> Dict[ModelReference, torch.Tensor]: + return {ModelReference.model_validate_json(k): v for k, v in kwargs.items()} + + +MergeTensorInput: TypeAlias = Union[ + GatherTensors, PermutedEmbeddings, TensorDictWrapper +] class ConfigParameterDef(BaseModel): diff --git a/mergekit/merge_methods/easy_define.py b/mergekit/merge_methods/easy_define.py index ebeb0eff..a84e229a 100644 --- a/mergekit/merge_methods/easy_define.py +++ b/mergekit/merge_methods/easy_define.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import inspect import typing diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 3e34a306..484b66c6 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
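# Illustrative sketch of the serialization round-trip that TensorDictWrapper
# relies on: a ModelReference is dumped to JSON to name a task argument, then
# reconstructed inside execute(). The model path is hypothetical.
from mergekit.common import ModelReference

ref = ModelReference.parse("./some-local-model")
key = ref.model_dump_json(exclude_none=True, exclude_defaults=True, round_trip=True)
restored = ModelReference.model_validate_json(key)  # rebuilds the reference from its JSON key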
+# SPDX-License-Identifier: BUSL-1.1 import logging from enum import Enum diff --git a/mergekit/merge_methods/linear.py b/mergekit/merge_methods/linear.py index f58f142d..65ba2a8b 100644 --- a/mergekit/merge_methods/linear.py +++ b/mergekit/merge_methods/linear.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from typing import Any, Dict, List, Optional diff --git a/mergekit/merge_methods/model_stock.py b/mergekit/merge_methods/model_stock.py index a1b0665b..581128a5 100644 --- a/mergekit/merge_methods/model_stock.py +++ b/mergekit/merge_methods/model_stock.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging from typing import Any, Dict, List, Optional diff --git a/mergekit/merge_methods/multislerp.py b/mergekit/merge_methods/multislerp.py index f0ed8767..04666d5e 100644 --- a/mergekit/merge_methods/multislerp.py +++ b/mergekit/merge_methods/multislerp.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 from typing import List, Optional diff --git a/mergekit/merge_methods/nearswap.py b/mergekit/merge_methods/nearswap.py index 9621108d..55460296 100644 --- a/mergekit/merge_methods/nearswap.py +++ b/mergekit/merge_methods/nearswap.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from typing import Any, Dict, List, Optional diff --git a/mergekit/merge_methods/nuslerp.py b/mergekit/merge_methods/nuslerp.py index 6a967064..b64a7e4d 100644 --- a/mergekit/merge_methods/nuslerp.py +++ b/mergekit/merge_methods/nuslerp.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from typing import Any, Dict, List, Optional diff --git a/mergekit/merge_methods/passthrough.py b/mergekit/merge_methods/passthrough.py index ce0922f9..4e9ae071 100644 --- a/mergekit/merge_methods/passthrough.py +++ b/mergekit/merge_methods/passthrough.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 from typing import Any, Dict, List, Optional diff --git a/mergekit/merge_methods/rectify_embed.py b/mergekit/merge_methods/rectify_embed.py index 2c481edd..6d38f748 100644 --- a/mergekit/merge_methods/rectify_embed.py +++ b/mergekit/merge_methods/rectify_embed.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging diff --git a/mergekit/merge_methods/registry.py b/mergekit/merge_methods/registry.py index 91e76918..83ad37fe 100644 --- a/mergekit/merge_methods/registry.py +++ b/mergekit/merge_methods/registry.py @@ -1,20 +1,9 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from typing import Dict, List +from mergekit.merge_methods.arcee_fusion import ArceeFusionMerge from mergekit.merge_methods.base import MergeMethod from mergekit.merge_methods.generalized_task_arithmetic import ( ConsensusMethod, @@ -37,6 +26,7 @@ ModelStockMerge(), SCEMerge(), NearSwapMerge(), + ArceeFusionMerge(), # generalized task arithmetic methods GeneralizedTaskArithmeticMerge( consensus_method=None, diff --git a/mergekit/merge_methods/sce.py b/mergekit/merge_methods/sce.py index d485cf4e..0a7b428f 100644 --- a/mergekit/merge_methods/sce.py +++ b/mergekit/merge_methods/sce.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import logging from typing import Any, Dict, List, Optional, Tuple diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py index 7f4e8d28..11755816 100644 --- a/mergekit/merge_methods/slerp.py +++ b/mergekit/merge_methods/slerp.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from typing import Any, Dict, List, Optional, Union diff --git a/mergekit/moe/arch.py b/mergekit/moe/arch.py index a52ad2d5..e385a48e 100644 --- a/mergekit/moe/arch.py +++ b/mergekit/moe/arch.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from abc import ABC, abstractmethod from typing import List, Optional diff --git a/mergekit/moe/common.py b/mergekit/moe/common.py index 64b888b4..2f8cf8e8 100644 --- a/mergekit/moe/common.py +++ b/mergekit/moe/common.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging from typing import Dict, Optional, Tuple diff --git a/mergekit/moe/config.py b/mergekit/moe/config.py index bfb3deb4..ff0874c6 100644 --- a/mergekit/moe/config.py +++ b/mergekit/moe/config.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. 
-# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging from typing import List, Optional diff --git a/mergekit/moe/deepseek.py b/mergekit/moe/deepseek.py index 54c2a0b7..9f8a4b1f 100644 --- a/mergekit/moe/deepseek.py +++ b/mergekit/moe/deepseek.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import json import logging @@ -22,7 +10,7 @@ import tqdm import transformers -from mergekit.architecture import get_architecture_info +from mergekit.architecture import ArchitectureInfoUtils from mergekit.moe.arch import MoEOutputArchitecture from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype from mergekit.moe.config import MoEMergeConfig @@ -138,7 +126,7 @@ def write_model( loaders, base_loader, writer = initialize_io(config, out_path, merge_options) shared_loader = loaders.get(shared_def.source_model) if shared_def else None for weight_info in tqdm.tqdm( - get_architecture_info(base_cfg).all_weights(base_cfg), + ArchitectureInfoUtils.get_architecture_info(base_cfg).all_weights(base_cfg), desc="Weights", ): tensor_name = weight_info.name diff --git a/mergekit/moe/mixtral.py b/mergekit/moe/mixtral.py index 6e68770f..5f0c7dfd 100644 --- a/mergekit/moe/mixtral.py +++ b/mergekit/moe/mixtral.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import logging from typing import List, Optional diff --git a/mergekit/moe/qwen.py b/mergekit/moe/qwen.py index 6d505846..f5730a6c 100644 --- a/mergekit/moe/qwen.py +++ b/mergekit/moe/qwen.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging from typing import List, Optional diff --git a/mergekit/moe/router.py b/mergekit/moe/router.py index 57454a87..5817dd6a 100644 --- a/mergekit/moe/router.py +++ b/mergekit/moe/router.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging import math diff --git a/mergekit/multigpu_executor.py b/mergekit/multigpu_executor.py new file mode 100644 index 00000000..73c41e42 --- /dev/null +++ b/mergekit/multigpu_executor.py @@ -0,0 +1,283 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 +""" +Implementation of multi-GPU parallel task execution. + +Handles distribution of parallelizable tasks across multiple GPUs while respecting: +- Main-thread-only task requirements +- Task dependency graphs +- GPU assignment of connected task components +- Intermediate result storage locations +""" + +import concurrent.futures +import logging +import queue +import threading +from collections import defaultdict +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple + +import networkx as nx +import torch +import tqdm + +from .graph import Executor, Task + +logger = logging.getLogger(__name__) + + +class MultiGPUExecutor: + """ + Execute tasks across multiple GPUs. + + Attributes: + num_gpus: Number of GPUs to utilize (None = all available) + storage_device: Device for storing tensors between stages + targets: Final output tasks to retain results for + """ + + def __init__( + self, + tasks: List[Task], + num_gpus: Optional[int] = None, + storage_device: Optional[torch.device] = None, + ): + """ + Initialize the executor with a list of tasks. 
+ + Args: + tasks: List of tasks to execute + num_gpus: Number of GPUs to utilize (None = all available) + storage_device: Device for storing tensors between stages + """ + self.results = {} + self.targets = set(tasks) + self.storage_device = storage_device + + if num_gpus is None: + num_gpus = torch.cuda.device_count() + + # Create temp executor to get full schedule + temp_exec = Executor(tasks) + ordered_tasks = temp_exec._make_schedule(tasks) + self.dependencies = temp_exec.dependencies + self.total_tasks = len(ordered_tasks) + + leading_tasks = self._find_leading_tasks(ordered_tasks) + trailing_tasks = self._find_trailing_tasks(ordered_tasks) + self.trailing_main_tasks = [t for t in ordered_tasks if t in trailing_tasks] + self.leading_main_tasks = [t for t in ordered_tasks if t in leading_tasks] + + self.trailing_dependencies = set() + for task in self.trailing_main_tasks: + self.trailing_dependencies.update(self.dependencies[task]) + + parallel_tasks = [ + t + for t in ordered_tasks + if (t not in trailing_tasks and t not in leading_tasks) + ] + logger.info( + f"Task breakdown: {len(self.leading_main_tasks)} leading, " + f"{len(parallel_tasks)} parallel, " + f"{len(self.trailing_main_tasks)} trailing" + ) + if any(t.main_thread_only() for t in parallel_tasks): + raise RuntimeError( + "Main-thread-only tasks must be either leading or trailing" + ) + self.gpu_assignments = self._assign_islands_to_gpus(parallel_tasks, num_gpus) + + self.task_completion_queue = queue.Queue() + self.done_event = threading.Event() + + def run(self, quiet: bool = False) -> Iterator[Tuple[Task, Any]]: + """ + Execute all tasks and yield target results. + + Yields: + Iterator[Tuple[Task, Any]]: Task and result pairs + """ + with tqdm.tqdm( + total=self.total_tasks, disable=quiet, desc="Executing graph" + ) as pbar: + if self.leading_main_tasks: + exec = Executor( + self.leading_main_tasks, + math_device=self.storage_device or torch.device("cpu"), + storage_device=self.storage_device or torch.device("cpu"), + ) + for task, result in exec.run(quiet=True): + pbar.update() + self.results[task] = result + + logger.debug("Leading tasks complete, beginning parallel execution") + + def update_progress(): + while not self.done_event.is_set(): + try: + task, result = self.task_completion_queue.get(timeout=0.1) + self.results[task] = result + pbar.update() + except queue.Empty: + continue + + progress_thread = threading.Thread(target=update_progress) + progress_thread.start() + + # Run parallel tasks + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + for device, island_tasks in self.gpu_assignments.items(): + futures.append( + executor.submit( + self._device_worker, + task_list=island_tasks, + cached_values=dict(self.results), + device=device, + quiet=True, + ) + ) + + for future in concurrent.futures.as_completed(futures): + if future.exception(): + self.done_event.set() + executor.shutdown(wait=False) + raise future.exception() + + self.done_event.set() + progress_thread.join() + + logger.debug("Parallel tasks complete") + + # Run main thread tasks + if self.trailing_main_tasks: + exec = Executor( + self.trailing_main_tasks, + math_device=self.storage_device or torch.device("cpu"), + storage_device=self.storage_device or torch.device("cpu"), + cached_values=dict(self.results), + ) + for task, result in exec.run(quiet=True): + pbar.update() + if task in self.targets: + self.results[task] = result + + # Yield final results + for task, result in self.results.items(): + if task in self.targets: + 
yield task, result + + def execute(self) -> None: + """Execute all tasks and discard results""" + for _ in self.run(quiet=False): + pass + + def _find_trailing_tasks(self, tasks: List[Task]) -> Set[Task]: + """ + Identify tasks that must execute AFTER parallel GPU tasks complete. + + Trailing tasks must: + - Require main thread execution + - Not have non-trailing dependants + """ + dependants = defaultdict(set) + for task, deps in self.dependencies.items(): + for dep in deps: + dependants[dep].add(task) + + trailing_tasks = set() + to_explore = set([t for t in tasks if not dependants[t]]) + while to_explore: + task = to_explore.pop() + if not task.main_thread_only(): + continue + if all(d in trailing_tasks for d in dependants[task]): + trailing_tasks.add(task) + to_explore.update(self.dependencies[task]) + return trailing_tasks + + def _find_leading_tasks(self, tasks: List[Task]) -> Set[Task]: + """Identify tasks that must execute BEFORE parallel GPU tasks. + + Leading tasks must: + - Require main thread execution + - Not have non-leading dependencies + """ + leading_tasks = set() + for task in tasks: + if not task.main_thread_only(): + continue + if self.dependencies[task] and any( + dep not in leading_tasks for dep in self.dependencies[task] + ): + continue + leading_tasks.add(task) + return leading_tasks + + def _assign_islands_to_gpus( + self, tasks: List[Task], num_gpus: int + ) -> Dict[torch.device, List[Task]]: + """ + Assign task islands to GPUs. + + Task islands (weakly connected components) are groups of tasks that + can execute independently. This method identifies islands in the + non-trailing, non-leading task graph and assigns them to devices. + """ + + island_graph = nx.DiGraph() + island_graph.add_nodes_from(tasks) + + # Add edges only between parallel tasks + for task in tasks: + for dep in self.dependencies[task]: + if dep in tasks: + island_graph.add_edge(dep, task) + + islands = list(nx.weakly_connected_components(island_graph)) + logger.info(f"Found {len(islands)} islands in parallel task graph") + assignments = {} + for island in islands: + # Borrow orderings from original task list + island_tasks = [t for t in tasks if t in island] + # assign to GPU with fewest tasks + device_idx = min( + range(num_gpus), + key=lambda i: len(assignments.get(torch.device(f"cuda:{i}"), [])), + ) + device = torch.device(f"cuda:{device_idx}") + assignments[device] = assignments.get(device, []) + island_tasks + return assignments + + def _device_worker( + self, + task_list: List[Task], + cached_values: Dict[Task, Any], + device: torch.device, + quiet: bool, + ): + """ + Execute a set of tasks on a single GPU. 
+ + Args: + task_list: List of tasks to execute + cached_values: Values of previously-executed dependent tasks + device: Device to execute tasks on + quiet: Suppress progress bar output + """ + stream = torch.cuda.Stream(device=device) + with torch.cuda.stream(stream): + exec = Executor( + tasks=task_list, + math_device=device, + storage_device=self.storage_device or device, + cached_values=cached_values, + ) + count = 0 + for task, result in exec.run(quiet=quiet): + count += 1 + if not (task in self.targets or task in self.trailing_dependencies): + result = None + self.task_completion_queue.put((task, result)) + torch.cuda.synchronize(device=device) diff --git a/mergekit/options.py b/mergekit/options.py index ae6dbcba..ef29908a 100644 --- a/mergekit/options.py +++ b/mergekit/options.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import functools import typing @@ -40,6 +28,7 @@ class MergeOptions(BaseModel): safe_serialization: bool = True quiet: bool = False read_to_gpu: bool = False + multi_gpu: bool = False OPTION_HELP = { @@ -58,6 +47,7 @@ class MergeOptions(BaseModel): "safe_serialization": "Save output in safetensors. Do this, don't poison the world with more pickled models.", "quiet": "Suppress progress bars and other non-essential output", "read_to_gpu": "Read model weights directly to GPU", + "multi_gpu": "Use multi-gpu parallel graph execution engine", } diff --git a/mergekit/plan.py b/mergekit/plan.py index 297e25fa..65e63bef 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/.
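# Illustrative sketch of the user-facing path to the new executor: setting
# multi_gpu=True on MergeOptions so run_merge picks MultiGPUExecutor instead of
# the single-device Executor. The YAML path and output directory are
# hypothetical, and at least one CUDA device is assumed to be available.
import yaml

from mergekit.config import MergeConfiguration
from mergekit.merge import run_merge
from mergekit.options import MergeOptions

with open("merge-config.yml", "r", encoding="utf-8") as fp:
    merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp))

run_merge(
    merge_config,
    out_path="./merged-model",
    options=MergeOptions(multi_gpu=True, low_cpu_memory=False),
)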
+# SPDX-License-Identifier: BUSL-1.1 import logging from functools import lru_cache diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py index cb3c912a..3834892d 100644 --- a/mergekit/scripts/ABM/activations_based_merge.py +++ b/mergekit/scripts/ABM/activations_based_merge.py @@ -8,7 +8,7 @@ import tqdm from transformers import AutoTokenizer -from mergekit.architecture import get_architecture_info +from mergekit.architecture import ArchitectureInfoUtils from mergekit.common import ModelReference, dtype_from_name from mergekit.io.tasks import LoaderCache from mergekit.io.tensor_writer import TensorWriter @@ -62,7 +62,7 @@ def main( ) model_config = model.config(trust_remote_code=merge_options.trust_remote_code) - model_arch_info = get_architecture_info( + model_arch_info = ArchitectureInfoUtils.get_architecture_info( model.config(trust_remote_code=merge_options.trust_remote_code) ) diff --git a/mergekit/scripts/ABM/extract_activations.py b/mergekit/scripts/ABM/extract_activations.py index 7cb5961b..3f7c151b 100644 --- a/mergekit/scripts/ABM/extract_activations.py +++ b/mergekit/scripts/ABM/extract_activations.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from transformers import AutoModel, AutoTokenizer, DefaultDataCollator -from mergekit.architecture import _template_substitution, get_architecture_info +from mergekit.architecture import ArchitectureInfoUtils, _template_substitution from mergekit.common import ModelReference logging.basicConfig(level=logging.INFO) @@ -130,7 +130,7 @@ def main( model = ModelReference.model_validate(model_path) model_config = model.config() - model_arch_info = get_architecture_info(model_config) + model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config) _json = model_arch_info.definition diff --git a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py index 75c58692..4c862664 100644 --- a/mergekit/scripts/ABM/extract_permutation_matrices.py +++ b/mergekit/scripts/ABM/extract_permutation_matrices.py @@ -8,7 +8,7 @@ import scipy import torch -from mergekit.architecture import _template_substitution, get_architecture_info +from mergekit.architecture import ArchitectureInfoUtils, _template_substitution from mergekit.common import ModelReference @@ -147,7 +147,7 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device): model_config = model.config() - model_arch_info = get_architecture_info(model_config) + model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config) _json = model_arch_info.definition diff --git a/mergekit/scripts/bakllama.py b/mergekit/scripts/bakllama.py index e363ec7b..21218270 100644 --- a/mergekit/scripts/bakllama.py +++ b/mergekit/scripts/bakllama.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 from typing import List, Optional diff --git a/mergekit/scripts/evolve.py b/mergekit/scripts/evolve.py index d8996522..5c5a3228 100644 --- a/mergekit/scripts/evolve.py +++ b/mergekit/scripts/evolve.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging import os diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py index c589bba7..2f3ddc55 100644 --- a/mergekit/scripts/extract_lora.py +++ b/mergekit/scripts/extract_lora.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import json import logging diff --git a/mergekit/scripts/fill_missing_params.py b/mergekit/scripts/fill_missing_params.py new file mode 100644 index 00000000..81aec1b3 --- /dev/null +++ b/mergekit/scripts/fill_missing_params.py @@ -0,0 +1,199 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 +import logging +import shutil +from pathlib import Path + +import click +import torch +from safetensors import safe_open +from tqdm import tqdm + +from mergekit.architecture import ParameterNamesUtils +from mergekit.io.lazy_tensor_loader import ShardedTensorIndex +from mergekit.io.tensor_writer import TensorWriter + +DEFAULT_SHARD_SIZE = 5 * 1024**3 + + +def load_tensor_from_file(tensor_name: str, tensor_file: str = None) -> torch.Tensor: + """ + Load a specific tensor from a .safetensors file. + + :param tensor_name: The name of the tensor to load. + :param tensor_file: The .safetensors file that contains the tensor. + :return: The loaded tensor as a PyTorch tensor. + """ + with safe_open(tensor_file, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + return f.get_tensor(tensor_name) + else: + raise ValueError( + f"Tensor '{tensor_name}' not found in file '{tensor_file}'" + ) + + +def load_tensor_from_index(tensor_name: str, index: ShardedTensorIndex) -> torch.Tensor: + """ + Load a specific tensor from a ShardedTensorIndex. + + :param tensor_name: The name of the tensor to load. + :param index: The ShardedTensorIndex containing the tensor. + :return: The loaded tensor as a PyTorch tensor. 
+ """ + return load_tensor_from_file( + tensor_name, Path(index.base_path) / index.tensor_paths[tensor_name] + ) + + +def copy_and_fill_missing_params( + base_model_repo_id: str, + sub_model_dir: str, + max_shard_size: int = DEFAULT_SHARD_SIZE, + output_dir: str = None, +): + """ + Merge submodel weights into a base model and fill in missing parameters. + + Use Case: + Given a submodel (e.g., a language model) that is structurally identical to a subset of a + larger base model (e.g., a vision-language model): the submodel contains only a subset of the + weights (e.g., for the language model part), while the base model contains all weights + required for the complete architecture. + + This function replaces the shared parameters in the base model with those from the submodel, + facilitating testing after generating submodel parameters through merging. + + Parameters: + base_model_repo_id (str): + The path to the base model's directory or its Hugging Face repository ID. + This model provides all parameters and files required for the complete model. + sub_model_dir (str): + The path to the submodel's directory containing the merged weights. + Parameters in this directory replace the corresponding weights in the base model. + max_shard_size (int, optional): + The maximum shard size for saving model weights, in bytes. Defaults to 5 GiB. + output_dir (str, optional): + The directory to save the final merged model. If not provided, a default directory + is created using the names of the base and submodel. + + Returns: + pathlib.Path: + The path to the directory where the final merged model is saved. + + Raises: + AssertionError: + If the base model has fewer parameters than the submodel. + ValueError: + If tensor loading or parameter alignment issues occur. + + Notes: + - The function does not modify the original base or submodel directories. + - For Hugging Face repository IDs, ensure the `HF_HOME` environment variable is properly configured. + - Non-shared parameters, as well as any additional configuration files, are copied from the base model to create a fully functional model. + """ + # Prepare paths and configurations + output_dir = ( + Path(sub_model_dir).parent + / f"{Path(base_model_repo_id).stem}--{Path(sub_model_dir).stem}" + if output_dir is None + else Path(output_dir) + ) + output_dir.mkdir(parents=True, exist_ok=True) + + # Resolve the model directory for the base model + base_dir = ParameterNamesUtils.resolve_model_directory(base_model_repo_id) + files_to_copy = [ + item + for item in base_dir.rglob("*") + if item.is_file() and item.suffix not in {".safetensors", ".bin"} + ] + + # Copy non-parameter files from the base model + with tqdm( + total=len(files_to_copy), desc="Copying non-parameter files", unit="file" + ) as pbar: + for item in files_to_copy: + target_path = output_dir / item.relative_to(base_dir) + target_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(item, target_path) + pbar.update(1) + + # Retrieve parameter names from both models + base_param_names = ParameterNamesUtils.get_model_parameter_names(base_model_repo_id) + submodel_param_names = ParameterNamesUtils.get_model_parameter_names(sub_model_dir) + + # Ensure the base model has more parameters than the submodel + assert len(base_param_names) > len(submodel_param_names), ( + f"Base model must have more parameters than the submodel. 
" + f"Base: {len(base_param_names)}, Submodel: {len(submodel_param_names)}" + ) + + # Determine parameter prefix and find common names + prefix = ParameterNamesUtils.find_prefix(base_param_names, submodel_param_names) + common_param_names = ParameterNamesUtils.find_common_ordered_names( + [base_param_names, submodel_param_names], ["", prefix] + ) + + # Load parameter indices for tensor storage + base_index = ShardedTensorIndex.from_disk(str(base_dir)) + submodel_index = ShardedTensorIndex.from_disk( + str(ParameterNamesUtils.resolve_model_directory(sub_model_dir)) + ) + + # Initialize the tensor writer + writer = TensorWriter( + out_path=str(output_dir), max_shard_size=max_shard_size, safe_serialization=True + ) + + # Copy and fill parameters from base to submodel + for name, tensor_path in tqdm( + base_index.tensor_paths.items(), + total=len(base_index.tensor_paths), + desc="Merging tensors", + unit="tensor", + ): + tensor = load_tensor_from_index(name, base_index) + + # Check if the parameter is common to both models + if name in common_param_names: + submodel_name = ParameterNamesUtils.strip_prefix(name, prefix) + submodel_tensor = load_tensor_from_index(submodel_name, submodel_index) + + # Log size mismatches + if submodel_tensor.size() != tensor.size(): + logging.warning( + f"Size mismatch for tensor '{name}': {tensor.size()} vs {submodel_tensor.size()}" + ) + + tensor = submodel_tensor + + # Save the tensor to the output directory + writer.save_tensor(name, tensor.clone()) + + # Finalize the writer to ensure data is saved and index file is created + writer.finalize() + + return output_dir + + +@click.command() +@click.argument("base_model_repo_id", type=str) +@click.argument("sub_model_dir", type=str) +@click.option("--max_shard_size", type=int, default=DEFAULT_SHARD_SIZE) +@click.option("--output_dir", type=str, default=None) +def main( + base_model_repo_id, + sub_model_dir, + max_shard_size, + output_dir, +): + copy_and_fill_missing_params( + base_model_repo_id, sub_model_dir, max_shard_size, output_dir + ) + + +if __name__ == "__main__": + main() diff --git a/mergekit/scripts/layershuffle.py b/mergekit/scripts/layershuffle.py index c7af0053..22e9d161 100644 --- a/mergekit/scripts/layershuffle.py +++ b/mergekit/scripts/layershuffle.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import random from typing import List @@ -19,7 +7,7 @@ import click import yaml -from mergekit.architecture import get_architecture_info +from mergekit.architecture import ArchitectureInfoUtils from mergekit.common import ModelReference from mergekit.config import ( InputSliceDefinition, @@ -76,7 +64,7 @@ def main( models = [ModelReference.parse(m) for m in model] m0_cfg = models[0].config() - arch_info = get_architecture_info(m0_cfg) + arch_info = ArchitectureInfoUtils.get_architecture_info(m0_cfg) total_num_layers = arch_info.num_layers(m0_cfg) out_slices: List[OutputSliceDefinition] = [] diff --git a/mergekit/scripts/legacy.py b/mergekit/scripts/legacy.py index 7a772d61..7ca9ca7d 100644 --- a/mergekit/scripts/legacy.py +++ b/mergekit/scripts/legacy.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from typing import List, Optional diff --git a/mergekit/scripts/merge_raw_pytorch.py b/mergekit/scripts/merge_raw_pytorch.py new file mode 100644 index 00000000..a7475582 --- /dev/null +++ b/mergekit/scripts/merge_raw_pytorch.py @@ -0,0 +1,254 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +import logging +from typing import Dict, List, Optional + +import click +import torch +import tqdm +import yaml +from pydantic import BaseModel + +import mergekit.merge_methods as merge_methods +from mergekit.architecture import WeightInfo +from mergekit.common import ImmutableMap, ModelReference, dtype_from_name +from mergekit.config import ParameterSetting, evaluate_setting +from mergekit.graph import Executor, Task +from mergekit.io import LazyTensorLoader, ShardedTensorIndex +from mergekit.io.tasks import FinalizeModel, SaveTensor, TensorWriterTask +from mergekit.merge_methods.base import MergeMethod, TensorDictWrapper +from mergekit.options import MergeOptions, add_merge_options + + +class InputModelDefinition(BaseModel, frozen=True): + model: str + parameters: Optional[Dict[str, ParameterSetting]] = None + + +class RawPyTorchMergeConfig(BaseModel, frozen=True): + merge_method: str + parameters: Optional[Dict[str, ParameterSetting]] + models: List[InputModelDefinition] + dtype: Optional[str] = None + base_model: Optional[str] = None + + +class SimpleLoaderCache: + loaders: Dict[str, LazyTensorLoader] + lazy_unpickle: bool = False + _instance: Optional["SimpleLoaderCache"] = None + + def __new__(cls) -> "SimpleLoaderCache": + if cls._instance is None: + cls._instance = super(SimpleLoaderCache, cls).__new__(cls) + cls._instance.loaders = {} + return cls._instance + + def get(self, model: str) -> LazyTensorLoader: + if model not in self.loaders: + self.loaders[model] = LazyTensorLoader( + ShardedTensorIndex.from_file(model), lazy_unpickle=self.lazy_unpickle + ) + return self.loaders[model] + + +class SimpleLoadTensor(Task[torch.Tensor]): + model: str + 
tensor_name: str + dtype: Optional[str] = None + device: Optional[str] = None + + def arguments(self) -> Dict[str, Task]: + return {} + + def execute(self) -> torch.Tensor: + loader = SimpleLoaderCache().get(self.model) + tensor = loader.get_tensor(self.tensor_name, device=self.device or "cpu") + if tensor is None: + return None + if dt := dtype_from_name(self.dtype): + tensor = tensor.to(dtype=dt) + return tensor + + +def plan_flat_merge( + config: RawPyTorchMergeConfig, + out_path: str, + tensor_union: bool, + tensor_intersection: bool, + options: MergeOptions, +) -> List[Task[torch.Tensor]]: + merge_method = merge_methods.get(config.merge_method) + + loaders = SimpleLoaderCache() + loaders.lazy_unpickle = options.lazy_unpickle + all_tensor_names = set() + for model_def in tqdm.tqdm(config.models, desc="Preparing model loaders"): + loader = loaders.get(model_def.model) + all_tensor_names.update(loader.index.tensor_paths.keys()) + + writer_task = TensorWriterTask( + out_path=out_path, + max_shard_size=options.out_shard_size, + safe_serialization=options.safe_serialization, + ) + + save_tasks = [] + for tensor_name in tqdm.tqdm(list(all_tensor_names), desc="Planning operations"): + inputs = { + model_def.model: SimpleLoadTensor( + model=model_def.model, tensor_name=tensor_name, dtype=config.dtype + ) + for model_def in config.models + } + if config.base_model is not None and config.base_model not in inputs: + inputs[config.base_model] = SimpleLoadTensor( + model=config.base_model, tensor_name=tensor_name, dtype=config.dtype + ) + + has_tensor = [ + lt.model + for lt in inputs.values() + if lt.tensor_name in loaders.get(lt.model).index.tensor_paths + ] + if len(has_tensor) < len(inputs): + if tensor_intersection: + continue + elif tensor_union: + pass + else: + missing = set(inputs) - set(has_tensor) + logging.warning(f"Tensor {tensor_name} not found in models:") + for model in missing: + logging.warning(f" {model}") + logging.warning("Was found in:") + for model in has_tensor: + logging.warning(f" {model}") + raise RuntimeError("Missing tensors") + + inputs = { + ModelReference.model_validate({"model": {"path": k}}): v + for k, v in inputs.items() + } + + global_params, tensor_params = construct_param_dicts( + config, merge_method, tensor_name + ) + + tensor_task = merge_method.make_task( + output_weight=WeightInfo(name=tensor_name), + tensors=TensorDictWrapper(tensors=inputs), + parameters=ImmutableMap(global_params), + tensor_parameters=ImmutableMap( + data={ + key: ImmutableMap(data=tensor_params[key]) for key in tensor_params + } + ), + base_model=( + ModelReference.model_validate({"model": {"path": config.base_model}}) + if config.base_model is not None + else None + ), + ) + save_task = SaveTensor( + tensor_name=tensor_name, + tensor_task=tensor_task, + writer_task=writer_task, + clone=options.clone_tensors, + dtype=config.dtype, + ) + save_tasks.append(save_task) + + finalize = FinalizeModel(tensor_save_tasks=save_tasks, writer_task=writer_task) + return save_tasks + [finalize] + + +def construct_param_dicts( + config: RawPyTorchMergeConfig, merge_method: MergeMethod, tensor_name: str +): + global_params = {} + for param_def in merge_method.parameters(): + if param_def.name in config.parameters: + value = evaluate_setting(tensor_name, config.parameters[param_def.name]) + if value is not None: + global_params[param_def.name] = value + + if param_def.name not in global_params: + if param_def.required: + raise RuntimeError( + f"Missing required parameter {param_def.name} for merge 
method {merge_method}" + ) + else: + global_params[param_def.name] = param_def.default_value + + tensor_params = {} + for param_def in merge_method.tensor_parameters(): + for model_def in config.models: + mr = ModelReference.model_validate({"model": {"path": model_def.model}}) + tensor_params[mr] = tensor_params.get(mr, {}) + if value := evaluate_setting( + tensor_name, model_def.parameters.get(param_def.name, []) + ): + tensor_params[mr][param_def.name] = value + elif value := evaluate_setting( + tensor_name, config.parameters.get(param_def.name, []) + ): + tensor_params[mr][param_def.name] = value + elif param_def.required: + raise RuntimeError( + f"Missing required parameter {param_def.name} for model {mr} tensor {tensor_name}" + ) + else: + tensor_params[mr][param_def.name] = param_def.default_value + return global_params, tensor_params + + +@click.command("mergekit-pytorch") +@click.argument("config_path", type=click.Path(exists=True)) +@click.argument("out_path", type=click.Path()) +@click.option( + "--tensor-intersection", + "-i", + type=bool, + default=False, + is_flag=True, + help="Only merge tensors that are present in all input models", +) +@click.option( + "--tensor-union", + "-u", + type=bool, + default=False, + is_flag=True, + help="Merge all tensors present in any input model", +) +@add_merge_options +def main( + config_path: str, + out_path: str, + tensor_union: bool, + tensor_intersection: bool, + merge_options: MergeOptions, +): + """Merge arbitrary PyTorch models. + + Uses similar configuration syntax to `mergekit-yaml`, minus the + `slices` sections. Each input model should be the path on disk to a + pytorch pickle file or safetensors file.""" + with open(config_path, "r", encoding="utf-8") as file: + config_source = file.read() + + config = RawPyTorchMergeConfig.model_validate(yaml.safe_load(config_source)) + tasks = plan_flat_merge( + config, out_path, tensor_union, tensor_intersection, merge_options + ) + + executor = Executor( + tasks, + math_device="cuda" if merge_options.cuda else "cpu", + storage_device=( + "cuda" if (merge_options.cuda and merge_options.low_cpu_memory) else "cpu" + ), + ) + executor.execute() diff --git a/mergekit/scripts/moe.py b/mergekit/scripts/moe.py index 7b10e268..0a532bfa 100644 --- a/mergekit/scripts/moe.py +++ b/mergekit/scripts/moe.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
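For context, a sketch of a configuration the new mergekit-pytorch entry point would accept, validated the same way main above does it. The field names come from RawPyTorchMergeConfig; the file paths, dtype, and weight values are hypothetical.

import yaml

from mergekit.scripts.merge_raw_pytorch import RawPyTorchMergeConfig

# Hypothetical config: linearly average two raw checkpoints.
# `parameters` is set to an empty mapping because the config model declares the
# field without a default and construct_param_dicts indexes into it.
config_source = """
merge_method: linear
dtype: float32
parameters: {}
models:
  - model: /path/to/model_a.safetensors
    parameters:
      weight: 0.5
  - model: /path/to/model_b.safetensors
    parameters:
      weight: 0.5
"""
config = RawPyTorchMergeConfig.model_validate(yaml.safe_load(config_source))

Written to a file, the same YAML could then be run as `mergekit-pytorch config.yaml ./merged-model`, matching the config_path and out_path arguments of the command above.
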
+# SPDX-License-Identifier: BUSL-1.1 import logging import os diff --git a/mergekit/scripts/run_yaml.py b/mergekit/scripts/run_yaml.py index 99889bac..16354d59 100644 --- a/mergekit/scripts/run_yaml.py +++ b/mergekit/scripts/run_yaml.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import logging diff --git a/mergekit/scripts/tokensurgeon.py b/mergekit/scripts/tokensurgeon.py index 514efd49..89d0ef59 100644 --- a/mergekit/scripts/tokensurgeon.py +++ b/mergekit/scripts/tokensurgeon.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import enum import logging @@ -25,9 +13,9 @@ from typing_extensions import TypeAlias from mergekit.architecture import ( + ArchitectureInfoUtils, ConfiguredArchitectureInfo, WeightInfo, - get_architecture_info, ) from mergekit.common import ModelReference from mergekit.io import TensorWriter @@ -281,7 +269,7 @@ def get_embedding_info( ) -> Tuple[WeightInfo, WeightInfo]: """Get WeightInfo for the input and output embeddings of a model.""" cfg = model.config(trust_remote_code=options.trust_remote_code) - arch_info = get_architecture_info(cfg) + arch_info = ArchitectureInfoUtils.get_architecture_info(cfg) embed, lm_head = None, None for weight_info in arch_info.pre_weights(cfg): @@ -596,8 +584,8 @@ def validate_architecture( """ model_cfg = model.config(trust_remote_code=options.trust_remote_code) donor_cfg = donor.config(trust_remote_code=options.trust_remote_code) - model_arch_info = get_architecture_info(model_cfg) - donor_arch_info = get_architecture_info(donor_cfg) + model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_cfg) + donor_arch_info = ArchitectureInfoUtils.get_architecture_info(donor_cfg) if donor_arch_info != model_arch_info: report_issue( f"Model architectures do not match: {model_arch_info.name()} vs {donor_arch_info.name()}", diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 39a1d6a2..f2b96242 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from enum import Enum diff --git a/mergekit/tokenizer/__init__.py b/mergekit/tokenizer/__init__.py index c47b6ad4..9413a3e7 100644 --- a/mergekit/tokenizer/__init__.py +++ b/mergekit/tokenizer/__init__.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo from mergekit.tokenizer.config import TokenizerConfig diff --git a/mergekit/tokenizer/build.py b/mergekit/tokenizer/build.py index d5ae95cf..ea13895c 100644 --- a/mergekit/tokenizer/build.py +++ b/mergekit/tokenizer/build.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 import json import logging @@ -28,6 +16,8 @@ from mergekit.common import ModelPath, ModelReference from mergekit.graph import Task +logger = logging.getLogger(__name__) + def get_vocab_size(model_path: ModelPath, trust_remote_code: bool) -> Optional[int]: try: @@ -38,7 +28,7 @@ def get_vocab_size(model_path: ModelPath, trust_remote_code: bool) -> Optional[i ) return cfg.vocab_size except Exception as e: - logging.warning(f"Unable to get vocab size for {model_path}", exc_info=e) + logger.warning(f"Unable to get vocab size for {model_path}", exc_info=e) return None @@ -128,7 +118,7 @@ def build_union_tokenizer( vocab = tokenizer.get_vocab() for tok, idx in vocab.items(): if idx >= vocab_size: - logging.warning( + logger.warning( f"Token {repr(tok)} present in {str(model)} tokenizer but >= vocab_size" ) continue @@ -146,7 +136,7 @@ def build_union_tokenizer( if tok in out_added_tokens: if (out_added_tokens[tok] != info) and tok not in warned_added_tokens: - logging.warning( + logger.warning( f"Token '{tok}' added with multiple different settings, using first" ) warned_added_tokens.add(tok) @@ -198,7 +188,7 @@ def build_tokenizer( ) # load all tokenizers - logging.info("Loading tokenizers") + logger.info("Loading tokenizers") tokenizers = {base_model: tokenizer_base} for model in referenced_models: if model == base_model: @@ -211,14 +201,14 @@ def build_tokenizer( trust_remote_code=trust_remote_code, ) except Exception as e: - logging.error(e) - logging.warning( + logger.error(e) + logger.warning( f"Unable to load tokenizer for {model}. Assuming same as {base_model}." 
) continue tokenizers[model] = model_tok - logging.info("Building output tokenizer") + logger.info("Building output tokenizer") # build final vocabulary if isinstance(tokenizer_source, ModelReference): tokenizer_out = transformers.AutoTokenizer.from_pretrained( @@ -241,7 +231,7 @@ def build_tokenizer( vocab_out = tokenizer_out.get_vocab() - logging.info("Building permutations") + logger.info("Building permutations") permutations = {} for model in ( pbar := tqdm.tqdm(referenced_models, desc="Building tokenizer permutations") @@ -264,7 +254,7 @@ def build_tokenizer( orig_idx = model_vocab[tok] if orig_idx >= vocab_size: - logging.warning( + logger.warning( f"{model} token {repr(tok)} has index {orig_idx}>{vocab_size-1} (padding?)" ) continue diff --git a/mergekit/tokenizer/config.py b/mergekit/tokenizer/config.py index a3889032..80f68490 100644 --- a/mergekit/tokenizer/config.py +++ b/mergekit/tokenizer/config.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. +# SPDX-License-Identifier: BUSL-1.1 from typing import Dict, Optional, Union diff --git a/mergekit/tokenizer/embed.py b/mergekit/tokenizer/embed.py index 194f3cf2..06b05653 100644 --- a/mergekit/tokenizer/embed.py +++ b/mergekit/tokenizer/embed.py @@ -1,17 +1,5 @@ # Copyright (C) 2025 Arcee AI -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
+# SPDX-License-Identifier: BUSL-1.1 import logging from typing import Dict, Optional diff --git a/pyproject.toml b/pyproject.toml index 79fabbfa..6b284b74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,8 @@ build-backend = "setuptools.build_meta" name = "mergekit" description = "Tools for merging pre-trained large language models" readme = "README.md" -license = { text = "LGPL-3.0-or-later" } -version = "0.0.6" +license = { text = "BUSL-1.1" } +version = "0.1.0" authors = [{ name = "Charles Goddard", email = "chargoddard@gmail.com" }] dependencies = [ "torch>=2.0.0", @@ -47,6 +47,7 @@ mergekit-moe = "mergekit.scripts.moe:main" mergekit-tokensurgeon = "mergekit.scripts.tokensurgeon:main" mergekit-extract-lora = "mergekit.scripts.extract_lora:main" mergekit-evolve = "mergekit.scripts.evolve:main" +mergekit-pytorch = "mergekit.scripts.merge_raw_pytorch:main" [tool.setuptools] packages = [ diff --git a/tests/common.py b/tests/common.py index 23f63b25..7ebbfc3b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,9 +2,16 @@ import tempfile from typing import Callable, Optional -from transformers import AutoConfig, LlamaConfig, LlamaForCausalLM +from transformers import ( + AutoConfig, + CLIPVisionConfig, + LlamaConfig, + LlamaForCausalLM, + LlavaConfig, + LlavaForConditionalGeneration, +) -from mergekit.architecture import get_architecture_info +from mergekit.architecture import ArchitectureInfoUtils from mergekit.config import MergeConfiguration from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex from mergekit.merge import MergeOptions, run_merge @@ -16,15 +23,17 @@ def run_and_check_merge( check_tensors: bool = True, validate: Optional[Callable[[str], None]] = None, index_json_name: Optional[str] = None, + auto_arch: bool = False, ): if index_json_name is None: index_json_name = "model.safetensors.index.json" with tempfile.TemporaryDirectory() as tmpdir: run_merge(config, out_path=tmpdir, options=MergeOptions()) - assert os.path.exists( - os.path.join(tmpdir, index_json_name) - ), "No index file for merge" + index_path = os.path.join(tmpdir, index_json_name) + index_exists = os.path.exists(index_path) + single_shard_exists = os.path.exists(index_path.replace(".index.json", "")) + assert index_exists or single_shard_exists, "No model produced by merge" assert os.path.exists( os.path.join(tmpdir, "config.json") ), "No config json produced by merge" @@ -40,11 +49,14 @@ def run_and_check_merge( assert not has_nan, "Output contains NaN" if check_tensors: - config = AutoConfig.from_pretrained(tmpdir) - arch_info = get_architecture_info(config) + model_config = AutoConfig.from_pretrained(tmpdir) + if auto_arch: + arch_info = ArchitectureInfoUtils.infer_architecture_info(config) + else: + arch_info = ArchitectureInfoUtils.get_architecture_info(model_config) index = ShardedTensorIndex.from_disk(tmpdir) - for weight_info in arch_info.all_weights(config): + for weight_info in arch_info.all_weights(model_config): if weight_info.optional: continue if weight_info.name not in index.tensor_paths and not any( @@ -67,3 +79,36 @@ def make_picollama(path: str, vocab_size: int = 64): model = LlamaForCausalLM(cfg) model.save_pretrained(path, safe_serialization=True) return str(path) + + +def make_picoLlaVa(path: str): + # Define minimal vision configuration + vision_config = CLIPVisionConfig( + image_size=32, + patch_size=4, + num_hidden_layers=2, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + ) + + # Define minimal text configuration + 
text_config = LlamaConfig( + vocab_size=64, + hidden_size=32, + intermediate_size=48, + num_attention_heads=16, + num_hidden_layers=2, + ) + + # Combine into Llava configuration + llava_config = LlavaConfig( + vision_config=vision_config, + text_config=text_config, + image_seq_length=16, + ) + + # Instantiate the model + model = LlavaForConditionalGeneration(config=llava_config) + model.save_pretrained(path, safe_serialization=True) + return str(path) diff --git a/tests/test_basic_merges.py b/tests/test_basic_merges.py index 797be01d..043e93bd 100644 --- a/tests/test_basic_merges.py +++ b/tests/test_basic_merges.py @@ -1,7 +1,7 @@ from typing import Dict, Optional import pytest -from common import make_picollama, run_and_check_merge +from common import make_picollama, make_picoLlaVa, run_and_check_merge from transformers import AutoConfig from mergekit.config import ( @@ -29,6 +29,21 @@ def model_c(tmp_path_factory): return make_picollama(tmp_path_factory.mktemp("model_c")) +@pytest.fixture(scope="session") +def vlm_a(tmp_path_factory): + return make_picoLlaVa(tmp_path_factory.mktemp("vlm_a")) + + +@pytest.fixture(scope="session") +def vlm_b(tmp_path_factory): + return make_picoLlaVa(tmp_path_factory.mktemp("vlm_b")) + + +@pytest.fixture(scope="session") +def vlm_c(tmp_path_factory): + return make_picoLlaVa(tmp_path_factory.mktemp("vlm_c")) + + class TestBasicMerges: def test_gpt2_copy(self): config = MergeConfiguration( @@ -195,6 +210,12 @@ def test_model_stock_filterwise_merge(self, model_a, model_b, model_c): ) run_and_check_merge(config) + def test_arcee_fusion_merge(self, model_a, model_b): + config = self.two_model_config( + model_a, model_b, merge_method="arcee_fusion", base_model=model_a + ) + run_and_check_merge(config) + def two_model_config( self, model_a, @@ -221,3 +242,95 @@ def two_model_config( ) return config + + def test_linear_VLM_merge(self, vlm_a, vlm_b): + config = self.two_model_config(vlm_a, vlm_b, merge_method="linear") + run_and_check_merge(config, auto_arch=True) + + def test_slerp_VLM_merge(self, vlm_a, vlm_b): + config = self.two_model_config( + vlm_a, vlm_b, merge_method="slerp", base_model=vlm_a + ) + config.parameters = {"t": 0.35} + run_and_check_merge(config, auto_arch=True) + + def test_nuslerp_VLM_merges(self, vlm_a, vlm_b, vlm_c): + for base_model in [None, vlm_c]: + for row_wise in [False, True]: + for flatten in [False, True]: + print( + f"Testing nuslerp with row_wise={row_wise}, flatten={flatten}, base_model={base_model}" + ) + run_and_check_merge( + self.two_model_config( + vlm_a, + vlm_b, + merge_method="nuslerp", + base_model=base_model, + params={ + "nuslerp_row_wise": row_wise, + "nuslerp_flatten": flatten, + }, + ), + auto_arch=True, + ) + + # test weights that sum to zero + config = self.two_model_config( + vlm_a, + vlm_b, + merge_method="nuslerp", + base_model=vlm_c, + params={"nuslerp_row_wise": False, "nuslerp_flatten": False}, + ) + config.models[0].parameters["weight"] = -0.5 + config.models[1].parameters["weight"] = 0.5 + run_and_check_merge(config, auto_arch=True) + + def test_task_arithmetic_VLM_merge(self, vlm_a, vlm_b, vlm_c): + config = self.two_model_config( + vlm_a, vlm_b, merge_method="task_arithmetic", base_model=vlm_c + ) + run_and_check_merge(config, auto_arch=True) + + def test_breadcrumbs_VLM_merge(self, vlm_a, vlm_b, vlm_c): + config = self.two_model_config( + vlm_a, vlm_b, merge_method="breadcrumbs", base_model=vlm_c + ) + run_and_check_merge(config, auto_arch=True) + + def test_ties_VLM_merge(self, vlm_a, vlm_b, vlm_c): + 
config = self.two_model_config( + vlm_a, + vlm_b, + merge_method="ties", + base_model=vlm_c, + params={"density": 0.3}, + ) + run_and_check_merge(config, auto_arch=True) + + def test_dare_ties_VLM_merge(self, vlm_a, vlm_b, vlm_c): + config = self.two_model_config( + vlm_a, + vlm_b, + merge_method="dare_ties", + base_model=vlm_c, + params={"density": 0.66}, + ) + run_and_check_merge(config, auto_arch=True) + + def test_model_stock_VLM_merge(self, vlm_a, vlm_b, vlm_c): + config = self.two_model_config( + vlm_b, vlm_c, merge_method="model_stock", base_model=vlm_a + ) + run_and_check_merge(config, auto_arch=True) + + def test_model_stock_filterwise_VLM_merge(self, vlm_a, vlm_b, vlm_c): + config = self.two_model_config( + vlm_b, + vlm_c, + merge_method="model_stock", + base_model=vlm_a, + params={"filter_wise": True}, + ) + run_and_check_merge(config, auto_arch=True) diff --git a/tests/test_io.py b/tests/test_io.py index c1e1d1aa..6c4c82b6 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -13,8 +13,7 @@ def test_safetensors(self): writer.save_tensor("steve", torch.randn(4)) writer.finalize() - assert os.path.exists(os.path.join(d, "model-00001-of-00001.safetensors")) - assert os.path.exists(os.path.join(d, "model.safetensors.index.json")) + assert os.path.exists(os.path.join(d, "model.safetensors")) def test_pickle(self): with tempfile.TemporaryDirectory() as d: @@ -22,8 +21,7 @@ def test_pickle(self): writer.save_tensor("timothan", torch.randn(4)) writer.finalize() - assert os.path.exists(os.path.join(d, "pytorch_model-00001-of-00001.bin")) - assert os.path.exists(os.path.join(d, "pytorch_model.bin.index.json")) + assert os.path.exists(os.path.join(d, "pytorch_model.bin")) def test_duplicate_tensor(self): with tempfile.TemporaryDirectory() as d: @@ -33,5 +31,4 @@ def test_duplicate_tensor(self): writer.save_tensor("jimbo", jim) writer.finalize() - assert os.path.exists(os.path.join(d, "model-00001-of-00001.safetensors")) - assert os.path.exists(os.path.join(d, "model.safetensors.index.json")) + assert os.path.exists(os.path.join(d, "model.safetensors"))
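Following the updated expectations in tests/test_io.py, here is a minimal sketch, using only calls that appear elsewhere in this diff, of the single-shard output layout the tests now check for: a model small enough to fit in one shard is written as a bare model.safetensors, and the old sharded-name and index-file assertions are gone.

import os
import tempfile

import torch

from mergekit.io import TensorWriter

# Mirrors the updated test_safetensors expectation: one small tensor ends up in
# "model.safetensors"; the removed assertions suggest no index file is written
# for single-shard output.
with tempfile.TemporaryDirectory() as d:
    writer = TensorWriter(out_path=d, safe_serialization=True)
    writer.save_tensor("steve", torch.randn(4))
    writer.finalize()
    assert os.path.exists(os.path.join(d, "model.safetensors"))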