[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [stem/master] Make the `with_default` decorator compatible with asynchronous functions



commit d865c718242146c9ef14901482ebe940be100bb2
Author: Illia Volochii <illia.volochii@xxxxxxxxx>
Date:   Wed Apr 15 21:49:02 2020 +0300

    Make the `with_default` decorator compatible with asynchronous functions
---
 stem/control.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/stem/control.py b/stem/control.py
index 0bf1b35c..d5e0d0ba 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -458,6 +458,12 @@ def with_default(yields: bool = False) -> Callable:
   """
 
   def decorator(func: Callable) -> Callable:
+    is_coroutine_func = asyncio.iscoroutinefunction(func)
+    def coroutine_if_needed(func: Callable) -> Callable:
+      if is_coroutine_func:
+        return asyncio.coroutine(func)
+      return func
+
     def get_default(func: Callable, args: Any, kwargs: Any) -> Any:
       arg_names = inspect.getfullargspec(func).args[1:]  # drop 'self'
       default_position = arg_names.index('default') if 'default' in arg_names else None
@@ -469,9 +475,13 @@ def with_default(yields: bool = False) -> Callable:
 
     if not yields:
       @functools.wraps(func)
+      @coroutine_if_needed
       def wrapped(self, *args: Any, **kwargs: Any) -> Any:
         try:
-          return func(self, *args, **kwargs)
+          result = func(self, *args, **kwargs)
+          if is_coroutine_func:
+            result = yield from result
+          return result
         except:
           default = get_default(func, args, kwargs)
 
@@ -481,9 +491,13 @@ def with_default(yields: bool = False) -> Callable:
             return default
     else:
       @functools.wraps(func)
+      @coroutine_if_needed
       def wrapped(self, *args: Any, **kwargs: Any) -> Any:
         try:
-          for val in func(self, *args, **kwargs):
+          result = func(self, *args, **kwargs)
+          if is_coroutine_func:
+            result = yield from result
+          for val in result:
             yield val
         except:
           default = get_default(func, args, kwargs)



_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits