[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]
[minion-cvs] [Halftested] Add support to MMTP for connection padding,...
Update of /home/minion/cvsroot/src/minion/lib/mixminion/server
In directory moria.mit.edu:/tmp/cvs-serv9624/lib/mixminion/server
Modified Files:
MMTPServer.py
Log Message:
[Halftested] Add support to MMTP for connection padding, key renegotiation,
protocol negotiation. Bump protocol version to 0.2, since older servers don't
receive padding correctly.
Index: MMTPServer.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/server/MMTPServer.py,v
retrieving revision 1.16
retrieving revision 1.17
diff -u -d -r1.16 -r1.17
--- MMTPServer.py 10 Jan 2003 19:29:29 -0000 1.16
+++ MMTPServer.py 12 Jan 2003 04:27:19 -0000 1.17
@@ -30,7 +30,7 @@
import mixminion._minionlib as _ml
from mixminion.Common import MixError, MixFatalError, LOG, stringContains
-from mixminion.Crypto import sha1
+from mixminion.Crypto import sha1, getCommonPRNG
from mixminion.Packet import MESSAGE_LEN, DIGEST_LEN
__all__ = [ 'AsyncServer', 'ListenConnection', 'MMTPServerConnection',
@@ -242,6 +242,7 @@
# __terminator: None, or a string which will terminate the current read.
# __outbuf: None, or the remainder of the string we're currently
# writing.
+ # DOCDOC __servermode
def __init__(self, sock, tls, serverMode, address=None):
"""Create a new SimpleTLSConnection.
@@ -253,6 +254,7 @@
self.__con = tls
self.fd = self.__con.fileno()
self.lastActivity = time.time()
+ self.__serverMode = serverMode
if serverMode:
self.__state = self.__acceptFn
@@ -304,6 +306,10 @@
self.__state = self.__writeFn
self.__server.registerWriter(self)
+ def get_num_renegotiations(self):
+ "DOCDOC"
+ return self.__con.get_num_renegotiations()
+
def __acceptFn(self):
"""Hook to implement server-side handshake."""
self.__con.accept() #may throw want*
@@ -389,6 +395,19 @@
if len(out) == 0:
self.finished()
+ def __handshakeFn(self):
+ "DOCDOC"
+ assert not self.__serverMode #DOCDOC
+ self.__con.do_handshake() #may throw want*
+ self.__server.unregister(self)
+ self.finished()
+
+ def startRenegotiate(self):
+ "DOCDOC"
+ self.__con.renegotiate() # Succeeds immediately.
+ self.__state = self.__handshakeFn
+ self.__server.registerBoth(self) #????
+
def tryTimeout(self, cutoff):
if self.lastActivity <= cutoff:
warn("Socket %s to %s timed out", self.fd, self.address)
@@ -475,7 +494,7 @@
# Implementation for MMTP.
# The protocol string to send.
-PROTOCOL_STRING = "MMTP 0.1\r\n"
+PROTOCOL_STRING = "MMTP 0.1,0.2\r\n"
# The protocol specification to expect.
PROTOCOL_RE = re.compile("MMTP ([^\s\r\n]+)\r\n")
# Control line for sending a message.
@@ -496,6 +515,9 @@
# messageConsumer: a function to call with all received messages.
# finished: callback when we're done with a read or write; see
# SimpleTLSConnection.
+ # DOCDOC protocol
+ # DOCDOC nrenegotiations: renegotiations when last message received.
+ PROTOCOL_VERSIONS = [ '0.2', '0.1' ]
def __init__(self, sock, tls, consumer):
"""Create an MMTP connection to receive messages sent along a given
socket. When valid packets are received, pass them to the
@@ -504,6 +526,7 @@
"%s:%s"%sock.getpeername())
self.messageConsumer = consumer
self.finished = self.__setupFinished
+ self.protocol = None
def __setupFinished(self):
"""Called once we're done accepting. Begins reading the protocol
@@ -511,6 +534,7 @@
"""
self.finished = self.__receivedProtocol
self.expectRead(1024, '\n')
+ self.nRenegotiations = self.get_num_renegotiations()
def __receivedProtocol(self):
"""Called once we're done reading the protocol string. Either
@@ -521,19 +545,24 @@
m = PROTOCOL_RE.match(inp)
if not m:
- warn("Bad protocol list. Closing connection to %s", self.address)
- self.shutdown(err=1)
- return
- protocols = m.group(1).split(",")
- if "0.1" not in protocols:
- warn("Unsupported protocol list. Closing connection to %s",
+ warn("Bad protocol list: %r. Closing connection to %s", inp,
self.address)
self.shutdown(err=1)
return
- else:
- trace("protocol ok (fd %s)", self.fd)
- self.finished = self.__sentProtocol
- self.beginWrite(PROTOCOL_STRING)
+ protocols = m.group(1).split(",")
+ for p in self.PROTOCOL_VERSIONS:
+ if p in protocols:
+ trace("Using protocol %s with %s (fd %s)",
+ p, self.address, self.fd)
+ self.protocol = p
+ self.finished = self.__sentProtocol
+ self.beginWrite("MMTP %s\r\n"% p)
+ return
+
+ warn("Unsupported protocol list. Closing connection to %s",
+ self.address)
+ self.shutdown(err=1)
+ return
def __sentProtocol(self):
"""Called once we're done sending our protocol response. Begins
@@ -553,9 +582,11 @@
if data.startswith(JUNK_CONTROL):
expectedDigest = sha1(msg+"JUNK")
replyDigest = sha1(msg+"RECEIVED JUNK")
+ isJunk = 1
elif data.startswith(SEND_CONTROL):
expectedDigest = sha1(msg+"SEND")
replyDigest = sha1(msg+"RECEIVED")
+ isJunk = 0
else:
warn("Unrecognized command from %s. Closing connection.",
self.address)
@@ -567,11 +598,13 @@
self.shutdown(err=1)
return
else:
+ self.nRenegotiations = self.get_num_renegotiations()
debug("%s packet received from %s; Checksum valid.",
data[:4], self.address)
self.finished = self.__sentAck
self.beginWrite(RECEIVED_CONTROL+replyDigest)
- self.messageConsumer(msg)
+ if not isJunk:
+ self.messageConsumer(msg)
def __sentAck(self):
"""Called once we're done sending an ACK. Begins reading a new
@@ -590,6 +623,8 @@
## Fields:
# ip, port, keyID, messageList, handleList, sendCallback, failCallback:
# As described in the docstring for __init__ below.
+ # DOCDOC protocol
+ PROTOCOL_VERSIONS = [ '0.1', '0.2' ]
def __init__(self, context, ip, port, keyID, messageList, handleList,
sentCallback=None, failCallback=None):
"""Create a connection to send messages to an MMTP server.
@@ -599,6 +634,7 @@
port -- The port to connect to.
keyID -- None, or the expected SHA1 hash of the server's public key
messageList -- a list of message payloads to send
+ DOCDOC, or 'JUNK', or 'RENEGOTIATE'.
handleList -- a list of objects corresponding to the messages in
messageList. Used for callback.
sentCallback -- None, or a function of (msg, handle) to be called
@@ -606,6 +642,12 @@
failCallback -- None, or a function of (msg, handle, retriable)
to be called when messages can't be sent."""
+ # Generate junk before connecting to avoid timing attacks
+ self.junk = [] #XXXX doc this field.
+ for m in messageList:
+ if m == 'JUNK':
+ self.junk.append(getCommonPRNG().getBytes(MESSAGE_LEN))
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setblocking(0)
self.keyID = keyID
@@ -626,6 +668,7 @@
self.finished = self.__setupFinished
self.sentCallback = sentCallback
self.failCallback = failCallback
+ self.protocol = None
debug("Opening client connection (fd %s)", self.fd)
@@ -644,7 +687,7 @@
else:
debug("KeyID from %s is valid", self.address)
- self.beginWrite(PROTOCOL_STRING)
+ self.beginWrite("MMTP %s\r\n"%(",".join(self.PROTOCOL_VERSIONS)))
self.finished = self.__sentProtocol
def __sentProtocol(self):
@@ -659,14 +702,19 @@
sending a packet, or exits if we're done sending.
"""
inp = self.getInput()
- if inp != PROTOCOL_STRING:
- warn("Invalid protocol. Closing connection to %s", self.address)
- # This isn't retriable; we don't talk to servers we don't
- # understand.
- self.shutdown(err=1,retriable=0)
- return
+
+ for p in self.PROTOCOL_VERSIONS:
+ if inp == 'MMTP %s\r\n'%p:
+ trace("Speaking MMTP version %s with %s", p, self.address)
+ self.protocol = inp
+ self.beginNextMessage()
+ return
- self.beginNextMessage()
+ warn("Invalid protocol. Closing connection to %s", self.address)
+ # This isn't retriable; we don't talk to servers we don't
+ # understand.
+ self.shutdown(err=1,retriable=0)
+ return
def beginNextMessage(self):
"""Start writing a message to the connection."""
@@ -674,10 +722,28 @@
self.shutdown(0)
return
msg = self.messageList[0]
- self.expectedDigest = sha1(msg+"RECEIVED")
- msg = SEND_CONTROL+msg+sha1(msg+"SEND")
- assert len(msg) == SEND_RECORD_LEN
+ if msg == 'RENEGOTIATE':
+ del self.messageList[0]
+ self.finished = self.beginNextMessage
+ self.startRenegotiate()
+ return
+ elif msg == 'JUNK':
+ del self.messageList[0]
+ msg = self.junk[0]
+ del self.junk[0]
+ if self.protocol == '0.1':
+ debug("Won't send junk to a 0.1 server.")
+ self.beginNextMessage()
+ return
+ self.expectedDigest = sha1(msg+"RECEIVED JUNK")
+ msg = JUNK_CONTROL+msg+sha1(msg+"JUNK")
+ self.isJunk = 1 #DOCDOC
+ else:
+ self.expectedDigest = sha1(msg+"RECEIVED")
+ msg = SEND_CONTROL+msg+sha1(msg+"SEND")
+ self.isJunk = 0 #DOCDOC
+ assert len(msg) == SEND_RECORD_LEN
self.beginWrite(msg)
self.finished = self.__sentMessage
@@ -708,12 +774,13 @@
return
debug("Received valid ACK for message from %s", self.address)
- justSent = self.messageList[0]
- justSentHandle = self.handleList[0]
- del self.messageList[0]
- del self.handleList[0]
- if self.sentCallback is not None:
- self.sentCallback(justSent, justSentHandle)
+ if not self.isJunk:
+ justSent = self.messageList[0]
+ justSentHandle = self.handleList[0]
+ del self.messageList[0]
+ del self.handleList[0]
+ if self.sentCallback is not None:
+ self.sentCallback(justSent, justSentHandle)
self.beginNextMessage()
@@ -723,7 +790,7 @@
for msg, handle in zip(self.messageList, self.handleList):
self.failCallback(msg,handle,retriable)
-LISTEN_BACKLOG = 10 # ???? Is something else more reasonable?
+LISTEN_BACKLOG = 128
class MMTPAsyncServer(AsyncServer):
"""A helper class to invoke AsyncServer, MMTPServerConnection, and
MMTPClientConnection, with a function to add new connections, and