[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Rewrite descriptor downloading
commit 448060eabed41b3bad22cc5b0a5b5494f2793816
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Mon Jun 15 16:23:40 2020 -0700
Rewrite descriptor downloading
Using run_in_executor() here has a couple of issues...
1. Executor threads aren't cleaned up. Running our tests with the '--all'
argument concludes with...
Threads lingering after test run:
<_MainThread(MainThread, started 140249831520000)>
<Thread(ThreadPoolExecutor-0_0, started daemon 140249689769728)>
<Thread(ThreadPoolExecutor-0_1, started daemon 140249606911744)>
<Thread(ThreadPoolExecutor-0_2, started daemon 140249586980608)>
<Thread(ThreadPoolExecutor-0_3, started daemon 140249578587904)>
<Thread(ThreadPoolExecutor-0_4, started daemon 140249570195200)>
...
2. Asyncio has its own IO. Wrapping urllib within an executor is easy,
but loses asyncio benefits such as imposing timeouts through
asyncio.wait_for().
Urllib marshals and parses HTTP headers, but we already do that
for ORPort requests, so using a raw asyncio connection actually
lets us deduplicate some code.
Deduplication greatly simplifies testing in that we can mock _download_from()
rather than the raw connection. However, I couldn't adapt our timeout test.
Asyncio's wait_for() works in practice, but no dice when mocked.
---
stem/descriptor/remote.py | 229 ++++++++++++++++-------------------------
test/unit/descriptor/remote.py | 183 ++++++++------------------------
2 files changed, 133 insertions(+), 279 deletions(-)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index c23ab7a9..f1ce79db 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -84,14 +84,11 @@ content. For example...
"""
import asyncio
-import functools
import io
import random
-import socket
import sys
import threading
import time
-import urllib.request
import stem
import stem.client
@@ -313,7 +310,7 @@ class AsyncQuery(object):
:var bool is_done: flag that indicates if our request has finished
:var float start_time: unix timestamp when we first started running
- :var http.client.HTTPMessage reply_headers: headers provided in the response,
+ :var dict reply_headers: headers provided in the response,
**None** if we haven't yet made our request
:var float runtime: time our query took, this is **None** if it's not yet
finished
@@ -330,13 +327,9 @@ class AsyncQuery(object):
:var float timeout: duration before we'll time out our request
:var str download_url: last url used to download the descriptor, this is
unset until we've actually made a download attempt
-
- :param start: start making the request when constructed (default is **True**)
- :param block: only return after the request has been completed, this is
- the same as running **query.run(True)** (default is **False**)
"""
- def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
if not resource.startswith('/'):
raise ValueError("Resources should start with a '/': %s" % resource)
@@ -395,22 +388,15 @@ class AsyncQuery(object):
self._downloader_task = None # type: Optional[asyncio.Task]
self._downloader_lock = threading.RLock()
- self._asyncio_loop = asyncio.get_event_loop()
-
- if start:
- self.start()
-
- if block:
- self.run(True)
-
- def start(self) -> None:
+ async def start(self) -> None:
"""
Starts downloading the descriptors if we haven't started already.
"""
with self._downloader_lock:
if self._downloader_task is None:
- self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout))
+ loop = asyncio.get_running_loop()
+ self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
"""
@@ -434,7 +420,7 @@ class AsyncQuery(object):
async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
with self._downloader_lock:
- self.start()
+ await self.start()
await self._downloader_task
if self.error:
@@ -491,36 +477,71 @@ class AsyncQuery(object):
return random.choice(self.endpoints)
async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
- try:
- self.start_time = time.time()
+ self.start_time = time.time()
+
+ retries = self.retries
+ time_remaining = self.timeout
+
+ while True:
endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort):
downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource)
- self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource)
elif isinstance(endpoint, stem.DirPort):
- self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
- downloaded_from = self.download_url
- self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout)
+ downloaded_from = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
+ self.download_url = downloaded_from
else:
raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
- self.runtime = time.time() - self.start_time
- log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime))
- except:
- exc = sys.exc_info()[1]
+ try:
+ response = await asyncio.wait_for(self._download_from(endpoint), time_remaining)
+ self.content, self.reply_headers = _http_body_and_headers(response)
+
+ self.is_done = True
+ self.runtime = time.time() - self.start_time
+
+ log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime))
+ return
+ except asyncio.TimeoutError as exc:
+ self.is_done = True
+ self.error = stem.DownloadTimeout(downloaded_from, exc, sys.exc_info()[2], self.timeout)
+ return
+ except:
+ exception = sys.exc_info()[1]
+ retries -= 1
+
+ if time_remaining is not None:
+ time_remaining -= time.time() - self.start_time
+
+ if retries > 0:
+ log.debug("Failed to download descriptors from '%s' (%i retries remaining): %s" % (downloaded_from, retries, exception))
+ else:
+ log.debug("Failed to download descriptors from '%s': %s" % (self.download_url, exception))
+
+ self.is_done = True
+ self.error = exception
+ return
- if timeout is not None:
- timeout -= time.time() - self.start_time
+ async def _download_from(self, endpoint: stem.Endpoint) -> bytes:
+ http_request = '\r\n'.join((
+ 'GET %s HTTP/1.0' % self.resource,
+ 'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, self.compression)),
+ 'User-Agent: %s' % stem.USER_AGENT,
+ )) + '\r\n\r\n'
- if retries > 0 and (timeout is None or timeout > 0):
- log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc))
- return await self._download_descriptors(retries - 1, timeout)
- else:
- log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc))
- self.error = exc
- finally:
- self.is_done = True
+ if isinstance(endpoint, stem.ORPort):
+ link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
+
+ async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
+ async with await relay.create_circuit() as circ:
+ return await circ.directory(http_request, stream_id = 1)
+ elif isinstance(endpoint, stem.DirPort):
+ reader, writer = await asyncio.open_connection(endpoint.address, endpoint.port)
+ writer.write(str_tools._to_bytes(http_request))
+
+ return await reader.read()
+ else:
+ raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
class Query(stem.util.AsyncClassWrapper):
@@ -663,8 +684,8 @@ class Query(stem.util.AsyncClassWrapper):
"""
def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
- self._loop = asyncio.get_event_loop()
- self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+ self._loop = asyncio.new_event_loop()
+ self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'query asyncio')
self._loop_thread.setDaemon(True)
self._loop_thread.start()
@@ -677,19 +698,23 @@ class Query(stem.util.AsyncClassWrapper):
retries,
fall_back_to_authority,
timeout,
- start,
- block,
validate,
document_handler,
**kwargs,
)
+ if start:
+ self.start()
+
+ if block:
+ self.run(True)
+
def start(self) -> None:
"""
Starts downloading the descriptors if we haven't started already.
"""
- self._call_async_method_soon('start')
+ self._execute_async_method('start')
def run(self, suppress = False) -> List['stem.descriptor.Descriptor']:
"""
@@ -1146,10 +1171,9 @@ class DescriptorDownloader(object):
return Query(resource, **args)
-async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
+def _http_body_and_headers(data: bytes) -> Tuple[bytes, Dict[str, str]]:
"""
- Downloads descriptors from the given orport. Payload is just like an http
- response (headers and all)...
+ Parse the headers and decompressed body from a HTTP response, such as...
::
@@ -1164,112 +1188,41 @@ async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[ste
identity-ed25519
... rest of the descriptor content...
- :param endpoint: endpoint to download from
- :param compression: compression methods for the request
- :param resource: descriptor resource to download
+ :param data: HTTP response
- :returns: two value tuple of the form (data, reply_headers)
+ :returns: **tuple** with the decompressed data and headers
:raises:
- * :class:`stem.ProtocolError` if not a valid descriptor response
- * :class:`stem.SocketError` if unable to establish a connection
- """
-
- link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
-
- async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
- async with await relay.create_circuit() as circ:
- request = '\r\n'.join((
- 'GET %s HTTP/1.0' % resource,
- 'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)),
- 'User-Agent: %s' % stem.USER_AGENT,
- )) + '\r\n\r\n'
-
- response = await circ.directory(request, stream_id = 1)
- first_line, data = response.split(b'\r\n', 1)
- header_data, body_data = data.split(b'\r\n\r\n', 1)
-
- if not first_line.startswith(b'HTTP/1.0 2'):
- raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line))
-
- headers = {}
-
- for line in str_tools._to_unicode(header_data).splitlines():
- if ': ' not in line:
- raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
-
- key, value = line.split(': ', 1)
- headers[key] = value
-
- return _decompress(body_data, headers.get('Content-Encoding')), headers
-
-
-async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
- """
- Downloads descriptors from the given url.
-
- :param url: dirport url from which to download from
- :param compression: compression methods for the request
- :param timeout: duration before we'll time out our request
-
- :returns: two value tuple of the form (data, reply_headers)
-
- :raises:
- * :class:`~stem.DownloadTimeout` if our request timed out
- * :class:`~stem.DownloadFailed` if our request fails
- """
-
- # TODO: use an asyncronous solution for the HTTP request.
- request = urllib.request.Request(
- url,
- headers = {
- 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
- 'User-Agent': stem.USER_AGENT,
- }
- )
- get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout)
-
- loop = asyncio.get_event_loop()
- try:
- response = await loop.run_in_executor(None, get_response)
- except socket.timeout as exc:
- raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
- except:
- exception, stacktrace = sys.exc_info()[1:3]
- raise stem.DownloadFailed(url, exception, stacktrace)
-
- return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
-
-
-def _decompress(data: bytes, encoding: str) -> bytes:
+ * **stem.ProtocolError** if response was unsuccessful or malformed
+ * **ValueError** if encoding is unrecognized
+ * **ImportError** if missing the decompression module
"""
- Decompresses descriptor data.
- Tor doesn't include compression headers. As such when using gzip we
- need to include '32' for automatic header detection...
+ first_line, data = data.split(b'\r\n', 1)
+ header_data, body_data = data.split(b'\r\n\r\n', 1)
- https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
+ if not first_line.startswith(b'HTTP/1.0 2'):
+ raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line))
- ... and with zstd we need to use the streaming API.
+ headers = {}
- :param data: data we received
- :param encoding: 'Content-Encoding' header of the response
+ for line in str_tools._to_unicode(header_data).splitlines():
+ if ': ' not in line:
+ raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
- :returns: **bytes** with the decompressed data
+ key, value = line.split(': ', 1)
+ headers[key] = value
- :raises:
- * **ValueError** if encoding is unrecognized
- * **ImportError** if missing the decompression module
- """
+ encoding = headers.get('Content-Encoding')
if encoding == 'deflate':
- return stem.descriptor.Compression.GZIP.decompress(data)
+ return stem.descriptor.Compression.GZIP.decompress(body_data), headers
for compression in stem.descriptor.Compression:
if encoding == compression.encoding:
- return compression.decompress(data)
+ return compression.decompress(body_data), headers
- raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
+ raise ValueError("'%s' is an unrecognized encoding" % encoding)
def _guess_descriptor_type(resource: str) -> str:
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 33ee57fb..797bc8a3 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -2,9 +2,6 @@
Unit tests for stem.descriptor.remote.
"""
-import http.client
-import socket
-import time
import unittest
import stem
@@ -67,47 +64,13 @@ HEADER = '\r\n'.join([
])
-def _orport_mock(data, encoding = 'identity', response_code_header = None):
+def mock_download(descriptor, encoding = 'identity', response_code_header = None):
if response_code_header is None:
response_code_header = b'HTTP/1.0 200 OK\r\n'
- data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + data
- cells = []
+ data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + descriptor
- for hunk in [data[i:i + 50] for i in range(0, len(data), 50)]:
- cell = Mock()
- cell.data = hunk
- cells.append(cell)
-
- class AsyncMock(Mock):
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- return
-
- circ_mock = AsyncMock()
- circ_mock.directory.side_effect = coro_func_returning_value(data)
-
- relay_mock = AsyncMock()
- relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock)
-
- return coro_func_returning_value(relay_mock)
-
-
-def _dirport_mock(data, encoding = 'identity'):
- dirport_mock = Mock()
- dirport_mock().read.return_value = data
-
- headers = http.client.HTTPMessage()
-
- for line in HEADER.splitlines():
- key, value = line.split(': ', 1)
- headers.add_header(key, encoding if key == 'Content-Encoding' else value)
-
- dirport_mock().headers = headers
-
- return dirport_mock
+ return patch('stem.descriptor.remote.AsyncQuery._download_from', Mock(side_effect = coro_func_returning_value(data)))
class TestDescriptorDownloader(unittest.TestCase):
@@ -115,10 +78,10 @@ class TestDescriptorDownloader(unittest.TestCase):
# prevent our mocks from impacting other tests
stem.descriptor.remote.SINGLETON_DOWNLOADER = None
- @patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR))
- def test_using_orport(self):
+ @mock_download(TEST_DESCRIPTOR)
+ def test_download(self):
"""
- Download a descriptor through the ORPort.
+ Simply download and parse a descriptor.
"""
reply = stem.descriptor.remote.their_server_descriptor(
@@ -128,10 +91,16 @@ class TestDescriptorDownloader(unittest.TestCase):
)
self.assertEqual(1, len(list(reply)))
- self.assertEqual('moria1', list(reply)[0].nickname)
self.assertEqual(5, len(reply.reply_headers))
- def test_orport_response_code_headers(self):
+ desc = list(reply)[0]
+
+ self.assertEqual('moria1', desc.nickname)
+ self.assertEqual('128.31.0.34', desc.address)
+ self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
+ self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
+
+ def test_response_header_code(self):
"""
When successful Tor provides a '200 OK' status, but we should accept other 2xx
response codes, reason text, and recognize HTTP errors.
@@ -144,14 +113,14 @@ class TestDescriptorDownloader(unittest.TestCase):
)
for header in response_code_headers:
- with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = header)):
+ with mock_download(TEST_DESCRIPTOR, response_code_header = header):
stem.descriptor.remote.their_server_descriptor(
endpoints = [stem.ORPort('12.34.56.78', 1100)],
validate = True,
skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
).run()
- with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n')):
+ with mock_download(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n'):
request = stem.descriptor.remote.their_server_descriptor(
endpoints = [stem.ORPort('12.34.56.78', 1100)],
validate = True,
@@ -160,28 +129,32 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaisesRegexp(stem.ProtocolError, "^Response should begin with HTTP success, but was 'HTTP/1.0 500 Kaboom'", request.run)
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
- def test_using_dirport(self):
- """
- Download a descriptor through the DirPort.
- """
+ @mock_download(TEST_DESCRIPTOR)
+ def test_reply_header_data(self):
+ query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
+ self.assertEqual(None, query.reply_headers) # initially we don't have a reply
+ query.run()
- reply = stem.descriptor.remote.their_server_descriptor(
- endpoints = [stem.DirPort('12.34.56.78', 1100)],
- validate = True,
- skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
- )
+ self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date'))
+ self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))
+ self.assertEqual('97.103.17.56', query.reply_headers.get('X-Your-Address-Is'))
+ self.assertEqual('no-cache', query.reply_headers.get('Pragma'))
+ self.assertEqual('identity', query.reply_headers.get('Content-Encoding'))
- self.assertEqual(1, len(list(reply)))
- self.assertEqual('moria1', list(reply)[0].nickname)
- self.assertEqual(5, len(reply.reply_headers))
+ # request a header that isn't present
+ self.assertEqual(None, query.reply_headers.get('no-such-header'))
+ self.assertEqual('default', query.reply_headers.get('no-such-header', 'default'))
+
+ descriptors = list(query)
+ self.assertEqual(1, len(descriptors))
+ self.assertEqual('moria1', descriptors[0].nickname)
def test_gzip_url_override(self):
query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
self.assertEqual(TEST_RESOURCE, query.resource)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_identity'), encoding = 'identity'))
+ @mock_download(read_resource('compressed_identity'), encoding = 'identity')
def test_compression_plaintext(self):
"""
Download a plaintext descriptor.
@@ -197,7 +170,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_gzip'), encoding = 'gzip'))
+ @mock_download(read_resource('compressed_gzip'), encoding = 'gzip')
def test_compression_gzip(self):
"""
Download a gzip compressed descriptor.
@@ -213,7 +186,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_zstd'), encoding = 'x-zstd'))
+ @mock_download(read_resource('compressed_zstd'), encoding = 'x-zstd')
def test_compression_zstd(self):
"""
Download a zstd compressed descriptor.
@@ -231,7 +204,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_lzma'), encoding = 'x-tor-lzma'))
+ @mock_download(read_resource('compressed_lzma'), encoding = 'x-tor-lzma')
def test_compression_lzma(self):
"""
Download a lzma compressed descriptor.
@@ -249,8 +222,8 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen')
- def test_each_getter(self, dirport_mock):
+ @mock_download(TEST_DESCRIPTOR)
+ def test_each_getter(self):
"""
Surface level exercising of each getter method for downloading descriptors.
"""
@@ -266,57 +239,8 @@ class TestDescriptorDownloader(unittest.TestCase):
downloader.get_bandwidth_file()
downloader.get_detached_signatures()
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
- def test_reply_headers(self):
- query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
- self.assertEqual(None, query.reply_headers) # initially we don't have a reply
- query.run()
-
- self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('date'))
- self.assertEqual('application/octet-stream', query.reply_headers.get('content-type'))
- self.assertEqual('97.103.17.56', query.reply_headers.get('x-your-address-is'))
- self.assertEqual('no-cache', query.reply_headers.get('pragma'))
- self.assertEqual('identity', query.reply_headers.get('content-encoding'))
-
- # getting headers should be case insensitive
- self.assertEqual('identity', query.reply_headers.get('CoNtEnT-ENCODING'))
-
- # request a header that isn't present
- self.assertEqual(None, query.reply_headers.get('no-such-header'))
- self.assertEqual('default', query.reply_headers.get('no-such-header', 'default'))
-
- descriptors = list(query)
- self.assertEqual(1, len(descriptors))
- self.assertEqual('moria1', descriptors[0].nickname)
-
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
- def test_query_download(self):
- """
- Check Query functionality when we successfully download a descriptor.
- """
-
- query = stem.descriptor.remote.Query(
- TEST_RESOURCE,
- 'server-descriptor 1.0',
- endpoints = [stem.DirPort('128.31.0.39', 9131)],
- compression = Compression.PLAINTEXT,
- validate = True,
- skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
- )
-
- self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint())
-
- descriptors = list(query)
- self.assertEqual(1, len(descriptors))
- desc = descriptors[0]
-
- self.assertEqual('moria1', desc.nickname)
- self.assertEqual('128.31.0.34', desc.address)
- self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
- self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
-
- @patch('urllib.request.urlopen', _dirport_mock(b'some malformed stuff'))
- def test_query_with_malformed_content(self):
+ @mock_download(b'some malformed stuff')
+ def test_malformed_content(self):
"""
Query with malformed descriptor content.
"""
@@ -340,29 +264,6 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- @patch('urllib.request.urlopen')
- def test_query_with_timeout(self, dirport_mock):
- def urlopen_call(*args, **kwargs):
- time.sleep(0.06)
- raise socket.timeout('connection timed out')
-
- dirport_mock.side_effect = urlopen_call
-
- query = stem.descriptor.remote.Query(
- TEST_RESOURCE,
- 'server-descriptor 1.0',
- endpoints = [stem.DirPort('128.31.0.39', 9131)],
- fall_back_to_authority = False,
- timeout = 0.1,
- validate = True,
- )
-
- # After two requests we'll have reached our total permissable timeout.
- # It would be nice to check that we don't make a third, but this
- # assertion has proved unreliable so only checking for the exception.
-
- self.assertRaises(stem.DownloadTimeout, query.run)
-
def test_query_with_invalid_endpoints(self):
invalid_endpoints = {
'hello': "'h' is a str.",
@@ -375,7 +276,7 @@ class TestDescriptorDownloader(unittest.TestCase):
expected_error = 'Endpoints must be an stem.ORPort or stem.DirPort. ' + error_suffix
self.assertRaisesWith(ValueError, expected_error, stem.descriptor.remote.Query, TEST_RESOURCE, 'server-descriptor 1.0', endpoints = endpoints)
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
+ @mock_download(TEST_DESCRIPTOR)
def test_can_iterate_multiple_times(self):
query = stem.descriptor.remote.Query(
TEST_RESOURCE,
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits