[tor-commits] [stem/master] Compression support for ORPort descriptor downloads
commit 5a875ed329ddb119c2b6787e051571fb7323f622
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Sat Apr 28 16:47:42 2018 -0700
Compression support for ORPort descriptor downloads
Compression works the same regardless of whether we download from an ORPort or
DirPort. Also including a couple of Python 3 compatibility fixes for circuit
construction.
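
For context, here is a minimal sketch of how this code path might be exercised from the caller's side: requesting a relay's own descriptor over its ORPort while advertising the compression methods we accept. The address, port, and compression list are placeholders rather than values from this commit, and the exact keyword arguments are assumed from the Query attributes documented further down in the patch.

import stem
import stem.descriptor.remote
from stem.descriptor.remote import Compression

# Hypothetical relay; substitute any reachable ORPort.
query = stem.descriptor.remote.their_server_descriptor(
  endpoints = [stem.ORPort('12.34.56.78', 1100)],
  compression = [Compression.GZIP, Compression.PLAINTEXT],
)

# Query.run() blocks until the download finishes, yielding parsed descriptors.
for desc in query.run():
  print(desc.nickname)
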
---
stem/client/__init__.py | 4 +-
stem/descriptor/remote.py | 88 +++++++++++++++++++++++++++---------------
test/unit/descriptor/remote.py | 10 ++---
3 files changed, 62 insertions(+), 40 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index aa4aa274..6e25f748 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -168,7 +168,7 @@ class Relay(object):
if not created_fast_cells:
raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
- created_fast_cell = created_fast_cells[0]
+ created_fast_cell = list(created_fast_cells)[0]
kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
if created_fast_cell.derivative_key != kdf.key_hash:
@@ -211,7 +211,7 @@ class Circuit(object):
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
- ctr = modes.CTR(ZERO * (algorithms.AES.block_size / 8))
+ ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8))
self.relay = relay
self.id = circ_id
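
The // change above matters because Python 3's true division returns a float, and multiplying a bytes sequence by a float raises a TypeError (Python 2's integer division masked this). A quick standalone illustration, independent of the patch:

ZERO = b'\x00'

# AES has a 128 bit block size, so this builds a sixteen byte zero nonce.
print(len(ZERO * (128 // 8)))  # 16 under both Python 2 and 3

try:
  ZERO * (128 / 8)  # 128 / 8 is the float 16.0 under Python 3
except TypeError as exc:
  print(exc)  # can't multiply sequence by non-int of type 'float'
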
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e745ce20..d6832eaa 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -254,7 +254,7 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg
return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args)
-def _download_from_orport(endpoint, resource):
+def _download_from_orport(endpoint, compression, resource):
"""
Downloads descriptors from the given orport. Payload is just like an http
response (headers and all)...
@@ -273,6 +273,7 @@ def _download_from_orport(endpoint, resource):
... rest of the descriptor content...
:param stem.ORPort endpoint: endpoint to download from
+ :param list compression: compression methods for the request
:param str resource: descriptor resource to download
:returns: two value tuple of the form (data, reply_headers)
@@ -286,26 +287,30 @@ def _download_from_orport(endpoint, resource):
with stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
with relay.create_circuit() as circ:
+ request = '\r\n'.join((
+ 'GET %s HTTP/1.0' % resource,
+ 'Accept-Encoding: %s' % ', '.join(compression),
+ 'User-Agent: Stem/%s' % stem.__version__,
+ )) + '\r\n\r\n'
+
circ.send('RELAY_BEGIN_DIR', stream_id = 1)
- lines = b''.join([cell.data for cell in circ.send('RELAY_DATA', 'GET %s HTTP/1.0\r\n\r\n' % resource, stream_id = 1)]).splitlines()
- first_line = lines.pop(0)
+ response = b''.join([cell.data for cell in circ.send('RELAY_DATA', request, stream_id = 1)])
+ first_line, data = response.split(b'\r\n', 1)
+ header_data, data = data.split(b'\r\n\r\n', 1)
- if first_line != 'HTTP/1.0 200 OK':
+ if first_line != b'HTTP/1.0 200 OK':
raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % first_line)
headers = {}
- next_line = lines.pop(0)
- while next_line:
- if ': ' not in next_line:
- raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % next_line)
+ for line in str_tools._to_unicode(header_data).splitlines():
+ if ': ' not in line:
+ raise stem.ProtocolError("'%s' is not a HTTP header" % line)
- key, value = next_line.split(': ', 1)
+ key, value = line.split(': ', 1)
headers[key] = value
- next_line = lines.pop(0)
-
- return '\n'.join(lines), headers
+ return _decompress(data, headers.get('Content-Encoding')), headers
def _download_from_dirport(url, compression, timeout):
@@ -334,29 +339,49 @@ def _download_from_dirport(url, compression, timeout):
timeout = timeout,
)
- data = response.read()
- encoding = response.headers.get('Content-Encoding')
+ return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
- # Tor doesn't include compression headers. As such when using gzip we
- # need to include '32' for automatic header detection...
- #
- # https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
- #
- # ... and with zstd we need to use the streaming API.
- if encoding in (Compression.GZIP, 'deflate'):
- data = zlib.decompress(data, zlib.MAX_WBITS | 32)
- elif encoding == Compression.ZSTD and ZSTD_SUPPORTED:
+def _decompress(data, encoding):
+ """
+ Decompresses descriptor data.
+
+ Tor doesn't include compression headers. As such when using gzip we
+ need to include '32' for automatic header detection...
+
+ https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
+
+ ... and with zstd we need to use the streaming API.
+
+ :param bytes data: data we received
+ :param str encoding: 'Content-Encoding' header of the response
+
+ :raises:
+ * **ValueError** if encoding is unrecognized
+ * **ImportError** if missing the decompression module
+ """
+
+ if encoding == Compression.PLAINTEXT:
+ return data.strip()
+ elif encoding in (Compression.GZIP, 'deflate'):
+ return zlib.decompress(data, zlib.MAX_WBITS | 32).strip()
+ elif encoding == Compression.ZSTD:
+ if not ZSTD_SUPPORTED:
+ raise ImportError('Decompressing zstd data requires https://pypi.python.org/pypi/zstandard')
+
output_buffer = io.BytesIO()
with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor:
decompressor.write(data)
- data = output_buffer.getvalue()
- elif encoding == Compression.LZMA and LZMA_SUPPORTED:
- data = lzma.decompress(data)
+ return output_buffer.getvalue().strip()
+ elif encoding == Compression.LZMA:
+ if not LZMA_SUPPORTED:
+ raise ImportError('Decompressing lzma data requires https://docs.python.org/3/library/lzma.html')
- return data.strip(), response.headers
+ return lzma.decompress(data).strip()
+ else:
+ raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
def _guess_descriptor_type(resource):
@@ -476,6 +501,9 @@ class Query(object):
:var list endpoints: :class:`~stem.DirPort` or :class:`~stem.ORPort` of the
authority or mirror we're querying, this uses authorities if undefined
+ :var list compression: list of :data:`stem.descriptor.remote.Compression`
+ we're willing to accept, when none are mutually supported downloads fall
+ back to Compression.PLAINTEXT
:var int retries: number of times to attempt the request if downloading it
fails
:var bool fall_back_to_authority: when retrying request issues the last
@@ -500,11 +528,7 @@ class Query(object):
Following are only applicable when downloading from a
:class:`~stem.DirPort`...
- :var list compression: list of :data:`stem.descriptor.remote.Compression`
- we're willing to accept, when none are mutually supported downloads fall
- back to Compression.PLAINTEXT
:var float timeout: duration before we'll time out our request
-
:var str download_url: last url used to download the descriptor, this is
unset until we've actually made a download attempt
@@ -683,7 +707,7 @@ class Query(object):
endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort):
- self.content, self.reply_headers = _download_from_orport(endpoint, self.resource)
+ self.content, self.reply_headers = _download_from_orport(endpoint, self.compression, self.resource)
elif isinstance(endpoint, stem.DirPort):
self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout)
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 040d4c6c..753681e0 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -12,6 +12,7 @@ import unittest
import stem.descriptor.remote
import stem.prereq
import stem.util.conf
+import stem.util.str_tools
from stem.descriptor.remote import Compression
from test.unit.descriptor import read_resource
@@ -126,10 +127,9 @@ HEADER = '\r\n'.join([
'Content-Encoding: %s',
])
-ORPORT_DESCRIPTOR = 'HTTP/1.0 200 OK\n' + HEADER + '\n\n' + TEST_DESCRIPTOR
-
-def _orport_mock(data):
+def _orport_mock(data, encoding = 'identity'):
+ data = b'HTTP/1.0 200 OK\r\n' + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + data
cells = []
for hunk in [data[i:i + 50] for i in range(0, len(data), 50)]:
@@ -167,7 +167,7 @@ class TestDescriptorDownloader(unittest.TestCase):
# prevent our mocks from impacting other tests
stem.descriptor.remote.SINGLETON_DOWNLOADER = None
- @patch('stem.client.Relay.connect', _orport_mock(ORPORT_DESCRIPTOR))
+ @patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR))
def test_using_orport(self):
"""
Download a descriptor through the ORPort.
@@ -175,7 +175,6 @@ class TestDescriptorDownloader(unittest.TestCase):
reply = stem.descriptor.remote.their_server_descriptor(
endpoints = [stem.ORPort('12.34.56.78', 1100)],
- fall_back_to_authority = False,
validate = True,
)
@@ -191,7 +190,6 @@ class TestDescriptorDownloader(unittest.TestCase):
reply = stem.descriptor.remote.their_server_descriptor(
endpoints = [stem.DirPort('12.34.56.78', 1100)],
- fall_back_to_authority = False,
validate = True,
)
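
One more aside on the gzip handling the _decompress docstring describes: because Tor doesn't include compression headers, zlib needs a wbits value of MAX_WBITS | 32 so it can detect the header format on its own. Below is a minimal standalone sketch of the plaintext and gzip/deflate branches of that dispatch; the helper name and sample data are illustrative, not from the patch.

import zlib

def decompress(data, encoding):
  # Mirrors the plaintext and gzip/deflate branches of _decompress above.
  if encoding in ('identity', None):
    return data
  elif encoding in ('gzip', 'deflate'):
    # MAX_WBITS | 32 enables automatic zlib/gzip header detection.
    return zlib.decompress(data, zlib.MAX_WBITS | 32)
  else:
    raise ValueError("'%s' isn't a recognized type of encoding" % encoding)

print(decompress(zlib.compress(b'example descriptor content'), 'deflate'))
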