From 618884211e560b0d98f7f0dbc7f24f4224543162 Mon Sep 17 00:00:00 2001 From: Polykarpos Thomadakis Date: Fri, 15 Sep 2023 23:39:53 -0700 Subject: [PATCH] In numpy-scipy, added support for A.multiply(B) as a match for scipy's method #22 --- frontends/numpy-scipy/comet.py | 19 ++++++++++++++++++- .../ops/test_eltwise_mult_CSRxDense_oCSR.py | 4 ++-- .../ops/test_eltwise_mult_CSRxDense_oDense.py | 4 ++-- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/frontends/numpy-scipy/comet.py b/frontends/numpy-scipy/comet.py index 180ade47..c8ad0927 100644 --- a/frontends/numpy-scipy/comet.py +++ b/frontends/numpy-scipy/comet.py @@ -358,7 +358,24 @@ def visit_Method_Call(self, node: Call, obj): self.tsemantics[out_id] = {'shape': [1,], 'format': DENSE, 'labels': []} self.ops.append(("s", [obj], out_id)) self.declarations.append(('d', 'v', 'l', out_id)) - + elif node.func.attr == "multiply": + op1 = NewVisitor.visit(self, node.args[0]) + op1_sems = self.tsemantics[op1] + if 'labels' not in op1_sems: + op1_sems['labels'] = op_semantics['labels'] + if self.tsemantics[obj]['format'] != DENSE: + op_semantics = self.tsemantics[obj] + self.tsemantics[op1]['labels'] = op_semantics['labels'] + else: + op_semantics = self.tsemantics[op1] + if self.tsemantics[op1]['format'] != DENSE: + self.tsemantics[obj]['labels'] = op_semantics['labels'] + s = 'a' + indices = "".join(chr(ord(s)+i) for i in range(len(op_semantics['labels']))) + self.ops.append(("*", [obj, op1], indices+','+indices+'->'+indices, self.tcurr, None)) + format = self.sp_elw_mult_conversions[op_semantics['format']][op1_sems['format']] + self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'labels': op_semantics['labels'], 'format': format} + self.declarations.append(('d', 'T', 'l', self.tcurr)) self.tcurr +=1 return out_id diff --git a/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oCSR.py b/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oCSR.py index f365e592..2272eb71 100644 --- a/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oCSR.py +++ b/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oCSR.py @@ -4,13 +4,13 @@ import comet def run_numpy(A,B): - C = A * B + C = A.multiply( B) return C @comet.compile(flags=None) def run_comet_with_jit(A,B): - C = A * B + C = A.multiply( B) return C diff --git a/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oDense.py b/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oDense.py index bdee5429..2a53feac 100644 --- a/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oDense.py +++ b/frontends/numpy-scipy/integration_tests/ops/test_eltwise_mult_CSRxDense_oDense.py @@ -4,13 +4,13 @@ import comet def run_numpy(A,B): - C = A * B + C = A.multiply( B) return C @comet.compile(flags=None) def run_comet_with_jit(A,B): - C = A * B + C = A.multiply( B) return C