diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs
index 5db783f5e1..da06f04887 100644
--- a/src/daft-connect/src/op/execute/write.rs
+++ b/src/daft-connect/src/op/execute/write.rs
@@ -55,9 +55,7 @@ impl Session {
                 bail!("Source is required");
             };
 
-            if source != "parquet" {
-                bail!("Unsupported source: {source}; only parquet is supported");
-            }
+            let file_format: FileFormat = source.parse()?;
 
             let Ok(mode) = SaveMode::try_from(mode) else {
                 bail!("Invalid save mode: {mode}");
@@ -115,7 +113,7 @@ impl Session {
             let plan = translator.to_logical_plan(input).await?;
 
             let plan = plan
-                .table_write(&path, FileFormat::Parquet, None, None, None)
+                .table_write(&path, file_format, None, None, None)
                 .wrap_err("Failed to create table write plan")?;
 
             let optimized_plan = plan.optimize()?;
diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
index 90164dd0bd..863b5e8f1d 100644
--- a/src/daft-connect/src/translation/logical_plan/read/data_source.rs
+++ b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
@@ -1,5 +1,5 @@
 use daft_logical_plan::LogicalPlanBuilder;
-use daft_scan::builder::ParquetScanBuilder;
+use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder};
 use eyre::{bail, ensure, WrapErr};
 use tracing::warn;
 
@@ -18,10 +18,6 @@ pub async fn data_source(
         bail!("Format is required");
     };
 
-    if format != "parquet" {
-        bail!("Unsupported format: {format}; only parquet is supported");
-    }
-
     ensure!(!paths.is_empty(), "Paths are required");
 
     if let Some(schema) = schema {
@@ -36,10 +32,23 @@ pub async fn data_source(
         warn!("Ignoring predicates: {predicates:?}; not yet implemented");
     }
 
-    let builder = ParquetScanBuilder::new(paths)
-        .finish()
-        .await
-        .wrap_err("Failed to create parquet scan builder")?;
+    let plan = match &*format {
+        "parquet" => ParquetScanBuilder::new(paths)
+            .finish()
+            .await
+            .wrap_err("Failed to create parquet scan builder")?,
+        "csv" => CsvScanBuilder::new(paths)
+            .finish()
+            .await
+            .wrap_err("Failed to create csv scan builder")?,
+        "json" => {
+            // todo(completeness): implement json reading
+            bail!("json reading is not yet implemented");
+        }
+        other => {
+            bail!("Unsupported format: {other}; only parquet and csv are supported");
+        }
+    };
 
-    Ok(builder)
+    Ok(plan)
 }
diff --git a/tests/connect/test_csv.py b/tests/connect/test_csv.py
new file mode 100644
index 0000000000..7e957dd394
--- /dev/null
+++ b/tests/connect/test_csv.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+import os
+
+import pytest
+
+
+def test_write_csv_basic(spark_session, tmp_path):
+    df = spark_session.range(10)
+    csv_dir = os.path.join(tmp_path, "csv")
+    df.write.csv(csv_dir)
+
+    csv_files = [f for f in os.listdir(csv_dir) if f.endswith(".csv")]
+    assert len(csv_files) > 0, "Expected at least one CSV file to be written"
+
+    df_read = spark_session.read.csv(str(csv_dir))
+    df_pandas = df.toPandas()
+    df_read_pandas = df_read.toPandas()
+    assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"
+
+
+def test_write_csv_with_header(spark_session, tmp_path):
+    df = spark_session.range(10)
+    csv_dir = os.path.join(tmp_path, "csv")
+    df.write.option("header", True).csv(csv_dir)
+
+    df_read = spark_session.read.option("header", True).csv(str(csv_dir))
+    df_pandas = df.toPandas()
+    df_read_pandas = df_read.toPandas()
+    assert df_pandas["id"].equals(df_read_pandas["id"])
+
+
+def test_write_csv_with_delimiter(spark_session, tmp_path):
+    df = spark_session.range(10)
+    csv_dir = os.path.join(tmp_path, "csv")
+    df.write.option("sep", "|").csv(csv_dir)
+
+    df_read = spark_session.read.option("sep", "|").csv(str(csv_dir))
+    df_pandas = df.toPandas()
+    df_read_pandas = df_read.toPandas()
+    assert df_pandas["id"].equals(df_read_pandas["id"])
+
+
+def test_write_csv_with_quote(spark_session, tmp_path):
+    df = spark_session.createDataFrame([("a,b",), ("c'd",)], ["text"])
+    csv_dir = os.path.join(tmp_path, "csv")
+    df.write.option("quote", "'").csv(csv_dir)
+
+    df_read = spark_session.read.option("quote", "'").csv(str(csv_dir))
+    df_pandas = df.toPandas()
+    df_read_pandas = df_read.toPandas()
+    assert df_pandas["text"].equals(df_read_pandas["text"])
+
+
+def test_write_csv_with_escape(spark_session, tmp_path):
+    df = spark_session.createDataFrame([("a'b",), ("c'd",)], ["text"])
+    csv_dir = os.path.join(tmp_path, "csv")
+    df.write.option("escape", "\\").csv(csv_dir)
+
+    df_read = spark_session.read.option("escape", "\\").csv(str(csv_dir))
+    df_pandas = df.toPandas()
+    df_read_pandas = df_read.toPandas()
+    assert df_pandas["text"].equals(df_read_pandas["text"])
+
+
+@pytest.mark.skip(
+    reason="https://github.com/Eventual-Inc/Daft/issues/3609: CSV null value handling not yet implemented"
+)
+def test_write_csv_with_null_value(spark_session, tmp_path):
+    df = spark_session.createDataFrame([(1, None), (2, "test")], ["id", "value"])
+    csv_dir = os.path.join(tmp_path, "csv")
+    df.write.option("nullValue", "NULL").csv(csv_dir)
+
+    df_read = spark_session.read.option("nullValue", "NULL").csv(str(csv_dir))
+    df_pandas = df.toPandas()
+    df_read_pandas = df_read.toPandas()
+    assert df_pandas["value"].isna().equals(df_read_pandas["value"].isna())
+
+
+def test_write_csv_with_compression(spark_session, tmp_path):
+    df = spark_session.range(10)
+    csv_dir = os.path.join(tmp_path, "csv")
+    df.write.option("compression", "gzip").csv(csv_dir)
+
+    df_read = spark_session.read.csv(str(csv_dir))
+    df_pandas = df.toPandas()
+    df_read_pandas = df_read.toPandas()
+    assert df_pandas["id"].equals(df_read_pandas["id"])
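
For context, a minimal client-side sketch of what this change enables, assuming a PySpark Spark Connect session pointed at a running Daft connect server; the remote URL and output path below are illustrative and not part of this diff:

from pyspark.sql import SparkSession

# Hypothetical endpoint; use wherever the Daft connect server is listening.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.range(5)

# "csv" is now accepted as a write source (previously only "parquet"),
# dispatched via FileFormat parsing in write.rs.
df.write.option("header", True).csv("/tmp/daft_csv_example")

# The read side is handled by CsvScanBuilder in data_source.rs.
spark.read.option("header", True).csv("/tmp/daft_csv_example").show()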