diff --git a/python/tests/test_classify.py b/python/tests/test_classify.py index 01a41be..f31fae1 100644 --- a/python/tests/test_classify.py +++ b/python/tests/test_classify.py @@ -1,7 +1,6 @@ import pytest import mall import polars as pl -import pyarrow import shutil import os @@ -13,17 +12,18 @@ def test_classify(): df = pl.DataFrame(dict(x=["one", "two", "three"])) df.llm.use("test", "echo", _cache="_test_cache") x = df.llm.classify("x", ["one", "two"]) - assert ( - x.select("classify").to_pandas().to_string() - == " classify\n0 one\n1 two\n2 None" - ) + assert pull(x, "classify") == ["one", "two", None] def test_classify_dict(): df = pl.DataFrame(dict(x=[1, 2, 3])) df.llm.use("test", "echo", _cache="_test_cache") x = df.llm.classify("x", {"one": 1, "two": 2}) - assert ( - x.select("classify").to_pandas().to_string() - == " classify\n0 1.0\n1 2.0\n2 NaN" - ) + assert pull(x, "classify") == [1, 2, None] + + +def pull(df, col): + out = [] + for i in df.select(col).to_dicts(): + out.append(i.get(col)) + return out diff --git a/python/tests/test_extract.py b/python/tests/test_extract.py index a320896..5c58424 100644 --- a/python/tests/test_extract.py +++ b/python/tests/test_extract.py @@ -1,16 +1,16 @@ import pytest import mall import polars as pl -import pyarrow - import shutil import os + if os._exists("_test_cache"): shutil.rmtree("_test_cache", ignore_errors=True) + def test_extract_list(): df = pl.DataFrame(dict(x="x")) - df.llm.use("test", "content", _cache = "_test_cache") + df.llm.use("test", "content", _cache="_test_cache") x = df.llm.extract("x", ["a", "b"]) assert ( x["extract"][0] @@ -20,7 +20,7 @@ def test_extract_list(): def test_extract_dict(): df = pl.DataFrame(dict(x="x")) - df.llm.use("test", "content", _cache = "_test_cache") + df.llm.use("test", "content", _cache="_test_cache") x = df.llm.extract("x", dict(a="one", b="two")) assert ( x["extract"][0] @@ -30,7 +30,7 @@ def test_extract_dict(): def test_extract_one(): df = pl.DataFrame(dict(x="x")) - df.llm.use("test", "content", _cache = "_test_cache") + df.llm.use("test", "content", _cache="_test_cache") x = df.llm.extract("x", labels="a") assert ( x["extract"][0] diff --git a/python/tests/test_sentiment.py b/python/tests/test_sentiment.py index 2e3f711..a21699b 100644 --- a/python/tests/test_sentiment.py +++ b/python/tests/test_sentiment.py @@ -1,8 +1,6 @@ import pytest import mall import polars as pl -import pyarrow - import shutil import os @@ -15,10 +13,7 @@ def test_sentiment_simple(): reviews = data.reviews reviews.llm.use("test", "echo", _cache="_test_cache") x = reviews.llm.sentiment("review") - assert ( - x.select("sentiment").to_pandas().to_string() - == " sentiment\n0 None\n1 None\n2 None" - ) + assert pull(x, "sentiment") == [None, None, None] def sim_sentiment(): @@ -30,19 +25,13 @@ def sim_sentiment(): def test_sentiment_valid(): x = sim_sentiment() x = x.llm.sentiment("x") - assert ( - x.select("sentiment").to_pandas().to_string() - == " sentiment\n0 positive\n1 negative\n2 neutral\n3 None" - ) + assert pull(x, "sentiment") == ["positive", "negative", "neutral", None] def test_sentiment_valid2(): x = sim_sentiment() x = x.llm.sentiment("x", ["positive", "negative"]) - assert ( - x.select("sentiment").to_pandas().to_string() - == " sentiment\n0 positive\n1 negative\n2 None\n3 None" - ) + assert pull(x, "sentiment") == ["positive", "negative", None, None] def test_sentiment_prompt(): @@ -53,3 +42,10 @@ def test_sentiment_prompt(): x["sentiment"][0] == "You are a helpful sentiment engine. Return only one of the following answers: positive, negative, neutral . No capitalization. No explanations. The answer is based on the following text:\n{}" ) + + +def pull(df, col): + out = [] + for i in df.select(col).to_dicts(): + out.append(i.get(col)) + return out diff --git a/python/tests/test_summarize.py b/python/tests/test_summarize.py index e2182d4..6d28578 100644 --- a/python/tests/test_summarize.py +++ b/python/tests/test_summarize.py @@ -1,7 +1,6 @@ import pytest import mall import polars as pl -import pyarrow import shutil import os diff --git a/python/tests/test_translate.py b/python/tests/test_translate.py index 5118d88..1230688 100644 --- a/python/tests/test_translate.py +++ b/python/tests/test_translate.py @@ -1,8 +1,6 @@ import pytest import mall import polars as pl -import pyarrow - import shutil import os diff --git a/python/tests/test_verify.py b/python/tests/test_verify.py index 58421e7..e520ab9 100644 --- a/python/tests/test_verify.py +++ b/python/tests/test_verify.py @@ -1,7 +1,6 @@ import pytest import mall import polars as pl -import pyarrow import shutil import os @@ -13,17 +12,18 @@ def test_verify(): df = pl.DataFrame(dict(x=[1, 1, 0, 2])) df.llm.use("test", "echo", _cache="_test_cache") x = df.llm.verify("x", "this is my question") - assert ( - x.select("verify").to_pandas().to_string() - == " verify\n0 1.0\n1 1.0\n2 0.0\n3 NaN" - ) + assert pull(x, "verify") == [1, 1, 0, None] def test_verify_yn(): df = pl.DataFrame(dict(x=["y", "n", "y", "x"])) df.llm.use("test", "echo", _cache="_test_cache") x = df.llm.verify("x", "this is my question", ["y", "n"]) - assert ( - x.select("verify").to_pandas().to_string() - == " verify\n0 y\n1 n\n2 y\n3 None" - ) + assert pull(x, "verify") == ["y", "n", "y", None] + + +def pull(df, col): + out = [] + for i in df.select(col).to_dicts(): + out.append(i.get(col)) + return out