diff --git a/.github/workflows/test_mlagility.yml b/.github/workflows/test_mlagility.yml
index 3cce76f3..89ea6e4a 100644
--- a/.github/workflows/test_mlagility.yml
+++ b/.github/workflows/test_mlagility.yml
@@ -34,7 +34,7 @@ jobs:
           python -m pip install --upgrade pip
           conda install pylint
           if [ -f setup.py ]; then pip install -e .; fi
-          pip install transformers
+          pip install transformers timm
           python -m pip check
       - name: Lint with PyLint
         shell: bash -el {0}
@@ -56,6 +56,8 @@ jobs:
           rm -rf ~/.cache/mlagility
           benchit examples/cli/scripts/hello_world.py
           rm -rf ~/.cache/mlagility
+          benchit examples/cli/scripts/multiple_invocations.py
+          rm -rf ~/.cache/mlagility
           benchit examples/cli/scripts/max_depth.py --max-depth 1
           rm -rf ~/.cache/mlagility
           benchit examples/cli/scripts/two_models.py
diff --git a/examples/cli/discovery.md b/examples/cli/discovery.md
index 8efe2e56..f4a382f3 100644
--- a/examples/cli/discovery.md
+++ b/examples/cli/discovery.md
@@ -112,7 +112,7 @@ You can see that `hello_world.py`, `two_models.py`, and `max_depth.py` are all e
 
 > See the [Benchmark Multiple Scripts documentation](https://github.com/groq/mlagility/blob/main/docs/tools_user_guide.md#benchmark-multiple-scripts) for more details.
 
-### Maximum Analysis Depth
+## Maximum Analysis Depth
 
 PyTorch models (eg, `torch.nn.Module`) are often built out of a collection of smaller instances. For example, a PyTorch multilayer perceptron (MLP) model may be built out of many `torch.nn.Linear` modules.
diff --git a/examples/cli/readme.md b/examples/cli/readme.md
index 6b833ea3..7a5f0547 100644
--- a/examples/cli/readme.md
+++ b/examples/cli/readme.md
@@ -17,6 +17,7 @@ In this tutorial you will learn things such as:
 - [A "hello world" example, which is the easiest way to get started](#hello-world)
 - [Benchmarking on Nvidia GPUs](#nvidia-benchmarking)
 - [Working with scripts that invoke more than one model](#multiple-models-per-script)
+- [Working with scripts that invoke a model multiple times](#multiple-invocations-of-a-model)
 - [Benchmarking an ONNX file](#onnx-benchmarking)
 
 # Just Benchmark BERT
@@ -171,6 +172,46 @@ Woohoo! The 'benchmark' command is complete.
 
 You can see that both model instances in `two_models.py`, `pytorch_model` and `another_pytorch_model`, are both discovered and benchmarked.
 
+## Multiple Invocations of a Model
+
+A single script may invoke the same model multiple times using different input shapes (e.g. when varying the batch size). When this happens, MLAgility will benchmark and display each of those invocations as sub-results of the same model instance.
+
+> **Note**: multiple invocations of a model with the same input shape will only be benchmarked once.
+
+The `multiple_invocations.py` script instantiates a single model and invokes it three times. The first two invocations use inputs of the same shape (batch size 1), while the third invocation uses a different input shape (batch size 2). Note that two unique static model invocations are identified.
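+
+As a rough sketch of that call pattern (this snippet uses a plain `torch.nn.Linear` as a stand-in for the `SmallModel` defined in the script):
+
+```python
+import torch
+
+model = torch.nn.Linear(11, 5)  # stand-in model, for illustration only
+
+model(torch.rand(1, 11))  # invocation 1: batch size 1
+model(torch.rand(1, 11))  # invocation 2: same input shape, grouped with invocation 1
+model(torch.rand(2, 11))  # invocation 3: batch size 2, a second unique invocation
+```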
+
+Run the following command:
+
+```
+benchit scripts/multiple_invocations.py
+```
+
+To get a result like:
+```
+Models discovered during profiling:
+
+multiple_invocations.py:
+    pytorch_model
+        Model Type:     Pytorch (torch.nn.Module)
+        Class:          SmallModel ()
+        Location:       /net/home/dhnoronha/mlagility/examples/cli/scripts/multiple_invocations.py, line 40
+        Parameters:     60 (<0.1 MB)
+
+        With input shape 1 (executed 2x)
+        Input Shape:    'x': (1, 11)
+        Hash:           b4aa73ae
+        Status:         Successfully benchmarked on Intel(R) Xeon(R) CPU @ 2.20GHz (ort v1.14.1)
+                        Mean Latency:   0.013 milliseconds (ms)
+                        Throughput:     77909.6 inferences per second (IPS)
+
+        With input shape 2 (executed 1x)
+        Input Shape:    'x': (2, 11)
+        Hash:           cfaa2e2c
+        Status:         Successfully benchmarked on Intel(R) Xeon(R) CPU @ 2.20GHz (ort v1.14.1)
+                        Mean Latency:   0.015 milliseconds (ms)
+                        Throughput:     64938.1 inferences per second (IPS)
+```
+
 ## ONNX Benchmarking
 
 If you already happen to have an ONNX file, `benchit` can benchmark it for you. We can demonstrate this with the ONNX file in `examples/cli/onnx/sample.onnx`.
diff --git a/examples/cli/scripts/multiple_invocations.py b/examples/cli/scripts/multiple_invocations.py
new file mode 100644
index 00000000..6ddc7e89
--- /dev/null
+++ b/examples/cli/scripts/multiple_invocations.py
@@ -0,0 +1,47 @@
+# labels: name::multiple_invocations
+"""
+This example demonstrates what happens when your script contains
+a model that is invoked multiple times with different input shapes
+
+To try it, run:
+
+benchit multiple_invocations.py
+
+You should see the two unique invocations being identified.
+"""
+import torch
+
+torch.manual_seed(1)
+
+# Define model class
+class SmallModel(torch.nn.Module):
+    def __init__(self, input_features, output_size):
+        super(SmallModel, self).__init__()
+        self.fc = torch.nn.Linear(input_features, output_size)
+
+    def forward(self, x):
+        # x has shape (batch_size, input_features)
+        # Set the batch size dimension to -1 to allow for flexibility
+        x = x.view(-1, x.size(1))
+
+        output = self.fc(x)
+
+        # Reshape the output to restore the original batch size dimension
+        output = output.view(-1, output_size)
+        return output
+
+
+# Instantiate model and generate inputs
+input_features = 11
+output_size = 5
+pytorch_model = SmallModel(input_features, output_size)
+
+# Create 3 sets of inputs
+batch_size = 1
+inputs1 = {"x": torch.rand(batch_size, input_features)}
+inputs2 = {"x": torch.rand(batch_size, input_features)}
+inputs3 = {"x": torch.rand(batch_size + 1, input_features)}
+
+pytorch_model(**inputs1)
+pytorch_model(**inputs2)
+pytorch_model(**inputs3)
diff --git a/models/llm_layer/llama_layer_prototype.py b/models/llm_layer/llama_layer_prototype.py
index a6214792..c61c5219 100644
--- a/models/llm_layer/llama_layer_prototype.py
+++ b/models/llm_layer/llama_layer_prototype.py
@@ -12,14 +12,7 @@ def call_llama_layer(params="7B", use_cache=False):
-    # Use different torch seeds for KV caching vs. 
not, so that - # the models end up with different mlagility hashes - # Remove the if-statement when - # https://github.com/groq/mlagility/issues/316 is fixed - if use_cache: - torch.manual_seed(0) - else: - torch.manual_seed(1) + torch.manual_seed(0) # Parsing command-line arguments batch_size, max_seq_length = parse(["batch_size", "max_seq_length"]) diff --git a/src/mlagility/analysis/analysis.py b/src/mlagility/analysis/analysis.py index f3098472..4f273e0e 100644 --- a/src/mlagility/analysis/analysis.py +++ b/src/mlagility/analysis/analysis.py @@ -8,7 +8,8 @@ import functools import dataclasses import traceback -from typing import Union, List, Dict +import hashlib +from typing import Union, List, Dict, Tuple from types import FrameType, TracebackType from enum import Enum import torch @@ -63,28 +64,31 @@ def torch_activations(self) -> List[str]: return act -def _store_traceback(model_info: util.ModelInfo): +def _store_traceback(invocation_info: util.UniqueInvocationInfo): """ - Store the traceback from an exception into model_info so that + Store the traceback from an exception into invocation_info so that we can print it during the status update. """ exc_type, exc_value, exc_traceback = sys.exc_info() - model_info.traceback = traceback.format_exception( + invocation_info.traceback = traceback.format_exception( exc_type, exc_value, exc_traceback ) -def call_benchit( - model_inputs: dict, model_info: util.ModelInfo, tracer_args: TracerArgs +def explore_invocation( + model_inputs: dict, + model_info: util.ModelInfo, + invocation_info: util.UniqueInvocationInfo, + tracer_args: TracerArgs, ) -> None: """ Calls the benchit function from within the model forward function """ # Update status to "computing" - model_info.status_message = "Computing..." - model_info.status_message_color = printing.Colors.OKBLUE + invocation_info.status_message = "Computing..." + invocation_info.status_message_color = printing.Colors.OKBLUE status.update(tracer_args.models_found) # Get a copy of the keyword arguments @@ -111,10 +115,10 @@ def call_benchit( inputs[all_args[i]] = torch.tensor(args[i].detach().numpy()) else: inputs[all_args[i]] = args[i] - model_info.inputs = inputs + invocation_info.inputs = inputs build_name = filesystem.get_build_name( - tracer_args.script_name, tracer_args.labels, model_info.hash + tracer_args.script_name, tracer_args.labels, invocation_info.hash ) # Save model labels @@ -124,12 +128,12 @@ def call_benchit( perf = None try: if model_info.model_type == build.ModelType.PYTORCH_COMPILED: - model_info.status_message = ( + invocation_info.status_message = ( "Skipping model compiled using torch.compile(). " "benchit requires models to be in eager mode " "(regardless of what runtime you have selected)." ) - model_info.status_message_color = printing.Colors.WARNING + invocation_info.status_message_color = printing.Colors.WARNING else: perf = benchmark_model( model_info.model, @@ -150,36 +154,36 @@ def call_benchit( onnx_opset=tracer_args.onnx_opset, ) if Action.BENCHMARK in tracer_args.actions: - model_info.status_message = "Model successfully benchmarked!" - model_info.performance = perf - model_info.status_message_color = printing.Colors.OKGREEN + invocation_info.status_message = "Model successfully benchmarked!" + invocation_info.performance = perf + invocation_info.status_message_color = printing.Colors.OKGREEN else: - model_info.status_message = "Model successfully built!" 
- model_info.status_message_color = printing.Colors.OKGREEN + invocation_info.status_message = "Model successfully built!" + invocation_info.status_message_color = printing.Colors.OKGREEN except exp.StageError: build_state = build.load_state( cache_dir=tracer_args.cache_dir, build_name=build_name ) - model_info.status_message = "Build Error: see log files for details." - model_info.status_message_color = printing.Colors.WARNING + invocation_info.status_message = "Build Error: see log files for details." + invocation_info.status_message_color = printing.Colors.WARNING - _store_traceback(model_info) + _store_traceback(invocation_info) except exp.Error: - model_info.status_message = "GroqFlowError: see log files for details." - model_info.status_message_color = printing.Colors.WARNING + invocation_info.status_message = "GroqFlowError: see log files for details." + invocation_info.status_message_color = printing.Colors.WARNING - _store_traceback(model_info) + _store_traceback(invocation_info) # This broad exception is ok since enumerating all exceptions is # not possible, as the tested software continuously evolves. except Exception as e: # pylint: disable=broad-except util.stop_stdout_forward() - model_info.status_message = f"Unknown benchit error: {e}" - model_info.status_message_color = printing.Colors.WARNING + invocation_info.status_message = f"Unknown benchit error: {e}" + invocation_info.status_message_color = printing.Colors.WARNING - _store_traceback(model_info) + _store_traceback(invocation_info) finally: # Ensure that stdout is not being forwarded before updating status if hasattr(sys.stdout, "terminal"): @@ -247,7 +251,28 @@ def call_benchit( def get_model_hash( model: Union[torch.nn.Module, "tf.keras.Model"], model_type: build.ModelType ): - return build.hash_model(model, model_type, hash_params=True)[:8] + return build.hash_model(model, model_type, hash_params=False)[:8] + + +def get_invocation_hash( + model_hash: str, parent_invocation_hash: str, args: Tuple, kwargs: Dict +) -> str: + """ + Combines the model hash and the input shapes to create the invocation hash + We also ensure that invocations that come from different parents have different hashes + """ + + # Merge positional and keyword args + args = {"Positional Arg {}".format(i + 1): arg for i, arg in enumerate(args)} + kwargs = {**kwargs, **args} + + # Get input shapes and types + input_shapes, input_dtypes = build.get_shapes_and_dtypes(kwargs) + + hashable_content = ( + f"{model_hash}{parent_invocation_hash}{input_shapes}{input_dtypes}" + ) + return hashlib.sha256(hashable_content.encode()).hexdigest()[:8], input_shapes def store_model_info( @@ -292,7 +317,6 @@ def store_model_info( depth=depth, hash=model_hash, parent_hash=parent_hash, - is_target=model_hash in tracer_args.targets or tracer_args.targets == [], build_model=build_model, model_type=model_type, script_name=tracer_args.script_name, @@ -421,11 +445,6 @@ def forward_spy(*args, **kwargs): # do so by setting the max_depth flag. 
return old_forward(*args, **kwargs) - # Keep track of execution time - start_time = time.time() - outputs = old_forward(*args, **kwargs) - end_time = time.time() - # We can only keep track of keras models once they have been executed if model_type == build.ModelType.KERAS: store_model_info( @@ -438,22 +457,54 @@ def forward_spy(*args, **kwargs): depth, parent_hash, ) + + # Get parent invocation hash + parent_invocation_hash = None + if parent_hash: + parent_invocation_hash = tracer_args.models_found[ + parent_hash + ].last_unique_invocation_executed + model_hash = get_model_hash(local_var, model_type) + invocation_hash, input_shapes = get_invocation_hash( + model_hash, parent_invocation_hash, args, kwargs + ) model_info = tracer_args.models_found[model_hash] - model_info.exec_time = model_info.exec_time + end_time - start_time - model_info.executed = model_info.executed + 1 + if invocation_hash not in model_info.unique_invocations: + model_info.unique_invocations[ + invocation_hash + ] = util.UniqueInvocationInfo( + hash=invocation_hash, + is_target=invocation_hash in tracer_args.targets + or len(tracer_args.targets) == 0, + input_shapes=input_shapes, + parent_hash=parent_invocation_hash, + ) + model_info.last_unique_invocation_executed = invocation_hash + + # Keep track of execution time + start_time = time.time() + outputs = old_forward(*args, **kwargs) + end_time = time.time() + + invocation_info = model_info.unique_invocations[invocation_hash] + invocation_info.exec_time = ( + invocation_info.exec_time + end_time - start_time + ) + invocation_info.executed = invocation_info.executed + 1 # Call groqit if this is the first time the model is being executed # and this model has been selected by the user if ( - model_info.executed == 1 - and model_info.is_target + invocation_info.executed == 1 + and invocation_info.is_target and (model_info.build_model) ): - call_benchit( + explore_invocation( model_inputs=[args, kwargs], model_info=model_info, + invocation_info=invocation_info, tracer_args=tracer_args, ) # Ensure that groqit() doesn't interfere with our execution count diff --git a/src/mlagility/analysis/status.py b/src/mlagility/analysis/status.py index dac41c93..ec88ca56 100644 --- a/src/mlagility/analysis/status.py +++ b/src/mlagility/analysis/status.py @@ -16,96 +16,151 @@ def update(models_found: Dict[str, ModelInfo]) -> None: "\nModels discovered during profiling:\n", c=printing.Colors.BOLD, ) - recursive_print(models_found, None, []) + recursive_print(models_found, None, None, []) def recursive_print( models_found: Dict[str, ModelInfo], - parent_hash: Union[str, None] = None, + parent_model_hash: Union[str, None] = None, + parent_invocation_hash: Union[str, None] = None, script_names_visited: List[str] = False, ) -> None: script_names_visited = [] - for h in models_found.keys(): - if parent_hash == models_found[h].parent_hash and models_found[h].executed > 0: - print_file_name = models_found[h].script_name not in script_names_visited - - print_model(models_found[h], h, print_file_name) - - if print_file_name: - script_names_visited.append(models_found[h].script_name) - - recursive_print( - models_found, parent_hash=h, script_names_visited=script_names_visited - ) - - -def print_model( - model_info: ModelInfo, model_hash: Union[str, None], print_file_name: bool = False + for model_hash in models_found.keys(): + model_visited = False + model_info = models_found[model_hash] + invocation_idx = 0 + for invocation_hash in model_info.unique_invocations.keys(): + unique_invocation = 
model_info.unique_invocations[invocation_hash] + + if ( + parent_model_hash == model_info.parent_hash + and unique_invocation.executed > 0 + and ( + model_info.unique_invocations[invocation_hash].parent_hash + == parent_invocation_hash + ) + ): + print_file_name = False + if model_info.script_name not in script_names_visited: + script_names_visited.append(model_info.script_name) + if model_info.depth == 0: + print_file_name = True + + print_invocation( + model_info, + invocation_hash, + print_file_name, + invocation_idx=invocation_idx, + model_visited=model_visited, + ) + model_visited = True + invocation_idx += 1 + + if print_file_name: + script_names_visited.append(model_info.script_name) + + recursive_print( + models_found, + parent_model_hash=model_hash, + parent_invocation_hash=invocation_hash, + script_names_visited=script_names_visited, + ) + + +def print_invocation( + model_info: ModelInfo, + invocation_hash: Union[str, None], + print_file_name: bool = False, + invocation_idx: int = 0, + model_visited: bool = False, ) -> None: """ Print information about a given model or submodel """ + unique_invocation = model_info.unique_invocations[invocation_hash] ident = "\t" * (2 * model_info.depth + 1) if print_file_name: print(f"{model_info.script_name}.py:") - printing.log(f"{ident}{model_info.name} ") - # Show the number of times the model has been executed - # Only show the execution time if we are not running benchit() as this - # impacts time measurement. - if model_info.exec_time == 0 or model_info.build_model: + if unique_invocation.exec_time == 0 or model_info.build_model: exec_time = "" else: - exec_time = f" - {model_info.exec_time:.2f}s" - printing.logn( - f"(executed {model_info.executed}x{exec_time})", - c=printing.Colors.OKGREEN, - ) + exec_time = f" - {unique_invocation.exec_time:.2f}s" - if model_info.model_type == build.ModelType.PYTORCH: - print(f"{ident}\tModel Type:\tPytorch (torch.nn.Module)") - elif model_info.model_type == build.ModelType.KERAS: - print(f"{ident}\tModel Type:\tKeras (tf.keras.Model)") + if model_info.depth == 0 and len(model_info.unique_invocations) > 1: + if not model_visited: + printing.logn(f"{ident}{model_info.name}") + else: + printing.log(f"{ident}{model_info.name}") + printing.logn( + f" (executed {unique_invocation.executed}x{exec_time})", + c=printing.Colors.OKGREEN, + ) + + if (model_info.depth == 0 and not model_visited) or (model_info.depth != 0): + if model_info.depth == 0: + if model_info.model_type == build.ModelType.PYTORCH: + print(f"{ident}\tModel Type:\tPytorch (torch.nn.Module)") + elif model_info.model_type == build.ModelType.KERAS: + print(f"{ident}\tModel Type:\tKeras (tf.keras.Model)") + + # Display class of found model and the file where it was found + model_class = type(model_info.model) + print(f"{ident}\tClass:\t\t{model_class.__name__} ({model_class})") + if model_info.depth == 0: + print(f"{ident}\tLocation:\t{model_info.file}, line {model_info.line}") + + # Converting number of parameters to MB assuming 2 bytes per parameter + # NOTE: https://github.com/groq/mlagility/issues/330 suggests eliminating this assumption + model_size = model_info.params * 2 / (1024 * 1024) + model_size = "{:.1f}".format(model_size) if model_size > 0.1 else "<0.1" + print( + f"{ident}\tParameters:\t{'{:,}'.format(model_info.params)} ({model_size} MB)" + ) - # Display class of found model and the file where it was found - model_class = type(model_info.model) - print(f"{ident}\tClass:\t\t{model_class.__name__} ({model_class})") - if 
model_info.depth == 0: - print(f"{ident}\tLocation:\t{model_info.file}, line {model_info.line}") + if model_info.depth == 0 and len(model_info.unique_invocations) > 1: + printing.logn( + f"\n{ident}\tWith input shape {invocation_idx+1} (executed {unique_invocation.executed}x{exec_time})", + c=printing.Colors.OKGREEN, + ) - # Converting number of parameters to MB assuming 2 bytes per parameter - model_size = model_info.params * 2 / (1024 * 1024) - model_size = "{:.1f}".format(model_size) if model_size > 0.1 else "<0.1" - print(f"{ident}\tParameters:\t{'{:,}'.format(model_info.params)} ({model_size} MB)") - print(f"{ident}\tHash:\t\t" + model_hash) + # Prepare input shape to be printed + input_shape = dict(model_info.unique_invocations[invocation_hash].input_shapes) + input_shape = {key: value for key, value in input_shape.items() if value != ()} + input_shape = str(input_shape).replace("{", "").replace("}", "") + + print(f"{ident}\tInput Shape:\t{input_shape}") + print(f"{ident}\tHash:\t\t" + invocation_hash) # Print benchit results if benchit was run - if model_info.performance: + if unique_invocation.performance: printing.log(f"{ident}\tStatus:\t\t") printing.logn( - f"Successfully benchmarked on {model_info.performance.device} ({model_info.performance.runtime} v{model_info.performance.runtime_version})", - c=model_info.status_message_color, + f"Successfully benchmarked on {unique_invocation.performance.device} ({unique_invocation.performance.runtime} v{unique_invocation.performance.runtime_version})", + c=unique_invocation.status_message_color, ) printing.logn( - f"{ident}\t\t\tMean Latency:\t{model_info.performance.mean_latency:.3f}" - f"\t{model_info.performance.latency_units}" + f"{ident}\t\t\tMean Latency:\t{unique_invocation.performance.mean_latency:.3f}" + f"\t{unique_invocation.performance.latency_units}" ) printing.logn( - f"{ident}\t\t\tThroughput:\t{model_info.performance.throughput:.1f}" - f"\t{model_info.performance.throughput_units}" + f"{ident}\t\t\tThroughput:\t{unique_invocation.performance.throughput:.1f}" + f"\t{unique_invocation.performance.throughput_units}" ) print() else: - if model_info.is_target and model_info.build_model: + if unique_invocation.is_target and model_info.build_model: printing.log(f"{ident}\tStatus:\t\t") printing.logn( - f"{model_info.status_message}", c=model_info.status_message_color + f"{unique_invocation.status_message}", + c=unique_invocation.status_message_color, ) - if model_info.traceback is not None: + if unique_invocation.traceback is not None: if os.environ.get("MLAGILITY_TRACEBACK") != "False": - for line in model_info.traceback: + for line in unique_invocation.traceback: for subline in line.split("\n")[:-1]: print(f"{ident}\t{subline}") @@ -117,5 +172,4 @@ def print_model( ) else: print() - else: - print("") + print() diff --git a/src/mlagility/analysis/util.py b/src/mlagility/analysis/util.py index 8a1abc2f..0c2b4845 100644 --- a/src/mlagility/analysis/util.py +++ b/src/mlagility/analysis/util.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Callable, List, Union, Dict import inspect +import dataclasses import torch import onnx from onnxflow.common import printing @@ -15,6 +16,27 @@ class AnalysisException(Exception): """ +@dataclass +class UniqueInvocationInfo: + """ + Refers to unique static model invocations + (i.e. 
models executed with unique input shapes) + """ + + hash: Union[str, None] = None + parent_hash: Union[str, None] = None + performance: MeasuredPerformance = None + traceback: List[str] = None + inputs: Union[dict, None] = None + input_shapes: Union[dict, None] = None + executed: int = 0 + exec_time: float = 0.0 + status_message: str = "" + is_target: bool = False + status_message_color: printing.Colors = printing.Colors.ENDC + traceback_message_color: printing.Colors = printing.Colors.FAIL + + @dataclass class ModelInfo: model: torch.nn.Module @@ -26,18 +48,13 @@ class ModelInfo: depth: int = 0 hash: Union[str, None] = None parent_hash: Union[str, None] = None - inputs: Union[dict, None] = None - executed: int = 0 - exec_time: float = 0.0 old_forward: Union[Callable, None] = None - status_message: str = "" - status_message_color: printing.Colors = printing.Colors.ENDC - traceback_message_color: printing.Colors = printing.Colors.FAIL - is_target: bool = False + unique_invocations: Union[ + Dict[str, UniqueInvocationInfo], None + ] = dataclasses.field(default_factory=dict) + last_unique_invocation_executed: Union[str, None] = None build_model: bool = False model_type: build.ModelType = build.ModelType.PYTORCH - performance: MeasuredPerformance = None - traceback: List[str] = None def __post_init__(self): self.params = count_parameters(self.model, self.model_type) diff --git a/src/mlagility/version.py b/src/mlagility/version.py index 1fe90f6a..0aff436e 100644 --- a/src/mlagility/version.py +++ b/src/mlagility/version.py @@ -1 +1 @@ -__version__ = "3.1.4" +__version__ = "3.1.5" diff --git a/src/onnxflow/common/build.py b/src/onnxflow/common/build.py index 7a45a57f..5ccf6f68 100644 --- a/src/onnxflow/common/build.py +++ b/src/onnxflow/common/build.py @@ -168,7 +168,10 @@ def get_shapes_and_dtypes(inputs: dict): subkey = f"{key}[{i}]" shapes[subkey] = np.array(v).shape dtypes[subkey] = np.array(v).dtype.name - elif torch.is_tensor(value) or tf_helpers.is_keras_tensor(value): + elif torch.is_tensor(value): + shapes[key] = np.array(value.detach()).shape + dtypes[key] = np.array(value.detach()).dtype.name + elif tf_helpers.is_keras_tensor(value): shapes[key] = np.array(value).shape dtypes[key] = np.array(value).dtype.name elif isinstance(value, np.ndarray): diff --git a/test/analysis.py b/test/analysis.py index 447c16c1..53171c64 100644 --- a/test/analysis.py +++ b/test/analysis.py @@ -118,6 +118,24 @@ def __init__(self, **kwargs): print(parsed_args) +""", + "two_executions": """ +import torch +import timm +from mlagility.parser import parse + +# Creating model and set it to evaluation mode +model = timm.create_model("mobilenetv2_035", pretrained=False) +model.eval() + +# Creating inputs +inputs1 = torch.rand((1, 3, 28, 28)) +inputs2 = torch.rand((1, 3, 224, 224)) + +# Calling model +model(inputs1) +model(inputs2) +model(inputs1) """, } minimal_tokenizer = """ @@ -220,7 +238,7 @@ def test_04_build(self): output = run_analysis( [ "benchit", - "linear_pytorch.py::bf68fb06", + "linear_pytorch.py::76af2f62", "--max-depth", "1", "--build-only", @@ -231,7 +249,7 @@ def test_04_build(self): assert np.array_equal(output, (2, 0, 1)) def test_05_cache(self): - model_hash = "bf68fb06" + model_hash = "76af2f62" run_analysis( [ "benchit", @@ -321,7 +339,7 @@ def test_12_benchit_hashes(self): output = run_analysis( [ "benchit", - "linear_pytorch.py::bf68fb06", + "linear_pytorch.py::76af2f62", "--build-only", "--max-depth", "1", @@ -332,7 +350,7 @@ def test_12_benchit_hashes(self): assert np.array_equal(output, 
(2, 0, 1)) def test_13_clean_cache(self): - model_hash = "bf68fb06" + model_hash = "76af2f62" run_analysis( [ "benchit", @@ -358,6 +376,28 @@ def test_13_clean_cache(self): assert cache_is_lean(cache_dir, build_name) + def test_14_same_model_different_input_shapes(self): + output = run_analysis( + [ + "benchit", + "two_executions.py", + "--analyze-only", + ] + ) + assert np.array_equal(output, (2, 0, 0)) + + def test_15_same_model_different_input_shapes_maxdepth(self): + output = run_analysis( + [ + "benchit", + "two_executions.py", + "--analyze-only", + "--max-depth", + "1", + ] + ) + assert np.array_equal(output, (6, 0, 0)) + if __name__ == "__main__": unittest.main() diff --git a/test/cli.py b/test/cli.py index d41e856e..97061152 100644 --- a/test/cli.py +++ b/test/cli.py @@ -420,7 +420,7 @@ def test_004_cli_report(self): linear_summary["model_class"] == "TwoLayerModel" ), f"Wrong class found {linear_summary['model_class']}" assert ( - linear_summary["hash"] == "54dedbb1" + linear_summary["hash"] == "80b93950" ), f"Wrong hash found {linear_summary['hash']}" assert ( float(linear_summary["x86_latency"]) > 0