Merge pull request #26 from pnnl/upgrade-numpy-scipy
Implement numpy-scipy frontend for comet
pthomadakis authored Sep 19, 2023
2 parents 5af35aa + 6188842 commit fe7276d
Showing 108 changed files with 5,297 additions and 1,175 deletions.
@@ -45,17 +45,18 @@ class Tensor_Decl_Builder:

tensor_decl_wrapper_text = jinja2.Template(
("" * indentation_size)
+ ' = "ta.dense_tensor_decl"{{dims_tuple}}'
+ '{format = "Dense"} : '
+ ' = "ta.{{decl}}_tensor_decl"{{dims_tuple}}'
+ '{format = {{format}}} : '
+"{{ranges_tuple}} -> "
+ "{{inputtype}}"
+ "\n" ,
undefined=jinja2.StrictUndefined,
)

def __init__(self, decl_vars:list, inputtype: str)->None:
def __init__(self, decl_vars:list, inputtype: str, format: str)->None:
self.inputtype = inputtype
self.decl_vars = decl_vars
self.format = format


def build_tensor(self):
@@ -64,30 +65,39 @@ def build_tensor(self):
for i in range(len(self.decl_vars)-1):
dims_tuple += self.decl_vars[i] + ","
ranges_tuple += "!ta.range,"

dims_tuple += self.decl_vars[-1] + ")"
ranges_tuple += "!ta.range)"
format = self.format
if self.format == "CSR" or self.format == "COO" or self.format == "CSC":
format = '"{}" , temporal_tensor = false'.format(self.format)
else:
format = '"{}"'.format(self.format)

return self.tensor_decl_wrapper_text.render(
dims_tuple = dims_tuple,
ranges_tuple = ranges_tuple,
inputtype = self.inputtype

format = format,
inputtype = self.inputtype,
decl = "dense" if self.format == "Dense" else "sparse"
)



class TC_and_TrPose_Builder:
class ArithOp_Builder:
indentation_size = 4
beta_val = 0.0

tc_decl_wrapper_text = jinja2.Template(
("" * indentation_size)
+ ' = "ta.tc"{{operators}}'
+ '{__alpha__ = 1.000000e+00 : f64, '
+ ' = "ta.mul"{{operators}}'
+ '{MaskType = "none", ' #[TODO] MaskType should not be static
+ '__alpha__ = 1.000000e+00 : f64, '
+"__beta__ = {{beta}} : f64,"
+ 'formats = ["Dense", "Dense", "Dense"],'
+'indexing_maps = {{indexing_maps}}, semiring = "plusxy_times"} : '
+ 'formats = [{{formats}}],'
+'indexing_maps = {{indexing_maps}}, '
+'operand_segment_sizes = array<i32:1, 1, {{lhs_dims}}, 0>, ' #[TODO] operand_segment_sizes should not be static
+'semiring = "plusxy_times"} : '
+"{{types_range_str}}"
+"-> {{outputtype}}"
+ "\n" ,
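For reference, the sparse/dense format handling that build_tensor now performs can be sketched as a standalone helper; the helper name and the asserts below are illustrative, not part of this commit:

def render_format_attr(fmt: str) -> str:
    # Sparse formats (CSR, COO, CSC) carry an extra temporal_tensor attribute;
    # dense tensors render only the quoted format name.
    if fmt in ("CSR", "COO", "CSC"):
        return '"{}" , temporal_tensor = false'.format(fmt)
    return '"{}"'.format(fmt)

assert render_format_attr("Dense") == '"Dense"'
assert render_format_attr("CSR") == '"CSR" , temporal_tensor = false'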
@@ -99,7 +109,7 @@ class TC_and_TrPose_Builder:
+ ' = "ta.elews_mul"{{operators}}'
+ '{__alpha__ = 1.000000e+00 : f64, '
+"__beta__ = {{beta}} : f64,"
+ 'formats = ["Dense", "Dense", "Dense"],'
+ 'formats = [{{formats}}],'
+'indexing_maps = {{indexing_maps}}, semiring = "noop_times"} : '
+"{{types_range_str}}"
+"-> {{outputtype}}"
@@ -112,16 +122,46 @@ class TC_and_TrPose_Builder:
+ ' = "ta.transpose"{{operators}}'
+ '{__alpha__ = 1.000000e+00 : f64, '
+"__beta__ = {{beta}} : f64,"
+ 'formats = ["Dense", "Dense", "Dense"],'
+ 'formats = [{{formats}}],'
+'indexing_maps = {{indexing_maps}},semiring = "plusxy_times"} : '
+"{{types_range_str}}"
+"-> {{outputtype}}"
+ "\n" ,
undefined=jinja2.StrictUndefined,
)

tensor_add_wrapper_text = jinja2.Template(
("" * indentation_size)
+' = "ta.add"{{operators}}'
+' {'
+' Masktype = "none",'
+' formats = [{{formats}}],'
+' indexing_maps = {{indexing_maps}},'
+' semiring = "noop_plusxy"'
+' }'
+' : {{Tensor_types_tuple}}'
+"-> {{outputtype}}"
+ "\n" ,
undefined=jinja2.StrictUndefined,
)

tensor_sub_wrapper_text = jinja2.Template(
("" * indentation_size)
+ ' = "ta.subtract"{{operators}}'
+' {'
+' Masktype = "none",'
+' formats = [{{formats}}],'
+' indexing_maps = {{indexing_maps}},'
+' semiring = "noop_minus"'
+' }'
+' : {{Tensor_types_tuple}}'
+"-> {{outputtype}}"
+ "\n" ,
undefined=jinja2.StrictUndefined,
)

def __init__(self, input_tensors:list, dimslbls_to_map:list, input_array_dims_lbls:list,
target_dims_lbls:list,tensor_types:list,tc_indices:list,opr_type:str,op:str) -> None:
target_dims_lbls:list,tensor_types:list,tc_indices:list,opr_type:str,op:str, formats:list) -> None:
self.operators = input_tensors
self.dimslbls_to_map = dimslbls_to_map
self.input_array_dims_lbls = input_array_dims_lbls
@@ -130,8 +170,9 @@ def __init__(self, input_tensors:list, dimslbls_to_map:list, input_array_dims_lb
self.tc_indices = tc_indices
self.opr_type = opr_type
self.op = op
self.formats = formats

def build_tc(self):
def build(self):

out_tensor_type = self.tensor_types.pop()
types_range_str = "("
@@ -146,44 +187,66 @@ def build_tc(self):
types_range_str += "!ta.range)"

self.operators = str(tuple(self.operators)).replace("'", "")

indexing_maps = TC_and_TrPose_Builder.create_affine_mapping(self.dimslbls_to_map,self.input_array_dims_lbls,self.target_dims_lbls)

self.beta_val = TC_and_TrPose_Builder.get_beta_val(self.op)
indexing_maps = ArithOp_Builder.create_affine_mapping(self.dimslbls_to_map,self.input_array_dims_lbls,self.target_dims_lbls)
self.beta_val = ArithOp_Builder.get_beta_val(self.op)

if(self.opr_type == "contraction"):
return self.tc_decl_wrapper_text.render(
operators = self.operators,
indexing_maps = indexing_maps,
types_range_str = types_range_str,
outputtype = out_tensor_type,
beta = self.beta_val
beta = self.beta_val,
formats = '"{}", "{}", "{}"'.format(*self.formats),
lhs_dims = len(self.target_dims_lbls)
)
elif(self.opr_type == "elewise_mult"):
return self.elewisemult_wrapper_text.render(
operators = self.operators,
indexing_maps = indexing_maps,
types_range_str = types_range_str,
outputtype = out_tensor_type,
formats = '"{}", "{}", "{}"'.format(*self.formats),
beta = self.beta_val
)
else:
# Transpose
elif(self.opr_type == "transpose"):
return self.tranpose_wrapper_text.render(
operators = self.operators,
indexing_maps = indexing_maps,
types_range_str = types_range_str,
outputtype = out_tensor_type,
beta = self.beta_val
beta = self.beta_val,
formats = '"{}", "{}"'.format(*self.formats)
)

# Add
elif(self.op == '+'):
return self.tensor_add_wrapper_text.render(
operators = self.operators,
Tensor_types_tuple = types_range_str,
outputtype = out_tensor_type,
formats = '"{}", "{}", "{}"'.format(*self.formats),
indexing_maps = indexing_maps
)
# Subtract
elif(self.op == '-'):
return self.tensor_sub_wrapper_text.render(
operators = self.operators,
Tensor_types_tuple = types_range_str,
outputtype = out_tensor_type,
formats = '"{}", "{}", "{}"'.format(*self.formats),
indexing_maps = indexing_maps
)

def get_beta_val(op):
if(op == '='):
beta_val = '0.000000e+00'
elif(op == '+='):
beta_val = '1.000000e+00'
elif(op == '-='):
beta_val = '-1.000000e+00'

elif(op == '+' or op == '-'):
beta_val = '0.000000e+00'
return beta_val
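As a readability note, the operator-to-__beta__ mapping implemented by get_beta_val above reduces to the lookup below; assuming the usual alpha/beta convention, beta decides whether the result overwrites or accumulates into the output tensor (this table is a restatement, not part of the diff):

BETA_FOR_OP = {
    '=':  '0.000000e+00',   # overwrite the output
    '+=': '1.000000e+00',   # accumulate into the output
    '-=': '-1.000000e+00',  # subtract from the output
    '+':  '0.000000e+00',   # standalone add/subtract start from zero
    '-':  '0.000000e+00',
}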


@@ -201,7 +264,7 @@ def create_affine_mapping(dims_to_map:list, input_array_dims:list, target_dims:l

if(len(d_map) == 1):
d_map = d_map.pop()
mapping_String += "affine_map<" + str(tuple(input_map.values())).replace("'", "") + "-> ({})".format(d_map.replace("'", "")) + ">,"
mapping_String += "affine_map<(" + ",".join(list(input_map.values())).replace("'", "") + ")-> ({})".format(d_map.replace("'", "")) + ">,"
else:
mapping_String += "affine_map<" + str(tuple(input_map.values())).replace("'", "") + "->" + str(tuple(d_map)).replace("'", "") + ">,"

@@ -212,7 +275,7 @@ def create_affine_mapping(dims_to_map:list, input_array_dims:list, target_dims:l

if(len(d_map) == 1):
d_map = d_map.pop()
mapping_String += "affine_map<{} -> ({})".format((str(tuple(input_map.values())).replace("'","")), d_map) + ">]"
mapping_String += "affine_map<({}) -> ({})".format(",".join(list(input_map.values())).replace("'", ""), d_map) + ">]"
else:
mapping_String += "affine_map<{} -> {}".format((str(tuple(input_map.values())).replace("'","")), (str(tuple(d_map)).replace("'",""))) + ">]"
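A small standalone sketch of why the affine-map rendering changed: with a single input dimension, str(tuple(...)) leaves a trailing comma that MLIR's affine_map syntax rejects, whereas joining the dimensions does not (the dimension names below are illustrative):

dims = ["d0"]
old_render = "affine_map<" + str(tuple(dims)).replace("'", "") + "-> (d0)>"
new_render = "affine_map<(" + ",".join(dims) + ")-> (d0)>"
print(old_render)  # affine_map<(d0,)-> (d0)>  -- trailing comma
print(new_render)  # affine_map<(d0)-> (d0)>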

@@ -257,72 +320,48 @@ def build(self):
return self.ele_wise_fill_wrapper_text.render(
lbtensor_op_var = self.lbtensor_op_var,
const_op_var = self.const_op_var,
assigned_val = self.assigned_val.tolist(),
assigned_val = self.assigned_val.tolist(), #[TODO]
tensor_type = self.tensor_type,
dims_tensor_tuple = dims_tensor_tuple,
ranges_tuple = ranges_tuple
)

class Tensor_arithOp_builder:


class TensorSumBuilder:
indentation_size = 4

tensor_add_wrapper_text = jinja2.Template(
tensor_sum_wrapper_text = jinja2.Template(
("" * indentation_size)
+ ' = "ta.add"{{operators}}'
+' : {{Tensor_types_tuple}}'
+"-> {{outputtype}}"
+ ' = "ta.reduce"({{operators}})'
+' : ({{Tensor_types_tuple}})'
+"-> {{output_types}}"
+ "\n" ,
undefined=jinja2.StrictUndefined,
)

tensor_sub_wrapper_text = jinja2.Template(
("" * indentation_size)
+ ' = "ta.subtract"{{operators}}'
+' : {{Tensor_types_tuple}}'
+"-> {{outputtype}}"
+ "\n" ,
undefined=jinja2.StrictUndefined,
)

def __init__(self,tensor_operators:list,input_types:list,outputtype:str,op:str) -> None:
self.tensor_operators = tensor_operators
def __init__(self, operators, input_types, output_types):
self.operators = operators
self.input_types = input_types
self.outputtype = outputtype
self.op = op
pass
self.output_types = output_types

def build(self):
operators_tuple = str(tuple(self.tensor_operators)).replace("'", "")
input_types_tuple = str(tuple(self.input_types)).replace("'", "")

if(self.op == '+'):
return self.tensor_add_wrapper_text.render(
operators = operators_tuple,
Tensor_types_tuple = input_types_tuple,
outputtype = self.outputtype
)
return self.tensor_sum_wrapper_text.render(
operators = ",".join(self.operators).replace("'",""),
Tensor_types_tuple = ",".join(self.input_types),
output_types = ",".join(self.output_types)
)

elif(self.op == '-'):
return self.tensor_sub_wrapper_text.render(
operators = operators_tuple,
Tensor_types_tuple = input_types_tuple,
outputtype = self.outputtype
)


class MLIRFunctionBuilder:
_ops = {}

default_indentation_size = 4
indentation_delta_size = 2
module_wrapper_text = jinja2.Template(
"{{ aliases }}module {\n {{ body }}\n}\n",
"{{ aliases }}module {\n {{ body }}\nfunc.func private @quick_sort(memref<*xindex>, index)\n}\n",
undefined=jinja2.StrictUndefined,
)
function_wrapper_text = jinja2.Template(
("" * default_indentation_size)
+ "func {% if private_func %}private {% endif %}@{{func_name}}({{signature}}) -> {{return_type}} {"
+ "func.func {% if private_func %}private {% endif %}@{{func_name}}({{signature}}) -> {{return_type}} {"
+ "\n"
+ "{{statements}}"
+ "\n"
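A hypothetical use of the new TensorSumBuilder above; the operand and type strings are made up for illustration:

builder = TensorSumBuilder(
    operators=["%t2"],
    input_types=["tensor<?x?xf64>"],
    output_types=["f64"],
)
print(builder.build())
# Given the template above, this should render roughly:
#  = "ta.reduce"(%t2) : (tensor<?x?xf64>)-> f64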
@@ -334,7 +373,7 @@ class MLIRFunctionBuilder:
def __init__(
self,
func_name: str,
#input_types: Sequence[Union[str, Type]],
input_types,
return_types: Sequence[Union[str, types_mlir.Type]],
aliases: types_mlir.AliasMap = None,
) -> None:
@@ -356,7 +395,7 @@ def __init__(self):
return_types = [types_mlir.Type.find(rt, aliases) for rt in return_types]

self.func_name = func_name
#self.inputs = inputs
self.inputs = input_types
self.return_types = return_types

self.var_name_counter = itertools.count()
@@ -421,13 +460,12 @@ def get_mlir(self, make_private=True, include_func_defs=True) -> str:
return_type = ", ".join(str(rt) for rt in self.return_types)
if len(self.return_types) != 1:
return_type = f"({return_type})"

#signature = ", ".join(f"{var}: {var.type}" for var in self.inputs)
signature = ", ".join(f"{ var[0].replace('%','%arg_')}: {var[1].replace('tensor', 'memref')}" if 'tensor' in var[1] else f"{ var[0].replace('%', '%arg_') }: memref<1xf64>" if var[1] =="f64" else f"{ var[0]}: {var[1]}" for var in self.inputs)

return needed_function_definitions + self.function_wrapper_text.render(
private_func=make_private,
func_name=self.func_name,
signature="",
signature=signature,
return_type=return_type,
statements=joined_statements,
)
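The new signature construction in get_mlir rewrites tensor arguments to memrefs and f64 scalars to one-element memrefs; a standalone sketch of that transformation, with illustrative argument names:

inputs = [("%0", "tensor<?x?xf64>"), ("%1", "f64"), ("%2", "index")]

def to_arg(var: str, ty: str) -> str:
    if "tensor" in ty:
        return f"{var.replace('%', '%arg_')}: {ty.replace('tensor', 'memref')}"
    if ty == "f64":
        return f"{var.replace('%', '%arg_')}: memref<1xf64>"
    return f"{var}: {ty}"

signature = ", ".join(to_arg(v, t) for v, t in inputs)
# -> "%arg_0: memref<?x?xf64>, %arg_1: memref<1xf64>, %2: index"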