Skip to content

Commit

Permalink
Add basic tests that smoked out a couple issues :)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Jul 9, 2024
1 parent 8676c9b commit 768b44d
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/nx-cugraph/nx_cugraph/convert_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def from_pandas_edgelist(
# Optimistically try to use cupy, but fall back to numpy if necessary
src_array = src_series.to_cupy()
dst_array = dst_series.to_cupy()
except (AttributeError, ValueError, NotImplementedError):
except (AttributeError, TypeError, ValueError, NotImplementedError):
src_array = src_series.to_numpy()
dst_array = dst_series.to_numpy()
try:
Expand Down Expand Up @@ -77,9 +77,13 @@ def from_pandas_edgelist(
src_indices = cp.asarray(np_or_cp.searchsorted(nodes, src_array), index_dtype)
dst_indices = cp.asarray(np_or_cp.searchsorted(nodes, dst_array), index_dtype)
else:
if not is_src_copied:
if is_src_copied:
src_indices = src_array
else:
src_indices = cp.array(src_array)
if not is_dst_copied:
if is_dst_copied:
dst_indices = dst_array
else:
dst_indices = cp.array(dst_array)

if not graph_class.is_directed():
Expand Down
82 changes: 82 additions & 0 deletions python/nx-cugraph/nx_cugraph/tests/test_convert_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pandas as pd
import pytest

import nx_cugraph as nxcg
from nx_cugraph.utils import _cp_iscopied_asarray

try:
import cudf
except ModuleNotFoundError:
cudf = None


data = [
{"source": [0, 1], "target": [1, 2]}, # nodes are 0, 1, 2
{"source": [0, 1], "target": [1, 3]}, # nodes are 0, 1, 3 (need renumbered!)
{"source": ["a", "b"], "target": ["b", "c"]}, # nodes are 'a', 'b', 'c'
]


@pytest.mark.skipif("not cudf")
@pytest.mark.parametrize("data", data)
def test_from_cudf_edgelist(data):
df = cudf.DataFrame(data)
nxcg.from_pandas_edgelist(df) # Basic smoke test
source = df["source"]
if source.dtype == int:
is_copied, src_array = _cp_iscopied_asarray(source)
assert is_copied is False
is_copied, src_array = _cp_iscopied_asarray(source.to_cupy())
assert is_copied is False
is_copied, src_array = _cp_iscopied_asarray(source, orig_object=source)
assert is_copied is False
is_copied, src_array = _cp_iscopied_asarray(
source.to_cupy(), orig_object=source
)
assert is_copied is False
# to numpy
is_copied, src_array = _cp_iscopied_asarray(source.to_numpy())
assert is_copied is True
is_copied, src_array = _cp_iscopied_asarray(
source.to_numpy(), orig_object=source
)
assert is_copied is True
else:
with pytest.raises(TypeError):
_cp_iscopied_asarray(source)
with pytest.raises(TypeError):
_cp_iscopied_asarray(source.to_cupy())
with pytest.raises(ValueError, match="Unsupported dtype"):
_cp_iscopied_asarray(source.to_numpy())
with pytest.raises(ValueError, match="Unsupported dtype"):
_cp_iscopied_asarray(source.to_numpy(), orig_object=source)


@pytest.mark.parametrize("data", data)
def test_from_pandas_edgelist(data):
df = pd.DataFrame(data)
nxcg.from_pandas_edgelist(df) # Basic smoke test
source = df["source"]
if source.dtype == int:
is_copied, src_array = _cp_iscopied_asarray(source)
assert is_copied is True
is_copied, src_array = _cp_iscopied_asarray(source, orig_object=source)
assert is_copied is True
is_copied, src_array = _cp_iscopied_asarray(source.to_numpy())
assert is_copied is True
is_copied, src_array = _cp_iscopied_asarray(
source.to_numpy(), orig_object=source
)
assert is_copied is True

0 comments on commit 768b44d

Please sign in to comment.