diff --git a/pygmtools/classic_solvers.py b/pygmtools/classic_solvers.py index 7109d4d..e5e5f8f 100644 --- a/pygmtools/classic_solvers.py +++ b/pygmtools/classic_solvers.py @@ -371,10 +371,10 @@ def sm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None, if _check_shape(K, 2, backend): K = _unsqueeze(K, 0, backend) non_batched_input = True - if type(n1) is int and n1max is None: + if isinstance(n1, (int, np.integer)) and n1max is None: n1max = n1 n1 = None - if type(n2) is int and n2max is None: + if isinstance(n2, (int, np.integer)) and n2max is None: n2max = n2 n2 = None elif _check_shape(K, 3, backend): @@ -727,10 +727,10 @@ def rrwm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None, if _check_shape(K, 2, backend): K = _unsqueeze(K, 0, backend) non_batched_input = True - if type(n1) is int and n1max is None: + if isinstance(n1, (int, np.integer)) and n1max is None: n1max = n1 n1 = None - if type(n2) is int and n2max is None: + if isinstance(n2, (int, np.integer)) and n2max is None: n2max = n2 n2 = None elif _check_shape(K, 3, backend): @@ -1034,10 +1034,10 @@ def ipfp(K, n1=None, n2=None, n1max=None, n2max=None, x0=None, if _check_shape(K, 2, backend): K = _unsqueeze(K, 0, backend) non_batched_input = True - if type(n1) is int and n1max is None: + if isinstance(n1, (int, np.integer)) and n1max is None: n1max = n1 n1 = None - if type(n2) is int and n2max is None: + if isinstance(n2, (int, np.integer)) and n2max is None: n2max = n2 n2 = None elif _check_shape(K, 3, backend): @@ -1165,10 +1165,10 @@ def astar(K, n1=None, n2=None, n1max=None, n2max=None, beam_width=0, backend=Non if _check_shape(K, 2, backend): K = _unsqueeze(K, 0, backend) non_batched_input = True - if type(n1) is int and n1max is None: + if isinstance(n1, (int, np.integer)) and n1max is None: n1max = n1 n1 = None - if type(n2) is int and n2max is None: + if isinstance(n2, (int, np.integer)) and n2max is None: n2max = n2 n2 = None elif _check_shape(K, 3, backend): diff --git a/pygmtools/linear_solvers.py b/pygmtools/linear_solvers.py index 1186d8a..c84de68 100644 --- a/pygmtools/linear_solvers.py +++ b/pygmtools/linear_solvers.py @@ -729,8 +729,8 @@ def sinkhorn(s, n1=None, n2=None, unmatch1=None, unmatch2=None, _check_data_type(s, 's', backend) if _check_shape(s, 2, backend): s = _unsqueeze(s, 0, backend) - if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) - if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend) + if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend) non_batched_input = True elif _check_shape(s, 3, backend): non_batched_input = False @@ -1298,8 +1298,8 @@ def hungarian(s, n1=None, n2=None, unmatch1=None, unmatch2=None, _check_data_type(s, backend) if _check_shape(s, 2, backend): s = _unsqueeze(s, 0, backend) - if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) - if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend) + if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend) non_batched_input = True elif _check_shape(s, 3, backend): non_batched_input = False diff --git a/pygmtools/neural_solvers.py b/pygmtools/neural_solvers.py index f7d7e2f..80629a4 100644 --- a/pygmtools/neural_solvers.py +++ b/pygmtools/neural_solvers.py @@ -285,8 +285,8 @@ def pca_gm(feat1, feat2, A1, A2, n1=None, n2=None, if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]): feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)] - if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) - if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend) + if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend) non_batched_input = True elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]): non_batched_input = False @@ -592,8 +592,8 @@ def ipca_gm(feat1, feat2, A1, A2, n1=None, n2=None, if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]): feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)] - if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) - if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend) + if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend) non_batched_input = True elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]): non_batched_input = False @@ -913,8 +913,8 @@ def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1=None, n2=None and all([_check_shape(_, 3, backend) for _ in (feat_edge1, feat_edge2)]): feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2 =\ [_unsqueeze(_, 0, backend) for _ in (feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2)] - if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) - if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend) + if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend) non_batched_input = True elif all([_check_shape(_, 3, backend) for _ in (feat_node1, feat_node2, A1, A2)]) \ and all([_check_shape(_, 4, backend) for _ in (feat_edge1, feat_edge2)]): @@ -1239,10 +1239,10 @@ def ngm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None, if _check_shape(K, 2, backend): K = _unsqueeze(K, 0, backend) non_batched_input = True - if type(n1) is int and n1max is None: + if isinstance(n1, (int, np.integer)) and n1max is None: n1max = n1 n1 = None - if type(n2) is int and n2max is None: + if isinstance(n2, (int, np.integer)) and n2max is None: n2max = n2 n2 = None elif _check_shape(K, 3, backend): @@ -1444,8 +1444,8 @@ def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=6 if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]): feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)] - if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) - if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + if isinstance(n1, (int, np.integer)): n1 = from_numpy(np.array([n1]), backend=backend) + if isinstance(n2, (int, np.integer)): n2 = from_numpy(np.array([n2]), backend=backend) non_batched_input = True elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]): non_batched_input = False