Skip to content

Commit

Permalink
【complex op】 No.34 add complex support for dot (#56349)
Browse files Browse the repository at this point in the history
* update

* fix codestyle

* update

* update
  • Loading branch information
huangjiyi authored Aug 17, 2023
1 parent 488071a commit 920c66e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 92 deletions.
26 changes: 22 additions & 4 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,8 +1080,8 @@ def dot(x, y, name=None):
is the batch dimension, which means that the vectors of multiple batches are dotted.
Parameters:
x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``
y(Tensor): 1-D or 2-D ``Tensor``. Its dtype soulde be ``float32``, ``float64``, ``int32``, ``int64``
x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
y(Tensor): 1-D or 2-D ``Tensor``. Its dtype soulde be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`
Returns:
Expand Down Expand Up @@ -1117,13 +1117,31 @@ def dot(x, y, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
[
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
op_type,
)
check_variable_and_dtype(
y,
'y',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
[
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
op_type,
)

Expand Down
110 changes: 22 additions & 88 deletions test/legacy_test/test_dot_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,102 +182,36 @@ def test_dygraph(self):
)


class TestComplexDotOp(OpTest):
def setUp(self):
self.op_type = "dot"
self.python_api = paddle.dot
self.init_base_dtype()
self.init_input_output()

self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}

def init_base_dtype(self):
self.dtype = np.float64
class TestComplex64DotOp(DotOp):
def init_dtype(self):
self.dtype = np.complex64

def init_input_output(self):
self.x = np.random.random(100).astype(
self.dtype
) + 1j * np.random.random(100).astype(self.dtype)
self.y = np.random.random(100).astype(
self.dtype
) + 1j * np.random.random(100).astype(self.dtype)
self.out = np.dot(self.x, self.y)

def test_check_output(self):
self.check_output()

def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
)

def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
)

def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
)


class TestComplexDotOp2D(OpTest):
def setUp(self):
self.op_type = "dot"
self.python_api = paddle.dot
self.init_base_dtype()
self.init_input_output()

self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}
shape = 100
self.x = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.y = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.out = np.dot(self.x, self.y).astype(self.dtype)

def init_base_dtype(self):
self.dtype = np.float64

class TestComplex64DotOp2D(TestComplex64DotOp):
def init_input_output(self):
self.x = np.random.random((2, 100)).astype(
self.dtype
) + 1j * np.random.random((2, 100)).astype(self.dtype)
self.y = np.random.random((2, 100)).astype(
self.dtype
) + 1j * np.random.random((2, 100)).astype(self.dtype)
shape = (2, 100)
self.x = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.y = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1)

def test_check_output(self):
self.check_output()

def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
)

def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
)

def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
)
class TestComplex128DotOp(TestComplex64DotOp):
def init_dtype(self):
self.dtype = np.complex128


@unittest.skipIf(
Expand Down

0 comments on commit 920c66e

Please sign in to comment.