Skip to content

Commit

Permalink
support more int type check
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerwwww committed Apr 11, 2024
1 parent c5f8797 commit 15f3b46
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
16 changes: 8 additions & 8 deletions pygmtools/classic_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,10 @@ def sm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
if _check_shape(K, 2, backend):
K = _unsqueeze(K, 0, backend)
non_batched_input = True
if type(n1) is int and n1max is None:
if isinstance(n1, (int, np.integer)) and n1max is None:
n1max = n1
n1 = None
if type(n2) is int and n2max is None:
if isinstance(n2, (int, np.integer)) and n2max is None:
n2max = n2
n2 = None
elif _check_shape(K, 3, backend):
Expand Down Expand Up @@ -727,10 +727,10 @@ def rrwm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
if _check_shape(K, 2, backend):
K = _unsqueeze(K, 0, backend)
non_batched_input = True
if type(n1) is int and n1max is None:
if isinstance(n1, (int, np.integer)) and n1max is None:
n1max = n1
n1 = None
if type(n2) is int and n2max is None:
if isinstance(n2, (int, np.integer)) and n2max is None:
n2max = n2
n2 = None
elif _check_shape(K, 3, backend):
Expand Down Expand Up @@ -1034,10 +1034,10 @@ def ipfp(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
if _check_shape(K, 2, backend):
K = _unsqueeze(K, 0, backend)
non_batched_input = True
if type(n1) is int and n1max is None:
if isinstance(n1, (int, np.integer)) and n1max is None:
n1max = n1
n1 = None
if type(n2) is int and n2max is None:
if isinstance(n2, (int, np.integer)) and n2max is None:
n2max = n2
n2 = None
elif _check_shape(K, 3, backend):
Expand Down Expand Up @@ -1165,10 +1165,10 @@ def astar(K, n1=None, n2=None, n1max=None, n2max=None, beam_width=0, backend=Non
if _check_shape(K, 2, backend):
K = _unsqueeze(K, 0, backend)
non_batched_input = True
if type(n1) is int and n1max is None:
if isinstance(n1, (int, np.integer)) and n1max is None:
n1max = n1
n1 = None
if type(n2) is int and n2max is None:
if isinstance(n2, (int, np.integer)) and n2max is None:
n2max = n2
n2 = None
elif _check_shape(K, 3, backend):
Expand Down
8 changes: 4 additions & 4 deletions pygmtools/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,8 +729,8 @@ def sinkhorn(s, n1=None, n2=None, unmatch1=None, unmatch2=None,
_check_data_type(s, 's', backend)
if _check_shape(s, 2, backend):
s = _unsqueeze(s, 0, backend)
if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend)
if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend)
if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend)
if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend)
non_batched_input = True
elif _check_shape(s, 3, backend):
non_batched_input = False
Expand Down Expand Up @@ -1298,8 +1298,8 @@ def hungarian(s, n1=None, n2=None, unmatch1=None, unmatch2=None,
_check_data_type(s, backend)
if _check_shape(s, 2, backend):
s = _unsqueeze(s, 0, backend)
if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend)
if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend)
if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend)
if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend)
non_batched_input = True
elif _check_shape(s, 3, backend):
non_batched_input = False
Expand Down
20 changes: 10 additions & 10 deletions pygmtools/neural_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def pca_gm(feat1, feat2, A1, A2, n1=None, n2=None,

if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]):
feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)]
if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend)
if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend)
if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend)
if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend)
non_batched_input = True
elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]):
non_batched_input = False
Expand Down Expand Up @@ -592,8 +592,8 @@ def ipca_gm(feat1, feat2, A1, A2, n1=None, n2=None,

if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]):
feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)]
if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend)
if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend)
if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend)
if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend)
non_batched_input = True
elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]):
non_batched_input = False
Expand Down Expand Up @@ -913,8 +913,8 @@ def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1=None, n2=None
and all([_check_shape(_, 3, backend) for _ in (feat_edge1, feat_edge2)]):
feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2 =\
[_unsqueeze(_, 0, backend) for _ in (feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2)]
if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend)
if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend)
if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend)
if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend)
non_batched_input = True
elif all([_check_shape(_, 3, backend) for _ in (feat_node1, feat_node2, A1, A2)]) \
and all([_check_shape(_, 4, backend) for _ in (feat_edge1, feat_edge2)]):
Expand Down Expand Up @@ -1239,10 +1239,10 @@ def ngm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
if _check_shape(K, 2, backend):
K = _unsqueeze(K, 0, backend)
non_batched_input = True
if type(n1) is int and n1max is None:
if isinstance(n1, (int, np.integer)) and n1max is None:
n1max = n1
n1 = None
if type(n2) is int and n2max is None:
if isinstance(n2, (int, np.integer)) and n2max is None:
n2max = n2
n2 = None
elif _check_shape(K, 3, backend):
Expand Down Expand Up @@ -1444,8 +1444,8 @@ def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=6

if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]):
feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)]
if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend)
if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend)
if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend)
if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend)
non_batched_input = True
elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]):
non_batched_input = False
Expand Down

0 comments on commit 15f3b46

Please sign in to comment.