From 768b44d57b66e1e6a3cb2779439c8915d1682ee5 Mon Sep 17 00:00:00 2001
From: Erik Welch
Date: Tue, 9 Jul 2024 08:10:55 -0700
Subject: [PATCH] Add basic tests that smoked out a couple issues :)

---
 .../nx-cugraph/nx_cugraph/convert_matrix.py   | 10 ++-
 .../nx_cugraph/tests/test_convert_matrix.py   | 82 +++++++++++++++++++
 2 files changed, 89 insertions(+), 3 deletions(-)
 create mode 100644 python/nx-cugraph/nx_cugraph/tests/test_convert_matrix.py

diff --git a/python/nx-cugraph/nx_cugraph/convert_matrix.py b/python/nx-cugraph/nx_cugraph/convert_matrix.py
index f5fd34a8d37..e00463fbc7b 100644
--- a/python/nx-cugraph/nx_cugraph/convert_matrix.py
+++ b/python/nx-cugraph/nx_cugraph/convert_matrix.py
@@ -42,7 +42,7 @@ def from_pandas_edgelist(
         # Optimistically try to use cupy, but fall back to numpy if necessary
         src_array = src_series.to_cupy()
         dst_array = dst_series.to_cupy()
-    except (AttributeError, ValueError, NotImplementedError):
+    except (AttributeError, TypeError, ValueError, NotImplementedError):
         src_array = src_series.to_numpy()
         dst_array = dst_series.to_numpy()
     try:
@@ -77,9 +77,13 @@ def from_pandas_edgelist(
         src_indices = cp.asarray(np_or_cp.searchsorted(nodes, src_array), index_dtype)
         dst_indices = cp.asarray(np_or_cp.searchsorted(nodes, dst_array), index_dtype)
     else:
-        if not is_src_copied:
+        if is_src_copied:
+            src_indices = src_array
+        else:
             src_indices = cp.array(src_array)
-        if not is_dst_copied:
+        if is_dst_copied:
+            dst_indices = dst_array
+        else:
             dst_indices = cp.array(dst_array)
 
     if not graph_class.is_directed():
diff --git a/python/nx-cugraph/nx_cugraph/tests/test_convert_matrix.py b/python/nx-cugraph/nx_cugraph/tests/test_convert_matrix.py
new file mode 100644
index 00000000000..ee6aca9e0ea
--- /dev/null
+++ b/python/nx-cugraph/nx_cugraph/tests/test_convert_matrix.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pandas as pd
+import pytest
+
+import nx_cugraph as nxcg
+from nx_cugraph.utils import _cp_iscopied_asarray
+
+try:
+    import cudf
+except ModuleNotFoundError:
+    cudf = None
+
+
+data = [
+    {"source": [0, 1], "target": [1, 2]},  # nodes are 0, 1, 2
+    {"source": [0, 1], "target": [1, 3]},  # nodes are 0, 1, 3 (need renumbered!)
+    {"source": ["a", "b"], "target": ["b", "c"]},  # nodes are 'a', 'b', 'c'
+]
+
+
+@pytest.mark.skipif("not cudf")
+@pytest.mark.parametrize("data", data)
+def test_from_cudf_edgelist(data):
+    """Smoke-test cudf input, then pin when `_cp_iscopied_asarray` copies."""
+    frame = cudf.DataFrame(data)
+    nxcg.from_pandas_edgelist(frame)  # Basic smoke test
+    column = frame["source"]
+    if not (column.dtype == int):
+        # Non-integer node IDs: every conversion attempt is expected to raise.
+        with pytest.raises(TypeError):
+            _cp_iscopied_asarray(column)
+        with pytest.raises(TypeError):
+            _cp_iscopied_asarray(column.to_cupy())
+        with pytest.raises(ValueError, match="Unsupported dtype"):
+            _cp_iscopied_asarray(column.to_numpy())
+        with pytest.raises(ValueError, match="Unsupported dtype"):
+            _cp_iscopied_asarray(column.to_numpy(), orig_object=column)
+        return
+    # cudf Series and cupy arrays: asserted to be converted without a copy.
+    copied, _ = _cp_iscopied_asarray(column)
+    assert copied is False
+    copied, _ = _cp_iscopied_asarray(column.to_cupy())
+    assert copied is False
+    copied, _ = _cp_iscopied_asarray(column, orig_object=column)
+    assert copied is False
+    copied, _ = _cp_iscopied_asarray(column.to_cupy(), orig_object=column)
+    assert copied is False
+    # numpy arrays: asserted to be copied.
+    copied, _ = _cp_iscopied_asarray(column.to_numpy())
+    assert copied is True
+    copied, _ = _cp_iscopied_asarray(column.to_numpy(), orig_object=column)
+    assert copied is True
+
+
+@pytest.mark.parametrize("data", data)
+def test_from_pandas_edgelist(data):
+    """Smoke-test pandas input; integer columns must always report a copy."""
+    frame = pd.DataFrame(data)
+    nxcg.from_pandas_edgelist(frame)  # Basic smoke test
+    column = frame["source"]
+    if not (column.dtype == int):
+        return  # non-integer node IDs get the smoke test only
+    # pandas Series and numpy arrays: asserted to be copied.
+    copied, _ = _cp_iscopied_asarray(column)
+    assert copied is True
+    copied, _ = _cp_iscopied_asarray(column, orig_object=column)
+    assert copied is True
+    copied, _ = _cp_iscopied_asarray(column.to_numpy())
+    assert copied is True
+    copied, _ = _cp_iscopied_asarray(column.to_numpy(), orig_object=column)
+    assert copied is True