Skip to content

Commit

Permalink
Reading tables with a dask-cudf DataFrame (#224)
Browse files Browse the repository at this point in the history
* add gpu param

* hive and code coverage

* Update pandaslike.py
  • Loading branch information
sarahyurick authored Aug 25, 2021
1 parent 4dab949 commit ece7ec7
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 10 deletions.
2 changes: 2 additions & 0 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def create_table(
format: str = None,
persist: bool = True,
schema_name: str = None,
gpu: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -199,6 +200,7 @@ def create_table(
table_name=table_name,
format=format,
persist=persist,
gpu=gpu,
**kwargs,
)
self.schema[schema_name].tables[table_name.lower()] = dc
Expand Down
12 changes: 9 additions & 3 deletions dask_sql/input_utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def to_dc(
table_name: str,
format: str = None,
persist: bool = True,
gpu: bool = False,
**kwargs,
) -> DataContainer:
"""
Expand All @@ -49,7 +50,7 @@ def to_dc(
maybe persist them to cluster memory before.
"""
filled_get_dask_dataframe = lambda *args: cls._get_dask_dataframe(
*args, table_name=table_name, format=format, **kwargs,
*args, table_name=table_name, format=format, gpu=gpu, **kwargs,
)

if isinstance(input_item, list):
Expand All @@ -64,7 +65,12 @@ def to_dc(

@classmethod
def _get_dask_dataframe(
cls, input_item: InputType, table_name: str, format: str = None, **kwargs,
cls,
input_item: InputType,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):
plugin_list = cls.get_plugins()

Expand All @@ -73,7 +79,7 @@ def _get_dask_dataframe(
input_item, table_name=table_name, format=format, **kwargs
):
return plugin.to_dc(
input_item, table_name=table_name, format=format, **kwargs
input_item, table_name=table_name, format=format, gpu=gpu, **kwargs
)

raise ValueError(f"Do not understand the input type {type(input_item)}")
12 changes: 11 additions & 1 deletion dask_sql/input_utils/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,17 @@ def is_correct_input(

return is_sqlalchemy_hive or is_hive_cursor or format == "hive"

def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item: Any,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):
if gpu: # pragma: no cover
raise Exception("Hive does not support gpu")

table_name = kwargs.pop("hive_table_name", table_name)
schema = kwargs.pop("hive_schema_name", "default")

Expand Down
14 changes: 12 additions & 2 deletions dask_sql/input_utils/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,21 @@ def is_correct_input(
isinstance(input_item, intake.catalog.Catalog) or format == "intake"
)

def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item: Any,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):
table_name = kwargs.pop("intake_table_name", table_name)
catalog_kwargs = kwargs.pop("catalog_kwargs", {})

if isinstance(input_item, str):
input_item = intake.open_catalog(input_item, **catalog_kwargs)

return input_item[table_name].to_dask(**kwargs)
if gpu: # pragma: no cover
raise Exception("Intake does not support gpu")
else:
return input_item[table_name].to_dask(**kwargs)
16 changes: 14 additions & 2 deletions dask_sql/input_utils/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ def is_correct_input(
):
return isinstance(input_item, str)

def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item: Any,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):

if format == "memory":
client = default_client()
Expand All @@ -27,7 +34,12 @@ def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
format = extension.lstrip(".")

try:
read_function = getattr(dd, f"read_{format}")
if gpu: # pragma: no cover
import dask_cudf

read_function = getattr(dask_cudf, f"read_{format}")
else:
read_function = getattr(dd, f"read_{format}")
except AttributeError:
raise AttributeError(f"Can not read files of format {format}")

Expand Down
23 changes: 21 additions & 2 deletions dask_sql/input_utils/pandaslike.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ def is_correct_input(
is_cudf_type = cudf and isinstance(input_item, cudf.DataFrame)
return is_cudf_type or isinstance(input_item, pd.DataFrame) or format == "dask"

def to_dc(self, input_item, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):
npartitions = kwargs.pop("npartitions", 1)
return dd.from_pandas(input_item, npartitions=npartitions, **kwargs)
if gpu: # pragma: no cover
import dask_cudf

if isinstance(input_item, pd.DataFrame):
return dask_cudf.from_cudf(
cudf.from_pandas(input_item), npartitions=npartitions, **kwargs,
)
else:
return dask_cudf.from_cudf(
input_item, npartitions=npartitions, **kwargs,
)
else:
return dd.from_pandas(input_item, npartitions=npartitions, **kwargs)
2 changes: 2 additions & 0 deletions dask_sql/physical/rel/custom/create_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ def convert(
except KeyError:
raise AttributeError("Parameters must include a 'location' parameter.")

gpu = kwargs.pop("gpu", False)
context.create_table(
table_name,
location,
format=format,
persist=persist,
schema_name=schema_name,
gpu=gpu,
**kwargs,
)

0 comments on commit ece7ec7

Please sign in to comment.