Skip to content

Commit

Permalink
Merge branch 'main' into docs/more_info
Browse files Browse the repository at this point in the history
  • Loading branch information
ArinaDanilina authored Jan 3, 2025
2 parents 4c04ff5 + 1e8ffdf commit 8c31536
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_stages:
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.0
hooks:
- id: mypy
additional_dependencies: [numpy>=1.25.0]
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: check-yaml
- id: check-toml
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
Expand All @@ -63,7 +63,7 @@ repos:
- id: doc8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.2
rev: v0.8.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
2 changes: 1 addition & 1 deletion src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
try:
from numpy.typing import DTypeLike, NDArray

ArrayLike = NDArray[np.float64]
ArrayLike = NDArray[np.floating]
except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc]
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _create_graph_geometry(
problem_shape = x.shape if problem_shape is None else problem_shape
return _instantiate_geodesic_cost(
arr=arr,
problem_shape=problem_shape,
problem_shape=problem_shape, # type: ignore[arg-type]
t=t,
is_linear_term=is_linear_term,
epsilon=epsilon,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102
return self._transport_matrix

@property
def shape(self) -> tuple[int, int]: # noqa: D102
def shape(self) -> tuple[int, ...]: # noqa: D102
return self.transport_matrix.shape

def to( # noqa: D102
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/utils/tagged_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def from_adata(
return cls(data_src=data, tag=tag, cost=cost_fn)

@property
def shape(self) -> Tuple[int, int]:
def shape(self) -> Tuple[int, ...]:
"""Shape of the cost matrix."""
if self.tag == Tag.POINT_CLOUD:
x, y = self.data_src, (self.data_src if self.data_tgt is None else self.data_tgt)
Expand Down

0 comments on commit 8c31536

Please sign in to comment.