[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Prepare to creating and wrapping one more asynchronous class
commit 79e8c1b47c63dfd49b62ec47c1b7902f51b06a83
Author: Illia Volochii <illia.volochii@xxxxxxxxx>
Date: Thu May 14 00:06:49 2020 +0300
Prepare to creating and wrapping one more asynchronous class
---
stem/connection.py | 2 +-
stem/control.py | 108 ++++++++-------------------------------
stem/interpreter/__init__.py | 2 +-
stem/interpreter/commands.py | 5 +-
stem/util/__init__.py | 81 +++++++++++++++++++++++++++++
test/integ/control/controller.py | 2 +-
test/runner.py | 2 +-
test/unit/control/controller.py | 4 +-
8 files changed, 110 insertions(+), 96 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py
index c44fddb1..8f57f3b3 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -257,7 +257,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
if controller is None or not issubclass(controller, stem.control.Controller):
raise ValueError('Controller should be a stem.control.BaseController subclass.')
- async_controller_thread = stem.control._AsyncControllerThread()
+ async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
async_controller_thread.start()
connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
diff --git a/stem/control.py b/stem/control.py
index 6de671b6..1488621a 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -553,29 +553,6 @@ def event_description(event: str) -> str:
return EVENT_DESCRIPTIONS.get(event.lower())
-class _MsgLock:
- __slots__ = ('_r_lock', '_async_lock')
-
- def __init__(self):
- self._r_lock = threading.RLock()
- self._async_lock = asyncio.Lock()
-
- async def acquire(self):
- await self._async_lock.acquire()
- self._r_lock.acquire()
-
- def release(self):
- self._r_lock.release()
- self._async_lock.release()
-
- async def __aenter__(self):
- await self.acquire()
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- self.release()
-
-
class _BaseControllerSocketMixin:
def is_alive(self) -> bool:
"""
@@ -644,7 +621,7 @@ class BaseController(_BaseControllerSocketMixin):
self._asyncio_loop = asyncio.get_event_loop()
- self._msg_lock = _MsgLock()
+ self._msg_lock = stem.util.CombinedReentrantAndAsyncioLock()
self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread)
self._status_listeners_lock = threading.RLock()
@@ -3901,22 +3878,7 @@ class AsyncController(_ControllerClassMethodMixin, BaseController):
return (set_events, failed_events)
-class _AsyncControllerThread(threading.Thread):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, *kwargs)
- self.loop = asyncio.new_event_loop()
- self.setDaemon(True)
-
- def run(self):
- self.loop.run_forever()
-
- def join(self, timeout = None):
- self.loop.call_soon_threadsafe(self.loop.stop)
- super().join(timeout)
- self.loop.close()
-
-
-class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
+class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
@classmethod
def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
instance = super().from_port(address, port)
@@ -3932,48 +3894,19 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None:
def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None):
if started_async_controller_thread:
- self._async_controller_thread = started_async_controller_thread
+ self._thread_for_wrapped_class = started_async_controller_thread
else:
- self._async_controller_thread = _AsyncControllerThread()
- self._async_controller_thread.start()
- self._asyncio_loop = self._async_controller_thread.loop
-
- self._async_controller = self._init_async_controller(control_socket, is_authenticated)
- self._socket = self._async_controller._socket
-
- def _init_async_controller(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool) -> 'stem.control.AsyncController':
- # The asynchronous controller should be initialized in the thread where its
- # methods will be executed.
- if self._async_controller_thread != threading.current_thread():
- async def init_async_controller() -> 'stem.control.AsyncController':
- return AsyncController(control_socket, is_authenticated)
-
- return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result()
-
- return AsyncController(control_socket, is_authenticated)
-
- def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
- return asyncio.run_coroutine_threadsafe(
- getattr(self._async_controller, method_name)(*args, **kwargs),
- self._asyncio_loop,
- ).result()
-
- def _execute_async_generator_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
- async def convert_async_generator(generator):
- return iter([d async for d in generator])
+ self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
+ self._thread_for_wrapped_class.start()
- return asyncio.run_coroutine_threadsafe(
- convert_async_generator(
- getattr(self._async_controller, method_name)(*args, **kwargs),
- ),
- self._asyncio_loop,
- ).result()
+ self._wrapped_instance = self._init_async_class(AsyncController, control_socket, is_authenticated)
+ self._socket = self._wrapped_instance._socket
def msg(self, message: str) -> stem.response.ControlMessage:
return self._execute_async_method('msg', message)
def is_authenticated(self) -> bool:
- return self._async_controller.is_authenticated()
+ return self._wrapped_instance.is_authenticated()
def connect(self) -> None:
self._execute_async_method('connect')
@@ -3985,13 +3918,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('close')
def get_latest_heartbeat(self) -> float:
- return self._async_controller.get_latest_heartbeat()
+ return self._wrapped_instance.get_latest_heartbeat()
def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None:
- self._async_controller.add_status_listener(callback, spawn)
+ self._wrapped_instance.add_status_listener(callback, spawn)
def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
- self._async_controller.remove_status_listener(callback)
+ self._wrapped_instance.remove_status_listener(callback)
def authenticate(self, *args: Any, **kwargs: Any) -> None:
self._execute_async_method('authenticate', *args, **kwargs)
@@ -4099,13 +4032,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('remove_event_listener', listener)
def is_caching_enabled(self) -> bool:
- return self._async_controller.is_caching_enabled()
+ return self._wrapped_instance.is_caching_enabled()
def set_caching(self, enabled: bool) -> None:
- self._async_controller.set_caching(enabled)
+ self._wrapped_instance.set_caching(enabled)
def clear_cache(self) -> None:
- self._async_controller.clear_cache()
+ self._wrapped_instance.clear_cache()
def load_conf(self, configtext: str) -> None:
self._execute_async_method('load_conf', configtext)
@@ -4114,10 +4047,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
return self._execute_async_method('save_conf', force)
def is_feature_enabled(self, feature: str) -> bool:
- return self._async_controller.is_feature_enabled(feature)
+ return self._wrapped_instance.is_feature_enabled(feature)
def enable_feature(self, features: Union[str, Sequence[str]]) -> None:
- self._async_controller.enable_feature(features)
+ self._wrapped_instance.enable_feature(features)
def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent:
return self._execute_async_method('get_circuit', circuit_id, default)
@@ -4150,10 +4083,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('signal', signal)
def is_newnym_available(self) -> bool:
- return self._async_controller.is_newnym_available()
+ return self._wrapped_instance.is_newnym_available()
def get_newnym_wait(self) -> float:
- return self._async_controller.get_newnym_wait()
+ return self._wrapped_instance.get_newnym_wait()
def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int:
return self._execute_async_method('get_effective_rate', default, burst)
@@ -4165,8 +4098,9 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('drop_guards')
def __del__(self) -> None:
- if self._asyncio_loop.is_running():
- self._asyncio_loop.call_soon_threadsafe(self._asyncio_loop.stop)
+ loop = self._thread_for_wrapped_class.loop
+ if loop.is_running():
+ loop.call_soon_threadsafe(loop.stop)
def __enter__(self) -> 'stem.control.Controller':
return self
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 07353d44..ae064a0a 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -127,7 +127,7 @@ def main() -> None:
async def handle_event(event_message):
print(format(str(event_message), *STANDARD_OUTPUT))
- controller._async_controller._handle_event = handle_event
+ controller._wrapped_instance._handle_event = handle_event
if sys.stdout.isatty():
events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 0d262ab5..edbcca70 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -128,7 +128,7 @@ class ControlInterpreter(code.InteractiveConsole):
# Intercept events our controller hears about at a pretty low level since
# the user will likely be requesting them by direct 'SETEVENTS' calls.
- handle_event_real = self._controller._async_controller._handle_event
+ handle_event_real = self._controller._wrapped_instance._handle_event
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
await handle_event_real(event_message)
@@ -139,8 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._async_controller._handle_event = handle_event_wrapper
- self._controller._handle_event = handle_event_wrapper # type: ignore
+ self._controller._wrapped_instance._handle_event = handle_event_wrapper
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
events = list(self._received_events)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index e4fa3ca8..a230cfbd 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -5,7 +5,9 @@
Utility functions used by the stem library.
"""
+import asyncio
import datetime
+import threading
from typing import Any, Union
@@ -139,3 +141,82 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
setattr(obj, '_cached_hash', my_hash)
return my_hash
+
+
+class CombinedReentrantAndAsyncioLock:
+ """
+ Lock that combines thread-safe reentrant and not thread-safe asyncio locks.
+ """
+
+ __slots__ = ('_r_lock', '_async_lock')
+
+ def __init__(self):
+ self._r_lock = threading.RLock()
+ self._async_lock = asyncio.Lock()
+
+ async def acquire(self):
+ await self._async_lock.acquire()
+ self._r_lock.acquire()
+
+ def release(self):
+ self._r_lock.release()
+ self._async_lock.release()
+
+ async def __aenter__(self):
+ await self.acquire()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ self.release()
+
+
+class ThreadForWrappedAsyncClass(threading.Thread):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, *kwargs)
+ self.loop = asyncio.new_event_loop()
+ self.setDaemon(True)
+
+ def run(self):
+ self.loop.run_forever()
+
+ def join(self, timeout=None):
+ self.loop.call_soon_threadsafe(self.loop.stop)
+ super().join(timeout)
+ self.loop.close()
+
+
+class AsyncClassWrapper:
+ _thread_for_wrapped_class: ThreadForWrappedAsyncClass
+ _wrapped_instance: type
+
+ def _init_async_class(self, async_class, *args, **kwargs):
+ thread = self._thread_for_wrapped_class
+ # The asynchronous class should be initialized in the thread where
+ # its methods will be executed.
+ if thread != threading.current_thread():
+ async def init():
+ return async_class(*args, **kwargs)
+
+ return asyncio.run_coroutine_threadsafe(init(), thread.loop).result()
+
+ return async_class(*args, **kwargs)
+
+ def _call_async_method_soon(self, method_name, *args, **kwargs):
+ return asyncio.run_coroutine_threadsafe(
+ getattr(self._wrapped_instance, method_name)(*args, **kwargs),
+ self._thread_for_wrapped_class.loop,
+ )
+
+ def _execute_async_method(self, method_name, *args, **kwargs):
+ return self._call_async_method_soon(method_name, *args, **kwargs).result()
+
+ def _execute_async_generator_method(self, method_name, *args, **kwargs):
+ async def convert_async_generator(generator):
+ return iter([d async for d in generator])
+
+ return asyncio.run_coroutine_threadsafe(
+ convert_async_generator(
+ getattr(self._wrapped_instance, method_name)(*args, **kwargs),
+ ),
+ self._thread_for_wrapped_class.loop,
+ ).result()
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 73e71fa4..b1772f34 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -113,7 +113,7 @@ class TestController(unittest.TestCase):
state_controller, state_type, state_timestamp = received_events[0]
- self.assertEqual(controller._async_controller, state_controller)
+ self.assertEqual(controller._wrapped_instance, state_controller)
self.assertEqual(State.RESET, state_type)
self.assertTrue(state_timestamp > before and state_timestamp < after)
diff --git a/test/runner.py b/test/runner.py
index 4a38e824..189a2d7b 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -488,7 +488,7 @@ class Runner(object):
:raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
"""
- async_controller_thread = stem.control._AsyncControllerThread()
+ async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
async_controller_thread.start()
try:
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index e8ef4787..a11aba45 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -44,7 +44,7 @@ class TestControl(unittest.TestCase):
with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))):
self.controller = Controller(socket)
- self.async_controller = self.controller._async_controller
+ self.async_controller = self.controller._wrapped_instance
self.circ_listener = Mock()
self.controller.add_event_listener(self.circ_listener, EventType.CIRC)
@@ -748,7 +748,7 @@ class TestControl(unittest.TestCase):
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
is_alive_mock.return_value = True
- loop = self.controller._asyncio_loop
+ loop = self.controller._thread_for_wrapped_class.loop
asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try:
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits