[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [stem/master] Fix the `with_default` decorator for synchronous functions



commit b3f59acd750f10b11af76c52575b56208aad80dd
Author: Illia Volochii <illia.volochii@xxxxxxxxx>
Date:   Fri Apr 17 18:41:12 2020 +0300

    Fix the `with_default` decorator for synchronous functions
---
 stem/control.py | 95 +++++++++++++++++++++++++++++++++------------------------
 1 file changed, 55 insertions(+), 40 deletions(-)

diff --git a/stem/control.py b/stem/control.py
index 0e03edfb..cc9ef964 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -458,12 +458,6 @@ def with_default(yields: bool = False) -> Callable:
   """
 
   def decorator(func: Callable) -> Callable:
-    is_coroutine_func = asyncio.iscoroutinefunction(func)
-    def coroutine_if_needed(func: Callable) -> Callable:
-      if is_coroutine_func:
-        return asyncio.coroutine(func)
-      return func
-
     def get_default(func: Callable, args: Any, kwargs: Any) -> Any:
       arg_names = inspect.getfullargspec(func).args[1:]  # drop 'self'
       default_position = arg_names.index('default') if 'default' in arg_names else None
@@ -473,41 +467,62 @@ def with_default(yields: bool = False) -> Callable:
       else:
         return kwargs.get('default', UNDEFINED)
 
-    if not yields:
-      @functools.wraps(func)
-      @coroutine_if_needed
-      def wrapped(self, *args: Any, **kwargs: Any) -> Any:
-        try:
-          result = func(self, *args, **kwargs)
-          if is_coroutine_func:
-            result = yield from result
-          return result
-        except:
-          default = get_default(func, args, kwargs)
-
-          if default == UNDEFINED:
-            raise
-          else:
-            return default
+    if asyncio.iscoroutinefunction(func):
+      if not yields:
+        @functools.wraps(func)
+        async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+          try:
+            return await func(self, *args, **kwargs)
+          except:
+            default = get_default(func, args, kwargs)
+
+            if default == UNDEFINED:
+              raise
+            else:
+              return default
+      else:
+        @functools.wraps(func)
+        async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+          try:
+            for val in await func(self, *args, **kwargs):
+              yield val
+          except:
+            default = get_default(func, args, kwargs)
+
+            if default == UNDEFINED:
+              raise
+            else:
+              if default is not None:
+                for val in default:
+                  yield val
     else:
-      @functools.wraps(func)
-      @coroutine_if_needed
-      def wrapped(self, *args: Any, **kwargs: Any) -> Any:
-        try:
-          result = func(self, *args, **kwargs)
-          if is_coroutine_func:
-            result = yield from result
-          for val in result:
-            yield val
-        except:
-          default = get_default(func, args, kwargs)
-
-          if default == UNDEFINED:
-            raise
-          else:
-            if default is not None:
-              for val in default:
-                yield val
+      if not yields:
+        @functools.wraps(func)
+        def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+          try:
+            return func(self, *args, **kwargs)
+          except:
+            default = get_default(func, args, kwargs)
+
+            if default == UNDEFINED:
+              raise
+            else:
+              return default
+      else:
+        @functools.wraps(func)
+        def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+          try:
+            for val in func(self, *args, **kwargs):
+              yield val
+          except:
+            default = get_default(func, args, kwargs)
+
+            if default == UNDEFINED:
+              raise
+            else:
+              if default is not None:
+                for val in default:
+                  yield val
 
     return wrapped
 



_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits