[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Fix the `with_default` decorator for synchronous functions
commit b3f59acd750f10b11af76c52575b56208aad80dd
Author: Illia Volochii <illia.volochii@xxxxxxxxx>
Date: Fri Apr 17 18:41:12 2020 +0300
Fix the `with_default` decorator for synchronous functions
---
stem/control.py | 95 +++++++++++++++++++++++++++++++++------------------------
1 file changed, 55 insertions(+), 40 deletions(-)
diff --git a/stem/control.py b/stem/control.py
index 0e03edfb..cc9ef964 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -458,12 +458,6 @@ def with_default(yields: bool = False) -> Callable:
"""
def decorator(func: Callable) -> Callable:
- is_coroutine_func = asyncio.iscoroutinefunction(func)
- def coroutine_if_needed(func: Callable) -> Callable:
- if is_coroutine_func:
- return asyncio.coroutine(func)
- return func
-
def get_default(func: Callable, args: Any, kwargs: Any) -> Any:
arg_names = inspect.getfullargspec(func).args[1:] # drop 'self'
default_position = arg_names.index('default') if 'default' in arg_names else None
@@ -473,41 +467,62 @@ def with_default(yields: bool = False) -> Callable:
else:
return kwargs.get('default', UNDEFINED)
- if not yields:
- @functools.wraps(func)
- @coroutine_if_needed
- def wrapped(self, *args: Any, **kwargs: Any) -> Any:
- try:
- result = func(self, *args, **kwargs)
- if is_coroutine_func:
- result = yield from result
- return result
- except:
- default = get_default(func, args, kwargs)
-
- if default == UNDEFINED:
- raise
- else:
- return default
+ if asyncio.iscoroutinefunction(func):
+ if not yields:
+ @functools.wraps(func)
+ async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ return await func(self, *args, **kwargs)
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ return default
+ else:
+ @functools.wraps(func)
+ async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ for val in await func(self, *args, **kwargs):
+ yield val
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ if default is not None:
+ for val in default:
+ yield val
else:
- @functools.wraps(func)
- @coroutine_if_needed
- def wrapped(self, *args: Any, **kwargs: Any) -> Any:
- try:
- result = func(self, *args, **kwargs)
- if is_coroutine_func:
- result = yield from result
- for val in result:
- yield val
- except:
- default = get_default(func, args, kwargs)
-
- if default == UNDEFINED:
- raise
- else:
- if default is not None:
- for val in default:
- yield val
+ if not yields:
+ @functools.wraps(func)
+ def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ return func(self, *args, **kwargs)
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ return default
+ else:
+ @functools.wraps(func)
+ def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ for val in func(self, *args, **kwargs):
+ yield val
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ if default is not None:
+ for val in default:
+ yield val
return wrapped
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits