20 changes: 11 additions & 9 deletions tests/tpch/conftest.py
@@ -63,8 +63,8 @@ def dataset_path(local, scale):
     }
     local_paths = {
         1: "./tpch-data/scale-1/",
-        10: "./tpch-data/scale10/",
-        100: "./tpch-data/scale100/",
+        10: "./tpch-data/scale-10/",
+        100: "./tpch-data/scale-100/",
     }

     if local:
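
Note: the rename makes every local directory follow the same scale-{N} pattern. A minimal sketch of deriving the path from the scale factor instead of hard-coding one dict entry per scale (the helper name is hypothetical, not part of this diff):

# Hypothetical helper, for illustration only: one naming rule
# instead of a hard-coded dict entry per scale factor.
def local_path(scale: int) -> str:
    return f"./tpch-data/scale-{scale}/"

assert local_path(10) == "./tpch-data/scale-10/"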
@@ -167,7 +167,7 @@ def cluster_spec(scale):
     if scale == 10:
         return {
             "worker_vm_types": ["m6i.large"],
-            "n_workers": 16,
+            "n_workers": 8,
             **everywhere,
         }
     elif scale == 100:
@@ -178,7 +178,7 @@ def cluster_spec(scale):
         }
     elif scale == 1000:
         return {
-            "worker_vm_types": ["m6i.large"],
+            "worker_vm_types": ["m6i.xlarge"],
             "n_workers": 32,
             **everywhere,
         }
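
Note: going by AWS's published specs (m6i.large: 2 vCPUs / 8 GiB, m6i.xlarge: 4 vCPUs / 16 GiB), this doubles per-worker resources for scale 1000 while keeping 32 workers. Back-of-envelope arithmetic for the new cluster:

# Rough totals only; instance specs from AWS's m6i size table.
n_workers, vcpus, mem_gib = 32, 4, 16  # m6i.xlarge
print(n_workers * vcpus, "vCPUs,", n_workers * mem_gib, "GiB")  # 128 vCPUs, 512 GiB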
@@ -203,8 +203,9 @@ def cluster(
     make_chart,
 ):
     if local:
-        with LocalCluster() as cluster:
-            yield cluster
+        with dask.config.set({"distributed.scheduler.worker-saturation": 4}):
+            with LocalCluster() as cluster:
+                yield cluster
     else:
         kwargs = dict(
             name=f"tpch-{module}-{scale}-{name}",
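
Note: distributed.scheduler.worker-saturation controls how many root tasks the scheduler queues on each worker relative to its thread count (the default is 1.1); raising it to 4 keeps a small local cluster saturated between batches at the cost of more data held in flight. A minimal standalone sketch of the same override, outside the fixture:

# The config must be set before the cluster is created
# for the scheduler to pick it up.
import dask
from dask.distributed import LocalCluster

with dask.config.set({"distributed.scheduler.worker-saturation": 4}):
    with LocalCluster() as cluster, cluster.get_client() as client:
        print(client.submit(sum, range(100)).result())  # 4950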
@@ -317,15 +318,15 @@ def fs(local):
 def machine_spec(scale):
     if scale == 10:
         return {
-            "vm_type": "m6i.8xlarge",
+            "vm_type": "m6i.4xlarge",
         }
     elif scale == 100:
         return {
             "vm_type": "m6i.8xlarge",
         }
     elif scale == 1000:
         return {
-            "vm_type": "m6i.16xlarge",
+            "vm_type": "m6i.32xlarge",
         }
     elif scale == 10000:
         return {
@@ -399,7 +400,8 @@ def make_chart(request, name, tmp_path_factory, local, scale):

     with lock:
         generate(
-            outfile=os.path.join("charts", f"{local}-{scale}-query-{name}.json"),
+            outfile=os.path.join("charts", f"{local}-{scale}-{name}.json"),
+            name=name,
             scale=scale,
             local=local,
         )
19 changes: 14 additions & 5 deletions tests/tpch/generate-data.py
@@ -1,4 +1,5 @@
 import functools
+import multiprocessing
 import pathlib
 import tempfile
 import warnings
@@ -49,14 +50,22 @@ def generate(
             # workload is best with 1vCPU and ~3-4GiB memory
             worker_vm_types=["m7a.medium", "m3.medium"],
             worker_options={"nthreads": 1},
+            spot_policy="spot_with_fallback",
             region=REGION,
         ) as cluster:
-            cluster.adapt(minimum=1, maximum=350)
+            cluster.adapt(minimum=1, maximum=500)
             with cluster.get_client() as client:
                 jobs = client.map(_tpch_data_gen, range(0, scale), **kwargs)
                 client.gather(jobs)
     else:
-        _tpch_data_gen(step=None, **kwargs)
+        with dask.distributed.LocalCluster(
+            threads_per_worker=1,
+            memory_limit=dask.distributed.system.MEMORY_LIMIT,
+            n_workers=multiprocessing.cpu_count() // 2,
+        ) as cluster:
+            with cluster.get_client() as client:
+                jobs = client.map(_tpch_data_gen, range(0, scale), **kwargs)
+                client.gather(jobs)


 def retry(f):
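
Note: the local branch now mirrors the cloud branch, mapping _tpch_data_gen over a LocalCluster instead of running a single in-process call. cpu_count() counts hyperthreads, so // 2 approximates one single-threaded worker per physical core; and since memory_limit applies per worker, handing each worker the full system MEMORY_LIMIT effectively disables per-worker memory caps. A hedged sketch of the same pattern with a stand-in task:

# Stand-in task; the real code maps _tpch_data_gen over range(0, scale).
import multiprocessing

from dask.distributed import LocalCluster


def fake_step(step: int) -> int:
    return step * step


if __name__ == "__main__":  # guard needed for process-based workers
    with LocalCluster(
        threads_per_worker=1,
        n_workers=multiprocessing.cpu_count() // 2,
    ) as cluster, cluster.get_client() as client:
        print(client.gather(client.map(fake_step, range(10))))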
@@ -116,8 +125,8 @@ def _tpch_data_gen(
     con.sql(
         f"""
         SET memory_limit='{psutil.virtual_memory().available // 2**30 }G';
-        SET preserve_insertion_order=false;
         SET threads TO 1;
+        SET preserve_insertion_order=false;
         SET enable_progress_bar=false;
         """
     )
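
Note: reordering the SET statements is cosmetic; each applies to the session independently. preserve_insertion_order=false lets DuckDB drop row-order guarantees, which reduces memory pressure during large COPY ... TO exports. A standalone sketch of the same settings against a throwaway connection (the 4G limit here is an arbitrary example):

import duckdb

con = duckdb.connect()
con.sql(
    """
    SET memory_limit='4G';
    SET threads TO 1;
    SET preserve_insertion_order=false;
    SET enable_progress_bar=false;
    """
)
print(con.sql("SELECT current_setting('threads')").fetchone())  # (1,)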
@@ -166,8 +175,8 @@ def _tpch_data_gen(
             (format parquet, per_thread_output true, filename_pattern "{table}_{{uuid}}", overwrite_or_ignore)
             """
         )
-        print(f"Finished exporting table {table}!")
-    print("Finished exporting all data!")
+        print(f"Finished exporting table {table}")
+    print("Finished exporting all data")


 def rows_approx_mb(con, table_name, partition_size: str):
6 changes: 4 additions & 2 deletions tests/tpch/generate_plot.py
@@ -3,7 +3,7 @@
 import pandas as pd


-def generate(outfile="chart.json", name=None, scale=None):
+def generate(outfile="chart.json", name=None, scale=None, local=None):
     df = pd.read_sql_table(table_name="test_run", con="sqlite:///benchmark.db")

     df = df[
@@ -47,7 +47,9 @@ def recent(df):
         ),
         tooltip=["library", "duration"],
     )
-    .properties(title=f"TPC-H -- scale:{df.scale.iloc[0]} name:{df.name.iloc[0]}")
+    .properties(
+        title=f"TPC-H: {local} scale {df.scale.iloc[0]} -- {df.name.iloc[0]}"
+    )
     .configure_title(
         fontSize=20,
     )
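
Note: with local threaded through, the chart title can distinguish local from cloud runs. A hypothetical call (all argument values invented for illustration; whether local is a bool or a label depends on the fixture, and generate reads actual results from benchmark.db):

generate(
    outfile="charts/local-10-query_1.json",
    name="query_1",
    scale=10,
    local="local",
)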
46 changes: 23 additions & 23 deletions tests/tpch/test_dask.py
@@ -5,7 +5,7 @@

 def test_query_1(client, dataset_path, fs):
     VAR1 = datetime(1998, 9, 2)
-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem")

     lineitem_filtered = lineitem_ds[lineitem_ds.l_shipdate <= VAR1]
     lineitem_filtered["sum_qty"] = lineitem_filtered.l_quantity
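
Note: dropping filesystem=fs relies on dd.read_parquet's default behavior of inferring the filesystem from the path scheme via fsspec, so plain local paths and s3:// URLs both resolve without the fixture (fs stays in the test signatures). A minimal sketch with illustrative paths:

import dask.dataframe as dd

df = dd.read_parquet("./tpch-data/scale-10/lineitem")  # local, inferred by fsspec
# df = dd.read_parquet("s3://some-bucket/tpch/lineitem")  # s3, inferred the same way
print(df.head())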
@@ -45,11 +45,11 @@ def test_query_2(client, dataset_path, fs):
     var2 = "BRASS"
     var3 = "EUROPE"

-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_filtered = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    supplier_filtered = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
-    part_filtered = dd.read_parquet(dataset_path + "part", filesystem=fs)
-    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region")
+    nation_filtered = dd.read_parquet(dataset_path + "nation")
+    supplier_filtered = dd.read_parquet(dataset_path + "supplier")
+    part_filtered = dd.read_parquet(dataset_path + "part")
+    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp")

     region_filtered = region_ds[(region_ds["r_name"] == var3)]
     r_n_merged = nation_filtered.merge(
@@ -112,9 +112,9 @@ def test_query_3(client, dataset_path, fs):
     var1 = datetime.strptime("1995-03-15", "%Y-%m-%d")
     var2 = "BUILDING"

-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    cutomer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    cutomer_ds = dd.read_parquet(dataset_path + "customer")

     lsel = lineitem_ds.l_shipdate > var1
     osel = orders_ds.o_orderdate < var1
@@ -137,8 +137,8 @@ def test_query_4(client, dataset_path, fs):
     date1 = datetime.strptime("1993-10-01", "%Y-%m-%d")
     date2 = datetime.strptime("1993-07-01", "%Y-%m-%d")

-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")

     lsel = line_item_ds.l_commitdate < line_item_ds.l_receiptdate
     osel = (orders_ds.o_orderdate < date1) & (orders_ds.o_orderdate >= date2)
@@ -160,12 +160,12 @@ def test_query_5(client, dataset_path, fs):
     date1 = datetime.strptime("1994-01-01", "%Y-%m-%d")
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")

-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region")
+    nation_ds = dd.read_parquet(dataset_path + "nation")
+    customer_ds = dd.read_parquet(dataset_path + "customer")
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    supplier_ds = dd.read_parquet(dataset_path + "supplier")

     rsel = region_ds.r_name == "ASIA"
     osel = (orders_ds.o_orderdate >= date1) & (orders_ds.o_orderdate < date2)
@@ -190,7 +190,7 @@ def test_query_6(client, dataset_path, fs):
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var3 = 24

-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")

     sel = (
         (line_item_ds.l_shipdate >= date1)
@@ -208,11 +208,11 @@ def test_query_7(client, dataset_path, fs):
     var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")

-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    nation_ds = dd.read_parquet(dataset_path + "nation")
+    customer_ds = dd.read_parquet(dataset_path + "customer")
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    supplier_ds = dd.read_parquet(dataset_path + "supplier")

     lineitem_filtered = line_item_ds[
         (line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
8 changes: 4 additions & 4 deletions tests/tpch/test_polars.py
@@ -7,20 +7,20 @@


 def read_data(filename):
-    pyarrow_dataset = dataset(filename, format="parquet")
-    return pl.scan_pyarrow_dataset(pyarrow_dataset)
-
     if filename.startswith("s3://"):
+        pyarrow_dataset = dataset(filename, format="parquet")
+        return pl.scan_pyarrow_dataset(pyarrow_dataset)
         import boto3

         session = boto3.session.Session()
         credentials = session.get_credentials()
         return pl.scan_parquet(
-            filename,
+            filename + "/*",
             storage_options={
                 "aws_access_key_id": credentials.access_key,
                 "aws_secret_access_key": credentials.secret_key,
                 "region": "us-east-2",
                 "session_token": credentials.token,
             },
         )
     else:
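
Note: after this change the s3:// branch appears to return early through scan_pyarrow_dataset, leaving the boto3-based scan_parquet path below it unreachable; pl.scan_pyarrow_dataset wraps a pyarrow Dataset lazily and handles local directories and s3:// URLs alike through pyarrow's filesystem layer. A minimal sketch (path illustrative):

import polars as pl
from pyarrow.dataset import dataset

lazy = pl.scan_pyarrow_dataset(dataset("./tpch-data/scale-10/lineitem", format="parquet"))
print(lazy.head(5).collect())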