Skip to content

Commit fe61844

Browse files
committed
Fixes serialize for extract fields
1 parent 1057a40 commit fe61844

File tree

3 files changed

+71
-26
lines changed

3 files changed

+71
-26
lines changed

examples/parallelism/file_processing/run.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import warnings
2+
3+
from hamilton.lifecycle import SlowDownYouMoveTooFast
4+
5+
warnings.simplefilter(action="ignore", category=FutureWarning)
6+
17
import logging
28

39
import aggregate_data
@@ -9,9 +15,9 @@
915

1016
from hamilton import driver, log_setup
1117
from hamilton.execution import executors
12-
from hamilton.plugins import h_dask, h_ray
18+
from hamilton.plugins import h_dask, h_ray, h_rich
1319

14-
log_setup.setup_logging(logging.INFO)
20+
log_setup.setup_logging(logging.FATAL)
1521

1622

1723
@click.command()
@@ -39,6 +45,10 @@ def main(mode: str):
3945
driver.Builder()
4046
.enable_dynamic_execution(allow_experimental_mode=True)
4147
.with_remote_executor(remote_executor) # We only need to specify remote exeecutor
48+
.with_adapters(
49+
h_rich.RichProgressBar(),
50+
SlowDownYouMoveTooFast(sleep_time_mean=0.3, sleep_time_std=0.1),
51+
)
4252
# The local executor just runs it synchronously
4353
.with_modules(aggregate_data, list_data, process_data)
4454
.build()

hamilton/function_modifiers/expanders.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,36 @@ def _validate_extract_fields(fields: dict):
731731
)
732732

733733

734+
async def dict_generator_async(
735+
*args,
736+
fn,
737+
fill_with,
738+
fields,
739+
**kwargs,
740+
):
741+
dict_generated = await fn(*args, **kwargs)
742+
if fill_with is not None:
743+
for field in fields:
744+
if field not in dict_generated:
745+
dict_generated[field] = fill_with
746+
return dict_generated
747+
748+
749+
async def dict_generator(
750+
*args,
751+
fn,
752+
fill_with,
753+
fields,
754+
**kwargs,
755+
):
756+
dict_generated = fn(*args, **kwargs)
757+
if fill_with is not None:
758+
for field in fields:
759+
if field not in dict_generated:
760+
dict_generated[field] = fill_with
761+
return dict_generated
762+
763+
734764
class extract_fields(base.SingleNodeNodeTransformer):
735765
"""Extracts fields from a dictionary of output."""
736766

@@ -804,29 +834,35 @@ def transform_node(
804834
"""
805835
fn = node_.callable
806836
base_doc = node_.documentation
807-
837+
dict_generator_fn = (
838+
functools.partial(dict_generator, fn=fn, fill_with=self.fill_with, fields=self.fields)
839+
if not (inspect.iscoroutinefunction(fn))
840+
else functools.partial(
841+
dict_generator_async, fn=fn, fill_with=self.fill_with, fields=self.fields
842+
)
843+
)
808844
# if fn is async
809-
if inspect.iscoroutinefunction(fn):
810-
811-
async def dict_generator(*args, **kwargs):
812-
dict_generated = await fn(*args, **kwargs)
813-
if self.fill_with is not None:
814-
for field in self.fields:
815-
if field not in dict_generated:
816-
dict_generated[field] = self.fill_with
817-
return dict_generated
818-
819-
else:
820-
821-
def dict_generator(*args, **kwargs):
822-
dict_generated = fn(*args, **kwargs)
823-
if self.fill_with is not None:
824-
for field in self.fields:
825-
if field not in dict_generated:
826-
dict_generated[field] = self.fill_with
827-
return dict_generated
828-
829-
output_nodes = [node_.copy_with(callabl=dict_generator)]
845+
# if inspect.iscoroutinefunction(fn):
846+
#
847+
# async def dict_generator(*args, **kwargs):
848+
# dict_generated = await fn(*args, **kwargs)
849+
# if self.fill_with is not None:
850+
# for field in self.fields:
851+
# if field not in dict_generated:
852+
# dict_generated[field] = self.fill_with
853+
# return dict_generated
854+
#
855+
# else:
856+
#
857+
# def dict_generator(*args, **kwargs):
858+
# dict_generated = fn(*args, **kwargs)
859+
# if self.fill_with is not None:
860+
# for field in self.fields:
861+
# if field not in dict_generated:
862+
# dict_generated[field] = self.fill_with
863+
# return dict_generated
864+
865+
output_nodes = [node_.copy_with(callabl=dict_generator_fn)]
830866

831867
for field, field_type in self.fields.items():
832868
doc_string = base_doc # default doc string of base function.

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ docs = [
5757
"diskcache",
5858
# required for all the plugins
5959
"dlt",
60-
# furo -- install from main for now until the next release is out:
61-
"furo @ git+https://github.com/pradyunsg/furo@main",
60+
"furo",
6261
"gitpython", # Required for parsing git info for generation of data-adapter docs
6362
"grpcio-status",
6463
"lightgbm",

0 commit comments

Comments
 (0)