Skip to content

Commit

Permalink
chore(weave): Better names for internal methods (#2982)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong authored Nov 21, 2024
1 parent a0baa70 commit ed510a1
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class ProcessedInputs:
OnFinishHandlerType = Callable[["Call", Any, Optional[BaseException]], None]


def value_is_sentinel(param: Any) -> bool:
def _value_is_sentinel(param: Any) -> bool:
return (
param.default is None
or param.default is OPENAI_NOT_GIVEN
Expand All @@ -123,7 +123,7 @@ def _apply_fn_defaults_to_inputs(
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if param_name not in inputs:
if param.default != inspect.Parameter.empty and not value_is_sentinel(
if param.default != inspect.Parameter.empty and not _value_is_sentinel(
param
):
inputs[param_name] = param.default
Expand Down Expand Up @@ -224,7 +224,7 @@ def _is_unbound_method(func: Callable) -> bool:
return bool(is_method)


def default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedInputs:
def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedInputs:
try:
sig = inspect.signature(func)
inputs = sig.bind(*args, **kwargs).arguments
Expand All @@ -249,7 +249,7 @@ def _create_call(
if func._on_input_handler is not None:
pargs = func._on_input_handler(func, args, kwargs)
if not pargs:
pargs = default_on_input_handler(func, args, kwargs)
pargs = _default_on_input_handler(func, args, kwargs)
inputs_with_defaults = pargs.inputs

# This should probably be configurable, but for now we redact the api_key
Expand All @@ -273,9 +273,9 @@ def _create_call(
)


def _execute_call(
def _execute_op(
__op: Op,
call: Any,
__call: Call,
*args: Any,
__should_raise: bool = True,
**kwargs: Any,
Expand All @@ -290,17 +290,17 @@ def finish(output: Any = None, exception: BaseException | None = None) -> None:
raise ValueError("Should not call finish more than once")

client.finish_call(
call,
__call,
output,
exception,
op=__op,
)
if not call_context.get_current_call():
print_call_link(call)
print_call_link(__call)

def on_output(output: Any) -> Any:
if handler := getattr(__op, "_on_output_handler", None):
return handler(output, finish, call.inputs)
return handler(output, finish, __call.inputs)
finish(output)
return output

Expand All @@ -319,15 +319,15 @@ def process(res: Any) -> tuple[Any, Call]:
# Is there a better place for this? We want to ensure that even
# if the final output fails to be captured, we still pop the call
# so we don't put future calls under the old call.
call_context.pop_call(call.id)
call_context.pop_call(__call.id)

return res, call
return res, __call

def handle_exception(e: Exception) -> tuple[Any, Call]:
finish(exception=e)
if __should_raise:
raise
return None, call
return None, __call

if inspect.iscoroutinefunction(func):

Expand All @@ -348,7 +348,7 @@ async def _call_async() -> tuple[Any, Call]:
else:
return process(res)

return None, call
return None, __call


def call(
Expand Down Expand Up @@ -376,11 +376,19 @@ def add(a: int, b: int) -> int:
"""
if inspect.iscoroutinefunction(op.resolve_fn):
return _do_call_async(
op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs
op,
*args,
__weave=__weave,
__should_raise=__should_raise,
**kwargs,
)
else:
return _do_call(
op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs
op,
*args,
__weave=__weave,
__should_raise=__should_raise,
**kwargs,
)


Expand Down Expand Up @@ -411,7 +419,7 @@ def _do_call(
if op._on_input_handler is not None:
pargs = op._on_input_handler(op, args, kwargs)
if not pargs:
pargs = default_on_input_handler(op, args, kwargs)
pargs = _default_on_input_handler(op, args, kwargs)

if settings.should_disable_weave():
res = func(*pargs.args, **pargs.kwargs)
Expand All @@ -435,7 +443,7 @@ def _do_call(
)
res = func(*pargs.args, **pargs.kwargs)
else:
execute_result = _execute_call(
execute_result = _execute_op(
op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs
)
if inspect.iscoroutine(execute_result):
Expand Down Expand Up @@ -478,7 +486,7 @@ async def _do_call_async(
)
res = await func(*args, **kwargs)
else:
execute_result = _execute_call(
execute_result = _execute_op(
op, call, *args, __should_raise=__should_raise, **kwargs
)
if not inspect.iscoroutine(execute_result):
Expand Down

0 comments on commit ed510a1

Please sign in to comment.