diff --git a/machop/chop/models/manual/bert_quantized/quant_config_bert.py b/machop/chop/models/manual/bert_quantized/quant_config_bert.py
index cc6cf0294..0665a224b 100644
--- a/machop/chop/models/manual/bert_quantized/quant_config_bert.py
+++ b/machop/chop/models/manual/bert_quantized/quant_config_bert.py
@@ -6,7 +6,7 @@
 import toml
 
 from chop.tools.config_load import convert_str_na_to_none
 
-from ..quant_utils import parse_node_config
+from ..quant_utils import parse_node_q_config
 
 logger = logging.getLogger(__name__)
@@ -68,20 +68,20 @@ def create_a_layer_config(
     # fmt: off
     qc = {
         "attention": {
-            "query": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("query", linear_qc), "linear")),
-            "key": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("key", linear_qc), "linear")),
-            "value": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("value", linear_qc), "linear")),
-            "matmul_0": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("matmul_0", matmul_qc), "matmul")),
-            "matmul_1": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("matmul_1", matmul_qc), "matmul")),
+            "query": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("query", linear_qc), "linear")),
+            "key": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("key", linear_qc), "linear")),
+            "value": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("value", linear_qc), "linear")),
+            "matmul_0": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("matmul_0", matmul_qc), "matmul")),
+            "matmul_1": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("matmul_1", matmul_qc), "matmul")),
             "output": {
-                "dense": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("output", {}).get("dense", linear_qc), "linear")),
+                "dense": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("output", {}).get("dense", linear_qc), "linear")),
             },
         },
         "intermediate": {
-            "dense": deepcopy(parse_node_config(layer_qc.get("intermediate", {}).get("dense", linear_qc), "linear")),
+            "dense": deepcopy(parse_node_q_config(layer_qc.get("intermediate", {}).get("dense", linear_qc), "linear")),
         },
         "output": {
-            "dense": deepcopy(parse_node_config(layer_qc.get("output", {}).get("dense", linear_qc), "linear")),
+            "dense": deepcopy(parse_node_q_config(layer_qc.get("output", {}).get("dense", linear_qc), "linear")),
         },
     }
     # fmt: on
@@ -94,10 +94,10 @@ def _parse_and_complete_config(
 ) -> dict:
     assert "default" in config, "Must provide a default config"
     default_qc: dict = config["default"]
-    linear_qc: dict = parse_node_config(
+    linear_qc: dict = parse_node_q_config(
         config.get("linear", default_qc), mase_op="linear"
     )
-    matmul_qc: dict = parse_node_config(
+    matmul_qc: dict = parse_node_q_config(
         config.get("matmul", default_qc), mase_op="matmul"
     )
     general_layer_qc: dict = config.get("model_layer", None)
diff --git a/machop/chop/models/manual/llama_quantized/quant_config_llama.py b/machop/chop/models/manual/llama_quantized/quant_config_llama.py
index b3988fd0f..d086a2a36 100644
--- a/machop/chop/models/manual/llama_quantized/quant_config_llama.py
+++ b/machop/chop/models/manual/llama_quantized/quant_config_llama.py
@@ -6,7 +6,7 @@
 import toml
 
 from chop.tools.config_load import convert_str_na_to_none
 
-from ..quant_utils import parse_node_config
+from ..quant_utils import parse_node_q_config
 
 
 """
@@ -48,18 +48,18 @@ def create_a_layer_config(
     # fmt: off
     qc = {
         "self_attn": {
-            "q_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")),
-            "k_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")),
-            "v_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")),
-            "o_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("o_proj", linear_qc), "linear")),
-            "rotary_positional_encoding": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("rotary_positional_encoding", rotary_positional_encoding_qc), "rotary_positional_encoding")),
-            "matmul_0": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("matmul_0", matmul_qc), "matmul")),
-            "matmul_1": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("matmul_1", matmul_qc), "matmul")),
+            "q_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")),
+            "k_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")),
+            "v_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")),
+            "o_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("o_proj", linear_qc), "linear")),
+            "rotary_positional_encoding": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("rotary_positional_encoding", rotary_positional_encoding_qc), "rotary_positional_encoding")),
+            "matmul_0": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("matmul_0", matmul_qc), "matmul")),
+            "matmul_1": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("matmul_1", matmul_qc), "matmul")),
         },
         "mlp": {
-            "gate_proj": deepcopy(parse_node_config(layer_qc.get("mlp", {}).get("gate_proj", linear_qc), "linear")),
-            "down_proj": deepcopy(parse_node_config(layer_qc.get("mlp", {}).get("down_proj", linear_qc), "linear")),
-            "up_proj": deepcopy(parse_node_config(layer_qc.get("mlp", {}).get("up_proj", linear_qc), "linear"))
+            "gate_proj": deepcopy(parse_node_q_config(layer_qc.get("mlp", {}).get("gate_proj", linear_qc), "linear")),
+            "down_proj": deepcopy(parse_node_q_config(layer_qc.get("mlp", {}).get("down_proj", linear_qc), "linear")),
+            "up_proj": deepcopy(parse_node_q_config(layer_qc.get("mlp", {}).get("up_proj", linear_qc), "linear"))
         },
     }
     # fmt: on
@@ -69,14 +69,14 @@ def create_a_layer_config(
 def _parse_and_complete_config(config: dict, num_hidden_layers: int) -> dict:
     assert "default" in config, "Must provide default config for by_name_parser"
     default_qc: dict = config["default"]
-    linear_qc: dict = parse_node_config(
+    linear_qc: dict = parse_node_q_config(
         config.get("linear", default_qc), mase_op="linear"
     )
-    rotary_positional_encoding_qc: dict = parse_node_config(
+    rotary_positional_encoding_qc: dict = parse_node_q_config(
         config.get("rotary_positional_encoding", default_qc),
         mase_op="rotary_positional_encoding",
     )
-    matmul_qc: dict = parse_node_config(
+    matmul_qc: dict = parse_node_q_config(
         config.get("matmul", default_qc), mase_op="matmul"
     )
     general_layer_qc: dict = config.get("model_layer", None)
diff --git a/machop/chop/models/manual/opt_quantized/quant_config_opt.py b/machop/chop/models/manual/opt_quantized/quant_config_opt.py
index a76f8adb8..371fa6a62 100644
--- a/machop/chop/models/manual/opt_quantized/quant_config_opt.py
+++ b/machop/chop/models/manual/opt_quantized/quant_config_opt.py
@@ -5,9 +5,9 @@
 import toml
 
 from ....tools.config_load import convert_str_na_to_none
 
-from ....passes.graph import parse_node_config
+from ....passes.graph import parse_node_q_config
 
-from chop.passes.graph.transforms.quantize.quant_parsers import parse_quant_config
+from chop.passes.graph.transforms.quantize.quant_parsers import parse_node_q_config
 
 """
 An example of quant_config for opt
@@ -43,15 +43,15 @@ def create_a_layer_config(
     # fmt: off
     qc = {
         "self_attn": {
-            "q_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")),
-            "k_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")),
-            "v_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")),
-            "out_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("out_proj", linear_qc), "linear")),
-            "bmm_0": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("bmm_0", bmm_qc), "matmul")),
-            "bmm_1": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("bmm_1", bmm_qc), "matmul")),
+            "q_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")),
+            "k_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")),
+            "v_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")),
+            "out_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("out_proj", linear_qc), "linear")),
+            "bmm_0": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("bmm_0", bmm_qc), "matmul")),
+            "bmm_1": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("bmm_1", bmm_qc), "matmul")),
         },
-        "fc1": deepcopy(parse_node_config(layer_qc.get("fc1", linear_qc), "linear")),
-        "fc2": deepcopy(parse_node_config(layer_qc.get("fc2", linear_qc), "linear")),
+        "fc1": deepcopy(parse_node_q_config(layer_qc.get("fc1", linear_qc), "linear")),
+        "fc2": deepcopy(parse_node_q_config(layer_qc.get("fc2", linear_qc), "linear")),
     }
     # fmt: on
     return qc
@@ -60,10 +60,10 @@ def create_a_layer_config(
 def _parse_and_complete_config(config: dict, num_hidden_layers: int) -> dict:
     assert "default" in config, "Must provide default config for by_name_parser"
     default_qc: dict = config["default"]
-    linear_qc: dict = parse_node_config(
+    linear_qc: dict = parse_node_q_config(
         config.get("linear", default_qc), mase_op="linear"
     )
-    bmm_qc: dict = parse_node_config(config.get("bmm", default_qc), mase_op="matmul")
+    bmm_qc: dict = parse_node_q_config(config.get("bmm", default_qc), mase_op="matmul")
     general_layer_qc: dict = config.get("model_layer", None)
 
     # parsed config
diff --git a/machop/chop/models/manual/quant_utils.py b/machop/chop/models/manual/quant_utils.py
index c828d10dc..48506aca0 100644
--- a/machop/chop/models/manual/quant_utils.py
+++ b/machop/chop/models/manual/quant_utils.py
@@ -1,6 +1,6 @@
 from typing import Callable
 
-from chop.passes.graph import parse_node_config
+from chop.passes.graph import parse_node_q_config
 from chop.passes.graph import quantized_func_map
 from chop.passes.graph import quantized_module_map
 
@@ -16,4 +16,4 @@ def get_quantized_func(mase_op: str, quant_config: dict) -> Callable:
 
 
 def parse_op_quant_config(mase_op: str, config: dict) -> dict:
-    return parse_node_config(config=config, mase_op=mase_op)
+    return parse_node_q_config(config=config, mase_op=mase_op)
diff --git a/machop/chop/passes/__init__.py b/machop/chop/passes/__init__.py
index 59ddba950..517c2edbf 100644
--- a/machop/chop/passes/__init__.py
+++ b/machop/chop/passes/__init__.py
@@ -15,6 +15,7 @@
     verify_common_metadata_analysis_pass,
     run_cosim_analysis_pass,
     get_synthesis_results,
+    test_verilog_analysis_pass,
 )
 from .graph.transforms import (
     prune_transform_pass,
diff --git a/machop/chop/passes/graph/__init__.py b/machop/chop/passes/graph/__init__.py
index 63a87f5b1..0135d0847 100644
--- a/machop/chop/passes/graph/__init__.py
+++ b/machop/chop/passes/graph/__init__.py
@@ -37,7 +37,7 @@
 )
 
 from .transforms.quantize import quantized_func_map, quantized_module_map
-from .transforms.quantize.quant_parsers import parse_node_config
+from .transforms.quantize.quant_parsers import parse_node_q_config
 
 ANALYSIS_PASSES = [
     "init_metadata",
diff --git a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
index 87d839644..0150791e1 100644
--- a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
+++ b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
@@ -45,6 +45,20 @@ def add_component_source(node):
     node.meta["mase"]["hardware"]["dependence_files"] = []
 
     node.meta["mase"]["hardware"]["device_id"] = -1
 
+    # Init data parallelism to 1 and use DSE pass for exploration
+    node.meta["mase"]["hardware"]["parallelism"] = {}
+    args = node.meta["mase"]["common"]["args"]
+    for arg, arg_info in args.items():
+        if isinstance(arg_info, dict):
+            node.meta["mase"]["hardware"]["parallelism"][arg] = [
+                1 for _ in range(len(arg_info["shape"]))
+            ]
+
+    results = node.meta["mase"]["common"]["results"]
+    for result, result_info in results.items():
+        node.meta["mase"]["hardware"]["parallelism"][result] = [
+            1 for _ in range(len(result_info["shape"]))
+        ]
+
     # Current only support on-chip parameters
     args = node.meta["mase"]["common"]["args"]
@@ -81,17 +95,17 @@
                 else 1
             )
             # If node data parallelism is set, take from hardware metadata
-            if node.meta["mase"]["hardware"]["parallelism"] is not None:
-                vp[_cap(arg + f"_parallelism_dim_{dim}")] = node.meta["mase"][
-                    "hardware"
-                ]["parallelism"][len(arg_info["shape"]) - 1 - dim]
-            # Otherwise, assign to tensor size by default
-            else:
-                vp[_cap(arg + f"_parallelism_dim_{dim}")] = (
-                    arg_info["shape"][len(arg_info["shape"]) - 1 - dim]
-                    if dim < len(arg_info["shape"])
-                    else 1
-                )
+            assert node.meta["mase"]["hardware"]["parallelism"][arg] is not None
+            vp[_cap(arg + f"_parallelism_dim_{dim}")] = node.meta["mase"][
+                "hardware"
+            ]["parallelism"][arg][len(arg_info["shape"]) - 1 - dim]
+            # # Otherwise, assign to tensor size by default
+            # else:
+            #     vp[_cap(arg + f"_parallelism_dim_{dim}")] = (
+            #         arg_info["shape"][len(arg_info["shape"]) - 1 - dim]
+            #         if dim < len(arg_info["shape"])
+            #         else 1
+            #     )
         elif type(arg_info) == bool:
             vp[_cap(arg)] = 1 if arg_info else 0
         else:
@@ -107,16 +121,16 @@
                 if dim < len(result_info["shape"])
                 else 1
             )
-            if node.meta["mase"]["hardware"]["parallelism"] is not None:
-                vp[_cap(result + f"_parallelism_dim_{dim}")] = node.meta["mase"][
-                    "hardware"
-                ]["parallelism"][len(result_info["shape"]) - 1 - dim]
-            else:
-                vp[_cap(result + f"_parallelism_dim_{dim}")] = (
-                    result_info["shape"][len(result_info["shape"]) - 1 - dim]
-                    if dim < len(result_info["shape"])
-                    else 1
-                )
+            assert node.meta["mase"]["hardware"]["parallelism"] is not None
+            vp[_cap(result + f"_parallelism_dim_{dim}")] = node.meta["mase"][
+                "hardware"
+            ]["parallelism"][result][len(result_info["shape"]) - 1 - dim]
+            # else:
+            #     vp[_cap(result + f"_parallelism_dim_{dim}")] = (
+            #         result_info["shape"][len(result_info["shape"]) - 1 - dim]
+            #         if dim < len(result_info["shape"])
+            #         else 1
+            #     )
         else:
             vp[_cap(result)] = result_info
@@ -369,11 +383,6 @@ def add_hardware_metadata_analysis_pass(graph, pass_args=None):
     for node in graph.nodes:
         add_component_source(node)
 
-    # Temporary: fix parallelism to small value to enable verilator simulation
-    for node in graph.nodes:
-        # Batch parallelism set to 1, data parallelism to 4
-        node.meta["mase"]["hardware"]["parallelism"] = [1, 4]
-
     # Add hardware parameters
     for node in graph.nodes:
         add_verilog_param(node)
diff --git a/machop/chop/passes/graph/analysis/report/report_node.py b/machop/chop/passes/graph/analysis/report/report_node.py
index cfa81812c..c2733a06a 100644
--- a/machop/chop/passes/graph/analysis/report/report_node.py
+++ b/machop/chop/passes/graph/analysis/report/report_node.py
@@ -5,6 +5,8 @@
 
 import copy
 
+import torch
+
 logger = logging.getLogger(__name__)
 
 
@@ -118,7 +120,7 @@
     return graph, {}
 
 
-def report_node_meta_param_analysis_pass(graph, pass_args: dict = None):
+def report_node_meta_param_analysis_pass(graph, pass_args: dict = {}):
     """
     Perform meta parameter analysis on the nodes in the graph and generate a report.
 
@@ -131,6 +133,7 @@
     :return: The analyzed graph and an empty dictionary.
     :rtype: tuple(MaseGraph, dict)
     """
+    torch.set_printoptions(threshold=20)
     which_param = pass_args.get("which", ("all",))
     assert isinstance(which_param, (list, tuple))
     for param in which_param:
@@ -184,4 +187,5 @@
     with open(Path(save_path), "w") as f:
         f.write(table_txt)
     logger.info(f"Node meta param table is saved to {save_path}")
+    torch.set_printoptions(threshold=1000)
     return graph, {}
diff --git a/machop/chop/passes/graph/analysis/verilog/cocotb.py b/machop/chop/passes/graph/analysis/verilog/cocotb.py
new file mode 100644
index 000000000..45743b2d6
--- /dev/null
+++ b/machop/chop/passes/graph/analysis/verilog/cocotb.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+
+import os, logging
+
+from mase_cocotb.random_test import check_results
+from mase_cocotb.runner import mase_runner
+
+import cocotb
+from cocotb.triggers import Timer
+from cocotb.triggers import FallingEdge
+from cocotb.clock import Clock
+
+logger = logging.getLogger(__name__)
+
+
+# DUT test specifications
+class VerificationCase:
+    def __init__(self, iterations=1, samples=10):
+        self.samples = samples
+        self.iterations = iterations
+
+
+@cocotb.test()
+async def test_top(dut):
+    """Test top-level model hardware design"""
+    samples = 1000
+    test_case = VerificationCase(samples=samples)
+
+    # Reset cycle
+    await Timer(20, units="ns")
+    dut.rst.value = 1
+    await Timer(100, units="ns")
+    dut.rst.value = 0
+
+    # Create a 10ns-period clock on port clk
+    clock = Clock(dut.clk, 10, units="ns")
+    # Start the clock
+    cocotb.start_soon(clock.start())
+    await Timer(500, units="ns")
+
+    # Synchronize with the clock
+    dut.data_in_0_valid.value = 0
+    dut.data_out_0_ready.value = 1
+    debug_state(dut, "Pre-clk")
+    await FallingEdge(dut.clk)
+    debug_state(dut, "Post-clk")
+    debug_state(dut, "Pre-clk")
+    await FallingEdge(dut.clk)
+    debug_state(dut, "Post-clk")
+
+    done = False
+    # Set a timeout to avoid deadlock
+    for i in range(samples * 100):
+        await FallingEdge(dut.clk)
+        debug_state(dut, "Post-clk")
+        dut.data_in_0_valid.value = test_case.data_in.pre_compute()
+        await Timer(1, units="ns")
+        dut.data_out_0_ready.value = test_case.outputs.pre_compute(
+            dut.data_out_0_valid.value
+        )
+        await Timer(1, units="ns")
+        debug_state(dut, "Post-clk")
+
+        dut.data_in_0_valid.value, dut.data_in_0.value = test_case.data_in.compute(
+            dut.data_in_0_ready.value
+        )
+        await Timer(1, units="ns")
+        dut.data_out_0_ready.value = test_case.outputs.compute(
+            dut.data_out_0_valid.value, dut.data_out_0.value
+        )
+        debug_state(dut, "Pre-clk")
+
+        if test_case.data_in.is_empty() and test_case.outputs.is_full():
+            done = True
+            break
+    assert (
+        done
+    ), "Deadlock detected or the simulation reached the maximum cycle limit (fix it by adjusting the loop trip count)"
+
+    check_results(test_case.outputs.data, test_case.ref)
diff --git a/machop/chop/passes/graph/analysis/verilog/test_verilog.py b/machop/chop/passes/graph/analysis/verilog/test_verilog.py
index 50479a85c..fd3d4e39c 100644
--- a/machop/chop/passes/graph/analysis/verilog/test_verilog.py
+++ b/machop/chop/passes/graph/analysis/verilog/test_verilog.py
@@ -1,49 +1,281 @@
-import logging
-from typing import Tuple, Dict
-import math
-import os
-import time
-from multiprocessing import Process, Queue
+import logging, toml, os, glob, math
+from pathlib import Path
+
+import torch
 
-from chop.passes.graph.utils import vf, v2p, init_project
+import cocotb
+from cocotb.runner import get_runner
+from cocotb.triggers import Timer
+from cocotb.triggers import FallingEdge
+from cocotb.clock import Clock
+
+from chop.passes.graph.utils import vf
+from mase_cocotb.random_test import RandomSource, RandomSink, check_results
 
 logger = logging.getLogger(__name__)
 
+# =============================================================================
+# DUT test specifications
+# =============================================================================
+
 
-def get_test_parameters(mg):
+def hardware_reshape(input_data, input_shape, tiling):
     """
-    Extract the verilog parameters from the mase graph for cocotb testing
+    Apply 2D tiling. TODO: For higher dimensions, just flatten it in time.
     """
-    return {}
+    assert len(input_shape) == 2, "Default hardware test bench only supports 2D inputs"
 
-
-def get_dummy_inputs(mg):
-    """
-    Fetch test inputs from dataset or create a random one
-    """
-    return {}
+    row_size = int(math.ceil(input_shape[0] / tiling[0]))
+    col_size = int(math.ceil(input_shape[1] / tiling[1]))
+    output_data = [
+        [0 for _ in range(tiling[1] * tiling[0])] for _ in range(row_size * col_size)
+    ]
+    for i in range(row_size):
+        for j in range(col_size):
+            for ii in range(0, tiling[0]):
+                for jj in range(0, tiling[1]):
+                    rowi = i * tiling[0] + ii
+                    coli = j * tiling[1] + jj
+                    if rowi < input_shape[0] and coli < input_shape[1]:
+                        output_data[i * row_size + j][ii * tiling[1] + jj] = int(
+                            input_data[rowi][coli]
+                        )
+    return output_data
 
-
-def run_software_test(mg, inputs):
-    """
-    Run software model on given inputs
-    """
-    return {}
+class VerificationCase:
+    # TODO: sample > 1 needs to be added
+    def __init__(self, samples=1):
+        self.samples = samples
 
-
-def run_cocotb_test(mg, parameters, inputs):
-    """
-    Create a cocotb test case and use mase runner to run hardware simulation
-    """
-    return {}
+    def generate_tv(self, mg):
+        """
+        Generate test vector and emit to ~/.mase.{pid}.toml
+        """
+
+        # Generate random inputs
+        test_inputs = {}
+        # TODO: here we just enumerate the inputs of the input nodes - which may be
+        # order insensitive and require manual connection when adding the graph to
+        # a system.
+        name_idx = 0
+        for node in mg.nodes_in:
+            for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
+                if "data_in" in arg:
+                    test_inputs[f"data_in_{name_idx}"] = torch.randint(
+                        32, arg_info["shape"]
+                    )
+                    name_idx += 1
+        logger.debug(test_inputs)
+
+        # Get software results
+        y = mg.model(*list(test_inputs.values()))
+
+        output_toml = {}
+        output_toml["samples"] = 1
+
+        # Reshape values for hardware testing
+        # TODO: assume 2D inputs
+        reshaped_inputs = {}
+        name_idx = 0
+        for node in mg.nodes_in:
+            for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
+                if "data_in" in arg:
+
+                    # By default: the data is passed column by column
+                    reshaped_inputs[f"data_in_{name_idx}"] = hardware_reshape(
+                        test_inputs[f"data_in_{name_idx}"],
+                        arg_info["shape"],
+                        node.meta["mase"].parameters["hardware"]["parallelism"][arg],
+                    )
+                    name_idx += 1
+
+        output_toml["inputs"] = reshaped_inputs
+
+        assert len(mg.nodes_out) == 1, "Expected the model to have exactly one output!"
+        reshaped_y = reshaped_inputs[f"data_out_0"] = hardware_reshape(
+            y,
+            mg.nodes_out[0]
+            .meta["mase"]
+            .parameters["common"]["results"]["data_out_0"]["shape"],
+            mg.nodes_out[0]
+            .meta["mase"]
+            .parameters["hardware"]["parallelism"]["data_out_0"],
+        )
+
+        output_toml["outputs"] = {"data_out_0": reshaped_y}
+
+        home = Path.home()
+        Path(os.path.join(home, f".mase")).mkdir(parents=True, exist_ok=True)
+        fname = os.path.join(home, f".mase", f"tv.toml")
+        assert not os.path.isfile(
+            fname
+        ), f"Cannot create a temporary toml for testing data - {fname} already exists"
+        with open(fname, "w") as toml_file:
+            toml.dump(output_toml, toml_file)
+
+        logger.debug(f"Test data saved to {fname}")
+
+    def load_tv(self, fname=""):
+        home = Path.home()
+        fname = os.path.join(home, ".mase", f"tv.toml")
+        assert os.path.isfile(
+            fname
+        ), f"Cannot find the temporary toml for testing data - {fname}"
+        with open(fname, "r") as f:
+            input_toml = toml.load(f)
+
+        self.samples = input_toml["samples"]
+
+        for val, values in input_toml["inputs"].items():
+            setattr(
+                self,
+                val,
+                RandomSource(
+                    name=val,
+                    samples=len(values),
+                    num=len(values[0]),
+                    max_stalls=0,
+                ),
+            )
+            source = getattr(self, val)
+            source.data = values
+
+        for val, values in input_toml["outputs"].items():
+            setattr(
+                self,
+                val,
+                RandomSink(
+                    name=val,
+                    samples=len(values),
+                    num=len(values[0]),
+                    max_stalls=0,
+                ),
+            )
+            self.ref = values
+
+        os.remove(fname)
+        logger.debug(f"Test data loaded from {fname}")
+
+
+class TestBehavior:
+    async def test_bench_behavior(dut):
+        """Test top-level model hardware design (default behavior)"""
+        test_case = VerificationCase()
+        test_case.load_tv()
+
+        # Reset cycle
+        await Timer(20, units="ns")
+        dut.rst.value = 1
+        await Timer(100, units="ns")
+        dut.rst.value = 0
+
+        # Create a 10ns-period clock on port clk
+        clock = Clock(dut.clk, 10, units="ns")
+        # Start the clock
+        cocotb.start_soon(clock.start())
+        await Timer(500, units="ns")
 
-def compare_results(r0, r1):
-    return r0 == r1
+        # Synchronize with the clock
+        dut.data_in_0_valid.value = 0
+        dut.data_out_0_ready.value = 1
+        await FallingEdge(dut.clk)
+        await FallingEdge(dut.clk)
 
+        done = False
+        # Set a timeout to avoid deadlock
+        for i in range(test_case.samples * 100):
+            await FallingEdge(dut.clk)
+
+            dut.data_in_0_valid.value = test_case.data_in_0.pre_compute()
+            await Timer(1, units="ns")
+
+            dut.data_out_0_ready.value = test_case.data_out_0.pre_compute(
+                dut.data_out_0_valid.value
+            )
+            await Timer(1, units="ns")
+
+            dut.data_in_0_valid.value, dut.data_in_0.value = (
+                test_case.data_in_0.compute(dut.data_in_0_ready.value)
+            )
+            await Timer(1, units="ns")
+
+            dut.data_out_0_ready.value = test_case.data_out_0.compute(
+                dut.data_out_0_valid.value, dut.data_out_0.value
+            )
+
+            if test_case.data_in_0.is_empty() and test_case.data_out_0.is_full():
+                done = True
+                break
+        assert (
+            done
+        ), "Deadlock detected or the simulation reached the maximum cycle limit (fix it by adjusting the loop trip count)"
+
+        check_results(test_case.data_out_0.data, test_case.ref)
+
+
+# =============================================================================
+# Cocotb interface setup
+# =============================================================================
+
+
+@cocotb.test()
+async def test_top(dut):
+    await TestBehavior.test_bench_behavior(dut)
+
+
+def get_dut_parameters(graph):
+    parameter_map = {}
+
+    for node in graph.fx_graph.nodes:
+        if node.meta["mase"].parameters["hardware"]["is_implicit"]:
+            continue
+        node_name = vf(node.name)
+
+        for key, value in (
+            node.meta["mase"].parameters["hardware"]["verilog_param"].items()
+        ):
+            if not isinstance(value, (int, float, complex, bool)):
+                value = '"' + value + '"'
+            assert (
+                f"{node_name}_{key}" not in parameter_map.keys()
+            ), f"{node_name}_{key} already exists in the parameter map"
+            parameter_map[f"{node_name}_{key}"] = value
+    return parameter_map
+
+
+def runner(mg, project_dir, top_name):
+    sim = os.getenv("SIM", "verilator")
+
+    # TODO: Grab internal verilog source only. Need to include HLS hardware as well.
+    sv_srcs = []
+    for v in glob.glob(os.path.join(project_dir, "hardware", "rtl", "*.sv")):
+        sv_srcs.append(os.path.relpath(v, os.getcwd()))
+
+    p = get_dut_parameters(mg)
+    # logger.debug(p)
+
+    # set parameters
+    extra_args = []
+    for k, v in p.items():
+        extra_args.append(f"-G{k}={v}")
+    logger.debug(extra_args)
+    runner = get_runner(sim)
+    runner.build(
+        verilog_sources=sv_srcs,
+        hdl_toplevel=top_name,
+        build_args=extra_args,
+    )
+
+    runner.test(
+        hdl_toplevel=top_name,
+        test_module=f"chop.passes.graph.analysis.verilog.test_verilog",
+    )
+
+
-def test_verilog_analysis_pass(graph, pass_args={}):
-    """Use cocotb to test the model design in Verilog
+def test_verilog_analysis_pass(mg, pass_args={}):
+    """Test the top-level hardware design using Cocotb
 
     :param graph: a MaseGraph
     :type graph: MaseGraph
@@ -52,24 +284,36 @@ def test_verilog_analysis_pass(graph, pass_args={}):
 
     :return: return a tuple of a MaseGraph and an empty dict (no additional info to return)
     :rtype: tuple(MaseGraph, Dict)
 
     - pass_args
         - project_dir -> str : the directory of the project for cosimulation
         - top_name -> str : top-level name
+        - samples -> int : the number of test inputs, samples = 1 by default
+        - test_bench -> str : the test bench behavior specified by the user, which runs end-to-end simulation by default
+        - preprocess -> str : preprocess of IO for testing, which generates random inputs by default
     """
-    logger.info("Testing the model in Verilog...")
+    logger.info(f"Running hardware simulation using Cocotb")
+    logger.debug(f"test verilog pass pass_args = {pass_args}")
     project_dir = (
-        pass_args["project_dir"] if "project_dir" in pass_args.keys() else "top"
+        pass_args["project_dir"]
+        if "project_dir" in pass_args.keys()
+        else Path.home() / ".mase" / "top"
    )
     top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top"
+    samples = pass_args["samples"] if "samples" in pass_args.keys() else 1
+
+    # TODO: Create a global variable traced by pass ID. This is bad...
+    test_case = VerificationCase(samples)
+    globals()["test_verilog_analysis_pass_tc"] = test_case
+    print(globals())
 
-    parameters = get_test_parameters(graph)
-    inputs = get_dummy_inputs(graph)
-    software_results = run_software_test(graph, inputs)
-    hardware_results = run_cocotb_test(graph, parameters, inputs)
+    if "preprocess" in pass_args.keys():
+        test_case.preprocess = pass_args["preprocess"]
+    if "test_bench" in pass_args.keys():
+        test_case.test_bench_behavior = pass_args["test_bench"]
 
-    compare_results(software_results, hardware_results)
+    test_case.generate_tv(mg)
 
-    return graph, {}
+    runner(mg, project_dir, top_name)
+    return mg, {}
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py
index 72a5b3b68..c5205a89a 100644
--- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py
@@ -1,2 +1,9 @@
-from .parse_quant_config import parse_node_config
-from .update_node_meta import relink_node_meta, update_quant_meta_param
+# from .parse_quant_config import parse_node_q_config
+from .parse_q_config import parse_node_q_config
+
+# from .update_node_meta import relink_node_meta, update_quant_meta_param
+from .update_node_meta import (
+    relink_node_meta,
+    update_q_meta_param,
+    infer_result_dtype_and_precision,
+)
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/parse_quant_config.py
similarity index 98%
rename from machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
rename to machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/parse_quant_config.py
index 63bbea13a..50c2c0779 100644
--- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/parse_quant_config.py
@@ -3,8 +3,6 @@ from .utils import cp_multi_values, has_multi_keys
 
 """ QUANT_ARITH_ENTRIES
-A mapping from (quantization arithmetic name) to (a mapping from (operand name) to (operand quantization spec name))
-
 Example
 A fixed point quantized value is defined by (width, frac_width), thus the mapping is defined as follows:
@@ -368,7 +366,7 @@ def optional_operand_entry_exists(config: dict, entry_name: str) -> bool:
     return False
 
 
-def parse_node_config(config: dict, mase_op: str, strict: bool = True) -> dict:
+def parse_node_q_config(config: dict, mase_op: str, strict: bool = True) -> dict:
     """
     Parse a node config from a MASE op config.
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/q_recipes.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/q_recipes.py
new file mode 100644
index 000000000..8e02397c2
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/q_recipes.py
@@ -0,0 +1,54 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class QRecipeFixed:
+    """_summary_
+    Fixed point quantization
+    """
+
+    name: str = field(default="fixed", init=False)
+    bypass: bool = field(default=False)
+    data_in_width: int
+    data_in_frac_width: int
+    weight_width: int | None = field(default=None)
+    weight_frac_width: int | None = field(default=None)
+    bias_width: int | None = field(default=None)
+    bias_frac_width: int | None = field(default=None)
+
+
+@dataclass
+class QRecipeLutNet:
+    """
+    LUTNET quantization
+
+    binarization_level (int): which level of binarization is applied, "binarized_weight" is only weights binarized others is no binarization
+    input_expanded (bool): If set to True, means all LUT's inputs are considered during calculations , else only the first input will considered and the remaining will be masked.
+    k: int # k entries of a LUT
+    levels (int): number of residual levels to use in LUTNET
+    dim: this is needed by convolution
+    """
+
+    name: str = field(default="lutnet", init=False)
+
+    data_in_width: int
+    data_in_frac_width: int
+    data_in_binarization_level: int
+    data_in_input_expanded: bool
+    data_in_k: int
+    data_in_in_levels: int
+    data_in_dim: tuple[int]
+
+    weight_width: int
+    weight_frac_width: int
+    weight_binarization_level: int
+    weight_input_expanded: bool
+    weight_k: int
+    weight_in_dim: tuple[int]
+
+    bias_width: int
+    bias_frac_width: int
+    bias_binarization_level: int
+    bias_input_expanded: bool
+    bias_k: int
+    bias_in_dim: tuple[int]
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/update_node_meta.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/update_node_meta.py
new file mode 100644
index 000000000..5f980e10f
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/update_node_meta.py
@@ -0,0 +1,134 @@
+from functools import partial
+
+
+def entry_to_list(config: dict, entry: str, suffixes: tuple[str]):
+    """e.g. [data_in_frac_width, data_in_width]"""
+    return list(config[f"{entry}_{suffix}"] for suffix in suffixes)
+
+
+QUANT_ARITH_TO_SUFFIXES = {
+    "integer": ("width", "frac_width"),
+    "fixed": ("width", "frac_width"),
+    "binary": (
+        "width",
+        "stochastic",
+        "bipolar",
+    ),  # TODO: stochastic, bipolar flags are operational flag instead of precision.
+    "binary_residual": (
+        "width",
+        "stochastic",
+        "bipolar",
+    ),  # TODO: stochastic, bipolar flags are operational flag instead of precision.
+ "lutnet": ("width", "input_expanded", "k", "binarization_level"), + "logicnets": ("width", "frac_width"), + "ternary": ("width", "scaling_factor", "mean", "median", "max"), + "minifloat_ieee": ("width", "exponent_width", "exponent_bias"), + "minifloat_denorm": ("width", "exponent_width", "exponent_bias"), + "log": ("width", "exponent_bias"), + "block_fp": ("width", "exponent_width", "exponent_bias", "block_size"), + "block_minifloat": ("width", "exponent_width", "exponent_bias_width", "block_size"), + "block_log": ("width", "exponent_bias_width", "block_size"), +} + + +# quant_arith_to_list_fn = { +# : { +# : entry_to_list_ +# } +quant_arith_to_list_fn = {} +for quant_arith, suffixes in QUANT_ARITH_TO_SUFFIXES.items(): + quant_arith_to_list_fn[quant_arith] = partial(entry_to_list, suffixes=suffixes) + + +def update_arg(node, arg_name, dtype=None, precision=None, size=None): + if dtype is not None: + node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = dtype + if precision is not None: + node.meta["mase"].parameters["common"]["args"][arg_name][ + "precision" + ] = precision + if size is not None: + node.meta["mase"].parameters["common"]["args"][arg_name]["size"] = size + + +MASE_OP_TO_INPUT_ENTRIES_AND_ARGS = { + # entry and arg corresponding to name in software and hardware mapping + "add": (("data_in", "data_in"), ("data_in_0", "data_in_1")), + "bmm": (("data_in", "weight"), ("data_in_0", "data_in_1")), + "conv1d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), + "conv2d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), + "matmul": (("data_in", "weight"), ("data_in_0", "data_in_1")), + "mul": (("data_in", "data_in"), ("data_in_0", "data_in_1")), + "linear": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), + "relu": (("data_in",), ("data_in_0",)), + "sub": (("data_in", "data_in"), ("data_in_0", "data_in_1")), +} + + +def update_result(node, output_name, dtype=None, precision=None, size=None): + if dtype is not None: + node.meta["mase"].parameters["common"]["results"][output_name]["type"] = dtype + if precision is not None: + node.meta["mase"].parameters["common"]["results"][output_name][ + "precision" + ] = precision + if size is not None: + node.meta["mase"].parameters["common"]["results"][output_name]["size"] = size + + +MASE_OP_TO_OUTPUT_ENTRIES = { + # entry and arg corresponding to name in software and hardware mapping + "add": (("data_out",), ("data_out_0",)), + "bmm": (("data_out",), ("data_out_0",)), + "conv1d": (("data_out",), ("data_out_0",)), + "conv2d": (("data_out",), ("data_out_0",)), + "matmul": (("data_out",), ("data_out_0",)), + "mul": (("data_out",), ("data_out_0",)), + "linear": (("data_out",), ("data_out_0",)), + "relu": (("data_out",), ("data_out_0",)), + "sub": (("data_out",), ("data_out_0",)), +} + + +def arg_exists(node, arg_name) -> bool: + return arg_name in node.meta["mase"].parameters["common"]["args"] + + +def update_quant_meta_param(node, config: dict, mase_op: str) -> None: + quant_arith = config["name"] + assert quant_arith in quant_arith_to_list_fn, f"Unknown quant_arith: {quant_arith}" + """ + MASE_OP_TO_INPUT_ENTRIES_AND_ARGS: Give a mapping between config file and mase model + How it works: + We find the precision of a certain paramter "e.g data_in" using the precision partial function. 
+
+    The precision partial function take a config file and entry "e.g data_in",
+    and it will search through all the attributes under this entry based on the quantisation scheme,
+    returning a list of precision with the order same as attributes defined in QUANT_ARITH_TO_SUFFIXES
+
+    This precision list is then being mapped to mase data using 'arg'
+    """
+    for entry, arg in zip(*MASE_OP_TO_INPUT_ENTRIES_AND_ARGS[mase_op]):
+        if not arg_exists(node, arg):
+            continue
+        update_arg(
+            node,
+            arg_name=arg,
+            dtype=quant_arith,
+            precision=quant_arith_to_list_fn[quant_arith](config, entry),
+        )
+
+    for entry, arg in zip(*MASE_OP_TO_OUTPUT_ENTRIES[mase_op]):
+        # Quantise all the output to fixed point. TODO: Make this automatic. Hardware will need change too
+        if quant_arith == "binary" or quant_arith == "binary_residual":
+            update_result(
+                node,
+                output_name=arg,
+                dtype="binary",
+                precision=[32, 0, 1],  # [bitwidth, stochastic, bipolar]
+            )
+
+
+def relink_node_meta(node, model):
+    node.meta["mase"].node = node
+    node.meta["mase"].model = model
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py
new file mode 100644
index 000000000..93491fb75
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py
@@ -0,0 +1,44 @@
+import copy
+from .q_op_entries import FIXED_OP_ENTRIES
+
+"""
+MASE_OP_TO_ENTRIES = {
+    <mase_op>: {
+        "required": (...),
+        "optional": (...)
+    }
+}
+"""
+
+
+def get_q_op_entries(q_name: str, mase_op: str):
+    match q_name:
+        case "fixed":
+            op_entries = FIXED_OP_ENTRIES
+        case _:
+            raise ValueError(f"Unknown quantization arithmetic name: {q_name}")
+
+    if mase_op not in op_entries:
+        raise ValueError(
+            f"Unknown MASE operation name: {mase_op} for quantization arithmetic: {q_name}"
+        )
+
+    return op_entries[mase_op]
+
+
+def parse_node_q_config(q_config: dict, mase_op: str):
+    q_op_entries = get_q_op_entries(q_config["name"], mase_op)
+
+    required_keys = q_op_entries["required"]
+    optional_keys = q_op_entries["optional"]
+
+    parsed_q_config = {}
+    for k in required_keys:
+        assert k in q_config, f"Required key {k} not found in q_config: {q_config}"
+        parsed_q_config[k] = copy.deepcopy(q_config[k])
+
+    for k in optional_keys:
+        if k in q_config:
+            parsed_q_config[k] = copy.deepcopy(q_config[k])
+
+    return parsed_q_config
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/__init__.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/__init__.py
new file mode 100644
index 000000000..af53c9d1f
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/__init__.py
@@ -0,0 +1 @@
+from .fixed import FIXED_OP_ENTRIES
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py
new file mode 100644
index 000000000..63d8fa0bc
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py
@@ -0,0 +1,68 @@
+FIXED_OP_ENTRIES = {
+    "add": {
+        "required": ("name", "data_in_width", "data_in_frac_width"),
+        "optional": ("bypass",),
+    },
+    "bmm": {
+        "required": (
+            "name",
+            "data_in_width",
+            "data_in_frac_width",
+            "weight_width",
+            "weight_frac_width",
+        ),
+        "optional": ("bypass",),
+    },
+    "conv1d": {
+        "required": (
+            "name",
+            "data_in_width",
+            "data_in_frac_width",
+            "weight_width",
+            "weight_frac_width",
+        ),
+        "optional": ("bypass", "bias_width", "bias_frac_width"),
+    },
+    "conv2d": {
+        "required": (
+            "name",
+            "data_in_width",
+            "data_in_frac_width",
+            "weight_width",
+            "weight_frac_width",
+        ),
+        "optional": ("bypass", "bias_width", "bias_frac_width"),
+    },
+    "linear": {
+        "required": (
+            "name",
+            "data_in_width",
+            "data_in_frac_width",
+            "weight_width",
+            "weight_frac_width",
+        ),
+        "optional": ("bypass", "cache_quantized_weight", "bias_width", "bias_frac_width"),
+    },
+    "matmul": {
+        "required": (
+            "name",
+            "data_in_width",
+            "data_in_frac_width",
+            "weight_width",
+            "weight_frac_width",
+        ),
+        "optional": ("bypass",),
+    },
+    "relu": {
+        "required": ("name", "data_in_width", "data_in_frac_width"),
+        "optional": ("bypass",),
+    },
+    "sub": {
+        "required": ("name", "data_in_width", "data_in_frac_width"),
+        "optional": ("bypass",),
+    },
+    "rotary_positional_encoding": {
+        "required": ("name", "data_in_width", "data_in_frac_width"),
+        "optional": ("bypass",),
+    },
+}
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py
index 5f980e10f..985a3c4e9 100644
--- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py
@@ -1,134 +1,363 @@
-from functools import partial
-
-
-def entry_to_list(config: dict, entry: str, suffixes: tuple[str]):
-    """e.g. [data_in_frac_width, data_in_width]"""
-    return list(config[f"{entry}_{suffix}"] for suffix in suffixes)
-
-
-QUANT_ARITH_TO_SUFFIXES = {
-    "integer": ("width", "frac_width"),
-    "fixed": ("width", "frac_width"),
-    "binary": (
-        "width",
-        "stochastic",
-        "bipolar",
-    ),  # TODO: stochastic, bipolar flags are operational flag instead of precision.
-    "binary_residual": (
-        "width",
-        "stochastic",
-        "bipolar",
-    ),  # TODO: stochastic, bipolar flags are operational flag instead of precision.
- "lutnet": ("width", "input_expanded", "k", "binarization_level"), - "logicnets": ("width", "frac_width"), - "ternary": ("width", "scaling_factor", "mean", "median", "max"), - "minifloat_ieee": ("width", "exponent_width", "exponent_bias"), - "minifloat_denorm": ("width", "exponent_width", "exponent_bias"), - "log": ("width", "exponent_bias"), - "block_fp": ("width", "exponent_width", "exponent_bias", "block_size"), - "block_minifloat": ("width", "exponent_width", "exponent_bias_width", "block_size"), - "block_log": ("width", "exponent_bias_width", "block_size"), -} +import logging +from ....utils import get_mase_op, get_mase_type + +logger = logging.getLogger(__name__) -# quant_arith_to_list_fn = { -# : { -# : entry_to_list_ -# } -quant_arith_to_list_fn = {} -for quant_arith, suffixes in QUANT_ARITH_TO_SUFFIXES.items(): - quant_arith_to_list_fn[quant_arith] = partial(entry_to_list, suffixes=suffixes) - - -def update_arg(node, arg_name, dtype=None, precision=None, size=None): - if dtype is not None: - node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = dtype - if precision is not None: - node.meta["mase"].parameters["common"]["args"][arg_name][ - "precision" - ] = precision - if size is not None: - node.meta["mase"].parameters["common"]["args"][arg_name]["size"] = size - - -MASE_OP_TO_INPUT_ENTRIES_AND_ARGS = { - # entry and arg corresponding to name in software and hardware mapping - "add": (("data_in", "data_in"), ("data_in_0", "data_in_1")), - "bmm": (("data_in", "weight"), ("data_in_0", "data_in_1")), - "conv1d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), - "conv2d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), - "matmul": (("data_in", "weight"), ("data_in_0", "data_in_1")), - "mul": (("data_in", "data_in"), ("data_in_0", "data_in_1")), - "linear": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), - "relu": (("data_in",), ("data_in_0",)), - "sub": (("data_in", "data_in"), ("data_in_0", "data_in_1")), +OPERANDS_TO_META_ARG_NAMES = { + "add": { + "required": (("data_in", "data_in"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "bmm": { + "required": (("data_in", "weight"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "conv1d": { + "required": (("data_in", "weight"), ("data_in_0", "weight")), + "optional": (("bias",), ("bias",)), + }, + "conv2d": { + "required": (("data_in", "weight"), ("data_in_0", "weight")), + "optional": (("bias",), ("bias",)), + }, + "matmul": { + "required": (("data_in", "weight"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "mul": { + "required": (("data_in", "data_in"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "linear": { + "required": (("data_in", "weight"), ("data_in_0", "weight")), + "optional": (("bias",), ("bias",)), + }, + "relu": { + "required": (("data_in",), ("data_in_0",)), + "optional": None, + }, + "sub": { + "required": (("data_in", "data_in"), ("data_in_0", "data_in_1")), + "optional": None, + }, } -def update_result(node, output_name, dtype=None, precision=None, size=None): - if dtype is not None: - node.meta["mase"].parameters["common"]["results"][output_name]["type"] = dtype - if precision is not None: - node.meta["mase"].parameters["common"]["results"][output_name][ - "precision" - ] = precision - if size is not None: - node.meta["mase"].parameters["common"]["results"][output_name]["size"] = size - - -MASE_OP_TO_OUTPUT_ENTRIES = { - # entry and arg corresponding to name in software and hardware mapping - "add": 
(("data_out",), ("data_out_0",)), - "bmm": (("data_out",), ("data_out_0",)), - "conv1d": (("data_out",), ("data_out_0",)), - "conv2d": (("data_out",), ("data_out_0",)), - "matmul": (("data_out",), ("data_out_0",)), - "mul": (("data_out",), ("data_out_0",)), - "linear": (("data_out",), ("data_out_0",)), - "relu": (("data_out",), ("data_out_0",)), - "sub": (("data_out",), ("data_out_0",)), -} +def update_node_meta_param_fixed(node, q_config): + """Add fixed-point precision to node meta for quantization + + Precision format: [width, frac_width] + """ + mase_op = get_mase_op(node) + if mase_op not in OPERANDS_TO_META_ARG_NAMES: + raise ValueError( + f"Unsupported MASE operation name `{mase_op}` for updating node meta for quantization" + ) + required_args = OPERANDS_TO_META_ARG_NAMES[mase_op]["required"] + optional_args = OPERANDS_TO_META_ARG_NAMES[mase_op]["optional"] -def arg_exists(node, arg_name) -> bool: - return arg_name in node.meta["mase"].parameters["common"]["args"] + for operand_name, arg_name in zip(*required_args): + node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = "fixed" + node.meta["mase"].parameters["common"]["args"][arg_name]["precision"] = [ + q_config[f"{operand_name}_width"], + q_config[f"{operand_name}_frac_width"], + ] + if optional_args is not None: + for operand_name, arg_name in zip(*optional_args): + if arg_name in node.meta["mase"].parameters["common"]["args"]: + if not ( + f"{operand_name}_width" in q_config + and f"{operand_name}_frac_width" in q_config + ): + raise RuntimeError( + f"Optional argument {arg_name} found in node meta, but not found in q_config: {q_config}" + ) + node.meta["mase"].parameters["common"]["args"][arg_name][ + "type" + ] = "fixed" + node.meta["mase"].parameters["common"]["args"][arg_name][ + "precision" + ] = [ + q_config[f"{operand_name}_width"], + q_config[f"{operand_name}_frac_width"], + ] -def update_quant_meta_param(node, config: dict, mase_op: str) -> None: - quant_arith = config["name"] - assert quant_arith in quant_arith_to_list_fn, f"Unknown quant_arith: {quant_arith}" - """ - MASE_OP_TO_INPUT_ENTRIES_AND_ARGS: Give a mapping between config file and mase model - How it works: - We find the precision of a certain paramter "e.g data_in" using the precision partial function. 
-    The precision partial function take a config file and entry "e.g data_in",
-    and it will search through all the attributes under this entry based on the quantisation scheme,
-    returning a list of precision with the order same as attributes defined in QUANT_ARITH_TO_SUFFIXES
+def relink_node_meta(node, model):
+    node.meta["mase"].node = node
+    node.meta["mase"].model = model
+
+
+def update_q_meta_param(node, config: dict):
+    q_arith = config["name"]
+
+    match q_arith:
+        case "fixed":
+            update_node_meta_param_fixed(node, config)
+        case _:
+            raise ValueError(f"Unsupported quantization arithmetic name: {q_arith}")
+
+
+from torch.fx import Node
+
 
-    This precision list is then being mapped to mase data using 'arg'
+def find_next_compute_node(node: Node):
+    for n in node.users:
+        if get_mase_type(n) in ["module_related_func", "builtin_func"]:
+            return node, n
+    for n in node.users:
+        return find_next_compute_node(n)
+    return None, None
+
+
+def find_prev_compute_node(node: Node):
+    for n in node.all_input_nodes:
+        if get_mase_type(n) in ["module_related_func", "builtin_func"]:
+            return node, n
+    for n in node.all_input_nodes:
+        return find_prev_compute_node(n)
+    return None, None
+
+
+def infer_result_dtype_and_precision(node: Node):
+    """
+    ```text
+    n_1   n_2
+      \   /
+      node
+    ```
+
+    assign node's args precision & dtype to n_1, n_2 results
+    """
+
+    if get_mase_type(node) == "placeholder":
+        # input node
+        input_node, next_node = find_next_compute_node(node)
+        if input_node is None:
+            logger.warning(
+                f"Failed to find next module_related_func node for input node {node.name}. Check if the graph contains module_related_func"
+            )
+            return
+        i = 0
+        for n in next_node.all_input_nodes:
+            if n is input_node:
+                break
+            i += 1
+        arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[i]
+        arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key]
+
+        for result in node.meta["mase"].parameters["common"]["results"]:
+            node.meta["mase"].parameters["common"]["results"][result]["type"] = (
+                arg_value["type"]
+            )
+            node.meta["mase"].parameters["common"]["results"][result]["precision"] = (
+                arg_value["precision"]
+            )
+
+        for arg in node.meta["mase"].parameters["common"]["args"]:
+            node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[
+                "type"
+            ]
+            node.meta["mase"].parameters["common"]["args"][arg]["precision"] = (
+                arg_value["precision"]
+            )
+
+        logger.debug(
+            f"Inferred arg & result dtype and precision for input node `{node.name}` using `{next_node.name}`"
+        )
-    for entry, arg in zip(*MASE_OP_TO_INPUT_ENTRIES_AND_ARGS[mase_op]):
-        if not arg_exists(node, arg):
-            continue
-        update_arg(
-            node,
-            arg_name=arg,
-            dtype=quant_arith,
-            precision=quant_arith_to_list_fn[quant_arith](config, entry),
-        )
 
-    for entry, arg in zip(*MASE_OP_TO_OUTPUT_ENTRIES[mase_op]):
-        # Quantise all the output to fixed point. TODO: Make this automatic. Hardware will need change too
-        if quant_arith == "binary" or quant_arith == "binary_residual":
-            update_result(
-                node,
-                output_name=arg,
-                dtype="binary",
-                precision=[32, 0, 1],  # [bitwidth, stochastic, bipolar]
+    elif get_mase_type(node) in ["module_related_func", "builtin_func"]:
+        input_node, next_node = find_next_compute_node(node)
+        if next_node is None:
+            # this is the last compute node in the graph, use its args to infer dtype and precision
+            max_precision = None
+            max_dtype = None
+            max_bitwidth = 0
+            for arg in node.meta["mase"].parameters["common"]["args"]:
+                if not isinstance(
+                    node.meta["mase"].parameters["common"]["args"][arg], dict
+                ):
+                    continue
+                if (
+                    not "precision"
+                    in node.meta["mase"].parameters["common"]["args"][arg]
+                ):
+                    continue
+                cur_width = node.meta["mase"].parameters["common"]["args"][arg][
+                    "precision"
+                ][0]
+                if cur_width > max_bitwidth:
+                    max_bitwidth = cur_width
+                    max_precision = node.meta["mase"].parameters["common"]["args"][arg][
+                        "precision"
+                    ]
+                    max_dtype = node.meta["mase"].parameters["common"]["args"][arg][
+                        "type"
+                    ]
+
+            if max_precision is None:
+                raise RuntimeError(
+                    f"Failed to infer dtype and precision for module_related_func node {node.name}"
+                )
+
+            for result in node.meta["mase"].parameters["common"]["results"]:
+                node.meta["mase"].parameters["common"]["results"][result][
+                    "type"
+                ] = max_dtype
+                node.meta["mase"].parameters["common"]["results"][result][
+                    "precision"
+                ] = max_precision
+            logger.debug(
+                f"Inferred result dtype and precision for module_related_func node `{node.name}` using its args"
             )
+        else:
+            # use next compute node's args to infer dtype and precision
+            i = 0
+            for n in next_node.all_input_nodes:
+                if n is input_node:
+                    break
+                i += 1
+            arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[
+                i
+            ]
+            arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key]
+            for result in node.meta["mase"].parameters["common"]["results"]:
+                node.meta["mase"].parameters["common"]["results"][result]["type"] = (
+                    arg_value["type"]
+                )
+                node.meta["mase"].parameters["common"]["results"][result][
+                    "precision"
+                ] = arg_value["precision"]
+            logger.debug(
+                f"Inferred result dtype and precision for module_related_func node `{node.name}` using `{next_node.name}`"
+            )
 
-def relink_node_meta(node, model):
-    node.meta["mase"].node = node
-    node.meta["mase"].model = model
+    elif get_mase_type(node) == "implicit_func":
+        input_node, next_node = find_next_compute_node(node)
+        user_node, prev_node = find_prev_compute_node(node)
+
+        if next_node is not None:
+            i = 0
+            for n in next_node.all_input_nodes:
+                if n is input_node:
+                    break
+                i += 1
+            arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[
+                i
+            ]
+            arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key]
+
+            for result in node.meta["mase"].parameters["common"]["results"]:
+                node.meta["mase"].parameters["common"]["results"][result]["type"] = (
+                    arg_value["type"]
+                )
+                node.meta["mase"].parameters["common"]["results"][result][
+                    "precision"
+                ] = arg_value["precision"]
+
+            for arg in node.meta["mase"].parameters["common"]["args"]:
+                if not isinstance(
+                    node.meta["mase"].parameters["common"]["args"][arg], dict
+                ):
+                    continue
+                if (
+                    not "precision"
+                    in node.meta["mase"].parameters["common"]["args"][arg]
+                ):
+                    continue
+                node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[
+                    "type"
+                ]
+                node.meta["mase"].parameters["common"]["args"][arg]["precision"] = (
+                    arg_value["precision"]
+                )
+            logger.debug(
+                f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{next_node.name}`"
+            )
+
+        elif prev_node is not None:
+            i = 0
+            for n in prev_node.users:
+                if n is user_node:
+                    break
+                i += 1
+            arg_key = list(prev_node.meta["mase"].parameters["common"]["args"].keys())[
+                i
+            ]
+            arg_value = prev_node.meta["mase"].parameters["common"]["args"][arg_key]
+
+            for result in node.meta["mase"].parameters["common"]["results"]:
+                node.meta["mase"].parameters["common"]["results"][result]["type"] = (
+                    arg_value["type"]
+                )
+                node.meta["mase"].parameters["common"]["results"][result][
+                    "precision"
+                ] = arg_value["precision"]
+
+            for arg in node.meta["mase"].parameters["common"]["args"]:
+                node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[
+                    "type"
+                ]
+                node.meta["mase"].parameters["common"]["args"][arg]["precision"] = (
+                    arg_value["precision"]
+                )
+            logger.debug(
+                f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{prev_node.name}`"
+            )
+
+        else:
+            raise RuntimeError(
+                f"Failed to infer dtype and precision for implicit_func node {node.name} as it has no input nodes or users of type `module_related_func`"
+            )
+
+    elif get_mase_type(node) == "output":
+        # output node
+        # find the max precision of all input nodes
+        user_node, prev_node = find_prev_compute_node(node)
+
+        if prev_node is None:
+            raise RuntimeError(
+                f"Failed to find prev module_related_func node for output node {node.name}"
+            )
+
+        max_precision = None
+        max_dtype = None
+        max_bitwidth = 0
+
+        i = 0
+        for n in prev_node.users:
+            if n is user_node:
+                break
+            i += 1
+
+        arg_key = list(prev_node.meta["mase"].parameters["common"]["args"].keys())[i]
+        arg_value = prev_node.meta["mase"].parameters["common"]["args"][arg_key]
+
+        for result in node.meta["mase"].parameters["common"]["results"]:
+            node.meta["mase"].parameters["common"]["results"][result]["type"] = (
+                arg_value["type"]
+            )
+            node.meta["mase"].parameters["common"]["results"][result]["precision"] = (
+                arg_value["precision"]
+            )
+
+        for arg in node.meta["mase"].parameters["common"]["args"]:
+            node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[
+                "type"
+            ]
+            node.meta["mase"].parameters["common"]["args"][arg]["precision"] = (
+                arg_value["precision"]
+            )
+
+        logger.debug(
+            f"Inferred dtype and precision for output node `{node.name}` using `{prev_node.name}`"
+        )
+
+    else:
+        raise RuntimeError(
+            f"Unsupported node type {get_mase_type(node)} for inferring dtype and precision"
+        )
diff --git a/machop/chop/passes/graph/transforms/quantize/quantize.py b/machop/chop/passes/graph/transforms/quantize/quantize.py
index ba3c62f10..febb27e50 100644
--- a/machop/chop/passes/graph/transforms/quantize/quantize.py
+++ b/machop/chop/passes/graph/transforms/quantize/quantize.py
@@ -15,7 +15,12 @@
 )
 
 from .modify import create_new_fn, create_new_module
-from .quant_parsers import parse_node_config, relink_node_meta, update_quant_meta_param
+from .quant_parsers import (
+    parse_node_q_config,
+    relink_node_meta,
+    update_q_meta_param,
+    infer_result_dtype_and_precision,
+)
 from .summary import graph_iterator_compare_nodes, graph_iterator_node_histogram
 
 logger = logging.getLogger(__name__)
@@ -63,7 +68,7 @@ def graph_iterator_quantize_by_type(graph, config: dict):
         node_config = get_config(config, get_mase_op(node))
         if node_config["name"] is None:
             continue
-        node_config = parse_node_config(node_config, get_mase_op(node))
+        node_config = parse_node_q_config(node_config, get_mase_op(node))
         # if get_mase_type(node) == "module":
         if node.op == "call_module":
             ori_module = get_node_actual_target(node)
@@ -82,7 +87,7 @@
             parent_name, name = get_parent_name(node.target)
             setattr(graph.modules[parent_name], name, new_module)
             # update precision and type in meta.parameters["common"]
-            update_quant_meta_param(node, node_config, get_mase_op(node))
+            update_q_meta_param(node, node_config)
         elif get_mase_type(node) in [
             "builtin_func",
             "module_related_func",
@@ -94,9 +99,19 @@
             new_node.meta["mase"] = copy(node.meta["mase"])
             # new_node.meta["mase"].node -> new_node
             relink_node_meta(new_node, model=graph.model)
-            update_quant_meta_param(new_node, node_config, get_mase_op(node))
+            update_q_meta_param(new_node, node_config)
             node.replace_all_uses_with(new_node)
             graph.fx_graph.erase_node(node)
+
+    for node in graph.fx_graph.nodes:
+        if get_mase_type(node) in [
+            "module_related_func",
+            "builtin_func",
+            "output",
+            "placeholder",
+            "implicit_func",
+        ]:
+            infer_result_dtype_and_precision(node)
     return graph
 
 
@@ -107,7 +122,7 @@ def graph_iterator_quantize_by_name(graph, config: dict):
         node_config = get_config(config, node.name)
         if node_config["name"] is None:
             continue
-        node_config = parse_node_config(node_config, get_mase_op(node))
+        node_config = parse_node_q_config(node_config, get_mase_op(node))
         output_layers_names = node_config.get("additional_layers_outputs", [])
         output_layers = [
             get_node_target_by_name(graph, name) for name in output_layers_names
@@ -128,7 +143,7 @@
             )
             parent_name, name = get_parent_name(node.target)
             setattr(graph.modules[parent_name], name, new_module)
-            update_quant_meta_param(node, node_config, get_mase_op(node))
+            update_q_meta_param(node, node_config)
             logger.debug(f"Quantized module: {node.target} with config: {node_config}")
         elif get_mase_type(node) in [
             "builtin_func",
@@ -140,7 +155,7 @@
             new_node.name = node.name
             new_node.meta["mase"] = copy(node.meta["mase"])
             relink_node_meta(new_node, model=graph.model)
-            update_quant_meta_param(new_node, node_config, get_mase_op(node))
+            update_q_meta_param(new_node, node_config)
             node.replace_all_uses_with(new_node)
             graph.fx_graph.erase_node(node)
             logger.debug(
@@ -150,6 +165,15 @@
         raise ValueError(
             "Unsupported node type for quantisation: {}".format(get_mase_type(node))
         )
+    for node in graph.fx_graph.nodes:
+        if get_mase_type(node) in [
+            "module_related_func",
+            "builtin_func",
+            "output",
+            "placeholder",
+            "implicit_func",
+        ]:
+            infer_result_dtype_and_precision(node)
     return graph
 
 
@@ -165,7 +189,7 @@ def graph_iterator_quantize_by_regex_name(graph, config: dict):
         node_config = get_config(config, matched_pattern)
         if node_config["name"] is None:
             continue
-        node_config = parse_node_config(node_config, get_mase_op(node))
+        node_config = parse_node_q_config(node_config, get_mase_op(node))
         # if get_mase_type(node) == "module":
         if node.op == "call_module":
             ori_module = graph.modules[node.target]
@@ -174,7 +198,7 @@
             )
             parent_name, name = get_parent_name(node.target)
             setattr(graph.modules[parent_name], name, new_module)
-            update_quant_meta_param(node, node_config, get_mase_op(node))
+            update_q_meta_param(node, node_config)
         elif get_mase_type(node) in [
             "builtin_func",
"module_related_func", @@ -185,13 +209,22 @@ def graph_iterator_quantize_by_regex_name(graph, config: dict): new_node.name = node.name new_node.meta["mase"] = deepcopy(node.meta["mase"]) relink_node_meta(new_node, model=graph.model) - update_quant_meta_param(new_node, node_config, get_mase_op(node)) + update_q_meta_param(new_node, node_config) node.replace_all_uses_with(new_node) graph.fx_graph.erase_node(node) else: raise ValueError( "Unsupported node type for quantisation:{}".format(get_mase_type(node)) ) + for node in graph.fx_graph.nodes: + if get_mase_type(node) in [ + "module_related_func", + "builtin_func", + "output", + "placeholder", + "implicit_func", + ]: + infer_result_dtype_and_precision(node) return graph diff --git a/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py b/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py index 710396e49..3411fe218 100644 --- a/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py +++ b/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py @@ -116,25 +116,29 @@ def __init__( integer_quantizer, width=b_width, frac_width=b_frac_width ) - # def get_output_bitwidth(self): - # config = self.config - # w_width, w_frac = config["weight_width"], config["weight_frac_width"] - # x_width, x_frac = config["data_in_width"], config["data_in_frac_width"] - # bias_width = config["bias_width"] - - # ops = self.in_features - # product_width = w_width + x_width - # product_frac_width = w_frac + x_frac - # # *: + 1 for bias - # output_width = max(bias_width, product_width + ceil(log2(ops))) + 1 - # output_frac_width = product_frac_width - - # o_bitwidth = {} - # o_bitwidth["data_out_width"] = output_width - # o_bitwidth["data_out_frac_width"] = output_frac_width - # # o_bitwidth["product_width"] = product_width - # # o_bitwidth["product_frac_width"] = product_frac_width - # return o_bitwidth + self.quantized_weight_is_cached = False + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + # if bypss, there is no quantization + return F.linear(x, self.weight, self.bias) + else: + x = self.x_quantizer(x) + if self.config.get("cache_quantized_weight", False): + if not self.quantized_weight_is_cached: + w = self.w_quantizer(self.weight) + self.weight.copy_(w) + if self.bias is not None: + bias = self.b_quantizer(self.bias) + self.bias.copy_(bias) + self.quantized_weight_is_cached = True + else: + w = self.weight + bias = self.bias + else: + w = self.w_quantizer(self.weight) + bias = self.b_quantizer(self.bias) if self.bias is not None else None + return F.linear(x, w, bias) class LinearMinifloatDenorm(_LinearBase): diff --git a/machop/chop/passes/graph/transforms/verilog/emit_bram.py b/machop/chop/passes/graph/transforms/verilog/emit_bram.py index 95a4bb5cb..19a17f36d 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_bram.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_bram.py @@ -34,18 +34,21 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): (Mostly because Vivado does not support string type parameters...) 
""" - # TODO: Force bias to have a depth of 1 for now - if param_name != "bias": - # out_depth = node.meta["mase"].parameters["hardware"]["verilog_param"][ - # "DATA_IN_0_DEPTH" - # ] - out_depth = 1 - else: - out_depth = 1 - addr_width = clog2(out_depth) + 1 total_size = math.prod( node.meta["mase"].parameters["common"]["args"][param_name]["shape"] ) + + dim = len(node.meta["mase"].parameters["common"]["args"][param_name]["shape"]) + out_depth = 1 + for i in range(dim): + out_depth *= int( + math.ceil( + node.meta["mase"].parameters["common"]["args"][param_name]["shape"][i] + / node.meta["mase"].parameters["hardware"]["parallelism"][param_name][i] + ) + ) + addr_width = clog2(out_depth) + 1 + # The depth of parameters must match with the input depth assert ( total_size % out_depth == 0 @@ -71,7 +74,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): module {node_param_name}_rom #( parameter DWIDTH = {out_size*out_width}, parameter MEM_SIZE = {out_depth}, - parameter AWIDTH = $clog2(MEM_SIZE) + 1 + parameter AWIDTH = $clog2(MEM_SIZE+1) ) ( input clk, input logic [AWIDTH-1:0] addr0, @@ -83,9 +86,9 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): logic [DWIDTH-1:0] q0_t0; logic [DWIDTH-1:0] q0_t1; - // initial begin - // $readmemh("{data_name}", ram); - // end + initial begin + $readmemh("{data_name}", ram); + end assign q0 = q0_t1; @@ -96,9 +99,9 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): `timescale 1 ns / 1 ps module {node_param_name} #( - parameter DATA_WIDTH = 32'd{out_width*out_size}, - parameter ADDR_RANGE = 32'd{out_depth}, - parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1 + parameter DATA_WIDTH = {out_width*out_size}, + parameter ADDR_RANGE = {out_depth}, + parameter ADDR_WIDTH = $clog2(ADDR_RANGE+1) ) ( input reset, input clk, @@ -126,7 +129,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): parameter {_cap(param_name)}_PARALLELISM_DIM_0 = 1, parameter {_cap(param_name)}_PARALLELISM_DIM_1 = 1, - parameter OUT_DEPTH = {_cap(param_name)}_TENSOR_SIZE_DIM_0 / {_cap(param_name)}_PARALLELISM_DIM_0 + parameter OUT_DEPTH = {_cap(param_name)}_TENSOR_SIZE_DIM_0 * {_cap(param_name)}_TENSOR_SIZE_DIM_1 / ({_cap(param_name)}_PARALLELISM_DIM_0 * {_cap(param_name)}_PARALLELISM_DIM_1) ) ( input clk, input rst, @@ -136,14 +139,14 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): input data_out_ready ); // 1-bit wider so IN_DEPTH also fits. 
- localparam COUNTER_WIDTH = $clog2(OUT_DEPTH); - logic [COUNTER_WIDTH:0] counter; + localparam COUNTER_WIDTH = $clog2(OUT_DEPTH+1); + logic [COUNTER_WIDTH-1:0] counter; always_ff @(posedge clk) if (rst) counter <= 0; else begin if (data_out_ready) begin - if (counter == OUT_DEPTH - 1) counter <= 0; + if (counter == COUNTER_WIDTH'(OUT_DEPTH) - 1) counter <= 0; else counter <= counter + 1; end end @@ -151,9 +154,9 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): logic ce0; assign ce0 = 1; - logic [{_cap(param_name)}_PRECISION_0*{_cap(param_name)}_TENSOR_SIZE_DIM_0-1:0] data_vector; + logic [{_cap(param_name)}_PRECISION_0*{_cap(param_name)}_PARALLELISM_DIM_0-1:0] data_vector; {node_param_name} #( - .DATA_WIDTH({_cap(param_name)}_PRECISION_0 * {_cap(param_name)}_TENSOR_SIZE_DIM_0), + .DATA_WIDTH({_cap(param_name)}_PRECISION_0 * {_cap(param_name)}_PARALLELISM_DIM_0), .ADDR_RANGE(OUT_DEPTH) ) {node_param_name}_mem ( .clk(clk), @@ -165,7 +168,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): // Cocotb/verilator does not support array flattening, so // we need to manually add some reshaping process. - for (genvar j = 0; j < {_cap(param_name)}_TENSOR_SIZE_DIM_0; j++) + for (genvar j = 0; j < {_cap(param_name)}_PARALLELISM_DIM_0; j++) assign data_out[j] = data_vector[{_cap(param_name)}_PRECISION_0*j+{_cap(param_name)}_PRECISION_0-1:{_cap(param_name)}_PRECISION_0*j]; assign data_out_valid = 1; diff --git a/machop/chop/passes/graph/transforms/verilog/emit_internal.py b/machop/chop/passes/graph/transforms/verilog/emit_internal.py index 8b40a1b83..d29186ae0 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_internal.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_internal.py @@ -46,7 +46,7 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): for node in graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue - if "INTERNAL_RTL" == node.meta["mase"].parameters["hardware"]["toolchain"]: + if "INTERNAL" == node.meta["mase"].parameters["hardware"]["toolchain"]: if ( hasattr(node.meta["mase"].module, "config") and node.meta["mase"].module.config.get("name", "") == "logicnets" @@ -70,11 +70,12 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): "..", "..", "..", - "..", "mase_components", ) for f in rtl_dependencies: - shutil.copy(os.path.join(hardware_dir, f), rtl_dir) + fname = os.path.join(hardware_dir, f) + assert os.path.isfile(fname), f"Cannot find file {fname}." 
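+ # the dependency exists; stage it into the project's rtl directory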
+ shutil.copy(fname, rtl_dir) return graph, {} diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py index ead3b3077..1d9a2640f 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_top.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py @@ -56,6 +56,10 @@ def param_needs_signals(node, param, value, qualifier="data_in"): ) +# ============================================================================= +# Emit design in a memory-independent dataflow graph +# ============================================================================= + # ============================================================================= # Verilog parameters # ============================================================================= @@ -65,16 +69,16 @@ class VerilogParameterEmitter: def __init__(self, graph): self.graph = graph - def emit(self, graph, parameter_map) -> Tuple[str, Dict[str, str]]: + def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]: """ Emit parameters at the top-level for the top-level module Returns Tuple: - 1) list of parameters as a string to be embedded in Verilog file + 1) list of parameters as a string to be embedded in DFVerilog file """ - nodes_in = graph.nodes_in - nodes_out = graph.nodes_out + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out node_in_name = vf(nodes_in[0].name) node_out_name = vf(nodes_out[0].name) @@ -88,15 +92,15 @@ def emit(self, graph, parameter_map) -> Tuple[str, Dict[str, str]]: # ============================================================================= -# Verilog interface +# DFVerilog interface # ============================================================================= -class VerilogInterfaceEmitter: +class DFVerilogInterfaceEmitter: def __init__(self, graph): self.graph = graph - def emit(self, graph, parameter_map): + def emit(self, parameter_map): """ Emit interface signal declarations for the top-level module """ @@ -142,17 +146,36 @@ def emit(self, graph, parameter_map): input data_out_{i}_ready,""" i += 1 - # TODO: emit off-chip parameter interface + # Emit all parameters as inputs (they will be mapped at the top-level) + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + if "data_in" not in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + input [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0], + input {node_name}_{arg}_valid, + output {node_name}_{arg}_ready,""" + i += 1 return _remove_last_comma(interface) # ============================================================================= -# Verilog signals +# DFVerilog signals # ============================================================================= -class VerilogSignalEmitter: +class DFVerilogSignalEmitter: def __init__(self, graph): self.graph = graph @@ -161,15 +184,7 @@ def _emit_signals_top_internal(self, node, parameter_map): node_name = vf(node.name) # Input signals for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - if not isinstance(arg_info, dict): - continue - - # Skip off-chip parameters as they will be directly connected to the top level - if ( - 
"data_in" in arg - or node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] - == "BRAM" - ): + if "data_in" in arg: arg_name = v2p(arg) parallelism_params = [ param @@ -188,14 +203,7 @@ def _emit_signals_top_internal(self, node, parameter_map): if not isinstance(result_info, dict): continue - # Skip off-chip parameters as they will be directly connected to the top level - if ( - "data_out" in result - or node.meta["mase"].parameters["hardware"]["interface"][result][ - "storage" - ] - == "BRAM" - ): + if "data_out" in result: result_name = v2p(result) parallelism_params = [ param @@ -259,13 +267,13 @@ def _emit_signals_top_hls(self, node, parameter_map): logic {node_name}_{key}_we0;""" return signals - def emit(self, graph, parameter_map): + def emit(self, parameter_map): """ Emit internal signal declarations for the top-level module """ signals = "" - for node in graph.fx_graph.nodes: + for node in self.graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue node_name = vf(node.name) @@ -284,37 +292,14 @@ def emit(self, graph, parameter_map): # ============================================================================= -# Verilog components (INTERNAL) +# DFVerilog components (INTERNAL) # ============================================================================= -class VerilogInternalComponentEmitter: +class DFVerilogInternalComponentEmitter: def __init__(self, graph): self.graph = graph - def _emit_module_parameters_top_internal(self, key, value, node, parameter_map): - node_name = vf(node.name) - component_name = f"{node_name}_{key}_source" - component_name_inst = f"{component_name}_0" - - parameters = "" - for param in node.meta["mase"].parameters["hardware"]["verilog_param"].keys(): - if f"{_cap(key)}_" in param: - parameters += f".{param}({node_name}_{param}),\n" - parameters = _remove_last_comma(parameters) - - return f""" -{component_name} #( -{parameters} -) {component_name_inst} ( - .clk(clk), - .rst(rst), - .data_out({node_name}_{key}), - .data_out_ready({node_name}_{key}_ready), - .data_out_valid({node_name}_{key}_valid) -); -""" - def emit(self, node, parameter_map): node_name = vf(node.name) component_name = node.meta["mase"].parameters["hardware"]["module"] @@ -332,7 +317,7 @@ def emit(self, node, parameter_map): # Emit component instantiation input signals for key, value in node.meta["mase"].parameters["common"]["args"].items(): - if "data" not in key: + if not isinstance(value, dict): continue signals += f""" .{key}({node_name}_{key}), @@ -342,7 +327,7 @@ def emit(self, node, parameter_map): # Emit component instantiation output signals for key, value in node.meta["mase"].parameters["common"]["results"].items(): - if "data" not in key: + if not isinstance(value, dict): continue signals += f""" .{key}({node_name}_{key}), @@ -363,26 +348,15 @@ def emit(self, node, parameter_map): ); """ - # Emit module parameter instances (e.g. 
weights and biases) - for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - if "data_in" in arg: - continue - if not isinstance(arg_info, dict): - continue - - components += self._emit_module_parameters_top_internal( - arg, arg_info, node, parameter_map - ) - return components # ============================================================================= -# Verilog components (HLS) +# DFVerilog components (HLS) # ============================================================================= -class VerilogHLSComponentEmitter: +class DFVerilogHLSComponentEmitter: def __init__(self, graph): self.graph = graph @@ -468,17 +442,17 @@ def emit(self, node, parameter_map): # ============================================================================= -# Verilog components +# DFVerilog components # ============================================================================= -class VerilogComponentEmitter: +class DFVerilogComponentEmitter: def __init__(self, graph): self.graph = graph - self.internal_emitter = VerilogInternalComponentEmitter(graph) - self.hls_emitter = VerilogHLSComponentEmitter(graph) + self.internal_emitter = DFVerilogInternalComponentEmitter(graph) + self.hls_emitter = DFVerilogHLSComponentEmitter(graph) - def emit(self, graph, parameter_map): + def emit(self, parameter_map): """ Emit component declarations for the top-level module """ @@ -488,7 +462,7 @@ def emit(self, graph, parameter_map): // Component instantiation // -------------------------- """ - for node in graph.fx_graph.nodes: + for node in self.graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: @@ -502,14 +476,13 @@ def emit(self, graph, parameter_map): # ============================================================================= -# Verilog wires +# DFVerilog wires # ============================================================================= -class VerilogWireEmitter: - def __init__(self, graph, parameter_map): +class DFVerilogWireEmitter: + def __init__(self, graph): self.graph = graph - self.parameter_map = parameter_map self.wires = """ // -------------------------- @@ -517,7 +490,7 @@ def __init__(self, graph, parameter_map): // -------------------------- """ - def _emit_top_wires(self): + def _emit_top_wires(self, parameter_map): nodes_in = self.graph.nodes_in nodes_out = self.graph.nodes_out @@ -580,7 +553,7 @@ def _emit_node2node_wires(self): """ return wires - def emit(self): + def emit(self, parameter_map): """ Emit internal signal connections for the top-level module This includes two interconnection types: @@ -588,7 +561,7 @@ def emit(self): 2. 
Interface casting between inputs and outputs """ - self.wires += self._emit_top_wires() + self.wires += self._emit_top_wires(parameter_map) self.wires += self._emit_node2node_wires() return self.wires @@ -598,39 +571,39 @@ def emit(self): # ============================================================================= -class VerilogEmitter: +class DataflowEmitter: def __init__(self, graph): self.graph = graph self.parameter_map = get_verilog_parameters(graph) - def emit(self, graph, top_name): - parameters_to_emit = VerilogParameterEmitter(graph).emit( - graph, self.parameter_map + def emit(self, top_name): + parameters_to_emit = VerilogParameterEmitter(self.graph).emit( + self.parameter_map ) - interface_to_emit = VerilogInterfaceEmitter(graph).emit( - graph, self.parameter_map + interface_to_emit = DFVerilogInterfaceEmitter(self.graph).emit( + self.parameter_map ) - signals_to_emit = VerilogSignalEmitter(graph).emit(graph, self.parameter_map) + signals_to_emit = DFVerilogSignalEmitter(self.graph).emit(self.parameter_map) - components_to_emit = VerilogComponentEmitter(graph).emit( - graph, self.parameter_map + components_to_emit = DFVerilogComponentEmitter(self.graph).emit( + self.parameter_map ) - wires_to_emit = VerilogWireEmitter(graph, self.parameter_map).emit() + wires_to_emit = DFVerilogWireEmitter(self.graph).emit(self.parameter_map) time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") module_inst = """ // ===================================== -// Mase Hardware +// Mase Hardware (Dataflow) // Model: {} // {} // ===================================== `timescale 1ns/1ps -module {} #( +module {}_dataflow #( {} ) ( input clk, @@ -654,6 +626,438 @@ def emit(self, graph, top_name): return module_inst +# ============================================================================= +# Emit top-level design with memory mapping +# ============================================================================= + +# ============================================================================= +# MMVerilog interface +# ============================================================================= + + +class MMVerilogInterfaceEmitter: + def __init__(self, graph): + self.graph = graph + + def emit(self, parameter_map): + """ + Emit interface signal declarations for the top-level module + """ + + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out + + interface = "" + # TODO: here we just enumerate the inputs of the input nodes - which may be + # order insensitive and require manual connection when adding the graph to + # a system. 
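+ # enumerate each input node's data_in args as numbered top-level data_in ports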
+ i = 0 + for node in nodes_in: + node_name = vf(node.name) + for arg in node.meta["mase"].parameters["common"]["args"].keys(): + if "data_in" in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + input [{node_name}_{arg_name}_PRECISION_0-1:0] data_in_{i} [{'*'.join(parallelism_params)}-1:0], + input data_in_{i}_valid, + output data_in_{i}_ready,""" + i += 1 + + i = 0 + for node in nodes_out: + node_name = vf(node.name) + for result in node.meta["mase"].parameters["common"]["results"].keys(): + if "data_out" in result: + result_name = _cap(result) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{result_name}_PARALLELISM_DIM" in param + ] + interface += f""" + output [{node_name}_{result_name}_PRECISION_0-1:0] data_out_{i} [{'*'.join(parallelism_params)}-1:0], + output data_out_{i}_valid, + input data_out_{i}_ready,""" + i += 1 + + # TODO: emit off-chip parameter interface + + return _remove_last_comma(interface) + + +# ============================================================================= +# MMVerilog signals +# ============================================================================= + + +class MMVerilogSignalEmitter: + def __init__(self, graph): + self.graph = graph + + def _emit_signals_top_internal(self, node, parameter_map): + signals = "" + node_name = vf(node.name) + # Input signals + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + if "data_in" in arg: + continue + + # Skip off-chip parameters as they will be directly connected to the top level + if ( + node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] + == "BRAM" + ): + arg_name = v2p(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + signals += f""" +logic [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0]; +logic {node_name}_{arg}_valid; +logic {node_name}_{arg}_ready;""" + + # Output signals + for result, result_info in ( + node.meta["mase"].parameters["common"]["results"].items() + ): + if not isinstance(result_info, dict): + continue + if "data_out" in result: + continue + + # Skip off-chip parameters as they will be directly connected to the top level + if ( + node.meta["mase"].parameters["hardware"]["interface"][result]["storage"] + == "BRAM" + ): + result_name = v2p(result) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{result_name}_PARALLELISM_DIM" in param + ] + signals += f""" +logic [{node_name}_{result_name}_PRECISION_0-1:0] {node_name}_{result} [{'*'.join(parallelism_params)}-1:0]; +logic {node_name}_{result}_valid; +logic {node_name}_{result}_ready;""" + + return signals + + def _emit_signals_top_hls(self, node, parameter_map): + """ + TODO + """ + + node_name = vf(node.name) + # Control signals for HLS component + signals = f""" +logic {node_name}_start; +logic {node_name}_done; +logic {node_name}_idle; +logic {node_name}_ready; +logic {node_name}_ce;""" + + # Input signals + for key, value in node.meta["mase"].parameters["common"]["args"].items(): + # No internal signals if the memory is stored off chip + if not param_needs_signals(node, key, value, qualifier="data_in"): + continue + + cap_key = v2p(key) + size = math.prod(value["shape"]) + + if key != "data_in": + a_width = math.ceil(math.log2(size)) 
+ else: + depth = parameter_map[f"{node_name}_{cap_key}_DEPTH"] + a_width = math.ceil(math.log2(depth * size)) + + signals += f""" +logic [{node_name}_{cap_key}_PRECISION_0-1:0] {node_name}_{key}_q0; +logic [{a_width}-1:0] {node_name}_{key}_address0; +logic {node_name}_{key}_ce0;""" + + # Output signals + for key, value in node.meta["mase"].parameters["common"]["results"].items(): + # No internal signals if the memory is stored off chip + if not param_needs_signals(node, key, value, qualifier="data_out"): + continue + + cap_key = v2p(key) + size = math.prod(value["shape"]) + a_width = math.ceil(math.log2(size)) + signals += f""" +logic [{node_name}_{cap_key}_PRECISION_0-1:0] {node_name}_{key}_d0; +logic [{a_width}-1:0] {node_name}_{key}_address0; +logic {node_name}_{key}_ce0; +logic {node_name}_{key}_we0;""" + return signals + + def emit(self, parameter_map): + """ + Emit internal signal declarations for the top-level module + """ + + signals = "" + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: + signals += self._emit_signals_top_internal(node, parameter_map) + elif node.meta["mase"].parameters["hardware"]["toolchain"] == "HLS": + signals += self._emit_signals_top_hls(node, parameter_map) + else: + assert False, "Unknown node toolchain for signal declarations." + + return signals + + +# ============================================================================= +# MMVerilog components (INTERNAL) +# ============================================================================= + + +class MMVerilogInternalComponentEmitter: + def __init__(self, graph): + self.graph = graph + + def _emit_module_parameters_top_internal(self, key, value, node, parameter_map): + node_name = vf(node.name) + component_name = f"{node_name}_{key}_source" + component_name_inst = f"{component_name}_0" + + parameters = "" + for param in node.meta["mase"].parameters["hardware"]["verilog_param"].keys(): + if f"{_cap(key)}_" in param: + parameters += f".{param}({node_name}_{param}),\n" + parameters = _remove_last_comma(parameters) + + return f""" +{component_name} #( +{parameters} +) {component_name_inst} ( + .clk(clk), + .rst(rst), + .data_out({node_name}_{key}), + .data_out_ready({node_name}_{key}_ready), + .data_out_valid({node_name}_{key}_valid) +); +""" + + def emit(self, node, parameter_map): + node_name = vf(node.name) + component_name = node.meta["mase"].parameters["hardware"]["module"] + signals = "" + + # Emit module parameter instances (e.g. 
weights and biases) + components = "" + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if "data_in" in arg: + continue + if not isinstance(arg_info, dict): + continue + + components += self._emit_module_parameters_top_internal( + arg, arg_info, node, parameter_map + ) + + return components + + +# ============================================================================= +# MMVerilog top interface connected to the dataflow design +# ============================================================================= + + +class MMVerilogTopInterfaceEmitter: + def __init__(self, graph): + self.graph = graph + + def emit(self, parameter_map): + """ + Emit interface signal declarations for the top-level module + """ + + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out + + interface = """ + .clk(clk), + .rst(rst), +""" + # TODO: here we just enumerate the inputs of the input nodes - which may be + # order insensitive and require manual connection when adding the graph to + # a system. + i = 0 + for node in nodes_in: + node_name = vf(node.name) + for arg in node.meta["mase"].parameters["common"]["args"].keys(): + if "data_in" in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + .data_in_{i}(data_in_{i}), + .data_in_{i}_valid(data_in_{i}_valid), + .data_in_{i}_ready(data_in_{i}_ready),""" + i += 1 + + i = 0 + for node in nodes_out: + node_name = vf(node.name) + for result in node.meta["mase"].parameters["common"]["results"].keys(): + if "data_out" in result: + result_name = _cap(result) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{result_name}_PARALLELISM_DIM" in param + ] + interface += f""" + .data_out_{i}(data_out_{i}), + .data_out_{i}_valid(data_out_{i}_valid), + .data_out_{i}_ready(data_out_{i}_ready),""" + i += 1 + + # Emit all parameters as inputs (they will be mapped at the top-level) + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + if "data_in" not in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + .{node_name}_{arg}({node_name}_{arg}), + .{node_name}_{arg}_valid({node_name}_{arg}_valid), + .{node_name}_{arg}_ready({node_name}_{arg}_ready),""" + i += 1 + + return _remove_last_comma(interface) + + +# ============================================================================= +# MMVerilog components +# ============================================================================= + + +class MMVerilogComponentEmitter: + def __init__(self, graph): + self.graph = graph + self.internal_emitter = MMVerilogInternalComponentEmitter(graph) + + def emit(self, parameter_map, top): + """ + Emit component declarations for the top-level module + """ + + # Write node parameters + top_parameters = "" + for key, value in parameter_map.items(): + top_parameters += f""" .{key}({key}),\n""" + top_parameters = _remove_last_comma(top_parameters) + + interface = MMVerilogTopInterfaceEmitter(self.graph).emit(parameter_map) + + components = f""" +// -------------------------- +// Component instantiation +// -------------------------- +{top}_dataflow #({top_parameters} +) 
{top}_df_inst ({interface} +); +""" + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: + components += self.internal_emitter.emit(node, parameter_map) + elif node.meta["mase"].parameters["hardware"]["toolchain"] == "HLS": + # Assume all parameters in HLS components are local + continue + else: + assert False, "Unknown node toolchain for signal declarations." + + return components + + +# ============================================================================= +# Emit MMVerilog +# ============================================================================= + + +class MemoryMapEmitter: + def __init__(self, graph): + self.graph = graph + + self.parameter_map = get_verilog_parameters(graph) + + def emit(self, top_name): + parameters_to_emit = VerilogParameterEmitter(self.graph).emit( + self.parameter_map + ) + + interface_to_emit = MMVerilogInterfaceEmitter(self.graph).emit( + self.parameter_map + ) + + signals_to_emit = MMVerilogSignalEmitter(self.graph).emit(self.parameter_map) + + components_to_emit = MMVerilogComponentEmitter(self.graph).emit( + self.parameter_map, top_name + ) + + time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") + + module_inst = """ +// ===================================== +// Mase Hardware (Memory Map) +// Model: {} +// {} +// ===================================== +`timescale 1ns/1ps +module {} #( +{} +) ( + input clk, + input rst, +{} +); +{} +{} +endmodule + """.format( + top_name, + time_to_emit, + top_name, + parameters_to_emit, + interface_to_emit, + signals_to_emit, + components_to_emit, + ) + return module_inst + + def emit_verilog_top_transform_pass(graph, pass_args={}): """Emit the top-level model design in Verilog @@ -680,10 +1084,18 @@ def emit_verilog_top_transform_pass(graph, pass_args={}): ) top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" init_project(project_dir) + logger.info(f"Project path: {project_dir}") + rtl_dir = os.path.join(project_dir, "hardware", "rtl") - top = VerilogEmitter(graph).emit(graph, top_name) + # Emit device-independent hardware design in dataflow + df = DataflowEmitter(graph).emit(top_name) + df_file = os.path.join(rtl_dir, f"{top_name}_df.sv") + with open(df_file, "w") as df_design: + df_design.write(df) + # Emit memory mapping with BRAMs for the top-level hardware design + top = MemoryMapEmitter(graph).emit(top_name) top_file = os.path.join(rtl_dir, f"{top_name}.sv") with open(top_file, "w") as top_design: top_design.write(top) diff --git a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py index a84625f93..9087e2ab3 100644 --- a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py +++ b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py @@ -12,6 +12,11 @@ "common/rtl/skid_buffer.sv", "common/rtl/join2.sv", "cast/rtl/fixed_rounding.sv", + "cast/rtl/fixed_round.sv", + ], + "relu": [ + "activations/rtl/fixed_relu.sv", + "cast/rtl/fixed_rounding.sv", + "cast/rtl/fixed_round.sv", ], - "relu": ["activations/fixed_relu.sv"], } diff --git a/machop/configs/tests/quantize/fixed.toml b/machop/configs/tests/quantize/fixed.toml index d7d109695..cf1eab877 100644 --- a/machop/configs/tests/quantize/fixed.toml +++ b/machop/configs/tests/quantize/fixed.toml @@ -1,15 +1,25 @@ -model = "toy" -dataset = "toy-tiny" +model="toy" 
+dataset="toy-tiny" [passes.quantize] -by = "type" -report = true + by="type" + report=true -[passes.quantize.default.config] -name = "fixed" -data_in_width = 8 -data_in_frac_width = 3 -weight_width = 8 -weight_frac_width = 3 -bias_width = 8 -bias_frac_width = 3 \ No newline at end of file + [passes.quantize.default.config] + name="fixed" + cache_quantized_weight=true + data_in_width=8 + data_in_frac_width=3 + weight_width=8 + weight_frac_width=3 + bias_width=8 + bias_frac_width=3 + + [passes.quantize.relu.config] + name="fixed" + data_in_width=4 + data_in_frac_width=2 + weight_width=4 + weight_frac_width=2 + bias_width=4 + bias_frac_width=2 diff --git a/machop/mase_components/activations/rtl/fixed_relu.sv b/machop/mase_components/activations/rtl/fixed_relu.sv index ceef9c053..2ef87b1bc 100644 --- a/machop/mase_components/activations/rtl/fixed_relu.sv +++ b/machop/mase_components/activations/rtl/fixed_relu.sv @@ -28,14 +28,29 @@ module fixed_relu #( input logic data_out_0_ready ); - for (genvar i = 0; i < DATA_IN_0_TENSOR_SIZE_DIM_0; i++) begin : ReLU + logic [DATA_IN_0_PRECISION_0-1:0] data[DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0] ; + + for ( + genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++ + ) begin : ReLU always_comb begin // negative value, put to zero - if ($signed(data_in_0[i]) <= 0) data_out_0[i] = '0; - else data_out_0[i] = data_in_0[i]; + if ($signed(data_in_0[i]) <= 0) data[i] = '0; + else data[i] = data_in_0[i]; end end + fixed_rounding #( + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .IN_WIDTH(DATA_IN_0_PRECISION_0), + .IN_FRAC_WIDTH(DATA_IN_0_PRECISION_1), + .OUT_WIDTH(DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) + ) fr_inst ( + .data_in (data), + .data_out(data_out_0) + ); + assign data_out_0_valid = data_in_0_valid; assign data_in_0_ready = data_out_0_ready; diff --git a/machop/mase_components/linear/rtl/fixed_linear.sv b/machop/mase_components/linear/rtl/fixed_linear.sv index de6958925..ce954dfa0 100644 --- a/machop/mase_components/linear/rtl/fixed_linear.sv +++ b/machop/mase_components/linear/rtl/fixed_linear.sv @@ -1,76 +1,77 @@ `timescale 1ns / 1ps - -/* - * Constrained by WEIGHT_PARALLELISM_DIM_0 == DATA_OUT_0_PARALLELISM_DIM_0 - * -*/ - module fixed_linear #( /* verilator lint_off UNUSEDPARAM */ - parameter HAS_BIAS = 0, + parameter HAS_BIAS = 1, - parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_0 = 8, parameter DATA_IN_0_PRECISION_1 = 3, - parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, - parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, - parameter IN_0_DEPTH = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, - parameter WEIGHT_PRECISION_0 = 16, + parameter WEIGHT_PRECISION_0 = 8, parameter WEIGHT_PRECISION_1 = 3, - parameter WEIGHT_TENSOR_SIZE_DIM_0 = 32, + parameter WEIGHT_TENSOR_SIZE_DIM_0 = 1, + parameter WEIGHT_PARALLELISM_DIM_0 = 1, parameter WEIGHT_TENSOR_SIZE_DIM_1 = 1, - parameter WEIGHT_PARALLELISM_DIM_0 = 4, parameter WEIGHT_PARALLELISM_DIM_1 = 1, + parameter WEIGHT_TENSOR_SIZE_DIM_2 = 1, + parameter WEIGHT_PARALLELISM_DIM_2 = 1, - parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( - DATA_IN_0_TENSOR_SIZE_DIM_0 - ) + $clog2( - IN_0_DEPTH 
- ) + HAS_BIAS, - parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1, - parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_OUT_0_PRECISION_0 = 8, + parameter DATA_OUT_0_PRECISION_1 = 3, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1, parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, - parameter DATA_OUT_0_PARALLELISM_DIM_0 = WEIGHT_PARALLELISM_DIM_0, parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_2 = 1, - parameter BIAS_PRECISION_0 = 16, + parameter BIAS_PRECISION_0 = 8, parameter BIAS_PRECISION_1 = 3, - parameter BIAS_TENSOR_SIZE_DIM_0 = DATA_OUT_0_TENSOR_SIZE_DIM_0, - parameter BIAS_TENSOR_SIZE_DIM_1 = 1, + parameter BIAS_TENSOR_SIZE_DIM_0 = 1, parameter BIAS_PARALLELISM_DIM_0 = 1, - parameter BIAS_PARALLELISM_DIM_1 = 1 + parameter BIAS_TENSOR_SIZE_DIM_1 = 1, + parameter BIAS_PARALLELISM_DIM_1 = 1, + parameter BIAS_TENSOR_SIZE_DIM_2 = 1, + parameter BIAS_PARALLELISM_DIM_2 = 1 + /* verilator lint_on UNUSEDPARAM */ ) ( input clk, input rst, // input port for data_inivations - input [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], - input data_in_0_valid, - output data_in_0_ready, + input [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_1-1:0], + input data_in_0_valid, + output data_in_0_ready, // input port for weight - input [WEIGHT_PRECISION_0-1:0] weight [WEIGHT_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_0-1:0], + input [WEIGHT_PRECISION_0-1:0] weight [WEIGHT_PARALLELISM_DIM_0*WEIGHT_PARALLELISM_DIM_1-1:0], input weight_valid, output weight_ready, /* verilator lint_off UNUSEDSIGNAL */ - input [BIAS_PRECISION_0-1:0] bias[BIAS_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_0-1:0], - input bias_valid, + input [BIAS_PRECISION_0-1:0] bias [BIAS_PARALLELISM_DIM_0-1:0], + input bias_valid, /* verilator lint_on UNUSEDSIGNAL */ - output bias_ready, + output bias_ready, - output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0], output data_out_0_valid, input data_out_0_ready ); localparam FDP_WIDTH = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( - DATA_IN_0_PARALLELISM_DIM_0 + DATA_IN_0_PARALLELISM_DIM_1 ); - localparam ACC_WIDTH = FDP_WIDTH + $clog2(IN_0_DEPTH); + localparam ACC_WIDTH = FDP_WIDTH + $clog2( + DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1 + ); + logic [ACC_WIDTH-1:0] data_out_buff[DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0]; logic fdp_join_valid, fdp_join_ready; join2 #() fdp_join_inst ( @@ -84,19 +85,19 @@ module fixed_linear #( // Assume the parallelised hardware above have the same arrival time // which means that they always have the same state. So we can just // pick one of the valid signal to use. 
- logic [WEIGHT_PARALLELISM_DIM_0-1:0] fdp_data_ready, fdp_weight_ready; + logic [DATA_OUT_0_PARALLELISM_DIM_1-1:0] fdp_data_ready, fdp_weight_ready; assign fdp_join_ready = fdp_data_ready[0]; /* verilator lint_on UNUSEDSIGNAL */ logic acc_ready; - logic [ACC_WIDTH-1:0] acc_data_out[WEIGHT_PARALLELISM_DIM_0*WEIGHT_PARALLELISM_DIM_1-1:0]; + logic [ACC_WIDTH-1:0] acc_data_out[DATA_OUT_0_PARALLELISM_DIM_1-1:0]; - // There are WEIGHT_PARALLELISM_DIM_0 number of dot product instances with DATA_IN_0_TENSOR_SIZE_DIM_0 inputs - // and each one computes for IN_0_DEPTH iterations for each inputs. - for (genvar i = 0; i < WEIGHT_PARALLELISM_DIM_0; i = i + 1) begin : linear + // There are DATA_OUT_0_PARALLELISM_DIM_1 number of dot product instances with DATA_IN_0_PARALLELISM_DIM_1 inputs + // and each one computes for DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1 iterations for each inputs. + for (genvar i = 0; i < DATA_OUT_0_PARALLELISM_DIM_1; i = i + 1) begin : linear // Assume the weight are transposed and partitioned - logic [WEIGHT_PRECISION_0-1:0] current_weight[DATA_IN_0_PARALLELISM_DIM_0-1:0]; - assign current_weight = weight[DATA_IN_0_PARALLELISM_DIM_0*(i+1)-1:DATA_IN_0_PARALLELISM_DIM_0*i]; + logic [WEIGHT_PRECISION_0-1:0] current_weight[DATA_IN_0_PARALLELISM_DIM_1-1:0]; + assign current_weight = weight[DATA_IN_0_PARALLELISM_DIM_1*i+DATA_IN_0_PARALLELISM_DIM_1-1:DATA_IN_0_PARALLELISM_DIM_1*i]; logic [FDP_WIDTH-1:0] fdp_data_out; logic fdp_data_out_valid, fdp_data_out_ready; @@ -105,7 +106,7 @@ module fixed_linear #( fixed_dot_product #( .IN_WIDTH(DATA_IN_0_PRECISION_0), .WEIGHT_WIDTH(WEIGHT_PRECISION_0), - .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0) + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_1) ) fdp_inst ( .clk(clk), .rst(rst), @@ -126,7 +127,7 @@ module fixed_linear #( fixed_accumulator #( .IN_WIDTH(FDP_WIDTH), - .IN_DEPTH(IN_0_DEPTH) + .IN_DEPTH(DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1) ) fixed_accumulator_inst ( .clk(clk), .rst(rst), @@ -146,9 +147,8 @@ module fixed_linear #( if (HAS_BIAS == 1) begin - logic [ACC_WIDTH-1:0] bias_sext[BIAS_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_0-1:0]; + logic [ACC_WIDTH-1:0] bias_sext[BIAS_PARALLELISM_DIM_0-1:0]; logic acc_join_valid, acc_join_ready; - logic [DATA_IN_0_PARALLELISM_DIM_0-1:0] reg_ready; join2 #() acc_join_inst ( .data_in_ready ({bias_ready, acc_ready}), .data_out_valid(acc_join_valid), .data_out_ready(acc_join_ready) ); + logic [BIAS_PARALLELISM_DIM_0-1:0] reg_ready; + assign acc_join_ready = &reg_ready; fixed_rounding #( - .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0), + .IN_SIZE(BIAS_PARALLELISM_DIM_0), .IN_WIDTH(BIAS_PRECISION_0), .IN_FRAC_WIDTH(BIAS_PRECISION_1), .OUT_WIDTH(ACC_WIDTH), @@ -168,15 +170,13 @@ module fixed_linear #( .data_out(bias_sext) ); - assign acc_join_ready = &reg_ready; - - for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0; i = i + 1) begin : add_bias - logic [DATA_OUT_0_PRECISION_0-1:0] add; + for (genvar i = 0; i < BIAS_PARALLELISM_DIM_0; i = i + 1) begin : add_bias + logic [ACC_WIDTH-1:0] add; assign add = $signed(acc_data_out[i]) + $signed(bias_sext[i]); /* verilator lint_off UNUSEDSIGNAL */ logic dout_valid; skid_buffer #( - .DATA_WIDTH(DATA_OUT_0_PRECISION_0) + .DATA_WIDTH(ACC_WIDTH) ) register_slice ( .clk (clk), .rst (rst), @@ -185,7 +185,7 @@ module fixed_linear #( .data_in (add), .data_out_valid(dout_valid), .data_out_ready(data_out_0_ready), - .data_out (data_out_0[i]) + .data_out (data_out_buff[i]) ); end assign data_out_0_valid = 
add_bias[0].dout_valid; @@ -193,11 +193,21 @@ module fixed_linear #( end else begin assign acc_ready = data_out_0_ready; assign data_out_0_valid = linear[0].acc_data_out_valid; - - for (genvar i = 0; i < WEIGHT_PARALLELISM_DIM_0; i = i + 1) begin - assign data_out_0[i] = acc_data_out[i]; - end + assign data_out_buff = acc_data_out; assign bias_ready = 1; end + + fixed_rounding #( + .IN_SIZE(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1), + .IN_WIDTH(ACC_WIDTH), + .IN_FRAC_WIDTH(DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1), + .OUT_WIDTH(DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) + ) fr_inst ( + .data_in (data_out_buff), + .data_out(data_out_0) + ); + + endmodule diff --git a/machop/mase_components/linear/test/fixed_linear_tb.py b/machop/mase_components/linear/test/fixed_linear_tb.py index 54596dea8..58e57b9f1 100644 --- a/machop/mase_components/linear/test/fixed_linear_tb.py +++ b/machop/mase_components/linear/test/fixed_linear_tb.py @@ -3,158 +3,252 @@ # This script tests the fixed point linear import os, logging -import cocotb -from cocotb.log import SimLog -from cocotb.triggers import * - -from mase_cocotb.testbench import Testbench -from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor -from mase_cocotb.z_qlayers import quantize_to_int +from mase_cocotb.random_test import RandomSource, RandomSink, check_results from mase_cocotb.runner import mase_runner -from mase_cocotb.utils import bit_driver, sign_extend_t - -from chop.passes.graph.transforms.quantize.quantized_modules import LinearInteger - -import torch - -logger = logging.getLogger("testbench") -logger.setLevel(logging.DEBUG) - -class LinearTB(Testbench): - def __init__(self, dut, in_features=4, out_features=4) -> None: - super().__init__(dut, dut.clk, dut.rst) - - if not hasattr(self, "log"): - self.log = SimLog("%s" % (type(self).__qualname__)) - - self.data_in_0_driver = StreamDriver( - dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready +import cocotb +from cocotb.triggers import Timer +from cocotb.triggers import FallingEdge +from cocotb.clock import Clock + +debug = False + +logger = logging.getLogger("tb_signals") +if debug: + logger.setLevel(logging.DEBUG) + + +# DUT test specifications +class VerificationCase: + def __init__(self, samples=10): + self.has_bias = 1 + + self.data_in_0_precision_0 = 8 + self.data_in_0_precision_1 = 3 + self.data_in_0_tensor_size_dim_0 = 1 + self.data_in_0_parallelism_dim_0 = 1 + self.data_in_0_tensor_size_dim_1 = 8 + self.data_in_0_parallelism_dim_1 = 8 + self.data_in_0_tensor_size_dim_2 = 1 + self.data_in_0_parallelism_dim_2 = 1 + + self.weight_precision_0 = 8 + self.weight_precision_1 = 3 + self.weight_tensor_size_dim_0 = 10 + self.weight_parallelism_dim_0 = 10 + self.weight_tensor_size_dim_1 = 8 + self.weight_parallelism_dim_1 = 8 + self.weight_tensor_size_dim_2 = 1 + self.weight_parallelism_dim_2 = 1 + + self.data_out_0_precision_0 = 32 + self.data_out_0_precision_1 = 16 + self.data_out_0_tensor_size_dim_0 = 1 + self.data_out_0_parallelism_dim_0 = 1 + self.data_out_0_tensor_size_dim_1 = 10 + self.data_out_0_parallelism_dim_1 = 10 + self.data_out_0_tensor_size_dim_2 = 1 + + self.bias_precision_0 = 8 + self.bias_precision_1 = 3 + self.bias_tensor_size_dim_0 = 10 + self.bias_parallelism_dim_0 = 10 + self.bias_tensor_size_dim_1 = 1 + self.bias_parallelism_dim_1 = 1 + self.bias_tensor_size_dim_2 = 1 + self.bias_parallelism_dim_2 = 1 + + self.data_in = RandomSource( + name="data_in", + samples=samples + * 
self.data_in_0_tensor_size_dim_1 + // self.data_in_0_parallelism_dim_1, + num=self.data_in_0_parallelism_dim_1, + max_stalls=0, + debug=debug, ) - self.weight_driver = StreamDriver( - dut.clk, dut.weight, dut.weight_valid, dut.weight_ready + self.weight = RandomSource( + name="weight", + samples=samples + * self.weight_tensor_size_dim_1 + // self.weight_parallelism_dim_1, + num=self.weight_parallelism_dim_0 * self.weight_parallelism_dim_1, + max_stalls=0, + debug=debug, ) - - if int(dut.HAS_BIAS) == 1: - self.bias_driver = StreamDriver( - dut.clk, dut.bias, dut.bias_valid, dut.bias_ready - ) - - self.data_out_0_monitor = StreamMonitor( - dut.clk, - dut.data_out_0, - dut.data_out_0_valid, - dut.data_out_0_ready, - check=False, + self.bias = RandomSource( + name="bias", + samples=samples, + num=self.bias_parallelism_dim_0, + max_stalls=0, + debug=debug, ) - # Model - self.model = LinearInteger( - in_features=in_features, - out_features=out_features, - bias=False, - config={ - "data_in_width": 16, - "data_in_frac_width": 3, - "weight_width": 16, - "weight_frac_width": 3, - "bias_width": 16, - "bias_frac_width": 3, - }, + self.outputs = RandomSink(samples=samples, max_stalls=0, debug=debug) + self.samples = samples + self.ref = self.sw_compute() + + def get_dut_parameters(self): + return { + "HAS_BIAS": self.has_bias, + "DATA_IN_0_PRECISION_0": self.data_in_0_precision_0, + "DATA_IN_0_PRECISION_1": self.data_in_0_precision_1, + "DATA_IN_0_TENSOR_SIZE_DIM_0": self.data_in_0_tensor_size_dim_0, + "DATA_IN_0_PARALLELISM_DIM_0": self.data_in_0_parallelism_dim_0, + "DATA_IN_0_TENSOR_SIZE_DIM_1": self.data_in_0_tensor_size_dim_1, + "DATA_IN_0_PARALLELISM_DIM_1": self.data_in_0_parallelism_dim_1, + "DATA_IN_0_TENSOR_SIZE_DIM_2": self.data_in_0_tensor_size_dim_2, + "DATA_IN_0_PARALLELISM_DIM_2": self.data_in_0_parallelism_dim_2, + "WEIGHT_PRECISION_0": self.weight_precision_0, + "WEIGHT_PRECISION_1": self.weight_precision_1, + "WEIGHT_TENSOR_SIZE_DIM_0": self.weight_tensor_size_dim_0, + "WEIGHT_PARALLELISM_DIM_0": self.weight_parallelism_dim_0, + "WEIGHT_TENSOR_SIZE_DIM_1": self.weight_tensor_size_dim_1, + "WEIGHT_PARALLELISM_DIM_1": self.weight_parallelism_dim_1, + "WEIGHT_TENSOR_SIZE_DIM_2": self.weight_tensor_size_dim_2, + "WEIGHT_PARALLELISM_DIM_2": self.weight_parallelism_dim_2, + "DATA_OUT_0_PRECISION_0": self.data_out_0_precision_0, + "DATA_OUT_0_PRECISION_1": self.data_out_0_precision_1, + "DATA_OUT_0_TENSOR_SIZE_DIM_0": self.data_out_0_tensor_size_dim_0, + "DATA_OUT_0_PARALLELISM_DIM_0": self.data_out_0_parallelism_dim_0, + "DATA_OUT_0_TENSOR_SIZE_DIM_1": self.data_out_0_tensor_size_dim_1, + "DATA_OUT_0_PARALLELISM_DIM_1": self.data_out_0_parallelism_dim_1, + "DATA_OUT_0_TENSOR_SIZE_DIM_2": self.data_out_0_tensor_size_dim_2, + "BIAS_PRECISION_0": self.bias_precision_0, + "BIAS_PRECISION_1": self.bias_precision_1, + "BIAS_TENSOR_SIZE_DIM_0": self.bias_tensor_size_dim_0, + "BIAS_PARALLELISM_DIM_0": self.bias_parallelism_dim_0, + "BIAS_TENSOR_SIZE_DIM_1": self.bias_tensor_size_dim_1, + "BIAS_PARALLELISM_DIM_1": self.bias_parallelism_dim_1, + "BIAS_TENSOR_SIZE_DIM_2": self.bias_tensor_size_dim_2, + "BIAS_PARALLELISM_DIM_2": self.bias_parallelism_dim_2, + } + + def sw_compute(self): + ref = [] + for i in range(self.samples): + acc = [0 for _ in range(self.data_out_0_parallelism_dim_1)] + for j in range( + self.data_in_0_tensor_size_dim_1 // self.data_in_0_parallelism_dim_1 + ): + data_idx = ( + i + * self.data_in_0_tensor_size_dim_1 + // self.data_in_0_parallelism_dim_1 + + j + ) + temp = [] + for 
k in range(self.data_out_0_parallelism_dim_1): + s = [ + self.data_in.data[data_idx][h] + * self.weight.data[data_idx][ + k * self.data_in_0_parallelism_dim_1 + h + ] + for h in range(self.data_in_0_parallelism_dim_1) + ] + acc[k] += sum(s) + if self.has_bias: + for k in range(self.bias_parallelism_dim_0): + acc[k] += self.bias.data[i][k] << ( + self.weight_precision_1 + + self.data_in_0_precision_1 + - self.bias_precision_1 + ) + ref.append(acc) + ref.reverse() + return ref + + +def debug_state(dut, state): + logger.debug( + "{} State: (bias_ready,bias_valid,weight_ready,weight_valid,data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{},{},{},{},{})".format( + state, + dut.bias_ready.value, + dut.bias_valid.value, + dut.weight_ready.value, + dut.weight_valid.value, + dut.data_in_0_ready.value, + dut.data_in_0_valid.value, + dut.data_out_0_ready.value, + dut.data_out_0_valid.value, ) + ) + - def generate_inputs(self): - return torch.randn((1, self.model.in_features)) - - def preprocess_tensor(self, tensor, quantizer, config, parallelism): - tensor = quantizer(tensor) - tensor = (tensor * 2 ** config["frac_width"]).int() - logger.info(f"Tensor in int format: {tensor}") - tensor = tensor.reshape(-1, parallelism).tolist() - return tensor - - async def run_test(self): - await self.reset() - logger.info(f"Reset finished") - self.data_out_0_monitor.ready.value = 1 - - inputs = self.generate_inputs() - exp_out = self.model(inputs) - - # Load the inputs driver - logger.info(f"Processing inputs") - inputs = self.preprocess_tensor( - inputs, - self.model.x_quantizer, - {"widht": 16, "frac_width": 3}, - int(self.dut.DATA_IN_0_PARALLELISM_DIM_0), +@cocotb.test() +async def test_fixed_linear(dut): + """Test integer based vector mult""" + samples = 1000 + test_case = VerificationCase(samples=samples) + + # Reset cycle + await Timer(20, units="ns") + dut.rst.value = 1 + await Timer(100, units="ns") + dut.rst.value = 0 + + # Create a 10ns-period clock on port clk + clock = Clock(dut.clk, 10, units="ns") + # Start the clock + cocotb.start_soon(clock.start()) + await Timer(500, units="ns") + + # Synchronize with the clock + dut.weight_valid.value = 0 + dut.bias_valid.value = 0 + dut.data_in_0_valid.value = 0 + dut.data_out_0_ready.value = 1 + debug_state(dut, "Pre-clk") + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + debug_state(dut, "Pre-clk") + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + + done = False + # Set a timeout to avoid deadlock + for i in range(samples * 100): + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + dut.weight_valid.value = test_case.weight.pre_compute() + dut.bias_valid.value = test_case.bias.pre_compute() + dut.data_in_0_valid.value = test_case.data_in.pre_compute() + await Timer(1, units="ns") + dut.data_out_0_ready.value = test_case.outputs.pre_compute( + dut.data_out_0_valid.value ) - self.data_in_0_driver.load_driver(inputs) - - # Load the weights driver - logger.info(f"Processing weights") - weights = self.preprocess_tensor( - self.model.weight, - self.model.w_quantizer, - {"widht": 16, "frac_width": 3}, - int(self.dut.WEIGHT_PARALLELISM_DIM_0) - * int(self.dut.DATA_IN_0_PARALLELISM_DIM_0), + await Timer(1, units="ns") + debug_state(dut, "Post-clk") + + dut.bias_valid.value, dut.bias.value = test_case.bias.compute( + dut.bias_ready.value ) - self.weight_driver.load_driver(weights) - - # Load the output monitor - logger.info(f"Processing outputs: {exp_out}") - # To do: need to quantize output to a different precision - outs = 
self.preprocess_tensor( - exp_out, - self.model.x_quantizer, - {"widht": 16, "frac_width": 3}, - int(self.dut.DATA_OUT_0_PARALLELISM_DIM_0), + dut.weight_valid.value, dut.weight.value = test_case.weight.compute( + dut.weight_ready.value ) - self.data_out_0_monitor.load_monitor(outs) - - await Timer(1000, units="us") - assert self.data_out_0_monitor.exp_queue.empty() + dut.data_in_0_valid.value, dut.data_in_0.value = test_case.data_in.compute( + dut.data_in_0_ready.value + ) + await Timer(1, units="ns") + dut.data_out_0_ready.value = test_case.outputs.compute( + dut.data_out_0_valid.value, dut.data_out_0.value + ) + debug_state(dut, "Pre-clk") + if ( + (not test_case.has_bias or test_case.bias.is_empty()) + and test_case.weight.is_empty() + and test_case.data_in.is_empty() + and test_case.outputs.is_full() + ): + done = True + break + assert ( + done + ), "Deadlock detected or the simulation reaches the maximum cycle limit (fixed it by adjusting the loop trip count)" -@cocotb.test() -async def test_20x20(dut): - tb = LinearTB(dut, in_features=20, out_features=20) - await tb.run_test() + check_results(test_case.outputs.data, test_case.ref) if __name__ == "__main__": - mase_runner( - trace=True, - module_param_list=[ - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, - "DATA_IN_0_PARALLELISM_DIM_0": 2, - "WEIGHT_TENSOR_SIZE_DIM_0": 20, - "WEIGHT_TENSOR_SIZE_DIM_1": 20, - "WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, - "DATA_IN_0_PARALLELISM_DIM_0": 4, - "WEIGHT_TENSOR_SIZE_DIM_0": 20, - "WEIGHT_TENSOR_SIZE_DIM_1": 20, - "WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, - "DATA_IN_0_PARALLELISM_DIM_0": 5, - "WEIGHT_TENSOR_SIZE_DIM_0": 20, - "WEIGHT_TENSOR_SIZE_DIM_1": 20, - "WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, - ], - ) + tb = VerificationCase() + mase_runner(module_param_list=[tb.get_dut_parameters()]) diff --git a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py index 1ab406ced..93cb7d1cb 100644 --- a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py +++ b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # This example converts a simple MLP model to Verilog import os, sys, logging -import toml +import toml, math import torch import torch.nn as nn @@ -29,15 +29,15 @@ class MLP(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.fc1 = nn.Linear(28 * 28, 28 * 28) - self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4) - self.fc3 = nn.Linear(28 * 28 * 4, 10) + self.fc1 = nn.Linear(5 * 5, 5 * 5) + self.fc2 = nn.Linear(5 * 5, 5 * 5 * 4) + self.fc3 = nn.Linear(5 * 5 * 4, 10) def forward(self, x): x = torch.flatten(x, start_dim=1, end_dim=-1) x = torch.nn.functional.relu(self.fc1(x)) - x = torch.nn.functional.relu(self.fc2(x)) - x = self.fc3(x) + # x = torch.nn.functional.relu(self.fc2(x)) + # x = self.fc3(x) return x @@ -47,7 +47,7 @@ def test_emit_verilog_linear(): # Provide a dummy input for the graph so it can use for tracing batch_size = 1 - x = torch.randn((batch_size, 28, 28)) + x = torch.randn((batch_size, 5, 5)) dummy_in = {"x": x} mg, _ 
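A note on the stimulus loop in `test_fixed_linear` above: each cycle first calls `pre_compute` on every driver to settle the valid/ready flags, then `compute` to actually move data, and the test only terminates once every source reports `is_empty()` and the sink reports `is_full()`, guarded by a bounded trip count so a stalled handshake fails the assert instead of hanging the simulator. The plain-Python sketch below models that termination logic only; `Source` and `Sink` are deliberately simplified stand-ins for the repository's `RandomSource`/`RandomSink`, not their real implementations.

from collections import deque


class Source:
    """Toy stand-in for RandomSource: presents items until drained."""

    def __init__(self, items):
        self.queue = deque(items)

    def is_empty(self):
        return not self.queue

    def compute(self, ready):
        # Assert valid while data remains; pop only on a valid & ready handshake.
        valid = not self.is_empty()
        data = self.queue[0] if valid else None
        if valid and ready:
            self.queue.popleft()
        return valid, data


class Sink:
    """Toy stand-in for RandomSink: always ready, collects until full."""

    def __init__(self, capacity):
        self.data, self.capacity = [], capacity

    def is_full(self):
        return len(self.data) >= self.capacity

    def compute(self, valid, data):
        if valid and not self.is_full():
            self.data.append(data)
        return 1  # ready


src, sink = Source(range(4)), Sink(4)
for _ in range(100):  # bounded trip count mirrors the deadlock timeout
    valid, data = src.compute(ready=1)
    sink.compute(valid, data)
    if src.is_empty() and sink.is_full():
        break
assert sink.data == [0, 1, 2, 3]

The cocotb loop has the same shape: drive, transfer, then check emptiness/fullness to decide `done`.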
diff --git a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
index 1ab406ced..93cb7d1cb 100644
--- a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
+++ b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # This example converts a simple MLP model to Verilog
 import os, sys, logging
-import toml
+import toml, math
 import torch
 import torch.nn as nn

@@ -29,15 +29,15 @@ class MLP(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.fc1 = nn.Linear(28 * 28, 28 * 28)
-        self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4)
-        self.fc3 = nn.Linear(28 * 28 * 4, 10)
+        self.fc1 = nn.Linear(5 * 5, 5 * 5)
+        self.fc2 = nn.Linear(5 * 5, 5 * 5 * 4)
+        self.fc3 = nn.Linear(5 * 5 * 4, 10)

     def forward(self, x):
         x = torch.flatten(x, start_dim=1, end_dim=-1)
         x = torch.nn.functional.relu(self.fc1(x))
-        x = torch.nn.functional.relu(self.fc2(x))
-        x = self.fc3(x)
+        # x = torch.nn.functional.relu(self.fc2(x))
+        # x = self.fc3(x)
         return x

@@ -47,7 +47,7 @@ def test_emit_verilog_linear():

     # Provide a dummy input for the graph so it can use for tracing
     batch_size = 1
-    x = torch.randn((batch_size, 28, 28))
+    x = torch.randn((batch_size, 5, 5))
     dummy_in = {"x": x}

     mg, _ = passes.init_metadata_analysis_pass(mg, None)
@@ -69,64 +69,32 @@ def test_emit_verilog_linear():
         "configs",
         "tests",
         "quantize",
-        "integer.toml",
+        "fixed.toml",
     )
     # load toml config file
     with open(config_file, "r") as f:
         quan_args = toml.load(f)["passes"]["quantize"]
-    mg, _ = passes.quantize_transform_pass(mg, quan_args)
-
-    # There is a bug in the current quantizzation pass, where the results metadata is not uppdated with the precision.
-    # Here we temporarily update the metadata here so we can test the hardware back end.
-    for node in mg.fx_graph.nodes:
-        for arg, _ in node.meta["mase"].parameters["common"]["args"].items():
-            if (
-                type(node.meta["mase"].parameters["common"]["args"][arg]) == dict
-                and "type" in node.meta["mase"].parameters["common"]["args"][arg].keys()
-            ):
-                node.meta["mase"].parameters["common"]["args"][arg]["type"] = "fixed"
-        for result, _ in node.meta["mase"].parameters["common"]["results"].items():
-            if (
-                type(node.meta["mase"].parameters["common"]["results"][result]) == dict
-                and "type"
-                in node.meta["mase"].parameters["common"]["results"][result].keys()
-            ):
-                node.meta["mase"].parameters["common"]["results"][result][
-                    "type"
-                ] = "fixed"
-                node.meta["mase"].parameters["common"]["results"][result][
-                    "precision"
-                ] = [8, 3]
-
-    mg, _ = passes.add_hardware_metadata_analysis_pass(
-        mg
-    )  # add metadata for hardware in each mase node of graph
-    mg, _ = passes.report_node_hardware_type_analysis_pass(mg)  # pretty print
+    with torch.no_grad():
+        mg, _ = passes.quantize_transform_pass(mg, quan_args)
+        mg.model(dummy_in["x"])
+
+    # inspect the graph metadata
+    # mg, _ = passes.report_node_meta_param_analysis_pass(mg)
+
+    # add metadata for hardware in each mase node of graph
+    mg, _ = passes.add_hardware_metadata_analysis_pass(mg)
+    # pretty print
+    mg, _ = passes.report_node_hardware_type_analysis_pass(mg)
     # mg = verify_hardware_metadata_analysis_pass(mg)

+    # Emit Verilog sources
     mg, _ = passes.emit_verilog_top_transform_pass(mg)
-    # mg = passes.emit_bram_transform_pass(mg)
-    # mg, _ = passes.emit_internal_rtl_transform_pass(mg)
-
-    # # For internal models, the test inputs can be directly fetched from the dataset
-    # # using InputGenerator from chop.tools.get_input
-    # project_dir = Path(__file__).parents[6] / "top"
-    # print(f"project_dir {project_dir}")
-    # cosim_config = {"test_inputs": [x], "trans_num": 1, "project_dir": project_dir}
-    # # mg = passes.emit_verilog_tb_transform_pass(mg, pass_args=cosim_config)
-
-    # # Run simulation pass if Vivado available
-    # try:
-    #     execute_cli("xelab -h", log_output=False)
-    #     has_verilog = True
-    #     # mg = get_synthesis_results("top", mg, target="xcu250-figd2104-2L-e", output_dir=".")
-    # except:
-    #     has_verilog = False
-    #     print(f"Vivado not available")
-
-    # if has_verilog:
-    #     mg = passes.run_cosim_analysis_pass(mg)
+    mg, _ = passes.emit_bram_transform_pass(mg)
+    mg, _ = passes.emit_internal_rtl_transform_pass(mg)
+
+    # Test Verilog sources
+    mg, _ = passes.test_verilog_analysis_pass(mg)

 if __name__ == "__main__":
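One further note on the fixed-point arithmetic in `sw_compute` from the testbench diff above: a product of operands carrying `data_in_0_precision_1` and `weight_precision_1` fractional bits carries the sum of the two as its fractional width, so the bias, stored with only `bias_precision_1` fractional bits, must be left-shifted by the difference before it can join the accumulator. A minimal, self-contained illustration of that alignment, using illustrative names and toy Q4.3 values rather than anything from the testbench:

# Why sw_compute shifts the bias: an integer v with `frac` fractional bits
# represents v / 2**frac, so a product of Q*.3 data and Q*.3 weights carries
# 6 fractional bits while the bias carries only 3.


def to_fixed(x: float, frac: int) -> int:
    return round(x * (1 << frac))


in_frac, w_frac, b_frac = 3, 3, 3
a = to_fixed(1.5, in_frac)  # 12 represents 1.5
w = to_fixed(0.25, w_frac)  # 2 represents 0.25
b = to_fixed(0.5, b_frac)  # 4 represents 0.5

acc = a * w  # 24, carrying in_frac + w_frac = 6 fractional bits
acc += b << (w_frac + in_frac - b_frac)  # align the bias to 6 fractional bits
assert acc / (1 << (in_frac + w_frac)) == 1.5 * 0.25 + 0.5  # == 0.875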