Skip to content

Commit

Permalink
Change numexpr to simpleeval to allow for more complex param evaluation
Browse files Browse the repository at this point in the history
Simpleeval allows for defining custom functinos like clog2 which is
common in desings.
Constants in numexpr have limited number of bits due to underlying
numpy

Co-authored-by: Krzysztof Obłonczek <[email protected]>
Internal-tag: [#64101]
Signed-off-by: bbrzyski <[email protected]>
  • Loading branch information
bbrzyski and koblonczek committed Aug 22, 2024
1 parent 69567b7 commit 9cd1837
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"click",
"amaranth==0.4.0.*",
"marshmallow_dataclass",
"numexpr",
"simpleeval",
"typing_extensions",
"marshmallow-dataclass[enum]",
"typing_extensions",
Expand Down
1 change: 0 additions & 1 deletion tests/tests_kpm/test_kpm_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,6 @@ def find_node_with_interface_id(design_graphs, iface_id, conn_id):
node_names.append(find_node_with_interface_id(design_graphs, conn["to"], conn["id"]))
node_names.append(find_node_with_interface_id(design_graphs, conn["from"], conn["id"]))
node_occurrence_dict = {item: node_names.count(item) for item in node_names}
breakpoint()
conn_dict = {
"Constant": 2,
"External Input": 2,
Expand Down
58 changes: 36 additions & 22 deletions topwrap/hdl_parsers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from logging import warning

import numexpr as ex
from simpleeval import SimpleEval

from topwrap.amaranth_helpers import DIR_IN, DIR_INOUT, DIR_OUT

Expand All @@ -26,15 +26,15 @@ class PortDefinition:
direction: PortDirection


def _eval_param(val, params: dict):
def _eval_param(val, params: dict, simpleeval_instance: SimpleEval):
"""Function used to calculate parameter value.
It is used for evaluating CONCAT and REPL_CONCAT in resolve_ops()"""
if isinstance(val, int):
return val
if isinstance(val, dict) and val.keys() == {"value", "width"}:
return val
if isinstance(val, str):
return _eval_param(params[val], params)
return _eval_param(params[val], params, simpleeval_instance)

elif val["__class__"] == "HdlValueInt":
value = int(val["val"], val["base"])
Expand All @@ -45,8 +45,8 @@ def _eval_param(val, params: dict):

elif val["__class__"] == "HdlOp":
if val["fn"] == "CONCAT":
bit_vector_l = _eval_param(val["ops"][0], params)
bit_vector_r = _eval_param(val["ops"][1], params)
bit_vector_l = _eval_param(val["ops"][0], params, simpleeval_instance)
bit_vector_r = _eval_param(val["ops"][1], params, simpleeval_instance)
bin_l = bin(int(bit_vector_l["value"], 16))[2:].zfill(bit_vector_l["width"])
bin_r = bin(int(bit_vector_r["value"], 16))[2:].zfill(bit_vector_r["width"])
return {
Expand All @@ -55,16 +55,23 @@ def _eval_param(val, params: dict):
}

elif val["fn"] == "REPL_CONCAT":
repeat = _eval_param(val["ops"][0], params)
bit_vector = _eval_param(val["ops"][1], params)
repeat = _eval_param(val["ops"][0], params, simpleeval_instance)
bit_vector = _eval_param(val["ops"][1], params, simpleeval_instance)
bin_val = bin(int(bit_vector["value"], 16))[2:].zfill(bit_vector["width"])
return {"value": hex(int(repeat * bin_val, 2)), "width": repeat * bit_vector["width"]}

else:
return int(ex.evaluate(resolve_ops(val, params)).take(0))
simpleeval_instance.names = params
return int(
simpleeval_instance.eval(
_eval_param(
resolve_ops(val, params, simpleeval_instance), params, simpleeval_instance
).replace("$", "")
)
)


def resolve_ops(val, params: dict):
def resolve_ops(val, params: dict, simpleeval_instance: SimpleEval):
"""Get 'val' representation, that will be used in ip core yaml
:param val: expression gathered from HdlConvertor data.
Expand Down Expand Up @@ -106,35 +113,39 @@ def resolve_ops(val, params: dict):
"GT": ">",
"GE": ">=",
"DOWNTO": ":",
"PART_SELECT_PRE": "-:",
}

if val["fn"] in bin_ops.keys():
op = bin_ops[val["fn"]]
left = resolve_ops(val["ops"][0], params)
right = resolve_ops(val["ops"][1], params)
left = resolve_ops(val["ops"][0], params, simpleeval_instance)
right = resolve_ops(val["ops"][1], params, simpleeval_instance)
if left is None or right is None:
return None
else:
return "(" + str(left) + op + str(right) + ")"
return f"({left}{op}{right})"

elif val["fn"] == "TERNARY":
cond = resolve_ops(val["ops"][0], params)
if_true = resolve_ops(val["ops"][1], params)
if_false = resolve_ops(val["ops"][2], params)
cond = resolve_ops(val["ops"][0], params, simpleeval_instance)
if_true = resolve_ops(val["ops"][1], params, simpleeval_instance)
if_false = resolve_ops(val["ops"][2], params, simpleeval_instance)
if cond is None or if_true is None or if_false is None:
return None
else:
return "where(" + str(cond) + ", " + str(if_true) + ", " + str(if_false) + ")"
return f"({str(if_true)} if {str(cond)} else {str(if_false)})"

elif val["fn"] == "CONCAT" or val["fn"] == "REPL_CONCAT":
elif val["fn"] in ["CONCAT", "REPL_CONCAT"]:
# TODO - try to find a better way to get parameters default values
# than copying params and evaluating them before each (REPL_)CONCAT
params_cp = params.copy()
for name in params_cp.keys():
if isinstance(params_cp[name], str):
params_cp[name] = int(ex.evaluate(params_cp[name], params_cp).take(0))
simpleeval_instance.names = params_cp
params_cp[name] = int(
simpleeval_instance.eval(params_cp[name].replace("$", ""))
)

return _eval_param(val, params_cp)
return _eval_param(val, params_cp, simpleeval_instance)

elif val["fn"] == "PARAMETRIZATION":
if (
Expand All @@ -144,13 +155,16 @@ def resolve_ops(val, params: dict):
):
# corner case - this happens in Verilog's output/input reg
return "(0:0)"
return resolve_ops(val["ops"][1], params)
return resolve_ops(val["ops"][1], params, simpleeval_instance)

elif val["fn"] == "INDEX":
# this happens in VHDL's 'std_logic_vector({up_id DOWNTO low_id})
# drop `std_logic_vector` and process the insides of parentheses
return resolve_ops(val["ops"][1], params)
if val["ops"][0] == "std_logic_vector":
return resolve_ops(val["ops"][1], params, simpleeval_instance)
else:
return f"{val['ops'][0]}[{resolve_ops(val['ops'][1], params, simpleeval_instance)}]"
elif val["fn"] == "CALL":
return f"{resolve_ops(val['ops'][0], params)}({','.join(resolve_ops(arg, params) for arg in val['ops'][1:])})"
return f"{resolve_ops(val['ops'][0], params, simpleeval_instance)}({','.join(resolve_ops(str(arg), params, simpleeval_instance) for arg in val['ops'][1:])})"
else:
warning(f'resolve_ops: unhandled HdlOp function: {val["fn"]}')
10 changes: 6 additions & 4 deletions topwrap/ipwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from pathlib import Path
from typing import List

import numexpr as ex
from amaranth import Instance, Module, Signal
from amaranth.build import Platform
from amaranth.hdl.ast import Cat, Const
from simpleeval import simple_eval

from topwrap.ip_desc import IPCoreComplexParameter, IPCoreDescription

Expand Down Expand Up @@ -45,8 +45,10 @@ def _evaluate_parameters(params: dict):
if isinstance(param, IPCoreComplexParameter):
params[name] = Const(param.value, shape=(param.width))
elif isinstance(param, str):
if ex.validate(param, params) is None:
params[name] = int(ex.evaluate(param, params).take(0))
try:
params[name] = simple_eval(param, names=params)
except Exception as e:
error(f"evaluating expression {name} failed with the following message: {str(e)}")


def _eval_bounds(bounds, params):
Expand All @@ -55,7 +57,7 @@ def _eval_bounds(bounds, params):
for i, item in enumerate(bounds):
if isinstance(item, str):
try:
result[i] = int(ex.evaluate(item, params).take(0))
result[i] = int(simple_eval(item, names=params))
except TypeError:
error(
"Could not evaluate expression with parameter: "
Expand Down
4 changes: 2 additions & 2 deletions topwrap/kpm_dataflow_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import re
from typing import NamedTuple, Union

import numexpr as ex
from pipeline_manager_backend_communication.misc_structures import MessageType
from simpleeval import simple_eval

from topwrap.design_to_kpm_dataflow_parser import KPMDataflowSubgraphnode

Expand Down Expand Up @@ -65,7 +65,7 @@ def _check_parameters_values(dataflow_data, specification) -> CheckResult:

if not re.match(r"\d+\'[hdob][\dabcdefABCDEF]+", param_val):
try:
evaluated[param_name] = int(ex.evaluate(param_val, evaluated).take(0))
evaluated[param_name] = simple_eval(param_val, names=evaluated)
except (ValueError, KeyError, SyntaxError, OverflowError):
invalid_params.append(f"{node['name']}:{param_name}")

Expand Down
5 changes: 3 additions & 2 deletions topwrap/verilog_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from hdlConvertor import HdlConvertor
from hdlConvertorAst.language import Language
from hdlConvertorAst.to.json import ToJson
from simpleeval import SimpleEval
from typing_extensions import override

from .hdl_module import HDLModule, HDLParameter
Expand All @@ -33,7 +34,7 @@ def module_name(self) -> str:
def parameters(self) -> Dict[str, HDLParameter]:
params = {}
for item in self._data["dec"]["params"]:
param_val = resolve_ops(item["value"], params)
param_val = resolve_ops(item["value"], params, SimpleEval())
if param_val is not None:
params[item["name"]["val"]] = param_val
return params
Expand All @@ -52,7 +53,7 @@ def ports(self) -> Set[PortDefinition]:
if type_or_bounds == "wire" or type_or_bounds["__class__"] == "HdlTypeAuto":
ubound, lbound = "0", "0"
else:
resolved_ops = resolve_ops(type_or_bounds, self.parameters)
resolved_ops = resolve_ops(type_or_bounds, self.parameters, SimpleEval())
if resolved_ops is not None:
ubound, lbound = resolved_ops[1:-1].split(":")
else:
Expand Down
5 changes: 3 additions & 2 deletions topwrap/vhdl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from hdlConvertor import HdlConvertor
from hdlConvertorAst.language import Language
from hdlConvertorAst.to.json import ToJson
from simpleeval import SimpleEval
from typing_extensions import override

from .hdl_module import HDLModule, HDLParameter
Expand Down Expand Up @@ -43,7 +44,7 @@ def module_name(self) -> str:
def parameters(self) -> Dict[str, HDLParameter]:
params = {}
for item in self.__data["params"]:
param_val = resolve_ops(item["value"], params)
param_val = resolve_ops(item["value"], params, SimpleEval())
if param_val is not None:
params[item["name"]["val"]] = param_val
return params
Expand All @@ -61,7 +62,7 @@ def ports(self) -> Set[PortDefinition]:
if type_or_bounds == "std_logic" or type_or_bounds == "std_ulogic":
ubound, lbound = "0", "0"
else:
resolved_ops = resolve_ops(type_or_bounds, self.parameters)
resolved_ops = resolve_ops(type_or_bounds, self.parameters, SimpleEval())
if resolved_ops is not None:
ubound, lbound = resolved_ops[1:-1].split(":")
else:
Expand Down

0 comments on commit 9cd1837

Please sign in to comment.