[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [stem/master] Make Relay thread safe
commit 331127483838c416b279d30cf041deb678984ab2
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date: Tue Feb 6 09:53:08 2018 -0800
Make Relay thread safe
---
stem/client/__init__.py | 91 +++++++++++++++++++++++++++----------------------
1 file changed, 51 insertions(+), 40 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 8d34b626..5a9d09e6 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -16,11 +16,18 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
|
|- is_alive - reports if our connection is open or closed
|- connection_time - time when we last connected or disconnected
- +- close - shuts down our connection
+ |- close - shuts down our connection
+ |
+ +- create_circuit - establishes a new circuit
+
+ Circuit - Circuit we've established through a relay.
+ |- send - sends a message through this circuit
+ +- close - closes this circuit
"""
import copy
import hashlib
+import threading
import stem
import stem.client.cell
@@ -47,6 +54,7 @@ class Relay(object):
def __init__(self, orport, link_protocol):
self.link_protocol = link_protocol
self._orport = orport
+ self._orport_lock = threading.RLock()
self._circuits = {}
@staticmethod
@@ -138,40 +146,42 @@ class Relay(object):
:func:`~stem.socket.BaseSocket.close` method.
"""
- return self._orport.close()
+ with self._orport_lock:
+ return self._orport.close()
def create_circuit(self):
"""
Establishes a new circuit.
"""
- # Find an unused circuit id. Since we're initiating the circuit we pick any
- # value from a range that's determined by our link protocol.
+ with self._orport_lock:
+ # Find an unused circuit id. Since we're initiating the circuit we pick any
+ # value from a range that's determined by our link protocol.
- circ_id = 0x80000000 if self.link_protocol > 3 else 0x01
+ circ_id = 0x80000000 if self.link_protocol > 3 else 0x01
- while circ_id in self._circuits:
- circ_id += 1
+ while circ_id in self._circuits:
+ circ_id += 1
- create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
- self._orport.send(create_fast_cell.pack(self.link_protocol))
+ create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
+ self._orport.send(create_fast_cell.pack(self.link_protocol))
- response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol)
- created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response)
+ response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol)
+ created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response)
- if not created_fast_cells:
- raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
+ if not created_fast_cells:
+ raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
- created_fast_cell = created_fast_cells[0]
- kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
+ created_fast_cell = created_fast_cells[0]
+ kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
- if created_fast_cell.derivative_key != kdf.key_hash:
- raise ValueError('Remote failed to prove that it knows our shared key')
+ if created_fast_cell.derivative_key != kdf.key_hash:
+ raise ValueError('Remote failed to prove that it knows our shared key')
- circ = Circuit(self, circ_id, kdf)
- self._circuits[circ.id] = circ
+ circ = Circuit(self, circ_id, kdf)
+ self._circuits[circ.id] = circ
- return circ
+ return circ
def __enter__(self):
return self
@@ -219,30 +229,31 @@ class Circuit(object):
"""
# TODO: move RelayCommand to this base module?
- # TODO: add lock
- orig_digest = self.forward_digest.copy()
- orig_key = copy.copy(self.forward_key)
+ with self.relay._orport_lock:
+ orig_digest = self.forward_digest.copy()
+ orig_key = copy.copy(self.forward_key)
- try:
- cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id)
- payload_without_digest = cell.pack(self.relay.link_protocol)[3:]
- self.forward_digest.update(payload_without_digest)
+ try:
+ cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id)
+ payload_without_digest = cell.pack(self.relay.link_protocol)[3:]
+ self.forward_digest.update(payload_without_digest)
- cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id)
- header, payload = split(cell.pack(self.relay.link_protocol), 3)
- encrypted_payload = header + self.forward_key.update(payload)
+ cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id)
+ header, payload = split(cell.pack(self.relay.link_protocol), 3)
+ encrypted_payload = header + self.forward_key.update(payload)
- self.relay._orport.send(encrypted_payload)
- reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol))
+ self.relay._orport.send(encrypted_payload)
+ reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol))
- decrypted = self.backward_key.update(reply.pack(3)[3:])
- return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3)
- except:
- self.forward_digest = orig_digest
- self.forward_key = orig_key
- raise
+ decrypted = self.backward_key.update(reply.pack(3)[3:])
+ return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3)
+ except:
+ self.forward_digest = orig_digest
+ self.forward_key = orig_key
+ raise
def close(self):
- self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
- del self.relay._circuits[self.id]
+ with self.relay._orport_lock:
+ self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
+ del self.relay._circuits[self.id]
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits