[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Make Synchronous class resumable
commit 9f71ce9b21d8c710025a440e69920006e37eeb88
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Tue Jul 7 18:38:06 2020 -0700
Make Synchronous class resumable
Our Controller needs to start and stop with its connect/close methods, so we
need for this to be resumable.
---
stem/util/__init__.py | 50 +++++++++++++++++++++++++++---------------
test/settings.cfg | 2 ++
test/unit/descriptor/remote.py | 20 ++++++++---------
test/unit/util/synchronous.py | 29 +++++++++++++++++-------
4 files changed, 65 insertions(+), 36 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index cddce755..54f90376 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -12,7 +12,7 @@ import inspect
import threading
from concurrent.futures import Future
-from typing import Any, AsyncIterator, Callable, Iterator, Type, Union
+from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
__all__ = [
'conf',
@@ -162,12 +162,12 @@ class Synchronous(object):
def sync_demo():
instance = Example()
print('%s from a synchronous context' % instance.hello())
- instance.close()
+ instance.stop()
async def async_demo():
instance = Example()
print('%s from an asynchronous context' % await instance.hello())
- instance.close()
+ instance.stop()
sync_demo()
asyncio.run(async_demo())
@@ -194,35 +194,34 @@ class Synchronous(object):
# asyncio.get_running_loop(), and construct objects that
# require it (like asyncio.Queue and asyncio.Lock).
- Users are responsible for calling :func:`~stem.util.Synchronous.close` when
+ Users are responsible for calling :func:`~stem.util.Synchronous.stop` when
finished to clean up underlying resources.
"""
def __init__(self) -> None:
+ self._loop_thread = None # type: Optional[threading.Thread]
+ self._loop_thread_lock = threading.RLock()
+
if Synchronous.is_asyncio_context():
self._loop = asyncio.get_running_loop()
- self._loop_thread = None
self.__ainit__()
else:
self._loop = asyncio.new_event_loop()
- self._loop_thread = threading.Thread(
- name = '%s asyncio' % type(self).__name__,
- target = self._loop.run_forever,
- daemon = True,
- )
- self._loop_thread.start()
+ Synchronous.start(self)
# call any coroutines through this loop
def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
if Synchronous.is_asyncio_context():
return func(*args, **kwargs)
- elif not self._loop_thread.is_alive():
- raise RuntimeError('%s has been closed' % type(self).__name__)
- return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
+ with self._loop_thread_lock:
+ if not self._loop_thread.is_alive():
+ raise RuntimeError('%s has been stopped' % type(self).__name__)
+
+ return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
if inspect.iscoroutinefunction(func):
@@ -273,16 +272,31 @@ class Synchronous(object):
pass
- def close(self) -> None:
+ def start(self) -> None:
+ """
+ Initiate resources to make this object callable from synchronous contexts.
+ """
+
+ with self._loop_thread_lock:
+ self._loop_thread = threading.Thread(
+ name = '%s asyncio' % type(self).__name__,
+ target = self._loop.run_forever,
+ daemon = True,
+ )
+
+ self._loop_thread.start()
+
+ def stop(self) -> None:
"""
Terminate resources that permits this from being callable from synchronous
contexts. Once called any further synchronous invocations will fail with a
**RuntimeError**.
"""
- if self._loop_thread and self._loop_thread.is_alive():
- self._loop.call_soon_threadsafe(self._loop.stop)
- self._loop_thread.join()
+ with self._loop_thread_lock:
+ if self._loop_thread and self._loop_thread.is_alive():
+ self._loop.call_soon_threadsafe(self._loop.stop)
+ self._loop_thread.join()
@staticmethod
def is_asyncio_context() -> bool:
diff --git a/test/settings.cfg b/test/settings.cfg
index fcef5ec1..70bdd069 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -229,6 +229,8 @@ mypy.ignore * => "_IntegerEnum" has no attribute *
mypy.ignore * => See https://mypy.readthedocs.io/en/latest/common_issues.html*
mypy.ignore * => *is not valid as a type*
+mypy.ignore stem/descriptor/remote.py => Return type "Coroutine[Any, Any, None]" of "start" *
+
# Metaprogramming prevents mypy from determining descriptor attributes.
mypy.ignore * => "Descriptor" has no attribute "*
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 1fd2aaf9..bb6f554c 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -100,7 +100,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
- reply.close()
+ reply.stop()
def test_response_header_code(self):
"""
@@ -150,13 +150,13 @@ class TestDescriptorDownloader(unittest.TestCase):
descriptors = list(query)
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- query.close()
+ query.stop()
def test_gzip_url_override(self):
query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
self.assertEqual(TEST_RESOURCE, query.resource)
- query.close()
+ query.stop()
@mock_download(read_resource('compressed_identity'), encoding = 'identity')
def test_compression_plaintext(self):
@@ -172,7 +172,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -191,7 +191,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -212,7 +212,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -233,7 +233,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -258,7 +258,7 @@ class TestDescriptorDownloader(unittest.TestCase):
queries.append(downloader.get_detached_signatures())
for query in queries:
- query.close()
+ query.stop()
@mock_download(b'some malformed stuff')
def test_malformed_content(self):
@@ -285,7 +285,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- query.close()
+ query.stop()
def test_query_with_invalid_endpoints(self):
invalid_endpoints = {
@@ -316,4 +316,4 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))
- query.close()
+ query.stop()
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 22271ffd..dd27c3c6 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -31,12 +31,12 @@ class TestSynchronous(unittest.TestCase):
def sync_demo():
instance = Example()
print('%s from a synchronous context' % instance.hello())
- instance.close()
+ instance.stop()
async def async_demo():
instance = Example()
print('%s from an asynchronous context' % await instance.hello())
- instance.close()
+ instance.stop()
sync_demo()
asyncio.run(async_demo())
@@ -66,20 +66,33 @@ class TestSynchronous(unittest.TestCase):
sync_demo()
asyncio.run(async_demo())
- def test_after_close(self):
+ def test_after_stop(self):
"""
- Check that closed instances raise a RuntimeError to synchronous callers.
+ Check that stopped instances raise a RuntimeError to synchronous callers.
"""
- # close a used instance
+ # stop a used instance
instance = Example()
self.assertEqual('hello', instance.hello())
- instance.close()
+ instance.stop()
self.assertRaises(RuntimeError, instance.hello)
- # close an unused instance
+ # stop an unused instance
instance = Example()
- instance.close()
+ instance.stop()
self.assertRaises(RuntimeError, instance.hello)
+
+ def test_resuming(self):
+ """
+ Resume a previously stopped instance.
+ """
+
+ instance = Example()
+ self.assertEqual('hello', instance.hello())
+ instance.stop()
+ self.assertRaises(RuntimeError, instance.hello)
+ instance.start()
+ self.assertEqual('hello', instance.hello())
+ instance.stop()
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits