Skip to content

Commit

Permalink
Update numpy rewrite for linear_operator_circulant.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 592864112
  • Loading branch information
ColCarroll authored and tensorflower-gardener committed Dec 21, 2023
1 parent fc47de9 commit ec23c1f
Showing 1 changed file with 78 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,32 @@ def _linop_inverse(self) -> "LinearOperatorCirculant":
is_square=True,
input_output_dtype=self.dtype)

def _linop_matmul(
self,
left_operator: "LinearOperatorCirculant",
right_operator: linear_operator.LinearOperator,
) -> linear_operator.LinearOperator:
if not isinstance(
right_operator, LinearOperatorCirculant
) or not isinstance(left_operator, type(right_operator)):
return super()._linop_matmul(left_operator, right_operator)

return LinearOperatorCirculant(
spectrum=left_operator.spectrum * right_operator.spectrum,
is_non_singular=property_hint_util.combined_non_singular_hint(
left_operator, right_operator
),
is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint(
left_operator, right_operator
),
is_positive_definite=(
property_hint_util.combined_commuting_positive_definite_hint(
left_operator, right_operator
)
),
is_square=True,
)

def _linop_solve(
self,
left_operator: "LinearOperatorCirculant",
Expand Down Expand Up @@ -1271,6 +1297,32 @@ def _linop_inverse(self) -> "LinearOperatorCirculant2D":
is_square=True,
input_output_dtype=self.dtype)

def _linop_matmul(
self,
left_operator: "LinearOperatorCirculant2D",
right_operator: linear_operator.LinearOperator,
) -> linear_operator.LinearOperator:
if not isinstance(
right_operator, LinearOperatorCirculant2D
) or not isinstance(left_operator, type(right_operator)):
return super()._linop_matmul(left_operator, right_operator)

return LinearOperatorCirculant2D(
spectrum=left_operator.spectrum * right_operator.spectrum,
is_non_singular=property_hint_util.combined_non_singular_hint(
left_operator, right_operator
),
is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint(
left_operator, right_operator
),
is_positive_definite=(
property_hint_util.combined_commuting_positive_definite_hint(
left_operator, right_operator
)
),
is_square=True,
)

def _linop_solve(
self,
left_operator: "LinearOperatorCirculant2D",
Expand Down Expand Up @@ -1473,6 +1525,32 @@ def _linop_inverse(self) -> "LinearOperatorCirculant3D":
is_square=True,
input_output_dtype=self.dtype)

def _linop_matmul(
self,
left_operator: "LinearOperatorCirculant3D",
right_operator: linear_operator.LinearOperator,
) -> linear_operator.LinearOperator:
if not isinstance(
right_operator, LinearOperatorCirculant3D
) or not isinstance(left_operator, type(right_operator)):
return super()._linop_matmul(left_operator, right_operator)

return LinearOperatorCirculant3D(
spectrum=left_operator.spectrum * right_operator.spectrum,
is_non_singular=property_hint_util.combined_non_singular_hint(
left_operator, right_operator
),
is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint(
left_operator, right_operator
),
is_positive_definite=(
property_hint_util.combined_commuting_positive_definite_hint(
left_operator, right_operator
)
),
is_square=True,
)

def _linop_solve(
self,
left_operator: "LinearOperatorCirculant3D",
Expand Down

0 comments on commit ec23c1f

Please sign in to comment.