diff --git a/src/decorator.py b/src/decorator.py index 2479b6f..35e5b18 100644 --- a/src/decorator.py +++ b/src/decorator.py @@ -37,6 +37,7 @@ import inspect import operator import itertools +import functools from contextlib import _GeneratorContextManager from inspect import getfullargspec, iscoroutinefunction, isgeneratorfunction @@ -71,7 +72,7 @@ def __init__(self, func=None, name=None, signature=None, self.name = '_lambda_' self.doc = func.__doc__ self.module = func.__module__ - if inspect.isroutine(func): + if inspect.isroutine(func) or isinstance(func, functools.partial): argspec = getfullargspec(func) self.annotations = getattr(func, '__annotations__', {}) for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', @@ -214,6 +215,8 @@ def decorate(func, caller, extras=(), kwsyntax=False): does. By default kwsyntax is False and the the arguments are untouched. """ sig = inspect.signature(func) + if isinstance(func, functools.partial): + func = functools.update_wrapper(func, func.func) if iscoroutinefunction(caller): async def fun(*args, **kw): if not kwsyntax: @@ -230,6 +233,7 @@ def fun(*args, **kw): if not kwsyntax: args, kw = fix(args, kw, sig) return caller(func, *(extras + args), **kw) + fun.__name__ = func.__name__ fun.__doc__ = func.__doc__ fun.__wrapped__ = func diff --git a/src/tests/test.py b/src/tests/test.py index be9f851..6172be4 100644 --- a/src/tests/test.py +++ b/src/tests/test.py @@ -3,6 +3,7 @@ import unittest import decimal import inspect +import functools from asyncio import get_event_loop from collections import defaultdict, ChainMap, abc as c from decorator import dispatch_on, contextmanager, decorator @@ -509,5 +510,20 @@ def __len__(self): h(u) +@decorator +def partial_before_after(func, *args, **kwargs): + return "" + func(*args, **kwargs) + "" + + +class PartialTestCase(unittest.TestCase): + def test_before_after(self): + def origin_func(x, y): + return x + y + _func = functools.partial(origin_func, "x") + partial_func = partial_before_after(_func) + out = partial_func("y") + self.assertEqual(out, 'xy') + + if __name__ == '__main__': unittest.main()