[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Fix mypy issues
commit 076f89dfb4fd4a156ae32fde8c78c531385162e8
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Tue Apr 14 18:43:20 2020 -0700
Fix mypy issues
Correcting the issues brought up by mypy. Some bug fixes and reconfiguring.
Significant changes include...
* Refactored command line parsing (stem/interpreter/arguments.py and
test/arguments.py) to be a typing.NamedTuple derived class.
* Temporarily ignoring all mypy warnings related to our enum class.
Python 3.x added its own enum that I'd like to swap us over to but
that will be a separate project.
---
run_tests.py | 6 +-
stem/__init__.py | 12 +-
stem/client/__init__.py | 46 ++++---
stem/client/cell.py | 52 +++----
stem/client/datatype.py | 91 +++++++------
stem/connection.py | 41 +++---
stem/control.py | 228 +++++++++++++++----------------
stem/descriptor/__init__.py | 231 ++++++++++++++++---------------
stem/descriptor/bandwidth_file.py | 34 +++--
stem/descriptor/certificate.py | 20 +--
stem/descriptor/collector.py | 82 +++++------
stem/descriptor/extrainfo_descriptor.py | 60 ++++----
stem/descriptor/hidden_service.py | 118 ++++++++--------
stem/descriptor/microdescriptor.py | 20 +--
stem/descriptor/networkstatus.py | 136 ++++++++++---------
stem/descriptor/remote.py | 48 +++----
stem/descriptor/router_status_entry.py | 61 +++++----
stem/descriptor/server_descriptor.py | 69 +++++-----
stem/descriptor/tordnsel.py | 30 ++--
stem/directory.py | 32 ++---
stem/exit_policy.py | 60 ++++----
stem/interpreter/__init__.py | 12 +-
stem/interpreter/arguments.py | 177 ++++++++++++------------
stem/interpreter/autocomplete.py | 11 +-
stem/interpreter/commands.py | 28 ++--
stem/interpreter/help.py | 13 +-
stem/manual.py | 37 ++---
stem/process.py | 6 +-
stem/response/__init__.py | 104 ++++++++++----
stem/response/events.py | 234 ++++++++++++++++++++++++++++----
stem/response/getconf.py | 4 +-
stem/response/getinfo.py | 12 +-
stem/response/protocolinfo.py | 7 +-
stem/socket.py | 55 +++++---
stem/util/__init__.py | 6 +-
stem/util/conf.py | 26 ++--
stem/util/connection.py | 41 +++---
stem/util/enum.py | 31 +++--
stem/util/log.py | 3 +-
stem/util/proc.py | 12 +-
stem/util/str_tools.py | 32 +++--
stem/util/system.py | 48 ++++---
stem/util/term.py | 12 +-
stem/util/test_tools.py | 85 ++++++------
stem/version.py | 11 +-
test/arguments.py | 233 ++++++++++++++++---------------
test/mypy.ini | 6 +
test/settings.cfg | 19 +++
test/task.py | 2 +-
test/unit/client/address.py | 2 +-
test/unit/control/controller.py | 6 +-
test/unit/descriptor/bandwidth_file.py | 3 +-
test/unit/interpreter/arguments.py | 32 ++---
test/unit/util/proc.py | 23 ++--
54 files changed, 1608 insertions(+), 1202 deletions(-)
diff --git a/run_tests.py b/run_tests.py
index 2ea07dab..fd46211f 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -194,7 +194,7 @@ def main():
test_config.load(os.environ['STEM_TEST_CONFIG'])
try:
- args = test.arguments.parse(sys.argv[1:])
+ args = test.arguments.Arguments.parse(sys.argv[1:])
test.task.TOR_VERSION.args = (args.tor_path,)
test.output.SUPPRESS_STDOUT = args.quiet
except ValueError as exc:
@@ -202,7 +202,7 @@ def main():
sys.exit(1)
if args.print_help:
- println(test.arguments.get_help())
+ println(test.arguments.Arguments.get_help())
sys.exit()
elif not args.run_unit and not args.run_integ:
println('Nothing to run (for usage provide --help)\n')
@@ -383,7 +383,7 @@ def _print_static_issues(static_check_issues):
if static_check_issues:
println('STATIC CHECKS', STATUS)
- for file_path in static_check_issues:
+ for file_path in sorted(static_check_issues):
println('* %s' % file_path, STATUS)
# Make a dict of line numbers to its issues. This is so we can both sort
diff --git a/stem/__init__.py b/stem/__init__.py
index c0efab19..ce8d70a9 100644
--- a/stem/__init__.py
+++ b/stem/__init__.py
@@ -567,7 +567,7 @@ __all__ = [
]
# Constant that we use by default for our User-Agent when downloading descriptors
-stem.USER_AGENT = 'Stem/%s' % __version__
+USER_AGENT = 'Stem/%s' % __version__
# Constant to indicate an undefined argument default. Usually we'd use None for
# this, but users will commonly provide None as the argument so need something
@@ -612,7 +612,7 @@ class ORPort(Endpoint):
:var list link_protocols: link protocol version we're willing to establish
"""
- def __init__(self, address: str, port: int, link_protocols: Optional[Sequence[int]] = None) -> None:
+ def __init__(self, address: str, port: int, link_protocols: Optional[Sequence['stem.client.datatype.LinkProtocol']] = None) -> None: # type: ignore
super(ORPort, self).__init__(address, port)
self.link_protocols = link_protocols
@@ -644,6 +644,8 @@ class OperationFailed(ControllerError):
message
"""
+ # TODO: should the code be an int instead?
+
def __init__(self, code: Optional[str] = None, message: Optional[str] = None) -> None:
super(ControllerError, self).__init__(message)
self.code = code
@@ -663,7 +665,7 @@ class CircuitExtensionFailed(UnsatisfiableRequest):
:var stem.response.events.CircuitEvent circ: response notifying us of the failure
"""
- def __init__(self, message: str, circ: Optional['stem.response.events.CircuitEvent'] = None) -> None:
+ def __init__(self, message: str, circ: Optional['stem.response.events.CircuitEvent'] = None) -> None: # type: ignore
super(CircuitExtensionFailed, self).__init__(message = message)
self.circ = circ
@@ -775,7 +777,7 @@ class DownloadTimeout(DownloadFailed):
.. versionadded:: 1.8.0
"""
- def __init__(self, url: str, error: BaseException, stacktrace: Any, timeout: int):
+ def __init__(self, url: str, error: BaseException, stacktrace: Any, timeout: float):
message = 'Failed to download from %s: %0.1f second timeout reached' % (url, timeout)
super(DownloadTimeout, self).__init__(url, error, stacktrace, message)
@@ -919,7 +921,7 @@ StreamStatus = stem.util.enum.UppercaseEnum(
)
# StreamClosureReason is a superset of RelayEndReason
-StreamClosureReason = stem.util.enum.UppercaseEnum(*(RelayEndReason.keys() + [
+StreamClosureReason = stem.util.enum.UppercaseEnum(*(RelayEndReason.keys() + [ # type: ignore
'END',
'PRIVATE_ADDR',
]))
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 2972985d..8726bdbf 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -34,11 +34,12 @@ import stem.socket
import stem.util.connection
from types import TracebackType
-from typing import Iterator, Optional, Tuple, Type
+from typing import Dict, Iterator, List, Optional, Sequence, Type, Union
from stem.client.cell import (
CELL_TYPE_SIZE,
FIXED_PAYLOAD_LEN,
+ PAYLOAD_LEN_SIZE,
Cell,
)
@@ -66,15 +67,15 @@ class Relay(object):
:var int link_protocol: link protocol version we established
"""
- def __init__(self, orport: int, link_protocol: int) -> None:
+ def __init__(self, orport: stem.socket.RelaySocket, link_protocol: int) -> None:
self.link_protocol = LinkProtocol(link_protocol)
self._orport = orport
self._orport_buffer = b'' # unread bytes
self._orport_lock = threading.RLock()
- self._circuits = {}
+ self._circuits = {} # type: Dict[int, stem.client.Circuit]
@staticmethod
- def connect(address: str, port: int, link_protocols: Tuple[int] = DEFAULT_LINK_PROTOCOLS) -> None:
+ def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore
"""
Establishes a connection with the given ORPort.
@@ -121,7 +122,7 @@ class Relay(object):
# first VERSIONS cell, always have CIRCID_LEN == 2 for backward
# compatibility.
- conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2))
+ conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore
response = conn.recv()
# Link negotiation ends right away if we lack a common protocol
@@ -131,12 +132,12 @@ class Relay(object):
conn.close()
raise stem.SocketError('Unable to establish a common link protocol with %s:%i' % (address, port))
- versions_reply = stem.client.cell.Cell.pop(response, 2)[0]
+ versions_reply = stem.client.cell.Cell.pop(response, 2)[0] # type: stem.client.cell.VersionsCell # type: ignore
common_protocols = set(link_protocols).intersection(versions_reply.versions)
if not common_protocols:
conn.close()
- raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(link_protocols), address, port, ', '.join(versions_reply.versions)))
+ raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(map(str, link_protocols)), address, port, ', '.join(map(str, versions_reply.versions))))
# Establishing connections requires sending a NETINFO, but including our
# address is optional. We can revisit including it when we have a usecase
@@ -147,7 +148,10 @@ class Relay(object):
return Relay(conn, link_protocol)
- def _recv(self, raw: bool = False) -> None:
+ def _recv_bytes(self) -> bytes:
+ return self._recv(True) # type: ignore
+
+ def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell':
"""
Reads the next cell from our ORPort. If none is present this blocks
until one is available.
@@ -172,18 +176,18 @@ class Relay(object):
else:
# variable length, our next field is the payload size
- while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN.size):
+ while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN):
self._orport_buffer += self._orport.recv() # read until we know the cell size
- payload_len = FIXED_PAYLOAD_LEN.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0]
- cell_size = circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN.size + payload_len
+ payload_len = PAYLOAD_LEN_SIZE.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0]
+ cell_size = circ_id_size + CELL_TYPE_SIZE.size + payload_len
while len(self._orport_buffer) < cell_size:
self._orport_buffer += self._orport.recv() # read until we have the full cell
if raw:
content, self._orport_buffer = split(self._orport_buffer, cell_size)
- return content
+ return content # type: ignore
else:
cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol)
return cell
@@ -213,12 +217,12 @@ class Relay(object):
:returns: **generator** with the cells received in reply
"""
+ # TODO: why is this an iterator?
+
self._orport.recv(timeout = 0) # discard unread data
self._orport.send(cell.pack(self.link_protocol))
response = self._orport.recv(timeout = 1)
-
- for received_cell in stem.client.cell.Cell.pop(response, self.link_protocol):
- yield received_cell
+ yield stem.client.cell.Cell.pop(response, self.link_protocol)[0]
def is_alive(self) -> bool:
"""
@@ -251,7 +255,7 @@ class Relay(object):
with self._orport_lock:
return self._orport.close()
- def create_circuit(self) -> None:
+ def create_circuit(self) -> 'stem.client.Circuit':
"""
Establishes a new circuit.
"""
@@ -314,7 +318,7 @@ class Circuit(object):
except ImportError:
raise ImportError('Circuit construction requires the cryptography module')
- ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8))
+ ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8)) # type: ignore
self.relay = relay
self.id = circ_id
@@ -323,7 +327,7 @@ class Circuit(object):
self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor()
self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
- def directory(self, request: str, stream_id: int = 0) -> str:
+ def directory(self, request: str, stream_id: int = 0) -> bytes:
"""
Request descriptors from the relay.
@@ -337,13 +341,13 @@ class Circuit(object):
self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id)
self._send(RelayCommand.DATA, request, stream_id = stream_id)
- response = []
+ response = [] # type: List[stem.client.cell.RelayCell]
while True:
# Decrypt relay cells received in response. Our digest/key only
# updates when handled successfully.
- encrypted_cell = self.relay._recv(raw = True)
+ encrypted_cell = self.relay._recv_bytes()
decrypted_cell, backward_key, backward_digest = stem.client.cell.RelayCell.decrypt(self.relay.link_protocol, encrypted_cell, self.backward_key, self.backward_digest)
@@ -358,7 +362,7 @@ class Circuit(object):
else:
response.append(decrypted_cell)
- def _send(self, command: 'stem.client.datatype.RelayCommand', data: bytes = b'', stream_id: int = 0) -> None:
+ def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None:
"""
Sends a message over the circuit.
diff --git a/stem/client/cell.py b/stem/client/cell.py
index ef445a64..c88ba716 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -49,12 +49,12 @@ from stem import UNDEFINED
from stem.client.datatype import HASH_LEN, ZERO, LinkProtocol, Address, Certificate, CloseReason, RelayCommand, Size, split
from stem.util import datetime_to_unix, str_tools
-from typing import Any, Sequence, Tuple, Type
+from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type, Union
FIXED_PAYLOAD_LEN = 509 # PAYLOAD_LEN, per tor-spec section 0.2
AUTH_CHALLENGE_SIZE = 32
-CELL_TYPE_SIZE = Size.CHAR
+CELL_TYPE_SIZE = Size.CHAR # type: stem.client.datatype.Size
PAYLOAD_LEN_SIZE = Size.SHORT
RELAY_DIGEST_SIZE = Size.LONG
@@ -138,11 +138,11 @@ class Cell(object):
raise ValueError("'%s' isn't a valid cell value" % value)
- def pack(self, link_protocol):
+ def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes:
raise NotImplementedError('Packing not yet implemented for %s cells' % type(self).NAME)
@staticmethod
- def unpack(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.Cell':
+ def unpack(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> Iterator['stem.client.cell.Cell']:
"""
Unpacks all cells from a response.
@@ -193,7 +193,7 @@ class Cell(object):
return cls._unpack(payload, circ_id, link_protocol), content
@classmethod
- def _pack(cls: Type['stem.client.cell.Cell'], link_protocol: 'stem.client.datatype.LinkProtocol', payload: bytes, unused: bytes = b'', circ_id: int = None) -> bytes:
+ def _pack(cls: Type['stem.client.cell.Cell'], link_protocol: 'stem.client.datatype.LinkProtocol', payload: bytes, unused: bytes = b'', circ_id: Optional[int] = None) -> bytes:
"""
Provides bytes that can be used on the wire for these cell attributes.
Format of a properly packed cell depends on if it's fixed or variable
@@ -292,7 +292,7 @@ class PaddingCell(Cell):
VALUE = 0
IS_FIXED_SIZE = True
- def __init__(self, payload: bytes = None) -> None:
+ def __init__(self, payload: Optional[bytes] = None) -> None:
if not payload:
payload = os.urandom(FIXED_PAYLOAD_LEN)
elif len(payload) != FIXED_PAYLOAD_LEN:
@@ -317,8 +317,8 @@ class CreateCell(CircuitCell):
VALUE = 1
IS_FIXED_SIZE = True
- def __init__(self) -> None:
- super(CreateCell, self).__init__() # TODO: implement
+ def __init__(self, circ_id: int, unused: bytes = b'') -> None:
+ super(CreateCell, self).__init__(circ_id, unused) # TODO: implement
class CreatedCell(CircuitCell):
@@ -326,8 +326,8 @@ class CreatedCell(CircuitCell):
VALUE = 2
IS_FIXED_SIZE = True
- def __init__(self) -> None:
- super(CreatedCell, self).__init__() # TODO: implement
+ def __init__(self, circ_id: int, unused: bytes = b'') -> None:
+ super(CreatedCell, self).__init__(circ_id, unused) # TODO: implement
class RelayCell(CircuitCell):
@@ -352,13 +352,13 @@ class RelayCell(CircuitCell):
VALUE = 3
IS_FIXED_SIZE = True
- def __init__(self, circ_id: int, command, data: bytes, digest: int = 0, stream_id: int = 0, recognized: int = 0, unused: bytes = b'') -> None:
+ def __init__(self, circ_id: int, command, data: Union[bytes, str], digest: Union[int, bytes, str, 'hashlib._HASH'] = 0, stream_id: int = 0, recognized: int = 0, unused: bytes = b'') -> None: # type: ignore
if 'hash' in str(type(digest)).lower():
# Unfortunately hashlib generates from a dynamic private class so
# isinstance() isn't such a great option. With python2/python3 the
# name is 'hashlib.HASH' whereas PyPy calls it just 'HASH' or 'Hash'.
- digest_packed = digest.digest()[:RELAY_DIGEST_SIZE.size]
+ digest_packed = digest.digest()[:RELAY_DIGEST_SIZE.size] # type: ignore
digest = RELAY_DIGEST_SIZE.unpack(digest_packed)
elif isinstance(digest, (bytes, str)):
digest_packed = digest[:RELAY_DIGEST_SIZE.size]
@@ -393,7 +393,7 @@ class RelayCell(CircuitCell):
return RelayCell._pack(link_protocol, bytes(payload), self.unused, self.circ_id)
@staticmethod
- def decrypt(link_protocol: 'stem.client.datatype.LinkProtocol', content: bytes, key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple['stem.client.cell.RelayCell', 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']:
+ def decrypt(link_protocol: 'stem.client.datatype.LinkProtocol', content: bytes, key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib._HASH') -> Tuple['stem.client.cell.RelayCell', 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib._HASH']: # type: ignore
"""
Decrypts content as a relay cell addressed to us. This provides back a
tuple of the form...
@@ -447,7 +447,7 @@ class RelayCell(CircuitCell):
return cell, new_key, new_digest
- def encrypt(self, link_protocol: 'stem.client.datatype.LinkProtocol', key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple[bytes, 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']:
+ def encrypt(self, link_protocol: 'stem.client.datatype.LinkProtocol', key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib._HASH') -> Tuple[bytes, 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib._HASH']: # type: ignore
"""
Encrypts our cell content to be sent with the given key. This provides back
a tuple of the form...
@@ -540,7 +540,7 @@ class CreateFastCell(CircuitCell):
VALUE = 5
IS_FIXED_SIZE = True
- def __init__(self, circ_id: int, key_material: bytes = None, unused: bytes = b'') -> None:
+ def __init__(self, circ_id: int, key_material: Optional[bytes] = None, unused: bytes = b'') -> None:
if not key_material:
key_material = os.urandom(HASH_LEN)
elif len(key_material) != HASH_LEN:
@@ -577,7 +577,7 @@ class CreatedFastCell(CircuitCell):
VALUE = 6
IS_FIXED_SIZE = True
- def __init__(self, circ_id: int, derivative_key: bytes, key_material: bytes = None, unused: bytes = b'') -> None:
+ def __init__(self, circ_id: int, derivative_key: bytes, key_material: Optional[bytes] = None, unused: bytes = b'') -> None:
if not key_material:
key_material = os.urandom(HASH_LEN)
elif len(key_material) != HASH_LEN:
@@ -594,7 +594,7 @@ class CreatedFastCell(CircuitCell):
return CreatedFastCell._pack(link_protocol, self.key_material + self.derivative_key, self.unused, self.circ_id)
@classmethod
- def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreateFastCell':
+ def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreatedFastCell':
if len(content) < HASH_LEN * 2:
raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content)))
@@ -653,7 +653,7 @@ class NetinfoCell(Cell):
VALUE = 8
IS_FIXED_SIZE = True
- def __init__(self, receiver_address: 'stem.client.datatype.Address', sender_addresses: Sequence['stem.client.datatype.Address'], timestamp: datetime.datetime = None, unused: bytes = b'') -> None:
+ def __init__(self, receiver_address: 'stem.client.datatype.Address', sender_addresses: Sequence['stem.client.datatype.Address'], timestamp: Optional[datetime.datetime] = None, unused: bytes = b'') -> None:
super(NetinfoCell, self).__init__(unused)
self.timestamp = timestamp if timestamp else datetime.datetime.now()
self.receiver_address = receiver_address
@@ -693,8 +693,8 @@ class RelayEarlyCell(CircuitCell):
VALUE = 9
IS_FIXED_SIZE = True
- def __init__(self) -> None:
- super(RelayEarlyCell, self).__init__() # TODO: implement
+ def __init__(self, circ_id: int, unused: bytes = b'') -> None:
+ super(RelayEarlyCell, self).__init__(circ_id, unused) # TODO: implement
class Create2Cell(CircuitCell):
@@ -702,8 +702,8 @@ class Create2Cell(CircuitCell):
VALUE = 10
IS_FIXED_SIZE = True
- def __init__(self) -> None:
- super(Create2Cell, self).__init__() # TODO: implement
+ def __init__(self, circ_id: int, unused: bytes = b'') -> None:
+ super(Create2Cell, self).__init__(circ_id, unused) # TODO: implement
class Created2Cell(Cell):
@@ -735,7 +735,7 @@ class VPaddingCell(Cell):
VALUE = 128
IS_FIXED_SIZE = False
- def __init__(self, size: int = None, payload: bytes = None) -> None:
+ def __init__(self, size: Optional[int] = None, payload: Optional[bytes] = None) -> None:
if size is None and payload is None:
raise ValueError('VPaddingCell constructor must specify payload or size')
elif size is not None and size < 0:
@@ -768,7 +768,7 @@ class CertsCell(Cell):
VALUE = 129
IS_FIXED_SIZE = False
- def __init__(self, certs: Sequence['stem.client.Certificate'], unused: bytes = b'') -> None:
+ def __init__(self, certs: Sequence['stem.client.datatype.Certificate'], unused: bytes = b'') -> None:
super(CertsCell, self).__init__(unused)
self.certificates = certs
@@ -778,7 +778,7 @@ class CertsCell(Cell):
@classmethod
def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CertsCell':
cert_count, content = Size.CHAR.pop(content)
- certs = []
+ certs = [] # type: List[stem.client.datatype.Certificate]
for i in range(cert_count):
if not content:
@@ -806,7 +806,7 @@ class AuthChallengeCell(Cell):
VALUE = 130
IS_FIXED_SIZE = False
- def __init__(self, methods: Sequence[int], challenge: bytes = None, unused: bytes = b'') -> None:
+ def __init__(self, methods: Sequence[int], challenge: Optional[bytes] = None, unused: bytes = b'') -> None:
if not challenge:
challenge = os.urandom(AUTH_CHALLENGE_SIZE)
elif len(challenge) != AUTH_CHALLENGE_SIZE:
diff --git a/stem/client/datatype.py b/stem/client/datatype.py
index 8d8ae7fb..acc9ec34 100644
--- a/stem/client/datatype.py
+++ b/stem/client/datatype.py
@@ -144,7 +144,7 @@ import stem.util
import stem.util.connection
import stem.util.enum
-from typing import Any, Tuple, Type, Union
+from typing import Any, Optional, Tuple, Union
ZERO = b'\x00'
HASH_LEN = 20
@@ -157,17 +157,17 @@ class _IntegerEnum(stem.util.enum.Enum):
**UNKNOWN** value for integer values that lack a mapping.
"""
- def __init__(self, *args: Tuple[str, int]) -> None:
+ def __init__(self, *args: Union[Tuple[str, int], Tuple[str, str, int]]) -> None:
self._enum_to_int = {}
self._int_to_enum = {}
parent_args = []
for entry in args:
if len(entry) == 2:
- enum, int_val = entry
+ enum, int_val = entry # type: ignore
str_val = enum
elif len(entry) == 3:
- enum, str_val, int_val = entry
+ enum, str_val, int_val = entry # type: ignore
else:
raise ValueError('IntegerEnums can only be constructed with two or three value tuples: %s' % repr(entry))
@@ -272,19 +272,16 @@ class LinkProtocol(int):
from a range that's determined by our link protocol.
"""
- def __new__(cls: Type['stem.client.datatype.LinkProtocol'], version: int) -> 'stem.client.datatype.LinkProtocol':
- if isinstance(version, LinkProtocol):
- return version # already a LinkProtocol
+ def __new__(self, version: int) -> 'stem.client.datatype.LinkProtocol':
+ return int.__new__(self, version) # type: ignore
- protocol = int.__new__(cls, version)
- protocol.version = version
- protocol.circ_id_size = Size.LONG if version > 3 else Size.SHORT
- protocol.first_circ_id = 0x80000000 if version > 3 else 0x01
+ def __init__(self, version: int) -> None:
+ self.version = version
+ self.circ_id_size = Size.LONG if version > 3 else Size.SHORT
+ self.first_circ_id = 0x80000000 if version > 3 else 0x01
- cell_header_size = protocol.circ_id_size.size + 1 # circuit id (2 or 4 bytes) + command (1 byte)
- protocol.fixed_cell_length = cell_header_size + stem.client.cell.FIXED_PAYLOAD_LEN
-
- return protocol
+ cell_header_size = self.circ_id_size.size + 1 # circuit id (2 or 4 bytes) + command (1 byte)
+ self.fixed_cell_length = cell_header_size + stem.client.cell.FIXED_PAYLOAD_LEN
def __hash__(self) -> int:
# All LinkProtocol attributes can be derived from our version, so that's
@@ -380,6 +377,11 @@ class Size(Field):
==================== ===========
"""
+ CHAR = None # type: Optional[stem.client.datatype.Size]
+ SHORT = None # type: Optional[stem.client.datatype.Size]
+ LONG = None # type: Optional[stem.client.datatype.Size]
+ LONG_LONG = None # type: Optional[stem.client.datatype.Size]
+
def __init__(self, name: str, size: int) -> None:
self.name = name
self.size = size
@@ -388,7 +390,7 @@ class Size(Field):
def pop(packed: bytes) -> Tuple[int, bytes]:
raise NotImplementedError("Use our constant's unpack() and pop() instead")
- def pack(self, content: int) -> bytes:
+ def pack(self, content: int) -> bytes: # type: ignore
try:
return content.to_bytes(self.size, 'big')
except:
@@ -399,13 +401,13 @@ class Size(Field):
else:
raise
- def unpack(self, packed: bytes) -> int:
+ def unpack(self, packed: bytes) -> int: # type: ignore
if self.size != len(packed):
raise ValueError('%s is the wrong size for a %s field' % (repr(packed), self.name))
return int.from_bytes(packed, 'big')
- def pop(self, packed: bytes) -> Tuple[int, bytes]:
+ def pop(self, packed: bytes) -> Tuple[int, bytes]: # type: ignore
to_unpack, remainder = split(packed, self.size)
return self.unpack(to_unpack), remainder
@@ -420,48 +422,53 @@ class Address(Field):
:var stem.client.AddrType type: address type
:var int type_int: integer value of the address type
- :var unicode value: address value
+ :var str value: address value
:var bytes value_bin: encoded address value
"""
- def __init__(self, value: str, addr_type: Union[int, 'stem.client.datatype.AddrType'] = None) -> None:
+ def __init__(self, value: Union[bytes, str], addr_type: Union[int, 'stem.client.datatype.AddrType'] = None) -> None:
if addr_type is None:
- if stem.util.connection.is_valid_ipv4_address(value):
+ if stem.util.connection.is_valid_ipv4_address(value): # type: ignore
addr_type = AddrType.IPv4
- elif stem.util.connection.is_valid_ipv6_address(value):
+ elif stem.util.connection.is_valid_ipv6_address(value): # type: ignore
addr_type = AddrType.IPv6
else:
- raise ValueError("'%s' isn't an IPv4 or IPv6 address" % value)
+ raise ValueError("'%s' isn't an IPv4 or IPv6 address" % stem.util.str_tools._to_unicode(value))
+
+ value_bytes = stem.util.str_tools._to_bytes(value)
+
+ self.value = None # type: Optional[str]
+ self.value_bin = None # type: Optional[bytes]
self.type, self.type_int = AddrType.get(addr_type)
if self.type == AddrType.IPv4:
- if stem.util.connection.is_valid_ipv4_address(value):
- self.value = value
- self.value_bin = b''.join([Size.CHAR.pack(int(v)) for v in value.split('.')])
+ if stem.util.connection.is_valid_ipv4_address(value_bytes): # type: ignore
+ self.value = stem.util.str_tools._to_unicode(value_bytes)
+ self.value_bin = b''.join([Size.CHAR.pack(int(v)) for v in value_bytes.split(b'.')])
else:
- if len(value) != 4:
+ if len(value_bytes) != 4:
raise ValueError('Packed IPv4 addresses should be four bytes, but was: %s' % repr(value))
- self.value = _unpack_ipv4_address(value)
- self.value_bin = value
+ self.value = _unpack_ipv4_address(value_bytes)
+ self.value_bin = value_bytes
elif self.type == AddrType.IPv6:
- if stem.util.connection.is_valid_ipv6_address(value):
- self.value = stem.util.connection.expand_ipv6_address(value).lower()
+ if stem.util.connection.is_valid_ipv6_address(value_bytes): # type: ignore
+ self.value = stem.util.connection.expand_ipv6_address(value_bytes).lower() # type: ignore
self.value_bin = b''.join([Size.SHORT.pack(int(v, 16)) for v in self.value.split(':')])
else:
- if len(value) != 16:
+ if len(value_bytes) != 16:
raise ValueError('Packed IPv6 addresses should be sixteen bytes, but was: %s' % repr(value))
- self.value = _unpack_ipv6_address(value)
- self.value_bin = value
+ self.value = _unpack_ipv6_address(value_bytes)
+ self.value_bin = value_bytes
else:
# The spec doesn't really tell us what form to expect errors to be. For
# now just leaving the value unset so we can fill it in later when we
# know what would be most useful.
self.value = None
- self.value_bin = value
+ self.value_bin = value_bytes
def pack(self) -> bytes:
cell = bytearray()
@@ -471,7 +478,7 @@ class Address(Field):
return bytes(cell)
@staticmethod
- def pop(content) -> Tuple['stem.client.datatype.Address', bytes]:
+ def pop(content: bytes) -> Tuple['stem.client.datatype.Address', bytes]:
addr_type, content = Size.CHAR.pop(content)
addr_length, content = Size.CHAR.pop(content)
@@ -590,7 +597,7 @@ class LinkByIPv4(LinkSpecifier):
@staticmethod
def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv4':
if len(value) != 6:
- raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
+ raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
addr, port = split(value, 4)
return LinkByIPv4(_unpack_ipv4_address(addr), Size.SHORT.unpack(port))
@@ -615,7 +622,7 @@ class LinkByIPv6(LinkSpecifier):
@staticmethod
def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv6':
if len(value) != 18:
- raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
+ raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
addr, port = split(value, 16)
return LinkByIPv6(_unpack_ipv6_address(addr), Size.SHORT.unpack(port))
@@ -634,7 +641,7 @@ class LinkByFingerprint(LinkSpecifier):
super(LinkByFingerprint, self).__init__(2, value)
if len(value) != 20:
- raise ValueError('Fingerprint link specifiers should be twenty bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
+ raise ValueError('Fingerprint link specifiers should be twenty bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
self.fingerprint = stem.util.str_tools._to_unicode(value)
@@ -652,7 +659,7 @@ class LinkByEd25519(LinkSpecifier):
super(LinkByEd25519, self).__init__(3, value)
if len(value) != 32:
- raise ValueError('Fingerprint link specifiers should be thirty two bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
+ raise ValueError('Fingerprint link specifiers should be thirty two bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
self.fingerprint = stem.util.str_tools._to_unicode(value)
@@ -695,7 +702,7 @@ def _pack_ipv4_address(address: str) -> bytes:
return b''.join([Size.CHAR.pack(int(v)) for v in address.split('.')])
-def _unpack_ipv4_address(value: str) -> bytes:
+def _unpack_ipv4_address(value: bytes) -> str:
return '.'.join([str(Size.CHAR.unpack(value[i:i + 1])) for i in range(4)])
@@ -703,7 +710,7 @@ def _pack_ipv6_address(address: str) -> bytes:
return b''.join([Size.SHORT.pack(int(v, 16)) for v in address.split(':')])
-def _unpack_ipv6_address(value: str) -> bytes:
+def _unpack_ipv6_address(value: bytes) -> str:
return ':'.join(['%04x' % Size.SHORT.unpack(value[i * 2:(i + 1) * 2]) for i in range(8)])
diff --git a/stem/connection.py b/stem/connection.py
index 3d3eb3ee..ff950a0c 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -143,7 +143,7 @@ import stem.util.str_tools
import stem.util.system
import stem.version
-from typing import Any, Optional, Sequence, Tuple, Type, Union
+from typing import Any, List, Optional, Sequence, Tuple, Type, Union
from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN')
@@ -211,7 +211,7 @@ COMMON_TOR_COMMANDS = (
)
-def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: type = stem.control.Controller) -> Union[stem.control.BaseController, stem.socket.ControlSocket]:
+def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type = stem.control.Controller) -> Any:
"""
Convenience function for quickly getting a control connection. This is very
handy for debugging or CLI setup, handling setup and prompting for a password
@@ -250,6 +250,8 @@ def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_so
**control_port** and **control_socket** are **None**
"""
+ # TODO: change this function's API so we can provide a concrete type
+
if control_port is None and control_socket is None:
raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.')
elif control_port:
@@ -260,7 +262,8 @@ def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_so
elif control_port[1] != 'default' and not stem.util.connection.is_valid_port(control_port[1]):
raise ValueError("'%s' isn't a valid port" % control_port[1])
- control_connection, error_msg = None, ''
+ control_connection = None # type: Optional[stem.socket.ControlSocket]
+ error_msg = ''
if control_socket:
if os.path.exists(control_socket):
@@ -297,7 +300,7 @@ def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_so
return _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Union[Type[stem.control.BaseController], Type[stem.socket.ControlSocket]]) -> Union[stem.control.BaseController, stem.socket.ControlSocket]:
+def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any:
"""
Helper for the connect_* functions that authenticates the socket and
constructs the controller.
@@ -363,7 +366,7 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
return None
-def authenticate(controller: Any, password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None:
+def authenticate(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None:
"""
Authenticates to a control socket using the information provided by a
PROTOCOLINFO response. In practice this will often be all we need to
@@ -481,7 +484,7 @@ def authenticate(controller: Any, password: Optional[str] = None, chroot_path: O
raise AuthenticationFailure('socket connection failed (%s)' % exc)
auth_methods = list(protocolinfo_response.auth_methods)
- auth_exceptions = []
+ auth_exceptions = [] # type: List[stem.connection.AuthenticationFailure]
if len(auth_methods) == 0:
raise NoAuthMethods('our PROTOCOLINFO response did not have any methods for authenticating')
@@ -846,10 +849,11 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
cookie_data = _read_cookie(cookie_path, True)
client_nonce = os.urandom(32)
+ authchallenge_response = None # type: stem.response.authchallenge.AuthChallengeResponse
try:
client_nonce_hex = stem.util.str_tools._to_unicode(binascii.b2a_hex(client_nonce))
- authchallenge_response = _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex)
+ authchallenge_response = _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex) # type: ignore
if not authchallenge_response.is_ok():
try:
@@ -862,13 +866,18 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
if 'Authentication required.' in authchallenge_response_str:
raise AuthChallengeUnsupported("SAFECOOKIE authentication isn't supported", cookie_path)
elif 'AUTHCHALLENGE only supports' in authchallenge_response_str:
- raise UnrecognizedAuthChallengeMethod(authchallenge_response_str, cookie_path)
+ # TODO: This code path has been broken for years. Do we still need it?
+ # If so, what should authchallenge_method be?
+
+ authchallenge_method = None
+
+ raise UnrecognizedAuthChallengeMethod(authchallenge_response_str, cookie_path, authchallenge_method)
elif 'Invalid base16 client nonce' in authchallenge_response_str:
raise InvalidClientNonce(authchallenge_response_str, cookie_path)
elif 'Cookie authentication is disabled' in authchallenge_response_str:
raise CookieAuthRejected(authchallenge_response_str, cookie_path, True)
else:
- raise AuthChallengeFailed(authchallenge_response, cookie_path)
+ raise AuthChallengeFailed(authchallenge_response_str, cookie_path)
except stem.ControllerError as exc:
try:
controller.connect()
@@ -878,7 +887,7 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
if not suppress_ctl_errors:
raise
else:
- raise AuthChallengeFailed('Socket failed (%s)' % exc, cookie_path, True)
+ raise AuthChallengeFailed('Socket failed (%s)' % exc, cookie_path)
try:
stem.response.convert('AUTHCHALLENGE', authchallenge_response)
@@ -970,7 +979,7 @@ def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.
raise stem.SocketError(exc)
stem.response.convert('PROTOCOLINFO', protocolinfo_response)
- return protocolinfo_response
+ return protocolinfo_response # type: ignore
def _msg(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], message: str) -> stem.response.ControlMessage:
@@ -1008,7 +1017,7 @@ def _connection_for_default_port(address: str) -> stem.socket.ControlPort:
raise exc
-def _read_cookie(cookie_path: str, is_safecookie: bool) -> str:
+def _read_cookie(cookie_path: str, is_safecookie: bool) -> bytes:
"""
Provides the contents of a given cookie file.
@@ -1016,7 +1025,7 @@ def _read_cookie(cookie_path: str, is_safecookie: bool) -> str:
:param bool is_safecookie: **True** if this was for SAFECOOKIE
authentication, **False** if for COOKIE
- :returns: **str** with the cookie file content
+ :returns: **bytes** with the cookie file content
:raises:
* :class:`stem.connection.UnreadableCookieFile` if the cookie file is
@@ -1052,12 +1061,12 @@ def _read_cookie(cookie_path: str, is_safecookie: bool) -> str:
raise UnreadableCookieFile(exc_msg, cookie_path, is_safecookie)
-def _hmac_sha256(key: str, msg: str) -> bytes:
+def _hmac_sha256(key: bytes, msg: bytes) -> bytes:
"""
Generates a sha256 digest using the given key and message.
- :param str key: starting key for the hash
- :param str msg: message to be hashed
+ :param bytes key: starting key for the hash
+ :param bytes msg: message to be hashed
:returns: sha256 digest of msg as bytes, hashed using the given key
"""
diff --git a/stem/control.py b/stem/control.py
index ec4ba54e..626b2b3e 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -271,7 +271,7 @@ import stem.version
from stem import UNDEFINED, CircStatus, Signal
from stem.util import log
from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events,
# but if it takes longer than this we terminate.
@@ -404,7 +404,7 @@ SERVER_DESCRIPTORS_UNSUPPORTED = "Tor is currently not configured to retrieve \
server descriptors. As of Tor version 0.2.3.25 it downloads microdescriptors \
instead unless you set 'UseMicrodescriptors 0' in your torrc."
-EVENT_DESCRIPTIONS = None
+EVENT_DESCRIPTIONS = None # type: Dict[str, str]
class AccountingStats(collections.namedtuple('AccountingStats', ['retrieved', 'status', 'interval_end', 'time_until_reset', 'read_bytes', 'read_bytes_left', 'read_limit', 'written_bytes', 'write_bytes_left', 'write_limit'])):
@@ -518,7 +518,7 @@ def event_description(event: str) -> str:
try:
config.load(config_path)
- EVENT_DESCRIPTIONS = dict([(key.lower()[18:], config.get_value(key)) for key in config.keys() if key.startswith('event.description.')])
+ EVENT_DESCRIPTIONS = dict([(key.lower()[18:], config.get_value(key)) for key in config.keys() if key.startswith('event.description.')]) # type: ignore
except Exception as exc:
log.warn("BUG: stem failed to load its internal manual information from '%s': %s" % (config_path, exc))
return None
@@ -546,19 +546,19 @@ class BaseController(object):
self._socket = control_socket
self._msg_lock = threading.RLock()
- self._status_listeners = [] # tuples of the form (callback, spawn_thread)
+ self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread)
self._status_listeners_lock = threading.RLock()
# queues where incoming messages are directed
- self._reply_queue = queue.Queue()
- self._event_queue = queue.Queue()
+ self._reply_queue = queue.Queue() # type: queue.Queue[Union[stem.response.ControlMessage, stem.ControllerError]]
+ self._event_queue = queue.Queue() # type: queue.Queue[stem.response.ControlMessage]
# thread to continually pull from the control socket
- self._reader_thread = None
+ self._reader_thread = None # type: Optional[threading.Thread]
# thread to pull from the _event_queue and call handle_event
self._event_notice = threading.Event()
- self._event_thread = None
+ self._event_thread = None # type: Optional[threading.Thread]
# saves our socket's prior _connect() and _close() methods so they can be
# called along with ours
@@ -566,13 +566,13 @@ class BaseController(object):
self._socket_connect = self._socket._connect
self._socket_close = self._socket._close
- self._socket._connect = self._connect
- self._socket._close = self._close
+ self._socket._connect = self._connect # type: ignore
+ self._socket._close = self._close # type: ignore
self._last_heartbeat = 0.0 # timestamp for when we last heard from tor
self._is_authenticated = False
- self._state_change_threads = [] # threads we've spawned to notify of state changes
+ self._state_change_threads = [] # type: List[threading.Thread] # threads we've spawned to notify of state changes
if self._socket.is_alive():
self._launch_threads()
@@ -757,7 +757,7 @@ class BaseController(object):
return self._last_heartbeat
- def add_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None], spawn: bool = True) -> None:
+ def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None:
"""
Notifies a given function when the state of our socket changes. Functions
are expected to be of the form...
@@ -986,7 +986,7 @@ class Controller(BaseController):
"""
@staticmethod
- def from_port(address: str = '127.0.0.1', port: int = 'default') -> 'stem.control.Controller':
+ def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
"""
Constructs a :class:`~stem.socket.ControlPort` based Controller.
@@ -1015,7 +1015,7 @@ class Controller(BaseController):
if port == 'default':
control_port = stem.connection._connection_for_default_port(address)
else:
- control_port = stem.socket.ControlPort(address, port)
+ control_port = stem.socket.ControlPort(address, int(port))
return Controller(control_port)
@@ -1036,33 +1036,33 @@ class Controller(BaseController):
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
self._is_caching_enabled = True
- self._request_cache = {}
+ self._request_cache = {} # type: Dict[str, Any]
self._last_newnym = 0.0
self._cache_lock = threading.RLock()
# mapping of event types to their listeners
- self._event_listeners = {}
+ self._event_listeners = {} # type: Dict[stem.control.EventType, List[Callable[[stem.response.events.Event], None]]]
self._event_listeners_lock = threading.RLock()
- self._enabled_features = []
+ self._enabled_features = [] # type: List[str]
- self._last_address_exc = None
- self._last_fingerprint_exc = None
+ self._last_address_exc = None # type: Optional[BaseException]
+ self._last_fingerprint_exc = None # type: Optional[BaseException]
super(Controller, self).__init__(control_socket, is_authenticated)
- def _sighup_listener(event: stem.response.events.Event) -> None:
+ def _sighup_listener(event: stem.response.events.SignalEvent) -> None:
if event.signal == Signal.RELOAD:
self.clear_cache()
self._notify_status_listeners(State.RESET)
- self.add_event_listener(_sighup_listener, EventType.SIGNAL)
+ self.add_event_listener(_sighup_listener, EventType.SIGNAL) # type: ignore
- def _confchanged_listener(event: stem.response.events.Event) -> None:
+ def _confchanged_listener(event: stem.response.events.ConfChangedEvent) -> None:
if self.is_caching_enabled():
to_cache_changed = dict((k.lower(), v) for k, v in event.changed.items())
- to_cache_unset = dict((k.lower(), []) for k in event.unset) # [] represents None value in cache
+ to_cache_unset = dict((k.lower(), []) for k in event.unset) # type: Dict[str, List[str]] # [] represents None value in cache
to_cache = {}
to_cache.update(to_cache_changed)
@@ -1072,15 +1072,15 @@ class Controller(BaseController):
self._confchanged_cache_invalidation(to_cache)
- self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED)
+ self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED) # type: ignore
- def _address_changed_listener(event: stem.response.events.Event) -> None:
+ def _address_changed_listener(event: stem.response.events.StatusEvent) -> None:
if event.action in ('EXTERNAL_ADDRESS', 'DNS_USELESS'):
self._set_cache({'exit_policy': None})
self._set_cache({'address': None}, 'getinfo')
self._last_address_exc = None
- self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER)
+ self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER) # type: ignore
def close(self) -> None:
self.clear_cache()
@@ -1152,15 +1152,15 @@ class Controller(BaseController):
if isinstance(params, (bytes, str)):
is_multiple = False
- params = set([params])
+ param_set = set([params])
else:
if not params:
return {}
is_multiple = True
- params = set(params)
+ param_set = set(params)
- for param in params:
+ for param in param_set:
if param.startswith('ip-to-country/') and param != 'ip-to-country/0.0.0.0' and self.get_info('ip-to-country/ipv4-available', '0') != '1':
raise stem.ProtocolError('Tor geoip database is unavailable')
elif param == 'address' and self._last_address_exc:
@@ -1170,16 +1170,16 @@ class Controller(BaseController):
# check for cached results
- from_cache = [param.lower() for param in params]
+ from_cache = [param.lower() for param in param_set]
cached_results = self._get_cache_map(from_cache, 'getinfo')
for key in cached_results:
- user_expected_key = _case_insensitive_lookup(params, key)
+ user_expected_key = _case_insensitive_lookup(param_set, key)
reply[user_expected_key] = cached_results[key]
- params.remove(user_expected_key)
+ param_set.remove(user_expected_key)
# if everything was cached then short circuit making the query
- if not params:
+ if not param_set:
if LOG_CACHE_FETCHES:
log.trace('GETINFO %s (cache fetch)' % ' '.join(reply.keys()))
@@ -1189,14 +1189,13 @@ class Controller(BaseController):
return list(reply.values())[0]
try:
- response = self.msg('GETINFO %s' % ' '.join(params))
- stem.response.convert('GETINFO', response)
- response._assert_matches(params)
+ response = stem.response._convert_to_getinfo(self.msg('GETINFO %s' % ' '.join(param_set)))
+ response._assert_matches(param_set)
# usually we want unicode values under python 3.x
if not get_bytes:
- response.entries = dict((k, stem.util.str_tools._to_unicode(v)) for (k, v) in response.entries.items())
+ response.entries = dict((k, stem.util.str_tools._to_unicode(v)) for (k, v) in response.entries.items()) # type: ignore
reply.update(response.entries)
@@ -1213,26 +1212,26 @@ class Controller(BaseController):
self._set_cache(to_cache, 'getinfo')
- if 'address' in params:
+ if 'address' in param_set:
self._last_address_exc = None
- if 'fingerprint' in params:
+ if 'fingerprint' in param_set:
self._last_fingerprint_exc = None
- log.debug('GETINFO %s (runtime: %0.4f)' % (' '.join(params), time.time() - start_time))
+ log.debug('GETINFO %s (runtime: %0.4f)' % (' '.join(param_set), time.time() - start_time))
if is_multiple:
return reply
else:
return list(reply.values())[0]
except stem.ControllerError as exc:
- if 'address' in params:
+ if 'address' in param_set:
self._last_address_exc = exc
- if 'fingerprint' in params:
+ if 'fingerprint' in param_set:
self._last_fingerprint_exc = exc
- log.debug('GETINFO %s (failed: %s)' % (' '.join(params), exc))
+ log.debug('GETINFO %s (failed: %s)' % (' '.join(param_set), exc))
raise
@with_default()
@@ -1363,7 +1362,7 @@ class Controller(BaseController):
if listeners is None:
proxy_addrs = []
- query = 'net/listeners/%s' % listener_type.lower()
+ query = 'net/listeners/%s' % str(listener_type).lower()
try:
for listener in self.get_info(query).split():
@@ -1413,7 +1412,7 @@ class Controller(BaseController):
Listener.CONTROL: 'ControlListenAddress',
}[listener_type]
- port_value = self.get_conf(port_option).split()[0]
+ port_value = self._get_conf_single(port_option).split()[0]
for listener in self.get_conf(listener_option, multiple = True):
if ':' in listener:
@@ -1571,7 +1570,7 @@ class Controller(BaseController):
pid = int(getinfo_pid)
if not pid and self.is_localhost():
- pid_file_path = self.get_conf('PidFile', None)
+ pid_file_path = self._get_conf_single('PidFile', None)
if pid_file_path is not None:
with open(pid_file_path) as pid_file:
@@ -1666,7 +1665,7 @@ class Controller(BaseController):
return time.time() - self.get_start_time()
- def is_user_traffic_allowed(self) -> bool:
+ def is_user_traffic_allowed(self) -> 'stem.control.UserTrafficAllowed':
"""
Checks if we're likely to service direct user traffic. This essentially
boils down to...
@@ -1687,7 +1686,7 @@ class Controller(BaseController):
.. versionadded:: 1.5.0
- :returns: :class:`~stem.cotroller.UserTrafficAllowed` with **inbound** and
+ :returns: :class:`~stem.control.UserTrafficAllowed` with **inbound** and
**outbound** boolean attributes to indicate if we're likely servicing
direct user traffic
"""
@@ -1860,7 +1859,7 @@ class Controller(BaseController):
return stem.descriptor.server_descriptor.RelayDescriptor(desc_content)
@with_default(yields = True)
- def get_server_descriptors(self, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor:
+ def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
"""
get_server_descriptors(default = UNDEFINED)
@@ -1893,7 +1892,7 @@ class Controller(BaseController):
raise stem.DescriptorUnavailable('Descriptor information is unavailable, tor might still be downloading it')
for desc in stem.descriptor.server_descriptor._parse_file(io.BytesIO(desc_content)):
- yield desc
+ yield desc # type: ignore
@with_default()
def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3:
@@ -1989,7 +1988,7 @@ class Controller(BaseController):
)
for desc in desc_iterator:
- yield desc
+ yield desc # type: ignore
@with_default()
def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2:
@@ -2035,8 +2034,12 @@ class Controller(BaseController):
if not stem.util.tor_tools.is_valid_hidden_service_address(address):
raise ValueError("'%s.onion' isn't a valid hidden service address" % address)
- hs_desc_queue, hs_desc_listener = queue.Queue(), None
- hs_desc_content_queue, hs_desc_content_listener = queue.Queue(), None
+ hs_desc_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event]
+ hs_desc_listener = None
+
+ hs_desc_content_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event]
+ hs_desc_content_listener = None
+
start_time = time.time()
if await_result:
@@ -2055,8 +2058,7 @@ class Controller(BaseController):
if servers:
request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
- response = self.msg(request)
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg(request))
if not response.is_ok():
raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code)
@@ -2137,6 +2139,14 @@ class Controller(BaseController):
entries = self.get_conf_map(param, default, multiple)
return _case_insensitive_lookup(entries, param, default)
+ # TODO: temporary aliases until we have better type support in our API
+
+ def _get_conf_single(self, param: str, default: Any = UNDEFINED) -> str:
+ return self.get_conf(param, default) # type: ignore
+
+ def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]:
+ return self.get_conf(param, default, multiple = True) # type: ignore
+
def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
"""
get_conf_map(params, default = UNDEFINED, multiple = True)
@@ -2218,8 +2228,7 @@ class Controller(BaseController):
return self._get_conf_dict_to_response(reply, default, multiple)
try:
- response = self.msg('GETCONF %s' % ' '.join(lookup_params))
- stem.response.convert('GETCONF', response)
+ response = stem.response._convert_to_getconf(self.msg('GETCONF %s' % ' '.join(lookup_params)))
reply.update(response.entries)
if self.is_caching_enabled():
@@ -2414,8 +2423,7 @@ class Controller(BaseController):
raise ValueError('Cannot set %s to %s since the value was a %s but we only accept strings' % (param, value, type(value).__name__))
query = ' '.join(query_comp)
- response = self.msg(query)
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg(query))
if response.is_ok():
log.debug('%s (runtime: %0.4f)' % (query, time.time() - start_time))
@@ -2489,15 +2497,14 @@ class Controller(BaseController):
start_time = time.time()
try:
- response = self.msg('GETCONF HiddenServiceOptions')
- stem.response.convert('GETCONF', response)
+ response = stem.response._convert_to_getconf(self.msg('GETCONF HiddenServiceOptions'))
log.debug('GETCONF HiddenServiceOptions (runtime: %0.4f)' %
(time.time() - start_time))
except stem.ControllerError as exc:
log.debug('GETCONF HiddenServiceOptions (failed: %s)' % exc)
raise
- service_dir_map = collections.OrderedDict()
+ service_dir_map = collections.OrderedDict() # type: collections.OrderedDict[str, Any]
directory = None
for status_code, divider, content in response.content():
@@ -2603,7 +2610,7 @@ class Controller(BaseController):
self.set_options(hidden_service_options)
- def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.cotroller.CreateHiddenServiceOutput':
+ def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput':
"""
Create a new hidden service. If the directory is already present, a
new port is added.
@@ -2629,7 +2636,7 @@ class Controller(BaseController):
:param str auth_type: authentication type: basic, stealth or None to disable auth
:param list client_names: client names (1-16 characters "A-Za-z0-9+-_")
- :returns: :class:`~stem.cotroller.CreateHiddenServiceOutput` if we create
+ :returns: :class:`~stem.control.CreateHiddenServiceOutput` if we create
or update a hidden service, **None** otherwise
:raises: :class:`stem.ControllerError` if the call fails
@@ -2905,7 +2912,8 @@ class Controller(BaseController):
* :class:`stem.Timeout` if **timeout** was reached
"""
- hs_desc_queue, hs_desc_listener = queue.Queue(), None
+ hs_desc_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event]
+ hs_desc_listener = None
start_time = time.time()
if await_publication:
@@ -2957,8 +2965,7 @@ class Controller(BaseController):
else:
request += ' ClientAuth=%s' % client_name
- response = self.msg(request)
- stem.response.convert('ADD_ONION', response)
+    response = stem.response._convert_to_add_onion(self.msg(request))
if await_publication:
# We should receive five UPLOAD events, followed by up to another five
@@ -3002,8 +3009,7 @@ class Controller(BaseController):
:raises: :class:`stem.ControllerError` if the call fails
"""
- response = self.msg('DEL_ONION %s' % service_id)
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('DEL_ONION %s' % service_id))
if response.is_ok():
return True
@@ -3056,7 +3062,7 @@ class Controller(BaseController):
event_type = stem.response.events.EVENT_TYPE_TO_CLASS.get(event_type)
if event_type and (self.get_version() < event_type._VERSION_ADDED):
- raise stem.InvalidRequest(552, '%s event requires Tor version %s or later' % (event_type, event_type._VERSION_ADDED))
+ raise stem.InvalidRequest('552', '%s event requires Tor version %s or later' % (event_type, event_type._VERSION_ADDED))
for event_type in events:
self._event_listeners.setdefault(event_type, []).append(listener)
@@ -3135,7 +3141,7 @@ class Controller(BaseController):
return cached_values
- def _set_cache(self, params: Mapping[str, Any], namespace: Optional[str] = None) -> None:
+ def _set_cache(self, params: Dict[str, Any], namespace: Optional[str] = None) -> None:
"""
Sets the given request cache entries. If the new cache value is **None**
then it is removed from our cache.
@@ -3241,8 +3247,7 @@ class Controller(BaseController):
:raises: :class:`stem.ControllerError` if the call fails
"""
- response = self.msg('LOADCONF\n%s' % configtext)
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('LOADCONF\n%s' % configtext))
if response.code in ('552', '553'):
if response.code == '552' and response.message.startswith('Invalid config file: Failed to parse/validate config: Unknown option'):
@@ -3267,11 +3272,10 @@ class Controller(BaseController):
the configuration file
"""
- response = self.msg('SAVECONF FORCE' if force else 'SAVECONF')
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('SAVECONF FORCE' if force else 'SAVECONF'))
if response.is_ok():
- return True
+ pass
elif response.code == '551':
raise stem.OperationFailed(response.code, response.message)
else:
@@ -3311,8 +3315,7 @@ class Controller(BaseController):
if isinstance(features, (bytes, str)):
features = [features]
- response = self.msg('USEFEATURE %s' % ' '.join(features))
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('USEFEATURE %s' % ' '.join(features)))
if not response.is_ok():
if response.code == '552':
@@ -3353,7 +3356,7 @@ class Controller(BaseController):
raise ValueError("Tor currently does not have a circuit with the id of '%s'" % circuit_id)
@with_default()
- def get_circuits(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.CircuitEvent]:
+ def get_circuits(self, default: Any = UNDEFINED) -> List[stem.response.events.CircuitEvent]:
"""
get_circuits(default = UNDEFINED)
@@ -3366,13 +3369,12 @@ class Controller(BaseController):
:raises: :class:`stem.ControllerError` if the call fails and no default was provided
"""
- circuits = []
+ circuits = [] # type: List[stem.response.events.CircuitEvent]
response = self.get_info('circuit-status')
for circ in response.splitlines():
- circ_message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ)))
- stem.response.convert('EVENT', circ_message)
- circuits.append(circ_message)
+ circ_message = stem.response._convert_to_event(stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ))))
+ circuits.append(circ_message) # type: ignore
return circuits
@@ -3442,7 +3444,8 @@ class Controller(BaseController):
# to build. This is icky, but we can't reliably do this via polling since
# we then can't get the failure if it can't be created.
- circ_queue, circ_listener = queue.Queue(), None
+ circ_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event]
+ circ_listener = None
start_time = time.time()
if await_build:
@@ -3463,8 +3466,7 @@ class Controller(BaseController):
if purpose:
args.append('purpose=%s' % purpose)
- response = self.msg('EXTENDCIRCUIT %s' % ' '.join(args))
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('EXTENDCIRCUIT %s' % ' '.join(args)))
if response.code in ('512', '552'):
raise stem.InvalidRequest(response.code, response.message)
@@ -3505,8 +3507,7 @@ class Controller(BaseController):
:raises: :class:`stem.InvalidArguments` if the circuit doesn't exist or if the purpose was invalid
"""
- response = self.msg('SETCIRCUITPURPOSE %s purpose=%s' % (circuit_id, purpose))
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('SETCIRCUITPURPOSE %s purpose=%s' % (circuit_id, purpose)))
if not response.is_ok():
if response.code == '552':
@@ -3527,8 +3528,7 @@ class Controller(BaseController):
* :class:`stem.InvalidRequest` if not enough information is provided
"""
- response = self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag))
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag)))
if not response.is_ok():
if response.code in ('512', '552'):
@@ -3539,7 +3539,7 @@ class Controller(BaseController):
raise stem.ProtocolError('CLOSECIRCUIT returned unexpected response code: %s' % response.code)
@with_default()
- def get_streams(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.StreamEvent]:
+ def get_streams(self, default: Any = UNDEFINED) -> List[stem.response.events.StreamEvent]:
"""
get_streams(default = UNDEFINED)
@@ -3553,13 +3553,12 @@ class Controller(BaseController):
provided
"""
- streams = []
+ streams = [] # type: List[stem.response.events.StreamEvent]
response = self.get_info('stream-status')
for stream in response.splitlines():
- message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream)))
- stem.response.convert('EVENT', message)
- streams.append(message)
+ message = stem.response._convert_to_event(stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream))))
+ streams.append(message) # type: ignore
return streams
@@ -3585,8 +3584,7 @@ class Controller(BaseController):
if exiting_hop:
query += ' HOP=%s' % exiting_hop
- response = self.msg(query)
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg(query))
if not response.is_ok():
if response.code == '552':
@@ -3614,8 +3612,7 @@ class Controller(BaseController):
# there's a single value offset between RelayEndReason.index_of() and the
# value that tor expects since tor's value starts with the index of one
- response = self.msg('CLOSESTREAM %s %s %s' % (stream_id, stem.RelayEndReason.index_of(reason) + 1, flag))
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('CLOSESTREAM %s %s %s' % (stream_id, stem.RelayEndReason.index_of(reason) + 1, flag)))
if not response.is_ok():
if response.code in ('512', '552'):
@@ -3638,8 +3635,7 @@ class Controller(BaseController):
* :class:`stem.InvalidArguments` if signal provided wasn't recognized
"""
- response = self.msg('SIGNAL %s' % signal)
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('SIGNAL %s' % signal))
if response.is_ok():
if signal == stem.Signal.NEWNYM:
@@ -3703,14 +3699,14 @@ class Controller(BaseController):
"""
if not burst:
- attributes = ('BandwidthRate', 'RelayBandwidthRate', 'MaxAdvertisedBandwidth')
+ attributes = ['BandwidthRate', 'RelayBandwidthRate', 'MaxAdvertisedBandwidth']
else:
- attributes = ('BandwidthBurst', 'RelayBandwidthBurst')
+ attributes = ['BandwidthBurst', 'RelayBandwidthBurst']
value = None
for attr in attributes:
- attr_value = int(self.get_conf(attr))
+ attr_value = int(self._get_conf_single(attr))
if attr_value == 0 and attr.startswith('Relay'):
continue # RelayBandwidthRate and RelayBandwidthBurst default to zero
@@ -3740,9 +3736,7 @@ class Controller(BaseController):
mapaddress_arg = ' '.join(['%s=%s' % (k, v) for (k, v) in list(mapping.items())])
response = self.msg('MAPADDRESS %s' % mapaddress_arg)
- stem.response.convert('MAPADDRESS', response)
-
- return response.entries
+ return stem.response._convert_to_mapaddress(response).entries
def drop_guards(self) -> None:
"""
@@ -3779,8 +3773,7 @@ class Controller(BaseController):
owning_pid = self.get_conf('__OwningControllerProcess', None)
if owning_pid == str(os.getpid()) and self.is_localhost():
- response = self.msg('TAKEOWNERSHIP')
- stem.response.convert('SINGLELINE', response)
+ response = stem.response._convert_to_single_line(self.msg('TAKEOWNERSHIP'))
if response.is_ok():
# Now that tor is tracking our ownership of the process via the control
@@ -3793,11 +3786,18 @@ class Controller(BaseController):
else:
log.warn('We were unable assert ownership of tor through TAKEOWNERSHIP, despite being configured to be the owning process through __OwningControllerProcess. (%s)' % response)
- def _handle_event(self, event_message: str) -> None:
+ def _handle_event(self, event_message: stem.response.ControlMessage) -> None:
+ event = None # type: Optional[stem.response.events.Event]
+
try:
- stem.response.convert('EVENT', event_message)
- event_type = event_message.type
+ event = stem.response._convert_to_event(event_message)
+ event_type = event.type
except stem.ProtocolError as exc:
+ # TODO: We should change this so malformed events convert to the base
+ # Event class, so we don't provide raw ControlMessages to listeners.
+
+ event = event_message # type: ignore
+
log.error('Tor sent a malformed event (%s): %s' % (exc, event_message))
event_type = MALFORMED_EVENTS
@@ -3806,9 +3806,9 @@ class Controller(BaseController):
if listener_type == event_type:
for listener in event_listeners:
try:
- listener(event_message)
+ listener(event)
except Exception as exc:
- log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event_message))
+ log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event))
def _attach_listeners(self) -> Tuple[Sequence[str], Sequence[str]]:
"""
diff --git a/stem/descriptor/__init__.py b/stem/descriptor/__init__.py
index 9c769749..477e15e9 100644
--- a/stem/descriptor/__init__.py
+++ b/stem/descriptor/__init__.py
@@ -108,6 +108,7 @@ import base64
import codecs
import collections
import copy
+import hashlib
import io
import os
import random
@@ -120,7 +121,7 @@ import stem.util.enum
import stem.util.str_tools
import stem.util.system
-from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type
+from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
__all__ = [
'bandwidth_file',
@@ -152,7 +153,7 @@ KEYWORD_LINE = re.compile('^([%s]+)(?:[%s]+(.*))?$' % (KEYWORD_CHAR, WHITESPACE)
SPECIFIC_KEYWORD_LINE = '^(%%s)(?:[%s]+(.*))?$' % WHITESPACE
PGP_BLOCK_START = re.compile('^-----BEGIN ([%s%s]+)-----$' % (KEYWORD_CHAR, WHITESPACE))
PGP_BLOCK_END = '-----END %s-----'
-EMPTY_COLLECTION = ([], {}, set())
+EMPTY_COLLECTION = ([], {}, set()) # type: ignore
DIGEST_TYPE_INFO = b'\x00\x01'
DIGEST_PADDING = b'\xFF'
@@ -164,6 +165,8 @@ skFtXhOHHqTRN4GPPrZsAIUOQGzQtGb66IQgT4tO/pj+P6QmSCCdTfhvGfgTCsC+
WPi4Fl2qryzTb3QO5r5x7T8OsG2IBUET1bLQzmtbC560SYR49IvVAgMBAAE=
"""
+ENTRY_TYPE = Dict[str, List[Tuple[str, str, str]]]
+
DigestHash = stem.util.enum.UppercaseEnum(
'SHA1',
'SHA256',
@@ -194,7 +197,7 @@ class _Compression(object):
.. versionadded:: 1.8.0
"""
- def __init__(self, name: str, module: Optional[str], encoding: str, extension: str, decompression_func: Callable[[Any, str], bytes]) -> None:
+ def __init__(self, name: str, module: Optional[str], encoding: str, extension: str, decompression_func: Callable[[Any, bytes], bytes]) -> None:
if module is None:
self._module = None
self.available = True
@@ -256,7 +259,7 @@ class _Compression(object):
return self._name
-def _zstd_decompress(module: Any, content: str) -> bytes:
+def _zstd_decompress(module: Any, content: bytes) -> bytes:
output_buffer = io.BytesIO()
with module.ZstdDecompressor().write_to(output_buffer) as decompressor:
@@ -304,7 +307,7 @@ class SigningKey(collections.namedtuple('SigningKey', ['private', 'public', 'pub
"""
-def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, normalize_newlines: Optional[bool] = None, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
+def parse_file(descriptor_file: Union[str, BinaryIO, tarfile.TarFile, IO[bytes]], descriptor_type: str = None, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, normalize_newlines: Optional[bool] = None, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
"""
Simple function to read the descriptor contents from a file, providing an
iterator for its :class:`~stem.descriptor.__init__.Descriptor` contents.
@@ -372,7 +375,7 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate:
# Delegate to a helper if this is a path or tarfile.
- handler = None
+ handler = None # type: Callable
if isinstance(descriptor_file, (bytes, str)):
if stem.util.system.is_tarfile(descriptor_file):
@@ -388,7 +391,7 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate:
return
- if not descriptor_file.seekable():
+ if not descriptor_file.seekable(): # type: ignore
raise IOError(UNSEEKABLE_MSG)
# The tor descriptor specifications do not provide a reliable method for
@@ -397,19 +400,19 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate:
# by an annotation on their first line...
# https://trac.torproject.org/5651
- initial_position = descriptor_file.tell()
- first_line = stem.util.str_tools._to_unicode(descriptor_file.readline().strip())
+ initial_position = descriptor_file.tell() # type: ignore
+ first_line = stem.util.str_tools._to_unicode(descriptor_file.readline().strip()) # type: ignore
metrics_header_match = re.match('^@type (\\S+) (\\d+).(\\d+)$', first_line)
if not metrics_header_match:
- descriptor_file.seek(initial_position)
+ descriptor_file.seek(initial_position) # type: ignore
descriptor_path = getattr(descriptor_file, 'name', None)
- filename = '<undefined>' if descriptor_path is None else os.path.basename(descriptor_file.name)
+ filename = '<undefined>' if descriptor_path is None else os.path.basename(descriptor_file.name) # type: str # type: ignore
def parse(descriptor_file: BinaryIO) -> Iterator['stem.descriptor.Descriptor']:
if normalize_newlines:
- descriptor_file = NewlineNormalizer(descriptor_file)
+ descriptor_file = NewlineNormalizer(descriptor_file) # type: ignore
if descriptor_type is not None:
descriptor_type_match = re.match('^(\\S+) (\\d+).(\\d+)$', descriptor_type)
@@ -428,7 +431,7 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate:
# Cached descriptor handling. These contain multiple descriptors per file.
if normalize_newlines is None and stem.util.system.is_windows():
- descriptor_file = NewlineNormalizer(descriptor_file)
+ descriptor_file = NewlineNormalizer(descriptor_file) # type: ignore
if filename == 'cached-descriptors' or filename == 'cached-descriptors.new':
return stem.descriptor.server_descriptor._parse_file(descriptor_file, validate = validate, **kwargs)
@@ -441,29 +444,29 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate:
elif filename == 'cached-microdesc-consensus':
return stem.descriptor.networkstatus._parse_file(descriptor_file, is_microdescriptor = True, validate = validate, document_handler = document_handler, **kwargs)
else:
- raise TypeError("Unable to determine the descriptor's type. filename: '%s', first line: '%s'" % (filename, first_line))
+ raise TypeError("Unable to determine the descriptor's type. filename: '%s', first line: '%s'" % (filename, stem.util.str_tools._to_unicode(first_line)))
- for desc in parse(descriptor_file):
+ for desc in parse(descriptor_file): # type: ignore
if descriptor_path is not None:
desc._set_path(os.path.abspath(descriptor_path))
yield desc
-def _parse_file_for_path(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
+def _parse_file_for_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
with open(descriptor_file, 'rb') as desc_file:
for desc in parse_file(desc_file, *args, **kwargs):
yield desc
-def _parse_file_for_tar_path(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
+def _parse_file_for_tar_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
with tarfile.open(descriptor_file) as tar_file:
for desc in parse_file(tar_file, *args, **kwargs):
desc._set_path(os.path.abspath(descriptor_file))
yield desc
-def _parse_file_for_tarfile(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
+def _parse_file_for_tarfile(descriptor_file: tarfile.TarFile, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
for tar_entry in descriptor_file:
if tar_entry.isfile():
entry = descriptor_file.extractfile(tar_entry)
@@ -479,10 +482,14 @@ def _parse_file_for_tarfile(descriptor_file: BinaryIO, *args: Any, **kwargs: Any
entry.close()
-def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], major_version: int, minor_version: int, descriptor_file: BinaryIO, validate: bool, document_handler: 'stem.descriptor.DocumentHandler', **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
+def _parse_metrics_file(descriptor_type: str, major_version: int, minor_version: int, descriptor_file: BinaryIO, validate: bool, document_handler: 'stem.descriptor.DocumentHandler', **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']:
# Parses descriptor files from metrics, yielding individual descriptors. This
# throws a TypeError if the descriptor_type or version isn't recognized.
+ desc = None # type: Optional[Any]
+ desc_type = None # type: Optional[Type[stem.descriptor.Descriptor]]
+ document_type = None # type: Optional[Type]
+
if descriptor_type == stem.descriptor.server_descriptor.RelayDescriptor.TYPE_ANNOTATION_NAME and major_version == 1:
for desc in stem.descriptor.server_descriptor._parse_file(descriptor_file, is_bridge = False, validate = validate, **kwargs):
yield desc
@@ -507,7 +514,7 @@ def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], maj
for desc in stem.descriptor.networkstatus._parse_file(descriptor_file, document_type, validate = validate, document_handler = document_handler, **kwargs):
yield desc
elif descriptor_type == stem.descriptor.networkstatus.KeyCertificate.TYPE_ANNOTATION_NAME and major_version == 1:
- for desc in stem.descriptor.networkstatus._parse_file_key_certs(descriptor_file, validate = validate, **kwargs):
+ for desc in stem.descriptor.networkstatus._parse_file_key_certs(descriptor_file, validate = validate):
yield desc
elif descriptor_type in ('network-status-consensus-3', 'network-status-vote-3') and major_version == 1:
document_type = stem.descriptor.networkstatus.NetworkStatusDocumentV3
@@ -549,7 +556,7 @@ def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], maj
raise TypeError("Unrecognized metrics descriptor format. type: '%s', version: '%i.%i'" % (descriptor_type, major_version, minor_version))
-def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] = (), header_template: Sequence[str] = (), footer_template: Sequence[str] = ()) -> bytes:
+def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] = (), header_template: Sequence[Tuple[str, Optional[str]]] = (), footer_template: Sequence[Tuple[str, Optional[str]]] = ()) -> bytes:
"""
Constructs a minimal descriptor with the given attributes. The content we
provide back is of the form...
@@ -586,8 +593,9 @@ def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] =
:returns: bytes with the requested descriptor content
"""
- header_content, footer_content = [], []
- attr = {} if attr is None else collections.OrderedDict(attr) # shallow copy since we're destructive
+ header_content = [] # type: List[str]
+ footer_content = [] # type: List[str]
+ attr = {} if attr is None else collections.OrderedDict(attr) # type: Dict[str, str] # shallow copy since we're destructive
for content, template in ((header_content, header_template),
(footer_content, footer_template)):
@@ -621,28 +629,28 @@ def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] =
return stem.util.str_tools._to_bytes('\n'.join(header_content + remainder + footer_content))
-def _value(line: str, entries: Dict[str, Sequence[str]]) -> str:
+def _value(line: str, entries: ENTRY_TYPE) -> str:
return entries[line][0][0]
-def _values(line: str, entries: Dict[str, Sequence[str]]) -> Sequence[str]:
+def _values(line: str, entries: ENTRY_TYPE) -> Sequence[str]:
return [entry[0] for entry in entries[line]]
-def _parse_simple_line(keyword: str, attribute: str, func: Callable[[str], str] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_simple_line(keyword: str, attribute: str, func: Optional[Callable[[str], Any]] = None) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
setattr(descriptor, attribute, func(value) if func else value)
return _parse
-def _parse_if_present(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+def _parse_if_present(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
return lambda descriptor, entries: setattr(descriptor, attribute, keyword in entries)
-def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
line_match = re.search(stem.util.str_tools._to_bytes('^(opt )?%s(?:[%s]+(.*))?$' % (keyword, WHITESPACE)), descriptor.get_bytes(), re.MULTILINE)
result = None
@@ -655,8 +663,8 @@ def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descripto
return _parse
-def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
try:
@@ -672,10 +680,10 @@ def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -
return _parse
-def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
# "<keyword>" YYYY-MM-DD HH:MM:SS
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
try:
@@ -686,10 +694,10 @@ def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descr
return _parse
-def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
# format of fingerprints, sha1 digests, etc
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
if not stem.util.tor_tools.is_hex_digits(value, 40):
@@ -700,15 +708,15 @@ def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.
return _parse
-def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# parses 'protocol' entries like: Cons=1-2 Desc=1-2 DirCache=1 HSDir=1
value = _value(keyword, entries)
protocols = collections.OrderedDict()
for k, v in _mappings_for(keyword, value):
- versions = []
+ versions = [] # type: List[int]
if not v:
continue
@@ -731,8 +739,8 @@ def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descri
return _parse
-def _parse_key_block(keyword: str, attribute: str, expected_block_type: str, value_attribute: Optional[str] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_key_block(keyword: str, attribute: str, expected_block_type: str, value_attribute: Optional[str] = None) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value, block_type, block_contents = entries[keyword][0]
if not block_contents or block_type != expected_block_type:
@@ -788,7 +796,7 @@ def _copy(default: Any) -> Any:
return copy.copy(default)
-def _encode_digest(hash_value: bytes, encoding: 'stem.descriptor.DigestEncoding') -> str:
+def _encode_digest(hash_value: 'hashlib._HASH', encoding: 'stem.descriptor.DigestEncoding') -> Union[str, 'hashlib._HASH']: # type: ignore
"""
Encodes a hash value with the given HashEncoding.
"""
@@ -810,21 +818,21 @@ class Descriptor(object):
Common parent for all types of descriptors.
"""
- ATTRIBUTES = {} # mapping of 'attribute' => (default_value, parsing_function)
- PARSER_FOR_LINE = {} # line keyword to its associated parsing function
- TYPE_ANNOTATION_NAME = None
+ ATTRIBUTES = {} # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] # mapping of 'attribute' => (default_value, parsing_function)
+ PARSER_FOR_LINE = {} # type: Dict[str, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]] # line keyword to its associated parsing function
+ TYPE_ANNOTATION_NAME = None # type: Optional[str]
- def __init__(self, contents, lazy_load = False):
- self._path = None
- self._archive_path = None
+ def __init__(self, contents: bytes, lazy_load: bool = False) -> None:
+ self._path = None # type: Optional[str]
+ self._archive_path = None # type: Optional[str]
self._raw_contents = contents
self._lazy_loading = lazy_load
- self._entries = {}
- self._hash = None
- self._unrecognized_lines = []
+ self._entries = {} # type: ENTRY_TYPE
+ self._hash = None # type: Optional[int]
+ self._unrecognized_lines = [] # type: List[str]
@classmethod
- def from_str(cls, content, **kwargs):
+ def from_str(cls, content: str, **kwargs: Any) -> Union['stem.descriptor.Descriptor', List['stem.descriptor.Descriptor']]:
"""
Provides a :class:`~stem.descriptor.__init__.Descriptor` for the given content.
@@ -873,7 +881,7 @@ class Descriptor(object):
raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." % len(results))
@classmethod
- def content(cls, attr = None, exclude = ()):
+ def content(cls, attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
"""
Creates descriptor content with the given attributes. Mandatory fields are
filled with dummy information unless data is supplied. This doesn't yet
@@ -885,7 +893,7 @@ class Descriptor(object):
:param list exclude: mandatory keywords to exclude from the descriptor, this
results in an invalid descriptor
- :returns: **str** with the content of a descriptor
+ :returns: **bytes** with the content of a descriptor
:raises:
* **ImportError** if cryptography is unavailable and sign is True
@@ -895,7 +903,7 @@ class Descriptor(object):
raise NotImplementedError("The create and content methods haven't been implemented for %s" % cls.__name__)
@classmethod
- def create(cls, attr = None, exclude = (), validate = True):
+ def create(cls, attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.Descriptor':
"""
Creates a descriptor with the given attributes. Mandatory fields are filled
with dummy information unless data is supplied. This doesn't yet create a
@@ -917,9 +925,9 @@ class Descriptor(object):
* **NotImplementedError** if not implemented for this descriptor type
"""
- return cls(cls.content(attr, exclude), validate = validate)
+ return cls(cls.content(attr, exclude), validate = validate) # type: ignore
- def type_annotation(self):
+ def type_annotation(self) -> 'stem.descriptor.TypeAnnotation':
"""
Provides the `Tor metrics annotation
<https://metrics.torproject.org/collector.html#relay-descriptors>`_ of this
@@ -941,7 +949,7 @@ class Descriptor(object):
else:
raise NotImplementedError('%s does not have a @type annotation' % type(self).__name__)
- def get_path(self):
+ def get_path(self) -> str:
"""
Provides the absolute path that we loaded this descriptor from.
@@ -950,7 +958,7 @@ class Descriptor(object):
return self._path
- def get_archive_path(self):
+ def get_archive_path(self) -> str:
"""
If this descriptor came from an archive then provides its path within the
archive. This is only set if the descriptor was read by
@@ -962,7 +970,7 @@ class Descriptor(object):
return self._archive_path
- def get_bytes(self):
+ def get_bytes(self) -> bytes:
"""
Provides the ASCII **bytes** of the descriptor. This only differs from
**str()** if you're running python 3.x, in which case **str()** provides a
@@ -973,7 +981,7 @@ class Descriptor(object):
return stem.util.str_tools._to_bytes(self._raw_contents)
- def get_unrecognized_lines(self):
+ def get_unrecognized_lines(self) -> List[str]:
"""
Provides a list of lines that were either ignored or had data that we did
not know how to process. This is most common due to new descriptor fields
@@ -989,7 +997,7 @@ class Descriptor(object):
return list(self._unrecognized_lines)
- def _parse(self, entries, validate, parser_for_line = None):
+ def _parse(self, entries: ENTRY_TYPE, validate: bool, parser_for_line: Optional[Dict[str, Callable]] = None) -> None:
"""
Parses a series of 'keyword => (value, pgp block)' mappings and applies
them as attributes.
@@ -1020,16 +1028,16 @@ class Descriptor(object):
if validate:
raise
- def _set_path(self, path):
+ def _set_path(self, path: str) -> None:
self._path = path
- def _set_archive_path(self, path):
+ def _set_archive_path(self, path: str) -> None:
self._archive_path = path
- def _name(self, is_plural = False):
+ def _name(self, is_plural: bool = False) -> str:
return str(type(self))
- def _digest_for_signature(self, signing_key, signature):
+ def _digest_for_signature(self, signing_key: str, signature: str) -> str:
"""
Provides the signed digest we should have given this key and signature.
@@ -1091,13 +1099,15 @@ class Descriptor(object):
digest_hex = codecs.encode(decrypted_bytes[seperator_index + 1:], 'hex_codec')
return stem.util.str_tools._to_unicode(digest_hex.upper())
- def _content_range(self, start = None, end = None):
+ def _content_range(self, start: Optional[Union[str, bytes]] = None, end: Optional[Union[str, bytes]] = None) -> bytes:
"""
Provides the descriptor content inclusively between two substrings.
:param bytes start: start of the content range to get
:param bytes end: end of the content range to get
+ :returns: **bytes** within the given range
+
:raises: ValueError if either the start or end substring are not within our content
"""
@@ -1108,24 +1118,24 @@ class Descriptor(object):
start_index = content.find(stem.util.str_tools._to_bytes(start))
if start_index == -1:
- raise ValueError("'%s' is not present within our descriptor content" % start)
+ raise ValueError("'%s' is not present within our descriptor content" % stem.util.str_tools._to_unicode(start))
if end is not None:
end_index = content.find(stem.util.str_tools._to_bytes(end), start_index)
if end_index == -1:
- raise ValueError("'%s' is not present within our descriptor content" % end)
+ raise ValueError("'%s' is not present within our descriptor content" % stem.util.str_tools._to_unicode(end))
end_index += len(end) # make the ending index inclusive
return content[start_index:end_index]
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
# We can't use standard hasattr() since it calls this function, recursing.
# Doing so works since it stops recursing after several dozen iterations
# (not sure why), but horrible in terms of performance.
- def has_attr(attr):
+ def has_attr(attr: str) -> bool:
try:
super(Descriptor, self).__getattribute__(attr)
return True
@@ -1156,31 +1166,31 @@ class Descriptor(object):
return super(Descriptor, self).__getattribute__(name)
- def __str__(self):
+ def __str__(self) -> str:
return stem.util.str_tools._to_unicode(self._raw_contents)
- def _compare(self, other, method):
+ def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool:
if type(self) != type(other):
return False
return method(str(self).strip(), str(other).strip())
- def __hash__(self):
+ def __hash__(self) -> int:
if self._hash is None:
self._hash = hash(str(self).strip())
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s == o)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
- def __lt__(self, other):
+ def __lt__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s < o)
- def __le__(self, other):
+ def __le__(self, other: Any) -> bool:
return self._compare(other, lambda s, o: s <= o)
@@ -1189,27 +1199,31 @@ class NewlineNormalizer(object):
File wrapper that normalizes CRLF line endings.
"""
- def __init__(self, wrapped_file):
+ def __init__(self, wrapped_file: BinaryIO) -> None:
self._wrapped_file = wrapped_file
self.name = getattr(wrapped_file, 'name', None)
- def read(self, *args):
+ def read(self, *args: Any) -> bytes:
return self._wrapped_file.read(*args).replace(b'\r\n', b'\n')
- def readline(self, *args):
+ def readline(self, *args: Any) -> bytes:
return self._wrapped_file.readline(*args).replace(b'\r\n', b'\n')
- def readlines(self, *args):
+ def readlines(self, *args: Any) -> List[bytes]:
return [line.rstrip(b'\r') for line in self._wrapped_file.readlines(*args)]
- def seek(self, *args):
+ def seek(self, *args: Any) -> int:
return self._wrapped_file.seek(*args)
- def tell(self, *args):
+ def tell(self, *args: Any) -> int:
return self._wrapped_file.tell(*args)
-def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_first = False, skip = False, end_position = None, include_ending_keyword = False):
+def _read_until_keywords(keywords: Union[str, Sequence[str]], descriptor_file: BinaryIO, inclusive: bool = False, ignore_first: bool = False, skip: bool = False, end_position: Optional[int] = None) -> List[bytes]:
+ return _read_until_keywords_with_ending_keyword(keywords, descriptor_file, inclusive, ignore_first, skip, end_position, include_ending_keyword = False) # type: ignore
+
+
+def _read_until_keywords_with_ending_keyword(keywords: Union[str, Sequence[str]], descriptor_file: BinaryIO, inclusive: bool = False, ignore_first: bool = False, skip: bool = False, end_position: Optional[int] = None, include_ending_keyword: bool = False) -> Tuple[List[bytes], str]:
"""
Reads from the descriptor file until we get to one of the given keywords or reach the
end of the file.
@@ -1228,7 +1242,7 @@ def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_fi
**True**
"""
- content = None if skip else []
+ content = None if skip else [] # type: Optional[List[bytes]]
ending_keyword = None
if isinstance(keywords, (bytes, str)):
@@ -1270,10 +1284,10 @@ def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_fi
if include_ending_keyword:
return (content, ending_keyword)
else:
- return content
+ return content # type: ignore
-def _bytes_for_block(content):
+def _bytes_for_block(content: str) -> bytes:
"""
Provides the base64 decoded content of a pgp-style block.
@@ -1291,7 +1305,7 @@ def _bytes_for_block(content):
return base64.b64decode(stem.util.str_tools._to_bytes(content))
-def _get_pseudo_pgp_block(remaining_contents):
+def _get_pseudo_pgp_block(remaining_contents: List[str]) -> Tuple[str, str]:
"""
Checks if given contents begins with a pseudo-Open-PGP-style block and, if
so, pops it off and provides it back to the caller.
@@ -1311,7 +1325,7 @@ def _get_pseudo_pgp_block(remaining_contents):
if block_match:
block_type = block_match.groups()[0]
- block_lines = []
+ block_lines = [] # type: List[str]
end_line = PGP_BLOCK_END % block_type
while True:
@@ -1327,7 +1341,7 @@ def _get_pseudo_pgp_block(remaining_contents):
return None
-def create_signing_key(private_key = None):
+def create_signing_key(private_key: Optional['cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey'] = None) -> 'stem.descriptor.SigningKey': # type: ignore
"""
Serializes a signing key if we have one. Otherwise this creates a new signing
key we can use to create descriptors.
@@ -1363,11 +1377,11 @@ def create_signing_key(private_key = None):
#
# https://github.com/pyca/cryptography/issues/3713
- def no_op(*args, **kwargs):
+ def no_op(*args: Any, **kwargs: Any) -> int:
return 1
- private_key._backend._lib.EVP_PKEY_CTX_set_signature_md = no_op
- private_key._backend.openssl_assert = no_op
+ private_key._backend._lib.EVP_PKEY_CTX_set_signature_md = no_op # type: ignore
+ private_key._backend.openssl_assert = no_op # type: ignore
public_key = private_key.public_key()
public_digest = b'\n' + public_key.public_bytes(
@@ -1378,7 +1392,7 @@ def create_signing_key(private_key = None):
return SigningKey(private_key, public_key, public_digest)
-def _append_router_signature(content, private_key):
+def _append_router_signature(content: bytes, private_key: 'cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey') -> bytes: # type: ignore
"""
Appends a router signature to a server or extrainfo descriptor.
@@ -1399,23 +1413,23 @@ def _append_router_signature(content, private_key):
return content + b'\n'.join([b'-----BEGIN SIGNATURE-----'] + stem.util.str_tools._split_by_length(signature, 64) + [b'-----END SIGNATURE-----\n'])
-def _random_nickname():
+def _random_nickname() -> str:
return ('Unnamed%i' % random.randint(0, 100000000000000))[:19]
-def _random_fingerprint():
+def _random_fingerprint() -> str:
return ('%040x' % random.randrange(16 ** 40)).upper()
-def _random_ipv4_address():
+def _random_ipv4_address() -> str:
return '%i.%i.%i.%i' % (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
-def _random_date():
+def _random_date() -> str:
return '%i-%02i-%02i %02i:%02i:%02i' % (random.randint(2000, 2015), random.randint(1, 12), random.randint(1, 20), random.randint(0, 23), random.randint(0, 59), random.randint(0, 59))
-def _random_crypto_blob(block_type = None):
+def _random_crypto_blob(block_type: Optional[str] = None) -> str:
"""
Provides a random string that can be used for crypto blocks.
"""
@@ -1429,7 +1443,11 @@ def _random_crypto_blob(block_type = None):
return crypto_blob
-def _descriptor_components(raw_contents, validate, extra_keywords = (), non_ascii_fields = ()):
+def _descriptor_components(raw_contents: bytes, validate: bool, non_ascii_fields: Sequence[str] = ()) -> ENTRY_TYPE:
+ return _descriptor_components_with_extra(raw_contents, validate, (), non_ascii_fields) # type: ignore
+
+
+def _descriptor_components_with_extra(raw_contents: bytes, validate: bool, extra_keywords: Sequence[str] = (), non_ascii_fields: Sequence[str] = ()) -> Tuple[ENTRY_TYPE, List[str]]:
"""
Initial breakup of the server descriptor contents to make parsing easier.
@@ -1443,7 +1461,7 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci
entries because this influences the resulting exit policy, but for everything
else in server descriptors the order does not matter.
- :param str raw_contents: descriptor content provided by the relay
+ :param bytes raw_contents: descriptor content provided by the relay
:param bool validate: checks the validity of the descriptor's content if
True, skips these checks otherwise
:param list extra_keywords: entity keywords to put into a separate listing
@@ -1456,12 +1474,9 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci
value tuple, the second being a list of those entries.
"""
- if isinstance(raw_contents, bytes):
- raw_contents = stem.util.str_tools._to_unicode(raw_contents)
-
- entries = collections.OrderedDict()
+ entries = collections.OrderedDict() # type: ENTRY_TYPE
extra_entries = [] # entries with a keyword in extra_keywords
- remaining_lines = raw_contents.split('\n')
+ remaining_lines = stem.util.str_tools._to_unicode(raw_contents).split('\n')
while remaining_lines:
line = remaining_lines.pop(0)
@@ -1525,7 +1540,7 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci
if extra_keywords:
return entries, extra_entries
else:
- return entries
+ return entries # type: ignore
# importing at the end to avoid circular dependencies on our Descriptor class
diff --git a/stem/descriptor/bandwidth_file.py b/stem/descriptor/bandwidth_file.py
index 49df3173..f1f0b1e2 100644
--- a/stem/descriptor/bandwidth_file.py
+++ b/stem/descriptor/bandwidth_file.py
@@ -21,9 +21,10 @@ import time
import stem.util.str_tools
-from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type
+from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type
from stem.descriptor import (
+ ENTRY_TYPE,
_mappings_for,
Descriptor,
)
@@ -168,11 +169,14 @@ def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any
* **IOError** if the file can't be read
"""
- yield BandwidthFile(descriptor_file.read(), validate, **kwargs)
+ if kwargs:
+ raise ValueError('BUG: keyword arguments unused by bandwidth files')
+ yield BandwidthFile(descriptor_file.read(), validate)
-def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
- header = collections.OrderedDict()
+
+def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
+ header = collections.OrderedDict() # type: collections.OrderedDict[str, str]
content = io.BytesIO(descriptor.get_bytes())
content.readline() # skip the first line, which should be the timestamp
@@ -197,7 +201,7 @@ def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
if key == 'version':
version_index = index
else:
- raise ValueError("Header expected to be key=value pairs, but had '%s'" % line)
+ raise ValueError("Header expected to be key=value pairs, but had '%s'" % stem.util.str_tools._to_unicode(line))
index += 1
@@ -216,16 +220,16 @@ def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
raise ValueError("The 'version' header must be in the second position")
-def _parse_timestamp(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_timestamp(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
first_line = io.BytesIO(descriptor.get_bytes()).readline().strip()
if first_line.isdigit():
descriptor.timestamp = datetime.datetime.utcfromtimestamp(int(first_line))
else:
- raise ValueError("First line should be a unix timestamp, but was '%s'" % first_line)
+ raise ValueError("First line should be a unix timestamp, but was '%s'" % stem.util.str_tools._to_unicode(first_line))
-def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# In version 1.0.0 the body is everything after the first line. Otherwise
# it's everything after the header's divider.
@@ -239,13 +243,13 @@ def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Seq
measurements = {}
- for line in content.readlines():
- line = stem.util.str_tools._to_unicode(line.strip())
+ for line_bytes in content.readlines():
+ line = stem.util.str_tools._to_unicode(line_bytes.strip())
attr = dict(_mappings_for('measurement', line))
fingerprint = attr.get('node_id', '').lstrip('$') # bwauths prefix fingerprints with '$'
if not fingerprint:
- raise ValueError("Every meaurement must include 'node_id': %s" % line)
+ raise ValueError("Every measurement must include 'node_id': %s" % stem.util.str_tools._to_unicode(line))
elif fingerprint in measurements:
raise ValueError('Relay %s is listed multiple times. It should only be present once.' % fingerprint)
@@ -298,12 +302,12 @@ class BandwidthFile(Descriptor):
'timestamp': (None, _parse_timestamp),
'header': ({}, _parse_header),
'measurements': ({}, _parse_body),
- }
+ } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
ATTRIBUTES.update(dict([(k, (None, _parse_header)) for k in HEADER_ATTR.keys()]))
@classmethod
- def content(cls: Type['stem.descriptor.bandwidth_file.BandwidthFile'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.bandwidth_file.BandwidthFile'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
"""
Creates descriptor content with the given attributes. This descriptor type
differs somewhat from others and treats our attr/exclude attributes as
@@ -328,7 +332,7 @@ class BandwidthFile(Descriptor):
header = collections.OrderedDict(attr) if attr is not None else collections.OrderedDict()
timestamp = header.pop('timestamp', str(int(time.time())))
- content = header.pop('content', [])
+ content = header.pop('content', []) # type: List[str] # type: ignore
version = header.get('version', HEADER_DEFAULT.get('version'))
lines = []
@@ -354,7 +358,7 @@ class BandwidthFile(Descriptor):
return b'\n'.join(lines)
- def __init__(self, raw_content: str, validate: bool = False) -> None:
+ def __init__(self, raw_content: bytes, validate: bool = False) -> None:
super(BandwidthFile, self).__init__(raw_content, lazy_load = not validate)
if validate:
diff --git a/stem/descriptor/certificate.py b/stem/descriptor/certificate.py
index 6956a60f..bc09be2d 100644
--- a/stem/descriptor/certificate.py
+++ b/stem/descriptor/certificate.py
@@ -64,7 +64,8 @@ import stem.util.enum
import stem.util.str_tools
from stem.client.datatype import CertType, Field, Size, split
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from stem.descriptor import ENTRY_TYPE
+from typing import Callable, List, Optional, Sequence, Tuple, Union
ED25519_KEY_LENGTH = 32
ED25519_HEADER_LENGTH = 40
@@ -218,7 +219,7 @@ class Ed25519Certificate(object):
return stem.util.str_tools._to_unicode(encoded)
@staticmethod
- def _from_descriptor(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
+ def _from_descriptor(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
def _parse(descriptor, entries):
value, block_type, block_contents = entries[keyword][0]
@@ -253,7 +254,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
is unavailable
"""
- def __init__(self, cert_type: Optional['stem.client.datatype.CertType'] = None, expiration: Optional[datetime.datetime] = None, key_type: Optional[int] = None, key: Optional[bytes] = None, extensions: Optional[Sequence['stem.descriptor.certificate.Ed25519Extension']] = None, signature: Optional[bytes] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> None:
+ def __init__(self, cert_type: Optional['stem.client.datatype.CertType'] = None, expiration: Optional[datetime.datetime] = None, key_type: Optional[int] = None, key: Optional[bytes] = None, extensions: Optional[Sequence['stem.descriptor.certificate.Ed25519Extension']] = None, signature: Optional[bytes] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> None: # type: ignore
super(Ed25519CertificateV1, self).__init__(1)
if cert_type is None:
@@ -261,12 +262,15 @@ class Ed25519CertificateV1(Ed25519Certificate):
elif key is None:
raise ValueError('Certificate key is required')
+ self.type = None # type: Optional[stem.client.datatype.CertType]
+ self.type_int = None # type: Optional[int]
+
self.type, self.type_int = CertType.get(cert_type)
- self.expiration = expiration if expiration else datetime.datetime.utcnow() + datetime.timedelta(hours = DEFAULT_EXPIRATION_HOURS)
- self.key_type = key_type if key_type else 1
- self.key = stem.util._pubkey_bytes(key)
- self.extensions = extensions if extensions else []
- self.signature = signature
+ self.expiration = expiration if expiration else datetime.datetime.utcnow() + datetime.timedelta(hours = DEFAULT_EXPIRATION_HOURS) # type: datetime.datetime
+ self.key_type = key_type if key_type else 1 # type: int
+ self.key = stem.util._pubkey_bytes(key) # type: bytes
+ self.extensions = list(extensions) if extensions else [] # type: List[stem.descriptor.certificate.Ed25519Extension]
+ self.signature = signature # type: Optional[bytes]
if signing_key:
calculated_sig = signing_key.sign(self.pack())
diff --git a/stem/descriptor/collector.py b/stem/descriptor/collector.py
index 1f1b1e95..9749dadb 100644
--- a/stem/descriptor/collector.py
+++ b/stem/descriptor/collector.py
@@ -63,7 +63,7 @@ import stem.util.connection
import stem.util.str_tools
from stem.descriptor import Compression, DocumentHandler
-from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
COLLECTOR_URL = 'https://collector.torproject.org/'
REFRESH_INDEX_RATE = 3600 # get new index if cached copy is an hour old
@@ -93,7 +93,7 @@ def get_instance() -> 'stem.descriptor.collector.CollecTor':
return SINGLETON_COLLECTOR
-def get_server_descriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']:
+def get_server_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_server_descriptors`
@@ -104,7 +104,7 @@ def get_server_descriptors(start: datetime.datetime = None, end: datetime.dateti
yield desc
-def get_extrainfo_descriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']:
+def get_extrainfo_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor]:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_extrainfo_descriptors`
@@ -115,7 +115,7 @@ def get_extrainfo_descriptors(start: datetime.datetime = None, end: datetime.dat
yield desc
-def get_microdescriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']:
+def get_microdescriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_microdescriptors`
@@ -126,7 +126,7 @@ def get_microdescriptors(start: datetime.datetime = None, end: datetime.datetime
yield desc
-def get_consensus(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']:
+def get_consensus(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntry]:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_consensus`
@@ -137,7 +137,7 @@ def get_consensus(start: datetime.datetime = None, end: datetime.datetime = None
yield desc
-def get_key_certificates(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']:
+def get_key_certificates(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.networkstatus.KeyCertificate]:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_key_certificates`
@@ -148,7 +148,7 @@ def get_key_certificates(start: datetime.datetime = None, end: datetime.datetime
yield desc
-def get_bandwidth_files(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']:
+def get_bandwidth_files(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.bandwidth_file.BandwidthFile]:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_bandwidth_files`
@@ -159,7 +159,7 @@ def get_bandwidth_files(start: datetime.datetime = None, end: datetime.datetime
yield desc
-def get_exit_lists(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']:
+def get_exit_lists(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.tordnsel.TorDNSEL]:
"""
Shorthand for
:func:`~stem.descriptor.collector.CollecTor.get_exit_lists`
@@ -188,14 +188,14 @@ class File(object):
:var datetime last_modified: when the file was last modified
"""
- def __init__(self, path: str, types: Tuple[str], size: int, sha256: str, first_published: datetime.datetime, last_published: datetime.datetime, last_modified: datetime.datetime) -> None:
+ def __init__(self, path: str, types: Tuple[str, ...], size: int, sha256: str, first_published: str, last_published: str, last_modified: str) -> None:
self.path = path
self.types = tuple(types) if types else ()
self.compression = File._guess_compression(path)
self.size = size
self.sha256 = sha256
self.last_modified = datetime.datetime.strptime(last_modified, '%Y-%m-%d %H:%M')
- self._downloaded_to = None # location we last downloaded to
+ self._downloaded_to = None # type: Optional[str] # location we last downloaded to
# Most descriptor types have publication time fields, but microdescriptors
# don't because these files lack timestamps to parse.
@@ -206,7 +206,7 @@ class File(object):
else:
self.start, self.end = File._guess_time_range(path)
- def read(self, directory: Optional[str] = None, descriptor_type: Optional[str] = None, start: datetime.datetime = None, end: datetime.datetime = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.Descriptor']:
+ def read(self, directory: Optional[str] = None, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.Descriptor]:
"""
Provides descriptors from this archive. Descriptors are downloaded or read
from disk as follows...
@@ -325,8 +325,8 @@ class File(object):
# check if this file already exists with the correct checksum
if os.path.exists(path):
- with open(path) as prior_file:
- expected_hash = binascii.hexlify(base64.b64decode(self.sha256))
+ with open(path, 'rb') as prior_file:
+ expected_hash = binascii.hexlify(base64.b64decode(self.sha256)).decode('utf-8')
actual_hash = hashlib.sha256(prior_file.read()).hexdigest()
if expected_hash == actual_hash:
@@ -346,7 +346,7 @@ class File(object):
return path
@staticmethod
- def _guess_compression(path) -> 'stem.descriptor.Compression':
+ def _guess_compression(path: str) -> stem.descriptor._Compression:
"""
Determine file comprssion from CollecTor's filename.
"""
@@ -358,7 +358,7 @@ class File(object):
return Compression.PLAINTEXT
@staticmethod
- def _guess_time_range(path) -> Tuple[datetime.datetime, datetime.datetime]:
+ def _guess_time_range(path: str) -> Tuple[datetime.datetime, datetime.datetime]:
"""
Attemt to determine the (start, end) time range from CollecTor's filename.
This provides (None, None) if this cannot be determined.
@@ -404,10 +404,10 @@ class CollecTor(object):
self.timeout = timeout
self._cached_index = None
- self._cached_files = None
- self._cached_index_at = 0
+ self._cached_files = None # type: Optional[List[File]]
+ self._cached_index_at = 0.0
- def get_server_descriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']:
+ def get_server_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
"""
Provides server descriptors published during the given time range, sorted
oldest to newest.
@@ -432,9 +432,9 @@ class CollecTor(object):
for f in self.files(desc_type, start, end):
for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries):
- yield desc
+ yield desc # type: ignore
- def get_extrainfo_descriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']:
+ def get_extrainfo_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor]:
"""
Provides extrainfo descriptors published during the given time range,
sorted oldest to newest.
@@ -459,9 +459,9 @@ class CollecTor(object):
for f in self.files(desc_type, start, end):
for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries):
- yield desc
+ yield desc # type: ignore
- def get_microdescriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']:
+ def get_microdescriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]:
"""
Provides microdescriptors estimated to be published during the given time
range, sorted oldest to newest. Unlike server/extrainfo descriptors,
@@ -493,9 +493,9 @@ class CollecTor(object):
for f in self.files('microdescriptor', start, end):
for desc in f.read(cache_to, 'microdescriptor', start, end, timeout = timeout, retries = retries):
- yield desc
+ yield desc # type: ignore
- def get_consensus(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']:
+ def get_consensus(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntry]:
"""
Provides consensus router status entries published during the given time
range, sorted oldest to newest.
@@ -537,9 +537,9 @@ class CollecTor(object):
for f in self.files(desc_type, start, end):
for desc in f.read(cache_to, desc_type, start, end, document_handler, timeout = timeout, retries = retries):
- yield desc
+ yield desc # type: ignore
- def get_key_certificates(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']:
+ def get_key_certificates(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.networkstatus.KeyCertificate]:
"""
Directory authority key certificates for the given time range,
sorted oldest to newest.
@@ -561,9 +561,9 @@ class CollecTor(object):
for f in self.files('dir-key-certificate-3', start, end):
for desc in f.read(cache_to, 'dir-key-certificate-3', start, end, timeout = timeout, retries = retries):
- yield desc
+ yield desc # type: ignore
- def get_bandwidth_files(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']:
+ def get_bandwidth_files(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.bandwidth_file.BandwidthFile]:
"""
Bandwidth authority heuristics for the given time range, sorted oldest to
newest.
@@ -585,9 +585,9 @@ class CollecTor(object):
for f in self.files('bandwidth-file', start, end):
for desc in f.read(cache_to, 'bandwidth-file', start, end, timeout = timeout, retries = retries):
- yield desc
+ yield desc # type: ignore
- def get_exit_lists(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']:
+ def get_exit_lists(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.tordnsel.TorDNSEL]:
"""
`TorDNSEL exit lists <https://www.torproject.org/projects/tordnsel.html.en>`_
for the given time range, sorted oldest to newest.
@@ -609,9 +609,9 @@ class CollecTor(object):
for f in self.files('tordnsel', start, end):
for desc in f.read(cache_to, 'tordnsel', start, end, timeout = timeout, retries = retries):
- yield desc
+ yield desc # type: ignore
- def index(self, compression: Union[str, 'descriptor.Compression'] = 'best') -> Dict[str, Any]:
+ def index(self, compression: Union[str, stem.descriptor._Compression] = 'best') -> Dict[str, Any]:
"""
Provides the archives available in CollecTor.
@@ -632,21 +632,25 @@ class CollecTor(object):
if compression == 'best':
for option in (Compression.LZMA, Compression.BZ2, Compression.GZIP, Compression.PLAINTEXT):
if option.available:
- compression = option
+ compression_enum = option
break
elif compression is None:
- compression = Compression.PLAINTEXT
+ compression_enum = Compression.PLAINTEXT
+ elif isinstance(compression, stem.descriptor._Compression):
+ compression_enum = compression
+ else:
+ raise ValueError('compression must be a descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
- extension = compression.extension if compression != Compression.PLAINTEXT else ''
+ extension = compression_enum.extension if compression_enum != Compression.PLAINTEXT else ''
url = COLLECTOR_URL + 'index/index.json' + extension
- response = compression.decompress(stem.util.connection.download(url, self.timeout, self.retries))
+ response = compression_enum.decompress(stem.util.connection.download(url, self.timeout, self.retries))
self._cached_index = json.loads(stem.util.str_tools._to_unicode(response))
self._cached_index_at = time.time()
return self._cached_index
- def files(self, descriptor_type: str = None, start: datetime.datetime = None, end: datetime.datetime = None) -> Sequence['stem.descriptor.collector.File']:
+ def files(self, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None) -> List['stem.descriptor.collector.File']:
"""
Provides files CollecTor presently has, sorted oldest to newest.
@@ -681,7 +685,7 @@ class CollecTor(object):
return matches
@staticmethod
- def _files(val: str, path: Sequence[str]) -> Sequence['stem.descriptor.collector.File']:
+ def _files(val: Dict[str, Any], path: List[str]) -> List['stem.descriptor.collector.File']:
"""
Recursively provies files within the index.
@@ -698,7 +702,7 @@ class CollecTor(object):
for k, v in val.items():
if k == 'files':
- for attr in v:
+ for attr in v: # Dict[str, str]
file_path = '/'.join(path + [attr.get('path')])
files.append(File(file_path, attr.get('types'), attr.get('size'), attr.get('sha256'), attr.get('first_published'), attr.get('last_published'), attr.get('last_modified')))
elif k == 'directories':
diff --git a/stem/descriptor/extrainfo_descriptor.py b/stem/descriptor/extrainfo_descriptor.py
index 6aca3c29..cd9467d1 100644
--- a/stem/descriptor/extrainfo_descriptor.py
+++ b/stem/descriptor/extrainfo_descriptor.py
@@ -76,9 +76,10 @@ import stem.util.connection
import stem.util.enum
import stem.util.str_tools
-from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import (
+ ENTRY_TYPE,
PGP_BLOCK_END,
Descriptor,
DigestHash,
@@ -184,6 +185,9 @@ def _parse_file(descriptor_file: BinaryIO, is_bridge = False, validate = False,
* **IOError** if the file can't be read
"""
+ if kwargs:
+ raise ValueError('BUG: keyword arguments unused by extrainfo descriptors')
+
while True:
if not is_bridge:
extrainfo_content = _read_until_keywords('router-signature', descriptor_file)
@@ -200,9 +204,9 @@ def _parse_file(descriptor_file: BinaryIO, is_bridge = False, validate = False,
extrainfo_content = extrainfo_content[1:]
if is_bridge:
- yield BridgeExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate, **kwargs)
+ yield BridgeExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate)
else:
- yield RelayExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate, **kwargs)
+ yield RelayExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate)
else:
break # done parsing file
@@ -241,7 +245,7 @@ def _parse_timestamp_and_interval(keyword: str, content: str) -> Tuple[datetime.
raise ValueError("%s line's timestamp wasn't parsable: %s" % (keyword, line))
-def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "extra-info" Nickname Fingerprint
value = _value('extra-info', entries)
@@ -258,7 +262,7 @@ def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: Di
descriptor.fingerprint = extra_info_comp[1]
-def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "transport" transportname address:port [arglist]
# Everything after the transportname is scrubbed in published bridge
# descriptors, so we'll never see it in practice.
@@ -304,7 +308,7 @@ def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: Dic
descriptor.transport = transports
-def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "padding-counts" YYYY-MM-DD HH:MM:SS (NSEC s) key=val key=val...
value = _value('padding-counts', entries)
@@ -319,7 +323,7 @@ def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries
setattr(descriptor, 'padding_counts', counts)
-def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_counts_attr: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_counts_attr: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
recognized_counts = {}
@@ -343,7 +347,7 @@ def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_c
setattr(descriptor, unrecognized_counts_attr, unrecognized_counts)
-def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
if not value.endswith('%'):
@@ -356,7 +360,7 @@ def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.des
setattr(descriptor, attribute, float(value[:-1]) / 100)
-def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "<keyword>" num,...,num
value = _value(keyword, entries)
@@ -378,7 +382,7 @@ def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.
raise exc
-def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interval_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interval_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s)
timestamp, interval, _ = _parse_timestamp_and_interval(keyword, _value(keyword, entries))
@@ -386,7 +390,7 @@ def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interva
setattr(descriptor, interval_attribute, interval)
-def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "conn-bi-direct" YYYY-MM-DD HH:MM:SS (NSEC s) BELOW,READ,WRITE,BOTH
value = _value('conn-bi-direct', entries)
@@ -404,7 +408,7 @@ def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries
descriptor.conn_bi_direct_both = int(stats[3])
-def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: str, values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: str, values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s) NUM,NUM,NUM,NUM,NUM...
value = _value(keyword, entries)
@@ -422,7 +426,7 @@ def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: st
setattr(descriptor, values_attribute, history_values)
-def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "<keyword>" port=N,port=N,...
value, port_mappings = _value(keyword, entries), {}
@@ -431,13 +435,13 @@ def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descr
if (port != 'other' and not stem.util.connection.is_valid_port(port)) or not stat.isdigit():
raise ValueError('Entries in %s line should only be PORT=N entries: %s %s' % (keyword, keyword, value))
- port = int(port) if port.isdigit() else port
+ port = int(port) if port.isdigit() else port # type: ignore # this can be an int or 'other'
port_mappings[port] = int(stat)
setattr(descriptor, attribute, port_mappings)
-def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "<keyword>" CC=N,CC=N,...
#
# The maxmind geoip (https://www.maxmind.com/app/iso3166) has numeric
@@ -457,7 +461,7 @@ def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.d
setattr(descriptor, attribute, locale_usage)
-def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value, ip_versions = _value('bridge-ip-versions', entries), {}
for protocol, count in _mappings_for('bridge-ip-versions', value, divider = ','):
@@ -469,7 +473,7 @@ def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', ent
descriptor.ip_versions = ip_versions
-def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value, ip_transports = _value('bridge-ip-transports', entries), {}
for protocol, count in _mappings_for('bridge-ip-transports', value, divider = ','):
@@ -481,7 +485,7 @@ def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', e
descriptor.ip_transports = ip_transports
-def _parse_hs_stats(keyword: str, stat_attribute: str, extra_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_hs_stats(keyword: str, stat_attribute: str, extra_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "<keyword>" num key=val key=val...
value, stat, extra = _value(keyword, entries), None, {}
@@ -768,7 +772,7 @@ class ExtraInfoDescriptor(Descriptor):
'ip_versions': (None, _parse_bridge_ip_versions_line),
'ip_transports': (None, _parse_bridge_ip_transports_line),
- }
+ } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = {
'extra-info': _parse_extra_info_line,
@@ -817,7 +821,7 @@ class ExtraInfoDescriptor(Descriptor):
'bridge-ip-transports': _parse_bridge_ip_transports_line,
}
- def __init__(self, raw_contents: str, validate: bool = False) -> None:
+ def __init__(self, raw_contents: bytes, validate: bool = False) -> None:
"""
Extra-info descriptor constructor. By default this validates the
descriptor's content as it's parsed. This validation can be disabled to
@@ -854,7 +858,7 @@ class ExtraInfoDescriptor(Descriptor):
else:
self._entries = entries
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore
"""
Digest of this descriptor's content. These are referenced by...
@@ -879,7 +883,7 @@ class ExtraInfoDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ExtraInfoDescriptor subclass')
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
return REQUIRED_FIELDS
def _first_keyword(self) -> str:
@@ -920,7 +924,7 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor):
})
@classmethod
- def content(cls: Type['stem.descriptor.extrainfo.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> str:
+ def content(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> bytes:
base_header = (
('extra-info', '%s %s' % (_random_nickname(), _random_fingerprint())),
('published', _random_date()),
@@ -941,11 +945,11 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor):
))
@classmethod
- def create(cls: Type['stem.descriptor.extrainfo.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> 'stem.descriptor.extrainfo.RelayExtraInfoDescriptor':
+ def create(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> 'stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor':
return cls(cls.content(attr, exclude, sign, signing_key), validate = validate)
@functools.lru_cache()
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore
if hash_type == DigestHash.SHA1:
# our digest is calculated from everything except our signature
@@ -989,7 +993,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
})
@classmethod
- def content(cls: Type['stem.descriptor.extrainfo.BridgeExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.extrainfo_descriptor.BridgeExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('extra-info', 'ec2bridgereaac65a3 %s' % _random_fingerprint()),
('published', _random_date()),
@@ -997,7 +1001,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
('router-digest', _random_fingerprint()),
))
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore
if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX:
return self._digest
elif hash_type == DigestHash.SHA256 and encoding == DigestEncoding.BASE64:
@@ -1005,7 +1009,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
else:
raise NotImplementedError('Bridge extrainfo digests are only available as sha1/hex and sha256/base64, not %s/%s' % (hash_type, encoding))
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
excluded_fields = [
'router-signature',
]
diff --git a/stem/descriptor/hidden_service.py b/stem/descriptor/hidden_service.py
index 8d23838e..2eb7d02f 100644
--- a/stem/descriptor/hidden_service.py
+++ b/stem/descriptor/hidden_service.py
@@ -51,9 +51,10 @@ import stem.util.tor_tools
from stem.client.datatype import CertType
from stem.descriptor.certificate import ExtensionType, Ed25519Extension, Ed25519Certificate, Ed25519CertificateV1
-from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import (
+ ENTRY_TYPE,
PGP_BLOCK_END,
Descriptor,
_descriptor_content,
@@ -104,7 +105,7 @@ INTRODUCTION_POINTS_ATTR = {
'onion_key': None,
'service_key': None,
'intro_authentication': [],
-}
+} # type: Dict[str, Any]
# introduction-point fields that can only appear once
@@ -133,7 +134,7 @@ class DecryptionFailure(Exception):
"""
-class IntroductionPointV2(collections.namedtuple('IntroductionPointV2', INTRODUCTION_POINTS_ATTR.keys())):
+class IntroductionPointV2(collections.namedtuple('IntroductionPointV2', INTRODUCTION_POINTS_ATTR.keys())): # type: ignore
"""
Introduction point for a v2 hidden service.
@@ -163,7 +164,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
"""
@staticmethod
- def parse(content: str) -> 'stem.descriptor.hidden_service.IntroductionPointV3':
+ def parse(content: bytes) -> 'stem.descriptor.hidden_service.IntroductionPointV3':
"""
Parses an introduction point from its descriptor content.
@@ -175,7 +176,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
"""
entry = _descriptor_components(content, False)
- link_specifiers = IntroductionPointV3._parse_link_specifiers(_value('introduction-point', entry))
+ link_specifiers = IntroductionPointV3._parse_link_specifiers(stem.util.str_tools._to_bytes(_value('introduction-point', entry)))
onion_key_line = _value('onion-key', entry)
onion_key = onion_key_line[5:] if onion_key_line.startswith('ntor ') else None
@@ -201,7 +202,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, legacy_key, legacy_key_cert)
@staticmethod
- def create_for_address(address: str, port: int, expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3':
+ def create_for_address(address: str, port: int, expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': # type: ignore
"""
Simplified constructor for a single address/port link specifier.
@@ -223,6 +224,8 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
if not stem.util.connection.is_valid_port(port):
raise ValueError("'%s' is an invalid port" % port)
+ link_specifiers = None # type: Optional[List[stem.client.datatype.LinkSpecifier]]
+
if stem.util.connection.is_valid_ipv4_address(address):
link_specifiers = [stem.client.datatype.LinkByIPv4(address, port)]
elif stem.util.connection.is_valid_ipv6_address(address):
@@ -233,7 +236,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3.create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None)
@staticmethod
- def create_for_link_specifiers(link_specifiers: Sequence['stem.client.datatype.LinkSpecifier'], expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3':
+ def create_for_link_specifiers(link_specifiers: Sequence['stem.client.datatype.LinkSpecifier'], expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': # type: ignore
"""
Simplified constructor. For more sophisticated use cases you can use this
as a template for how introduction points are properly created.
@@ -300,7 +303,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return '\n'.join(lines)
- def onion_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey':
+ def onion_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore
"""
Provides our ntor introduction point public key.
@@ -313,7 +316,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.onion_key_raw, x25519 = True)
- def auth_key(self) -> 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey':
+ def auth_key(self) -> 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey': # type: ignore
"""
Provides our authentication certificate's public key.
@@ -326,7 +329,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.auth_key_cert.key, ed25519 = True)
- def enc_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey':
+ def enc_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore
"""
Provides our encryption key.
@@ -339,7 +342,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.enc_key_raw, x25519 = True)
- def legacy_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey':
+ def legacy_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore
"""
Provides our legacy introduction point public key.
@@ -353,7 +356,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.legacy_key_raw, x25519 = True)
@staticmethod
- def _key_as(value: str, x25519: bool = False, ed25519: bool = False) -> Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']:
+ def _key_as(value: bytes, x25519: bool = False, ed25519: bool = False) -> Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']: # type: ignore
if value is None or (not x25519 and not ed25519):
return value
@@ -376,11 +379,11 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return Ed25519PublicKey.from_public_bytes(value)
@staticmethod
- def _parse_link_specifiers(content: str) -> 'stem.client.datatype.LinkSpecifier':
+ def _parse_link_specifiers(content: bytes) -> List['stem.client.datatype.LinkSpecifier']:
try:
content = base64.b64decode(content)
except Exception as exc:
- raise ValueError('Unable to base64 decode introduction point (%s): %s' % (exc, content))
+ raise ValueError('Unable to base64 decode introduction point (%s): %s' % (exc, stem.util.str_tools._to_unicode(content)))
link_specifiers = []
count, content = stem.client.datatype.Size.CHAR.pop(content)
@@ -390,7 +393,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
link_specifiers.append(link_specifier)
if content:
- raise ValueError('Introduction point had excessive data (%s)' % content)
+ raise ValueError('Introduction point had excessive data (%s)' % stem.util.str_tools._to_unicode(content))
return link_specifiers
@@ -418,7 +421,7 @@ class AuthorizedClient(object):
:var str cookie: base64 encoded authentication cookie
"""
- def __init__(self, id: str = None, iv: str = None, cookie: str = None) -> None:
+ def __init__(self, id: Optional[str] = None, iv: Optional[str] = None, cookie: Optional[str] = None) -> None:
self.id = stem.util.str_tools._to_unicode(id if id else base64.b64encode(os.urandom(8)).rstrip(b'='))
self.iv = stem.util.str_tools._to_unicode(iv if iv else base64.b64encode(os.urandom(16)).rstrip(b'='))
self.cookie = stem.util.str_tools._to_unicode(cookie if cookie else base64.b64encode(os.urandom(16)).rstrip(b'='))
@@ -433,7 +436,7 @@ class AuthorizedClient(object):
return not self == other
-def _parse_file(descriptor_file: BinaryIO, desc_type: str = None, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.hidden_service.HiddenServiceDescriptor']:
+def _parse_file(descriptor_file: BinaryIO, desc_type: Optional[Type['stem.descriptor.hidden_service.HiddenServiceDescriptor']] = None, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.hidden_service.HiddenServiceDescriptor']:
"""
Iterates over the hidden service descriptors in a file.
@@ -468,12 +471,12 @@ def _parse_file(descriptor_file: BinaryIO, desc_type: str = None, validate: bool
if descriptor_content[0].startswith(b'@type'):
descriptor_content = descriptor_content[1:]
- yield desc_type(bytes.join(b'', descriptor_content), validate, **kwargs)
+ yield desc_type(bytes.join(b'', descriptor_content), validate, **kwargs) # type: ignore
else:
break # done parsing file
-def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> str:
+def _decrypt_layer(encrypted_block: str, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> str:
if encrypted_block.startswith('-----BEGIN MESSAGE-----\n') and encrypted_block.endswith('\n-----END MESSAGE-----'):
encrypted_block = encrypted_block[24:-22]
@@ -492,7 +495,7 @@ def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: in
cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt)
if expected_mac != mac_for(ciphertext):
- raise ValueError('Malformed mac (expected %s, but was %s)' % (expected_mac, mac_for(ciphertext)))
+ raise ValueError('Malformed mac (expected %s, but was %s)' % (stem.util.str_tools._to_unicode(expected_mac), stem.util.str_tools._to_unicode(mac_for(ciphertext))))
decryptor = cipher.decryptor()
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
@@ -500,7 +503,7 @@ def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: in
return stem.util.str_tools._to_unicode(plaintext)
-def _encrypt_layer(plaintext: str, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
+def _encrypt_layer(plaintext: bytes, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
salt = os.urandom(16)
cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt)
@@ -511,7 +514,7 @@ def _encrypt_layer(plaintext: str, constant: bytes, revision_counter: int, subcr
return b'-----BEGIN MESSAGE-----\n%s\n-----END MESSAGE-----' % b'\n'.join(stem.util.str_tools._split_by_length(encoded, 64))
-def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes, salt: bytes) -> Tuple['cryptography.hazmat.primitives.ciphers.Cipher', Callable[[bytes], bytes]]:
+def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes, salt: bytes) -> Tuple['cryptography.hazmat.primitives.ciphers.Cipher', Callable[[bytes], bytes]]: # type: ignore
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
@@ -531,7 +534,7 @@ def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes,
return cipher, lambda ciphertext: hashlib.sha3_256(mac_prefix + ciphertext).digest()
-def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value('protocol-versions', entries)
try:
@@ -546,7 +549,7 @@ def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entr
descriptor.protocol_versions = versions
-def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
_, block_type, block_contents = entries['introduction-points'][0]
if not block_contents or block_type != 'MESSAGE':
@@ -560,7 +563,7 @@ def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', en
raise ValueError("'introduction-points' isn't base64 encoded content:\n%s" % block_contents)
-def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "auth-client" client-id iv encrypted-cookie
clients = {}
@@ -576,7 +579,7 @@ def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: D
descriptor.clients = clients
-def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value, formats = _value('create2-formats', entries), []
for entry in value.split(' '):
@@ -588,7 +591,7 @@ def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: D
descriptor.formats = formats
-def _parse_v3_introduction_points(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_v3_introduction_points(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
if hasattr(descriptor, '_unparsed_introduction_points'):
introduction_points = []
remaining = descriptor._unparsed_introduction_points
@@ -674,7 +677,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
'introduction_points_encoded': (None, _parse_introduction_points_line),
'introduction_points_content': (None, _parse_introduction_points_line),
'signature': (None, _parse_v2_signature_line),
- }
+ } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = {
'rendezvous-service-descriptor': _parse_rendezvous_service_descriptor_line,
@@ -688,7 +691,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
}
@classmethod
- def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('rendezvous-service-descriptor', 'y3olqqblqw2gbh6phimfuiroechjjafa'),
('version', '2'),
@@ -705,7 +708,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV2':
return cls(cls.content(attr, exclude), validate = validate, skip_crypto_validation = True)
- def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None:
+ def __init__(self, raw_contents: bytes, validate: bool = False, skip_crypto_validation: bool = False) -> None:
super(HiddenServiceDescriptorV2, self).__init__(raw_contents, lazy_load = not validate)
entries = _descriptor_components(raw_contents, validate, non_ascii_fields = ('introduction-points'))
@@ -737,11 +740,11 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
self._entries = entries
@functools.lru_cache()
- def introduction_points(self, authentication_cookie: Optional[str] = None) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']:
+ def introduction_points(self, authentication_cookie: Optional[bytes] = None) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']:
"""
Provided this service's introduction points.
- :param str authentication_cookie: base64 encoded authentication cookie
+ :param bytes authentication_cookie: base64 encoded authentication cookie
:returns: **list** of :class:`~stem.descriptor.hidden_service.IntroductionPointV2`
@@ -777,7 +780,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
return HiddenServiceDescriptorV2._parse_introduction_points(content)
@staticmethod
- def _decrypt_basic_auth(content: bytes, authentication_cookie: str) -> bytes:
+ def _decrypt_basic_auth(content: bytes, authentication_cookie: bytes) -> bytes:
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
@@ -787,7 +790,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
try:
client_blocks = int(binascii.hexlify(content[1:2]), 16)
except ValueError:
- raise DecryptionFailure("When using basic auth the content should start with a number of blocks but wasn't a hex digit: %s" % binascii.hexlify(content[1:2]))
+ raise DecryptionFailure("When using basic auth the content should start with a number of blocks but wasn't a hex digit: %s" % binascii.hexlify(content[1:2]).decode('utf-8'))
# parse the client id and encrypted session keys
@@ -824,7 +827,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
return content # nope, unable to decrypt the content
@staticmethod
- def _decrypt_stealth_auth(content: bytes, authentication_cookie: str) -> bytes:
+ def _decrypt_stealth_auth(content: bytes, authentication_cookie: bytes) -> bytes:
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
@@ -888,7 +891,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor):
auth_type, auth_data = auth_value.split(' ')[:2]
auth_entries.append((auth_type, auth_data))
- introduction_points.append(IntroductionPointV2(**attr))
+ introduction_points.append(IntroductionPointV2(**attr)) # type: ignore
return introduction_points
@@ -931,7 +934,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
}
@classmethod
- def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> str:
+ def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.certificate.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> bytes: # type: ignore
"""
Hidden service v3 descriptors consist of three parts:
@@ -992,7 +995,12 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
blinded_key = _blinded_pubkey(identity_key, blinding_nonce) if blinding_nonce else b'a' * 32
subcredential = HiddenServiceDescriptorV3._subcredential(identity_key, blinded_key)
- custom_sig = attr.pop('signature') if (attr and 'signature' in attr) else None
+
+ if attr and 'signature' in attr:
+ custom_sig = attr['signature']
+ attr = dict(filter(lambda entry: entry[0] != 'signature', attr.items()))
+ else:
+ custom_sig = None
if not outer_layer:
outer_layer = OuterLayer.create(
@@ -1014,7 +1022,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
('descriptor-lifetime', '180'),
('descriptor-signing-key-cert', '\n' + signing_cert.to_base64(pem = True)),
('revision-counter', str(revision_counter)),
- ('superencrypted', b'\n' + outer_layer._encrypt(revision_counter, subcredential, blinded_key)),
+ ('superencrypted', stem.util.str_tools._to_unicode(b'\n' + outer_layer._encrypt(revision_counter, subcredential, blinded_key))),
), ()) + b'\n'
if custom_sig:
@@ -1026,13 +1034,13 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return desc_content
@classmethod
- def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3':
+ def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.certificate.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3': # type: ignore
return cls(cls.content(attr, exclude, sign, inner_layer, outer_layer, identity_key, signing_key, signing_cert, revision_counter, blinding_nonce), validate = validate)
def __init__(self, raw_contents: bytes, validate: bool = False) -> None:
super(HiddenServiceDescriptorV3, self).__init__(raw_contents, lazy_load = not validate)
- self._inner_layer = None
+ self._inner_layer = None # type: Optional[stem.descriptor.hidden_service.InnerLayer]
entries = _descriptor_components(raw_contents, validate)
if validate:
@@ -1089,7 +1097,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return self._inner_layer
@staticmethod
- def address_from_identity_key(key: Union[bytes, 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'], suffix: bool = True) -> str:
+ def address_from_identity_key(key: Union[bytes, 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'], suffix: bool = True) -> str: # type: ignore
"""
Converts a hidden service identity key into its address. This accepts all
key formats (private, public, or public bytes).
@@ -1112,7 +1120,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return stem.util.str_tools._to_unicode(onion_address + b'.onion' if suffix else onion_address).lower()
@staticmethod
- def identity_key_from_address(onion_address: str) -> bool:
+ def identity_key_from_address(onion_address: str) -> bytes:
"""
Converts a hidden service address into its public identity key.
@@ -1149,7 +1157,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
return pubkey
@staticmethod
- def _subcredential(identity_key: bytes, blinded_key: bytes) -> bytes:
+ def _subcredential(identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes) -> bytes: # type: ignore
# credential = H('credential' | public-identity-key)
# subcredential = H('subcredential' | credential | blinded-public-key)
@@ -1179,7 +1187,7 @@ class OuterLayer(Descriptor):
'ephemeral_key': (None, _parse_v3_outer_ephemeral_key),
'clients': ({}, _parse_v3_outer_clients),
'encrypted': (None, _parse_v3_outer_encrypted),
- }
+ } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = {
'desc-auth-type': _parse_v3_outer_auth_type,
@@ -1189,9 +1197,9 @@ class OuterLayer(Descriptor):
}
@staticmethod
- def _decrypt(encrypted: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.OuterLayer':
+ def _decrypt(encrypted: str, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.OuterLayer':
plaintext = _decrypt_layer(encrypted, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key)
- return OuterLayer(plaintext)
+ return OuterLayer(stem.util.str_tools._to_bytes(plaintext))
def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
# Spec mandated padding: "Before encryption the plaintext is padded with
@@ -1204,7 +1212,7 @@ class OuterLayer(Descriptor):
return _encrypt_layer(content, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key)
@classmethod
- def content(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: Optional[int] = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> str:
+ def content(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: Optional[int] = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> bytes:
try:
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
@@ -1230,11 +1238,11 @@ class OuterLayer(Descriptor):
return _descriptor_content(attr, exclude, [
('desc-auth-type', 'x25519'),
- ('desc-auth-ephemeral-key', base64.b64encode(stem.util._pubkey_bytes(X25519PrivateKey.generate()))),
+ ('desc-auth-ephemeral-key', stem.util.str_tools._to_unicode(base64.b64encode(stem.util._pubkey_bytes(X25519PrivateKey.generate())))),
] + [
('auth-client', '%s %s %s' % (c.id, c.iv, c.cookie)) for c in authorized_clients
], (
- ('encrypted', b'\n' + inner_layer._encrypt(revision_counter, subcredential, blinded_key)),
+ ('encrypted', stem.util.str_tools._to_unicode(b'\n' + inner_layer._encrypt(revision_counter, subcredential, blinded_key))),
))
@classmethod
@@ -1285,17 +1293,17 @@ class InnerLayer(Descriptor):
}
@staticmethod
- def _decrypt(outer_layer: 'stem.descriptor.hidden_service.OuterLayer', revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
+ def _decrypt(outer_layer: 'stem.descriptor.hidden_service.OuterLayer', revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.InnerLayer':
plaintext = _decrypt_layer(outer_layer.encrypted, b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key)
- return InnerLayer(plaintext, validate = True, outer_layer = outer_layer)
+ return InnerLayer(stem.util.str_tools._to_bytes(plaintext), validate = True, outer_layer = outer_layer)
- def _encrypt(self, revision_counter, subcredential, blinded_key):
+ def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes:
# encrypt back into an outer layer's 'encrypted' field
return _encrypt_layer(self.get_bytes(), b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key)
@classmethod
- def content(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> str:
+ def content(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> bytes:
if introduction_points:
suffix = '\n' + '\n'.join(map(IntroductionPointV3.encode, introduction_points))
else:
@@ -1342,7 +1350,7 @@ def _blinded_pubkey(identity_key: bytes, blinding_nonce: bytes) -> bytes:
return ed25519.encodepoint(ed25519.scalarmult(P, mult))
-def _blinded_sign(msg: bytes, identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes, blinding_nonce: bytes) -> bytes:
+def _blinded_sign(msg: bytes, identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes, blinding_nonce: bytes) -> bytes: # type: ignore
try:
from cryptography.hazmat.primitives import serialization
except ImportError:
diff --git a/stem/descriptor/microdescriptor.py b/stem/descriptor/microdescriptor.py
index c2c104ff..7bd241e4 100644
--- a/stem/descriptor/microdescriptor.py
+++ b/stem/descriptor/microdescriptor.py
@@ -72,6 +72,7 @@ import stem.exit_policy
from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type, Union
from stem.descriptor import (
+ ENTRY_TYPE,
Descriptor,
DigestHash,
DigestEncoding,
@@ -120,6 +121,9 @@ def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any
* **IOError** if the file can't be read
"""
+ if kwargs:
+ raise ValueError('BUG: keyword arguments unused by microdescriptors')
+
while True:
annotations = _read_until_keywords('onion-key', descriptor_file)
@@ -156,12 +160,12 @@ def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any
descriptor_text = bytes.join(b'', descriptor_lines)
- yield Microdescriptor(descriptor_text, validate, annotations, **kwargs)
+ yield Microdescriptor(descriptor_text, validate, annotations)
else:
break # done parsing descriptors
-def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
identities = {}
for entry in _values('id', entries):
@@ -246,12 +250,12 @@ class Microdescriptor(Descriptor):
}
@classmethod
- def content(cls: Type['stem.descriptor.microdescriptor.Microdescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.microdescriptor.Microdescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('onion-key', _random_crypto_blob('RSA PUBLIC KEY')),
))
- def __init__(self, raw_contents, validate = False, annotations = None):
+ def __init__(self, raw_contents: bytes, validate: bool = False, annotations: Optional[Sequence[bytes]] = None) -> None:
super(Microdescriptor, self).__init__(raw_contents, lazy_load = not validate)
self._annotation_lines = annotations if annotations else []
entries = _descriptor_components(raw_contents, validate)
@@ -262,7 +266,7 @@ class Microdescriptor(Descriptor):
else:
self._entries = entries
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA256, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.BASE64) -> Union[str, 'hashlib.HASH']:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA256, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.BASE64) -> Union[str, 'hashlib._HASH']: # type: ignore
"""
Digest of this microdescriptor. These are referenced by...
@@ -287,7 +291,7 @@ class Microdescriptor(Descriptor):
raise NotImplementedError('Microdescriptor digests are only available in sha1 and sha256, not %s' % hash_type)
@functools.lru_cache()
- def get_annotations(self) -> Dict[str, str]:
+ def get_annotations(self) -> Dict[bytes, bytes]:
"""
Provides content that appeared prior to the descriptor. If this comes from
the cached-microdescs then this commonly contains content like...
@@ -310,7 +314,7 @@ class Microdescriptor(Descriptor):
return annotation_dict
- def get_annotation_lines(self) -> Sequence[str]:
+ def get_annotation_lines(self) -> Sequence[bytes]:
"""
Provides the lines of content that appeared prior to the descriptor. This
is the same as the
@@ -322,7 +326,7 @@ class Microdescriptor(Descriptor):
return self._annotation_lines
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
+ def _check_constraints(self, entries: ENTRY_TYPE) -> None:
"""
Does a basic check that the entries conform to this descriptor type's
constraints.
diff --git a/stem/descriptor/networkstatus.py b/stem/descriptor/networkstatus.py
index 48940987..6c0f5e8f 100644
--- a/stem/descriptor/networkstatus.py
+++ b/stem/descriptor/networkstatus.py
@@ -65,9 +65,10 @@ import stem.util.str_tools
import stem.util.tor_tools
import stem.version
-from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Type
+from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import (
+ ENTRY_TYPE,
PGP_BLOCK_END,
Descriptor,
DigestHash,
@@ -295,7 +296,7 @@ class DocumentDigest(collections.namedtuple('DocumentDigest', ['flavor', 'algori
"""
-def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.descriptor.networkstatus.NetworkStatusDocument']] = None, validate: bool = False, is_microdescriptor: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, **kwargs: Any) -> 'stem.descriptor.networkstatus.NetworkStatusDocument':
+def _parse_file(document_file: BinaryIO, document_type: Optional[Type] = None, validate: bool = False, is_microdescriptor: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, **kwargs: Any) -> Iterator[Union['stem.descriptor.networkstatus.NetworkStatusDocument', 'stem.descriptor.router_status_entry.RouterStatusEntry']]:
"""
Parses a network status and iterates over the RouterStatusEntry in it. The
document that these instances reference have an empty 'routers' attribute to
@@ -324,6 +325,8 @@ def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.desc
if document_type is None:
document_type = NetworkStatusDocumentV3
+ router_type = None # type: Optional[Type[stem.descriptor.router_status_entry.RouterStatusEntry]]
+
if document_type == NetworkStatusDocumentV2:
document_type, router_type = NetworkStatusDocumentV2, RouterStatusEntryV2
elif document_type == NetworkStatusDocumentV3:
@@ -334,10 +337,10 @@ def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.desc
yield document_type(document_file.read(), validate, **kwargs)
return
else:
- raise ValueError("Document type %i isn't recognized (only able to parse v2, v3, and bridge)" % document_type)
+ raise ValueError("Document type %s isn't recognized (only able to parse v2, v3, and bridge)" % document_type)
if document_handler == DocumentHandler.DOCUMENT:
- yield document_type(document_file.read(), validate, **kwargs)
+ yield document_type(document_file.read(), validate, **kwargs) # type: ignore
return
# getting the document without the routers section
@@ -355,7 +358,7 @@ def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.desc
document_content = bytes.join(b'', header + footer)
if document_handler == DocumentHandler.BARE_DOCUMENT:
- yield document_type(document_content, validate, **kwargs)
+ yield document_type(document_content, validate, **kwargs) # type: ignore
elif document_handler == DocumentHandler.ENTRIES:
desc_iterator = stem.descriptor.router_status_entry._parse_file(
document_file,
@@ -433,7 +436,7 @@ class NetworkStatusDocument(Descriptor):
Common parent for network status documents.
"""
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> None:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore
"""
Digest of this descriptor's content. These are referenced by...
@@ -460,8 +463,8 @@ class NetworkStatusDocument(Descriptor):
raise NotImplementedError('Network status document digests are only available in sha1 and sha256, not %s' % hash_type)
-def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
if not value.isdigit():
@@ -475,7 +478,7 @@ def _parse_version_line(keyword: str, attribute: str, expected_version: int) ->
return _parse
-def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value('dir-source', entries)
dir_source_comp = value.split()
@@ -495,7 +498,7 @@ def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Di
descriptor.dir_port = None if dir_source_comp[2] == '0' else int(dir_source_comp[2])
-def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
digests = []
for val in _values('additional-digest', entries):
@@ -509,7 +512,7 @@ def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries:
descriptor.additional_digests = digests
-def _parse_additional_signatures(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_additional_signatures(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
signatures = []
for val, block_type, block_contents in entries['additional-signature']:
@@ -584,7 +587,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
'signing_authority': (None, _parse_directory_signature_line),
'signatures': (None, _parse_directory_signature_line),
- }
+ } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = {
'network-status-version': _parse_network_status_version_line,
@@ -600,7 +603,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
}
@classmethod
- def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('network-status-version', '2'),
('dir-source', '%s %s 80' % (_random_ipv4_address(), _random_ipv4_address())),
@@ -648,7 +651,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
else:
self._entries = entries
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
+ def _check_constraints(self, entries: ENTRY_TYPE) -> None:
required_fields = [field for (field, is_mandatory) in NETWORK_STATUS_V2_FIELDS if is_mandatory]
for keyword in required_fields:
if keyword not in entries:
@@ -664,7 +667,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
raise ValueError("Network status document (v2) are expected to start with a 'network-status-version' line:\n%s" % str(self))
-def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "network-status-version" version
value = _value('network-status-version', entries)
@@ -685,7 +688,7 @@ def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descr
raise ValueError("Expected a version 3 network status document, got version '%s' instead" % descriptor.version)
-def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "vote-status" type
#
# The consensus-method and consensus-methods fields are optional since
@@ -702,7 +705,7 @@ def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', ent
raise ValueError("A network status document's vote-status line can only be 'consensus' or 'vote', got '%s' instead" % value)
-def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "consensus-methods" IntegerList
if descriptor._lazy_loading and descriptor.is_vote:
@@ -719,7 +722,7 @@ def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor
descriptor.consensus_methods = consensus_methods
-def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "consensus-method" Integer
if descriptor._lazy_loading and descriptor.is_consensus:
@@ -733,7 +736,7 @@ def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor'
descriptor.consensus_method = int(value)
-def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "voting-delay" VoteSeconds DistSeconds
value = _value('voting-delay', entries)
@@ -746,8 +749,8 @@ def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', en
raise ValueError("A network status document's 'voting-delay' line must be a pair of integer values, but was '%s'" % value)
-def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]:
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]:
+ def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value, entries = _value(keyword, entries), []
for entry in value.split(','):
@@ -761,7 +764,7 @@ def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descri
return _parse
-def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "flag-thresholds" SP THRESHOLDS
value, thresholds = _value('flag-thresholds', entries).strip(), {}
@@ -784,7 +787,7 @@ def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor',
descriptor.flag_thresholds = thresholds
-def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "params" [Parameters]
# Parameter ::= Keyword '=' Int32
# Int32 ::= A decimal integer between -2147483648 and 2147483647.
@@ -800,7 +803,7 @@ def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entr
descriptor._check_params_constraints()
-def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# nothing to parse, simply checking that we don't have a value
value = _value('directory-footer', entries)
@@ -809,7 +812,7 @@ def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entri
raise ValueError("A network status document's 'directory-footer' line shouldn't have any content, got 'directory-footer %s'" % value)
-def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
signatures = []
for sig_value, block_type, block_contents in entries['directory-signature']:
@@ -830,7 +833,7 @@ def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descript
descriptor.signatures = signatures
-def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
package_versions = []
for value, _, _ in entries['package']:
@@ -851,7 +854,7 @@ def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[
descriptor.packages = package_versions
-def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "shared-rand-commit" Version AlgName Identity Commit [Reveal]
commitments = []
@@ -873,7 +876,7 @@ def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries
descriptor.shared_randomness_commitments = commitments
-def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "shared-rand-previous-value" NumReveals Value
value = _value('shared-rand-previous-value', entries)
@@ -886,7 +889,7 @@ def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor',
raise ValueError("A network status document's 'shared-rand-previous-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "shared-rand-current-value" NumReveals Value
value = _value('shared-rand-current-value', entries)
@@ -899,7 +902,7 @@ def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', e
raise ValueError("A network status document's 'shared-rand-current-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "bandwidth-file-headers" KeyValues
# KeyValues ::= "" | KeyValue | KeyValues SP KeyValue
# KeyValue ::= Keyword '=' Value
@@ -914,7 +917,7 @@ def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entr
descriptor.bandwidth_file_headers = results
-def _parse_bandwidth_file_digest(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_bandwidth_file_digest(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "bandwidth-file-digest" 1*(SP algorithm "=" digest)
value = _value('bandwidth-file-digest', entries)
@@ -1098,7 +1101,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
}
@classmethod
- def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> str:
+ def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> bytes:
attr = {} if attr is None else dict(attr)
is_vote = attr.get('vote-status') == 'vote'
@@ -1170,10 +1173,10 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return desc_content
@classmethod
- def create(cls: Type['stem.descriptor.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, authorities: Optional[Sequence['stem.directory.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> 'stem.descriptor.NetworkStatusDocumentV3':
+ def create(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3':
return cls(cls.content(attr, exclude, authorities, routers), validate = validate)
- def __init__(self, raw_content: str, validate: bool = False, default_params: bool = True) -> None:
+ def __init__(self, raw_content: bytes, validate: bool = False, default_params: bool = True) -> None:
"""
Parse a v3 network status document.
@@ -1188,13 +1191,15 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
super(NetworkStatusDocumentV3, self).__init__(raw_content, lazy_load = not validate)
document_file = io.BytesIO(raw_content)
+ self._header_entries = None # type: Optional[ENTRY_TYPE]
+
self._default_params = default_params
self._header(document_file, validate)
self.directory_authorities = tuple(stem.descriptor.router_status_entry._parse_file(
document_file,
validate,
- entry_class = DirectoryAuthority,
+ entry_class = DirectoryAuthority, # type: ignore # TODO: move to another parse_file()
entry_keyword = AUTH_START,
section_end_keywords = (ROUTERS_START, FOOTER_START, V2_FOOTER_START),
extra_args = (self.is_vote,),
@@ -1255,13 +1260,13 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return self.valid_after < datetime.datetime.utcnow() < self.fresh_until
- def validate_signatures(self, key_certs: Sequence['stem.descriptor.networkstatus.KeyCertificates']) -> None:
+ def validate_signatures(self, key_certs: Sequence['stem.descriptor.networkstatus.KeyCertificate']) -> None:
"""
Validates we're properly signed by the signing certificates.
.. versionadded:: 1.6.0
- :param list key_certs: :class:`~stem.descriptor.networkstatus.KeyCertificates`
+ :param list key_certs: :class:`~stem.descriptor.networkstatus.KeyCertificate`
to validate the consensus against
:raises: **ValueError** if an insufficient number of valid signatures are present.
@@ -1289,7 +1294,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
if valid_digests < required_digests:
raise ValueError('Network Status Document has %i valid signatures out of %i total, needed %i' % (valid_digests, total_digests, required_digests))
- def get_unrecognized_lines(self) -> Sequence[str]:
+ def get_unrecognized_lines(self) -> List[str]:
if self._lazy_loading:
self._parse(self._header_entries, False, parser_for_line = self._HEADER_PARSER_FOR_LINE)
self._parse(self._footer_entries, False, parser_for_line = self._FOOTER_PARSER_FOR_LINE)
@@ -1308,10 +1313,10 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
:returns: **True** if we meet the given consensus-method, and **False** otherwise
"""
- if self.consensus_method is not None:
- return self.consensus_method >= method
- elif self.consensus_methods is not None:
- return bool([x for x in self.consensus_methods if x >= method])
+ if self.consensus_method is not None: # type: ignore
+ return self.consensus_method >= method # type: ignore
+ elif self.consensus_methods is not None: # type: ignore
+ return bool([x for x in self.consensus_methods if x >= method]) # type: ignore
else:
return False # malformed document
@@ -1341,9 +1346,9 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
# default consensus_method and consensus_methods based on if we're a consensus or vote
- if self.is_consensus and not self.consensus_method:
+ if self.is_consensus and not self.consensus_method: # type: ignore
self.consensus_method = 1
- elif self.is_vote and not self.consensus_methods:
+ elif self.is_vote and not self.consensus_methods: # type: ignore
self.consensus_methods = [1]
else:
self._header_entries = entries
@@ -1400,7 +1405,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
raise ValueError("'%s' value on the params line must be in the range of %i - %i, was %i" % (key, minimum, maximum, value))
-def _check_for_missing_and_disallowed_fields(document: 'stem.descriptor.networkstatus.NetworkStatusDocumentV3', entries: Mapping[str, str], fields: Sequence[str]) -> None:
+def _check_for_missing_and_disallowed_fields(document: 'stem.descriptor.networkstatus.NetworkStatusDocumentV3', entries: ENTRY_TYPE, fields: Sequence[Tuple[str, bool, bool, bool]]) -> None:
"""
Checks that we have mandatory fields for our type, and that we don't have
any fields exclusive to the other (ie, no vote-only fields appear in a
@@ -1438,7 +1443,8 @@ def _parse_int_mappings(keyword: str, value: str, validate: bool) -> Dict[str, i
# - values are integers
# - keys are sorted in lexical order
- results, seen_keys = {}, []
+ results = {} # type: Dict[str, int]
+ seen_keys = [] # type: List[str]
error_template = "Unable to parse network status document's '%s' line (%%s): %s'" % (keyword, value)
for key, val in _mappings_for(keyword, value):
@@ -1463,7 +1469,7 @@ def _parse_int_mappings(keyword: str, value: str, validate: bool) -> Dict[str, i
return results
-def _parse_dirauth_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_dirauth_source_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "dir-source" nickname identity address IP dirport orport
value = _value('dir-source', entries)
@@ -1582,7 +1588,7 @@ class DirectoryAuthority(Descriptor):
}
@classmethod
- def content(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), is_vote: bool = False) -> str:
+ def content(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), is_vote: bool = False) -> bytes:
attr = {} if attr is None else dict(attr)
# include mandatory 'vote-digest' if a consensus
@@ -1604,7 +1610,7 @@ class DirectoryAuthority(Descriptor):
def create(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, is_vote: bool = False) -> 'stem.descriptor.networkstatus.DirectoryAuthority':
return cls(cls.content(attr, exclude, is_vote), validate = validate, is_vote = is_vote)
- def __init__(self, raw_content: str, validate: bool = False, is_vote: bool = False) -> None:
+ def __init__(self, raw_content: bytes, validate: bool = False, is_vote: bool = False) -> None:
"""
Parse a directory authority entry in a v3 network status document.
@@ -1623,12 +1629,12 @@ class DirectoryAuthority(Descriptor):
key_div = content.find('\ndir-key-certificate-version')
if key_div != -1:
- self.key_certificate = KeyCertificate(content[key_div + 1:], validate)
+ self.key_certificate = KeyCertificate(content[key_div + 1:].encode('utf-8'), validate)
content = content[:key_div + 1]
else:
self.key_certificate = None
- entries = _descriptor_components(content, validate)
+ entries = _descriptor_components(content.encode('utf-8'), validate)
if validate and 'dir-source' != list(entries.keys())[0]:
raise ValueError("Authority entries are expected to start with a 'dir-source' line:\n%s" % (content))
@@ -1679,7 +1685,7 @@ class DirectoryAuthority(Descriptor):
self._entries = entries
-def _parse_dir_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_dir_address_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "dir-address" IPPort
value = _value('dir-address', entries)
@@ -1754,7 +1760,7 @@ class KeyCertificate(Descriptor):
}
@classmethod
- def content(cls: Type['stem.descriptor.networkstatus.KeyCertificate'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.networkstatus.KeyCertificate'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('dir-key-certificate-version', '3'),
('fingerprint', _random_fingerprint()),
@@ -1766,26 +1772,26 @@ class KeyCertificate(Descriptor):
('dir-key-certification', _random_crypto_blob('SIGNATURE')),
))
- def __init__(self, raw_content: str, validate: str = False) -> None:
+ def __init__(self, raw_content: bytes, validate: bool = False) -> None:
super(KeyCertificate, self).__init__(raw_content, lazy_load = not validate)
entries = _descriptor_components(raw_content, validate)
if validate:
if 'dir-key-certificate-version' != list(entries.keys())[0]:
- raise ValueError("Key certificates must start with a 'dir-key-certificate-version' line:\n%s" % (raw_content))
+ raise ValueError("Key certificates must start with a 'dir-key-certificate-version' line:\n%s" % stem.util.str_tools._to_unicode(raw_content))
elif 'dir-key-certification' != list(entries.keys())[-1]:
- raise ValueError("Key certificates must end with a 'dir-key-certification' line:\n%s" % (raw_content))
+ raise ValueError("Key certificates must end with a 'dir-key-certification' line:\n%s" % stem.util.str_tools._to_unicode(raw_content))
# check that we have mandatory fields and that our known fields only
# appear once
for keyword, is_mandatory in KEY_CERTIFICATE_PARAMS:
if is_mandatory and keyword not in entries:
- raise ValueError("Key certificates must have a '%s' line:\n%s" % (keyword, raw_content))
+ raise ValueError("Key certificates must have a '%s' line:\n%s" % (keyword, stem.util.str_tools._to_unicode(raw_content)))
entry_count = len(entries.get(keyword, []))
if entry_count > 1:
- raise ValueError("Key certificates can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, raw_content))
+ raise ValueError("Key certificates can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, stem.util.str_tools._to_unicode(raw_content)))
self._parse(entries, validate)
else:
@@ -1887,7 +1893,7 @@ class DetachedSignature(Descriptor):
'additional_digests': ([], _parse_additional_digests),
'additional_signatures': ([], _parse_additional_signatures),
'signatures': ([], _parse_footer_directory_signature_line),
- }
+ } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = {
'consensus-digest': _parse_consensus_digest_line,
@@ -1900,7 +1906,7 @@ class DetachedSignature(Descriptor):
}
@classmethod
- def content(cls: Type['stem.descriptor.networkstatus.DetachedSignature'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.networkstatus.DetachedSignature'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('consensus-digest', '6D3CC0EFA408F228410A4A8145E1B0BB0670E442'),
('valid-after', _random_date()),
@@ -1908,23 +1914,23 @@ class DetachedSignature(Descriptor):
('valid-until', _random_date()),
))
- def __init__(self, raw_content: str, validate: bool = False) -> None:
+ def __init__(self, raw_content: bytes, validate: bool = False) -> None:
super(DetachedSignature, self).__init__(raw_content, lazy_load = not validate)
entries = _descriptor_components(raw_content, validate)
if validate:
if 'consensus-digest' != list(entries.keys())[0]:
- raise ValueError("Detached signatures must start with a 'consensus-digest' line:\n%s" % (raw_content))
+ raise ValueError("Detached signatures must start with a 'consensus-digest' line:\n%s" % stem.util.str_tools._to_unicode(raw_content))
# check that we have mandatory fields and certain fields only appear once
for keyword, is_mandatory, is_multiple in DETACHED_SIGNATURE_PARAMS:
if is_mandatory and keyword not in entries:
- raise ValueError("Detached signatures must have a '%s' line:\n%s" % (keyword, raw_content))
+ raise ValueError("Detached signatures must have a '%s' line:\n%s" % (keyword, stem.util.str_tools._to_unicode(raw_content)))
entry_count = len(entries.get(keyword, []))
if not is_multiple and entry_count > 1:
- raise ValueError("Detached signatures can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, raw_content))
+ raise ValueError("Detached signatures can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, stem.util.str_tools._to_unicode(raw_content)))
self._parse(entries, validate)
else:
@@ -1943,7 +1949,7 @@ class BridgeNetworkStatusDocument(NetworkStatusDocument):
TYPE_ANNOTATION_NAME = 'bridge-network-status'
- def __init__(self, raw_content: str, validate: bool = False) -> None:
+ def __init__(self, raw_content: bytes, validate: bool = False) -> None:
super(BridgeNetworkStatusDocument, self).__init__(raw_content)
self.published = None
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index f3c6d6bd..2e2bb53b 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -101,7 +101,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression
from stem.util import log, str_tools
-from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their
# fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -371,7 +371,7 @@ class Query(object):
the same as running **query.run(True)** (default is **False**)
"""
- def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence['stem.Endpoint']] = None, compression: Sequence['stem.descriptor.Compression'] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
if not resource.startswith('/'):
raise ValueError("Resources should start with a '/': %s" % resource)
@@ -380,8 +380,10 @@ class Query(object):
resource = resource[:-2]
elif isinstance(compression, tuple):
compression = list(compression)
- elif not isinstance(compression, list):
+ elif isinstance(compression, stem.descriptor._Compression):
compression = [compression] # caller provided only a single option
+ else:
+ raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
if Compression.ZSTD in compression and not Compression.ZSTD.available:
compression.remove(Compression.ZSTD)
@@ -411,21 +413,21 @@ class Query(object):
self.retries = retries
self.fall_back_to_authority = fall_back_to_authority
- self.content = None
- self.error = None
+ self.content = None # type: Optional[bytes]
+ self.error = None # type: Optional[BaseException]
self.is_done = False
- self.download_url = None
+ self.download_url = None # type: Optional[str]
- self.start_time = None
+ self.start_time = None # type: Optional[float]
self.timeout = timeout
- self.runtime = None
+ self.runtime = None # type: Optional[float]
self.validate = validate
self.document_handler = document_handler
- self.reply_headers = None
+ self.reply_headers = None # type: Optional[Dict[str, str]]
self.kwargs = kwargs
- self._downloader_thread = None
+ self._downloader_thread = None # type: Optional[threading.Thread]
self._downloader_thread_lock = threading.RLock()
if start:
@@ -450,7 +452,7 @@ class Query(object):
self._downloader_thread.setDaemon(True)
self._downloader_thread.start()
- def run(self, suppress: bool = False) -> Sequence['stem.descriptor.Descriptor']:
+ def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
@@ -470,7 +472,7 @@ class Query(object):
return list(self._run(suppress))
- def _run(self, suppress: bool) -> Iterator['stem.descriptor.Descriptor']:
+ def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
with self._downloader_thread_lock:
self.start()
self._downloader_thread.join()
@@ -506,11 +508,11 @@ class Query(object):
raise self.error
- def __iter__(self) -> Iterator['stem.descriptor.Descriptor']:
+ def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
for desc in self._run(True):
yield desc
- def _pick_endpoint(self, use_authority: bool = False) -> 'stem.Endpoint':
+ def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint:
"""
Provides an endpoint to query. If we have multiple endpoints then one
is picked at random.
@@ -576,7 +578,7 @@ class DescriptorDownloader(object):
def __init__(self, use_mirrors: bool = False, **default_args: Any) -> None:
self._default_args = default_args
- self._endpoints = None
+ self._endpoints = None # type: Optional[List[stem.DirPort]]
if use_mirrors:
try:
@@ -586,7 +588,7 @@ class DescriptorDownloader(object):
except Exception as exc:
log.debug('Unable to retrieve directory mirrors: %s' % exc)
- def use_directory_mirrors(self) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3':
+ def use_directory_mirrors(self) -> stem.descriptor.networkstatus.NetworkStatusDocumentV3:
"""
Downloads the present consensus and configures ourselves to use directory
mirrors, in addition to authorities.
@@ -610,7 +612,7 @@ class DescriptorDownloader(object):
self._endpoints = list(new_endpoints)
- return consensus
+ return consensus # type: ignore
def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
@@ -776,7 +778,7 @@ class DescriptorDownloader(object):
return consensus_query
- def get_vote(self, authority: 'stem.directory.Authority', **query_args: Any) -> 'stem.descriptor.remote.Query':
+ def get_vote(self, authority: stem.directory.Authority, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Provides the present vote for a given directory authority.
@@ -924,7 +926,7 @@ class DescriptorDownloader(object):
return Query(resource, **args)
-def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.Compression'], resource: str) -> Tuple[bytes, Dict[str, str]]:
+def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given orport. Payload is just like an http
response (headers and all)...
@@ -974,7 +976,7 @@ def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.C
for line in str_tools._to_unicode(header_data).splitlines():
if ': ' not in line:
- raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % line)
+ raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
key, value = line.split(': ', 1)
headers[key] = value
@@ -982,7 +984,7 @@ def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.C
return _decompress(body_data, headers.get('Content-Encoding')), headers
-def _download_from_dirport(url: str, compression: Sequence['stem.descriptor.Compression'], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
+def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given url.
@@ -1011,8 +1013,8 @@ def _download_from_dirport(url: str, compression: Sequence['stem.descriptor.Comp
except socket.timeout as exc:
raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
except:
- exc, stacktrace = sys.exc_info()[1:3]
- raise stem.DownloadFailed(url, exc, stacktrace)
+ exception, stacktrace = sys.exc_info()[1:3]
+ raise stem.DownloadFailed(url, exception, stacktrace)
return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
diff --git a/stem/descriptor/router_status_entry.py b/stem/descriptor/router_status_entry.py
index 20822c82..2c4937f3 100644
--- a/stem/descriptor/router_status_entry.py
+++ b/stem/descriptor/router_status_entry.py
@@ -27,9 +27,10 @@ import io
import stem.exit_policy
import stem.util.str_tools
-from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type
+from typing import Any, BinaryIO, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import (
+ ENTRY_TYPE,
KEYWORD_LINE,
Descriptor,
_descriptor_content,
@@ -37,7 +38,7 @@ from stem.descriptor import (
_values,
_descriptor_components,
_parse_protocol_line,
- _read_until_keywords,
+ _read_until_keywords_with_ending_keyword,
_random_nickname,
_random_ipv4_address,
_random_date,
@@ -46,7 +47,7 @@ from stem.descriptor import (
_parse_pr_line = _parse_protocol_line('pr', 'protocols')
-def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], entry_keyword: str = 'r', start_position: int = None, end_position: int = None, section_end_keywords: Sequence[str] = (), extra_args: Sequence[str] = ()) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']:
+def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], entry_keyword: str = 'r', start_position: Optional[int] = None, end_position: Optional[int] = None, section_end_keywords: Tuple[str, ...] = (), extra_args: Sequence[Any] = ()) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']:
"""
Reads a range of the document_file containing some number of entry_class
instances. We deliminate the entry_class entries by the keyword on their
@@ -93,7 +94,7 @@ def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem
return
while end_position is None or document_file.tell() < end_position:
- desc_lines, ending_keyword = _read_until_keywords(
+ desc_lines, ending_keyword = _read_until_keywords_with_ending_keyword(
(entry_keyword,) + section_end_keywords,
document_file,
ignore_first = True,
@@ -113,7 +114,7 @@ def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem
break
-def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# Parses a RouterStatusEntry's 'r' line. They're very nearly identical for
# all current entry types (v2, v3, and microdescriptor v3) with one little
# wrinkle: only the microdescriptor flavor excludes a 'digest' field.
@@ -165,7 +166,7 @@ def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
raise ValueError("Publication time time wasn't parsable: r %s" % value)
-def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "a" SP address ":" portlist
# example: a [2001:888:2133:0:82:94:251:204]:9001
@@ -188,7 +189,7 @@ def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
descriptor.or_addresses = or_addresses
-def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "s" Flags
# example: s Named Running Stable Valid
@@ -203,7 +204,7 @@ def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
raise ValueError("%s had extra whitespace on its 's' line: s %s" % (descriptor._name(), value))
-def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "v" version
# example: v Tor 0.2.2.35
#
@@ -221,7 +222,7 @@ def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
raise ValueError('%s has a malformed tor version (%s): v %s' % (descriptor._name(), exc, value))
-def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "w" "Bandwidth=" INT ["Measured=" INT] ["Unmeasured=1"]
# example: w Bandwidth=7980
@@ -268,7 +269,7 @@ def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
descriptor.unrecognized_bandwidth_entries = unrecognized_bandwidth_entries
-def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "p" ("accept" / "reject") PortList
#
# examples:
@@ -284,7 +285,7 @@ def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
raise ValueError('%s exit policy is malformed (%s): p %s' % (descriptor._name(), exc, value))
-def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "id" "ed25519" ed25519-identity
#
# examples:
@@ -307,7 +308,7 @@ def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str,
raise ValueError("'id' lines should contain both the key type and digest: id %s" % value)
-def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "m" methods 1*(algorithm "=" digest)
# example: m 8,9,10,11,12 sha256=g1vx9si329muxV3tquWIXXySNOIwRGMeAESKs/v4DWs
@@ -341,7 +342,7 @@ def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S
descriptor.microdescriptor_hashes = all_hashes
-def _parse_microdescriptor_m_line(descriptor: 'stem.descriptor.Descriptor', entries):
+def _parse_microdescriptor_m_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "m" digest
# example: m aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70
@@ -422,7 +423,7 @@ class RouterStatusEntry(Descriptor):
}
@classmethod
- def from_str(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], content: str, **kwargs: Any) -> 'stem.descriptor.router_status_entry.RouterStatusEntry':
+ def from_str(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], content: str, **kwargs: Any) -> Union['stem.descriptor.router_status_entry.RouterStatusEntry', List['stem.descriptor.router_status_entry.RouterStatusEntry']]: # type: ignore
# Router status entries don't have their own @type annotation, so to make
# our subclass from_str() work we need to do the type inferencing ourself.
@@ -442,7 +443,7 @@ class RouterStatusEntry(Descriptor):
else:
raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." % len(results))
- def __init__(self, content: str, validate: bool = False, document: Optional['stem.descriptor.NetworkStatusDocument'] = None) -> None:
+ def __init__(self, content: bytes, validate: bool = False, document: Optional['stem.descriptor.networkstatus.NetworkStatusDocument'] = None) -> None:
"""
Parse a router descriptor in a network status document.
@@ -481,14 +482,14 @@ class RouterStatusEntry(Descriptor):
return 'Router status entries' if is_plural else 'Router status entry'
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
"""
Provides lines that must appear in the descriptor.
"""
return ()
- def _single_fields(self) -> Tuple[str]:
+ def _single_fields(self) -> Tuple[str, ...]:
"""
Provides lines that can only appear in the descriptor once.
"""
@@ -514,7 +515,7 @@ class RouterStatusEntryV2(RouterStatusEntry):
})
@classmethod
- def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())),
))
@@ -522,10 +523,10 @@ class RouterStatusEntryV2(RouterStatusEntry):
def _name(self, is_plural: bool = False) -> str:
return 'Router status entries (v2)' if is_plural else 'Router status entry (v2)'
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
return ('r',)
- def _single_fields(self) -> Tuple[str]:
+ def _single_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'v')
@@ -579,7 +580,7 @@ class RouterStatusEntryV3(RouterStatusEntry):
TYPE_ANNOTATION_NAME = 'network-status-consensus-3'
- ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{
+ ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ # type: ignore
'digest': (None, _parse_r_line),
'or_addresses': ([], _parse_a_line),
'identifier_type': (None, _parse_id_line),
@@ -595,7 +596,7 @@ class RouterStatusEntryV3(RouterStatusEntry):
'microdescriptor_hashes': ([], _parse_m_line),
})
- PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{
+ PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ # type: ignore
'a': _parse_a_line,
'w': _parse_w_line,
'p': _parse_p_line,
@@ -605,7 +606,7 @@ class RouterStatusEntryV3(RouterStatusEntry):
})
@classmethod
- def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())),
('s', 'Fast Named Running Stable Valid'),
@@ -614,10 +615,10 @@ class RouterStatusEntryV3(RouterStatusEntry):
def _name(self, is_plural: bool = False) -> str:
return 'Router status entries (v3)' if is_plural else 'Router status entry (v3)'
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
return ('r', 's')
- def _single_fields(self) -> Tuple[str]:
+ def _single_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'v', 'w', 'p', 'pr')
@@ -652,7 +653,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry):
TYPE_ANNOTATION_NAME = 'network-status-microdesc-consensus-3'
- ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{
+ ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ # type: ignore
'or_addresses': ([], _parse_a_line),
'bandwidth': (None, _parse_w_line),
'measured': (None, _parse_w_line),
@@ -662,7 +663,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry):
'microdescriptor_digest': (None, _parse_microdescriptor_m_line),
})
- PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{
+ PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ # type: ignore
'a': _parse_a_line,
'w': _parse_w_line,
'm': _parse_microdescriptor_m_line,
@@ -670,7 +671,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry):
})
@classmethod
- def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryMicroV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryMicroV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('r', '%s ARIJF2zbqirB9IwsW0mQznccWww %s %s 9001 9030' % (_random_nickname(), _random_date(), _random_ipv4_address())),
('m', 'aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70'),
@@ -680,8 +681,8 @@ class RouterStatusEntryMicroV3(RouterStatusEntry):
def _name(self, is_plural: bool = False) -> str:
return 'Router status entries (micro v3)' if is_plural else 'Router status entry (micro v3)'
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'm')
- def _single_fields(self) -> Tuple[str]:
+ def _single_fields(self) -> Tuple[str, ...]:
return ('r', 's', 'v', 'w', 'm', 'pr')
diff --git a/stem/descriptor/server_descriptor.py b/stem/descriptor/server_descriptor.py
index 11b44972..fbb5c633 100644
--- a/stem/descriptor/server_descriptor.py
+++ b/stem/descriptor/server_descriptor.py
@@ -61,16 +61,17 @@ import stem.version
from stem.descriptor.certificate import Ed25519Certificate
from stem.descriptor.router_status_entry import RouterStatusEntryV3
-from typing import Any, BinaryIO, Dict, Iterator, Optional, Mapping, Sequence, Tuple, Type, Union
+from typing import Any, BinaryIO, Iterator, Optional, Mapping, Sequence, Tuple, Type, Union
from stem.descriptor import (
+ ENTRY_TYPE,
PGP_BLOCK_END,
Descriptor,
DigestHash,
DigestEncoding,
create_signing_key,
_descriptor_content,
- _descriptor_components,
+ _descriptor_components_with_extra,
_read_until_keywords,
_bytes_for_block,
_value,
@@ -214,14 +215,17 @@ def _parse_file(descriptor_file: BinaryIO, is_bridge: bool = False, validate: bo
descriptor_text = bytes.join(b'', descriptor_content)
if is_bridge:
- yield BridgeDescriptor(descriptor_text, validate, **kwargs)
+ if kwargs:
+ raise ValueError('BUG: keyword arguments unused by bridge descriptors')
+
+ yield BridgeDescriptor(descriptor_text, validate)
else:
yield RelayDescriptor(descriptor_text, validate, **kwargs)
else:
break # done parsing descriptors
-def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "router" nickname address ORPort SocksPort DirPort
value = _value('router', entries)
@@ -247,7 +251,7 @@ def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[s
descriptor.dir_port = None if router_comp[4] == '0' else int(router_comp[4])
-def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "bandwidth" bandwidth-avg bandwidth-burst bandwidth-observed
value = _value('bandwidth', entries)
@@ -267,7 +271,7 @@ def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: Dic
descriptor.observed_bandwidth = int(bandwidth_comp[2])
-def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "platform" string
_parse_bytes_line('platform', 'platform')(descriptor, entries)
@@ -293,7 +297,7 @@ def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict
pass
-def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# This is forty hex digits split into space separated groups of four.
# Checking that we match this pattern.
@@ -310,7 +314,7 @@ def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: D
descriptor.fingerprint = fingerprint
-def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value('extra-info-digest', entries)
digest_comp = value.split(' ')
@@ -321,7 +325,7 @@ def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entri
descriptor.extra_info_sha256_digest = digest_comp[1] if len(digest_comp) >= 2 else None
-def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
# "hibernating" 0|1 (in practice only set if one)
value = _value('hibernating', entries)
@@ -332,7 +336,7 @@ def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: D
descriptor.hibernating = value == '1'
-def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value('protocols', entries)
protocols_match = re.match('^Link (.*) Circuit (.*)$', value)
@@ -344,7 +348,7 @@ def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: Dic
descriptor.circuit_protocols = circuit_versions.split(' ')
-def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
all_values = _values('or-address', entries)
or_addresses = []
@@ -367,7 +371,7 @@ def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Di
descriptor.or_addresses = or_addresses
-def _parse_history_line(keyword: str, history_end_attribute: str, history_interval_attribute: str, history_values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_history_line(keyword: str, history_end_attribute: str, history_interval_attribute: str, history_values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
value = _value(keyword, entries)
timestamp, interval, remainder = stem.descriptor.extrainfo_descriptor._parse_timestamp_and_interval(keyword, value)
@@ -384,7 +388,7 @@ def _parse_history_line(keyword: str, history_end_attribute: str, history_interv
setattr(descriptor, history_values_attribute, history_values)
-def _parse_exit_policy(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None:
+def _parse_exit_policy(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None:
if hasattr(descriptor, '_unparsed_exit_policy'):
if descriptor._unparsed_exit_policy and stem.util.str_tools._to_unicode(descriptor._unparsed_exit_policy[0]) == 'reject *:*':
descriptor.exit_policy = REJECT_ALL_POLICY
@@ -577,7 +581,7 @@ class ServerDescriptor(Descriptor):
'eventdns': _parse_eventdns_line,
}
- def __init__(self, raw_contents: str, validate: bool = False) -> None:
+ def __init__(self, raw_contents: bytes, validate: bool = False) -> None:
"""
Server descriptor constructor, created from an individual relay's
descriptor content (as provided by 'GETINFO desc/*', cached descriptors,
@@ -604,7 +608,7 @@ class ServerDescriptor(Descriptor):
# influences the resulting exit policy, but for everything else the order
# does not matter so breaking it into key / value pairs.
- entries, self._unparsed_exit_policy = _descriptor_components(stem.util.str_tools._to_unicode(raw_contents), validate, extra_keywords = ('accept', 'reject'), non_ascii_fields = ('contact', 'platform'))
+ entries, self._unparsed_exit_policy = _descriptor_components_with_extra(raw_contents, validate, extra_keywords = ('accept', 'reject'), non_ascii_fields = ('contact', 'platform'))
if validate:
self._parse(entries, validate)
@@ -622,7 +626,7 @@ class ServerDescriptor(Descriptor):
else:
self._entries = entries
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore
"""
Digest of this descriptor's content. These are referenced by...
@@ -642,7 +646,7 @@ class ServerDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ServerDescriptor subclass')
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
+ def _check_constraints(self, entries: ENTRY_TYPE) -> None:
"""
Does a basic check that the entries conform to this descriptor type's
constraints.
@@ -680,16 +684,16 @@ class ServerDescriptor(Descriptor):
# Constraints that the descriptor must meet to be valid. These can be None if
# not applicable.
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
return REQUIRED_FIELDS
- def _single_fields(self) -> Tuple[str]:
+ def _single_fields(self) -> Tuple[str, ...]:
return REQUIRED_FIELDS + SINGLE_FIELDS
def _first_keyword(self) -> str:
return 'router'
- def _last_keyword(self) -> str:
+ def _last_keyword(self) -> Optional[str]:
return 'router-signature'
@@ -754,7 +758,7 @@ class RelayDescriptor(ServerDescriptor):
'router-signature': _parse_router_signature_line,
})
- def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None:
+ def __init__(self, raw_contents: bytes, validate: bool = False, skip_crypto_validation: bool = False) -> None:
super(RelayDescriptor, self).__init__(raw_contents, validate)
if validate:
@@ -786,9 +790,8 @@ class RelayDescriptor(ServerDescriptor):
pass # cryptography module unavailable
@classmethod
- def content(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> str:
- if attr is None:
- attr = {}
+ def content(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> bytes:
+ attr = dict(attr) if attr else {}
if exit_policy is None:
exit_policy = REJECT_ALL_POLICY
@@ -798,7 +801,7 @@ class RelayDescriptor(ServerDescriptor):
('published', _random_date()),
('bandwidth', '153600 256000 104590'),
] + [
- tuple(line.split(' ', 1)) for line in str(exit_policy).splitlines()
+ tuple(line.split(' ', 1)) for line in str(exit_policy).splitlines() # type: ignore
] + [
('onion-key', _random_crypto_blob('RSA PUBLIC KEY')),
('signing-key', _random_crypto_blob('RSA PUBLIC KEY')),
@@ -832,7 +835,7 @@ class RelayDescriptor(ServerDescriptor):
return cls(cls.content(attr, exclude, sign, signing_key, exit_policy), validate = validate, skip_crypto_validation = not sign)
@functools.lru_cache()
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore
"""
Provides the digest of our descriptor's content.
@@ -889,7 +892,7 @@ class RelayDescriptor(ServerDescriptor):
if self.certificate:
attr['id'] = 'ed25519 %s' % _truncated_b64encode(self.certificate.key)
- return RouterStatusEntryV3.create(attr)
+ return RouterStatusEntryV3.create(attr) # type: ignore
@functools.lru_cache()
def _onion_key_crosscert_digest(self) -> str:
@@ -906,7 +909,7 @@ class RelayDescriptor(ServerDescriptor):
data = signing_key_digest + base64.b64decode(stem.util.str_tools._to_bytes(self.ed25519_master_key) + b'=')
return stem.util.str_tools._to_unicode(binascii.hexlify(data).upper())
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None:
+ def _check_constraints(self, entries: ENTRY_TYPE) -> None:
super(RelayDescriptor, self)._check_constraints(entries)
if self.certificate:
@@ -945,7 +948,7 @@ class BridgeDescriptor(ServerDescriptor):
})
@classmethod
- def content(cls: Type['stem.descriptor.server_descriptor.BridgeDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str:
+ def content(cls: Type['stem.descriptor.server_descriptor.BridgeDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes:
return _descriptor_content(attr, exclude, (
('router', '%s %s 9001 0 0' % (_random_nickname(), _random_ipv4_address())),
('router-digest', '006FD96BA35E7785A6A3B8B75FE2E2435A13BDB4'),
@@ -954,7 +957,7 @@ class BridgeDescriptor(ServerDescriptor):
('reject', '*:*'),
))
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']:
+ def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore
if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX:
return self._digest
else:
@@ -1007,7 +1010,7 @@ class BridgeDescriptor(ServerDescriptor):
return issues
- def _required_fields(self) -> Tuple[str]:
+ def _required_fields(self) -> Tuple[str, ...]:
# bridge required fields are the same as a relay descriptor, minus items
# excluded according to the format page
@@ -1023,8 +1026,8 @@ class BridgeDescriptor(ServerDescriptor):
return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields])
- def _single_fields(self) -> str:
+ def _single_fields(self) -> Tuple[str, ...]:
return self._required_fields() + SINGLE_FIELDS
- def _last_keyword(self) -> str:
+ def _last_keyword(self) -> Optional[str]:
return None
diff --git a/stem/descriptor/tordnsel.py b/stem/descriptor/tordnsel.py
index c36e343d..6b9d4ceb 100644
--- a/stem/descriptor/tordnsel.py
+++ b/stem/descriptor/tordnsel.py
@@ -10,13 +10,16 @@ exit list files.
TorDNSEL - Exit list provided by TorDNSEL
"""
+import datetime
+
import stem.util.connection
import stem.util.str_tools
import stem.util.tor_tools
-from typing import Any, BinaryIO, Dict, Iterator, Sequence
+from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Tuple
from stem.descriptor import (
+ ENTRY_TYPE,
Descriptor,
_read_until_keywords,
_descriptor_components,
@@ -35,6 +38,9 @@ def _parse_file(tordnsel_file: BinaryIO, validate: bool = False, **kwargs: Any)
* **IOError** if the file can't be read
"""
+ if kwargs:
+ raise ValueError("TorDNSEL doesn't support additional arguments: %s" % kwargs)
+
# skip content prior to the first ExitNode
_read_until_keywords('ExitNode', tordnsel_file, skip = True)
@@ -43,7 +49,7 @@ def _parse_file(tordnsel_file: BinaryIO, validate: bool = False, **kwargs: Any)
contents += _read_until_keywords('ExitNode', tordnsel_file)
if contents:
- yield TorDNSEL(bytes.join(b'', contents), validate, **kwargs)
+ yield TorDNSEL(bytes.join(b'', contents), validate)
else:
break # done parsing file
@@ -64,19 +70,21 @@ class TorDNSEL(Descriptor):
TYPE_ANNOTATION_NAME = 'tordnsel'
- def __init__(self, raw_contents: str, validate: bool) -> None:
+ def __init__(self, raw_contents: bytes, validate: bool) -> None:
super(TorDNSEL, self).__init__(raw_contents)
- raw_contents = stem.util.str_tools._to_unicode(raw_contents)
entries = _descriptor_components(raw_contents, validate)
- self.fingerprint = None
- self.published = None
- self.last_status = None
- self.exit_addresses = []
+ self.fingerprint = None # type: Optional[str]
+ self.published = None # type: Optional[datetime.datetime]
+ self.last_status = None # type: Optional[datetime.datetime]
+ self.exit_addresses = [] # type: List[Tuple[str, datetime.datetime]]
self._parse(entries, validate)
- def _parse(self, entries: Dict[str, Sequence[str]], validate: bool) -> None:
+ def _parse(self, entries: ENTRY_TYPE, validate: bool, parser_for_line: Optional[Dict[str, Callable]] = None) -> None:
+ if parser_for_line:
+ raise ValueError('parser_for_line is unused by TorDNSEL')
+
for keyword, values in list(entries.items()):
value, block_type, block_content = values[0]
@@ -102,7 +110,7 @@ class TorDNSEL(Descriptor):
raise ValueError("LastStatus time wasn't parsable: %s" % value)
elif keyword == 'ExitAddress':
for value, block_type, block_content in values:
- address, date = value.split(' ', 1)
+ address, date_str = value.split(' ', 1)
if validate:
if not stem.util.connection.is_valid_ipv4_address(address):
@@ -111,7 +119,7 @@ class TorDNSEL(Descriptor):
raise ValueError('Unexpected block content: %s' % block_content)
try:
- date = stem.util.str_tools._parse_timestamp(date)
+ date = stem.util.str_tools._parse_timestamp(date_str)
self.exit_addresses.append((address, date))
except ValueError:
if validate:
diff --git a/stem/directory.py b/stem/directory.py
index f96adfbb..3ecb0b71 100644
--- a/stem/directory.py
+++ b/stem/directory.py
@@ -49,7 +49,7 @@ import stem.util
import stem.util.conf
from stem.util import connection, str_tools, tor_tools
-from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Pattern, Sequence, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Pattern, Sequence, Tuple, Union
GITWEB_AUTHORITY_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/auth_dirs.inc'
GITWEB_FALLBACK_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/fallback_dirs.inc'
@@ -69,7 +69,7 @@ FALLBACK_EXTRAINFO = re.compile('/\\* extrainfo=([0-1]) \\*/')
FALLBACK_IPV6 = re.compile('" ipv6=\\[([\\da-f:]+)\\]:(\\d+)"')
-def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Optional[bool] = None) -> Dict[Pattern, Tuple[str]]:
+def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Dict[Pattern, Union[str, List[str]]]:
"""
Scans the given content against a series of regex matchers, providing back a
mapping of regexes to their capture groups. This maping is with the value if
@@ -102,7 +102,7 @@ def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Opti
return matches
-def _directory_entries(lines: Sequence[str], pop_section_func: Callable[[Sequence[str]], Sequence[str]], regexes: Sequence[Pattern], required: bool = None) -> Iterator[Dict[Pattern, Tuple[str]]]:
+def _directory_entries(lines: List[str], pop_section_func: Callable[[List[str]], List[str]], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Iterator[Dict[Pattern, Union[str, List[str]]]]:
next_section = pop_section_func(lines)
while next_section:
@@ -130,11 +130,11 @@ class Directory(object):
:var int dir_port: port on which directory information is available
:var str fingerprint: relay fingerprint
:var str nickname: relay nickname
- :var str orport_v6: **(address, port)** tuple for the directory's IPv6
+ :var tuple orport_v6: **(address, port)** tuple for the directory's IPv6
ORPort, or **None** if it doesn't have one
"""
- def __init__(self, address: str, or_port: int, dir_port: int, fingerprint: str, nickname: str, orport_v6: str) -> None:
+ def __init__(self, address: str, or_port: Union[int, str], dir_port: Union[int, str], fingerprint: str, nickname: str, orport_v6: Tuple[str, int]) -> None:
identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint
if not connection.is_valid_ipv4_address(address):
@@ -164,7 +164,7 @@ class Directory(object):
self.orport_v6 = (orport_v6[0], int(orport_v6[1])) if orport_v6 else None
@staticmethod
- def from_cache() -> Dict[str, 'stem.directory.Directory']:
+ def from_cache() -> Dict[str, Any]:
"""
Provides cached Tor directory information. This information is hardcoded
into Tor and occasionally changes, so the information provided by this
@@ -182,7 +182,7 @@ class Directory(object):
raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass')
@staticmethod
- def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Directory']:
+ def from_remote(timeout: int = 60) -> Dict[str, Any]:
"""
Reads and parses tor's directory data `from gitweb.torproject.org <https://gitweb.torproject.org/>`_.
Note that while convenient, this reliance on GitWeb means you should alway
@@ -232,7 +232,7 @@ class Authority(Directory):
:var str v3ident: identity key fingerprint used to sign votes and consensus
"""
- def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, orport_v6: Optional[int] = None, v3ident: Optional[str] = None) -> None:
+ def __init__(self, address: Optional[str] = None, or_port: Optional[Union[int, str]] = None, dir_port: Optional[Union[int, str]] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, orport_v6: Optional[Tuple[str, int]] = None, v3ident: Optional[str] = None) -> None:
super(Authority, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6)
if v3ident and not tor_tools.is_valid_fingerprint(v3ident):
@@ -276,8 +276,8 @@ class Authority(Directory):
dir_port = dir_port,
fingerprint = fingerprint.replace(' ', ''),
nickname = nickname,
- orport_v6 = matches.get(AUTHORITY_IPV6),
- v3ident = matches.get(AUTHORITY_V3IDENT),
+ orport_v6 = matches.get(AUTHORITY_IPV6), # type: ignore
+ v3ident = matches.get(AUTHORITY_V3IDENT), # type: ignore
)
except ValueError as exc:
raise IOError(str(exc))
@@ -285,7 +285,7 @@ class Authority(Directory):
return results
@staticmethod
- def _pop_section(lines: Sequence[str]) -> Sequence[str]:
+ def _pop_section(lines: List[str]) -> List[str]:
"""
Provides the next authority entry.
"""
@@ -349,7 +349,7 @@ class Fallback(Directory):
:var collections.OrderedDict header: metadata about the fallback directory file this originated from
"""
- def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, has_extrainfo: bool = False, orport_v6: Optional[int] = None, header: Optional[Mapping[str, str]] = None) -> None:
+ def __init__(self, address: Optional[str] = None, or_port: Optional[Union[int, str]] = None, dir_port: Optional[Union[int, str]] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, has_extrainfo: bool = False, orport_v6: Optional[Tuple[str, int]] = None, header: Optional[Mapping[str, str]] = None) -> None:
super(Fallback, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6)
self.has_extrainfo = has_extrainfo
self.header = collections.OrderedDict(header) if header else collections.OrderedDict()
@@ -440,9 +440,9 @@ class Fallback(Directory):
or_port = int(or_port),
dir_port = int(dir_port),
fingerprint = fingerprint,
- nickname = matches.get(FALLBACK_NICKNAME),
+ nickname = matches.get(FALLBACK_NICKNAME), # type: ignore
has_extrainfo = matches.get(FALLBACK_EXTRAINFO) == '1',
- orport_v6 = matches.get(FALLBACK_IPV6),
+ orport_v6 = matches.get(FALLBACK_IPV6), # type: ignore
header = header,
)
except ValueError as exc:
@@ -451,7 +451,7 @@ class Fallback(Directory):
return results
@staticmethod
- def _pop_section(lines: Sequence[str]) -> Sequence[str]:
+ def _pop_section(lines: List[str]) -> List[str]:
"""
Provides lines up through the next divider. This excludes lines with just a
comma since they're an artifact of these being C strings.
@@ -514,7 +514,7 @@ class Fallback(Directory):
return not self == other
-def _fallback_directory_differences(previous_directories: Sequence['stem.directory.Dirctory'], new_directories: Sequence['stem.directory.Directory']) -> str:
+def _fallback_directory_differences(previous_directories: Mapping[str, 'stem.directory.Fallback'], new_directories: Mapping[str, 'stem.directory.Fallback']) -> str:
"""
Provides a description of how fallback directories differ.
"""
diff --git a/stem/exit_policy.py b/stem/exit_policy.py
index 076611d2..19178c9a 100644
--- a/stem/exit_policy.py
+++ b/stem/exit_policy.py
@@ -71,7 +71,7 @@ import stem.util.connection
import stem.util.enum
import stem.util.str_tools
-from typing import Any, Iterator, Optional, Sequence, Union
+from typing import Any, Iterator, List, Optional, Sequence, Set, Union
AddressType = stem.util.enum.Enum(('WILDCARD', 'Wildcard'), ('IPv4', 'IPv4'), ('IPv6', 'IPv6'))
@@ -167,6 +167,8 @@ class ExitPolicy(object):
def __init__(self, *rules: Union[str, 'stem.exit_policy.ExitPolicyRule']) -> None:
# sanity check the types
+ self._input_rules = None # type: Optional[Union[bytes, Sequence[Union[str, bytes, stem.exit_policy.ExitPolicyRule]]]]
+
for rule in rules:
if not isinstance(rule, (bytes, str)) and not isinstance(rule, ExitPolicyRule):
raise TypeError('Exit policy rules can only contain strings or ExitPolicyRules, got a %s (%s)' % (type(rule), rules))
@@ -183,13 +185,14 @@ class ExitPolicy(object):
is_all_str = False
if rules and is_all_str:
- byte_rules = [stem.util.str_tools._to_bytes(r) for r in rules]
+ byte_rules = [stem.util.str_tools._to_bytes(r) for r in rules] # type: ignore
self._input_rules = zlib.compress(b','.join(byte_rules))
else:
self._input_rules = rules
- self._rules = None
- self._hash = None
+ self._policy_str = None # type: Optional[str]
+ self._rules = None # type: List[stem.exit_policy.ExitPolicyRule]
+ self._hash = None # type: Optional[int]
# Result when no rules apply. According to the spec policies default to 'is
# allowed', but our microdescriptor policy subclass might want to change
@@ -228,7 +231,7 @@ class ExitPolicy(object):
otherwise.
"""
- rejected_ports = set()
+ rejected_ports = set() # type: Set[int]
for rule in self._get_rules():
if rule.is_accept:
@@ -298,7 +301,8 @@ class ExitPolicy(object):
# convert port list to a list of ranges (ie, ['1-3'] rather than [1, 2, 3])
if display_ports:
- display_ranges, temp_range = [], []
+ display_ranges = []
+ temp_range = [] # type: List[int]
display_ports.sort()
display_ports.append(None) # ending item to include last range in loop
@@ -384,23 +388,28 @@ class ExitPolicy(object):
input_rules = self._input_rules
if self._rules is None and input_rules is not None:
- rules = []
+ rules = [] # type: List[stem.exit_policy.ExitPolicyRule]
is_all_accept, is_all_reject = True, True
+ decompressed_rules = None # type: Optional[Sequence[Union[str, bytes, stem.exit_policy.ExitPolicyRule]]]
if isinstance(input_rules, bytes):
decompressed_rules = zlib.decompress(input_rules).split(b',')
else:
decompressed_rules = input_rules
- for rule in decompressed_rules:
- if isinstance(rule, bytes):
- rule = stem.util.str_tools._to_unicode(rule)
+ for rule_val in decompressed_rules:
+ if isinstance(rule_val, bytes):
+ rule_val = stem.util.str_tools._to_unicode(rule_val)
- if isinstance(rule, (bytes, str)):
- if not rule.strip():
+ if isinstance(rule_val, str):
+ if not rule_val.strip():
continue
- rule = ExitPolicyRule(rule.strip())
+ rule = ExitPolicyRule(rule_val.strip())
+ elif isinstance(rule_val, stem.exit_policy.ExitPolicyRule):
+ rule = rule_val
+ else:
+ raise TypeError('BUG: unexpected type within decompressed policy: %s (%s)' % (stem.util.str_tools._to_unicode(rule_val), type(rule_val).__name__))
if rule.is_accept:
is_all_reject = False
@@ -446,9 +455,11 @@ class ExitPolicy(object):
for rule in self._get_rules():
yield rule
- @functools.lru_cache()
def __str__(self) -> str:
- return ', '.join([str(rule) for rule in self._get_rules()])
+ if self._policy_str is None:
+ self._policy_str = ', '.join([str(rule) for rule in self._get_rules()])
+
+ return self._policy_str
def __hash__(self) -> int:
if self._hash is None:
@@ -505,7 +516,7 @@ class MicroExitPolicy(ExitPolicy):
# PortList ::= PortList "," PortOrRange
# PortOrRange ::= INT "-" INT / INT
- self._policy = policy
+ policy_str = policy
if policy.startswith('accept'):
self.is_accept = True
@@ -517,7 +528,7 @@ class MicroExitPolicy(ExitPolicy):
policy = policy[6:]
if not policy.startswith(' '):
- raise ValueError('A microdescriptor exit policy should have a space separating accept/reject from its port list: %s' % self._policy)
+ raise ValueError('A microdescriptor exit policy should have a space separating accept/reject from its port list: %s' % policy_str)
policy = policy.lstrip()
@@ -538,9 +549,10 @@ class MicroExitPolicy(ExitPolicy):
super(MicroExitPolicy, self).__init__(*rules)
self._is_allowed_default = not self.is_accept
+ self._policy_str = policy_str
def __str__(self) -> str:
- return self._policy
+ return self._policy_str
def __hash__(self) -> int:
return hash(str(self))
@@ -606,17 +618,17 @@ class ExitPolicyRule(object):
if ':' not in exitpattern or ']' in exitpattern.rsplit(':', 1)[1]:
raise ValueError("An exitpattern must be of the form 'addrspec:portspec': %s" % rule)
- self.address = None
- self._address_type = None
- self._masked_bits = None
- self.min_port = self.max_port = None
- self._hash = None
+ self.address = None # type: Optional[str]
+ self._address_type = None # type: Optional[stem.exit_policy.AddressType]
+ self._masked_bits = None # type: Optional[int]
+ self.min_port = self.max_port = None # type: Optional[int]
+ self._hash = None # type: Optional[int]
# Our mask in ip notation (ex. '255.255.255.0'). This is only set if we
# either have a custom mask that can't be represented by a number of bits,
# or the user has called mask(), lazily loading this.
- self._mask = None
+ self._mask = None # type: Optional[str]
# Malformed exit policies are rejected, but there's an exception where it's
# just skipped: when an accept6/reject6 rule has an IPv4 address...
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 2a3cff18..1d08abb6 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -54,13 +54,13 @@ def main() -> None:
import stem.interpreter.commands
try:
- args = stem.interpreter.arguments.parse(sys.argv[1:])
+ args = stem.interpreter.arguments.Arguments.parse(sys.argv[1:])
except ValueError as exc:
print(exc)
sys.exit(1)
if args.print_help:
- print(stem.interpreter.arguments.get_help())
+ print(stem.interpreter.arguments.Arguments.get_help())
sys.exit()
if args.disable_color or not sys.stdout.isatty():
@@ -82,13 +82,11 @@ def main() -> None:
if not args.run_cmd and not args.run_path:
print(format(msg('msg.starting_tor'), *HEADER_OUTPUT))
- control_port = '9051' if args.control_port == 'default' else str(args.control_port)
-
try:
stem.process.launch_tor_with_config(
config = {
'SocksPort': '0',
- 'ControlPort': control_port,
+ 'ControlPort': '9051' if args.control_port is None else str(args.control_port),
'CookieAuthentication': '1',
'ExitPolicy': 'reject *:*',
},
@@ -115,7 +113,7 @@ def main() -> None:
control_port = control_port,
control_socket = control_socket,
password_prompt = True,
- )
+ ) # type: stem.control.Controller
if controller is None:
sys.exit(1)
@@ -126,7 +124,7 @@ def main() -> None:
if args.run_cmd:
if args.run_cmd.upper().startswith('SETEVENTS '):
- controller._handle_event = lambda event_message: print(format(str(event_message), *STANDARD_OUTPUT))
+ controller._handle_event = lambda event_message: print(format(str(event_message), *STANDARD_OUTPUT)) # type: ignore
if sys.stdout.isatty():
events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/arguments.py b/stem/interpreter/arguments.py
index 8ac1c2c1..dd0b19bb 100644
--- a/stem/interpreter/arguments.py
+++ b/stem/interpreter/arguments.py
@@ -5,103 +5,102 @@
Commandline argument parsing for our interpreter prompt.
"""
-import collections
import getopt
import os
import stem.interpreter
import stem.util.connection
-from typing import NamedTuple, Sequence
-
-DEFAULT_ARGS = {
- 'control_address': '127.0.0.1',
- 'control_port': 'default',
- 'user_provided_port': False,
- 'control_socket': '/var/run/tor/control',
- 'user_provided_socket': False,
- 'tor_path': 'tor',
- 'run_cmd': None,
- 'run_path': None,
- 'disable_color': False,
- 'print_help': False,
-}
+from typing import Any, Dict, NamedTuple, Optional, Sequence
OPT = 'i:s:h'
OPT_EXPANDED = ['interface=', 'socket=', 'tor=', 'run=', 'no-color', 'help']
-def parse(argv: Sequence[str]) -> NamedTuple:
- """
- Parses our arguments, providing a named tuple with their values.
-
- :param list argv: input arguments to be parsed
-
- :returns: a **named tuple** with our parsed arguments
-
- :raises: **ValueError** if we got an invalid argument
- """
-
- args = dict(DEFAULT_ARGS)
-
- try:
- recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED)
-
- if unrecognized_args:
- error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument"
- raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg))
- except Exception as exc:
- raise ValueError('%s (for usage provide --help)' % exc)
-
- for opt, arg in recognized_args:
- if opt in ('-i', '--interface'):
- if ':' in arg:
- address, port = arg.rsplit(':', 1)
- else:
- address, port = None, arg
-
- if address is not None:
- if not stem.util.connection.is_valid_ipv4_address(address):
- raise ValueError("'%s' isn't a valid IPv4 address" % address)
-
- args['control_address'] = address
-
- if not stem.util.connection.is_valid_port(port):
- raise ValueError("'%s' isn't a valid port number" % port)
-
- args['control_port'] = int(port)
- args['user_provided_port'] = True
- elif opt in ('-s', '--socket'):
- args['control_socket'] = arg
- args['user_provided_socket'] = True
- elif opt in ('--tor'):
- args['tor_path'] = arg
- elif opt in ('--run'):
- if os.path.exists(arg):
- args['run_path'] = arg
- else:
- args['run_cmd'] = arg
- elif opt == '--no-color':
- args['disable_color'] = True
- elif opt in ('-h', '--help'):
- args['print_help'] = True
-
- # translates our args dict into a named tuple
-
- Args = collections.namedtuple('Args', args.keys())
- return Args(**args)
-
-
-def get_help() -> str:
- """
- Provides our --help usage information.
-
- :returns: **str** with our usage information
- """
-
- return stem.interpreter.msg(
- 'msg.help',
- address = DEFAULT_ARGS['control_address'],
- port = DEFAULT_ARGS['control_port'],
- socket = DEFAULT_ARGS['control_socket'],
- )
+class Arguments(NamedTuple):
+ control_address: str = '127.0.0.1'
+ control_port: Optional[int] = None
+ user_provided_port: bool = False
+ control_socket: str = '/var/run/tor/control'
+ user_provided_socket: bool = False
+ tor_path: str = 'tor'
+ run_cmd: Optional[str] = None
+ run_path: Optional[str] = None
+ disable_color: bool = False
+ print_help: bool = False
+
+ @staticmethod
+ def parse(argv: Sequence[str]) -> 'stem.interpreter.arguments.Arguments':
+ """
+ Parses our commandline arguments into this class.
+
+ :param list argv: input arguments to be parsed
+
+ :returns: :class:`stem.interpreter.arguments.Arguments` for this
+ commandline input
+
+ :raises: **ValueError** if we got an invalid argument
+ """
+
+ args = {} # type: Dict[str, Any]
+
+ try:
+ recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) # type: ignore
+
+ if unrecognized_args:
+ error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument"
+ raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg))
+ except Exception as exc:
+ raise ValueError('%s (for usage provide --help)' % exc)
+
+ for opt, arg in recognized_args:
+ if opt in ('-i', '--interface'):
+ if ':' in arg:
+ address, port = arg.rsplit(':', 1)
+ else:
+ address, port = None, arg
+
+ if address is not None:
+ if not stem.util.connection.is_valid_ipv4_address(address):
+ raise ValueError("'%s' isn't a valid IPv4 address" % address)
+
+ args['control_address'] = address
+
+ if not stem.util.connection.is_valid_port(port):
+ raise ValueError("'%s' isn't a valid port number" % port)
+
+ args['control_port'] = int(port)
+ args['user_provided_port'] = True
+ elif opt in ('-s', '--socket'):
+ args['control_socket'] = arg
+ args['user_provided_socket'] = True
+ elif opt in ('--tor'):
+ args['tor_path'] = arg
+ elif opt in ('--run'):
+ if os.path.exists(arg):
+ args['run_path'] = arg
+ else:
+ args['run_cmd'] = arg
+ elif opt == '--no-color':
+ args['disable_color'] = True
+ elif opt in ('-h', '--help'):
+ args['print_help'] = True
+
+ return Arguments(**args)
+
+ @staticmethod
+ def get_help() -> str:
+ """
+ Provides our --help usage information.
+
+ :returns: **str** with our usage information
+ """
+
+ defaults = Arguments()
+
+ return stem.interpreter.msg(
+ 'msg.help',
+ address = defaults.control_address,
+ port = defaults.control_port if defaults.control_port else 'default',
+ socket = defaults.control_socket,
+ )
diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py
index 671085a7..e310ed28 100644
--- a/stem/interpreter/autocomplete.py
+++ b/stem/interpreter/autocomplete.py
@@ -7,12 +7,15 @@ Tab completion for our interpreter prompt.
import functools
+import stem.control
+import stem.util.conf
+
from stem.interpreter import uses_settings
-from typing import Optional, Sequence
+from typing import List, Optional
@uses_settings
-def _get_commands(controller: 'stem.control.Controller', config: 'stem.util.conf.Config') -> Sequence[str]:
+def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Config) -> List[str]:
"""
Provides commands recognized by tor.
"""
@@ -77,11 +80,11 @@ def _get_commands(controller: 'stem.control.Controller', config: 'stem.util.conf
class Autocompleter(object):
- def __init__(self, controller: 'stem.control.Controller') -> None:
+ def __init__(self, controller: stem.control.Controller) -> None:
self._commands = _get_commands(controller)
@functools.lru_cache()
- def matches(self, text: str) -> Sequence[str]:
+ def matches(self, text: str) -> List[str]:
"""
Provides autocompletion matches for the given text.
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 1d610dac..254e46a1 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -21,12 +21,12 @@ import stem.util.tor_tools
from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg
from stem.util.term import format
-from typing import BinaryIO, Iterator, Sequence, Tuple
+from typing import Iterator, List, TextIO
MAX_EVENTS = 100
-def _get_fingerprint(arg: str, controller: 'stem.control.Controller') -> str:
+def _get_fingerprint(arg: str, controller: stem.control.Controller) -> str:
"""
Resolves user input into a relay fingerprint. This accepts...
@@ -91,7 +91,7 @@ def _get_fingerprint(arg: str, controller: 'stem.control.Controller') -> str:
@contextlib.contextmanager
-def redirect(stdout: BinaryIO, stderr: BinaryIO) -> Iterator[None]:
+def redirect(stdout: TextIO, stderr: TextIO) -> Iterator[None]:
original = sys.stdout, sys.stderr
sys.stdout, sys.stderr = stdout, stderr
@@ -107,8 +107,8 @@ class ControlInterpreter(code.InteractiveConsole):
for special irc style subcommands.
"""
- def __init__(self, controller: 'stem.control.Controller') -> None:
- self._received_events = []
+ def __init__(self, controller: stem.control.Controller) -> None:
+ self._received_events = [] # type: List[stem.response.events.Event]
code.InteractiveConsole.__init__(self, {
'stem': stem,
@@ -130,18 +130,19 @@ class ControlInterpreter(code.InteractiveConsole):
handle_event_real = self._controller._handle_event
- def handle_event_wrapper(event_message: 'stem.response.events.Event') -> None:
+ def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
handle_event_real(event_message)
- self._received_events.insert(0, event_message)
+ self._received_events.insert(0, event_message) # type: ignore
if len(self._received_events) > MAX_EVENTS:
self._received_events.pop()
- self._controller._handle_event = handle_event_wrapper
+ # type check disabled due to https://github.com/python/mypy/issues/708
- def get_events(self, *event_types: 'stem.control.EventType') -> Sequence['stem.response.events.Event']:
+ self._controller._handle_event = handle_event_wrapper # type: ignore
+
+ def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
events = list(self._received_events)
- event_types = list(map(str.upper, event_types)) # make filtering case insensitive
if event_types:
events = [e for e in events if e.type in event_types]
@@ -296,7 +297,7 @@ class ControlInterpreter(code.InteractiveConsole):
return format(response, *STANDARD_OUTPUT)
@uses_settings
- def run_command(self, command: str, config: 'stem.util.conf.Config', print_response: bool = False) -> Sequence[Tuple[str, int]]:
+ def run_command(self, command: str, config: stem.util.conf.Config, print_response: bool = False) -> str:
"""
Runs the given command. Requests starting with a '/' are special commands
to the interpreter, and anything else is sent to the control port.
@@ -304,8 +305,7 @@ class ControlInterpreter(code.InteractiveConsole):
:param str command: command to be processed
:param bool print_response: prints the response to stdout if true
- :returns: **list** out output lines, each line being a list of
- (msg, format) tuples
+ :returns: **str** output of the command
:raises: **stem.SocketClosed** if the control connection has been severed
"""
@@ -363,7 +363,7 @@ class ControlInterpreter(code.InteractiveConsole):
output = console_output.getvalue().strip()
else:
try:
- output = format(self._controller.msg(command).raw_content().strip(), *STANDARD_OUTPUT)
+ output = format(str(self._controller.msg(command).raw_content()).strip(), *STANDARD_OUTPUT)
except stem.ControllerError as exc:
if isinstance(exc, stem.SocketClosed):
raise
diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py
index 81c76d34..3a206c35 100644
--- a/stem/interpreter/help.py
+++ b/stem/interpreter/help.py
@@ -7,6 +7,11 @@ Provides our /help responses.
import functools
+import stem.control
+import stem.util.conf
+
+from stem.util.term import format
+
from stem.interpreter import (
STANDARD_OUTPUT,
BOLD_OUTPUT,
@@ -15,10 +20,8 @@ from stem.interpreter import (
uses_settings,
)
-from stem.util.term import format
-
-def response(controller: 'stem.control.Controller', arg: str) -> str:
+def response(controller: stem.control.Controller, arg: str) -> str:
"""
Provides our /help response.
@@ -33,7 +36,7 @@ def response(controller: 'stem.control.Controller', arg: str) -> str:
return _response(controller, _normalize(arg))
-def _normalize(arg) -> str:
+def _normalize(arg: str) -> str:
arg = arg.upper()
# If there's multiple arguments then just take the first. This is
@@ -52,7 +55,7 @@ def _normalize(arg) -> str:
@functools.lru_cache()
@uses_settings
-def _response(controller: 'stem.control.Controller', arg: str, config: 'stem.util.conf.Config') -> str:
+def _response(controller: stem.control.Controller, arg: str, config: stem.util.conf.Config) -> str:
if not arg:
return _general_help()
diff --git a/stem/manual.py b/stem/manual.py
index e28e0e6f..9bc10b85 100644
--- a/stem/manual.py
+++ b/stem/manual.py
@@ -61,9 +61,10 @@ import stem.util
import stem.util.conf
import stem.util.enum
import stem.util.log
+import stem.util.str_tools
import stem.util.system
-from typing import Any, Dict, Mapping, Optional, Sequence, TextIO, Tuple, Union
+from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
Category = stem.util.enum.Enum('GENERAL', 'CLIENT', 'RELAY', 'DIRECTORY', 'AUTHORITY', 'HIDDEN_SERVICE', 'DENIAL_OF_SERVICE', 'TESTING', 'UNKNOWN')
GITWEB_MANUAL_URL = 'https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt'
@@ -111,7 +112,7 @@ class SchemaMismatch(IOError):
self.supported_schemas = supported_schemas
-def query(query: str, *param: str) -> 'sqlite3.Cursor':
+def query(query: str, *param: str) -> 'sqlite3.Cursor': # type: ignore
"""
Performs the given query on our sqlite manual cache. This database should
be treated as being read-only. File permissions generally enforce this, and
@@ -182,7 +183,7 @@ class ConfigOption(object):
@functools.lru_cache()
-def _config(lowercase: bool = True) -> Dict[str, Union[Sequence[str], str]]:
+def _config(lowercase: bool = True) -> Dict[str, Union[List[str], str]]:
"""
Provides a dictionary for our settings.cfg. This has a couple categories...
@@ -264,7 +265,7 @@ def is_important(option: str) -> bool:
return option.lower() in _config()['manual.important']
-def download_man_page(path: Optional[str] = None, file_handle: Optional[TextIO] = None, url: str = GITWEB_MANUAL_URL, timeout: int = 20) -> None:
+def download_man_page(path: Optional[str] = None, file_handle: Optional[IO[bytes]] = None, url: str = GITWEB_MANUAL_URL, timeout: int = 20) -> None:
"""
Downloads tor's latest man page from `gitweb.torproject.org
<https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt>`_. This method is
@@ -303,7 +304,7 @@ def download_man_page(path: Optional[str] = None, file_handle: Optional[TextIO]
if not os.path.exists(manual_path):
raise OSError('no man page was generated')
except stem.util.system.CallError as exc:
- raise IOError("Unable to run '%s': %s" % (exc.command, exc.stderr))
+ raise IOError("Unable to run '%s': %s" % (exc.command, stem.util.str_tools._to_unicode(exc.stderr)))
if path:
try:
@@ -349,7 +350,7 @@ class Manual(object):
:var str stem_commit: stem commit to cache this manual information
"""
- def __init__(self, name: str, synopsis: str, description: str, commandline_options: Mapping[str, str], signals: Mapping[str, str], files: Mapping[str, str], config_options: Mapping[str, str]) -> None:
+ def __init__(self, name: str, synopsis: str, description: str, commandline_options: Mapping[str, str], signals: Mapping[str, str], files: Mapping[str, str], config_options: Mapping[str, 'stem.manual.ConfigOption']) -> None:
self.name = name
self.synopsis = synopsis
self.description = description
@@ -449,7 +450,8 @@ class Manual(object):
except OSError as exc:
raise IOError("Unable to run '%s': %s" % (man_cmd, exc))
- categories, config_options = _get_categories(man_output), collections.OrderedDict()
+ categories = _get_categories(man_output)
+ config_options = collections.OrderedDict() # type: collections.OrderedDict[str, stem.manual.ConfigOption]
for category_header, category_enum in CATEGORY_SECTIONS.items():
_add_config_options(config_options, category_enum, categories.get(category_header, []))
@@ -561,7 +563,7 @@ class Manual(object):
return not self == other
-def _get_categories(content: str) -> Dict[str, str]:
+def _get_categories(content: Sequence[str]) -> Dict[str, List[str]]:
"""
The man page is headers followed by an indented section. First pass gets
the mapping of category titles to their lines.
@@ -576,7 +578,8 @@ def _get_categories(content: str) -> Dict[str, str]:
content = content[:-1]
categories = collections.OrderedDict()
- category, lines = None, []
+ category = None
+ lines = [] # type: List[str]
for line in content:
# replace non-ascii characters
@@ -607,7 +610,7 @@ def _get_categories(content: str) -> Dict[str, str]:
return categories
-def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, Sequence[str]]:
+def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, str]:
"""
Parses the commandline argument and signal sections. These are options
followed by an indented description. For example...
@@ -624,7 +627,8 @@ def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, Sequence[str]]
ignoring those.
"""
- options, last_arg = collections.OrderedDict(), None
+ options = collections.OrderedDict() # type: collections.OrderedDict[str, List[str]]
+ last_arg = None
for line in lines:
if line == ' Note':
@@ -637,7 +641,7 @@ def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, Sequence[str]]
return dict([(arg, ' '.join(desc_lines)) for arg, desc_lines in options.items() if desc_lines])
-def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption'], category: str, lines: Sequence[str]) -> None:
+def _add_config_options(config_options: Dict[str, 'stem.manual.ConfigOption'], category: str, lines: Sequence[str]) -> None:
"""
Parses a section of tor configuration options. These have usage information,
followed by an indented description. For instance...
@@ -655,7 +659,7 @@ def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption']
since that platform lacks getrlimit(). (Default: 1000)
"""
- def add_option(title: str, description: str) -> None:
+ def add_option(title: str, description: List[str]) -> None:
if 'PER INSTANCE OPTIONS' in title:
return # skip, unfortunately amid the options
@@ -669,7 +673,7 @@ def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption']
add_option(subtitle, description)
else:
name, usage = title.split(' ', 1) if ' ' in title else (title, '')
- summary = _config().get('manual.summary.%s' % name.lower(), '')
+ summary = str(_config().get('manual.summary.%s' % name.lower(), ''))
config_options[name] = ConfigOption(name, category, usage, summary, _join_lines(description).strip())
# Remove the section's description by finding the sentence the section
@@ -681,7 +685,8 @@ def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption']
lines = lines[max(end_indices):] # trim to the description paragrah
lines = lines[lines.index(''):] # drop the paragraph
- last_title, description = None, []
+ last_title = None
+ description = [] # type: List[str]
for line in lines:
if line and not line.startswith(' '):
@@ -704,7 +709,7 @@ def _join_lines(lines: Sequence[str]) -> str:
Simple join, except we want empty lines to still provide a newline.
"""
- result = []
+ result = [] # type: List[str]
for line in lines:
if not line:
diff --git a/stem/process.py b/stem/process.py
index bfab4967..3c7688a5 100644
--- a/stem/process.py
+++ b/stem/process.py
@@ -29,7 +29,7 @@ import stem.util.str_tools
import stem.util.system
import stem.version
-from typing import Any, Callable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Union
NO_TORRC = '<no torrc>'
DEFAULT_INIT_TIMEOUT = 90
@@ -199,7 +199,7 @@ def launch_tor(tor_cmd: str = 'tor', args: Optional[Sequence[str]] = None, torrc
pass
-def launch_tor_with_config(config: Mapping[str, Union[str, Sequence[str]]], tor_cmd: str = 'tor', completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True) -> subprocess.Popen:
+def launch_tor_with_config(config: Dict[str, Union[str, Sequence[str]]], tor_cmd: str = 'tor', completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True) -> subprocess.Popen:
"""
Initializes a tor process, like :func:`~stem.process.launch_tor`, but with a
customized configuration. This writes a temporary torrc to disk, launches
@@ -260,7 +260,7 @@ def launch_tor_with_config(config: Mapping[str, Union[str, Sequence[str]]], tor_
break
if not has_stdout:
- config['Log'].append('NOTICE stdout')
+ config['Log'] = list(config['Log']) + ['NOTICE stdout']
config_str = ''
diff --git a/stem/response/__init__.py b/stem/response/__init__.py
index 4b1f9533..2f851389 100644
--- a/stem/response/__init__.py
+++ b/stem/response/__init__.py
@@ -38,7 +38,7 @@ import stem.socket
import stem.util
import stem.util.str_tools
-from typing import Any, Iterator, Optional, Sequence, Tuple, Union
+from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
__all__ = [
'add_onion',
@@ -123,7 +123,40 @@ def convert(response_type: str, message: 'stem.response.ControlMessage', **kwarg
raise TypeError('Unsupported response type: %s' % response_type)
message.__class__ = response_class
- message._parse_message(**kwargs)
+ message._parse_message(**kwargs) # type: ignore
+
+
+# TODO: These aliases are for type hint compatibility. We should refactor how
+# message conversion is performed to avoid this headache.
+
+def _convert_to_single_line(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.SingleLineResponse':
+ stem.response.convert('SINGLELINE', message)
+ return message # type: ignore
+
+
+def _convert_to_event(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.events.Event':
+ stem.response.convert('EVENT', message)
+ return message # type: ignore
+
+
+def _convert_to_getinfo(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.getinfo.GetInfoResponse':
+ stem.response.convert('GETINFO', message)
+ return message # type: ignore
+
+
+def _convert_to_getconf(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.getconf.GetConfResponse':
+ stem.response.convert('GETCONF', message)
+ return message # type: ignore
+
+
+def _convert_to_add_onion(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.add_onion.AddOnionResponse':
+ stem.response.convert('ADD_ONION', message)
+ return message # type: ignore
+
+
+def _convert_to_mapaddress(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.mapaddress.MapAddressResponse':
+ stem.response.convert('MAPADDRESS', message)
+ return message # type: ignore
class ControlMessage(object):
@@ -142,7 +175,7 @@ class ControlMessage(object):
"""
@staticmethod
- def from_str(content: str, msg_type: Optional[str] = None, normalize: bool = False, **kwargs: Any) -> 'stem.response.ControlMessage':
+ def from_str(content: Union[str, bytes], msg_type: Optional[str] = None, normalize: bool = False, **kwargs: Any) -> 'stem.response.ControlMessage':
"""
Provides a ControlMessage for the given content.
@@ -160,28 +193,35 @@ class ControlMessage(object):
:returns: stem.response.ControlMessage instance
"""
+ if isinstance(content, str):
+ content = stem.util.str_tools._to_bytes(content)
+
if normalize:
- if not content.endswith('\n'):
- content += '\n'
+ if not content.endswith(b'\n'):
+ content += b'\n'
- content = re.sub('([\r]?)\n', '\r\n', content)
+ content = re.sub(b'([\r]?)\n', b'\r\n', content)
- msg = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes(content)), arrived_at = kwargs.pop('arrived_at', None))
+ msg = stem.socket.recv_message(io.BytesIO(content), arrived_at = kwargs.pop('arrived_at', None))
if msg_type is not None:
convert(msg_type, msg, **kwargs)
return msg
- def __init__(self, parsed_content: Sequence[Tuple[str, str, bytes]], raw_content: bytes, arrived_at: Optional[int] = None) -> None:
+ def __init__(self, parsed_content: Sequence[Tuple[str, str, bytes]], raw_content: bytes, arrived_at: Optional[float] = None) -> None:
if not parsed_content:
raise ValueError("ControlMessages can't be empty")
- self.arrived_at = arrived_at if arrived_at else int(time.time())
+ # TODO: Change arrived_at to a float (can't yet because it causes Event
+ # equality checks to fail - events include arrived_at within their hash
+ # whereas ControlMessages don't).
+
+ self.arrived_at = int(arrived_at if arrived_at else time.time())
self._parsed_content = parsed_content
self._raw_content = raw_content
- self._str = None
+ self._str = None # type: Optional[str]
self._hash = stem.util._hash_attr(self, '_raw_content')
def is_ok(self) -> bool:
@@ -197,7 +237,12 @@ class ControlMessage(object):
return False
- def content(self, get_bytes: bool = False) -> Sequence[Tuple[str, str, bytes]]:
+ # TODO: drop this alias when we provide better type support
+
+ def _content_bytes(self) -> List[Tuple[str, str, bytes]]:
+ return self.content(get_bytes = True) # type: ignore
+
+ def content(self, get_bytes: bool = False) -> List[Tuple[str, str, str]]:
"""
Provides the parsed message content. These are entries of the form...
@@ -234,9 +279,9 @@ class ControlMessage(object):
if not get_bytes:
return [(code, div, stem.util.str_tools._to_unicode(content)) for (code, div, content) in self._parsed_content]
else:
- return list(self._parsed_content)
+ return list(self._parsed_content) # type: ignore
- def raw_content(self, get_bytes: bytes = False) -> Union[str, bytes]:
+ def raw_content(self, get_bytes: bool = False) -> Union[str, bytes]:
"""
Provides the unparsed content read from the control socket.
@@ -253,6 +298,9 @@ class ControlMessage(object):
else:
return self._raw_content
+ def _parse_message(self) -> None:
+ raise NotImplementedError('Implemented by subclasses')
+
def __str__(self) -> str:
"""
Content of the message, stripped of status code and divider protocol
@@ -288,9 +336,7 @@ class ControlMessage(object):
"""
for _, _, content in self._parsed_content:
- content = stem.util.str_tools._to_unicode(content)
-
- yield ControlLine(content)
+ yield ControlLine(stem.util.str_tools._to_unicode(content))
def __len__(self) -> int:
"""
@@ -330,7 +376,7 @@ class ControlLine(str):
"""
def __new__(self, value: str) -> 'stem.response.ControlLine':
- return str.__new__(self, value)
+ return str.__new__(self, value) # type: ignore
def __init__(self, value: str) -> None:
self._remainder = value
@@ -443,7 +489,12 @@ class ControlLine(str):
with self._remainder_lock:
next_entry, remainder = _parse_entry(self._remainder, quoted, escaped, False)
self._remainder = remainder
- return next_entry
+ return next_entry # type: ignore
+
+ # TODO: drop this alias when we provide better type support
+
+ def _pop_mapping_bytes(self, quoted: bool = False, escaped: bool = False) -> Tuple[str, bytes]:
+ return self.pop_mapping(quoted, escaped, get_bytes = True) # type: ignore
def pop_mapping(self, quoted: bool = False, escaped: bool = False, get_bytes: bool = False) -> Tuple[str, str]:
"""
@@ -479,7 +530,7 @@ class ControlLine(str):
next_entry, remainder = _parse_entry(remainder, quoted, escaped, get_bytes)
self._remainder = remainder
- return (key, next_entry)
+ return (key, next_entry) # type: ignore
def _parse_entry(line: str, quoted: bool, escaped: bool, get_bytes: bool) -> Tuple[Union[str, bytes], str]:
@@ -532,15 +583,15 @@ def _parse_entry(line: str, quoted: bool, escaped: bool, get_bytes: bool) -> Tup
#
# https://stackoverflow.com/questions/14820429/how-do-i-decodestring-escape-in-python3
- next_entry = codecs.escape_decode(next_entry)[0]
+ next_entry = codecs.escape_decode(next_entry)[0] # type: ignore
if not get_bytes:
next_entry = stem.util.str_tools._to_unicode(next_entry) # normalize back to str
if get_bytes:
- next_entry = stem.util.str_tools._to_bytes(next_entry)
-
- return (next_entry, remainder.lstrip())
+ return (stem.util.str_tools._to_bytes(next_entry), remainder.lstrip())
+ else:
+ return (next_entry, remainder.lstrip())
def _get_quote_indices(line: str, escaped: bool) -> Tuple[int, int]:
@@ -566,7 +617,7 @@ def _get_quote_indices(line: str, escaped: bool) -> Tuple[int, int]:
indices.append(quote_index)
- return tuple(indices)
+ return tuple(indices) # type: ignore
class SingleLineResponse(ControlMessage):
@@ -604,4 +655,7 @@ class SingleLineResponse(ControlMessage):
elif len(content) == 0:
raise stem.ProtocolError('Received empty response')
else:
- self.code, _, self.message = content[0]
+ code, _, msg = content[0]
+
+ self.code = stem.util.str_tools._to_unicode(code)
+ self.message = stem.util.str_tools._to_unicode(msg)
diff --git a/stem/response/events.py b/stem/response/events.py
index 0e112373..65419fe6 100644
--- a/stem/response/events.py
+++ b/stem/response/events.py
@@ -1,6 +1,9 @@
# Copyright 2012-2020, Damian Johnson and The Tor Project
# See LICENSE for licensing information
+#
+# mypy: ignore-errors
+import datetime
import io
import re
@@ -12,7 +15,7 @@ import stem.util
import stem.version
from stem.util import connection, log, str_tools, tor_tools
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
# Matches keyword=value arguments. This can't be a simple "(.*)=(.*)" pattern
# because some positional arguments, like circuit paths, can have an equal
@@ -34,10 +37,13 @@ class Event(stem.response.ControlMessage):
:var dict keyword_args: key/value arguments of the event
"""
- _POSITIONAL_ARGS = () # attribute names for recognized positional arguments
- _KEYWORD_ARGS = {} # map of 'keyword => attribute' for recognized attributes
- _QUOTED = () # positional arguments that are quoted
- _OPTIONALLY_QUOTED = () # positional arguments that may or may not be quoted
+ # TODO: Replace metaprogramming with concrete implementations (to simplify type information)
+  # TODO: _QUOTED appears to be unused
+
+ _POSITIONAL_ARGS = () # type: Tuple[str, ...] # attribute names for recognized positional arguments
+ _KEYWORD_ARGS = {} # type: Dict[str, str] # map of 'keyword => attribute' for recognized attributes
+ _QUOTED = () # type: Tuple[str, ...] # positional arguments that are quoted
+ _OPTIONALLY_QUOTED = () # type: Tuple[str, ...] # positional arguments that may or may not be quoted
_SKIP_PARSING = False # skip parsing contents into our positional_args and keyword_args
_VERSION_ADDED = stem.version.Version('0.1.1.1-alpha') # minimum version with control-spec V1 event support
@@ -46,13 +52,14 @@ class Event(stem.response.ControlMessage):
raise stem.ProtocolError('Received a blank tor event. Events must at the very least have a type.')
self.type = str(self).split()[0]
- self.positional_args = []
- self.keyword_args = {}
+ self.positional_args = [] # type: List[str]
+ self.keyword_args = {} # type: Dict[str, str]
# if we're a recognized event type then translate ourselves into that subclass
if self.type in EVENT_TYPE_TO_CLASS:
self.__class__ = EVENT_TYPE_TO_CLASS[self.type]
+ self.__init__() # type: ignore
if not self._SKIP_PARSING:
self._parse_standard_attr()
@@ -123,7 +130,7 @@ class Event(stem.response.ControlMessage):
for controller_attr_name, attr_name in self._KEYWORD_ARGS.items():
setattr(self, attr_name, self.keyword_args.get(controller_attr_name))
- def _iso_timestamp(self, timestamp: str) -> 'datetime.datetime':
+ def _iso_timestamp(self, timestamp: str) -> datetime.datetime:
"""
Parses an iso timestamp (ISOTime2Frac in the control-spec).
@@ -146,7 +153,7 @@ class Event(stem.response.ControlMessage):
def _parse(self) -> None:
pass
- def _log_if_unrecognized(self, attr: str, attr_enum: 'stem.util.enum.Enum') -> None:
+ def _log_if_unrecognized(self, attr: str, attr_enum: Union[stem.util.enum.Enum, Sequence[stem.util.enum.Enum]]) -> None:
"""
Checks if an attribute exists in a given enumeration, logging a message if
it isn't. Attributes can either be for a string or collection of strings
@@ -195,7 +202,15 @@ class AddrMapEvent(Event):
'EXPIRES': 'utc_expiry',
'CACHED': 'cached',
}
- _OPTIONALLY_QUOTED = ('expiry')
+ _OPTIONALLY_QUOTED = ('expiry',)
+
+ def __init__(self):
+ self.hostname = None # type: Optional[str]
+ self.destination = None # type: Optional[str]
+ self.expiry = None # type: Optional[datetime.datetime]
+ self.error = None # type: Optional[str]
+ self.utc_expiry = None # type: Optional[datetime.datetime]
+ self.cached = None # type: Optional[bool]
def _parse(self) -> None:
if self.destination == '<error>':
@@ -235,6 +250,10 @@ class BandwidthEvent(Event):
_POSITIONAL_ARGS = ('read', 'written')
+ def __init__(self):
+ self.read = None # type: Optional[int]
+ self.written = None # type: Optional[int]
+
def _parse(self) -> None:
if not self.read:
raise stem.ProtocolError('BW event is missing its read value')
@@ -278,6 +297,17 @@ class BuildTimeoutSetEvent(Event):
}
_VERSION_ADDED = stem.version.Version('0.2.2.7-alpha')
+ def __init__(self):
+ self.set_type = None # type: Optional[stem.TimeoutSetType]
+ self.total_times = None # type: Optional[int]
+ self.timeout = None # type: Optional[int]
+ self.xm = None # type: Optional[int]
+ self.alpha = None # type: Optional[float]
+ self.quantile = None # type: Optional[float]
+ self.timeout_rate = None # type: Optional[float]
+ self.close_timeout = None # type: Optional[int]
+ self.close_rate = None # type: Optional[float]
+
def _parse(self) -> None:
# convert our integer and float parameters
@@ -347,6 +377,20 @@ class CircuitEvent(Event):
'SOCKS_PASSWORD': 'socks_password',
}
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.status = None # type: Optional[stem.CircStatus]
+ self.path = None # type: Optional[Tuple[Tuple[str, str], ...]]
+ self.build_flags = None # type: Optional[Tuple[stem.CircBuildFlag, ...]]
+ self.purpose = None # type: Optional[stem.CircPurpose]
+ self.hs_state = None # type: Optional[stem.HiddenServiceState]
+ self.rend_query = None # type: Optional[str]
+ self.created = None # type: Optional[datetime.datetime]
+ self.reason = None # type: Optional[stem.CircClosureReason]
+ self.remote_reason = None # type: Optional[stem.CircClosureReason]
+ self.socks_username = None # type: Optional[str]
+ self.socks_password = None # type: Optional[str]
+
def _parse(self) -> None:
self.path = tuple(stem.control._parse_circ_path(self.path))
self.created = self._iso_timestamp(self.created)
@@ -415,6 +459,18 @@ class CircMinorEvent(Event):
}
_VERSION_ADDED = stem.version.Version('0.2.3.11-alpha')
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.event = None # type: Optional[stem.CircEvent]
+ self.path = None # type: Optional[Tuple[Tuple[str, str], ...]]
+ self.build_flags = None # type: Optional[Tuple[stem.CircBuildFlag, ...]]
+ self.purpose = None # type: Optional[stem.CircPurpose]
+ self.hs_state = None # type: Optional[stem.HiddenServiceState]
+ self.rend_query = None # type: Optional[str]
+ self.created = None # type: Optional[datetime.datetime]
+ self.old_purpose = None # type: Optional[stem.CircPurpose]
+ self.old_hs_state = None # type: Optional[stem.HiddenServiceState]
+
def _parse(self) -> None:
self.path = tuple(stem.control._parse_circ_path(self.path))
self.created = self._iso_timestamp(self.created)
@@ -451,6 +507,11 @@ class ClientsSeenEvent(Event):
}
_VERSION_ADDED = stem.version.Version('0.2.1.10-alpha')
+ def __init__(self):
+ self.start_time = None # type: Optional[datetime.datetime]
+ self.locales = None # type: Optional[Dict[str, int]]
+ self.ip_versions = None # type: Optional[Dict[str, int]]
+
def _parse(self) -> None:
if self.start_time is not None:
self.start_time = stem.util.str_tools._parse_timestamp(self.start_time)
@@ -510,6 +571,10 @@ class ConfChangedEvent(Event):
_SKIP_PARSING = True
_VERSION_ADDED = stem.version.Version('0.2.3.3-alpha')
+ def __init__(self):
+ self.changed = {} # type: Dict[str, List[str]]
+ self.unset = [] # type: List[str]
+
def _parse(self) -> None:
self.changed = {}
self.unset = []
@@ -541,6 +606,9 @@ class DescChangedEvent(Event):
_VERSION_ADDED = stem.version.Version('0.1.2.2-alpha')
+ def __init__(self):
+ pass
+
class GuardEvent(Event):
"""
@@ -564,10 +632,14 @@ class GuardEvent(Event):
_VERSION_ADDED = stem.version.Version('0.1.2.5-alpha')
_POSITIONAL_ARGS = ('guard_type', 'endpoint', 'status')
- def _parse(self) -> None:
- self.endpoint_fingerprint = None
- self.endpoint_nickname = None
+ def __init__(self):
+ self.guard_type = None # type: Optional[stem.GuardType]
+ self.endpoint = None # type: Optional[str]
+ self.endpoint_fingerprint = None # type: Optional[str]
+ self.endpoint_nickname = None # type: Optional[str]
+ self.status = None # type: Optional[stem.GuardStatus]
+ def _parse(self) -> None:
try:
self.endpoint_fingerprint, self.endpoint_nickname = \
stem.control._parse_circ_entry(self.endpoint)
@@ -611,10 +683,19 @@ class HSDescEvent(Event):
_POSITIONAL_ARGS = ('action', 'address', 'authentication', 'directory', 'descriptor_id')
_KEYWORD_ARGS = {'REASON': 'reason', 'REPLICA': 'replica', 'HSDIR_INDEX': 'index'}
- def _parse(self) -> None:
- self.directory_fingerprint = None
- self.directory_nickname = None
+ def __init__(self):
+ self.action = None # type: Optional[stem.HSDescAction]
+ self.address = None # type: Optional[str]
+ self.authentication = None # type: Optional[stem.HSAuth]
+ self.directory = None # type: Optional[str]
+ self.directory_fingerprint = None # type: Optional[str]
+ self.directory_nickname = None # type: Optional[str]
+ self.descriptor_id = None # type: Optional[str]
+ self.reason = None # type: Optional[stem.HSDescReason]
+ self.replica = None # type: Optional[int]
+ self.index = None # type: Optional[str]
+ def _parse(self) -> None:
if self.directory != 'UNKNOWN':
try:
self.directory_fingerprint, self.directory_nickname = \
@@ -651,13 +732,18 @@ class HSDescContentEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.7.1-alpha')
_POSITIONAL_ARGS = ('address', 'descriptor_id', 'directory')
+ def __init__(self):
+ self.address = None # type: Optional[str]
+ self.descriptor_id = None # type: Optional[str]
+ self.directory = None # type: Optional[str]
+ self.directory_fingerprint = None # type: Optional[str]
+ self.directory_nickname = None # type: Optional[str]
+ self.descriptor = None # type: Optional[stem.descriptor.hidden_service.HiddenServiceDescriptorV2]
+
def _parse(self) -> None:
if self.address == 'UNKNOWN':
self.address = None
- self.directory_fingerprint = None
- self.directory_nickname = None
-
try:
self.directory_fingerprint, self.directory_nickname = \
stem.control._parse_circ_entry(self.directory)
@@ -687,6 +773,10 @@ class LogEvent(Event):
_SKIP_PARSING = True
+ def __init__(self):
+ self.runlevel = None # type: Optional[stem.Runlevel]
+ self.message = None # type: Optional[str]
+
def _parse(self) -> None:
self.runlevel = self.type
self._log_if_unrecognized('runlevel', stem.Runlevel)
@@ -710,6 +800,9 @@ class NetworkStatusEvent(Event):
_SKIP_PARSING = True
_VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
+ def __init__(self):
+ self.descriptors = None # type: Optional[List[stem.descriptor.router_status_entry.RouterStatusEntryV3]]
+
def _parse(self) -> None:
content = str(self).lstrip('NS\n').rstrip('\nOK')
@@ -735,6 +828,9 @@ class NetworkLivenessEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.7.2-alpha')
_POSITIONAL_ARGS = ('status',)
+ def __init__(self):
+ self.status = None # type: Optional[str]
+
class NewConsensusEvent(Event):
"""
@@ -754,11 +850,14 @@ class NewConsensusEvent(Event):
_SKIP_PARSING = True
_VERSION_ADDED = stem.version.Version('0.2.1.13-alpha')
+ def __init__(self):
+ self.consensus_content = None # type: Optional[str]
+ self._parsed = None # type: List[stem.descriptor.router_status_entry.RouterStatusEntryV3]
+
def _parse(self) -> None:
self.consensus_content = str(self).lstrip('NEWCONSENSUS\n').rstrip('\nOK')
- self._parsed = None
- def entries(self) -> Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']:
+ def entries(self) -> List[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
"""
Relay router status entries residing within this consensus.
@@ -774,7 +873,7 @@ class NewConsensusEvent(Event):
entry_class = stem.descriptor.router_status_entry.RouterStatusEntryV3,
))
- return self._parsed
+ return list(self._parsed)
class NewDescEvent(Event):
@@ -792,6 +891,9 @@ class NewDescEvent(Event):
new descriptors
"""
+ def __init__(self):
+ self.relays = () # type: Tuple[Tuple[str, str], ...]
+
def _parse(self) -> None:
self.relays = tuple([stem.control._parse_circ_entry(entry) for entry in str(self).split()[1:]])
@@ -833,12 +935,18 @@ class ORConnEvent(Event):
'ID': 'id',
}
- def _parse(self) -> None:
- self.endpoint_fingerprint = None
- self.endpoint_nickname = None
- self.endpoint_address = None
- self.endpoint_port = None
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.endpoint = None # type: Optional[str]
+ self.endpoint_fingerprint = None # type: Optional[str]
+ self.endpoint_nickname = None # type: Optional[str]
+ self.endpoint_address = None # type: Optional[str]
+ self.endpoint_port = None # type: Optional[int]
+ self.status = None # type: Optional[stem.ORStatus]
+ self.reason = None # type: Optional[stem.ORClosureReason]
+ self.circ_count = None # type: Optional[int]
+ def _parse(self) -> None:
try:
self.endpoint_fingerprint, self.endpoint_nickname = \
stem.control._parse_circ_entry(self.endpoint)
@@ -887,6 +995,9 @@ class SignalEvent(Event):
_POSITIONAL_ARGS = ('signal',)
_VERSION_ADDED = stem.version.Version('0.2.3.1-alpha')
+ def __init__(self):
+ self.signal = None # type: Optional[stem.Signal]
+
def _parse(self) -> None:
# log if we recieved an unrecognized signal
expected_signals = (
@@ -919,6 +1030,12 @@ class StatusEvent(Event):
_POSITIONAL_ARGS = ('runlevel', 'action')
_VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
+ def __init__(self):
+ self.status_type = None # type: Optional[stem.StatusType]
+ self.runlevel = None # type: Optional[stem.Runlevel]
+ self.action = None # type: Optional[str]
+ self.arguments = None # type: Optional[Dict[str, str]]
+
def _parse(self) -> None:
if self.type == 'STATUS_GENERAL':
self.status_type = stem.StatusType.GENERAL
@@ -971,6 +1088,21 @@ class StreamEvent(Event):
'PURPOSE': 'purpose',
}
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.status = None # type: Optional[stem.StreamStatus]
+ self.circ_id = None # type: Optional[str]
+ self.target = None # type: Optional[str]
+ self.target_address = None # type: Optional[str]
+ self.target_port = None # type: Optional[int]
+ self.reason = None # type: Optional[stem.StreamClosureReason]
+ self.remote_reason = None # type: Optional[stem.StreamClosureReason]
+ self.source = None # type: Optional[stem.StreamSource]
+ self.source_addr = None # type: Optional[str]
+ self.source_address = None # type: Optional[str]
+ self.source_port = None # type: Optional[str]
+ self.purpose = None # type: Optional[stem.StreamPurpose]
+
def _parse(self) -> None:
if self.target is None:
raise stem.ProtocolError("STREAM event didn't have a target: %s" % self)
@@ -1030,6 +1162,12 @@ class StreamBwEvent(Event):
_POSITIONAL_ARGS = ('id', 'written', 'read', 'time')
_VERSION_ADDED = stem.version.Version('0.1.2.8-beta')
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.written = None # type: Optional[int]
+ self.read = None # type: Optional[int]
+ self.time = None # type: Optional[datetime.datetime]
+
def _parse(self) -> None:
if not tor_tools.is_valid_stream_id(self.id):
raise stem.ProtocolError("Stream IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self))
@@ -1063,6 +1201,12 @@ class TransportLaunchedEvent(Event):
_POSITIONAL_ARGS = ('type', 'name', 'address', 'port')
_VERSION_ADDED = stem.version.Version('0.2.5.0-alpha')
+ def __init__(self):
+ self.type = None # type: Optional[str]
+ self.name = None # type: Optional[str]
+ self.address = None # type: Optional[str]
+ self.port = None # type: Optional[int]
+
def _parse(self) -> None:
if self.type not in ('server', 'client'):
raise stem.ProtocolError("Transport type should either be 'server' or 'client': %s" % self)
@@ -1105,6 +1249,12 @@ class ConnectionBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.conn_type = None # type: Optional[stem.ConnectionType]
+ self.read = None # type: Optional[int]
+ self.written = None # type: Optional[int]
+
def _parse(self) -> None:
if not self.id:
raise stem.ProtocolError('CONN_BW event is missing its id')
@@ -1164,6 +1314,16 @@ class CircuitBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.read = None # type: Optional[int]
+ self.written = None # type: Optional[int]
+ self.delivered_read = None # type: Optional[int]
+ self.delivered_written = None # type: Optional[int]
+ self.overhead_read = None # type: Optional[int]
+ self.overhead_written = None # type: Optional[int]
+ self.time = None # type: Optional[datetime.datetime]
+
def _parse(self) -> None:
if not self.id:
raise stem.ProtocolError('CIRC_BW event is missing its id')
@@ -1234,6 +1394,19 @@ class CellStatsEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self):
+ self.id = None # type: Optional[str]
+ self.inbound_queue = None # type: Optional[str]
+ self.inbound_connection = None # type: Optional[str]
+ self.inbound_added = None # type: Optional[Dict[str, int]]
+ self.inbound_removed = None # type: Optional[Dict[str, int]]
+ self.inbound_time = None # type: Optional[Dict[str, int]]
+ self.outbound_queue = None # type: Optional[str]
+ self.outbound_connection = None # type: Optional[str]
+ self.outbound_added = None # type: Optional[Dict[str, int]]
+ self.outbound_removed = None # type: Optional[Dict[str, int]]
+ self.outbound_time = None # type: Optional[Dict[str, int]]
+
def _parse(self) -> None:
if self.id and not tor_tools.is_valid_circuit_id(self.id):
raise stem.ProtocolError("Circuit IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self))
@@ -1280,6 +1453,13 @@ class TokenBucketEmptyEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self):
+ self.bucket = None # type: Optional[stem.TokenBucket]
+ self.id = None # type: Optional[str]
+ self.read = None # type: Optional[int]
+ self.written = None # type: Optional[int]
+ self.last_refill = None # type: Optional[int]
+
def _parse(self) -> None:
if self.id and not tor_tools.is_valid_connection_id(self.id):
raise stem.ProtocolError("Connection IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self))
diff --git a/stem/response/getconf.py b/stem/response/getconf.py
index 7ba972ae..6c65c4ec 100644
--- a/stem/response/getconf.py
+++ b/stem/response/getconf.py
@@ -4,6 +4,8 @@
import stem.response
import stem.socket
+from typing import Dict, List
+
class GetConfResponse(stem.response.ControlMessage):
"""
@@ -23,7 +25,7 @@ class GetConfResponse(stem.response.ControlMessage):
# 250-DataDirectory=/home/neena/.tor
# 250 DirPort
- self.entries = {}
+ self.entries = {} # type: Dict[str, List[str]]
remaining_lines = list(self)
if self.content() == [('250', ' ', 'OK')]:
diff --git a/stem/response/getinfo.py b/stem/response/getinfo.py
index 7aebd70a..9d9da21b 100644
--- a/stem/response/getinfo.py
+++ b/stem/response/getinfo.py
@@ -4,7 +4,7 @@
import stem.response
import stem.socket
-from typing import Sequence
+from typing import Dict, Set
class GetInfoResponse(stem.response.ControlMessage):
@@ -27,8 +27,8 @@ class GetInfoResponse(stem.response.ControlMessage):
# .
# 250 OK
- self.entries = {}
- remaining_lines = [content for (code, div, content) in self.content(get_bytes = True)]
+ self.entries = {} # type: Dict[str, bytes]
+ remaining_lines = [content for (code, div, content) in self._content_bytes()]
if not self.is_ok() or not remaining_lines.pop() == b'OK':
unrecognized_keywords = []
@@ -51,11 +51,11 @@ class GetInfoResponse(stem.response.ControlMessage):
while remaining_lines:
try:
- key, value = remaining_lines.pop(0).split(b'=', 1)
+ key_bytes, value = remaining_lines.pop(0).split(b'=', 1)
except ValueError:
raise stem.ProtocolError('GETINFO replies should only contain parameter=value mappings:\n%s' % self)
- key = stem.util.str_tools._to_unicode(key)
+ key = stem.util.str_tools._to_unicode(key_bytes)
# if the value is a multiline value then it *must* be of the form
# '<key>=\n<value>'
@@ -68,7 +68,7 @@ class GetInfoResponse(stem.response.ControlMessage):
self.entries[key] = value
- def _assert_matches(self, params: Sequence[str]) -> None:
+ def _assert_matches(self, params: Set[str]) -> None:
"""
Checks if we match a given set of parameters, and raise a ProtocolError if not.
diff --git a/stem/response/protocolinfo.py b/stem/response/protocolinfo.py
index 330b165e..c1387fab 100644
--- a/stem/response/protocolinfo.py
+++ b/stem/response/protocolinfo.py
@@ -9,6 +9,7 @@ import stem.version
import stem.util.str_tools
from stem.util import log
+from typing import Tuple
class ProtocolInfoResponse(stem.response.ControlMessage):
@@ -36,8 +37,8 @@ class ProtocolInfoResponse(stem.response.ControlMessage):
self.protocol_version = None
self.tor_version = None
- self.auth_methods = ()
- self.unknown_auth_methods = ()
+ self.auth_methods = () # type: Tuple[stem.connection.AuthMethod, ...]
+ self.unknown_auth_methods = () # type: Tuple[str, ...]
self.cookie_path = None
auth_methods, unknown_auth_methods = [], []
@@ -107,7 +108,7 @@ class ProtocolInfoResponse(stem.response.ControlMessage):
# parse optional COOKIEFILE mapping (quoted and can have escapes)
if line.is_next_mapping('COOKIEFILE', True, True):
- self.cookie_path = line.pop_mapping(True, True, get_bytes = True)[1].decode(sys.getfilesystemencoding())
+ self.cookie_path = line._pop_mapping_bytes(True, True)[1].decode(sys.getfilesystemencoding())
self.cookie_path = stem.util.str_tools._to_unicode(self.cookie_path) # normalize back to str
elif line_type == 'VERSION':
# Line format:
diff --git a/stem/socket.py b/stem/socket.py
index 179ae16e..b5da4b78 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -80,7 +80,7 @@ import stem.util.str_tools
from stem.util import log
from types import TracebackType
-from typing import BinaryIO, Callable, Optional, Type
+from typing import BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload
MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]')
ERROR_MSG = 'Error while receiving a control message (%s): %s'
@@ -96,7 +96,8 @@ class BaseSocket(object):
"""
def __init__(self) -> None:
- self._socket, self._socket_file = None, None
+ self._socket = None # type: Optional[Union[socket.socket, ssl.SSLSocket]]
+ self._socket_file = None # type: Optional[BinaryIO]
self._is_alive = False
self._connection_time = 0.0 # time when we last connected or disconnected
@@ -218,7 +219,7 @@ class BaseSocket(object):
if is_change:
self._close()
- def _send(self, message: str, handler: Callable[[socket.socket, BinaryIO, str], None]) -> None:
+ def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None:
"""
Send message in a thread safe manner. Handler is expected to be of the form...
@@ -242,7 +243,15 @@ class BaseSocket(object):
raise
- def _recv(self, handler: Callable[[socket.socket, BinaryIO], None]) -> bytes:
+ @overload
+ def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes:
+ ...
+
+ @overload
+ def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage:
+ ...
+
+ def _recv(self, handler):
"""
Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -317,7 +326,7 @@ class BaseSocket(object):
pass
- def _make_socket(self) -> socket.socket:
+ def _make_socket(self) -> Union[socket.socket, ssl.SSLSocket]:
"""
Constructs and connects new socket. This is implemented by subclasses.
@@ -362,7 +371,7 @@ class RelaySocket(BaseSocket):
if connect:
self.connect()
- def send(self, message: str) -> None:
+ def send(self, message: Union[str, bytes]) -> None:
"""
Sends a message to the relay's ORPort.
@@ -389,26 +398,26 @@ class RelaySocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket closes before we receive a complete message
"""
- def wrapped_recv(s: socket.socket, sf: BinaryIO) -> bytes:
+ def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes:
if timeout is None:
- return s.recv()
+ return s.recv(1024)
else:
- s.setblocking(0)
+ s.setblocking(False)
s.settimeout(timeout)
try:
- return s.recv()
+ return s.recv(1024)
except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError):
return None
finally:
- s.setblocking(1)
+ s.setblocking(True)
return self._recv(wrapped_recv)
def is_localhost(self) -> bool:
return self.address == '127.0.0.1'
- def _make_socket(self) -> socket.socket:
+ def _make_socket(self) -> ssl.SSLSocket:
try:
relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
relay_socket.connect((self.address, self.port))
@@ -430,7 +439,7 @@ class ControlSocket(BaseSocket):
def __init__(self) -> None:
super(ControlSocket, self).__init__()
- def send(self, message: str) -> None:
+ def send(self, message: Union[bytes, str]) -> None:
"""
Formats and sends a message to the control socket. For more information see
the :func:`~stem.socket.send_message` function.
@@ -536,7 +545,7 @@ class ControlSocketFile(ControlSocket):
raise stem.SocketError(exc)
-def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> None:
+def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = False) -> None:
"""
Sends a message to the control socket, adding the expected formatting for
single verses multi-line messages. Neither message type should contain an
@@ -568,6 +577,8 @@ def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> Non
* :class:`stem.SocketClosed` if the socket is known to be shut down
"""
+ message = stem.util.str_tools._to_unicode(message)
+
if not raw:
message = send_formatting(message)
@@ -579,7 +590,7 @@ def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> Non
log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file: BinaryIO, message: str) -> None:
+def _write_to_socket(socket_file: BinaryIO, message: Union[str, bytes]) -> None:
try:
socket_file.write(stem.util.str_tools._to_bytes(message))
socket_file.flush()
@@ -618,7 +629,9 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
a complete message
"""
- parsed_content, raw_content, first_line = None, None, True
+ parsed_content = [] # type: List[Tuple[str, str, bytes]]
+ raw_content = bytearray()
+ first_line = True
while True:
try:
@@ -649,10 +662,10 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
log.info(ERROR_MSG % ('SocketClosed', 'empty socket content'))
raise stem.SocketClosed('Received empty socket content.')
elif not MESSAGE_PREFIX.match(line):
- log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line)))
+ log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line.decode('utf-8'))))
raise stem.ProtocolError('Badly formatted reply line: beginning is malformed')
elif not line.endswith(b'\r\n'):
- log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line)))
+ log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line.decode('utf-8'))))
raise stem.ProtocolError('All lines should end with CRLF')
status_code, divider, content = line[:3], line[3:4], line[4:-2] # strip CRLF off content
@@ -691,11 +704,11 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
line = control_file.readline()
raw_content += line
except socket.error as exc:
- log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))
+ log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8')))))
raise stem.SocketClosed(exc)
if not line.endswith(b'\r\n'):
- log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content))))
+ log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content).decode('utf-8'))))
raise stem.ProtocolError('All lines should end with CRLF')
elif line == b'.\r\n':
break # data block termination
@@ -722,7 +735,7 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
-def send_formatting(message: str) -> None:
+def send_formatting(message: str) -> str:
"""
Performs the formatting expected from sent control messages. For more
information see the :func:`~stem.socket.send_message` function.
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 050f6c91..498234cd 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -80,13 +80,15 @@ def datetime_to_unix(timestamp: 'datetime.datetime') -> float:
return (timestamp - datetime.datetime(1970, 1, 1)).total_seconds()
-def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']) -> bytes:
+def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']) -> bytes: # type: ignore
"""
Normalizes X25509 and ED25519 keys into their public key bytes.
"""
- if isinstance(key, (bytes, str)):
+ if isinstance(key, bytes):
return key
+ elif isinstance(key, str):
+ return key.encode('utf-8')
try:
from cryptography.hazmat.primitives import serialization
diff --git a/stem/util/conf.py b/stem/util/conf.py
index 37d1c5f4..1fd31fd0 100644
--- a/stem/util/conf.py
+++ b/stem/util/conf.py
@@ -162,8 +162,10 @@ import inspect
import os
import threading
+import stem.util.enum
+
from stem.util import log
-from typing import Any, Callable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union
CONFS = {} # mapping of identifier to singleton instances of configs
@@ -186,10 +188,10 @@ class _SyncListener(object):
if interceptor_value:
new_value = interceptor_value
- self.config_dict[key] = new_value
+ self.config_dict[key] = new_value # type: ignore
-def config_dict(handle: str, conf_mappings: Mapping[str, Any], handler: Optional[Callable[[str, Any], Any]] = None) -> Mapping[str, Any]:
+def config_dict(handle: str, conf_mappings: Dict[str, Any], handler: Optional[Callable[[str, Any], Any]] = None) -> Dict[str, Any]:
"""
Makes a dictionary that stays synchronized with a configuration.
@@ -308,7 +310,7 @@ def parse_enum(key: str, value: str, enumeration: 'stem.util.enum.Enum') -> Any:
return parse_enum_csv(key, value, enumeration, 1)[0]
-def parse_enum_csv(key: str, value: str, enumeration: 'stem.util.enum.Enum', count: Optional[Union[int, Sequence[int]]] = None) -> Sequence[Any]:
+def parse_enum_csv(key: str, value: str, enumeration: 'stem.util.enum.Enum', count: Optional[Union[int, Sequence[int]]] = None) -> List[Any]:
"""
Parses a given value as being a comma separated listing of enumeration keys,
returning the corresponding enumeration values. This is intended to be a
@@ -449,15 +451,15 @@ class Config(object):
"""
def __init__(self) -> None:
- self._path = None # location we last loaded from or saved to
- self._contents = collections.OrderedDict() # configuration key/value pairs
- self._listeners = [] # functors to be notified of config changes
+ self._path = None # type: Optional[str] # location we last loaded from or saved to
+ self._contents = collections.OrderedDict() # type: Dict[str, Any] # configuration key/value pairs
+ self._listeners = [] # type: List[Callable[['stem.util.conf.Config', str], Any]] # functors to be notified of config changes
# used for accessing _contents
self._contents_lock = threading.RLock()
# keys that have been requested (used to provide unused config contents)
- self._requested_keys = set()
+ self._requested_keys = set() # type: Set[str]
# flag to support lazy loading in uses_settings()
self._settings_loaded = False
@@ -577,7 +579,7 @@ class Config(object):
self._contents.clear()
self._requested_keys = set()
- def add_listener(self, listener: Callable[[str, Any], Any], backfill: bool = True) -> None:
+ def add_listener(self, listener: Callable[['stem.util.conf.Config', str], Any], backfill: bool = True) -> None:
"""
Registers the function to be notified of configuration updates. Listeners
are expected to be functors which accept (config, key).
@@ -600,7 +602,7 @@ class Config(object):
self._listeners = []
- def keys(self) -> Sequence[str]:
+ def keys(self) -> List[str]:
"""
Provides all keys in the currently loaded configuration.
@@ -609,7 +611,7 @@ class Config(object):
return list(self._contents.keys())
- def unused_keys(self) -> Sequence[str]:
+ def unused_keys(self) -> Set[str]:
"""
Provides the configuration keys that have never been provided to a caller
via :func:`~stem.util.conf.config_dict` or the
@@ -740,7 +742,7 @@ class Config(object):
return val
- def get_value(self, key: str, default: Optional[Any] = None, multiple: bool = False) -> Union[str, Sequence[str]]:
+ def get_value(self, key: str, default: Optional[Any] = None, multiple: bool = False) -> Union[str, List[str]]:
"""
This provides the current value associated with a given key.
diff --git a/stem/util/connection.py b/stem/util/connection.py
index 2f815a46..21745c43 100644
--- a/stem/util/connection.py
+++ b/stem/util/connection.py
@@ -65,7 +65,7 @@ import stem.util.proc
import stem.util.system
from stem.util import conf, enum, log, str_tools
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Tuple, Union
# Connection resolution is risky to log about since it's highly likely to
# contain sensitive information. That said, it's also difficult to get right in
@@ -158,15 +158,15 @@ class Connection(collections.namedtuple('Connection', ['local_address', 'local_p
"""
-def download(url: str, timeout: Optional[int] = None, retries: Optional[int] = None) -> bytes:
+def download(url: str, timeout: Optional[float] = None, retries: Optional[int] = None) -> bytes:
"""
Download from the given url.
.. versionadded:: 1.8.0
:param str url: uncompressed url to download from
- :param int timeout: timeout when connection becomes idle, no timeout applied
- if **None**
+ :param float timeout: timeout when connection becomes idle, no timeout
+ applied if **None**
:param int retires: maximum attempts to impose
:returns: **bytes** content of the given url
@@ -186,17 +186,17 @@ def download(url: str, timeout: Optional[int] = None, retries: Optional[int] = N
except socket.timeout as exc:
raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
except:
- exc, stacktrace = sys.exc_info()[1:3]
+ exception, stacktrace = sys.exc_info()[1:3]
if timeout is not None:
timeout -= time.time() - start_time
if retries > 0 and (timeout is None or timeout > 0):
- log.debug('Failed to download from %s (%i retries remaining): %s' % (url, retries, exc))
+ log.debug('Failed to download from %s (%i retries remaining): %s' % (url, retries, exception))
return download(url, timeout, retries - 1)
else:
- log.debug('Failed to download from %s: %s' % (url, exc))
- raise stem.DownloadFailed(url, exc, stacktrace)
+ log.debug('Failed to download from %s: %s' % (url, exception))
+ raise stem.DownloadFailed(url, exception, stacktrace)
def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None, process_pid: Optional[int] = None, process_name: Optional[str] = None) -> Sequence['stem.util.connection.Connection']:
@@ -254,7 +254,7 @@ def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None,
raise ValueError('Process pid was non-numeric: %s' % process_pid)
if process_pid is None:
- all_pids = stem.util.system.pid_by_name(process_name, True)
+ all_pids = stem.util.system.pid_by_name(process_name, True) # type: List[int] # type: ignore
if len(all_pids) == 0:
if resolver in (Resolver.NETSTAT_WINDOWS, Resolver.PROC, Resolver.BSD_PROCSTAT):
@@ -289,7 +289,7 @@ def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None,
connections = []
resolver_regex = re.compile(resolver_regex_str)
- def _parse_address_str(addr_type: str, addr_str: str, line: str) -> str:
+ def _parse_address_str(addr_type: str, addr_str: str, line: str) -> Tuple[str, int]:
addr, port = addr_str.rsplit(':', 1)
if not is_valid_ipv4_address(addr) and not is_valid_ipv6_address(addr, allow_brackets = True):
@@ -524,8 +524,15 @@ def is_valid_port(entry: Union[str, int, Sequence[str], Sequence[int]], allow_ze
:returns: **True** if input is an integer and within the valid port range, **False** otherwise
"""
+ if isinstance(entry, (tuple, list)):
+ for port in entry:
+ if not is_valid_port(port, allow_zero):
+ return False
+
+ return True
+
try:
- value = int(entry)
+ value = int(entry) # type: ignore
if str(value) != str(entry):
return False # invalid leading char, e.g. space or zero
@@ -534,14 +541,7 @@ def is_valid_port(entry: Union[str, int, Sequence[str], Sequence[int]], allow_ze
else:
return value > 0 and value < 65536
except TypeError:
- if isinstance(entry, (tuple, list)):
- for port in entry:
- if not is_valid_port(port, allow_zero):
- return False
-
- return True
- else:
- return False
+ return False
except ValueError:
return False
@@ -621,6 +621,9 @@ def expand_ipv6_address(address: str) -> str:
:raises: **ValueError** if the address can't be expanded due to being malformed
"""
+ if isinstance(address, bytes):
+ address = str_tools._to_unicode(address)
+
if not is_valid_ipv6_address(address):
raise ValueError("'%s' isn't a valid IPv6 address" % address)
diff --git a/stem/util/enum.py b/stem/util/enum.py
index b70d29f4..719a4c06 100644
--- a/stem/util/enum.py
+++ b/stem/util/enum.py
@@ -40,10 +40,10 @@ constructed as simple type listings...
+- __iter__ - iterator over our enum keys
"""
-from typing import Iterator, Sequence
+from typing import Any, Iterator, List, Sequence, Tuple, Union
-def UppercaseEnum(*args: str) -> 'stem.util.enum.Enum':
+def UppercaseEnum(*args: str) -> 'Enum':
"""
Provides an :class:`~stem.util.enum.Enum` instance where the values are
identical to the keys. Since the keys are uppercase by convention this means
@@ -69,14 +69,15 @@ class Enum(object):
Basic enumeration.
"""
- def __init__(self, *args: str) -> None:
+ def __init__(self, *args: Union[str, Tuple[str, Any]]) -> None:
from stem.util.str_tools import _to_camel_case
# ordered listings of our keys and values
- keys, values = [], []
+ keys = [] # type: List[str]
+ values = [] # type: List[Any]
for entry in args:
- if isinstance(entry, (bytes, str)):
+ if isinstance(entry, str):
key, val = entry, _to_camel_case(entry)
elif isinstance(entry, tuple) and len(entry) == 2:
key, val = entry
@@ -99,11 +100,11 @@ class Enum(object):
return list(self._keys)
- def index_of(self, value: str) -> int:
+ def index_of(self, value: Any) -> int:
"""
Provides the index of the given value in the collection.
- :param str value: entry to be looked up
+ :param object value: entry to be looked up
:returns: **int** index of the given entry
@@ -112,11 +113,11 @@ class Enum(object):
return self._values.index(value)
- def next(self, value: str) -> str:
+ def next(self, value: Any) -> Any:
"""
Provides the next enumeration after the given value.
- :param str value: enumeration for which to get the next entry
+ :param object value: enumeration for which to get the next entry
:returns: enum value following the given entry
@@ -129,11 +130,11 @@ class Enum(object):
next_index = (self._values.index(value) + 1) % len(self._values)
return self._values[next_index]
- def previous(self, value: str) -> str:
+ def previous(self, value: Any) -> Any:
"""
Provides the previous enumeration before the given value.
- :param str value: enumeration for which to get the previous entry
+ :param object value: enumeration for which to get the previous entry
    :returns: enum value preceding the given entry
@@ -146,13 +147,13 @@ class Enum(object):
prev_index = (self._values.index(value) - 1) % len(self._values)
return self._values[prev_index]
- def __getitem__(self, item: str) -> str:
+ def __getitem__(self, item: str) -> Any:
"""
Provides the values for the given key.
- :param str item: key to be looked up
+ :param str item: key to look up
- :returns: **str** with the value for the given key
+ :returns: value for the given key
:raises: **ValueError** if the key doesn't exist
"""
@@ -163,7 +164,7 @@ class Enum(object):
keys = ', '.join(self.keys())
raise ValueError("'%s' isn't among our enumeration keys, which includes: %s" % (item, keys))
- def __iter__(self) -> Iterator[str]:
+ def __iter__(self) -> Iterator[Any]:
"""
Provides an ordered listing of the enums in this set.
"""
diff --git a/stem/util/log.py b/stem/util/log.py
index 940469a3..404249a7 100644
--- a/stem/util/log.py
+++ b/stem/util/log.py
@@ -172,7 +172,7 @@ def log(runlevel: 'stem.util.log.Runlevel', message: str) -> None:
LOGGER.log(LOG_VALUES[runlevel], message)
-def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) -> None:
+def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) -> bool:
"""
Logs a message at the given runlevel. If a message with this ID has already
been logged then this is a no-op.
@@ -189,6 +189,7 @@ def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str)
else:
DEDUPLICATION_MESSAGE_IDS.add(message_id)
log(runlevel, message)
+ return True
# shorter aliases for logging at a runlevel
diff --git a/stem/util/proc.py b/stem/util/proc.py
index 10f2ae60..e180bb66 100644
--- a/stem/util/proc.py
+++ b/stem/util/proc.py
@@ -56,7 +56,7 @@ import stem.util.enum
import stem.util.str_tools
from stem.util import log
-from typing import Any, Mapping, Optional, Sequence, Set, Tuple, Type
+from typing import Any, Mapping, Optional, Sequence, Set, Tuple
try:
# unavailable on windows (#19823)
@@ -233,7 +233,7 @@ def memory_usage(pid: int) -> Tuple[int, int]:
raise exc
-def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[Any]:
+def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[str]:
"""
Provides process specific information. See the :data:`~stem.util.proc.Stat`
enum for valid options.
@@ -290,7 +290,7 @@ def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[Any]:
results.append(str(float(stat_comp[14]) / CLOCK_TICKS))
elif stat_type == Stat.START_TIME:
if pid == 0:
- return system_start_time()
+ results.append(str(system_start_time()))
else:
# According to documentation, starttime is in field 21 and the unit is
# jiffies (clock ticks). We divide it for clock ticks, then add the
@@ -452,7 +452,7 @@ def _inodes_for_sockets(pid: int) -> Set[bytes]:
return inodes
-def _unpack_addr(addr: str) -> str:
+def _unpack_addr(addr: bytes) -> str:
"""
Translates an address entry in the /proc/net/* contents to a human readable
form (`reference <http://linuxdevcenter.com/pub/a/linux/2000/11/16/LinuxAdmin.html>`_,
@@ -554,7 +554,7 @@ def _get_lines(file_path: str, line_prefixes: Sequence[str], parameter: str) ->
raise
-def _log_runtime(parameter: str, proc_location: str, start_time: int) -> None:
+def _log_runtime(parameter: str, proc_location: str, start_time: float) -> None:
"""
Logs a message indicating a successful proc query.
@@ -567,7 +567,7 @@ def _log_runtime(parameter: str, proc_location: str, start_time: int) -> None:
log.debug('proc call (%s): %s (runtime: %0.4f)' % (parameter, proc_location, runtime))
-def _log_failure(parameter: str, exc: Type[Exception]) -> None:
+def _log_failure(parameter: str, exc: BaseException) -> None:
"""
Logs a message indicating that the proc query failed.
diff --git a/stem/util/str_tools.py b/stem/util/str_tools.py
index c606906a..a0bef734 100644
--- a/stem/util/str_tools.py
+++ b/stem/util/str_tools.py
@@ -26,7 +26,7 @@ import sys
import stem.util
import stem.util.enum
-from typing import Sequence, Tuple, Union
+from typing import List, Sequence, Tuple, Union, overload
# label conversion tuples of the form...
# (bits / bytes / seconds, short label, long label)
@@ -73,7 +73,7 @@ def _to_bytes(msg: Union[str, bytes]) -> bytes:
"""
if isinstance(msg, str):
- return codecs.latin_1_encode(msg, 'replace')[0]
+ return codecs.latin_1_encode(msg, 'replace')[0] # type: ignore
else:
return msg
@@ -95,7 +95,7 @@ def _to_unicode(msg: Union[str, bytes]) -> str:
return msg
-def _decode_b64(msg: Union[str, bytes]) -> str:
+def _decode_b64(msg: bytes) -> bytes:
"""
Base64 decode, without padding concerns.
"""
@@ -103,7 +103,7 @@ def _decode_b64(msg: Union[str, bytes]) -> str:
missing_padding = len(msg) % 4
padding_chr = b'=' if isinstance(msg, bytes) else '='
- return base64.b64decode(msg + padding_chr * missing_padding)
+ return base64.b64decode(msg + (padding_chr * missing_padding))
def _to_int(msg: Union[str, bytes]) -> int:
@@ -150,7 +150,17 @@ def _to_camel_case(label: str, divider: str = '_', joiner: str = ' ') -> str:
return joiner.join(words)
-def _split_by_length(msg: str, size: int) -> Sequence[str]:
+@overload
+def _split_by_length(msg: bytes, size: int) -> List[bytes]:
+ ...
+
+
+@overload
+def _split_by_length(msg: str, size: int) -> List[str]:
+ ...
+
+
+def _split_by_length(msg, size):
"""
Splits a string into a list of strings up to the given size.
@@ -174,7 +184,7 @@ def _split_by_length(msg: str, size: int) -> Sequence[str]:
Ending = stem.util.enum.Enum('ELLIPSE', 'HYPHEN')
-def crop(msg: str, size: int, min_word_length: int = 4, min_crop: int = 0, ending: 'stem.util.str_tools.Ending' = Ending.ELLIPSE, get_remainder: bool = False) -> str:
+def crop(msg: str, size: int, min_word_length: int = 4, min_crop: int = 0, ending: 'stem.util.str_tools.Ending' = Ending.ELLIPSE, get_remainder: bool = False) -> Union[str, Tuple[str, str]]:
"""
Shortens a string to a given length.
@@ -381,7 +391,7 @@ def time_labels(seconds: int, is_long: bool = False) -> Sequence[str]:
for count_per_unit, _, _ in TIME_UNITS:
if abs(seconds) >= count_per_unit:
time_labels.append(_get_label(TIME_UNITS, seconds, 0, is_long))
- seconds %= count_per_unit
+ seconds %= int(count_per_unit)
return time_labels
@@ -413,7 +423,7 @@ def short_time_label(seconds: int) -> str:
for amount, _, label in TIME_UNITS:
count = int(seconds / amount)
- seconds %= amount
+ seconds %= int(amount)
time_comp[label.strip()] = count
label = '%02i:%02i' % (time_comp['minute'], time_comp['second'])
@@ -471,7 +481,7 @@ def parse_short_time_label(label: str) -> int:
raise ValueError('Non-numeric value in time entry: %s' % label)
-def _parse_timestamp(entry: str) -> 'datetime.datetime':
+def _parse_timestamp(entry: str) -> datetime.datetime:
"""
   Parses the date and time in a format like...
@@ -535,7 +545,7 @@ def _parse_iso_timestamp(entry: str) -> 'datetime.datetime':
return timestamp + datetime.timedelta(microseconds = int(microseconds))
-def _get_label(units: Tuple[int, str, str], count: int, decimal: int, is_long: bool, round: bool = False) -> str:
+def _get_label(units: Sequence[Tuple[float, str, str]], count: int, decimal: int, is_long: bool, round: bool = False) -> str:
"""
Provides label corresponding to units of the highest significance in the
provided set. This rounds down (ie, integer truncation after visible units).
@@ -580,3 +590,5 @@ def _get_label(units: Tuple[int, str, str], count: int, decimal: int, is_long: b
return count_label + long_label + ('s' if is_plural else '')
else:
return count_label + short_label
+
+ raise ValueError('BUG: value should always be divisible by a unit (%s)' % str(units))
diff --git a/stem/util/system.py b/stem/util/system.py
index 8a61b2b9..a5147976 100644
--- a/stem/util/system.py
+++ b/stem/util/system.py
@@ -82,7 +82,7 @@ import stem.util.str_tools
from stem import UNDEFINED
from stem.util import log
-from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, TextIO, Union
+from typing import Any, BinaryIO, Callable, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
State = stem.util.enum.UppercaseEnum(
'PENDING',
@@ -98,11 +98,11 @@ SIZE_RECURSES = {
dict: lambda d: itertools.chain.from_iterable(d.items()),
set: iter,
frozenset: iter,
-}
+} # type: Dict[Type, Callable]
# Mapping of commands to if they're available or not.
-CMD_AVAILABLE_CACHE = {}
+CMD_AVAILABLE_CACHE = {} # type: Dict[str, bool]
# An incomplete listing of commands provided by the shell. Expand this as
# needed. Some noteworthy things about shell commands...
@@ -186,11 +186,11 @@ class CallError(OSError):
:var str command: command that was ran
:var int exit_status: exit code of the process
:var float runtime: time the command took to run
- :var str stdout: stdout of the process
- :var str stderr: stderr of the process
+ :var bytes stdout: stdout of the process
+ :var bytes stderr: stderr of the process
"""
- def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str) -> None:
+ def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: bytes, stderr: bytes) -> None:
self.msg = msg
self.command = command
self.exit_status = exit_status
@@ -211,7 +211,7 @@ class CallTimeoutError(CallError):
:var float timeout: time we waited
"""
- def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str, timeout: float) -> None:
+ def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: bytes, stderr: bytes, timeout: float) -> None:
super(CallTimeoutError, self).__init__(msg, command, exit_status, runtime, stdout, stderr)
self.timeout = timeout
@@ -242,8 +242,8 @@ class DaemonTask(object):
self.result = None
self.error = None
- self._process = None
- self._pipe = None
+ self._process = None # type: Optional[multiprocessing.Process]
+ self._pipe = None # type: Optional[multiprocessing.connection.Connection]
if start:
self.run()
@@ -462,7 +462,7 @@ def is_running(command: Union[str, int, Sequence[str]]) -> bool:
return None
-def size_of(obj: Any, exclude: Optional[Sequence[int]] = None) -> int:
+def size_of(obj: Any, exclude: Optional[Collection[int]] = None) -> int:
"""
Provides the `approximate memory usage of an object
<https://code.activestate.com/recipes/577504/>`_. This can recurse tuples,
@@ -486,9 +486,9 @@ def size_of(obj: Any, exclude: Optional[Sequence[int]] = None) -> int:
if platform.python_implementation() == 'PyPy':
raise NotImplementedError('PyPy does not implement sys.getsizeof()')
- if exclude is None:
- exclude = set()
- elif id(obj) in exclude:
+ exclude = set(exclude) if exclude is not None else set()
+
+ if id(obj) in exclude:
return 0
try:
@@ -548,7 +548,7 @@ def name_by_pid(pid: int) -> Optional[str]:
return process_name
-def pid_by_name(process_name: str, multiple: bool = False) -> Union[int, Sequence[int]]:
+def pid_by_name(process_name: str, multiple: bool = False) -> Union[int, List[int]]:
"""
Attempts to determine the process id for a running process, using...
@@ -996,10 +996,8 @@ def user(pid: int) -> Optional[str]:
import pwd # only available on unix platforms
uid = stem.util.proc.uid(pid)
-
- if uid and uid.isdigit():
- return pwd.getpwuid(int(uid)).pw_name
- except:
+ return pwd.getpwuid(uid).pw_name
+ except ImportError:
pass
if is_available('ps'):
@@ -1042,7 +1040,7 @@ def start_time(pid: str) -> Optional[float]:
return None
-def tail(target: Union[str, TextIO], lines: Optional[int] = None) -> Iterator[str]:
+def tail(target: Union[str, BinaryIO], lines: Optional[int] = None) -> Iterator[str]:
"""
Provides lines of a file starting with the end. For instance,
'tail -n 50 /tmp/my_log' could be done with...
@@ -1061,8 +1059,8 @@ def tail(target: Union[str, TextIO], lines: Optional[int] = None) -> Iterator[st
if isinstance(target, str):
with open(target, 'rb') as target_file:
- for line in tail(target_file, lines):
- yield line
+ for tail_line in tail(target_file, lines):
+ yield tail_line
return
@@ -1299,7 +1297,7 @@ def call(command: Union[str, Sequence[str]], default: Any = UNDEFINED, ignore_ex
if timeout:
while process.poll() is None:
if time.time() - start_time > timeout:
- raise CallTimeoutError("Process didn't finish after %0.1f seconds" % timeout, ' '.join(command_list), None, timeout, '', '', timeout)
+ raise CallTimeoutError("Process didn't finish after %0.1f seconds" % timeout, ' '.join(command_list), None, timeout, b'', b'', timeout)
time.sleep(0.001)
@@ -1313,11 +1311,11 @@ def call(command: Union[str, Sequence[str]], default: Any = UNDEFINED, ignore_ex
trace_prefix = 'Received from system (%s)' % command
if stdout and stderr:
- log.trace(trace_prefix + ', stdout:\n%s\nstderr:\n%s' % (stdout, stderr))
+ log.trace(trace_prefix + ', stdout:\n%s\nstderr:\n%s' % (stdout.decode('utf-8'), stderr.decode('utf-8')))
elif stdout:
- log.trace(trace_prefix + ', stdout:\n%s' % stdout)
+ log.trace(trace_prefix + ', stdout:\n%s' % stdout.decode('utf-8'))
elif stderr:
- log.trace(trace_prefix + ', stderr:\n%s' % stderr)
+ log.trace(trace_prefix + ', stderr:\n%s' % stderr.decode('utf-8'))
exit_status = process.poll()
diff --git a/stem/util/term.py b/stem/util/term.py
index acc52cad..862767c4 100644
--- a/stem/util/term.py
+++ b/stem/util/term.py
@@ -72,14 +72,14 @@ CSI = '\x1B[%sm'
RESET = CSI % '0'
-def encoding(*attrs: Union['stem.util.terminal.Color', 'stem.util.terminal.BgColor', 'stem.util.terminal.Attr']) -> Optional[str]:
+def encoding(*attrs: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> Optional[str]:
"""
Provides the ANSI escape sequence for these terminal color or attributes.
.. versionadded:: 1.5.0
- :param list attr: :data:`~stem.util.terminal.Color`,
- :data:`~stem.util.terminal.BgColor`, or :data:`~stem.util.terminal.Attr` to
+ :param list attr: :data:`~stem.util.term.Color`,
+ :data:`~stem.util.term.BgColor`, or :data:`~stem.util.term.Attr` to
     provide an encoding for
:returns: **str** of the ANSI escape sequence, **None** no attributes are
@@ -99,9 +99,11 @@ def encoding(*attrs: Union['stem.util.terminal.Color', 'stem.util.terminal.BgCol
if term_encodings:
return CSI % ';'.join(term_encodings)
+ else:
+ return None
-def format(msg: str, *attr: Union['stem.util.terminal.Color', 'stem.util.terminal.BgColor', 'stem.util.terminal.Attr']) -> str:
+def format(msg: str, *attr: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> str:
"""
Simple terminal text formatting using `ANSI escape sequences
<https://en.wikipedia.org/wiki/ANSI_escape_code#CSI_codes>`_.
@@ -125,12 +127,12 @@ def format(msg: str, *attr: Union['stem.util.terminal.Color', 'stem.util.termina
"""
msg = stem.util.str_tools._to_unicode(msg)
+ attr = list(attr)
if DISABLE_COLOR_SUPPORT:
return msg
if Attr.LINES in attr:
- attr = list(attr)
attr.remove(Attr.LINES)
lines = [format(line, *attr) for line in msg.split('\n')]
return '\n'.join(lines)
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index 80de447e..34573450 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -44,7 +44,7 @@ import stem.util.conf
import stem.util.enum
import stem.util.system
-from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Type
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
CONFIG = stem.util.conf.config_dict('test', {
'pycodestyle.ignore': [],
@@ -53,8 +53,8 @@ CONFIG = stem.util.conf.config_dict('test', {
'exclude_paths': [],
})
-TEST_RUNTIMES = {}
-ASYNC_TESTS = {}
+TEST_RUNTIMES: Dict[str, float] = {}
+ASYNC_TESTS: Dict[str, 'stem.util.test_tools.AsyncTest'] = {}
AsyncStatus = stem.util.enum.UppercaseEnum('PENDING', 'RUNNING', 'FINISHED')
AsyncResult = collections.namedtuple('AsyncResult', 'type msg')
@@ -147,11 +147,11 @@ class AsyncTest(object):
self.method = lambda test: self.result(test) # method that can be mixed into TestCases
- self._process = None
- self._process_pipe = None
+ self._process = None # type: Optional[Union[threading.Thread, multiprocessing.Process]]
+ self._process_pipe = None # type: Optional[multiprocessing.connection.Connection]
self._process_lock = threading.RLock()
- self._result = None
+ self._result = None # type: Optional[stem.util.test_tools.AsyncResult]
self._status = AsyncStatus.PENDING
def run(self, *runner_args: Any, **kwargs: Any) -> None:
@@ -194,9 +194,9 @@ class AsyncTest(object):
self._process.start()
self._status = AsyncStatus.RUNNING
- def pid(self) -> int:
+ def pid(self) -> Optional[int]:
with self._process_lock:
- return self._process.pid if (self._process and not self._threaded) else None
+ return getattr(self._process, 'pid', None)
def join(self) -> None:
self.result(None)
@@ -238,9 +238,9 @@ class TimedTestRunner(unittest.TextTestRunner):
.. versionadded:: 1.6.0
"""
- def run(self, test: 'unittest.TestCase') -> None:
- for t in test._tests:
- original_type = type(t)
+ def run(self, test: Union[unittest.TestCase, unittest.TestSuite]) -> unittest.TestResult:
+ for t in getattr(test, '_tests', ()):
+ original_type = type(t) # type: Any
class _TestWrapper(original_type):
def run(self, result: Optional[Any] = None) -> Any:
@@ -273,7 +273,7 @@ class TimedTestRunner(unittest.TextTestRunner):
return super(TimedTestRunner, self).run(test)
-def test_runtimes() -> Mapping[str, float]:
+def test_runtimes() -> Dict[str, float]:
"""
Provides the runtimes of tests executed through TimedTestRunners.
@@ -286,7 +286,7 @@ def test_runtimes() -> Mapping[str, float]:
return dict(TEST_RUNTIMES)
-def clean_orphaned_pyc(paths: Sequence[str]) -> Sequence[str]:
+def clean_orphaned_pyc(paths: Sequence[str]) -> List[str]:
"""
   Deletes any file with a \\*.pyc extension without a corresponding \\*.py. This
helps to address a common gotcha when deleting python files...
@@ -302,7 +302,7 @@ def clean_orphaned_pyc(paths: Sequence[str]) -> Sequence[str]:
:param list paths: paths to search for orphaned pyc files
- :returns: list of absolute paths that were deleted
+ :returns: **list** of absolute paths that were deleted
"""
orphaned_pyc = []
@@ -366,7 +366,7 @@ def is_mypy_available() -> bool:
return _module_exists('mypy.api')
-def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_exception_keyword: bool = False, prefer_single_quotes: bool = False) -> Mapping[str, 'stem.util.test_tools.Issue']:
+def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_exception_keyword: bool = False, prefer_single_quotes: bool = False) -> Dict[str, List['stem.util.test_tools.Issue']]:
"""
Checks for stylistic issues that are an issue according to the parts of PEP8
we conform to. You can suppress pycodestyle issues by making a 'test'
@@ -424,7 +424,7 @@ def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_e
:returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances
"""
- issues = {}
+ issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]]
ignore_rules = []
ignore_for_file = []
@@ -505,7 +505,7 @@ def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_e
return issues
-def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issue']:
+def pyflakes_issues(paths: Sequence[str]) -> Dict[str, List['stem.util.test_tools.Issue']]:
"""
Performs static checks via pyflakes. False positives can be ignored via
'pyflakes.ignore' entries in our 'test' config. For instance...
@@ -531,7 +531,7 @@ def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.
:returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances
"""
- issues = {}
+ issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]]
if is_pyflakes_available():
import pyflakes.api
@@ -539,19 +539,19 @@ def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.
class Reporter(pyflakes.reporter.Reporter):
def __init__(self) -> None:
- self._ignored_issues = {}
+ self._ignored_issues = {} # type: Dict[str, List[str]]
for line in CONFIG['pyflakes.ignore']:
path, issue = line.split('=>')
self._ignored_issues.setdefault(path.strip(), []).append(issue.strip())
- def unexpectedError(self, filename: str, msg: str) -> None:
+ def unexpectedError(self, filename: str, msg: 'pyflakes.messages.Message') -> None:
self._register_issue(filename, None, msg, None)
def syntaxError(self, filename: str, msg: str, lineno: int, offset: int, text: str) -> None:
self._register_issue(filename, lineno, msg, text)
- def flake(self, msg: str) -> None:
+ def flake(self, msg: 'pyflakes.messages.Message') -> None:
self._register_issue(msg.filename, msg.lineno, msg.message % msg.message_args, None)
def _register_issue(self, path: str, line_number: int, issue: str, line: str) -> None:
@@ -569,7 +569,7 @@ def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.
return issues
-def type_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issue']:
+def type_issues(args: Sequence[str]) -> Dict[str, List['stem.util.test_tools.Issue']]:
"""
Performs type checks via mypy. False positives can be ignored via
'mypy.ignore' entries in our 'test' config. For instance...
@@ -578,23 +578,25 @@ def type_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issu
mypy.ignore stem/util/system.py => Incompatible types in assignment*
- :param list paths: paths to search for problems
+  :param list args: mypy commandline arguments
:returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances
"""
- issues = {}
+ issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]]
if is_mypy_available():
import mypy.api
- ignored_issues = {}
+ ignored_issues = {} # type: Dict[str, List[str]]
for line in CONFIG['mypy.ignore']:
path, issue = line.split('=>')
ignored_issues.setdefault(path.strip(), []).append(issue.strip())
- lines = mypy.api.run(paths)[0].splitlines() # mypy returns (report, errors, exit_status)
+ # mypy returns (report, errors, exit_status)
+
+ lines = mypy.api.run(args)[0].splitlines() # type: ignore
for line in lines:
# example:
@@ -606,13 +608,13 @@ def type_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issu
raise ValueError('Failed to parse mypy line: %s' % line)
path, line_number, _, issue = line.split(':', 3)
- issue = issue.strip()
- if line_number.isdigit():
- line_number = int(line_number)
- else:
+ if not line_number.isdigit():
raise ValueError('Malformed line number on: %s' % line)
+ issue = issue.strip()
+ line_number = int(line_number)
+
if _is_ignored(ignored_issues, path, issue):
continue
@@ -660,16 +662,21 @@ def _python_files(paths: Sequence[str]) -> Iterator[str]:
def _is_ignored(config: Mapping[str, Sequence[str]], path: str, issue: str) -> bool:
for ignored_path, ignored_issues in config.items():
- if path.endswith(ignored_path):
- if issue in ignored_issues:
- return True
-
- for prefix in [i[:1] for i in ignored_issues if i.endswith('*')]:
- if issue.startswith(prefix):
+ if ignored_path == '*' or path.endswith(ignored_path):
+ for ignored_issue in ignored_issues:
+ if issue == ignored_issue:
return True
- for suffix in [i[1:] for i in ignored_issues if i.startswith('*')]:
- if issue.endswith(suffix):
- return True
+ # TODO: try using glob module instead?
+
+ if ignored_issue.startswith('*') and ignored_issue.endswith('*'):
+ if ignored_issue[1:-1] in issue:
+ return True # substring match
+ elif ignored_issue.startswith('*'):
+ if issue.endswith(ignored_issue[1:]):
+ return True # prefix match
+ elif ignored_issue.endswith('*'):
+ if issue.startswith(ignored_issue[:-1]):
+ return True # suffix match
return False
diff --git a/stem/version.py b/stem/version.py
index 8ec35293..6ef7c890 100644
--- a/stem/version.py
+++ b/stem/version.py
@@ -72,9 +72,9 @@ def get_system_tor_version(tor_cmd: str = 'tor') -> 'stem.version.Version':
if 'No such file or directory' in str(exc):
if os.path.isabs(tor_cmd):
- exc = "Unable to check tor's version. '%s' doesn't exist." % tor_cmd
+ raise IOError("Unable to check tor's version. '%s' doesn't exist." % tor_cmd)
else:
- exc = "Unable to run '%s'. Maybe tor isn't in your PATH?" % version_cmd
+ raise IOError("Unable to run '%s'. Maybe tor isn't in your PATH?" % version_cmd)
raise IOError(exc)
@@ -132,13 +132,12 @@ class Version(object):
version_parts = VERSION_PATTERN.match(version_str)
if version_parts:
- major, minor, micro, patch, status, extra_str, _ = version_parts.groups()
+ major, minor, micro, patch_str, status, extra_str, _ = version_parts.groups()
# The patch and status matches are optional (may be None) and have an extra
       # preceding period or dash if they exist. Stripping those off.
- if patch:
- patch = int(patch[1:])
+ patch = int(patch_str[1:]) if patch_str else None
if status:
status = status[1:]
@@ -166,7 +165,7 @@ class Version(object):
return self.version_str
- def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> Callable[[Any, Any], bool]:
+ def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool:
"""
Compares version ordering according to the spec.
"""
diff --git a/test/arguments.py b/test/arguments.py
index d0f0dc3f..e06148c4 100644
--- a/test/arguments.py
+++ b/test/arguments.py
@@ -5,13 +5,14 @@
Commandline argument parsing for our test runner.
"""
-import collections
import getopt
import stem.util.conf
import stem.util.log
import test
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence
+
LOG_TYPE_ERROR = """\
'%s' isn't a logging runlevel, use one of the following instead:
TRACE, DEBUG, INFO, NOTICE, WARN, ERROR
@@ -23,138 +24,136 @@ CONFIG = stem.util.conf.config_dict('test', {
'target.torrc': {},
})
-DEFAULT_ARGS = {
- 'run_unit': False,
- 'run_integ': False,
- 'specific_test': [],
- 'exclude_test': [],
- 'logging_runlevel': None,
- 'logging_path': None,
- 'tor_path': 'tor',
- 'run_targets': [test.Target.RUN_OPEN],
- 'attribute_targets': [],
- 'quiet': False,
- 'verbose': False,
- 'print_help': False,
-}
-
OPT = 'auit:l:qvh'
OPT_EXPANDED = ['all', 'unit', 'integ', 'targets=', 'test=', 'exclude-test=', 'log=', 'log-file=', 'tor=', 'quiet', 'verbose', 'help']
-def parse(argv):
- """
- Parses our arguments, providing a named tuple with their values.
+class Arguments(NamedTuple):
+ run_unit: bool = False
+ run_integ: bool = False
+ specific_test: List[str] = []
+ exclude_test: List[str] = []
+ logging_runlevel: Optional[str] = None
+ logging_path: Optional[str] = None
+ tor_path: str = 'tor'
+ run_targets: List['test.Target'] = [test.Target.RUN_OPEN]
+ attribute_targets: List['test.Target'] = []
+ quiet: bool = False
+ verbose: bool = False
+ print_help: bool = False
+
+ @staticmethod
+ def parse(argv: Sequence[str]) -> 'test.arguments.Arguments':
+ """
+ Parses our commandline arguments into this class.
+
+ :param list argv: input arguments to be parsed
+
+ :returns: :class:`test.arguments.Arguments` for this commandline input
+
+ :raises: **ValueError** if we got an invalid argument
+ """
- :param list argv: input arguments to be parsed
+ args = {} # type: Dict[str, Any]
- :returns: a **named tuple** with our parsed arguments
+ try:
+ recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) # type: ignore
- :raises: **ValueError** if we got an invalid argument
- """
+ if unrecognized_args:
+ error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument"
+ raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg))
+ except Exception as exc:
+ raise ValueError('%s (for usage provide --help)' % exc)
- args = dict(DEFAULT_ARGS)
-
- try:
- recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED)
-
- if unrecognized_args:
- error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument"
- raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg))
- except Exception as exc:
- raise ValueError('%s (for usage provide --help)' % exc)
-
- for opt, arg in recognized_args:
- if opt in ('-a', '--all'):
- args['run_unit'] = True
- args['run_integ'] = True
- elif opt in ('-u', '--unit'):
- args['run_unit'] = True
- elif opt in ('-i', '--integ'):
- args['run_integ'] = True
- elif opt in ('-t', '--targets'):
- run_targets, attribute_targets = [], []
-
- integ_targets = arg.split(',')
- all_run_targets = [t for t in test.Target if CONFIG['target.torrc'].get(t) is not None]
-
- # validates the targets and split them into run and attribute targets
-
- if not integ_targets:
- raise ValueError('No targets provided')
-
- for target in integ_targets:
- if target not in test.Target:
- raise ValueError('Invalid integration target: %s' % target)
- elif target in all_run_targets:
- run_targets.append(target)
- else:
- attribute_targets.append(target)
-
- # check if we were told to use all run targets
-
- if test.Target.RUN_ALL in attribute_targets:
- attribute_targets.remove(test.Target.RUN_ALL)
- run_targets = all_run_targets
-
- # if no RUN_* targets are provided then keep the default (otherwise we
- # won't have any tests to run)
-
- if run_targets:
- args['run_targets'] = run_targets
-
- args['attribute_targets'] = attribute_targets
- elif opt == '--test':
- args['specific_test'].append(crop_module_name(arg))
- elif opt == '--exclude-test':
- args['exclude_test'].append(crop_module_name(arg))
- elif opt in ('-l', '--log'):
- arg = arg.upper()
-
- if arg not in stem.util.log.LOG_VALUES:
- raise ValueError(LOG_TYPE_ERROR % arg)
-
- args['logging_runlevel'] = arg
- elif opt == '--log-file':
- args['logging_path'] = arg
- elif opt in ('--tor'):
- args['tor_path'] = arg
- elif opt in ('-q', '--quiet'):
- args['quiet'] = True
- elif opt in ('-v', '--verbose'):
- args['verbose'] = True
- elif opt in ('-h', '--help'):
- args['print_help'] = True
-
- # translates our args dict into a named tuple
-
- Args = collections.namedtuple('Args', args.keys())
- return Args(**args)
-
-
-def get_help():
- """
- Provides usage information, as provided by the '--help' argument. This
- includes a listing of the valid integration targets.
+ for opt, arg in recognized_args:
+ if opt in ('-a', '--all'):
+ args['run_unit'] = True
+ args['run_integ'] = True
+ elif opt in ('-u', '--unit'):
+ args['run_unit'] = True
+ elif opt in ('-i', '--integ'):
+ args['run_integ'] = True
+ elif opt in ('-t', '--targets'):
+ run_targets, attribute_targets = [], []
- :returns: **str** with our usage information
- """
+ integ_targets = arg.split(',')
+ all_run_targets = [t for t in test.Target if CONFIG['target.torrc'].get(t) is not None]
+
+ # validates the targets and split them into run and attribute targets
+
+ if not integ_targets:
+ raise ValueError('No targets provided')
+
+ for target in integ_targets:
+ if target not in test.Target:
+ raise ValueError('Invalid integration target: %s' % target)
+ elif target in all_run_targets:
+ run_targets.append(target)
+ else:
+ attribute_targets.append(target)
+
+ # check if we were told to use all run targets
+
+ if test.Target.RUN_ALL in attribute_targets:
+ attribute_targets.remove(test.Target.RUN_ALL)
+ run_targets = all_run_targets
+
+ # if no RUN_* targets are provided then keep the default (otherwise we
+ # won't have any tests to run)
+
+ if run_targets:
+ args['run_targets'] = run_targets
+
+ args['attribute_targets'] = attribute_targets
+ elif opt == '--test':
+ args['specific_test'].append(crop_module_name(arg))
+ elif opt == '--exclude-test':
+ args['exclude_test'].append(crop_module_name(arg))
+ elif opt in ('-l', '--log'):
+ arg = arg.upper()
+
+ if arg not in stem.util.log.LOG_VALUES:
+ raise ValueError(LOG_TYPE_ERROR % arg)
+
+ args['logging_runlevel'] = arg
+ elif opt == '--log-file':
+ args['logging_path'] = arg
+ elif opt in ('--tor'):
+ args['tor_path'] = arg
+ elif opt in ('-q', '--quiet'):
+ args['quiet'] = True
+ elif opt in ('-v', '--verbose'):
+ args['verbose'] = True
+ elif opt in ('-h', '--help'):
+ args['print_help'] = True
+
+ return Arguments(**args)
+
+ @staticmethod
+ def get_help() -> str:
+ """
+ Provides usage information, as provided by the '--help' argument. This
+ includes a listing of the valid integration targets.
+
+ :returns: **str** with our usage information
+ """
+
+ help_msg = CONFIG['msg.help']
- help_msg = CONFIG['msg.help']
+ # gets the longest target length so we can show the entries in columns
- # gets the longest target length so we can show the entries in columns
- target_name_length = max(map(len, test.Target))
- description_format = '\n %%-%is - %%s' % target_name_length
+ target_name_length = max(map(len, test.Target))
+ description_format = '\n %%-%is - %%s' % target_name_length
- for target in test.Target:
- help_msg += description_format % (target, CONFIG['target.description'].get(target, ''))
+ for target in test.Target:
+ help_msg += description_format % (target, CONFIG['target.description'].get(target, ''))
- help_msg += '\n'
+ help_msg += '\n'
- return help_msg
+ return help_msg
-def crop_module_name(name):
+def crop_module_name(name: str) -> str:
"""
Test modules have a 'test.unit.' or 'test.integ.' prefix which can
be omitted from our '--test' argument. Cropping this so we can do
diff --git a/test/mypy.ini b/test/mypy.ini
new file mode 100644
index 00000000..1c77449a
--- /dev/null
+++ b/test/mypy.ini
@@ -0,0 +1,6 @@
+[mypy]
+allow_redefinition = True
+ignore_missing_imports = True
+show_error_codes = True
+strict_optional = False
+warn_unused_ignores = True
diff --git a/test/settings.cfg b/test/settings.cfg
index 8c6423bb..51109f96 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -196,10 +196,12 @@ pyflakes.ignore stem/manual.py => undefined name 'sqlite3'
pyflakes.ignore stem/client/cell.py => undefined name 'cryptography'
pyflakes.ignore stem/client/cell.py => undefined name 'hashlib'
pyflakes.ignore stem/client/datatype.py => redefinition of unused 'pop' from *
+pyflakes.ignore stem/descriptor/__init__.py => undefined name 'cryptography'
pyflakes.ignore stem/descriptor/hidden_service.py => undefined name 'cryptography'
pyflakes.ignore stem/interpreter/autocomplete.py => undefined name 'stem'
pyflakes.ignore stem/interpreter/help.py => undefined name 'stem'
pyflakes.ignore stem/response/events.py => undefined name 'datetime'
+pyflakes.ignore stem/socket.py => redefinition of unused '_recv'*
pyflakes.ignore stem/util/conf.py => undefined name 'stem'
pyflakes.ignore stem/util/enum.py => undefined name 'stem'
pyflakes.ignore test/require.py => 'cryptography.utils.int_from_bytes' imported but unused
@@ -214,6 +216,23 @@ pyflakes.ignore test/unit/response/events.py => 'from stem import *' used; unabl
pyflakes.ignore test/unit/response/events.py => *may be undefined, or defined from star imports: stem
pyflakes.ignore test/integ/interpreter.py => 'readline' imported but unused
+# Our enum class confuses mypy. Ignore this until we can change to python 3.x's
+# new enum builtin.
+#
+# For example...
+#
+# See https://mypy.readthedocs.io/en/latest/common_issues.html#variables-vs-type-aliases
+# Variable "stem.control.EventType" is not valid as a type [valid-type]
+
+mypy.ignore * => "Enum" has no attribute *
+mypy.ignore * => "_IntegerEnum" has no attribute *
+mypy.ignore * => See https://mypy.readthedocs.io/en/latest/common_issues.html*
+mypy.ignore * => *is not valid as a type*
+
+# Metaprogramming prevents mypy from determining descriptor attributes.
+
+mypy.ignore * => "Descriptor" has no attribute "*
+
# Test modules we want to run. Modules are roughly ordered by the dependencies
# so the lowest level tests come first. This is because a problem in say,
# controller message parsing, will cause all higher level tests to fail too.
diff --git a/test/task.py b/test/task.py
index 2366564c..b2957e65 100644
--- a/test/task.py
+++ b/test/task.py
@@ -355,7 +355,7 @@ PYCODESTYLE_TASK = StaticCheckTask(
MYPY_TASK = StaticCheckTask(
'running mypy',
stem.util.test_tools.type_issues,
- args = ([os.path.join(test.STEM_BASE, 'stem')],),
+ args = (['--config-file', os.path.join(test.STEM_BASE, 'test', 'mypy.ini'), os.path.join(test.STEM_BASE, 'stem')],),
is_available = stem.util.test_tools.is_mypy_available(),
unavailable_msg = MYPY_UNAVAILABLE,
)
diff --git a/test/unit/client/address.py b/test/unit/client/address.py
index c4e51f4e..352a8d8c 100644
--- a/test/unit/client/address.py
+++ b/test/unit/client/address.py
@@ -50,7 +50,7 @@ class TestAddress(unittest.TestCase):
self.assertEqual(AddrType.UNKNOWN, addr.type)
self.assertEqual(12, addr.type_int)
self.assertEqual(None, addr.value)
- self.assertEqual('hello', addr.value_bin)
+ self.assertEqual(b'hello', addr.value_bin)
def test_packing(self):
test_data = {
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 37e252b9..02ed2774 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -206,7 +206,7 @@ class TestControl(unittest.TestCase):
get_info_mock.side_effect = InvalidArguments
- get_conf_mock.side_effect = lambda param, **kwargs: {
+ get_conf_mock.side_effect = lambda param, *args, **kwargs: {
'ControlPort': '9050',
'ControlListenAddress': ['127.0.0.1'],
}[param]
@@ -217,7 +217,7 @@ class TestControl(unittest.TestCase):
# non-local addresss
- get_conf_mock.side_effect = lambda param, **kwargs: {
+ get_conf_mock.side_effect = lambda param, *args, **kwargs: {
'ControlPort': '9050',
'ControlListenAddress': ['27.4.4.1'],
}[param]
@@ -679,7 +679,7 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- get_conf_mock.side_effect = lambda param, **kwargs: {
+ get_conf_mock.side_effect = lambda param, *args, **kwargs: {
'BandwidthRate': '1073741824',
'BandwidthBurst': '1073741824',
'RelayBandwidthRate': '0',
diff --git a/test/unit/descriptor/bandwidth_file.py b/test/unit/descriptor/bandwidth_file.py
index 9bee5f95..5e56f9d2 100644
--- a/test/unit/descriptor/bandwidth_file.py
+++ b/test/unit/descriptor/bandwidth_file.py
@@ -7,6 +7,7 @@ import datetime
import unittest
import stem.descriptor
+import stem.util.str_tools
from unittest.mock import Mock, patch
@@ -334,5 +335,5 @@ class TestBandwidthFile(unittest.TestCase):
)
for value in test_values:
- expected_exc = "First line should be a unix timestamp, but was '%s'" % value
+ expected_exc = "First line should be a unix timestamp, but was '%s'" % stem.util.str_tools._to_unicode(value)
self.assertRaisesWith(ValueError, expected_exc, BandwidthFile.create, {'timestamp': value})
diff --git a/test/unit/interpreter/arguments.py b/test/unit/interpreter/arguments.py
index df81e7e3..d61de42d 100644
--- a/test/unit/interpreter/arguments.py
+++ b/test/unit/interpreter/arguments.py
@@ -1,39 +1,39 @@
import unittest
-from stem.interpreter.arguments import DEFAULT_ARGS, parse, get_help
+from stem.interpreter.arguments import Arguments
class TestArgumentParsing(unittest.TestCase):
def test_that_we_get_default_values(self):
- args = parse([])
+ args = Arguments.parse([])
- for attr in DEFAULT_ARGS:
- self.assertEqual(DEFAULT_ARGS[attr], getattr(args, attr))
+ for attr, value in Arguments._field_defaults.items():
+ self.assertEqual(value, getattr(args, attr))
def test_that_we_load_arguments(self):
- args = parse(['--interface', '10.0.0.25:80'])
+ args = Arguments.parse(['--interface', '10.0.0.25:80'])
self.assertEqual('10.0.0.25', args.control_address)
self.assertEqual(80, args.control_port)
- args = parse(['--interface', '80'])
- self.assertEqual(DEFAULT_ARGS['control_address'], args.control_address)
+ args = Arguments.parse(['--interface', '80'])
+ self.assertEqual('127.0.0.1', args.control_address)
self.assertEqual(80, args.control_port)
- args = parse(['--socket', '/tmp/my_socket'])
+ args = Arguments.parse(['--socket', '/tmp/my_socket'])
self.assertEqual('/tmp/my_socket', args.control_socket)
- args = parse(['--help'])
+ args = Arguments.parse(['--help'])
self.assertEqual(True, args.print_help)
def test_examples(self):
- args = parse(['-i', '1643'])
+ args = Arguments.parse(['-i', '1643'])
self.assertEqual(1643, args.control_port)
- args = parse(['-s', '~/.tor/socket'])
+ args = Arguments.parse(['-s', '~/.tor/socket'])
self.assertEqual('~/.tor/socket', args.control_socket)
def test_that_we_reject_unrecognized_arguments(self):
- self.assertRaises(ValueError, parse, ['--blarg', 'stuff'])
+ self.assertRaises(ValueError, Arguments.parse, ['--blarg', 'stuff'])
def test_that_we_reject_invalid_interfaces(self):
invalid_inputs = (
@@ -49,15 +49,15 @@ class TestArgumentParsing(unittest.TestCase):
)
for invalid_input in invalid_inputs:
- self.assertRaises(ValueError, parse, ['--interface', invalid_input])
+ self.assertRaises(ValueError, Arguments.parse, ['--interface', invalid_input])
def test_run_with_command(self):
- self.assertEqual('GETINFO version', parse(['--run', 'GETINFO version']).run_cmd)
+ self.assertEqual('GETINFO version', Arguments.parse(['--run', 'GETINFO version']).run_cmd)
def test_run_with_path(self):
- self.assertEqual(__file__, parse(['--run', __file__]).run_path)
+ self.assertEqual(__file__, Arguments.parse(['--run', __file__]).run_path)
def test_get_help(self):
- help_text = get_help()
+ help_text = Arguments.get_help()
self.assertTrue('Interactive interpreter for Tor.' in help_text)
self.assertTrue('change control interface from 127.0.0.1:default' in help_text)
diff --git a/test/unit/util/proc.py b/test/unit/util/proc.py
index 2316f669..39087cbe 100644
--- a/test/unit/util/proc.py
+++ b/test/unit/util/proc.py
@@ -147,18 +147,17 @@ class TestProc(unittest.TestCase):
# tests the case where pid = 0
- if 'start time' in args:
- response = 10
- else:
- response = ()
-
- for arg in args:
- if arg == 'command':
- response += ('sched',)
- elif arg == 'utime':
- response += ('0',)
- elif arg == 'stime':
- response += ('0',)
+ response = ()
+
+ for arg in args:
+ if arg == 'command':
+ response += ('sched',)
+ elif arg == 'utime':
+ response += ('0',)
+ elif arg == 'stime':
+ response += ('0',)
+ elif arg == 'start time':
+ response += ('10',)
get_line_mock.side_effect = lambda *params: {
('/proc/0/stat', '0', 'process %s' % ', '.join(args)): stat
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits