[tor-commits] [stem/master] Type hints
commit b8063b3b23af95e02b27848f6ab5c82edd644609
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Tue Mar 24 17:55:31 2020 -0700
Type hints
Now that our minimum requirement has finally reached Python 3.5, we can provide
type hints for our IDE users.
https://docs.python.org/3/library/typing.html
This is just a best-effort first pass. I don't have an IDE that uses these, so
no doubt there will be mistakes that need adjustment.
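For readers unfamiliar with PEP 484, here is a minimal standalone sketch (not part
of the patch) of the annotation style applied throughout: typing imports,
'-> None' on __init__, Optional[...] for arguments defaulting to None, and quoted
forward references for types that can't be imported at annotation time. The
Endpoint/ORPort shapes loosely mirror the simplified classes touched in
stem/__init__.py, and the describe() helper is purely illustrative.

  from typing import Optional, Sequence

  class Endpoint(object):
    def __init__(self, address: str, port: int) -> None:
      self.address = address
      self.port = int(port)

  class ORPort(Endpoint):
    def __init__(self, address: str, port: int, link_protocols: Optional[Sequence[int]] = None) -> None:
      super(ORPort, self).__init__(address, port)
      self.link_protocols = link_protocols

  def describe(orport: 'ORPort') -> str:
    # quoted annotations are resolved lazily, which the patch relies on for
    # cross-module references such as 'stem.client.cell.Cell'
    return '%s:%i (link protocols: %s)' % (orport.address, orport.port, orport.link_protocols)

  print(describe(ORPort('127.0.0.1', 9001, [4, 5])))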
---
stem/__init__.py | 30 ++--
stem/client/__init__.py | 37 ++---
stem/client/cell.py | 128 +++++++++--------
stem/client/datatype.py | 84 +++++------
stem/connection.py | 44 +++---
stem/control.py | 241 ++++++++++++++++----------------
stem/descriptor/__init__.py | 66 ++++-----
stem/descriptor/bandwidth_file.py | 26 ++--
stem/descriptor/certificate.py | 39 +++---
stem/descriptor/collector.py | 49 +++----
stem/descriptor/extrainfo_descriptor.py | 59 ++++----
stem/descriptor/hidden_service.py | 109 ++++++++-------
stem/descriptor/microdescriptor.py | 18 +--
stem/descriptor/networkstatus.py | 122 ++++++++--------
stem/descriptor/remote.py | 67 ++++-----
stem/descriptor/router_status_entry.py | 62 ++++----
stem/descriptor/server_descriptor.py | 74 +++++-----
stem/descriptor/tordnsel.py | 9 +-
stem/directory.py | 49 +++----
stem/exit_policy.py | 100 ++++++-------
stem/interpreter/__init__.py | 6 +-
stem/interpreter/arguments.py | 6 +-
stem/interpreter/autocomplete.py | 9 +-
stem/interpreter/commands.py | 22 +--
stem/interpreter/help.py | 8 +-
stem/manual.py | 50 +++----
stem/process.py | 8 +-
stem/response/__init__.py | 55 ++++----
stem/response/add_onion.py | 2 +-
stem/response/authchallenge.py | 2 +-
stem/response/events.py | 71 +++++-----
stem/response/getconf.py | 2 +-
stem/response/getinfo.py | 6 +-
stem/response/mapaddress.py | 2 +-
stem/response/protocolinfo.py | 5 +-
stem/socket.py | 73 +++++-----
stem/util/__init__.py | 10 +-
stem/util/conf.py | 47 ++++---
stem/util/connection.py | 39 +++---
stem/util/enum.py | 18 +--
stem/util/log.py | 34 ++---
stem/util/proc.py | 34 ++---
stem/util/str_tools.py | 32 +++--
stem/util/system.py | 71 +++++-----
stem/util/term.py | 10 +-
stem/util/test_tools.py | 68 ++++-----
stem/util/tor_tools.py | 25 ++--
stem/version.py | 22 +--
test/integ/control/controller.py | 8 +-
test/integ/response/protocolinfo.py | 8 +-
test/settings.cfg | 23 ++-
test/unit/response/protocolinfo.py | 2 +-
52 files changed, 1143 insertions(+), 1048 deletions(-)
diff --git a/stem/__init__.py b/stem/__init__.py
index 907156fe..c0efab19 100644
--- a/stem/__init__.py
+++ b/stem/__init__.py
@@ -507,6 +507,8 @@ import traceback
import stem.util
import stem.util.enum
+from typing import Any, Optional, Sequence
+
__version__ = '1.8.0-dev'
__author__ = 'Damian Johnson'
__contact__ = 'atagar@xxxxxxxxxxxxxx'
@@ -584,7 +586,7 @@ class Endpoint(object):
:var int port: port of the endpoint
"""
- def __init__(self, address, port):
+ def __init__(self, address: str, port: int) -> None:
if not stem.util.connection.is_valid_ipv4_address(address) and not stem.util.connection.is_valid_ipv6_address(address):
raise ValueError("'%s' isn't a valid IPv4 or IPv6 address" % address)
elif not stem.util.connection.is_valid_port(port):
@@ -593,13 +595,13 @@ class Endpoint(object):
self.address = address
self.port = int(port)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'address', 'port', cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, Endpoint) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -610,11 +612,11 @@ class ORPort(Endpoint):
:var list link_protocols: link protocol version we're willing to establish
"""
- def __init__(self, address, port, link_protocols = None):
+ def __init__(self, address: str, port: int, link_protocols: Optional[Sequence[int]] = None) -> None:
super(ORPort, self).__init__(address, port)
self.link_protocols = link_protocols
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'link_protocols', parent = Endpoint, cache = True)
@@ -642,7 +644,7 @@ class OperationFailed(ControllerError):
message
"""
- def __init__(self, code = None, message = None):
+ def __init__(self, code: Optional[str] = None, message: Optional[str] = None) -> None:
super(ControllerError, self).__init__(message)
self.code = code
self.message = message
@@ -658,10 +660,10 @@ class CircuitExtensionFailed(UnsatisfiableRequest):
"""
An attempt to create or extend a circuit failed.
- :var stem.response.CircuitEvent circ: response notifying us of the failure
+ :var stem.response.events.CircuitEvent circ: response notifying us of the failure
"""
- def __init__(self, message, circ = None):
+ def __init__(self, message: str, circ: Optional['stem.response.events.CircuitEvent'] = None) -> None:
super(CircuitExtensionFailed, self).__init__(message = message)
self.circ = circ
@@ -674,7 +676,7 @@ class DescriptorUnavailable(UnsatisfiableRequest):
Subclassed under UnsatisfiableRequest rather than OperationFailed.
"""
- def __init__(self, message):
+ def __init__(self, message: str) -> None:
super(DescriptorUnavailable, self).__init__(message = message)
@@ -685,7 +687,7 @@ class Timeout(UnsatisfiableRequest):
.. versionadded:: 1.7.0
"""
- def __init__(self, message):
+ def __init__(self, message: str) -> None:
super(Timeout, self).__init__(message = message)
@@ -705,7 +707,7 @@ class InvalidArguments(InvalidRequest):
:var list arguments: a list of arguments which were invalid
"""
- def __init__(self, code = None, message = None, arguments = None):
+ def __init__(self, code: Optional[str] = None, message: Optional[str] = None, arguments: Optional[Sequence[str]] = None):
super(InvalidArguments, self).__init__(code, message)
self.arguments = arguments
@@ -736,7 +738,7 @@ class DownloadFailed(IOError):
:var str stacktrace_str: string representation of the stacktrace
"""
- def __init__(self, url, error, stacktrace, message = None):
+ def __init__(self, url: str, error: BaseException, stacktrace: Any, message: Optional[str] = None) -> None:
if message is None:
# The string representation of exceptions can reside in several places.
# urllib.URLError use a 'reason' attribute that in turn may referrence
@@ -773,7 +775,7 @@ class DownloadTimeout(DownloadFailed):
.. versionadded:: 1.8.0
"""
- def __init__(self, url, error, stacktrace, timeout):
+ def __init__(self, url: str, error: BaseException, stacktrace: Any, timeout: int):
message = 'Failed to download from %s: %0.1f second timeout reached' % (url, timeout)
super(DownloadTimeout, self).__init__(url, error, stacktrace, message)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 57cd3457..2972985d 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -33,6 +33,9 @@ import stem.client.cell
import stem.socket
import stem.util.connection
+from types import TracebackType
+from typing import Iterator, Optional, Tuple, Type
+
from stem.client.cell import (
CELL_TYPE_SIZE,
FIXED_PAYLOAD_LEN,
@@ -63,7 +66,7 @@ class Relay(object):
:var int link_protocol: link protocol version we established
"""
- def __init__(self, orport, link_protocol):
+ def __init__(self, orport: int, link_protocol: int) -> None:
self.link_protocol = LinkProtocol(link_protocol)
self._orport = orport
self._orport_buffer = b'' # unread bytes
@@ -71,7 +74,7 @@ class Relay(object):
self._circuits = {}
@staticmethod
- def connect(address, port, link_protocols = DEFAULT_LINK_PROTOCOLS):
+ def connect(address: str, port: int, link_protocols: Tuple[int] = DEFAULT_LINK_PROTOCOLS) -> None:
"""
Establishes a connection with the given ORPort.
@@ -144,7 +147,7 @@ class Relay(object):
return Relay(conn, link_protocol)
- def _recv(self, raw = False):
+ def _recv(self, raw: bool = False) -> None:
"""
Reads the next cell from our ORPort. If none is present this blocks
until one is available.
@@ -185,7 +188,7 @@ class Relay(object):
cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol)
return cell
- def _msg(self, cell):
+ def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
"""
Sends a cell on the ORPort and provides the response we receive in reply.
@@ -217,7 +220,7 @@ class Relay(object):
for received_cell in stem.client.cell.Cell.pop(response, self.link_protocol):
yield received_cell
- def is_alive(self):
+ def is_alive(self) -> bool:
"""
Checks if our socket is currently connected. This is a pass-through for our
socket's :func:`~stem.socket.BaseSocket.is_alive` method.
@@ -227,7 +230,7 @@ class Relay(object):
return self._orport.is_alive()
- def connection_time(self):
+ def connection_time(self) -> float:
"""
Provides the unix timestamp for when our socket was either connected or
disconnected. That is to say, the time we connected if we're currently
@@ -239,7 +242,7 @@ class Relay(object):
return self._orport.connection_time()
- def close(self):
+ def close(self) -> None:
"""
Closes our socket connection. This is a pass-through for our socket's
:func:`~stem.socket.BaseSocket.close` method.
@@ -248,7 +251,7 @@ class Relay(object):
with self._orport_lock:
return self._orport.close()
- def create_circuit(self):
+ def create_circuit(self) -> None:
"""
Establishes a new circuit.
"""
@@ -277,15 +280,15 @@ class Relay(object):
return circ
- def __iter__(self):
+ def __iter__(self) -> Iterator['stem.client.Circuit']:
with self._orport_lock:
for circ in self._circuits.values():
yield circ
- def __enter__(self):
+ def __enter__(self) -> 'stem.client.Relay':
return self
- def __exit__(self, exit_type, value, traceback):
+ def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
self.close()
@@ -304,7 +307,7 @@ class Circuit(object):
:raises: **ImportError** if the cryptography module is unavailable
"""
- def __init__(self, relay, circ_id, kdf):
+ def __init__(self, relay: 'stem.client.Relay', circ_id: int, kdf: 'stem.client.datatype.KDF') -> None:
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
@@ -320,7 +323,7 @@ class Circuit(object):
self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor()
self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
- def directory(self, request, stream_id = 0):
+ def directory(self, request: str, stream_id: int = 0) -> str:
"""
Request descriptors from the relay.
@@ -355,7 +358,7 @@ class Circuit(object):
else:
response.append(decrypted_cell)
- def _send(self, command, data = '', stream_id = 0):
+ def _send(self, command: 'stem.client.datatype.RelayCommand', data: bytes = b'', stream_id: int = 0) -> None:
"""
Sends a message over the circuit.
@@ -375,13 +378,13 @@ class Circuit(object):
self.forward_digest = forward_digest
self.forward_key = forward_key
- def close(self):
+ def close(self) -> None:
with self.relay._orport_lock:
self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
del self.relay._circuits[self.id]
- def __enter__(self):
+ def __enter__(self) -> 'stem.client.Circuit':
return self
- def __exit__(self, exit_type, value, traceback):
+ def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
self.close()
diff --git a/stem/client/cell.py b/stem/client/cell.py
index 83888556..ef445a64 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -49,6 +49,8 @@ from stem import UNDEFINED
from stem.client.datatype import HASH_LEN, ZERO, LinkProtocol, Address, Certificate, CloseReason, RelayCommand, Size, split
from stem.util import datetime_to_unix, str_tools
+from typing import Any, Sequence, Tuple, Type
+
FIXED_PAYLOAD_LEN = 509 # PAYLOAD_LEN, per tor-spec section 0.2
AUTH_CHALLENGE_SIZE = 32
@@ -96,17 +98,19 @@ class Cell(object):
VALUE = -1
IS_FIXED_SIZE = False
- def __init__(self, unused = b''):
+ def __init__(self, unused: bytes = b'') -> None:
super(Cell, self).__init__()
self.unused = unused
@staticmethod
- def by_name(name):
+ def by_name(name: str) -> Type['stem.client.cell.Cell']:
"""
Provides cell attributes by its name.
:param str name: cell command to fetch
+ :returns: cell class with this name
+
:raises: **ValueError** if cell type is invalid
"""
@@ -117,12 +121,14 @@ class Cell(object):
raise ValueError("'%s' isn't a valid cell type" % name)
@staticmethod
- def by_value(value):
+ def by_value(value: int) -> Type['stem.client.cell.Cell']:
"""
Provides cell attributes by its value.
:param int value: cell value to fetch
+ :returns: cell class with this numeric value
+
:raises: **ValueError** if cell type is invalid
"""
@@ -136,7 +142,7 @@ class Cell(object):
raise NotImplementedError('Packing not yet implemented for %s cells' % type(self).NAME)
@staticmethod
- def unpack(content, link_protocol):
+ def unpack(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.Cell':
"""
Unpacks all cells from a response.
@@ -155,7 +161,7 @@ class Cell(object):
yield cell
@staticmethod
- def pop(content, link_protocol):
+ def pop(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> Tuple['stem.client.cell.Cell', bytes]:
"""
Unpacks the first cell.
@@ -187,7 +193,7 @@ class Cell(object):
return cls._unpack(payload, circ_id, link_protocol), content
@classmethod
- def _pack(cls, link_protocol, payload, unused = b'', circ_id = None):
+ def _pack(cls: Type['stem.client.cell.Cell'], link_protocol: 'stem.client.datatype.LinkProtocol', payload: bytes, unused: bytes = b'', circ_id: int = None) -> bytes:
"""
Provides bytes that can be used on the wire for these cell attributes.
Format of a properly packed cell depends on if it's fixed or variable
@@ -241,13 +247,13 @@ class Cell(object):
return bytes(cell)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls: Type['stem.client.cell.Cell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.Cell':
"""
Subclass implementation for unpacking cell content.
:param bytes content: payload to decode
- :param stem.client.datatype.LinkProtocol link_protocol: link protocol version
:param int circ_id: circuit id cell is for
+ :param stem.client.datatype.LinkProtocol link_protocol: link protocol version
:returns: instance of this cell type
@@ -256,10 +262,10 @@ class Cell(object):
raise NotImplementedError('Unpacking not yet implemented for %s cells' % cls.NAME)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, Cell) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -270,7 +276,7 @@ class CircuitCell(Cell):
:var int circ_id: circuit id
"""
- def __init__(self, circ_id, unused = b''):
+ def __init__(self, circ_id: int, unused: bytes = b'') -> None:
super(CircuitCell, self).__init__(unused)
self.circ_id = circ_id
@@ -286,7 +292,7 @@ class PaddingCell(Cell):
VALUE = 0
IS_FIXED_SIZE = True
- def __init__(self, payload = None):
+ def __init__(self, payload: bytes = None) -> None:
if not payload:
payload = os.urandom(FIXED_PAYLOAD_LEN)
elif len(payload) != FIXED_PAYLOAD_LEN:
@@ -295,14 +301,14 @@ class PaddingCell(Cell):
super(PaddingCell, self).__init__()
self.payload = payload
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
return PaddingCell._pack(link_protocol, self.payload)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.PaddingCell':
return PaddingCell(content)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'payload', cache = True)
@@ -311,7 +317,7 @@ class CreateCell(CircuitCell):
VALUE = 1
IS_FIXED_SIZE = True
- def __init__(self):
+ def __init__(self) -> None:
super(CreateCell, self).__init__() # TODO: implement
@@ -320,7 +326,7 @@ class CreatedCell(CircuitCell):
VALUE = 2
IS_FIXED_SIZE = True
- def __init__(self):
+ def __init__(self) -> None:
super(CreatedCell, self).__init__() # TODO: implement
@@ -346,7 +352,7 @@ class RelayCell(CircuitCell):
VALUE = 3
IS_FIXED_SIZE = True
- def __init__(self, circ_id, command, data, digest = 0, stream_id = 0, recognized = 0, unused = b''):
+ def __init__(self, circ_id: int, command, data: bytes, digest: int = 0, stream_id: int = 0, recognized: int = 0, unused: bytes = b'') -> None:
if 'hash' in str(type(digest)).lower():
# Unfortunately hashlib generates from a dynamic private class so
# isinstance() isn't such a great option. With python2/python3 the
@@ -375,7 +381,7 @@ class RelayCell(CircuitCell):
elif stream_id and self.command in STREAM_ID_DISALLOWED:
raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command)
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
payload = bytearray()
payload += Size.CHAR.pack(self.command_int)
payload += Size.SHORT.pack(self.recognized)
@@ -387,7 +393,7 @@ class RelayCell(CircuitCell):
return RelayCell._pack(link_protocol, bytes(payload), self.unused, self.circ_id)
@staticmethod
- def decrypt(link_protocol, content, key, digest):
+ def decrypt(link_protocol: 'stem.client.datatype.LinkProtocol', content: bytes, key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple['stem.client.cell.RelayCell', 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']:
"""
Decrypts content as a relay cell addressed to us. This provides back a
tuple of the form...
@@ -441,7 +447,7 @@ class RelayCell(CircuitCell):
return cell, new_key, new_digest
- def encrypt(self, link_protocol, key, digest):
+ def encrypt(self, link_protocol: 'stem.client.datatype.LinkProtocol', key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple[bytes, 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']:
"""
Encrypts our cell content to be sent with the given key. This provides back
a tuple of the form...
@@ -477,7 +483,7 @@ class RelayCell(CircuitCell):
return header + new_key.update(payload), new_key, new_digest
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.RelayCell':
command, content = Size.CHAR.pop(content)
recognized, content = Size.SHORT.pop(content) # 'recognized' field
stream_id, content = Size.SHORT.pop(content)
@@ -490,7 +496,7 @@ class RelayCell(CircuitCell):
return RelayCell(circ_id, command, data, digest, stream_id, recognized, unused)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'command_int', 'stream_id', 'digest', 'data', cache = True)
@@ -506,19 +512,19 @@ class DestroyCell(CircuitCell):
VALUE = 4
IS_FIXED_SIZE = True
- def __init__(self, circ_id, reason = CloseReason.NONE, unused = b''):
+ def __init__(self, circ_id: int, reason: 'stem.client.datatype.CloseReason' = CloseReason.NONE, unused: bytes = b'') -> None:
super(DestroyCell, self).__init__(circ_id, unused)
self.reason, self.reason_int = CloseReason.get(reason)
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
return DestroyCell._pack(link_protocol, Size.CHAR.pack(self.reason_int), self.unused, self.circ_id)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls: Type['stem.client.cell.DestroyCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.DestroyCell':
reason, unused = Size.CHAR.pop(content)
return DestroyCell(circ_id, reason, unused)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'circ_id', 'reason_int', cache = True)
@@ -534,7 +540,7 @@ class CreateFastCell(CircuitCell):
VALUE = 5
IS_FIXED_SIZE = True
- def __init__(self, circ_id, key_material = None, unused = b''):
+ def __init__(self, circ_id: int, key_material: bytes = None, unused: bytes = b'') -> None:
if not key_material:
key_material = os.urandom(HASH_LEN)
elif len(key_material) != HASH_LEN:
@@ -543,11 +549,11 @@ class CreateFastCell(CircuitCell):
super(CreateFastCell, self).__init__(circ_id, unused)
self.key_material = key_material
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
return CreateFastCell._pack(link_protocol, self.key_material, self.unused, self.circ_id)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreateFastCell':
key_material, unused = split(content, HASH_LEN)
if len(key_material) != HASH_LEN:
@@ -555,7 +561,7 @@ class CreateFastCell(CircuitCell):
return CreateFastCell(circ_id, key_material, unused)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'circ_id', 'key_material', cache = True)
@@ -571,7 +577,7 @@ class CreatedFastCell(CircuitCell):
VALUE = 6
IS_FIXED_SIZE = True
- def __init__(self, circ_id, derivative_key, key_material = None, unused = b''):
+ def __init__(self, circ_id: int, derivative_key: bytes, key_material: bytes = None, unused: bytes = b'') -> None:
if not key_material:
key_material = os.urandom(HASH_LEN)
elif len(key_material) != HASH_LEN:
@@ -584,11 +590,11 @@ class CreatedFastCell(CircuitCell):
self.key_material = key_material
self.derivative_key = derivative_key
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
return CreatedFastCell._pack(link_protocol, self.key_material + self.derivative_key, self.unused, self.circ_id)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreateFastCell':
if len(content) < HASH_LEN * 2:
raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content)))
@@ -597,7 +603,7 @@ class CreatedFastCell(CircuitCell):
return CreatedFastCell(circ_id, derivative_key, key_material, content)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'circ_id', 'derivative_key', 'key_material', cache = True)
@@ -612,16 +618,16 @@ class VersionsCell(Cell):
VALUE = 7
IS_FIXED_SIZE = False
- def __init__(self, versions):
+ def __init__(self, versions: Sequence[int]) -> None:
super(VersionsCell, self).__init__()
self.versions = versions
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
payload = b''.join([Size.SHORT.pack(v) for v in self.versions])
return VersionsCell._pack(link_protocol, payload)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls: Type['stem.client.cell.VersionsCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.VersionsCell':
link_protocols = []
while content:
@@ -630,7 +636,7 @@ class VersionsCell(Cell):
return VersionsCell(link_protocols)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'versions', cache = True)
@@ -647,13 +653,13 @@ class NetinfoCell(Cell):
VALUE = 8
IS_FIXED_SIZE = True
- def __init__(self, receiver_address, sender_addresses, timestamp = None, unused = b''):
+ def __init__(self, receiver_address: 'stem.client.datatype.Address', sender_addresses: Sequence['stem.client.datatype.Address'], timestamp: datetime.datetime = None, unused: bytes = b'') -> None:
super(NetinfoCell, self).__init__(unused)
self.timestamp = timestamp if timestamp else datetime.datetime.now()
self.receiver_address = receiver_address
self.sender_addresses = sender_addresses
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
payload = bytearray()
payload += Size.LONG.pack(int(datetime_to_unix(self.timestamp)))
payload += self.receiver_address.pack()
@@ -665,7 +671,7 @@ class NetinfoCell(Cell):
return NetinfoCell._pack(link_protocol, bytes(payload), self.unused)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.NetinfoCell':
timestamp, content = Size.LONG.pop(content)
receiver_address, content = Address.pop(content)
@@ -678,7 +684,7 @@ class NetinfoCell(Cell):
return NetinfoCell(receiver_address, sender_addresses, datetime.datetime.utcfromtimestamp(timestamp), unused = content)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'timestamp', 'receiver_address', 'sender_addresses', cache = True)
@@ -687,7 +693,7 @@ class RelayEarlyCell(CircuitCell):
VALUE = 9
IS_FIXED_SIZE = True
- def __init__(self):
+ def __init__(self) -> None:
super(RelayEarlyCell, self).__init__() # TODO: implement
@@ -696,7 +702,7 @@ class Create2Cell(CircuitCell):
VALUE = 10
IS_FIXED_SIZE = True
- def __init__(self):
+ def __init__(self) -> None:
super(Create2Cell, self).__init__() # TODO: implement
@@ -705,7 +711,7 @@ class Created2Cell(Cell):
VALUE = 11
IS_FIXED_SIZE = True
- def __init__(self):
+ def __init__(self) -> None:
super(Created2Cell, self).__init__() # TODO: implement
@@ -714,7 +720,7 @@ class PaddingNegotiateCell(Cell):
VALUE = 12
IS_FIXED_SIZE = True
- def __init__(self):
+ def __init__(self) -> None:
super(PaddingNegotiateCell, self).__init__() # TODO: implement
@@ -729,7 +735,7 @@ class VPaddingCell(Cell):
VALUE = 128
IS_FIXED_SIZE = False
- def __init__(self, size = None, payload = None):
+ def __init__(self, size: int = None, payload: bytes = None) -> None:
if size is None and payload is None:
raise ValueError('VPaddingCell constructor must specify payload or size')
elif size is not None and size < 0:
@@ -740,14 +746,14 @@ class VPaddingCell(Cell):
super(VPaddingCell, self).__init__()
self.payload = payload if payload is not None else os.urandom(size)
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
return VPaddingCell._pack(link_protocol, self.payload)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.VPaddingCell':
return VPaddingCell(payload = content)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'payload', cache = True)
@@ -762,15 +768,15 @@ class CertsCell(Cell):
VALUE = 129
IS_FIXED_SIZE = False
- def __init__(self, certs, unused = b''):
+ def __init__(self, certs: Sequence['stem.client.Certificate'], unused: bytes = b'') -> None:
super(CertsCell, self).__init__(unused)
self.certificates = certs
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
return CertsCell._pack(link_protocol, Size.CHAR.pack(len(self.certificates)) + b''.join([cert.pack() for cert in self.certificates]), self.unused)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CertsCell':
cert_count, content = Size.CHAR.pop(content)
certs = []
@@ -783,7 +789,7 @@ class CertsCell(Cell):
return CertsCell(certs, unused = content)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'certificates', cache = True)
@@ -800,7 +806,7 @@ class AuthChallengeCell(Cell):
VALUE = 130
IS_FIXED_SIZE = False
- def __init__(self, methods, challenge = None, unused = b''):
+ def __init__(self, methods: Sequence[int], challenge: bytes = None, unused: bytes = b'') -> None:
if not challenge:
challenge = os.urandom(AUTH_CHALLENGE_SIZE)
elif len(challenge) != AUTH_CHALLENGE_SIZE:
@@ -810,7 +816,7 @@ class AuthChallengeCell(Cell):
self.challenge = challenge
self.methods = methods
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
payload = bytearray()
payload += self.challenge
payload += Size.SHORT.pack(len(self.methods))
@@ -821,7 +827,7 @@ class AuthChallengeCell(Cell):
return AuthChallengeCell._pack(link_protocol, bytes(payload), self.unused)
@classmethod
- def _unpack(cls, content, circ_id, link_protocol):
+ def _unpack(cls: Type['stem.client.cell.AuthChallengeCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.AuthChallengeCell':
min_size = AUTH_CHALLENGE_SIZE + Size.SHORT.size
if len(content) < min_size:
raise ValueError('AUTH_CHALLENGE payload should be at least %i bytes, but was %i' % (min_size, len(content)))
@@ -840,7 +846,7 @@ class AuthChallengeCell(Cell):
return AuthChallengeCell(methods, challenge, unused = content)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'challenge', 'methods', cache = True)
@@ -849,7 +855,7 @@ class AuthenticateCell(Cell):
VALUE = 131
IS_FIXED_SIZE = False
- def __init__(self):
+ def __init__(self) -> None:
super(AuthenticateCell, self).__init__() # TODO: implement
@@ -858,5 +864,5 @@ class AuthorizeCell(Cell):
VALUE = 132
IS_FIXED_SIZE = False
- def __init__(self):
+ def __init__(self) -> None:
super(AuthorizeCell, self).__init__() # TODO: implement
diff --git a/stem/client/datatype.py b/stem/client/datatype.py
index 4f7110e9..8d8ae7fb 100644
--- a/stem/client/datatype.py
+++ b/stem/client/datatype.py
@@ -144,6 +144,8 @@ import stem.util
import stem.util.connection
import stem.util.enum
+from typing import Any, Tuple, Type, Union
+
ZERO = b'\x00'
HASH_LEN = 20
KEY_LEN = 16
@@ -155,7 +157,7 @@ class _IntegerEnum(stem.util.enum.Enum):
**UNKNOWN** value for integer values that lack a mapping.
"""
- def __init__(self, *args):
+ def __init__(self, *args: Tuple[str, int]) -> None:
self._enum_to_int = {}
self._int_to_enum = {}
parent_args = []
@@ -176,7 +178,7 @@ class _IntegerEnum(stem.util.enum.Enum):
parent_args.append(('UNKNOWN', 'UNKNOWN'))
super(_IntegerEnum, self).__init__(*parent_args)
- def get(self, val):
+ def get(self, val: Union[int, str]) -> Tuple[str, int]:
"""
Provides the (enum, int_value) tuple for a given value.
"""
@@ -246,7 +248,7 @@ CloseReason = _IntegerEnum(
)
-def split(content, size):
+def split(content: bytes, size: int) -> Tuple[bytes, bytes]:
"""
Simple split of bytes into two substrings.
@@ -270,7 +272,7 @@ class LinkProtocol(int):
from a range that's determined by our link protocol.
"""
- def __new__(cls, version):
+ def __new__(cls: Type['stem.client.datatype.LinkProtocol'], version: int) -> 'stem.client.datatype.LinkProtocol':
if isinstance(version, LinkProtocol):
return version # already a LinkProtocol
@@ -284,14 +286,14 @@ class LinkProtocol(int):
return protocol
- def __hash__(self):
+ def __hash__(self) -> int:
# All LinkProtocol attributes can be derived from our version, so that's
# all we need in our hash. Offsetting by our type so we don't hash conflict
# with ints.
return self.version * hash(str(type(self)))
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, int):
return self.version == other
elif isinstance(other, LinkProtocol):
@@ -299,10 +301,10 @@ class LinkProtocol(int):
else:
return False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
- def __int__(self):
+ def __int__(self) -> int:
return self.version
@@ -311,7 +313,7 @@ class Field(object):
Packable and unpackable datatype.
"""
- def pack(self):
+ def pack(self) -> bytes:
"""
Encodes field into bytes.
@@ -323,7 +325,7 @@ class Field(object):
raise NotImplementedError('Not yet available')
@classmethod
- def unpack(cls, packed):
+ def unpack(cls, packed: bytes) -> 'stem.client.datatype.Field':
"""
Decodes bytes into a field of this type.
@@ -342,7 +344,7 @@ class Field(object):
return unpacked
@staticmethod
- def pop(packed):
+ def pop(packed: bytes) -> Tuple[Any, bytes]:
"""
Decodes bytes as this field type, providing it and the remainder.
@@ -355,10 +357,10 @@ class Field(object):
raise NotImplementedError('Not yet available')
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, Field) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -378,15 +380,15 @@ class Size(Field):
==================== ===========
"""
- def __init__(self, name, size):
+ def __init__(self, name: str, size: int) -> None:
self.name = name
self.size = size
@staticmethod
- def pop(packed):
+ def pop(packed: bytes) -> Tuple[int, bytes]:
raise NotImplementedError("Use our constant's unpack() and pop() instead")
- def pack(self, content):
+ def pack(self, content: int) -> bytes:
try:
return content.to_bytes(self.size, 'big')
except:
@@ -397,18 +399,18 @@ class Size(Field):
else:
raise
- def unpack(self, packed):
+ def unpack(self, packed: bytes) -> int:
if self.size != len(packed):
raise ValueError('%s is the wrong size for a %s field' % (repr(packed), self.name))
return int.from_bytes(packed, 'big')
- def pop(self, packed):
+ def pop(self, packed: bytes) -> Tuple[int, bytes]:
to_unpack, remainder = split(packed, self.size)
return self.unpack(to_unpack), remainder
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'name', 'size', cache = True)
@@ -422,7 +424,7 @@ class Address(Field):
:var bytes value_bin: encoded address value
"""
- def __init__(self, value, addr_type = None):
+ def __init__(self, value: str, addr_type: Union[int, 'stem.client.datatype.AddrType'] = None) -> None:
if addr_type is None:
if stem.util.connection.is_valid_ipv4_address(value):
addr_type = AddrType.IPv4
@@ -461,7 +463,7 @@ class Address(Field):
self.value = None
self.value_bin = value
- def pack(self):
+ def pack(self) -> bytes:
cell = bytearray()
cell += Size.CHAR.pack(self.type_int)
cell += Size.CHAR.pack(len(self.value_bin))
@@ -469,7 +471,7 @@ class Address(Field):
return bytes(cell)
@staticmethod
- def pop(content):
+ def pop(content) -> Tuple['stem.client.datatype.Address', bytes]:
addr_type, content = Size.CHAR.pop(content)
addr_length, content = Size.CHAR.pop(content)
@@ -480,7 +482,7 @@ class Address(Field):
return Address(addr_value, addr_type), content
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'type_int', 'value_bin', cache = True)
@@ -493,11 +495,11 @@ class Certificate(Field):
:var bytes value: certificate value
"""
- def __init__(self, cert_type, value):
+ def __init__(self, cert_type: Union[int, 'stem.client.datatype.CertType'], value: bytes) -> None:
self.type, self.type_int = CertType.get(cert_type)
self.value = value
- def pack(self):
+ def pack(self) -> bytes:
cell = bytearray()
cell += Size.CHAR.pack(self.type_int)
cell += Size.SHORT.pack(len(self.value))
@@ -505,7 +507,7 @@ class Certificate(Field):
return bytes(cell)
@staticmethod
- def pop(content):
+ def pop(content: bytes) -> Tuple['stem.client.datatype.Certificate', bytes]:
cert_type, content = Size.CHAR.pop(content)
cert_size, content = Size.SHORT.pop(content)
@@ -515,7 +517,7 @@ class Certificate(Field):
cert_bytes, content = split(content, cert_size)
return Certificate(cert_type, cert_bytes), content
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'type_int', 'value')
@@ -532,12 +534,12 @@ class LinkSpecifier(Field):
:var bytes value: encoded link specification destination
"""
- def __init__(self, link_type, value):
+ def __init__(self, link_type: int, value: bytes) -> None:
self.type = link_type
self.value = value
@staticmethod
- def pop(packed):
+ def pop(packed: bytes) -> Tuple['stem.client.datatype.LinkSpecifier', bytes]:
# LSTYPE (Link specifier type) [1 byte]
# LSLEN (Link specifier length) [1 byte]
# LSPEC (Link specifier) [LSLEN bytes]
@@ -561,7 +563,7 @@ class LinkSpecifier(Field):
else:
return LinkSpecifier(link_type, value), packed # unrecognized type
- def pack(self):
+ def pack(self) -> bytes:
cell = bytearray()
cell += Size.CHAR.pack(self.type)
cell += Size.CHAR.pack(len(self.value))
@@ -579,14 +581,14 @@ class LinkByIPv4(LinkSpecifier):
:var int port: relay ORPort
"""
- def __init__(self, address, port):
+ def __init__(self, address: str, port: int) -> None:
super(LinkByIPv4, self).__init__(0, _pack_ipv4_address(address) + Size.SHORT.pack(port))
self.address = address
self.port = port
@staticmethod
- def unpack(value):
+ def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv4':
if len(value) != 6:
raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
@@ -604,14 +606,14 @@ class LinkByIPv6(LinkSpecifier):
:var int port: relay ORPort
"""
- def __init__(self, address, port):
+ def __init__(self, address: str, port: int) -> None:
super(LinkByIPv6, self).__init__(1, _pack_ipv6_address(address) + Size.SHORT.pack(port))
self.address = address
self.port = port
@staticmethod
- def unpack(value):
+ def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv6':
if len(value) != 18:
raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
@@ -628,7 +630,7 @@ class LinkByFingerprint(LinkSpecifier):
:var str fingerprint: relay sha1 fingerprint
"""
- def __init__(self, value):
+ def __init__(self, value: bytes) -> None:
super(LinkByFingerprint, self).__init__(2, value)
if len(value) != 20:
@@ -646,7 +648,7 @@ class LinkByEd25519(LinkSpecifier):
:var str fingerprint: relay ed25519 fingerprint
"""
- def __init__(self, value):
+ def __init__(self, value: bytes) -> None:
super(LinkByEd25519, self).__init__(3, value)
if len(value) != 32:
@@ -668,7 +670,7 @@ class KDF(collections.namedtuple('KDF', ['key_hash', 'forward_digest', 'backward
"""
@staticmethod
- def from_value(key_material):
+ def from_value(key_material: bytes) -> 'stem.client.datatype.KDF':
# Derived key material, as per...
#
# K = H(K0 | [00]) | H(K0 | [01]) | H(K0 | [02]) | ...
@@ -689,19 +691,19 @@ class KDF(collections.namedtuple('KDF', ['key_hash', 'forward_digest', 'backward
return KDF(key_hash, forward_digest, backward_digest, forward_key, backward_key)
-def _pack_ipv4_address(address):
+def _pack_ipv4_address(address: str) -> bytes:
return b''.join([Size.CHAR.pack(int(v)) for v in address.split('.')])
-def _unpack_ipv4_address(value):
+def _unpack_ipv4_address(value: str) -> bytes:
return '.'.join([str(Size.CHAR.unpack(value[i:i + 1])) for i in range(4)])
-def _pack_ipv6_address(address):
+def _pack_ipv6_address(address: str) -> bytes:
return b''.join([Size.SHORT.pack(int(v, 16)) for v in address.split(':')])
-def _unpack_ipv6_address(value):
+def _unpack_ipv6_address(value: str) -> bytes:
return ':'.join(['%04x' % Size.SHORT.unpack(value[i * 2:(i + 1) * 2]) for i in range(8)])
diff --git a/stem/connection.py b/stem/connection.py
index e3032784..3d3eb3ee 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -135,6 +135,7 @@ import os
import stem.control
import stem.response
+import stem.response.protocolinfo
import stem.socket
import stem.util.connection
import stem.util.enum
@@ -142,6 +143,7 @@ import stem.util.str_tools
import stem.util.system
import stem.version
+from typing import Any, Optional, Sequence, Tuple, Type, Union
from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN')
@@ -209,7 +211,7 @@ COMMON_TOR_COMMANDS = (
)
-def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/tor/control', password = None, password_prompt = False, chroot_path = None, controller = stem.control.Controller):
+def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: type = stem.control.Controller) -> Union[stem.control.BaseController, stem.socket.ControlSocket]:
"""
Convenience function for quickly getting a control connection. This is very
handy for debugging or CLI setup, handling setup and prompting for a password
@@ -234,7 +236,7 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/
Use both port 9051 and 9151 by default.
:param tuple contol_port: address and port tuple, for instance **('127.0.0.1', 9051)**
- :param str path: path where the control socket is located
+ :param str control_socket: path where the control socket is located
:param str password: passphrase to authenticate to the socket
:param bool password_prompt: prompt for the controller password if it wasn't
supplied
@@ -295,7 +297,7 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/
return _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-def _connect_auth(control_socket, password, password_prompt, chroot_path, controller):
+def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Union[Type[stem.control.BaseController], Type[stem.socket.ControlSocket]]) -> Union[stem.control.BaseController, stem.socket.ControlSocket]:
"""
Helper for the connect_* functions that authenticates the socket and
constructs the controller.
@@ -361,7 +363,7 @@ def _connect_auth(control_socket, password, password_prompt, chroot_path, contro
return None
-def authenticate(controller, password = None, chroot_path = None, protocolinfo_response = None):
+def authenticate(controller: Any, password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None:
"""
Authenticates to a control socket using the information provided by a
PROTOCOLINFO response. In practice this will often be all we need to
@@ -575,7 +577,7 @@ def authenticate(controller, password = None, chroot_path = None, protocolinfo_r
raise AssertionError('BUG: Authentication failed without providing a recognized exception: %s' % str(auth_exceptions))
-def authenticate_none(controller, suppress_ctl_errors = True):
+def authenticate_none(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], suppress_ctl_errors: bool = True) -> None:
"""
Authenticates to an open control socket. All control connections need to
authenticate before they can be used, even if tor hasn't been configured to
@@ -622,7 +624,7 @@ def authenticate_none(controller, suppress_ctl_errors = True):
raise OpenAuthRejected('Socket failed (%s)' % exc)
-def authenticate_password(controller, password, suppress_ctl_errors = True):
+def authenticate_password(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: str, suppress_ctl_errors: bool = True) -> None:
"""
Authenticates to a control socket that uses a password (via the
HashedControlPassword torrc option). Quotes in the password are escaped.
@@ -692,7 +694,7 @@ def authenticate_password(controller, password, suppress_ctl_errors = True):
raise PasswordAuthRejected('Socket failed (%s)' % exc)
-def authenticate_cookie(controller, cookie_path, suppress_ctl_errors = True):
+def authenticate_cookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None:
"""
Authenticates to a control socket that uses the contents of an authentication
cookie (generated via the CookieAuthentication torrc option). This does basic
@@ -782,7 +784,7 @@ def authenticate_cookie(controller, cookie_path, suppress_ctl_errors = True):
raise CookieAuthRejected('Socket failed (%s)' % exc, cookie_path, False)
-def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True):
+def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None:
"""
Authenticates to a control socket using the safe cookie method, which is
enabled by setting the CookieAuthentication torrc option on Tor client's which
@@ -931,7 +933,7 @@ def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True)
raise CookieAuthRejected(str(auth_response), cookie_path, True, auth_response)
-def get_protocolinfo(controller):
+def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.ControlSocket]) -> stem.response.protocolinfo.ProtocolInfoResponse:
"""
Issues a PROTOCOLINFO query to a control socket, getting information about
the tor process running on it. If the socket is already closed then it is
@@ -971,7 +973,7 @@ def get_protocolinfo(controller):
return protocolinfo_response
-def _msg(controller, message):
+def _msg(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], message: str) -> stem.response.ControlMessage:
"""
Sends and receives a message with either a
:class:`~stem.socket.ControlSocket` or :class:`~stem.control.BaseController`.
@@ -984,7 +986,7 @@ def _msg(controller, message):
return controller.msg(message)
-def _connection_for_default_port(address):
+def _connection_for_default_port(address: str) -> stem.socket.ControlPort:
"""
Attempts to provide a controller connection for either port 9051 (default for
relays) or 9151 (default for Tor Browser). If both fail then this raises the
@@ -1006,7 +1008,7 @@ def _connection_for_default_port(address):
raise exc
-def _read_cookie(cookie_path, is_safecookie):
+def _read_cookie(cookie_path: str, is_safecookie: bool) -> str:
"""
Provides the contents of a given cookie file.
@@ -1014,6 +1016,8 @@ def _read_cookie(cookie_path, is_safecookie):
:param bool is_safecookie: **True** if this was for SAFECOOKIE
authentication, **False** if for COOKIE
+ :returns: **str** with the cookie file content
+
:raises:
* :class:`stem.connection.UnreadableCookieFile` if the cookie file is
unreadable
@@ -1048,7 +1052,7 @@ def _read_cookie(cookie_path, is_safecookie):
raise UnreadableCookieFile(exc_msg, cookie_path, is_safecookie)
-def _hmac_sha256(key, msg):
+def _hmac_sha256(key: str, msg: str) -> bytes:
"""
Generates a sha256 digest using the given key and message.
@@ -1065,11 +1069,11 @@ class AuthenticationFailure(Exception):
"""
Base error for authentication failures.
- :var stem.socket.ControlMessage auth_response: AUTHENTICATE response from the
+ :var stem.response.ControlMessage auth_response: AUTHENTICATE response from the
control socket, **None** if one wasn't received
"""
- def __init__(self, message, auth_response = None):
+ def __init__(self, message: str, auth_response: Optional[stem.response.ControlMessage] = None) -> None:
super(AuthenticationFailure, self).__init__(message)
self.auth_response = auth_response
@@ -1081,7 +1085,7 @@ class UnrecognizedAuthMethods(AuthenticationFailure):
:var list unknown_auth_methods: authentication methods that weren't recognized
"""
- def __init__(self, message, unknown_auth_methods):
+ def __init__(self, message: str, unknown_auth_methods: Sequence[str]) -> None:
super(UnrecognizedAuthMethods, self).__init__(message)
self.unknown_auth_methods = unknown_auth_methods
@@ -1125,7 +1129,7 @@ class CookieAuthFailed(AuthenticationFailure):
authentication attempt
"""
- def __init__(self, message, cookie_path, is_safecookie, auth_response = None):
+ def __init__(self, message: str, cookie_path: str, is_safecookie: bool, auth_response: Optional[stem.response.ControlMessage] = None) -> None:
super(CookieAuthFailed, self).__init__(message, auth_response)
self.is_safecookie = is_safecookie
self.cookie_path = cookie_path
@@ -1152,7 +1156,7 @@ class AuthChallengeFailed(CookieAuthFailed):
AUTHCHALLENGE command has failed.
"""
- def __init__(self, message, cookie_path):
+ def __init__(self, message: str, cookie_path: str) -> None:
super(AuthChallengeFailed, self).__init__(message, cookie_path, True)
@@ -1169,7 +1173,7 @@ class UnrecognizedAuthChallengeMethod(AuthChallengeFailed):
:var str authchallenge_method: AUTHCHALLENGE method that Tor couldn't recognize
"""
- def __init__(self, message, cookie_path, authchallenge_method):
+ def __init__(self, message: str, cookie_path: str, authchallenge_method: str) -> None:
super(UnrecognizedAuthChallengeMethod, self).__init__(message, cookie_path)
self.authchallenge_method = authchallenge_method
@@ -1201,7 +1205,7 @@ class NoAuthCookie(MissingAuthInfo):
authentication, **False** if for COOKIE
"""
- def __init__(self, message, is_safecookie):
+ def __init__(self, message: str, is_safecookie: bool) -> None:
super(NoAuthCookie, self).__init__(message)
self.is_safecookie = is_safecookie
diff --git a/stem/control.py b/stem/control.py
index 4016e762..ec4ba54e 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -255,7 +255,9 @@ import stem.descriptor.router_status_entry
import stem.descriptor.server_descriptor
import stem.exit_policy
import stem.response
+import stem.response.add_onion
import stem.response.events
+import stem.response.protocolinfo
import stem.socket
import stem.util
import stem.util.conf
@@ -268,6 +270,8 @@ import stem.version
from stem import UNDEFINED, CircStatus, Signal
from stem.util import log
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events,
# but if it takes longer than this we terminate.
@@ -447,15 +451,15 @@ class CreateHiddenServiceOutput(collections.namedtuple('CreateHiddenServiceOutpu
"""
-def with_default(yields = False):
+def with_default(yields: bool = False) -> Callable:
"""
Provides a decorator to support having a default value. This should be
treated as private.
"""
- def decorator(func):
- def get_default(func, args, kwargs):
- arg_names = inspect.getargspec(func).args[1:] # drop 'self'
+ def decorator(func: Callable) -> Callable:
+ def get_default(func: Callable, args: Any, kwargs: Any) -> Any:
+ arg_names = inspect.getfullargspec(func).args[1:] # drop 'self'
default_position = arg_names.index('default') if 'default' in arg_names else None
if default_position is not None and default_position < len(args):
@@ -465,7 +469,7 @@ def with_default(yields = False):
if not yields:
@functools.wraps(func)
- def wrapped(self, *args, **kwargs):
+ def wrapped(self, *args: Any, **kwargs: Any) -> Any:
try:
return func(self, *args, **kwargs)
except:
@@ -477,7 +481,7 @@ def with_default(yields = False):
return default
else:
@functools.wraps(func)
- def wrapped(self, *args, **kwargs):
+ def wrapped(self, *args: Any, **kwargs: Any) -> Any:
try:
for val in func(self, *args, **kwargs):
yield val
@@ -496,7 +500,7 @@ def with_default(yields = False):
return decorator
-def event_description(event):
+def event_description(event: str) -> str:
"""
Provides a description for Tor events.
@@ -538,7 +542,7 @@ class BaseController(object):
socket as though it hasn't yet been authenticated.
"""
- def __init__(self, control_socket, is_authenticated = False):
+ def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
self._socket = control_socket
self._msg_lock = threading.RLock()
@@ -576,7 +580,7 @@ class BaseController(object):
if is_authenticated:
self._post_authentication()
- def msg(self, message):
+ def msg(self, message: str) -> stem.response.ControlMessage:
"""
Sends a message to our control socket and provides back its reply.
@@ -659,7 +663,7 @@ class BaseController(object):
self.close()
raise
- def is_alive(self):
+ def is_alive(self) -> bool:
"""
Checks if our socket is currently connected. This is a pass-through for our
socket's :func:`~stem.socket.BaseSocket.is_alive` method.
@@ -669,7 +673,7 @@ class BaseController(object):
return self._socket.is_alive()
- def is_localhost(self):
+ def is_localhost(self) -> bool:
"""
Returns if the connection is for the local system or not.
@@ -680,7 +684,7 @@ class BaseController(object):
return self._socket.is_localhost()
- def connection_time(self):
+ def connection_time(self) -> float:
"""
Provides the unix timestamp for when our socket was either connected or
disconnected. That is to say, the time we connected if we're currently
@@ -694,7 +698,7 @@ class BaseController(object):
return self._socket.connection_time()
- def is_authenticated(self):
+ def is_authenticated(self) -> bool:
"""
Checks if our socket is both connected and authenticated.
@@ -704,7 +708,7 @@ class BaseController(object):
return self._is_authenticated if self.is_alive() else False
- def connect(self):
+ def connect(self) -> None:
"""
Reconnects our control socket. This is a pass-through for our socket's
:func:`~stem.socket.ControlSocket.connect` method.
@@ -714,7 +718,7 @@ class BaseController(object):
self._socket.connect()
- def close(self):
+ def close(self) -> None:
"""
Closes our socket connection. This is a pass-through for our socket's
:func:`~stem.socket.BaseSocket.close` method.
@@ -733,7 +737,7 @@ class BaseController(object):
if t.is_alive() and threading.current_thread() != t:
t.join()
- def get_socket(self):
+ def get_socket(self) -> stem.socket.ControlSocket:
"""
Provides the socket used to speak with the tor process. Communicating with
the socket directly isn't advised since it may confuse this controller.
@@ -743,7 +747,7 @@ class BaseController(object):
return self._socket
- def get_latest_heartbeat(self):
+ def get_latest_heartbeat(self) -> float:
"""
Provides the unix timestamp for when we last heard from tor. This is zero
if we've never received a message.
@@ -753,7 +757,7 @@ class BaseController(object):
return self._last_heartbeat
- def add_status_listener(self, callback, spawn = True):
+ def add_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None], spawn: bool = True) -> None:
"""
Notifies a given function when the state of our socket changes. Functions
are expected to be of the form...
@@ -783,7 +787,7 @@ class BaseController(object):
with self._status_listeners_lock:
self._status_listeners.append((callback, spawn))
- def remove_status_listener(self, callback):
+ def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
"""
Stops listener from being notified of further events.
@@ -805,13 +809,13 @@ class BaseController(object):
self._status_listeners = new_listeners
return is_changed
- def __enter__(self):
+ def __enter__(self) -> 'stem.control.BaseController':
return self
- def __exit__(self, exit_type, value, traceback):
+ def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
self.close()
- def _handle_event(self, event_message):
+ def _handle_event(self, event_message: stem.response.ControlMessage) -> None:
"""
Callback to be overwritten by subclasses for event listening. This is
notified whenever we receive an event from the control socket.
@@ -822,13 +826,13 @@ class BaseController(object):
pass
- def _connect(self):
+ def _connect(self) -> None:
self._launch_threads()
self._notify_status_listeners(State.INIT)
self._socket_connect()
self._is_authenticated = False
- def _close(self):
+ def _close(self) -> None:
# Our is_alive() state is now false. Our reader thread should already be
# awake from recv() raising a closure exception. Wake up the event thread
# too so it can end.
@@ -846,12 +850,12 @@ class BaseController(object):
self._socket_close()
- def _post_authentication(self):
+ def _post_authentication(self) -> None:
# actions to be taken after we have a newly authenticated connection
self._is_authenticated = True
- def _notify_status_listeners(self, state):
+ def _notify_status_listeners(self, state: 'stem.control.State') -> None:
"""
Informs our status listeners that a state change occurred.
@@ -895,7 +899,7 @@ class BaseController(object):
else:
listener(self, state, change_timestamp)
- def _launch_threads(self):
+ def _launch_threads(self) -> None:
"""
Initializes daemon threads. Threads can't be reused so we need to recreate
them if we're restarted.
@@ -915,7 +919,7 @@ class BaseController(object):
self._event_thread.setDaemon(True)
self._event_thread.start()
- def _reader_loop(self):
+ def _reader_loop(self) -> None:
"""
Continually pulls from the control socket, directing the messages into
queues based on their type. Controller messages come in two varieties...
@@ -944,7 +948,7 @@ class BaseController(object):
self._reply_queue.put(exc)
- def _event_loop(self):
+ def _event_loop(self) -> None:
"""
Continually pulls messages from the _event_queue and sends them to our
handle_event callback. This is done via its own thread so subclasses with a
@@ -982,7 +986,7 @@ class Controller(BaseController):
"""
@staticmethod
- def from_port(address = '127.0.0.1', port = 'default'):
+ def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
"""
Constructs a :class:`~stem.socket.ControlPort` based Controller.
@@ -1016,7 +1020,7 @@ class Controller(BaseController):
return Controller(control_port)
@staticmethod
- def from_socket_file(path = '/var/run/tor/control'):
+ def from_socket_file(path: str = '/var/run/tor/control') -> 'stem.control.Controller':
"""
Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
@@ -1030,7 +1034,7 @@ class Controller(BaseController):
control_socket = stem.socket.ControlSocketFile(path)
return Controller(control_socket)
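
Both constructors are annotated to return a Controller, so the usual connection patterns type-check as-is. A sketch, assuming a ControlPort of 9051 and the default socket path:

    import stem.control

    # over the TCP control port...
    with stem.control.Controller.from_port(address = '127.0.0.1', port = 9051) as controller:
      controller.authenticate()
      print('tor is running version %s' % controller.get_version())

    # ... or over a control socket file
    with stem.control.Controller.from_socket_file('/var/run/tor/control') as controller:
      controller.authenticate()
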
- def __init__(self, control_socket, is_authenticated = False):
+ def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
self._is_caching_enabled = True
self._request_cache = {}
self._last_newnym = 0.0
@@ -1048,14 +1052,14 @@ class Controller(BaseController):
super(Controller, self).__init__(control_socket, is_authenticated)
- def _sighup_listener(event):
+ def _sighup_listener(event: stem.response.events.Event) -> None:
if event.signal == Signal.RELOAD:
self.clear_cache()
self._notify_status_listeners(State.RESET)
self.add_event_listener(_sighup_listener, EventType.SIGNAL)
- def _confchanged_listener(event):
+ def _confchanged_listener(event: stem.response.events.Event) -> None:
if self.is_caching_enabled():
to_cache_changed = dict((k.lower(), v) for k, v in event.changed.items())
to_cache_unset = dict((k.lower(), []) for k in event.unset) # [] represents None value in cache
@@ -1070,7 +1074,7 @@ class Controller(BaseController):
self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED)
- def _address_changed_listener(event):
+ def _address_changed_listener(event: stem.response.events.Event) -> None:
if event.action in ('EXTERNAL_ADDRESS', 'DNS_USELESS'):
self._set_cache({'exit_policy': None})
self._set_cache({'address': None}, 'getinfo')
@@ -1078,11 +1082,11 @@ class Controller(BaseController):
self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER)
- def close(self):
+ def close(self) -> None:
self.clear_cache()
super(Controller, self).close()
- def authenticate(self, *args, **kwargs):
+ def authenticate(self, *args: Any, **kwargs: Any) -> None:
"""
A convenience method to authenticate the controller. This is just a
pass-through to :func:`stem.connection.authenticate`.
@@ -1091,7 +1095,7 @@ class Controller(BaseController):
import stem.connection
stem.connection.authenticate(self, *args, **kwargs)
- def reconnect(self, *args, **kwargs):
+ def reconnect(self, *args: Any, **kwargs: Any) -> None:
"""
Reconnects and authenticates to our control socket.
@@ -1108,7 +1112,7 @@ class Controller(BaseController):
self.authenticate(*args, **kwargs)
@with_default()
- def get_info(self, params, default = UNDEFINED, get_bytes = False):
+ def get_info(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, get_bytes: bool = False) -> Union[str, Dict[str, str]]:
"""
get_info(params, default = UNDEFINED, get_bytes = False)
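
The Union[str, Dict[str, str]] return mirrors GETINFO's behavior: a single parameter gives a string, several give a mapping. Sketch, assuming an authenticated controller on port 9051:

    from stem.control import Controller

    with Controller.from_port(port = 9051) as controller:
      controller.authenticate()

      version = controller.get_info('version')                    # str
      several = controller.get_info(['version', 'config-file'])   # Dict[str, str]
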
@@ -1232,7 +1236,7 @@ class Controller(BaseController):
raise
@with_default()
- def get_version(self, default = UNDEFINED):
+ def get_version(self, default: Any = UNDEFINED) -> stem.version.Version:
"""
get_version(default = UNDEFINED)
@@ -1261,7 +1265,7 @@ class Controller(BaseController):
return version
@with_default()
- def get_exit_policy(self, default = UNDEFINED):
+ def get_exit_policy(self, default: Any = UNDEFINED) -> stem.exit_policy.ExitPolicy:
"""
get_exit_policy(default = UNDEFINED)
@@ -1293,7 +1297,7 @@ class Controller(BaseController):
return policy
@with_default()
- def get_ports(self, listener_type, default = UNDEFINED):
+ def get_ports(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[int]:
"""
get_ports(listener_type, default = UNDEFINED)
@@ -1315,7 +1319,7 @@ class Controller(BaseController):
and no default was provided
"""
- def is_localhost(address):
+ def is_localhost(address: str) -> bool:
if stem.util.connection.is_valid_ipv4_address(address):
return address == '0.0.0.0' or address.startswith('127.')
elif stem.util.connection.is_valid_ipv6_address(address):
@@ -1330,7 +1334,7 @@ class Controller(BaseController):
return [port for (addr, port) in self.get_listeners(listener_type) if is_localhost(addr)]
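
Both listener helpers take the stem.control.Listener enum and hand back simple sequences. Sketch; the ports depend on your torrc:

    from stem.control import Controller, Listener

    with Controller.from_port(port = 9051) as controller:
      controller.authenticate()

      print(controller.get_ports(Listener.CONTROL))    # e.g. [9051]
      print(controller.get_listeners(Listener.SOCKS))  # e.g. [('127.0.0.1', 9050)]
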
@with_default()
- def get_listeners(self, listener_type, default = UNDEFINED):
+ def get_listeners(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[Tuple[str, int]]:
"""
get_listeners(listener_type, default = UNDEFINED)
@@ -1436,7 +1440,7 @@ class Controller(BaseController):
return listeners
@with_default()
- def get_accounting_stats(self, default = UNDEFINED):
+ def get_accounting_stats(self, default: Any = UNDEFINED) -> 'stem.control.AccountingStats':
"""
get_accounting_stats(default = UNDEFINED)
@@ -1480,7 +1484,7 @@ class Controller(BaseController):
)
@with_default()
- def get_protocolinfo(self, default = UNDEFINED):
+ def get_protocolinfo(self, default: Any = UNDEFINED) -> stem.response.protocolinfo.ProtocolInfoResponse:
"""
get_protocolinfo(default = UNDEFINED)
@@ -1503,7 +1507,7 @@ class Controller(BaseController):
return stem.connection.get_protocolinfo(self)
@with_default()
- def get_user(self, default = UNDEFINED):
+ def get_user(self, default: Any = UNDEFINED) -> str:
"""
get_user(default = UNDEFINED)
@@ -1538,7 +1542,7 @@ class Controller(BaseController):
raise ValueError("Unable to resolve tor's user" if self.is_localhost() else "Tor isn't running locally")
@with_default()
- def get_pid(self, default = UNDEFINED):
+ def get_pid(self, default: Any = UNDEFINED) -> int:
"""
get_pid(default = UNDEFINED)
@@ -1594,7 +1598,7 @@ class Controller(BaseController):
raise ValueError("Unable to resolve tor's pid" if self.is_localhost() else "Tor isn't running locally")
@with_default()
- def get_start_time(self, default = UNDEFINED):
+ def get_start_time(self, default: Any = UNDEFINED) -> float:
"""
get_start_time(default = UNDEFINED)
@@ -1644,7 +1648,7 @@ class Controller(BaseController):
raise ValueError("Unable to resolve when tor began" if self.is_localhost() else "Tor isn't running locally")
@with_default()
- def get_uptime(self, default = UNDEFINED):
+ def get_uptime(self, default: Any = UNDEFINED) -> float:
"""
get_uptime(default = UNDEFINED)
@@ -1662,7 +1666,7 @@ class Controller(BaseController):
return time.time() - self.get_start_time()
- def is_user_traffic_allowed(self):
+ def is_user_traffic_allowed(self) -> bool:
"""
Checks if we're likely to service direct user traffic. This essentially
boils down to...
@@ -1704,7 +1708,7 @@ class Controller(BaseController):
return UserTrafficAllowed(inbound_allowed, outbound_allowed)
@with_default()
- def get_microdescriptor(self, relay = None, default = UNDEFINED):
+ def get_microdescriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.microdescriptor.Microdescriptor:
"""
get_microdescriptor(relay = None, default = UNDEFINED)
@@ -1762,7 +1766,7 @@ class Controller(BaseController):
return stem.descriptor.microdescriptor.Microdescriptor(desc_content)
@with_default(yields = True)
- def get_microdescriptors(self, default = UNDEFINED):
+ def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]:
"""
get_microdescriptors(default = UNDEFINED)
@@ -1793,7 +1797,7 @@ class Controller(BaseController):
yield desc
@with_default()
- def get_server_descriptor(self, relay = None, default = UNDEFINED):
+ def get_server_descriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor:
"""
get_server_descriptor(relay = None, default = UNDEFINED)
@@ -1856,7 +1860,7 @@ class Controller(BaseController):
return stem.descriptor.server_descriptor.RelayDescriptor(desc_content)
@with_default(yields = True)
- def get_server_descriptors(self, default = UNDEFINED):
+ def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
"""
get_server_descriptors(default = UNDEFINED)
@@ -1892,7 +1896,7 @@ class Controller(BaseController):
yield desc
@with_default()
- def get_network_status(self, relay = None, default = UNDEFINED):
+ def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3:
"""
get_network_status(relay = None, default = UNDEFINED)
@@ -1951,7 +1955,7 @@ class Controller(BaseController):
return stem.descriptor.router_status_entry.RouterStatusEntryV3(desc_content)
@with_default(yields = True)
- def get_network_statuses(self, default = UNDEFINED):
+ def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
"""
get_network_statuses(default = UNDEFINED)
@@ -1988,7 +1992,7 @@ class Controller(BaseController):
yield desc
@with_default()
- def get_hidden_service_descriptor(self, address, default = UNDEFINED, servers = None, await_result = True, timeout = None):
+ def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2:
"""
get_hidden_service_descriptor(address, default = UNDEFINED, servers = None, await_result = True)
@@ -2036,10 +2040,10 @@ class Controller(BaseController):
start_time = time.time()
if await_result:
- def hs_desc_listener(event):
+ def hs_desc_listener(event: stem.response.events.Event) -> None:
hs_desc_queue.put(event)
- def hs_desc_content_listener(event):
+ def hs_desc_content_listener(event: stem.response.events.Event) -> None:
hs_desc_content_queue.put(event)
self.add_event_listener(hs_desc_listener, EventType.HS_DESC)
@@ -2084,7 +2088,7 @@ class Controller(BaseController):
if hs_desc_content_listener:
self.remove_event_listener(hs_desc_content_listener)
- def get_conf(self, param, default = UNDEFINED, multiple = False):
+ def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]:
"""
get_conf(param, default = UNDEFINED, multiple = False)
@@ -2133,7 +2137,7 @@ class Controller(BaseController):
entries = self.get_conf_map(param, default, multiple)
return _case_insensitive_lookup(entries, param, default)
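
The get_conf() return type depends on the multiple flag, while get_conf_map() always hands back a dict. Sketch; values depend on what your torrc sets:

    from stem.control import Controller

    with Controller.from_port(port = 9051) as controller:
      controller.authenticate()

      socks_port = controller.get_conf('SocksPort')                     # str
      exit_policy = controller.get_conf('ExitPolicy', multiple = True)  # Sequence[str]
      conf_map = controller.get_conf_map(['SocksPort', 'ControlPort'])  # Dict[str, Sequence[str]]
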
- def get_conf_map(self, params, default = UNDEFINED, multiple = True):
+ def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
"""
get_conf_map(params, default = UNDEFINED, multiple = True)
@@ -2251,7 +2255,7 @@ class Controller(BaseController):
else:
raise
- def _get_conf_dict_to_response(self, config_dict, default, multiple):
+ def _get_conf_dict_to_response(self, config_dict: Mapping[str, Sequence[str]], default: Any, multiple: bool) -> Dict[str, Union[str, Sequence[str]]]:
"""
Translates a dictionary of 'config key => [value1, value2...]' into the
return value of :func:`~stem.control.Controller.get_conf_map`, taking into
@@ -2273,7 +2277,7 @@ class Controller(BaseController):
return return_dict
@with_default()
- def is_set(self, param, default = UNDEFINED):
+ def is_set(self, param: str, default: Any = UNDEFINED) -> bool:
"""
is_set(param, default = UNDEFINED)
@@ -2293,7 +2297,7 @@ class Controller(BaseController):
return param in self._get_custom_options()
- def _get_custom_options(self):
+ def _get_custom_options(self) -> Dict[str, str]:
result = self._get_cache('get_custom_options')
if not result:
@@ -2320,7 +2324,7 @@ class Controller(BaseController):
return result
- def set_conf(self, param, value):
+ def set_conf(self, param: str, value: Union[str, Sequence[str]]) -> None:
"""
Changes the value of a tor configuration option. Our value can be any of
the following...
@@ -2342,7 +2346,7 @@ class Controller(BaseController):
self.set_options({param: value}, False)
- def reset_conf(self, *params):
+ def reset_conf(self, *params: str) -> None:
"""
Reverts one or more parameters to their default values.
@@ -2357,7 +2361,7 @@ class Controller(BaseController):
self.set_options(dict([(entry, None) for entry in params]), True)
- def set_options(self, params, reset = False):
+ def set_options(self, params: Union[Mapping[str, Union[str, Sequence[str]]], Sequence[Tuple[str, Union[str, Sequence[str]]]]], reset: bool = False) -> None:
"""
Changes multiple tor configuration options via either a SETCONF or
RESETCONF query. Both behave identically unless our value is None, in which
@@ -2439,7 +2443,7 @@ class Controller(BaseController):
raise stem.ProtocolError('Returned unexpected status code: %s' % response.code)
@with_default()
- def get_hidden_service_conf(self, default = UNDEFINED):
+ def get_hidden_service_conf(self, default: Any = UNDEFINED) -> Dict[str, Any]:
"""
get_hidden_service_conf(default = UNDEFINED)
@@ -2534,7 +2538,7 @@ class Controller(BaseController):
self._set_cache({'hidden_service_conf': service_dir_map})
return service_dir_map
- def set_hidden_service_conf(self, conf):
+ def set_hidden_service_conf(self, conf: Mapping[str, Any]) -> None:
"""
Update all the configured hidden services from a dictionary having
the same format as
@@ -2599,7 +2603,7 @@ class Controller(BaseController):
self.set_options(hidden_service_options)
- def create_hidden_service(self, path, port, target_address = None, target_port = None, auth_type = None, client_names = None):
+ def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput':
"""
Create a new hidden service. If the directory is already present, a
new port is added.
@@ -2717,7 +2721,7 @@ class Controller(BaseController):
config = conf,
)
- def remove_hidden_service(self, path, port = None):
+ def remove_hidden_service(self, path: str, port: Optional[int] = None) -> bool:
"""
Discontinues a given hidden service.
@@ -2759,7 +2763,7 @@ class Controller(BaseController):
return True
@with_default()
- def list_ephemeral_hidden_services(self, default = UNDEFINED, our_services = True, detached = False):
+ def list_ephemeral_hidden_services(self, default: Any = UNDEFINED, our_services: bool = True, detached: bool = False) -> Sequence[str]:
"""
list_ephemeral_hidden_services(default = UNDEFINED, our_services = True, detached = False)
@@ -2799,7 +2803,7 @@ class Controller(BaseController):
return [r for r in result if r] # drop any empty responses (GETINFO is blank if unset)
- def create_ephemeral_hidden_service(self, ports, key_type = 'NEW', key_content = 'BEST', discard_key = False, detached = False, await_publication = False, timeout = None, basic_auth = None, max_streams = None):
+ def create_ephemeral_hidden_service(self, ports: Union[int, Sequence[int], Mapping[int, str]], key_type: str = 'NEW', key_content: str = 'BEST', discard_key: bool = False, detached: bool = False, await_publication: bool = False, timeout: Optional[float] = None, basic_auth: Optional[Mapping[str, str]] = None, max_streams: Optional[int] = None) -> stem.response.add_onion.AddOnionResponse:
"""
Creates a new hidden service. Unlike
:func:`~stem.control.Controller.create_hidden_service` this style of
@@ -2905,7 +2909,7 @@ class Controller(BaseController):
start_time = time.time()
if await_publication:
- def hs_desc_listener(event):
+ def hs_desc_listener(event: stem.response.events.Event) -> None:
hs_desc_queue.put(event)
self.add_event_listener(hs_desc_listener, EventType.HS_DESC)
@@ -2983,7 +2987,7 @@ class Controller(BaseController):
return response
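
The AddOnionResponse return carries the service id we need for later removal. Sketch, assuming a ControlPort of 9051 and something listening locally on 8080:

    from stem.control import Controller

    with Controller.from_port(port = 9051) as controller:
      controller.authenticate()

      # expose a local service on 127.0.0.1:8080 as an ephemeral onion service on port 80
      response = controller.create_ephemeral_hidden_service({80: '127.0.0.1:8080'}, await_publication = True)
      print('service available at %s.onion' % response.service_id)

      controller.remove_ephemeral_hidden_service(response.service_id)
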
- def remove_ephemeral_hidden_service(self, service_id):
+ def remove_ephemeral_hidden_service(self, service_id: str) -> bool:
"""
Discontinues a given hidden service that was created with
:func:`~stem.control.Controller.create_ephemeral_hidden_service`.
@@ -3008,7 +3012,7 @@ class Controller(BaseController):
else:
raise stem.ProtocolError('DEL_ONION returned unexpected response code: %s' % response.code)
- def add_event_listener(self, listener, *events):
+ def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None:
"""
Directs further tor controller events to a given function. The function is
expected to take a single argument, which is a
@@ -3066,7 +3070,7 @@ class Controller(BaseController):
if failed_events:
raise stem.ProtocolError('SETEVENTS rejected %s' % ', '.join(failed_events))
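
Listeners are plain callables taking the event; here's a sketch annotated with the narrower BandwidthEvent for clarity, assuming a ControlPort of 9051:

    import time

    from stem.control import Controller, EventType
    from stem.response.events import BandwidthEvent

    def bw_listener(event: BandwidthEvent) -> None:
      # BW events report bytes read and written during the last second
      print('read: %i bytes, written: %i bytes' % (event.read, event.written))

    with Controller.from_port(port = 9051) as controller:
      controller.authenticate()
      controller.add_event_listener(bw_listener, EventType.BW)
      time.sleep(10)
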
- def remove_event_listener(self, listener):
+ def remove_event_listener(self, listener: Callable[[stem.response.events.Event], None]) -> None:
"""
Stops a listener from being notified of further tor events.
@@ -3092,7 +3096,7 @@ class Controller(BaseController):
if not response.is_ok():
raise stem.ProtocolError('SETEVENTS received unexpected response\n%s' % response)
- def _get_cache(self, param, namespace = None):
+ def _get_cache(self, param: str, namespace: Optional[str] = None) -> Any:
"""
Queries our request cache for the given key.
@@ -3109,7 +3113,7 @@ class Controller(BaseController):
cache_key = '%s.%s' % (namespace, param) if namespace else param
return self._request_cache.get(cache_key, None)
- def _get_cache_map(self, params, namespace = None):
+ def _get_cache_map(self, params: Sequence[str], namespace: Optional[str] = None) -> Dict[str, Any]:
"""
Queries our request cache for multiple entries.
@@ -3131,7 +3135,7 @@ class Controller(BaseController):
return cached_values
- def _set_cache(self, params, namespace = None):
+ def _set_cache(self, params: Mapping[str, Any], namespace: Optional[str] = None) -> None:
"""
Sets the given request cache entries. If the new cache value is **None**
then it is removed from our cache.
@@ -3173,7 +3177,7 @@ class Controller(BaseController):
else:
self._request_cache[cache_key] = value
- def _confchanged_cache_invalidation(self, params):
+ def _confchanged_cache_invalidation(self, params: Mapping[str, Any]) -> None:
"""
Drops dependent portions of the cache when configuration changes.
@@ -3197,7 +3201,7 @@ class Controller(BaseController):
self._set_cache({'exit_policy': None}) # numerous options can change our policy
- def is_caching_enabled(self):
+ def is_caching_enabled(self) -> bool:
"""
**True** if caching has been enabled, **False** otherwise.
@@ -3206,7 +3210,7 @@ class Controller(BaseController):
return self._is_caching_enabled
- def set_caching(self, enabled):
+ def set_caching(self, enabled: bool) -> None:
"""
Enables or disables caching of information retrieved from tor.
@@ -3218,7 +3222,7 @@ class Controller(BaseController):
if not self._is_caching_enabled:
self.clear_cache()
- def clear_cache(self):
+ def clear_cache(self) -> None:
"""
Drops any cached results.
"""
@@ -3227,7 +3231,7 @@ class Controller(BaseController):
self._request_cache = {}
self._last_newnym = 0.0
- def load_conf(self, configtext):
+ def load_conf(self, configtext: str) -> None:
"""
Sends the configuration text to Tor and loads it as if it has been read from
the torrc.
@@ -3247,7 +3251,7 @@ class Controller(BaseController):
elif not response.is_ok():
raise stem.ProtocolError('+LOADCONF Received unexpected response\n%s' % str(response))
- def save_conf(self, force = False):
+ def save_conf(self, force: bool = False) -> None:
"""
Saves the current configuration options into the active torrc file.
@@ -3273,7 +3277,7 @@ class Controller(BaseController):
else:
raise stem.ProtocolError('SAVECONF returned unexpected response code')
- def is_feature_enabled(self, feature):
+ def is_feature_enabled(self, feature: str) -> bool:
"""
Checks if a control connection feature is enabled. These features can be
enabled using :func:`~stem.control.Controller.enable_feature`.
@@ -3290,7 +3294,7 @@ class Controller(BaseController):
return feature in self._enabled_features
- def enable_feature(self, features):
+ def enable_feature(self, features: Union[str, Sequence[str]]) -> None:
"""
Enables features that are disabled by default to maintain backward
compatibility. Once enabled, a feature cannot be disabled and a new
@@ -3324,7 +3328,7 @@ class Controller(BaseController):
self._enabled_features += [entry.upper() for entry in features]
@with_default()
- def get_circuit(self, circuit_id, default = UNDEFINED):
+ def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent:
"""
get_circuit(circuit_id, default = UNDEFINED)
@@ -3349,7 +3353,7 @@ class Controller(BaseController):
raise ValueError("Tor currently does not have a circuit with the id of '%s'" % circuit_id)
@with_default()
- def get_circuits(self, default = UNDEFINED):
+ def get_circuits(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.CircuitEvent]:
"""
get_circuits(default = UNDEFINED)
@@ -3372,7 +3376,7 @@ class Controller(BaseController):
return circuits
- def new_circuit(self, path = None, purpose = 'general', await_build = False, timeout = None):
+ def new_circuit(self, path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str:
"""
Requests a new circuit. If the path isn't provided, one is automatically
selected.
@@ -3380,7 +3384,7 @@ class Controller(BaseController):
.. versionchanged:: 1.7.0
Added the timeout argument.
- :param list,str path: one or more relays to make a circuit through
+ :param str,list path: one or more relays to make a circuit through
:param str purpose: 'general' or 'controller'
:param bool await_build: blocks until the circuit is built if **True**
:param float timeout: seconds to wait when **await_build** is **True**
@@ -3394,7 +3398,7 @@ class Controller(BaseController):
return self.extend_circuit('0', path, purpose, await_build, timeout)
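
Both new_circuit() and extend_circuit() hand the new circuit id back as a str. Sketch; building a circuit needs a bootstrapped tor:

    from stem.control import Controller

    with Controller.from_port(port = 9051) as controller:
      controller.authenticate()

      circuit_id = controller.new_circuit(await_build = True)  # str, e.g. '42'
      controller.close_circuit(circuit_id)
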
- def extend_circuit(self, circuit_id = '0', path = None, purpose = 'general', await_build = False, timeout = None):
+ def extend_circuit(self, circuit_id: str = '0', path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str:
"""
Either requests the creation of a new circuit or extends an existing one.
@@ -3418,7 +3422,7 @@ class Controller(BaseController):
Added the timeout argument.
:param str circuit_id: id of a circuit to be extended
- :param list,str path: one or more relays to make a circuit through, this is
+ :param str,list path: one or more relays to make a circuit through, this is
required if the circuit id is non-zero
:param str purpose: 'general' or 'controller'
:param bool await_build: blocks until the circuit is built if **True**
@@ -3442,7 +3446,7 @@ class Controller(BaseController):
start_time = time.time()
if await_build:
- def circ_listener(event):
+ def circ_listener(event: stem.response.events.Event) -> None:
circ_queue.put(event)
self.add_event_listener(circ_listener, EventType.CIRC)
@@ -3489,7 +3493,7 @@ class Controller(BaseController):
if circ_listener:
self.remove_event_listener(circ_listener)
- def repurpose_circuit(self, circuit_id, purpose):
+ def repurpose_circuit(self, circuit_id: str, purpose: str) -> None:
"""
Changes a circuit's purpose. Currently, two purposes are recognized...
* general
@@ -3510,7 +3514,7 @@ class Controller(BaseController):
else:
raise stem.ProtocolError('SETCIRCUITPURPOSE returned unexpected response code: %s' % response.code)
- def close_circuit(self, circuit_id, flag = ''):
+ def close_circuit(self, circuit_id: str, flag: str = '') -> None:
"""
Closes the specified circuit.
@@ -3518,8 +3522,9 @@ class Controller(BaseController):
:param str flag: optional value to modify closing, the only flag available
is 'IfUnused' which will not close the circuit unless it is unused
- :raises: :class:`stem.InvalidArguments` if the circuit is unknown
- :raises: :class:`stem.InvalidRequest` if not enough information is provided
+ :raises:
+ * :class:`stem.InvalidArguments` if the circuit is unknown
+ * :class:`stem.InvalidRequest` if not enough information is provided
"""
response = self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag))
@@ -3534,7 +3539,7 @@ class Controller(BaseController):
raise stem.ProtocolError('CLOSECIRCUIT returned unexpected response code: %s' % response.code)
@with_default()
- def get_streams(self, default = UNDEFINED):
+ def get_streams(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.StreamEvent]:
"""
get_streams(default = UNDEFINED)
@@ -3558,7 +3563,7 @@ class Controller(BaseController):
return streams
- def attach_stream(self, stream_id, circuit_id, exiting_hop = None):
+ def attach_stream(self, stream_id: str, circuit_id: str, exiting_hop: Optional[int] = None) -> None:
"""
Attaches a stream to a circuit.
@@ -3593,7 +3598,7 @@ class Controller(BaseController):
else:
raise stem.ProtocolError('ATTACHSTREAM returned unexpected response code: %s' % response.code)
- def close_stream(self, stream_id, reason = stem.RelayEndReason.MISC, flag = ''):
+ def close_stream(self, stream_id: str, reason: stem.RelayEndReason = stem.RelayEndReason.MISC, flag: str = '') -> None:
"""
Closes the specified stream.
@@ -3622,7 +3627,7 @@ class Controller(BaseController):
else:
raise stem.ProtocolError('CLOSESTREAM returned unexpected response code: %s' % response.code)
- def signal(self, signal):
+ def signal(self, signal: stem.Signal) -> None:
"""
Sends a signal to the Tor client.
@@ -3645,7 +3650,7 @@ class Controller(BaseController):
raise stem.ProtocolError('SIGNAL response contained unrecognized status code: %s' % response.code)
- def is_newnym_available(self):
+ def is_newnym_available(self) -> bool:
"""
Indicates if tor would currently accept a NEWNYM signal. This can only
account for signals sent via this controller.
@@ -3661,7 +3666,7 @@ class Controller(BaseController):
else:
return False
- def get_newnym_wait(self):
+ def get_newnym_wait(self) -> float:
"""
Provides the number of seconds until a NEWNYM signal would be respected.
This can only account for signals sent via this controller.
@@ -3675,7 +3680,7 @@ class Controller(BaseController):
return max(0.0, self._last_newnym + 10 - time.time())
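
Together with signal(), these make rate-limited NEWNYM requests easy to express. Sketch, assuming a ControlPort of 9051:

    import time

    from stem import Signal
    from stem.control import Controller

    with Controller.from_port(port = 9051) as controller:
      controller.authenticate()

      if not controller.is_newnym_available():
        time.sleep(controller.get_newnym_wait())

      controller.signal(Signal.NEWNYM)
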
@with_default()
- def get_effective_rate(self, default = UNDEFINED, burst = False):
+ def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int:
"""
get_effective_rate(default = UNDEFINED, burst = False)
@@ -3714,7 +3719,7 @@ class Controller(BaseController):
return value
- def map_address(self, mapping):
+ def map_address(self, mapping: Mapping[str, str]) -> Dict[str, str]:
"""
Map addresses to replacement addresses. Tor replaces subseqent connections
to the original addresses with the replacement addresses.
@@ -3726,11 +3731,11 @@ class Controller(BaseController):
:param dict mapping: mapping of original addresses to replacement addresses
+ :returns: **dict** with 'original -> replacement' address mappings
+
:raises:
* :class:`stem.InvalidRequest` if the addresses are malformed
* :class:`stem.OperationFailed` if Tor couldn't fulfill the request
-
- :returns: **dict** with 'original -> replacement' address mappings
"""
mapaddress_arg = ' '.join(['%s=%s' % (k, v) for (k, v) in list(mapping.items())])
@@ -3739,7 +3744,7 @@ class Controller(BaseController):
return response.entries
- def drop_guards(self):
+ def drop_guards(self) -> None:
"""
Drops our present guard nodes and picks a new set.
@@ -3750,7 +3755,7 @@ class Controller(BaseController):
self.msg('DROPGUARDS')
- def _post_authentication(self):
+ def _post_authentication(self) -> None:
super(Controller, self)._post_authentication()
# try to re-attach event listeners to the new instance
@@ -3788,7 +3793,7 @@ class Controller(BaseController):
else:
log.warn('We were unable assert ownership of tor through TAKEOWNERSHIP, despite being configured to be the owning process through __OwningControllerProcess. (%s)' % response)
- def _handle_event(self, event_message):
+ def _handle_event(self, event_message: stem.response.ControlMessage) -> None:
try:
stem.response.convert('EVENT', event_message)
event_type = event_message.type
@@ -3805,7 +3810,7 @@ class Controller(BaseController):
except Exception as exc:
log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event_message))
- def _attach_listeners(self):
+ def _attach_listeners(self) -> Tuple[Sequence[str], Sequence[str]]:
"""
Attempts to subscribe to the self._event_listeners events from tor. This is
a no-op if we're not currently authenticated.
@@ -3849,7 +3854,7 @@ class Controller(BaseController):
return (set_events, failed_events)
-def _parse_circ_path(path):
+def _parse_circ_path(path: str) -> Sequence[Tuple[str, str]]:
"""
Parses a circuit path as a list of **(fingerprint, nickname)** tuples. Tor
circuit paths are defined as being of the form...
@@ -3892,7 +3897,7 @@ def _parse_circ_path(path):
return []
-def _parse_circ_entry(entry):
+def _parse_circ_entry(entry: str) -> Tuple[str, str]:
"""
Parses a single relay's 'LongName' or 'ServerID'. See the
:func:`~stem.control._parse_circ_path` function for more information.
@@ -3930,7 +3935,7 @@ def _parse_circ_entry(entry):
@with_default()
-def _case_insensitive_lookup(entries, key, default = UNDEFINED):
+def _case_insensitive_lookup(entries: Union[Sequence[str], Mapping[str, Any]], key: str, default: Any = UNDEFINED) -> Any:
"""
Makes a case insensitive lookup within a list or dictionary, providing the
first matching entry that we come across.
@@ -3957,7 +3962,7 @@ def _case_insensitive_lookup(entries, key, default = UNDEFINED):
raise ValueError("key '%s' doesn't exist in dict: %s" % (key, entries))
-def _get_with_timeout(event_queue, timeout, start_time):
+def _get_with_timeout(event_queue: queue.Queue, timeout: float, start_time: float) -> Any:
"""
Pulls an item from a queue with a given timeout.
"""
diff --git a/stem/descriptor/__init__.py b/stem/descriptor/__init__.py
index ff273405..9c769749 100644
--- a/stem/descriptor/__init__.py
+++ b/stem/descriptor/__init__.py
@@ -120,6 +120,8 @@ import stem.util.enum
import stem.util.str_tools
import stem.util.system
+from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type
+
__all__ = [
'bandwidth_file',
'certificate',
@@ -192,7 +194,7 @@ class _Compression(object):
.. versionadded:: 1.8.0
"""
- def __init__(self, name, module, encoding, extension, decompression_func):
+ def __init__(self, name: str, module: Optional[str], encoding: str, extension: str, decompression_func: Callable[[Any, bytes], bytes]) -> None:
if module is None:
self._module = None
self.available = True
@@ -222,7 +224,7 @@ class _Compression(object):
self._module_name = module
self._decompression_func = decompression_func
- def decompress(self, content):
+ def decompress(self, content: bytes) -> bytes:
"""
Decompresses the given content via this method.
@@ -250,11 +252,11 @@ class _Compression(object):
except Exception as exc:
raise IOError('Failed to decompress as %s: %s' % (self, exc))
- def __str__(self):
+ def __str__(self) -> str:
return self._name
-def _zstd_decompress(module, content):
+def _zstd_decompress(module: Any, content: bytes) -> bytes:
output_buffer = io.BytesIO()
with module.ZstdDecompressor().write_to(output_buffer) as decompressor:
@@ -286,7 +288,7 @@ class TypeAnnotation(collections.namedtuple('TypeAnnotation', ['name', 'major_ve
:var int minor_version: minor version number
"""
- def __str__(self):
+ def __str__(self) -> str:
return '@type %s %s.%s' % (self.name, self.major_version, self.minor_version)
@@ -302,7 +304,7 @@ class SigningKey(collections.namedtuple('SigningKey', ['private', 'public', 'pub
"""
-def parse_file(descriptor_file, descriptor_type = None, validate = False, document_handler = DocumentHandler.ENTRIES, normalize_newlines = None, **kwargs):
+def parse_file(descriptor_file: BinaryIO, descriptor_type: Optional[str] = None, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, normalize_newlines: Optional[bool] = None, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
"""
Simple function to read the descriptor contents from a file, providing an
iterator for its :class:`~stem.descriptor.__init__.Descriptor` contents.
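
parse_file() keeps its iterator behavior, now spelled out as Iterator[Descriptor]. Sketch; the path is just an example of a relay's descriptor cache:

    import stem.descriptor

    with open('/var/lib/tor/cached-descriptors', 'rb') as descriptor_file:
      for desc in stem.descriptor.parse_file(descriptor_file, 'server-descriptor 1.0'):
        print(desc.fingerprint)
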
@@ -405,7 +407,7 @@ def parse_file(descriptor_file, descriptor_type = None, validate = False, docume
descriptor_path = getattr(descriptor_file, 'name', None)
filename = '<undefined>' if descriptor_path is None else os.path.basename(descriptor_file.name)
- def parse(descriptor_file):
+ def parse(descriptor_file: BinaryIO) -> Iterator['stem.descriptor.Descriptor']:
if normalize_newlines:
descriptor_file = NewlineNormalizer(descriptor_file)
@@ -448,20 +450,20 @@ def parse_file(descriptor_file, descriptor_type = None, validate = False, docume
yield desc
-def _parse_file_for_path(descriptor_file, *args, **kwargs):
+def _parse_file_for_path(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
with open(descriptor_file, 'rb') as desc_file:
for desc in parse_file(desc_file, *args, **kwargs):
yield desc
-def _parse_file_for_tar_path(descriptor_file, *args, **kwargs):
+def _parse_file_for_tar_path(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
with tarfile.open(descriptor_file) as tar_file:
for desc in parse_file(tar_file, *args, **kwargs):
desc._set_path(os.path.abspath(descriptor_file))
yield desc
-def _parse_file_for_tarfile(descriptor_file, *args, **kwargs):
+def _parse_file_for_tarfile(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
for tar_entry in descriptor_file:
if tar_entry.isfile():
entry = descriptor_file.extractfile(tar_entry)
@@ -477,7 +479,7 @@ def _parse_file_for_tarfile(descriptor_file, *args, **kwargs):
entry.close()
-def _parse_metrics_file(descriptor_type, major_version, minor_version, descriptor_file, validate, document_handler, **kwargs):
+def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], major_version: int, minor_version: int, descriptor_file: BinaryIO, validate: bool, document_handler: 'stem.descriptor.DocumentHandler', **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
# Parses descriptor files from metrics, yielding individual descriptors. This
# throws a TypeError if the descriptor_type or version isn't recognized.
@@ -547,7 +549,7 @@ def _parse_metrics_file(descriptor_type, major_version, minor_version, descripto
raise TypeError("Unrecognized metrics descriptor format. type: '%s', version: '%i.%i'" % (descriptor_type, major_version, minor_version))
-def _descriptor_content(attr = None, exclude = (), header_template = (), footer_template = ()):
+def _descriptor_content(attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), header_template: Sequence[str] = (), footer_template: Sequence[str] = ()) -> bytes:
"""
Constructs a minimal descriptor with the given attributes. The content we
provide back is of the form...
@@ -619,28 +621,28 @@ def _descriptor_content(attr = None, exclude = (), header_template = (), footer_
return stem.util.str_tools._to_bytes('\n'.join(header_content + remainder + footer_content))
-def _value(line, entries):
+def _value(line: str, entries: Dict[str, Sequence[str]]) -> str:
return entries[line][0][0]
-def _values(line, entries):
+def _values(line: str, entries: Dict[str, Sequence[str]]) -> Sequence[str]:
return [entry[0] for entry in entries[line]]
-def _parse_simple_line(keyword, attribute, func = None):
- def _parse(descriptor, entries):
+def _parse_simple_line(keyword: str, attribute: str, func: Optional[Callable[[str], str]] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
setattr(descriptor, attribute, func(value) if func else value)
return _parse
-def _parse_if_present(keyword, attribute):
+def _parse_if_present(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
return lambda descriptor, entries: setattr(descriptor, attribute, keyword in entries)
-def _parse_bytes_line(keyword, attribute):
- def _parse(descriptor, entries):
+def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
line_match = re.search(stem.util.str_tools._to_bytes('^(opt )?%s(?:[%s]+(.*))?$' % (keyword, WHITESPACE)), descriptor.get_bytes(), re.MULTILINE)
result = None
@@ -653,8 +655,8 @@ def _parse_bytes_line(keyword, attribute):
return _parse
-def _parse_int_line(keyword, attribute, allow_negative = True):
- def _parse(descriptor, entries):
+def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
try:
@@ -670,10 +672,10 @@ def _parse_int_line(keyword, attribute, allow_negative = True):
return _parse
-def _parse_timestamp_line(keyword, attribute):
+def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
# "<keyword>" YYYY-MM-DD HH:MM:SS
- def _parse(descriptor, entries):
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
try:
@@ -684,10 +686,10 @@ def _parse_timestamp_line(keyword, attribute):
return _parse
-def _parse_forty_character_hex(keyword, attribute):
+def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
# format of fingerprints, sha1 digests, etc
- def _parse(descriptor, entries):
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
if not stem.util.tor_tools.is_hex_digits(value, 40):
@@ -698,8 +700,8 @@ def _parse_forty_character_hex(keyword, attribute):
return _parse
-def _parse_protocol_line(keyword, attribute):
- def _parse(descriptor, entries):
+def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# parses 'protocol' entries like: Cons=1-2 Desc=1-2 DirCache=1 HSDir=1
value = _value(keyword, entries)
@@ -729,8 +731,8 @@ def _parse_protocol_line(keyword, attribute):
return _parse
-def _parse_key_block(keyword, attribute, expected_block_type, value_attribute = None):
- def _parse(descriptor, entries):
+def _parse_key_block(keyword: str, attribute: str, expected_block_type: str, value_attribute: Optional[str] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value, block_type, block_contents = entries[keyword][0]
if not block_contents or block_type != expected_block_type:
@@ -744,7 +746,7 @@ def _parse_key_block(keyword, attribute, expected_block_type, value_attribute =
return _parse
-def _mappings_for(keyword, value, require_value = False, divider = ' '):
+def _mappings_for(keyword: str, value: str, require_value: bool = False, divider: str = ' ') -> Iterator[Tuple[str, str]]:
"""
Parses an attribute as a series of 'key=value' mappings. Unlike _parse_*
functions this is a helper, returning the attribute value rather than setting
@@ -777,7 +779,7 @@ def _mappings_for(keyword, value, require_value = False, divider = ' '):
yield k, v
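
For reference, a simplified stand-in shows the Iterator[Tuple[str, str]] shape of this helper; it's only an illustration, the stem implementation also validates malformed entries:

    from typing import Iterator, Tuple

    def mappings_for(value: str, divider: str = ' ') -> Iterator[Tuple[str, str]]:
      # yields (key, value) pairs from entries like 'Cons=1-2 Desc=1-2 DirCache=1 HSDir=1'
      for entry in filter(None, value.split(divider)):
        key, _, val = entry.partition('=')
        yield key, val

    print(dict(mappings_for('Cons=1-2 Desc=1-2 DirCache=1 HSDir=1')))
    # {'Cons': '1-2', 'Desc': '1-2', 'DirCache': '1', 'HSDir': '1'}
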
-def _copy(default):
+def _copy(default: Any) -> Any:
if default is None or isinstance(default, (bool, stem.exit_policy.ExitPolicy)):
return default # immutable
elif default in EMPTY_COLLECTION:
@@ -786,7 +788,7 @@ def _copy(default):
return copy.copy(default)
-def _encode_digest(hash_value, encoding):
+def _encode_digest(hash_value: bytes, encoding: 'stem.descriptor.DigestEncoding') -> str:
"""
Encodes a hash value with the given HashEncoding.
"""
diff --git a/stem/descriptor/bandwidth_file.py b/stem/descriptor/bandwidth_file.py
index 3cf20595..49df3173 100644
--- a/stem/descriptor/bandwidth_file.py
+++ b/stem/descriptor/bandwidth_file.py
@@ -21,6 +21,8 @@ import time
import stem.util.str_tools
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type
+
from stem.descriptor import (
_mappings_for,
Descriptor,
@@ -50,7 +52,7 @@ class RecentStats(object):
:var RelayFailures relay_failures: number of relays we failed to measure
"""
- def __init__(self):
+ def __init__(self) -> None:
self.consensus_count = None
self.prioritized_relays = None
self.prioritized_relay_lists = None
@@ -73,7 +75,7 @@ class RelayFailures(object):
by default)
"""
- def __init__(self):
+ def __init__(self) -> None:
self.no_measurement = None
self.insuffient_period = None
self.insufficient_measurements = None
@@ -83,22 +85,22 @@ class RelayFailures(object):
# Converts header attributes to a given type. Malformed fields should be
# ignored according to the spec.
-def _str(val):
+def _str(val: str) -> str:
return val # already a str
-def _int(val):
+def _int(val: str) -> Optional[int]:
return int(val) if (val and val.isdigit()) else None
-def _date(val):
+def _date(val: str) -> Optional[datetime.datetime]:
try:
return stem.util.str_tools._parse_iso_timestamp(val)
except ValueError:
return None # not an iso formatted date
-def _csv(val):
+def _csv(val: str) -> Optional[Sequence[str]]:
return list(map(lambda v: v.strip(), val.split(','))) if val is not None else None
@@ -150,7 +152,7 @@ HEADER_DEFAULT = {
}
-def _parse_file(descriptor_file, validate = False, **kwargs):
+def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']:
"""
Iterates over the bandwidth authority metrics in a file.
@@ -169,7 +171,7 @@ def _parse_file(descriptor_file, validate = False, **kwargs):
yield BandwidthFile(descriptor_file.read(), validate, **kwargs)
-def _parse_header(descriptor, entries):
+def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
header = collections.OrderedDict()
content = io.BytesIO(descriptor.get_bytes())
@@ -214,7 +216,7 @@ def _parse_header(descriptor, entries):
raise ValueError("The 'version' header must be in the second position")
-def _parse_timestamp(descriptor, entries):
+def _parse_timestamp(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
first_line = io.BytesIO(descriptor.get_bytes()).readline().strip()
if first_line.isdigit():
@@ -223,7 +225,7 @@ def _parse_timestamp(descriptor, entries):
raise ValueError("First line should be a unix timestamp, but was '%s'" % first_line)
-def _parse_body(descriptor, entries):
+def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# In version 1.0.0 the body is everything after the first line. Otherwise
# it's everything after the header's divider.
@@ -301,7 +303,7 @@ class BandwidthFile(Descriptor):
ATTRIBUTES.update(dict([(k, (None, _parse_header)) for k in HEADER_ATTR.keys()]))
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.bandwidth_file.BandwidthFile'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
"""
Creates descriptor content with the given attributes. This descriptor type
differs somewhat from others and treats our attr/exclude attributes as
@@ -352,7 +354,7 @@ class BandwidthFile(Descriptor):
return b'\n'.join(lines)
- def __init__(self, raw_content, validate = False):
+ def __init__(self, raw_content: bytes, validate: bool = False) -> None:
super(BandwidthFile, self).__init__(raw_content, lazy_load = not validate)
if validate:
diff --git a/stem/descriptor/certificate.py b/stem/descriptor/certificate.py
index 0522b883..6956a60f 100644
--- a/stem/descriptor/certificate.py
+++ b/stem/descriptor/certificate.py
@@ -64,6 +64,7 @@ import stem.util.enum
import stem.util.str_tools
from stem.client.datatype import CertType, Field, Size, split
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
ED25519_KEY_LENGTH = 32
ED25519_HEADER_LENGTH = 40
@@ -88,7 +89,7 @@ class Ed25519Extension(Field):
:var bytes data: data the extension concerns
"""
- def __init__(self, ext_type, flag_val, data):
+ def __init__(self, ext_type: 'stem.descriptor.certificate.ExtensionType', flag_val: int, data: bytes) -> None:
self.type = ext_type
self.flags = []
self.flag_int = flag_val if flag_val else 0
@@ -104,7 +105,7 @@ class Ed25519Extension(Field):
if ext_type == ExtensionType.HAS_SIGNING_KEY and len(data) != 32:
raise ValueError('Ed25519 HAS_SIGNING_KEY extension must be 32 bytes, but was %i.' % len(data))
- def pack(self):
+ def pack(self) -> bytes:
encoded = bytearray()
encoded += Size.SHORT.pack(len(self.data))
encoded += Size.CHAR.pack(self.type)
@@ -113,7 +114,7 @@ class Ed25519Extension(Field):
return bytes(encoded)
@staticmethod
- def pop(content):
+ def pop(content: bytes) -> Tuple['stem.descriptor.certificate.Ed25519Extension', bytes]:
if len(content) < 4:
raise ValueError('Ed25519 extension is missing header fields')
@@ -127,7 +128,7 @@ class Ed25519Extension(Field):
return Ed25519Extension(ext_type, flags, data), content
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'type', 'flag_int', 'data', cache = True)
@@ -138,11 +139,11 @@ class Ed25519Certificate(object):
:var int version: certificate format version
"""
- def __init__(self, version):
+ def __init__(self, version: int) -> None:
self.version = version
@staticmethod
- def unpack(content):
+ def unpack(content: bytes) -> 'stem.descriptor.certificate.Ed25519Certificate':
"""
Parses a byte encoded ED25519 certificate.
@@ -162,7 +163,7 @@ class Ed25519Certificate(object):
raise ValueError('Ed25519 certificate is version %i. Parser presently only supports version 1.' % version)
@staticmethod
- def from_base64(content):
+ def from_base64(content: str) -> 'stem.descriptor.certificate.Ed25519Certificate':
"""
Parses a base64 encoded ED25519 certificate.
@@ -189,7 +190,7 @@ class Ed25519Certificate(object):
except (TypeError, binascii.Error) as exc:
raise ValueError("Ed25519 certificate wasn't propoerly base64 encoded (%s):\n%s" % (exc, content))
- def pack(self):
+ def pack(self) -> bytes:
"""
Encoded byte representation of our certificate.
@@ -198,7 +199,7 @@ class Ed25519Certificate(object):
raise NotImplementedError('Certificate encoding has not been implemented for %s' % type(self).__name__)
- def to_base64(self, pem = False):
+ def to_base64(self, pem: bool = False) -> str:
"""
Base64 encoded certificate data.
@@ -206,7 +207,7 @@ class Ed25519Certificate(object):
<https://en.wikipedia.org/wiki/Privacy-Enhanced_Mail>`_, for more
information see `RFC 7468 <https://tools.ietf.org/html/rfc7468>`_
- :returns: **unicode** for our encoded certificate representation
+ :returns: **str** for our encoded certificate representation
"""
encoded = b'\n'.join(stem.util.str_tools._split_by_length(base64.b64encode(self.pack()), 64))
@@ -217,7 +218,7 @@ class Ed25519Certificate(object):
return stem.util.str_tools._to_unicode(encoded)
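
These round-trip with from_base64(), so a small re-wrapping helper type-checks cleanly. Sketch; cert_text stands in for the base64 blob of an 'identity-ed25519' block pulled from a descriptor:

    from stem.descriptor.certificate import Ed25519Certificate

    def reencode(cert_text: str) -> str:
      # parse a base64 certificate blob and hand back its PEM-wrapped form
      return Ed25519Certificate.from_base64(cert_text).to_base64(pem = True)
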
@staticmethod
- def _from_descriptor(keyword, attribute):
+ def _from_descriptor(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
def _parse(descriptor, entries):
value, block_type, block_contents = entries[keyword][0]
@@ -228,7 +229,7 @@ class Ed25519Certificate(object):
return _parse
- def __str__(self):
+ def __str__(self) -> str:
return self.to_base64(pem = True)
@@ -252,7 +253,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
is unavailable
"""
- def __init__(self, cert_type = None, expiration = None, key_type = None, key = None, extensions = None, signature = None, signing_key = None):
+ def __init__(self, cert_type: Optional['stem.client.datatype.CertType'] = None, expiration: Optional[datetime.datetime] = None, key_type: Optional[int] = None, key: Optional[bytes] = None, extensions: Optional[Sequence['stem.descriptor.certificate.Ed25519Extension']] = None, signature: Optional[bytes] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> None:
super(Ed25519CertificateV1, self).__init__(1)
if cert_type is None:
@@ -284,7 +285,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
elif self.type == CertType.UNKNOWN:
raise ValueError('Ed25519 certificate type %i is unrecognized' % self.type_int)
- def pack(self):
+ def pack(self) -> bytes:
encoded = bytearray()
encoded += Size.CHAR.pack(self.version)
encoded += Size.CHAR.pack(self.type_int)
@@ -302,7 +303,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
return bytes(encoded)
@staticmethod
- def unpack(content):
+ def unpack(content: bytes) -> 'stem.descriptor.certificate.Ed25519CertificateV1':
if len(content) < ED25519_HEADER_LENGTH + ED25519_SIGNATURE_LENGTH:
raise ValueError('Ed25519 certificate was %i bytes, but should be at least %i' % (len(content), ED25519_HEADER_LENGTH + ED25519_SIGNATURE_LENGTH))
@@ -329,7 +330,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
return Ed25519CertificateV1(cert_type, datetime.datetime.utcfromtimestamp(expiration_hours * 3600), key_type, key, extensions, signature)
- def is_expired(self):
+ def is_expired(self) -> bool:
"""
Checks if this certificate is presently expired or not.
@@ -338,7 +339,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
return datetime.datetime.now() > self.expiration
- def signing_key(self):
+ def signing_key(self) -> bytes:
"""
Provides this certificate's signing key.
@@ -354,7 +355,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
return None
- def validate(self, descriptor):
+ def validate(self, descriptor: Union['stem.descriptor.server_descriptor.RelayDescriptor', 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3']) -> None:
"""
Validate our descriptor content matches its ed25519 signature. Supported
descriptor types include...
@@ -410,7 +411,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
raise ValueError('Descriptor Ed25519 certificate signature invalid (signature forged or corrupt)')
@staticmethod
- def _signed_content(descriptor):
+ def _signed_content(descriptor: Union['stem.descriptor.server_descriptor.RelayDescriptor', 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3']) -> bytes:
"""
Provides this descriptor's signing constant, appended with the portion of
the descriptor that's signed.
diff --git a/stem/descriptor/collector.py b/stem/descriptor/collector.py
index 7aeb298b..1f1b1e95 100644
--- a/stem/descriptor/collector.py
+++ b/stem/descriptor/collector.py
@@ -63,6 +63,7 @@ import stem.util.connection
import stem.util.str_tools
from stem.descriptor import Compression, DocumentHandler
+from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
COLLECTOR_URL = 'https://collector.torproject.org/'
REFRESH_INDEX_RATE = 3600 # get new index if cached copy is an hour old
@@ -76,7 +77,7 @@ SEC_DATE = re.compile('(\\d{4}-\\d{2}-\\d{2}-\\d{2}-\\d{2}-\\d{2})')
FUTURE = datetime.datetime(9999, 1, 1)
-def get_instance():
+def get_instance() -> 'stem.descriptor.collector.CollecTor':
"""
Provides the singleton :class:`~stem.descriptor.collector.CollecTor`
used for this module's shorthand functions.
@@ -92,7 +93,7 @@ def get_instance():
return SINGLETON_COLLECTOR
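
The module's shorthand functions are annotated as iterators of the descriptor types they fetch. Sketch; this downloads from CollecTor, so it needs network access:

    import datetime

    import stem.descriptor.collector

    yesterday = datetime.datetime.utcnow() - datetime.timedelta(days = 1)

    for desc in stem.descriptor.collector.get_server_descriptors(start = yesterday):
      print(desc.fingerprint)
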
-def get_server_descriptors(start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3):
+def get_server_descriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_server_descriptors`
@@ -103,7 +104,7 @@ def get_server_descriptors(start = None, end = None, cache_to = None, bridge = F
yield desc
-def get_extrainfo_descriptors(start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3):
+def get_extrainfo_descriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_extrainfo_descriptors`
@@ -114,7 +115,7 @@ def get_extrainfo_descriptors(start = None, end = None, cache_to = None, bridge
yield desc
-def get_microdescriptors(start = None, end = None, cache_to = None, timeout = None, retries = 3):
+def get_microdescriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_microdescriptors`
@@ -125,7 +126,7 @@ def get_microdescriptors(start = None, end = None, cache_to = None, timeout = No
yield desc
-def get_consensus(start = None, end = None, cache_to = None, document_handler = DocumentHandler.ENTRIES, version = 3, microdescriptor = False, bridge = False, timeout = None, retries = 3):
+def get_consensus(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_consensus`
@@ -136,7 +137,7 @@ def get_consensus(start = None, end = None, cache_to = None, document_handler =
yield desc
-def get_key_certificates(start = None, end = None, cache_to = None, timeout = None, retries = 3):
+def get_key_certificates(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_key_certificates`
@@ -147,7 +148,7 @@ def get_key_certificates(start = None, end = None, cache_to = None, timeout = No
yield desc
-def get_bandwidth_files(start = None, end = None, cache_to = None, timeout = None, retries = 3):
+def get_bandwidth_files(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_bandwidth_files`
@@ -158,7 +159,7 @@ def get_bandwidth_files(start = None, end = None, cache_to = None, timeout = Non
yield desc
-def get_exit_lists(start = None, end = None, cache_to = None, timeout = None, retries = 3):
+def get_exit_lists(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_exit_lists`
@@ -187,7 +188,7 @@ class File(object):
:var datetime last_modified: when the file was last modified
"""
- def __init__(self, path, types, size, sha256, first_published, last_published, last_modified):
+ def __init__(self, path: str, types: Tuple[str, ...], size: int, sha256: str, first_published: datetime.datetime, last_published: datetime.datetime, last_modified: datetime.datetime) -> None:
self.path = path
self.types = tuple(types) if types else ()
self.compression = File._guess_compression(path)
@@ -205,7 +206,7 @@ class File(object):
else:
self.start, self.end = File._guess_time_range(path)
- def read(self, directory = None, descriptor_type = None, start = None, end = None, document_handler = DocumentHandler.ENTRIES, timeout = None, retries = 3):
+ def read(self, directory: Optional[str] = None, descriptor_type: Optional[str] = None, start: datetime.datetime = None, end: datetime.datetime = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.Descriptor']:
"""
Provides descriptors from this archive. Descriptors are downloaded or read
from disk as follows...
@@ -289,7 +290,7 @@ class File(object):
yield desc
- def download(self, directory, decompress = True, timeout = None, retries = 3, overwrite = False):
+ def download(self, directory: str, decompress: bool = True, timeout: Optional[int] = None, retries: Optional[int] = 3, overwrite: bool = False) -> str:
"""
Downloads this file to the given location. If a file already exists this is
a no-op.
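Taken together, read() and download() let callers mirror an archive locally and then parse it. A rough sketch (the cache directory and 'server-descriptor' type filter are illustrative):

  import stem.descriptor.collector

  collector = stem.descriptor.collector.CollecTor()
  archive = collector.files('server-descriptor')[0]

  # cache the archive on disk, then parse its descriptors (later runs reuse
  # the cached copy rather than re-downloading)
  path = archive.download('/tmp/descriptor_cache')
  print('downloaded %s (%i bytes)' % (path, archive.size))

  for desc in archive.read('/tmp/descriptor_cache'):
    print(desc.fingerprint)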
@@ -345,7 +346,7 @@ class File(object):
return path
@staticmethod
- def _guess_compression(path):
+ def _guess_compression(path: str) -> 'stem.descriptor.Compression':
"""
Determine file compression from CollecTor's filename.
"""
@@ -357,7 +358,7 @@ class File(object):
return Compression.PLAINTEXT
@staticmethod
- def _guess_time_range(path):
+ def _guess_time_range(path: str) -> Tuple[datetime.datetime, datetime.datetime]:
"""
Attempt to determine the (start, end) time range from CollecTor's filename.
This provides (None, None) if this cannot be determined.
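A hypothetical sketch of that guess, assuming the filename embeds 'YYYY-MM-DD-HH-MM-SS' timestamps (the exact patterns CollecTor uses vary by descriptor type):

  import datetime
  import re

  def guess_time_range(path):
    stamps = re.findall(r'\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}', path)

    if not stamps:
      return (None, None)

    parsed = [datetime.datetime.strptime(s, '%Y-%m-%d-%H-%M-%S') for s in stamps]
    return (min(parsed), max(parsed))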
@@ -398,7 +399,7 @@ class CollecTor(object):
:var float timeout: duration before we'll time out our request
"""
- def __init__(self, retries = 2, timeout = None):
+ def __init__(self, retries: Optional[int] = 2, timeout: Optional[int] = None) -> None:
self.retries = retries
self.timeout = timeout
@@ -406,7 +407,7 @@ class CollecTor(object):
self._cached_files = None
self._cached_index_at = 0
- def get_server_descriptors(self, start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3):
+ def get_server_descriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']:
"""
Provides server descriptors published during the given time range, sorted
oldest to newest.
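For example (the dates are illustrative):

  import datetime
  from stem.descriptor.collector import CollecTor

  collector = CollecTor(timeout = 60)

  start = datetime.datetime(2020, 3, 1)
  end = datetime.datetime(2020, 3, 2)

  for desc in collector.get_server_descriptors(start = start, end = end):
    print('%s published at %s' % (desc.nickname, desc.published))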
@@ -433,7 +434,7 @@ class CollecTor(object):
for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries):
yield desc
- def get_extrainfo_descriptors(self, start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3):
+ def get_extrainfo_descriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']:
"""
Provides extrainfo descriptors published during the given time range,
sorted oldest to newest.
@@ -460,7 +461,7 @@ class CollecTor(object):
for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries):
yield desc
- def get_microdescriptors(self, start = None, end = None, cache_to = None, timeout = None, retries = 3):
+ def get_microdescriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']:
"""
Provides microdescriptors estimated to be published during the given time
range, sorted oldest to newest. Unlike server/extrainfo descriptors,
@@ -494,7 +495,7 @@ class CollecTor(object):
for desc in f.read(cache_to, 'microdescriptor', start, end, timeout = timeout, retries = retries):
yield desc
- def get_consensus(self, start = None, end = None, cache_to = None, document_handler = DocumentHandler.ENTRIES, version = 3, microdescriptor = False, bridge = False, timeout = None, retries = 3):
+ def get_consensus(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']:
"""
Provides consensus router status entries published during the given time
range, sorted oldest to newest.
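With the default handler this yields individual router status entries; a different DocumentHandler provides whole documents instead. A quick sketch:

  import datetime
  from stem.descriptor import DocumentHandler
  from stem.descriptor.collector import CollecTor

  collector = CollecTor()
  recently = datetime.datetime.utcnow() - datetime.timedelta(hours = 2)

  for consensus in collector.get_consensus(start = recently, document_handler = DocumentHandler.DOCUMENT):
    print('consensus valid-after %s with %i relays' % (consensus.valid_after, len(consensus.routers)))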
@@ -538,7 +539,7 @@ class CollecTor(object):
for desc in f.read(cache_to, desc_type, start, end, document_handler, timeout = timeout, retries = retries):
yield desc
- def get_key_certificates(self, start = None, end = None, cache_to = None, timeout = None, retries = 3):
+ def get_key_certificates(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']:
"""
Directory authority key certificates for the given time range,
sorted oldest to newest.
@@ -562,7 +563,7 @@ class CollecTor(object):
for desc in f.read(cache_to, 'dir-key-certificate-3', start, end, timeout = timeout, retries = retries):
yield desc
- def get_bandwidth_files(self, start = None, end = None, cache_to = None, timeout = None, retries = 3):
+ def get_bandwidth_files(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']:
"""
Bandwidth authority heuristics for the given time range, sorted oldest to
newest.
@@ -586,7 +587,7 @@ class CollecTor(object):
for desc in f.read(cache_to, 'bandwidth-file', start, end, timeout = timeout, retries = retries):
yield desc
- def get_exit_lists(self, start = None, end = None, cache_to = None, timeout = None, retries = 3):
+ def get_exit_lists(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']:
"""
`TorDNSEL exit lists <https://www.torproject.org/projects/tordnsel.html.en>`_
for the given time range, sorted oldest to newest.
@@ -610,7 +611,7 @@ class CollecTor(object):
for desc in f.read(cache_to, 'tordnsel', start, end, timeout = timeout, retries = retries):
yield desc
- def index(self, compression = 'best'):
+ def index(self, compression: Union[str, 'stem.descriptor.Compression'] = 'best') -> Dict[str, Any]:
"""
Provides the archives available in CollecTor.
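For instance, the decoded index can be inspected directly (its top-level keys come from CollecTor's index.json and aren't enumerated here):

  from stem.descriptor.collector import CollecTor

  index = CollecTor().index()
  print(sorted(index.keys()))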
@@ -645,7 +646,7 @@ class CollecTor(object):
return self._cached_index
- def files(self, descriptor_type = None, start = None, end = None):
+ def files(self, descriptor_type: str = None, start: datetime.datetime = None, end: datetime.datetime = None) -> Sequence['stem.descriptor.collector.File']:
"""
Provides files CollecTor presently has, sorted oldest to newest.
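For example (the type prefix and date are illustrative):

  import datetime
  from stem.descriptor.collector import CollecTor

  collector = CollecTor()

  for f in collector.files('network-status-consensus-3', start = datetime.datetime(2020, 3, 1)):
    print('%s (%i bytes, %s to %s)' % (f.path, f.size, f.start, f.end))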
@@ -680,7 +681,7 @@ class CollecTor(object):
return matches
@staticmethod
- def _files(val, path):
+ def _files(val: str, path: Sequence[str]) -> Sequence['stem.descriptor.collector.File']:
"""
Recursively provides files within the index.
diff --git a/stem/descriptor/extrainfo_descriptor.py b/stem/descriptor/extrainfo_descriptor.py
index d92bb770..6aca3c29 100644
--- a/stem/descriptor/extrainfo_descriptor.py
+++ b/stem/descriptor/extrainfo_descriptor.py
@@ -67,6 +67,7 @@ Extra-info descriptors are available from a few sources...
===================== ===========
"""
+import datetime
import functools
import hashlib
import re
@@ -75,6 +76,8 @@ import stem.util.connection
import stem.util.enum
import stem.util.str_tools
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
+
from stem.descriptor import (
PGP_BLOCK_END,
Descriptor,
@@ -163,7 +166,7 @@ _timestamp_re = re.compile('^(.*) \\(([0-9]+) s\\)( .*)?$')
_locale_re = re.compile('^[a-zA-Z0-9\\?]{2}$')
-def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs):
+def _parse_file(descriptor_file: BinaryIO, is_bridge: bool = False, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.extrainfo_descriptor.ExtraInfoDescriptor']:
"""
Iterates over the extra-info descriptors in a file.
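In practice this is reached through stem.descriptor.parse_file(), for example against tor's extra-info cache (the path is illustrative):

  import stem.descriptor

  for desc in stem.descriptor.parse_file('/var/lib/tor/cached-extrainfo'):
    print('%s (%s)' % (desc.nickname, desc.fingerprint))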
@@ -204,7 +207,7 @@ def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs):
break # done parsing file
-def _parse_timestamp_and_interval(keyword, content):
+def _parse_timestamp_and_interval(keyword: str, content: str) -> Tuple[datetime.datetime, int, str]:
"""
Parses a 'YYYY-MM-DD HH:MM:SS (NSEC s) *' entry.
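A standalone sketch of that format, using the same approach as the _timestamp_re above (the sample line is illustrative):

  import datetime
  import re

  line = 'dirreq-stats-end 2020-03-01 12:00:00 (86400 s)'
  match = re.match(r'^dirreq-stats-end (.*) \((\d+) s\)$', line)

  timestamp = datetime.datetime.strptime(match.group(1), '%Y-%m-%d %H:%M:%S')
  interval = int(match.group(2))  # number of seconds the stats cover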
@@ -238,7 +241,7 @@ def _parse_timestamp_and_interval(keyword, content):
raise ValueError("%s line's timestamp wasn't parsable: %s" % (keyword, line))
-def _parse_extra_info_line(descriptor, entries):
+def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "extra-info" Nickname Fingerprint
value = _value('extra-info', entries)
@@ -255,7 +258,7 @@ def _parse_extra_info_line(descriptor, entries):
descriptor.fingerprint = extra_info_comp[1]
-def _parse_transport_line(descriptor, entries):
+def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "transport" transportname address:port [arglist]
# Everything after the transportname is scrubbed in published bridge
# descriptors, so we'll never see it in practice.
@@ -301,7 +304,7 @@ def _parse_transport_line(descriptor, entries):
descriptor.transport = transports
-def _parse_padding_counts_line(descriptor, entries):
+def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "padding-counts" YYYY-MM-DD HH:MM:SS (NSEC s) key=val key=val...
value = _value('padding-counts', entries)
@@ -316,7 +319,7 @@ def _parse_padding_counts_line(descriptor, entries):
setattr(descriptor, 'padding_counts', counts)
-def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr, descriptor, entries):
+def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_counts_attr: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
recognized_counts = {}
@@ -340,7 +343,7 @@ def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr
setattr(descriptor, unrecognized_counts_attr, unrecognized_counts)
-def _parse_dirreq_share_line(keyword, attribute, descriptor, entries):
+def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
if not value.endswith('%'):
@@ -353,7 +356,7 @@ def _parse_dirreq_share_line(keyword, attribute, descriptor, entries):
setattr(descriptor, attribute, float(value[:-1]) / 100)
-def _parse_cell_line(keyword, attribute, descriptor, entries):
+def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "<keyword>" num,...,num
value = _value(keyword, entries)
@@ -375,7 +378,7 @@ def _parse_cell_line(keyword, attribute, descriptor, entries):
raise exc
-def _parse_timestamp_and_interval_line(keyword, end_attribute, interval_attribute, descriptor, entries):
+def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interval_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s)
timestamp, interval, _ = _parse_timestamp_and_interval(keyword, _value(keyword, entries))
@@ -383,7 +386,7 @@ def _parse_timestamp_and_interval_line(keyword, end_attribute, interval_attribut
setattr(descriptor, interval_attribute, interval)
-def _parse_conn_bi_direct_line(descriptor, entries):
+def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "conn-bi-direct" YYYY-MM-DD HH:MM:SS (NSEC s) BELOW,READ,WRITE,BOTH
value = _value('conn-bi-direct', entries)
@@ -401,7 +404,7 @@ def _parse_conn_bi_direct_line(descriptor, entries):
descriptor.conn_bi_direct_both = int(stats[3])
-def _parse_history_line(keyword, end_attribute, interval_attribute, values_attribute, descriptor, entries):
+def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: str, values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s) NUM,NUM,NUM,NUM,NUM...
value = _value(keyword, entries)
@@ -419,7 +422,7 @@ def _parse_history_line(keyword, end_attribute, interval_attribute, values_attri
setattr(descriptor, values_attribute, history_values)
-def _parse_port_count_line(keyword, attribute, descriptor, entries):
+def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "<keyword>" port=N,port=N,...
value, port_mappings = _value(keyword, entries), {}
@@ -434,7 +437,7 @@ def _parse_port_count_line(keyword, attribute, descriptor, entries):
setattr(descriptor, attribute, port_mappings)
-def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries):
+def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "<keyword>" CC=N,CC=N,...
#
# The maxmind geoip (https://www.maxmind.com/app/iso3166) has numeric
@@ -454,7 +457,7 @@ def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries):
setattr(descriptor, attribute, locale_usage)
-def _parse_bridge_ip_versions_line(descriptor, entries):
+def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value, ip_versions = _value('bridge-ip-versions', entries), {}
for protocol, count in _mappings_for('bridge-ip-versions', value, divider = ','):
@@ -466,7 +469,7 @@ def _parse_bridge_ip_versions_line(descriptor, entries):
descriptor.ip_versions = ip_versions
-def _parse_bridge_ip_transports_line(descriptor, entries):
+def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value, ip_transports = _value('bridge-ip-transports', entries), {}
for protocol, count in _mappings_for('bridge-ip-transports', value, divider = ','):
@@ -478,7 +481,7 @@ def _parse_bridge_ip_transports_line(descriptor, entries):
descriptor.ip_transports = ip_transports
-def _parse_hs_stats(keyword, stat_attribute, extra_attribute, descriptor, entries):
+def _parse_hs_stats(keyword: str, stat_attribute: str, extra_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "<keyword>" num key=val key=val...
value, stat, extra = _value(keyword, entries), None, {}
@@ -814,7 +817,7 @@ class ExtraInfoDescriptor(Descriptor):
'bridge-ip-transports': _parse_bridge_ip_transports_line,
}
- def __init__(self, raw_contents, validate = False):
+ def __init__(self, raw_contents: str, validate: bool = False) -> None:
"""
Extra-info descriptor constructor. By default this validates the
descriptor's content as it's parsed. This validation can be disabled to
@@ -851,7 +854,7 @@ class ExtraInfoDescriptor(Descriptor):
else:
self._entries = entries
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
"""
Digest of this descriptor's content. These are referenced by...
@@ -876,13 +879,13 @@ class ExtraInfoDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ExtraInfoDescriptor subclass')
- def _required_fields(self):
+ def _required_fields(self) -> Tuple[str, ...]:
return REQUIRED_FIELDS
- def _first_keyword(self):
+ def _first_keyword(self) -> str:
return 'extra-info'
- def _last_keyword(self):
+ def _last_keyword(self) -> str:
return 'router-signature'
@@ -917,7 +920,7 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor):
})
@classmethod
- def content(cls, attr = None, exclude = (), sign = False, signing_key = None):
+ def content(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> str:
base_header = (
('extra-info', '%s %s' % (_random_nickname(), _random_fingerprint())),
('published', _random_date()),
@@ -938,11 +941,11 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor):
))
@classmethod
- def create(cls, attr = None, exclude = (), validate = True, sign = False, signing_key = None):
+ def create(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> 'stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor':
return cls(cls.content(attr, exclude, sign, signing_key), validate = validate)
@functools.lru_cache()
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
if hash_type == DigestHash.SHA1:
# our digest is calculated from everything except our signature
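Together these let unit tests fabricate a relay extra-info descriptor and compute its digest, e.g. (the attribute override is illustrative):

  from stem.descriptor import DigestEncoding, DigestHash
  from stem.descriptor.extrainfo_descriptor import RelayExtraInfoDescriptor

  desc = RelayExtraInfoDescriptor.create({'geoip-db-digest': '916A3CA8B7DF61473D5AE5B21711F35F301CE9E8'})
  print(desc.digest(DigestHash.SHA256, DigestEncoding.BASE64))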
@@ -986,7 +989,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
})
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.extrainfo_descriptor.BridgeExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('extra-info', 'ec2bridgereaac65a3 %s' % _random_fingerprint()),
('published', _random_date()),
@@ -994,7 +997,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
('router-digest', _random_fingerprint()),
))
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX:
return self._digest
elif hash_type == DigestHash.SHA256 and encoding == DigestEncoding.BASE64:
@@ -1002,7 +1005,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
else:
raise NotImplementedError('Bridge extrainfo digests are only available as sha1/hex and sha256/base64, not %s/%s' % (hash_type, encoding))
- def _required_fields(self):
+ def _required_fields(self) -> Tuple[str, ...]:
excluded_fields = [
'router-signature',
]
@@ -1013,5 +1016,5 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields])
- def _last_keyword(self):
+ def _last_keyword(self) -> str:
return None
diff --git a/stem/descriptor/hidden_service.py b/stem/descriptor/hidden_service.py
index 75a78d2e..8d23838e 100644
--- a/stem/descriptor/hidden_service.py
+++ b/stem/descriptor/hidden_service.py
@@ -51,6 +51,7 @@ import stem.util.tor_tools
from stem.client.datatype import CertType
from stem.descriptor.certificate import ExtensionType, Ed25519Extension, Ed25519Certificate, Ed25519CertificateV1
+from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import (
PGP_BLOCK_END,
@@ -162,7 +163,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
"""
@staticmethod
- def parse(content):
+ def parse(content: str) -> 'stem.descriptor.hidden_service.IntroductionPointV3':
"""
Parses an introduction point from its descriptor content.
@@ -200,7 +201,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, legacy_key, legacy_key_cert)
@staticmethod
- def create_for_address(address, port, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None):
+ def create_for_address(address: str, port: int, expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3':
"""
Simplified constructor for a single address/port link specifier.
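A quick sketch (the address and port are illustrative; missing keys are generated on the fly, which requires the cryptography module):

  from stem.descriptor.hidden_service import IntroductionPointV3

  intro_point = IntroductionPointV3.create_for_address('1.2.3.4', 9001)
  encoded = intro_point.encode()

  # parse() is the inverse of encode()
  reparsed = IntroductionPointV3.parse(encoded)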
@@ -232,7 +233,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3.create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None)
@staticmethod
- def create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None):
+ def create_for_link_specifiers(link_specifiers: Sequence['stem.client.datatype.LinkSpecifier'], expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3':
"""
Simplified constructor. For more sophisticated use cases you can use this
as a template for how introduction points are properly created.
@@ -271,7 +272,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, None, None)
- def encode(self):
+ def encode(self) -> str:
"""
Descriptor representation of this introduction point.
@@ -299,7 +300,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return '\n'.join(lines)
- def onion_key(self):
+ def onion_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey':
"""
Provides our ntor introduction point public key.
@@ -312,7 +313,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.onion_key_raw, x25519 = True)
- def auth_key(self):
+ def auth_key(self) -> 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey':
"""
Provides our authentication certificate's public key.
@@ -325,7 +326,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.auth_key_cert.key, ed25519 = True)
- def enc_key(self):
+ def enc_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey':
"""
Provides our encryption key.
@@ -338,7 +339,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.enc_key_raw, x25519 = True)
- def legacy_key(self):
+ def legacy_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey':
"""
Provides our legacy introduction point public key.
@@ -352,7 +353,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.legacy_key_raw, x25519 = True)
@staticmethod
- def _key_as(value, x25519 = False, ed25519 = False):
+ def _key_as(value: str, x25519: bool = False, ed25519: bool = False) -> Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']:
if value is None or (not x25519 and not ed25519):
return value
@@ -375,7 +376,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return Ed25519PublicKey.from_public_bytes(value)
@staticmethod
- def _parse_link_specifiers(content):
+ def _parse_link_specifiers(content: str) -> Sequence['stem.client.datatype.LinkSpecifier']:
try:
content = base64.b64decode(content)
except Exception as exc:
@@ -393,16 +394,16 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return link_specifiers
- def __hash__(self):
+ def __hash__(self) -> int:
if not hasattr(self, '_hash'):
self._hash = hash(self.encode())
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, IntroductionPointV3) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -417,22 +418,22 @@ class AuthorizedClient(object):
:var str cookie: base64 encoded authentication cookie
"""
- def __init__(self, id = None, iv = None, cookie = None):
+ def __init__(self, id: str = None, iv: str = None, cookie: str = None) -> None:
self.id = stem.util.str_tools._to_unicode(id if id else base64.b64encode(os.urandom(8)).rstrip(b'='))
self.iv = stem.util.str_tools._to_unicode(iv if iv else base64.b64encode(os.urandom(16)).rstrip(b'='))
self.cookie = stem.util.str_tools._to_unicode(cookie if cookie else base64.b64encode(os.urandom(16)).rstrip(b'='))
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'id', 'iv', 'cookie', cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, AuthorizedClient) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
-def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs):
+def _parse_file(descriptor_file: BinaryIO, desc_type: str = None, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.hidden_service.HiddenServiceDescriptor']:
"""
Iterates over the hidden service descriptors in a file.
@@ -442,7 +443,7 @@ def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs):
**True**, skips these checks otherwise
:param dict kwargs: additional arguments for the descriptor constructor
- :returns: iterator for :class:`~stem.descriptor.hidden_service.HiddenServiceDescriptorV2`
+ :returns: iterator for :class:`~stem.descriptor.hidden_service.HiddenServiceDescriptor`
instances in the file
:raises:
@@ -472,7 +473,7 @@ def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs):
break # done parsing file
-def _decrypt_layer(encrypted_block, constant, revision_counter, subcredential, blinded_key):
+def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> str:
if encrypted_block.startswith('-----BEGIN MESSAGE-----\n') and encrypted_block.endswith('\n-----END MESSAGE-----'):
encrypted_block = encrypted_block[24:-22]
@@ -499,7 +500,7 @@ def _decrypt_layer(encrypted_block, constant, revision_counter, subcredential, b
return stem.util.str_tools._to_unicode(plaintext)
-def _encrypt_layer(plaintext, constant, revision_counter, subcredential, blinded_key):
+def _encrypt_layer(plaintext: str, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
salt = os.urandom(16)
cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt)
@@ -510,7 +511,7 @@ def _encrypt_layer(plaintext, constant, revision_counter, subcredential, blinded
return b'-----BEGIN MESSAGE-----\n%s\n-----END MESSAGE-----' % b'\n'.join(stem.util.str_tools._split_by_length(encoded, 64))
-def _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt):
+def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes, salt: bytes) -> Tuple['cryptography.hazmat.primitives.ciphers.Cipher', Callable[[bytes], bytes]]:
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
@@ -530,7 +531,7 @@ def _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt):
return cipher, lambda ciphertext: hashlib.sha3_256(mac_prefix + ciphertext).digest()
-def _parse_protocol_versions_line(descriptor, entries):
+def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value('protocol-versions', entries)
try:
@@ -545,7 +546,7 @@ def _parse_protocol_versions_line(descriptor, entries):
descriptor.protocol_versions = versions
-def _parse_introduction_points_line(descriptor, entries):
+def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
_, block_type, block_contents = entries['introduction-points'][0]
if not block_contents or block_type != 'MESSAGE':
@@ -559,7 +560,7 @@ def _parse_introduction_points_line(descriptor, entries):
raise ValueError("'introduction-points' isn't base64 encoded content:\n%s" % block_contents)
-def _parse_v3_outer_clients(descriptor, entries):
+def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "auth-client" client-id iv encrypted-cookie
clients = {}
@@ -575,7 +576,7 @@ def _parse_v3_outer_clients(descriptor, entries):
descriptor.clients = clients
-def _parse_v3_inner_formats(descriptor, entries):
+def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value, formats = _value('create2-formats', entries), []
for entry in value.split(' '):
@@ -587,7 +588,7 @@ def _parse_v3_inner_formats(descriptor, entries):
descriptor.formats = formats
-def _parse_v3_introduction_points(descriptor, entries):
+def _parse_v3_introduction_points(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
if hasattr(descriptor, '_unparsed_introduction_points'):
introduction_points = []
remaining = descriptor._unparsed_introduction_points
@@ -687,7 +688,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
}
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('rendezvous-service-descriptor', 'y3olqqblqw2gbh6phimfuiroechjjafa'),
('version', '2'),
@@ -701,10 +702,10 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
))
@classmethod
- def create(cls, attr = None, exclude = (), validate = True):
+ def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV2':
return cls(cls.content(attr, exclude), validate = validate, skip_crypto_validation = True)
- def __init__(self, raw_contents, validate = False, skip_crypto_validation = False):
+ def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None:
super(HiddenServiceDescriptorV2, self).__init__(raw_contents, lazy_load = not validate)
entries = _descriptor_components(raw_contents, validate, non_ascii_fields = ('introduction-points'))
@@ -736,10 +737,12 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
self._entries = entries
@functools.lru_cache()
- def introduction_points(self, authentication_cookie = None):
+ def introduction_points(self, authentication_cookie: Optional[str] = None) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']:
"""
Provides this service's introduction points.
+ :param str authentication_cookie: base64 encoded authentication cookie
+
:returns: **list** of :class:`~stem.descriptor.hidden_service.IntroductionPointV2`
:raises:
@@ -774,7 +777,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
return HiddenServiceDescriptorV2._parse_introduction_points(content)
@staticmethod
- def _decrypt_basic_auth(content, authentication_cookie):
+ def _decrypt_basic_auth(content: bytes, authentication_cookie: str) -> bytes:
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
@@ -821,7 +824,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
return content # nope, unable to decrypt the content
@staticmethod
- def _decrypt_stealth_auth(content, authentication_cookie):
+ def _decrypt_stealth_auth(content: bytes, authentication_cookie: str) -> bytes:
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
@@ -836,7 +839,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
return decryptor.update(encrypted) + decryptor.finalize()
@staticmethod
- def _parse_introduction_points(content):
+ def _parse_introduction_points(content: bytes) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']:
"""
Provides the parsed list of IntroductionPointV2 for the unencrypted content.
"""
@@ -928,7 +931,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
}
@classmethod
- def content(cls, attr = None, exclude = (), sign = False, inner_layer = None, outer_layer = None, identity_key = None, signing_key = None, signing_cert = None, revision_counter = None, blinding_nonce = None):
+ def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> str:
"""
Hidden service v3 descriptors consist of three parts:
@@ -1023,10 +1026,10 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return desc_content
@classmethod
- def create(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, outer_layer = None, identity_key = None, signing_key = None, signing_cert = None, revision_counter = None, blinding_nonce = None):
+ def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3':
return cls(cls.content(attr, exclude, sign, inner_layer, outer_layer, identity_key, signing_key, signing_cert, revision_counter, blinding_nonce), validate = validate)
- def __init__(self, raw_contents, validate = False):
+ def __init__(self, raw_contents: bytes, validate: bool = False) -> None:
super(HiddenServiceDescriptorV3, self).__init__(raw_contents, lazy_load = not validate)
self._inner_layer = None
@@ -1054,7 +1057,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
else:
self._entries = entries
- def decrypt(self, onion_address):
+ def decrypt(self, onion_address: str) -> 'stem.descriptor.hidden_service.InnerLayer':
"""
Decrypt this descriptor. Hidden service descriptors contain two encryption
layers (:class:`~stem.descriptor.hidden_service.OuterLayer` and
@@ -1086,7 +1089,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return self._inner_layer
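A usage sketch for decrypt(); the helper below is hypothetical and assumes the caller already holds a fetched HiddenServiceDescriptorV3 along with the service's onion address:

  from stem.descriptor.hidden_service import HiddenServiceDescriptorV3

  def print_introduction_points(desc: HiddenServiceDescriptorV3, onion_address: str) -> None:
    # the onion address is needed to derive the decryption subcredential
    inner_layer = desc.decrypt(onion_address)

    for intro_point in inner_layer.introduction_points:
      print(intro_point.link_specifiers)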
@staticmethod
- def address_from_identity_key(key, suffix = True):
+ def address_from_identity_key(key: Union[bytes, 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'], suffix: bool = True) -> str:
"""
Converts a hidden service identity key into its address. This accepts all
key formats (private, public, or public bytes).
@@ -1094,7 +1097,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
:param Ed25519PublicKey,Ed25519PrivateKey,bytes key: hidden service identity key
:param bool suffix: includes the '.onion' suffix if true, excluded otherwise
- :returns: **unicode** hidden service address
+ :returns: **str** hidden service address
:raises: **ImportError** if key is a cryptographic type and ed25519 support
is unavailable
@@ -1109,7 +1112,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return stem.util.str_tools._to_unicode(onion_address + b'.onion' if suffix else onion_address).lower()
@staticmethod
- def identity_key_from_address(onion_address):
+ def identity_key_from_address(onion_address: str) -> bytes:
"""
Converts a hidden service address into its public identity key.
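These two helpers round-trip between identity keys and onion addresses (requires the cryptography module):

  from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
  from stem.descriptor.hidden_service import HiddenServiceDescriptorV3

  private_key = Ed25519PrivateKey.generate()

  # derive the address from the key, then recover the public key bytes from it
  address = HiddenServiceDescriptorV3.address_from_identity_key(private_key, suffix = False)
  pubkey = HiddenServiceDescriptorV3.identity_key_from_address(address)

  print('%s.onion' % address)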
@@ -1146,7 +1149,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return pubkey
@staticmethod
- def _subcredential(identity_key, blinded_key):
+ def _subcredential(identity_key: bytes, blinded_key: bytes) -> bytes:
# credential = H('credential' | public-identity-key)
# subcredential = H('subcredential' | credential | blinded-public-key)
@@ -1186,11 +1189,11 @@ class OuterLayer(Descriptor):
}
@staticmethod
- def _decrypt(encrypted, revision_counter, subcredential, blinded_key):
+ def _decrypt(encrypted: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.OuterLayer':
plaintext = _decrypt_layer(encrypted, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key)
return OuterLayer(plaintext)
- def _encrypt(self, revision_counter, subcredential, blinded_key):
+ def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
# Spec mandated padding: "Before encryption the plaintext is padded with
# NUL bytes to the nearest multiple of 10k bytes."
@@ -1201,7 +1204,7 @@ class OuterLayer(Descriptor):
return _encrypt_layer(content, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key)
@classmethod
- def content(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, revision_counter = None, authorized_clients = None, subcredential = None, blinded_key = None):
+ def content(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: Optional[int] = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> str:
try:
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
@@ -1235,10 +1238,10 @@ class OuterLayer(Descriptor):
))
@classmethod
- def create(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, revision_counter = None, authorized_clients = None, subcredential = None, blinded_key = None):
+ def create(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: int = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> 'stem.descriptor.hidden_service.OuterLayer':
return cls(cls.content(attr, exclude, validate, sign, inner_layer, revision_counter, authorized_clients, subcredential, blinded_key), validate = validate)
- def __init__(self, content, validate = False):
+ def __init__(self, content: bytes, validate: bool = False) -> None:
content = stem.util.str_tools._to_bytes(content).rstrip(b'\x00') # strip null byte padding
super(OuterLayer, self).__init__(content, lazy_load = not validate)
@@ -1282,7 +1285,7 @@ class InnerLayer(Descriptor):
}
@staticmethod
- def _decrypt(outer_layer, revision_counter, subcredential, blinded_key):
+ def _decrypt(outer_layer: 'stem.descriptor.hidden_service.OuterLayer', revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
plaintext = _decrypt_layer(outer_layer.encrypted, b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key)
return InnerLayer(plaintext, validate = True, outer_layer = outer_layer)
@@ -1292,7 +1295,7 @@ class InnerLayer(Descriptor):
return _encrypt_layer(self.get_bytes(), b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key)
@classmethod
- def content(cls, attr = None, exclude = (), introduction_points = None):
+ def content(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> str:
if introduction_points:
suffix = '\n' + '\n'.join(map(IntroductionPointV3.encode, introduction_points))
else:
@@ -1303,10 +1306,10 @@ class InnerLayer(Descriptor):
)) + stem.util.str_tools._to_bytes(suffix)
@classmethod
- def create(cls, attr = None, exclude = (), validate = True, introduction_points = None):
+ def create(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> 'stem.descriptor.hidden_service.InnerLayer':
return cls(cls.content(attr, exclude, introduction_points), validate = validate)
- def __init__(self, content, validate = False, outer_layer = None):
+ def __init__(self, content: bytes, validate: bool = False, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None) -> None:
super(InnerLayer, self).__init__(content, lazy_load = not validate)
self.outer = outer_layer
@@ -1331,7 +1334,7 @@ class InnerLayer(Descriptor):
self._entries = entries
-def _blinded_pubkey(identity_key, blinding_nonce):
+def _blinded_pubkey(identity_key: bytes, blinding_nonce: bytes) -> bytes:
from stem.util import ed25519
mult = 2 ** (ed25519.b - 2) + sum(2 ** i * ed25519.bit(blinding_nonce, i) for i in range(3, ed25519.b - 2))
@@ -1339,7 +1342,7 @@ def _blinded_pubkey(identity_key, blinding_nonce):
return ed25519.encodepoint(ed25519.scalarmult(P, mult))
-def _blinded_sign(msg, identity_key, blinded_key, blinding_nonce):
+def _blinded_sign(msg: bytes, identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes, blinding_nonce: bytes) -> bytes:
try:
from cryptography.hazmat.primitives import serialization
except ImportError:
diff --git a/stem/descriptor/microdescriptor.py b/stem/descriptor/microdescriptor.py
index c62a3d0d..c2c104ff 100644
--- a/stem/descriptor/microdescriptor.py
+++ b/stem/descriptor/microdescriptor.py
@@ -69,6 +69,8 @@ import hashlib
import stem.exit_policy
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type, Union
+
from stem.descriptor import (
Descriptor,
DigestHash,
@@ -102,7 +104,7 @@ SINGLE_FIELDS = (
)
-def _parse_file(descriptor_file, validate = False, **kwargs):
+def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']:
"""
Iterates over the microdescriptors in a file.
@@ -159,7 +161,7 @@ def _parse_file(descriptor_file, validate = False, **kwargs):
break # done parsing descriptors
-def _parse_id_line(descriptor, entries):
+def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
identities = {}
for entry in _values('id', entries):
@@ -244,7 +246,7 @@ class Microdescriptor(Descriptor):
}
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.microdescriptor.Microdescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('onion-key', _random_crypto_blob('RSA PUBLIC KEY')),
))
@@ -260,7 +262,7 @@ class Microdescriptor(Descriptor):
else:
self._entries = entries
- def digest(self, hash_type = DigestHash.SHA256, encoding = DigestEncoding.BASE64):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA256, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.BASE64) -> Union[str, 'hashlib.HASH']:
"""
Digest of this microdescriptor. These are referenced by...
@@ -285,7 +287,7 @@ class Microdescriptor(Descriptor):
raise NotImplementedError('Microdescriptor digests are only available in sha1 and sha256, not %s' % hash_type)
@functools.lru_cache()
- def get_annotations(self):
+ def get_annotations(self) -> Dict[str, str]:
"""
Provides content that appeared prior to the descriptor. If this comes from
the cached-microdescs then this commonly contains content like...
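For example, the microdescriptor cache can be parsed directly (the path is illustrative) and each entry's annotations and digest inspected:

  import stem.descriptor

  for desc in stem.descriptor.parse_file('/var/lib/tor/cached-microdescs.new'):
    print(desc.get_annotations())
    print(desc.digest())  # sha256/base64 by default, matching consensus references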
@@ -308,7 +310,7 @@ class Microdescriptor(Descriptor):
return annotation_dict
- def get_annotation_lines(self):
+ def get_annotation_lines(self) -> Sequence[str]:
"""
Provides the lines of content that appeared prior to the descriptor. This
is the same as the
@@ -320,7 +322,7 @@ class Microdescriptor(Descriptor):
return self._annotation_lines
- def _check_constraints(self, entries):
+ def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
"""
Does a basic check that the entries conform to this descriptor type's
constraints.
@@ -341,5 +343,5 @@ class Microdescriptor(Descriptor):
if 'onion-key' != list(entries.keys())[0]:
raise ValueError("Microdescriptor must start with a 'onion-key' entry")
- def _name(self, is_plural = False):
+ def _name(self, is_plural: bool = False) -> str:
return 'microdescriptors' if is_plural else 'microdescriptor'
diff --git a/stem/descriptor/networkstatus.py b/stem/descriptor/networkstatus.py
index 77c6d612..48940987 100644
--- a/stem/descriptor/networkstatus.py
+++ b/stem/descriptor/networkstatus.py
@@ -65,6 +65,8 @@ import stem.util.str_tools
import stem.util.tor_tools
import stem.version
+from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Type, Union
+
from stem.descriptor import (
PGP_BLOCK_END,
Descriptor,
@@ -293,7 +295,7 @@ class DocumentDigest(collections.namedtuple('DocumentDigest', ['flavor', 'algori
"""
-def _parse_file(document_file, document_type = None, validate = False, is_microdescriptor = False, document_handler = DocumentHandler.ENTRIES, **kwargs):
+def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.descriptor.networkstatus.NetworkStatusDocument']] = None, validate: bool = False, is_microdescriptor: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, **kwargs: Any) -> Iterator[Union['stem.descriptor.networkstatus.NetworkStatusDocument', 'stem.descriptor.router_status_entry.RouterStatusEntry']]:
"""
Parses a network status and iterates over the RouterStatusEntry in it. The
document that these instances reference have an empty 'routers' attribute to
@@ -372,7 +374,7 @@ def _parse_file(document_file, document_type = None, validate = False, is_microd
raise ValueError('Unrecognized document_handler: %s' % document_handler)
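In practice this is reached through stem.descriptor.parse_file(), for instance to read a cached consensus as a single document (the path is illustrative):

  from stem.descriptor import DocumentHandler, parse_file

  consensus = next(parse_file('/var/lib/tor/cached-consensus', document_handler = DocumentHandler.DOCUMENT))

  for fingerprint, relay in consensus.routers.items():
    print('%s => %s' % (fingerprint, relay.nickname))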
-def _parse_file_key_certs(certificate_file, validate = False):
+def _parse_file_key_certs(certificate_file: BinaryIO, validate: bool = False) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']:
"""
Parses a file containing one or more authority key certificates.
@@ -401,7 +403,7 @@ def _parse_file_key_certs(certificate_file, validate = False):
break # done parsing file
-def _parse_file_detached_sigs(detached_signature_file, validate = False):
+def _parse_file_detached_sigs(detached_signature_file: BinaryIO, validate: bool = False) -> Iterator['stem.descriptor.networkstatus.DetachedSignature']:
"""
Parses a file containing one or more detached signatures.
@@ -431,7 +433,7 @@ class NetworkStatusDocument(Descriptor):
Common parent for network status documents.
"""
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
"""
Digest of this descriptor's content. These are referenced by...
@@ -458,8 +460,8 @@ class NetworkStatusDocument(Descriptor):
raise NotImplementedError('Network status document digests are only available in sha1 and sha256, not %s' % hash_type)
-def _parse_version_line(keyword, attribute, expected_version):
- def _parse(descriptor, entries):
+def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
if not value.isdigit():
@@ -473,7 +475,7 @@ def _parse_version_line(keyword, attribute, expected_version):
return _parse
-def _parse_dir_source_line(descriptor, entries):
+def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value('dir-source', entries)
dir_source_comp = value.split()
@@ -493,7 +495,7 @@ def _parse_dir_source_line(descriptor, entries):
descriptor.dir_port = None if dir_source_comp[2] == '0' else int(dir_source_comp[2])
-def _parse_additional_digests(descriptor, entries):
+def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
digests = []
for val in _values('additional-digest', entries):
@@ -507,7 +509,7 @@ def _parse_additional_digests(descriptor, entries):
descriptor.additional_digests = digests
-def _parse_additional_signatures(descriptor, entries):
+def _parse_additional_signatures(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
signatures = []
for val, block_type, block_contents in entries['additional-signature']:
@@ -598,7 +600,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
}
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('network-status-version', '2'),
('dir-source', '%s %s 80' % (_random_ipv4_address(), _random_ipv4_address())),
@@ -610,7 +612,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
('directory-signature', 'moria2' + _random_crypto_blob('SIGNATURE')),
))
- def __init__(self, raw_content, validate = False):
+ def __init__(self, raw_content: bytes, validate: bool = False) -> None:
super(NetworkStatusDocumentV2, self).__init__(raw_content, lazy_load = not validate)
# Splitting the document from the routers. Unlike v3 documents we're not
@@ -646,7 +648,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
else:
self._entries = entries
- def _check_constraints(self, entries):
+ def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
required_fields = [field for (field, is_mandatory) in NETWORK_STATUS_V2_FIELDS if is_mandatory]
for keyword in required_fields:
if keyword not in entries:
@@ -662,7 +664,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
raise ValueError("Network status document (v2) are expected to start with a 'network-status-version' line:\n%s" % str(self))
-def _parse_header_network_status_version_line(descriptor, entries):
+def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "network-status-version" version
value = _value('network-status-version', entries)
@@ -683,7 +685,7 @@ def _parse_header_network_status_version_line(descriptor, entries):
raise ValueError("Expected a version 3 network status document, got version '%s' instead" % descriptor.version)
-def _parse_header_vote_status_line(descriptor, entries):
+def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "vote-status" type
#
# The consensus-method and consensus-methods fields are optional since
@@ -700,7 +702,7 @@ def _parse_header_vote_status_line(descriptor, entries):
raise ValueError("A network status document's vote-status line can only be 'consensus' or 'vote', got '%s' instead" % value)
-def _parse_header_consensus_methods_line(descriptor, entries):
+def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "consensus-methods" IntegerList
if descriptor._lazy_loading and descriptor.is_vote:
@@ -717,7 +719,7 @@ def _parse_header_consensus_methods_line(descriptor, entries):
descriptor.consensus_methods = consensus_methods
-def _parse_header_consensus_method_line(descriptor, entries):
+def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "consensus-method" Integer
if descriptor._lazy_loading and descriptor.is_consensus:
@@ -731,7 +733,7 @@ def _parse_header_consensus_method_line(descriptor, entries):
descriptor.consensus_method = int(value)
-def _parse_header_voting_delay_line(descriptor, entries):
+def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "voting-delay" VoteSeconds DistSeconds
value = _value('voting-delay', entries)
@@ -744,8 +746,8 @@ def _parse_header_voting_delay_line(descriptor, entries):
raise ValueError("A network status document's 'voting-delay' line must be a pair of integer values, but was '%s'" % value)
-def _parse_versions_line(keyword, attribute):
- def _parse(descriptor, entries):
+def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value, entries = _value(keyword, entries), []
for entry in value.split(','):
@@ -759,7 +761,7 @@ def _parse_versions_line(keyword, attribute):
return _parse
-def _parse_header_flag_thresholds_line(descriptor, entries):
+def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "flag-thresholds" SP THRESHOLDS
value, thresholds = _value('flag-thresholds', entries).strip(), {}
@@ -782,7 +784,7 @@ def _parse_header_flag_thresholds_line(descriptor, entries):
descriptor.flag_thresholds = thresholds
-def _parse_header_parameters_line(descriptor, entries):
+def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "params" [Parameters]
# Parameter ::= Keyword '=' Int32
# Int32 ::= A decimal integer between -2147483648 and 2147483647.
@@ -798,7 +800,7 @@ def _parse_header_parameters_line(descriptor, entries):
descriptor._check_params_constraints()
-def _parse_directory_footer_line(descriptor, entries):
+def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# nothing to parse, simply checking that we don't have a value
value = _value('directory-footer', entries)
@@ -807,7 +809,7 @@ def _parse_directory_footer_line(descriptor, entries):
raise ValueError("A network status document's 'directory-footer' line shouldn't have any content, got 'directory-footer %s'" % value)
-def _parse_footer_directory_signature_line(descriptor, entries):
+def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
signatures = []
for sig_value, block_type, block_contents in entries['directory-signature']:
@@ -828,7 +830,7 @@ def _parse_footer_directory_signature_line(descriptor, entries):
descriptor.signatures = signatures
-def _parse_package_line(descriptor, entries):
+def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
package_versions = []
for value, _, _ in entries['package']:
@@ -849,7 +851,7 @@ def _parse_package_line(descriptor, entries):
descriptor.packages = package_versions
-def _parsed_shared_rand_commit(descriptor, entries):
+def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "shared-rand-commit" Version AlgName Identity Commit [Reveal]
commitments = []
@@ -871,7 +873,7 @@ def _parsed_shared_rand_commit(descriptor, entries):
descriptor.shared_randomness_commitments = commitments
-def _parse_shared_rand_previous_value(descriptor, entries):
+def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "shared-rand-previous-value" NumReveals Value
value = _value('shared-rand-previous-value', entries)
@@ -884,7 +886,7 @@ def _parse_shared_rand_previous_value(descriptor, entries):
raise ValueError("A network status document's 'shared-rand-previous-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_shared_rand_current_value(descriptor, entries):
+def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "shared-rand-current-value" NumReveals Value
value = _value('shared-rand-current-value', entries)
@@ -897,7 +899,7 @@ def _parse_shared_rand_current_value(descriptor, entries):
raise ValueError("A network status document's 'shared-rand-current-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_bandwidth_file_headers(descriptor, entries):
+def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "bandwidth-file-headers" KeyValues
# KeyValues ::= "" | KeyValue | KeyValues SP KeyValue
# KeyValue ::= Keyword '=' Value
@@ -912,7 +914,7 @@ def _parse_bandwidth_file_headers(descriptor, entries):
descriptor.bandwidth_file_headers = results
-def _parse_bandwidth_file_digest(descriptor, entries):
+def _parse_bandwidth_file_digest(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "bandwidth-file-digest" 1*(SP algorithm "=" digest)
value = _value('bandwidth-file-digest', entries)
@@ -1096,7 +1098,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
}
@classmethod
- def content(cls, attr = None, exclude = (), authorities = None, routers = None):
+ def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> str:
attr = {} if attr is None else dict(attr)
is_vote = attr.get('vote-status') == 'vote'
@@ -1168,10 +1170,10 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return desc_content
@classmethod
- def create(cls, attr = None, exclude = (), validate = True, authorities = None, routers = None):
+ def create(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3':
return cls(cls.content(attr, exclude, authorities, routers), validate = validate)
- def __init__(self, raw_content, validate = False, default_params = True):
+ def __init__(self, raw_content: str, validate: bool = False, default_params: bool = True) -> None:
"""
Parse a v3 network status document.
@@ -1213,7 +1215,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
self.routers = dict((desc.fingerprint, desc) for desc in router_iter)
self._footer(document_file, validate)
- def type_annotation(self):
+ def type_annotation(self) -> 'stem.descriptor.TypeAnnotation':
if isinstance(self, BridgeNetworkStatusDocument):
return TypeAnnotation('bridge-network-status', 1, 0)
elif not self.is_microdescriptor:
@@ -1225,7 +1227,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return TypeAnnotation('network-status-microdesc-consensus-3', 1, 0)
- def is_valid(self):
+ def is_valid(self) -> bool:
"""
Checks if the current time is between this document's **valid_after** and
**valid_until** timestamps. To be valid means the information within this
@@ -1239,7 +1241,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return self.valid_after < datetime.datetime.utcnow() < self.valid_until
- def is_fresh(self):
+ def is_fresh(self) -> bool:
"""
Checks if the current time is between this document's **valid_after** and
**fresh_until** timestamps. To be fresh means this should be the latest
@@ -1253,7 +1255,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return self.valid_after < datetime.datetime.utcnow() < self.fresh_until
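As a rough usage sketch, these checks can be applied to a freshly downloaded consensus (the DOCUMENT handler below simply asks for the parsed document rather than its entries)::

  import stem.descriptor
  import stem.descriptor.remote

  # fetch the current consensus as a single NetworkStatusDocumentV3
  consensus = stem.descriptor.remote.get_consensus(
    document_handler = stem.descriptor.DocumentHandler.DOCUMENT,
  ).run()[0]

  print('valid_after: %s' % consensus.valid_after)
  print('is_valid: %s' % consensus.is_valid())
  print('is_fresh: %s' % consensus.is_fresh())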
- def validate_signatures(self, key_certs):
+ def validate_signatures(self, key_certs: Sequence['stem.descriptor.networkstatus.KeyCertificate']) -> None:
"""
Validates we're properly signed by the signing certificates.
@@ -1287,7 +1289,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
if valid_digests < required_digests:
raise ValueError('Network Status Document has %i valid signatures out of %i total, needed %i' % (valid_digests, total_digests, required_digests))
- def get_unrecognized_lines(self):
+ def get_unrecognized_lines(self) -> Sequence[str]:
if self._lazy_loading:
self._parse(self._header_entries, False, parser_for_line = self._HEADER_PARSER_FOR_LINE)
self._parse(self._footer_entries, False, parser_for_line = self._FOOTER_PARSER_FOR_LINE)
@@ -1295,7 +1297,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return super(NetworkStatusDocumentV3, self).get_unrecognized_lines()
- def meets_consensus_method(self, method):
+ def meets_consensus_method(self, method: int) -> bool:
"""
Checks if we meet the given consensus-method. This works for both votes and
consensuses, checking our 'consensus-method' and 'consensus-methods'
@@ -1313,7 +1315,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
else:
return False # malformed document
- def _header(self, document_file, validate):
+ def _header(self, document_file: BinaryIO, validate: bool) -> None:
content = bytes.join(b'', _read_until_keywords((AUTH_START, ROUTERS_START, FOOTER_START), document_file))
entries = _descriptor_components(content, validate)
header_fields = [attr[0] for attr in HEADER_STATUS_DOCUMENT_FIELDS]
@@ -1347,7 +1349,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
self._header_entries = entries
self._entries.update(entries)
- def _footer(self, document_file, validate):
+ def _footer(self, document_file: BinaryIO, validate: bool) -> None:
entries = _descriptor_components(document_file.read(), validate)
footer_fields = [attr[0] for attr in FOOTER_STATUS_DOCUMENT_FIELDS]
@@ -1379,7 +1381,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
self._footer_entries = entries
self._entries.update(entries)
- def _check_params_constraints(self):
+ def _check_params_constraints(self) -> None:
"""
Checks that the params we know about are within their documented ranges.
"""
@@ -1398,7 +1400,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
raise ValueError("'%s' value on the params line must be in the range of %i - %i, was %i" % (key, minimum, maximum, value))
-def _check_for_missing_and_disallowed_fields(document, entries, fields):
+def _check_for_missing_and_disallowed_fields(document: 'stem.descriptor.networkstatus.NetworkStatusDocumentV3', entries: Mapping[str, str], fields: Sequence[str]) -> None:
"""
Checks that we have mandatory fields for our type, and that we don't have
any fields exclusive to the other (ie, no vote-only fields appear in a
@@ -1431,7 +1433,7 @@ def _check_for_missing_and_disallowed_fields(document, entries, fields):
raise ValueError("Network status document has fields that shouldn't appear in this document type or version: %s" % ', '.join(disallowed_fields))
-def _parse_int_mappings(keyword, value, validate):
+def _parse_int_mappings(keyword: str, value: str, validate: bool) -> Dict[str, int]:
# Parse a series of 'key=value' entries, checking the following:
# - values are integers
# - keys are sorted in lexical order
@@ -1461,7 +1463,7 @@ def _parse_int_mappings(keyword, value, validate):
return results
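The shape of this parsing is a split on space separated 'key=value' pairs with integer values; a minimal standalone sketch of the idea (an illustration, not the function above)::

  def parse_int_mappings(value):
    # space separated 'key=value' pairs, values must be integers
    results = {}

    for entry in filter(None, value.split(' ')):
      key, _, entry_value = entry.partition('=')
      results[key] = int(entry_value)  # ValueError if non-numeric

    return results

  print(parse_int_mappings('CircuitPriorityHalflifeMsec=30000 bwweightscale=10000'))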
-def _parse_dirauth_source_line(descriptor, entries):
+def _parse_dirauth_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "dir-source" nickname identity address IP dirport orport
value = _value('dir-source', entries)
@@ -1580,7 +1582,7 @@ class DirectoryAuthority(Descriptor):
}
@classmethod
- def content(cls, attr = None, exclude = (), is_vote = False):
+ def content(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), is_vote: bool = False) -> str:
attr = {} if attr is None else dict(attr)
# include mandatory 'vote-digest' if a consensus
@@ -1599,10 +1601,10 @@ class DirectoryAuthority(Descriptor):
return content
@classmethod
- def create(cls, attr = None, exclude = (), validate = True, is_vote = False):
+ def create(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, is_vote: bool = False) -> 'stem.descriptor.networkstatus.DirectoryAuthority':
return cls(cls.content(attr, exclude, is_vote), validate = validate, is_vote = is_vote)
- def __init__(self, raw_content, validate = False, is_vote = False):
+ def __init__(self, raw_content: str, validate: bool = False, is_vote: bool = False) -> None:
"""
Parse a directory authority entry in a v3 network status document.
@@ -1677,7 +1679,7 @@ class DirectoryAuthority(Descriptor):
self._entries = entries
-def _parse_dir_address_line(descriptor, entries):
+def _parse_dir_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "dir-address" IPPort
value = _value('dir-address', entries)
@@ -1752,7 +1754,7 @@ class KeyCertificate(Descriptor):
}
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.networkstatus.KeyCertificate'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('dir-key-certificate-version', '3'),
('fingerprint', _random_fingerprint()),
@@ -1764,7 +1766,7 @@ class KeyCertificate(Descriptor):
('dir-key-certification', _random_crypto_blob('SIGNATURE')),
))
- def __init__(self, raw_content, validate = False):
+ def __init__(self, raw_content: str, validate: bool = False) -> None:
super(KeyCertificate, self).__init__(raw_content, lazy_load = not validate)
entries = _descriptor_components(raw_content, validate)
@@ -1805,7 +1807,7 @@ class DocumentSignature(object):
:raises: **ValueError** if a validity check fails
"""
- def __init__(self, method, identity, key_digest, signature, flavor = None, validate = False):
+ def __init__(self, method: str, identity: str, key_digest: str, signature: str, flavor: Optional[str] = None, validate: bool = False) -> None:
# Checking that these attributes are valid. Technically the key
# digest isn't a fingerprint, but it has the same characteristics.
@@ -1822,7 +1824,7 @@ class DocumentSignature(object):
self.signature = signature
self.flavor = flavor
- def _compare(self, other, method):
+ def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool:
if not isinstance(other, DocumentSignature):
return False
@@ -1832,19 +1834,19 @@ class DocumentSignature(object):
return method(True, True) # we're equal
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(str(self).strip())
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s == o)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
- def __lt__(self, other):
+ def __lt__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s < o)
- def __le__(self, other):
+ def __le__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s <= o)
@@ -1898,7 +1900,7 @@ class DetachedSignature(Descriptor):
}
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.networkstatus.DetachedSignature'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('consensus-digest', '6D3CC0EFA408F228410A4A8145E1B0BB0670E442'),
('valid-after', _random_date()),
@@ -1906,7 +1908,7 @@ class DetachedSignature(Descriptor):
('valid-until', _random_date()),
))
- def __init__(self, raw_content, validate = False):
+ def __init__(self, raw_content: str, validate: bool = False) -> None:
super(DetachedSignature, self).__init__(raw_content, lazy_load = not validate)
entries = _descriptor_components(raw_content, validate)
@@ -1941,7 +1943,7 @@ class BridgeNetworkStatusDocument(NetworkStatusDocument):
TYPE_ANNOTATION_NAME = 'bridge-network-status'
- def __init__(self, raw_content, validate = False):
+ def __init__(self, raw_content: str, validate: bool = False) -> None:
super(BridgeNetworkStatusDocument, self).__init__(raw_content)
self.published = None
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index 24eb7b9b..f3c6d6bd 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -101,6 +101,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression
from stem.util import log, str_tools
+from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their
# fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -121,7 +122,7 @@ SINGLETON_DOWNLOADER = None
DIR_PORT_BLACKLIST = ('tor26', 'Serge')
-def get_instance():
+def get_instance() -> 'stem.descriptor.remote.DescriptorDownloader':
"""
Provides the singleton :class:`~stem.descriptor.remote.DescriptorDownloader`
used for this module's shorthand functions.
@@ -139,7 +140,7 @@ def get_instance():
return SINGLETON_DOWNLOADER
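As a usage sketch, the module level shorthands defer to this singleton, so the following are roughly equivalent::

  import stem.descriptor.remote

  # module level shorthand
  for desc in stem.descriptor.remote.get_server_descriptors():
    print(desc.fingerprint)

  # the same request through the shared downloader
  downloader = stem.descriptor.remote.get_instance()

  for desc in downloader.get_server_descriptors():
    print(desc.fingerprint)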
-def their_server_descriptor(**query_args):
+def their_server_descriptor(**query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the server descriptor of the relay we're downloading from.
@@ -154,7 +155,7 @@ def their_server_descriptor(**query_args):
return get_instance().their_server_descriptor(**query_args)
-def get_server_descriptors(fingerprints = None, **query_args):
+def get_server_descriptors(fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Shorthand for
:func:`~stem.descriptor.remote.DescriptorDownloader.get_server_descriptors`
@@ -166,7 +167,7 @@ def get_server_descriptors(fingerprints = None, **query_args):
return get_instance().get_server_descriptors(fingerprints, **query_args)
-def get_extrainfo_descriptors(fingerprints = None, **query_args):
+def get_extrainfo_descriptors(fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Shorthand for
:func:`~stem.descriptor.remote.DescriptorDownloader.get_extrainfo_descriptors`
@@ -178,7 +179,7 @@ def get_extrainfo_descriptors(fingerprints = None, **query_args):
return get_instance().get_extrainfo_descriptors(fingerprints, **query_args)
-def get_microdescriptors(hashes, **query_args):
+def get_microdescriptors(hashes: Optional[Union[str, Sequence[str]]], **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Shorthand for
:func:`~stem.descriptor.remote.DescriptorDownloader.get_microdescriptors`
@@ -190,7 +191,7 @@ def get_microdescriptors(hashes, **query_args):
return get_instance().get_microdescriptors(hashes, **query_args)
-def get_consensus(authority_v3ident = None, microdescriptor = False, **query_args):
+def get_consensus(authority_v3ident: Optional[str] = None, microdescriptor: bool = False, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Shorthand for
:func:`~stem.descriptor.remote.DescriptorDownloader.get_consensus`
@@ -202,7 +203,7 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg
return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args)
-def get_bandwidth_file(**query_args):
+def get_bandwidth_file(**query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Shorthand for
:func:`~stem.descriptor.remote.DescriptorDownloader.get_bandwidth_file`
@@ -214,7 +215,7 @@ def get_bandwidth_file(**query_args):
return get_instance().get_bandwidth_file(**query_args)
-def get_detached_signatures(**query_args):
+def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Shorthand for
:func:`~stem.descriptor.remote.DescriptorDownloader.get_detached_signatures`
@@ -370,7 +371,7 @@ class Query(object):
the same as running **query.run(True)** (default is **False**)
"""
- def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs):
+ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence['stem.Endpoint']] = None, compression: Sequence['stem.descriptor.Compression'] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
if not resource.startswith('/'):
raise ValueError("Resources should start with a '/': %s" % resource)
@@ -433,7 +434,7 @@ class Query(object):
if block:
self.run(True)
- def start(self):
+ def start(self) -> None:
"""
Starts downloading the descriptors if we haven't started already.
"""
@@ -449,7 +450,7 @@ class Query(object):
self._downloader_thread.setDaemon(True)
self._downloader_thread.start()
- def run(self, suppress = False):
+ def run(self, suppress: bool = False) -> Sequence['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
@@ -469,7 +470,7 @@ class Query(object):
return list(self._run(suppress))
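For instance, a Query can be constructed directly and then iterated over or run() to completion (a sketch using the '/tor/server/all' resource)::

  from stem.descriptor.remote import Query

  query = Query(
    '/tor/server/all',
    'server-descriptor 1.0',
    timeout = 30,
  )

  # suppress = True swallows download errors, yielding whatever we received
  for desc in query.run(suppress = True):
    print('%s (%s)' % (desc.nickname, desc.fingerprint))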
- def _run(self, suppress):
+ def _run(self, suppress: bool) -> Iterator['stem.descriptor.Descriptor']:
with self._downloader_thread_lock:
self.start()
self._downloader_thread.join()
@@ -505,11 +506,11 @@ class Query(object):
raise self.error
- def __iter__(self):
+ def __iter__(self) -> Iterator['stem.descriptor.Descriptor']:
for desc in self._run(True):
yield desc
- def _pick_endpoint(self, use_authority = False):
+ def _pick_endpoint(self, use_authority: bool = False) -> 'stem.Endpoint':
"""
Provides an endpoint to query. If we have multiple endpoints then one
is picked at random.
@@ -527,7 +528,7 @@ class Query(object):
else:
return random.choice(self.endpoints)
- def _download_descriptors(self, retries, timeout):
+ def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
try:
self.start_time = time.time()
endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
@@ -572,7 +573,7 @@ class DescriptorDownloader(object):
:class:`~stem.descriptor.remote.Query` constructor
"""
- def __init__(self, use_mirrors = False, **default_args):
+ def __init__(self, use_mirrors: bool = False, **default_args: Any) -> None:
self._default_args = default_args
self._endpoints = None
@@ -585,7 +586,7 @@ class DescriptorDownloader(object):
except Exception as exc:
log.debug('Unable to retrieve directory mirrors: %s' % exc)
- def use_directory_mirrors(self):
+ def use_directory_mirrors(self) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3':
"""
Downloads the present consensus and configures ourselves to use directory
mirrors, in addition to authorities.
@@ -611,7 +612,7 @@ class DescriptorDownloader(object):
return consensus
- def their_server_descriptor(self, **query_args):
+ def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the server descriptor of the relay we're downloading from.
@@ -625,7 +626,7 @@ class DescriptorDownloader(object):
return self.query('/tor/server/authority', **query_args)
- def get_server_descriptors(self, fingerprints = None, **query_args):
+ def get_server_descriptors(self, fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the server descriptors with the given fingerprints. If no
fingerprints are provided then this returns all descriptors known
@@ -655,7 +656,7 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_extrainfo_descriptors(self, fingerprints = None, **query_args):
+ def get_extrainfo_descriptors(self, fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the extrainfo descriptors with the given fingerprints. If no
fingerprints are provided then this returns all descriptors in the present
@@ -685,7 +686,7 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_microdescriptors(self, hashes, **query_args):
+ def get_microdescriptors(self, hashes: Optional[Union[str, Sequence[str]]], **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the microdescriptors with the given hashes. To get these see the
**microdescriptor_digest** attribute of
@@ -731,7 +732,7 @@ class DescriptorDownloader(object):
return self.query('/tor/micro/d/%s' % '-'.join(hashes), **query_args)
- def get_consensus(self, authority_v3ident = None, microdescriptor = False, **query_args):
+ def get_consensus(self, authority_v3ident: Optional[str] = None, microdescriptor: bool = False, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the present router status entries.
@@ -775,7 +776,7 @@ class DescriptorDownloader(object):
return consensus_query
- def get_vote(self, authority, **query_args):
+ def get_vote(self, authority: 'stem.directory.Authority', **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the present vote for a given directory authority.
@@ -794,13 +795,13 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_key_certificates(self, authority_v3idents = None, **query_args):
+ def get_key_certificates(self, authority_v3idents: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the key certificates for authorities with the given fingerprints.
If no fingerprints are provided then this returns all present key
certificates.
- :param str authority_v3idents: fingerprint or list of fingerprints of the
+ :param str,list authority_v3idents: fingerprint or list of fingerprints of the
authority keys, see `'v3ident' in tor's config.c
<https://gitweb.torproject.org/tor.git/tree/src/or/config.c#n819>`_
for the values.
@@ -827,7 +828,7 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_bandwidth_file(self, **query_args):
+ def get_bandwidth_file(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the bandwidth authority heuristics used to make the next
consensus.
@@ -843,7 +844,7 @@ class DescriptorDownloader(object):
return self.query('/tor/status-vote/next/bandwidth', **query_args)
- def get_detached_signatures(self, **query_args):
+ def get_detached_signatures(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the detached signatures that will be used to make the next
consensus. Please note that **these are only available during minutes 55-60
@@ -896,7 +897,7 @@ class DescriptorDownloader(object):
return self.query('/tor/status-vote/next/consensus-signatures', **query_args)
- def query(self, resource, **query_args):
+ def query(self, resource: str, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Issues a request for the given resource.
@@ -923,7 +924,7 @@ class DescriptorDownloader(object):
return Query(resource, **args)
-def _download_from_orport(endpoint, compression, resource):
+def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.descriptor.Compression'], resource: str) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given orport. Payload is just like an http
response (headers and all)...
@@ -981,7 +982,7 @@ def _download_from_orport(endpoint, compression, resource):
return _decompress(body_data, headers.get('Content-Encoding')), headers
-def _download_from_dirport(url, compression, timeout):
+def _download_from_dirport(url: str, compression: Sequence['stem.descriptor.Compression'], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given url.
@@ -1016,7 +1017,7 @@ def _download_from_dirport(url, compression, timeout):
return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
-def _decompress(data, encoding):
+def _decompress(data: bytes, encoding: str) -> bytes:
"""
Decompresses descriptor data.
@@ -1030,6 +1031,8 @@ def _decompress(data, encoding):
:param bytes data: data we received
:param str encoding: 'Content-Encoding' header of the response
+ :returns: **bytes** with the decompressed data
+
:raises:
* **ValueError** if encoding is unrecognized
* **ImportError** if missing the decompression module
@@ -1045,7 +1048,7 @@ def _decompress(data, encoding):
raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
-def _guess_descriptor_type(resource):
+def _guess_descriptor_type(resource: str) -> str:
# Attempts to determine the descriptor type based on the resource url. This
# raises a ValueError if the resource isn't recognized.
diff --git a/stem/descriptor/router_status_entry.py b/stem/descriptor/router_status_entry.py
index c2d8dd07..20822c82 100644
--- a/stem/descriptor/router_status_entry.py
+++ b/stem/descriptor/router_status_entry.py
@@ -27,6 +27,8 @@ import io
import stem.exit_policy
import stem.util.str_tools
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type
+
from stem.descriptor import (
KEYWORD_LINE,
Descriptor,
@@ -44,7 +46,7 @@ from stem.descriptor import (
_parse_pr_line = _parse_protocol_line('pr', 'protocols')
-def _parse_file(document_file, validate, entry_class, entry_keyword = 'r', start_position = None, end_position = None, section_end_keywords = (), extra_args = ()):
+def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], entry_keyword: str = 'r', start_position: Optional[int] = None, end_position: Optional[int] = None, section_end_keywords: Sequence[str] = (), extra_args: Sequence[str] = ()) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']:
"""
Reads a range of the document_file containing some number of entry_class
instances. We deliminate the entry_class entries by the keyword on their
@@ -111,7 +113,7 @@ def _parse_file(document_file, validate, entry_class, entry_keyword = 'r', start
break
-def _parse_r_line(descriptor, entries):
+def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# Parses a RouterStatusEntry's 'r' line. They're very nearly identical for
# all current entry types (v2, v3, and microdescriptor v3) with one little
# wrinkle: only the microdescriptor flavor excludes a 'digest' field.
@@ -163,7 +165,7 @@ def _parse_r_line(descriptor, entries):
raise ValueError("Publication time time wasn't parsable: r %s" % value)
-def _parse_a_line(descriptor, entries):
+def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "a" SP address ":" portlist
# example: a [2001:888:2133:0:82:94:251:204]:9001
@@ -186,7 +188,7 @@ def _parse_a_line(descriptor, entries):
descriptor.or_addresses = or_addresses
-def _parse_s_line(descriptor, entries):
+def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "s" Flags
# example: s Named Running Stable Valid
@@ -201,7 +203,7 @@ def _parse_s_line(descriptor, entries):
raise ValueError("%s had extra whitespace on its 's' line: s %s" % (descriptor._name(), value))
-def _parse_v_line(descriptor, entries):
+def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "v" version
# example: v Tor 0.2.2.35
#
@@ -219,7 +221,7 @@ def _parse_v_line(descriptor, entries):
raise ValueError('%s has a malformed tor version (%s): v %s' % (descriptor._name(), exc, value))
-def _parse_w_line(descriptor, entries):
+def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "w" "Bandwidth=" INT ["Measured=" INT] ["Unmeasured=1"]
# example: w Bandwidth=7980
@@ -266,7 +268,7 @@ def _parse_w_line(descriptor, entries):
descriptor.unrecognized_bandwidth_entries = unrecognized_bandwidth_entries
-def _parse_p_line(descriptor, entries):
+def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "p" ("accept" / "reject") PortList
#
# examples:
@@ -282,7 +284,7 @@ def _parse_p_line(descriptor, entries):
raise ValueError('%s exit policy is malformed (%s): p %s' % (descriptor._name(), exc, value))
-def _parse_id_line(descriptor, entries):
+def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "id" "ed25519" ed25519-identity
#
# examples:
@@ -305,7 +307,7 @@ def _parse_id_line(descriptor, entries):
raise ValueError("'id' lines should contain both the key type and digest: id %s" % value)
-def _parse_m_line(descriptor, entries):
+def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "m" methods 1*(algorithm "=" digest)
# example: m 8,9,10,11,12 sha256=g1vx9si329muxV3tquWIXXySNOIwRGMeAESKs/v4DWs
@@ -339,14 +341,14 @@ def _parse_m_line(descriptor, entries):
descriptor.microdescriptor_hashes = all_hashes
-def _parse_microdescriptor_m_line(descriptor, entries):
+def _parse_microdescriptor_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "m" digest
# example: m aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70
descriptor.microdescriptor_digest = _value('m', entries)
-def _base64_to_hex(identity, check_if_fingerprint = True):
+def _base64_to_hex(identity: str, check_if_fingerprint: bool = True) -> str:
"""
Decodes a base64 value to hex. For example...
@@ -420,7 +422,7 @@ class RouterStatusEntry(Descriptor):
}
@classmethod
- def from_str(cls, content, **kwargs):
+ def from_str(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], content: str, **kwargs: Any) -> 'stem.descriptor.router_status_entry.RouterStatusEntry':
# Router status entries don't have their own @type annotation, so to make
# our subclass from_str() work we need to do the type inferencing ourself.
@@ -440,14 +442,14 @@ class RouterStatusEntry(Descriptor):
else:
raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." % len(results))
- def __init__(self, content, validate = False, document = None):
+ def __init__(self, content: str, validate: bool = False, document: Optional['stem.descriptor.networkstatus.NetworkStatusDocument'] = None) -> None:
"""
Parse a router descriptor in a network status document.
:param str content: router descriptor content to be parsed
- :param NetworkStatusDocument document: document this descriptor came from
:param bool validate: checks the validity of the content if **True**, skips
these checks otherwise
+ :param NetworkStatusDocument document: document this descriptor came from
:raises: **ValueError** if the descriptor data is invalid
"""
@@ -472,21 +474,21 @@ class RouterStatusEntry(Descriptor):
else:
self._entries = entries
- def _name(self, is_plural = False):
+ def _name(self, is_plural: bool = False) -> str:
"""
Name for this descriptor type.
"""
return 'Router status entries' if is_plural else 'Router status entry'
- def _required_fields(self):
+ def _required_fields(self) -> Tuple[str, ...]:
"""
Provides lines that must appear in the descriptor.
"""
return ()
- def _single_fields(self):
+ def _single_fields(self) -> Tuple[str, ...]:
"""
Provides lines that can only appear in the descriptor once.
"""
@@ -512,18 +514,18 @@ class RouterStatusEntryV2(RouterStatusEntry):
})
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())),
))
- def _name(self, is_plural = False):
+ def _name(self, is_plural: bool = False) -> str:
return 'Router status entries (v2)' if is_plural else 'Router status entry (v2)'
- def _required_fields(self):
- return ('r')
+ def _required_fields(self) -> Tuple[str, ...]:
+ return ('r',)
- def _single_fields(self):
+ def _single_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'v')
@@ -603,19 +605,19 @@ class RouterStatusEntryV3(RouterStatusEntry):
})
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())),
('s', 'Fast Named Running Stable Valid'),
))
- def _name(self, is_plural = False):
+ def _name(self, is_plural: bool = False) -> str:
return 'Router status entries (v3)' if is_plural else 'Router status entry (v3)'
- def _required_fields(self):
+ def _required_fields(self) -> Tuple[str, ...]:
return ('r', 's')
- def _single_fields(self):
+ def _single_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'v', 'w', 'p', 'pr')
@@ -668,18 +670,18 @@ class RouterStatusEntryMicroV3(RouterStatusEntry):
})
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryMicroV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('r', '%s ARIJF2zbqirB9IwsW0mQznccWww %s %s 9001 9030' % (_random_nickname(), _random_date(), _random_ipv4_address())),
('m', 'aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70'),
('s', 'Fast Guard HSDir Named Running Stable V2Dir Valid'),
))
- def _name(self, is_plural = False):
+ def _name(self, is_plural: bool = False) -> str:
return 'Router status entries (micro v3)' if is_plural else 'Router status entry (micro v3)'
- def _required_fields(self):
+ def _required_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'm')
- def _single_fields(self):
+ def _single_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'v', 'w', 'm', 'pr')
diff --git a/stem/descriptor/server_descriptor.py b/stem/descriptor/server_descriptor.py
index 955b8429..11b44972 100644
--- a/stem/descriptor/server_descriptor.py
+++ b/stem/descriptor/server_descriptor.py
@@ -61,6 +61,7 @@ import stem.version
from stem.descriptor.certificate import Ed25519Certificate
from stem.descriptor.router_status_entry import RouterStatusEntryV3
+from typing import Any, BinaryIO, Dict, Iterator, Optional, Mapping, Sequence, Tuple, Type, Union
from stem.descriptor import (
PGP_BLOCK_END,
@@ -139,11 +140,11 @@ REJECT_ALL_POLICY = stem.exit_policy.ExitPolicy('reject *:*')
DEFAULT_BRIDGE_DISTRIBUTION = 'any'
-def _truncated_b64encode(content):
+def _truncated_b64encode(content: bytes) -> str:
return stem.util.str_tools._to_unicode(base64.b64encode(content).rstrip(b'='))
-def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs):
+def _parse_file(descriptor_file: BinaryIO, is_bridge: bool = False, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.server_descriptor.ServerDescriptor']:
"""
Iterates over the server descriptors in a file.
@@ -220,7 +221,7 @@ def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs):
break # done parsing descriptors
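Callers normally reach this through stem.descriptor.parse_file rather than invoking it directly; a sketch with an illustrative cached descriptor path::

  import stem.descriptor

  # path and descriptor type here are only examples
  with open('/var/lib/tor/cached-descriptors', 'rb') as descriptor_file:
    for desc in stem.descriptor.parse_file(descriptor_file, 'server-descriptor 1.0'):
      print('%s (%s)' % (desc.nickname, desc.fingerprint))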
-def _parse_router_line(descriptor, entries):
+def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "router" nickname address ORPort SocksPort DirPort
value = _value('router', entries)
@@ -246,7 +247,7 @@ def _parse_router_line(descriptor, entries):
descriptor.dir_port = None if router_comp[4] == '0' else int(router_comp[4])
-def _parse_bandwidth_line(descriptor, entries):
+def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "bandwidth" bandwidth-avg bandwidth-burst bandwidth-observed
value = _value('bandwidth', entries)
@@ -266,7 +267,7 @@ def _parse_bandwidth_line(descriptor, entries):
descriptor.observed_bandwidth = int(bandwidth_comp[2])
-def _parse_platform_line(descriptor, entries):
+def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "platform" string
_parse_bytes_line('platform', 'platform')(descriptor, entries)
@@ -292,7 +293,7 @@ def _parse_platform_line(descriptor, entries):
pass
-def _parse_fingerprint_line(descriptor, entries):
+def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# This is forty hex digits split into space separated groups of four.
# Checking that we match this pattern.
@@ -309,7 +310,7 @@ def _parse_fingerprint_line(descriptor, entries):
descriptor.fingerprint = fingerprint
-def _parse_extrainfo_digest_line(descriptor, entries):
+def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value('extra-info-digest', entries)
digest_comp = value.split(' ')
@@ -320,7 +321,7 @@ def _parse_extrainfo_digest_line(descriptor, entries):
descriptor.extra_info_sha256_digest = digest_comp[1] if len(digest_comp) >= 2 else None
-def _parse_hibernating_line(descriptor, entries):
+def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
# "hibernating" 0|1 (in practice only set if one)
value = _value('hibernating', entries)
@@ -331,7 +332,7 @@ def _parse_hibernating_line(descriptor, entries):
descriptor.hibernating = value == '1'
-def _parse_protocols_line(descriptor, entries):
+def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value('protocols', entries)
protocols_match = re.match('^Link (.*) Circuit (.*)$', value)
@@ -343,7 +344,7 @@ def _parse_protocols_line(descriptor, entries):
descriptor.circuit_protocols = circuit_versions.split(' ')
-def _parse_or_address_line(descriptor, entries):
+def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
all_values = _values('or-address', entries)
or_addresses = []
@@ -366,7 +367,7 @@ def _parse_or_address_line(descriptor, entries):
descriptor.or_addresses = or_addresses
-def _parse_history_line(keyword, history_end_attribute, history_interval_attribute, history_values_attribute, descriptor, entries):
+def _parse_history_line(keyword: str, history_end_attribute: str, history_interval_attribute: str, history_values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
value = _value(keyword, entries)
timestamp, interval, remainder = stem.descriptor.extrainfo_descriptor._parse_timestamp_and_interval(keyword, value)
@@ -383,7 +384,7 @@ def _parse_history_line(keyword, history_end_attribute, history_interval_attribu
setattr(descriptor, history_values_attribute, history_values)
-def _parse_exit_policy(descriptor, entries):
+def _parse_exit_policy(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
if hasattr(descriptor, '_unparsed_exit_policy'):
if descriptor._unparsed_exit_policy and stem.util.str_tools._to_unicode(descriptor._unparsed_exit_policy[0]) == 'reject *:*':
descriptor.exit_policy = REJECT_ALL_POLICY
@@ -576,7 +577,7 @@ class ServerDescriptor(Descriptor):
'eventdns': _parse_eventdns_line,
}
- def __init__(self, raw_contents, validate = False):
+ def __init__(self, raw_contents: str, validate: bool = False) -> None:
"""
Server descriptor constructor, created from an individual relay's
descriptor content (as provided by 'GETINFO desc/*', cached descriptors,
@@ -621,7 +622,7 @@ class ServerDescriptor(Descriptor):
else:
self._entries = entries
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
"""
Digest of this descriptor's content. These are referenced by...
@@ -641,7 +642,7 @@ class ServerDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ServerDescriptor subclass')
- def _check_constraints(self, entries):
+ def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
"""
Does a basic check that the entries conform to this descriptor type's
constraints.
@@ -679,16 +680,16 @@ class ServerDescriptor(Descriptor):
# Constraints that the descriptor must meet to be valid. These can be None if
# not applicable.
- def _required_fields(self):
+ def _required_fields(self) -> Tuple[str, ...]:
return REQUIRED_FIELDS
- def _single_fields(self):
+ def _single_fields(self) -> Tuple[str, ...]:
return REQUIRED_FIELDS + SINGLE_FIELDS
- def _first_keyword(self):
+ def _first_keyword(self) -> str:
return 'router'
- def _last_keyword(self):
+ def _last_keyword(self) -> str:
return 'router-signature'
@@ -753,7 +754,7 @@ class RelayDescriptor(ServerDescriptor):
'router-signature': _parse_router_signature_line,
})
- def __init__(self, raw_contents, validate = False, skip_crypto_validation = False):
+ def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None:
super(RelayDescriptor, self).__init__(raw_contents, validate)
if validate:
@@ -785,7 +786,7 @@ class RelayDescriptor(ServerDescriptor):
pass # cryptography module unavailable
@classmethod
- def content(cls, attr = None, exclude = (), sign = False, signing_key = None, exit_policy = None):
+ def content(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> str:
if attr is None:
attr = {}
@@ -827,15 +828,18 @@ class RelayDescriptor(ServerDescriptor):
))
@classmethod
- def create(cls, attr = None, exclude = (), validate = True, sign = False, signing_key = None, exit_policy = None):
+ def create(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> 'stem.descriptor.server_descriptor.RelayDescriptor':
return cls(cls.content(attr, exclude, sign, signing_key, exit_policy), validate = validate, skip_crypto_validation = not sign)
@functools.lru_cache()
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
"""
Provides the digest of our descriptor's content.
- :returns: the digest string encoded in uppercase hex
+ :param stem.descriptor.DigestHash hash_type: digest hashing algorithm
+ :param stem.descriptor.DigestEncoding encoding: digest encoding
+
+ :returns: **hashlib.HASH** or **str** based on our encoding argument
:raises: ValueError if the digest cannot be calculated
"""
@@ -849,7 +853,7 @@ class RelayDescriptor(ServerDescriptor):
else:
raise NotImplementedError('Server descriptor digests are only available in sha1 and sha256, not %s' % hash_type)
- def make_router_status_entry(self):
+ def make_router_status_entry(self) -> 'stem.descriptor.router_status_entry.RouterStatusEntryV3':
"""
Provides a RouterStatusEntryV3 for this descriptor content.
@@ -888,12 +892,12 @@ class RelayDescriptor(ServerDescriptor):
return RouterStatusEntryV3.create(attr)
@functools.lru_cache()
- def _onion_key_crosscert_digest(self):
+ def _onion_key_crosscert_digest(self) -> str:
"""
Provides the digest of the onion-key-crosscert data. This consists of the
RSA identity key sha1 and ed25519 identity key.
- :returns: **unicode** digest encoded in uppercase hex
+ :returns: **str** digest encoded in uppercase hex
:raises: ValueError if the digest cannot be calculated
"""
@@ -902,7 +906,7 @@ class RelayDescriptor(ServerDescriptor):
data = signing_key_digest + base64.b64decode(stem.util.str_tools._to_bytes(self.ed25519_master_key) + b'=')
return stem.util.str_tools._to_unicode(binascii.hexlify(data).upper())
- def _check_constraints(self, entries):
+ def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
super(RelayDescriptor, self)._check_constraints(entries)
if self.certificate:
@@ -941,7 +945,7 @@ class BridgeDescriptor(ServerDescriptor):
})
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls: Type['stem.descriptor.server_descriptor.BridgeDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
return _descriptor_content(attr, exclude, (
('router', '%s %s 9001 0 0' % (_random_nickname(), _random_ipv4_address())),
('router-digest', '006FD96BA35E7785A6A3B8B75FE2E2435A13BDB4'),
@@ -950,13 +954,13 @@ class BridgeDescriptor(ServerDescriptor):
('reject', '*:*'),
))
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX):
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX:
return self._digest
else:
raise NotImplementedError('Bridge server descriptor digests are only available as sha1/hex, not %s/%s' % (hash_type, encoding))
- def is_scrubbed(self):
+ def is_scrubbed(self) -> bool:
"""
Checks if we've been properly scrubbed in accordance with the `bridge
descriptor specification
@@ -969,7 +973,7 @@ class BridgeDescriptor(ServerDescriptor):
return self.get_scrubbing_issues() == []
@functools.lru_cache()
- def get_scrubbing_issues(self):
+ def get_scrubbing_issues(self) -> Sequence[str]:
"""
Provides issues with our scrubbing.
@@ -1003,7 +1007,7 @@ class BridgeDescriptor(ServerDescriptor):
return issues
- def _required_fields(self):
+ def _required_fields(self) -> Tuple[str, ...]:
# bridge required fields are the same as a relay descriptor, minus items
# excluded according to the format page
@@ -1019,8 +1023,8 @@ class BridgeDescriptor(ServerDescriptor):
return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields])
- def _single_fields(self):
+ def _single_fields(self) -> Tuple[str, ...]:
return self._required_fields() + SINGLE_FIELDS
- def _last_keyword(self):
+ def _last_keyword(self) -> Optional[str]:
return None
diff --git a/stem/descriptor/tordnsel.py b/stem/descriptor/tordnsel.py
index d0f57b93..c36e343d 100644
--- a/stem/descriptor/tordnsel.py
+++ b/stem/descriptor/tordnsel.py
@@ -14,6 +14,8 @@ import stem.util.connection
import stem.util.str_tools
import stem.util.tor_tools
+from typing import Any, BinaryIO, Dict, Iterator, Sequence
+
from stem.descriptor import (
Descriptor,
_read_until_keywords,
@@ -21,7 +23,7 @@ from stem.descriptor import (
)
-def _parse_file(tordnsel_file, validate = False, **kwargs):
+def _parse_file(tordnsel_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']:
"""
Iterates over a tordnsel file.
@@ -62,7 +64,7 @@ class TorDNSEL(Descriptor):
TYPE_ANNOTATION_NAME = 'tordnsel'
- def __init__(self, raw_contents, validate):
+ def __init__(self, raw_contents: str, validate: bool) -> None:
super(TorDNSEL, self).__init__(raw_contents)
raw_contents = stem.util.str_tools._to_unicode(raw_contents)
entries = _descriptor_components(raw_contents, validate)
@@ -74,8 +76,7 @@ class TorDNSEL(Descriptor):
self._parse(entries, validate)
- def _parse(self, entries, validate):
-
+ def _parse(self, entries: Dict[str, Sequence[str]], validate: bool) -> None:
for keyword, values in list(entries.items()):
value, block_type, block_content = values[0]
diff --git a/stem/directory.py b/stem/directory.py
index 67079c80..f96adfbb 100644
--- a/stem/directory.py
+++ b/stem/directory.py
@@ -49,6 +49,7 @@ import stem.util
import stem.util.conf
from stem.util import connection, str_tools, tor_tools
+from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Pattern, Sequence, Tuple
GITWEB_AUTHORITY_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/auth_dirs.inc'
GITWEB_FALLBACK_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/fallback_dirs.inc'
@@ -68,7 +69,7 @@ FALLBACK_EXTRAINFO = re.compile('/\\* extrainfo=([0-1]) \\*/')
FALLBACK_IPV6 = re.compile('" ipv6=\\[([\\da-f:]+)\\]:(\\d+)"')
-def _match_with(lines, regexes, required = None):
+def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Dict[Pattern, Tuple[str]]:
"""
Scans the given content against a series of regex matchers, providing back a
mapping of regexes to their capture groups. This mapping is with the value if
@@ -101,7 +102,7 @@ def _match_with(lines, regexes, required = None):
return matches
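For instance, a fallback_dirs.inc style line maps onto its capture groups roughly as follows (pattern and line are shortened illustrations of the idea rather than a call to the helper itself)::

  import re

  FALLBACK_LINE = re.compile('"([\\d\\.]+):(\\d+) orport=(\\d+) id=([\\dA-F]{40})"')

  line = '"5.9.110.236:9030 orport=9001 id=0756B7CD4DFC8182BE23143FAC0642F515182CEB"'
  match = FALLBACK_LINE.match(line)

  if match:
    address, dir_port, or_port, fingerprint = match.groups()
    print('%s %s:%s (dirport %s)' % (fingerprint, address, or_port, dir_port))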
-def _directory_entries(lines, pop_section_func, regexes, required = None):
+def _directory_entries(lines: Sequence[str], pop_section_func: Callable[[Sequence[str]], Sequence[str]], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Iterator[Dict[Pattern, Tuple[str]]]:
next_section = pop_section_func(lines)
while next_section:
@@ -133,7 +134,7 @@ class Directory(object):
ORPort, or **None** if it doesn't have one
"""
- def __init__(self, address, or_port, dir_port, fingerprint, nickname, orport_v6):
+ def __init__(self, address: str, or_port: int, dir_port: int, fingerprint: str, nickname: str, orport_v6: Optional[Tuple[str, int]]) -> None:
identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint
if not connection.is_valid_ipv4_address(address):
@@ -163,7 +164,7 @@ class Directory(object):
self.orport_v6 = (orport_v6[0], int(orport_v6[1])) if orport_v6 else None
@staticmethod
- def from_cache():
+ def from_cache() -> Dict[str, 'stem.directory.Directory']:
"""
Provides cached Tor directory information. This information is hardcoded
into Tor and occasionally changes, so the information provided by this
@@ -181,7 +182,7 @@ class Directory(object):
raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass')
@staticmethod
- def from_remote(timeout = 60):
+ def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Directory']:
"""
Reads and parses tor's directory data `from gitweb.torproject.org <https://gitweb.torproject.org/>`_.
Note that while convenient, this reliance on GitWeb means you should always
@@ -209,13 +210,13 @@ class Directory(object):
raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass')
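A usage sketch with the Authority subclass, where both calls provide a dict of nicknames to directory entries::

  import stem.directory

  # authorities hardcoded within stem
  for nickname, authority in stem.directory.Authority.from_cache().items():
    print('%s: %s:%i' % (nickname, authority.address, authority.or_port))

  # the same information, refreshed from gitweb.torproject.org
  remote_authorities = stem.directory.Authority.from_remote()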
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'address', 'or_port', 'dir_port', 'fingerprint', 'nickname', 'orport_v6')
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, Directory) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -231,7 +232,7 @@ class Authority(Directory):
:var str v3ident: identity key fingerprint used to sign votes and consensus
"""
- def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, orport_v6 = None, v3ident = None):
+ def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, orport_v6: Optional[Tuple[str, int]] = None, v3ident: Optional[str] = None) -> None:
super(Authority, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6)
if v3ident and not tor_tools.is_valid_fingerprint(v3ident):
@@ -241,11 +242,11 @@ class Authority(Directory):
self.v3ident = v3ident
@staticmethod
- def from_cache():
+ def from_cache() -> Dict[str, 'stem.directory.Authority']:
return dict(DIRECTORY_AUTHORITIES)
@staticmethod
- def from_remote(timeout = 60):
+ def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Authority']:
try:
lines = str_tools._to_unicode(urllib.request.urlopen(GITWEB_AUTHORITY_URL, timeout = timeout).read()).splitlines()
@@ -284,7 +285,7 @@ class Authority(Directory):
return results
@staticmethod
- def _pop_section(lines):
+ def _pop_section(lines: Sequence[str]) -> Sequence[str]:
"""
Provides the next authority entry.
"""
@@ -299,13 +300,13 @@ class Authority(Directory):
return section_lines
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'v3ident', parent = Directory, cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, Authority) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -348,13 +349,13 @@ class Fallback(Directory):
:var collections.OrderedDict header: metadata about the fallback directory file this originated from
"""
- def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, has_extrainfo = False, orport_v6 = None, header = None):
+ def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, has_extrainfo: bool = False, orport_v6: Optional[Tuple[str, int]] = None, header: Optional[Mapping[str, str]] = None) -> None:
super(Fallback, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6)
self.has_extrainfo = has_extrainfo
self.header = collections.OrderedDict(header) if header else collections.OrderedDict()
@staticmethod
- def from_cache(path = FALLBACK_CACHE_PATH):
+ def from_cache(path: str = FALLBACK_CACHE_PATH) -> Dict[str, 'stem.directory.Fallback']:
conf = stem.util.conf.Config()
conf.load(path)
headers = collections.OrderedDict([(k.split('.', 1)[1], conf.get(k)) for k in conf.keys() if k.startswith('header.')])
@@ -393,7 +394,7 @@ class Fallback(Directory):
return results
@staticmethod
- def from_remote(timeout = 60):
+ def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Fallback']:
try:
lines = str_tools._to_unicode(urllib.request.urlopen(GITWEB_FALLBACK_URL, timeout = timeout).read()).splitlines()
@@ -450,7 +451,7 @@ class Fallback(Directory):
return results
@staticmethod
- def _pop_section(lines):
+ def _pop_section(lines: Sequence[str]) -> Sequence[str]:
"""
Provides lines up through the next divider. This excludes lines with just a
comma since they're an artifact of these being C strings.
@@ -470,7 +471,7 @@ class Fallback(Directory):
return section_lines
@staticmethod
- def _write(fallbacks, tor_commit, stem_commit, headers, path = FALLBACK_CACHE_PATH):
+ def _write(fallbacks: Dict[str, 'stem.directory.Fallback'], tor_commit: str, stem_commit: str, headers: Mapping[str, str], path: str = FALLBACK_CACHE_PATH) -> None:
"""
Persists fallback directories to a location in a way that can be read by
from_cache().
@@ -503,17 +504,17 @@ class Fallback(Directory):
conf.save(path)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'has_extrainfo', 'header', parent = Directory, cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, Fallback) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
-def _fallback_directory_differences(previous_directories, new_directories):
+def _fallback_directory_differences(previous_directories: Sequence['stem.directory.Directory'], new_directories: Sequence['stem.directory.Directory']) -> str:
"""
Provides a description of how fallback directories differ.
"""
diff --git a/stem/exit_policy.py b/stem/exit_policy.py
index ddcd7dfd..076611d2 100644
--- a/stem/exit_policy.py
+++ b/stem/exit_policy.py
@@ -71,6 +71,8 @@ import stem.util.connection
import stem.util.enum
import stem.util.str_tools
+from typing import Any, Iterator, Optional, Sequence, Union
+
AddressType = stem.util.enum.Enum(('WILDCARD', 'Wildcard'), ('IPv4', 'IPv4'), ('IPv6', 'IPv6'))
# Addresses aliased by the 'private' policy. From the tor man page...
@@ -89,7 +91,7 @@ PRIVATE_ADDRESSES = (
)
-def _flag_private_rules(rules):
+def _flag_private_rules(rules: Sequence['ExitPolicyRule']) -> None:
"""
Determine if part of our policy was expanded from the 'private' keyword. This
doesn't differentiate if this actually came from the 'private' keyword or a
@@ -139,7 +141,7 @@ def _flag_private_rules(rules):
last_rule._is_private = True
-def _flag_default_rules(rules):
+def _flag_default_rules(rules: Sequence['ExitPolicyRule']) -> None:
"""
Determine if part of our policy ends with the suffix tor appends by default.
"""
@@ -162,7 +164,7 @@ class ExitPolicy(object):
entries that make up this policy
"""
- def __init__(self, *rules):
+ def __init__(self, *rules: Union[str, 'stem.exit_policy.ExitPolicyRule']) -> None:
# sanity check the types
for rule in rules:
@@ -196,7 +198,7 @@ class ExitPolicy(object):
self._is_allowed_default = True
@functools.lru_cache()
- def can_exit_to(self, address = None, port = None, strict = False):
+ def can_exit_to(self, address: Optional[str] = None, port: Optional[int] = None, strict: bool = False) -> bool:
"""
Checks if this policy allows exiting to a given destination or not. If the
address or port is omitted then this will check if we're allowed to exit to
@@ -220,7 +222,7 @@ class ExitPolicy(object):
return self._is_allowed_default
@functools.lru_cache()
- def is_exiting_allowed(self):
+ def is_exiting_allowed(self) -> bool:
"""
Provides **True** if the policy allows exiting whatsoever, **False**
otherwise.
@@ -242,7 +244,7 @@ class ExitPolicy(object):
return self._is_allowed_default
@functools.lru_cache()
- def summary(self):
+ def summary(self) -> str:
"""
Provides a short description of our policy chain, similar to a
microdescriptor. This excludes entries that don't cover all IP
@@ -320,7 +322,7 @@ class ExitPolicy(object):
return (label_prefix + ', '.join(display_ranges)).strip()
- def has_private(self):
+ def has_private(self) -> bool:
"""
Checks if we have any rules expanded from the 'private' keyword. Tor
appends these by default to the start of the policy and includes a dynamic
@@ -338,7 +340,7 @@ class ExitPolicy(object):
return False
- def strip_private(self):
+ def strip_private(self) -> 'ExitPolicy':
"""
Provides a copy of this policy without 'private' policy entries.
@@ -349,7 +351,7 @@ class ExitPolicy(object):
return ExitPolicy(*[rule for rule in self._get_rules() if not rule.is_private()])
- def has_default(self):
+ def has_default(self) -> bool:
"""
Checks if we have the default policy suffix.
@@ -364,7 +366,7 @@ class ExitPolicy(object):
return False
- def strip_default(self):
+ def strip_default(self) -> 'ExitPolicy':
"""
Provides a copy of this policy without the default policy suffix.
@@ -375,7 +377,7 @@ class ExitPolicy(object):
return ExitPolicy(*[rule for rule in self._get_rules() if not rule.is_default()])
- def _get_rules(self):
+ def _get_rules(self) -> Sequence['stem.exit_policy.ExitPolicyRule']:
# Local reference to our input_rules so this can be lock free. Otherwise
# another thread might unset our input_rules while processing them.
@@ -437,18 +439,18 @@ class ExitPolicy(object):
return self._rules
- def __len__(self):
+ def __len__(self) -> int:
return len(self._get_rules())
- def __iter__(self):
+ def __iter__(self) -> Iterator['stem.exit_policy.ExitPolicyRule']:
for rule in self._get_rules():
yield rule
@functools.lru_cache()
- def __str__(self):
+ def __str__(self) -> str:
return ', '.join([str(rule) for rule in self._get_rules()])
- def __hash__(self):
+ def __hash__(self) -> int:
if self._hash is None:
my_hash = 0
@@ -460,10 +462,10 @@ class ExitPolicy(object):
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, ExitPolicy) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -495,7 +497,7 @@ class MicroExitPolicy(ExitPolicy):
:param str policy: policy string that describes this policy
"""
- def __init__(self, policy):
+ def __init__(self, policy: str) -> None:
# Microdescriptor policies are of the form...
#
# MicrodescriptorPolicy ::= ("accept" / "reject") SP PortList NL
@@ -537,16 +539,16 @@ class MicroExitPolicy(ExitPolicy):
super(MicroExitPolicy, self).__init__(*rules)
self._is_allowed_default = not self.is_accept
- def __str__(self):
+ def __str__(self) -> str:
return self._policy
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(str(self))
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, MicroExitPolicy) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -580,7 +582,7 @@ class ExitPolicyRule(object):
:raises: **ValueError** if input isn't a valid tor exit policy rule
"""
- def __init__(self, rule):
+ def __init__(self, rule: str) -> None:
# policy ::= "accept[6]" exitpattern | "reject[6]" exitpattern
# exitpattern ::= addrspec ":" portspec
@@ -634,7 +636,7 @@ class ExitPolicyRule(object):
self._is_private = False
self._is_default_suffix = False
- def is_address_wildcard(self):
+ def is_address_wildcard(self) -> bool:
"""
**True** if we'll match against **any** address, **False** otherwise.
@@ -646,7 +648,7 @@ class ExitPolicyRule(object):
return self._address_type == _address_type_to_int(AddressType.WILDCARD)
- def is_port_wildcard(self):
+ def is_port_wildcard(self) -> bool:
"""
**True** if we'll match against any port, **False** otherwise.
@@ -655,7 +657,7 @@ class ExitPolicyRule(object):
return self.min_port in (0, 1) and self.max_port == 65535
- def is_match(self, address = None, port = None, strict = False):
+ def is_match(self, address: Optional[str] = None, port: Optional[int] = None, strict: bool = False) -> bool:
"""
**True** if we match against the given destination, **False** otherwise. If
the address or port is omitted then this will check if we're allowed to
@@ -726,7 +728,7 @@ class ExitPolicyRule(object):
else:
return True
- def get_address_type(self):
+ def get_address_type(self) -> AddressType:
"""
Provides the :data:`~stem.exit_policy.AddressType` for our policy.
@@ -735,7 +737,7 @@ class ExitPolicyRule(object):
return _int_to_address_type(self._address_type)
- def get_mask(self, cache = True):
+ def get_mask(self, cache: bool = True) -> Optional[str]:
"""
Provides the address represented by our mask. This is **None** if our
address type is a wildcard.
@@ -765,7 +767,7 @@ class ExitPolicyRule(object):
return self._mask
- def get_masked_bits(self):
+ def get_masked_bits(self) -> Optional[int]:
"""
Provides the number of bits our subnet mask represents. This is **None** if
our mask can't have a bit representation.
@@ -775,7 +777,7 @@ class ExitPolicyRule(object):
return self._masked_bits
- def is_private(self):
+ def is_private(self) -> bool:
"""
Checks if this rule was expanded from the 'private' policy keyword.
@@ -786,7 +788,7 @@ class ExitPolicyRule(object):
return self._is_private
- def is_default(self):
+ def is_default(self) -> bool:
"""
Checks if this rule belongs to the default exit policy suffix.
@@ -798,7 +800,7 @@ class ExitPolicyRule(object):
return self._is_default_suffix
@functools.lru_cache()
- def __str__(self):
+ def __str__(self) -> str:
"""
Provides the string representation of our policy. This does not
necessarily match the rule that we were constructed from (due to things
@@ -842,18 +844,18 @@ class ExitPolicyRule(object):
return label
@functools.lru_cache()
- def _get_mask_bin(self):
+ def _get_mask_bin(self) -> int:
# provides an integer representation of our mask
return int(stem.util.connection._address_to_binary(self.get_mask(False)), 2)
@functools.lru_cache()
- def _get_address_bin(self):
+ def _get_address_bin(self) -> int:
# provides an integer representation of our address
return stem.util.connection.address_to_int(self.address) & self._get_mask_bin()
- def _apply_addrspec(self, rule, addrspec, is_ipv6_only):
+ def _apply_addrspec(self, rule: str, addrspec: str, is_ipv6_only: bool) -> None:
# Parses the addrspec...
# addrspec ::= "*" | ip4spec | ip6spec
@@ -924,7 +926,7 @@ class ExitPolicyRule(object):
else:
raise ValueError("'%s' isn't a wildcard, IPv4, or IPv6 address: %s" % (addrspec, rule))
- def _apply_portspec(self, rule, portspec):
+ def _apply_portspec(self, rule: str, portspec: str) -> None:
# Parses the portspec...
# portspec ::= "*" | port | port "-" port
# port ::= an integer between 1 and 65535, inclusive.
@@ -955,24 +957,24 @@ class ExitPolicyRule(object):
else:
raise ValueError("Port value isn't a wildcard, integer, or range: %s" % rule)
- def __hash__(self):
+ def __hash__(self) -> int:
if self._hash is None:
self._hash = stem.util._hash_attr(self, 'is_accept', 'address', 'min_port', 'max_port') * 1024 + hash(self.get_mask(False))
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, ExitPolicyRule) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
-def _address_type_to_int(address_type):
+def _address_type_to_int(address_type: AddressType) -> int:
return AddressType.index_of(address_type)
-def _int_to_address_type(address_type_int):
+def _int_to_address_type(address_type_int: int) -> AddressType:
return list(AddressType)[address_type_int]
@@ -981,32 +983,32 @@ class MicroExitPolicyRule(ExitPolicyRule):
Lighter weight ExitPolicyRule derivative for microdescriptors.
"""
- def __init__(self, is_accept, min_port, max_port):
+ def __init__(self, is_accept: bool, min_port: int, max_port: int) -> None:
self.is_accept = is_accept
self.address = None # wildcard address
self.min_port = min_port
self.max_port = max_port
self._skip_rule = False
- def is_address_wildcard(self):
+ def is_address_wildcard(self) -> bool:
return True
- def get_address_type(self):
+ def get_address_type(self) -> AddressType:
return AddressType.WILDCARD
- def get_mask(self, cache = True):
+ def get_mask(self, cache: bool = True) -> Optional[str]:
return None
- def get_masked_bits(self):
+ def get_masked_bits(self) -> Optional[int]:
return None
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'is_accept', 'min_port', 'max_port', cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, MicroExitPolicyRule) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
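A short, hedged example of the ExitPolicy API these annotations describe (expected results shown in comments):

  from stem.exit_policy import ExitPolicy, MicroExitPolicy

  policy = ExitPolicy('accept *:80', 'accept *:443', 'reject *:*')

  print(policy.summary())                          # accept 80, 443
  print(policy.can_exit_to('75.119.206.243', 80))  # True

  micro = MicroExitPolicy('accept 80,443')
  print(micro.can_exit_to(port = 443))             # True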
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 07a5f573..2a3cff18 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -38,11 +38,11 @@ uses_settings = stem.util.conf.uses_settings('stem_interpreter', settings_path)
@uses_settings
-def msg(message, config, **attr):
+def msg(message: str, config: 'stem.util.conf.Config', **attr: str) -> str:
return config.get(message).format(**attr)
-def main():
+def main() -> None:
try:
import readline
except ImportError:
@@ -135,7 +135,7 @@ def main():
controller.msg(args.run_cmd)
try:
- raw_input()
+ input()
except (KeyboardInterrupt, stem.SocketClosed):
pass
else:
diff --git a/stem/interpreter/arguments.py b/stem/interpreter/arguments.py
index 00c8891d..8ac1c2c1 100644
--- a/stem/interpreter/arguments.py
+++ b/stem/interpreter/arguments.py
@@ -12,6 +12,8 @@ import os
import stem.interpreter
import stem.util.connection
+from typing import NamedTuple, Sequence
+
DEFAULT_ARGS = {
'control_address': '127.0.0.1',
'control_port': 'default',
@@ -29,7 +31,7 @@ OPT = 'i:s:h'
OPT_EXPANDED = ['interface=', 'socket=', 'tor=', 'run=', 'no-color', 'help']
-def parse(argv):
+def parse(argv: Sequence[str]) -> NamedTuple:
"""
Parses our arguments, providing a named tuple with their values.
@@ -90,7 +92,7 @@ def parse(argv):
return Args(**args)
-def get_help():
+def get_help() -> str:
"""
Provides our --help usage information.
diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py
index 9f5f2659..671085a7 100644
--- a/stem/interpreter/autocomplete.py
+++ b/stem/interpreter/autocomplete.py
@@ -8,10 +8,11 @@ Tab completion for our interpreter prompt.
import functools
from stem.interpreter import uses_settings
+from typing import Optional, Sequence
@uses_settings
-def _get_commands(controller, config):
+def _get_commands(controller: 'stem.control.Controller', config: 'stem.util.conf.Config') -> Sequence[str]:
"""
Provides commands recognized by tor.
"""
@@ -76,11 +77,11 @@ def _get_commands(controller, config):
class Autocompleter(object):
- def __init__(self, controller):
+ def __init__(self, controller: 'stem.control.Controller') -> None:
self._commands = _get_commands(controller)
@functools.lru_cache()
- def matches(self, text):
+ def matches(self, text: str) -> Sequence[str]:
"""
Provides autocompletion matches for the given text.
@@ -92,7 +93,7 @@ class Autocompleter(object):
lowercase_text = text.lower()
return [cmd for cmd in self._commands if cmd.lower().startswith(lowercase_text)]
- def complete(self, text, state):
+ def complete(self, text: str, state: int) -> Optional[str]:
"""
Provides case insensitive autocompletion options, acting as a functor for
readline's set_completer function.
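Autocompleter.complete() follows readline's completer protocol, so wiring it up looks roughly like this sketch (assumes an authenticated controller on port 9051):

  import readline

  from stem.control import Controller
  from stem.interpreter.autocomplete import Autocompleter

  controller = Controller.from_port(port = 9051)
  controller.authenticate()

  # tab completion for tor controller commands
  readline.set_completer(Autocompleter(controller).complete)
  readline.parse_and_bind('tab: complete')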
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 6e61fdda..1d610dac 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -21,11 +21,12 @@ import stem.util.tor_tools
from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg
from stem.util.term import format
+from typing import BinaryIO, Iterator, Sequence, Tuple
MAX_EVENTS = 100
-def _get_fingerprint(arg, controller):
+def _get_fingerprint(arg: str, controller: 'stem.control.Controller') -> str:
"""
Resolves user input into a relay fingerprint. This accepts...
@@ -90,7 +91,7 @@ def _get_fingerprint(arg, controller):
@contextlib.contextmanager
-def redirect(stdout, stderr):
+def redirect(stdout: BinaryIO, stderr: BinaryIO) -> Iterator[None]:
original = sys.stdout, sys.stderr
sys.stdout, sys.stderr = stdout, stderr
@@ -106,7 +107,7 @@ class ControlInterpreter(code.InteractiveConsole):
for special irc style subcommands.
"""
- def __init__(self, controller):
+ def __init__(self, controller: 'stem.control.Controller') -> None:
self._received_events = []
code.InteractiveConsole.__init__(self, {
@@ -129,7 +130,7 @@ class ControlInterpreter(code.InteractiveConsole):
handle_event_real = self._controller._handle_event
- def handle_event_wrapper(event_message):
+ def handle_event_wrapper(event_message: 'stem.response.events.Event') -> None:
handle_event_real(event_message)
self._received_events.insert(0, event_message)
@@ -138,7 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
self._controller._handle_event = handle_event_wrapper
- def get_events(self, *event_types):
+ def get_events(self, *event_types: 'stem.control.EventType') -> Sequence['stem.response.events.Event']:
events = list(self._received_events)
event_types = list(map(str.upper, event_types)) # make filtering case insensitive
@@ -147,7 +148,7 @@ class ControlInterpreter(code.InteractiveConsole):
return events
- def do_help(self, arg):
+ def do_help(self, arg: str) -> str:
"""
Performs the '/help' operation, giving usage information for the given
argument or a general summary if there wasn't one.
@@ -155,7 +156,7 @@ class ControlInterpreter(code.InteractiveConsole):
return stem.interpreter.help.response(self._controller, arg)
- def do_events(self, arg):
+ def do_events(self, arg: str) -> str:
"""
Performs the '/events' operation, dumping the events that we've received
belonging to the given types. If no types are specified then this provides
@@ -173,7 +174,7 @@ class ControlInterpreter(code.InteractiveConsole):
return '\n'.join([format(str(e), *STANDARD_OUTPUT) for e in self.get_events(*event_types)])
- def do_info(self, arg):
+ def do_info(self, arg: str) -> str:
"""
Performs the '/info' operation, looking up a relay by fingerprint, IP
address, or nickname and printing its descriptor and consensus entries in a
@@ -271,7 +272,7 @@ class ControlInterpreter(code.InteractiveConsole):
return '\n'.join(lines)
- def do_python(self, arg):
+ def do_python(self, arg: str) -> str:
"""
Performs the '/python' operation, toggling if we accept python commands or
not.
@@ -295,12 +296,11 @@ class ControlInterpreter(code.InteractiveConsole):
return format(response, *STANDARD_OUTPUT)
@uses_settings
- def run_command(self, command, config, print_response = False):
+ def run_command(self, command: str, config: 'stem.util.conf.Config', print_response: bool = False) -> Sequence[Tuple[str, int]]:
"""
Runs the given command. Requests starting with a '/' are special commands
to the interpreter, and anything else is sent to the control port.
- :param stem.control.Controller controller: tor control connection
:param str command: command to be processed
:param bool print_response: prints the response to stdout if true
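For context, a rough sketch of driving ControlInterpreter.run_command() outside the interpreter prompt (the config argument is injected by the uses_settings decorator; port 9051 is an assumption):

  from stem.control import Controller
  from stem.interpreter.commands import ControlInterpreter

  with Controller.from_port(port = 9051) as controller:
    controller.authenticate()

    interpreter = ControlInterpreter(controller)
    interpreter.run_command('GETINFO version', print_response = True)
    interpreter.run_command('/help', print_response = True)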
diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py
index 1f242a8e..81c76d34 100644
--- a/stem/interpreter/help.py
+++ b/stem/interpreter/help.py
@@ -18,7 +18,7 @@ from stem.interpreter import (
from stem.util.term import format
-def response(controller, arg):
+def response(controller: 'stem.control.Controller', arg: str) -> str:
"""
Provides our /help response.
@@ -33,7 +33,7 @@ def response(controller, arg):
return _response(controller, _normalize(arg))
-def _normalize(arg):
+def _normalize(arg: str) -> str:
arg = arg.upper()
# If there's multiple arguments then just take the first. This is
@@ -52,7 +52,7 @@ def _normalize(arg):
@functools.lru_cache()
@uses_settings
-def _response(controller, arg, config):
+def _response(controller: 'stem.control.Controller', arg: str, config: 'stem.util.conf.Config') -> str:
if not arg:
return _general_help()
@@ -126,7 +126,7 @@ def _response(controller, arg, config):
return output.rstrip()
-def _general_help():
+def _general_help() -> str:
lines = []
for line in msg('help.general').splitlines():
diff --git a/stem/manual.py b/stem/manual.py
index 367b6d7e..e28e0e6f 100644
--- a/stem/manual.py
+++ b/stem/manual.py
@@ -63,6 +63,8 @@ import stem.util.enum
import stem.util.log
import stem.util.system
+from typing import Any, Dict, Mapping, Optional, Sequence, TextIO, Tuple, Union
+
Category = stem.util.enum.Enum('GENERAL', 'CLIENT', 'RELAY', 'DIRECTORY', 'AUTHORITY', 'HIDDEN_SERVICE', 'DENIAL_OF_SERVICE', 'TESTING', 'UNKNOWN')
GITWEB_MANUAL_URL = 'https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt'
CACHE_PATH = os.path.join(os.path.dirname(__file__), 'cached_manual.sqlite')
@@ -103,13 +105,13 @@ class SchemaMismatch(IOError):
:var tuple supported_schemas: schemas library supports
"""
- def __init__(self, message, database_schema, library_schema):
+ def __init__(self, message: str, database_schema: int, supported_schemas: Tuple[int]) -> None:
super(SchemaMismatch, self).__init__(message)
self.database_schema = database_schema
- self.library_schema = library_schema
+ self.supported_schemas = supported_schemas
-def query(query, *param):
+def query(query: str, *param: str) -> 'sqlite3.Cursor':
"""
Performs the given query on our sqlite manual cache. This database should
be treated as being read-only. File permissions generally enforce this, and
@@ -162,25 +164,25 @@ class ConfigOption(object):
:var str description: longer manual description with details
"""
- def __init__(self, name, category = Category.UNKNOWN, usage = '', summary = '', description = ''):
+ def __init__(self, name: str, category: 'stem.manual.Category' = Category.UNKNOWN, usage: str = '', summary: str = '', description: str = '') -> None:
self.name = name
self.category = category
self.usage = usage
self.summary = summary
self.description = description
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'name', 'category', 'usage', 'summary', 'description', cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, ConfigOption) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@functools.lru_cache()
-def _config(lowercase = True):
+def _config(lowercase: bool = True) -> Dict[str, Union[Sequence[str], str]]:
"""
Provides a dictionary for our settings.cfg. This has a couple categories...
@@ -204,7 +206,7 @@ def _config(lowercase = True):
return {}
-def _manual_differences(previous_manual, new_manual):
+def _manual_differences(previous_manual: 'stem.manual.Manual', new_manual: 'stem.manual.Manual') -> str:
"""
Provides a description of how two manuals differ.
"""
@@ -249,7 +251,7 @@ def _manual_differences(previous_manual, new_manual):
return '\n'.join(lines)
-def is_important(option):
+def is_important(option: str) -> bool:
"""
Indicates if a configuration option is of particularly common importance or not.
@@ -262,7 +264,7 @@ def is_important(option):
return option.lower() in _config()['manual.important']
-def download_man_page(path = None, file_handle = None, url = GITWEB_MANUAL_URL, timeout = 20):
+def download_man_page(path: Optional[str] = None, file_handle: Optional[TextIO] = None, url: str = GITWEB_MANUAL_URL, timeout: int = 20) -> None:
"""
Downloads tor's latest man page from `gitweb.torproject.org
<https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt>`_. This method is
@@ -347,7 +349,7 @@ class Manual(object):
:var str stem_commit: stem commit to cache this manual information
"""
- def __init__(self, name, synopsis, description, commandline_options, signals, files, config_options):
+ def __init__(self, name: str, synopsis: str, description: str, commandline_options: Mapping[str, str], signals: Mapping[str, str], files: Mapping[str, str], config_options: Mapping[str, str]) -> None:
self.name = name
self.synopsis = synopsis
self.description = description
@@ -360,7 +362,7 @@ class Manual(object):
self.schema = None
@staticmethod
- def from_cache(path = None):
+ def from_cache(path: Optional[str] = None) -> 'stem.manual.Manual':
"""
Provides manual information cached with Stem. Unlike
:func:`~stem.manual.Manual.from_man` and
@@ -424,7 +426,7 @@ class Manual(object):
return manual
@staticmethod
- def from_man(man_path = 'tor'):
+ def from_man(man_path: str = 'tor') -> 'stem.manual.Manual':
"""
Reads and parses a given man page.
@@ -467,7 +469,7 @@ class Manual(object):
)
@staticmethod
- def from_remote(timeout = 60):
+ def from_remote(timeout: int = 60) -> 'stem.manual.Manual':
"""
Reads and parses the latest tor man page `from gitweb.torproject.org
<https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt>`_. Note that
@@ -500,7 +502,7 @@ class Manual(object):
download_man_page(file_handle = tmp, timeout = timeout)
return Manual.from_man(tmp.name)
- def save(self, path):
+ def save(self, path: str) -> None:
"""
Persists the manual content to a given location.
@@ -549,17 +551,17 @@ class Manual(object):
os.rename(tmp_path, path)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'name', 'synopsis', 'description', 'commandline_options', 'signals', 'files', 'config_options', cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, Manual) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
-def _get_categories(content):
+def _get_categories(content: str) -> Dict[str, str]:
"""
The man page is headers followed by an indented section. First pass gets
the mapping of category titles to their lines.
@@ -605,7 +607,7 @@ def _get_categories(content):
return categories
-def _get_indented_descriptions(lines):
+def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, Sequence[str]]:
"""
Parses the commandline argument and signal sections. These are options
followed by an indented description. For example...
@@ -635,7 +637,7 @@ def _get_indented_descriptions(lines):
return dict([(arg, ' '.join(desc_lines)) for arg, desc_lines in options.items() if desc_lines])
-def _add_config_options(config_options, category, lines):
+def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption'], category: str, lines: Sequence[str]) -> None:
"""
Parses a section of tor configuration options. These have usage information,
followed by an indented description. For instance...
@@ -653,7 +655,7 @@ def _add_config_options(config_options, category, lines):
since that platform lacks getrlimit(). (Default: 1000)
"""
- def add_option(title, description):
+ def add_option(title: str, description: str) -> None:
if 'PER INSTANCE OPTIONS' in title:
return # skip, unfortunately amid the options
@@ -697,7 +699,7 @@ def _add_config_options(config_options, category, lines):
add_option(last_title, description)
-def _join_lines(lines):
+def _join_lines(lines: Sequence[str]) -> str:
"""
Simple join, except we want empty lines to still provide a newline.
"""
diff --git a/stem/process.py b/stem/process.py
index a1d805ec..bfab4967 100644
--- a/stem/process.py
+++ b/stem/process.py
@@ -29,11 +29,13 @@ import stem.util.str_tools
import stem.util.system
import stem.version
+from typing import Any, Callable, Mapping, Optional, Sequence, Union
+
NO_TORRC = '<no torrc>'
DEFAULT_INIT_TIMEOUT = 90
-def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_percent = 100, init_msg_handler = None, timeout = DEFAULT_INIT_TIMEOUT, take_ownership = False, close_output = True, stdin = None):
+def launch_tor(tor_cmd: str = 'tor', args: Optional[Sequence[str]] = None, torrc_path: Optional[str] = None, completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True, stdin: Optional[str] = None) -> subprocess.Popen:
"""
Initializes a tor process. This blocks until initialization completes or we
error out.
@@ -131,7 +133,7 @@ def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_perce
tor_process.stdin.close()
if timeout:
- def timeout_handler(signum, frame):
+ def timeout_handler(signum: int, frame: Any) -> None:
raise OSError('reached a %i second timeout without success' % timeout)
signal.signal(signal.SIGALRM, timeout_handler)
@@ -197,7 +199,7 @@ def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_perce
pass
-def launch_tor_with_config(config, tor_cmd = 'tor', completion_percent = 100, init_msg_handler = None, timeout = DEFAULT_INIT_TIMEOUT, take_ownership = False, close_output = True):
+def launch_tor_with_config(config: Mapping[str, Union[str, Sequence[str]]], tor_cmd: str = 'tor', completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True) -> subprocess.Popen:
"""
Initializes a tor process, like :func:`~stem.process.launch_tor`, but with a
customized configuration. This writes a temporary torrc to disk, launches
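Roughly how launch_tor_with_config() is used (the port here is a placeholder):

  import stem.process

  tor_process = stem.process.launch_tor_with_config(
    config = {
      'SocksPort': '9999',
      'ExitPolicy': 'reject *:*',
    },
    init_msg_handler = print,  # echo bootstrap messages as they arrive
  )

  # ... use the tor instance ...

  tor_process.kill()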
diff --git a/stem/response/__init__.py b/stem/response/__init__.py
index 2fbb9c48..4b1f9533 100644
--- a/stem/response/__init__.py
+++ b/stem/response/__init__.py
@@ -38,6 +38,8 @@ import stem.socket
import stem.util
import stem.util.str_tools
+from typing import Any, Iterator, Optional, Sequence, Tuple, Union
+
__all__ = [
'add_onion',
'events',
@@ -54,7 +56,7 @@ __all__ = [
KEY_ARG = re.compile('^(\\S+)=')
-def convert(response_type, message, **kwargs):
+def convert(response_type: str, message: 'stem.response.ControlMessage', **kwargs: Any) -> None:
"""
Converts a :class:`~stem.response.ControlMessage` into a particular kind of
tor response. This does an in-place conversion of the message from being a
@@ -140,7 +142,7 @@ class ControlMessage(object):
"""
@staticmethod
- def from_str(content, msg_type = None, normalize = False, **kwargs):
+ def from_str(content: str, msg_type: Optional[str] = None, normalize: bool = False, **kwargs: Any) -> 'stem.response.ControlMessage':
"""
Provides a ControlMessage for the given content.
@@ -171,7 +173,7 @@ class ControlMessage(object):
return msg
- def __init__(self, parsed_content, raw_content, arrived_at = None):
+ def __init__(self, parsed_content: Sequence[Tuple[str, str, bytes]], raw_content: bytes, arrived_at: Optional[int] = None) -> None:
if not parsed_content:
raise ValueError("ControlMessages can't be empty")
@@ -182,7 +184,7 @@ class ControlMessage(object):
self._str = None
self._hash = stem.util._hash_attr(self, '_raw_content')
- def is_ok(self):
+ def is_ok(self) -> bool:
"""
Checks if any of our lines have a 250 response.
@@ -195,7 +197,7 @@ class ControlMessage(object):
return False
- def content(self, get_bytes = False):
+ def content(self, get_bytes: bool = False) -> Sequence[Tuple[str, str, bytes]]:
"""
Provides the parsed message content. These are entries of the form...
@@ -234,7 +236,7 @@ class ControlMessage(object):
else:
return list(self._parsed_content)
- def raw_content(self, get_bytes = False):
+ def raw_content(self, get_bytes: bool = False) -> Union[str, bytes]:
"""
Provides the unparsed content read from the control socket.
@@ -251,7 +253,7 @@ class ControlMessage(object):
else:
return self._raw_content
- def __str__(self):
+ def __str__(self) -> str:
"""
Content of the message, stripped of status code and divider protocol
formatting.
@@ -262,7 +264,7 @@ class ControlMessage(object):
return self._str
- def __iter__(self):
+ def __iter__(self) -> Iterator['stem.response.ControlLine']:
"""
Provides :class:`~stem.response.ControlLine` instances for the content of
the message. This is stripped of status codes and dividers, for instance...
@@ -290,14 +292,14 @@ class ControlMessage(object):
yield ControlLine(content)
- def __len__(self):
+ def __len__(self) -> int:
"""
:returns: number of ControlLines
"""
return len(self._parsed_content)
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> 'stem.response.ControlLine':
"""
:returns: :class:`~stem.response.ControlLine` at the index
"""
@@ -307,13 +309,13 @@ class ControlMessage(object):
return ControlLine(content)
- def __hash__(self):
+ def __hash__(self) -> int:
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return hash(self) == hash(other) if isinstance(other, ControlMessage) else False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
@@ -327,14 +329,14 @@ class ControlLine(str):
immutable). All methods are thread safe.
"""
- def __new__(self, value):
+ def __new__(self, value: str) -> 'stem.response.ControlLine':
return str.__new__(self, value)
- def __init__(self, value):
+ def __init__(self, value: str) -> None:
self._remainder = value
self._remainder_lock = threading.RLock()
- def remainder(self):
+ def remainder(self) -> str:
"""
Provides our unparsed content. This is an empty string after we've popped
all entries.
@@ -344,7 +346,7 @@ class ControlLine(str):
return self._remainder
- def is_empty(self):
+ def is_empty(self) -> bool:
"""
Checks if we have further content to pop or not.
@@ -353,7 +355,7 @@ class ControlLine(str):
return self._remainder == ''
- def is_next_quoted(self, escaped = False):
+ def is_next_quoted(self, escaped: bool = False) -> bool:
"""
Checks if our next entry is a quoted value or not.
@@ -365,7 +367,7 @@ class ControlLine(str):
start_quote, end_quote = _get_quote_indices(self._remainder, escaped)
return start_quote == 0 and end_quote != -1
- def is_next_mapping(self, key = None, quoted = False, escaped = False):
+ def is_next_mapping(self, key: Optional[str] = None, quoted: bool = False, escaped: bool = False) -> bool:
"""
Checks if our next entry is a KEY=VALUE mapping or not.
@@ -393,7 +395,7 @@ class ControlLine(str):
else:
return False # doesn't start with a key
- def peek_key(self):
+ def peek_key(self) -> str:
"""
Provides the key of the next entry, providing **None** if it isn't a
key/value mapping.
@@ -409,7 +411,7 @@ class ControlLine(str):
else:
return None
- def pop(self, quoted = False, escaped = False):
+ def pop(self, quoted: bool = False, escaped: bool = False) -> str:
"""
Parses the next space separated entry, removing it and the space from our
remaining content. Examples...
@@ -443,7 +445,7 @@ class ControlLine(str):
self._remainder = remainder
return next_entry
- def pop_mapping(self, quoted = False, escaped = False, get_bytes = False):
+ def pop_mapping(self, quoted: bool = False, escaped: bool = False, get_bytes: bool = False) -> Tuple[str, str]:
"""
Parses the next space separated entry as a KEY=VALUE mapping, removing it
and the space from our remaining content.
@@ -480,13 +482,14 @@ class ControlLine(str):
return (key, next_entry)
-def _parse_entry(line, quoted, escaped, get_bytes):
+def _parse_entry(line: str, quoted: bool, escaped: bool, get_bytes: bool) -> Tuple[Union[str, bytes], str]:
"""
Parses the next entry from the given space separated content.
:param str line: content to be parsed
:param bool quoted: parses the next entry as a quoted value, removing the quotes
:param bool escaped: unescapes the string
+ :param bool get_bytes: provides **bytes** for the entry rather than a **str**
:returns: **tuple** of the form (entry, remainder)
@@ -540,7 +543,7 @@ def _parse_entry(line, quoted, escaped, get_bytes):
return (next_entry, remainder.lstrip())
-def _get_quote_indices(line, escaped):
+def _get_quote_indices(line: str, escaped: bool) -> Tuple[int, int]:
"""
Provides the indices of the next two quotes in the given content.
@@ -576,7 +579,7 @@ class SingleLineResponse(ControlMessage):
:var str message: content of the line
"""
- def is_ok(self, strict = False):
+ def is_ok(self, strict: bool = False) -> bool:
"""
Checks if the response code is "250". If strict is **True** then this
checks if the response is "250 OK"
@@ -593,7 +596,7 @@ class SingleLineResponse(ControlMessage):
return self.content()[0][0] == '250'
- def _parse_message(self):
+ def _parse_message(self) -> None:
content = self.content()
if len(content) > 1:
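A brief sketch of the ControlMessage/ControlLine parsing these annotations cover (expected results in comments):

  import stem.response

  message = stem.response.ControlMessage.from_str('250 OK\r\n')
  print(message.is_ok())  # True

  line = stem.response.ControlLine('SOCKSPORT=9050 SOCKSPOLICY="accept *"')
  print(line.pop_mapping())               # ('SOCKSPORT', '9050')
  print(line.pop_mapping(quoted = True))  # ('SOCKSPOLICY', 'accept *')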
diff --git a/stem/response/add_onion.py b/stem/response/add_onion.py
index 64d58282..3f52f9f2 100644
--- a/stem/response/add_onion.py
+++ b/stem/response/add_onion.py
@@ -15,7 +15,7 @@ class AddOnionResponse(stem.response.ControlMessage):
:var dict client_auth: newly generated client credentials the service accepts
"""
- def _parse_message(self):
+ def _parse_message(self) -> None:
# Example:
# 250-ServiceID=gfzprpioee3hoppz
# 250-PrivateKey=RSA1024:MIICXgIBAAKBgQDZvYVxv...
diff --git a/stem/response/authchallenge.py b/stem/response/authchallenge.py
index d9cc5491..80a1c0f5 100644
--- a/stem/response/authchallenge.py
+++ b/stem/response/authchallenge.py
@@ -17,7 +17,7 @@ class AuthChallengeResponse(stem.response.ControlMessage):
:var str server_nonce: server nonce provided by tor
"""
- def _parse_message(self):
+ def _parse_message(self) -> None:
# Example:
# 250 AUTHCHALLENGE SERVERHASH=680A73C9836C4F557314EA1C4EDE54C285DB9DC89C83627401AEF9D7D27A95D5 SERVERNONCE=F8EA4B1F2C8B40EF1AF68860171605B910E3BBCABADF6FC3DB1FA064F4690E85
diff --git a/stem/response/events.py b/stem/response/events.py
index fdd17a25..0e112373 100644
--- a/stem/response/events.py
+++ b/stem/response/events.py
@@ -12,6 +12,7 @@ import stem.util
import stem.version
from stem.util import connection, log, str_tools, tor_tools
+from typing import Any, Dict, Sequence
# Matches keyword=value arguments. This can't be a simple "(.*)=(.*)" pattern
# because some positional arguments, like circuit paths, can have an equal
@@ -40,7 +41,7 @@ class Event(stem.response.ControlMessage):
_SKIP_PARSING = False # skip parsing contents into our positional_args and keyword_args
_VERSION_ADDED = stem.version.Version('0.1.1.1-alpha') # minimum version with control-spec V1 event support
- def _parse_message(self):
+ def _parse_message(self) -> None:
if not str(self).strip():
raise stem.ProtocolError('Received a blank tor event. Events must at the very least have a type.')
@@ -58,10 +59,10 @@ class Event(stem.response.ControlMessage):
self._parse()
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'arrived_at', parent = stem.response.ControlMessage, cache = True)
- def _parse_standard_attr(self):
+ def _parse_standard_attr(self) -> None:
"""
Most events are of the form...
650 *( positional_args ) *( key "=" value )
@@ -122,7 +123,7 @@ class Event(stem.response.ControlMessage):
for controller_attr_name, attr_name in self._KEYWORD_ARGS.items():
setattr(self, attr_name, self.keyword_args.get(controller_attr_name))
- def _iso_timestamp(self, timestamp):
+ def _iso_timestamp(self, timestamp: str) -> 'datetime.datetime':
"""
Parses an iso timestamp (ISOTime2Frac in the control-spec).
@@ -142,10 +143,10 @@ class Event(stem.response.ControlMessage):
raise stem.ProtocolError('Unable to parse timestamp (%s): %s' % (exc, self))
# method overwritten by our subclasses for special handling that they do
- def _parse(self):
+ def _parse(self) -> None:
pass
- def _log_if_unrecognized(self, attr, attr_enum):
+ def _log_if_unrecognized(self, attr: str, attr_enum: 'stem.util.enum.Enum') -> None:
"""
Checks if an attribute exists in a given enumeration, logging a message if
it isn't. Attributes can either be for a string or collection of strings
@@ -196,7 +197,7 @@ class AddrMapEvent(Event):
}
_OPTIONALLY_QUOTED = ('expiry')
- def _parse(self):
+ def _parse(self) -> None:
if self.destination == '<error>':
self.destination = None
@@ -234,7 +235,7 @@ class BandwidthEvent(Event):
_POSITIONAL_ARGS = ('read', 'written')
- def _parse(self):
+ def _parse(self) -> None:
if not self.read:
raise stem.ProtocolError('BW event is missing its read value')
elif not self.written:
@@ -277,7 +278,7 @@ class BuildTimeoutSetEvent(Event):
}
_VERSION_ADDED = stem.version.Version('0.2.2.7-alpha')
- def _parse(self):
+ def _parse(self) -> None:
# convert our integer and float parameters
for param in ('total_times', 'timeout', 'xm', 'close_timeout'):
@@ -346,7 +347,7 @@ class CircuitEvent(Event):
'SOCKS_PASSWORD': 'socks_password',
}
- def _parse(self):
+ def _parse(self) -> None:
self.path = tuple(stem.control._parse_circ_path(self.path))
self.created = self._iso_timestamp(self.created)
@@ -363,7 +364,7 @@ class CircuitEvent(Event):
self._log_if_unrecognized('reason', stem.CircClosureReason)
self._log_if_unrecognized('remote_reason', stem.CircClosureReason)
- def _compare(self, other, method):
+ def _compare(self, other: Any, method: Any) -> bool:
# sorting circuit events by their identifier
if not isinstance(other, CircuitEvent):
@@ -374,10 +375,10 @@ class CircuitEvent(Event):
return method(my_id, their_id) if my_id != their_id else method(hash(self), hash(other))
- def __gt__(self, other):
+ def __gt__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s > o)
- def __ge__(self, other):
+ def __ge__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s >= o)
@@ -414,7 +415,7 @@ class CircMinorEvent(Event):
}
_VERSION_ADDED = stem.version.Version('0.2.3.11-alpha')
- def _parse(self):
+ def _parse(self) -> None:
self.path = tuple(stem.control._parse_circ_path(self.path))
self.created = self._iso_timestamp(self.created)
@@ -450,7 +451,7 @@ class ClientsSeenEvent(Event):
}
_VERSION_ADDED = stem.version.Version('0.2.1.10-alpha')
- def _parse(self):
+ def _parse(self) -> None:
if self.start_time is not None:
self.start_time = stem.util.str_tools._parse_timestamp(self.start_time)
@@ -509,7 +510,7 @@ class ConfChangedEvent(Event):
_SKIP_PARSING = True
_VERSION_ADDED = stem.version.Version('0.2.3.3-alpha')
- def _parse(self):
+ def _parse(self) -> None:
self.changed = {}
self.unset = []
@@ -563,7 +564,7 @@ class GuardEvent(Event):
_VERSION_ADDED = stem.version.Version('0.1.2.5-alpha')
_POSITIONAL_ARGS = ('guard_type', 'endpoint', 'status')
- def _parse(self):
+ def _parse(self) -> None:
self.endpoint_fingerprint = None
self.endpoint_nickname = None
@@ -610,7 +611,7 @@ class HSDescEvent(Event):
_POSITIONAL_ARGS = ('action', 'address', 'authentication', 'directory', 'descriptor_id')
_KEYWORD_ARGS = {'REASON': 'reason', 'REPLICA': 'replica', 'HSDIR_INDEX': 'index'}
- def _parse(self):
+ def _parse(self) -> None:
self.directory_fingerprint = None
self.directory_nickname = None
@@ -650,7 +651,7 @@ class HSDescContentEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.7.1-alpha')
_POSITIONAL_ARGS = ('address', 'descriptor_id', 'directory')
- def _parse(self):
+ def _parse(self) -> None:
if self.address == 'UNKNOWN':
self.address = None
@@ -686,7 +687,7 @@ class LogEvent(Event):
_SKIP_PARSING = True
- def _parse(self):
+ def _parse(self) -> None:
self.runlevel = self.type
self._log_if_unrecognized('runlevel', stem.Runlevel)
@@ -709,7 +710,7 @@ class NetworkStatusEvent(Event):
_SKIP_PARSING = True
_VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
- def _parse(self):
+ def _parse(self) -> None:
content = str(self).lstrip('NS\n').rstrip('\nOK')
self.descriptors = list(stem.descriptor.router_status_entry._parse_file(
@@ -753,11 +754,11 @@ class NewConsensusEvent(Event):
_SKIP_PARSING = True
_VERSION_ADDED = stem.version.Version('0.2.1.13-alpha')
- def _parse(self):
+ def _parse(self) -> None:
self.consensus_content = str(self).lstrip('NEWCONSENSUS\n').rstrip('\nOK')
self._parsed = None
- def entries(self):
+ def entries(self) -> Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']:
"""
Relay router status entries residing within this consensus.
@@ -791,7 +792,7 @@ class NewDescEvent(Event):
new descriptors
"""
- def _parse(self):
+ def _parse(self) -> None:
self.relays = tuple([stem.control._parse_circ_entry(entry) for entry in str(self).split()[1:]])
@@ -832,7 +833,7 @@ class ORConnEvent(Event):
'ID': 'id',
}
- def _parse(self):
+ def _parse(self) -> None:
self.endpoint_fingerprint = None
self.endpoint_nickname = None
self.endpoint_address = None
@@ -886,7 +887,7 @@ class SignalEvent(Event):
_POSITIONAL_ARGS = ('signal',)
_VERSION_ADDED = stem.version.Version('0.2.3.1-alpha')
- def _parse(self):
+ def _parse(self) -> None:
# log if we received an unrecognized signal
expected_signals = (
stem.Signal.RELOAD,
@@ -918,7 +919,7 @@ class StatusEvent(Event):
_POSITIONAL_ARGS = ('runlevel', 'action')
_VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
- def _parse(self):
+ def _parse(self) -> None:
if self.type == 'STATUS_GENERAL':
self.status_type = stem.StatusType.GENERAL
elif self.type == 'STATUS_CLIENT':
@@ -970,7 +971,7 @@ class StreamEvent(Event):
'PURPOSE': 'purpose',
}
- def _parse(self):
+ def _parse(self) -> None:
if self.target is None:
raise stem.ProtocolError("STREAM event didn't have a target: %s" % self)
else:
@@ -1029,7 +1030,7 @@ class StreamBwEvent(Event):
_POSITIONAL_ARGS = ('id', 'written', 'read', 'time')
_VERSION_ADDED = stem.version.Version('0.1.2.8-beta')
- def _parse(self):
+ def _parse(self) -> None:
if not tor_tools.is_valid_stream_id(self.id):
raise stem.ProtocolError("Stream IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self))
elif not self.written:
@@ -1062,7 +1063,7 @@ class TransportLaunchedEvent(Event):
_POSITIONAL_ARGS = ('type', 'name', 'address', 'port')
_VERSION_ADDED = stem.version.Version('0.2.5.0-alpha')
- def _parse(self):
+ def _parse(self) -> None:
if self.type not in ('server', 'client'):
raise stem.ProtocolError("Transport type should either be 'server' or 'client': %s" % self)
@@ -1104,7 +1105,7 @@ class ConnectionBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self):
+ def _parse(self) -> None:
if not self.id:
raise stem.ProtocolError('CONN_BW event is missing its id')
elif not self.conn_type:
@@ -1163,7 +1164,7 @@ class CircuitBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self):
+ def _parse(self) -> None:
if not self.id:
raise stem.ProtocolError('CIRC_BW event is missing its id')
elif not self.read:
@@ -1233,7 +1234,7 @@ class CellStatsEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self):
+ def _parse(self) -> None:
if self.id and not tor_tools.is_valid_circuit_id(self.id):
raise stem.ProtocolError("Circuit IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self))
elif self.inbound_queue and not tor_tools.is_valid_circuit_id(self.inbound_queue):
@@ -1279,7 +1280,7 @@ class TokenBucketEmptyEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self):
+ def _parse(self) -> None:
if self.id and not tor_tools.is_valid_connection_id(self.id):
raise stem.ProtocolError("Connection IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self))
elif not self.read.isdigit():
@@ -1296,7 +1297,7 @@ class TokenBucketEmptyEvent(Event):
self._log_if_unrecognized('bucket', stem.TokenBucket)
-def _parse_cell_type_mapping(mapping):
+def _parse_cell_type_mapping(mapping: str) -> Dict[str, int]:
"""
Parses a mapping of the form...
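These event classes are normally consumed through Controller.add_event_listener(); a minimal sketch, assuming a control port on 9051:

  import time

  from stem.control import Controller, EventType

  def print_bw(event):
    # BandwidthEvent exposes bytes read and written over the last second
    print('read %i bytes, wrote %i bytes' % (event.read, event.written))

  with Controller.from_port(port = 9051) as controller:
    controller.authenticate()
    controller.add_event_listener(print_bw, EventType.BW)
    time.sleep(10)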
diff --git a/stem/response/getconf.py b/stem/response/getconf.py
index 6de49b1f..7ba972ae 100644
--- a/stem/response/getconf.py
+++ b/stem/response/getconf.py
@@ -16,7 +16,7 @@ class GetConfResponse(stem.response.ControlMessage):
values (**list** of **str**)
"""
- def _parse_message(self):
+ def _parse_message(self) -> None:
# Example:
# 250-CookieAuthentication=0
# 250-ControlPort=9100
diff --git a/stem/response/getinfo.py b/stem/response/getinfo.py
index 27442ffd..7aebd70a 100644
--- a/stem/response/getinfo.py
+++ b/stem/response/getinfo.py
@@ -4,6 +4,8 @@
import stem.response
import stem.socket
+from typing import Sequence
+
class GetInfoResponse(stem.response.ControlMessage):
"""
@@ -12,7 +14,7 @@ class GetInfoResponse(stem.response.ControlMessage):
:var dict entries: mapping between the queried options and their bytes values
"""
- def _parse_message(self):
+ def _parse_message(self) -> None:
# Example:
# 250-version=0.2.3.11-alpha-dev (git-ef0bc7f8f26a917c)
# 250+config-text=
@@ -66,7 +68,7 @@ class GetInfoResponse(stem.response.ControlMessage):
self.entries[key] = value
- def _assert_matches(self, params):
+ def _assert_matches(self, params: Sequence[str]) -> None:
"""
Checks if we match a given set of parameters, raising a ProtocolError if not.
diff --git a/stem/response/mapaddress.py b/stem/response/mapaddress.py
index 73ed84f1..92ce16d2 100644
--- a/stem/response/mapaddress.py
+++ b/stem/response/mapaddress.py
@@ -17,7 +17,7 @@ class MapAddressResponse(stem.response.ControlMessage):
* :class:`stem.InvalidRequest` if the addresses provided were invalid
"""
- def _parse_message(self):
+ def _parse_message(self) -> None:
# Example:
# 250-127.192.10.10=torproject.org
# 250 1.2.3.4=tor.freehaven.net
diff --git a/stem/response/protocolinfo.py b/stem/response/protocolinfo.py
index 459fef5b..330b165e 100644
--- a/stem/response/protocolinfo.py
+++ b/stem/response/protocolinfo.py
@@ -8,7 +8,6 @@ import stem.socket
import stem.version
import stem.util.str_tools
-from stem.connection import AuthMethod
from stem.util import log
@@ -26,13 +25,15 @@ class ProtocolInfoResponse(stem.response.ControlMessage):
:var str cookie_path: path of tor's authentication cookie
"""
- def _parse_message(self):
+ def _parse_message(self) -> None:
# Example:
# 250-PROTOCOLINFO 1
# 250-AUTH METHODS=COOKIE COOKIEFILE="/home/atagar/.tor/control_auth_cookie"
# 250-VERSION Tor="0.2.1.30"
# 250 OK
+ from stem.connection import AuthMethod
+
self.protocol_version = None
self.tor_version = None
self.auth_methods = ()
diff --git a/stem/socket.py b/stem/socket.py
index db110973..179ae16e 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -62,8 +62,7 @@ Tor...
|- is_localhost - returns if the socket is for the local system or not
|- connection_time - timestamp when socket last connected or disconnected
|- connect - connects a new socket
- |- close - shuts down the socket
- +- __enter__ / __exit__ - manages socket connection
+ +- close - shuts down the socket
send_message - Writes a message to a control socket.
recv_message - Reads a ControlMessage from a control socket.
@@ -80,6 +79,8 @@ import stem.response
import stem.util.str_tools
from stem.util import log
+from types import TracebackType
+from typing import BinaryIO, Callable, Optional, Type
MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]')
ERROR_MSG = 'Error while receiving a control message (%s): %s'
@@ -94,7 +95,7 @@ class BaseSocket(object):
Thread safe socket, providing common socket functionality.
"""
- def __init__(self):
+ def __init__(self) -> None:
self._socket, self._socket_file = None, None
self._is_alive = False
self._connection_time = 0.0 # time when we last connected or disconnected
@@ -106,7 +107,7 @@ class BaseSocket(object):
self._send_lock = threading.RLock()
self._recv_lock = threading.RLock()
- def is_alive(self):
+ def is_alive(self) -> bool:
"""
Checks if the socket is known to be closed. We won't be aware if it is
until we either use it or have explicitly shut it down.
@@ -125,7 +126,7 @@ class BaseSocket(object):
return self._is_alive
- def is_localhost(self):
+ def is_localhost(self) -> bool:
"""
Returns if the connection is for the local system or not.
@@ -135,7 +136,7 @@ class BaseSocket(object):
return False
- def connection_time(self):
+ def connection_time(self) -> float:
"""
Provides the unix timestamp for when our socket was either connected or
disconnected. That is to say, the time we connected if we're currently
@@ -149,7 +150,7 @@ class BaseSocket(object):
return self._connection_time
- def connect(self):
+ def connect(self) -> None:
"""
Connects to a new socket, closing our previous one if we're already
attached.
@@ -181,7 +182,7 @@ class BaseSocket(object):
except stem.SocketError:
self._connect() # single retry
- def close(self):
+ def close(self) -> None:
"""
Shuts down the socket. If it's already closed then this is a no-op.
"""
@@ -217,7 +218,7 @@ class BaseSocket(object):
if is_change:
self._close()
- def _send(self, message, handler):
+ def _send(self, message: str, handler: Callable[[socket.socket, BinaryIO, str], None]) -> None:
"""
Send message in a thread safe manner. Handler is expected to be of the form...
@@ -241,7 +242,7 @@ class BaseSocket(object):
raise
- def _recv(self, handler):
+ def _recv(self, handler: Callable[[socket.socket, BinaryIO], None]) -> bytes:
"""
Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -283,7 +284,7 @@ class BaseSocket(object):
raise
- def _get_send_lock(self):
+ def _get_send_lock(self) -> threading.RLock:
"""
The send lock is useful to classes that interact with us at a deep level
because it's used to lock :func:`stem.socket.ControlSocket.connect` /
@@ -296,27 +297,27 @@ class BaseSocket(object):
return self._send_lock
- def __enter__(self):
+ def __enter__(self) -> 'stem.socket.BaseSocket':
return self
- def __exit__(self, exit_type, value, traceback):
+ def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
self.close()
- def _connect(self):
+ def _connect(self) -> None:
"""
Connection callback that can be overwritten by subclasses and wrappers.
"""
pass
- def _close(self):
+ def _close(self) -> None:
"""
Disconnection callback that can be overwritten by subclasses and wrappers.
"""
pass
- def _make_socket(self):
+ def _make_socket(self) -> socket.socket:
"""
Constructs and connects new socket. This is implemented by subclasses.
@@ -342,7 +343,7 @@ class RelaySocket(BaseSocket):
:var int port: ORPort our socket connects to
"""
- def __init__(self, address = '127.0.0.1', port = 9050, connect = True):
+ def __init__(self, address: str = '127.0.0.1', port: int = 9050, connect: bool = True) -> None:
"""
RelaySocket constructor.
@@ -361,7 +362,7 @@ class RelaySocket(BaseSocket):
if connect:
self.connect()
- def send(self, message):
+ def send(self, message: str) -> None:
"""
Sends a message to the relay's ORPort.
@@ -374,7 +375,7 @@ class RelaySocket(BaseSocket):
self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg))
- def recv(self, timeout = None):
+ def recv(self, timeout: Optional[float] = None) -> bytes:
"""
Receives a message from the relay.
@@ -388,7 +389,7 @@ class RelaySocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket closes before we receive a complete message
"""
- def wrapped_recv(s, sf):
+ def wrapped_recv(s: socket.socket, sf: BinaryIO) -> bytes:
if timeout is None:
return s.recv()
else:
@@ -404,10 +405,10 @@ class RelaySocket(BaseSocket):
return self._recv(wrapped_recv)
- def is_localhost(self):
+ def is_localhost(self) -> bool:
return self.address == '127.0.0.1'
- def _make_socket(self):
+ def _make_socket(self) -> socket.socket:
try:
relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
relay_socket.connect((self.address, self.port))
@@ -426,10 +427,10 @@ class ControlSocket(BaseSocket):
which are expected to implement the **_make_socket()** method.
"""
- def __init__(self):
+ def __init__(self) -> None:
super(ControlSocket, self).__init__()
- def send(self, message):
+ def send(self, message: str) -> None:
"""
Formats and sends a message to the control socket. For more information see
the :func:`~stem.socket.send_message` function.
@@ -443,7 +444,7 @@ class ControlSocket(BaseSocket):
self._send(message, lambda s, sf, msg: send_message(sf, msg))
- def recv(self):
+ def recv(self) -> stem.response.ControlMessage:
"""
Receives a message from the control socket, blocking until we've received
one. For more information see the :func:`~stem.socket.recv_message` function.
@@ -467,7 +468,7 @@ class ControlPort(ControlSocket):
:var int port: ControlPort our socket connects to
"""
- def __init__(self, address = '127.0.0.1', port = 9051, connect = True):
+ def __init__(self, address: str = '127.0.0.1', port: int = 9051, connect: bool = True) -> None:
"""
ControlPort constructor.
@@ -486,10 +487,10 @@ class ControlPort(ControlSocket):
if connect:
self.connect()
- def is_localhost(self):
+ def is_localhost(self) -> bool:
return self.address == '127.0.0.1'
- def _make_socket(self):
+ def _make_socket(self) -> socket.socket:
try:
control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
control_socket.connect((self.address, self.port))
@@ -506,7 +507,7 @@ class ControlSocketFile(ControlSocket):
:var str path: filesystem path of the socket we connect to
"""
- def __init__(self, path = '/var/run/tor/control', connect = True):
+ def __init__(self, path: str = '/var/run/tor/control', connect: bool = True) -> None:
"""
ControlSocketFile constructor.
@@ -523,10 +524,10 @@ class ControlSocketFile(ControlSocket):
if connect:
self.connect()
- def is_localhost(self):
+ def is_localhost(self) -> bool:
return True
- def _make_socket(self):
+ def _make_socket(self) -> socket.socket:
try:
control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
control_socket.connect(self.path)
@@ -535,7 +536,7 @@ class ControlSocketFile(ControlSocket):
raise stem.SocketError(exc)
-def send_message(control_file, message, raw = False):
+def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> None:
"""
Sends a message to the control socket, adding the expected formatting for
single versus multi-line messages. Neither message type should contain an
@@ -578,7 +579,7 @@ def send_message(control_file, message, raw = False):
log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file, message):
+def _write_to_socket(socket_file: BinaryIO, message: str) -> None:
try:
socket_file.write(stem.util.str_tools._to_bytes(message))
socket_file.flush()
@@ -601,7 +602,7 @@ def _write_to_socket(socket_file, message):
raise stem.SocketClosed('file has been closed')
-def recv_message(control_file, arrived_at = None):
+def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage:
"""
Pulls from a control socket until we either have a complete message or
encounter a problem.
@@ -721,7 +722,7 @@ def recv_message(control_file, arrived_at = None):
raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
-def send_formatting(message):
+def send_formatting(message: str) -> str:
"""
Performs the formatting expected from sent control messages. For more
information see the :func:`~stem.socket.send_message` function.
@@ -750,7 +751,7 @@ def send_formatting(message):
return message + '\r\n'
-def _log_trace(response):
+def _log_trace(response: bytes) -> None:
if not log.is_tracing():
return
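To sanity check the new stem.socket annotations, here's an illustrative usage sketch (not part of the patch; it assumes a local tor with ControlPort 9051):

  import stem
  import stem.socket

  try:
    control_socket = stem.socket.ControlPort(port = 9051)
    control_socket.send('PROTOCOLINFO 1')
    reply = control_socket.recv()  # a stem.response.ControlMessage, per the recv() hint
    print(reply)
  except stem.SocketError as exc:
    print('Unable to reach port 9051: %s' % exc)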
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index e4e08174..050f6c91 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -7,6 +7,8 @@ Utility functions used by the stem library.
import datetime
+from typing import Any, Union
+
__all__ = [
'conf',
'connection',
@@ -43,7 +45,7 @@ __all__ = [
HASH_TYPES = True
-def _hash_value(val):
+def _hash_value(val: Any) -> int:
if not HASH_TYPES:
my_hash = 0
else:
@@ -64,7 +66,7 @@ def _hash_value(val):
return my_hash
-def datetime_to_unix(timestamp):
+def datetime_to_unix(timestamp: 'datetime.datetime') -> float:
"""
Converts a utc datetime object to a unix timestamp.
@@ -78,7 +80,7 @@ def datetime_to_unix(timestamp):
return (timestamp - datetime.datetime(1970, 1, 1)).total_seconds()
-def _pubkey_bytes(key):
+def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']) -> bytes:
"""
Normalizes X25519 and ED25519 keys into their public key bytes.
"""
@@ -107,7 +109,7 @@ def _pubkey_bytes(key):
raise ValueError('Key must be a string or cryptographic public/private key (was %s)' % type(key).__name__)
-def _hash_attr(obj, *attributes, **kwargs):
+def _hash_attr(obj: Any, *attributes: str, **kwargs: Any) -> int:
"""
Provide a hash value for the given set of attributes.
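A quick illustration of the datetime_to_unix() hint above (not part of the patch):

  import datetime
  import stem.util

  # one minute past the unix epoch, returned as a float per the new hint
  print(stem.util.datetime_to_unix(datetime.datetime(1970, 1, 1, 0, 1)))  # 60.0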
diff --git a/stem/util/conf.py b/stem/util/conf.py
index a06f1fd7..37d1c5f4 100644
--- a/stem/util/conf.py
+++ b/stem/util/conf.py
@@ -163,16 +163,17 @@ import os
import threading
from stem.util import log
+from typing import Any, Callable, Mapping, Optional, Sequence, Union
CONFS = {} # mapping of identifier to singleton instances of configs
class _SyncListener(object):
- def __init__(self, config_dict, interceptor):
+ def __init__(self, config_dict: Mapping[str, Any], interceptor: Callable[[str, Any], Any]) -> None:
self.config_dict = config_dict
self.interceptor = interceptor
- def update(self, config, key):
+ def update(self, config: 'stem.util.conf.Config', key: str) -> None:
if key in self.config_dict:
new_value = config.get(key, self.config_dict[key])
@@ -188,7 +189,7 @@ class _SyncListener(object):
self.config_dict[key] = new_value
-def config_dict(handle, conf_mappings, handler = None):
+def config_dict(handle: str, conf_mappings: Mapping[str, Any], handler: Optional[Callable[[str, Any], Any]] = None) -> Mapping[str, Any]:
"""
Makes a dictionary that stays synchronized with a configuration.
@@ -214,6 +215,8 @@ def config_dict(handle, conf_mappings, handler = None):
:param str handle: unique identifier for a config instance
:param dict conf_mappings: config key/value mappings used as our defaults
:param functor handler: function referred to prior to assigning values
+
+ :returns: mapping of attributes to their current configuration value
"""
selected_config = get_config(handle)
@@ -221,7 +224,7 @@ def config_dict(handle, conf_mappings, handler = None):
return conf_mappings
-def get_config(handle):
+def get_config(handle: str) -> 'stem.util.conf.Config':
"""
Singleton constructor for configuration file instances. If a configuration
already exists for the handle then it's returned. Otherwise a fresh instance
@@ -236,7 +239,7 @@ def get_config(handle):
return CONFS[handle]
-def uses_settings(handle, path, lazy_load = True):
+def uses_settings(handle: str, path: str, lazy_load: bool = True) -> Callable:
"""
Provides a function that can be used as a decorator for other functions that
require settings to be loaded. Functions with this decorator will be provided
@@ -272,13 +275,13 @@ def uses_settings(handle, path, lazy_load = True):
config.load(path)
config._settings_loaded = True
- def decorator(func):
- def wrapped(*args, **kwargs):
+ def decorator(func: Callable) -> Callable:
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
if lazy_load and not config._settings_loaded:
config.load(path)
config._settings_loaded = True
- if 'config' in inspect.getargspec(func).args:
+ if 'config' in inspect.getfullargspec(func).args:
return func(*args, config = config, **kwargs)
else:
return func(*args, **kwargs)
@@ -288,7 +291,7 @@ def uses_settings(handle, path, lazy_load = True):
return decorator
-def parse_enum(key, value, enumeration):
+def parse_enum(key: str, value: str, enumeration: 'stem.util.enum.Enum') -> Any:
"""
Provides the enumeration value for a given key. This is a case insensitive
lookup and raises an exception if the enum key doesn't exist.
@@ -305,7 +308,7 @@ def parse_enum(key, value, enumeration):
return parse_enum_csv(key, value, enumeration, 1)[0]
-def parse_enum_csv(key, value, enumeration, count = None):
+def parse_enum_csv(key: str, value: str, enumeration: 'stem.util.enum.Enum', count: Optional[Union[int, Sequence[int]]] = None) -> Sequence[Any]:
"""
Parses a given value as being a comma separated listing of enumeration keys,
returning the corresponding enumeration values. This is intended to be a
@@ -445,7 +448,7 @@ class Config(object):
Class can now be used as a dictionary.
"""
- def __init__(self):
+ def __init__(self) -> None:
self._path = None # location we last loaded from or saved to
self._contents = collections.OrderedDict() # configuration key/value pairs
self._listeners = [] # functors to be notified of config changes
@@ -459,7 +462,7 @@ class Config(object):
# flag to support lazy loading in uses_settings()
self._settings_loaded = False
- def load(self, path = None, commenting = True):
+ def load(self, path: Optional[str] = None, commenting: bool = True) -> None:
"""
Reads in the contents of the given path, adding its configuration values
to our current contents. If the path is a directory then this loads each
@@ -534,7 +537,7 @@ class Config(object):
else:
self.set(line, '', False) # default to a key => '' mapping
- def save(self, path = None):
+ def save(self, path: Optional[str] = None) -> None:
"""
Saves configuration contents to disk. If a path is provided then it
replaces the configuration location that we track.
@@ -564,7 +567,7 @@ class Config(object):
output_file.write('%s %s\n' % (entry_key, entry_value))
- def clear(self):
+ def clear(self) -> None:
"""
Drops the configuration contents and reverts back to a blank, unloaded
state.
@@ -574,7 +577,7 @@ class Config(object):
self._contents.clear()
self._requested_keys = set()
- def add_listener(self, listener, backfill = True):
+ def add_listener(self, listener: Callable[['stem.util.conf.Config', str], Any], backfill: bool = True) -> None:
"""
Registers the function to be notified of configuration updates. Listeners
are expected to be functors which accept (config, key).
@@ -590,14 +593,14 @@ class Config(object):
for key in self.keys():
listener(self, key)
- def clear_listeners(self):
+ def clear_listeners(self) -> None:
"""
Removes all attached listeners.
"""
self._listeners = []
- def keys(self):
+ def keys(self) -> Sequence[str]:
"""
Provides all keys in the currently loaded configuration.
@@ -606,7 +609,7 @@ class Config(object):
return list(self._contents.keys())
- def unused_keys(self):
+ def unused_keys(self) -> Sequence[str]:
"""
Provides the configuration keys that have never been provided to a caller
via :func:`~stem.util.conf.config_dict` or the
@@ -618,7 +621,7 @@ class Config(object):
return set(self.keys()).difference(self._requested_keys)
- def set(self, key, value, overwrite = True):
+ def set(self, key: str, value: Union[str, Sequence[str]], overwrite: bool = True) -> None:
"""
Appends the given key/value configuration mapping, behaving the same as if
we'd loaded this from a configuration file.
@@ -657,7 +660,7 @@ class Config(object):
else:
raise ValueError("Config.set() only accepts str (bytes or unicode), list, or tuple. Provided value was a '%s'" % type(value))
- def get(self, key, default = None):
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
"""
Fetches the given configuration, using the key and default value to
determine the type it should be. Recognized inferences are:
@@ -737,7 +740,7 @@ class Config(object):
return val
- def get_value(self, key, default = None, multiple = False):
+ def get_value(self, key: str, default: Optional[Any] = None, multiple: bool = False) -> Union[str, Sequence[str]]:
"""
This provides the current value associated with a given key.
@@ -763,6 +766,6 @@ class Config(object):
log.log_once(message_id, log.TRACE, "config entry '%s' not found, defaulting to '%s'" % (key, default))
return default
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> Any:
with self._contents_lock:
return self._contents[key]
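For reviewers unfamiliar with Config's type inference, an illustrative sketch of the annotated getters (not part of the patch; 'example' is just a made-up handle):

  import stem.util.conf

  config = stem.util.conf.get_config('example')
  config.set('connection.port', '9051')

  print(config.get('connection.port', 80))    # 9051, cast to int from the default's type
  print(config.get_value('connection.port'))  # '9051', the raw string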
diff --git a/stem/util/connection.py b/stem/util/connection.py
index eaeafec4..2f815a46 100644
--- a/stem/util/connection.py
+++ b/stem/util/connection.py
@@ -65,6 +65,7 @@ import stem.util.proc
import stem.util.system
from stem.util import conf, enum, log, str_tools
+from typing import Optional, Sequence, Union
# Connection resolution is risky to log about since it's highly likely to
# contain sensitive information. That said, it's also difficult to get right in
@@ -157,7 +158,7 @@ class Connection(collections.namedtuple('Connection', ['local_address', 'local_p
"""
-def download(url, timeout = None, retries = None):
+def download(url: str, timeout: Optional[int] = None, retries: Optional[int] = None) -> bytes:
"""
Download from the given url.
@@ -198,7 +199,7 @@ def download(url, timeout = None, retries = None):
raise stem.DownloadFailed(url, exc, stacktrace)
-def get_connections(resolver = None, process_pid = None, process_name = None):
+def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None, process_pid: Optional[int] = None, process_name: Optional[str] = None) -> Sequence['stem.util.connection.Connection']:
"""
Retrieves a list of the current connections for a given process. This
provides a list of :class:`~stem.util.connection.Connection`. Note that
@@ -239,7 +240,7 @@ def get_connections(resolver = None, process_pid = None, process_name = None):
if not process_pid and not process_name:
raise ValueError('You must provide a pid or process name to provide connections for')
- def _log(msg):
+ def _log(msg: str) -> None:
if LOG_CONNECTION_RESOLUTION:
log.debug(msg)
@@ -288,7 +289,7 @@ def get_connections(resolver = None, process_pid = None, process_name = None):
connections = []
resolver_regex = re.compile(resolver_regex_str)
- def _parse_address_str(addr_type, addr_str, line):
+ def _parse_address_str(addr_type: str, addr_str: str, line: str) -> str:
addr, port = addr_str.rsplit(':', 1)
if not is_valid_ipv4_address(addr) and not is_valid_ipv6_address(addr, allow_brackets = True):
@@ -334,7 +335,7 @@ def get_connections(resolver = None, process_pid = None, process_name = None):
return connections
-def system_resolvers(system = None):
+def system_resolvers(system: Optional[str] = None) -> Sequence['stem.util.connection.Resolver']:
"""
Provides the types of connection resolvers likely to be available on this platform.
@@ -383,7 +384,7 @@ def system_resolvers(system = None):
return resolvers
-def port_usage(port):
+def port_usage(port: int) -> Optional[str]:
"""
Provides the common use of a given port. For example, 'HTTP' for port 80 or
'SSH' for 22.
@@ -429,7 +430,7 @@ def port_usage(port):
return PORT_USES.get(port)
-def is_valid_ipv4_address(address):
+def is_valid_ipv4_address(address: str) -> bool:
"""
Checks if a string is a valid IPv4 address.
@@ -458,7 +459,7 @@ def is_valid_ipv4_address(address):
return True
-def is_valid_ipv6_address(address, allow_brackets = False):
+def is_valid_ipv6_address(address: str, allow_brackets: bool = False) -> bool:
"""
Checks if a string is a valid IPv6 address.
@@ -513,7 +514,7 @@ def is_valid_ipv6_address(address, allow_brackets = False):
return True
-def is_valid_port(entry, allow_zero = False):
+def is_valid_port(entry: Union[str, int, Sequence[str], Sequence[int]], allow_zero: bool = False) -> bool:
"""
Checks if a string or int is a valid port number.
@@ -545,7 +546,7 @@ def is_valid_port(entry, allow_zero = False):
return False
-def is_private_address(address):
+def is_private_address(address: str) -> bool:
"""
Checks if the IPv4 address is in a range belonging to the local network or
loopback. These include:
@@ -581,7 +582,7 @@ def is_private_address(address):
return False
-def address_to_int(address):
+def address_to_int(address: str) -> int:
"""
Provides an integer representation of a IPv4 or IPv6 address that can be used
for sorting.
@@ -599,7 +600,7 @@ def address_to_int(address):
return int(_address_to_binary(address), 2)
-def expand_ipv6_address(address):
+def expand_ipv6_address(address: str) -> str:
"""
Expands abbreviated IPv6 addresses to their full colon separated hex format.
For instance...
@@ -660,7 +661,7 @@ def expand_ipv6_address(address):
return address
-def get_mask_ipv4(bits):
+def get_mask_ipv4(bits: int) -> str:
"""
Provides the IPv4 mask for a given number of bits, in the dotted-quad format.
@@ -686,7 +687,7 @@ def get_mask_ipv4(bits):
return '.'.join([str(int(octet, 2)) for octet in octets])
-def get_mask_ipv6(bits):
+def get_mask_ipv6(bits: int) -> str:
"""
Provides the IPv6 mask for a given number of bits, in the hex colon-delimited
format.
@@ -713,7 +714,7 @@ def get_mask_ipv6(bits):
return ':'.join(['%04x' % int(group, 2) for group in groupings]).upper()
-def _get_masked_bits(mask):
+def _get_masked_bits(mask: str) -> int:
"""
Provides the number of bits that an IPv4 subnet mask represents. Note that
not all masks can be represented by a bit count.
@@ -738,13 +739,15 @@ def _get_masked_bits(mask):
raise ValueError('Unable to convert mask to a bit count: %s' % mask)
-def _get_binary(value, bits):
+def _get_binary(value: int, bits: int) -> str:
"""
Provides the given value as a binary string, padded with zeros to the given
number of bits.
:param int value: value to be converted
:param int bits: number of bits to pad to
+
+ :returns: **str** of this binary value
"""
# http://www.daniweb.com/code/snippet216539.html
@@ -754,10 +757,12 @@ def _get_binary(value, bits):
# TODO: In stem 2.x we should consider unifying this with
# stem.client.datatype's _unpack_ipv4_address() and _unpack_ipv6_address().
-def _address_to_binary(address):
+def _address_to_binary(address: str) -> str:
"""
Provides the binary value for an IPv4 or IPv6 address.
+ :param str address: address to convert
+
:returns: **str** with the binary representation of this address
:raises: **ValueError** if address is neither an IPv4 nor IPv6 address
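Illustrative usage of the annotated address helpers (not part of the patch):

  from stem.util import connection

  print(connection.is_valid_ipv4_address('192.168.0.1'))           # True
  print(connection.is_valid_port('8080'))                          # True
  print(connection.get_mask_ipv4(18))                              # 255.255.192.0
  print(connection.expand_ipv6_address('2001:db8::ff00:42:8329'))  # fully expanded, zero padded form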
diff --git a/stem/util/enum.py b/stem/util/enum.py
index 56bf119d..b70d29f4 100644
--- a/stem/util/enum.py
+++ b/stem/util/enum.py
@@ -40,8 +40,10 @@ constructed as simple type listings...
+- __iter__ - iterator over our enum keys
"""
+from typing import Iterator, Sequence
-def UppercaseEnum(*args):
+
+def UppercaseEnum(*args: str) -> 'stem.util.enum.Enum':
"""
Provides an :class:`~stem.util.enum.Enum` instance where the values are
identical to the keys. Since the keys are uppercase by convention this means
@@ -67,7 +69,7 @@ class Enum(object):
Basic enumeration.
"""
- def __init__(self, *args):
+ def __init__(self, *args: str) -> None:
from stem.util.str_tools import _to_camel_case
# ordered listings of our keys and values
@@ -88,7 +90,7 @@ class Enum(object):
self._keys = tuple(keys)
self._values = tuple(values)
- def keys(self):
+ def keys(self) -> Sequence[str]:
"""
Provides an ordered listing of the enumeration keys in this set.
@@ -97,7 +99,7 @@ class Enum(object):
return list(self._keys)
- def index_of(self, value):
+ def index_of(self, value: str) -> int:
"""
Provides the index of the given value in the collection.
@@ -110,7 +112,7 @@ class Enum(object):
return self._values.index(value)
- def next(self, value):
+ def next(self, value: str) -> str:
"""
Provides the next enumeration after the given value.
@@ -127,7 +129,7 @@ class Enum(object):
next_index = (self._values.index(value) + 1) % len(self._values)
return self._values[next_index]
- def previous(self, value):
+ def previous(self, value: str) -> str:
"""
Provides the previous enumeration before the given value.
@@ -144,7 +146,7 @@ class Enum(object):
prev_index = (self._values.index(value) - 1) % len(self._values)
return self._values[prev_index]
- def __getitem__(self, item):
+ def __getitem__(self, item: str) -> str:
"""
Provides the values for the given key.
@@ -161,7 +163,7 @@ class Enum(object):
keys = ', '.join(self.keys())
raise ValueError("'%s' isn't among our enumeration keys, which includes: %s" % (item, keys))
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
"""
Provides an ordered listing of the enums in this set.
"""
diff --git a/stem/util/log.py b/stem/util/log.py
index 94d055ff..940469a3 100644
--- a/stem/util/log.py
+++ b/stem/util/log.py
@@ -92,10 +92,10 @@ DEDUPLICATION_MESSAGE_IDS = set()
class _NullHandler(logging.Handler):
- def __init__(self):
+ def __init__(self) -> None:
logging.Handler.__init__(self, level = logging.FATAL + 5) # disable logging
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
pass
@@ -103,7 +103,7 @@ if not LOGGER.handlers:
LOGGER.addHandler(_NullHandler())
-def get_logger():
+def get_logger() -> logging.Logger:
"""
Provides the stem logger.
@@ -113,7 +113,7 @@ def get_logger():
return LOGGER
-def logging_level(runlevel):
+def logging_level(runlevel: 'stem.util.log.Runlevel') -> int:
"""
Translates a runlevel into the value expected by the logging module.
@@ -126,7 +126,7 @@ def logging_level(runlevel):
return logging.FATAL + 5
-def is_tracing():
+def is_tracing() -> bool:
"""
Checks if we're logging at the trace runlevel.
@@ -142,7 +142,7 @@ def is_tracing():
return False
-def escape(message):
+def escape(message: str) -> str:
"""
Escapes specific sequences for logging (newlines, tabs, carriage returns). If
the input is **bytes** then this converts it to **unicode** under python 3.x.
@@ -160,7 +160,7 @@ def escape(message):
return message
-def log(runlevel, message):
+def log(runlevel: 'stem.util.log.Runlevel', message: str) -> None:
"""
Logs a message at the given runlevel.
@@ -172,7 +172,7 @@ def log(runlevel, message):
LOGGER.log(LOG_VALUES[runlevel], message)
-def log_once(message_id, runlevel, message):
+def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) -> None:
"""
Logs a message at the given runlevel. If a message with this ID has already
been logged then this is a no-op.
@@ -193,43 +193,43 @@ def log_once(message_id, runlevel, message):
# shorter aliases for logging at a runlevel
-def trace(message):
+def trace(message: str) -> None:
log(Runlevel.TRACE, message)
-def debug(message):
+def debug(message: str) -> None:
log(Runlevel.DEBUG, message)
-def info(message):
+def info(message: str) -> None:
log(Runlevel.INFO, message)
-def notice(message):
+def notice(message: str) -> None:
log(Runlevel.NOTICE, message)
-def warn(message):
+def warn(message: str) -> None:
log(Runlevel.WARN, message)
-def error(message):
+def error(message: str) -> None:
log(Runlevel.ERROR, message)
class _StdoutLogger(logging.Handler):
- def __init__(self, runlevel):
+ def __init__(self, runlevel: 'stem.util.log.Runlevel') -> None:
logging.Handler.__init__(self, level = logging_level(runlevel))
self.formatter = logging.Formatter(
fmt = '%(asctime)s [%(levelname)s] %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S')
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
print(self.formatter.format(record))
-def log_to_stdout(runlevel):
+def log_to_stdout(runlevel: 'stem.util.log.Runlevel') -> None:
"""
Logs further events to stdout.
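Illustrative usage of the annotated logging helpers (not part of the patch):

  from stem.util import log

  log.log_to_stdout(log.Runlevel.NOTICE)  # emit NOTICE and above to stdout

  log.notice('this is emitted')
  log.debug('this is below the NOTICE runlevel, so suppressed')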
diff --git a/stem/util/proc.py b/stem/util/proc.py
index 3589af13..10f2ae60 100644
--- a/stem/util/proc.py
+++ b/stem/util/proc.py
@@ -56,6 +56,7 @@ import stem.util.enum
import stem.util.str_tools
from stem.util import log
+from typing import Any, Mapping, Optional, Sequence, Set, Tuple, Type
try:
# unavailable on windows (#19823)
@@ -80,7 +81,7 @@ Stat = stem.util.enum.Enum(
@functools.lru_cache()
-def is_available():
+def is_available() -> bool:
"""
Checks if proc information is available on this platform.
@@ -101,7 +102,7 @@ def is_available():
@functools.lru_cache()
-def system_start_time():
+def system_start_time() -> float:
"""
Provides the unix time (seconds since epoch) when the system started.
@@ -124,7 +125,7 @@ def system_start_time():
@functools.lru_cache()
-def physical_memory():
+def physical_memory() -> int:
"""
Provides the total physical memory on the system in bytes.
@@ -146,7 +147,7 @@ def physical_memory():
raise exc
-def cwd(pid):
+def cwd(pid: int) -> str:
"""
Provides the current working directory for the given process.
@@ -174,7 +175,7 @@ def cwd(pid):
return cwd
-def uid(pid):
+def uid(pid: int) -> int:
"""
Provides the user ID the given process is running under.
@@ -199,7 +200,7 @@ def uid(pid):
raise exc
-def memory_usage(pid):
+def memory_usage(pid: int) -> Tuple[int, int]:
"""
Provides the memory usage in bytes for the given process.
@@ -232,7 +233,7 @@ def memory_usage(pid):
raise exc
-def stats(pid, *stat_types):
+def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[Any]:
"""
Provides process specific information. See the :data:`~stem.util.proc.Stat`
enum for valid options.
@@ -270,6 +271,7 @@ def stats(pid, *stat_types):
raise exc
results = []
+
for stat_type in stat_types:
if stat_type == Stat.COMMAND:
if pid == 0:
@@ -300,7 +302,7 @@ def stats(pid, *stat_types):
return tuple(results)
-def file_descriptors_used(pid):
+def file_descriptors_used(pid: int) -> int:
"""
Provides the number of file descriptors currently being used by a process.
@@ -327,7 +329,7 @@ def file_descriptors_used(pid):
raise IOError('Unable to check number of file descriptors used: %s' % exc)
-def connections(pid = None, user = None):
+def connections(pid: Optional[int] = None, user: Optional[str] = None) -> Sequence['stem.util.connection.Connection']:
"""
Queries connections from the proc contents. This matches netstat, lsof, and
friends but is much faster. If no **pid** or **user** are provided this
@@ -412,7 +414,7 @@ def connections(pid = None, user = None):
raise
-def _inodes_for_sockets(pid):
+def _inodes_for_sockets(pid: int) -> Set[bytes]:
"""
Provides inodes in use by a process for its sockets.
@@ -450,7 +452,7 @@ def _inodes_for_sockets(pid):
return inodes
-def _unpack_addr(addr):
+def _unpack_addr(addr: str) -> str:
"""
Translates an address entry in the /proc/net/* contents to a human readable
form (`reference <http://linuxdevcenter.com/pub/a/linux/2000/11/16/LinuxAdmin.html>`_,
@@ -494,7 +496,7 @@ def _unpack_addr(addr):
return ENCODED_ADDR[addr]
-def _is_float(*value):
+def _is_float(*value: Any) -> bool:
try:
for v in value:
float(v)
@@ -504,11 +506,11 @@ def _is_float(*value):
return False
-def _get_line(file_path, line_prefix, parameter):
+def _get_line(file_path: str, line_prefix: str, parameter: str) -> str:
return _get_lines(file_path, (line_prefix, ), parameter)[line_prefix]
-def _get_lines(file_path, line_prefixes, parameter):
+def _get_lines(file_path: str, line_prefixes: Sequence[str], parameter: str) -> Mapping[str, str]:
"""
Fetches lines with the given prefixes from a file. This only provides back
the first instance of each prefix.
@@ -552,7 +554,7 @@ def _get_lines(file_path, line_prefixes, parameter):
raise
-def _log_runtime(parameter, proc_location, start_time):
+def _log_runtime(parameter: str, proc_location: str, start_time: int) -> None:
"""
Logs a message indicating a successful proc query.
@@ -565,7 +567,7 @@ def _log_runtime(parameter, proc_location, start_time):
log.debug('proc call (%s): %s (runtime: %0.4f)' % (parameter, proc_location, runtime))
-def _log_failure(parameter, exc):
+def _log_failure(parameter: str, exc: Type[Exception]) -> None:
"""
Logs a message indicating that the proc query failed.
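Illustrative usage of the annotated proc helpers (not part of the patch; proc contents are only available on Linux, hence the guard):

  import os
  import stem.util.proc

  if stem.util.proc.is_available():
    pid = os.getpid()
    print(stem.util.proc.cwd(pid))           # our current working directory
    print(stem.util.proc.memory_usage(pid))  # (resident_size, virtual_size) in bytes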
diff --git a/stem/util/str_tools.py b/stem/util/str_tools.py
index c1285626..c606906a 100644
--- a/stem/util/str_tools.py
+++ b/stem/util/str_tools.py
@@ -26,6 +26,8 @@ import sys
import stem.util
import stem.util.enum
+from typing import Sequence, Tuple, Union
+
# label conversion tuples of the form...
# (bits / bytes / seconds, short label, long label)
@@ -57,7 +59,7 @@ TIME_UNITS = (
_timestamp_re = re.compile(r'(\d{4})-(\d{2})-(\d{2}) (\d{2}):(\d{2}):(\d{2})')
-def _to_bytes(msg):
+def _to_bytes(msg: Union[str, bytes]) -> bytes:
"""
Provides the ASCII bytes for the given string. This is purely to provide
python 3 compatibility, normalizing the unicode/ASCII change in the version
@@ -76,7 +78,7 @@ def _to_bytes(msg):
return msg
-def _to_unicode(msg):
+def _to_unicode(msg: Union[str, bytes]) -> str:
"""
Provides the unicode string for the given ASCII bytes. This is purely to
provide python 3 compatibility, normalizing the unicode/ASCII change in the
@@ -93,7 +95,7 @@ def _to_unicode(msg):
return msg
-def _decode_b64(msg):
+def _decode_b64(msg: Union[str, bytes]) -> str:
"""
Base64 decode, without padding concerns.
"""
@@ -104,7 +106,7 @@ def _decode_b64(msg):
return base64.b64decode(msg + padding_chr * missing_padding)
-def _to_int(msg):
+def _to_int(msg: Union[str, bytes]) -> int:
"""
Serializes a string to a number.
@@ -120,7 +122,7 @@ def _to_int(msg):
return sum([pow(256, (len(msg) - i - 1)) * ord(c) for (i, c) in enumerate(msg)])
-def _to_camel_case(label, divider = '_', joiner = ' '):
+def _to_camel_case(label: str, divider: str = '_', joiner: str = ' ') -> str:
"""
Converts the given string to camel case, ie:
@@ -148,7 +150,7 @@ def _to_camel_case(label, divider = '_', joiner = ' '):
return joiner.join(words)
-def _split_by_length(msg, size):
+def _split_by_length(msg: str, size: int) -> Sequence[str]:
"""
Splits a string into a list of strings up to the given size.
@@ -172,7 +174,7 @@ def _split_by_length(msg, size):
Ending = stem.util.enum.Enum('ELLIPSE', 'HYPHEN')
-def crop(msg, size, min_word_length = 4, min_crop = 0, ending = Ending.ELLIPSE, get_remainder = False):
+def crop(msg: str, size: int, min_word_length: int = 4, min_crop: int = 0, ending: 'stem.util.str_tools.Ending' = Ending.ELLIPSE, get_remainder: bool = False) -> Union[str, Tuple[str, str]]:
"""
Shortens a string to a given length.
@@ -286,7 +288,7 @@ def crop(msg, size, min_word_length = 4, min_crop = 0, ending = Ending.ELLIPSE,
return (return_msg, remainder) if get_remainder else return_msg
-def size_label(byte_count, decimal = 0, is_long = False, is_bytes = True, round = False):
+def size_label(byte_count: int, decimal: int = 0, is_long: bool = False, is_bytes: bool = True, round: bool = False) -> str:
"""
Converts a number of bytes into a human readable label in its most
significant units. For instance, 7500 bytes would return "7 KB". If the
@@ -323,7 +325,7 @@ def size_label(byte_count, decimal = 0, is_long = False, is_bytes = True, round
return _get_label(SIZE_UNITS_BITS, byte_count, decimal, is_long, round)
-def time_label(seconds, decimal = 0, is_long = False):
+def time_label(seconds: int, decimal: int = 0, is_long: bool = False) -> str:
"""
Converts seconds into a time label truncated to its most significant units.
For instance, 7500 seconds would return "2h". Units go up through days.
@@ -354,7 +356,7 @@ def time_label(seconds, decimal = 0, is_long = False):
return _get_label(TIME_UNITS, seconds, decimal, is_long)
-def time_labels(seconds, is_long = False):
+def time_labels(seconds: int, is_long: bool = False) -> Sequence[str]:
"""
Provides a list of label conversions for each time unit, starting with its
most significant units on down. Any counts that evaluate to zero are omitted.
@@ -384,7 +386,7 @@ def time_labels(seconds, is_long = False):
return time_labels
-def short_time_label(seconds):
+def short_time_label(seconds: int) -> str:
"""
Provides a time in the following format:
[[dd-]hh:]mm:ss
@@ -424,7 +426,7 @@ def short_time_label(seconds):
return label
-def parse_short_time_label(label):
+def parse_short_time_label(label: str) -> int:
"""
Provides the number of seconds corresponding to the formatting used for the
cputime and etime fields of ps:
@@ -469,7 +471,7 @@ def parse_short_time_label(label):
raise ValueError('Non-numeric value in time entry: %s' % label)
-def _parse_timestamp(entry):
+def _parse_timestamp(entry: str) -> 'datetime.datetime':
"""
Parses a date and time in a format like...
@@ -495,7 +497,7 @@ def _parse_timestamp(entry):
return datetime.datetime(time[0], time[1], time[2], time[3], time[4], time[5])
-def _parse_iso_timestamp(entry):
+def _parse_iso_timestamp(entry: str) -> 'datetime.datetime':
"""
Parses the ISO 8601 standard that provides for timestamps like...
@@ -533,7 +535,7 @@ def _parse_iso_timestamp(entry):
return timestamp + datetime.timedelta(microseconds = int(microseconds))
-def _get_label(units, count, decimal, is_long, round = False):
+def _get_label(units: Sequence[Tuple[float, str, str]], count: int, decimal: int, is_long: bool, round: bool = False) -> str:
"""
Provides label corresponding to units of the highest significance in the
provided set. This rounds down (ie, integer truncation after visible units).
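Illustrative usage of the annotated label helpers, taken from their docstring examples (not part of the patch):

  from stem.util import str_tools

  print(str_tools.size_label(2000000))    # '1 MB'
  print(str_tools.size_label(1050, 2))    # '1.02 KB'
  print(str_tools.time_label(10000))      # '2h'
  print(str_tools.short_time_label(111))  # '01:51'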
diff --git a/stem/util/system.py b/stem/util/system.py
index b3dee151..8a61b2b9 100644
--- a/stem/util/system.py
+++ b/stem/util/system.py
@@ -82,6 +82,7 @@ import stem.util.str_tools
from stem import UNDEFINED
from stem.util import log
+from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, TextIO, Union
State = stem.util.enum.UppercaseEnum(
'PENDING',
@@ -189,7 +190,7 @@ class CallError(OSError):
:var str stderr: stderr of the process
"""
- def __init__(self, msg, command, exit_status, runtime, stdout, stderr):
+ def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str) -> None:
self.msg = msg
self.command = command
self.exit_status = exit_status
@@ -197,7 +198,7 @@ class CallError(OSError):
self.stdout = stdout
self.stderr = stderr
- def __str__(self):
+ def __str__(self) -> str:
return self.msg
@@ -210,7 +211,7 @@ class CallTimeoutError(CallError):
:var float timeout: time we waited
"""
- def __init__(self, msg, command, exit_status, runtime, stdout, stderr, timeout):
+ def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str, timeout: float) -> None:
super(CallTimeoutError, self).__init__(msg, command, exit_status, runtime, stdout, stderr)
self.timeout = timeout
@@ -231,7 +232,7 @@ class DaemonTask(object):
:var exception error: exception raised by subprocess if it failed
"""
- def __init__(self, runner, args = None, priority = 15, start = False):
+ def __init__(self, runner: Callable, args: Optional[Sequence[Any]] = None, priority: int = 15, start: bool = False) -> None:
self.runner = runner
self.args = args
self.priority = priority
@@ -247,7 +248,7 @@ class DaemonTask(object):
if start:
self.run()
- def run(self):
+ def run(self) -> None:
"""
Invokes the task if it hasn't already been started. If it has this is a
no-op.
@@ -259,7 +260,7 @@ class DaemonTask(object):
self._process.start()
self.status = State.RUNNING
- def join(self):
+ def join(self) -> Any:
"""
Provides the result of the daemon task. If still running this blocks until
the task is completed.
@@ -292,7 +293,7 @@ class DaemonTask(object):
raise RuntimeError('BUG: unexpected status from daemon task, %s' % self.status)
@staticmethod
- def _run_wrapper(conn, priority, runner, args):
+ def _run_wrapper(conn: 'multiprocessing.connection.Connection', priority: int, runner: Callable, args: Sequence[Any]) -> None:
start_time = time.time()
os.nice(priority)
@@ -305,7 +306,7 @@ class DaemonTask(object):
conn.close()
-def is_windows():
+def is_windows() -> bool:
"""
Checks if we are running on Windows.
@@ -315,7 +316,7 @@ def is_windows():
return platform.system() == 'Windows'
-def is_mac():
+def is_mac() -> bool:
"""
Checks if we are running on Mac OSX.
@@ -325,7 +326,7 @@ def is_mac():
return platform.system() == 'Darwin'
-def is_gentoo():
+def is_gentoo() -> bool:
"""
Checks if we're running on Gentoo.
@@ -335,7 +336,7 @@ def is_gentoo():
return os.path.exists('/etc/gentoo-release')
-def is_slackware():
+def is_slackware() -> bool:
"""
Checks if we are running on a Slackware system.
@@ -345,7 +346,7 @@ def is_slackware():
return os.path.exists('/etc/slackware-version')
-def is_bsd():
+def is_bsd() -> bool:
"""
Checks if we are within the BSD family of operating systems. This currently
recognizes Macs, FreeBSD, and OpenBSD but may be expanded later.
@@ -356,7 +357,7 @@ def is_bsd():
return platform.system() in ('Darwin', 'FreeBSD', 'OpenBSD', 'NetBSD')
-def is_available(command, cached=True):
+def is_available(command: str, cached: bool = True) -> bool:
"""
Checks the current PATH to see if a command is available or not. If more
than one command is present (for instance "ls -a | grep foo") then this
@@ -399,7 +400,7 @@ def is_available(command, cached=True):
return cmd_exists
-def is_running(command):
+def is_running(command: Union[str, int, Sequence[str]]) -> Optional[bool]:
"""
Checks if a process with a given name or pid is running.
@@ -461,7 +462,7 @@ def is_running(command):
return None
-def size_of(obj, exclude = None):
+def size_of(obj: Any, exclude: Optional[Sequence[int]] = None) -> int:
"""
Provides the `approximate memory usage of an object
<https://code.activestate.com/recipes/577504/>`_. This can recurse tuples,
@@ -504,7 +505,7 @@ def size_of(obj, exclude = None):
return size
-def name_by_pid(pid):
+def name_by_pid(pid: int) -> Optional[str]:
"""
Attempts to determine the name a given process is running under (not
including arguments). This uses...
@@ -547,7 +548,7 @@ def name_by_pid(pid):
return process_name
-def pid_by_name(process_name, multiple = False):
+def pid_by_name(process_name: str, multiple: bool = False) -> Optional[Union[int, Sequence[int]]]:
"""
Attempts to determine the process id for a running process, using...
@@ -718,7 +719,7 @@ def pid_by_name(process_name, multiple = False):
return [] if multiple else None
-def pid_by_port(port):
+def pid_by_port(port: int) -> Optional[int]:
"""
Attempts to determine the process id for a process with the given port,
using...
@@ -838,7 +839,7 @@ def pid_by_port(port):
return None # all queries failed
-def pid_by_open_file(path):
+def pid_by_open_file(path: str) -> Optional[int]:
"""
Attempts to determine the process id for a process with the given open file,
using...
@@ -876,7 +877,7 @@ def pid_by_open_file(path):
return None # all queries failed
-def pids_by_user(user):
+def pids_by_user(user: str) -> Optional[Sequence[int]]:
"""
Provides processes owned by a given user.
@@ -908,7 +909,7 @@ def pids_by_user(user):
return None
-def cwd(pid):
+def cwd(pid: int) -> Optional[str]:
"""
Provides the working directory of the given process.
@@ -977,7 +978,7 @@ def cwd(pid):
return None # all queries failed
-def user(pid):
+def user(pid: int) -> Optional[str]:
"""
Provides the user a process is running under.
@@ -1010,7 +1011,7 @@ def user(pid):
return None
-def start_time(pid):
+def start_time(pid: int) -> Optional[float]:
"""
Provides the unix timestamp when the given process started.
@@ -1041,7 +1042,7 @@ def start_time(pid):
return None
-def tail(target, lines = None):
+def tail(target: Union[str, TextIO], lines: Optional[int] = None) -> Iterator[str]:
"""
Provides lines of a file starting with the end. For instance,
'tail -n 50 /tmp/my_log' could be done with...
@@ -1094,7 +1095,7 @@ def tail(target, lines = None):
block_number -= 1
-def bsd_jail_id(pid):
+def bsd_jail_id(pid: int) -> int:
"""
Gets the jail id for a process. These seem to only exist for FreeBSD (this
style for jails does not exist on Linux, OSX, or OpenBSD).
@@ -1129,7 +1130,7 @@ def bsd_jail_id(pid):
return 0
-def bsd_jail_path(jid):
+def bsd_jail_path(jid: int) -> Optional[str]:
"""
Provides the path of the given FreeBSD jail.
@@ -1151,7 +1152,7 @@ def bsd_jail_path(jid):
return None
-def is_tarfile(path):
+def is_tarfile(path: str) -> bool:
"""
Checks whether the path belongs to a tarfile.
@@ -1177,7 +1178,7 @@ def is_tarfile(path):
return mimetypes.guess_type(path)[0] == 'application/x-tar'
-def expand_path(path, cwd = None):
+def expand_path(path: str, cwd: Optional[str] = None) -> str:
"""
Provides an absolute path, expanding tildes with the user's home and
appending a current working directory if the path was relative.
@@ -1222,7 +1223,7 @@ def expand_path(path, cwd = None):
return relative_path
-def files_with_suffix(base_path, suffix):
+def files_with_suffix(base_path: str, suffix: str) -> Iterator[str]:
"""
Iterates over files in a given directory, providing filenames with a certain
suffix.
@@ -1245,7 +1246,7 @@ def files_with_suffix(base_path, suffix):
yield os.path.join(root, filename)
-def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = None, cwd = None, env = None):
+def call(command: Union[str, Sequence[str]], default: Any = UNDEFINED, ignore_exit_status: bool = False, timeout: Optional[float] = None, cwd: Optional[str] = None, env: Optional[Mapping[str, str]] = None) -> Sequence[str]:
"""
call(command, default = UNDEFINED, ignore_exit_status = False)
@@ -1346,7 +1347,7 @@ def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = Non
SYSTEM_CALL_TIME += time.time() - start_time
-def get_process_name():
+def get_process_name() -> str:
"""
Provides the present name of our process.
@@ -1398,7 +1399,7 @@ def get_process_name():
return _PROCESS_NAME
-def set_process_name(process_name):
+def set_process_name(process_name: str) -> None:
"""
Renames our current process from "python <args>" to a custom name. This is
best-effort, not necessarily working on all platforms.
@@ -1432,7 +1433,7 @@ def set_process_name(process_name):
_set_proc_title(process_name)
-def _set_argv(process_name):
+def _set_argv(process_name: str) -> None:
"""
Overwrites our argv in a similar fashion to how it's done in C with:
strcpy(argv[0], 'new_name');
@@ -1462,7 +1463,7 @@ def _set_argv(process_name):
_PROCESS_NAME = process_name
-def _set_prctl_name(process_name):
+def _set_prctl_name(process_name: str) -> None:
"""
Sets the prctl name, which is used by top and killall. This appears to be
Linux specific and has a maximum of 15 characters.
@@ -1477,7 +1478,7 @@ def _set_prctl_name(process_name):
libc.prctl(PR_SET_NAME, ctypes.byref(name_buffer), 0, 0, 0)
-def _set_proc_title(process_name):
+def _set_proc_title(process_name: str) -> None:
"""
BSD specific calls (should be compatible with both FreeBSD and OpenBSD):
http://fxr.watson.org/fxr/source/gen/setproctitle.c?v=FREEBSD-LIBC
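Illustrative usage of the annotated system helpers (not part of the patch; assumes a unix-ish platform with 'ps' and 'echo' available):

  import stem.util.system

  print(stem.util.system.is_available('ps'))  # True if 'ps' is in our PATH
  print(stem.util.system.call('echo hello'))  # ['hello'], stdout lines as a list
  print(stem.util.system.pid_by_name('tor'))  # pid as an int, or None if tor isn't running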
diff --git a/stem/util/term.py b/stem/util/term.py
index 06391441..acc52cad 100644
--- a/stem/util/term.py
+++ b/stem/util/term.py
@@ -50,6 +50,8 @@ Utilities for working with the terminal.
import stem.util.enum
import stem.util.str_tools
+from typing import Optional, Union
+
TERM_COLORS = ('BLACK', 'RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA', 'CYAN', 'WHITE')
# DISABLE_COLOR_SUPPORT is *not* being vended to Stem users. This is likely to
@@ -70,7 +72,7 @@ CSI = '\x1B[%sm'
RESET = CSI % '0'
-def encoding(*attrs):
+def encoding(*attrs: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> Optional[str]:
"""
Provides the ANSI escape sequence for these terminal color or attributes.
@@ -81,7 +83,7 @@ def encoding(*attrs):
provide an encoding for
:returns: **str** of the ANSI escape sequence, **None** if any attribute is
- recognized
+ unrecognized
"""
term_encodings = []
@@ -99,7 +101,7 @@ def encoding(*attrs):
return CSI % ';'.join(term_encodings)
-def format(msg, *attr):
+def format(msg: str, *attr: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> str:
"""
Simple terminal text formatting using `ANSI escape sequences
<https://en.wikipedia.org/wiki/ANSI_escape_code#CSI_codes>`_.
@@ -118,7 +120,7 @@ def format(msg, *attr):
:data:`~stem.util.term.BgColor`, or :data:`~stem.util.term.Attr` enums
and are case insensitive (so strings like 'red' are fine)
- :returns: **unicode** wrapped with ANSI escape encodings, starting with the given
+ :returns: **str** wrapped with ANSI escape encodings, starting with the given
attributes and ending with a reset
"""
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index 71165214..d5d0f842 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -42,6 +42,8 @@ import stem.util.conf
import stem.util.enum
import stem.util.system
+from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Type
+
CONFIG = stem.util.conf.config_dict('test', {
'pycodestyle.ignore': [],
'pyflakes.ignore': [],
@@ -55,7 +57,7 @@ AsyncStatus = stem.util.enum.UppercaseEnum('PENDING', 'RUNNING', 'FINISHED')
AsyncResult = collections.namedtuple('AsyncResult', 'type msg')
-def assert_equal(expected, actual, msg = None):
+def assert_equal(expected: Any, actual: Any, msg: Optional[str] = None) -> None:
"""
Function form of a TestCase's assertEqual.
@@ -72,7 +74,7 @@ def assert_equal(expected, actual, msg = None):
raise AssertionError("Expected '%s' but was '%s'" % (expected, actual) if msg is None else msg)
-def assert_in(expected, actual, msg = None):
+def assert_in(expected: Any, actual: Any, msg: Optional[str] = None) -> None:
"""
Asserts that a given value is within this content.
@@ -89,7 +91,7 @@ def assert_in(expected, actual, msg = None):
raise AssertionError("Expected '%s' to be within '%s'" % (expected, actual) if msg is None else msg)
-def skip(msg):
+def skip(msg: str) -> None:
"""
Function form of a TestCase's skipTest.
@@ -100,10 +102,12 @@ def skip(msg):
:raises: **unittest.case.SkipTest** for this reason
"""
+ # TODO: remove now that python 2.x is unsupported?
+
raise unittest.case.SkipTest(msg)
-def asynchronous(func):
+def asynchronous(func: Callable) -> Callable:
test = stem.util.test_tools.AsyncTest(func)
ASYNC_TESTS[test.name] = test
return test.method
@@ -131,7 +135,7 @@ class AsyncTest(object):
.. versionadded:: 1.6.0
"""
- def __init__(self, runner, args = None, threaded = False):
+ def __init__(self, runner: Callable, args: Optional[Any] = None, threaded: bool = False) -> None:
self.name = '%s.%s' % (runner.__module__, runner.__name__)
self._runner = runner
@@ -147,8 +151,8 @@ class AsyncTest(object):
self._result = None
self._status = AsyncStatus.PENDING
- def run(self, *runner_args, **kwargs):
- def _wrapper(conn, runner, args):
+ def run(self, *runner_args: Any, **kwargs: Any) -> None:
+ def _wrapper(conn: 'multiprocessing.connection.Connection', runner: Callable, args: Any) -> None:
os.nice(12)
try:
@@ -187,14 +191,14 @@ class AsyncTest(object):
self._process.start()
self._status = AsyncStatus.RUNNING
- def pid(self):
+ def pid(self) -> Optional[int]:
with self._process_lock:
return self._process.pid if (self._process and not self._threaded) else None
- def join(self):
+ def join(self) -> None:
self.result(None)
- def result(self, test):
+ def result(self, test: 'unittest.TestCase') -> None:
with self._process_lock:
if self._status == AsyncStatus.PENDING:
self.run()
@@ -231,18 +235,18 @@ class TimedTestRunner(unittest.TextTestRunner):
.. versionadded:: 1.6.0
"""
- def run(self, test):
+ def run(self, test: 'unittest.TestCase') -> 'unittest.TestResult':
for t in test._tests:
original_type = type(t)
class _TestWrapper(original_type):
- def run(self, result = None):
+ def run(self, result: Optional[Any] = None) -> Any:
start_time = time.time()
result = super(type(self), self).run(result)
TEST_RUNTIMES[self.id()] = time.time() - start_time
return result
- def assertRaisesWith(self, exc_type, exc_msg, func, *args, **kwargs):
+ def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, func: Callable, *args: Any, **kwargs: Any) -> None:
"""
Asserts the given invocation raises the expected exception. This is
similar to unittest's assertRaises and assertRaisesRegexp, but checks
@@ -255,10 +259,10 @@ class TimedTestRunner(unittest.TextTestRunner):
return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), func, *args, **kwargs)
- def id(self):
+ def id(self) -> str:
return '%s.%s.%s' % (original_type.__module__, original_type.__name__, self._testMethodName)
- def __str__(self):
+ def __str__(self) -> str:
return '%s (%s.%s)' % (self._testMethodName, original_type.__module__, original_type.__name__)
t.__class__ = _TestWrapper
@@ -266,7 +270,7 @@ class TimedTestRunner(unittest.TextTestRunner):
return super(TimedTestRunner, self).run(test)
-def test_runtimes():
+def test_runtimes() -> Mapping[str, float]:
"""
Provides the runtimes of tests executed through TimedTestRunners.
@@ -279,7 +283,7 @@ def test_runtimes():
return dict(TEST_RUNTIMES)
-def clean_orphaned_pyc(paths):
+def clean_orphaned_pyc(paths: Sequence[str]) -> Sequence[str]:
"""
Deletes any file with a \\*.pyc extension without a corresponding \\*.py. This
helps to address a common gotcha when deleting python files...
@@ -324,7 +328,7 @@ def clean_orphaned_pyc(paths):
return orphaned_pyc
-def is_pyflakes_available():
+def is_pyflakes_available() -> bool:
"""
Checks if pyflakes is available.
@@ -334,7 +338,7 @@ def is_pyflakes_available():
return _module_exists('pyflakes.api') and _module_exists('pyflakes.reporter')
-def is_pycodestyle_available():
+def is_pycodestyle_available() -> bool:
"""
Checks if pycodestyle is available.
@@ -349,7 +353,7 @@ def is_pycodestyle_available():
return hasattr(pycodestyle, 'BaseReport')
-def stylistic_issues(paths, check_newlines = False, check_exception_keyword = False, prefer_single_quotes = False):
+def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_exception_keyword: bool = False, prefer_single_quotes: bool = False) -> Mapping[str, Sequence['stem.util.test_tools.Issue']]:
"""
Checks for stylistic issues according to the parts of PEP8
we conform to. You can suppress pycodestyle issues by making a 'test'
@@ -425,7 +429,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa
else:
ignore_rules.append(rule)
- def is_ignored(path, rule, code):
+ def is_ignored(path: str, rule: str, code: str) -> bool:
for ignored_path, ignored_rule, ignored_code in ignore_for_file:
if path.endswith(ignored_path) and ignored_rule == rule and code.strip().startswith(ignored_code):
return True
@@ -440,7 +444,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa
import pycodestyle
class StyleReport(pycodestyle.BaseReport):
- def init_file(self, filename, lines, expected, line_offset):
+ def init_file(self, filename: str, lines: Sequence[str], expected: Tuple[str], line_offset: int) -> None:
super(StyleReport, self).init_file(filename, lines, expected, line_offset)
if not check_newlines and not check_exception_keyword and not prefer_single_quotes:
@@ -473,7 +477,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa
issues.setdefault(filename, []).append(Issue(index + 1, 'use single rather than double quotes', line))
- def error(self, line_number, offset, text, check):
+ def error(self, line_number: int, offset: int, text: str, check: str) -> None:
code = super(StyleReport, self).error(line_number, offset, text, check)
if code:
@@ -488,7 +492,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa
return issues
-def pyflakes_issues(paths):
+def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, Sequence['stem.util.test_tools.Issue']]:
"""
Performs static checks via pyflakes. False positives can be ignored via
'pyflakes.ignore' entries in our 'test' config. For instance...
@@ -521,23 +525,23 @@ def pyflakes_issues(paths):
import pyflakes.reporter
class Reporter(pyflakes.reporter.Reporter):
- def __init__(self):
+ def __init__(self) -> None:
self._ignored_issues = {}
for line in CONFIG['pyflakes.ignore']:
path, issue = line.split('=>')
self._ignored_issues.setdefault(path.strip(), []).append(issue.strip())
- def unexpectedError(self, filename, msg):
+ def unexpectedError(self, filename: str, msg: str) -> None:
self._register_issue(filename, None, msg, None)
- def syntaxError(self, filename, msg, lineno, offset, text):
+ def syntaxError(self, filename: str, msg: str, lineno: int, offset: int, text: str) -> None:
self._register_issue(filename, lineno, msg, text)
- def flake(self, msg):
+ def flake(self, msg: str) -> None:
self._register_issue(msg.filename, msg.lineno, msg.message % msg.message_args, None)
- def _is_ignored(self, path, issue):
+ def _is_ignored(self, path: str, issue: str) -> bool:
# Paths in pyflakes_ignore are relative, so we need to check to see if our
# path ends with any of them.
@@ -556,7 +560,7 @@ def pyflakes_issues(paths):
return False
- def _register_issue(self, path, line_number, issue, line):
+ def _register_issue(self, path: str, line_number: Optional[int], issue: str, line: Optional[str]) -> None:
if not self._is_ignored(path, issue):
if path and line_number and not line:
line = linecache.getline(path, line_number).strip()
@@ -571,7 +575,7 @@ def pyflakes_issues(paths):
return issues
-def _module_exists(module_name):
+def _module_exists(module_name: str) -> bool:
"""
Checks if a module exists.
@@ -587,7 +591,7 @@ def _module_exists(module_name):
return False
-def _python_files(paths):
+def _python_files(paths: Sequence[str]) -> Iterator[str]:
for path in paths:
for file_path in stem.util.system.files_with_suffix(path, '.py'):
skip = False
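Illustrative usage of the annotated pyflakes helpers (not part of the patch; assumes pyflakes is installed and we're running from a stem checkout):

  import stem.util.test_tools

  if stem.util.test_tools.is_pyflakes_available():
    for path, issues in stem.util.test_tools.pyflakes_issues(['stem/']).items():
      for issue in issues:
        print('%s: %s' % (path, issue))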
diff --git a/stem/util/tor_tools.py b/stem/util/tor_tools.py
index 8987635e..2398b7bc 100644
--- a/stem/util/tor_tools.py
+++ b/stem/util/tor_tools.py
@@ -23,6 +23,8 @@ import re
import stem.util.str_tools
+from typing import Optional, Sequence, Union
+
# The control-spec defines the following as...
#
# Fingerprint = "$" 40*HEXDIG
@@ -45,7 +47,7 @@ HS_V2_ADDRESS_PATTERN = re.compile('^[a-z2-7]{16}$')
HS_V3_ADDRESS_PATTERN = re.compile('^[a-z2-7]{56}$')
-def is_valid_fingerprint(entry, check_prefix = False):
+def is_valid_fingerprint(entry: str, check_prefix: bool = False) -> bool:
"""
Checks if a string is a properly formatted relay fingerprint. This checks for
a '$' prefix if check_prefix is true, otherwise this only validates the hex
@@ -72,11 +74,11 @@ def is_valid_fingerprint(entry, check_prefix = False):
return False
-def is_valid_nickname(entry):
+def is_valid_nickname(entry: str) -> bool:
"""
Checks if a string is a valid format for being a nickname.
- :param str entry: string to be checked
+ :param str entry: string to check
:returns: **True** if the string could be a nickname, **False** otherwise
"""
@@ -90,10 +92,12 @@ def is_valid_nickname(entry):
return False
-def is_valid_circuit_id(entry):
+def is_valid_circuit_id(entry: str) -> bool:
"""
Checks if a string is a valid format for being a circuit identifier.
+ :param str entry: string to check
+
:returns: **True** if the string could be a circuit id, **False** otherwise
"""
@@ -106,29 +110,33 @@ def is_valid_circuit_id(entry):
return False
-def is_valid_stream_id(entry):
+def is_valid_stream_id(entry: str) -> bool:
"""
Checks if a string is a valid format for being a stream identifier.
Currently, this is just an alias to :func:`~stem.util.tor_tools.is_valid_circuit_id`.
+ :param str entry: string to check
+
:returns: **True** if the string could be a stream id, **False** otherwise
"""
return is_valid_circuit_id(entry)
-def is_valid_connection_id(entry):
+def is_valid_connection_id(entry: str) -> bool:
"""
Checks if a string is a valid format for being a connection identifier.
Currently, this is just an alias to :func:`~stem.util.tor_tools.is_valid_circuit_id`.
+ :param str entry: string to check
+
:returns: **True** if the string could be a connection id, **False** otherwise
"""
return is_valid_circuit_id(entry)
-def is_valid_hidden_service_address(entry, version = None):
+def is_valid_hidden_service_address(entry: str, version: Optional[Union[int, Sequence[int]]] = None) -> bool:
"""
Checks if a string is a valid format for being a hidden service address (not
including the '.onion' suffix).
@@ -137,6 +145,7 @@ def is_valid_hidden_service_address(entry, version = None):
Added the **version** argument, and now responds with **True** for version 3
hidden service addresses rather than just version 2 addresses.
+ :param str entry: string to check
:param int,list version: versions to check for, if unspecified either v2 or v3
hidden service address will provide **True**
@@ -166,7 +175,7 @@ def is_valid_hidden_service_address(entry, version = None):
return False
-def is_hex_digits(entry, count):
+def is_hex_digits(entry: str, count: int) -> bool:
"""
Checks if a string is the given number of hex digits. Digits represented by
letters are case insensitive.
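Illustrative usage of the annotated validators (not part of the patch; the fingerprint below is just a forty digit hex example):

  from stem.util import tor_tools

  print(tor_tools.is_valid_fingerprint('9695DFC35FFEB861329B9F1AB04C46397020CE31'))  # True
  print(tor_tools.is_valid_nickname('caerSidi'))                                     # True
  print(tor_tools.is_hex_digits('5f3', 3))                                           # True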
diff --git a/stem/version.py b/stem/version.py
index 181aec8a..8ec35293 100644
--- a/stem/version.py
+++ b/stem/version.py
@@ -42,13 +42,15 @@ import stem.util
import stem.util.enum
import stem.util.system
+from typing import Any, Callable
+
# cache for the get_system_tor_version function
VERSION_CACHE = {}
VERSION_PATTERN = re.compile(r'^([0-9]+)\.([0-9]+)\.([0-9]+)(\.[0-9]+)?(-\S*)?(( \(\S*\))*)$')
-def get_system_tor_version(tor_cmd = 'tor'):
+def get_system_tor_version(tor_cmd: str = 'tor') -> 'stem.version.Version':
"""
Queries tor for its version. This is os dependent, only working on linux,
osx, and bsd.
@@ -96,7 +98,7 @@ def get_system_tor_version(tor_cmd = 'tor'):
@functools.lru_cache()
-def _get_version(version_str):
+def _get_version(version_str: str) -> 'stem.version.Version':
return Version(version_str)
@@ -125,7 +127,7 @@ class Version(object):
:raises: **ValueError** if input isn't a valid tor version
"""
- def __init__(self, version_str):
+ def __init__(self, version_str: str) -> None:
self.version_str = version_str
version_parts = VERSION_PATTERN.match(version_str)
@@ -157,14 +159,14 @@ class Version(object):
else:
raise ValueError("'%s' isn't a properly formatted tor version" % version_str)
- def __str__(self):
+ def __str__(self) -> str:
"""
Provides the string used to construct the version.
"""
return self.version_str
- def _compare(self, other, method):
+ def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool:
"""
Compares version ordering according to the spec.
"""
@@ -195,23 +197,23 @@ class Version(object):
return method(my_status, other_status)
- def __hash__(self):
+ def __hash__(self) -> int:
return stem.util._hash_attr(self, 'major', 'minor', 'micro', 'patch', 'status', cache = True)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s == o)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
- def __gt__(self, other):
+ def __gt__(self, other: Any) -> bool:
"""
Checks if this version meets the requirements for a given feature.
"""
return self._compare(other, lambda s, o: s > o)
- def __ge__(self, other):
+ def __ge__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s >= o)
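Illustrative usage of the annotated Version comparisons (not part of the patch; '0.4.3.5' is an arbitrary example version):

  import stem.version

  version = stem.version.Version('0.4.3.5')

  print(version.major, version.minor, version.micro)  # 0 4 3
  print(version >= stem.version.Version('0.4.2.1'))   # True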
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 8b8b3205..732ae50a 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -339,14 +339,14 @@ class TestController(unittest.TestCase):
auth_methods = []
if test.runner.Torrc.COOKIE in tor_options:
- auth_methods.append(stem.response.protocolinfo.AuthMethod.COOKIE)
- auth_methods.append(stem.response.protocolinfo.AuthMethod.SAFECOOKIE)
+ auth_methods.append(stem.connection.AuthMethod.COOKIE)
+ auth_methods.append(stem.connection.AuthMethod.SAFECOOKIE)
if test.runner.Torrc.PASSWORD in tor_options:
- auth_methods.append(stem.response.protocolinfo.AuthMethod.PASSWORD)
+ auth_methods.append(stem.connection.AuthMethod.PASSWORD)
if not auth_methods:
- auth_methods.append(stem.response.protocolinfo.AuthMethod.NONE)
+ auth_methods.append(stem.connection.AuthMethod.NONE)
self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
diff --git a/test/integ/response/protocolinfo.py b/test/integ/response/protocolinfo.py
index 2fb060db..3a9ee0be 100644
--- a/test/integ/response/protocolinfo.py
+++ b/test/integ/response/protocolinfo.py
@@ -125,8 +125,8 @@ class TestProtocolInfo(unittest.TestCase):
auth_methods, auth_cookie_path = [], None
if test.runner.Torrc.COOKIE in tor_options:
- auth_methods.append(stem.response.protocolinfo.AuthMethod.COOKIE)
- auth_methods.append(stem.response.protocolinfo.AuthMethod.SAFECOOKIE)
+ auth_methods.append(stem.connection.AuthMethod.COOKIE)
+ auth_methods.append(stem.connection.AuthMethod.SAFECOOKIE)
chroot_path = runner.get_chroot()
auth_cookie_path = runner.get_auth_cookie_path()
@@ -135,10 +135,10 @@ class TestProtocolInfo(unittest.TestCase):
auth_cookie_path = auth_cookie_path[len(chroot_path):]
if test.runner.Torrc.PASSWORD in tor_options:
- auth_methods.append(stem.response.protocolinfo.AuthMethod.PASSWORD)
+ auth_methods.append(stem.connection.AuthMethod.PASSWORD)
if not auth_methods:
- auth_methods.append(stem.response.protocolinfo.AuthMethod.NONE)
+ auth_methods.append(stem.connection.AuthMethod.NONE)
self.assertEqual((), protocolinfo_response.unknown_auth_methods)
self.assertEqual(tuple(auth_methods), protocolinfo_response.auth_methods)
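These test updates follow AuthMethod's move from stem.response.protocolinfo to stem.connection. A minimal sketch of the new import path; the enum values themselves are unchanged:

  import stem.connection

  expected = (
    stem.connection.AuthMethod.COOKIE,
    stem.connection.AuthMethod.SAFECOOKIE,
  )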
diff --git a/test/settings.cfg b/test/settings.cfg
index 38a37ef9..8c6423bb 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -192,20 +192,16 @@ pycodestyle.ignore test/unit/util/connection.py => W291: _tor tor 158
# False positives from pyflakes. These are mappings between the path and the
# issue.
-pyflakes.ignore run_tests.py => 'unittest' imported but unused
-pyflakes.ignore stem/control.py => undefined name 'controller'
-pyflakes.ignore stem/manual.py => undefined name 'unichr'
+pyflakes.ignore stem/manual.py => undefined name 'sqlite3'
+pyflakes.ignore stem/client/cell.py => undefined name 'cryptography'
+pyflakes.ignore stem/client/cell.py => undefined name 'hashlib'
pyflakes.ignore stem/client/datatype.py => redefinition of unused 'pop' from *
-pyflakes.ignore stem/descriptor/hidden_service_descriptor.py => 'stem.descriptor.hidden_service.*' imported but unused
-pyflakes.ignore stem/descriptor/hidden_service_descriptor.py => 'from stem.descriptor.hidden_service import *' used; unable to detect undefined names
-pyflakes.ignore stem/interpreter/__init__.py => undefined name 'raw_input'
-pyflakes.ignore stem/response/events.py => undefined name 'long'
-pyflakes.ignore stem/util/__init__.py => undefined name 'long'
-pyflakes.ignore stem/util/__init__.py => undefined name 'unicode'
-pyflakes.ignore stem/util/conf.py => undefined name 'unicode'
-pyflakes.ignore stem/util/test_tools.py => 'pyflakes' imported but unused
-pyflakes.ignore stem/util/test_tools.py => 'pycodestyle' imported but unused
-pyflakes.ignore test/__init__.py => undefined name 'test'
+pyflakes.ignore stem/descriptor/hidden_service.py => undefined name 'cryptography'
+pyflakes.ignore stem/interpreter/autocomplete.py => undefined name 'stem'
+pyflakes.ignore stem/interpreter/help.py => undefined name 'stem'
+pyflakes.ignore stem/response/events.py => undefined name 'datetime'
+pyflakes.ignore stem/util/conf.py => undefined name 'stem'
+pyflakes.ignore stem/util/enum.py => undefined name 'stem'
pyflakes.ignore test/require.py => 'cryptography.utils.int_from_bytes' imported but unused
pyflakes.ignore test/require.py => 'cryptography.utils.int_to_bytes' imported but unused
pyflakes.ignore test/require.py => 'cryptography.hazmat.backends.default_backend' imported but unused
@@ -216,7 +212,6 @@ pyflakes.ignore test/require.py => 'cryptography.hazmat.primitives.serialization
pyflakes.ignore test/require.py => 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey' imported but unused
pyflakes.ignore test/unit/response/events.py => 'from stem import *' used; unable to detect undefined names
pyflakes.ignore test/unit/response/events.py => *may be undefined, or defined from star imports: stem
-pyflakes.ignore stem/util/str_tools.py => undefined name 'unicode'
pyflakes.ignore test/integ/interpreter.py => 'readline' imported but unused
# Test modules we want to run. Modules are roughly ordered by the dependencies
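The suppressions above follow settings.cfg's 'pyflakes.ignore <path> => <message>' convention, so a hypothetical new entry (path and message invented for illustration) would look like:

  pyflakes.ignore stem/example.py => undefined name 'example'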
diff --git a/test/unit/response/protocolinfo.py b/test/unit/response/protocolinfo.py
index dd8d2160..a71746c9 100644
--- a/test/unit/response/protocolinfo.py
+++ b/test/unit/response/protocolinfo.py
@@ -13,8 +13,8 @@ import stem.version
from unittest.mock import Mock, patch
+from stem.connection import AuthMethod
from stem.response import ControlMessage
-from stem.response.protocolinfo import AuthMethod
NO_AUTH = """250-PROTOCOLINFO 1
250-AUTH METHODS=NULL
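The unit fixture above (shown truncated) is the kind of reply these responses parse. A minimal sketch using a hypothetical, completed reply rather than the fixture itself:

  from stem.connection import AuthMethod
  from stem.response import ControlMessage

  # hypothetical PROTOCOLINFO reply for illustration, not the NO_AUTH fixture above
  reply = '250-PROTOCOLINFO 1\n250-AUTH METHODS=NULL\n250-VERSION Tor="0.4.3.5"\n250 OK'

  message = ControlMessage.from_str(reply, 'PROTOCOLINFO', normalize = True)
  print(message.auth_methods == (AuthMethod.NONE,))  # True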