Skip to content

Commit 87810f1

Browse files
committed
Adds create_temp_table to allow using dataframes in sessions
1 parent 8706dcf commit 87810f1

File tree

11 files changed

+237
-68
lines changed

11 files changed

+237
-68
lines changed

daft/catalog/__init__.py

+75-9
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,16 @@
4242

4343
from abc import ABC, abstractmethod
4444
from collections.abc import Sequence
45-
from daft.daft import catalog as native_catalog
45+
from daft.daft import PyTableSource, catalog as native_catalog
4646
from daft.daft import PyIdentifier, PyTable
4747
from daft.logical.builder import LogicalPlanBuilder
4848

4949
from daft.dataframe import DataFrame
5050

5151
from typing import TYPE_CHECKING
5252

53+
from daft.logical.schema import Schema
54+
5355
if TYPE_CHECKING:
5456
from daft.dataframe.dataframe import ColumnInputType
5557

@@ -286,9 +288,11 @@ class Table(ABC):
286288
"""Interface for python table implementations."""
287289

288290
@staticmethod
289-
def from_df(dataframe: DataFrame) -> Table:
291+
def from_df(name: str, dataframe: DataFrame) -> Table:
290292
"""Returns a read-only table backed by the DataFrame."""
291-
return PyTable.from_builder(dataframe._builder._builder)
293+
from daft.catalog.__memory import MemoryTable
294+
295+
return MemoryTable(name, dataframe)
292296

293297
@staticmethod
294298
def from_iceberg(obj: object) -> Table:
@@ -314,10 +318,24 @@ def from_unity(obj: object) -> Table:
314318
@staticmethod
315319
def _from_obj(obj: object) -> Table:
316320
"""Returns a Daft Table from a supported object type or raises an error."""
317-
if isinstance(obj, DataFrame):
318-
return Table.from_df(obj)
319321
raise ValueError(f"Unsupported table type: {type(obj)}")
320322

323+
# TODO catalog APIs part 3
324+
# @property
325+
# @abstractmethod
326+
# def name(self) -> str:
327+
# """Returns the table name."""
328+
329+
# TODO catalog APIs part 3
330+
# @property
331+
# @abstractmethod
332+
# def inner(self) -> object | None:
333+
# """Returns the inner table object if this is an adapter."""
334+
335+
@abstractmethod
336+
def read(self) -> DataFrame:
337+
"""Returns a DataFrame from this table."""
338+
321339
# TODO deprecated catalog APIs #3819
322340
def to_dataframe(self) -> DataFrame:
323341
"""DEPRECATED: Please use `read` instead; version 0.5.0!"""
@@ -327,14 +345,62 @@ def to_dataframe(self) -> DataFrame:
327345
)
328346
return self.read()
329347

330-
@abstractmethod
331-
def read(self) -> DataFrame:
332-
"""Returns a DataFrame from this table."""
333-
334348
def select(self, *columns: ColumnInputType) -> DataFrame:
335349
"""Returns a DataFrame from this table with the selected columns."""
336350
return self.read().select(*columns)
337351

338352
def show(self, n: int = 8) -> None:
339353
"""Shows the first n rows from this table."""
340354
return self.read().show(n)
355+
356+
357+
class TableSource:
    """A TableSource is used to create a new table; this could be a Schema or DataFrame.

    Instances are built through the static factory methods rather than
    ``__init__``; each factory populates ``_source`` with the native
    ``PyTableSource`` handle that is handed to the session.
    """

    # Native table source wrapped by this object.
    _source: PyTableSource

    def __init__(self) -> None:
        """Always raises; use one of the ``from_*`` factory methods instead."""
        raise ValueError("We do not support creating a TableSource via __init__")

    @staticmethod
    def from_df(df: DataFrame) -> TableSource:
        """Returns a TableSource backed by the DataFrame's logical plan builder."""
        # Bypass __init__, which intentionally raises.
        source = TableSource.__new__(TableSource)
        source._source = PyTableSource.from_builder(df._builder._builder)
        return source

    @staticmethod
    def _from_obj(obj: object = None) -> TableSource:
        """Returns a TableSource from a supported object type.

        Supported inputs are ``None``, a ``DataFrame``, a ``str`` path and a
        ``Schema``; each is dispatched to the matching factory.

        Raises:
            ValueError: if ``obj`` is not one of the supported types.
        """
        # TODO for future sources, consider https://github.com/Eventual-Inc/Daft/pull/2864
        if obj is None:
            return TableSource._from_none()
        if isinstance(obj, DataFrame):
            return TableSource.from_df(obj)
        if isinstance(obj, str):
            return TableSource._from_path(obj)
        if isinstance(obj, Schema):
            return TableSource._from_schema(obj)
        # ValueError matches Table._from_obj and is still caught by `except Exception`.
        raise ValueError(f"Unknown table source: {obj}")

    @staticmethod
    def _from_none() -> TableSource:
        """Returns the TableSource used when no source object is given."""
        # TODO: once mutable temp tables exist, build this from PyTableSource.empty().
        # Temporary workaround: just use an empty schema.
        return TableSource._from_schema(Schema._from_fields([]))

    @staticmethod
    def _from_schema(schema: Schema) -> TableSource:
        """Returns a TableSource for the given schema.

        NOTE: ``schema`` is currently ignored. There are no mutable temp tables
        yet, so the result is a view over an empty DataFrame.
        """
        # TODO: once create_table is wired, use PyTableSource.from_schema(schema._schema).
        return TableSource.from_df(DataFrame._from_pylist([]))

    @staticmethod
    def _from_path(path: str) -> TableSource:
        """Not yet supported; always raises ``NotImplementedError``."""
        # for supporting daft.create_table("t", "/path/to/data") <-> CREATE TABLE t AS '/path/to/my.data'
        raise NotImplementedError("creating a table source from a path is not yet supported.")

daft/catalog/__memory.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ def get_table(self, name: str | Identifier) -> Table:
4141
class MemoryTable(Table):
4242
"""An in-memory table holds a reference to an existing dataframe."""
4343

44+
_name: str
4445
_inner: DataFrame
4546

46-
def __init__(self, inner: DataFrame):
47+
def __init__(self, name: str, inner: DataFrame):
48+
self._name = name
4749
self._inner = inner
4850

4951
###

daft/daft/__init__.pyi

+5-2
Original file line numberDiff line numberDiff line change
@@ -1867,11 +1867,13 @@ class PyIdentifier:
18671867
def __repr__(self) -> str: ...
18681868

18691869
class PyTable(Table):
1870+
def read(self):
1871+
LogicalPlanBuilder
1872+
1873+
class PyTableSource:
18701874
@staticmethod
18711875
def from_builder(builder: LogicalPlanBuilder):
18721876
PyTable
1873-
def read(self):
1874-
LogicalPlanBuilder
18751877

18761878
###
18771879
# daft-session
@@ -1885,6 +1887,7 @@ class PySession:
18851887
def attach_table(self, table: Table, alias: str): ...
18861888
def detach_catalog(self, alias: str): ...
18871889
def detach_table(self, alias: str): ...
1890+
def create_temp_table(self, name: str, source: PyTableSource, replace: bool): ...
18881891
def current_catalog(self) -> Catalog: ...
18891892
def get_catalog(self, name: str) -> Catalog: ...
18901893
def get_table(self, name: PyIdentifier) -> Table: ...

daft/session.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING
44

5-
from daft.catalog import Catalog, Identifier, Table
5+
from daft.catalog import Catalog, Identifier, Table, TableSource
66
from daft.daft import PySession
77

88
if TYPE_CHECKING:
@@ -15,19 +15,12 @@ class Session:
1515
_session: PySession
1616

1717
def __init__(self):
18-
raise NotImplementedError("We do not support creating a Session via __init__ ")
18+
self._session = PySession.empty()
1919

2020
###
2121
# factory methods
2222
###
2323

24-
@staticmethod
25-
def empty() -> Session:
26-
"""Creates an empty session."""
27-
s = Session.__new__(Session)
28-
s._session = PySession.empty()
29-
return s
30-
3124
@staticmethod
3225
def _from_pysession(session: PySession) -> Session:
3326
"""Creates a session from a rust session wrapper."""
@@ -39,7 +32,7 @@ def _from_pysession(session: PySession) -> Session:
3932
def _from_env() -> Session:
4033
"""Creates a session from the environment's configuration."""
4134
# todo session builders, raise if DAFT_SESSION=0
42-
return Session.empty()
35+
return Session()
4336

4437
###
4538
# attach & detach
@@ -63,6 +56,15 @@ def detach_table(self, alias: str):
6356
"""Detaches the table from this session."""
6457
return self._session.detach_table(alias)
6558

59+
###
60+
# create_*
61+
###
62+
63+
def create_temp_table(self, name: str, source: object | TableSource = None) -> Table:
    """Creates a temp table scoped to this session's lifetime.

    ``source`` may already be a ``TableSource``; any other object (including
    ``None``) is converted with ``TableSource._from_obj``. The underlying
    session call is always made with ``replace=True``.
    """
    if isinstance(source, TableSource):
        table_source = source
    else:
        table_source = TableSource._from_obj(source)
    return self._session.create_temp_table(name, table_source._source, replace=True)
67+
6668
###
6769
# session state
6870
###

src/daft-catalog/src/python.rs

+34-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use std::sync::Arc;
22

3-
use daft_core::prelude::SchemaRef;
3+
use daft_core::{prelude::SchemaRef, python::PySchema};
44
use daft_logical_plan::{LogicalPlanRef, PyLogicalPlanBuilder};
55
use pyo3::{exceptions::PyIndexError, intern, prelude::*};
66

77
use crate::{
8-
error::Result, global_catalog, Catalog, CatalogRef, Identifier, Table, TableRef, View,
8+
error::Result, global_catalog, Catalog, CatalogRef, Identifier, Table, TableRef, TableSource,
9+
View,
910
};
1011

1112
/// Read a table from the specified `DaftMetaCatalog`.
@@ -277,10 +278,41 @@ impl Table for PyTableWrapper {
277278
}
278279
}
279280

281+
/// PyTableSource wraps either a schema or dataframe.
282+
#[pyclass]
283+
pub struct PyTableSource(TableSource);
284+
285+
impl From<TableSource> for PyTableSource {
286+
fn from(source: TableSource) -> Self {
287+
Self(source)
288+
}
289+
}
290+
291+
/// PyTableSource -> TableSource
292+
impl AsRef<TableSource> for PyTableSource {
293+
fn as_ref(&self) -> &TableSource {
294+
&self.0
295+
}
296+
}
297+
298+
#[pymethods]
299+
impl PyTableSource {
300+
#[staticmethod]
301+
pub fn from_schema(schema: PySchema) -> PyTableSource {
302+
Self(TableSource::Schema(schema.schema))
303+
}
304+
305+
#[staticmethod]
306+
pub fn from_builder(view: &PyLogicalPlanBuilder) -> PyTableSource {
307+
Self(TableSource::View(view.builder.build()))
308+
}
309+
}
310+
280311
pub fn register_modules<'py>(parent: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyModule>> {
281312
parent.add_class::<PyCatalog>()?;
282313
parent.add_class::<PyIdentifier>()?;
283314
parent.add_class::<PyTable>()?;
315+
parent.add_class::<PyTableSource>()?;
284316
// TODO deprecated catalog APIs #3819
285317
let module = PyModule::new(parent.py(), "catalog")?;
286318
module.add_wrapped(wrap_pyfunction!(py_read_table))?;

src/daft-connect/src/execute.rs

+9-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{future::ready, rc::Rc, sync::Arc};
22

33
use common_error::DaftResult;
44
use common_file_formats::FileFormat;
5-
use daft_catalog::View;
5+
use daft_catalog::TableSource;
66
use daft_context::get_context;
77
use daft_dsl::LiteralValue;
88
use daft_logical_plan::LogicalPlanBuilder;
@@ -234,19 +234,16 @@ impl ConnectSession {
234234
)
235235
})?;
236236

237-
{
238-
// TODO session should handle the pre-existence error
239-
if !replace && self.session().has_table(&name.clone().into()) {
240-
return Err(Status::internal("Dataframe view already exists"));
241-
}
242-
}
243-
244237
let session = self.session_mut();
245-
let view = View::from(input.build()).arced();
238+
let source = TableSource::from(input);
246239

247-
session.attach_table(view, name).map_err(|e| {
248-
Status::internal(textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"))
249-
})?;
240+
session
241+
.create_temp_table(name, &source, replace)
242+
.map_err(|e| {
243+
Status::internal(
244+
textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
245+
)
246+
})?;
250247

251248
let response = rb.result_complete_response();
252249
let stream = stream::once(ready(Ok(response)));

src/daft-session/src/python.rs

+18-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use daft_catalog::python::{PyCatalogWrapper, PyIdentifier, PyTableWrapper};
1+
use daft_catalog::python::{
2+
PyCatalogWrapper, PyIdentifier, PyTable, PyTableSource, PyTableWrapper,
3+
};
24
use pyo3::prelude::*;
35

46
use crate::Session;
@@ -19,8 +21,8 @@ impl PySession {
1921
.attach_catalog(PyCatalogWrapper::wrap(catalog), alias)?)
2022
}
2123

22-
pub fn current_catalog(&self, py: Python<'_>) -> PyResult<PyObject> {
23-
self.0.current_catalog()?.to_py(py)
24+
pub fn attach_table(&self, table: PyObject, alias: String) -> PyResult<()> {
25+
Ok(self.0.attach_table(PyTableWrapper::wrap(table), alias)?)
2426
}
2527

2628
pub fn detach_catalog(&self, alias: &str) -> PyResult<()> {
@@ -31,8 +33,19 @@ impl PySession {
3133
Ok(self.0.detach_table(alias)?)
3234
}
3335

34-
pub fn attach_table(&self, table: PyObject, alias: String) -> PyResult<()> {
35-
Ok(self.0.attach_table(PyTableWrapper::wrap(table), alias)?)
36+
pub fn create_temp_table(
37+
&self,
38+
name: String,
39+
source: &PyTableSource,
40+
replace: bool,
41+
) -> PyResult<PyTable> {
42+
let table = self.0.create_temp_table(name, source.as_ref(), replace)?;
43+
let table = PyTable::new(table);
44+
Ok(table)
45+
}
46+
47+
pub fn current_catalog(&self, py: Python<'_>) -> PyResult<PyObject> {
48+
self.0.current_catalog()?.to_py(py)
3649
}
3750

3851
pub fn get_catalog(&self, py: Python<'_>, name: &str) -> PyResult<PyObject> {

src/daft-session/src/session.rs

+27-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
22

3-
use daft_catalog::{Bindings, CatalogRef, Identifier, TableRef};
3+
use daft_catalog::{Bindings, CatalogRef, Identifier, TableRef, TableSource, View};
44
use uuid::Uuid;
55

66
use crate::{
@@ -72,6 +72,32 @@ impl Session {
7272
Ok(())
7373
}
7474

75+
/// Creates a temp table scoped to this session from an existing view.
76+
///
77+
/// TODO feat: consider making a CreateTableSource object for more complicated options.
78+
///
79+
/// ```
80+
/// CREATE [OR REPLACE] TEMP TABLE [IF NOT EXISTS] <name> <source>;
81+
/// ```
82+
pub fn create_temp_table(
83+
&self,
84+
name: impl Into<String>,
85+
source: &TableSource,
86+
replace: bool,
87+
) -> Result<TableRef> {
88+
let name = name.into();
89+
if !replace && self.state().tables.exists(&name) {
90+
obj_already_exists_err!("Temporary table", &name.into())
91+
}
92+
// we don't have mutable temporary tables, only immutable views over dataframes.
93+
let table = match source {
94+
TableSource::Schema(_) => unsupported_err!("temporary table with schema"),
95+
TableSource::View(plan) => View::from(plan.clone()).arced(),
96+
};
97+
self.state_mut().tables.insert(name, table.clone());
98+
Ok(table)
99+
}
100+
75101
/// Returns the session's current catalog.
76102
pub fn current_catalog(&self) -> Result<CatalogRef> {
77103
self.get_catalog(&self.state().options.curr_catalog)

0 commit comments

Comments
 (0)