@@ -731,6 +731,36 @@ def _validate_extract_fields(fields: dict):
731731 )
732732
733733
734+ async def dict_generator_async (
735+ * args ,
736+ fn ,
737+ fill_with ,
738+ fields ,
739+ ** kwargs ,
740+ ):
741+ dict_generated = await fn (* args , ** kwargs )
742+ if fill_with is not None :
743+ for field in fields :
744+ if field not in dict_generated :
745+ dict_generated [field ] = fill_with
746+ return dict_generated
747+
748+
749+ async def dict_generator (
750+ * args ,
751+ fn ,
752+ fill_with ,
753+ fields ,
754+ ** kwargs ,
755+ ):
756+ dict_generated = fn (* args , ** kwargs )
757+ if fill_with is not None :
758+ for field in fields :
759+ if field not in dict_generated :
760+ dict_generated [field ] = fill_with
761+ return dict_generated
762+
763+
734764class extract_fields (base .SingleNodeNodeTransformer ):
735765 """Extracts fields from a dictionary of output."""
736766
@@ -804,29 +834,35 @@ def transform_node(
804834 """
805835 fn = node_ .callable
806836 base_doc = node_ .documentation
807-
837+ dict_generator_fn = (
838+ functools .partial (dict_generator , fn = fn , fill_with = self .fill_with , fields = self .fields )
839+ if not (inspect .iscoroutinefunction (fn ))
840+ else functools .partial (
841+ dict_generator_async , fn = fn , fill_with = self .fill_with , fields = self .fields
842+ )
843+ )
808844 # if fn is async
809- if inspect .iscoroutinefunction (fn ):
810-
811- async def dict_generator (* args , ** kwargs ):
812- dict_generated = await fn (* args , ** kwargs )
813- if self .fill_with is not None :
814- for field in self .fields :
815- if field not in dict_generated :
816- dict_generated [field ] = self .fill_with
817- return dict_generated
818-
819- else :
820-
821- def dict_generator (* args , ** kwargs ):
822- dict_generated = fn (* args , ** kwargs )
823- if self .fill_with is not None :
824- for field in self .fields :
825- if field not in dict_generated :
826- dict_generated [field ] = self .fill_with
827- return dict_generated
828-
829- output_nodes = [node_ .copy_with (callabl = dict_generator )]
845+ # if inspect.iscoroutinefunction(fn):
846+ #
847+ # async def dict_generator(*args, **kwargs):
848+ # dict_generated = await fn(*args, **kwargs)
849+ # if self.fill_with is not None:
850+ # for field in self.fields:
851+ # if field not in dict_generated:
852+ # dict_generated[field] = self.fill_with
853+ # return dict_generated
854+ #
855+ # else:
856+ #
857+ # def dict_generator(*args, **kwargs):
858+ # dict_generated = fn(*args, **kwargs)
859+ # if self.fill_with is not None:
860+ # for field in self.fields:
861+ # if field not in dict_generated:
862+ # dict_generated[field] = self.fill_with
863+ # return dict_generated
864+
865+ output_nodes = [node_ .copy_with (callabl = dict_generator_fn )]
830866
831867 for field , field_type in self .fields .items ():
832868 doc_string = base_doc # default doc string of base function.
0 commit comments