[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [stem/master] Synchronous context management



commit 8b539f2facdd86d07e76b5bf5daa379bf0d3d2ba
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date:   Thu Jul 9 17:29:58 2020 -0700

    Synchronous context management
    
    Make our class handle 'with' statements, and tidy up both its implementation
    and tests.
---
 stem/util/__init__.py         |  79 +++++++++------
 test/unit/util/synchronous.py | 223 +++++++++++++++++++++++++++++-------------
 2 files changed, 205 insertions(+), 97 deletions(-)

diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 15840f56..c147a5a4 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,10 +10,11 @@ import datetime
 import functools
 import inspect
 import threading
+import typing
 import unittest.mock
 
 from concurrent.futures import Future
-
+from types import TracebackType
 from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
 
 __all__ = [
@@ -201,26 +202,31 @@ class Synchronous(object):
   """
 
   def __init__(self) -> None:
+    self._loop = None  # type: Optional[asyncio.AbstractEventLoop]
     self._loop_thread = None  # type: Optional[threading.Thread]
     self._loop_thread_lock = threading.RLock()
 
-    if Synchronous.is_asyncio_context():
-      self._loop = asyncio.get_running_loop()
+    # when created from an asyncio context this class is a no-op (no loop or thread of our own)
 
-      self.__ainit__()
-    else:
-      self._loop = asyncio.new_event_loop()
+    self._no_op = Synchronous.is_asyncio_context()
 
+    if not self._no_op:
+      self._loop = asyncio.new_event_loop()
       Synchronous.start(self)
 
-      # call any coroutines through this loop
+      # route coroutine methods (including coroutine mocks) through our loop
 
       for name, func in inspect.getmembers(self):
-        if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
-          setattr(self, name, functools.partial(self._call_async_method, name))
+        if name in ('__aiter__', '__aenter__', '__aexit__'):
+          pass  # async protocol methods, reached via __iter__, __enter__ and __exit__ instead
+        elif isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
+          setattr(self, name, functools.partial(self._run_async_method, name))
         elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
-          setattr(self, name, functools.partial(self._call_async_method, name))
+          setattr(self, name, functools.partial(self._run_async_method, name))
 
+    if self._no_op:
+      self.__ainit__()  # this is already an asyncio context
+    else:
       asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
 
   def __ainit__(self):
@@ -272,13 +278,14 @@ class Synchronous(object):
     """
 
     with self._loop_thread_lock:
-      self._loop_thread = threading.Thread(
-        name = '%s asyncio' % type(self).__name__,
-        target = self._loop.run_forever,
-        daemon = True,
-      )
+      if not self._no_op and self._loop_thread is None:  # skip when already started or a no-op
+        self._loop_thread = threading.Thread(
+          name = '%s asyncio' % type(self).__name__,
+          target = self._loop.run_forever,
+          daemon = True,
+        )
 
-      self._loop_thread.start()
+        self._loop_thread.start()
 
   def stop(self) -> None:
     """
@@ -288,9 +295,13 @@ class Synchronous(object):
     """
 
     with self._loop_thread_lock:
-      if self._loop_thread and self._loop_thread.is_alive():
+      if not self._no_op and self._loop_thread is not None:
         self._loop.call_soon_threadsafe(self._loop.stop)
-        self._loop_thread.join()
+
+        if threading.current_thread() != self._loop_thread:  # a thread cannot join itself
+          self._loop_thread.join()
+
+        self._loop_thread = None
 
   @staticmethod
   def is_asyncio_context() -> bool:
@@ -306,7 +317,7 @@ class Synchronous(object):
     except RuntimeError:
       return False
 
-  def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
+  def _run_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
     """
     Run this async method from either a synchronous or asynchronous context.
 
@@ -324,25 +335,35 @@ class Synchronous(object):
 
-    func = getattr(type(self), method_name)
+    func = getattr(type(self), method_name, None)
 
-    if Synchronous.is_asyncio_context():
+    if func is None:
+      raise TypeError("'%s' does not have a %s method" % (type(self).__name__, method_name))
+    elif self._no_op or Synchronous.is_asyncio_context():
       return func(self, *args, **kwargs)
 
     with self._loop_thread_lock:
-      if self._loop_thread and not self._loop_thread.is_alive():
-        raise RuntimeError('%s has been closed' % type(self).__name__)
+      if self._loop_thread is None:
+        raise RuntimeError('%s has been stopped' % type(self).__name__)
+
+      # convert iterator if indicated by this method's name or type hint
+      # NOTE(review): func comes from type(self), so regular methods are plain
+      # functions here and inspect.ismethod() is False for them - confirm intent
+
+      if method_name == '__aiter__' or (inspect.ismethod(func) and typing.get_type_hints(func).get('return') == AsyncIterator):
+        async def convert_generator(generator: AsyncIterator) -> Iterator:
+          return iter([d async for d in generator])
 
-      return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
+        return asyncio.run_coroutine_threadsafe(convert_generator(func(self, *args, **kwargs)), self._loop).result()
+      else:
+        return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
 
   def __iter__(self) -> Iterator:
-    async def convert_generator(generator: AsyncIterator) -> Iterator:
-      return iter([d async for d in generator])
+    return self._run_async_method('__aiter__')
 
-    iter_func = getattr(self, '__aiter__', None)
+  def __enter__(self):
+    return self._run_async_method('__aenter__')
 
-    if iter_func:
-      return asyncio.run_coroutine_threadsafe(convert_generator(iter_func()), self._loop).result()
-    else:
-      raise TypeError("'%s' object is not iterable" % type(self).__name__)
+  def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
+    return self._run_async_method('__aexit__', exit_type, value, traceback)
 
 
 class AsyncClassWrapper:
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 5b38a7b5..d4429f3e 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -17,12 +17,34 @@ hello from an asynchronous context
 """
 
 
-class Example(Synchronous):
-  async def hello(self):
-    return 'hello'
+class Demo(Synchronous):
+  def __init__(self):
+    super(Demo, self).__init__()
+
+    self.called_enter = False
+    self.called_exit = False
+
+  def __ainit__(self):
+    self.ainit_loop = asyncio.get_running_loop()  # only succeeds when run within an event loop
+
+  async def async_method(self):
+    return 'async call'
+
+  def sync_method(self):
+    return 'sync call'
+
+  async def __aiter__(self):
+    for i in range(3):
+      yield i
+
+  async def __aenter__(self):
+    self.called_enter = True
+    return self
+
+  async def __aexit__(self, exit_type, value, traceback):
+    self.called_exit = True
+    return  # None, so exceptions raised within a 'with' block are not suppressed
 
-  def sync_hello(self):
-    return 'hello'
 
 class TestSynchronous(unittest.TestCase):
   @patch('sys.stdout', new_callable = io.StringIO)
@@ -31,6 +53,10 @@ class TestSynchronous(unittest.TestCase):
     Run the example from our pydoc.
     """
 
+    class Example(Synchronous):
+      async def hello(self):
+        return 'hello'
+
     def sync_demo():
       instance = Example()
       print('%s from a synchronous context' % instance.hello())
@@ -48,102 +74,165 @@ class TestSynchronous(unittest.TestCase):
 
   def test_ainit(self):
     """
-    Check that our constructor runs __ainit__ when present.
+    Check that construction runs __ainit__ within a running event loop in both contexts.
     """
 
-    class AinitDemo(Synchronous):
-      def __init__(self):
-        super(AinitDemo, self).__init__()
-
-      def __ainit__(self):
-        self.ainit_loop = asyncio.get_running_loop()
-
-    def sync_demo():
-      instance = AinitDemo()
-      self.assertTrue(hasattr(instance, 'ainit_loop'))
+    def sync_test():
+      instance = Demo()
+      self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop))
+      instance.stop()
 
-    async def async_demo():
-      instance = AinitDemo()
-      self.assertTrue(hasattr(instance, 'ainit_loop'))
+    async def async_test():
+      instance = Demo()
+      self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop))
+      instance.stop()
 
-    sync_demo()
-    asyncio.run(async_demo())
+    sync_test()
+    asyncio.run(async_test())
 
-  def test_after_stop(self):
+  def test_stop(self):
     """
-    Check that stopped instances raise a RuntimeError to synchronous callers.
+    Synchronous callers should receive a RuntimeError when stopped.
     """
 
-    # stop a used instance
+    def sync_test():
+      instance = Demo()
+      self.assertEqual('async call', instance.async_method())
+      instance.stop()
+
+      self.assertRaises(RuntimeError, instance.async_method)
+
+      # synchronous methods still work
 
-    instance = Example()
-    self.assertEqual('hello', instance.hello())
-    instance.stop()
-    self.assertRaises(RuntimeError, instance.hello)
+      self.assertEqual('sync call', instance.sync_method())
 
-    # stop an unused instance
+    async def async_test():
+      instance = Demo()
+      self.assertEqual('async call', await instance.async_method())
+      instance.stop()
+
+      # stop has no effect on async users
+
+      self.assertEqual('async call', await instance.async_method())
 
-    instance = Example()
-    instance.stop()
-    self.assertRaises(RuntimeError, instance.hello)
+    sync_test()
+    asyncio.run(async_test())
 
   def test_resuming(self):
     """
     Resume a previously stopped instance.
     """
 
-    instance = Example()
-    self.assertEqual('hello', instance.hello())
-    instance.stop()
-    self.assertRaises(RuntimeError, instance.hello)
-    instance.start()
-    self.assertEqual('hello', instance.hello())
-    instance.stop()
+    def sync_test():
+      instance = Demo()
+      self.assertEqual('async call', instance.async_method())
+      instance.stop()
+
+      self.assertRaises(RuntimeError, instance.async_method)
 
-  def test_asynchronous_mockability(self):
+      instance.start()
+      self.assertEqual('async call', instance.async_method())
+      instance.stop()
+
+    async def async_test():
+      instance = Demo()
+      self.assertEqual('async call', await instance.async_method())
+      instance.stop()
+
+      # start has no effect on async users
+
+      instance.start()
+      self.assertEqual('async call', await instance.async_method())
+      instance.stop()
+
+    sync_test()
+    asyncio.run(async_test())
+
+  def test_iteration(self):
     """
-    Check that method mocks are respected.
+    Check that instances can be iterated from both synchronous and asynchronous contexts.
     """
 
-    # mock prior to construction
+    def sync_test():
+      instance = Demo()
+      result = []
 
-    with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
-      instance = Example()
-      self.assertEqual('mocked hello', instance.hello())
+      for val in instance:
+        result.append(val)
 
-    self.assertEqual('hello', instance.hello())  # mock should now be reverted
-    instance.stop()
+      self.assertEqual([0, 1, 2], result)
+      instance.stop()
 
-    # mock after construction
+    async def async_test():
+      instance = Demo()
+      result = []
 
-    instance = Example()
+      async for val in instance:
+        result.append(val)
 
-    with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
-      self.assertEqual('mocked hello', instance.hello())
+      self.assertEqual([0, 1, 2], result)
+      instance.stop()
 
-    self.assertEqual('hello', instance.hello())
-    instance.stop()
+    sync_test()
+    asyncio.run(async_test())
 
-  def test_synchronous_mockability(self):
+  def test_context_management(self):
     """
-    Ensure we do not disrupt non-asynchronous method mocks.
+    Exercise context management via 'with' statements.
     """
 
-    # mock prior to construction
+    def sync_test():
+      instance = Demo()
 
-    with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
-      instance = Example()
-      self.assertEqual('mocked hello', instance.sync_hello())
+      self.assertFalse(instance.called_enter)
+      self.assertFalse(instance.called_exit)
+      with instance:
+        self.assertTrue(instance.called_enter)
+        self.assertFalse(instance.called_exit)
+
+      self.assertTrue(instance.called_enter)
+      self.assertTrue(instance.called_exit)
+      instance.stop()  # release the loop thread this instance started
+
+    async def async_test():
+      instance = Demo()
+
+      self.assertFalse(instance.called_enter)
+      self.assertFalse(instance.called_exit)
+
+      async with instance:
+        self.assertTrue(instance.called_enter)
+        self.assertFalse(instance.called_exit)
+
+      self.assertTrue(instance.called_enter)
+      self.assertTrue(instance.called_exit)
+
+    sync_test()
+    asyncio.run(async_test())
+
+  def test_mockability(self):
+    """
+    Check that method mocks are respected for both previously constructed
+    instances and those made after the mock.
+    """
+
+    pre_constructed = Demo()
+
+    with patch('test.unit.util.synchronous.Demo.async_method', Mock(side_effect = coro_func_returning_value('mocked call'))):
+      post_constructed = Demo()
+
+      self.assertEqual('mocked call', pre_constructed.async_method())
+      self.assertEqual('mocked call', post_constructed.async_method())
 
-    self.assertEqual('hello', instance.sync_hello())  # mock should now be reverted
-    instance.stop()
+    self.assertEqual('async call', pre_constructed.async_method())
+    self.assertEqual('async call', post_constructed.async_method())
 
-    # mock after construction
+    # synchronous methods aren't wrapped by our constructor, so mocking them works as usual
 
-    instance = Example()
+    with patch('test.unit.util.synchronous.Demo.sync_method', Mock(return_value = 'mocked call')):
+      self.assertEqual('mocked call', pre_constructed.sync_method())
 
-    with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
-      self.assertEqual('mocked hello', instance.sync_hello())
+    self.assertEqual('sync call', pre_constructed.sync_method())
 
-    self.assertEqual('hello', instance.sync_hello())
-    instance.stop()
+    pre_constructed.stop()
+    post_constructed.stop()



_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits