[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Use Synchronous for Controller
commit dc93ee7257b9a0eed1a4632dec3c7b16a65e9782
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Thu Jun 25 16:13:54 2020 -0700
Use Synchronous for Controller
Finally migrating our Controller class from Illia's AsyncClassWrapper to our
Synchronous mixin.
Benefits are...
* Class no longer requires a synchronous and asynchronous copy.
* Controller can be implemented as a fully asynchronous class, while still
functioning in synchronous contexts.
Downside is...
* Python type checkers (like mypy) only recognice our Controller as an
asynchronous class, producing false positives for synchronous users.
---
run_tests.py | 4 +
stem/connection.py | 9 +-
stem/control.py | 579 +++++-----------------
stem/descriptor/remote.py | 8 +-
stem/interpreter/__init__.py | 2 +-
stem/interpreter/commands.py | 4 +-
stem/util/__init__.py | 43 +-
stem/util/test_tools.py | 3 +
test/integ/connection/authentication.py | 13 +-
test/integ/control/controller.py | 818 ++++++++++++++++++--------------
test/runner.py | 64 ++-
test/settings.cfg | 8 +
test/unit/control/controller.py | 119 +++--
test/unit/descriptor/remote.py | 2 +-
14 files changed, 715 insertions(+), 961 deletions(-)
diff --git a/run_tests.py b/run_tests.py
index c738f9b8..5218008f 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -259,6 +259,10 @@ def main():
# 2.7 or later because before that test results didn't have a 'skipped'
# attribute.
+ # TODO: handling of earlier python versions is no longer necessary here
+ # TODO: this invokes all asynchronous tests, even if we have a --test or
+ # --exclude-test argument
+
skipped_tests = 0
if args.run_integ:
diff --git a/stem/connection.py b/stem/connection.py
index 86d32d7f..8495da2a 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -89,7 +89,7 @@ fine-grained control over the authentication process. For instance...
::
connect - Simple method for getting authenticated control connection for synchronous usage.
- async_connect - Simple method for getting authenticated control connection for asynchronous usage.
+ async_connect - Simple method for getting authenticated control connection for asynchronous usage.
authenticate - Main method for authenticating to a control socket
authenticate_none - Authenticates to an open control socket
@@ -292,7 +292,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
raise
-async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.BaseController] = stem.control.AsyncController) -> Any:
+async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.BaseController] = stem.control.Controller) -> Any:
"""
Convenience function for quickly getting a control connection for
asynchronous usage. This is very handy for debugging or CLI setup, handling
@@ -364,6 +364,7 @@ async def _connect_async(control_port: Tuple[str, Union[str, int]], control_sock
control_connection = _connection_for_default_port(address)
else:
control_connection = stem.socket.ControlPort(address, int(port))
+
await control_connection.connect()
except stem.SocketError as exc:
error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc)
@@ -405,9 +406,7 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
if controller is None:
return control_socket
- elif issubclass(controller, stem.control.BaseController) or issubclass(controller, stem.control.Controller):
- # TODO: Controller no longer extends BaseController (we'll probably change that)
-
+ else:
return controller(control_socket, is_authenticated = True)
except IncorrectSocketType:
if isinstance(control_socket, stem.socket.ControlPort):
diff --git a/stem/control.py b/stem/control.py
index 47ddaa35..7b90eed0 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -269,9 +269,9 @@ import stem.util.tor_tools
import stem.version
from stem import UNDEFINED, CircStatus, Signal
-from stem.util import log
+from stem.util import Synchronous, log
from types import TracebackType
-from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events,
# but if it takes longer than this we terminate.
@@ -553,56 +553,7 @@ def event_description(event: str) -> str:
return EVENT_DESCRIPTIONS.get(event.lower())
-class _BaseControllerSocketMixin:
- _socket: stem.socket.ControlSocket
-
- def is_alive(self) -> bool:
- """
- Checks if our socket is currently connected. This is a pass-through for our
- socket's :func:`~stem.socket.BaseSocket.is_alive` method.
-
- :returns: **bool** that's **True** if our socket is connected and **False** otherwise
- """
-
- return self._socket.is_alive()
-
- def is_localhost(self) -> bool:
- """
- Returns if the connection is for the local system or not.
-
- .. versionadded:: 1.3.0
-
- :returns: **bool** that's **True** if the connection is for the local host and **False** otherwise
- """
-
- return self._socket.is_localhost()
-
- def connection_time(self) -> float:
- """
- Provides the unix timestamp for when our socket was either connected or
- disconnected. That is to say, the time we connected if we're currently
- connected and the time we disconnected if we're not connected.
-
- .. versionadded:: 1.3.0
-
- :returns: **float** for when we last connected or disconnected, zero if
- we've never connected
- """
-
- return self._socket.connection_time()
-
- def get_socket(self) -> stem.socket.ControlSocket:
- """
- Provides the socket used to speak with the tor process. Communicating with
- the socket directly isn't advised since it may confuse this controller.
-
- :returns: :class:`~stem.socket.ControlSocket` we're communicating with
- """
-
- return self._socket
-
-
-class BaseController(_BaseControllerSocketMixin):
+class BaseController(Synchronous):
"""
Controller for the tor process. This is a minimal base class for other
controllers, providing basic process communication and event listing. Don't
@@ -619,21 +570,13 @@ class BaseController(_BaseControllerSocketMixin):
"""
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
- self._socket = control_socket
+ super(BaseController, self).__init__()
- self._asyncio_loop = asyncio.get_event_loop()
-
- self._msg_lock = asyncio.Lock()
+ self._socket = control_socket
self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread)
self._status_listeners_lock = threading.RLock()
- # queues where incoming messages are directed
- self._reply_queue = asyncio.Queue() # type: asyncio.Queue[Union[stem.response.ControlMessage, stem.ControllerError]]
- self._event_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.ControlMessage]
-
- self._event_notice = asyncio.Event()
-
# saves our socket's prior _connect() and _close() methods so they can be
# called along with ours
@@ -650,11 +593,22 @@ class BaseController(_BaseControllerSocketMixin):
self._reader_loop_task = None # type: Optional[asyncio.Task]
self._event_loop_task = None # type: Optional[asyncio.Task]
+
if self._socket.is_alive():
self._create_loop_tasks()
if is_authenticated:
- self._asyncio_loop.create_task(self._post_authentication())
+ self._loop.create_task(self._post_authentication())
+
+ def __ainit__(self) -> None:
+ self._msg_lock = asyncio.Lock()
+
+ # queues where incoming messages are directed
+
+ self._reply_queue = asyncio.Queue() # type: asyncio.Queue[Union[stem.response.ControlMessage, stem.ControllerError]]
+ self._event_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.ControlMessage]
+
+ self._event_notice = asyncio.Event()
async def msg(self, message: str) -> stem.response.ControlMessage:
"""
@@ -736,8 +690,43 @@ class BaseController(_BaseControllerSocketMixin):
# provide an assurance to the caller that when we raise a SocketClosed
# exception we are shut down afterward for realz.
- await self.close()
- raise
+ await self.close()
+ raise
+
+ def is_alive(self) -> bool:
+ """
+ Checks if our socket is currently connected. This is a pass-through for our
+ socket's :func:`~stem.socket.BaseSocket.is_alive` method.
+
+ :returns: **bool** that's **True** if our socket is connected and **False** otherwise
+ """
+
+ return self._socket.is_alive()
+
+ def is_localhost(self) -> bool:
+ """
+ Returns if the connection is for the local system or not.
+
+ .. versionadded:: 1.3.0
+
+ :returns: **bool** that's **True** if the connection is for the local host and **False** otherwise
+ """
+
+ return self._socket.is_localhost()
+
+ def connection_time(self) -> float:
+ """
+ Provides the unix timestamp for when our socket was either connected or
+ disconnected. That is to say, the time we connected if we're currently
+ connected and the time we disconnected if we're not connected.
+
+ .. versionadded:: 1.3.0
+
+ :returns: **float** for when we last connected or disconnected, zero if
+ we've never connected
+ """
+
+ return self._socket.connection_time()
def is_authenticated(self) -> bool:
"""
@@ -778,6 +767,18 @@ class BaseController(_BaseControllerSocketMixin):
if t.is_alive() and threading.current_thread() != t:
t.join()
+ self.stop()
+
+ def get_socket(self) -> stem.socket.ControlSocket:
+ """
+ Provides the socket used to speak with the tor process. Communicating with
+ the socket directly isn't advised since it may confuse this controller.
+
+ :returns: :class:`~stem.socket.ControlSocket` we're communicating with
+ """
+
+ return self._socket
+
def get_latest_heartbeat(self) -> float:
"""
Provides the unix timestamp for when we last heard from tor. This is zero
@@ -858,7 +859,7 @@ class BaseController(_BaseControllerSocketMixin):
async def _connect(self) -> None:
self._create_loop_tasks()
- await self._notify_status_listeners(State.INIT, acquire_send_lock=False)
+ await self._notify_status_listeners(State.INIT, acquire_send_lock = False)
await self._socket_connect()
self._is_authenticated = False
@@ -874,13 +875,14 @@ class BaseController(_BaseControllerSocketMixin):
self._reader_loop_task = None
event_loop_task = self._event_loop_task
self._event_loop_task = None
+
if reader_loop_task and self.is_alive():
await reader_loop_task
+
if event_loop_task:
await event_loop_task
- await self._notify_status_listeners(State.CLOSED, acquire_send_lock=False)
-
+ await self._notify_status_listeners(State.CLOSED, acquire_send_lock = False)
await self._socket_close()
async def _post_authentication(self) -> None:
@@ -899,6 +901,7 @@ class BaseController(_BaseControllerSocketMixin):
# need to have it to ensure it doesn't change beneath us.
send_lock = self._socket._get_send_lock()
+
try:
if acquire_send_lock:
await send_lock.acquire()
@@ -944,8 +947,8 @@ class BaseController(_BaseControllerSocketMixin):
them if we're restarted.
"""
- self._reader_loop_task = self._asyncio_loop.create_task(self._reader_loop())
- self._event_loop_task = self._asyncio_loop.create_task(self._event_loop())
+ self._reader_loop_task = self._loop.create_task(self._reader_loop())
+ self._event_loop_task = self._loop.create_task(self._event_loop())
async def _reader_loop(self) -> None:
"""
@@ -1011,21 +1014,24 @@ class BaseController(_BaseControllerSocketMixin):
self._event_notice.clear()
-class AsyncController(BaseController):
+class Controller(BaseController):
"""
Connection with Tor's control socket. This is built on top of the
BaseController and provides a more user friendly API for library users.
"""
- @classmethod
- def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'AsyncController':
+ @staticmethod
+ def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
"""
- Constructs a :class:`~stem.socket.ControlPort` based AsyncController.
+ Constructs a :class:`~stem.socket.ControlPort` based Controller.
If the **port** is **'default'** then this checks on both 9051 (default
for relays) and 9151 (default for the Tor Browser). This default may change
in the future.
+ .. versionchanged:: 1.5.0
+ Use both port 9051 and 9151 by default.
+
:param address: ip address of the controller
:param port: port number of the controller
@@ -1034,13 +1040,31 @@ class AsyncController(BaseController):
:raises: :class:`stem.SocketError` if we're unable to establish a connection
"""
- control_socket = _init_control_port(address, port)
- return cls(control_socket)
+ import stem.connection
+
+ if not stem.util.connection.is_valid_ipv4_address(address):
+ raise ValueError('Invalid IP address: %s' % address)
+ elif port != 'default' and not stem.util.connection.is_valid_port(port):
+ raise ValueError('Invalid port: %s' % port)
+
+ if port == 'default':
+ control_port = stem.connection._connection_for_default_port(address)
+ else:
+ control_port = stem.socket.ControlPort(address, int(port))
+
+ controller = Controller(control_port)
- @classmethod
- def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'AsyncController':
+ try:
+ controller.connect()
+ return controller
+ except:
+ controller.stop()
+ raise
+
+ @staticmethod
+ def from_socket_file(path: str = '/var/run/tor/control') -> 'stem.control.Controller':
"""
- Constructs a :class:`~stem.socket.ControlSocketFile` based AsyncController.
+ Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
:param path: path where the control socket is located
@@ -1049,8 +1073,15 @@ class AsyncController(BaseController):
:raises: :class:`stem.SocketError` if we're unable to establish a connection
"""
- control_socket = _init_control_socket_file(path)
- return cls(control_socket)
+ control_socket = stem.socket.ControlSocketFile(path)
+ controller = Controller(control_socket)
+
+ try:
+ controller.connect()
+ return controller
+ except:
+ controller.stop()
+ raise
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
self._is_caching_enabled = True
@@ -1062,13 +1093,12 @@ class AsyncController(BaseController):
# mapping of event types to their listeners
self._event_listeners = {} # type: Dict[stem.control.EventType, List[Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]]]
- self._event_listeners_lock = asyncio.Lock()
self._enabled_features = [] # type: List[str]
self._last_address_exc = None # type: Optional[BaseException]
self._last_fingerprint_exc = None # type: Optional[BaseException]
- super(AsyncController, self).__init__(control_socket, is_authenticated)
+ super(Controller, self).__init__(control_socket, is_authenticated)
async def _sighup_listener(event: stem.response.events.SignalEvent) -> None:
if event.signal == Signal.RELOAD:
@@ -1101,11 +1131,16 @@ class AsyncController(BaseController):
self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER),
)
- self._asyncio_loop.create_task(_add_event_listeners())
+ self._loop.create_task(_add_event_listeners())
+
+ def __ainit__(self):
+ super(Controller, self).__ainit__()
+
+ self._event_listeners_lock = asyncio.Lock()
async def close(self) -> None:
self.clear_cache()
- await super(AsyncController, self).close()
+ await super(Controller, self).close()
async def authenticate(self, *args: Any, **kwargs: Any) -> None:
"""
@@ -1186,7 +1221,7 @@ class AsyncController(BaseController):
raise stem.ProtocolError('Tor geoip database is unavailable')
elif param == 'address' and self._last_address_exc:
raise self._last_address_exc # we already know we can't resolve an address
- elif param == 'fingerprint' and self._last_fingerprint_exc and self.get_conf('ORPort', None) is None:
+ elif param == 'fingerprint' and self._last_fingerprint_exc and await self.get_conf('ORPort', None) is None:
raise self._last_fingerprint_exc # we already know we're not a relay
# check for cached results
@@ -2082,7 +2117,6 @@ class AsyncController(BaseController):
request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
response = stem.response._convert_to_single_line(await self.msg(request))
- stem.response.convert('SINGLELINE', response)
if not response.is_ok():
raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code)
@@ -3778,7 +3812,7 @@ class AsyncController(BaseController):
await self.msg('DROPGUARDS')
async def _post_authentication(self) -> None:
- await super(AsyncController, self)._post_authentication()
+ await super(Controller, self)._post_authentication()
# try to re-attach event listeners to the new instance
@@ -3834,9 +3868,10 @@ class AsyncController(BaseController):
if listener_type == event_type:
for listener in event_listeners:
try:
- potential_coroutine = listener(event)
- if asyncio.iscoroutine(potential_coroutine):
- await potential_coroutine
+ listener_call = listener(event)
+
+ if asyncio.iscoroutine(listener_call):
+ await listener_call
except Exception as exc:
log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event))
@@ -3883,346 +3918,6 @@ class AsyncController(BaseController):
return (set_events, failed_events)
-def _set_doc_from_async_controller(func: Callable) -> Callable:
- func.__doc__ = getattr(AsyncController, func.__name__).__doc__
- return func
-
-
-class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
- """
- Connection with Tor's control socket. This wraps
- :class:`~stem.control.AsyncController` to provide a synchronous
- interface and for backwards compatibility.
- """
-
- @classmethod
- def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'Controller':
- """
- Constructs a :class:`~stem.socket.ControlPort` based Controller.
-
- If the **port** is **'default'** then this checks on both 9051 (default
- for relays) and 9151 (default for the Tor Browser). This default may change
- in the future.
-
- .. versionchanged:: 1.5.0
- Use both port 9051 and 9151 by default.
-
- :param address: ip address of the controller
- :param port: port number of the controller
-
- :returns: :class:`~stem.control.Controller` attached to the given port
-
- :raises: :class:`stem.SocketError` if we're unable to establish a connection
- """
-
- control_socket = _init_control_port(address, port)
- controller = cls(control_socket)
- controller.connect()
- return controller
-
- @classmethod
- def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'Controller':
- """
- Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
-
- :param str path: path where the control socket is located
-
- :returns: :class:`~stem.control.Controller` attached to the given socket file
-
- :raises: :class:`stem.SocketError` if we're unable to establish a connection
- """
-
- control_socket = _init_control_socket_file(path)
- controller = cls(control_socket)
- controller.connect()
- return controller
-
- def __init__(
- self,
- control_socket: stem.socket.ControlSocket,
- is_authenticated: bool = False,
- ) -> None:
- # if within an asyncio context use its loop, otherwise spawn our own
-
- try:
- self._loop = asyncio.get_running_loop()
- self._loop_thread = threading.current_thread()
- except RuntimeError:
- self._loop = asyncio.new_event_loop()
- self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
- self._loop_thread.setDaemon(True)
- self._loop_thread.start()
-
- self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore
- self._socket = self._wrapped_instance._socket
-
- @_set_doc_from_async_controller
- def msg(self, message: str) -> stem.response.ControlMessage:
- return self._execute_async_method('msg', message)
-
- @_set_doc_from_async_controller
- def is_authenticated(self) -> bool:
- return self._wrapped_instance.is_authenticated()
-
- @_set_doc_from_async_controller
- def connect(self) -> None:
- self._execute_async_method('connect')
-
- @_set_doc_from_async_controller
- def reconnect(self, *args: Any, **kwargs: Any) -> None:
- self._execute_async_method('reconnect', *args, **kwargs)
-
- @_set_doc_from_async_controller
- def close(self) -> None:
- self._execute_async_method('close')
-
- @_set_doc_from_async_controller
- def get_latest_heartbeat(self) -> float:
- return self._wrapped_instance.get_latest_heartbeat()
-
- @_set_doc_from_async_controller
- def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None:
- self._wrapped_instance.add_status_listener(callback, spawn)
-
- @_set_doc_from_async_controller
- def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
- return self._wrapped_instance.remove_status_listener(callback)
-
- @_set_doc_from_async_controller
- def authenticate(self, *args: Any, **kwargs: Any) -> None:
- self._execute_async_method('authenticate', *args, **kwargs)
-
- @_set_doc_from_async_controller
- def get_info(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, get_bytes: bool = False) -> Union[str, Dict[str, str]]:
- return self._execute_async_method('get_info', params, default, get_bytes)
-
- @_set_doc_from_async_controller
- def get_version(self, default: Any = UNDEFINED) -> stem.version.Version:
- return self._execute_async_method('get_version', default)
-
- @_set_doc_from_async_controller
- def get_exit_policy(self, default: Any = UNDEFINED) -> stem.exit_policy.ExitPolicy:
- return self._execute_async_method('get_exit_policy', default)
-
- @_set_doc_from_async_controller
- def get_ports(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[int]:
- return self._execute_async_method('get_ports', listener_type, default)
-
- @_set_doc_from_async_controller
- def get_listeners(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[Tuple[str, int]]:
- return self._execute_async_method('get_listeners', listener_type, default)
-
- @_set_doc_from_async_controller
- def get_accounting_stats(self, default: Any = UNDEFINED) -> 'stem.control.AccountingStats':
- return self._execute_async_method('get_accounting_stats', default)
-
- @_set_doc_from_async_controller
- def get_protocolinfo(self, default: Any = UNDEFINED) -> stem.response.protocolinfo.ProtocolInfoResponse:
- return self._execute_async_method('get_protocolinfo', default)
-
- @_set_doc_from_async_controller
- def get_user(self, default: Any = UNDEFINED) -> str:
- return self._execute_async_method('get_user', default)
-
- @_set_doc_from_async_controller
- def get_pid(self, default: Any = UNDEFINED) -> int:
- return self._execute_async_method('get_pid', default)
-
- @_set_doc_from_async_controller
- def get_start_time(self, default: Any = UNDEFINED) -> float:
- return self._execute_async_method('get_start_time', default)
-
- @_set_doc_from_async_controller
- def get_uptime(self, default: Any = UNDEFINED) -> float:
- return self._execute_async_method('get_uptime', default)
-
- @_set_doc_from_async_controller
- def is_user_traffic_allowed(self) -> 'stem.control.UserTrafficAllowed':
- return self._execute_async_method('is_user_traffic_allowed')
-
- @_set_doc_from_async_controller
- def get_microdescriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.microdescriptor.Microdescriptor:
- return self._execute_async_method('get_microdescriptor', relay, default)
-
- @_set_doc_from_async_controller
- def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]:
- return self._execute_async_generator_method('get_microdescriptors', default)
-
- @_set_doc_from_async_controller
- def get_server_descriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor:
- return self._execute_async_method('get_server_descriptor', relay, default)
-
- @_set_doc_from_async_controller
- def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
- return self._execute_async_generator_method('get_server_descriptors', default)
-
- @_set_doc_from_async_controller
- def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3:
- return self._execute_async_method('get_network_status', relay, default)
-
- @_set_doc_from_async_controller
- def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
- return self._execute_async_generator_method('get_network_statuses', default)
-
- @_set_doc_from_async_controller
- def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2:
- return self._execute_async_method('get_hidden_service_descriptor', address, default, servers, await_result, timeout)
-
- @_set_doc_from_async_controller
- def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]:
- return self._execute_async_method('get_conf', param, default, multiple)
-
- @_set_doc_from_async_controller
- def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
- return self._execute_async_method('get_conf_map', params, default, multiple)
-
- @_set_doc_from_async_controller
- def is_set(self, param: str, default: Any = UNDEFINED) -> bool:
- return self._execute_async_method('is_set', param, default)
-
- @_set_doc_from_async_controller
- def set_conf(self, param: str, value: Union[str, Sequence[str]]) -> None:
- self._execute_async_method('set_conf', param, value)
-
- @_set_doc_from_async_controller
- def reset_conf(self, *params: str) -> None:
- self._execute_async_method('reset_conf', *params)
-
- @_set_doc_from_async_controller
- def set_options(self, params: Union[Mapping[str, Union[str, Sequence[str]]], Sequence[Tuple[str, Union[str, Sequence[str]]]]], reset: bool = False) -> None:
- self._execute_async_method('set_options', params, reset)
-
- @_set_doc_from_async_controller
- def get_hidden_service_conf(self, default: Any = UNDEFINED) -> Dict[str, Any]:
- return self._execute_async_method('get_hidden_service_conf', default)
-
- @_set_doc_from_async_controller
- def set_hidden_service_conf(self, conf: Mapping[str, Any]) -> None:
- self._execute_async_method('set_hidden_service_conf', conf)
-
- @_set_doc_from_async_controller
- def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput':
- return self._execute_async_method('create_hidden_service', path, port, target_address, target_port, auth_type, client_names)
-
- @_set_doc_from_async_controller
- def remove_hidden_service(self, path: str, port: Optional[int] = None) -> bool:
- return self._execute_async_method('remove_hidden_service', path, port)
-
- @_set_doc_from_async_controller
- def list_ephemeral_hidden_services(self, default: Any = UNDEFINED, our_services: bool = True, detached: bool = False) -> Sequence[str]:
- return self._execute_async_method('list_ephemeral_hidden_services', default, our_services, detached)
-
- @_set_doc_from_async_controller
- def create_ephemeral_hidden_service(self, ports: Union[int, Sequence[int], Mapping[int, str]], key_type: str = 'NEW', key_content: str = 'BEST', discard_key: bool = False, detached: bool = False, await_publication: bool = False, timeout: Optional[float] = None, basic_auth: Optional[Mapping[str, str]] = None, max_streams: Optional[int] = None) -> stem.response.add_onion.AddOnionResponse:
- return self._execute_async_method('create_ephemeral_hidden_service', ports, key_type, key_content, discard_key, detached, await_publication, timeout, basic_auth, max_streams)
-
- @_set_doc_from_async_controller
- def remove_ephemeral_hidden_service(self, service_id: str) -> bool:
- return self._execute_async_method('remove_ephemeral_hidden_service', service_id)
-
- @_set_doc_from_async_controller
- def add_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]], *events: 'stem.control.EventType') -> None:
- self._execute_async_method('add_event_listener', listener, *events)
-
- @_set_doc_from_async_controller
- def remove_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]) -> None:
- self._execute_async_method('remove_event_listener', listener)
-
- @_set_doc_from_async_controller
- def is_caching_enabled(self) -> bool:
- return self._wrapped_instance.is_caching_enabled()
-
- @_set_doc_from_async_controller
- def set_caching(self, enabled: bool) -> None:
- self._wrapped_instance.set_caching(enabled)
-
- @_set_doc_from_async_controller
- def clear_cache(self) -> None:
- self._wrapped_instance.clear_cache()
-
- @_set_doc_from_async_controller
- def load_conf(self, configtext: str) -> None:
- self._execute_async_method('load_conf', configtext)
-
- @_set_doc_from_async_controller
- def save_conf(self, force: bool = False) -> None:
- return self._execute_async_method('save_conf', force)
-
- @_set_doc_from_async_controller
- def is_feature_enabled(self, feature: str) -> bool:
- return self._wrapped_instance.is_feature_enabled(feature)
-
- @_set_doc_from_async_controller
- def enable_feature(self, features: Union[str, Sequence[str]]) -> None:
- self._wrapped_instance.enable_feature(features)
-
- @_set_doc_from_async_controller
- def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent:
- return self._execute_async_method('get_circuit', circuit_id, default)
-
- @_set_doc_from_async_controller
- def get_circuits(self, default: Any = UNDEFINED) -> List[stem.response.events.CircuitEvent]:
- return self._execute_async_method('get_circuits', default)
-
- @_set_doc_from_async_controller
- def new_circuit(self, path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str:
- return self._execute_async_method('new_circuit', path, purpose, await_build, timeout)
-
- @_set_doc_from_async_controller
- def extend_circuit(self, circuit_id: str = '0', path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str:
- return self._execute_async_method('extend_circuit', circuit_id, path, purpose, await_build, timeout)
-
- @_set_doc_from_async_controller
- def repurpose_circuit(self, circuit_id: str, purpose: str) -> None:
- self._execute_async_method('repurpose_circuit', circuit_id, purpose)
-
- @_set_doc_from_async_controller
- def close_circuit(self, circuit_id: str, flag: str = '') -> None:
- self._execute_async_method('close_circuit', circuit_id, flag)
-
- @_set_doc_from_async_controller
- def get_streams(self, default: Any = UNDEFINED) -> List[stem.response.events.StreamEvent]:
- return self._execute_async_method('get_streams', default)
-
- @_set_doc_from_async_controller
- def attach_stream(self, stream_id: str, circuit_id: str, exiting_hop: Optional[int] = None) -> None:
- self._execute_async_method('attach_stream', stream_id, circuit_id, exiting_hop)
-
- @_set_doc_from_async_controller
- def close_stream(self, stream_id: str, reason: stem.RelayEndReason = stem.RelayEndReason.MISC, flag: str = '') -> None:
- self._execute_async_method('close_stream', stream_id, reason, flag)
-
- @_set_doc_from_async_controller
- def signal(self, signal: stem.Signal) -> None:
- self._execute_async_method('signal', signal)
-
- @_set_doc_from_async_controller
- def is_newnym_available(self) -> bool:
- return self._wrapped_instance.is_newnym_available()
-
- @_set_doc_from_async_controller
- def get_newnym_wait(self) -> float:
- return self._wrapped_instance.get_newnym_wait()
-
- @_set_doc_from_async_controller
- def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int:
- return self._execute_async_method('get_effective_rate', default, burst)
-
- @_set_doc_from_async_controller
- def map_address(self, mapping: Mapping[str, str]) -> Dict[str, str]:
- return self._execute_async_method('map_address', mapping)
-
- @_set_doc_from_async_controller
- def drop_guards(self) -> None:
- self._execute_async_method('drop_guards')
-
- def __enter__(self) -> 'stem.control.Controller':
- return self
-
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
- self.close()
-
-
def _parse_circ_path(path: str) -> Sequence[Tuple[str, str]]:
"""
Parses a circuit path as a list of **(fingerprint, nickname)** tuples. Tor
@@ -4342,26 +4037,6 @@ async def _get_with_timeout(event_queue: asyncio.Queue, timeout: Optional[float]
time_left = None
try:
- return await asyncio.wait_for(event_queue.get(), timeout=time_left)
+ return await asyncio.wait_for(event_queue.get(), timeout = time_left)
except asyncio.TimeoutError:
raise stem.Timeout('Reached our %0.1f second timeout' % timeout)
-
-
-def _init_control_port(address: str, port: Union[int, str]) -> stem.socket.ControlPort:
- import stem.connection
-
- if not stem.util.connection.is_valid_ipv4_address(address):
- raise ValueError('Invalid IP address: %s' % address)
- elif port != 'default' and not stem.util.connection.is_valid_port(port):
- raise ValueError('Invalid port: %s' % port)
-
- if port == 'default':
- control_port = stem.connection._connection_for_default_port(address)
- else:
- control_port = stem.socket.ControlPort(address, int(port))
-
- return control_port
-
-
-def _init_control_socket_file(path: str) -> stem.socket.ControlSocketFile:
- return stem.socket.ControlSocketFile(path)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index 942d81e9..3428f0d2 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -445,13 +445,13 @@ class Query(Synchronous):
loop = asyncio.get_running_loop()
self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
- async def run(self, suppress: bool = False, close: bool = True) -> List['stem.descriptor.Descriptor']:
+ async def run(self, suppress: bool = False, stop: bool = True) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
:param suppress: avoids raising exceptions if **True**
- :param close: terminates the resources backing this query if **True**,
+ :param stop: terminates the resources backing this query if **True**,
further method calls will raise a RuntimeError
:returns: list for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
@@ -465,14 +465,14 @@ class Query(Synchronous):
* :class:`~stem.DownloadFailed` if our request fails
"""
- # TODO: We should replace our 'close' argument with a new API design prior
+ # TODO: We should replace our 'stop' argument with a new API design prior
# to release. Self-destructing this object by default for synchronous users
# is quite a step backward, but is acceptable as we iterate on this.
try:
return [desc async for desc in self._run(suppress)]
finally:
- if close:
+ if stop:
self._loop.call_soon_threadsafe(self._loop.stop)
async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 370b9aa6..872e7f6f 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -127,7 +127,7 @@ def main() -> None:
async def handle_event(event_message: stem.response.ControlMessage) -> None:
print(format(str(event_message), *STANDARD_OUTPUT))
- controller._wrapped_instance._handle_event = handle_event # type: ignore
+ controller._handle_event = handle_event # type: ignore
if sys.stdout.isatty():
events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 99f1219d..b04fcc85 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -128,7 +128,7 @@ class ControlInterpreter(code.InteractiveConsole):
# Intercept events our controller hears about at a pretty low level since
# the user will likely be requesting them by direct 'SETEVENTS' calls.
- handle_event_real = self._controller._wrapped_instance._handle_event
+ handle_event_real = self._controller._handle_event
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
await handle_event_real(event_message)
@@ -139,7 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._wrapped_instance._handle_event = handle_event_wrapper # type: ignore
+ self._controller._handle_event = handle_event_wrapper # type: ignore
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
events = list(self._received_events)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index d780a0de..de946fd9 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -13,7 +13,6 @@ import threading
import typing
import unittest.mock
-from concurrent.futures import Future
from types import TracebackType
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
@@ -211,6 +210,7 @@ class Synchronous(object):
self._no_op = Synchronous.is_asyncio_context()
if self._no_op:
+ self._loop = asyncio.get_running_loop()
self.__ainit__() # this is already an asyncio context
else:
# Run coroutines through our loop. This calls methods by name rather than
@@ -361,44 +361,3 @@ class Synchronous(object):
def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
return self._run_async_method('__aexit__', exit_type, value, traceback)
-
-
-class AsyncClassWrapper:
- _loop: asyncio.AbstractEventLoop
- _loop_thread: threading.Thread
- _wrapped_instance: type
-
- def _init_async_class(self, async_class: Type, *args: Any, **kwargs: Any) -> Any:
- # The asynchronous class should be initialized in the thread where
- # its methods will be executed.
- if self._loop_thread != threading.current_thread():
- async def init():
- return async_class(*args, **kwargs)
-
- return asyncio.run_coroutine_threadsafe(init(), self._loop).result()
-
- return async_class(*args, **kwargs)
-
- def _call_async_method_soon(self, method_name: str, *args: Any, **kwargs: Any) -> Future:
- return asyncio.run_coroutine_threadsafe(
- getattr(self._wrapped_instance, method_name)(*args, **kwargs),
- self._loop,
- )
-
- def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
- return self._call_async_method_soon(method_name, *args, **kwargs).result()
-
- def _execute_async_generator_method(self, method_name: str, *args: Any, **kwargs: Any) -> Iterator:
- async def convert_async_generator(generator: AsyncIterator) -> Iterator:
- return iter([d async for d in generator])
-
- return asyncio.run_coroutine_threadsafe(
- convert_async_generator(
- getattr(self._wrapped_instance, method_name)(*args, **kwargs),
- ),
- self._loop,
- ).result()
-
- def __del__(self) -> None:
- self._loop.call_soon_threadsafe(self._loop.stop)
- self._loop_thread.join()
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index 67133195..ac9f8b88 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -696,11 +696,14 @@ def async_test(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
loop = asyncio.new_event_loop()
+
try:
result = loop.run_until_complete(func(*args, **kwargs))
finally:
loop.close()
+
return result
+
return wrapper
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 042d3939..bbca5e43 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -3,7 +3,6 @@ Integration tests for authenticating to the control socket via
stem.connection.authenticate* functions.
"""
-import asyncio
import os
import unittest
@@ -121,11 +120,8 @@ class TestAuthenticate(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
- asyncio.run_coroutine_threadsafe(
- stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
- controller._loop,
- ).result()
+ async with await runner.get_tor_controller(False) as controller:
+ await stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot())
await test.runner.exercise_controller(self, controller)
@test.require.controller
@@ -276,7 +272,8 @@ class TestAuthenticate(unittest.TestCase):
await self._check_auth(auth_type, auth_value)
@test.require.controller
- def test_wrong_password_with_controller(self):
+ @async_test
+ async def test_wrong_password_with_controller(self):
"""
We ran into a race condition where providing the wrong password to the
Controller caused inconsistent responses. Checking for that...
@@ -290,7 +287,7 @@ class TestAuthenticate(unittest.TestCase):
self.skipTest('(requires only password auth)')
for i in range(10):
- with runner.get_tor_controller(False) as controller:
+ async with await runner.get_tor_controller(False) as controller:
with self.assertRaises(stem.connection.IncorrectPassword):
controller.authenticate('wrong_password')
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 7853d407..33b62ea7 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -7,7 +7,6 @@ import os
import shutil
import socket
import tempfile
-import threading
import time
import unittest
@@ -38,24 +37,25 @@ TEST_ROUTER_STATUS_ENTRY = None
class TestController(unittest.TestCase):
@test.require.only_run_once
@test.require.controller
- def test_missing_capabilities(self):
+ @async_test
+ async def test_missing_capabilities(self):
"""
Check to see if tor supports any events, signals, or features that we
don't.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- for event in controller.get_info('events/names').split():
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ for event in (await controller.get_info('events/names')).split():
if event not in EventType:
test.register_new_capability('Event', event)
- for signal in controller.get_info('signal/names').split():
+ for signal in (await controller.get_info('signal/names')).split():
if signal not in Signal:
test.register_new_capability('Signal', signal)
# new features should simply be added to enable_feature()'s docs
- for feature in controller.get_info('features/names').split():
+ for feature in (await controller.get_info('features/names')).split():
if feature not in ('EXTENDED_EVENTS', 'VERBOSE_NAMES'):
test.register_new_capability('Feature', feature)
@@ -88,7 +88,7 @@ class TestController(unittest.TestCase):
Checks that a notificiation listener is... well, notified of SIGHUPs.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
received_events = []
def status_listener(my_controller, state, timestamp):
@@ -97,7 +97,7 @@ class TestController(unittest.TestCase):
controller.add_status_listener(status_listener)
before = time.time()
- controller.signal(Signal.HUP)
+ await controller.signal(Signal.HUP)
# I really hate adding a sleep here, but signal() is non-blocking.
while len(received_events) == 0:
@@ -112,20 +112,21 @@ class TestController(unittest.TestCase):
state_controller, state_type, state_timestamp = received_events[0]
- self.assertEqual(controller._wrapped_instance, state_controller)
+ self.assertEqual(controller, state_controller)
self.assertEqual(State.RESET, state_type)
self.assertTrue(state_timestamp > before and state_timestamp < after)
- controller.reset_conf('__OwningControllerProcess')
+ await controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_event_handling(self):
+ @async_test
+ async def test_event_handling(self):
"""
Add a couple listeners for various events and make sure that they receive
them. Then remove the listeners.
"""
- event_notice1, event_notice2 = threading.Event(), threading.Event()
+ event_notice1, event_notice2 = asyncio.Event(), asyncio.Event()
event_buffer1, event_buffer2 = [], []
def listener1(event):
@@ -138,30 +139,30 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- controller.add_event_listener(listener1, EventType.CONF_CHANGED)
- controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG)
+ async with await runner.get_tor_controller() as controller:
+ await controller.add_event_listener(listener1, EventType.CONF_CHANGED)
+ await controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG)
# The NodeFamily is a harmless option we can toggle
- controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
+ await controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
# Wait for the event. Assert that we get it within 10 seconds
- event_notice1.wait(10)
+ await asyncio.wait_for(event_notice1.wait(), timeout = 10)
self.assertEqual(len(event_buffer1), 1)
event_notice1.clear()
- event_notice2.wait(10)
+ await asyncio.wait_for(event_notice2.wait(), timeout = 10)
self.assertTrue(len(event_buffer2) >= 1)
event_notice2.clear()
# Checking that a listener's no longer called after being removed.
- controller.remove_event_listener(listener2)
+ await controller.remove_event_listener(listener2)
buffer2_size = len(event_buffer2)
- controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
- event_notice1.wait(10)
+ await controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
+ await asyncio.wait_for(event_notice1.wait(), timeout = 10)
self.assertEqual(len(event_buffer1), 2)
event_notice1.clear()
@@ -174,16 +175,17 @@ class TestController(unittest.TestCase):
self.assertTrue(isinstance(event, stem.response.events.ConfChangedEvent))
- controller.reset_conf('NodeFamily')
+ await controller.reset_conf('NodeFamily')
@test.require.controller
- def test_reattaching_listeners(self):
+ @async_test
+ async def test_reattaching_listeners(self):
"""
Checks that event listeners are re-attached when a controller disconnects
then reconnects to tor.
"""
- event_notice = threading.Event()
+ event_notice = asyncio.Event()
event_buffer = []
def listener(event):
@@ -192,79 +194,85 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- controller.add_event_listener(listener, EventType.CONF_CHANGED)
+ async with await runner.get_tor_controller() as controller:
+ await controller.add_event_listener(listener, EventType.CONF_CHANGED)
# trigger an event
- controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
- event_notice.wait(4)
+ await controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
+ await asyncio.wait_for(event_notice.wait(), timeout = 4)
self.assertTrue(len(event_buffer) >= 1)
# disconnect, then reconnect and check that we get events again
- controller.close()
+ await controller.close()
event_notice.clear()
event_buffer = []
- controller.connect()
- controller.authenticate(password = test.runner.CONTROL_PASSWORD)
+ await controller.connect()
+ await controller.authenticate(password = test.runner.CONTROL_PASSWORD)
self.assertTrue(len(event_buffer) == 0)
- controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
+ await controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
- event_notice.wait(4)
+ await asyncio.wait_for(event_notice.wait(), timeout = 4)
self.assertTrue(len(event_buffer) >= 1)
- controller.reset_conf('NodeFamily')
+ await controller.reset_conf('NodeFamily')
@test.require.controller
- def test_getinfo(self):
+ @async_test
+ async def test_getinfo(self):
"""
Exercises GETINFO with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# successful single query
torrc_path = runner.get_torrc_path()
- self.assertEqual(torrc_path, controller.get_info('config-file'))
- self.assertEqual(torrc_path, controller.get_info('config-file', 'ho hum'))
+ self.assertEqual(torrc_path, await controller.get_info('config-file'))
+ self.assertEqual(torrc_path, await controller.get_info('config-file', 'ho hum'))
expected = {'config-file': torrc_path}
- self.assertEqual(expected, controller.get_info(['config-file']))
- self.assertEqual(expected, controller.get_info(['config-file'], 'ho hum'))
+ self.assertEqual(expected, await controller.get_info(['config-file']))
+ self.assertEqual(expected, await controller.get_info(['config-file'], 'ho hum'))
# successful batch query, we don't know the values so just checking for
# the keys
getinfo_params = set(['version', 'config-file', 'config/names'])
- self.assertEqual(getinfo_params, set(controller.get_info(['version', 'config-file', 'config/names']).keys()))
+ self.assertEqual(getinfo_params, set((await controller.get_info(['version', 'config-file', 'config/names'])).keys()))
# non-existant option
- self.assertRaises(stem.ControllerError, controller.get_info, 'blarg')
- self.assertEqual('ho hum', controller.get_info('blarg', 'ho hum'))
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_info('blarg')
+
+ self.assertEqual('ho hum', await controller.get_info('blarg', 'ho hum'))
# empty input
- self.assertRaises(stem.ControllerError, controller.get_info, '')
- self.assertEqual('ho hum', controller.get_info('', 'ho hum'))
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_info('')
+
+ self.assertEqual('ho hum', await controller.get_info('', 'ho hum'))
- self.assertEqual({}, controller.get_info([]))
- self.assertEqual({}, controller.get_info([], {}))
+ self.assertEqual({}, await controller.get_info([]))
+ self.assertEqual({}, await controller.get_info([], {}))
@test.require.controller
- def test_getinfo_freshrelaydescs(self):
+ @async_test
+ async def test_getinfo_freshrelaydescs(self):
"""
Exercises 'GETINFO status/fresh-relay-descs'.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- response = controller.get_info('status/fresh-relay-descs')
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ response = await controller.get_info('status/fresh-relay-descs')
div = response.find('\nextra-info ')
- nickname = controller.get_conf('Nickname')
+ nickname = await controller.get_conf('Nickname')
if div == -1:
self.fail('GETINFO response should have both a server and extrainfo descriptor:\n%s' % response)
@@ -274,44 +282,47 @@ class TestController(unittest.TestCase):
self.assertEqual(nickname, server_desc.nickname)
self.assertEqual(nickname, extrainfo_desc.nickname)
- self.assertEqual(controller.get_info('address'), server_desc.address)
+ self.assertEqual(await controller.get_info('address'), server_desc.address)
self.assertEqual(test.runner.ORPORT, server_desc.or_port)
@test.require.controller
@test.require.online
- def test_getinfo_dir_status(self):
+ @async_test
+ async def test_getinfo_dir_status(self):
"""
Exercise 'GETINFO dir/status-vote/*'.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- consensus = controller.get_info('dir/status-vote/current/consensus')
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ consensus = await controller.get_info('dir/status-vote/current/consensus')
self.assertTrue('moria1' in consensus, 'moria1 not found in the consensus')
if test.tor_version() >= stem.version.Version('0.4.3.1-alpha'):
- microdescs = controller.get_info('dir/status-vote/current/consensus-microdesc')
+ microdescs = await controller.get_info('dir/status-vote/current/consensus-microdesc')
self.assertTrue('moria1' in microdescs, 'moria1 not found in the microdescriptor consensus')
@test.require.controller
- def test_get_version(self):
+ @async_test
+ async def test_get_version(self):
"""
Test that the convenient method get_version() works.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- version = controller.get_version()
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ version = await controller.get_version()
self.assertTrue(isinstance(version, stem.version.Version))
self.assertEqual(version, test.tor_version())
@test.require.controller
- def test_get_exit_policy(self):
+ @async_test
+ async def test_get_exit_policy(self):
"""
Sanity test for get_exit_policy(). Our 'ExitRelay 0' torrc entry causes us
to have a simple reject-all policy.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- self.assertEqual(ExitPolicy('reject *:*'), controller.get_exit_policy())
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ self.assertEqual(ExitPolicy('reject *:*'), await controller.get_exit_policy())
@test.require.controller
@async_test
@@ -322,20 +333,21 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
- controller.authenticate(test.runner.CONTROL_PASSWORD)
+ async with await runner.get_tor_controller(False) as controller:
+ await controller.authenticate(test.runner.CONTROL_PASSWORD)
await test.runner.exercise_controller(self, controller)
@test.require.controller
- def test_protocolinfo(self):
+ @async_test
+ async def test_protocolinfo(self):
"""
Test that the convenient method protocolinfo() works.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
- protocolinfo = controller.get_protocolinfo()
+ async with await runner.get_tor_controller(False) as controller:
+ protocolinfo = await controller.get_protocolinfo()
self.assertTrue(isinstance(protocolinfo, stem.response.protocolinfo.ProtocolInfoResponse))
# Doing a sanity test on the ProtocolInfoResponse instance returned.
@@ -355,14 +367,15 @@ class TestController(unittest.TestCase):
self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
@test.require.controller
- def test_getconf(self):
+ @async_test
+ async def test_getconf(self):
"""
Exercises GETCONF with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
control_socket = controller.get_socket()
if isinstance(control_socket, stem.socket.ControlPort):
@@ -373,79 +386,89 @@ class TestController(unittest.TestCase):
config_key = 'ControlSocket'
# successful single query
- self.assertEqual(connection_value, controller.get_conf(config_key))
- self.assertEqual(connection_value, controller.get_conf(config_key, 'la-di-dah'))
+ self.assertEqual(connection_value, await controller.get_conf(config_key))
+ self.assertEqual(connection_value, await controller.get_conf(config_key, 'la-di-dah'))
# succeessful batch query
expected = {config_key: [connection_value]}
- self.assertEqual(expected, controller.get_conf_map([config_key]))
- self.assertEqual(expected, controller.get_conf_map([config_key], 'la-di-dah'))
+ self.assertEqual(expected, await controller.get_conf_map([config_key]))
+ self.assertEqual(expected, await controller.get_conf_map([config_key], 'la-di-dah'))
request_params = ['ControlPORT', 'dirport', 'datadirectory']
- reply_params = controller.get_conf_map(request_params, multiple=False).keys()
+ reply_params = (await controller.get_conf_map(request_params, multiple=False)).keys()
self.assertEqual(set(request_params), set(reply_params))
# queries an option that is unset
- self.assertEqual(None, controller.get_conf('HTTPSProxy'))
- self.assertEqual('la-di-dah', controller.get_conf('HTTPSProxy', 'la-di-dah'))
- self.assertEqual([], controller.get_conf('HTTPSProxy', [], multiple = True))
+ self.assertEqual(None, await controller.get_conf('HTTPSProxy'))
+ self.assertEqual('la-di-dah', await controller.get_conf('HTTPSProxy', 'la-di-dah'))
+ self.assertEqual([], await controller.get_conf('HTTPSProxy', [], multiple = True))
# non-existant option(s)
- self.assertRaises(stem.InvalidArguments, controller.get_conf, 'blarg')
- self.assertEqual('la-di-dah', controller.get_conf('blarg', 'la-di-dah'))
- self.assertRaises(stem.InvalidArguments, controller.get_conf_map, 'blarg')
- self.assertEqual({'blarg': 'la-di-dah'}, controller.get_conf_map('blarg', 'la-di-dah'))
- self.assertRaises(stem.InvalidRequest, controller.get_conf_map, ['blarg', 'huadf'], multiple = True)
- self.assertEqual({'erfusdj': 'la-di-dah', 'afiafj': 'la-di-dah'}, controller.get_conf_map(['erfusdj', 'afiafj'], 'la-di-dah', multiple = True))
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.get_conf('blarg')
+
+ self.assertEqual('la-di-dah', await controller.get_conf('blarg', 'la-di-dah'))
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.get_conf_map('blarg')
+
+ self.assertEqual({'blarg': 'la-di-dah'}, await controller.get_conf_map('blarg', 'la-di-dah'))
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.get_conf_map(['blarg', 'huadf'], multiple = True)
+
+ self.assertEqual({'erfusdj': 'la-di-dah', 'afiafj': 'la-di-dah'}, await controller.get_conf_map(['erfusdj', 'afiafj'], 'la-di-dah', multiple = True))
# multivalue configuration keys
nodefamilies = [('abc', 'xyz', 'pqrs'), ('mno', 'tuv', 'wxyz')]
- controller.msg('SETCONF %s' % ' '.join(['nodefamily="' + ','.join(x) + '"' for x in nodefamilies]))
- self.assertEqual([','.join(n) for n in nodefamilies], controller.get_conf('nodefamily', multiple = True))
- controller.msg('RESETCONF NodeFamily')
+ await controller.msg('SETCONF %s' % ' '.join(['nodefamily="' + ','.join(x) + '"' for x in nodefamilies]))
+ self.assertEqual([','.join(n) for n in nodefamilies], await controller.get_conf('nodefamily', multiple = True))
+ await controller.msg('RESETCONF NodeFamily')
# empty input
- self.assertEqual(None, controller.get_conf(''))
- self.assertEqual({}, controller.get_conf_map([]))
- self.assertEqual({}, controller.get_conf_map(['']))
- self.assertEqual(None, controller.get_conf(' '))
- self.assertEqual({}, controller.get_conf_map([' ', ' ']))
+ self.assertEqual(None, await controller.get_conf(''))
+ self.assertEqual({}, await controller.get_conf_map([]))
+ self.assertEqual({}, await controller.get_conf_map(['']))
+ self.assertEqual(None, await controller.get_conf(' '))
+ self.assertEqual({}, await controller.get_conf_map([' ', ' ']))
- self.assertEqual('la-di-dah', controller.get_conf('', 'la-di-dah'))
- self.assertEqual({}, controller.get_conf_map('', 'la-di-dah'))
- self.assertEqual({}, controller.get_conf_map([], 'la-di-dah'))
+ self.assertEqual('la-di-dah', await controller.get_conf('', 'la-di-dah'))
+ self.assertEqual({}, await controller.get_conf_map('', 'la-di-dah'))
+ self.assertEqual({}, await controller.get_conf_map([], 'la-di-dah'))
@test.require.controller
- def test_is_set(self):
+ @async_test
+ async def test_is_set(self):
"""
Exercises our is_set() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- custom_options = controller._execute_async_method('_get_custom_options')
+ async with await runner.get_tor_controller() as controller:
+ custom_options = await controller._get_custom_options()
self.assertTrue('ControlPort' in custom_options or 'ControlSocket' in custom_options)
self.assertEqual('1', custom_options['DownloadExtraInfo'])
self.assertEqual('1112', custom_options['SocksPort'])
- self.assertTrue(controller.is_set('DownloadExtraInfo'))
- self.assertTrue(controller.is_set('SocksPort'))
- self.assertFalse(controller.is_set('CellStatistics'))
- self.assertFalse(controller.is_set('ConnLimit'))
+ self.assertTrue(await controller.is_set('DownloadExtraInfo'))
+ self.assertTrue(await controller.is_set('SocksPort'))
+ self.assertFalse(await controller.is_set('CellStatistics'))
+ self.assertFalse(await controller.is_set('ConnLimit'))
# check we update when setting and resetting values
- controller.set_conf('ConnLimit', '1005')
- self.assertTrue(controller.is_set('ConnLimit'))
+ await controller.set_conf('ConnLimit', '1005')
+ self.assertTrue(await controller.is_set('ConnLimit'))
- controller.reset_conf('ConnLimit')
- self.assertFalse(controller.is_set('ConnLimit'))
+ await controller.reset_conf('ConnLimit')
+ self.assertFalse(await controller.is_set('ConnLimit'))
@test.require.controller
- def test_hidden_services_conf(self):
+ @async_test
+ async def test_hidden_services_conf(self):
"""
Exercises the hidden service family of methods (get_hidden_service_conf,
set_hidden_service_conf, create_hidden_service, and remove_hidden_service).
@@ -459,16 +482,16 @@ class TestController(unittest.TestCase):
service3_path = os.path.join(test_dir, 'test_hidden_service3')
service4_path = os.path.join(test_dir, 'test_hidden_service4')
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
try:
# initially we shouldn't be running any hidden services
- self.assertEqual({}, controller.get_hidden_service_conf())
+ self.assertEqual({}, await controller.get_hidden_service_conf())
# try setting a blank config, shouldn't have any impact
- controller.set_hidden_service_conf({})
- self.assertEqual({}, controller.get_hidden_service_conf())
+ await controller.set_hidden_service_conf({})
+ self.assertEqual({}, await controller.get_hidden_service_conf())
# create a hidden service
@@ -491,58 +514,58 @@ class TestController(unittest.TestCase):
},
}
- controller.set_hidden_service_conf(initialconf)
- self.assertEqual(initialconf, controller.get_hidden_service_conf())
+ await controller.set_hidden_service_conf(initialconf)
+ self.assertEqual(initialconf, await controller.get_hidden_service_conf())
# add already existing services, with/without explicit target
- self.assertEqual(None, controller.create_hidden_service(service1_path, 8020))
- self.assertEqual(None, controller.create_hidden_service(service1_path, 8021, target_port = 8021))
- self.assertEqual(initialconf, controller.get_hidden_service_conf())
+ self.assertEqual(None, await controller.create_hidden_service(service1_path, 8020))
+ self.assertEqual(None, await controller.create_hidden_service(service1_path, 8021, target_port = 8021))
+ self.assertEqual(initialconf, await controller.get_hidden_service_conf())
# add a new service, with/without explicit target
hs_path = os.path.join(os.getcwd(), service3_path)
- hs_address1 = controller.create_hidden_service(hs_path, 8888).hostname
- hs_address2 = controller.create_hidden_service(hs_path, 8989, target_port = 8021).hostname
+ hs_address1 = (await controller.create_hidden_service(hs_path, 8888)).hostname
+ hs_address2 = (await controller.create_hidden_service(hs_path, 8989, target_port = 8021)).hostname
self.assertEqual(hs_address1, hs_address2)
self.assertTrue(hs_address1.endswith('.onion'))
- conf = controller.get_hidden_service_conf()
+ conf = await controller.get_hidden_service_conf()
self.assertEqual(3, len(conf))
self.assertEqual(2, len(conf[hs_path]['HiddenServicePort']))
# remove a hidden service, the service dir should still be there
- controller.remove_hidden_service(hs_path, 8888)
- self.assertEqual(3, len(controller.get_hidden_service_conf()))
+ await controller.remove_hidden_service(hs_path, 8888)
+ self.assertEqual(3, len(await controller.get_hidden_service_conf()))
# remove a service completely, it should now be gone
- controller.remove_hidden_service(hs_path, 8989)
- self.assertEqual(2, len(controller.get_hidden_service_conf()))
+ await controller.remove_hidden_service(hs_path, 8989)
+ self.assertEqual(2, len(await controller.get_hidden_service_conf()))
# add a new service, this time with client authentication
hs_path = os.path.join(os.getcwd(), service4_path)
- hs_attributes = controller.create_hidden_service(hs_path, 8888, auth_type = 'basic', client_names = ['c1', 'c2'])
+ hs_attributes = await controller.create_hidden_service(hs_path, 8888, auth_type = 'basic', client_names = ['c1', 'c2'])
self.assertEqual(2, len(hs_attributes.hostname.splitlines()))
self.assertEqual(2, len(hs_attributes.hostname_for_client))
self.assertTrue(hs_attributes.hostname_for_client['c1'].endswith('.onion'))
self.assertTrue(hs_attributes.hostname_for_client['c2'].endswith('.onion'))
- conf = controller.get_hidden_service_conf()
+ conf = await controller.get_hidden_service_conf()
self.assertEqual(3, len(conf))
self.assertEqual(1, len(conf[hs_path]['HiddenServicePort']))
# remove a hidden service
- controller.remove_hidden_service(hs_path, 8888)
- self.assertEqual(2, len(controller.get_hidden_service_conf()))
+ await controller.remove_hidden_service(hs_path, 8888)
+ self.assertEqual(2, len(await controller.get_hidden_service_conf()))
finally:
- controller.set_hidden_service_conf({}) # drop hidden services created during the test
+ await controller.set_hidden_service_conf({}) # drop hidden services created during the test
# clean up the hidden service directories created as part of this test
@@ -553,47 +576,50 @@ class TestController(unittest.TestCase):
pass
@test.require.controller
- def test_without_ephemeral_hidden_services(self):
+ @async_test
+ async def test_without_ephemeral_hidden_services(self):
"""
Exercises ephemeral hidden service methods when none are present.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- self.assertEqual([], controller.list_ephemeral_hidden_services())
- self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
- self.assertEqual(False, controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz'))
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
+ self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True))
+ self.assertEqual(False, await controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz'))
@test.require.controller
- def test_with_invalid_ephemeral_hidden_service_port(self):
- with test.runner.get_runner().get_tor_controller() as controller:
+ @async_test
+ async def test_with_invalid_ephemeral_hidden_service_port(self):
+ async with await test.runner.get_runner().get_tor_controller() as controller:
for ports in (4567890, [4567, 4567890], {4567: '-:4567'}):
- exc_msg = "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET"
- self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, ports)
+ with self.assertRaisesWith(stem.ProtocolError, "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET"):
+ await controller.create_ephemeral_hidden_service(ports)
@test.require.controller
- def test_ephemeral_hidden_services_v2(self):
+ @async_test
+ async def test_ephemeral_hidden_services_v2(self):
"""
Exercises creating v2 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024')
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024')
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
self.assertEqual('RSA1024', response.private_key_type)
self.assertEqual({}, response.client_auth)
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services())
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
# recreate the service with the same private key
- recreate_response = controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ recreate_response = await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertEqual(response.service_id, recreate_response.service_id)
# the response only includes the private key when making a new one
@@ -603,41 +629,42 @@ class TestController(unittest.TestCase):
# create a service where we never see the private key
- response = controller.create_ephemeral_hidden_service(4568, key_content = 'RSA1024', discard_key = True)
- self.assertTrue(response.service_id in controller.list_ephemeral_hidden_services())
+ response = await controller.create_ephemeral_hidden_service(4568, key_content = 'RSA1024', discard_key = True)
+ self.assertTrue(response.service_id in await controller.list_ephemeral_hidden_services())
self.assertEqual(None, response.private_key)
self.assertEqual(None, response.private_key_type)
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller:
- self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
- self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
+ async with await runner.get_tor_controller() as second_controller:
+ self.assertEqual(2, len(await controller.list_ephemeral_hidden_services()))
+ self.assertEqual(0, len(await second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- def test_ephemeral_hidden_services_v3(self):
+ @async_test
+ async def test_ephemeral_hidden_services_v3(self):
"""
Exercises creating v3 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3')
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3')
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
self.assertEqual('ED25519-V3', response.private_key_type)
self.assertEqual({}, response.client_auth)
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services())
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
# recreate the service with the same private key
- recreate_response = controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ recreate_response = await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertEqual(response.service_id, recreate_response.service_id)
# the response only includes the private key when making a new one
@@ -647,38 +674,40 @@ class TestController(unittest.TestCase):
# create a service where we never see the private key
- response = controller.create_ephemeral_hidden_service(4568, key_content = 'ED25519-V3', discard_key = True)
- self.assertTrue(response.service_id in controller.list_ephemeral_hidden_services())
+ response = await controller.create_ephemeral_hidden_service(4568, key_content = 'ED25519-V3', discard_key = True)
+ self.assertTrue(response.service_id in await controller.list_ephemeral_hidden_services())
self.assertEqual(None, response.private_key)
self.assertEqual(None, response.private_key_type)
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller:
- self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
- self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
+ async with await runner.get_tor_controller() as second_controller:
+ self.assertEqual(2, len(await controller.list_ephemeral_hidden_services()))
+ self.assertEqual(0, len(await second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- def test_with_ephemeral_hidden_services_basic_auth(self):
+ @async_test
+ async def test_with_ephemeral_hidden_services_basic_auth(self):
"""
Exercises creating ephemeral hidden services that uses basic authentication.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None})
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None})
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
self.assertEqual(['bob'], list(response.client_auth.keys())) # newly created credentials were only created for bob
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services())
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
@test.require.controller
- def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
+ @async_test
+ async def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
"""
Exercises creating ephemeral hidden services when attempting to use basic
auth but not including any credentials.
@@ -686,12 +715,13 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- exc_msg = "ADD_ONION response didn't have an OK status: No auth clients specified"
- self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, 4567, basic_auth = {})
+ async with await runner.get_tor_controller() as controller:
+ with self.assertRaisesWith(stem.ProtocolError, "ADD_ONION response didn't have an OK status: No auth clients specified"):
+ await controller.create_ephemeral_hidden_service(4567, basic_auth = {})
@test.require.controller
- def test_with_detached_ephemeral_hidden_services(self):
+ @async_test
+ async def test_with_detached_ephemeral_hidden_services(self):
"""
Exercises creating detached ephemeral hidden services and methods when
they're present.
@@ -699,34 +729,35 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, detached = True)
- self.assertEqual([], controller.list_ephemeral_hidden_services())
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, detached = True)
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
# drop and recreate the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
- controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key, detached = True)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True))
+ await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key, detached = True)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
# other controllers should be able to see this service, and drop it
- with runner.get_tor_controller() as second_controller:
- self.assertEqual([response.service_id], second_controller.list_ephemeral_hidden_services(detached = True))
- self.assertEqual(True, second_controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
+ async with await runner.get_tor_controller() as second_controller:
+ self.assertEqual([response.service_id], await second_controller.list_ephemeral_hidden_services(detached = True))
+ self.assertEqual(True, await second_controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True))
# recreate the service and confirms that it outlives this controller
- response = second_controller.create_ephemeral_hidden_service(4567, detached = True)
+ response = await second_controller.create_ephemeral_hidden_service(4567, detached = True)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
- controller.remove_ephemeral_hidden_service(response.service_id)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
+ await controller.remove_ephemeral_hidden_service(response.service_id)
@test.require.controller
- def test_rejecting_unanonymous_hidden_services_creation(self):
+ @async_test
+ async def test_rejecting_unanonymous_hidden_services_creation(self):
"""
Attempt to create a non-anonymous hidden service despite not setting
HiddenServiceSingleHopMode and HiddenServiceNonAnonymousMode.
@@ -734,11 +765,12 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- self.assertEqual('Tor is in anonymous hidden service mode', str(controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567')))
+ async with await runner.get_tor_controller() as controller:
+ self.assertEqual('Tor is in anonymous hidden service mode', str(await controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567')))
@test.require.controller
- def test_set_conf(self):
+ @async_test
+ async def test_set_conf(self):
"""
Exercises set_conf(), reset_conf(), and set_options() methods with valid
and invalid requests.
@@ -748,42 +780,42 @@ class TestController(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
try:
# successfully set a single option
- connlimit = int(controller.get_conf('ConnLimit'))
- controller.set_conf('connlimit', str(connlimit - 1))
- self.assertEqual(connlimit - 1, int(controller.get_conf('ConnLimit')))
+ connlimit = int(await controller.get_conf('ConnLimit'))
+ await controller.set_conf('connlimit', str(connlimit - 1))
+ self.assertEqual(connlimit - 1, int(await controller.get_conf('ConnLimit')))
# successfully set a single list option
exit_policy = ['accept *:7777', 'reject *:*']
- controller.set_conf('ExitPolicy', exit_policy)
- self.assertEqual(exit_policy, controller.get_conf('ExitPolicy', multiple = True))
+ await controller.set_conf('ExitPolicy', exit_policy)
+ self.assertEqual(exit_policy, await controller.get_conf('ExitPolicy', multiple = True))
# fail to set a single option
try:
- controller.set_conf('invalidkeyboo', 'abcde')
+ await controller.set_conf('invalidkeyboo', 'abcde')
self.fail()
except stem.InvalidArguments as exc:
self.assertEqual(['invalidkeyboo'], exc.arguments)
# resets configuration parameters
- controller.reset_conf('ConnLimit', 'ExitPolicy')
- self.assertEqual(connlimit, int(controller.get_conf('ConnLimit')))
- self.assertEqual(None, controller.get_conf('ExitPolicy'))
+ await controller.reset_conf('ConnLimit', 'ExitPolicy')
+ self.assertEqual(connlimit, int(await controller.get_conf('ConnLimit')))
+ self.assertEqual(None, await controller.get_conf('ExitPolicy'))
# successfully sets multiple config options
- controller.set_options({
+ await controller.set_options({
'connlimit': str(connlimit - 2),
'contactinfo': 'stem@testing',
})
- self.assertEqual(connlimit - 2, int(controller.get_conf('ConnLimit')))
- self.assertEqual('stem@testing', controller.get_conf('contactinfo'))
+ self.assertEqual(connlimit - 2, int(await controller.get_conf('ConnLimit')))
+ self.assertEqual('stem@testing', await controller.get_conf('contactinfo'))
# fail to set multiple config options
try:
- controller.set_options({
+ await controller.set_options({
'contactinfo': 'stem@testing',
'bombay': 'vadapav',
})
@@ -792,17 +824,17 @@ class TestController(unittest.TestCase):
self.assertEqual(['bombay'], exc.arguments)
# context-sensitive keys (the only retched things for which order matters)
- controller.set_options((
+ await controller.set_options((
('HiddenServiceDir', tmpdir),
('HiddenServicePort', '17234 127.0.0.1:17235'),
))
- self.assertEqual(tmpdir, controller.get_conf('HiddenServiceDir'))
- self.assertEqual('17234 127.0.0.1:17235', controller.get_conf('HiddenServicePort'))
+ self.assertEqual(tmpdir, await controller.get_conf('HiddenServiceDir'))
+ self.assertEqual('17234 127.0.0.1:17235', await controller.get_conf('HiddenServicePort'))
finally:
# reverts configuration changes
- controller.set_options((
+ await controller.set_options((
('ExitPolicy', 'reject *:*'),
('ConnLimit', None),
('ContactInfo', None),
@@ -811,47 +843,53 @@ class TestController(unittest.TestCase):
), reset = True)
@test.require.controller
- def test_set_conf_for_usebridges(self):
+ @async_test
+ async def test_set_conf_for_usebridges(self):
"""
Ensure we can set UseBridges=1 and also set a Bridge. This is a tor
regression check (:trac:`31945`).
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- orport = controller.get_conf('ORPort')
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ orport = await controller.get_conf('ORPort')
try:
- controller.set_conf('ORPort', '0') # ensure we're not a relay so UseBridges is usabe
- controller.set_options([('UseBridges', '1'), ('Bridge', '127.0.0.1:9999')])
- self.assertEqual('127.0.0.1:9999', controller.get_conf('Bridge'))
+ await controller.set_conf('ORPort', '0') # ensure we're not a relay so UseBridges is usabe
+ await controller.set_options([('UseBridges', '1'), ('Bridge', '127.0.0.1:9999')])
+ self.assertEqual('127.0.0.1:9999', await controller.get_conf('Bridge'))
finally:
# reverts configuration changes
- controller.set_options((
+ await controller.set_options((
('ORPort', orport),
('UseBridges', None),
('Bridge', None),
), reset = True)
@test.require.controller
- def test_set_conf_when_immutable(self):
+ @async_test
+ async def test_set_conf_when_immutable(self):
"""
Issue a SETCONF for tor options that cannot be changed while running.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running", controller.set_conf, 'DisableAllSwap', '1')
- self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running", controller.set_options, {'User': 'atagar', 'DisableAllSwap': '1'})
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ with self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running"):
+ await controller.set_conf('DisableAllSwap', '1')
+
+ with self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running"):
+ await controller.set_options({'User': 'atagar', 'DisableAllSwap': '1'})
@test.require.controller
- def test_loadconf(self):
+ @async_test
+ async def test_loadconf(self):
"""
Exercises Controller.load_conf with valid and invalid requests.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
@@ -863,98 +901,105 @@ class TestController(unittest.TestCase):
# ("/home/atagar/Desktop/stem/test/data"->"/home/atagar/.tor") is not
# allowed.
- self.assertRaises(stem.InvalidRequest, controller.load_conf, 'ContactInfo confloaded')
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.load_conf('ContactInfo confloaded')
try:
- controller.load_conf('Blahblah blah')
+ await controller.load_conf('Blahblah blah')
self.fail()
except stem.InvalidArguments as exc:
self.assertEqual(['Blahblah'], exc.arguments)
# valid config
- controller.load_conf(runner.get_torrc_contents() + '\nContactInfo confloaded\n')
- self.assertEqual('confloaded', controller.get_conf('ContactInfo'))
+ await controller.load_conf(runner.get_torrc_contents() + '\nContactInfo confloaded\n')
+ self.assertEqual('confloaded', await controller.get_conf('ContactInfo'))
finally:
# reload original valid config
- controller.load_conf(oldconf)
- controller.reset_conf('__OwningControllerProcess')
+ await controller.load_conf(oldconf)
+ await controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_saveconf(self):
+ @async_test
+ async def test_saveconf(self):
runner = test.runner.get_runner()
# only testing for success, since we need to run out of disk space to test
# for failure
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
- controller.set_conf('ContactInfo', 'confsaved')
- controller.save_conf()
+ await controller.set_conf('ContactInfo', 'confsaved')
+ await controller.save_conf()
with open(runner.get_torrc_path()) as torrcfile:
self.assertTrue('\nContactInfo confsaved\n' in torrcfile.read())
finally:
- controller.load_conf(oldconf)
- controller.save_conf()
- controller.reset_conf('__OwningControllerProcess')
+ await controller.load_conf(oldconf)
+ await controller.save_conf()
+ await controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_get_ports(self):
+ @async_test
+ async def test_get_ports(self):
"""
Test Controller.get_ports against a running tor instance.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- self.assertEqual([test.runner.ORPORT], controller.get_ports(Listener.OR))
- self.assertEqual([], controller.get_ports(Listener.DIR))
- self.assertEqual([test.runner.SOCKS_PORT], controller.get_ports(Listener.SOCKS))
- self.assertEqual([], controller.get_ports(Listener.TRANS))
- self.assertEqual([], controller.get_ports(Listener.NATD))
- self.assertEqual([], controller.get_ports(Listener.DNS))
+ async with await runner.get_tor_controller() as controller:
+ self.assertEqual([test.runner.ORPORT], await controller.get_ports(Listener.OR))
+ self.assertEqual([], await controller.get_ports(Listener.DIR))
+ self.assertEqual([test.runner.SOCKS_PORT], await controller.get_ports(Listener.SOCKS))
+ self.assertEqual([], await controller.get_ports(Listener.TRANS))
+ self.assertEqual([], await controller.get_ports(Listener.NATD))
+ self.assertEqual([], await controller.get_ports(Listener.DNS))
if test.runner.Torrc.PORT in runner.get_options():
- self.assertEqual([test.runner.CONTROL_PORT], controller.get_ports(Listener.CONTROL))
+ self.assertEqual([test.runner.CONTROL_PORT], await controller.get_ports(Listener.CONTROL))
else:
- self.assertEqual([], controller.get_ports(Listener.CONTROL))
+ self.assertEqual([], await controller.get_ports(Listener.CONTROL))
@test.require.controller
- def test_get_listeners(self):
+ @async_test
+ async def test_get_listeners(self):
"""
Test Controller.get_listeners against a running tor instance.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- self.assertEqual([('0.0.0.0', test.runner.ORPORT)], controller.get_listeners(Listener.OR))
- self.assertEqual([], controller.get_listeners(Listener.DIR))
- self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], controller.get_listeners(Listener.SOCKS))
- self.assertEqual([], controller.get_listeners(Listener.TRANS))
- self.assertEqual([], controller.get_listeners(Listener.NATD))
- self.assertEqual([], controller.get_listeners(Listener.DNS))
+ async with await runner.get_tor_controller() as controller:
+ self.assertEqual([('0.0.0.0', test.runner.ORPORT)], await controller.get_listeners(Listener.OR))
+ self.assertEqual([], await controller.get_listeners(Listener.DIR))
+ self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], await controller.get_listeners(Listener.SOCKS))
+ self.assertEqual([], await controller.get_listeners(Listener.TRANS))
+ self.assertEqual([], await controller.get_listeners(Listener.NATD))
+ self.assertEqual([], await controller.get_listeners(Listener.DNS))
if test.runner.Torrc.PORT in runner.get_options():
- self.assertEqual([('127.0.0.1', test.runner.CONTROL_PORT)], controller.get_listeners(Listener.CONTROL))
+ self.assertEqual([('127.0.0.1', test.runner.CONTROL_PORT)], await controller.get_listeners(Listener.CONTROL))
else:
- self.assertEqual([], controller.get_listeners(Listener.CONTROL))
+ self.assertEqual([], await controller.get_listeners(Listener.CONTROL))
@test.require.controller
@test.require.online
@test.require.version(stem.version.Version('0.1.2.2-alpha'))
- def test_enable_feature(self):
+ @async_test
+ async def test_enable_feature(self):
"""
Test Controller.enable_feature with valid and invalid inputs.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
self.assertTrue(controller.is_feature_enabled('VERBOSE_NAMES'))
- self.assertRaises(stem.InvalidArguments, controller.enable_feature, ['NOT', 'A', 'FEATURE'])
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.enable_feature(['NOT', 'A', 'FEATURE'])
try:
controller.enable_feature(['NOT', 'A', 'FEATURE'])
@@ -964,58 +1009,70 @@ class TestController(unittest.TestCase):
self.fail()
@test.require.controller
- def test_signal(self):
+ @async_test
+ async def test_signal(self):
"""
Test controller.signal with valid and invalid signals.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# valid signal
- controller.signal('CLEARDNSCACHE')
+ await controller.signal('CLEARDNSCACHE')
# invalid signals
- self.assertRaises(stem.InvalidArguments, controller.signal, 'FOOBAR')
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.signal('FOOBAR')
@test.require.controller
- def test_newnym_availability(self):
+ @async_test
+ async def test_newnym_availability(self):
"""
Test the is_newnym_available and get_newnym_wait methods.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual(True, controller.is_newnym_available())
self.assertEqual(0.0, controller.get_newnym_wait())
- controller.signal(stem.Signal.NEWNYM)
+ await controller.signal(stem.Signal.NEWNYM)
self.assertEqual(False, controller.is_newnym_available())
self.assertTrue(controller.get_newnym_wait() > 9.0)
@test.require.controller
@test.require.online
- def test_extendcircuit(self):
- with test.runner.get_runner().get_tor_controller() as controller:
+ @async_test
+ async def test_extendcircuit(self):
+ async with await test.runner.get_runner().get_tor_controller() as controller:
circuit_id = controller.extend_circuit('0')
# check if our circuit was created
+
self.assertNotEqual(None, controller.get_circuit(circuit_id, None))
circuit_id = controller.new_circuit()
self.assertNotEqual(None, controller.get_circuit(circuit_id, None))
- self.assertRaises(stem.InvalidRequest, controller.extend_circuit, 'foo')
- self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#')
- self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#', 'foo')
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.extend_circuit('foo')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.extend_circuit('0', 'thisroutershouldntexistbecausestemexists!@##$%#')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.extend_circuit('0', 'thisroutershouldntexistbecausestemexists!@##$%#', 'foo')
@test.require.controller
@test.require.online
- def test_repurpose_circuit(self):
+ @async_test
+ async def test_repurpose_circuit(self):
"""
Tests Controller.repurpose_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
circ_id = controller.new_circuit()
controller.repurpose_circuit(circ_id, 'CONTROLLER')
circuit = controller.get_circuit(circ_id)
@@ -1025,38 +1082,47 @@ class TestController(unittest.TestCase):
circuit = controller.get_circuit(circ_id)
self.assertTrue(circuit.purpose == 'GENERAL')
- self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, 'f934h9f3h4', 'fooo')
- self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, '4', 'fooo')
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.repurpose_circuit('f934h9f3h4', 'fooo')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.repurpose_circuit('4', 'fooo')
@test.require.controller
@test.require.online
- def test_close_circuit(self):
+ @async_test
+ async def test_close_circuit(self):
"""
Tests Controller.close_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
circuit_id = controller.new_circuit()
controller.close_circuit(circuit_id)
- circuit_output = controller.get_info('circuit-status')
+ circuit_output = await controller.get_info('circuit-status')
circ = [x.split()[0] for x in circuit_output.splitlines()]
self.assertFalse(circuit_id in circ)
circuit_id = controller.new_circuit()
controller.close_circuit(circuit_id, 'IfUnused')
- circuit_output = controller.get_info('circuit-status')
+ circuit_output = await controller.get_info('circuit-status')
circ = [x.split()[0] for x in circuit_output.splitlines()]
self.assertFalse(circuit_id in circ)
circuit_id = controller.new_circuit()
- self.assertRaises(stem.InvalidArguments, controller.close_circuit, circuit_id + '1024')
- self.assertRaises(stem.InvalidRequest, controller.close_circuit, '')
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.close_circuit(circuit_id + '1024')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.close_circuit('')
@test.require.controller
@test.require.online
- def test_get_streams(self):
+ @async_test
+ async def test_get_streams(self):
"""
Tests Controller.get_streams().
"""
@@ -1065,9 +1131,11 @@ class TestController(unittest.TestCase):
port = 443
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+
+ async with await runner.get_tor_controller() as controller:
# we only need one proxy port, so take the first
- socks_listener = controller.get_listeners(Listener.SOCKS)[0]
+
+ socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s:
s.settimeout(30)
@@ -1081,17 +1149,18 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_close_stream(self):
+ @async_test
+ async def test_close_stream(self):
"""
Tests Controller.close_stream with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# use the first socks listener
- socks_listener = controller.get_listeners(Listener.SOCKS)[0]
+ socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s:
s.settimeout(30)
@@ -1116,16 +1185,18 @@ class TestController(unittest.TestCase):
# unknown stream
- self.assertRaises(stem.InvalidArguments, controller.close_stream, 'blarg')
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.close_stream('blarg')
@test.require.controller
@test.require.online
- def test_mapaddress(self):
+ @async_test
+ async def test_mapaddress(self):
self.skipTest('(https://trac.torproject.org/projects/tor/ticket/25611)')
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- controller.map_address({'1.2.1.2': 'ifconfig.me'})
+ async with await runner.get_tor_controller() as controller:
+ await controller.map_address({'1.2.1.2': 'ifconfig.me'})
s = None
response = None
@@ -1136,7 +1207,7 @@ class TestController(unittest.TestCase):
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(30)
- s.connect(('127.0.0.1', int(controller.get_conf('SocksPort'))))
+ s.connect(('127.0.0.1', int(await controller.get_conf('SocksPort'))))
test.network.negotiate_socks(s, '1.2.1.2', 80)
s.sendall(stem.util.str_tools._to_bytes(test.network.IP_REQUEST)) # make the http request for the ip address
response = s.recv(1000)
@@ -1158,14 +1229,15 @@ class TestController(unittest.TestCase):
self.assertTrue(stem.util.connection.is_valid_ipv4_address(stem.util.str_tools._to_unicode(ip_addr)), "'%s' isn't an address" % ip_addr)
@test.require.controller
- def test_mapaddress_offline(self):
+ @async_test
+ async def test_mapaddress_offline(self):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# try mapping one element, ensuring results are as expected
map1 = {'1.2.1.2': 'ifconfig.me'}
- x = controller.map_address(map1)
+ x = await controller.map_address(map1)
self.assertEqual(x, map1)
# try mapping two elements, ensuring results are as expected
@@ -1173,17 +1245,18 @@ class TestController(unittest.TestCase):
map2 = {'1.2.3.4': 'foobar.example.com',
'1.2.3.5': 'barfuzz.example.com'}
- x = controller.map_address(map2)
+ x = await controller.map_address(map2)
self.assertEqual(x, map2)
# try mapping zero elements
- self.assertRaises(stem.InvalidRequest, controller.map_address, {})
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.map_address({})
# try a virtual mapping to IPv4, the default virtualaddressrange is 127.192.0.0/10
map3 = {'0.0.0.0': 'quux'}
- x = controller.map_address(map3)
+ x = await controller.map_address(map3)
self.assertEquals(len(x), 1)
addr1, target = list(x.items())[0]
@@ -1193,15 +1266,15 @@ class TestController(unittest.TestCase):
# try a virtual mapping to IPv6, the default IPv6 virtualaddressrange is FE80::/10
map4 = {'::': 'quibble'}
- x = controller.map_address(map4)
+ x = await controller.map_address(map4)
self.assertEquals(len(x), 1)
addr2, target = list(x.items())[0]
self.assertTrue(addr2.startswith('[fe'), '%s did not start with [fe.' % addr2)
self.assertEquals(target, 'quibble')
- def address_mappings(addr_type):
- response = controller.get_info(['address-mappings/%s' % addr_type])
+ async def address_mappings(addr_type):
+ response = await controller.get_info(['address-mappings/%s' % addr_type])
result = {}
for line in response['address-mappings/%s' % addr_type].splitlines():
@@ -1218,7 +1291,7 @@ class TestController(unittest.TestCase):
'1.2.3.5': 'barfuzz.example.com',
addr1: 'quux',
addr2: 'quibble',
- }, address_mappings('control'))
+ }, await address_mappings('control'))
# ask for a list of all the address mappings
@@ -1228,29 +1301,40 @@ class TestController(unittest.TestCase):
'1.2.3.5': 'barfuzz.example.com',
addr1: 'quux',
addr2: 'quibble',
- }, address_mappings('all'))
+ }, await address_mappings('all'))
# Now ask for a list of only the mappings configured with the
# configuration. Ours should not be there.
- self.assertEquals({}, address_mappings('config'))
+ self.assertEquals({}, await address_mappings('config'))
@test.require.controller
@test.require.online
- def test_get_microdescriptor(self):
+ @async_test
+ async def test_get_microdescriptor(self):
"""
Basic checks for get_microdescriptor().
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
- self.assertRaises(ValueError, controller.get_microdescriptor, '')
- self.assertRaises(ValueError, controller.get_microdescriptor, 5)
- self.assertRaises(ValueError, controller.get_microdescriptor, 'z' * 30)
+
+ with self.assertRaises(ValueError):
+ await controller.get_microdescriptor('')
+
+ with self.assertRaises(ValueError):
+ await controller.get_microdescriptor(5)
+
+ with self.assertRaises(ValueError):
+ await controller.get_microdescriptor('z' * 30)
# try with a relay that doesn't exist
- self.assertRaises(stem.ControllerError, controller.get_microdescriptor, 'blargg')
- self.assertRaises(stem.ControllerError, controller.get_microdescriptor, '5' * 40)
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_microdescriptor('blargg')
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_microdescriptor('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1261,7 +1345,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_microdescriptors(self):
+ @async_test
+ async def test_get_microdescriptors(self):
"""
Fetches a few descriptors via the get_microdescriptors() method.
"""
@@ -1271,7 +1356,7 @@ class TestController(unittest.TestCase):
if not os.path.exists(runner.get_test_dir('cached-microdescs')):
self.skipTest('(no cached microdescriptors)')
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_microdescriptors():
@@ -1283,22 +1368,33 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_server_descriptor(self):
+ @async_test
+ async def test_get_server_descriptor(self):
"""
Basic checks for get_server_descriptor().
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# we should balk at invalid content
- self.assertRaises(ValueError, controller.get_server_descriptor, '')
- self.assertRaises(ValueError, controller.get_server_descriptor, 5)
- self.assertRaises(ValueError, controller.get_server_descriptor, 'z' * 30)
+
+ with self.assertRaises(ValueError):
+ await controller.get_server_descriptor('')
+
+ with self.assertRaises(ValueError):
+ await controller.get_server_descriptor(5)
+
+ with self.assertRaises(ValueError):
+ await controller.get_server_descriptor('z' * 30)
# try with a relay that doesn't exist
- self.assertRaises(stem.ControllerError, controller.get_server_descriptor, 'blargg')
- self.assertRaises(stem.ControllerError, controller.get_server_descriptor, '5' * 40)
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_server_descriptor('blargg')
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_server_descriptor('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1309,14 +1405,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_server_descriptors(self):
+ @async_test
+ async def test_get_server_descriptors(self):
"""
Fetches a few descriptors via the get_server_descriptors() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_server_descriptors():
@@ -1334,20 +1431,31 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_network_status(self):
+ @async_test
+ async def test_get_network_status(self):
"""
Basic checks for get_network_status().
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
- self.assertRaises(ValueError, controller.get_network_status, '')
- self.assertRaises(ValueError, controller.get_network_status, 5)
- self.assertRaises(ValueError, controller.get_network_status, 'z' * 30)
+
+ with self.assertRaises(ValueError):
+ await controller.get_network_status('')
+
+ with self.assertRaises(ValueError):
+ await controller.get_network_status(5)
+
+ with self.assertRaises(ValueError):
+ await controller.get_network_status('z' * 30)
# try with a relay that doesn't exist
- self.assertRaises(stem.ControllerError, controller.get_network_status, 'blargg')
- self.assertRaises(stem.ControllerError, controller.get_network_status, '5' * 40)
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_network_status('blargg')
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_network_status('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1358,14 +1466,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_network_statuses(self):
+ @async_test
+ async def test_get_network_statuses(self):
"""
Fetches a few descriptors via the get_network_statuses() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_network_statuses():
@@ -1381,14 +1490,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_hidden_service_descriptor(self):
+ @async_test
+ async def test_get_hidden_service_descriptor(self):
"""
Fetches a few descriptors via the get_hidden_service_descriptor() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# fetch the descriptor for DuckDuckGo
desc = controller.get_hidden_service_descriptor('3g2upl4pq6kufc4m.onion')
@@ -1396,8 +1506,8 @@ class TestController(unittest.TestCase):
# try to fetch something that doesn't exist
- exc_msg = 'No running hidden service at m4cfuk6qp4lpu2g3.onion'
- self.assertRaisesWith(stem.DescriptorUnavailable, exc_msg, controller.get_hidden_service_descriptor, 'm4cfuk6qp4lpu2g3')
+ with self.assertRaisesWith(stem.DescriptorUnavailable, 'No running hidden service at m4cfuk6qp4lpu2g3.onion'):
+ await controller.get_hidden_service_descriptor('m4cfuk6qp4lpu2g3')
# ... but shouldn't fail if we have a default argument or aren't awaiting the descriptor
@@ -1406,7 +1516,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_attachstream(self):
+ @async_test
+ async def test_attachstream(self):
host = socket.gethostbyname('www.torproject.org')
port = 80
@@ -1416,15 +1527,16 @@ class TestController(unittest.TestCase):
if stream.status == 'NEW' and circuit_id:
controller.attach_stream(stream.id, circuit_id)
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# try 10 times to build a circuit we can connect through
+
for i in range(10):
- controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM)
- controller.set_conf('__LeaveStreamsUnattached', '1')
+ await controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM)
+ await controller.set_conf('__LeaveStreamsUnattached', '1')
try:
circuit_id = controller.new_circuit(await_build = True)
- socks_listener = controller.get_listeners(Listener.SOCKS)[0]
+ socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s:
s.settimeout(30)
@@ -1435,7 +1547,7 @@ class TestController(unittest.TestCase):
continue
finally:
controller.remove_event_listener(handle_streamcreated)
- controller.reset_conf('__LeaveStreamsUnattached')
+ await controller.reset_conf('__LeaveStreamsUnattached')
our_stream = [stream for stream in streams if stream.target_address == host][0]
@@ -1446,38 +1558,40 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_circuits(self):
+ @async_test
+ async def test_get_circuits(self):
"""
Fetches circuits via the get_circuits() method.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
new_circ = controller.new_circuit()
circuits = controller.get_circuits()
self.assertTrue(new_circ in [circ.id for circ in circuits])
@test.require.controller
- def test_transition_to_relay(self):
+ @async_test
+ async def test_transition_to_relay(self):
"""
Transitions Tor to turn into a relay, then back to a client. This helps to
catch transition issues such as the one cited in :trac:`14901`.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
try:
- controller.reset_conf('OrPort', 'DisableNetwork')
- self.assertEqual(None, controller.get_conf('OrPort'))
+ await controller.reset_conf('OrPort', 'DisableNetwork')
+ self.assertEqual(None, await controller.get_conf('OrPort'))
# DisableNetwork ensures no port is actually opened
- controller.set_options({'OrPort': '9090', 'DisableNetwork': '1'})
+ await controller.set_options({'OrPort': '9090', 'DisableNetwork': '1'})
# TODO once tor 0.2.7.x exists, test that we can generate a descriptor on demand.
- self.assertEqual('9090', controller.get_conf('OrPort'))
- controller.reset_conf('OrPort', 'DisableNetwork')
- self.assertEqual(None, controller.get_conf('OrPort'))
+ self.assertEqual('9090', await controller.get_conf('OrPort'))
+ await controller.reset_conf('OrPort', 'DisableNetwork')
+ self.assertEqual(None, await controller.get_conf('OrPort'))
finally:
- controller.set_conf('OrPort', str(test.runner.ORPORT))
+ await controller.set_conf('OrPort', str(test.runner.ORPORT))
def _get_router_status_entry(self, controller):
"""
diff --git a/test/runner.py b/test/runner.py
index b132b8f5..a02fb769 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -88,7 +88,7 @@ class TorInaccessable(Exception):
async def exercise_controller(test_case, controller):
- """with await test.runner.get_runner().get_tor_socket
+ """
Checks that we can now use the socket by issuing a 'GETINFO config-file'
query. Controller can be either a :class:`stem.socket.ControlSocket` or
:class:`stem.control.BaseController`.
@@ -102,11 +102,10 @@ async def exercise_controller(test_case, controller):
if isinstance(controller, stem.socket.ControlSocket):
await controller.send('GETINFO config-file')
+
config_file_response = await controller.recv()
else:
- config_file_response = controller.msg('GETINFO config-file')
- if asyncio.iscoroutine(config_file_response):
- config_file_response = await config_file_response
+ config_file_response = await controller.msg('GETINFO config-file')
test_case.assertEqual('config-file=%s\nOK' % torrc_path, str(config_file_response))
@@ -261,9 +260,19 @@ class Runner(object):
stem.socket.recv_message = _chroot_recv_message
if self.is_accessible():
- self._owner_controller = stem.control.Controller(self._get_unconnected_socket(), False)
- self._owner_controller.connect()
- self._authenticate_controller(self._owner_controller)
+ # TODO: refactor so owner controller is less convoluted
+
+ loop = asyncio.new_event_loop()
+
+ self._owner_controller_thread = threading.Thread(
+ name = 'owning_controller',
+ target = loop.run_forever,
+ daemon = True,
+ )
+
+ self._owner_controller_thread.start()
+
+ self._owner_controller = asyncio.run_coroutine_threadsafe(self.get_tor_controller(True), loop).result()
if test.Target.RELATIVE in self.attribute_targets:
os.chdir(original_cwd) # revert our cwd back to normal
@@ -279,7 +288,9 @@ class Runner(object):
println('Shutting down tor... ', STATUS, NO_NL)
if self._owner_controller:
- self._owner_controller.close()
+ asyncio.run_coroutine_threadsafe(self._owner_controller.close(), self._owner_controller._loop).result()
+ self._owner_controller._loop.call_soon_threadsafe(self._owner_controller._loop.stop)
+ self._owner_controller_thread.join()
self._owner_controller = None
if self._tor_process:
@@ -445,16 +456,6 @@ class Runner(object):
tor_process = self._get('_tor_process')
return tor_process.pid
- def _get_unconnected_socket(self):
- if Torrc.PORT in self._custom_opts:
- control_socket = stem.socket.ControlPort(port = CONTROL_PORT)
- elif Torrc.SOCKET in self._custom_opts:
- control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH)
- else:
- raise TorInaccessable('Unable to connect to tor')
-
- return control_socket
-
async def get_tor_socket(self, authenticate = True):
"""
Provides a socket connected to our tor test instance.
@@ -466,7 +467,13 @@ class Runner(object):
:raises: :class:`test.runner.TorInaccessable` if tor can't be connected to
"""
- control_socket = self._get_unconnected_socket()
+ if Torrc.PORT in self._custom_opts:
+ control_socket = stem.socket.ControlPort(port = CONTROL_PORT)
+ elif Torrc.SOCKET in self._custom_opts:
+ control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH)
+ else:
+ raise TorInaccessable('Unable to connect to tor')
+
await control_socket.connect()
if authenticate:
@@ -474,10 +481,7 @@ class Runner(object):
return control_socket
- def _authenticate_controller(self, controller):
- controller.authenticate(password=CONTROL_PASSWORD, chroot_path=self.get_chroot())
-
- def get_tor_controller(self, authenticate = True):
+ async def get_tor_controller(self, authenticate = True):
"""
Provides a controller connected to our tor test instance.
@@ -488,19 +492,11 @@ class Runner(object):
:raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
"""
- loop = asyncio.new_event_loop()
- loop_thread = threading.Thread(target = loop.run_forever, name = 'get_tor_controller')
- loop_thread.setDaemon(True)
- loop_thread.start()
-
- async def wrapped_get_controller():
- control_socket = await self.get_tor_socket(False)
- return stem.control.Controller(control_socket)
-
- controller = asyncio.run_coroutine_threadsafe(wrapped_get_controller(), loop).result()
+ control_socket = await self.get_tor_socket(False)
+ controller = stem.control.Controller(control_socket)
if authenticate:
- self._authenticate_controller(controller)
+ await controller.authenticate(password = CONTROL_PASSWORD, chroot_path = self.get_chroot())
return controller
diff --git a/test/settings.cfg b/test/settings.cfg
index 70bdd069..ef543a18 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -235,6 +235,14 @@ mypy.ignore stem/descriptor/remote.py => Return type "Coroutine[Any, Any, None]"
mypy.ignore * => "Descriptor" has no attribute "*
+# Metaprogramming false positive for our close method.
+
+mypy.ignore stem/control.py => Return type "Coroutine[Any, Any, None]" of "close" *
+
+# Interpreter uses a synchronous controller, which can cause false positives.
+
+mypy.ignore stem/interpreter/commands.py => "Coroutine[Any, Any, ControlMessage]" has no attribute "*
+
# Test modules we want to run. Modules are roughly ordered by the dependencies
# so the lowest level tests come first. This is because a problem in say,
# controller message parsing, will cause all higher level tests to fail too.
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 84fcdfed..6c33da6b 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -21,11 +21,7 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval
from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType
from stem.response import ControlMessage
from stem.exit_policy import ExitPolicy
-from stem.util.test_tools import (
- async_test,
- coro_func_raising_exc,
- coro_func_returning_value,
-)
+from stem.util.test_tools import coro_func_raising_exc, coro_func_returning_value
NS_DESC = 'r %s %s u5lTXJKGsLKufRLnSyVqT7TdGYw 2012-12-30 22:02:49 77.223.43.54 9001 0\ns Fast Named Running Stable Valid\nw Bandwidth=75'
TEST_TIMESTAMP = 12345
@@ -44,7 +40,6 @@ class TestControl(unittest.TestCase):
with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))):
self.controller = Controller(socket)
- self.async_controller = self.controller._wrapped_instance
self.circ_listener = Mock()
self.controller.add_event_listener(self.circ_listener, EventType.CIRC)
@@ -69,24 +64,23 @@ class TestControl(unittest.TestCase):
for event in stem.control.EventType:
self.assertTrue(stem.control.event_description(event) is not None)
- @patch('stem.control.AsyncController.msg')
+ @patch('stem.control.Controller.msg')
def test_get_info(self, msg_mock):
message = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO')
msg_mock.side_effect = coro_func_returning_value(message)
self.assertEqual('hi right back!', self.controller.get_info('hello'))
- @patch('stem.control.AsyncController.msg')
- @async_test
- async def test_get_info_address_caching(self, msg_mock):
+ @patch('stem.control.Controller.msg')
+ def test_get_info_address_caching(self, msg_mock):
def set_message(*args):
message = ControlMessage.from_str(*args)
msg_mock.side_effect = coro_func_returning_value(message)
set_message('551 Address unknown\r\n')
- self.assertEqual(None, self.async_controller._last_address_exc)
+ self.assertEqual(None, self.controller._last_address_exc)
self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
- self.assertEqual('Address unknown', str(self.async_controller._last_address_exc))
+ self.assertEqual('Address unknown', str(self.controller._last_address_exc))
self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back
@@ -98,26 +92,26 @@ class TestControl(unittest.TestCase):
set_message('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO')
self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
- await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
+ self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
self.assertEqual('17.2.89.80', self.controller.get_info('address'))
# invalidates the cache, transitioning from one address to another
set_message('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO')
self.assertEqual('17.2.89.80', self.controller.get_info('address'))
- await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
+ self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
self.assertEqual('80.89.2.17', self.controller.get_info('address'))
- @patch('stem.control.AsyncController.msg')
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.msg')
+ @patch('stem.control.Controller.get_conf')
def test_get_info_without_fingerprint(self, get_conf_mock, msg_mock):
message = ControlMessage.from_str('551 Not running in server mode\r\n')
msg_mock.side_effect = coro_func_returning_value(message)
- get_conf_mock.return_value = None
+ get_conf_mock.side_effect = coro_func_returning_value(None)
- self.assertEqual(None, self.async_controller._last_fingerprint_exc)
+ self.assertEqual(None, self.controller._last_fingerprint_exc)
self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
- self.assertEqual('Not running in server mode', str(self.async_controller._last_fingerprint_exc))
+ self.assertEqual('Not running in server mode', str(self.controller._last_fingerprint_exc))
self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back
@@ -127,11 +121,11 @@ class TestControl(unittest.TestCase):
# ... but if we become a relay we'll call it again
- get_conf_mock.return_value = '443'
+ get_conf_mock.side_effect = coro_func_returning_value('443')
self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
self.assertEqual(2, msg_mock.call_count)
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_version(self, get_info_mock):
"""
Exercises the get_version() method.
@@ -155,7 +149,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(version_2_1_object, self.controller.get_version())
# Turn off caching.
- self.async_controller._is_caching_enabled = False
+ self.controller._is_caching_enabled = False
# Return a version without caching, so it will be the new version.
self.assertEqual(version_2_2_object, self.controller.get_version())
@@ -184,13 +178,13 @@ class TestControl(unittest.TestCase):
# Turn caching back on before we leave.
self.controller._is_caching_enabled = True
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_exit_policy(self, get_info_mock):
"""
Exercises the get_exit_policy() method.
"""
- async def get_info_mock_side_effect(param, default = None):
+ async def get_info_mock_side_effect(self, param, default = None):
return {
'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*',
}[param]
@@ -213,8 +207,8 @@ class TestControl(unittest.TestCase):
self.assertEqual(str(expected), str(self.controller.get_exit_policy()))
- @patch('stem.control.AsyncController.get_info')
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.get_info')
+ @patch('stem.control.Controller.get_conf')
def test_get_ports(self, get_conf_mock, get_info_mock):
"""
Exercises the get_ports() and get_listeners() methods.
@@ -225,7 +219,7 @@ class TestControl(unittest.TestCase):
get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
- async def get_conf_mock_side_effect(param, *args, **kwargs):
+ async def get_conf_mock_side_effect(self, param, *args, **kwargs):
return {
'ControlPort': '9050',
'ControlListenAddress': ['127.0.0.1'],
@@ -239,7 +233,7 @@ class TestControl(unittest.TestCase):
# non-local addresss
- async def get_conf_mock_side_effect(param, *args, **kwargs):
+ async def get_conf_mock_side_effect(self, param, *args, **kwargs):
return {
'ControlPort': '9050',
'ControlListenAddress': ['27.4.4.1'],
@@ -290,14 +284,14 @@ class TestControl(unittest.TestCase):
self.assertEqual([], self.controller.get_listeners(Listener.CONTROL))
self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
@patch('time.time', Mock(return_value = 1410723598.276578))
def test_get_accounting_stats(self, get_info_mock):
"""
Exercises the get_accounting_stats() method.
"""
- async def get_info_mock_side_effect(param, **kwargs):
+ async def get_info_mock_side_effect(self, param, **kwargs):
return {
'accounting/enabled': '1',
'accounting/hibernating': 'awake',
@@ -358,6 +352,7 @@ class TestControl(unittest.TestCase):
self.assertRaises(ProtocolError, self.controller.get_protocolinfo)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
def test_get_user_remote(self):
"""
Exercise the get_user() method for a non-local socket.
@@ -367,7 +362,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(123, self.controller.get_user(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('atagar')))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value('atagar')))
def test_get_user_by_getinfo(self):
"""
Exercise the get_user() resolution via its getinfo option.
@@ -376,7 +371,8 @@ class TestControl(unittest.TestCase):
self.assertEqual('atagar', self.controller.get_user())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.util.system.pid_by_name', Mock(return_value = 432))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_pid', Mock(side_effect = coro_func_returning_value(432)))
@patch('stem.util.system.user', Mock(return_value = 'atagar'))
def test_get_user_by_system(self):
"""
@@ -386,6 +382,7 @@ class TestControl(unittest.TestCase):
self.assertEqual('atagar', self.controller.get_user())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
def test_get_pid_remote(self):
"""
Exercise the get_pid() method for a non-local socket.
@@ -395,7 +392,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(123, self.controller.get_pid(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('321')))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value('321')))
def test_get_pid_by_getinfo(self):
"""
Exercise the get_pid() resolution via its getinfo option.
@@ -404,7 +401,8 @@ class TestControl(unittest.TestCase):
self.assertEqual(321, self.controller.get_pid())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_conf')
@patch('stem.control.open', create = True)
def test_get_pid_by_pid_file(self, open_mock, get_conf_mock):
"""
@@ -418,6 +416,8 @@ class TestControl(unittest.TestCase):
open_mock.assert_called_once_with('/tmp/pid_file')
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_conf', Mock(side_effect = coro_func_returning_value(None)))
@patch('stem.util.system.pid_by_name', Mock(return_value = 432))
def test_get_pid_by_name(self):
"""
@@ -426,9 +426,9 @@ class TestControl(unittest.TestCase):
self.assertEqual(432, self.controller.get_pid())
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
+ @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
@patch('time.time', Mock(return_value = 1000.0))
def test_get_uptime_by_getinfo(self, getinfo_mock):
"""
@@ -443,8 +443,9 @@ class TestControl(unittest.TestCase):
self.assertRaisesWith(ValueError, "'GETINFO uptime' did not provide a valid numeric response: abc", self.controller.get_uptime)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14'))))
- @patch('stem.control.AsyncController.get_pid', Mock(side_effect = coro_func_returning_value('12')))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14'))))
+ @patch('stem.control.Controller.get_pid', Mock(side_effect = coro_func_returning_value('12')))
@patch('stem.util.system.start_time', Mock(return_value = 5000.0))
@patch('time.time', Mock(return_value = 5200.0))
def test_get_uptime_by_process(self):
@@ -454,7 +455,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(200.0, self.controller.get_uptime())
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_network_status_for_ourselves(self, get_info_mock):
"""
Exercises the get_network_status() method for getting our own relay.
@@ -472,7 +473,7 @@ class TestControl(unittest.TestCase):
desc = NS_DESC % ('moria1', '/96bKo4soysolMgKn5Hex2nyFSY')
- async def get_info_mock_side_effect(param, **kwargs):
+ async def get_info_mock_side_effect(self, param, **kwargs):
return {
'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31',
'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc,
@@ -482,7 +483,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(stem.descriptor.router_status_entry.RouterStatusEntryV3(desc), self.controller.get_network_status())
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_network_status_when_unavailable(self, get_info_mock):
"""
Exercises the get_network_status() method.
@@ -494,7 +495,7 @@ class TestControl(unittest.TestCase):
exc_msg = "Tor was unable to provide the descriptor for '5AC9C5AA75BA1F18D8459B326B4B8111A856D290'"
self.assertRaisesWith(DescriptorUnavailable, exc_msg, self.controller.get_network_status, '5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_network_status(self, get_info_mock):
"""
Exercises the get_network_status() method.
@@ -540,16 +541,14 @@ class TestControl(unittest.TestCase):
self.assertRaises(InvalidArguments, self.controller.get_network_status, nickname)
- @patch('stem.control.AsyncController.is_authenticated', Mock(return_value = True))
- @patch('stem.control.AsyncController._attach_listeners')
- @patch('stem.control.AsyncController.get_version')
- def test_add_event_listener(self, get_version_mock, attach_listeners_mock):
+ @patch('stem.control.Controller.is_authenticated', Mock(return_value = True))
+ @patch('stem.control.Controller._attach_listeners', Mock(side_effect = coro_func_returning_value(([], []))))
+ @patch('stem.control.Controller.get_version')
+ def test_add_event_listener(self, get_version_mock):
"""
Exercises the add_event_listener and remove_event_listener methods.
"""
- attach_listeners_mock.side_effect = coro_func_returning_value(([], []))
-
def set_version(version_str):
version = stem.version.Version(version_str)
get_version_mock.side_effect = coro_func_returning_value(version)
@@ -621,10 +620,10 @@ class TestControl(unittest.TestCase):
self._emit_event(BW_EVENT)
self.bw_listener.assert_called_once_with(BW_EVENT)
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
- @patch('stem.control.AsyncController.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n'))))
- @patch('stem.control.AsyncController.add_event_listener', Mock(side_effect = coro_func_returning_value(None)))
- @patch('stem.control.AsyncController.remove_event_listener', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
+ @patch('stem.control.Controller.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n'))))
+ @patch('stem.control.Controller.add_event_listener', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.remove_event_listener', Mock(side_effect = coro_func_returning_value(None)))
def test_timeout(self):
"""
Methods that have an 'await' argument also have an optional timeout. Check
@@ -648,7 +647,7 @@ class TestControl(unittest.TestCase):
response = ''.join(['%s\r\n' % ' '.join(entry) for entry in valid_streams])
get_info_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.AsyncController.get_info', get_info_mock):
+ with patch('stem.control.Controller.get_info', get_info_mock):
streams = self.controller.get_streams()
self.assertEqual(len(valid_streams), len(streams))
@@ -669,7 +668,7 @@ class TestControl(unittest.TestCase):
response = stem.response.ControlMessage.from_str('555 Connection is not managed by controller.\r\n')
msg_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.AsyncController.msg', msg_mock):
+ with patch('stem.control.Controller.msg', msg_mock):
self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
def test_parse_circ_path(self):
@@ -712,7 +711,7 @@ class TestControl(unittest.TestCase):
for test_input in malformed_inputs:
self.assertRaises(ProtocolError, _parse_circ_path, test_input)
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.get_conf')
def test_get_effective_rate(self, get_conf_mock):
"""
Exercise the get_effective_rate() method.
@@ -720,7 +719,7 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- async def get_conf_mock_side_effect(param, *args, **kwargs):
+ async def get_conf_mock_side_effect(self, param, *args, **kwargs):
return {
'BandwidthRate': '1073741824',
'BandwidthBurst': '1073741824',
@@ -749,19 +748,19 @@ class TestControl(unittest.TestCase):
# with its work is to join on the thread.
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
- with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
+ with patch('stem.control.Controller.is_alive') as is_alive_mock:
is_alive_mock.return_value = True
loop = self.controller._loop
- asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
+ asyncio.run_coroutine_threadsafe(Controller._event_loop(self.controller), loop)
try:
# Converting an event back into an uncast ControlMessage, then feeding it
# into our controller's event queue.
uncast_event = ControlMessage.from_str(event.raw_content())
- event_queue = self.async_controller._event_queue
+ event_queue = self.controller._event_queue
asyncio.run_coroutine_threadsafe(event_queue.put(uncast_event), loop).result()
asyncio.run_coroutine_threadsafe(event_queue.join(), loop).result() # block until the event is consumed
finally:
is_alive_mock.return_value = False
- asyncio.run_coroutine_threadsafe(self.async_controller._close(), loop).result()
+ self.controller._close()
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index bb6f554c..3facd6a5 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -135,7 +135,7 @@ class TestDescriptorDownloader(unittest.TestCase):
def test_reply_header_data(self):
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
self.assertEqual(None, query.reply_headers) # initially we don't have a reply
- query.run(close = False)
+ query.run(stop = False)
self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date'))
self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits