[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [stem/master] Make `Controller.get_hidden_service_descriptor` asynchronous



commit a729b61c61f4c2e136f13a5a921bebfddf1cd4ed
Author: Illia Volochii <illia.volochii@xxxxxxxxx>
Date:   Fri Apr 17 23:07:47 2020 +0300

    Make `Controller.get_hidden_service_descriptor` asynchronous
---
 stem/control.py | 34 ++++++++++++++++++++--------------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/stem/control.py b/stem/control.py
index 84efbf81..fd38871a 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -2007,7 +2007,7 @@ class Controller(BaseController):
       yield desc  # type: ignore
 
   @with_default()
-  def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2:
+  async def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2:
     """
     get_hidden_service_descriptor(address, default = UNDEFINED, servers = None, await_result = True)
 
@@ -2050,23 +2050,25 @@ class Controller(BaseController):
     if not stem.util.tor_tools.is_valid_hidden_service_address(address):
       raise ValueError("'%s.onion' isn't a valid hidden service address" % address)
 
-    hs_desc_queue = queue.Queue()  # type: queue.Queue[stem.response.events.Event]
+    hs_desc_queue = asyncio.Queue()  # type: asyncio.Queue[stem.response.events.Event]
     hs_desc_listener = None
 
-    hs_desc_content_queue = queue.Queue()  # type: queue.Queue[stem.response.events.Event]
+    hs_desc_content_queue = asyncio.Queue()  # type: asyncio.Queue[stem.response.events.Event]
     hs_desc_content_listener = None
 
     start_time = time.time()
 
     if await_result:
-      def hs_desc_listener(event: stem.response.events.Event) -> None:
-        hs_desc_queue.put(event)
+      async def hs_desc_listener(event: stem.response.events.Event) -> None:
+        await hs_desc_queue.put(event)
 
-      def hs_desc_content_listener(event: stem.response.events.Event) -> None:
-        hs_desc_content_queue.put(event)
+      async def hs_desc_content_listener(event: stem.response.events.Event) -> None:
+        await hs_desc_content_queue.put(event)
 
-      self.add_event_listener(hs_desc_listener, EventType.HS_DESC)
-      self.add_event_listener(hs_desc_content_listener, EventType.HS_DESC_CONTENT)
+      await asyncio.gather(
+        self.add_event_listener(hs_desc_listener, EventType.HS_DESC),
+        self.add_event_listener(hs_desc_content_listener, EventType.HS_DESC_CONTENT),
+      )
 
     try:
       request = 'HSFETCH %s' % address
@@ -2074,7 +2076,7 @@ class Controller(BaseController):
       if servers:
         request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
 
-      response = stem.response._convert_to_single_line(self.msg(request))
+      response = stem.response._convert_to_single_line(await self.msg(request))
 
       if not response.is_ok():
         raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code)
@@ -2083,7 +2085,7 @@ class Controller(BaseController):
         return None  # not waiting, so nothing to provide back
       else:
         while True:
-          event = _get_with_timeout(hs_desc_content_queue, timeout, start_time)
+          event = await _get_with_timeout(hs_desc_content_queue, timeout)
 
           if event.address == address:
             if event.descriptor:
@@ -2092,7 +2094,7 @@ class Controller(BaseController):
               # no descriptor, looking through HS_DESC to figure out why
 
               while True:
-                event = _get_with_timeout(hs_desc_queue, timeout, start_time)
+                event = await _get_with_timeout(hs_desc_queue, timeout)
 
                 if event.address == address and event.action == stem.HSDescAction.FAILED:
                   if event.reason == stem.HSDescReason.NOT_FOUND:
@@ -2100,11 +2102,15 @@ class Controller(BaseController):
                   else:
                     raise stem.DescriptorUnavailable('Unable to retrieve the descriptor for %s.onion (retrieved from %s): %s' % (address, event.directory_fingerprint, event.reason))
     finally:
+      awaitable_removals = []
+
       if hs_desc_listener:
-        self.remove_event_listener(hs_desc_listener)
+        awaitable_removals.append(self.remove_event_listener(hs_desc_listener))
 
       if hs_desc_content_listener:
-        self.remove_event_listener(hs_desc_content_listener)
+        awaitable_removals.append(self.remove_event_listener(hs_desc_content_listener))
+
+      await asyncio.gather(*awaitable_removals)
 
   async def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]:
     """



_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits