From 2e28da63266ce235a59e7119dcab79def8af8a6a Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 10:26:21 +0000
Subject: [PATCH 01/31] Updated Docker to main for local build

---
 Docker | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Docker b/Docker
index 782137fc2..e201d1e77 160000
--- a/Docker
+++ b/Docker
@@ -1 +1 @@
-Subproject commit 782137fc28db6af2f68074c06d5fd6fc86e4448f
+Subproject commit e201d1e77259e5f2454d3c6048b505c89d10aa0e

From 08c5fda5987cc0eb0a52b2c1731f2fd0e8fc82f2 Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 11:05:49 +0000
Subject: [PATCH 02/31] Sync scripts for updated Docker setups (separating CPU and GPU containers)

---
 .github/workflows/buildAndTest.yml  | 2 +-
 .github/workflows/buildDoc.yml      | 2 +-
 .github/workflows/testHardware.yml  | 2 +-
 .github/workflows/testTorchMLIR.yml | 2 +-
 Makefile                            | 8 ++++----
 5 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml
index cedd6d8af..1ff3c9742 100644
--- a/.github/workflows/buildAndTest.yml
+++ b/.github/workflows/buildAndTest.yml
@@ -21,7 +21,7 @@ jobs:
   software-regression-test:
     runs-on: ubuntu-latest
     container:
-      image: deepwok/mase-docker:latest
+      image: deepwok/mase-docker-cpu:latest

     steps:
       # Clone the MASE repo and its submodules.
diff --git a/.github/workflows/buildDoc.yml b/.github/workflows/buildDoc.yml
index 6e2dee55a..9899242b9 100644
--- a/.github/workflows/buildDoc.yml
+++ b/.github/workflows/buildDoc.yml
@@ -22,7 +22,7 @@ jobs:
   software-regression-test:
     runs-on: ubuntu-latest
    container:
-      image: deepwok/mase-docker:latest
+      image: deepwok/mase-docker-cpu:latest

     steps:
       # Clone the MASE repo and its submodules.
diff --git a/.github/workflows/testHardware.yml b/.github/workflows/testHardware.yml
index dfe2d61f0..91eb3c1ca 100644
--- a/.github/workflows/testHardware.yml
+++ b/.github/workflows/testHardware.yml
@@ -25,7 +25,7 @@ jobs:
   hardware-regression-test:
     runs-on: ubuntu-latest
     container:
-      image: deepwok/mase-docker:latest
+      image: deepwok/mase-docker-cpu:latest

     steps:
       # Clone the MASE repo and its submodules.
diff --git a/.github/workflows/testTorchMLIR.yml b/.github/workflows/testTorchMLIR.yml
index a79fc7bd7..9f5ac29bb 100644
--- a/.github/workflows/testTorchMLIR.yml
+++ b/.github/workflows/testTorchMLIR.yml
@@ -21,7 +21,7 @@ jobs:
   torch-mlir-test:
     runs-on: ubuntu-latest
     container:
-      image: deepwok/mase-docker:latest
+      image: deepwok/mase-docker-cpu:latest

     steps:
       # Clone the MASE repo and its submodules.
diff --git a/Makefile b/Makefile
index 401ba540d..305b63e5b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,8 @@
 vhls=/mnt/applications/Xilinx/23.1
 vhls_version=2023.1
 local=0
-img=$(if $local,"mase-ubuntu2204:latest","deepwok/mase-docker:latest")
+target=cpu
+img=$(if $(filter 1,$(local)),"mase-ubuntu2204:latest","deepwok/mase-docker-$(target):latest")
 user=$(if $(shell id -u),$(shell id -u),9001)
 group=$(if $(shell id -g),$(shell id -g),1000)
 coverage=machop/test/
@@ -20,9 +21,9 @@ sync-mlir:
 # Build Docker container
 build-docker:
 	if [ $(local) = 1 ]; then \
-		docker build --build-arg VHLS_PATH=$(vhls) --build-arg VHLS_VERSION=$(vhls_version) -f Docker/Dockerfile --tag mase-ubuntu2204 Docker; \
+		docker build --build-arg VHLS_PATH=$(vhls) --build-arg VHLS_VERSION=$(vhls_version) -f Docker/Dockerfile-$(target) --tag mase-ubuntu2204 Docker; \
 	else \
-		docker pull docker.io/deepwok/mase-docker:latest; \
+		docker pull docker.io/deepwok/mase-docker-$(target):latest; \
 	fi

 shell: build-docker
@@ -39,7 +40,6 @@ shell: build-docker
 # Short-term solution: call scripts under /tmp so we can clean it properly
 test-hw:
 	mkdir -p ./tmp
-	pip install .
 	(cd tmp; python3 ../scripts/test-hardware.py -a || exit 1)

 test-sw:

From 8537c2cf2e4658d54d32b52e6f3ebbdb542d660b Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 11:40:37 +0000
Subject: [PATCH 03/31] Removed Docker as a submodule

---
 .gitmodules | 3 ---
 Docker      | 1 -
 2 files changed, 4 deletions(-)
 delete mode 160000 Docker

diff --git a/.gitmodules b/.gitmodules
index 7ad4ee597..39b4aa1c1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,6 +7,3 @@
 [submodule "machop/third-party/peft"]
 	path = machop/third-party/peft
 	url = git@github.com:huggingface/peft.git
-[submodule "Docker"]
-	path = Docker
-	url = git@github.com:JianyiCheng/mase-docker.git
diff --git a/Docker b/Docker
deleted file mode 160000
index e201d1e77..000000000
--- a/Docker
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit e201d1e77259e5f2454d3c6048b505c89d10aa0e

From c2e47dc53ff88572159cc1cfb0b134780b94a7ef Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 11:40:48 +0000
Subject: [PATCH 04/31] Clone Docker repo from the Makefile to avoid repeated submodule syncs

---
 Makefile | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/Makefile b/Makefile
index 305b63e5b..e38cbee61 100644
--- a/Makefile
+++ b/Makefile
@@ -21,6 +21,9 @@ sync-mlir:
 # Build Docker container
 build-docker:
 	if [ $(local) = 1 ]; then \
+		if [ ! -d Docker ]; then \
+			git clone git@github.com:jianyicheng/mase-docker.git Docker; \
+		fi; \
 		docker build --build-arg VHLS_PATH=$(vhls) --build-arg VHLS_VERSION=$(vhls_version) -f Docker/Dockerfile-$(target) --tag mase-ubuntu2204 Docker; \
 	else \
 		docker pull docker.io/deepwok/mase-docker-$(target):latest; \
 	fi

From b3081baefd7e887ce219e32ba09d62ad746fadac Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 12:21:50 +0000
Subject: [PATCH 05/31] Added command tracing for better readability of the CI log

---
 scripts/test-machop.sh | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/scripts/test-machop.sh b/scripts/test-machop.sh
index 33ccf8261..30f7008fb 100755
--- a/scripts/test-machop.sh
+++ b/scripts/test-machop.sh
@@ -13,6 +13,8 @@ MASE=$SCRIPT_DIR/..
 cd $MASE/machop

+set -o xtrace
+
 ##### Basic training and testing
 # training
 ./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug
@@ -32,3 +34,5 @@ cd $MASE/machop
 ./ch transform --config configs/examples/jsc_toy_by_type_module.toml --task cls --accelerator=cpu --load ../mase_output/tmp/software/training_ckpts/best.ckpt --load-type pl
 # train the transformed network
 ./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --load ../mase_output/jsc-tiny/software/transform/transformed_ckpt/state_dict.pt --load-type pt
+
+set +o xtrace

From 9fd3d6711cee63e74f28be3392d244eff40d694f Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 12:38:18 +0000
Subject: [PATCH 06/31] Initially separate dataflow sv out of top

---
 .../graph/transforms/verilog/emit_top.py | 62 ++++++++++---------
 1 file changed, 33 insertions(+), 29 deletions(-)

diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py
index ead3b3077..9e38f804f 100644
--- a/machop/chop/passes/graph/transforms/verilog/emit_top.py
+++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py
@@ -65,7 +65,7 @@ class VerilogParameterEmitter:
     def __init__(self, graph):
         self.graph = graph

-    def emit(self, graph, parameter_map) -> Tuple[str, Dict[str, str]]:
+    def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]:
         """
         Emit parameters at the top-level for the top-level module

@@ -73,8 +73,8 @@ def emit(self, graph, parameter_map) -> Tuple[str, Dict[str, str]]:
             1) list of parameters as a string to be embedded in Verilog file
         """

-        nodes_in = graph.nodes_in
-        nodes_out = graph.nodes_out
+        nodes_in = self.graph.nodes_in
+        nodes_out = self.graph.nodes_out
         node_in_name = vf(nodes_in[0].name)
         node_out_name = vf(nodes_out[0].name)

@@ -96,7 +96,7 @@ class VerilogInterfaceEmitter:
     def __init__(self, graph):
         self.graph = graph

-    def emit(self, graph, parameter_map):
+    def emit(self, parameter_map):
         """
         Emit interface signal declarations for the top-level module
         """
@@ -259,13 +259,13 @@ def _emit_signals_top_hls(self, node, parameter_map):
 logic {node_name}_{key}_we0;"""
         return signals

-    def emit(self, graph, parameter_map):
+    def emit(self, parameter_map):
         """
         Emit internal signal declarations for the top-level module
         """

         signals = ""
-        for node in graph.fx_graph.nodes:
+        for node in self.graph.fx_graph.nodes:
             if node.meta["mase"].parameters["hardware"]["is_implicit"]:
                 continue
             node_name = vf(node.name)
@@ -478,7 +478,7 @@ def __init__(self, graph):
         self.internal_emitter = VerilogInternalComponentEmitter(graph)
         self.hls_emitter = VerilogHLSComponentEmitter(graph)

-    def emit(self, graph, parameter_map):
+    def emit(self, parameter_map):
         """
         Emit component declarations for the top-level module
         """
@@ -488,7 +488,7 @@ def emit(self, graph, parameter_map):
 // Component instantiation
 // --------------------------
 """
-        for node in graph.fx_graph.nodes:
+        for node in self.graph.fx_graph.nodes:
             if node.meta["mase"].parameters["hardware"]["is_implicit"]:
                 continue
             if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]:
@@ -507,9 +507,8 @@ def emit(self, graph, parameter_map):


 class VerilogWireEmitter:
-    def __init__(self, graph, parameter_map):
+    def __init__(self, graph):
         self.graph = graph
-        self.parameter_map = parameter_map

         self.wires = """
 // --------------------------
 // Interconnections
 // --------------------------
     """

-    def 
_emit_top_wires(self): + def _emit_top_wires(self, parameter_map): nodes_in = self.graph.nodes_in nodes_out = self.graph.nodes_out @@ -580,7 +579,7 @@ def _emit_node2node_wires(self): """ return wires - def emit(self): + def emit(self, parameter_map): """ Emit internal signal connections for the top-level module This includes two interconnection types: @@ -588,7 +587,7 @@ def emit(self): 2. Interface casting between inputs and outputs """ - self.wires += self._emit_top_wires() + self.wires += self._emit_top_wires(parameter_map) self.wires += self._emit_node2node_wires() return self.wires @@ -598,28 +597,26 @@ def emit(self): # ============================================================================= -class VerilogEmitter: +class DataflowEmitter: def __init__(self, graph): self.graph = graph self.parameter_map = get_verilog_parameters(graph) - def emit(self, graph, top_name): - parameters_to_emit = VerilogParameterEmitter(graph).emit( - graph, self.parameter_map + def emit(self, top_name): + parameters_to_emit = VerilogParameterEmitter(self.graph).emit( + self.parameter_map ) - interface_to_emit = VerilogInterfaceEmitter(graph).emit( - graph, self.parameter_map - ) + interface_to_emit = VerilogInterfaceEmitter(self.graph).emit(self.parameter_map) - signals_to_emit = VerilogSignalEmitter(graph).emit(graph, self.parameter_map) + signals_to_emit = VerilogSignalEmitter(self.graph).emit(self.parameter_map) - components_to_emit = VerilogComponentEmitter(graph).emit( - graph, self.parameter_map + components_to_emit = VerilogComponentEmitter(self.graph).emit( + self.parameter_map ) - wires_to_emit = VerilogWireEmitter(graph, self.parameter_map).emit() + wires_to_emit = VerilogWireEmitter(self.graph).emit(self.parameter_map) time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") @@ -676,16 +673,23 @@ def emit_verilog_top_transform_pass(graph, pass_args={}): project_dir = ( pass_args["project_dir"] if "project_dir" in pass_args.keys() - else Path.home() / ".mase" / "top" + else "./top" + # else Path.home() / ".mase" / "top" ) top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" init_project(project_dir) + logger.info(f"Project path: {project_dir}") + rtl_dir = os.path.join(project_dir, "hardware", "rtl") - top = VerilogEmitter(graph).emit(graph, top_name) + df = DataflowEmitter(graph).emit(top_name) + df_file = os.path.join(rtl_dir, f"{top_name}_df.sv") + with open(df_file, "w") as df_design: + df_design.write(df) - top_file = os.path.join(rtl_dir, f"{top_name}.sv") - with open(top_file, "w") as top_design: - top_design.write(top) + # top = MemoryMapEmitter(graph).emit(top_name) + # top_file = os.path.join(rtl_dir, f"{top_name}.sv") + # with open(top_file, "w") as top_design: + # top_design.write(top) return graph, {} From cf6f26c40fdb624f634f4a9ed1a18a6251608cc1 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Wed, 24 Apr 2024 12:50:02 +0000 Subject: [PATCH 07/31] Split two files initially --- .../graph/transforms/verilog/emit_top.py | 629 +++++++++++++++++- 1 file changed, 601 insertions(+), 28 deletions(-) diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py index 9e38f804f..3a63d9016 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_top.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py @@ -57,11 +57,582 @@ def param_needs_signals(node, param, value, qualifier="data_in"): # ============================================================================= -# Verilog parameters 
+# Emit design in a memory-independent dataflow graph # ============================================================================= +# ============================================================================= +# DFVerilog parameters +# ============================================================================= + + +class DFVerilogParameterEmitter: + def __init__(self, graph): + self.graph = graph + + def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]: + """ + Emit parameters at the top-level for the top-level module + + Returns Tuple: + 1) list of parameters as a string to be embedded in DFVerilog file + """ + + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out + node_in_name = vf(nodes_in[0].name) + node_out_name = vf(nodes_out[0].name) + + parameters = "" + + # Write node parameters + for key, value in parameter_map.items(): + parameters += f""" parameter {key} = {value},\n""" + + return _remove_last_comma(parameters) + + +# ============================================================================= +# DFVerilog interface +# ============================================================================= + + +class DFVerilogInterfaceEmitter: + def __init__(self, graph): + self.graph = graph + + def emit(self, parameter_map): + """ + Emit interface signal declarations for the top-level module + """ + + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out + + interface = "" + # TODO: here we just enumerate the inputs of the input nodes - which may be + # order insensitive and require manual connection when adding the graph to + # a system. + i = 0 + for node in nodes_in: + node_name = vf(node.name) + for arg in node.meta["mase"].parameters["common"]["args"].keys(): + if "data_in" in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + input [{node_name}_{arg_name}_PRECISION_0-1:0] data_in_{i} [{'*'.join(parallelism_params)}-1:0], + input data_in_{i}_valid, + output data_in_{i}_ready,""" + i += 1 + + i = 0 + for node in nodes_out: + node_name = vf(node.name) + for result in node.meta["mase"].parameters["common"]["results"].keys(): + if "data_out" in result: + result_name = _cap(result) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{result_name}_PARALLELISM_DIM" in param + ] + interface += f""" + output [{node_name}_{result_name}_PRECISION_0-1:0] data_out_{i} [{'*'.join(parallelism_params)}-1:0], + output data_out_{i}_valid, + input data_out_{i}_ready,""" + i += 1 + + # TODO: emit off-chip parameter interface + + return _remove_last_comma(interface) + + +# ============================================================================= +# DFVerilog signals +# ============================================================================= + + +class DFVerilogSignalEmitter: + def __init__(self, graph): + self.graph = graph + + def _emit_signals_top_internal(self, node, parameter_map): + signals = "" + node_name = vf(node.name) + # Input signals + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + + # Skip off-chip parameters as they will be directly connected to the top level + if ( + "data_in" in arg + or node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] + == "BRAM" + ): + arg_name = v2p(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param 
+ ] + signals += f""" +logic [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0]; +logic {node_name}_{arg}_valid; +logic {node_name}_{arg}_ready;""" + + # Output signals + for result, result_info in ( + node.meta["mase"].parameters["common"]["results"].items() + ): + if not isinstance(result_info, dict): + continue + + # Skip off-chip parameters as they will be directly connected to the top level + if ( + "data_out" in result + or node.meta["mase"].parameters["hardware"]["interface"][result][ + "storage" + ] + == "BRAM" + ): + result_name = v2p(result) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{result_name}_PARALLELISM_DIM" in param + ] + signals += f""" +logic [{node_name}_{result_name}_PRECISION_0-1:0] {node_name}_{result} [{'*'.join(parallelism_params)}-1:0]; +logic {node_name}_{result}_valid; +logic {node_name}_{result}_ready;""" + + return signals + + def _emit_signals_top_hls(self, node, parameter_map): + """ + TODO + """ + + node_name = vf(node.name) + # Control signals for HLS component + signals = f""" +logic {node_name}_start; +logic {node_name}_done; +logic {node_name}_idle; +logic {node_name}_ready; +logic {node_name}_ce;""" + + # Input signals + for key, value in node.meta["mase"].parameters["common"]["args"].items(): + # No internal signals if the memory is stored off chip + if not param_needs_signals(node, key, value, qualifier="data_in"): + continue + + cap_key = v2p(key) + size = math.prod(value["shape"]) + + if key != "data_in": + a_width = math.ceil(math.log2(size)) + else: + depth = parameter_map[f"{node_name}_{cap_key}_DEPTH"] + a_width = math.ceil(math.log2(depth * size)) + + signals += f""" +logic [{node_name}_{cap_key}_PRECISION_0-1:0] {node_name}_{key}_q0; +logic [{a_width}-1:0] {node_name}_{key}_address0; +logic {node_name}_{key}_ce0;""" + + # Output signals + for key, value in node.meta["mase"].parameters["common"]["results"].items(): + # No internal signals if the memory is stored off chip + if not param_needs_signals(node, key, value, qualifier="data_out"): + continue + + cap_key = v2p(key) + size = math.prod(value["shape"]) + a_width = math.ceil(math.log2(size)) + signals += f""" +logic [{node_name}_{cap_key}_PRECISION_0-1:0] {node_name}_{key}_d0; +logic [{a_width}-1:0] {node_name}_{key}_address0; +logic {node_name}_{key}_ce0; +logic {node_name}_{key}_we0;""" + return signals + + def emit(self, parameter_map): + """ + Emit internal signal declarations for the top-level module + """ + + signals = "" + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + signals += f""" +// -------------------------- +// {node_name} signals +// --------------------------""" + if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: + signals += self._emit_signals_top_internal(node, parameter_map) + elif node.meta["mase"].parameters["hardware"]["toolchain"] == "HLS": + signals += self._emit_signals_top_hls(node, parameter_map) + else: + assert False, "Unknown node toolchain for signal declarations." 
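+            # Toolchain dispatch: INTERNAL components expose valid/ready
+            # handshake signals, while HLS components expose ap_* control
+            # plus BRAM-style address/enable/data element ports.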
+ + return signals + + +# ============================================================================= +# DFVerilog components (INTERNAL) +# ============================================================================= + + +class DFVerilogInternalComponentEmitter: + def __init__(self, graph): + self.graph = graph + + def emit(self, node, parameter_map): + node_name = vf(node.name) + component_name = node.meta["mase"].parameters["hardware"]["module"] + signals = "" + + # Emit component instantiation parameters + parameters = "" + for key, value in ( + node.meta["mase"].parameters["hardware"]["verilog_param"].items() + ): + key_value = parameter_map[f"{node_name}_{key}"] + debug_info = f"// = {key_value}" + parameters += f""" .{key}({node_name}_{key}), {debug_info}\n""" + parameters = _remove_last_comma(parameters) + + # Emit component instantiation input signals + for key, value in node.meta["mase"].parameters["common"]["args"].items(): + if "data" not in key: + continue + signals += f""" + .{key}({node_name}_{key}), + .{key}_valid({node_name}_{key}_valid), + .{key}_ready({node_name}_{key}_ready), + """ + + # Emit component instantiation output signals + for key, value in node.meta["mase"].parameters["common"]["results"].items(): + if "data" not in key: + continue + signals += f""" + .{key}({node_name}_{key}), + .{key}_valid({node_name}_{key}_valid), + .{key}_ready({node_name}_{key}_ready), + """ + signals = _remove_last_comma(signals) + + # Combine component instantiation + components = f""" +// {node_name} +{component_name} #( +{parameters} +) {node_name}_inst ( + .clk(clk), + .rst(rst), +{signals} +); +""" + + return components + + +# ============================================================================= +# DFVerilog components (HLS) +# ============================================================================= + + +class DFVerilogHLSComponentEmitter: + def __init__(self, graph): + self.graph = graph + + def _emit_module_parameters_top_hls(self, key, value, node, parameter_map): + node_name = vf(node.name) + cap_key = v2p(key) + component_name = f"{node_name}_{key}_source" + component_name_inst = f"{node_name}_{key}_0" + + size_debug_info = math.prod(value["shape"]) + a_width = math.ceil(math.log2(size_debug_info)) + + return f""" +{component_name} #( + .DATA_WIDTH({node_name}_{cap_key}_PRECISION_0), + .ADDR_RANGE({node_name}_{cap_key}_TENSOR_SIZE_0), + .ADDR_WIDTH({a_width}) +) {component_name_inst} ( + .clk(clk), + .reset(rst), + + .address0({node_name}_{key}_address0), + .ce0({node_name}_{key}_ce0), + .q0({node_name}_{key}_q0) +); +""" + + def emit(self, node, parameter_map): + node_name = vf(node.name) + component_name = node.meta["mase"].parameters["hardware"]["module"] + + # Emit kernel instance + signals = "" + for key, value in node.meta["mase"].parameters["common"]["args"].items(): + signals += f""" + .{key}_address0({node_name}_{key}_address0), + .{key}_ce0({node_name}_{key}_ce0), + .{key}_q0({node_name}_{key}_q0), +""" + + for key, value in node.meta["mase"].parameters["common"]["results"].items(): + signals += f""" + .{key}_address0({node_name}_{key}_address0), + .{key}_ce0({node_name}_{key}_ce0), + .{key}_we0({node_name}_{key}_we0), + .{key}_d0({node_name}_{key}_d0), +""" + signals = _remove_last_comma(signals) + components = f""" +// {node_name} +{component_name} #( +) {node_name}_inst ( + .ap_clk(clk), + .ap_rst(rst), + + .ap_start({node_name}_start), + .ap_idle({node_name}_idle), + .ap_ready({node_name}_ready), + .ap_done({node_name}_done), + 
.ap_ce({node_name}_ce), +{signals} +); +""" + + # Emit parameter instance + for key, value in node.meta["mase"].parameters["common"]["args"].items(): + # Skip the parameter instance if the memory is stored off chip + if not param_needs_signals(node, key, value, qualifier="data_in"): + continue + components += self._emit_module_parameters_top_hls( + key, value, node, parameter_map + ) + + for key, value in node.meta["mase"].parameters["common"]["results"].items(): + # Skip the parameter instance if the memory is stored off chip + if not param_needs_signals(node, key, value, qualifier="data_out"): + continue + components += self._emit_module_parameters_top_hls( + key, value, node, parameter_map + ) + + return components + + +# ============================================================================= +# DFVerilog components +# ============================================================================= + + +class DFVerilogComponentEmitter: + def __init__(self, graph): + self.graph = graph + self.internal_emitter = DFVerilogInternalComponentEmitter(graph) + self.hls_emitter = DFVerilogHLSComponentEmitter(graph) + + def emit(self, parameter_map): + """ + Emit component declarations for the top-level module + """ + + components = """ +// -------------------------- +// Component instantiation +// -------------------------- +""" + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: + components += self.internal_emitter.emit(node, parameter_map) + elif node.meta["mase"].parameters["hardware"]["toolchain"] == "HLS": + components += self.hls_emitter.emit(node, parameter_map) + else: + assert False, "Unknown node toolchain for signal declarations." + + return components + + +# ============================================================================= +# DFVerilog wires +# ============================================================================= + + +class DFVerilogWireEmitter: + def __init__(self, graph): + self.graph = graph + + self.wires = """ +// -------------------------- +// Interconnections +// -------------------------- + """ + + def _emit_top_wires(self, parameter_map): + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out + + # ============================================================ + # Top level wires + # ============================================================ + + wires = "" + # TODO: here we just enumerate the inputs of the input nodes - which may be + # order insensitive and require manual connection when adding the graph to + # a system. 
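+        # The counter below numbers top-level data ports positionally:
+        # data_in_<i> is tied to the i-th data argument discovered on the
+        # input nodes, and data_out_<i> likewise on the output nodes.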
+ i = 0 + for node in nodes_in: + node_name = vf(node.name) + for arg in node.meta["mase"].parameters["common"]["args"].keys(): + if "data_in" in arg: + wires += f""" + assign data_in_{i}_ready = {node_name}_{arg}_ready; + assign {node_name}_{arg}_valid = data_in_{i}_valid; + assign {node_name}_{arg} = data_in_{i}; +""" + i += 1 + i = 0 + for node in nodes_out: + node_name = vf(node.name) + for result in node.meta["mase"].parameters["common"]["results"].keys(): + if "data_out" in result: + wires += f""" + assign data_out_{i}_valid = {node_name}_{result}_valid; + assign {node_name}_{result}_ready = data_out_{i}_ready; + assign data_out_{i} = {node_name}_{result}; +""" + i += 1 + + # TODO: emit off-chip parameter interface + + return wires + + def _emit_node2node_wires(self): + nodes_in = self.graph.nodes_in + + # Ignore the input of the input nodes + # (as they are already connected by the previous process) + # For each other explicit node, emit the edge of their inputs. + # Assume all the node has only one output. + wires = "" + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + if node in nodes_in: + continue + + to_name = vf(node.name) + for i, node_in in enumerate(node.all_input_nodes): + from_name = vf(node_in.name) + wires += f""" + assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready; + assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid; + assign {to_name}_data_in_{i} = {from_name}_data_out_0; +""" + return wires + + def emit(self, parameter_map): + """ + Emit internal signal connections for the top-level module + This includes two interconnection types: + 1. Type casting between inputs and outputs + 2. Interface casting between inputs and outputs + """ + + self.wires += self._emit_top_wires(parameter_map) + self.wires += self._emit_node2node_wires() + return self.wires + + +# ============================================================================= +# Emit Verilog +# ============================================================================= + + +class DataflowEmitter: + def __init__(self, graph): + self.graph = graph + + self.parameter_map = get_verilog_parameters(graph) + + def emit(self, top_name): + parameters_to_emit = DFVerilogParameterEmitter(self.graph).emit( + self.parameter_map + ) + + interface_to_emit = DFVerilogInterfaceEmitter(self.graph).emit( + self.parameter_map + ) + + signals_to_emit = DFVerilogSignalEmitter(self.graph).emit(self.parameter_map) + + components_to_emit = DFVerilogComponentEmitter(self.graph).emit( + self.parameter_map + ) + + wires_to_emit = DFVerilogWireEmitter(self.graph).emit(self.parameter_map) + + time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") + + module_inst = """ +// ===================================== +// Mase Hardware +// Model: {} +// {} +// ===================================== +`timescale 1ns/1ps +module {} #( +{} +) ( + input clk, + input rst, +{} +); +{} +{} +{} +endmodule + """.format( + top_name, + time_to_emit, + top_name, + parameters_to_emit, + interface_to_emit, + signals_to_emit, + components_to_emit, + wires_to_emit, + ) + return module_inst + + +# ============================================================================= +# Emit top-level design with memory mapping +# ============================================================================= -class VerilogParameterEmitter: +# ============================================================================= +# MMVerilog parameters +# 
============================================================================= + + +class MMVerilogParameterEmitter: def __init__(self, graph): self.graph = graph @@ -70,7 +641,7 @@ def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]: Emit parameters at the top-level for the top-level module Returns Tuple: - 1) list of parameters as a string to be embedded in Verilog file + 1) list of parameters as a string to be embedded in MMVerilog file """ nodes_in = self.graph.nodes_in @@ -88,11 +659,11 @@ def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]: # ============================================================================= -# Verilog interface +# MMVerilog interface # ============================================================================= -class VerilogInterfaceEmitter: +class MMVerilogInterfaceEmitter: def __init__(self, graph): self.graph = graph @@ -148,11 +719,11 @@ def emit(self, parameter_map): # ============================================================================= -# Verilog signals +# MMVerilog signals # ============================================================================= -class VerilogSignalEmitter: +class MMVerilogSignalEmitter: def __init__(self, graph): self.graph = graph @@ -284,11 +855,11 @@ def emit(self, parameter_map): # ============================================================================= -# Verilog components (INTERNAL) +# MMVerilog components (INTERNAL) # ============================================================================= -class VerilogInternalComponentEmitter: +class MMVerilogInternalComponentEmitter: def __init__(self, graph): self.graph = graph @@ -378,11 +949,11 @@ def emit(self, node, parameter_map): # ============================================================================= -# Verilog components (HLS) +# MMVerilog components (HLS) # ============================================================================= -class VerilogHLSComponentEmitter: +class MMVerilogHLSComponentEmitter: def __init__(self, graph): self.graph = graph @@ -468,15 +1039,15 @@ def emit(self, node, parameter_map): # ============================================================================= -# Verilog components +# MMVerilog components # ============================================================================= -class VerilogComponentEmitter: +class MMVerilogComponentEmitter: def __init__(self, graph): self.graph = graph - self.internal_emitter = VerilogInternalComponentEmitter(graph) - self.hls_emitter = VerilogHLSComponentEmitter(graph) + self.internal_emitter = MMVerilogInternalComponentEmitter(graph) + self.hls_emitter = MMVerilogHLSComponentEmitter(graph) def emit(self, parameter_map): """ @@ -502,11 +1073,11 @@ def emit(self, parameter_map): # ============================================================================= -# Verilog wires +# MMVerilog wires # ============================================================================= -class VerilogWireEmitter: +class MMVerilogWireEmitter: def __init__(self, graph): self.graph = graph @@ -593,30 +1164,32 @@ def emit(self, parameter_map): # ============================================================================= -# Emit Verilog +# Emit MMVerilog # ============================================================================= -class DataflowEmitter: +class MemoryMapEmitter: def __init__(self, graph): self.graph = graph self.parameter_map = get_verilog_parameters(graph) def emit(self, top_name): - parameters_to_emit = VerilogParameterEmitter(self.graph).emit( + 
parameters_to_emit = MMVerilogParameterEmitter(self.graph).emit( self.parameter_map ) - interface_to_emit = VerilogInterfaceEmitter(self.graph).emit(self.parameter_map) + interface_to_emit = MMVerilogInterfaceEmitter(self.graph).emit( + self.parameter_map + ) - signals_to_emit = VerilogSignalEmitter(self.graph).emit(self.parameter_map) + signals_to_emit = MMVerilogSignalEmitter(self.graph).emit(self.parameter_map) - components_to_emit = VerilogComponentEmitter(self.graph).emit( + components_to_emit = MMVerilogComponentEmitter(self.graph).emit( self.parameter_map ) - wires_to_emit = VerilogWireEmitter(self.graph).emit(self.parameter_map) + wires_to_emit = MMVerilogWireEmitter(self.graph).emit(self.parameter_map) time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") @@ -687,9 +1260,9 @@ def emit_verilog_top_transform_pass(graph, pass_args={}): with open(df_file, "w") as df_design: df_design.write(df) - # top = MemoryMapEmitter(graph).emit(top_name) - # top_file = os.path.join(rtl_dir, f"{top_name}.sv") - # with open(top_file, "w") as top_design: - # top_design.write(top) + top = MemoryMapEmitter(graph).emit(top_name) + top_file = os.path.join(rtl_dir, f"{top_name}.sv") + with open(top_file, "w") as top_design: + top_design.write(top) return graph, {} From 3bfe0649b867ebbf539aa40ee28b1f12e911c6fc Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Wed, 24 Apr 2024 15:29:24 +0000 Subject: [PATCH 08/31] Refactored dataflow level --- .../graph/transforms/verilog/emit_top.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py index 3a63d9016..c8adced98 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_top.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py @@ -146,7 +146,26 @@ def emit(self, parameter_map): input data_out_{i}_ready,""" i += 1 - # TODO: emit off-chip parameter interface + # Emit all parameters as inputs (they will be mapped at the top-level) + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + if "data_in" not in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + input [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0], + input {node_name}_{arg}_valid, + output {node_name}_{arg}_ready,""" + i += 1 return _remove_last_comma(interface) @@ -165,15 +184,7 @@ def _emit_signals_top_internal(self, node, parameter_map): node_name = vf(node.name) # Input signals for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - if not isinstance(arg_info, dict): - continue - - # Skip off-chip parameters as they will be directly connected to the top level - if ( - "data_in" in arg - or node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] - == "BRAM" - ): + if "data_in" in arg: arg_name = v2p(arg) parallelism_params = [ param @@ -192,14 +203,7 @@ def _emit_signals_top_internal(self, node, parameter_map): if not isinstance(result_info, dict): continue - # Skip off-chip parameters as they will be directly connected to the top level - if ( - "data_out" in result - or 
node.meta["mase"].parameters["hardware"]["interface"][result][
-                    "storage"
-                ]
-                == "BRAM"
-            ):
+            if "data_out" in result:
                 result_name = v2p(result)
                 parallelism_params = [
                     param
                     for param in parameter_map
                     if f"{node_name}_{result_name}_PARALLELISM_DIM" in param
@@ -1255,11 +1259,13 @@ def emit_verilog_top_transform_pass(graph, pass_args={}):

     rtl_dir = os.path.join(project_dir, "hardware", "rtl")

+    # Emit device-independent hardware design in dataflow
     df = DataflowEmitter(graph).emit(top_name)
     df_file = os.path.join(rtl_dir, f"{top_name}_df.sv")
     with open(df_file, "w") as df_design:
         df_design.write(df)

+    # Emit memory mapping with BRAMs for the top-level hardware design
     top = MemoryMapEmitter(graph).emit(top_name)
     top_file = os.path.join(rtl_dir, f"{top_name}.sv")
     with open(top_file, "w") as top_design:
         top_design.write(top)

     return graph, {}

From caad0b1651af38ef43853f9db1a012833fa9052a Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 17:04:44 +0000
Subject: [PATCH 09/31] Added missing dependencies and updated file paths

---
 .../chop/passes/graph/transforms/verilog/emit_internal.py | 7 ++++---
 .../graph/transforms/verilog/internal_file_dependences.py | 7 ++++++-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/machop/chop/passes/graph/transforms/verilog/emit_internal.py b/machop/chop/passes/graph/transforms/verilog/emit_internal.py
index 8b40a1b83..d29186ae0 100644
--- a/machop/chop/passes/graph/transforms/verilog/emit_internal.py
+++ b/machop/chop/passes/graph/transforms/verilog/emit_internal.py
@@ -46,7 +46,7 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}):
     for node in graph.fx_graph.nodes:
         if node.meta["mase"].parameters["hardware"]["is_implicit"]:
             continue
-        if "INTERNAL_RTL" == node.meta["mase"].parameters["hardware"]["toolchain"]:
+        if "INTERNAL" == node.meta["mase"].parameters["hardware"]["toolchain"]:
             if (
                 hasattr(node.meta["mase"].module, "config")
                 and node.meta["mase"].module.config.get("name", "") == "logicnets"
@@ -70,11 +70,12 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}):
         "..",
         "..",
         "..",
-        "..",
         "mase_components",
     )
     for f in rtl_dependencies:
-        shutil.copy(os.path.join(hardware_dir, f), rtl_dir)
+        fname = os.path.join(hardware_dir, f)
+        assert os.path.isfile(fname), f"Cannot find file {fname}."
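+        # The assert above fails fast on stale INTERNAL_RTL_DEPENDENCIES
+        # entries; valid files are copied into the project rtl directory.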
+        shutil.copy(fname, rtl_dir)

     return graph, {}

diff --git a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py
index a84625f93..9087e2ab3 100644
--- a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py
+++ b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py
@@ -12,6 +12,11 @@
         "common/rtl/skid_buffer.sv",
         "common/rtl/join2.sv",
         "cast/rtl/fixed_rounding.sv",
+        "cast/rtl/fixed_round.sv",
+    ],
+    "relu": [
+        "activations/rtl/fixed_relu.sv",
+        "cast/rtl/fixed_rounding.sv",
+        "cast/rtl/fixed_round.sv",
     ],
-    "relu": ["activations/fixed_relu.sv"],
 }

From 5b8990fb5e03e3e554ffbc6d3371583dcb3853b1 Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Wed, 24 Apr 2024 17:36:02 +0000
Subject: [PATCH 10/31] Refactored memory map emit

---
 .../graph/transforms/verilog/emit_top.py | 361 +++++------------
 1 file changed, 94 insertions(+), 267 deletions(-)

diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py
index c8adced98..32f79b0ff 100644
--- a/machop/chop/passes/graph/transforms/verilog/emit_top.py
+++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py
@@ -61,11 +61,11 @@ def param_needs_signals(node, param, value, qualifier="data_in"):
 # =============================================================================

 # =============================================================================
-# DFVerilog parameters
+# Verilog parameters
 # =============================================================================


-class DFVerilogParameterEmitter:
+class VerilogParameterEmitter:
     def __init__(self, graph):
         self.graph = graph

@@ -574,11 +574,10 @@ def emit(self, parameter_map):
 class DataflowEmitter:
     def __init__(self, graph):
         self.graph = graph
-
         self.parameter_map = get_verilog_parameters(graph)

     def emit(self, top_name):
-        parameters_to_emit = DFVerilogParameterEmitter(self.graph).emit(
+        parameters_to_emit = VerilogParameterEmitter(self.graph).emit(
             self.parameter_map
         )

@@ -598,12 +597,12 @@ def emit(self, top_name):
         module_inst = """
 // =====================================
-// Mase Hardware
+// Mase Hardware (Dataflow)
 // Model: {}
 // {}
 // =====================================
 `timescale 1ns/1ps
-module {} #(
+module {}_dataflow #(
 {}
 ) (
     input clk,
     input rst,
@@ -631,37 +630,6 @@ def emit(self, top_name):
 # Emit top-level design with memory mapping
 # =============================================================================

-# =============================================================================
-# MMVerilog parameters
-# =============================================================================
-
-
-class MMVerilogParameterEmitter:
-    def __init__(self, graph):
-        self.graph = graph
-
-    def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]:
-        """
-        Emit parameters at the top-level for the top-level module
-
-        Returns Tuple:
-        1) list of parameters as a string to be embedded in MMVerilog file
-        """
-
-        nodes_in = self.graph.nodes_in
-        nodes_out = self.graph.nodes_out
-        node_in_name = vf(nodes_in[0].name)
-        node_out_name = vf(nodes_out[0].name)
-
-        parameters = ""
-
-        # Write node parameters
-        for key, value in parameter_map.items():
-            parameters += f"""    parameter {key} = {value},\n"""
-
-        return _remove_last_comma(parameters)
-
-
 # =============================================================================
 # MMVerilog interface
 # 
============================================================================= @@ -738,11 +706,12 @@ def _emit_signals_top_internal(self, node, parameter_map): for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): if not isinstance(arg_info, dict): continue + if "data_in" in arg: + continue # Skip off-chip parameters as they will be directly connected to the top level if ( - "data_in" in arg - or node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] + node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] == "BRAM" ): arg_name = v2p(arg) @@ -762,13 +731,12 @@ def _emit_signals_top_internal(self, node, parameter_map): ): if not isinstance(result_info, dict): continue + if "data_out" in result: + continue # Skip off-chip parameters as they will be directly connected to the top level if ( - "data_out" in result - or node.meta["mase"].parameters["hardware"]["interface"][result][ - "storage" - ] + node.meta["mase"].parameters["hardware"]["interface"][result]["storage"] == "BRAM" ): result_name = v2p(result) @@ -844,10 +812,6 @@ def emit(self, parameter_map): if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue node_name = vf(node.name) - signals += f""" -// -------------------------- -// {node_name} signals -// --------------------------""" if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: signals += self._emit_signals_top_internal(node, parameter_map) elif node.meta["mase"].parameters["hardware"]["toolchain"] == "HLS": @@ -895,50 +859,8 @@ def emit(self, node, parameter_map): component_name = node.meta["mase"].parameters["hardware"]["module"] signals = "" - # Emit component instantiation parameters - parameters = "" - for key, value in ( - node.meta["mase"].parameters["hardware"]["verilog_param"].items() - ): - key_value = parameter_map[f"{node_name}_{key}"] - debug_info = f"// = {key_value}" - parameters += f""" .{key}({node_name}_{key}), {debug_info}\n""" - parameters = _remove_last_comma(parameters) - - # Emit component instantiation input signals - for key, value in node.meta["mase"].parameters["common"]["args"].items(): - if "data" not in key: - continue - signals += f""" - .{key}({node_name}_{key}), - .{key}_valid({node_name}_{key}_valid), - .{key}_ready({node_name}_{key}_ready), - """ - - # Emit component instantiation output signals - for key, value in node.meta["mase"].parameters["common"]["results"].items(): - if "data" not in key: - continue - signals += f""" - .{key}({node_name}_{key}), - .{key}_valid({node_name}_{key}_valid), - .{key}_ready({node_name}_{key}_ready), - """ - signals = _remove_last_comma(signals) - - # Combine component instantiation - components = f""" -// {node_name} -{component_name} #( -{parameters} -) {node_name}_inst ( - .clk(clk), - .rst(rst), -{signals} -); -""" - # Emit module parameter instances (e.g. 
weights and biases) + components = "" for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): if "data_in" in arg: continue @@ -953,93 +875,85 @@ def emit(self, node, parameter_map): # ============================================================================= -# MMVerilog components (HLS) +# MMVerilog top interface connected to the dataflow design # ============================================================================= -class MMVerilogHLSComponentEmitter: +class MMVerilogTopInterfaceEmitter: def __init__(self, graph): self.graph = graph - def _emit_module_parameters_top_hls(self, key, value, node, parameter_map): - node_name = vf(node.name) - cap_key = v2p(key) - component_name = f"{node_name}_{key}_source" - component_name_inst = f"{node_name}_{key}_0" + def emit(self, parameter_map): + """ + Emit interface signal declarations for the top-level module + """ - size_debug_info = math.prod(value["shape"]) - a_width = math.ceil(math.log2(size_debug_info)) + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out - return f""" -{component_name} #( - .DATA_WIDTH({node_name}_{cap_key}_PRECISION_0), - .ADDR_RANGE({node_name}_{cap_key}_TENSOR_SIZE_0), - .ADDR_WIDTH({a_width}) -) {component_name_inst} ( + interface = """ .clk(clk), - .reset(rst), - - .address0({node_name}_{key}_address0), - .ce0({node_name}_{key}_ce0), - .q0({node_name}_{key}_q0) -); -""" - - def emit(self, node, parameter_map): - node_name = vf(node.name) - component_name = node.meta["mase"].parameters["hardware"]["module"] - - # Emit kernel instance - signals = "" - for key, value in node.meta["mase"].parameters["common"]["args"].items(): - signals += f""" - .{key}_address0({node_name}_{key}_address0), - .{key}_ce0({node_name}_{key}_ce0), - .{key}_q0({node_name}_{key}_q0), -""" - - for key, value in node.meta["mase"].parameters["common"]["results"].items(): - signals += f""" - .{key}_address0({node_name}_{key}_address0), - .{key}_ce0({node_name}_{key}_ce0), - .{key}_we0({node_name}_{key}_we0), - .{key}_d0({node_name}_{key}_d0), -""" - signals = _remove_last_comma(signals) - components = f""" -// {node_name} -{component_name} #( -) {node_name}_inst ( - .ap_clk(clk), - .ap_rst(rst), - - .ap_start({node_name}_start), - .ap_idle({node_name}_idle), - .ap_ready({node_name}_ready), - .ap_done({node_name}_done), - .ap_ce({node_name}_ce), -{signals} -); + .rst(rst), """ + # TODO: here we just enumerate the inputs of the input nodes - which may be + # order insensitive and require manual connection when adding the graph to + # a system. 
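+        # Data ports are hooked up positionally (data_in_<i>/data_out_<i>);
+        # every non-data argument (e.g. weights and biases) is then exposed
+        # under its node-prefixed name so the BRAM instances at this level
+        # can drive it.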
+ i = 0 + for node in nodes_in: + node_name = vf(node.name) + for arg in node.meta["mase"].parameters["common"]["args"].keys(): + if "data_in" in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + .data_in_{i}(data_in_{i}), + .data_in_{i}_valid(data_in_{i}_valid), + .data_in_{i}_ready(data_in_{i}_ready),""" + i += 1 - # Emit parameter instance - for key, value in node.meta["mase"].parameters["common"]["args"].items(): - # Skip the parameter instance if the memory is stored off chip - if not param_needs_signals(node, key, value, qualifier="data_in"): - continue - components += self._emit_module_parameters_top_hls( - key, value, node, parameter_map - ) + i = 0 + for node in nodes_out: + node_name = vf(node.name) + for result in node.meta["mase"].parameters["common"]["results"].keys(): + if "data_out" in result: + result_name = _cap(result) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{result_name}_PARALLELISM_DIM" in param + ] + interface += f""" + .data_out_{i}(data_out_{i}), + .data_out_{i}_valid(data_out_{i}_valid), + .data_out_{i}_ready(data_out_{i}_ready),""" + i += 1 - for key, value in node.meta["mase"].parameters["common"]["results"].items(): - # Skip the parameter instance if the memory is stored off chip - if not param_needs_signals(node, key, value, qualifier="data_out"): + # Emit all parameters as inputs (they will be mapped at the top-level) + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue - components += self._emit_module_parameters_top_hls( - key, value, node, parameter_map - ) + node_name = vf(node.name) + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + if "data_in" not in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + .{node_name}_{arg}({node_name}_{arg}), + .{node_name}_{arg}_valid({node_name}_{arg}_valid), + .{node_name}_{arg}_ready({node_name}_{arg}_ready),""" + i += 1 - return components + return _remove_last_comma(interface) # ============================================================================= @@ -1051,17 +965,25 @@ class MMVerilogComponentEmitter: def __init__(self, graph): self.graph = graph self.internal_emitter = MMVerilogInternalComponentEmitter(graph) - self.hls_emitter = MMVerilogHLSComponentEmitter(graph) - def emit(self, parameter_map): + def emit(self, parameter_map, top): """ Emit component declarations for the top-level module """ - components = """ + # Write node parameters + top_parameters = "" + for key, value in parameter_map.items(): + top_parameters += f""" .{key}({key}),\n""" + interface = MMVerilogTopInterfaceEmitter(self.graph).emit(parameter_map) + + components = f""" // -------------------------- // Component instantiation // -------------------------- +{top}_dataflow #({top_parameters} +) {top}_df_inst ({interface} +); """ for node in self.graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: @@ -1069,104 +991,14 @@ def emit(self, parameter_map): if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: components += self.internal_emitter.emit(node, parameter_map) elif node.meta["mase"].parameters["hardware"]["toolchain"] == "HLS": - components += self.hls_emitter.emit(node, 
parameter_map) + # Assume all parameters in HLS components are local + continue else: assert False, "Unknown node toolchain for signal declarations." return components -# ============================================================================= -# MMVerilog wires -# ============================================================================= - - -class MMVerilogWireEmitter: - def __init__(self, graph): - self.graph = graph - - self.wires = """ -// -------------------------- -// Interconnections -// -------------------------- - """ - - def _emit_top_wires(self, parameter_map): - nodes_in = self.graph.nodes_in - nodes_out = self.graph.nodes_out - - # ============================================================ - # Top level wires - # ============================================================ - - wires = "" - # TODO: here we just enumerate the inputs of the input nodes - which may be - # order insensitive and require manual connection when adding the graph to - # a system. - i = 0 - for node in nodes_in: - node_name = vf(node.name) - for arg in node.meta["mase"].parameters["common"]["args"].keys(): - if "data_in" in arg: - wires += f""" - assign data_in_{i}_ready = {node_name}_{arg}_ready; - assign {node_name}_{arg}_valid = data_in_{i}_valid; - assign {node_name}_{arg} = data_in_{i}; -""" - i += 1 - i = 0 - for node in nodes_out: - node_name = vf(node.name) - for result in node.meta["mase"].parameters["common"]["results"].keys(): - if "data_out" in result: - wires += f""" - assign data_out_{i}_valid = {node_name}_{result}_valid; - assign {node_name}_{result}_ready = data_out_{i}_ready; - assign data_out_{i} = {node_name}_{result}; -""" - i += 1 - - # TODO: emit off-chip parameter interface - - return wires - - def _emit_node2node_wires(self): - nodes_in = self.graph.nodes_in - - # Ignore the input of the input nodes - # (as they are already connected by the previous process) - # For each other explicit node, emit the edge of their inputs. - # Assume all the node has only one output. - wires = "" - for node in self.graph.fx_graph.nodes: - if node.meta["mase"].parameters["hardware"]["is_implicit"]: - continue - if node in nodes_in: - continue - - to_name = vf(node.name) - for i, node_in in enumerate(node.all_input_nodes): - from_name = vf(node_in.name) - wires += f""" - assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready; - assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid; - assign {to_name}_data_in_{i} = {from_name}_data_out_0; -""" - return wires - - def emit(self, parameter_map): - """ - Emit internal signal connections for the top-level module - This includes two interconnection types: - 1. Type casting between inputs and outputs - 2. 
Interface casting between inputs and outputs - """ - - self.wires += self._emit_top_wires(parameter_map) - self.wires += self._emit_node2node_wires() - return self.wires - - # ============================================================================= # Emit MMVerilog # ============================================================================= @@ -1179,7 +1011,7 @@ def __init__(self, graph): self.parameter_map = get_verilog_parameters(graph) def emit(self, top_name): - parameters_to_emit = MMVerilogParameterEmitter(self.graph).emit( + parameters_to_emit = VerilogParameterEmitter(self.graph).emit( self.parameter_map ) @@ -1190,16 +1022,14 @@ def emit(self, top_name): signals_to_emit = MMVerilogSignalEmitter(self.graph).emit(self.parameter_map) components_to_emit = MMVerilogComponentEmitter(self.graph).emit( - self.parameter_map + self.parameter_map, top_name ) - wires_to_emit = MMVerilogWireEmitter(self.graph).emit(self.parameter_map) - time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") module_inst = """ // ===================================== -// Mase Hardware +// Mase Hardware (Memory Map) // Model: {} // {} // ===================================== @@ -1213,7 +1043,6 @@ def emit(self, top_name): ); {} {} -{} endmodule """.format( top_name, @@ -1223,7 +1052,6 @@ def emit(self, top_name): interface_to_emit, signals_to_emit, components_to_emit, - wires_to_emit, ) return module_inst @@ -1250,8 +1078,7 @@ def emit_verilog_top_transform_pass(graph, pass_args={}): project_dir = ( pass_args["project_dir"] if "project_dir" in pass_args.keys() - else "./top" - # else Path.home() / ".mase" / "top" + else Path.home() / ".mase" / "top" ) top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" init_project(project_dir) From 050901445cc7975c4090a19d2c8ba17acb12cfa7 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Wed, 24 Apr 2024 17:38:19 +0000 Subject: [PATCH 11/31] Added rounding at the output of relu --- .../activations/rtl/fixed_relu.sv | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/machop/mase_components/activations/rtl/fixed_relu.sv b/machop/mase_components/activations/rtl/fixed_relu.sv index ceef9c053..91e3399c8 100644 --- a/machop/mase_components/activations/rtl/fixed_relu.sv +++ b/machop/mase_components/activations/rtl/fixed_relu.sv @@ -28,14 +28,29 @@ module fixed_relu #( input logic data_out_0_ready ); - for (genvar i = 0; i < DATA_IN_0_TENSOR_SIZE_DIM_0; i++) begin : ReLU + logic [DATA_IN_0_PRECISION_0-1:0] data[DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0] ; + + for ( + genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++ + ) begin : ReLU always_comb begin // negative value, put to zero if ($signed(data_in_0[i]) <= 0) data_out_0[i] = '0; - else data_out_0[i] = data_in_0[i]; + else data[i] = data_in_0[i]; end end + fixed_rounding #( + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .IN_WIDTH(DATA_IN_0_PRECISION_0), + .IN_FRAC_WIDTH(DATA_IN_0_PRECISION_1), + .OUT_WIDTH(DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) + ) fr_inst ( + .data_in (data), + .data_out(data_out) + ); + assign data_out_0_valid = data_in_0_valid; assign data_in_0_ready = data_out_0_ready; From 181b5377b5190564fd89191eee7a16f5999790b9 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Thu, 25 Apr 2024 09:13:36 +0000 Subject: [PATCH 12/31] Fixed relu var name typos --- machop/mase_components/activations/rtl/fixed_relu.sv | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/machop/mase_components/activations/rtl/fixed_relu.sv b/machop/mase_components/activations/rtl/fixed_relu.sv index 91e3399c8..2ef87b1bc 100644 --- a/machop/mase_components/activations/rtl/fixed_relu.sv +++ b/machop/mase_components/activations/rtl/fixed_relu.sv @@ -35,7 +35,7 @@ module fixed_relu #( ) begin : ReLU always_comb begin // negative value, put to zero - if ($signed(data_in_0[i]) <= 0) data_out_0[i] = '0; + if ($signed(data_in_0[i]) <= 0) data[i] = '0; else data[i] = data_in_0[i]; end end @@ -48,7 +48,7 @@ module fixed_relu #( .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) ) fr_inst ( .data_in (data), - .data_out(data_out) + .data_out(data_out_0) ); assign data_out_0_valid = data_in_0_valid; From f93c9768072908868767cd774a1437dab0d94d5c Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Thu, 25 Apr 2024 14:01:32 +0000 Subject: [PATCH 13/31] Updated linear layer with the right parameter names --- .../linear/rtl/fixed_linear.sv | 110 +++--- .../linear/test/fixed_linear_tb.py | 374 +++++++++++------- 2 files changed, 287 insertions(+), 197 deletions(-) diff --git a/machop/mase_components/linear/rtl/fixed_linear.sv b/machop/mase_components/linear/rtl/fixed_linear.sv index de6958925..584733bc1 100644 --- a/machop/mase_components/linear/rtl/fixed_linear.sv +++ b/machop/mase_components/linear/rtl/fixed_linear.sv @@ -1,76 +1,76 @@ `timescale 1ns / 1ps - -/* - * Constrained by WEIGHT_PARALLELISM_DIM_0 == DATA_OUT_0_PARALLELISM_DIM_0 - * -*/ - module fixed_linear #( /* verilator lint_off UNUSEDPARAM */ - parameter HAS_BIAS = 0, + parameter HAS_BIAS = 1, - parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_0 = 8, parameter DATA_IN_0_PRECISION_1 = 3, - parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, - parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, - parameter IN_0_DEPTH = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, - parameter WEIGHT_PRECISION_0 = 16, + parameter WEIGHT_PRECISION_0 = 8, parameter WEIGHT_PRECISION_1 = 3, - parameter WEIGHT_TENSOR_SIZE_DIM_0 = 32, + parameter WEIGHT_TENSOR_SIZE_DIM_0 = 1, + parameter WEIGHT_PARALLELISM_DIM_0 = 1, parameter WEIGHT_TENSOR_SIZE_DIM_1 = 1, - parameter WEIGHT_PARALLELISM_DIM_0 = 4, parameter WEIGHT_PARALLELISM_DIM_1 = 1, + parameter WEIGHT_TENSOR_SIZE_DIM_2 = 1, + parameter WEIGHT_PARALLELISM_DIM_2 = 1, - parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( - DATA_IN_0_TENSOR_SIZE_DIM_0 - ) + $clog2( - IN_0_DEPTH - ) + HAS_BIAS, - parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1, - parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_OUT_0_PRECISION_0 = 8, + parameter DATA_OUT_0_PRECISION_1 = 3, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1, parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, - parameter DATA_OUT_0_PARALLELISM_DIM_0 = WEIGHT_PARALLELISM_DIM_0, parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_2 = 1, - parameter BIAS_PRECISION_0 = 16, + parameter BIAS_PRECISION_0 = 8, parameter BIAS_PRECISION_1 = 3, - parameter BIAS_TENSOR_SIZE_DIM_0 = DATA_OUT_0_TENSOR_SIZE_DIM_0, - parameter BIAS_TENSOR_SIZE_DIM_1 = 1, + parameter BIAS_TENSOR_SIZE_DIM_0 = 1, 
 parameter BIAS_PARALLELISM_DIM_0 = 1,
-    parameter BIAS_PARALLELISM_DIM_1 = 1
+    parameter BIAS_TENSOR_SIZE_DIM_1 = 1,
+    parameter BIAS_PARALLELISM_DIM_1 = 1,
+    parameter BIAS_TENSOR_SIZE_DIM_2 = 1,
+    parameter BIAS_PARALLELISM_DIM_2 = 1
+    /* verilator lint_on UNUSEDPARAM */
 ) (
     input clk,
     input rst,
 
     // input port for data activations
-    input [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0],
-    input data_in_0_valid,
-    output data_in_0_ready,
+    input [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_1-1:0],
+    input data_in_0_valid,
+    output data_in_0_ready,
 
     // input port for weight
-    input [WEIGHT_PRECISION_0-1:0] weight [WEIGHT_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_0-1:0],
+    input [WEIGHT_PRECISION_0-1:0] weight [WEIGHT_PARALLELISM_DIM_0*WEIGHT_PARALLELISM_DIM_1-1:0],
     input weight_valid,
     output weight_ready,
 
     /* verilator lint_off UNUSEDSIGNAL */
-    input [BIAS_PRECISION_0-1:0] bias[BIAS_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_0-1:0],
-    input bias_valid,
+    input [BIAS_PRECISION_0-1:0] bias [BIAS_PARALLELISM_DIM_0-1:0],
+    input bias_valid,
     /* verilator lint_on UNUSEDSIGNAL */
-    output bias_ready,
+    output bias_ready,
 
-    output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0],
-    output data_out_0_valid,
-    input data_out_0_ready
+    output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_1-1:0],
+    output data_out_0_valid,
+    input data_out_0_ready
 );
 
   localparam FDP_WIDTH = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2(
-      DATA_IN_0_PARALLELISM_DIM_0
+      DATA_IN_0_PARALLELISM_DIM_1
+  );
+  localparam ACC_WIDTH = FDP_WIDTH + $clog2(
+      DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1
   );
-  localparam ACC_WIDTH = FDP_WIDTH + $clog2(IN_0_DEPTH);
 
   logic fdp_join_valid, fdp_join_ready;
   join2 #() fdp_join_inst (
@@ -84,19 +84,19 @@ module fixed_linear #(
   // Assume the parallelised hardware instances above have the same arrival time,
   // which means that they always have the same state. So we can just
   // pick one of the valid signals to use.
-  logic [WEIGHT_PARALLELISM_DIM_0-1:0] fdp_data_ready, fdp_weight_ready;
+  logic [DATA_OUT_0_PARALLELISM_DIM_1-1:0] fdp_data_ready, fdp_weight_ready;
   assign fdp_join_ready = fdp_data_ready[0];
   /* verilator lint_on UNUSEDSIGNAL */
 
   logic acc_ready;
-  logic [ACC_WIDTH-1:0] acc_data_out[WEIGHT_PARALLELISM_DIM_0*WEIGHT_PARALLELISM_DIM_1-1:0];
+  logic [ACC_WIDTH-1:0] acc_data_out[DATA_OUT_0_PARALLELISM_DIM_1-1:0];
 
-  // There are WEIGHT_PARALLELISM_DIM_0 number of dot product instances with DATA_IN_0_TENSOR_SIZE_DIM_0 inputs
-  // and each one computes for IN_0_DEPTH iterations for each inputs.
-  for (genvar i = 0; i < WEIGHT_PARALLELISM_DIM_0; i = i + 1) begin : linear
+  // There are DATA_OUT_0_PARALLELISM_DIM_1 dot product instances, each with DATA_IN_0_PARALLELISM_DIM_1 inputs,
+  // and each one computes for DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1 iterations for each input.
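+  // Worked example (hypothetical numbers, for illustration only): with
+  // DATA_IN_0_TENSOR_SIZE_DIM_1 = 8, DATA_IN_0_PARALLELISM_DIM_1 = 2 and
+  // DATA_OUT_0_PARALLELISM_DIM_1 = 4, the loop below builds four 2-input
+  // dot-product units, and each accumulator runs for 8 / 2 = 4 beats per
+  // output element.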
+  for (genvar i = 0; i < DATA_OUT_0_PARALLELISM_DIM_1; i = i + 1) begin : linear
     // Assume the weights are transposed and partitioned
-    logic [WEIGHT_PRECISION_0-1:0] current_weight[DATA_IN_0_PARALLELISM_DIM_0-1:0];
-    assign current_weight = weight[DATA_IN_0_PARALLELISM_DIM_0*(i+1)-1:DATA_IN_0_PARALLELISM_DIM_0*i];
+    logic [WEIGHT_PRECISION_0-1:0] current_weight[DATA_IN_0_PARALLELISM_DIM_1-1:0];
+    assign current_weight = weight[DATA_IN_0_PARALLELISM_DIM_1*i+DATA_IN_0_PARALLELISM_DIM_1-1:DATA_IN_0_PARALLELISM_DIM_1*i];
 
     logic [FDP_WIDTH-1:0] fdp_data_out;
     logic fdp_data_out_valid, fdp_data_out_ready;
@@ -105,7 +105,7 @@ module fixed_linear #(
     fixed_dot_product #(
         .IN_WIDTH(DATA_IN_0_PRECISION_0),
         .WEIGHT_WIDTH(WEIGHT_PRECISION_0),
-        .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0)
+        .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_1)
     ) fdp_inst (
         .clk(clk),
         .rst(rst),
@@ -126,7 +126,7 @@ module fixed_linear #(
 
     fixed_accumulator #(
         .IN_WIDTH(FDP_WIDTH),
-        .IN_DEPTH(IN_0_DEPTH)
+        .IN_DEPTH(DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1)
     ) fixed_accumulator_inst (
         .clk(clk),
         .rst(rst),
@@ -146,9 +146,8 @@ module fixed_linear #(
 
   if (HAS_BIAS == 1) begin
-    logic [ACC_WIDTH-1:0] bias_sext[BIAS_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_0-1:0];
+    logic [ACC_WIDTH-1:0] bias_sext[BIAS_PARALLELISM_DIM_0-1:0];
     logic acc_join_valid, acc_join_ready;
-    logic [DATA_IN_0_PARALLELISM_DIM_0-1:0] reg_ready;
 
     join2 #() acc_join_inst (
         .data_in_ready ({bias_ready, acc_ready}),
@@ -156,9 +155,11 @@ module fixed_linear #(
         .data_out_valid(acc_join_valid),
         .data_out_ready(acc_join_ready)
     );
+    logic [BIAS_PARALLELISM_DIM_0-1:0] reg_ready;
+    assign acc_join_ready = &reg_ready;
 
     fixed_rounding #(
-        .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0),
+        .IN_SIZE(BIAS_PARALLELISM_DIM_0),
         .IN_WIDTH(BIAS_PRECISION_0),
         .IN_FRAC_WIDTH(BIAS_PRECISION_1),
         .OUT_WIDTH(ACC_WIDTH),
@@ -168,9 +169,7 @@ module fixed_linear #(
         .data_out(bias_sext)
     );
 
-    assign acc_join_ready = &reg_ready;
-
-    for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0; i = i + 1) begin : add_bias
+    for (genvar i = 0; i < BIAS_PARALLELISM_DIM_0; i = i + 1) begin : add_bias
       logic [DATA_OUT_0_PRECISION_0-1:0] add;
       assign add = $signed(acc_data_out[i]) + $signed(bias_sext[i]);
       /* verilator lint_off UNUSEDSIGNAL */
@@ -193,10 +192,7 @@ module fixed_linear #(
   end else begin
     assign acc_ready = data_out_0_ready;
     assign data_out_0_valid = linear[0].acc_data_out_valid;
-
-    for (genvar i = 0; i < WEIGHT_PARALLELISM_DIM_0; i = i + 1) begin
-      assign data_out_0[i] = acc_data_out[i];
-    end
+    assign data_out_0 = acc_data_out;
 
     assign bias_ready = 1;
   end
diff --git a/machop/mase_components/linear/test/fixed_linear_tb.py b/machop/mase_components/linear/test/fixed_linear_tb.py
index 54596dea8..58e57b9f1 100644
--- a/machop/mase_components/linear/test/fixed_linear_tb.py
+++ b/machop/mase_components/linear/test/fixed_linear_tb.py
@@ -3,158 +3,252 @@
 # This script tests the fixed point linear
 import os, logging
 
-import cocotb
-from cocotb.log import SimLog
-from cocotb.triggers import *
-
-from mase_cocotb.testbench import Testbench
-from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor
-from mase_cocotb.z_qlayers import quantize_to_int
+from mase_cocotb.random_test import RandomSource, RandomSink, check_results
 from mase_cocotb.runner import mase_runner
-from mase_cocotb.utils import bit_driver, sign_extend_t
-
-from chop.passes.graph.transforms.quantize.quantized_modules import LinearInteger
-
-import torch
-
-logger = logging.getLogger("testbench")
-logger.setLevel(logging.DEBUG)
-
-class 
LinearTB(Testbench): - def __init__(self, dut, in_features=4, out_features=4) -> None: - super().__init__(dut, dut.clk, dut.rst) - - if not hasattr(self, "log"): - self.log = SimLog("%s" % (type(self).__qualname__)) - - self.data_in_0_driver = StreamDriver( - dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready +import cocotb +from cocotb.triggers import Timer +from cocotb.triggers import FallingEdge +from cocotb.clock import Clock + +debug = False + +logger = logging.getLogger("tb_signals") +if debug: + logger.setLevel(logging.DEBUG) + + +# DUT test specifications +class VerificationCase: + def __init__(self, samples=10): + self.has_bias = 1 + + self.data_in_0_precision_0 = 8 + self.data_in_0_precision_1 = 3 + self.data_in_0_tensor_size_dim_0 = 1 + self.data_in_0_parallelism_dim_0 = 1 + self.data_in_0_tensor_size_dim_1 = 8 + self.data_in_0_parallelism_dim_1 = 8 + self.data_in_0_tensor_size_dim_2 = 1 + self.data_in_0_parallelism_dim_2 = 1 + + self.weight_precision_0 = 8 + self.weight_precision_1 = 3 + self.weight_tensor_size_dim_0 = 10 + self.weight_parallelism_dim_0 = 10 + self.weight_tensor_size_dim_1 = 8 + self.weight_parallelism_dim_1 = 8 + self.weight_tensor_size_dim_2 = 1 + self.weight_parallelism_dim_2 = 1 + + self.data_out_0_precision_0 = 32 + self.data_out_0_precision_1 = 16 + self.data_out_0_tensor_size_dim_0 = 1 + self.data_out_0_parallelism_dim_0 = 1 + self.data_out_0_tensor_size_dim_1 = 10 + self.data_out_0_parallelism_dim_1 = 10 + self.data_out_0_tensor_size_dim_2 = 1 + + self.bias_precision_0 = 8 + self.bias_precision_1 = 3 + self.bias_tensor_size_dim_0 = 10 + self.bias_parallelism_dim_0 = 10 + self.bias_tensor_size_dim_1 = 1 + self.bias_parallelism_dim_1 = 1 + self.bias_tensor_size_dim_2 = 1 + self.bias_parallelism_dim_2 = 1 + + self.data_in = RandomSource( + name="data_in", + samples=samples + * self.data_in_0_tensor_size_dim_1 + // self.data_in_0_parallelism_dim_1, + num=self.data_in_0_parallelism_dim_1, + max_stalls=0, + debug=debug, ) - self.weight_driver = StreamDriver( - dut.clk, dut.weight, dut.weight_valid, dut.weight_ready + self.weight = RandomSource( + name="weight", + samples=samples + * self.weight_tensor_size_dim_1 + // self.weight_parallelism_dim_1, + num=self.weight_parallelism_dim_0 * self.weight_parallelism_dim_1, + max_stalls=0, + debug=debug, ) - - if int(dut.HAS_BIAS) == 1: - self.bias_driver = StreamDriver( - dut.clk, dut.bias, dut.bias_valid, dut.bias_ready - ) - - self.data_out_0_monitor = StreamMonitor( - dut.clk, - dut.data_out_0, - dut.data_out_0_valid, - dut.data_out_0_ready, - check=False, + self.bias = RandomSource( + name="bias", + samples=samples, + num=self.bias_parallelism_dim_0, + max_stalls=0, + debug=debug, ) - # Model - self.model = LinearInteger( - in_features=in_features, - out_features=out_features, - bias=False, - config={ - "data_in_width": 16, - "data_in_frac_width": 3, - "weight_width": 16, - "weight_frac_width": 3, - "bias_width": 16, - "bias_frac_width": 3, - }, + self.outputs = RandomSink(samples=samples, max_stalls=0, debug=debug) + self.samples = samples + self.ref = self.sw_compute() + + def get_dut_parameters(self): + return { + "HAS_BIAS": self.has_bias, + "DATA_IN_0_PRECISION_0": self.data_in_0_precision_0, + "DATA_IN_0_PRECISION_1": self.data_in_0_precision_1, + "DATA_IN_0_TENSOR_SIZE_DIM_0": self.data_in_0_tensor_size_dim_0, + "DATA_IN_0_PARALLELISM_DIM_0": self.data_in_0_parallelism_dim_0, + "DATA_IN_0_TENSOR_SIZE_DIM_1": self.data_in_0_tensor_size_dim_1, + "DATA_IN_0_PARALLELISM_DIM_1": 
self.data_in_0_parallelism_dim_1,
+            "DATA_IN_0_TENSOR_SIZE_DIM_2": self.data_in_0_tensor_size_dim_2,
+            "DATA_IN_0_PARALLELISM_DIM_2": self.data_in_0_parallelism_dim_2,
+            "WEIGHT_PRECISION_0": self.weight_precision_0,
+            "WEIGHT_PRECISION_1": self.weight_precision_1,
+            "WEIGHT_TENSOR_SIZE_DIM_0": self.weight_tensor_size_dim_0,
+            "WEIGHT_PARALLELISM_DIM_0": self.weight_parallelism_dim_0,
+            "WEIGHT_TENSOR_SIZE_DIM_1": self.weight_tensor_size_dim_1,
+            "WEIGHT_PARALLELISM_DIM_1": self.weight_parallelism_dim_1,
+            "WEIGHT_TENSOR_SIZE_DIM_2": self.weight_tensor_size_dim_2,
+            "WEIGHT_PARALLELISM_DIM_2": self.weight_parallelism_dim_2,
+            "DATA_OUT_0_PRECISION_0": self.data_out_0_precision_0,
+            "DATA_OUT_0_PRECISION_1": self.data_out_0_precision_1,
+            "DATA_OUT_0_TENSOR_SIZE_DIM_0": self.data_out_0_tensor_size_dim_0,
+            "DATA_OUT_0_PARALLELISM_DIM_0": self.data_out_0_parallelism_dim_0,
+            "DATA_OUT_0_TENSOR_SIZE_DIM_1": self.data_out_0_tensor_size_dim_1,
+            "DATA_OUT_0_PARALLELISM_DIM_1": self.data_out_0_parallelism_dim_1,
+            "DATA_OUT_0_TENSOR_SIZE_DIM_2": self.data_out_0_tensor_size_dim_2,
+            "BIAS_PRECISION_0": self.bias_precision_0,
+            "BIAS_PRECISION_1": self.bias_precision_1,
+            "BIAS_TENSOR_SIZE_DIM_0": self.bias_tensor_size_dim_0,
+            "BIAS_PARALLELISM_DIM_0": self.bias_parallelism_dim_0,
+            "BIAS_TENSOR_SIZE_DIM_1": self.bias_tensor_size_dim_1,
+            "BIAS_PARALLELISM_DIM_1": self.bias_parallelism_dim_1,
+            "BIAS_TENSOR_SIZE_DIM_2": self.bias_tensor_size_dim_2,
+            "BIAS_PARALLELISM_DIM_2": self.bias_parallelism_dim_2,
+        }
+
+    def sw_compute(self):
+        ref = []
+        for i in range(self.samples):
+            acc = [0 for _ in range(self.data_out_0_parallelism_dim_1)]
+            for j in range(
+                self.data_in_0_tensor_size_dim_1 // self.data_in_0_parallelism_dim_1
+            ):
+                data_idx = (
+                    i
+                    * self.data_in_0_tensor_size_dim_1
+                    // self.data_in_0_parallelism_dim_1
+                    + j
+                )
+                temp = []
+                for k in range(self.data_out_0_parallelism_dim_1):
+                    s = [
+                        self.data_in.data[data_idx][h]
+                        * self.weight.data[data_idx][
+                            k * self.data_in_0_parallelism_dim_1 + h
+                        ]
+                        for h in range(self.data_in_0_parallelism_dim_1)
+                    ]
+                    acc[k] += sum(s)
+            if self.has_bias:
+                for k in range(self.bias_parallelism_dim_0):
+                    acc[k] += self.bias.data[i][k] << (
+                        self.weight_precision_1
+                        + self.data_in_0_precision_1
+                        - self.bias_precision_1
+                    )
+            ref.append(acc)
+        ref.reverse()
+        return ref
+
+
+def debug_state(dut, state):
+    logger.debug(
+        "{} State: (bias_ready,bias_valid,weight_ready,weight_valid,data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{},{},{},{},{})".format(
+            state,
+            dut.bias_ready.value,
+            dut.bias_valid.value,
+            dut.weight_ready.value,
+            dut.weight_valid.value,
+            dut.data_in_0_ready.value,
+            dut.data_in_0_valid.value,
+            dut.data_out_0_ready.value,
+            dut.data_out_0_valid.value,
+        )
+    )
+
-    def generate_inputs(self):
-        return torch.randn((1, self.model.in_features))
-
-    def preprocess_tensor(self, tensor, quantizer, config, parallelism):
-        tensor = quantizer(tensor)
-        tensor = (tensor * 2 ** config["frac_width"]).int()
-        logger.info(f"Tensor in int format: {tensor}")
-        tensor = tensor.reshape(-1, parallelism).tolist()
-        return tensor
-
-    async def run_test(self):
-        await self.reset()
-        logger.info(f"Reset finished")
-        self.data_out_0_monitor.ready.value = 1
-
-        inputs = self.generate_inputs()
-        exp_out = self.model(inputs)
-
-        # Load the inputs driver
-        logger.info(f"Processing inputs")
-        inputs = self.preprocess_tensor(
-            inputs,
-            self.model.x_quantizer,
-            {"widht": 16, "frac_width": 3},
-            int(self.dut.DATA_IN_0_PARALLELISM_DIM_0),
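+# The test below drives every stream in two phases per falling clock edge:
+# pre_compute() proposes the next valid (or, on the sink side, ready) value,
+# then after a 1 ns settling delay compute() commits a transfer whenever
+# valid and ready pair up. (A summary of the RandomSource/RandomSink
+# handshake as it is used here.)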
+@cocotb.test()
+async def test_fixed_linear(dut):
+    """Test the fixed-point linear layer"""
+    samples = 1000
+    test_case = VerificationCase(samples=samples)
+
+    # Reset cycle
+    await Timer(20, units="ns")
+    dut.rst.value = 1
+    await Timer(100, units="ns")
+    dut.rst.value = 0
+
+    # Create a 10ns-period clock on port clk
+    clock = Clock(dut.clk, 10, units="ns")
+    # Start the clock
+    cocotb.start_soon(clock.start())
+    await Timer(500, units="ns")
+
+    # Synchronize with the clock
+    dut.weight_valid.value = 0
+    dut.bias_valid.value = 0
+    dut.data_in_0_valid.value = 0
+    dut.data_out_0_ready.value = 1
+    debug_state(dut, "Pre-clk")
+    await FallingEdge(dut.clk)
+    debug_state(dut, "Post-clk")
+    debug_state(dut, "Pre-clk")
+    await FallingEdge(dut.clk)
+    debug_state(dut, "Post-clk")
+
+    done = False
+    # Set a timeout to avoid deadlock
+    for i in range(samples * 100):
+        await FallingEdge(dut.clk)
+        debug_state(dut, "Post-clk")
+        dut.weight_valid.value = test_case.weight.pre_compute()
+        dut.bias_valid.value = test_case.bias.pre_compute()
+        dut.data_in_0_valid.value = test_case.data_in.pre_compute()
+        await Timer(1, units="ns")
+        dut.data_out_0_ready.value = test_case.outputs.pre_compute(
+            dut.data_out_0_valid.value
         )
-        self.data_in_0_driver.load_driver(inputs)
-
-        # Load the weights driver
-        logger.info(f"Processing weights")
-        weights = self.preprocess_tensor(
-            self.model.weight,
-            self.model.w_quantizer,
-            {"widht": 16, "frac_width": 3},
-            int(self.dut.WEIGHT_PARALLELISM_DIM_0)
-            * int(self.dut.DATA_IN_0_PARALLELISM_DIM_0),
+        await Timer(1, units="ns")
+        debug_state(dut, "Post-clk")
+
+        dut.bias_valid.value, dut.bias.value = test_case.bias.compute(
+            dut.bias_ready.value
         )
-        self.weight_driver.load_driver(weights)
-
-        # Load the output monitor
-        logger.info(f"Processing outputs: {exp_out}")
-        # To do: need to quantize output to a different precision
-        outs = self.preprocess_tensor(
-            exp_out,
-            self.model.x_quantizer,
-            {"widht": 16, "frac_width": 3},
-            int(self.dut.DATA_OUT_0_PARALLELISM_DIM_0),
+        dut.weight_valid.value, dut.weight.value = test_case.weight.compute(
+            dut.weight_ready.value
         )
-        self.data_out_0_monitor.load_monitor(outs)
-
-        await Timer(1000, units="us")
-        assert self.data_out_0_monitor.exp_queue.empty()
+        dut.data_in_0_valid.value, dut.data_in_0.value = test_case.data_in.compute(
+            dut.data_in_0_ready.value
+        )
+        await Timer(1, units="ns")
+        dut.data_out_0_ready.value = test_case.outputs.compute(
+            dut.data_out_0_valid.value, dut.data_out_0.value
+        )
+        debug_state(dut, "Pre-clk")
+        if (
+            (not test_case.has_bias or test_case.bias.is_empty())
+            and test_case.weight.is_empty()
+            and test_case.data_in.is_empty()
+            and test_case.outputs.is_full()
+        ):
+            done = True
+            break
+    assert (
+        done
+    ), "Deadlock detected or the simulation reached the maximum cycle limit (fix it by adjusting the loop trip count)"
 
-@cocotb.test()
-async def test_20x20(dut):
-    tb = LinearTB(dut, in_features=20, out_features=20)
-    await tb.run_test()
+    check_results(test_case.outputs.data, test_case.ref)
 
 
 if __name__ == "__main__":
-    mase_runner(
-        trace=True,
-        module_param_list=[
-            {
-                "DATA_IN_0_TENSOR_SIZE_DIM_0": 20,
-                "DATA_IN_0_PARALLELISM_DIM_0": 2,
-                "WEIGHT_TENSOR_SIZE_DIM_0": 20,
-                "WEIGHT_TENSOR_SIZE_DIM_1": 20,
-                "WEIGHT_PARALLELISM_DIM_0": 20,
-                "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20,
-                "DATA_OUT_0_PARALLELISM_DIM_0": 20,
-                "BIAS_TENSOR_SIZE_DIM_0": 20,
-            },
-            {
-                "DATA_IN_0_TENSOR_SIZE_DIM_0": 20,
-                "DATA_IN_0_PARALLELISM_DIM_0": 4,
-                "WEIGHT_TENSOR_SIZE_DIM_0": 20,
-                "WEIGHT_TENSOR_SIZE_DIM_1": 20,
-                
"WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, - "DATA_IN_0_PARALLELISM_DIM_0": 5, - "WEIGHT_TENSOR_SIZE_DIM_0": 20, - "WEIGHT_TENSOR_SIZE_DIM_1": 20, - "WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, - ], - ) + tb = VerificationCase() + mase_runner(module_param_list=[tb.get_dut_parameters()]) From 0a8b4dd0bb14464534d7a00c85a76b7617d52412 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Thu, 25 Apr 2024 14:24:21 +0000 Subject: [PATCH 14/31] Remove temporary code for parallelism --- .../graph/analysis/add_metadata/add_hardware_metadata.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index 87d839644..b8ed538a7 100644 --- a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py +++ b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py @@ -45,6 +45,8 @@ def add_component_source(node): node.meta["mase"]["hardware"]["dependence_files"] = [] node.meta["mase"]["hardware"]["device_id"] = -1 + # Init data parallelism to 1 and use DSE pass for exploration + node.meta["mase"]["hardware"]["parallelism"] = {0: 1, 1: 1, 2: 1} # Current only support on-chip parameters args = node.meta["mase"]["common"]["args"] @@ -369,11 +371,6 @@ def add_hardware_metadata_analysis_pass(graph, pass_args=None): for node in graph.nodes: add_component_source(node) - # Temporary: fix parallelism to small value to enable verilator simulation - for node in graph.nodes: - # Batch parallelism set to 1, data parallelism to 4 - node.meta["mase"]["hardware"]["parallelism"] = [1, 4] - # Add hardware parameters for node in graph.nodes: add_verilog_param(node) From 35afa85ecd53f08f9943948626fb7741f7f700e6 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Thu, 25 Apr 2024 14:31:47 +0000 Subject: [PATCH 15/31] Added missing component interface --- machop/chop/passes/graph/transforms/verilog/emit_top.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py index 32f79b0ff..2833729c5 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_top.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py @@ -317,7 +317,7 @@ def emit(self, node, parameter_map): # Emit component instantiation input signals for key, value in node.meta["mase"].parameters["common"]["args"].items(): - if "data" not in key: + if not isinstance(value, dict): continue signals += f""" .{key}({node_name}_{key}), @@ -327,7 +327,7 @@ def emit(self, node, parameter_map): # Emit component instantiation output signals for key, value in node.meta["mase"].parameters["common"]["results"].items(): - if "data" not in key: + if not isinstance(value, dict): continue signals += f""" .{key}({node_name}_{key}), From 8ef71cf608cb2d0b8d7ee63a63c7a914d44cd54d Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Thu, 25 Apr 2024 14:44:11 +0000 Subject: [PATCH 16/31] removed extra comma in parameter map --- machop/chop/passes/graph/transforms/verilog/emit_top.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py 
b/machop/chop/passes/graph/transforms/verilog/emit_top.py index 2833729c5..1d9a2640f 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_top.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py @@ -975,6 +975,8 @@ def emit(self, parameter_map, top): top_parameters = "" for key, value in parameter_map.items(): top_parameters += f""" .{key}({key}),\n""" + top_parameters = _remove_last_comma(top_parameters) + interface = MMVerilogTopInterfaceEmitter(self.graph).emit(parameter_map) components = f""" From 921075621f5423f2a60795f3c68822dc7b5a0582 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 26 Apr 2024 20:13:32 +0000 Subject: [PATCH 17/31] Updated parallelism parameter formats --- .../add_metadata/add_hardware_metadata.py | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index b8ed538a7..890e3a251 100644 --- a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py +++ b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py @@ -46,7 +46,13 @@ def add_component_source(node): node.meta["mase"]["hardware"]["device_id"] = -1 # Init data parallelism to 1 and use DSE pass for exploration - node.meta["mase"]["hardware"]["parallelism"] = {0: 1, 1: 1, 2: 1} + node.meta["mase"]["hardware"]["parallelism"] = {} + args = node.meta["mase"]["common"]["args"] + for arg, arg_info in args.items(): + node.meta["mase"]["hardware"]["parallelism"][arg] = {0: 1, 1: 1, 2: 1} + results = node.meta["mase"]["common"]["results"] + for result, result_info in results.items(): + node.meta["mase"]["hardware"]["parallelism"][result] = {0: 1, 1: 1, 2: 1} # Current only support on-chip parameters args = node.meta["mase"]["common"]["args"] @@ -83,17 +89,17 @@ def add_verilog_param(node): else 1 ) # If node data parallelism is set, take from hardware metadata - if node.meta["mase"]["hardware"]["parallelism"] is not None: - vp[_cap(arg + f"_parallelism_dim_{dim}")] = node.meta["mase"][ - "hardware" - ]["parallelism"][len(arg_info["shape"]) - 1 - dim] - # Otherwise, assign to tensor size by default - else: - vp[_cap(arg + f"_parallelism_dim_{dim}")] = ( - arg_info["shape"][len(arg_info["shape"]) - 1 - dim] - if dim < len(arg_info["shape"]) - else 1 - ) + assert node.meta["mase"]["hardware"]["parallelism"][arg] is not None + vp[_cap(arg + f"_parallelism_dim_{dim}")] = node.meta["mase"][ + "hardware" + ]["parallelism"][arg][len(arg_info["shape"]) - 1 - dim] + # # Otherwise, assign to tensor size by default + # else: + # vp[_cap(arg + f"_parallelism_dim_{dim}")] = ( + # arg_info["shape"][len(arg_info["shape"]) - 1 - dim] + # if dim < len(arg_info["shape"]) + # else 1 + # ) elif type(arg_info) == bool: vp[_cap(arg)] = 1 if arg_info else 0 else: @@ -109,16 +115,16 @@ def add_verilog_param(node): if dim < len(result_info["shape"]) else 1 ) - if node.meta["mase"]["hardware"]["parallelism"] is not None: - vp[_cap(result + f"_parallelism_dim_{dim}")] = node.meta["mase"][ - "hardware" - ]["parallelism"][len(result_info["shape"]) - 1 - dim] - else: - vp[_cap(result + f"_parallelism_dim_{dim}")] = ( - result_info["shape"][len(result_info["shape"]) - 1 - dim] - if dim < len(result_info["shape"]) - else 1 - ) + assert node.meta["mase"]["hardware"]["parallelism"] is not None + vp[_cap(result + f"_parallelism_dim_{dim}")] = node.meta["mase"][ + "hardware" + 
]["parallelism"][result][len(result_info["shape"]) - 1 - dim] + # else: + # vp[_cap(result + f"_parallelism_dim_{dim}")] = ( + # result_info["shape"][len(result_info["shape"]) - 1 - dim] + # if dim < len(result_info["shape"]) + # else 1 + # ) else: vp[_cap(result)] = result_info From 3c272a72c1fec0b892be7d835ca080714d612543 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 26 Apr 2024 20:31:31 +0000 Subject: [PATCH 18/31] Updated bram width calculation with the latest parallelism parameters --- .../graph/transforms/verilog/emit_bram.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/machop/chop/passes/graph/transforms/verilog/emit_bram.py b/machop/chop/passes/graph/transforms/verilog/emit_bram.py index 95a4bb5cb..ee03b3412 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_bram.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_bram.py @@ -34,18 +34,21 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): (Mostly because Vivado does not support string type parameters...) """ - # TODO: Force bias to have a depth of 1 for now - if param_name != "bias": - # out_depth = node.meta["mase"].parameters["hardware"]["verilog_param"][ - # "DATA_IN_0_DEPTH" - # ] - out_depth = 1 - else: - out_depth = 1 - addr_width = clog2(out_depth) + 1 total_size = math.prod( node.meta["mase"].parameters["common"]["args"][param_name]["shape"] ) + + dim = len(node.meta["mase"].parameters["common"]["args"][param_name]["shape"]) + out_depth = 1 + for i in range(dim): + out_depth *= int( + math.ceil( + node.meta["mase"].parameters["common"]["args"][param_name]["shape"][i] + / node.meta["mase"].parameters["hardware"]["parallelism"][param_name][i] + ) + ) + addr_width = clog2(out_depth) + 1 + # The depth of parameters must match with the input depth assert ( total_size % out_depth == 0 @@ -71,7 +74,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): module {node_param_name}_rom #( parameter DWIDTH = {out_size*out_width}, parameter MEM_SIZE = {out_depth}, - parameter AWIDTH = $clog2(MEM_SIZE) + 1 + parameter AWIDTH = $clog2(MEM_SIZE+1) ) ( input clk, input logic [AWIDTH-1:0] addr0, @@ -83,9 +86,9 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): logic [DWIDTH-1:0] q0_t0; logic [DWIDTH-1:0] q0_t1; - // initial begin - // $readmemh("{data_name}", ram); - // end + initial begin + $readmemh("{data_name}", ram); + end assign q0 = q0_t1; @@ -96,9 +99,9 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): `timescale 1 ns / 1 ps module {node_param_name} #( - parameter DATA_WIDTH = 32'd{out_width*out_size}, - parameter ADDR_RANGE = 32'd{out_depth}, - parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1 + parameter DATA_WIDTH = {out_width*out_size}, + parameter ADDR_RANGE = {out_depth}, + parameter ADDR_WIDTH = $clog2(ADDR_RANGE+1) ) ( input reset, input clk, @@ -126,7 +129,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): parameter {_cap(param_name)}_PARALLELISM_DIM_0 = 1, parameter {_cap(param_name)}_PARALLELISM_DIM_1 = 1, - parameter OUT_DEPTH = {_cap(param_name)}_TENSOR_SIZE_DIM_0 / {_cap(param_name)}_PARALLELISM_DIM_0 + parameter OUT_DEPTH = {_cap(param_name)}_TENSOR_SIZE_DIM_0 * {_cap(param_name)}_TENSOR_SIZE_DIM_1 / ({_cap(param_name)}_PARALLELISM_DIM_0 * {_cap(param_name)}_PARALLELISM_DIM_1) ) ( input clk, input rst, @@ -136,8 +139,8 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): input 
data_out_ready ); // 1-bit wider so IN_DEPTH also fits. - localparam COUNTER_WIDTH = $clog2(OUT_DEPTH); - logic [COUNTER_WIDTH:0] counter; + localparam COUNTER_WIDTH = $clog2(OUT_DEPTH+1); + logic [COUNTER_WIDTH-1:0] counter; always_ff @(posedge clk) if (rst) counter <= 0; @@ -151,9 +154,9 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): logic ce0; assign ce0 = 1; - logic [{_cap(param_name)}_PRECISION_0*{_cap(param_name)}_TENSOR_SIZE_DIM_0-1:0] data_vector; + logic [{_cap(param_name)}_PRECISION_0*{_cap(param_name)}_PARALLELISM_DIM_0-1:0] data_vector; {node_param_name} #( - .DATA_WIDTH({_cap(param_name)}_PRECISION_0 * {_cap(param_name)}_TENSOR_SIZE_DIM_0), + .DATA_WIDTH({_cap(param_name)}_PRECISION_0 * {_cap(param_name)}_PARALLELISM_DIM_0), .ADDR_RANGE(OUT_DEPTH) ) {node_param_name}_mem ( .clk(clk), @@ -165,7 +168,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): // Cocotb/verilator does not support array flattening, so // we need to manually add some reshaping process. - for (genvar j = 0; j < {_cap(param_name)}_TENSOR_SIZE_DIM_0; j++) + for (genvar j = 0; j < {_cap(param_name)}_PARALLELISM_DIM_0; j++) assign data_out[j] = data_vector[{_cap(param_name)}_PRECISION_0*j+{_cap(param_name)}_PRECISION_0-1:{_cap(param_name)}_PRECISION_0*j]; assign data_out_valid = 1; From 29faf71a61efe3af61b2505e53ec303e18bd5b91 Mon Sep 17 00:00:00 2001 From: ChengZhang-98 <102538889+ChengZhang-98@users.noreply.github.com> Date: Fri, 26 Apr 2024 21:39:42 +0100 Subject: [PATCH 19/31] Fix quantization meta data for fixed-point quantization (#173) * new function to update meta after quantization [only works on fixed] * comment out manual precision & dtype assignment --------- Co-authored-by: Jianyi Cheng --- .../bert_quantized/quant_config_bert.py | 22 +- .../llama_quantized/quant_config_llama.py | 28 +- .../manual/opt_quantized/quant_config_opt.py | 24 +- machop/chop/models/manual/quant_utils.py | 4 +- machop/chop/passes/graph/__init__.py | 2 +- .../graph/analysis/report/report_node.py | 6 +- .../quantize/quant_parsers/__init__.py | 6 +- .../{ => archive}/parse_quant_config.py | 4 +- .../quant_parsers/archive/q_recipes.py | 54 +++ .../quant_parsers/archive/update_node_meta.py | 134 +++++++ .../quantize/quant_parsers/parse_q_config.py | 41 ++ .../quant_parsers/q_op_entries/__init__.py | 1 + .../quant_parsers/q_op_entries/fixed.py | 68 ++++ .../quant_parsers/update_node_meta.py | 375 ++++++++++++------ .../graph/transforms/quantize/quantize.py | 30 +- .../graph/transforms/verilog/emit_internal.py | 7 +- .../graph/transforms/verilog/emit_top.py | 168 ++++---- .../verilog/internal_file_dependences.py | 7 +- machop/configs/tests/quantize/fixed.toml | 33 +- .../verilog/test_emit_verilog_linear.py | 44 +- 20 files changed, 768 insertions(+), 290 deletions(-) rename machop/chop/passes/graph/transforms/quantize/quant_parsers/{ => archive}/parse_quant_config.py (98%) create mode 100644 machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/q_recipes.py create mode 100644 machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/update_node_meta.py create mode 100644 machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py create mode 100644 machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/__init__.py create mode 100644 machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py diff --git a/machop/chop/models/manual/bert_quantized/quant_config_bert.py 
b/machop/chop/models/manual/bert_quantized/quant_config_bert.py index cc6cf0294..0665a224b 100644 --- a/machop/chop/models/manual/bert_quantized/quant_config_bert.py +++ b/machop/chop/models/manual/bert_quantized/quant_config_bert.py @@ -6,7 +6,7 @@ import toml from chop.tools.config_load import convert_str_na_to_none -from ..quant_utils import parse_node_config +from ..quant_utils import parse_node_q_config logger = logging.getLogger(__name__) @@ -68,20 +68,20 @@ def create_a_layer_config( # fmt: off qc = { "attention": { - "query": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("query", linear_qc), "linear")), - "key": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("key", linear_qc), "linear")), - "value": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("value", linear_qc), "linear")), - "matmul_0": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("matmul_0", matmul_qc), "matmul")), - "matmul_1": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("matmul_1", matmul_qc), "matmul")), + "query": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("query", linear_qc), "linear")), + "key": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("key", linear_qc), "linear")), + "value": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("value", linear_qc), "linear")), + "matmul_0": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("matmul_0", matmul_qc), "matmul")), + "matmul_1": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("matmul_1", matmul_qc), "matmul")), "output": { - "dense": deepcopy(parse_node_config(layer_qc.get("attention", {}).get("output", {}).get("dense", linear_qc), "linear")), + "dense": deepcopy(parse_node_q_config(layer_qc.get("attention", {}).get("output", {}).get("dense", linear_qc), "linear")), }, }, "intermediate": { - "dense": deepcopy(parse_node_config(layer_qc.get("intermediate", {}).get("dense", linear_qc), "linear")), + "dense": deepcopy(parse_node_q_config(layer_qc.get("intermediate", {}).get("dense", linear_qc), "linear")), }, "output": { - "dense": deepcopy(parse_node_config(layer_qc.get("output", {}).get("dense", linear_qc), "linear")), + "dense": deepcopy(parse_node_q_config(layer_qc.get("output", {}).get("dense", linear_qc), "linear")), }, } # fmt: on @@ -94,10 +94,10 @@ def _parse_and_complete_config( ) -> dict: assert "default" in config, "Must provide a default config" default_qc: dict = config["default"] - linear_qc: dict = parse_node_config( + linear_qc: dict = parse_node_q_config( config.get("linear", default_qc), mase_op="linear" ) - matmul_qc: dict = parse_node_config( + matmul_qc: dict = parse_node_q_config( config.get("matmul", default_qc), mase_op="matmul" ) general_layer_qc: dict = config.get("model_layer", None) diff --git a/machop/chop/models/manual/llama_quantized/quant_config_llama.py b/machop/chop/models/manual/llama_quantized/quant_config_llama.py index b3988fd0f..d086a2a36 100644 --- a/machop/chop/models/manual/llama_quantized/quant_config_llama.py +++ b/machop/chop/models/manual/llama_quantized/quant_config_llama.py @@ -6,7 +6,7 @@ import toml from chop.tools.config_load import convert_str_na_to_none -from ..quant_utils import parse_node_config +from ..quant_utils import parse_node_q_config """ @@ -48,18 +48,18 @@ def create_a_layer_config( # fmt: off qc = { "self_attn": { - "q_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")), - "k_proj": 
deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")), - "v_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")), - "o_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("o_proj", linear_qc), "linear")), - "rotary_positional_encoding": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("rotary_positional_encoding", rotary_positional_encoding_qc), "rotary_positional_encoding")), - "matmul_0": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("matmul_0", matmul_qc), "matmul")), - "matmul_1": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("matmul_1", matmul_qc), "matmul")), + "q_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")), + "k_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")), + "v_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")), + "o_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("o_proj", linear_qc), "linear")), + "rotary_positional_encoding": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("rotary_positional_encoding", rotary_positional_encoding_qc), "rotary_positional_encoding")), + "matmul_0": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("matmul_0", matmul_qc), "matmul")), + "matmul_1": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("matmul_1", matmul_qc), "matmul")), }, "mlp": { - "gate_proj": deepcopy(parse_node_config(layer_qc.get("mlp", {}).get("gate_proj", linear_qc), "linear")), - "down_proj": deepcopy(parse_node_config(layer_qc.get("mlp", {}).get("down_proj", linear_qc), "linear")), - "up_proj": deepcopy(parse_node_config(layer_qc.get("mlp", {}).get("up_proj", linear_qc), "linear")) + "gate_proj": deepcopy(parse_node_q_config(layer_qc.get("mlp", {}).get("gate_proj", linear_qc), "linear")), + "down_proj": deepcopy(parse_node_q_config(layer_qc.get("mlp", {}).get("down_proj", linear_qc), "linear")), + "up_proj": deepcopy(parse_node_q_config(layer_qc.get("mlp", {}).get("up_proj", linear_qc), "linear")) }, } # fmt: on @@ -69,14 +69,14 @@ def create_a_layer_config( def _parse_and_complete_config(config: dict, num_hidden_layers: int) -> dict: assert "default" in config, "Must provide default config for by_name_parser" default_qc: dict = config["default"] - linear_qc: dict = parse_node_config( + linear_qc: dict = parse_node_q_config( config.get("linear", default_qc), mase_op="linear" ) - rotary_positional_encoding_qc: dict = parse_node_config( + rotary_positional_encoding_qc: dict = parse_node_q_config( config.get("rotary_positional_encoding", default_qc), mase_op="rotary_positional_encoding", ) - matmul_qc: dict = parse_node_config( + matmul_qc: dict = parse_node_q_config( config.get("matmul", default_qc), mase_op="matmul" ) general_layer_qc: dict = config.get("model_layer", None) diff --git a/machop/chop/models/manual/opt_quantized/quant_config_opt.py b/machop/chop/models/manual/opt_quantized/quant_config_opt.py index a76f8adb8..371fa6a62 100644 --- a/machop/chop/models/manual/opt_quantized/quant_config_opt.py +++ b/machop/chop/models/manual/opt_quantized/quant_config_opt.py @@ -5,9 +5,9 @@ import toml from ....tools.config_load import convert_str_na_to_none -from ....passes.graph import parse_node_config +from ....passes.graph import parse_node_q_config -from chop.passes.graph.transforms.quantize.quant_parsers 
import parse_quant_config +from chop.passes.graph.transforms.quantize.quant_parsers import parse_node_q_config """ An example of quant_config for opt @@ -43,15 +43,15 @@ def create_a_layer_config( # fmt: off qc = { "self_attn": { - "q_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")), - "k_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")), - "v_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")), - "out_proj": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("out_proj", linear_qc), "linear")), - "bmm_0": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("bmm_0", bmm_qc), "matmul")), - "bmm_1": deepcopy(parse_node_config(layer_qc.get("self_attn", {}).get("bmm_1", bmm_qc), "matmul")), + "q_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("q_proj", linear_qc), "linear")), + "k_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("k_proj", linear_qc), "linear")), + "v_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("v_proj", linear_qc), "linear")), + "out_proj": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("out_proj", linear_qc), "linear")), + "bmm_0": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("bmm_0", bmm_qc), "matmul")), + "bmm_1": deepcopy(parse_node_q_config(layer_qc.get("self_attn", {}).get("bmm_1", bmm_qc), "matmul")), }, - "fc1": deepcopy(parse_node_config(layer_qc.get("fc1", linear_qc), "linear")), - "fc2": deepcopy(parse_node_config(layer_qc.get("fc2", linear_qc), "linear")), + "fc1": deepcopy(parse_node_q_config(layer_qc.get("fc1", linear_qc), "linear")), + "fc2": deepcopy(parse_node_q_config(layer_qc.get("fc2", linear_qc), "linear")), } # fmt: on return qc @@ -60,10 +60,10 @@ def create_a_layer_config( def _parse_and_complete_config(config: dict, num_hidden_layers: int) -> dict: assert "default" in config, "Must provide default config for by_name_parser" default_qc: dict = config["default"] - linear_qc: dict = parse_node_config( + linear_qc: dict = parse_node_q_config( config.get("linear", default_qc), mase_op="linear" ) - bmm_qc: dict = parse_node_config(config.get("bmm", default_qc), mase_op="matmul") + bmm_qc: dict = parse_node_q_config(config.get("bmm", default_qc), mase_op="matmul") general_layer_qc: dict = config.get("model_layer", None) # parsed config diff --git a/machop/chop/models/manual/quant_utils.py b/machop/chop/models/manual/quant_utils.py index c828d10dc..48506aca0 100644 --- a/machop/chop/models/manual/quant_utils.py +++ b/machop/chop/models/manual/quant_utils.py @@ -1,6 +1,6 @@ from typing import Callable -from chop.passes.graph import parse_node_config +from chop.passes.graph import parse_node_q_config from chop.passes.graph import quantized_func_map from chop.passes.graph import quantized_module_map @@ -16,4 +16,4 @@ def get_quantized_func(mase_op: str, quant_config: dict) -> Callable: def parse_op_quant_config(mase_op: str, config: dict) -> dict: - return parse_node_config(config=config, mase_op=mase_op) + return parse_node_q_config(config=config, mase_op=mase_op) diff --git a/machop/chop/passes/graph/__init__.py b/machop/chop/passes/graph/__init__.py index 63a87f5b1..0135d0847 100644 --- a/machop/chop/passes/graph/__init__.py +++ b/machop/chop/passes/graph/__init__.py @@ -37,7 +37,7 @@ ) from .transforms.quantize import quantized_func_map, quantized_module_map -from 
.transforms.quantize.quant_parsers import parse_node_config +from .transforms.quantize.quant_parsers import parse_node_q_config ANALYSIS_PASSES = [ "init_metadata", diff --git a/machop/chop/passes/graph/analysis/report/report_node.py b/machop/chop/passes/graph/analysis/report/report_node.py index cfa81812c..c2733a06a 100644 --- a/machop/chop/passes/graph/analysis/report/report_node.py +++ b/machop/chop/passes/graph/analysis/report/report_node.py @@ -5,6 +5,8 @@ import copy +import torch + logger = logging.getLogger(__name__) @@ -118,7 +120,7 @@ def report_node_hardware_type_analysis_pass(graph, pass_args: dict = {}): return graph, {} -def report_node_meta_param_analysis_pass(graph, pass_args: dict = None): +def report_node_meta_param_analysis_pass(graph, pass_args: dict = {}): """ Perform meta parameter analysis on the nodes in the graph and generate a report. @@ -131,6 +133,7 @@ def report_node_meta_param_analysis_pass(graph, pass_args: dict = None): :return: The analyzed graph and an empty dictionary. :rtype: tuple(MaseGraph, dict) """ + torch.set_printoptions(threshold=20) which_param = pass_args.get("which", ("all",)) assert isinstance(which_param, (list, tuple)) for param in which_param: @@ -184,4 +187,5 @@ def report_node_meta_param_analysis_pass(graph, pass_args: dict = None): with open(Path(save_path), "w") as f: f.write(table_txt) logger.info(f"Node meta param table is saved to {save_path}") + torch.set_printoptions(threshold=1000) return graph, {} diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py index 72a5b3b68..9668d2a7b 100644 --- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py +++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py @@ -1,2 +1,4 @@ -from .parse_quant_config import parse_node_config -from .update_node_meta import relink_node_meta, update_quant_meta_param +# from .parse_quant_config import parse_node_q_config +from .parse_q_config import parse_node_q_config +# from .update_node_meta import relink_node_meta, update_quant_meta_param +from .update_node_meta import relink_node_meta, update_q_meta_param, infer_result_dtype_and_precision \ No newline at end of file diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/parse_quant_config.py similarity index 98% rename from machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py rename to machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/parse_quant_config.py index 63bbea13a..50c2c0779 100644 --- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py +++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/parse_quant_config.py @@ -3,8 +3,6 @@ from .utils import cp_multi_values, has_multi_keys """ QUANT_ARITH_ENTRIES -A mapping from (quantization arithmetic name) to (a mapping from (operand name) to (operand quantization spec name)) - Example A fixed point quantized value is defined by (width, frac_width), thus the mapping is defined as follows: @@ -368,7 +366,7 @@ def optional_operand_entry_exists(config: dict, entry_name: str) -> bool: return False -def parse_node_config(config: dict, mase_op: str, strict: bool = True) -> dict: +def parse_node_q_config(config: dict, mase_op: str, strict: bool = True) -> dict: """ Parse a node config from a MASE op config. 
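# Illustrative usage of the parser introduced by this patch (an editorial
# sketch with hypothetical values, not part of the committed code):
#
#   q_config = {
#       "name": "fixed",
#       "data_in_width": 8, "data_in_frac_width": 3,
#       "weight_width": 8, "weight_frac_width": 3,
#       "bias_width": 8, "bias_frac_width": 3,  # optional keys, kept if present
#   }
#   parsed = parse_node_q_config(q_config, mase_op="linear")
#
# Required keys missing from q_config trigger an assertion; keys outside the
# (q_name, mase_op) entry table are dropped from the parsed config.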
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/q_recipes.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/q_recipes.py
new file mode 100644
index 000000000..8e02397c2
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/q_recipes.py
@@ -0,0 +1,54 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class QRecipeFixed:
+    """
+    Fixed point quantization
+    """
+
+    data_in_width: int
+    data_in_frac_width: int
+    name: str = field(default="fixed", init=False)
+    bypass: bool = field(default=False)
+    weight_width: int | None = field(default=None)
+    weight_frac_width: int | None = field(default=None)
+    bias_width: int | None = field(default=None)
+    bias_frac_width: int | None = field(default=None)
+
+
+@dataclass
+class QRecipeLutNet:
+    """
+    LUTNET quantization
+
+    binarization_level (int): which level of binarization is applied; "binarized_weight" binarizes only the weights, other levels apply no binarization
+    input_expanded (bool): if set to True, all of the LUT's inputs are considered during calculation; otherwise only the first input is considered and the remaining ones are masked
+    k: int # k entries of a LUT
+    levels (int): number of residual levels to use in LUTNET
+    dim: this is needed by convolution
+    """
+
+    data_in_width: int
+    data_in_frac_width: int
+    data_in_binarization_level: int
+    data_in_input_expanded: bool
+    data_in_k: int
+    data_in_in_levels: int
+    data_in_dim: tuple[int]
+
+    weight_width: int
+    weight_frac_width: int
+    weight_binarization_level: int
+    weight_input_expanded: bool
+    weight_k: int
+    weight_in_dim: tuple[int]
+
+    bias_width: int
+    bias_frac_width: int
+    bias_binarization_level: int
+    bias_input_expanded: bool
+    bias_k: int
+    bias_in_dim: tuple[int]
+
+    name: str = field(default="lutnet", init=False)
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/update_node_meta.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/update_node_meta.py
new file mode 100644
index 000000000..5f980e10f
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/archive/update_node_meta.py
@@ -0,0 +1,134 @@
+from functools import partial
+
+
+def entry_to_list(config: dict, entry: str, suffixes: tuple[str]):
+    """e.g. [data_in_frac_width, data_in_width]"""
+    return list(config[f"{entry}_{suffix}"] for suffix in suffixes)
+
+
+QUANT_ARITH_TO_SUFFIXES = {
+    "integer": ("width", "frac_width"),
+    "fixed": ("width", "frac_width"),
+    "binary": (
+        "width",
+        "stochastic",
+        "bipolar",
+    ),  # TODO: stochastic, bipolar flags are operational flag instead of precision.
+    "binary_residual": (
+        "width",
+        "stochastic",
+        "bipolar",
+    ),  # TODO: stochastic, bipolar flags are operational flag instead of precision.
+    "lutnet": ("width", "input_expanded", "k", "binarization_level"),
+    "logicnets": ("width", "frac_width"),
+    "ternary": ("width", "scaling_factor", "mean", "median", "max"),
+    "minifloat_ieee": ("width", "exponent_width", "exponent_bias"),
+    "minifloat_denorm": ("width", "exponent_width", "exponent_bias"),
+    "log": ("width", "exponent_bias"),
+    "block_fp": ("width", "exponent_width", "exponent_bias", "block_size"),
+    "block_minifloat": ("width", "exponent_width", "exponent_bias_width", "block_size"),
+    "block_log": ("width", "exponent_bias_width", "block_size"),
+}
+
+
+# quant_arith_to_list_fn = {
+#    <quant_arith>: {
+#       <entry>: entry_to_list_<quant_arith>
+# }
+quant_arith_to_list_fn = {}
+for quant_arith, suffixes in QUANT_ARITH_TO_SUFFIXES.items():
+    quant_arith_to_list_fn[quant_arith] = partial(entry_to_list, suffixes=suffixes)
+
+
+def update_arg(node, arg_name, dtype=None, precision=None, size=None):
+    if dtype is not None:
+        node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = dtype
+    if precision is not None:
+        node.meta["mase"].parameters["common"]["args"][arg_name][
+            "precision"
+        ] = precision
+    if size is not None:
+        node.meta["mase"].parameters["common"]["args"][arg_name]["size"] = size
+
+
+MASE_OP_TO_INPUT_ENTRIES_AND_ARGS = {
+    # entry and arg corresponding to name in software and hardware mapping
+    "add": (("data_in", "data_in"), ("data_in_0", "data_in_1")),
+    "bmm": (("data_in", "weight"), ("data_in_0", "data_in_1")),
+    "conv1d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")),
+    "conv2d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")),
+    "matmul": (("data_in", "weight"), ("data_in_0", "data_in_1")),
+    "mul": (("data_in", "data_in"), ("data_in_0", "data_in_1")),
+    "linear": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")),
+    "relu": (("data_in",), ("data_in_0",)),
+    "sub": (("data_in", "data_in"), ("data_in_0", "data_in_1")),
+}
+
+
+def update_result(node, output_name, dtype=None, precision=None, size=None):
+    if dtype is not None:
+        node.meta["mase"].parameters["common"]["results"][output_name]["type"] = dtype
+    if precision is not None:
+        node.meta["mase"].parameters["common"]["results"][output_name][
+            "precision"
+        ] = precision
+    if size is not None:
+        node.meta["mase"].parameters["common"]["results"][output_name]["size"] = size
+
+
+MASE_OP_TO_OUTPUT_ENTRIES = {
+    # entry and arg corresponding to name in software and hardware mapping
+    "add": (("data_out",), ("data_out_0",)),
+    "bmm": (("data_out",), ("data_out_0",)),
+    "conv1d": (("data_out",), ("data_out_0",)),
+    "conv2d": (("data_out",), ("data_out_0",)),
+    "matmul": (("data_out",), ("data_out_0",)),
+    "mul": (("data_out",), ("data_out_0",)),
+    "linear": (("data_out",), ("data_out_0",)),
+    "relu": (("data_out",), ("data_out_0",)),
+    "sub": (("data_out",), ("data_out_0",)),
+}
+
+
+def arg_exists(node, arg_name) -> bool:
+    return arg_name in node.meta["mase"].parameters["common"]["args"]
+
+
+def update_quant_meta_param(node, config: dict, mase_op: str) -> None:
+    quant_arith = config["name"]
+    assert quant_arith in quant_arith_to_list_fn, f"Unknown quant_arith: {quant_arith}"
+    """
+    MASE_OP_TO_INPUT_ENTRIES_AND_ARGS: gives a mapping between the config file and the mase model
+    How it works:
+    We find the precision of a certain parameter, e.g. "data_in", using the precision partial function.
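+
+    For example, with the "fixed" scheme the entry "data_in" expands to
+    [config["data_in_width"], config["data_in_frac_width"]].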
+
+    The precision partial function takes a config file and an entry, e.g. "data_in",
+    and it will search through all the attributes under this entry based on the quantisation scheme,
+    returning a list of precisions in the same order as the attributes defined in QUANT_ARITH_TO_SUFFIXES.
+
+    This precision list is then mapped to the mase metadata using 'arg'.
+    """
+    for entry, arg in zip(*MASE_OP_TO_INPUT_ENTRIES_AND_ARGS[mase_op]):
+        if not arg_exists(node, arg):
+            continue
+        update_arg(
+            node,
+            arg_name=arg,
+            dtype=quant_arith,
+            precision=quant_arith_to_list_fn[quant_arith](config, entry),
+        )
+
+    for entry, arg in zip(*MASE_OP_TO_OUTPUT_ENTRIES[mase_op]):
+        # Quantise all the output to fixed point. TODO: Make this automatic. Hardware will need change too
+        if quant_arith == "binary" or quant_arith == "binary_residual":
+            update_result(
+                node,
+                output_name=arg,
+                dtype="binary",
+                precision=[32, 0, 1],  # [bitwidth, stochastic, bipolar]
+            )
+
+
+def relink_node_meta(node, model):
+    node.meta["mase"].node = node
+    node.meta["mase"].model = model
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py
new file mode 100644
index 000000000..3caa6f839
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py
@@ -0,0 +1,41 @@
+import copy
+from .q_op_entries import FIXED_OP_ENTRIES
+
+"""
+MASE_OP_TO_ENTRIES = {
+    <mase_op>: {
+        "required": (...),
+        "optional": (...)
+    }
+}
+"""
+
+def get_q_op_entries(q_name: str, mase_op: str):
+    match q_name:
+        case "fixed":
+            op_entries = FIXED_OP_ENTRIES
+        case _:
+            raise ValueError(f"Unknown quantization arithmetic name: {q_name}")
+
+    if mase_op not in op_entries:
+        raise ValueError(f"Unknown MASE operation name: {mase_op} for quantization arithmetic: {q_name}")
+
+    return op_entries[mase_op]
+
+
+def parse_node_q_config(q_config: dict, mase_op: str):
+    q_op_entries = get_q_op_entries(q_config["name"], mase_op)
+
+    required_keys = q_op_entries["required"]
+    optional_keys = q_op_entries["optional"]
+
+    parsed_q_config = {}
+    for k in required_keys:
+        assert k in q_config, f"Required key {k} not found in q_config: {q_config}"
+        parsed_q_config[k] = copy.deepcopy(q_config[k])
+
+    for k in optional_keys:
+        if k in q_config:
+            parsed_q_config[k] = copy.deepcopy(q_config[k])
+
+    return parsed_q_config
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/__init__.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/__init__.py
new file mode 100644
index 000000000..af53c9d1f
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/__init__.py
@@ -0,0 +1 @@
+from .fixed import FIXED_OP_ENTRIES
diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py
new file mode 100644
index 000000000..f1c8c8da1
--- /dev/null
+++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py
@@ -0,0 +1,68 @@
+FIXED_OP_ENTRIES = {
+    "add": {
+        "required": ("name", "data_in_width", "data_in_frac_width"),
+        "optional": ("bypass",),
+    },
+    "bmm": {
+        "required": (
+            "name",
+            "data_in_width",
+            "data_in_frac_width",
+            "weight_width",
+            "weight_frac_width",
+        ),
+        "optional": ("bypass",),
+    },
+    "conv1d": {
+        "required": (
+            "name",
+            "data_in_width",
+            "data_in_frac_width",
+            "weight_width",
+            "weight_frac_width",
+        ),
+        
"optional": ("bypass", "bias_width", "bias_frac_width"), + }, + "conv2d": { + "required": ( + "name", + "data_in_width", + "data_in_frac_width", + "weight_width", + "weight_frac_width", + ), + "optional": ("bypass", "bias_width", "bias_frac_width"), + }, + "linear": { + "required": ( + "name", + "data_in_width", + "data_in_frac_width", + "weight_width", + "weight_frac_width", + ), + "optional": ("bypass", "bias_width", "bias_frac_width"), + }, + "matmul": { + "required": ( + "name", + "data_in_width", + "data_in_frac_width", + "weight_width", + "weight_frac_width", + ), + "optional": ("bypass",), + }, + "relu": { + "required": ("name", "data_in_width", "data_in_frac_width"), + "optional": ("bypass",), + }, + "sub": { + "required": ("name", "data_in_width", "data_in_frac_width"), + "optional": ("bypass",), + }, + "rotary_positional_encoding": { + "required": ("name", "data_in_width", "data_in_frac_width"), + "optional": ("bypass",), + }, +} diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py index 5f980e10f..a5dc9cd03 100644 --- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py +++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py @@ -1,134 +1,273 @@ -from functools import partial - - -def entry_to_list(config: dict, entry: str, suffixes: tuple[str]): - """e.g. [data_in_frac_width, data_in_width]""" - return list(config[f"{entry}_{suffix}"] for suffix in suffixes) - - -QUANT_ARITH_TO_SUFFIXES = { - "integer": ("width", "frac_width"), - "fixed": ("width", "frac_width"), - "binary": ( - "width", - "stochastic", - "bipolar", - ), # TODO: stochastic, bipolar flags are operational flag instead of precision. - "binary_residual": ( - "width", - "stochastic", - "bipolar", - ), # TODO: stochastic, bipolar flags are operational flag instead of precision. 
- "lutnet": ("width", "input_expanded", "k", "binarization_level"), - "logicnets": ("width", "frac_width"), - "ternary": ("width", "scaling_factor", "mean", "median", "max"), - "minifloat_ieee": ("width", "exponent_width", "exponent_bias"), - "minifloat_denorm": ("width", "exponent_width", "exponent_bias"), - "log": ("width", "exponent_bias"), - "block_fp": ("width", "exponent_width", "exponent_bias", "block_size"), - "block_minifloat": ("width", "exponent_width", "exponent_bias_width", "block_size"), - "block_log": ("width", "exponent_bias_width", "block_size"), -} - +import logging +from ....utils import get_mase_op, get_mase_type -# quant_arith_to_list_fn = { -# : { -# : entry_to_list_ -# } -quant_arith_to_list_fn = {} -for quant_arith, suffixes in QUANT_ARITH_TO_SUFFIXES.items(): - quant_arith_to_list_fn[quant_arith] = partial(entry_to_list, suffixes=suffixes) - - -def update_arg(node, arg_name, dtype=None, precision=None, size=None): - if dtype is not None: - node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = dtype - if precision is not None: - node.meta["mase"].parameters["common"]["args"][arg_name][ - "precision" - ] = precision - if size is not None: - node.meta["mase"].parameters["common"]["args"][arg_name]["size"] = size - - -MASE_OP_TO_INPUT_ENTRIES_AND_ARGS = { - # entry and arg corresponding to name in software and hardware mapping - "add": (("data_in", "data_in"), ("data_in_0", "data_in_1")), - "bmm": (("data_in", "weight"), ("data_in_0", "data_in_1")), - "conv1d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), - "conv2d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), - "matmul": (("data_in", "weight"), ("data_in_0", "data_in_1")), - "mul": (("data_in", "data_in"), ("data_in_0", "data_in_1")), - "linear": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")), - "relu": (("data_in",), ("data_in_0",)), - "sub": (("data_in", "data_in"), ("data_in_0", "data_in_1")), -} +logger = logging.getLogger(__name__) -def update_result(node, output_name, dtype=None, precision=None, size=None): - if dtype is not None: - node.meta["mase"].parameters["common"]["results"][output_name]["type"] = dtype - if precision is not None: - node.meta["mase"].parameters["common"]["results"][output_name][ - "precision" - ] = precision - if size is not None: - node.meta["mase"].parameters["common"]["results"][output_name]["size"] = size - - -MASE_OP_TO_OUTPUT_ENTRIES = { - # entry and arg corresponding to name in software and hardware mapping - "add": (("data_out",), ("data_out_0",)), - "bmm": (("data_out",), ("data_out_0",)), - "conv1d": (("data_out",), ("data_out_0",)), - "conv2d": (("data_out",), ("data_out_0",)), - "matmul": (("data_out",), ("data_out_0",)), - "mul": (("data_out",), ("data_out_0",)), - "linear": (("data_out",), ("data_out_0",)), - "relu": (("data_out",), ("data_out_0",)), - "sub": (("data_out",), ("data_out_0",)), +OPERANDS_TO_META_ARG_NAMES = { + "add": { + "required": (("data_in", "data_in"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "bmm": { + "required": (("data_in", "weight"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "conv1d": { + "required": (("data_in", "weight"), ("data_in_0", "weight")), + "optional": (("bias",), ("bias",)), + }, + "conv2d": { + "required": (("data_in", "weight"), ("data_in_0", "weight")), + "optional": (("bias",), ("bias",)), + }, + "matmul": { + "required": (("data_in", "weight"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "mul": { + "required": 
(("data_in", "data_in"), ("data_in_0", "data_in_1")), + "optional": None, + }, + "linear": { + "required": (("data_in", "weight"), ("data_in_0", "weight")), + "optional": (("bias",), ("bias",)), + }, + "relu": { + "required": (("data_in",), ("data_in_0",)), + "optional": None, + }, + "sub": { + "required": (("data_in", "data_in"), ("data_in_0", "data_in_1")), + "optional": None, + }, } -def arg_exists(node, arg_name) -> bool: - return arg_name in node.meta["mase"].parameters["common"]["args"] +def update_node_meta_param_fixed(node, q_config): + """Add fixed-point precision to node meta for quantization - -def update_quant_meta_param(node, config: dict, mase_op: str) -> None: - quant_arith = config["name"] - assert quant_arith in quant_arith_to_list_fn, f"Unknown quant_arith: {quant_arith}" + Precision format: [width, frac_width] """ - MASE_OP_TO_INPUT_ENTRIES_AND_ARGS: Give a mapping between config file and mase model - How it works: - We find the precision of a certain paramter "e.g data_in" using the precision partial function. + mase_op = get_mase_op(node) + if mase_op not in OPERANDS_TO_META_ARG_NAMES: + raise ValueError( + f"Unsupported MASE operation name `{mase_op}` for updating node meta for quantization" + ) - The precision partial function take a config file and entry "e.g data_in", - and it will search through all the attributes under this entry based on the quantisation scheme, - returning a list of precision with the order same as attributes defined in QUANT_ARITH_TO_SUFFIXES + required_args = OPERANDS_TO_META_ARG_NAMES[mase_op]["required"] + optional_args = OPERANDS_TO_META_ARG_NAMES[mase_op]["optional"] - This precision list is then being mapped to mase data using 'arg' - """ - for entry, arg in zip(*MASE_OP_TO_INPUT_ENTRIES_AND_ARGS[mase_op]): - if not arg_exists(node, arg): - continue - update_arg( - node, - arg_name=arg, - dtype=quant_arith, - precision=quant_arith_to_list_fn[quant_arith](config, entry), - ) + for operand_name, arg_name in zip(*required_args): + node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = "fixed" + node.meta["mase"].parameters["common"]["args"][arg_name]["precision"] = [ + q_config[f"{operand_name}_width"], + q_config[f"{operand_name}_frac_width"], + ] - for entry, arg in zip(*MASE_OP_TO_OUTPUT_ENTRIES[mase_op]): - # Quantise all the output to fixed point. TODO: Make this automatic. 
Hardware will need change too - if quant_arith == "binary" or quant_arith == "binary_residual": - update_result( - node, - output_name=arg, - dtype="binary", - precision=[32, 0, 1], # [bitwidth, stochastic, bipolar] - ) + if optional_args is not None: + for operand_name, arg_name in zip(*optional_args): + if arg_name in node.meta["mase"].parameters["common"]["args"]: + if not ( + f"{operand_name}_width" in q_config + and f"{operand_name}_frac_width" in q_config + ): + raise RuntimeError( + f"Optional argument {arg_name} found in node meta, but not found in q_config: {q_config}" + ) + node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = "fixed" + node.meta["mase"].parameters["common"]["args"][arg_name]["precision"] = [ + q_config[f"{operand_name}_width"], + q_config[f"{operand_name}_frac_width"], + ] def relink_node_meta(node, model): node.meta["mase"].node = node node.meta["mase"].model = model + + +def update_q_meta_param(node, config: dict): + q_arith = config["name"] + + match q_arith: + case "fixed": + update_node_meta_param_fixed(node, config) + case _: + raise ValueError(f"Unsupported quantization arithmetic name: {q_arith}") + +from torch.fx import Node + +def find_next_compute_node(node: Node): + for n in node.users: + if get_mase_type(n) in ["module_related_func", "builtin_func"]: + return node, n + for n in node.users: + return find_next_compute_node(n) + return None, None + +def find_prev_compute_node(node: Node): + for n in node.all_input_nodes: + if get_mase_type(n) in ["module_related_func", "builtin_func"]: + return node, n + for n in node.all_input_nodes: + return find_prev_compute_node(n) + return None, None + +def infer_result_dtype_and_precision(node: Node): + """ + ```text + n_1 n_2 + \ / + node + ``` + + assign node's args precision & dtype to n_1, n_2 results + """ + + if get_mase_type(node) == "placeholder": + # input node + input_node, next_node = find_next_compute_node(node) + if input_node is None: + logger.warning(f"Failed to find next module_related_func node for input node {node.name}. 
Check if the graph contains module_related_func") + return + i = 0 + for n in next_node.all_input_nodes: + if n is input_node: + break + i += 1 + arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key] + + for result in node.meta["mase"].parameters["common"]["results"]: + node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + + for arg in node.meta["mase"].parameters["common"]["args"]: + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] + + logger.debug(f"Inferred arg & result dtype and precision for input node `{node.name}` using `{next_node.name}`") + + elif get_mase_type(node) in ["module_related_func", "builtin_func"]: + input_node, next_node = find_next_compute_node(node) + if next_node is None: + # this is the last compute node in the graph, use its args to infer dtype and precision + max_precision = None + max_dtype = None + max_bitwidth = 0 + for arg in node.meta["mase"].parameters["common"]["args"]: + if not isinstance(node.meta["mase"].parameters["common"]["args"][arg], dict): + continue + if not "precision" in node.meta["mase"].parameters["common"]["args"][arg]: + continue + cur_width = node.meta["mase"].parameters["common"]["args"][arg]["precision"][0] + if cur_width > max_bitwidth: + max_bitwidth = cur_width + max_precision = node.meta["mase"].parameters["common"]["args"][arg]["precision"] + max_dtype = node.meta["mase"].parameters["common"]["args"][arg]["type"] + + if max_precision is None: + raise RuntimeError(f"Failed to infer dtype and precision for module_related_func node {node.name}") + + for result in node.meta["mase"].parameters["common"]["results"]: + node.meta["mase"].parameters["common"]["results"][result]["type"] = max_dtype + node.meta["mase"].parameters["common"]["results"][result]["precision"] = max_precision + logger.debug(f"Inferred result dtype and precision for module_related_func node `{node.name}` using its args") + else: + # use next compute node's args to infer dtype and precision + i = 0 + for n in next_node.all_input_nodes: + if n is input_node: + break + i += 1 + arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key] + + for result in node.meta["mase"].parameters["common"]["results"]: + node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + logger.debug(f"Inferred result dtype and precision for module_related_func node `{node.name}` using `{next_node.name}`") + + elif get_mase_type(node) == "implicit_func": + input_node, next_node = find_next_compute_node(node) + user_node, prev_node = find_prev_compute_node(node) + + if next_node is not None: + i = 0 + for n in next_node.all_input_nodes: + if n is input_node: + break + i += 1 + arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key] + + for result in node.meta["mase"].parameters["common"]["results"]: + node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] + 
node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + + for arg in node.meta["mase"].parameters["common"]["args"]: + if not isinstance(node.meta["mase"].parameters["common"]["args"][arg], dict): + continue + if not "precision" in node.meta["mase"].parameters["common"]["args"][arg]: + continue + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] + logger.debug(f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{next_node.name}`") + + elif prev_node is not None: + i = 0 + for n in prev_node.users: + if n is user_node: + break + i += 1 + arg_key = list(prev_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_value = prev_node.meta["mase"].parameters["common"]["args"][arg_key] + + for result in node.meta["mase"].parameters["common"]["results"]: + node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + + for arg in node.meta["mase"].parameters["common"]["args"]: + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] + logger.debug(f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{prev_node.name}`") + + else: + raise RuntimeError(f"Failed to infer dtype and precision for implicit_func node {node.name} as it has no input nodes or users of type `module_related_func`") + + elif get_mase_type(node) == "output": + # output node + # find the max precision of all input nodes + user_node, prev_node = find_prev_compute_node(node) + + if prev_node is None: + raise RuntimeError(f"Failed to find prev module_related_func node for output node {node.name}") + + max_precision = None + max_dtype = None + max_bitwidth = 0 + + i = 0 + for n in prev_node.users: + if n is user_node: + break + i += 1 + + arg_key = list(prev_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_value = prev_node.meta["mase"].parameters["common"]["args"][arg_key] + + for result in node.meta["mase"].parameters["common"]["results"]: + node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + + for arg in node.meta["mase"].parameters["common"]["args"]: + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] + + logger.debug(f"Inferred dtype and precision for output node `{node.name}` using `{prev_node.name}`") + + else: + raise RuntimeError(f"Unsupported node type {get_mase_type(node)} for inferring dtype and precision") + + diff --git a/machop/chop/passes/graph/transforms/quantize/quantize.py b/machop/chop/passes/graph/transforms/quantize/quantize.py index ba3c62f10..312d6d820 100644 --- a/machop/chop/passes/graph/transforms/quantize/quantize.py +++ b/machop/chop/passes/graph/transforms/quantize/quantize.py @@ -15,7 +15,7 @@ ) from .modify import create_new_fn, create_new_module -from .quant_parsers import parse_node_config, relink_node_meta, update_quant_meta_param +from .quant_parsers import parse_node_q_config, relink_node_meta, update_q_meta_param, infer_result_dtype_and_precision from 
.summary import graph_iterator_compare_nodes, graph_iterator_node_histogram logger = logging.getLogger(__name__) @@ -63,7 +63,7 @@ def graph_iterator_quantize_by_type(graph, config: dict): node_config = get_config(config, get_mase_op(node)) if node_config["name"] is None: continue - node_config = parse_node_config(node_config, get_mase_op(node)) + node_config = parse_node_q_config(node_config, get_mase_op(node)) # if get_mase_type(node) == "module": if node.op == "call_module": ori_module = get_node_actual_target(node) @@ -82,7 +82,7 @@ def graph_iterator_quantize_by_type(graph, config: dict): parent_name, name = get_parent_name(node.target) setattr(graph.modules[parent_name], name, new_module) # update precision and type in meta.parameters["common"] - update_quant_meta_param(node, node_config, get_mase_op(node)) + update_q_meta_param(node, node_config) elif get_mase_type(node) in [ "builtin_func", "module_related_func", @@ -94,9 +94,13 @@ def graph_iterator_quantize_by_type(graph, config: dict): new_node.meta["mase"] = copy(node.meta["mase"]) # new_node.meta["mase"].node -> new_node relink_node_meta(new_node, model=graph.model) - update_quant_meta_param(new_node, node_config, get_mase_op(node)) + update_q_meta_param(new_node, node_config) node.replace_all_uses_with(new_node) graph.fx_graph.erase_node(node) + + for node in graph.fx_graph.nodes: + if get_mase_type(node) in ["module_related_func", "builtin_func", "output", "placeholder", "implicit_func"]: + infer_result_dtype_and_precision(node) return graph @@ -107,7 +111,7 @@ def graph_iterator_quantize_by_name(graph, config: dict): node_config = get_config(config, node.name) if node_config["name"] is None: continue - node_config = parse_node_config(node_config, get_mase_op(node)) + node_config = parse_node_q_config(node_config, get_mase_op(node)) output_layers_names = node_config.get("additional_layers_outputs", []) output_layers = [ get_node_target_by_name(graph, name) for name in output_layers_names @@ -128,7 +132,7 @@ def graph_iterator_quantize_by_name(graph, config: dict): ) parent_name, name = get_parent_name(node.target) setattr(graph.modules[parent_name], name, new_module) - update_quant_meta_param(node, node_config, get_mase_op(node)) + update_q_meta_param(node, node_config) logger.debug(f"Quantized module: {node.target} with config: {node_config}") elif get_mase_type(node) in [ "builtin_func", @@ -140,7 +144,7 @@ def graph_iterator_quantize_by_name(graph, config: dict): new_node.name = node.name new_node.meta["mase"] = copy(node.meta["mase"]) relink_node_meta(new_node, model=graph.model) - update_quant_meta_param(new_node, node_config, get_mase_op(node)) + update_q_meta_param(new_node, node_config) node.replace_all_uses_with(new_node) graph.fx_graph.erase_node(node) logger.debug( @@ -150,6 +154,9 @@ def graph_iterator_quantize_by_name(graph, config: dict): raise ValueError( "Unsupported node type for quantisation: {}".format(get_mase_type(node)) ) + for node in graph.fx_graph.nodes: + if get_mase_type(node) in ["module_related_func", "builtin_func", "output", "placeholder", "implicit_func"]: + infer_result_dtype_and_precision(node) return graph @@ -165,7 +172,7 @@ def graph_iterator_quantize_by_regex_name(graph, config: dict): node_config = get_config(config, matched_pattern) if node_config["name"] is None: continue - node_config = parse_node_config(node_config, get_mase_op(node)) + node_config = parse_node_q_config(node_config, get_mase_op(node)) # if get_mase_type(node) == "module": if node.op == "call_module": ori_module = 
graph.modules[node.target] @@ -174,7 +181,7 @@ def graph_iterator_quantize_by_regex_name(graph, config: dict): ) parent_name, name = get_parent_name(node.target) setattr(graph.modules[parent_name], name, new_module) - update_quant_meta_param(node, node_config, get_mase_op(node)) + update_q_meta_param(node, node_config) elif get_mase_type(node) in [ "builtin_func", "module_related_func", @@ -185,13 +192,16 @@ def graph_iterator_quantize_by_regex_name(graph, config: dict): new_node.name = node.name new_node.meta["mase"] = deepcopy(node.meta["mase"]) relink_node_meta(new_node, model=graph.model) - update_quant_meta_param(new_node, node_config, get_mase_op(node)) + update_q_meta_param(new_node, node_config) node.replace_all_uses_with(new_node) graph.fx_graph.erase_node(node) else: raise ValueError( "Unsupported node type for quantisation:{}".format(get_mase_type(node)) ) + for node in graph.fx_graph.nodes: + if get_mase_type(node) in ["module_related_func", "builtin_func", "output", "placeholder", "implicit_func"]: + infer_result_dtype_and_precision(node) return graph diff --git a/machop/chop/passes/graph/transforms/verilog/emit_internal.py b/machop/chop/passes/graph/transforms/verilog/emit_internal.py index d29186ae0..8b40a1b83 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_internal.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_internal.py @@ -46,7 +46,7 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): for node in graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue - if "INTERNAL" == node.meta["mase"].parameters["hardware"]["toolchain"]: + if "INTERNAL_RTL" == node.meta["mase"].parameters["hardware"]["toolchain"]: if ( hasattr(node.meta["mase"].module, "config") and node.meta["mase"].module.config.get("name", "") == "logicnets" @@ -70,12 +70,11 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): "..", "..", "..", + "..", "mase_components", ) for f in rtl_dependencies: - fname = os.path.join(hardware_dir, f) - assert os.path.isfile(fname), f"Cannot find file {fname}." 
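The quantisation changes above record precision as a [width, frac_width] pair. For reference, a small self-contained sketch of the fixed-point convention this pair describes (illustrative helpers, not code from this patch):

    # A real value x is stored as round(x * 2**frac_width) in a signed
    # `width`-bit integer, so frac_width bits sit below the binary point.
    def to_fixed(x: float, width: int, frac_width: int) -> int:
        scaled = round(x * (1 << frac_width))
        lo, hi = -(1 << (width - 1)), (1 << (width - 1)) - 1
        return max(lo, min(hi, scaled))  # clamp to the representable range

    def from_fixed(q: int, frac_width: int) -> float:
        return q / (1 << frac_width)

    assert to_fixed(0.75, 8, 3) == 6   # 0.75 * 2**3 = 6
    assert from_fixed(6, 3) == 0.75

With data_in_width=8 and data_in_frac_width=3, as in the fixed.toml test config, activations therefore cover roughly [-16, 15.875] in steps of 0.125.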
- shutil.copy(fname, rtl_dir) + shutil.copy(os.path.join(hardware_dir, f), rtl_dir) return graph, {} diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py index 1d9a2640f..f9f8d1881 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_top.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py @@ -56,10 +56,6 @@ def param_needs_signals(node, param, value, qualifier="data_in"): ) -# ============================================================================= -# Emit design in a memory-independent dataflow graph -# ============================================================================= - # ============================================================================= # Verilog parameters # ============================================================================= @@ -69,16 +65,16 @@ class VerilogParameterEmitter: def __init__(self, graph): self.graph = graph - def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]: + def emit(self, graph, parameter_map) -> Tuple[str, Dict[str, str]]: """ Emit parameters at the top-level for the top-level module Returns Tuple: - 1) list of parameters as a string to be embedded in DFVerilog file + 1) list of parameters as a string to be embedded in Verilog file """ - nodes_in = self.graph.nodes_in - nodes_out = self.graph.nodes_out + nodes_in = graph.nodes_in + nodes_out = graph.nodes_out node_in_name = vf(nodes_in[0].name) node_out_name = vf(nodes_out[0].name) @@ -92,15 +88,15 @@ def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]: # ============================================================================= -# DFVerilog interface +# Verilog interface # ============================================================================= -class DFVerilogInterfaceEmitter: +class VerilogInterfaceEmitter: def __init__(self, graph): self.graph = graph - def emit(self, parameter_map): + def emit(self, graph, parameter_map): """ Emit interface signal declarations for the top-level module """ @@ -146,36 +142,17 @@ def emit(self, parameter_map): input data_out_{i}_ready,""" i += 1 - # Emit all parameters as inputs (they will be mapped at the top-level) - for node in self.graph.fx_graph.nodes: - if node.meta["mase"].parameters["hardware"]["is_implicit"]: - continue - node_name = vf(node.name) - for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - if not isinstance(arg_info, dict): - continue - if "data_in" not in arg: - arg_name = _cap(arg) - parallelism_params = [ - param - for param in parameter_map - if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param - ] - interface += f""" - input [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0], - input {node_name}_{arg}_valid, - output {node_name}_{arg}_ready,""" - i += 1 + # TODO: emit off-chip parameter interface return _remove_last_comma(interface) # ============================================================================= -# DFVerilog signals +# Verilog signals # ============================================================================= -class DFVerilogSignalEmitter: +class VerilogSignalEmitter: def __init__(self, graph): self.graph = graph @@ -184,7 +161,15 @@ def _emit_signals_top_internal(self, node, parameter_map): node_name = vf(node.name) # Input signals for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - if "data_in" in arg: + if not isinstance(arg_info, dict): + continue + + # Skip off-chip parameters as 
they will be directly connected to the top level + if ( + "data_in" in arg + or node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] + == "BRAM" + ): arg_name = v2p(arg) parallelism_params = [ param @@ -203,7 +188,14 @@ def _emit_signals_top_internal(self, node, parameter_map): if not isinstance(result_info, dict): continue - if "data_out" in result: + # Skip off-chip parameters as they will be directly connected to the top level + if ( + "data_out" in result + or node.meta["mase"].parameters["hardware"]["interface"][result][ + "storage" + ] + == "BRAM" + ): result_name = v2p(result) parallelism_params = [ param @@ -267,13 +259,13 @@ def _emit_signals_top_hls(self, node, parameter_map): logic {node_name}_{key}_we0;""" return signals - def emit(self, parameter_map): + def emit(self, graph, parameter_map): """ Emit internal signal declarations for the top-level module """ signals = "" - for node in self.graph.fx_graph.nodes: + for node in graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue node_name = vf(node.name) @@ -292,14 +284,37 @@ def emit(self, parameter_map): # ============================================================================= -# DFVerilog components (INTERNAL) +# Verilog components (INTERNAL) # ============================================================================= -class DFVerilogInternalComponentEmitter: +class VerilogInternalComponentEmitter: def __init__(self, graph): self.graph = graph + def _emit_module_parameters_top_internal(self, key, value, node, parameter_map): + node_name = vf(node.name) + component_name = f"{node_name}_{key}_source" + component_name_inst = f"{component_name}_0" + + parameters = "" + for param in node.meta["mase"].parameters["hardware"]["verilog_param"].keys(): + if f"{_cap(key)}_" in param: + parameters += f".{param}({node_name}_{param}),\n" + parameters = _remove_last_comma(parameters) + + return f""" +{component_name} #( +{parameters} +) {component_name_inst} ( + .clk(clk), + .rst(rst), + .data_out({node_name}_{key}), + .data_out_ready({node_name}_{key}_ready), + .data_out_valid({node_name}_{key}_valid) +); +""" + def emit(self, node, parameter_map): node_name = vf(node.name) component_name = node.meta["mase"].parameters["hardware"]["module"] @@ -348,15 +363,26 @@ def emit(self, node, parameter_map): ); """ + # Emit module parameter instances (e.g. 
weights and biases) + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if "data_in" in arg: + continue + if not isinstance(arg_info, dict): + continue + + components += self._emit_module_parameters_top_internal( + arg, arg_info, node, parameter_map + ) + return components # ============================================================================= -# DFVerilog components (HLS) +# Verilog components (HLS) # ============================================================================= -class DFVerilogHLSComponentEmitter: +class VerilogHLSComponentEmitter: def __init__(self, graph): self.graph = graph @@ -442,17 +468,17 @@ def emit(self, node, parameter_map): # ============================================================================= -# DFVerilog components +# Verilog components # ============================================================================= -class DFVerilogComponentEmitter: +class VerilogComponentEmitter: def __init__(self, graph): self.graph = graph - self.internal_emitter = DFVerilogInternalComponentEmitter(graph) - self.hls_emitter = DFVerilogHLSComponentEmitter(graph) + self.internal_emitter = VerilogInternalComponentEmitter(graph) + self.hls_emitter = VerilogHLSComponentEmitter(graph) - def emit(self, parameter_map): + def emit(self, graph, parameter_map): """ Emit component declarations for the top-level module """ @@ -462,7 +488,7 @@ def emit(self, parameter_map): // Component instantiation // -------------------------- """ - for node in self.graph.fx_graph.nodes: + for node in graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: @@ -476,13 +502,14 @@ def emit(self, parameter_map): # ============================================================================= -# DFVerilog wires +# Verilog wires # ============================================================================= -class DFVerilogWireEmitter: - def __init__(self, graph): +class VerilogWireEmitter: + def __init__(self, graph, parameter_map): self.graph = graph + self.parameter_map = parameter_map self.wires = """ // -------------------------- @@ -490,7 +517,7 @@ def __init__(self, graph): // -------------------------- """ - def _emit_top_wires(self, parameter_map): + def _emit_top_wires(self): nodes_in = self.graph.nodes_in nodes_out = self.graph.nodes_out @@ -553,7 +580,7 @@ def _emit_node2node_wires(self): """ return wires - def emit(self, parameter_map): + def emit(self): """ Emit internal signal connections for the top-level module This includes two interconnection types: @@ -561,7 +588,7 @@ def emit(self, parameter_map): 2. 
Interface casting between inputs and outputs """ - self.wires += self._emit_top_wires(parameter_map) + self.wires += self._emit_top_wires() self.wires += self._emit_node2node_wires() return self.wires @@ -571,38 +598,39 @@ def emit(self, parameter_map): # ============================================================================= -class DataflowEmitter: +class VerilogEmitter: def __init__(self, graph): self.graph = graph + self.parameter_map = get_verilog_parameters(graph) - def emit(self, top_name): - parameters_to_emit = VerilogParameterEmitter(self.graph).emit( - self.parameter_map + def emit(self, graph, top_name): + parameters_to_emit = VerilogParameterEmitter(graph).emit( + graph, self.parameter_map ) - interface_to_emit = DFVerilogInterfaceEmitter(self.graph).emit( - self.parameter_map + interface_to_emit = VerilogInterfaceEmitter(graph).emit( + graph, self.parameter_map ) - signals_to_emit = DFVerilogSignalEmitter(self.graph).emit(self.parameter_map) + signals_to_emit = VerilogSignalEmitter(graph).emit(graph, self.parameter_map) - components_to_emit = DFVerilogComponentEmitter(self.graph).emit( - self.parameter_map + components_to_emit = VerilogComponentEmitter(graph).emit( + graph, self.parameter_map ) - wires_to_emit = DFVerilogWireEmitter(self.graph).emit(self.parameter_map) + wires_to_emit = VerilogWireEmitter(graph, self.parameter_map).emit() time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") module_inst = """ // ===================================== -// Mase Hardware (Dataflow) +// Mase Hardware // Model: {} // {} // ===================================== `timescale 1ns/1ps -module {}_dataflow #( +module {} #( {} ) ( input clk, @@ -1084,18 +1112,10 @@ def emit_verilog_top_transform_pass(graph, pass_args={}): ) top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" init_project(project_dir) - logger.info(f"Project path: {project_dir}") - rtl_dir = os.path.join(project_dir, "hardware", "rtl") - # Emit device-independent hardware design in dataflow - df = DataflowEmitter(graph).emit(top_name) - df_file = os.path.join(rtl_dir, f"{top_name}_df.sv") - with open(df_file, "w") as df_design: - df_design.write(df) + top = VerilogEmitter(graph).emit(graph, top_name) - # Emit memory mapping with BRAMs for the top-level hardware design - top = MemoryMapEmitter(graph).emit(top_name) top_file = os.path.join(rtl_dir, f"{top_name}.sv") with open(top_file, "w") as top_design: top_design.write(top) diff --git a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py index 9087e2ab3..a84625f93 100644 --- a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py +++ b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py @@ -12,11 +12,6 @@ "common/rtl/skid_buffer.sv", "common/rtl/join2.sv", "cast/rtl/fixed_rounding.sv", - "cast/rtl/fixed_round.sv", - ], - "relu": [ - "activations/rtl/fixed_relu.sv", - "cast/rtl/fixed_rounding.sv", - "cast/rtl/fixed_round.sv", ], + "relu": ["activations/fixed_relu.sv"], } diff --git a/machop/configs/tests/quantize/fixed.toml b/machop/configs/tests/quantize/fixed.toml index d7d109695..ad4c79007 100644 --- a/machop/configs/tests/quantize/fixed.toml +++ b/machop/configs/tests/quantize/fixed.toml @@ -1,15 +1,24 @@ -model = "toy" -dataset = "toy-tiny" +model="toy" +dataset="toy-tiny" [passes.quantize] -by = "type" -report = true + by="type" + report=true -[passes.quantize.default.config] -name = "fixed" -data_in_width = 
8
-data_in_frac_width = 3
-weight_width = 8
-weight_frac_width = 3
-bias_width = 8
-bias_frac_width = 3
\ No newline at end of file
+    [passes.quantize.default.config]
+        name="fixed"
+        data_in_width=8
+        data_in_frac_width=3
+        weight_width=8
+        weight_frac_width=3
+        bias_width=8
+        bias_frac_width=3
+
+    [passes.quantize.relu.config]
+        name="fixed"
+        data_in_width=4
+        data_in_frac_width=2
+        weight_width=4
+        weight_frac_width=2
+        bias_width=4
+        bias_frac_width=2
diff --git a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
index 1ab406ced..ba2fbe224 100644
--- a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
+++ b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
@@ -35,6 +35,7 @@ def __init__(self) -> None:
 
     def forward(self, x):
         x = torch.flatten(x, start_dim=1, end_dim=-1)
+        # x = torch.nn.functional.relu(x)
         x = torch.nn.functional.relu(self.fc1(x))
         x = torch.nn.functional.relu(self.fc2(x))
         x = self.fc3(x)
@@ -69,7 +70,7 @@ def test_emit_verilog_linear():
         "configs",
         "tests",
         "quantize",
-        "integer.toml",
+        "fixed.toml",
     )
 
     # load toml config file
@@ -77,27 +78,30 @@
         quan_args = toml.load(f)["passes"]["quantize"]
     mg, _ = passes.quantize_transform_pass(mg, quan_args)
 
+    # inspect the graph metadata
+    # mg, _ = passes.report_node_meta_param_analysis_pass(mg)
+
     # There is a bug in the current quantization pass, where the results metadata is not updated with the precision.
     # Here we temporarily update the metadata so we can test the hardware back end.
-    for node in mg.fx_graph.nodes:
-        for arg, _ in node.meta["mase"].parameters["common"]["args"].items():
-            if (
-                type(node.meta["mase"].parameters["common"]["args"][arg]) == dict
-                and "type" in node.meta["mase"].parameters["common"]["args"][arg].keys()
-            ):
-                node.meta["mase"].parameters["common"]["args"][arg]["type"] = "fixed"
-        for result, _ in node.meta["mase"].parameters["common"]["results"].items():
-            if (
-                type(node.meta["mase"].parameters["common"]["results"][result]) == dict
-                and "type"
-                in node.meta["mase"].parameters["common"]["results"][result].keys()
-            ):
-                node.meta["mase"].parameters["common"]["results"][result][
-                    "type"
-                ] = "fixed"
-                node.meta["mase"].parameters["common"]["results"][result][
-                    "precision"
-                ] = [8, 3]
+    # for node in mg.fx_graph.nodes:
+    #     for arg, _ in node.meta["mase"].parameters["common"]["args"].items():
+    #         if (
+    #             type(node.meta["mase"].parameters["common"]["args"][arg]) == dict
+    #             and "type" in node.meta["mase"].parameters["common"]["args"][arg].keys()
+    #         ):
+    #             node.meta["mase"].parameters["common"]["args"][arg]["type"] = "fixed"
+    #     for result, _ in node.meta["mase"].parameters["common"]["results"].items():
+    #         if (
+    #             type(node.meta["mase"].parameters["common"]["results"][result]) == dict
+    #             and "type"
+    #             in node.meta["mase"].parameters["common"]["results"][result].keys()
+    #         ):
+    #             node.meta["mase"].parameters["common"]["results"][result][
+    #                 "type"
+    #             ] = "fixed"
+    #             node.meta["mase"].parameters["common"]["results"][result][
+    #                 "precision"
+    #             ] = [8, 3]
 
     mg, _ = passes.add_hardware_metadata_analysis_pass(
         mg

From 50816af9931e02afc5e0e9537505eb7aedf63a93 Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Fri, 26 Apr 2024 20:40:39 +0000
Subject: [PATCH 20/31] reverted changes made by the quantization PR

---
 .../graph/transforms/verilog/emit_internal.py |   7 +-
 .../graph/transforms/verilog/emit_top.py      | 168 
++++++++---------- .../verilog/internal_file_dependences.py | 7 +- 3 files changed, 84 insertions(+), 98 deletions(-) diff --git a/machop/chop/passes/graph/transforms/verilog/emit_internal.py b/machop/chop/passes/graph/transforms/verilog/emit_internal.py index 8b40a1b83..d29186ae0 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_internal.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_internal.py @@ -46,7 +46,7 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): for node in graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue - if "INTERNAL_RTL" == node.meta["mase"].parameters["hardware"]["toolchain"]: + if "INTERNAL" == node.meta["mase"].parameters["hardware"]["toolchain"]: if ( hasattr(node.meta["mase"].module, "config") and node.meta["mase"].module.config.get("name", "") == "logicnets" @@ -70,11 +70,12 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): "..", "..", "..", - "..", "mase_components", ) for f in rtl_dependencies: - shutil.copy(os.path.join(hardware_dir, f), rtl_dir) + fname = os.path.join(hardware_dir, f) + assert os.path.isfile(fname), f"Cannot find file {fname}." + shutil.copy(fname, rtl_dir) return graph, {} diff --git a/machop/chop/passes/graph/transforms/verilog/emit_top.py b/machop/chop/passes/graph/transforms/verilog/emit_top.py index f9f8d1881..1d9a2640f 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_top.py +++ b/machop/chop/passes/graph/transforms/verilog/emit_top.py @@ -56,6 +56,10 @@ def param_needs_signals(node, param, value, qualifier="data_in"): ) +# ============================================================================= +# Emit design in a memory-independent dataflow graph +# ============================================================================= + # ============================================================================= # Verilog parameters # ============================================================================= @@ -65,16 +69,16 @@ class VerilogParameterEmitter: def __init__(self, graph): self.graph = graph - def emit(self, graph, parameter_map) -> Tuple[str, Dict[str, str]]: + def emit(self, parameter_map) -> Tuple[str, Dict[str, str]]: """ Emit parameters at the top-level for the top-level module Returns Tuple: - 1) list of parameters as a string to be embedded in Verilog file + 1) list of parameters as a string to be embedded in DFVerilog file """ - nodes_in = graph.nodes_in - nodes_out = graph.nodes_out + nodes_in = self.graph.nodes_in + nodes_out = self.graph.nodes_out node_in_name = vf(nodes_in[0].name) node_out_name = vf(nodes_out[0].name) @@ -88,15 +92,15 @@ def emit(self, graph, parameter_map) -> Tuple[str, Dict[str, str]]: # ============================================================================= -# Verilog interface +# DFVerilog interface # ============================================================================= -class VerilogInterfaceEmitter: +class DFVerilogInterfaceEmitter: def __init__(self, graph): self.graph = graph - def emit(self, graph, parameter_map): + def emit(self, parameter_map): """ Emit interface signal declarations for the top-level module """ @@ -142,17 +146,36 @@ def emit(self, graph, parameter_map): input data_out_{i}_ready,""" i += 1 - # TODO: emit off-chip parameter interface + # Emit all parameters as inputs (they will be mapped at the top-level) + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + for arg, 
arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + if "data_in" not in arg: + arg_name = _cap(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + input [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0], + input {node_name}_{arg}_valid, + output {node_name}_{arg}_ready,""" + i += 1 return _remove_last_comma(interface) # ============================================================================= -# Verilog signals +# DFVerilog signals # ============================================================================= -class VerilogSignalEmitter: +class DFVerilogSignalEmitter: def __init__(self, graph): self.graph = graph @@ -161,15 +184,7 @@ def _emit_signals_top_internal(self, node, parameter_map): node_name = vf(node.name) # Input signals for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - if not isinstance(arg_info, dict): - continue - - # Skip off-chip parameters as they will be directly connected to the top level - if ( - "data_in" in arg - or node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] - == "BRAM" - ): + if "data_in" in arg: arg_name = v2p(arg) parallelism_params = [ param @@ -188,14 +203,7 @@ def _emit_signals_top_internal(self, node, parameter_map): if not isinstance(result_info, dict): continue - # Skip off-chip parameters as they will be directly connected to the top level - if ( - "data_out" in result - or node.meta["mase"].parameters["hardware"]["interface"][result][ - "storage" - ] - == "BRAM" - ): + if "data_out" in result: result_name = v2p(result) parallelism_params = [ param @@ -259,13 +267,13 @@ def _emit_signals_top_hls(self, node, parameter_map): logic {node_name}_{key}_we0;""" return signals - def emit(self, graph, parameter_map): + def emit(self, parameter_map): """ Emit internal signal declarations for the top-level module """ signals = "" - for node in graph.fx_graph.nodes: + for node in self.graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue node_name = vf(node.name) @@ -284,37 +292,14 @@ def emit(self, graph, parameter_map): # ============================================================================= -# Verilog components (INTERNAL) +# DFVerilog components (INTERNAL) # ============================================================================= -class VerilogInternalComponentEmitter: +class DFVerilogInternalComponentEmitter: def __init__(self, graph): self.graph = graph - def _emit_module_parameters_top_internal(self, key, value, node, parameter_map): - node_name = vf(node.name) - component_name = f"{node_name}_{key}_source" - component_name_inst = f"{component_name}_0" - - parameters = "" - for param in node.meta["mase"].parameters["hardware"]["verilog_param"].keys(): - if f"{_cap(key)}_" in param: - parameters += f".{param}({node_name}_{param}),\n" - parameters = _remove_last_comma(parameters) - - return f""" -{component_name} #( -{parameters} -) {component_name_inst} ( - .clk(clk), - .rst(rst), - .data_out({node_name}_{key}), - .data_out_ready({node_name}_{key}_ready), - .data_out_valid({node_name}_{key}_valid) -); -""" - def emit(self, node, parameter_map): node_name = vf(node.name) component_name = node.meta["mase"].parameters["hardware"]["module"] @@ -363,26 +348,15 @@ def emit(self, node, parameter_map): ); """ - # Emit module parameter instances (e.g. 
weights and biases) - for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - if "data_in" in arg: - continue - if not isinstance(arg_info, dict): - continue - - components += self._emit_module_parameters_top_internal( - arg, arg_info, node, parameter_map - ) - return components # ============================================================================= -# Verilog components (HLS) +# DFVerilog components (HLS) # ============================================================================= -class VerilogHLSComponentEmitter: +class DFVerilogHLSComponentEmitter: def __init__(self, graph): self.graph = graph @@ -468,17 +442,17 @@ def emit(self, node, parameter_map): # ============================================================================= -# Verilog components +# DFVerilog components # ============================================================================= -class VerilogComponentEmitter: +class DFVerilogComponentEmitter: def __init__(self, graph): self.graph = graph - self.internal_emitter = VerilogInternalComponentEmitter(graph) - self.hls_emitter = VerilogHLSComponentEmitter(graph) + self.internal_emitter = DFVerilogInternalComponentEmitter(graph) + self.hls_emitter = DFVerilogHLSComponentEmitter(graph) - def emit(self, graph, parameter_map): + def emit(self, parameter_map): """ Emit component declarations for the top-level module """ @@ -488,7 +462,7 @@ def emit(self, graph, parameter_map): // Component instantiation // -------------------------- """ - for node in graph.fx_graph.nodes: + for node in self.graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue if "INTERNAL" in node.meta["mase"].parameters["hardware"]["toolchain"]: @@ -502,14 +476,13 @@ def emit(self, graph, parameter_map): # ============================================================================= -# Verilog wires +# DFVerilog wires # ============================================================================= -class VerilogWireEmitter: - def __init__(self, graph, parameter_map): +class DFVerilogWireEmitter: + def __init__(self, graph): self.graph = graph - self.parameter_map = parameter_map self.wires = """ // -------------------------- @@ -517,7 +490,7 @@ def __init__(self, graph, parameter_map): // -------------------------- """ - def _emit_top_wires(self): + def _emit_top_wires(self, parameter_map): nodes_in = self.graph.nodes_in nodes_out = self.graph.nodes_out @@ -580,7 +553,7 @@ def _emit_node2node_wires(self): """ return wires - def emit(self): + def emit(self, parameter_map): """ Emit internal signal connections for the top-level module This includes two interconnection types: @@ -588,7 +561,7 @@ def emit(self): 2. 
Interface casting between inputs and outputs """ - self.wires += self._emit_top_wires() + self.wires += self._emit_top_wires(parameter_map) self.wires += self._emit_node2node_wires() return self.wires @@ -598,39 +571,38 @@ def emit(self): # ============================================================================= -class VerilogEmitter: +class DataflowEmitter: def __init__(self, graph): self.graph = graph - self.parameter_map = get_verilog_parameters(graph) - def emit(self, graph, top_name): - parameters_to_emit = VerilogParameterEmitter(graph).emit( - graph, self.parameter_map + def emit(self, top_name): + parameters_to_emit = VerilogParameterEmitter(self.graph).emit( + self.parameter_map ) - interface_to_emit = VerilogInterfaceEmitter(graph).emit( - graph, self.parameter_map + interface_to_emit = DFVerilogInterfaceEmitter(self.graph).emit( + self.parameter_map ) - signals_to_emit = VerilogSignalEmitter(graph).emit(graph, self.parameter_map) + signals_to_emit = DFVerilogSignalEmitter(self.graph).emit(self.parameter_map) - components_to_emit = VerilogComponentEmitter(graph).emit( - graph, self.parameter_map + components_to_emit = DFVerilogComponentEmitter(self.graph).emit( + self.parameter_map ) - wires_to_emit = VerilogWireEmitter(graph, self.parameter_map).emit() + wires_to_emit = DFVerilogWireEmitter(self.graph).emit(self.parameter_map) time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S") module_inst = """ // ===================================== -// Mase Hardware +// Mase Hardware (Dataflow) // Model: {} // {} // ===================================== `timescale 1ns/1ps -module {} #( +module {}_dataflow #( {} ) ( input clk, @@ -1112,10 +1084,18 @@ def emit_verilog_top_transform_pass(graph, pass_args={}): ) top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" init_project(project_dir) + logger.info(f"Project path: {project_dir}") + rtl_dir = os.path.join(project_dir, "hardware", "rtl") - top = VerilogEmitter(graph).emit(graph, top_name) + # Emit device-independent hardware design in dataflow + df = DataflowEmitter(graph).emit(top_name) + df_file = os.path.join(rtl_dir, f"{top_name}_df.sv") + with open(df_file, "w") as df_design: + df_design.write(df) + # Emit memory mapping with BRAMs for the top-level hardware design + top = MemoryMapEmitter(graph).emit(top_name) top_file = os.path.join(rtl_dir, f"{top_name}.sv") with open(top_file, "w") as top_design: top_design.write(top) diff --git a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py index a84625f93..9087e2ab3 100644 --- a/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py +++ b/machop/chop/passes/graph/transforms/verilog/internal_file_dependences.py @@ -12,6 +12,11 @@ "common/rtl/skid_buffer.sv", "common/rtl/join2.sv", "cast/rtl/fixed_rounding.sv", + "cast/rtl/fixed_round.sv", + ], + "relu": [ + "activations/rtl/fixed_relu.sv", + "cast/rtl/fixed_rounding.sv", + "cast/rtl/fixed_round.sv", ], - "relu": ["activations/fixed_relu.sv"], } From c1d1fd7c6d265157b33ce9d1396d262d55e3011a Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 26 Apr 2024 20:42:59 +0000 Subject: [PATCH 21/31] reduced test case --- .../verilog/test_emit_verilog_linear.py | 47 +++++-------------- 1 file changed, 12 insertions(+), 35 deletions(-) diff --git a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py index 
ba2fbe224..79c154c37 100644
--- a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
+++ b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py
@@ -29,16 +29,15 @@ class MLP(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        self.fc1 = nn.Linear(28 * 28, 28 * 28)
-        self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4)
-        self.fc3 = nn.Linear(28 * 28 * 4, 10)
+        self.fc1 = nn.Linear(5 * 5, 5 * 5)
+        self.fc2 = nn.Linear(5 * 5, 5 * 5 * 4)
+        self.fc3 = nn.Linear(5 * 5 * 4, 10)
 
     def forward(self, x):
         x = torch.flatten(x, start_dim=1, end_dim=-1)
-        # x = torch.nn.functional.relu(x)
         x = torch.nn.functional.relu(self.fc1(x))
-        x = torch.nn.functional.relu(self.fc2(x))
-        x = self.fc3(x)
+        # x = torch.nn.functional.relu(self.fc2(x))
+        # x = self.fc3(x)
         return x
 
 
@@ -48,7 +47,7 @@ def test_emit_verilog_linear():
 
     # Provide a dummy input for the graph so it can use for tracing
     batch_size = 1
-    x = torch.randn((batch_size, 28, 28))
+    x = torch.randn((batch_size, 5, 5))
     dummy_in = {"x": x}
 
     mg, _ = passes.init_metadata_analysis_pass(mg, None)
@@ -81,37 +80,15 @@
 
     # inspect the graph metadata
     # mg, _ = passes.report_node_meta_param_analysis_pass(mg)
 
-    # There is a bug in the current quantization pass, where the results metadata is not updated with the precision.
-    # Here we temporarily update the metadata so we can test the hardware back end.
-    # for node in mg.fx_graph.nodes:
-    #     for arg, _ in node.meta["mase"].parameters["common"]["args"].items():
-    #         if (
-    #             type(node.meta["mase"].parameters["common"]["args"][arg]) == dict
-    #             and "type" in node.meta["mase"].parameters["common"]["args"][arg].keys()
-    #         ):
-    #             node.meta["mase"].parameters["common"]["args"][arg]["type"] = "fixed"
-    #     for result, _ in node.meta["mase"].parameters["common"]["results"].items():
-    #         if (
-    #             type(node.meta["mase"].parameters["common"]["results"][result]) == dict
-    #             and "type"
-    #             in node.meta["mase"].parameters["common"]["results"][result].keys()
-    #         ):
-    #             node.meta["mase"].parameters["common"]["results"][result][
-    #                 "type"
-    #             ] = "fixed"
-    #             node.meta["mase"].parameters["common"]["results"][result][
-    #                 "precision"
-    #             ] = [8, 3]
-
-    mg, _ = passes.add_hardware_metadata_analysis_pass(
-        mg
-    )  # add metadata for hardware in each mase node of graph
-    mg, _ = passes.report_node_hardware_type_analysis_pass(mg)  # pretty print
+    # add metadata for hardware in each mase node of graph
+    mg, _ = passes.add_hardware_metadata_analysis_pass(mg)
+    # pretty print
+    mg, _ = passes.report_node_hardware_type_analysis_pass(mg)
     # mg = verify_hardware_metadata_analysis_pass(mg)
 
     mg, _ = passes.emit_verilog_top_transform_pass(mg)
-    # mg = passes.emit_bram_transform_pass(mg)
-    # mg, _ = passes.emit_internal_rtl_transform_pass(mg)
+    mg, _ = passes.emit_bram_transform_pass(mg)
+    mg, _ = passes.emit_internal_rtl_transform_pass(mg)
 
     # # For internal models, the test inputs can be directly fetched from the dataset
     # # using InputGenerator from chop.tools.get_input

From 8535c575ae06f14fd0f07f3019e0eb9c5455f5b4 Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Sun, 28 Apr 2024 16:24:10 +0000
Subject: [PATCH 22/31] Fetched previous verilog analysis pass for testing

---
 .../passes/graph/analysis/verilog/cocotb.py   | 102 +++++++++++++
 .../graph/analysis/verilog/test_verilog.py    | 143 +++++++++---------
 2 files changed, 176 insertions(+), 69 deletions(-)
 create mode 100644 machop/chop/passes/graph/analysis/verilog/cocotb.py

diff --git 
a/machop/chop/passes/graph/analysis/verilog/cocotb.py b/machop/chop/passes/graph/analysis/verilog/cocotb.py new file mode 100644 index 000000000..0194a4fcb --- /dev/null +++ b/machop/chop/passes/graph/analysis/verilog/cocotb.py @@ -0,0 +1,102 @@ +import os, glob +from chop.passes.graph.utils import vf + +from .cocotb import VerificationCase +from mase_cocotb.random_test import RandomSource, RandomSink, check_results +from cocotb.runner import get_runner + + +def get_dut_parameters(graph): + parameter_map = {} + + for node in graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + + for key, value in ( + node.meta["mase"].parameters["hardware"]["verilog_param"].items() + ): + if not isinstance(value, (int, float, complex, bool)): + value = '"' + value + '"' + assert ( + f"{node_name}_{key}" not in parameter_map.keys() + ), f"{node_name}_{key} already exists in the parameter map" + parameter_map[f"{node_name}_{key}"] = value + return parameter_map + + +def get_dependence_files(graph): + f = [] + for node in graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + f += node.meta["mase"].parameters["hardware"]["dependence_files"] + + f = list(dict.fromkeys(f)) + return f + + +def runner(mg): + sim = os.getenv("SIM", "verilator") + + verilog_sources = get_dependence_files(mg) + for i, v in enumerate(verilog_sources): + verilog_sources[i] = os.path.relpath( + os.path.join("/workspace", "mase_components", v), os.getcwd() + ) + # TODO: make project name variable + for v in glob.glob("./top/hardware/rtl/*.sv"): + verilog_sources.append(os.path.relpath(v, os.getcwd())) + + # TODO: make samples and iterations variable + tb = VerificationCase(samples=1, iterations=1) + + # TODO: work out the num + for node in mg.nodes_in: + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + setattr( + tb, + arg, + RandomSource( + name=arg, + samples=tb.samples * tb.iterations, + num=12324, + max_stalls=0, + ), + ) + for node in mg.nodes_out: + for result, result_info in ( + node.meta["mase"].parameters["common"]["results"].items() + ): + setattr( + tb, + result, + RandomSink( + name=result, + samples=tb.samples * tb.iterations, + num=12324, + max_stalls=0, + ), + ) + + p = get_dut_parameters(mg) + print(p) + + # set parameters + extra_args = [] + for k, v in p.items(): + extra_args.append(f"-G{k}={v}") + print(extra_args) + runner = get_runner(sim) + runner.build( + verilog_sources=verilog_sources, + hdl_toplevel="top", + build_args=extra_args, + ) + runner.test(hdl_toplevel="top", test_module="top_tb") + + +def test_verilog_analysis_pass(mg, pass_args={}): + runner(mg) + return mg, {} diff --git a/machop/chop/passes/graph/analysis/verilog/test_verilog.py b/machop/chop/passes/graph/analysis/verilog/test_verilog.py index 50479a85c..f9b272c74 100644 --- a/machop/chop/passes/graph/analysis/verilog/test_verilog.py +++ b/machop/chop/passes/graph/analysis/verilog/test_verilog.py @@ -1,75 +1,80 @@ -import logging -from typing import Tuple, Dict -import math -import os -import time -from multiprocessing import Process, Queue +#!/usr/bin/env python3 -from chop.passes.graph.utils import vf, v2p, init_project +import os, logging -logger = logging.getLogger(__name__) - - -def get_test_parameters(mg): - """ - Extract the verilog parameters from the mase graph for cocotb testing - """ - return {} - - -def get_dummy_inputs(mg): - """ - Fetch test inputs from dataset or create a random one - """ - 
return {} - - -def run_software_test(mg, inputs): - """ - Run software model on given inputs - """ - return {} - - -def run_cocotb_test(mg, parameters, inputs): - """ - Create a cocotb test case and use mase runner to run hardware simulation - """ - return {} +from mase_cocotb.random_test import check_results +from mase_cocotb.runner import mase_runner +import cocotb +from cocotb.triggers import Timer +from cocotb.triggers import FallingEdge +from cocotb.clock import Clock -def compare_results(r0, r1): - return r0 == r1 - - -def test_verilog_analysis_pass(graph, pass_args={}): - """Use cocotb to test the model design in Verilog - - :param graph: a MaseGraph - :type graph: MaseGraph - :param pass_args: this pass requires additional arguments which is explained below, defaults to {} - :type pass_args: _type_, optional - :return: return a tuple of a MaseGraph and an empty dict (no additional info to return) - :rtype: tuple(MaseGraph, Dict) - - - - pass_args - - project_dir -> str : the directory of the project for cosimulation - - top_name -> str : top-level name - """ - - logger.info("Testing the model in Verilog...") - - project_dir = ( - pass_args["project_dir"] if "project_dir" in pass_args.keys() else "top" - ) - top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" - - parameters = get_test_parameters(graph) - inputs = get_dummy_inputs(graph) - software_results = run_software_test(graph, inputs) - hardware_results = run_cocotb_test(graph, parameters, inputs) +logger = logging.getLogger(__name__) - compare_results(software_results, hardware_results) - return graph, {} +# DUT test specifications +class VerificationCase: + def __init__(self, iterations=1, samples=10): + self.samples = samples + self.iterations = iterations + + +@cocotb.test() +async def test_fixed_linear(dut): + """Test integer based vector mult""" + samples = 1000 + test_case = VerificationCase(samples=samples) + + # Reset cycle + await Timer(20, units="ns") + dut.rst.value = 1 + await Timer(100, units="ns") + dut.rst.value = 0 + + # Create a 10ns-period clock on port clk + clock = Clock(dut.clk, 10, units="ns") + # Start the clock + cocotb.start_soon(clock.start()) + await Timer(500, units="ns") + + # Synchronize with the clock + dut.data_in_0_valid.value = 0 + dut.data_out_0_ready.value = 1 + debug_state(dut, "Pre-clk") + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + debug_state(dut, "Pre-clk") + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + + done = False + # Set a timeout to avoid deadlock + for i in range(samples * 100): + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + dut.data_in_0_valid.value = test_case.data_in.pre_compute() + await Timer(1, units="ns") + dut.data_out_0_ready.value = test_case.outputs.pre_compute( + dut.data_out_0_valid.value + ) + await Timer(1, units="ns") + debug_state(dut, "Post-clk") + + dut.data_in_0_valid.value, dut.data_in_0.value = test_case.data_in.compute( + dut.data_in_0_ready.value + ) + await Timer(1, units="ns") + dut.data_out_0_ready.value = test_case.outputs.compute( + dut.data_out_0_valid.value, dut.data_out_0.value + ) + debug_state(dut, "Pre-clk") + + if test_case.data_in.is_empty() and test_case.outputs.is_full(): + done = True + break + assert ( + done + ), "Deadlock detected or the simulation reaches the maximum cycle limit (fixed it by adjusting the loop trip count)" + + check_results(test_case.outputs.data, test_case.ref) From 590f20e633a2907827adf1d00fbbbcd5e058373a Mon Sep 17 00:00:00 2001 From: Jianyi 
Cheng Date: Sun, 28 Apr 2024 16:29:41 +0000 Subject: [PATCH 23/31] Fetched the latest version of draft --- .../passes/graph/analysis/verilog/cocotb.py | 174 ++++++++---------- .../graph/analysis/verilog/test_verilog.py | 174 ++++++++++-------- 2 files changed, 174 insertions(+), 174 deletions(-) diff --git a/machop/chop/passes/graph/analysis/verilog/cocotb.py b/machop/chop/passes/graph/analysis/verilog/cocotb.py index 0194a4fcb..f9b272c74 100644 --- a/machop/chop/passes/graph/analysis/verilog/cocotb.py +++ b/machop/chop/passes/graph/analysis/verilog/cocotb.py @@ -1,102 +1,80 @@ -import os, glob -from chop.passes.graph.utils import vf - -from .cocotb import VerificationCase -from mase_cocotb.random_test import RandomSource, RandomSink, check_results -from cocotb.runner import get_runner - - -def get_dut_parameters(graph): - parameter_map = {} - - for node in graph.fx_graph.nodes: - if node.meta["mase"].parameters["hardware"]["is_implicit"]: - continue - node_name = vf(node.name) - - for key, value in ( - node.meta["mase"].parameters["hardware"]["verilog_param"].items() - ): - if not isinstance(value, (int, float, complex, bool)): - value = '"' + value + '"' - assert ( - f"{node_name}_{key}" not in parameter_map.keys() - ), f"{node_name}_{key} already exists in the parameter map" - parameter_map[f"{node_name}_{key}"] = value - return parameter_map - - -def get_dependence_files(graph): - f = [] - for node in graph.fx_graph.nodes: - if node.meta["mase"].parameters["hardware"]["is_implicit"]: - continue - f += node.meta["mase"].parameters["hardware"]["dependence_files"] - - f = list(dict.fromkeys(f)) - return f - - -def runner(mg): - sim = os.getenv("SIM", "verilator") - - verilog_sources = get_dependence_files(mg) - for i, v in enumerate(verilog_sources): - verilog_sources[i] = os.path.relpath( - os.path.join("/workspace", "mase_components", v), os.getcwd() +#!/usr/bin/env python3 + +import os, logging + +from mase_cocotb.random_test import check_results +from mase_cocotb.runner import mase_runner + +import cocotb +from cocotb.triggers import Timer +from cocotb.triggers import FallingEdge +from cocotb.clock import Clock + +logger = logging.getLogger(__name__) + + +# DUT test specifications +class VerificationCase: + def __init__(self, iterations=1, samples=10): + self.samples = samples + self.iterations = iterations + + +@cocotb.test() +async def test_fixed_linear(dut): + """Test integer based vector mult""" + samples = 1000 + test_case = VerificationCase(samples=samples) + + # Reset cycle + await Timer(20, units="ns") + dut.rst.value = 1 + await Timer(100, units="ns") + dut.rst.value = 0 + + # Create a 10ns-period clock on port clk + clock = Clock(dut.clk, 10, units="ns") + # Start the clock + cocotb.start_soon(clock.start()) + await Timer(500, units="ns") + + # Synchronize with the clock + dut.data_in_0_valid.value = 0 + dut.data_out_0_ready.value = 1 + debug_state(dut, "Pre-clk") + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + debug_state(dut, "Pre-clk") + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + + done = False + # Set a timeout to avoid deadlock + for i in range(samples * 100): + await FallingEdge(dut.clk) + debug_state(dut, "Post-clk") + dut.data_in_0_valid.value = test_case.data_in.pre_compute() + await Timer(1, units="ns") + dut.data_out_0_ready.value = test_case.outputs.pre_compute( + dut.data_out_0_valid.value ) - # TODO: make project name variable - for v in glob.glob("./top/hardware/rtl/*.sv"): - verilog_sources.append(os.path.relpath(v, 
os.getcwd())) - - # TODO: make samples and iterations variable - tb = VerificationCase(samples=1, iterations=1) + await Timer(1, units="ns") + debug_state(dut, "Post-clk") - # TODO: work out the num - for node in mg.nodes_in: - for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): - setattr( - tb, - arg, - RandomSource( - name=arg, - samples=tb.samples * tb.iterations, - num=12324, - max_stalls=0, - ), - ) - for node in mg.nodes_out: - for result, result_info in ( - node.meta["mase"].parameters["common"]["results"].items() - ): - setattr( - tb, - result, - RandomSink( - name=result, - samples=tb.samples * tb.iterations, - num=12324, - max_stalls=0, - ), - ) - - p = get_dut_parameters(mg) - print(p) - - # set parameters - extra_args = [] - for k, v in p.items(): - extra_args.append(f"-G{k}={v}") - print(extra_args) - runner = get_runner(sim) - runner.build( - verilog_sources=verilog_sources, - hdl_toplevel="top", - build_args=extra_args, - ) - runner.test(hdl_toplevel="top", test_module="top_tb") + dut.data_in_0_valid.value, dut.data_in_0.value = test_case.data_in.compute( + dut.data_in_0_ready.value + ) + await Timer(1, units="ns") + dut.data_out_0_ready.value = test_case.outputs.compute( + dut.data_out_0_valid.value, dut.data_out_0.value + ) + debug_state(dut, "Pre-clk") + if test_case.data_in.is_empty() and test_case.outputs.is_full(): + done = True + break + assert ( + done + ), "Deadlock detected or the simulation reaches the maximum cycle limit (fixed it by adjusting the loop trip count)" -def test_verilog_analysis_pass(mg, pass_args={}): - runner(mg) - return mg, {} + check_results(test_case.outputs.data, test_case.ref) diff --git a/machop/chop/passes/graph/analysis/verilog/test_verilog.py b/machop/chop/passes/graph/analysis/verilog/test_verilog.py index f9b272c74..0194a4fcb 100644 --- a/machop/chop/passes/graph/analysis/verilog/test_verilog.py +++ b/machop/chop/passes/graph/analysis/verilog/test_verilog.py @@ -1,80 +1,102 @@ -#!/usr/bin/env python3 - -import os, logging - -from mase_cocotb.random_test import check_results -from mase_cocotb.runner import mase_runner - -import cocotb -from cocotb.triggers import Timer -from cocotb.triggers import FallingEdge -from cocotb.clock import Clock - -logger = logging.getLogger(__name__) - - -# DUT test specifications -class VerificationCase: - def __init__(self, iterations=1, samples=10): - self.samples = samples - self.iterations = iterations - - -@cocotb.test() -async def test_fixed_linear(dut): - """Test integer based vector mult""" - samples = 1000 - test_case = VerificationCase(samples=samples) - - # Reset cycle - await Timer(20, units="ns") - dut.rst.value = 1 - await Timer(100, units="ns") - dut.rst.value = 0 - - # Create a 10ns-period clock on port clk - clock = Clock(dut.clk, 10, units="ns") - # Start the clock - cocotb.start_soon(clock.start()) - await Timer(500, units="ns") - - # Synchronize with the clock - dut.data_in_0_valid.value = 0 - dut.data_out_0_ready.value = 1 - debug_state(dut, "Pre-clk") - await FallingEdge(dut.clk) - debug_state(dut, "Post-clk") - debug_state(dut, "Pre-clk") - await FallingEdge(dut.clk) - debug_state(dut, "Post-clk") - - done = False - # Set a timeout to avoid deadlock - for i in range(samples * 100): - await FallingEdge(dut.clk) - debug_state(dut, "Post-clk") - dut.data_in_0_valid.value = test_case.data_in.pre_compute() - await Timer(1, units="ns") - dut.data_out_0_ready.value = test_case.outputs.pre_compute( - dut.data_out_0_valid.value - ) - await Timer(1, units="ns") - 
debug_state(dut, "Post-clk") +import os, glob +from chop.passes.graph.utils import vf - dut.data_in_0_valid.value, dut.data_in_0.value = test_case.data_in.compute( - dut.data_in_0_ready.value - ) - await Timer(1, units="ns") - dut.data_out_0_ready.value = test_case.outputs.compute( - dut.data_out_0_valid.value, dut.data_out_0.value +from .cocotb import VerificationCase +from mase_cocotb.random_test import RandomSource, RandomSink, check_results +from cocotb.runner import get_runner + + +def get_dut_parameters(graph): + parameter_map = {} + + for node in graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + + for key, value in ( + node.meta["mase"].parameters["hardware"]["verilog_param"].items() + ): + if not isinstance(value, (int, float, complex, bool)): + value = '"' + value + '"' + assert ( + f"{node_name}_{key}" not in parameter_map.keys() + ), f"{node_name}_{key} already exists in the parameter map" + parameter_map[f"{node_name}_{key}"] = value + return parameter_map + + +def get_dependence_files(graph): + f = [] + for node in graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + f += node.meta["mase"].parameters["hardware"]["dependence_files"] + + f = list(dict.fromkeys(f)) + return f + + +def runner(mg): + sim = os.getenv("SIM", "verilator") + + verilog_sources = get_dependence_files(mg) + for i, v in enumerate(verilog_sources): + verilog_sources[i] = os.path.relpath( + os.path.join("/workspace", "mase_components", v), os.getcwd() ) - debug_state(dut, "Pre-clk") + # TODO: make project name variable + for v in glob.glob("./top/hardware/rtl/*.sv"): + verilog_sources.append(os.path.relpath(v, os.getcwd())) + + # TODO: make samples and iterations variable + tb = VerificationCase(samples=1, iterations=1) + + # TODO: work out the num + for node in mg.nodes_in: + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + setattr( + tb, + arg, + RandomSource( + name=arg, + samples=tb.samples * tb.iterations, + num=12324, + max_stalls=0, + ), + ) + for node in mg.nodes_out: + for result, result_info in ( + node.meta["mase"].parameters["common"]["results"].items() + ): + setattr( + tb, + result, + RandomSink( + name=result, + samples=tb.samples * tb.iterations, + num=12324, + max_stalls=0, + ), + ) + + p = get_dut_parameters(mg) + print(p) + + # set parameters + extra_args = [] + for k, v in p.items(): + extra_args.append(f"-G{k}={v}") + print(extra_args) + runner = get_runner(sim) + runner.build( + verilog_sources=verilog_sources, + hdl_toplevel="top", + build_args=extra_args, + ) + runner.test(hdl_toplevel="top", test_module="top_tb") - if test_case.data_in.is_empty() and test_case.outputs.is_full(): - done = True - break - assert ( - done - ), "Deadlock detected or the simulation reaches the maximum cycle limit (fixed it by adjusting the loop trip count)" - check_results(test_case.outputs.data, test_case.ref) +def test_verilog_analysis_pass(mg, pass_args={}): + runner(mg) + return mg, {} From 8076bc1bc597038edd84cf623953787a87c41334 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Sun, 28 Apr 2024 17:00:57 +0000 Subject: [PATCH 24/31] Pass syntax error in Python --- machop/chop/passes/__init__.py | 1 + .../graph/analysis/verilog/test_verilog.py | 58 ++++++++++++++----- .../verilog/test_emit_verilog_linear.py | 2 + 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/machop/chop/passes/__init__.py b/machop/chop/passes/__init__.py index 
59ddba950..517c2edbf 100644 --- a/machop/chop/passes/__init__.py +++ b/machop/chop/passes/__init__.py @@ -15,6 +15,7 @@ verify_common_metadata_analysis_pass, run_cosim_analysis_pass, get_synthesis_results, + test_verilog_analysis_pass, ) from .graph.transforms import ( prune_transform_pass, diff --git a/machop/chop/passes/graph/analysis/verilog/test_verilog.py b/machop/chop/passes/graph/analysis/verilog/test_verilog.py index 0194a4fcb..276e624b3 100644 --- a/machop/chop/passes/graph/analysis/verilog/test_verilog.py +++ b/machop/chop/passes/graph/analysis/verilog/test_verilog.py @@ -1,10 +1,15 @@ +import logging import os, glob -from chop.passes.graph.utils import vf +from pathlib import Path from .cocotb import VerificationCase -from mase_cocotb.random_test import RandomSource, RandomSink, check_results from cocotb.runner import get_runner +from chop.passes.graph.utils import vf +from mase_cocotb.random_test import RandomSource, RandomSink, check_results + +logger = logging.getLogger(__name__) + def get_dut_parameters(graph): parameter_map = {} @@ -37,17 +42,13 @@ def get_dependence_files(graph): return f -def runner(mg): +def runner(mg, project_dir, top_name): sim = os.getenv("SIM", "verilator") - verilog_sources = get_dependence_files(mg) - for i, v in enumerate(verilog_sources): - verilog_sources[i] = os.path.relpath( - os.path.join("/workspace", "mase_components", v), os.getcwd() - ) - # TODO: make project name variable - for v in glob.glob("./top/hardware/rtl/*.sv"): - verilog_sources.append(os.path.relpath(v, os.getcwd())) + # TODO: Grab internal verilog source only. Need to include HLS hardware as well. + sv_srcs = [] + for v in glob.glob(os.path.join(project_dir, "hardware", "rtl", "*.sv")): + sv_srcs.append(os.path.relpath(v, os.getcwd())) # TODO: make samples and iterations variable tb = VerificationCase(samples=1, iterations=1) @@ -65,6 +66,7 @@ def runner(mg): max_stalls=0, ), ) + for node in mg.nodes_out: for result, result_info in ( node.meta["mase"].parameters["common"]["results"].items() @@ -81,22 +83,46 @@ def runner(mg): ) p = get_dut_parameters(mg) - print(p) + # logger.debug(p) # set parameters extra_args = [] for k, v in p.items(): extra_args.append(f"-G{k}={v}") - print(extra_args) + logger.debug(extra_args) runner = get_runner(sim) runner.build( - verilog_sources=verilog_sources, - hdl_toplevel="top", + verilog_sources=sv_srcs, + hdl_toplevel=top_name, build_args=extra_args, ) runner.test(hdl_toplevel="top", test_module="top_tb") def test_verilog_analysis_pass(mg, pass_args={}): - runner(mg) + """Test the top-level hardware design using Cocotb + + :param graph: a MaseGraph + :type graph: MaseGraph + :param pass_args: this pass requires additional arguments which is explained below, defaults to {} + :type pass_args: _type_, optional + :return: return a tuple of a MaseGraph and an empty dict (no additional info to return) + :rtype: tuple(MaseGraph, Dict) + + - pass_args + - project_dir -> str : the directory of the project for cosimulation + - top_name -> str : top-level name + """ + + logger.info(f"Running hardware simulation using Cocotb") + + project_dir = ( + pass_args["project_dir"] + if "project_dir" in pass_args.keys() + else Path.home() / ".mase" / "top" + ) + top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" + logger.info(f"Project path: {project_dir}") + + runner(mg, project_dir, top_name) return mg, {} diff --git a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py 
b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py index 79c154c37..9baab3f11 100644 --- a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py +++ b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py @@ -90,6 +90,8 @@ def test_emit_verilog_linear(): mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) + mg, _ = passes.test_verilog_analysis_pass(mg) + # # For internal models, the test inputs can be directly fetched from the dataset # # using InputGenerator from chop.tools.get_input # project_dir = Path(__file__).parents[6] / "top" From 5d26837fbae680c4e23069442273d768fd1d9230 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Sun, 28 Apr 2024 17:10:08 +0000 Subject: [PATCH 25/31] refactored test pass format --- .../graph/analysis/verilog/test_verilog.py | 13 +----------- .../verilog/test_emit_verilog_linear.py | 21 ++----------------- 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/machop/chop/passes/graph/analysis/verilog/test_verilog.py b/machop/chop/passes/graph/analysis/verilog/test_verilog.py index 276e624b3..a7598a653 100644 --- a/machop/chop/passes/graph/analysis/verilog/test_verilog.py +++ b/machop/chop/passes/graph/analysis/verilog/test_verilog.py @@ -31,17 +31,6 @@ def get_dut_parameters(graph): return parameter_map -def get_dependence_files(graph): - f = [] - for node in graph.fx_graph.nodes: - if node.meta["mase"].parameters["hardware"]["is_implicit"]: - continue - f += node.meta["mase"].parameters["hardware"]["dependence_files"] - - f = list(dict.fromkeys(f)) - return f - - def runner(mg, project_dir, top_name): sim = os.getenv("SIM", "verilator") @@ -96,7 +85,7 @@ def runner(mg, project_dir, top_name): hdl_toplevel=top_name, build_args=extra_args, ) - runner.test(hdl_toplevel="top", test_module="top_tb") + runner.test(hdl_toplevel=top_name, test_module=f"{top_name}_tb") def test_verilog_analysis_pass(mg, pass_args={}): diff --git a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py index 9baab3f11..2a40e55ba 100644 --- a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py +++ b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py @@ -86,31 +86,14 @@ def test_emit_verilog_linear(): mg, _ = passes.report_node_hardware_type_analysis_pass(mg) # mg = verify_hardware_metadata_analysis_pass(mg) + # Emit Verilog sources mg, _ = passes.emit_verilog_top_transform_pass(mg) mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) + # Test Verilog sources mg, _ = passes.test_verilog_analysis_pass(mg) - # # For internal models, the test inputs can be directly fetched from the dataset - # # using InputGenerator from chop.tools.get_input - # project_dir = Path(__file__).parents[6] / "top" - # print(f"project_dir {project_dir}") - # cosim_config = {"test_inputs": [x], "trans_num": 1, "project_dir": project_dir} - # # mg = passes.emit_verilog_tb_transform_pass(mg, pass_args=cosim_config) - - # # Run simulation pass if Vivado available - # try: - # execute_cli("xelab -h", log_output=False) - # has_verilog = True - # # mg = get_synthesis_results("top", mg, target="xcu250-figd2104-2L-e", output_dir=".") - # except: - # has_verilog = False - # print(f"Vivado not available") - - # if has_verilog: - # mg = passes.run_cosim_analysis_pass(mg) - if __name__ == "__main__": test_emit_verilog_linear() From 
cfb256e896119ccdee516e3e2ef6fb73cc67ff1a Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Sun, 28 Apr 2024 17:22:52 +0000 Subject: [PATCH 26/31] format quantize PR https://github.com/DeepWok/mase/pull/173 --- .../passes/graph/analysis/verilog/cocotb.py | 4 +- .../quantize/quant_parsers/__init__.py | 7 +- .../quantize/quant_parsers/parse_q_config.py | 5 +- .../quant_parsers/update_node_meta.py | 184 +++++++++++++----- .../graph/transforms/quantize/quantize.py | 31 ++- 5 files changed, 176 insertions(+), 55 deletions(-) diff --git a/machop/chop/passes/graph/analysis/verilog/cocotb.py b/machop/chop/passes/graph/analysis/verilog/cocotb.py index f9b272c74..45743b2d6 100644 --- a/machop/chop/passes/graph/analysis/verilog/cocotb.py +++ b/machop/chop/passes/graph/analysis/verilog/cocotb.py @@ -21,8 +21,8 @@ def __init__(self, iterations=1, samples=10): @cocotb.test() -async def test_fixed_linear(dut): - """Test integer based vector mult""" +async def test_top(dut): + """Test top-level model hardware design""" samples = 1000 test_case = VerificationCase(samples=samples) diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py index 9668d2a7b..c5205a89a 100644 --- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py +++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/__init__.py @@ -1,4 +1,9 @@ # from .parse_quant_config import parse_node_q_config from .parse_q_config import parse_node_q_config + # from .update_node_meta import relink_node_meta, update_quant_meta_param -from .update_node_meta import relink_node_meta, update_q_meta_param, infer_result_dtype_and_precision \ No newline at end of file +from .update_node_meta import ( + relink_node_meta, + update_q_meta_param, + infer_result_dtype_and_precision, +) diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py index 3caa6f839..93491fb75 100644 --- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py +++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/parse_q_config.py @@ -10,6 +10,7 @@ } """ + def get_q_op_entries(q_name: str, mase_op: str): match q_name: case "fixed": @@ -18,7 +19,9 @@ def get_q_op_entries(q_name: str, mase_op: str): raise ValueError(f"Unknown quantization arithmetic name: {q_name}") if mase_op not in op_entries: - raise ValueError(f"Unknown MASE operation name: {mase_op} for quantization arithmetic: {q_name}") + raise ValueError( + f"Unknown MASE operation name: {mase_op} for quantization arithmetic: {q_name}" + ) return op_entries[mase_op] diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py index a5dc9cd03..985a3c4e9 100644 --- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py +++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py @@ -75,8 +75,12 @@ def update_node_meta_param_fixed(node, q_config): raise RuntimeError( f"Optional argument {arg_name} found in node meta, but not found in q_config: {q_config}" ) - node.meta["mase"].parameters["common"]["args"][arg_name]["type"] = "fixed" - node.meta["mase"].parameters["common"]["args"][arg_name]["precision"] = [ + node.meta["mase"].parameters["common"]["args"][arg_name][ + "type" + ] = "fixed" + 
node.meta["mase"].parameters["common"]["args"][arg_name][ + "precision" + ] = [ q_config[f"{operand_name}_width"], q_config[f"{operand_name}_frac_width"], ] @@ -96,8 +100,10 @@ def update_q_meta_param(node, config: dict): case _: raise ValueError(f"Unsupported quantization arithmetic name: {q_arith}") + from torch.fx import Node + def find_next_compute_node(node: Node): for n in node.users: if get_mase_type(n) in ["module_related_func", "builtin_func"]: @@ -106,6 +112,7 @@ def find_next_compute_node(node: Node): return find_next_compute_node(n) return None, None + def find_prev_compute_node(node: Node): for n in node.all_input_nodes: if get_mase_type(n) in ["module_related_func", "builtin_func"]: @@ -114,6 +121,7 @@ def find_prev_compute_node(node: Node): return find_prev_compute_node(n) return None, None + def infer_result_dtype_and_precision(node: Node): """ ```text @@ -129,7 +137,9 @@ def infer_result_dtype_and_precision(node: Node): # input node input_node, next_node = find_next_compute_node(node) if input_node is None: - logger.warning(f"Failed to find next module_related_func node for input node {node.name}. Check if the graph contains module_related_func") + logger.warning( + f"Failed to find next module_related_func node for input node {node.name}. Check if the graph contains module_related_func" + ) return i = 0 for n in next_node.all_input_nodes: @@ -140,14 +150,24 @@ def infer_result_dtype_and_precision(node: Node): arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key] for result in node.meta["mase"].parameters["common"]["results"]: - node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + node.meta["mase"].parameters["common"]["results"][result]["type"] = ( + arg_value["type"] + ) + node.meta["mase"].parameters["common"]["results"][result]["precision"] = ( + arg_value["precision"] + ) for arg in node.meta["mase"].parameters["common"]["args"]: - node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] - - logger.debug(f"Inferred arg & result dtype and precision for input node `{node.name}` using `{next_node.name}`") + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[ + "type" + ] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = ( + arg_value["precision"] + ) + + logger.debug( + f"Inferred arg & result dtype and precision for input node `{node.name}` using `{next_node.name}`" + ) elif get_mase_type(node) in ["module_related_func", "builtin_func"]: input_node, next_node = find_next_compute_node(node) @@ -157,23 +177,42 @@ def infer_result_dtype_and_precision(node: Node): max_dtype = None max_bitwidth = 0 for arg in node.meta["mase"].parameters["common"]["args"]: - if not isinstance(node.meta["mase"].parameters["common"]["args"][arg], dict): + if not isinstance( + node.meta["mase"].parameters["common"]["args"][arg], dict + ): continue - if not "precision" in node.meta["mase"].parameters["common"]["args"][arg]: + if ( + not "precision" + in node.meta["mase"].parameters["common"]["args"][arg] + ): continue - cur_width = node.meta["mase"].parameters["common"]["args"][arg]["precision"][0] + cur_width = node.meta["mase"].parameters["common"]["args"][arg][ + "precision" + ][0] if cur_width > max_bitwidth: max_bitwidth = cur_width - max_precision = 
node.meta["mase"].parameters["common"]["args"][arg]["precision"] - max_dtype = node.meta["mase"].parameters["common"]["args"][arg]["type"] + max_precision = node.meta["mase"].parameters["common"]["args"][arg][ + "precision" + ] + max_dtype = node.meta["mase"].parameters["common"]["args"][arg][ + "type" + ] if max_precision is None: - raise RuntimeError(f"Failed to infer dtype and precision for module_related_func node {node.name}") + raise RuntimeError( + f"Failed to infer dtype and precision for module_related_func node {node.name}" + ) for result in node.meta["mase"].parameters["common"]["results"]: - node.meta["mase"].parameters["common"]["results"][result]["type"] = max_dtype - node.meta["mase"].parameters["common"]["results"][result]["precision"] = max_precision - logger.debug(f"Inferred result dtype and precision for module_related_func node `{node.name}` using its args") + node.meta["mase"].parameters["common"]["results"][result][ + "type" + ] = max_dtype + node.meta["mase"].parameters["common"]["results"][result][ + "precision" + ] = max_precision + logger.debug( + f"Inferred result dtype and precision for module_related_func node `{node.name}` using its args" + ) else: # use next compute node's args to infer dtype and precision i = 0 @@ -181,13 +220,21 @@ def infer_result_dtype_and_precision(node: Node): if n is input_node: break i += 1 - arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[ + i + ] arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key] for result in node.meta["mase"].parameters["common"]["results"]: - node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] - logger.debug(f"Inferred result dtype and precision for module_related_func node `{node.name}` using `{next_node.name}`") + node.meta["mase"].parameters["common"]["results"][result]["type"] = ( + arg_value["type"] + ) + node.meta["mase"].parameters["common"]["results"][result][ + "precision" + ] = arg_value["precision"] + logger.debug( + f"Inferred result dtype and precision for module_related_func node `{node.name}` using `{next_node.name}`" + ) elif get_mase_type(node) == "implicit_func": input_node, next_node = find_next_compute_node(node) @@ -199,21 +246,38 @@ def infer_result_dtype_and_precision(node: Node): if n is input_node: break i += 1 - arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_key = list(next_node.meta["mase"].parameters["common"]["args"].keys())[ + i + ] arg_value = next_node.meta["mase"].parameters["common"]["args"][arg_key] for result in node.meta["mase"].parameters["common"]["results"]: - node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + node.meta["mase"].parameters["common"]["results"][result]["type"] = ( + arg_value["type"] + ) + node.meta["mase"].parameters["common"]["results"][result][ + "precision" + ] = arg_value["precision"] for arg in node.meta["mase"].parameters["common"]["args"]: - if not isinstance(node.meta["mase"].parameters["common"]["args"][arg], dict): + if not isinstance( + node.meta["mase"].parameters["common"]["args"][arg], dict + ): continue - if not "precision" in node.meta["mase"].parameters["common"]["args"][arg]: + if ( + not "precision" + in 
node.meta["mase"].parameters["common"]["args"][arg] + ): continue - node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] - logger.debug(f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{next_node.name}`") + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[ + "type" + ] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = ( + arg_value["precision"] + ) + logger.debug( + f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{next_node.name}`" + ) elif prev_node is not None: i = 0 @@ -221,20 +285,34 @@ def infer_result_dtype_and_precision(node: Node): if n is user_node: break i += 1 - arg_key = list(prev_node.meta["mase"].parameters["common"]["args"].keys())[i] + arg_key = list(prev_node.meta["mase"].parameters["common"]["args"].keys())[ + i + ] arg_value = prev_node.meta["mase"].parameters["common"]["args"][arg_key] for result in node.meta["mase"].parameters["common"]["results"]: - node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + node.meta["mase"].parameters["common"]["results"][result]["type"] = ( + arg_value["type"] + ) + node.meta["mase"].parameters["common"]["results"][result][ + "precision" + ] = arg_value["precision"] for arg in node.meta["mase"].parameters["common"]["args"]: - node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] - logger.debug(f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{prev_node.name}`") + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[ + "type" + ] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = ( + arg_value["precision"] + ) + logger.debug( + f"Inferred arg & result dtype and precision for implicit_func node `{node.name}` using `{prev_node.name}`" + ) else: - raise RuntimeError(f"Failed to infer dtype and precision for implicit_func node {node.name} as it has no input nodes or users of type `module_related_func`") + raise RuntimeError( + f"Failed to infer dtype and precision for implicit_func node {node.name} as it has no input nodes or users of type `module_related_func`" + ) elif get_mase_type(node) == "output": # output node @@ -242,7 +320,9 @@ def infer_result_dtype_and_precision(node: Node): user_node, prev_node = find_prev_compute_node(node) if prev_node is None: - raise RuntimeError(f"Failed to find prev module_related_func node for output node {node.name}") + raise RuntimeError( + f"Failed to find prev module_related_func node for output node {node.name}" + ) max_precision = None max_dtype = None @@ -258,16 +338,26 @@ def infer_result_dtype_and_precision(node: Node): arg_value = prev_node.meta["mase"].parameters["common"]["args"][arg_key] for result in node.meta["mase"].parameters["common"]["results"]: - node.meta["mase"].parameters["common"]["results"][result]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["results"][result]["precision"] = arg_value["precision"] + node.meta["mase"].parameters["common"]["results"][result]["type"] = ( + arg_value["type"] + ) + node.meta["mase"].parameters["common"]["results"][result]["precision"] = ( + arg_value["precision"] + ) for arg 
in node.meta["mase"].parameters["common"]["args"]: - node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value["type"] - node.meta["mase"].parameters["common"]["args"][arg]["precision"] = arg_value["precision"] - - logger.debug(f"Inferred dtype and precision for output node `{node.name}` using `{prev_node.name}`") + node.meta["mase"].parameters["common"]["args"][arg]["type"] = arg_value[ + "type" + ] + node.meta["mase"].parameters["common"]["args"][arg]["precision"] = ( + arg_value["precision"] + ) + + logger.debug( + f"Inferred dtype and precision for output node `{node.name}` using `{prev_node.name}`" + ) else: - raise RuntimeError(f"Unsupported node type {get_mase_type(node)} for inferring dtype and precision") - - + raise RuntimeError( + f"Unsupported node type {get_mase_type(node)} for inferring dtype and precision" + ) diff --git a/machop/chop/passes/graph/transforms/quantize/quantize.py b/machop/chop/passes/graph/transforms/quantize/quantize.py index 312d6d820..febb27e50 100644 --- a/machop/chop/passes/graph/transforms/quantize/quantize.py +++ b/machop/chop/passes/graph/transforms/quantize/quantize.py @@ -15,7 +15,12 @@ ) from .modify import create_new_fn, create_new_module -from .quant_parsers import parse_node_q_config, relink_node_meta, update_q_meta_param, infer_result_dtype_and_precision +from .quant_parsers import ( + parse_node_q_config, + relink_node_meta, + update_q_meta_param, + infer_result_dtype_and_precision, +) from .summary import graph_iterator_compare_nodes, graph_iterator_node_histogram logger = logging.getLogger(__name__) @@ -99,7 +104,13 @@ def graph_iterator_quantize_by_type(graph, config: dict): graph.fx_graph.erase_node(node) for node in graph.fx_graph.nodes: - if get_mase_type(node) in ["module_related_func", "builtin_func", "output", "placeholder", "implicit_func"]: + if get_mase_type(node) in [ + "module_related_func", + "builtin_func", + "output", + "placeholder", + "implicit_func", + ]: infer_result_dtype_and_precision(node) return graph @@ -155,7 +166,13 @@ def graph_iterator_quantize_by_name(graph, config: dict): "Unsupported node type for quantisation: {}".format(get_mase_type(node)) ) for node in graph.fx_graph.nodes: - if get_mase_type(node) in ["module_related_func", "builtin_func", "output", "placeholder", "implicit_func"]: + if get_mase_type(node) in [ + "module_related_func", + "builtin_func", + "output", + "placeholder", + "implicit_func", + ]: infer_result_dtype_and_precision(node) return graph @@ -200,7 +217,13 @@ def graph_iterator_quantize_by_regex_name(graph, config: dict): "Unsupported node type for quantisation:{}".format(get_mase_type(node)) ) for node in graph.fx_graph.nodes: - if get_mase_type(node) in ["module_related_func", "builtin_func", "output", "placeholder", "implicit_func"]: + if get_mase_type(node) in [ + "module_related_func", + "builtin_func", + "output", + "placeholder", + "implicit_func", + ]: infer_result_dtype_and_precision(node) return graph From d825cbe11d3307a6dc0d44b81b06199003a27439 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Sun, 28 Apr 2024 17:28:11 +0000 Subject: [PATCH 27/31] Added bit truncation in bram param to avoid verilator warnings --- machop/chop/passes/graph/transforms/verilog/emit_bram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/machop/chop/passes/graph/transforms/verilog/emit_bram.py b/machop/chop/passes/graph/transforms/verilog/emit_bram.py index ee03b3412..19a17f36d 100644 --- a/machop/chop/passes/graph/transforms/verilog/emit_bram.py +++ 
b/machop/chop/passes/graph/transforms/verilog/emit_bram.py @@ -146,7 +146,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): if (rst) counter <= 0; else begin if (data_out_ready) begin - if (counter == OUT_DEPTH - 1) counter <= 0; + if (counter == COUNTER_WIDTH'(OUT_DEPTH) - 1) counter <= 0; else counter <= counter + 1; end end From 3a23cb2a762376ea46d395cba2c250d682ffa4cd Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Sun, 28 Apr 2024 17:38:14 +0000 Subject: [PATCH 28/31] Fixed bitwidth error (this is temporary for the version of casting inside hardware components --- .../linear/rtl/fixed_linear.sv | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/machop/mase_components/linear/rtl/fixed_linear.sv b/machop/mase_components/linear/rtl/fixed_linear.sv index 584733bc1..ce954dfa0 100644 --- a/machop/mase_components/linear/rtl/fixed_linear.sv +++ b/machop/mase_components/linear/rtl/fixed_linear.sv @@ -60,9 +60,9 @@ module fixed_linear #( /* verilator lint_on UNUSEDSIGNAL */ output bias_ready, - output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_1-1:0], - output data_out_0_valid, - input data_out_0_ready + output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output data_out_0_valid, + input data_out_0_ready ); localparam FDP_WIDTH = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( @@ -71,6 +71,7 @@ module fixed_linear #( localparam ACC_WIDTH = FDP_WIDTH + $clog2( DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1 ); + logic [ACC_WIDTH-1:0] data_out_buff[DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0]; logic fdp_join_valid, fdp_join_ready; join2 #() fdp_join_inst ( @@ -170,12 +171,12 @@ module fixed_linear #( ); for (genvar i = 0; i < BIAS_PARALLELISM_DIM_0; i = i + 1) begin : add_bias - logic [DATA_OUT_0_PRECISION_0-1:0] add; + logic [ACC_WIDTH-1:0] add; assign add = $signed(acc_data_out[i]) + $signed(bias_sext[i]); /* verilator lint_off UNUSEDSIGNAL */ logic dout_valid; skid_buffer #( - .DATA_WIDTH(DATA_OUT_0_PRECISION_0) + .DATA_WIDTH(ACC_WIDTH) ) register_slice ( .clk (clk), .rst (rst), @@ -184,7 +185,7 @@ module fixed_linear #( .data_in (add), .data_out_valid(dout_valid), .data_out_ready(data_out_0_ready), - .data_out (data_out_0[i]) + .data_out (data_out_buff[i]) ); end assign data_out_0_valid = add_bias[0].dout_valid; @@ -192,8 +193,21 @@ module fixed_linear #( end else begin assign acc_ready = data_out_0_ready; assign data_out_0_valid = linear[0].acc_data_out_valid; - assign data_out_0 = acc_data_out; + assign data_out_buff = acc_data_out; assign bias_ready = 1; end + + fixed_rounding #( + .IN_SIZE(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1), + .IN_WIDTH(ACC_WIDTH), + .IN_FRAC_WIDTH(DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1), + .OUT_WIDTH(DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) + ) fr_inst ( + .data_in (data_out_buff), + .data_out(data_out_0) + ); + + endmodule From 80c671fe4f53480bf31b72159e74c1487e822082 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Sun, 28 Apr 2024 21:18:53 +0000 Subject: [PATCH 29/31] Fixed minor parallelism parameter shapes --- .../analysis/add_metadata/add_hardware_metadata.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index 890e3a251..0150791e1 100644 --- 
a/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
+++ b/machop/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
@@ -49,10 +49,16 @@ def add_component_source(node):
     node.meta["mase"]["hardware"]["parallelism"] = {}
     args = node.meta["mase"]["common"]["args"]
     for arg, arg_info in args.items():
-        node.meta["mase"]["hardware"]["parallelism"][arg] = {0: 1, 1: 1, 2: 1}
+        if isinstance(arg_info, dict):
+            node.meta["mase"]["hardware"]["parallelism"][arg] = [
+                1 for _ in range(len(arg_info["shape"]))
+            ]
+
     results = node.meta["mase"]["common"]["results"]
     for result, result_info in results.items():
-        node.meta["mase"]["hardware"]["parallelism"][result] = {0: 1, 1: 1, 2: 1}
+        node.meta["mase"]["hardware"]["parallelism"][result] = [
+            1 for _ in range(len(result_info["shape"]))
+        ]
 
     # Current only support on-chip parameters
     args = node.meta["mase"]["common"]["args"]

From 8b5309c9d562eea5991fc9540798e588dac0191b Mon Sep 17 00:00:00 2001
From: Jianyi Cheng
Date: Sun, 28 Apr 2024 22:34:53 +0000
Subject: [PATCH 30/31] Get working flow for hardware testing - but need to check hardware results

---
 .../graph/analysis/verilog/test_verilog.py | 276 +++++++++++++++---
 1 file changed, 239 insertions(+), 37 deletions(-)

diff --git a/machop/chop/passes/graph/analysis/verilog/test_verilog.py b/machop/chop/passes/graph/analysis/verilog/test_verilog.py
index a7598a653..fd3d4e39c 100644
--- a/machop/chop/passes/graph/analysis/verilog/test_verilog.py
+++ b/machop/chop/passes/graph/analysis/verilog/test_verilog.py
@@ -1,15 +1,229 @@
-import logging
-import os, glob
+import logging, toml, os, glob, math
 from pathlib import Path
+import torch
 
-from .cocotb import VerificationCase
+import cocotb
 from cocotb.runner import get_runner
+from cocotb.triggers import Timer
+from cocotb.triggers import FallingEdge
+from cocotb.clock import Clock
 
 from chop.passes.graph.utils import vf
 from mase_cocotb.random_test import RandomSource, RandomSink, check_results
 
 logger = logging.getLogger(__name__)
 
+# =============================================================================
+# DUT test specifications
+# =============================================================================
+
+
+def hardware_reshape(input_data, input_shape, tiling):
+    """
+    Apply 2D tiling. TODO: For higher dimensions, just flatten it in time.
+    """
+
+    assert len(input_shape) == 2, "Default hardware test bench only supports 2D inputs"
+
+    row_size = int(math.ceil(input_shape[0] / tiling[0]))
+    col_size = int(math.ceil(input_shape[1] / tiling[1]))
+    output_data = [
+        [0 for _ in range(tiling[1] * tiling[0])] for _ in range(row_size * col_size)
+    ]
+    for i in range(row_size):
+        for j in range(col_size):
+            for ii in range(0, tiling[0]):
+                for jj in range(0, tiling[1]):
+                    rowi = i * tiling[0] + ii
+                    coli = j * tiling[1] + jj
+                    if rowi < input_shape[0] and coli < input_shape[1]:
+                        output_data[i * row_size + j][ii * tiling[1] + jj] = int(
+                            input_data[rowi][coli]
+                        )
+
+    return output_data
+
+
+class VerificationCase:
+    # TODO: sample > 1 needs to be added
+    def __init__(self, samples=1):
+        self.samples = samples
+
+    def generate_tv(self, mg):
+        """
+        Generate test vectors and emit them to ~/.mase/tv.toml
+        """
+
+        # Generate random inputs
+        test_inputs = {}
+        # TODO: here we just enumerate the inputs of the input nodes - which may be
+        # order insensitive and require manual connection when adding the graph to
+        # a system.
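+        # The inputs are then renamed positionally (data_in_0, data_in_1, ...);
+        # this numbering is assumed to match the port naming of the emitted
+        # top-level module.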
+        name_idx = 0
+        for node in mg.nodes_in:
+            for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
+                if "data_in" in arg:
+                    test_inputs[f"data_in_{name_idx}"] = torch.randint(
+                        32, arg_info["shape"]
+                    )
+                    name_idx += 1
+        logger.debug(test_inputs)
+
+        # Get software results
+        y = mg.model(*list(test_inputs.values()))
+
+        output_toml = {}
+        output_toml["samples"] = 1
+
+        # Reshape values for hardware testing
+        # TODO: assume 2D inputs
+        reshaped_inputs = {}
+        name_idx = 0
+        for node in mg.nodes_in:
+            for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
+                if "data_in" in arg:
+
+                    # By default: the data is passed column by column
+                    reshaped_inputs[f"data_in_{name_idx}"] = hardware_reshape(
+                        test_inputs[f"data_in_{name_idx}"],
+                        arg_info["shape"],
+                        node.meta["mase"].parameters["hardware"]["parallelism"][arg],
+                    )
+                    name_idx += 1
+
+        output_toml["inputs"] = reshaped_inputs
+
+        assert len(mg.nodes_out) == 1, "Expected the model to have exactly one output!"
+        reshaped_y = hardware_reshape(
+            y,
+            mg.nodes_out[0]
+            .meta["mase"]
+            .parameters["common"]["results"]["data_out_0"]["shape"],
+            mg.nodes_out[0]
+            .meta["mase"]
+            .parameters["hardware"]["parallelism"]["data_out_0"],
+        )
+
+        output_toml["outputs"] = {"data_out_0": reshaped_y}
+
+        home = Path.home()
+        Path(os.path.join(home, ".mase")).mkdir(parents=True, exist_ok=True)
+        fname = os.path.join(home, ".mase", "tv.toml")
+        assert not os.path.isfile(
+            fname
+        ), f"Cannot create a temporary toml for testing data - {fname} already exists"
+        with open(fname, "w") as toml_file:
+            toml.dump(output_toml, toml_file)
+
+        logger.debug(f"Test data saved to {fname}")
+
+    def load_tv(self, fname=""):
+        home = Path.home()
+        fname = os.path.join(home, ".mase", "tv.toml")
+        assert os.path.isfile(
+            fname
+        ), f"Cannot find the temporary toml for testing data - {fname}"
+        with open(fname, "r") as f:
+            input_toml = toml.load(f)
+
+        self.samples = input_toml["samples"]
+
+        for val, values in input_toml["inputs"].items():
+            setattr(
+                self,
+                val,
+                RandomSource(
+                    name=val,
+                    samples=len(values),
+                    num=len(values[0]),
+                    max_stalls=0,
+                ),
+            )
+            source = getattr(self, val)
+            source.data = values
+
+        for val, values in input_toml["outputs"].items():
+            setattr(
+                self,
+                val,
+                RandomSink(
+                    name=val,
+                    samples=len(values),
+                    num=len(values[0]),
+                    max_stalls=0,
+                ),
+            )
+            self.ref = values
+
+        os.remove(fname)
+        logger.debug(f"Test data loaded from {fname}")
+
+
+class TestBehavior:
+    async def test_bench_behavior(dut):
+        """Test top-level model hardware design (default behavior)"""
+        test_case = VerificationCase()
+        test_case.load_tv()
+
+        # Reset cycle
+        await Timer(20, units="ns")
+        dut.rst.value = 1
+        await Timer(100, units="ns")
+        dut.rst.value = 0
+
+        # Create a 10ns-period clock on port clk
+        clock = Clock(dut.clk, 10, units="ns")
+        # Start the clock
+        cocotb.start_soon(clock.start())
+        await Timer(500, units="ns")
+
+        # Synchronize with the clock
+        dut.data_in_0_valid.value = 0
+        dut.data_out_0_ready.value = 1
+        await FallingEdge(dut.clk)
+        await FallingEdge(dut.clk)
+
+        done = False
+        # Set a timeout to avoid deadlock
+        for i in range(test_case.samples * 100):
+            await FallingEdge(dut.clk)
+
+            dut.data_in_0_valid.value = test_case.data_in_0.pre_compute()
+            await Timer(1, units="ns")
+
+            dut.data_out_0_ready.value = test_case.data_out_0.pre_compute(
+                dut.data_out_0_valid.value
+            )
+            await Timer(1, units="ns")
+
+            dut.data_in_0_valid.value, dut.data_in_0.value = (
+                test_case.data_in_0.compute(dut.data_in_0_ready.value)
+            )
+            await Timer(1, units="ns")
+
+            dut.data_out_0_ready.value = test_case.data_out_0.compute(
+                dut.data_out_0_valid.value, dut.data_out_0.value
+            )
+
+            if test_case.data_in_0.is_empty() and test_case.data_out_0.is_full():
+                done = True
+                break
+        assert (
+            done
+        ), "Deadlock detected or the simulation reached the maximum cycle limit (fix this by adjusting the loop trip count)"
+
+        check_results(test_case.data_out_0.data, test_case.ref)
+
+
+# =============================================================================
+# Cocotb interface setup
+# =============================================================================
+
+
+@cocotb.test()
+async def test_top(dut):
+    await TestBehavior.test_bench_behavior(dut)
+
 
 def get_dut_parameters(graph):
     parameter_map = {}
@@ -39,38 +253,6 @@ def runner(mg, project_dir, top_name):
     for v in glob.glob(os.path.join(project_dir, "hardware", "rtl", "*.sv")):
         sv_srcs.append(os.path.relpath(v, os.getcwd()))
 
-    # TODO: make samples and iterations variable
-    tb = VerificationCase(samples=1, iterations=1)
-
-    # TODO: work out the num
-    for node in mg.nodes_in:
-        for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
-            setattr(
-                tb,
-                arg,
-                RandomSource(
-                    name=arg,
-                    samples=tb.samples * tb.iterations,
-                    num=12324,
-                    max_stalls=0,
-                ),
-            )
-
-    for node in mg.nodes_out:
-        for result, result_info in (
-            node.meta["mase"].parameters["common"]["results"].items()
-        ):
-            setattr(
-                tb,
-                result,
-                RandomSink(
-                    name=result,
-                    samples=tb.samples * tb.iterations,
-                    num=12324,
-                    max_stalls=0,
-                ),
-            )
-
     p = get_dut_parameters(mg)
     # logger.debug(p)
 
@@ -85,7 +267,11 @@ def runner(mg, project_dir, top_name):
         hdl_toplevel=top_name,
         build_args=extra_args,
     )
-    runner.test(hdl_toplevel=top_name, test_module=f"{top_name}_tb")
+
+    runner.test(
+        hdl_toplevel=top_name,
+        test_module="chop.passes.graph.analysis.verilog.test_verilog",
+    )
 
 
 def test_verilog_analysis_pass(mg, pass_args={}):
@@ -101,9 +287,13 @@ def test_verilog_analysis_pass(mg, pass_args={}):
     - pass_args
         - project_dir -> str : the directory of the project for cosimulation
         - top_name -> str : top-level name
+        - samples -> int : the number of test inputs, samples = 1 by default
+        - test_bench -> str : the test bench behavior specified by the user, which runs end-to-end simulation by default
+        - preprocess -> str : preprocess of IO for testing, which generates random inputs by default
     """
 
     logger.info(f"Running hardware simulation using Cocotb")
+    logger.debug(f"test verilog pass pass_args = {pass_args}")
 
     project_dir = (
         pass_args["project_dir"]
@@ -111,7 +301,19 @@ def test_verilog_analysis_pass(mg, pass_args={}):
         else Path.home() / ".mase" / "top"
     )
     top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top"
-    logger.info(f"Project path: {project_dir}")
+    samples = pass_args["samples"] if "samples" in pass_args.keys() else 1
+
+    # TODO: Create a global variable traced by pass ID. This is bad...
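+    # NOTE (assumption): the cocotb runner re-imports this module inside the
+    # simulator process, so a value stashed in globals() here is unlikely to be
+    # visible to the running test; the test instead reloads its vectors from
+    # ~/.mase/tv.toml, written by test_case.generate_tv(mg) below.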
+ test_case = VerificationCase(samples) + globals()["test_verilog_analysis_pass_tc"] = test_case + print(globals()) + + if "preprocess" in pass_args.keys(): + test_case.preprocess = pass_args["preprocess"] + if "test_bench" in pass_args.keys(): + test_case.test_bench_behavior = pass_args["test_bench"] + + test_case.generate_tv(mg) runner(mg, project_dir, top_name) return mg, {} From 047f27b9b156a7575b416c1156ea571c229cf9e8 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Mon, 29 Apr 2024 17:09:48 +0100 Subject: [PATCH 31/31] Mlp quantization error (#178) * Created test case with errors * cache quantized weight after the first inference * Removed quantization check --------- Co-authored-by: Cheng Zhang --- .../quant_parsers/q_op_entries/fixed.py | 2 +- .../quantize/quantized_modules/linear.py | 42 ++++++++++--------- machop/configs/tests/quantize/fixed.toml | 1 + .../verilog/test_emit_verilog_linear.py | 6 ++- 4 files changed, 29 insertions(+), 22 deletions(-) diff --git a/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py index f1c8c8da1..63d8fa0bc 100644 --- a/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py +++ b/machop/chop/passes/graph/transforms/quantize/quant_parsers/q_op_entries/fixed.py @@ -41,7 +41,7 @@ "weight_width", "weight_frac_width", ), - "optional": ("bypass", "bias_width", "bias_frac_width"), + "optional": ("bypass", "cache_quantized_weight", "bias_width", "bias_frac_width"), }, "matmul": { "required": ( diff --git a/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py b/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py index 710396e49..3411fe218 100644 --- a/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py +++ b/machop/chop/passes/graph/transforms/quantize/quantized_modules/linear.py @@ -116,25 +116,29 @@ def __init__( integer_quantizer, width=b_width, frac_width=b_frac_width ) - # def get_output_bitwidth(self): - # config = self.config - # w_width, w_frac = config["weight_width"], config["weight_frac_width"] - # x_width, x_frac = config["data_in_width"], config["data_in_frac_width"] - # bias_width = config["bias_width"] - - # ops = self.in_features - # product_width = w_width + x_width - # product_frac_width = w_frac + x_frac - # # *: + 1 for bias - # output_width = max(bias_width, product_width + ceil(log2(ops))) + 1 - # output_frac_width = product_frac_width - - # o_bitwidth = {} - # o_bitwidth["data_out_width"] = output_width - # o_bitwidth["data_out_frac_width"] = output_frac_width - # # o_bitwidth["product_width"] = product_width - # # o_bitwidth["product_frac_width"] = product_frac_width - # return o_bitwidth + self.quantized_weight_is_cached = False + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + # if bypss, there is no quantization + return F.linear(x, self.weight, self.bias) + else: + x = self.x_quantizer(x) + if self.config.get("cache_quantized_weight", False): + if not self.quantized_weight_is_cached: + w = self.w_quantizer(self.weight) + self.weight.copy_(w) + if self.bias is not None: + bias = self.b_quantizer(self.bias) + self.bias.copy_(bias) + self.quantized_weight_is_cached = True + else: + w = self.weight + bias = self.bias + else: + w = self.w_quantizer(self.weight) + bias = self.b_quantizer(self.bias) if self.bias is not None else None + return F.linear(x, w, bias) class LinearMinifloatDenorm(_LinearBase): diff --git 
a/machop/configs/tests/quantize/fixed.toml b/machop/configs/tests/quantize/fixed.toml index ad4c79007..cf1eab877 100644 --- a/machop/configs/tests/quantize/fixed.toml +++ b/machop/configs/tests/quantize/fixed.toml @@ -7,6 +7,7 @@ dataset="toy-tiny" [passes.quantize.default.config] name="fixed" + cache_quantized_weight=true data_in_width=8 data_in_frac_width=3 weight_width=8 diff --git a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py index 2a40e55ba..93cb7d1cb 100644 --- a/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py +++ b/machop/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # This example converts a simple MLP model to Verilog import os, sys, logging -import toml +import toml, math import torch import torch.nn as nn @@ -75,7 +75,9 @@ def test_emit_verilog_linear(): # load toml config file with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] - mg, _ = passes.quantize_transform_pass(mg, quan_args) + with torch.no_grad(): + mg, _ = passes.quantize_transform_pass(mg, quan_args) + mg.model(dummy_in["x"]) # inspect the graph metadata # mg, _ = passes.report_node_meta_param_analysis_pass(mg)