Skip to content

Commit

Permalink
Removes dependency of tests on pyarrow
Browse files Browse the repository at this point in the history
  • Loading branch information
edgararuiz committed Oct 15, 2024
1 parent a223d8a commit b0d568e
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 40 deletions.
18 changes: 9 additions & 9 deletions python/tests/test_classify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import mall
import polars as pl
import pyarrow
import shutil
import os

Expand All @@ -13,17 +12,18 @@ def test_classify():
df = pl.DataFrame(dict(x=["one", "two", "three"]))
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.classify("x", ["one", "two"])
assert (
x.select("classify").to_pandas().to_string()
== " classify\n0 one\n1 two\n2 None"
)
assert pull(x, "classify") == ["one", "two", None]


def test_classify_dict():
df = pl.DataFrame(dict(x=[1, 2, 3]))
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.classify("x", {"one": 1, "two": 2})
assert (
x.select("classify").to_pandas().to_string()
== " classify\n0 1.0\n1 2.0\n2 NaN"
)
assert pull(x, "classify") == [1, 2, None]


def pull(df, col):
    """Return the values of column ``col`` of ``df`` as a plain Python list.

    The column is selected, converted to row dictionaries with ``to_dicts()``,
    and each row's value for ``col`` is collected in row order, so a test can
    compare a whole column against a list literal.
    """
    # .get() keeps the original lookup semantics: a row without ``col``
    # contributes None instead of raising KeyError.
    return [row.get(col) for row in df.select(col).to_dicts()]
10 changes: 5 additions & 5 deletions python/tests/test_extract.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import pytest
import mall
import polars as pl
import pyarrow

import shutil
import os

# Start every run from a clean cache directory.
# NOTE: os._exists() is a private helper that only checks whether a name is
# defined inside the ``os`` module itself; it never looks at the filesystem,
# so the guard was always False. os.path.exists() performs the intended check.
if os.path.exists("_test_cache"):
    shutil.rmtree("_test_cache", ignore_errors=True)


def test_extract_list():
df = pl.DataFrame(dict(x="x"))
df.llm.use("test", "content", _cache = "_test_cache")
df.llm.use("test", "content", _cache="_test_cache")
x = df.llm.extract("x", ["a", "b"])
assert (
x["extract"][0]
Expand All @@ -20,7 +20,7 @@ def test_extract_list():

def test_extract_dict():
df = pl.DataFrame(dict(x="x"))
df.llm.use("test", "content", _cache = "_test_cache")
df.llm.use("test", "content", _cache="_test_cache")
x = df.llm.extract("x", dict(a="one", b="two"))
assert (
x["extract"][0]
Expand All @@ -30,7 +30,7 @@ def test_extract_dict():

def test_extract_one():
df = pl.DataFrame(dict(x="x"))
df.llm.use("test", "content", _cache = "_test_cache")
df.llm.use("test", "content", _cache="_test_cache")
x = df.llm.extract("x", labels="a")
assert (
x["extract"][0]
Expand Down
24 changes: 10 additions & 14 deletions python/tests/test_sentiment.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest
import mall
import polars as pl
import pyarrow

import shutil
import os

Expand All @@ -15,10 +13,7 @@ def test_sentiment_simple():
reviews = data.reviews
reviews.llm.use("test", "echo", _cache="_test_cache")
x = reviews.llm.sentiment("review")
assert (
x.select("sentiment").to_pandas().to_string()
== " sentiment\n0 None\n1 None\n2 None"
)
assert pull(x, "sentiment") == [None, None, None]


def sim_sentiment():
Expand All @@ -30,19 +25,13 @@ def sim_sentiment():
def test_sentiment_valid():
x = sim_sentiment()
x = x.llm.sentiment("x")
assert (
x.select("sentiment").to_pandas().to_string()
== " sentiment\n0 positive\n1 negative\n2 neutral\n3 None"
)
assert pull(x, "sentiment") == ["positive", "negative", "neutral", None]


def test_sentiment_valid2():
x = sim_sentiment()
x = x.llm.sentiment("x", ["positive", "negative"])
assert (
x.select("sentiment").to_pandas().to_string()
== " sentiment\n0 positive\n1 negative\n2 None\n3 None"
)
assert pull(x, "sentiment") == ["positive", "negative", None, None]


def test_sentiment_prompt():
Expand All @@ -53,3 +42,10 @@ def test_sentiment_prompt():
x["sentiment"][0]
== "You are a helpful sentiment engine. Return only one of the following answers: positive, negative, neutral . No capitalization. No explanations. The answer is based on the following text:\n{}"
)


def pull(df, col):
    """Return the values of column ``col`` of ``df`` as a plain Python list.

    The column is selected, converted to row dictionaries with ``to_dicts()``,
    and each row's value for ``col`` is collected in row order, so a test can
    compare a whole column against a list literal.
    """
    # .get() keeps the original lookup semantics: a row without ``col``
    # contributes None instead of raising KeyError.
    return [row.get(col) for row in df.select(col).to_dicts()]
1 change: 0 additions & 1 deletion python/tests/test_summarize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import mall
import polars as pl
import pyarrow
import shutil
import os

Expand Down
2 changes: 0 additions & 2 deletions python/tests/test_translate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest
import mall
import polars as pl
import pyarrow

import shutil
import os

Expand Down
18 changes: 9 additions & 9 deletions python/tests/test_verify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import mall
import polars as pl
import pyarrow
import shutil
import os

Expand All @@ -13,17 +12,18 @@ def test_verify():
df = pl.DataFrame(dict(x=[1, 1, 0, 2]))
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.verify("x", "this is my question")
assert (
x.select("verify").to_pandas().to_string()
== " verify\n0 1.0\n1 1.0\n2 0.0\n3 NaN"
)
assert pull(x, "verify") == [1, 1, 0, None]


def test_verify_yn():
df = pl.DataFrame(dict(x=["y", "n", "y", "x"]))
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.verify("x", "this is my question", ["y", "n"])
assert (
x.select("verify").to_pandas().to_string()
== " verify\n0 y\n1 n\n2 y\n3 None"
)
assert pull(x, "verify") == ["y", "n", "y", None]


def pull(df, col):
    """Return the values of column ``col`` of ``df`` as a plain Python list.

    The column is selected, converted to row dictionaries with ``to_dicts()``,
    and each row's value for ``col`` is collected in row order, so a test can
    compare a whole column against a list literal.
    """
    # .get() keeps the original lookup semantics: a row without ``col``
    # contributes None instead of raising KeyError.
    return [row.get(col) for row in df.select(col).to_dicts()]

0 comments on commit b0d568e

Please sign in to comment.