[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[minion-cvs] [Halftested] Add support to MMTP for connection padding,...



Update of /home/minion/cvsroot/src/minion/lib/mixminion/server
In directory moria.mit.edu:/tmp/cvs-serv9624/lib/mixminion/server

Modified Files:
	MMTPServer.py 
Log Message:
[Halftested] Add support to MMTP for connection padding, key renegotiation,
and protocol negotiation.  Bump the protocol version to 0.2 (0.1 is still
negotiated for compatibility), since older servers don't receive padding
correctly.


Index: MMTPServer.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/server/MMTPServer.py,v
retrieving revision 1.16
retrieving revision 1.17
diff -u -d -r1.16 -r1.17
--- MMTPServer.py	10 Jan 2003 19:29:29 -0000	1.16
+++ MMTPServer.py	12 Jan 2003 04:27:19 -0000	1.17
@@ -30,7 +30,7 @@
 
 import mixminion._minionlib as _ml
 from mixminion.Common import MixError, MixFatalError, LOG, stringContains
-from mixminion.Crypto import sha1
+from mixminion.Crypto import sha1, getCommonPRNG
 from mixminion.Packet import MESSAGE_LEN, DIGEST_LEN
 
 __all__ = [ 'AsyncServer', 'ListenConnection', 'MMTPServerConnection',
@@ -242,6 +242,7 @@
     #    __terminator: None, or a string which will terminate the current read.
     #    __outbuf: None, or the remainder of the string we're currently
     #           writing.
+    # DOCDOC __servermode
     def __init__(self, sock, tls, serverMode, address=None):
         """Create a new SimpleTLSConnection.
 
@@ -253,6 +254,7 @@
         self.__con = tls
         self.fd = self.__con.fileno()
         self.lastActivity = time.time()
+        self.__serverMode = serverMode
 
         if serverMode:
             self.__state = self.__acceptFn
@@ -304,6 +306,10 @@
         self.__state = self.__writeFn
         self.__server.registerWriter(self)
 
+    def get_num_renegotiations(self):
+        "DOCDOC"
+        return self.__con.get_num_renegotiations()
+
     def __acceptFn(self):
         """Hook to implement server-side handshake."""
         self.__con.accept() #may throw want*
@@ -389,6 +395,19 @@
         if len(out) == 0:
             self.finished()
 
+    def __handshakeFn(self):
+        "DOCDOC"
+        assert not self.__serverMode #DOCDOC
+        self.__con.do_handshake() #may throw want*
+        self.__server.unregister(self)
+        self.finished()
+
+    def startRenegotiate(self):
+        "DOCDOC"
+        self.__con.renegotiate() # Succeeds immediately.
+        self.__state = self.__handshakeFn
+        self.__server.registerBoth(self) #????
+
     def tryTimeout(self, cutoff):
         if self.lastActivity <= cutoff:
             warn("Socket %s to %s timed out", self.fd, self.address)
@@ -475,7 +494,7 @@
 # Implementation for MMTP.
 
 # The protocol string to send.
-PROTOCOL_STRING      = "MMTP 0.1\r\n"
+PROTOCOL_STRING      = "MMTP 0.1,0.2\r\n"
 # The protocol specification to expect.
 PROTOCOL_RE          = re.compile("MMTP ([^\s\r\n]+)\r\n")
 # Control line for sending a message.
@@ -496,6 +515,9 @@
     # messageConsumer: a function to call with all received messages.
     # finished: callback when we're done with a read or write; see
     #     SimpleTLSConnection.
+    # DOCDOC protocol
+    # DOCDOC nrenegotiations: renegotiations when last message received.
+    PROTOCOL_VERSIONS = [ '0.2', '0.1' ]
     def __init__(self, sock, tls, consumer):
         """Create an MMTP connection to receive messages sent along a given
            socket.  When valid packets are received, pass them to the
@@ -504,6 +526,7 @@
                                      "%s:%s"%sock.getpeername())
         self.messageConsumer = consumer
         self.finished = self.__setupFinished
+        self.protocol = None
 
     def __setupFinished(self):
         """Called once we're done accepting.  Begins reading the protocol
@@ -511,6 +534,7 @@
         """
         self.finished = self.__receivedProtocol
         self.expectRead(1024, '\n')
+        self.nRenegotiations = self.get_num_renegotiations()
 
     def __receivedProtocol(self):
         """Called once we're done reading the protocol string.  Either
@@ -521,19 +545,24 @@
         m = PROTOCOL_RE.match(inp)
 
         if not m:
-            warn("Bad protocol list.  Closing connection to %s", self.address)
-            self.shutdown(err=1)
-            return
-        protocols = m.group(1).split(",")
-        if "0.1" not in protocols:
-            warn("Unsupported protocol list.  Closing connection to %s",
+            warn("Bad protocol list: %r.  Closing connection to %s", inp,
                  self.address)
             self.shutdown(err=1)
             return
-        else:
-            trace("protocol ok (fd %s)", self.fd)
-            self.finished = self.__sentProtocol
-            self.beginWrite(PROTOCOL_STRING)
+        protocols = m.group(1).split(",")
+        for p in self.PROTOCOL_VERSIONS:
+            if p in protocols:
+                trace("Using protocol %s with %s (fd %s)",
+                      p, self.address, self.fd)
+                self.protocol = p
+                self.finished = self.__sentProtocol
+                self.beginWrite("MMTP %s\r\n"% p)
+                return
+            
+        warn("Unsupported protocol list.  Closing connection to %s",
+             self.address)
+        self.shutdown(err=1)
+        return
 
     def __sentProtocol(self):
         """Called once we're done sending our protocol response.  Begins
@@ -553,9 +582,11 @@
         if data.startswith(JUNK_CONTROL):
             expectedDigest = sha1(msg+"JUNK")
             replyDigest = sha1(msg+"RECEIVED JUNK")
+            isJunk = 1
         elif data.startswith(SEND_CONTROL):
             expectedDigest = sha1(msg+"SEND")
             replyDigest = sha1(msg+"RECEIVED")
+            isJunk = 0
         else:
             warn("Unrecognized command from %s.  Closing connection.",
                  self.address)
@@ -567,11 +598,13 @@
             self.shutdown(err=1)
             return
         else:
+            self.nRenegotiations = self.get_num_renegotiations()
             debug("%s packet received from %s; Checksum valid.",
                   data[:4], self.address)
             self.finished = self.__sentAck
             self.beginWrite(RECEIVED_CONTROL+replyDigest)
-            self.messageConsumer(msg)
+            if not isJunk:
+                self.messageConsumer(msg)
 
     def __sentAck(self):
         """Called once we're done sending an ACK.  Begins reading a new
@@ -590,6 +623,8 @@
     ## Fields:
     # ip, port, keyID, messageList, handleList, sendCallback, failCallback:
     #   As described in the docstring for __init__ below.
+    # DOCDOC protocol
+    PROTOCOL_VERSIONS = [ '0.1', '0.2' ]
     def __init__(self, context, ip, port, keyID, messageList, handleList,
                  sentCallback=None, failCallback=None):
         """Create a connection to send messages to an MMTP server.
@@ -599,6 +634,7 @@
            port -- The port to connect to.
            keyID -- None, or the expected SHA1 hash of the server's public key
            messageList -- a list of message payloads to send
+               DOCDOC, or 'JUNK', or 'RENEGOTIATE'.
            handleList -- a list of objects corresponding to the messages in
               messageList.  Used for callback.
            sentCallback -- None, or a function of (msg, handle) to be called
@@ -606,6 +642,12 @@
            failCallback -- None, or a function of (msg, handle, retriable)
               to be called when messages can't be sent."""
 
+        # Generate junk before connecting to avoid timing attacks
+        self.junk = [] #XXXX doc this field.
+        for m in messageList:
+            if m == 'JUNK':
+                self.junk.append(getCommonPRNG().getBytes(MESSAGE_LEN))
+
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setblocking(0)
         self.keyID = keyID
@@ -626,6 +668,7 @@
         self.finished = self.__setupFinished
         self.sentCallback = sentCallback
         self.failCallback = failCallback
+        self.protocol = None
 
         debug("Opening client connection (fd %s)", self.fd)
 
@@ -644,7 +687,7 @@
             else:
                 debug("KeyID from %s is valid", self.address)
 
-        self.beginWrite(PROTOCOL_STRING)
+        self.beginWrite("MMTP %s\r\n"%(",".join(self.PROTOCOL_VERSIONS)))
         self.finished = self.__sentProtocol
 
     def __sentProtocol(self):
@@ -659,14 +702,19 @@
            sending a packet, or exits if we're done sending.
         """
         inp = self.getInput()
-        if inp != PROTOCOL_STRING:
-            warn("Invalid protocol.  Closing connection to %s", self.address)
-            # This isn't retriable; we don't talk to servers we don't
-            # understand.
-            self.shutdown(err=1,retriable=0)
-            return
+        
+        for p in self.PROTOCOL_VERSIONS:
+            if inp == 'MMTP %s\r\n'%p:
+                trace("Speaking MMTP version %s with %s", p, self.address)
+                self.protocol = inp
+                self.beginNextMessage()
+                return
 
-        self.beginNextMessage()
+        warn("Invalid protocol.  Closing connection to %s", self.address)
+        # This isn't retriable; we don't talk to servers we don't
+        # understand.
+        self.shutdown(err=1,retriable=0)
+        return
 
     def beginNextMessage(self):
         """Start writing a message to the connection."""
@@ -674,10 +722,28 @@
             self.shutdown(0)
             return
         msg = self.messageList[0]
-        self.expectedDigest = sha1(msg+"RECEIVED")
-        msg = SEND_CONTROL+msg+sha1(msg+"SEND")
-        assert len(msg) == SEND_RECORD_LEN
+        if msg == 'RENEGOTIATE':
+            del self.messageList[0]
+            self.finished = self.beginNextMessage
+            self.startRenegotiate()
+            return
+        elif msg == 'JUNK':
+            del self.messageList[0]
+            msg = self.junk[0]
+            del self.junk[0]
+            if self.protocol == '0.1':
+                debug("Won't send junk to a 0.1 server.")
+                self.beginNextMessage()
+                return
+            self.expectedDigest = sha1(msg+"RECEIVED JUNK")
+            msg = JUNK_CONTROL+msg+sha1(msg+"JUNK")
+            self.isJunk = 1 #DOCDOC
+        else:
+            self.expectedDigest = sha1(msg+"RECEIVED")
+            msg = SEND_CONTROL+msg+sha1(msg+"SEND")
+            self.isJunk = 0 #DOCDOC
 
+        assert len(msg) == SEND_RECORD_LEN
         self.beginWrite(msg)
         self.finished = self.__sentMessage
 
@@ -708,12 +774,13 @@
            return
 
        debug("Received valid ACK for message from %s", self.address)
-       justSent = self.messageList[0]
-       justSentHandle = self.handleList[0]
-       del self.messageList[0]
-       del self.handleList[0]
-       if self.sentCallback is not None:
-           self.sentCallback(justSent, justSentHandle)
+       if not self.isJunk:
+           justSent = self.messageList[0]
+           justSentHandle = self.handleList[0]
+           del self.messageList[0]
+           del self.handleList[0]
+           if self.sentCallback is not None:
+               self.sentCallback(justSent, justSentHandle)
 
        self.beginNextMessage()
 
@@ -723,7 +790,7 @@
             for msg, handle in zip(self.messageList, self.handleList):
                 self.failCallback(msg,handle,retriable)
 
-LISTEN_BACKLOG = 10 # ???? Is something else more reasonable?
+LISTEN_BACKLOG = 128
 class MMTPAsyncServer(AsyncServer):
     """A helper class to invoke AsyncServer, MMTPServerConnection, and
        MMTPClientConnection, with a function to add new connections, and