[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[or-cvs] r9619: Added NodeRestriction and PathRestriction interfaces and imp (torflow/trunk)



Author: mikeperry
Date: 2007-02-22 05:57:42 -0500 (Thu, 22 Feb 2007)
New Revision: 9619

Added:
   torflow/trunk/unit.py
Modified:
   torflow/trunk/TorCtl.py
   torflow/trunk/metatroller.py
   torflow/trunk/soat.pl
Log:
Added NodeRestriction and PathRestriction interfaces and implementations.
Architecture is really flexible, expressive, and modularized. Tested
integration only briefly, but did do some unit tests and pychecker.

Will be reorganizing TorCtl into its own module directory after this to allow
this new path support stuff to live in its own py file, so expect yet more
compatibility breakage before this stabilizes.



Modified: torflow/trunk/TorCtl.py
===================================================================
--- torflow/trunk/TorCtl.py	2007-02-22 08:21:17 UTC (rev 9618)
+++ torflow/trunk/TorCtl.py	2007-02-22 10:57:42 UTC (rev 9619)
@@ -7,6 +7,8 @@
 TorCtl -- Library to control Tor processes.
 """
 
+# XXX: Docstring all exported classes/interfaces. Also need __all__
+
 import os
 import re
 import struct
@@ -19,6 +21,7 @@
 import binascii
 import types
 import time
+import copy
 from TorUtil import *
 
 # Types of "EVENT" message.
@@ -51,6 +54,10 @@
     "Raised when Tor controller returns an error"
     pass
 
+class NodeError(TorCtlError):
+    "Raised when we have no nodes satisfying restrictions"
+    pass
+
 class NetworkStatus:
     "Filled in during NS events"
     def __init__(self, nickname, idhash, orhash, updated, ip, orport, dirport, flags):
@@ -133,20 +140,140 @@
         self.event_name = event_name
         self.event_string = event_string
 
-class NodeSelector:
-    "Interface for node selection policies"
-    def __init__(self, target_ip, target_port):
-        self.to_ip = target_ip
-        self.to_port = target_port
+class PathRestriction:
+    "Interface for path restriction policies"
+    def r_is_ok(self, path, r): return True    
+    def entry_is_ok(self, path, r): return self.r_is_ok(path, r)
+    def middle_is_ok(self, path, r): return self.r_is_ok(path, r)
+    def exit_is_ok(self, path, r): return self.r_is_ok(path, r)
 
+class PathRestrictionList:
+    def __init__(self, restrictions):
+        self.restrictions = restrictions
+    
+    def entry_is_ok(self, path, r):
+        for rs in self.restrictions:
+            if not rs.entry_is_ok(path, r):
+                return False
+        return True
+
+    def middle_is_ok(self, path, r):
+        for rs in self.restrictions:
+            if not rs.middle_is_ok(path, r):
+                return False
+        return True
+
+    def exit_is_ok(self, path, r):
+        for rs in self.restrictions:
+            if not rs.exit_is_ok(path, r):
+                return False
+        return True
+
+    def add_restriction(self, rstr):
+        self.restrictions.append(rstr)
+
+    def del_restriction(self, RestrictionClass):
+        # im_class actually returns current base class, not
+        # implementing class. We abuse this fact here. 
+        # XXX: Is this a standard, or a bug?
+        self.restrictions = filter(
+                lambda r: r.r_is_ok.im_class != RestrictionClass,
+                    self.restrictions)
+
+class NodeRestriction:
+    "Interface for node restriction policies"
+    def r_is_ok(self, r): return True    
+    def reset(self, router_list): pass
+
+class NodeRestrictionList:
+    def __init__(self, restrictions, sorted_r):
+        self.restrictions = restrictions
+        self.update_routers(sorted_r)
+
+    def __check_r(self, r):
+        for rst in self.restrictions:
+            if not rst.r_is_ok(r): return False
+        self.restricted_bw += r.bw
+        return True
+
+    def update_routers(self, sorted_r):
+        self._sorted_r = sorted_r
+        self.restricted_bw = 0
+        for rs in self.restrictions: rs.reset(sorted_r)
+        self.restricted_r = filter(self.__check_r, self._sorted_r)
+
+    def add_restriction(self, restr):
+        self.restrictions.append(restr)
+        for r in self.restricted_r:
+            if not restr.r_is_ok(r):
+                self.restricted_r.remove(r)
+                self.restricted_bw -= r.bw
+    
+    # XXX: This does not collapse And/Or restrictions.. That is non-trivial
+    # in the general case
+    def del_restriction(self, RestrictionClass):
+        self.restrictions = filter(
+                lambda r: r.r_is_ok.im_class != RestrictionClass,
+                    self.restrictions)
+        self.update_routers(self._sorted_r)
+
+
+class NodeGenerator:
+    "Interface for node generation"
+    def __init__(self, restriction_list):
+        self.restriction_list = restriction_list
+        self.rewind()
+
+    def rewind(self):
+        # TODO: Hrmm... Is there any way to handle termination other 
+        # than to make a list of routers that we pop from? Random generators 
+        # will not terminate if no node matches the selector without this..
+        # Not so much an issue now, but in a few years, the Tor network
+        # will be large enough that having all these list copies will
+        # be obscene... Possible candidate for a python list comprehension
+        self.routers = copy.copy(self.restriction_list.restricted_r)
+        self.bw = self.restriction_list.restricted_bw
+
+    def mark_chosen(self, r):
+        self.routers.remove(r)
+        self.bw -= r.bw
+
+    def all_chosen(self):
+        if not self.routers and self.bw or not self.bw and self.routers:
+            plog("WARN", str(len(self.routers))+" routers left but bw="
+                 +str(self.bw))
+        return not self.routers
+
+    def next_r(self): raise NotImplemented()
+
+class PathSelector:
+    "Implementation of path selection policies"
+    def __init__(self, entry_gen, mid_gen, exit_gen, path_restrict):
+        self.entry_gen = entry_gen
+        self.mid_gen = mid_gen
+        self.exit_gen = exit_gen
+        self.path_restrict = path_restrict
+
     def entry_chooser(self, path):
-        raise NotImplemented()
-
+        self.entry_gen.rewind()
+        for r in self.entry_gen.next_r():
+            if self.path_restrict.entry_is_ok(path, r):
+                return r
+        raise NodeError();
+        
     def middle_chooser(self, path):
-        raise NotImplemented()
+        self.mid_gen.rewind()
+        for r in self.mid_gen.next_r():
+            if self.path_restrict.middle_is_ok(path, r):
+                return r
+        raise NodeError();
 
     def exit_chooser(self, path):
-        raise NotImplemented()
+        self.exit_gen.rewind()
+        for r in self.exit_gen.next_r():
+            if self.path_restrict.exit_is_ok(path, r):
+                return r
+        raise NodeError();
 
 class ExitPolicyLine:
     def __init__(self, match, ip_mask, port_low, port_high):
@@ -181,19 +308,35 @@
                 return self.match
         return -1
 
-# XXX: Parse out version and OS
+class RouterVersion:
+    def __init__(self, version):
+        v = re.search("^(\d+).(\d+).(\d+).(\d+)", version).groups()
+        self.version = int(v[0])*0x1000000 + int(v[1])*0x10000 + int(v[2])*0x100 + int(v[3])
+        self.ver_string = version
+
+    def __lt__(self, other): return self.version < other.version
+    def __gt__(self, other): return self.version > other.version
+    def __ge__(self, other): return self.version >= other.version
+    def __le__(self, other): return self.version <= other.version
+    def __eq__(self, other): return self.version == other.version
+    def __ne__(self, other): return self.version != other.version
+    def __str__(self): return self.ver_string
+
 class Router:
-    def __init__(self, idhex, name, bw, exitpolicy, down, guard, valid,
-                 badexit, fast):
+    def __init__(self, idhex, name, bw, down, exitpolicy, flags, ip, version, os):
         self.idhex = idhex
         self.name = name
         self.bw = bw
         self.exitpolicy = exitpolicy
-        self.guard = guard
+        self.guard = "Guard" in flags
+        self.badexit = "BadExit" in flags
+        self.valid = "Valid" in flags
+        self.fast = "Fast" in flags
+        self.flags = flags
         self.down = down
-        self.badexit = badexit
-        self.valid = valid
-        self.fast = fast
+        self.ip = struct.unpack(">I", socket.inet_aton(ip))[0]
+        self.version = RouterVersion(version)
+        self.os = os
 
     def will_exit_to(self, ip, port):
         for line in self.exitpolicy:
@@ -201,7 +344,10 @@
             if ret != -1:
                 return ret
         plog("NOTICE", "No matching exit line for "+self.name)
-        return 0
+        return False
+    
+    def __eq__(self, other): return self.idhex == other.idhex
+    def __ne__(self, other): return self.idhex != other.idhex
 
 class Circuit:
     def __init__(self):
@@ -209,7 +355,10 @@
         self.created_at = 0 # time
         self.path = [] # routers
         self.exit = 0
+    
+    def id_path(self): return map(lambda r: r.idhex, self.path)
 
+
 class Connection:
     """A Connection represents a connection to the Tor process."""
     def __init__(self, sock):
@@ -423,7 +572,7 @@
             self._debugFile.write(">>> %s" % amsg)
         self._s.write(msg)
 
-    def _sendAndRecv(self, msg="", expectedTypes=("250", "251")):
+    def sendAndRecv(self, msg="", expectedTypes=("250", "251")):
         """Helper: Send a command 'msg' to Tor, and wait for a command
            in response.  If the response type is in expectedTypes,
            return a list of (tp,body,extra) tuples.  If it is an
@@ -448,7 +597,7 @@
            method before Tor can start.
         """
         hexstr = binascii.b2a_hex(secret)
-        self._sendAndRecv("AUTHENTICATE %s\r\n"%hexstr)
+        self.sendAndRecv("AUTHENTICATE %s\r\n"%hexstr)
 
     def get_option(self, name):
         """Get the value of the configuration option named 'name'.  To
@@ -458,7 +607,7 @@
         """
         if not isinstance(name, str):
             name = " ".join(name)
-        lines = self._sendAndRecv("GETCONF %s\r\n" % name)
+        lines = self.sendAndRecv("GETCONF %s\r\n" % name)
 
         r = []
         for _,line,_ in lines:
@@ -482,7 +631,7 @@
         if not kvlist:
             return
         msg = " ".join(["%s=%s"%(k,quote(v)) for k,v in kvlist])
-        self._sendAndRecv("SETCONF %s\r\n"%msg)
+        self.sendAndRecv("SETCONF %s\r\n"%msg)
 
     def reset_options(self, keylist):
         """Reset the options listed in 'keylist' to their default values.
@@ -491,18 +640,18 @@
            previous versions wanted you to set configuration keys to "".
            That no longer works.
         """
-        self._sendAndRecv("RESETCONF %s\r\n"%(" ".join(keylist)))
+        self.sendAndRecv("RESETCONF %s\r\n"%(" ".join(keylist)))
 
     def get_network_status(self, who="all"):
         """Get the entire network status list"""
-        return parse_ns_body(self._sendAndRecv("GETINFO ns/"+who+"\r\n")[0][2])
+        return parse_ns_body(self.sendAndRecv("GETINFO ns/"+who+"\r\n")[0][2])
 
     def get_router(self, ns):
         """Fill in a Router class corresponding to a given NS class"""
-        desc = self._sendAndRecv("GETINFO desc/id/" + ns.idhex + "\r\n")[0][2].split("\n")
+        desc = self.sendAndRecv("GETINFO desc/id/" + ns.idhex + "\r\n")[0][2].split("\n")
         line = desc.pop(0)
-        m = re.search(r"^router\s+(\S+)\s+", line)
-        router = m.group(1)
+        m = re.search(r"^router\s+(\S+)\s+(\S+)", line)
+        router,ip = m.groups()
         exitpolicy = []
         dead = not ("Running" in ns.flags)
         bw_observed = 0
@@ -510,22 +659,26 @@
             plog("NOTICE", "Got different names " + ns.nickname + " vs " +
                          router + " for " + ns.idhex)
         for line in desc:
+            pl = re.search(r"^platform Tor (\S+) on (\S+)", line)
             ac = re.search(r"^accept (\S+):([^-]+)(?:-(\d+))?", line)
             rj = re.search(r"^reject (\S+):([^-]+)(?:-(\d+))?", line)
             bw = re.search(r"^bandwidth \d+ \d+ (\d+)", line)
             if re.search(r"^opt hibernating 1", line):
                 dead = 1 # XXX: Technically this may be stale..
+                if ("Running" in ns.flags):
+                    plog("NOTICE", "Hibernating router is running..")
             if ac:
-                exitpolicy.append(ExitPolicyLine(1, *ac.groups()))
+                exitpolicy.append(ExitPolicyLine(True, *ac.groups()))
             elif rj:
-                exitpolicy.append(ExitPolicyLine(0, *rj.groups()))
+                exitpolicy.append(ExitPolicyLine(False, *rj.groups()))
             elif bw:
                 bw_observed = int(bw.group(1))
+            elif pl:
+                version, os = pl.groups()
         if not bw_observed and not dead and ("Valid" in ns.flags):
             plog("NOTICE", "No bandwidth for live router " + ns.nickname)
-        return Router(ns.idhex, ns.nickname, bw_observed, exitpolicy, dead,
-                ("Guard" in ns.flags), ("Valid" in ns.flags),
-                ("BadExit" in ns.flags), ("Fast" in ns.flags))
+        return Router(ns.idhex, ns.nickname, bw_observed, dead, exitpolicy,
+                ns.flags, ip, version, os)
 
     def get_info(self, name):
         """Return the value of the internal information field named 'name'.
@@ -534,7 +687,7 @@
         """
         if not isinstance(name, str):
             name = " ".join(name)
-        lines = self._sendAndRecv("GETINFO %s\r\n"%name)
+        lines = self.sendAndRecv("GETINFO %s\r\n"%name)
         d = {}
         for _,msg,more in lines:
             if msg == "OK":
@@ -556,14 +709,14 @@
         """
         if extended:
             plog ("DEBUG", "SETEVENTS EXTENDED %s\r\n" % " ".join(events))
-            self._sendAndRecv("SETEVENTS EXTENDED %s\r\n" % " ".join(events))
+            self.sendAndRecv("SETEVENTS EXTENDED %s\r\n" % " ".join(events))
         else:
-            self._sendAndRecv("SETEVENTS %s\r\n" % " ".join(events))
+            self.sendAndRecv("SETEVENTS %s\r\n" % " ".join(events))
 
     def save_conf(self):
         """Flush all configuration changes to disk.
         """
-        self._sendAndRecv("SAVECONF\r\n")
+        self.sendAndRecv("SAVECONF\r\n")
 
     def send_signal(self, sig):
         """Send the signal 'sig' to the Tor process; The allowed values for
@@ -574,13 +727,13 @@
                 0x0A : "USR1",
                 0x0C : "USR2",
                 0x0F : "TERM" }.get(sig,sig)
-        self._sendAndRecv("SIGNAL %s\r\n"%sig)
+        self.sendAndRecv("SIGNAL %s\r\n"%sig)
 
     def map_address(self, kvList):
         if not kvList:
             return
         m = " ".join([ "%s=%s" for k,v in kvList])
-        lines = self._sendAndRecv("MAPADDRESS %s\r\n"%m)
+        lines = self.sendAndRecv("MAPADDRESS %s\r\n"%m)
         r = []
         for _,line,_ in lines:
             try:
@@ -597,7 +750,7 @@
         if circid is None:
             circid = "0"
         plog("DEBUG", "Extending circuit")
-        lines = self._sendAndRecv("EXTENDCIRCUIT %d %s\r\n"
+        lines = self.sendAndRecv("EXTENDCIRCUIT %d %s\r\n"
                                   %(circid, ",".join(hops)))
         tp,msg,_ = lines[0]
         m = re.match(r'EXTENDED (\S*)', msg)
@@ -606,46 +759,46 @@
         plog("DEBUG", "Circuit extended")
         return int(m.group(1))
 
-    def build_circuit(self, pathlen, nodesel):
+    def build_circuit(self, pathlen, path_sel):
         circ = Circuit()
         if pathlen == 1:
-            circ.exit = nodesel.exit_chooser(circ.path)
-            circ.path = [circ.exit.idhex]
-            circ.cid = self.extend_circuit(0, circ.path)
+            circ.exit = path_sel.exit_chooser(circ.path)
+            circ.path = [circ.exit]
+            circ.cid = self.extend_circuit(0, circ.id_path())
         else:
-            circ.path.append(nodesel.entry_chooser(circ.path).idhex)
+            circ.path.append(path_sel.entry_chooser(circ.path))
             for i in xrange(1, pathlen-1):
-                circ.path.append(nodesel.middle_chooser(circ.path).idhex)
-            circ.exit = nodesel.exit_chooser(circ.path)
-            circ.path.append(circ.exit.idhex)
-            circ.cid = self.extend_circuit(0, circ.path)
+                circ.path.append(path_sel.middle_chooser(circ.path))
+            circ.exit = path_sel.exit_chooser(circ.path)
+            circ.path.append(circ.exit)
+            circ.cid = self.extend_circuit(0, circ.id_path())
         circ.created_at = datetime.datetime.now()
         return circ
 
     def redirect_stream(self, streamid, newaddr, newport=""):
         """DOCDOC"""
         if newport:
-            self._sendAndRecv("REDIRECTSTREAM %d %s %s\r\n"%(streamid, newaddr, newport))
+            self.sendAndRecv("REDIRECTSTREAM %d %s %s\r\n"%(streamid, newaddr, newport))
         else:
-            self._sendAndRecv("REDIRECTSTREAM %d %s\r\n"%(streamid, newaddr))
+            self.sendAndRecv("REDIRECTSTREAM %d %s\r\n"%(streamid, newaddr))
 
     def attach_stream(self, streamid, circid):
         """DOCDOC"""
         plog("DEBUG", "Attaching stream: "+str(streamid)+" to "+str(circid))
-        self._sendAndRecv("ATTACHSTREAM %d %d\r\n"%(streamid, circid))
+        self.sendAndRecv("ATTACHSTREAM %d %d\r\n"%(streamid, circid))
 
     def close_stream(self, streamid, reason=0, flags=()):
         """DOCDOC"""
-        self._sendAndRecv("CLOSESTREAM %d %s %s\r\n"
+        self.sendAndRecv("CLOSESTREAM %d %s %s\r\n"
                           %(streamid, reason, "".join(flags)))
 
     def close_circuit(self, circid, reason=0, flags=()):
         """DOCDOC"""
-        self._sendAndRecv("CLOSECIRCUIT %d %s %s\r\n"
+        self.sendAndRecv("CLOSECIRCUIT %d %s %s\r\n"
                           %(circid, reason, "".join(flags)))
 
     def post_descriptor(self, desc):
-        self._sendAndRecv("+POSTDESCRIPTOR\r\n%s"%escape_dots(desc))
+        self.sendAndRecv("+POSTDESCRIPTOR\r\n%s"%escape_dots(desc))
 
 def parse_ns_body(data):
     "Parse the body of an NS event or command."
@@ -665,24 +818,25 @@
     def __init__(self):
         """Create a new EventHandler."""
         self._map1 = {
-            "CIRC" : self.circ_status,
-            "STREAM" : self.stream_status,
-            "ORCONN" : self.or_conn_status,
-            "BW" : self.bandwidth,
-            "DEBUG" : self.msg,
-            "INFO" : self.msg,
-            "NOTICE" : self.msg,
-            "WARN" : self.msg,
-            "ERR" : self.msg,
-            "NEWDESC" : self.new_desc,
-            "ADDRMAP" : self.address_mapped,
-            "NS" : self.ns
+            "CIRC" : self.circ_status_event,
+            "STREAM" : self.stream_status_event,
+            "ORCONN" : self.or_conn_status_event,
+            "BW" : self.bandwidth_event,
+            "DEBUG" : self.msg_event,
+            "INFO" : self.msg_event,
+            "NOTICE" : self.msg_event,
+            "WARN" : self.msg_event,
+            "ERR" : self.msg_event,
+            "NEWDESC" : self.new_desc_event,
+            "ADDRMAP" : self.address_mapped_event,
+            "NS" : self.ns_event
             }
 
     def handle1(self, lines):
         """Dispatcher: called from Connection when an event is received."""
         for code, msg, data in lines:
             event = self.decode1(msg, data)
+            self.heartbeat_event()
             self._map1.get(event.event_name, self.unknown_event)(event)
 
     def decode1(self, body, data):
@@ -767,13 +921,19 @@
 
         return event
 
+    def heartbeat_event(self):
+        """Called every time any event is received. Convenience function
+           for any cleanup you may need to do.
+        """
+        pass
+
     def unknown_event(self, event):
         """Called when we get an event type we don't recognize.  This
            is almost alwyas an error.
         """
         raise NotImplemented()
 
-    def circ_status(self, event):
+    def circ_status_event(self, event):
         """Called when a circuit status changes if listening to CIRCSTATUS
            events.  'status' is a member of CIRC_STATUS; circID is a numeric
            circuit ID, and 'path' is the circuit's path so far as a list of
@@ -781,41 +941,41 @@
         """
         raise NotImplemented()
 
-    def stream_status(self, event):
+    def stream_status_event(self, event):
         """Called when a stream status changes if listening to STREAMSTATUS
            events.  'status' is a member of STREAM_STATUS; streamID is a
            numeric stream ID, and 'target' is the destination of the stream.
         """
         raise NotImplemented()
 
-    def or_conn_status(self, event):
+    def or_conn_status_event(self, event):
         """Called when an OR connection's status changes if listening to
            ORCONNSTATUS events. 'status' is a member of OR_CONN_STATUS; target
            is the OR in question.
         """
         raise NotImplemented()
 
-    def bandwidth(self, event):
+    def bandwidth_event(self, event):
         """Called once a second if listening to BANDWIDTH events.  'read' is
            the number of bytes read; 'written' is the number of bytes written.
         """
         raise NotImplemented()
 
-    def new_desc(self, event):
+    def new_desc_event(self, event):
         """Called when Tor learns a new server descriptor if listenting to
            NEWDESC events.
         """
         raise NotImplemented()
 
-    def msg(self, event):
+    def msg_event(self, event):
         """Called when a log message of a given severity arrives if listening
            to INFO_MSG, NOTICE_MSG, WARN_MSG, or ERR_MSG events."""
         raise NotImplemented()
 
-    def ns(self, event):
+    def ns_event(self, event):
         raise NotImplemented()
 
-    def address_mapped(self, event):
+    def address_mapped_event(self, event):
         """Called when Tor adds a mapping for an address if listening
            to ADDRESSMAPPED events.
         """
@@ -824,7 +984,7 @@
 
 class DebugEventHandler(EventHandler):
     """Trivial debug event handler: reassembles all parsed events to stdout."""
-    def circ_status(self, circ_event): # CircuitEvent()
+    def circ_status_event(self, circ_event): # CircuitEvent()
         output = [circ_event.event_name, str(circ_event.circ_id),
                   circ_event.status]
         if circ_event.path:
@@ -835,7 +995,7 @@
             output.append("REMOTE_REASON=" + circ_event.remote_reason)
         print " ".join(output)
 
-    def stream_status(self, strm_event):
+    def stream_status_event(self, strm_event):
         output = [strm_event.event_name, str(strm_event.strm_id),
                   strm_event.status, str(strm_event.circ_id),
                   strm_event.target_host, str(strm_event.target_port)]
@@ -845,16 +1005,16 @@
             output.append("REMOTE_REASON=" + strm_event.remote_reason)
         print " ".join(output)
 
-    def ns(self, ns_event):
+    def ns_event(self, ns_event):
         for ns in ns_event.nslist:
             print " ".join((ns_event.event_name, ns.nickname, ns.idhash,
               ns.updated.isoformat(), ns.ip, str(ns.orport),
               str(ns.dirport), " ".join(ns.flags)))
 
-    def new_desc(self, newdesc_event):
+    def new_desc_event(self, newdesc_event):
         print " ".join((newdesc_event.event_name, " ".join(newdesc_event.idlist)))
    
-    def or_conn_status(self, orconn_event):
+    def or_conn_status_event(self, orconn_event):
         if orconn_event.age: age = "AGE="+str(orconn_event.age)
         else: age = ""
         if orconn_event.read_bytes: read = "READ="+str(orconn_event.read_bytes)
@@ -868,10 +1028,10 @@
         print " ".join((orconn_event.event_name, orconn_event.endpoint,
                         orconn_event.status, age, read, wrote, reason, ncircs))
 
-    def msg(self, log_event):
+    def msg_event(self, log_event):
         print log_event.event_name+" "+log_event.msg
     
-    def bandwidth(self, bw_event):
+    def bandwidth_event(self, bw_event):
         print bw_event.event_name+" "+str(bw_event.read)+" "+str(bw_event.written)
 
 def parseHostAndPort(h):

Modified: torflow/trunk/metatroller.py
===================================================================
--- torflow/trunk/metatroller.py	2007-02-22 08:21:17 UTC (rev 9618)
+++ torflow/trunk/metatroller.py	2007-02-22 10:57:42 UTC (rev 9619)
@@ -14,16 +14,14 @@
 import random
 import datetime
 import threading
+import struct
 from TorUtil import *
 
 routers = {} # indexed by idhex
 name_to_key = {}
 key_to_name = {}
 
-total_r_bw = 0
 sorted_r = []
-sorted_g = []
-total_g_bw = 0
 
 circuits = {} # map from ID # to circuit object
 streams = {} # map from stream id to circuit
@@ -33,7 +31,7 @@
 # TODO: Move these to config file
 # TODO: Option to ignore guard flag
 control_host = "127.0.0.1"
-control_port = 9051
+control_port = 9061
 meta_host = "127.0.0.1"
 meta_port = 9052
 max_detach = 3
@@ -85,132 +83,195 @@
         self.host = host
         self.port = port
 
-# TODO: Obviously we need other node selector implementations
-#  - BwWeightedSelector
-#  - Restrictors (puts self.r_is_ok() into list):
-#    - Subnet16
-#    - AvoidWastingExits
-#    - VersionRange (Less than, greater than, in-range, not-equal)
-#    - OSSelector (ex Yes: Linux, *BSD; No: Windows, Solaris)
-#    - OceanPhobicRestrictor (avoids Pacific Ocean or two atlantic crossings)
-#      or ContinentRestrictor (avoids doing more than N continent crossings)
-#      - Mathematical/empirical study of predecessor expectation
-#        - If middle node is on the same continent as exit, exit learns nothing
-#        - else, exit has a bias on the continent of origin of user
-#          - Language and browser accept string determine this anyway
-#    - ExitCountry
-#    - AllCountry
+# TODO: We still need more path support implementations
+#  - BwWeightedGenerator
+#  - NodeRestrictions:
+#    - Uptime
+#    - GeoIP
+#      - NodeCountry
+#  - PathRestrictions
+#    - Family
+#    - GeoIP:
+#      - OceanPhobicRestrictor (avoids Pacific Ocean or two atlantic crossings)
+#        or ContinentRestrictor (avoids doing more than N continent crossings)
+#        - Mathematical/empirical study of predecessor expectation
+#          - If middle node on the same continent as exit, exit learns nothing
+#          - else, exit has a bias on the continent of origin of user
+#            - Language and browser accept string determine this anyway
 
-class UniformSelector(TorCtl.NodeSelector):
-    "Uniform node selection"
+class PercentileRestriction(TorCtl.NodeRestriction):
+    """If used, this restriction MUST be FIRST in the RestrictionList."""
+    def __init__(self, pct_skip, pct_fast, r_list):
+        self.pct_skip = pct_skip
+        self.pct_fast = pct_fast
+        self.sorted_r = r_list
+        self.position = 0
 
-    next_exit_by_port = {} # class member (aka C++ 'static')
-
-    def __init__(self, host, port):
-        if not port:
-            plog("DEBUG", "Using resolve: "+host+":"+str(resolve_port))
-            port = resolve_port
-        TorCtl.NodeSelector.__init__(self, host, port)
-        self.pct_fast = percent_fast
-        self.pct_skip = percent_skip
-        self.min_bw = min_bw
-        self.order_exits = order_exits
-        self.all_exits = use_all_exits
+    def reset(self, r_list):
+        self.sorted_r = r_list
+        self.position = 0
         
     def r_is_ok(self, r):
-        if r.bw < self.min_bw or not r.valid or not r.fast:
-            return False
-        else:
-            return True
+        ret = True
+        if self.position == len(self.sorted_r):
+            self.position = 0
+            plog("WARN", "Resetting PctFastRestriction")
+        if self.position != self.sorted_r.index(r): # XXX expensive?
+            plog("WARN", "Router"+r.name+" at mismatched index: "
+                         +self.position+" vs "+self.sorted_r.index(r))
+        
+        if self.position < len(self.sorted_r)*self.pct_skip/100:
+            ret = False
+        elif self.position > len(self.sorted_r)*self.pct_fast/100:
+            ret = False
+        
+        self.position += 1
+        return ret
+        
+class OSRestriction(TorCtl.NodeRestriction):
+    def __init__(self, ok, bad=[]):
+        self.ok = ok
+        self.bad = bad
 
-    def pick_r(self, r_list):
-        idx = random.randint(len(r_list)*self.pct_skip/100,
-                             len(r_list)*self.pct_fast/100)
-        return r_list[idx]
+    def r_is_ok(self, r):
+        for y in self.ok:
+            if re.search(y, r.os):
+                return True
+        for b in self.bad:
+            if re.search(b, r.os):
+                return False
+        if self.ok: return False
+        if self.bad: return True
 
-    def entry_chooser(self, path):
-        r = self.pick_r(sorted_g)
-        while not self.r_is_ok(r) or r.idhex in path:
-            r = self.pick_r(sorted_g)
-        return r
+class ConserveExitsRestriction(TorCtl.NodeRestriction):
+    def r_is_ok(self, r): return not "Exit" in r.flags
 
-    def middle_chooser(self, path):
-        r = self.pick_r(sorted_r)
-        while not self.r_is_ok(r) or r.idhex in path:
-            r = self.pick_r(sorted_r)
-        return r
+class FlagsRestriction(TorCtl.NodeRestriction):
+    def __init__(self, mandatory, forbidden=[]):
+        self.mandatory = mandatory
+        self.forbidden = forbidden
 
-    def exit_chooser(self, path):
-        if self.order_exits:
-            if self.to_port not in self.next_exit_by_port or self.next_exit_by_port[self.to_port] >= len(sorted_r):
-                self.next_exit_by_port[self.to_port] = 0
-                
-            r = sorted_r[self.next_exit_by_port[self.to_port]]
-            self.next_exit_by_port[self.to_port] += 1
-            while not r.will_exit_to(self.to_ip, self.to_port):
-                r = sorted_r[self.next_exit_by_port[self.to_port]]
-                self.next_exit_by_port[self.to_port] += 1
-                if self.next_exit_by_port[self.to_port] >= len(sorted_r):
-                    self.next_exit_by_port[self.to_port] = 0
-            return r
+    def r_is_ok(self, router):
+        for m in self.mandatory:
+            if not m in router.flags: return False
+        for f in self.forbidden:
+            if f in router.flags: return False
+        return True
+        
 
-        # FIXME: This should apply to ORDEREXITS (for speedracer?)
-        if self.all_exits:
-            minbw = self.min_bw
-            pct_fast = self.pct_fast
-            pct_skip = self.pct_skip
-            self.min_bw = self.pct_skip = 0
-            self.pct_fast = 100
+class MinBWRestriction(TorCtl.NodeRestriction):
+    def __init__(self, minbw):
+        self.min_bw = minbw
+
+    def r_is_ok(self, router): return router.bw >= self.min_bw
      
-        allowed = []
-        for r in sorted_r:
-            if self.r_is_ok(r) and not r.badexit and r.will_exit_to(self.to_ip, self.to_port):
-                allowed.append(r)
-        r = self.pick_r(allowed)
-        while r.idhex in path:
-            r = self.pick_r(allowed)
+class VersionIncludeRestriction(TorCtl.NodeRestriction):
+    def __init__(self, eq):
+        self.eq = map(TorCtl.RouterVersion, eq)
+    
+    def r_is_ok(self, router):
+        for e in self.eq:
+            if e == router.version:
+                return True
+        return False
 
-        if self.all_exits:
-            self.min_bw = minbw
-            self.pct_fast = pct_fast
-            self.pct_skip = pct_skip
- 
-        return r
 
- 
-def read_routers(c, nslist):
-    bad_key = 0
-    for ns in nslist:
-        try:
-            key_to_name[ns.idhex] = ns.nickname
-            name_to_key[ns.nickname] = ns.idhex
-            r = MetaRouter(c.get_router(ns))
-            if ns.idhex in routers:
-                if routers[ns.idhex].name != r.name:
-                    plog("NOTICE", "Router "+r.idhex+" changed names from "
-                         +routers[ns.idhex].name+" to "+r.name)
-                sorted_r.remove(routers[ns.idhex])
-            routers[ns.idhex] = r
-            sorted_r.append(r)
-        except TorCtl.ErrorReply:
-            bad_key += 1
-            if "Running" in ns.flags:
-                plog("NOTICE", "Running router "+ns.nickname+"="
-                     +ns.idhex+" has no descriptor")
-            pass
-        except:
-            traceback.print_exception(*sys.exc_info())
-            continue
-    sorted_r.sort(lambda x, y: cmp(y.bw, x.bw))
+class VersionExcludeRestriction(TorCtl.NodeRestriction):
+    "Reject routers running any of the listed versions"
+    def __init__(self, exclude):
+        self.exclude = [TorCtl.RouterVersion(v) for v in exclude]
+
+    def r_is_ok(self, router):
+        # Gather the banned versions equal to this router's version
+        banned = [v for v in self.exclude if v == router.version]
+        return len(banned) == 0
 
-    global total_r_bw, total_g_bw # lame....
-    for r in sorted_r:
-        if not r.down:
-            total_r_bw += r.bw
-            if r.guard and r.valid:
-                total_g_bw += r.bw
-                sorted_g.append(r)
+class VersionRangeRestriction(TorCtl.NodeRestriction):
+    "Accept routers with version >= gr_eq and, if given, <= less_eq"
+    def __init__(self, gr_eq, less_eq=None):
+        self.gr_eq = TorCtl.RouterVersion(gr_eq)
+        self.less_eq = None # stays None when no upper bound is passed
+        if less_eq: self.less_eq = TorCtl.RouterVersion(less_eq)
 
+    def r_is_ok(self, router):
+        if self.gr_eq and not router.version >= self.gr_eq: return False
+        return not self.less_eq or router.version <= self.less_eq
+
+class ExitPolicyRestriction(TorCtl.NodeRestriction):
+    "Accept only routers for which will_exit_to(to_ip, to_port) is true"
+    def __init__(self, to_ip, to_port):
+        self.to_ip, self.to_port = to_ip, to_port
+
+    def r_is_ok(self, r):
+        return r.will_exit_to(self.to_ip, self.to_port)
+
+class AndRestriction(TorCtl.NodeRestriction):
+    "Accept a router only if both wrapped restrictions accept it"
+    def __init__(self, a, b):
+        self.a, self.b = a, b
+
+    def r_is_ok(self, r): return self.a.r_is_ok(r) and self.b.r_is_ok(r)
+
+class OrRestriction(TorCtl.NodeRestriction):
+    "Accept a router if at least one wrapped restriction accepts it"
+    def __init__(self, a, b):
+        self.a, self.b = a, b
+
+    def r_is_ok(self, r): return self.a.r_is_ok(r) or self.b.r_is_ok(r)
+
+class NotRestriction(TorCtl.NodeRestriction):
+    "Accept exactly the routers that the wrapped restriction rejects"
+    def __init__(self, a): self.a = a
+
+    def r_is_ok(self, r): return not self.a.r_is_ok(r)
+
+class Subnet16Restriction(TorCtl.PathRestriction):
+    "Reject a router whose ip shares its top 16 bits with a path member"
+    def r_is_ok(self, path, router):
+        mask16 = struct.unpack(">I", socket.inet_aton("255.255.0.0"))[0]
+        wanted = router.ip & mask16
+        # Collect path members whose masked address equals the candidate's
+        clashes = [hop for hop in path if (hop.ip & mask16) == wanted]
+        return not clashes
+
+class UniqueRestriction(TorCtl.PathRestriction): # no router twice in a path
+    def r_is_ok(self, path, r): return r not in path
+
+class UniformGenerator(TorCtl.NodeGenerator):
+    def next_r(self): # generator: random picks until all_chosen() is true
+        while not self.all_chosen():
+            pick = random.choice(self.routers)
+            self.mark_chosen(pick)
+            yield pick
+
+class OrderedExitGenerator(TorCtl.NodeGenerator):
+    "Yield routers in list order, resuming per port where the last lap ended"
+    next_exit_by_port = {} # class member (aka C++ 'static')
+    def __init__(self, restriction_list, to_port):
+        self.to_port = to_port
+        TorCtl.NodeGenerator.__init__(self, restriction_list)
+
+    def rewind(self):
+        TorCtl.NodeGenerator.rewind(self)
+        idx = self.next_exit_by_port.get(self.to_port, 0)
+        if idx >= len(self.routers): idx = 0 # saved index is out of range
+        self.next_exit_by_port[self.to_port] = idx
+        self.last_idx = idx or len(self.routers) # stop after one full lap
+
+    # Selection here is purely positional; chosen-marking is unsupported
+    def mark_chosen(self, r): raise NotImplementedError()
+    def all_chosen(self): raise NotImplementedError()
+
+    def next_r(self):
+        while self.routers: # an empty router list yields nothing
+            if self.next_exit_by_port[self.to_port] >= len(self.routers):
+                self.next_exit_by_port[self.to_port] = 0
+            r = self.routers[self.next_exit_by_port[self.to_port]]
+            self.next_exit_by_port[self.to_port] += 1
+            yield r
+            if self.last_idx == self.next_exit_by_port[self.to_port]:
+                break
+        
 # TODO: Make passive mode so people can get aggregate node reliability 
 # stats for normal usage without us attaching streams
 
@@ -219,7 +280,54 @@
     def __init__(self, c):
         TorCtl.EventHandler.__init__(self)
         self.c = c
+        nslist = c.get_network_status()
+        self.read_routers(nslist)
+        plog("INFO", "Read "+str(len(sorted_r))+"/"+str(len(nslist))+" routers")
+        self.path_rstr = TorCtl.PathRestrictionList(
+                 [Subnet16Restriction(), UniqueRestriction()])
+        self.entry_rstr = TorCtl.NodeRestrictionList(
+            [PercentileRestriction(percent_skip, percent_fast, sorted_r),
+             ConserveExitsRestriction(),
+             FlagsRestriction(["Guard", "Valid", "Running"], [])], sorted_r)
+        self.mid_rstr = TorCtl.NodeRestrictionList(
+            [PercentileRestriction(percent_skip, percent_fast, sorted_r),
+             ConserveExitsRestriction(),
+             FlagsRestriction(["Valid", "Running"], [])], sorted_r)
+        self.exit_rstr = TorCtl.NodeRestrictionList(
+            [PercentileRestriction(percent_skip, percent_fast, sorted_r),
+             FlagsRestriction(["Valid", "Running", "Exit"], ["BadExit"])],
+             sorted_r)
+        self.path_selector = TorCtl.PathSelector(
+             UniformGenerator(self.entry_rstr),
+             UniformGenerator(self.mid_rstr),
+             OrderedExitGenerator(self.exit_rstr, 80), self.path_rstr)
 
+    def read_routers(self, nslist): # merge NS entries into the global tables
+        bad_key = 0 # counts ErrorReply failures; not read again in this method
+        for ns in nslist:
+            try:
+                key_to_name[ns.idhex] = ns.nickname
+                name_to_key[ns.nickname] = ns.idhex
+                r = MetaRouter(self.c.get_router(ns))
+                if ns.idhex in routers:
+                    if routers[ns.idhex].name != r.name:
+                        plog("NOTICE", "Router "+r.idhex+" changed names from "
+                             +routers[ns.idhex].name+" to "+r.name)
+                    sorted_r.remove(routers[ns.idhex]) # drop the stale entry
+                routers[ns.idhex] = r
+                sorted_r.append(r)
+            except TorCtl.ErrorReply: # presumably from get_router -- confirm
+                bad_key += 1
+                if "Running" in ns.flags:
+                    plog("NOTICE", "Running router "+ns.nickname+"="
+                         +ns.idhex+" has no descriptor")
+                pass
+            except: # anything else: print the traceback, go on to the next ns
+                traceback.print_exception(*sys.exc_info())
+                continue
+        # Keep sorted_r ordered by descending bw
+        sorted_r.sort(lambda x, y: cmp(y.bw, x.bw))
+
     def attach_stream_any(self, stream, badcircs):
         # Newnym, and warn if not built plus pending
         unattached_streams = [stream]
@@ -253,9 +361,12 @@
         else:
             circ = None
             while circ == None:
+                self.exit_rstr.del_restriction(ExitPolicyRestriction)
+                self.exit_rstr.add_restriction(
+                     ExitPolicyRestriction(stream.host, stream.port))
                 try:
                     circ = MetaCircuit(self.c.build_circuit(pathlen,
-                                    UniformSelector(stream.host, stream.port)))
+                                    self.path_selector))
                 except TorCtl.ErrorReply, e:
                     # FIXME: How come some routers are non-existant? Shouldn't
                     # we have gotten an NS event to notify us they
@@ -270,7 +381,11 @@
         global last_exit # Last attempted exit
         last_exit = circ.exit
 
-    def circ_status(self, c):
+    def heartbeat_event(self):
+        # XXX: Config updates to selectors
+        pass
+
+    def circ_status_event(self, c):
         output = [c.event_name, str(c.circ_id), c.status]
         if c.path: output.append(",".join(c.path))
         if c.reason: output.append("REASON=" + c.reason)
@@ -292,7 +407,7 @@
                 self.c.attach_stream(stream.sid, c.circ_id)
                 circuits[c.circ_id].used_cnt += 1
 
-    def stream_status(self, s):
+    def stream_status_event(self, s):
         output = [s.event_name, str(s.strm_id), s.status, str(s.circ_id),
                   s.target_host, str(s.target_port)]
         if s.reason: output.append("REASON=" + s.reason)
@@ -365,16 +480,23 @@
                 streams[s.strm_id].host = s.target_host
                 streams[s.strm_id].port = s.target_port
 
-    def ns(self, n):
-        read_routers(self.c, n.nslist)
+
+    def ns_event(self, n):
+        self.read_routers(n.nslist)
         plog("DEBUG", "Read " + str(len(n.nslist))+" NS => " 
              + str(len(sorted_r)) + " routers")
+        self.entry_rstr.update_routers(sorted_r)
+        self.mid_rstr.update_routers(sorted_r)
+        self.exit_rstr.update_routers(sorted_r)
     
-    def new_desc(self, d):
+    def new_desc_event(self, d):
         for i in d.idlist: # Is this too slow?
-            read_routers(self.c, self.c.get_network_status("id/"+i))
+            self.read_routers(self.c.get_network_status("id/"+i))
         plog("DEBUG", "Read " + str(len(d.idlist))+" Desc => " 
              + str(len(sorted_r)) + " routers")
+        self.entry_rstr.update_routers(sorted_r)
+        self.mid_rstr.update_routers(sorted_r)
+        self.exit_rstr.update_routers(sorted_r)
         
 
 def commandloop(s):
@@ -480,10 +602,6 @@
 
 def listenloop(c):
     """Loop that handles metatroller commands"""
-    nslist = c.get_network_status()
-    read_routers(c, nslist)
-    c.set_option("__LeaveStreamsUnattached", "1")
-    plog("INFO", "Read "+str(len(sorted_r))+"/"+str(len(nslist))+" routers")
     srv = ListenSocket(meta_host, meta_port)
     atexit.register(cleanup, *(c, srv))
     while 1:
@@ -493,19 +611,22 @@
         thr.run()
     srv.close()
 
-def main(argv):
+def startup():
     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     s.connect((control_host,control_port))
     c = TorCtl.get_connection(s)
     c.debug(file("control.log", "w"))
+    c.authenticate()
     c.set_event_handler(SnakeHandler(c))
-    c.launch_thread()
-    c.authenticate()
     c.set_events([TorCtl.EVENT_TYPE.STREAM,
                   TorCtl.EVENT_TYPE.NS,
                   TorCtl.EVENT_TYPE.CIRC,
                   TorCtl.EVENT_TYPE.NEWDESC], True)
-    listenloop(c)
+    c.set_option("__LeaveStreamsUnattached", "1")
+    return c
 
+def main(argv):
+    listenloop(startup())
+
 if __name__ == '__main__':
     main(sys.argv)

Modified: torflow/trunk/soat.pl
===================================================================
--- torflow/trunk/soat.pl	2007-02-22 08:21:17 UTC (rev 9618)
+++ torflow/trunk/soat.pl	2007-02-22 10:57:42 UTC (rev 9619)
@@ -17,7 +17,7 @@
 
 #Privoxy is a bad idea since it rewrites shit that will mess with our 
 #baseline md5s of html
-my $SOCKS_PROXY = "127.0.0.1:9050";
+my $SOCKS_PROXY = "127.0.0.1:9060";
 
 my @TO_SCAN = ("ssl");
 my $ALLOW_NEW_SSL_IPS = 1;

Added: torflow/trunk/unit.py
===================================================================
--- torflow/trunk/unit.py	2007-02-22 08:21:17 UTC (rev 9618)
+++ torflow/trunk/unit.py	2007-02-22 10:57:42 UTC (rev 9619)
@@ -0,0 +1,59 @@
+#!/usr/bin/python
+# Metatroller and TorCtl Unit Tests
+
+"""
+Unit tests
+"""
+
+import metatroller
+import copy
+import TorCtl
+c = metatroller.startup()
+
+print "Done!"
+
+# TODO: Tests:
+#  - Test each NodeRestriction and print in/out lines for it
+#  - Test NodeGenerator and reapply NodeRestrictions
+#  - Same for PathSelector and PathRestrictions
+#    - Also Reapply each restriction by hand to path. Verify returns true
+
+def do_unit(rst, r_list):
+    print "\n"
+    print rst.r_is_ok.im_class
+    for r in r_list:
+        print r.name+" "+r.os+" "+str(r.version)+"="+str(rst.r_is_ok(r))
+
+# Work on a copy for thread safety (XXX: hopefully the copy itself is atomic)
+sorted_r = copy.copy(metatroller.sorted_r)
+pct_rst = metatroller.PercentileRestriction(10, 20, sorted_r)
+oss_rst = metatroller.OSRestriction([r"[lL]inux", r"BSD", "Darwin"], [])
+prop_rst = metatroller.OSRestriction([], ["Windows", "Solaris"])
+
+#do_unit(metatroller.VersionRangeRestriction("0.1.2.0"), sorted_r)
+#do_unit(metatroller.VersionRangeRestriction("0.1.2.0", "0.1.2.5"), sorted_r)
+#do_unit(metatroller.VersionIncludeRestriction(["0.1.1.26-alpha"]), sorted_r)
+#do_unit(metatroller.VersionExcludeRestriction(["0.1.1.26"]), sorted_r)
+
+#do_unit(metatroller.ConserveExitsRestriction(), sorted_r)
+
+#do_unit(metatroller.FlagsRestriction([], ["Valid"]), sorted_r)
+
+# TODO: Cross check ns exit flag with this list
+#do_unit(metatroller.ExitPolicyRestriction("255.255.255.255", 25), sorted_r)
+
+#do_unit(pct_rst, sorted_r)
+#do_unit(oss_rst, sorted_r)
+#do_unit(alpha_rst, sorted_r)
+    
+rl =  [metatroller.ExitPolicyRestriction("255.255.255.255", 80), metatroller.OrRestriction(metatroller.ExitPolicyRestriction("255.255.255.255", 443), metatroller.ExitPolicyRestriction("255.255.255.255", 6667)), metatroller.FlagsRestriction([], ["BadExit"])]
+
+exit_rstr = TorCtl.NodeRestrictionList(rl, sorted_r)
+
+ug = metatroller.UniformGenerator(exit_rstr)
+# Every router the generator yields must satisfy every restriction in rl
+for r in ug.next_r():
+    print "Checking: " + r.name
+    for rs in rl:
+        if not rs.r_is_ok(r):
+            raise AssertionError(r.name + " violates a restriction")