[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Synchronous class mockability
commit ef1e41ebce0aa1bb5fde9064410402bee9887451
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Wed Jul 8 17:12:39 2020 -0700
Synchronous class mockability
The meta-programming behind our Synchronous class doesn't play well with test
mocks. Handling this, and testing the permutations I can think of.
---
stem/util/__init__.py | 49 +++++++++++++++++++++++++++-----------
test/unit/util/synchronous.py | 55 +++++++++++++++++++++++++++++++++++++++++--
2 files changed, 88 insertions(+), 16 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 54f90376..15840f56 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,9 +10,11 @@ import datetime
import functools
import inspect
import threading
+import unittest.mock
+
from concurrent.futures import Future
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
+from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [
'conf',
@@ -213,19 +215,11 @@ class Synchronous(object):
# call any coroutines through this loop
- def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
- if Synchronous.is_asyncio_context():
- return func(*args, **kwargs)
-
- with self._loop_thread_lock:
- if not self._loop_thread.is_alive():
- raise RuntimeError('%s has been stopped' % type(self).__name__)
-
- return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
-
- for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
- if inspect.iscoroutinefunction(func):
- setattr(self, method_name, functools.partial(call_async, func))
+ for name, func in inspect.getmembers(self):
+ if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
+ setattr(self, name, functools.partial(self._call_async_method, name))
+ elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
+ setattr(self, name, functools.partial(self._call_async_method, name))
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
@@ -312,6 +306,33 @@ class Synchronous(object):
except RuntimeError:
return False
+ def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
+ """
+ Run this async method from either a synchronous or asynchronous context.
+
+ :param method_name: name of the method to invoke
+ :param args: positional arguments
+ :param kwargs: keyword arguments
+
+ :returns: method's return value
+
+ :raises: **AttributeError** if this method doesn't exist
+ """
+
+ # Retrieving methods by name (rather than keeping a reference) so runtime
+ # replacements like test mocks work.
+
+ func = getattr(type(self), method_name)
+
+ if Synchronous.is_asyncio_context():
+ return func(self, *args, **kwargs)
+
+ with self._loop_thread_lock:
+ if self._loop_thread and not self._loop_thread.is_alive():
+ raise RuntimeError('%s has been closed' % type(self).__name__)
+
+ return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
+
def __iter__(self) -> Iterator:
async def convert_generator(generator: AsyncIterator) -> Iterator:
return iter([d async for d in generator])
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index dd27c3c6..5b38a7b5 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -6,9 +6,10 @@ import asyncio
import io
import unittest
-from unittest.mock import patch
+from unittest.mock import patch, Mock
from stem.util import Synchronous
+from stem.util.test_tools import coro_func_returning_value
EXAMPLE_OUTPUT = """\
hello from a synchronous context
@@ -20,6 +21,8 @@ class Example(Synchronous):
async def hello(self):
return 'hello'
+ def sync_hello(self):
+ return 'hello'
class TestSynchronous(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
@@ -45,7 +48,7 @@ class TestSynchronous(unittest.TestCase):
def test_ainit(self):
"""
- Check that our constructor runs __ainit__ if present.
+ Check that our constructor runs __ainit__ when present.
"""
class AinitDemo(Synchronous):
@@ -96,3 +99,51 @@ class TestSynchronous(unittest.TestCase):
instance.start()
self.assertEqual('hello', instance.hello())
instance.stop()
+
+ def test_asynchronous_mockability(self):
+ """
+ Check that method mocks are respected.
+ """
+
+ # mock prior to construction
+
+ with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
+ instance = Example()
+ self.assertEqual('mocked hello', instance.hello())
+
+ self.assertEqual('hello', instance.hello()) # mock should now be reverted
+ instance.stop()
+
+ # mock after construction
+
+ instance = Example()
+
+ with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
+ self.assertEqual('mocked hello', instance.hello())
+
+ self.assertEqual('hello', instance.hello())
+ instance.stop()
+
+ def test_synchronous_mockability(self):
+ """
+ Ensure we do not disrupt non-asynchronous method mocks.
+ """
+
+ # mock prior to construction
+
+ with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
+ instance = Example()
+ self.assertEqual('mocked hello', instance.sync_hello())
+
+ self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted
+ instance.stop()
+
+ # mock after construction
+
+ instance = Example()
+
+ with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
+ self.assertEqual('mocked hello', instance.sync_hello())
+
+ self.assertEqual('hello', instance.sync_hello())
+ instance.stop()
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits