diff --git a/pyproject.toml b/pyproject.toml index 9c0f1289..83a52788 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "click", "amaranth==0.4.0.*", "marshmallow_dataclass", - "numexpr", + "simpleeval", "typing_extensions", "marshmallow-dataclass[enum]", "typing_extensions", diff --git a/tests/tests_kpm/test_kpm_import.py b/tests/tests_kpm/test_kpm_import.py index 6cf8d499..fb9c8fd1 100644 --- a/tests/tests_kpm/test_kpm_import.py +++ b/tests/tests_kpm/test_kpm_import.py @@ -378,7 +378,6 @@ def find_node_with_interface_id(design_graphs, iface_id, conn_id): node_names.append(find_node_with_interface_id(design_graphs, conn["to"], conn["id"])) node_names.append(find_node_with_interface_id(design_graphs, conn["from"], conn["id"])) node_occurrence_dict = {item: node_names.count(item) for item in node_names} - breakpoint() conn_dict = { "Constant": 2, "External Input": 2, diff --git a/topwrap/hdl_parsers_utils.py b/topwrap/hdl_parsers_utils.py index 76ad0f82..224b697e 100644 --- a/topwrap/hdl_parsers_utils.py +++ b/topwrap/hdl_parsers_utils.py @@ -4,7 +4,7 @@ from enum import Enum from logging import warning -import numexpr as ex +from simpleeval import SimpleEval from topwrap.amaranth_helpers import DIR_IN, DIR_INOUT, DIR_OUT @@ -26,7 +26,7 @@ class PortDefinition: direction: PortDirection -def _eval_param(val, params: dict): +def _eval_param(val, params: dict, simpleeval_instance: SimpleEval): """Function used to calculate parameter value. It is used for evaluating CONCAT and REPL_CONCAT in resolve_ops()""" if isinstance(val, int): @@ -34,7 +34,7 @@ def _eval_param(val, params: dict): if isinstance(val, dict) and val.keys() == {"value", "width"}: return val if isinstance(val, str): - return _eval_param(params[val], params) + return _eval_param(params[val], params, simpleeval_instance) elif val["__class__"] == "HdlValueInt": value = int(val["val"], val["base"]) @@ -45,8 +45,8 @@ def _eval_param(val, params: dict): elif val["__class__"] == "HdlOp": if val["fn"] == "CONCAT": - bit_vector_l = _eval_param(val["ops"][0], params) - bit_vector_r = _eval_param(val["ops"][1], params) + bit_vector_l = _eval_param(val["ops"][0], params, simpleeval_instance) + bit_vector_r = _eval_param(val["ops"][1], params, simpleeval_instance) bin_l = bin(int(bit_vector_l["value"], 16))[2:].zfill(bit_vector_l["width"]) bin_r = bin(int(bit_vector_r["value"], 16))[2:].zfill(bit_vector_r["width"]) return { @@ -55,16 +55,23 @@ def _eval_param(val, params: dict): } elif val["fn"] == "REPL_CONCAT": - repeat = _eval_param(val["ops"][0], params) - bit_vector = _eval_param(val["ops"][1], params) + repeat = _eval_param(val["ops"][0], params, simpleeval_instance) + bit_vector = _eval_param(val["ops"][1], params, simpleeval_instance) bin_val = bin(int(bit_vector["value"], 16))[2:].zfill(bit_vector["width"]) return {"value": hex(int(repeat * bin_val, 2)), "width": repeat * bit_vector["width"]} else: - return int(ex.evaluate(resolve_ops(val, params)).take(0)) + simpleeval_instance.names = params + return int( + simpleeval_instance.eval( + _eval_param( + resolve_ops(val, params, simpleeval_instance), params, simpleeval_instance + ).replace("$", "") + ) + ) -def resolve_ops(val, params: dict): +def resolve_ops(val, params: dict, simpleeval_instance: SimpleEval): """Get 'val' representation, that will be used in ip core yaml :param val: expression gathered from HdlConvertor data. @@ -106,35 +113,39 @@ def resolve_ops(val, params: dict): "GT": ">", "GE": ">=", "DOWNTO": ":", + "PART_SELECT_PRE": "-:", } if val["fn"] in bin_ops.keys(): op = bin_ops[val["fn"]] - left = resolve_ops(val["ops"][0], params) - right = resolve_ops(val["ops"][1], params) + left = resolve_ops(val["ops"][0], params, simpleeval_instance) + right = resolve_ops(val["ops"][1], params, simpleeval_instance) if left is None or right is None: return None else: - return "(" + str(left) + op + str(right) + ")" + return f"({left}{op}{right})" elif val["fn"] == "TERNARY": - cond = resolve_ops(val["ops"][0], params) - if_true = resolve_ops(val["ops"][1], params) - if_false = resolve_ops(val["ops"][2], params) + cond = resolve_ops(val["ops"][0], params, simpleeval_instance) + if_true = resolve_ops(val["ops"][1], params, simpleeval_instance) + if_false = resolve_ops(val["ops"][2], params, simpleeval_instance) if cond is None or if_true is None or if_false is None: return None else: - return "where(" + str(cond) + ", " + str(if_true) + ", " + str(if_false) + ")" + return f"({str(if_true)} if {str(cond)} else {str(if_false)})" - elif val["fn"] == "CONCAT" or val["fn"] == "REPL_CONCAT": + elif val["fn"] in ["CONCAT", "REPL_CONCAT"]: # TODO - try to find a better way to get parameters default values # than copying params and evaluating them before each (REPL_)CONCAT params_cp = params.copy() for name in params_cp.keys(): if isinstance(params_cp[name], str): - params_cp[name] = int(ex.evaluate(params_cp[name], params_cp).take(0)) + simpleeval_instance.names = params_cp + params_cp[name] = int( + simpleeval_instance.eval(params_cp[name].replace("$", "")) + ) - return _eval_param(val, params_cp) + return _eval_param(val, params_cp, simpleeval_instance) elif val["fn"] == "PARAMETRIZATION": if ( @@ -144,13 +155,16 @@ def resolve_ops(val, params: dict): ): # corner case - this happens in Verilog's output/input reg return "(0:0)" - return resolve_ops(val["ops"][1], params) + return resolve_ops(val["ops"][1], params, simpleeval_instance) elif val["fn"] == "INDEX": # this happens in VHDL's 'std_logic_vector({up_id DOWNTO low_id}) # drop `std_logic_vector` and process the insides of parentheses - return resolve_ops(val["ops"][1], params) + if val["ops"][0] == "std_logic_vector": + return resolve_ops(val["ops"][1], params, simpleeval_instance) + else: + return f"{val['ops'][0]}[{resolve_ops(val['ops'][1], params, simpleeval_instance)}]" elif val["fn"] == "CALL": - return f"{resolve_ops(val['ops'][0], params)}({','.join(resolve_ops(arg, params) for arg in val['ops'][1:])})" + return f"{resolve_ops(val['ops'][0], params, simpleeval_instance)}({','.join(resolve_ops(str(arg), params, simpleeval_instance) for arg in val['ops'][1:])})" else: warning(f'resolve_ops: unhandled HdlOp function: {val["fn"]}') diff --git a/topwrap/ipwrapper.py b/topwrap/ipwrapper.py index 7a5596fc..f6af7b3c 100644 --- a/topwrap/ipwrapper.py +++ b/topwrap/ipwrapper.py @@ -5,10 +5,10 @@ from pathlib import Path from typing import List -import numexpr as ex from amaranth import Instance, Module, Signal from amaranth.build import Platform from amaranth.hdl.ast import Cat, Const +from simpleeval import simple_eval from topwrap.ip_desc import IPCoreComplexParameter, IPCoreDescription @@ -45,8 +45,10 @@ def _evaluate_parameters(params: dict): if isinstance(param, IPCoreComplexParameter): params[name] = Const(param.value, shape=(param.width)) elif isinstance(param, str): - if ex.validate(param, params) is None: - params[name] = int(ex.evaluate(param, params).take(0)) + try: + params[name] = simple_eval(param, names=params) + except Exception as e: + error(f"evaluating expression {name} failed with the following message: {str(e)}") def _eval_bounds(bounds, params): @@ -55,7 +57,7 @@ def _eval_bounds(bounds, params): for i, item in enumerate(bounds): if isinstance(item, str): try: - result[i] = int(ex.evaluate(item, params).take(0)) + result[i] = int(simple_eval(item, names=params)) except TypeError: error( "Could not evaluate expression with parameter: " diff --git a/topwrap/kpm_dataflow_validator.py b/topwrap/kpm_dataflow_validator.py index 51607400..15b54551 100644 --- a/topwrap/kpm_dataflow_validator.py +++ b/topwrap/kpm_dataflow_validator.py @@ -4,8 +4,8 @@ import re from typing import NamedTuple, Union -import numexpr as ex from pipeline_manager_backend_communication.misc_structures import MessageType +from simpleeval import simple_eval from topwrap.design_to_kpm_dataflow_parser import KPMDataflowSubgraphnode @@ -65,7 +65,7 @@ def _check_parameters_values(dataflow_data, specification) -> CheckResult: if not re.match(r"\d+\'[hdob][\dabcdefABCDEF]+", param_val): try: - evaluated[param_name] = int(ex.evaluate(param_val, evaluated).take(0)) + evaluated[param_name] = simple_eval(param_val, names=evaluated) except (ValueError, KeyError, SyntaxError, OverflowError): invalid_params.append(f"{node['name']}:{param_name}") diff --git a/topwrap/verilog_parser.py b/topwrap/verilog_parser.py index 9030f01e..58a41261 100644 --- a/topwrap/verilog_parser.py +++ b/topwrap/verilog_parser.py @@ -7,6 +7,7 @@ from hdlConvertor import HdlConvertor from hdlConvertorAst.language import Language from hdlConvertorAst.to.json import ToJson +from simpleeval import SimpleEval from typing_extensions import override from .hdl_module import HDLModule, HDLParameter @@ -33,7 +34,7 @@ def module_name(self) -> str: def parameters(self) -> Dict[str, HDLParameter]: params = {} for item in self._data["dec"]["params"]: - param_val = resolve_ops(item["value"], params) + param_val = resolve_ops(item["value"], params, SimpleEval()) if param_val is not None: params[item["name"]["val"]] = param_val return params @@ -52,7 +53,7 @@ def ports(self) -> Set[PortDefinition]: if type_or_bounds == "wire" or type_or_bounds["__class__"] == "HdlTypeAuto": ubound, lbound = "0", "0" else: - resolved_ops = resolve_ops(type_or_bounds, self.parameters) + resolved_ops = resolve_ops(type_or_bounds, self.parameters, SimpleEval()) if resolved_ops is not None: ubound, lbound = resolved_ops[1:-1].split(":") else: diff --git a/topwrap/vhdl_parser.py b/topwrap/vhdl_parser.py index bf0b0efa..520a8cc1 100644 --- a/topwrap/vhdl_parser.py +++ b/topwrap/vhdl_parser.py @@ -7,6 +7,7 @@ from hdlConvertor import HdlConvertor from hdlConvertorAst.language import Language from hdlConvertorAst.to.json import ToJson +from simpleeval import SimpleEval from typing_extensions import override from .hdl_module import HDLModule, HDLParameter @@ -43,7 +44,7 @@ def module_name(self) -> str: def parameters(self) -> Dict[str, HDLParameter]: params = {} for item in self.__data["params"]: - param_val = resolve_ops(item["value"], params) + param_val = resolve_ops(item["value"], params, SimpleEval()) if param_val is not None: params[item["name"]["val"]] = param_val return params @@ -61,7 +62,7 @@ def ports(self) -> Set[PortDefinition]: if type_or_bounds == "std_logic" or type_or_bounds == "std_ulogic": ubound, lbound = "0", "0" else: - resolved_ops = resolve_ops(type_or_bounds, self.parameters) + resolved_ops = resolve_ops(type_or_bounds, self.parameters, SimpleEval()) if resolved_ops is not None: ubound, lbound = resolved_ops[1:-1].split(":") else: