From 79cb1c92eb260174de874967baa42324a4fe6eff Mon Sep 17 00:00:00 2001 From: Per-Arne Andersen Date: Mon, 21 Oct 2024 15:35:57 +0200 Subject: [PATCH] Rewritten Composite. Now should run more smoothly. also updated callbacks and tested that example runs --- examples/composite/TMCompositeCIFAR10Demo.py | 4 +- tmu/composite/callbacks/base.py | 34 +- tmu/composite/components/base.py | 33 +- .../components/color_thermometer_scoring.py | 47 ++- tmu/composite/composite.py | 390 +++++++++++++----- 5 files changed, 358 insertions(+), 150 deletions(-) diff --git a/examples/composite/TMCompositeCIFAR10Demo.py b/examples/composite/TMCompositeCIFAR10Demo.py index 50e6ac99..fbdc8893 100644 --- a/examples/composite/TMCompositeCIFAR10Demo.py +++ b/examples/composite/TMCompositeCIFAR10Demo.py @@ -50,10 +50,10 @@ def main(args): class TMCompositeCheckpointCallback(TMCompositeCallback): - def on_epoch_component_begin(self, component, epoch, logs=None): + def on_epoch_component_begin(self, component, epoch, logs=None, **kwargs): pass - def on_epoch_component_end(self, component, epoch, logs=None): + def on_epoch_component_end(self, component, epoch, logs=None, **kwargs): component.save(component_path / f"{component}-{epoch}.pkl") class TMCompositeEvaluationCallback(TMCompositeCallback): diff --git a/tmu/composite/callbacks/base.py b/tmu/composite/callbacks/base.py index 33dd107e..ad541d23 100644 --- a/tmu/composite/callbacks/base.py +++ b/tmu/composite/callbacks/base.py @@ -1,37 +1,49 @@ +from dataclasses import dataclass +from enum import auto, Enum from multiprocessing import Queue +from typing import Any, Dict -class TMCompositeCallback: +class CallbackMethod(Enum): + ON_TRAIN_COMPOSITE_BEGIN = auto() + ON_TRAIN_COMPOSITE_END = auto() + ON_EPOCH_COMPONENT_BEGIN = auto() + ON_EPOCH_COMPONENT_END = auto() + UPDATE_PROGRESS = auto() + +@dataclass +class CallbackMessage: + method: CallbackMethod + kwargs: Dict[str, Any] +class TMCompositeCallback: def __init__(self): pass - def on_epoch_component_begin(self, component, epoch, logs=None): + def on_epoch_component_begin(self, component, epoch, logs=None, **kwargs): pass - def on_epoch_component_end(self, component, epoch, logs=None): + def on_epoch_component_end(self, component, epoch, logs=None, **kwargs): pass - def on_train_composite_end(self, composite, logs=None): + def on_train_composite_end(self, composite, logs=None, **kwargs): pass - def on_train_composite_begin(self, composite, logs=None): + def on_train_composite_begin(self, composite, logs=None, **kwargs): pass - class TMCompositeCallbackProxy: - def __init__(self, queue: Queue): self.queue = queue def on_epoch_component_begin(self, component, epoch, logs=None): - self.queue.put(('on_epoch_component_begin', component, epoch, logs)) + self.queue.put(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_BEGIN, {'component': component, 'epoch': epoch, 'logs': logs})) def on_epoch_component_end(self, component, epoch, logs=None): - self.queue.put(('on_epoch_component_end', component, epoch, logs)) + self.queue.put(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_END, {'component': component, 'epoch': epoch, 'logs': logs})) def on_train_composite_end(self, composite, logs=None): - self.queue.put(('on_train_composite_end', composite, logs)) + self.queue.put(CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_END, {'composite': composite, 'logs': logs})) def on_train_composite_begin(self, composite, logs=None): - self.queue.put(('on_train_composite_begin', composite, logs)) + self.queue.put(CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN, {'composite': composite, 'logs': logs})) \ No newline at end of file diff --git a/tmu/composite/components/base.py b/tmu/composite/components/base.py index 56a95c88..2ebff22e 100644 --- a/tmu/composite/components/base.py +++ b/tmu/composite/components/base.py @@ -1,4 +1,6 @@ import abc +import uuid + import numpy as np from pathlib import Path from typing import Union, Tuple @@ -13,6 +15,7 @@ def __init__(self, model_cls, model_config, epochs=1, **kwargs) -> None: self.model_cls = model_cls self.model_config = model_config self.epochs = epochs + self.uuid = uuid.uuid4() # Warn about unused kwargs if kwargs: @@ -36,19 +39,23 @@ def preprocess(self, data: dict) -> dict: return data def fit(self, data: dict) -> None: - x_train, y_train = data["X"], data["Y"] - - # Check if type is uint32 - if x_train.dtype != np.uint32: - x_train = x_train.astype(np.uint32) - - if y_train.dtype != np.uint32: - y_train = y_train.astype(np.uint32) - - self.model_instance.fit( - x_train, - y_train, - ) + try: + x_train, y_train = data["X"], data["Y"] + + # Check if type is uint32 + if x_train.dtype != np.uint32: + x_train = x_train.astype(np.uint32) + + if y_train.dtype != np.uint32: + y_train = y_train.astype(np.uint32) + + self.model_instance.fit( + x_train, + y_train, + ) + except Exception as e: + print(f"Error: {e}") + raise e def predict(self, data: dict) -> Tuple[np.array, np.array]: X_test = data["X"] diff --git a/tmu/composite/components/color_thermometer_scoring.py b/tmu/composite/components/color_thermometer_scoring.py index a3f50df6..6410160d 100644 --- a/tmu/composite/components/color_thermometer_scoring.py +++ b/tmu/composite/components/color_thermometer_scoring.py @@ -1,26 +1,49 @@ import numpy as np +from typing import Dict, Any from tmu.composite.components.base import TMComponent class ColorThermometerComponent(TMComponent): - def __init__(self, model_cls, model_config, resolution=8, **kwargs) -> None: super().__init__(model_cls=model_cls, model_config=model_config, **kwargs) + if resolution < 2 or resolution > 255: + raise ValueError("Resolution must be between 2 and 255") self.resolution = resolution + self._thresholds = None + + def _create_thresholds(self) -> None: + self._thresholds = np.linspace(0, 255, self.resolution + 1)[1:-1] - def preprocess(self, data: dict): + def preprocess(self, data: dict) -> Dict[str, Any]: super().preprocess(data=data) + X_org = data.get("X") + Y = data.get("Y") + + if X_org is None: + raise ValueError("Input data 'X' is missing") + + if X_org.ndim != 4: + raise ValueError(f"Expected 4D input, got {X_org.ndim}D") + + if X_org.shape[-1] != 3: + raise ValueError(f"Expected 3 color channels, got {X_org.shape[-1]}") + + if self._thresholds is None: + self._create_thresholds() - X_org = data["X"] - Y = data["Y"] + # Use broadcasting for efficient computation + X = (X_org[:, :, :, :, np.newaxis] >= self._thresholds).astype(np.uint8) - X = np.empty((X_org.shape[0], X_org.shape[1], X_org.shape[2], X_org.shape[3], self.resolution), dtype=np.uint8) - for z in range(self.resolution): - X[:, :, :, :, z] = X_org[:, :, :, :] >= (z + 1) * 255 / (self.resolution + 1) + # Reshape correctly + batch_size, height, width, channels, _ = X.shape + X = X.transpose(0, 1, 2, 4, 3).reshape(batch_size, height, width, channels * (self.resolution - 1)) - X = X.reshape((X_org.shape[0], X_org.shape[1], X_org.shape[2], 3 * self.resolution)) + return { + "X": X, + "Y": Y + } - return dict( - X=X, - Y=Y, - ) + def get_output_shape(self, input_shape: tuple) -> tuple: + if len(input_shape) != 4: + raise ValueError(f"Expected 4D input shape, got {len(input_shape)}D") + return (*input_shape[:-1], input_shape[-1] * (self.resolution - 1)) \ No newline at end of file diff --git a/tmu/composite/composite.py b/tmu/composite/composite.py index c942e1d1..dc461844 100644 --- a/tmu/composite/composite.py +++ b/tmu/composite/composite.py @@ -1,17 +1,37 @@ -import threading -from collections import defaultdict -from os import cpu_count -from typing import Optional, Type, Union, List -from pathlib import Path -from multiprocessing import Pool, Manager +import concurrent.futures +import multiprocessing +import traceback +import uuid +from functools import partial +from multiprocessing import Manager, Queue, cpu_count, Pool import numpy as np -from tqdm import tqdm - +from typing import Optional, List, Dict, Any, Tuple, Union +from dataclasses import dataclass +from pathlib import Path +import threading from tmu.composite.callbacks.base import TMCompositeCallbackProxy, TMCompositeCallback from tmu.composite.components.base import TMComponent from tmu.composite.gating.base import BaseGate from tmu.composite.gating.linear_gate import LinearGate - +from tmu.composite.callbacks.base import CallbackMessage, CallbackMethod + + +@dataclass +class ComponentTask: + component: TMComponent + data: Any + epochs: int + progress: int = 0 + result: Any = None + @property + def component_id(self) -> uuid.UUID: + return self.component.uuid + +@dataclass +class FitResult: + component: TMComponent + success: bool + error: Optional[Exception] = None class TMCompositeBase: @@ -35,122 +55,270 @@ def _component_predict(self, component, data): return votes -class TMCompositeMP(TMCompositeBase): +class TMCompositeMP: + def __init__(self, composite: 'TMComposite', **kwargs) -> None: + self.composite = composite + self.max_workers = min(cpu_count(), len(composite.components)) + self.remove_data_after_preprocess = kwargs.get('remove_data_after_preprocess', False) + multiprocessing.set_start_method('spawn', force=True) + + def _process_callbacks(self, callbacks: List[TMCompositeCallback], message: CallbackMessage) -> None: + method_name = message.method.name.lower() + for callback in callbacks: + try: + getattr(callback, method_name)(**message.kwargs, composite=self.composite) + except Exception as e: + print(f"Error in callback {callback.__class__.__name__}.{method_name}: {e}") + traceback.print_exc() + + def _fit_component(self, task, callback_queue) -> FitResult: + try: + data_preprocessed = task.component.preprocess(task.data) + callbacks = [] + + # remove task.data + if self.remove_data_after_preprocess: + task.data = None + + for epoch in range(task.epochs): + + if callback_queue: + callbacks.append(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_BEGIN, {'component': task.component, 'epoch': epoch})) + + task.component.fit(data=data_preprocessed) + task.progress += 1 + + if callback_queue: + callbacks.append(CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_END, {'component': task.component, 'epoch': epoch})) + + if callback_queue: + callback_queue.put(callbacks) # Send all callbacks at once + callbacks.clear() + + return FitResult(component=task.component, success=True) + except Exception as e: + print(f"Error in _fit_component for {task.component.__class__.__name__}: {e}") + traceback.print_exc() + return FitResult(component=task.component, success=False, error=e) + + def fit(self, data: Dict[str, Any], callbacks: Optional[List[TMCompositeCallback]] = None) -> None: + with Manager() as manager: + callback_queue: Optional[Queue] = manager.Queue() if callbacks else None + error_queue: Queue = manager.Queue() - def __init__(self, composite) -> None: - super().__init__(composite=composite) - - def _listener(self, queue, callbacks): - while True: - item = queue.get() - if item == 'DONE': - break - method, *args = item - for callback in callbacks: - getattr(callback, method)(*args) - - @staticmethod - def _mp_fit(args: tuple) -> None: - idx, component, data_preprocessed, proxy_callback = args - - if proxy_callback: - proxy_callback.on_train_composite_begin(composite=component) - - epochs = component.epochs - pbar = tqdm(total=epochs, position=idx) - pbar.set_description(f"Component {idx}: {type(component).__name__}") - for epoch in range(epochs): - if proxy_callback: - proxy_callback.on_epoch_component_begin(component=component, epoch=epoch) - component.fit(data=data_preprocessed) - pbar.update(1) - if proxy_callback: - proxy_callback.on_epoch_component_end(component=component, epoch=epoch) - - if proxy_callback: - proxy_callback.on_train_composite_end(composite=component) - return component - - def fit(self, data: dict, callbacks: Optional[list[TMCompositeCallback]] = None) -> None: + self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN, {})) - with Manager() as manager: + tasks = [ComponentTask(component=component, data=data, epochs=component.epochs) + for component in self.composite.components] + callback_thread = None if callbacks: - callback_queue = manager.Queue() # Create a queue with the manager - callback_proxy = TMCompositeCallbackProxy(callback_queue) + callback_thread = self._start_callback_handler(callbacks, callback_queue, error_queue) - # Start listener thread - listener_thread = threading.Thread(target=self._listener, args=(callback_queue, callbacks)) - listener_thread.start() - else: - callback_proxy = None - with Pool() as pool: - data_preprocessed = [component.preprocess(data) for component in self.composite.components] - self.composite.components = pool.map(TMCompositeMP._mp_fit, - ((idx, component, data_preprocessed[idx], callback_proxy) for - idx, component in - enumerate(self.composite.components))) + results = self._execute_tasks(tasks, callback_queue) - if callbacks: - callback_queue.put('DONE') # Send done signal to listener - listener_thread.join() # Wait for listener to process all logs + self._process_results(results, error_queue) - def predict(self, data: dict, votes: dict, gating_mask: np.ndarray) -> np.array: - # Determine number of processes based on available CPU cores - n_processes = min(cpu_count(), len(self.composite.components)) - with Pool(n_processes) as pool: - results = pool.starmap(self._component_predict, [ - (component, data) for i, component in enumerate(self.composite.components) - ]) + self._cleanup(callback_queue, callback_thread) - # Aggregate results from each process - for i, result in enumerate(results): - for key, score in result.items(): + self._check_errors(error_queue) - # Apply gating mask - masked_score = score * gating_mask[:, i] + self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_END, {})) - if key not in votes: - votes[key] = masked_score - else: - votes[key] += masked_score + def _start_callback_handler(self, callbacks: List[TMCompositeCallback], callback_queue: Queue, error_queue: Queue): + def callback_handler() -> None: + while True: + try: + message = callback_queue.get() + if message == 'DONE': + break + if isinstance(message, list): # Handle batch of callbacks + for callback_message in message: + if isinstance(callback_message, CallbackMessage): + self._process_callbacks(callbacks, callback_message) + except Exception as e: + print(f"Error in callback handler: {e}") + traceback.print_exc() + error_queue.put(('callback_handler', e)) + callback_thread = threading.Thread(target=callback_handler) + callback_thread.start() + return callback_thread -class TMCompositeSingleCPU(TMCompositeBase): - def __init__(self, composite) -> None: - super().__init__(composite=composite) - def fit(self, data: dict, callbacks: Optional[list[TMCompositeCallback]] = None) -> None: + + + def _execute_tasks(self, tasks: List[ComponentTask], callback_queue: Optional[Queue]) -> List[FitResult]: + results = [] + + # Create a partial function with the callback_queue + fit_component_partial = partial(self._fit_component, callback_queue=callback_queue) + + with Pool(processes=self.max_workers) as pool: + # Map the tasks to the pool + async_results = [ + pool.apply_async(fit_component_partial, (task,)) for task in tasks + ] + + # Collect results as they complete + for async_result, task in zip(async_results, tasks): + try: + result = async_result.get() # This will wait for the task to complete + results.append(result) + except Exception as e: + print(f"Exception when processing results for {task.component.__class__.__name__}: {e}") + traceback.print_exc() + results.append( + FitResult(component=task.component, success=False, error=e) + ) + + return results + + def _process_results(self, results: List[FitResult], error_queue: Queue) -> None: + for result in results: + if result.success: + matching_component = next( + (c for c in self.composite.components if c.uuid == result.component.uuid), + None + ) + + if matching_component is not None: + idx = self.composite.components.index(matching_component) + self.composite.components[idx] = result.component + else: + error_message = f"Could not find a matching component for {result.component}" + print(error_message) + error_queue.put(('process_results', ValueError(error_message))) + else: + error_queue.put(('fit_component', result.error)) + + def _cleanup(self, callback_queue: Optional[Queue], callback_thread: Optional[threading.Thread]) -> None: + if callback_queue: + callback_queue.put('DONE') + if callback_thread: + callback_thread.join() + + + def _check_errors(self, error_queue: Queue) -> None: + if not error_queue.empty(): + print("Errors occurred during fitting:") + while not error_queue.empty(): + error_source, error = error_queue.get() + print(f"Error in {error_source}: {error}") + + def predict(self, data: Dict[str, Any], votes: Dict[str, np.ndarray], gating_mask: np.ndarray) -> None: + with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor: + future_to_component = { + executor.submit(self._predict_component, component, data, gating_mask, i): component + for i, component in enumerate(self.composite.components) + } + + for future in concurrent.futures.as_completed(future_to_component): + component = future_to_component[future] + try: + result = future.result() + for key, score in result.items(): + if key not in votes: + votes[key] = score + else: + votes[key] += score + except Exception as e: + print(f"Exception when processing results for {component.__class__.__name__}: {e}") + traceback.print_exc() + + def _predict_component( + self, + component: TMComponent, + data: Dict[str, Any], + gating_mask: np.ndarray, + component_idx: int + ) -> Dict[str, np.ndarray]: + try: + # Preprocess data and get scores + _, scores = component.predict(component.preprocess(data)) + scores = scores.reshape(scores.shape[0], -1) # Ensure 2D + + # Normalize scores + denominator = np.maximum(np.ptp(scores, axis=1), 1e-8) # Avoid division by zero + normalized_scores = scores / denominator[:, np.newaxis] + + # Apply gating mask + mask = gating_mask[:, component_idx].reshape(-1, 1) if gating_mask.ndim > 1 else gating_mask.reshape(-1, 1) + masked_scores = normalized_scores * mask + + # Create and return votes + return { + "composite": masked_scores, + str(component): masked_scores + } + except Exception as e: + print(f"Error in predict_component for {component}: {str(e)}") + print(f"Shapes - scores: {scores.shape}, gating_mask: {gating_mask.shape}, mask: {mask.shape}") + traceback.print_exc() + return {} + + +class TMCompositeSingleCPU: + def __init__(self, composite, **kwargs) -> None: + self.composite = composite + + def _component_predict(self, component, data): + data_preprocessed = component.preprocess(data) + _, scores = component.predict(data_preprocessed) + + votes = dict() + votes["composite"] = np.zeros_like(scores, dtype=np.float32) + votes[str(component)] = np.zeros_like(scores, dtype=np.float32) + + for i in range(scores.shape[0]): + denominator = np.max(scores[i]) - np.min(scores[i]) + score = 1.0 * scores[i] / denominator if denominator != 0 else 0 + votes["composite"][i] += score + votes[str(component)][i] += score + + return votes + + def _process_callbacks(self, callbacks: List[TMCompositeCallback], message: CallbackMessage) -> None: + method_name = message.method.name.lower() + for callback in callbacks: + try: + getattr(callback, method_name)(**message.kwargs) + except Exception as e: + print(f"Error in callback {callback.__class__.__name__}.{method_name}: {e}") + import traceback + traceback.print_exc() + + def fit(self, data: Dict[str, Any], callbacks: Optional[List[TMCompositeCallback]] = None) -> None: + if callbacks is None: + callbacks = [] + data_preprocessed = [component.preprocess(data) for component in self.composite.components] epochs_left = [component.epochs for component in self.composite.components] - pbars = [tqdm(total=component.epochs) for component in self.composite.components] - for idx, (pbar, component) in enumerate(zip(pbars, self.composite.components)): - pbar.set_description(f"Component {idx}: {type(component).__name__}") - [callback.on_train_composite_begin(composite=self) for callback in callbacks] + self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_BEGIN, {'composite': self.composite})) + epoch = 0 while any(epochs_left): for idx, component in enumerate(self.composite.components): if epochs_left[idx] > 0: - [callback.on_epoch_component_begin(component=component, epoch=epoch) for callback in callbacks] + self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_BEGIN, {'component': component, 'epoch': epoch})) + component.fit(data=data_preprocessed[idx]) - [callback.on_epoch_component_end(component=component, epoch=epoch) for callback in callbacks] - pbars[idx].update(1) + + self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_EPOCH_COMPONENT_END, {'component': component, 'epoch': epoch})) + epochs_left[idx] -= 1 epoch += 1 - [callback.on_train_composite_end(composite=self) for callback in callbacks] + self._process_callbacks(callbacks, CallbackMessage(CallbackMethod.ON_TRAIN_COMPOSITE_END, {'composite': self.composite})) - def predict(self, data: dict, votes: dict, gating_mask: np.ndarray): - pbar = tqdm(total=len(self.composite.components)) + def predict(self, data: Dict[str, Any], votes: Dict[str, np.ndarray], gating_mask: np.ndarray) -> None: for i, component in enumerate(self.composite.components): - pbar.set_description(f"Component {i}: {type(component).__name__}") component_votes = self._component_predict(component, data) for key, score in component_votes.items(): - # Apply gating mask masked_score = score * gating_mask[:, i] @@ -158,34 +326,31 @@ def predict(self, data: dict, votes: dict, gating_mask: np.ndarray): votes[key] = masked_score else: votes[key] += masked_score - pbar.update(1) - class TMComposite: - def __init__( self, - components: Optional[list[TMComponent]] = None, - gate_function: Optional[Type[BaseGate]] = None, - gate_function_params: Optional[dict] = None, - use_multiprocessing: bool = False + components: Optional[List[TMComponent]] = None, + gate_function: Optional[type[BaseGate]] = None, + gate_function_params: Optional[Dict[str, Any]] = None, + use_multiprocessing: bool = False, + **kwargs ) -> None: self.components: List[TMComponent] = components or [] self.use_multiprocessing = use_multiprocessing if gate_function_params is None: - gate_function_params = dict() + gate_function_params = {} - self.gate_function_instance = gate_function(self, **gate_function_params) if gate_function else LinearGate(self, - **gate_function_params) + self.gate_function_instance: BaseGate = gate_function(self, **gate_function_params) if gate_function else LinearGate(self, **gate_function_params) - self.logic = TMCompositeSingleCPU(composite=self) if not use_multiprocessing else TMCompositeMP(composite=self) + self.logic: Union[TMCompositeSingleCPU, TMCompositeMP] = TMCompositeMP(composite=self, **kwargs) if use_multiprocessing else TMCompositeSingleCPU(composite=self, **kwargs) - def fit(self, data: dict, callbacks: Optional[list[TMCompositeCallback]] = None) -> None: + def fit(self, data: Dict[str, Any], callbacks: Optional[List[TMCompositeCallback]] = None) -> None: self.logic.fit(data, callbacks) - def predict(self, data: dict) -> np.array: - votes = dict() + def predict(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]: + votes: Dict[str, np.ndarray] = {} # Gating Mechanism gating_mask: np.ndarray = self.gate_function_instance.predict(data) @@ -196,7 +361,7 @@ def predict(self, data: dict) -> np.array: return {k: v.argmax(axis=1) for k, v in votes.items()} - def save_model(self, path: Union[Path, str], format="pkl") -> None: + def save_model(self, path: Union[Path, str], format: str = "pkl") -> None: path = Path(path) if isinstance(path, str) else path if format == "pkl": @@ -206,7 +371,7 @@ def save_model(self, path: Union[Path, str], format="pkl") -> None: else: raise NotImplementedError(f"Format {format} not supported") - def load_model(self, path: Union[Path, str], format="pkl") -> None: + def load_model(self, path: Union[Path, str], format: str = "pkl") -> None: path = Path(path) if isinstance(path, str) else path if format == "pkl": @@ -217,7 +382,7 @@ def load_model(self, path: Union[Path, str], format="pkl") -> None: else: raise NotImplementedError(f"Format {format} not supported") - def load_model_from_components(self, path: Union[Path, str], format="pkl") -> None: + def load_model_from_components(self, path: Union[Path, str], format: str = "pkl") -> None: path = Path(path) if isinstance(path, str) else path if not path.is_dir(): @@ -227,7 +392,8 @@ def load_model_from_components(self, path: Union[Path, str], format="pkl") -> No files = [f for f in path.iterdir() if f.is_file() and f.suffix == f".{format}"] # Group files by component details - component_groups = defaultdict(list) + from collections import defaultdict + component_groups: Dict[str, List[Tuple[int, Path]]] = defaultdict(list) for file in files: parts = file.stem.split('-') epoch = int(parts[-1])