
[minion-cvs] Many documentation/code cleanups, as suggested by Roger.



Update of /home/minion/cvsroot/src/minion/lib/mixminion
In directory moria.seul.org:/tmp/cvs-serv25507/lib/mixminion

Modified Files:
	BuildMessage.py Common.py HashLog.py MMTPClient.py 
	MMTPServer.py Packet.py PacketHandler.py Queue.py __init__.py 
	test.py 
Log Message:
Many documentation/code cleanups, as suggested by Roger.

Also...

Common.py:
	- Recover from missing /usr/bin/shred.
	- Call waitpid properly. (Both fixes are sketched just below.)
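
For reference, a minimal sketch (not the code in the tree) of both fixes:
degrade gracefully when /usr/bin/shred is absent, and reap finished child
processes without blocking.  The print call stands in for whatever real
logging we end up with.

    import errno
    import os

    _SHRED_CMD = "/usr/bin/shred"
    if not os.path.exists(_SHRED_CMD):
        # Fall back to plain unlink rather than failing outright.
        print("Warning: %s not found; files will not be securely deleted."
              % _SHRED_CMD)
        _SHRED_CMD = None

    def reap_children():
        # Collect exit statuses of finished children.  The patch switches
        # from waitpid(-1, ...) to waitpid(0, ...), i.e. "any child in our
        # process group" instead of "any child at all".
        while 1:
            try:
                pid, status = os.waitpid(0, os.WNOHANG)
                if pid == 0:          # children exist, but none have exited
                    break
            except OSError as e:
                if e.errno == errno.ECHILD:   # no children left to wait for
                    break
                raise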

MMTPServer.py:
	- Handle interrupted select.
	- Call setsockopt correctly. (socket.SOL_SOCKET != 0, no matter what
	  the example code I was reading might have said.)
	- Simplify expectRead by removing its maxBytes argument.
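
A rough sketch (again, not the code in the tree) of the two server-loop
fixes: tolerate a select() interrupted by a signal, and pass
socket.SOL_SOCKET rather than 0 as the level argument.

    import errno
    import select
    import socket

    def make_listener(ip, port, backlog=5):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.setblocking(0)
        # The level must be SOL_SOCKET; 0 is actually IPPROTO_IP.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((ip, port))
        s.listen(backlog)
        return s

    def wait_for_io(readfds, writefds, timeout):
        # Treat an interrupting signal as "nothing became ready".
        # (Recent Python versions retry EINTR automatically; older ones
        # surface it as select.error.)
        try:
            return select.select(readfds, writefds, [], timeout)
        except select.error as e:
            if e.args[0] == errno.EINTR:
                return [], [], []
            raise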

Packet.py:
	- Be a little stricter about reply block length.

PacketHandler.py:
	- A list of private keys requires a list of hash logs.
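
A small sketch of that invariant, as a hypothetical standalone helper (the
real check lives in PacketHandler's constructor): a single key and hash log
get wrapped into one-element tuples, and when sequences are passed there
must be one hash log per key.

    def normalize_keys_and_logs(privatekey, hashlog):
        try:
            _ = privatekey[0]          # sequence of keys, or a single key?
        except TypeError:
            return (privatekey,), (hashlog,)
        if len(hashlog) != len(privatekey):
            raise ValueError("need one hash log per private key")
        return tuple(privatekey), tuple(hashlog)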

Queue.py:
	- Avoid having multiple instances of shred running at once; they
	  seem to step on one another's toes.
	- Add more bits to a handle.
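
A sketch of the "one clean at a time" guard this uses: a timestamped marker
file stakes a claim on the clean, and a stale marker (older than
CLEAN_TIMEOUT) is ignored so a crashed cleaner cannot wedge the queue.  The
helper name here is illustrative, not the actual method.

    import os
    import stat
    import time

    CLEAN_TIMEOUT = 60    # seconds after which an old clean is presumed dead

    def try_start_clean(queue_dir):
        # Return the marker path if we may clean now, or None if another
        # clean started recently.
        marker = os.path.join(queue_dir, ".cleaning")
        now = time.time()
        try:
            age = now - os.stat(marker)[stat.ST_MTIME]
            if age <= CLEAN_TIMEOUT:
                return None       # somebody else is already cleaning
        except OSError:
            pass                  # no marker file: nobody is cleaning
        f = open(marker, "w")
        f.write(str(now))
        f.close()
        return marker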

__init__.py:
	- Make __init__.py act like a regular __init__ file.


test.py:
	- Be a bit more careful about shredding files and closing sockets.

_minionlib.h:
	- Refactor individual METHOD macros into a common declaration.

crypt.c:
	- Replace an impossible error with an assertion.
	
main.c:
	- More comments

tls.c:
	- Better description of SSL_ERROR_SYSCALL




Index: BuildMessage.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/BuildMessage.py,v
retrieving revision 1.8
retrieving revision 1.9
diff -u -d -r1.8 -r1.9
--- BuildMessage.py	27 Jun 2002 23:32:24 -0000	1.8
+++ BuildMessage.py	1 Jul 2002 18:03:05 -0000	1.9
@@ -32,8 +32,8 @@
                          path1=path1, path2=replyBlock)
 
 def buildReplyBlock(path, exitType, exitInfo, expiryTime=0, secretPRNG=None):
-    """Return a newly-constructed reply block and a list of secrets used
-       to make it.
+    """Return a 2-tuple containing (1) a newly-constructed reply block and (2)
+       a list of secrets used to make it.
        
               path: A list of ServerInfo
               exitType: Routing type to use for the final node
@@ -235,7 +235,8 @@
         # Create a subheader object for this node, but don't fill in the
         # digest until we've constructed the rest of the header.
         subhead = Subheader(MAJOR_NO, MINOR_NO,
-                            secrets[i], " "*20,
+                            secrets[i],
+                            None, #placeholder for as-yet-uncalculated digest
                             rt, ri)
 
         extHeaders = "".join(subhead.getExtraBlocks())

Index: Common.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/Common.py,v
retrieving revision 1.5
retrieving revision 1.6
diff -u -d -r1.5 -r1.6
--- Common.py	25 Jun 2002 11:41:08 -0000	1.5
+++ Common.py	1 Jul 2002 18:03:05 -0000	1.6
@@ -55,19 +55,31 @@
 # FFFF This needs to be made portable.
 _SHRED_CMD = "/usr/bin/shred"
 
+if not os.path.exists(_SHRED_CMD):
+    # XXXX use real logging
+    log("Warning: %s not found. Files will not be securely deleted." %
+        _SHRED_CMD)
+    _SHRED_CMD = None
+
 def secureDelete(fnames, blocking=0):
     """Given a list of filenames, removes the contents of all of those
        files, from the disk, 'securely'.  If blocking=1, does not
        return until the remove is complete.  If blocking=0, returns
        immediately, and returns the PID of the process removing the
        files.  (Returns None if this process unlinked the files
-       itself) XXXX Clarify this.
+       itself.) XXXX Clarify this.
 
        XXXX Securely deleting files only does so much good.  Metadata on
        XXXX the file system, such as atime and dtime, can still be used
        XXXX to reconstruct information about message timings.  To be
        XXXX really safe, we should use a loopback device and shred _that_
        XXXX from time to time.
+
+       XXXX Currently, we use shred from GNU fileutils.  Shred's 'unlink'
+       XXXX operation has the regrettable property that two shred commands
+       XXXX running in the same directory can sometimes get into a race.
+       XXXX The source to shred.c seems to imply that this is harmless, but
+       XXXX let's try to avoid that, to be on the safe side. 
     """
     if isinstance(fnames, StringType):
         fnames = [fnames]
@@ -116,11 +128,10 @@
     while 1:
         try:
             # FFFF This won't work on Windows.  What to do?
-            pid, status = os.waitpid(-1, 0)
-        except:
+            pid, status = os.waitpid(0, 0)
+        except OSError, e:
             break
 
-
 def _sigChldHandler(signal_num, _):
     '''(Signal handler for SIGCHLD)'''
     # Because of the peculiarities of Python's signal handling logic, I
@@ -130,7 +141,7 @@
     while 1:
         try:
             # This waitpid call won't work on Windows.  What to do?
-            pid, status = os.waitpid(-1, os.WNOHANG)
+            pid, status = os.waitpid(0, os.WNOHANG)
             if pid == 0:
                 break
         except OSError:

Index: HashLog.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/HashLog.py,v
retrieving revision 1.6
retrieving revision 1.7
diff -u -d -r1.6 -r1.7
--- HashLog.py	27 Jun 2002 23:32:24 -0000	1.6
+++ HashLog.py	1 Jul 2002 18:03:05 -0000	1.7
@@ -19,9 +19,14 @@
        Each HashLog corresponds to a single public key (whose hash is the
        log's keyid).  A HashLog must persist for as long as the key does.
 
-       It is not necessary to sync the HashLog to the disk every time a new
-       message is seen; rather, the HashLog must be synced before any messages
-       are sent to the network.
+       It is not necessary to sync the HashLog to the disk every time
+       a new message is seen; instead, we must only ensure that every
+       _retransmitted_ message is first inserted into the hashlog and
+       synced.  (One way to implement this is to process messages from
+       'state A' into 'state B', marking them in the hashlog as we go,
+       and syncing the hashlog before any message is sent from 'B' to
+       the network.  On a restart, we reinsert all messages waiting in 'B'
+       into the log.)
 
        HashLogs are implemented using Python's anydbm interface.  This defaults
        to using Berkeley DB, GDBM, or --if you have none of these-- a flat
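
The revised HashLog docstring above describes a retransmission-safe
discipline rather than "sync on every insert".  A toy rendering of the
'state A to state B' idea, with hypothetical helper and field names (only
logHash and a sync-to-disk step come from the module itself):

    def promote_to_ready(hashlog, pending, ready):
        # State A -> state B: mark each message in the log as we promote it,
        # and make sure the log is on disk before anything in B is sent.
        for msg in pending:
            hashlog.logHash(msg.replay_hash)
        ready.extend(pending)
        del pending[:]
        hashlog.sync()

    def recover_after_restart(hashlog, ready):
        # Messages already in B may or may not have been logged before a
        # crash; reinserting their hashes is harmless.
        for msg in ready:
            hashlog.logHash(msg.replay_hash)
        hashlog.sync()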

Index: MMTPClient.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/MMTPClient.py,v
retrieving revision 1.2
retrieving revision 1.3
diff -u -d -r1.2 -r1.3
--- MMTPClient.py	25 Jun 2002 11:41:08 -0000	1.2
+++ MMTPClient.py	1 Jul 2002 18:03:05 -0000	1.3
@@ -6,13 +6,17 @@
    side of the Mixminion Transfer protocol.  You can use this client to 
    upload messages to any conforming Mixminion server.
 
-   XXXX (We don't want to use this module for tranferring packets
-   XXXX between servers; once we have async IO working in MMTPServer, we'll
-   XXXX use that.)
+   (We don't use this module for tranferring packets between servers;
+   in fact, MMTPServer makes it redundant.  We only keep this module
+   around [A] so that clients have an easy (blocking) interface to
+   introduce messages into the system, and [B] so that we've got an
+   easy-to-verify reference implementation of the protocol.)
 
    XXXX We don't yet check for the correct keyid.
 
-   XXXX: As yet unsupported are: Session resumption and key renegotiation."""
+   XXXX: As yet unsupported are: Session resumption and key renegotiation.
+
+   XXXX: Also unsupported: timeouts."""
 
 import socket
 import mixminion._minionlib as _ml
@@ -44,7 +48,7 @@
         
         ####
         # Protocol negotiation
-
+        # For now, we only support 1.0
         self.tls.write("PROTOCOL 1.0\n")
         inp = self.tls.read(len("PROTOCOL 1.0\n"))
         if inp != "PROTOCOL 1.0\n":

Index: MMTPServer.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/MMTPServer.py,v
retrieving revision 1.3
retrieving revision 1.4
diff -u -d -r1.3 -r1.4
--- MMTPServer.py	27 Jun 2002 23:32:24 -0000	1.3
+++ MMTPServer.py	1 Jul 2002 18:03:05 -0000	1.4
@@ -10,8 +10,20 @@
    If you just want to send messages into the system, use MMTPClient.
 
    XXXX As yet unsupported are: Session resumption, key renegotiation,
-   XXXX checking KeyID."""
+   XXXX checking KeyID.
+
+   XXXX: Also unsupported: timeouts."""
+
+# NOTE FOR THE CURIOUS: The 'asyncore' module in the standard library
+#    is another general select/poll wrapper... so why are we using our
+#    own?  Basically, because asyncore has IMO a couple of misdesigns,
+#    the largest of which is that it has the 'server loop' periodically
+#    query the connections for their status, whereas we have the
+#    connections inform the server of their status whenever they
+#    change.  This latter approach turns out to be far easier to use
+#    with TLS.
 
+import errno
 import socket
 import select
 import re
@@ -35,7 +47,7 @@
        Connection objects that are waiting for reads and writes
        (respectively), and waits for their underlying sockets to be
        available for the desired operations.
-       """       
+       """
     def __init__(self):
         """Create a new AsyncServer with no readers or writers."""
         self.writers = {}
@@ -45,14 +57,24 @@
         """If any relevant file descriptors become available within
            'timeout' seconds, call the appropriate methods on their
            connections and return immediately after. Otherwise, wait
-           'timeout' seconds and return."""
+           'timeout' seconds and return.
+
+           If we receive an unblocked signal, return immediately.
+           """
 
         debug("%s readers, %s writers" % (len(self.readers),
                                           len(self.writers)))
         
         readfds = self.readers.keys()
         writefds = self.writers.keys()
-        readfds, writefds, exfds = select.select(readfds, writefds,[], timeout)
+        try:
+            readfds,writefds,exfds = select.select(readfds,writefds,[],timeout)
+        except select.error, e:
+            if e[0] == errno.EINTR:
+                return
+            else:
+                raise e
+        
         for fd in readfds:
             debug("Got a read on "+str(fd))
             self.readers[fd].handleRead()
@@ -126,7 +148,7 @@
         self.port = port
         self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self.sock.setblocking(0)
-        self.sock.setsockopt(0, socket.SO_REUSEADDR, 1)
+        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         self.sock.bind((self.ip, self.port))
         self.sock.listen(backlog)
         # FFFF LOG
@@ -175,8 +197,6 @@
     #    __inbuflen: The total length of all the strings in __inbuf
     #    __expectReadLen: None, or the number of bytes to read before
     #           the current read succeeds.
-    #    __maxReadLen: None, or a number of bytes above which the current
-    #           read must fail.
     #    __terminator: None, or a string which will terminate the current read.
     #    __outbuf: None, or the remainder of the string we're currently
     #           writing.
@@ -208,24 +228,20 @@
             assert self.__state == self.__connectFn
             server.registerWriter(self)
         
-    def expectRead(self, bytes=None, bytesMax=None, terminator=None):
+    def expectRead(self, bytes=None, terminator=None):
         """Begin reading from the underlying TLS connection.
 
            After the read is finished, this object's finished method
            is invoked.  A call to 'getInput' will retrieve the contents
            of the input buffer since the last call to 'expectRead'.
 
-           bytes -- If provided, a number of bytes to read before
-                    exiting the read state.
-           bytesMax -- If provided, a maximal number of bytes to read
-                    before exiting the read state.
-           terminator -- If provided, a character sequence to read
-                    before exiting the read state.
+           If 'terminator' is not provided, we try to read exactly
+           'bytes' bytes.  If terminator is provided, we read until we
+           encounter the terminator, but give up after 'bytes' bytes.
         """
         self.__inbuf = []
         self.__inbuflen = 0
         self.__expectReadLen = bytes
-        self.__maxReadLen = bytesMax
         self.__terminator = terminator
 
         self.__state = self.__readFn
@@ -283,7 +299,7 @@
         if self.__terminator and len(self.__inbuf) > 1:
             self.__inbuf = ["".join(self.__inbuf)]
 
-        if self.__maxReadLen and self.__inbuflen > self.__maxReadLen:
+        if self.__expectReadLen and self.__inbuflen > self.__expectReadLen:
             debug("Read got too much.")
             self.shutdown(err=1)
             return
@@ -293,7 +309,7 @@
             self.__server.unregister(self)
             self.finished()
 
-        if self.__expectReadLen and (self.__inbuflen >= self.__expectReadLen):
+        if self.__expectReadLen and (self.__inbuflen == self.__expectReadLen):
             debug("read got enough.")
             self.__server.unregister(self)
             self.finished()
@@ -395,7 +411,7 @@
            string.
         """
         self.finished = self.__receivedProtocol
-        self.expectRead(None, 1024, '\n')
+        self.expectRead(1024, '\n')
 
     def __receivedProtocol(self):
         """Called once we're done reading the protocol string.  Either
@@ -419,7 +435,7 @@
         """
         debug("done w/ server sendproto")
         self.finished = self.__receivedMessage
-        self.expectRead(SEND_RECORD_LEN, SEND_RECORD_LEN)
+        self.expectRead(SEND_RECORD_LEN)
 
     def __receivedMessage(self):
         """Called once we've read a message from the line.  Checks the
@@ -444,7 +460,7 @@
         debug("done w/ send ack")
         #XXXX Rehandshake
         self.finished = self.__receivedMessage
-        self.expectRead(SEND_RECORD_LEN, SEND_RECORD_LEN)
+        self.expectRead(SEND_RECORD_LEN)
 
 #----------------------------------------------------------------------
         
@@ -481,7 +497,7 @@
         """Called when we're done sending the protocol string.  Begins
            reading the server's response.
         """
-        self.expectRead(len(PROTOCOL_STRING), len(PROTOCOL_STRING))
+        self.expectRead(1024, '\n')
         self.finished = self.__receivedProtocol
 
     def __receivedProtocol(self):
@@ -517,7 +533,8 @@
     def __receivedAck(self):
        """Called when we're done reading the ACK.  If the ACK is bad,
           closes the connection.  If the ACK is correct, removes the
-          just-sent message from the queue, and calls sentCallback.
+          just-sent message from the connection's internal queue, and
+          calls sentCallback.
 
           If there are more messages to send, begins sending the next.
           Otherwise, begins shutting down.

Index: Packet.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/Packet.py,v
retrieving revision 1.4
retrieving revision 1.5
diff -u -d -r1.4 -r1.5
--- Packet.py	27 Jun 2002 23:32:24 -0000	1.4
+++ Packet.py	1 Jul 2002 18:03:05 -0000	1.5
@@ -259,6 +259,7 @@
        on the path, and the RoutingType and RoutingInfo for the server."""
     def __init__(self, header, useBy, rt, ri):
         """Construct a new Reply Block."""
+        assert len(header) == HEADER_LEN
         self.header = header
         self.timestamp = useBy
         self.routingType = rt

Index: PacketHandler.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/PacketHandler.py,v
retrieving revision 1.4
retrieving revision 1.5
diff -u -d -r1.4 -r1.5
--- PacketHandler.py	27 Jun 2002 23:32:24 -0000	1.4
+++ PacketHandler.py	1 Jul 2002 18:03:05 -0000	1.5
@@ -27,17 +27,20 @@
 
            A sequence of private keys may be provided, if you'd like the
            server to accept messages encrypted with any of them.  Beware,
-           though: this slows down the packet handler a lot.
+           though: PK decryption is expensive.  Also, a hashlog must be
+           provided for each private key.
         """
         # ???? Any way to support multiple keys in protocol?
         try:
             # Check whether we have a key or a tuple of keys.
             _ = privatekey[0]
+            assert len(hashlog) == len(privatekey)
+            
             self.privatekey = privatekey
+            self.hashlog = hashlog
         except:
             self.privatekey = (privatekey, )
-
-        self.hashlog = hashlog
+            self.hashlog = (hashlog, )
 
     def processMessage(self, msg):
         """Given a 32K mixminion message, processes it completely.
@@ -67,20 +70,24 @@
         # order.  Only fail if all private keys fail.
         subh = None
         e = None
-        for pk in self.privatekey:
+        for pk, hashlog in zip(self.privatekey, self.hashlog):
             try:
                 subh = Crypto.pk_decrypt(header1[0], pk)
+                break
             except Crypto.CryptoError, err:
                 e = err
         if not subh:
+            # Nobody managed to get us the first subheader.  Raise the
+            # most-recently-received error.
             raise e
-        subh = Packet.parseSubheader(subh)
+
+        subh = Packet.parseSubheader(subh) #may raise ParseError
 
         # Check the version: can we read it?
         if subh.major != Packet.MAJOR_NO or subh.minor != Packet.MINOR_NO:
             raise ContentError("Invalid protocol version")
 
-        # Check the digest: is it correct?
+        # Check the digest of all of header1 but the first subheader.
         if subh.digest != Crypto.sha1(header1[1:]):
             raise ContentError("Invalid digest")
 
@@ -89,10 +96,10 @@
 
         # Replay prevention
         replayhash = keys.get(Crypto.REPLAY_PREVENTION_MODE, 20)
-        if self.hashlog.seenHash(replayhash):
+        if hashlog.seenHash(replayhash):
             raise ContentError("Duplicate message detected.")
         else:
-            self.hashlog.logHash(replayhash)
+            hashlog.logHash(replayhash)
 
         # If we're meant to drop, drop now.
         rt = subh.routingtype
@@ -112,7 +119,7 @@
                 # size can be longer than the number of bytes in the rest
                 # of the header.
                 raise ContentError("Impossibly long routing info length")
-                
+
             extra = Crypto.ctr_crypt(header1[1:1+nExtra], header_sec_key)
             subh.appendExtraBlocks(extra)
             remainingHeader = header1[1+nExtra:]

Index: Queue.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/Queue.py,v
retrieving revision 1.2
retrieving revision 1.3
diff -u -d -r1.2 -r1.3
--- Queue.py	25 Jun 2002 11:41:08 -0000	1.2
+++ Queue.py	1 Jul 2002 18:03:05 -0000	1.3
@@ -26,6 +26,10 @@
 # trash.
 INPUT_TIMEOUT = 600
 
+# If we've been cleaning for more than CLEAN_TIMEOUT seconds, assume the 
+# old clean is dead.
+CLEAN_TIMEOUT = 60
+
 class Queue:
     """A Queue is an unordered collection of files with secure remove and
        move operations.
@@ -37,9 +41,18 @@
              rmv_HANDLE   (A message waiting to be deleted)
              msg_HANDLE  (A message waiting in the queue.
              inp_HANDLE  (An incomplete message being created.)
-       (Where HANDLE is a randomly chosen 8-character selection from the
+       (Where HANDLE is a randomly chosen 12-character selection from the
        characters 'A-Za-z0-9+-'.  [Collision probability is negligable.])
        """
+       # How negligible?  A back-of-the-envelope approximation: The chance
+       # of a collision reaches .1% when you have 3e9 messages in a single
+       # queue.  If Alice somehow manages to accumulate a 96 gigabyte
+       # backlog, we'll have bigger problems than name collision... such
+       # as the fact that most Unices behave badly when confronted with
+       # 3 billion files in the same directory... or the fact that,
+       # at today's processor speeds, it will take Alice 3 or 4
+       # CPU-years to clear her backlog. 
+
     # Fields:   rng--a random number generator for creating new messages
     #                and getting a random slice of the queue.
     #           dir--the location of the queue.
@@ -63,7 +76,7 @@
                 os.mkdir(location, 0700)
             else:
                 raise MixFatalError("No directory for queue %s" % location)
-
+ 
         # Check permissions
         mode = os.stat(location)[stat.ST_MODE]
         if mode & 0077:
@@ -118,7 +131,6 @@
     def removeMessage(self, handle):
         """Given a handle, removes the corresponding message from the queue."""
         self.__changeState(handle, "msg", "rmv")
-        secureDelete(os.path.join(self.dir, "rmv_"+handle))
 
     def removeAll(self):
         """Removes all messages from this queue."""
@@ -126,8 +138,10 @@
         for m in os.listdir(self.dir):
             if m[:4] in ('inp_', 'msg_'):
                 self.__changeState(m[4:], m[:3], "rmv")
-                removed.append(os.path.join(self.dir, "rmv_"+m[4:]))
-        secureDelete(removed)
+                #removed.append(os.path.join(self.dir, "rmv_"+m[4:]))
+        #    elif m[:4] == 'rmv_':
+        #        removed.append(self.dir)
+        self.cleanQueue()
 
     def moveMessage(self, handle, queue):
         """Given a handle and a queue, moves the corresponding message from
@@ -172,23 +186,42 @@
            rejects the corresponding message."""
         f.close()
         self.__changeState(handle, "inp", "rmv")
-        secureDelete(os.path.join(self.dir, "rmv_"+handle))
 
-    def cleanQueue(self, initial=0):
-        """Removes all timed-out or trash messages from the queue.  If
-           'initial', assumes we're starting up and nobody's already removing
-           messages.  Else, assumes halfway-removed messages are garbage."""
+    def cleanQueue(self):
+        """Removes all timed-out or trash messages from the queue.
+
+           Returns 1 if a clean is already in progress; otherwise
+           returns 0.
+        """
+        now = time.time() 
+        cleanFile = os.path.join(self.dir,".cleaning")
+        try:
+            s = os.stat(cleanFile)
+            if now - s[stat.ST_MTIME] > CLEAN_TIMEOUT:
+                cleaning = 0
+            cleaning = 1    
+        except OSError:
+            cleaning = 0
+
+        if cleaning:
+            return 1
+
+        f = open(cleanFile, 'w')
+        f.write(str(now))
+        f.close()
+        
         rmv = []
         allowedTime = int(time.time()) - INPUT_TIMEOUT
         for m in os.listdir(self.dir):
-            if initial and m.startswith("rmv_"):
+            if m.startswith("rmv_"):
                 rmv.append(os.path.join(self.dir, m))
             elif m.startswith("inp_"):
                 s = os.stat(m)
                 if s[stat.ST_MTIME] < allowedTime:
                     self.__changeState(m[4:], "inp", "rmv")
                     rmv.append(os.path.join(self.dir, m))
-        secureDelete(rmv)
+        _secureDelete_bg(rmv, cleanFile)
+        return 0
 
     def __changeState(self, handle, s1, s2):
         """Helper method: changes the state of message 'handle' from 's1'
@@ -198,5 +231,16 @@
 
     def __newHandle(self):
         """Helper method: creates a new random handle."""
-        junk = self.rng.getBytes(6)
+        junk = self.rng.getBytes(9)
         return base64.encodestring(junk).strip().replace("/","-")
+
+def _secureDelete_bg(files, cleanFile):
+    if os.fork() != 0:
+        return
+    # Now we're in the child process.
+    secureDelete(files, blocking=1)
+    try:
+        os.unlink(cleanFile)
+    except OSError:
+        pass
+    os._exit(0)
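
The "how negligible?" comment added above checks out with a quick
birthday-bound estimate: 12 handle characters drawn from a 64-symbol
alphabet give 72 bits, and the chance of any collision among n handles is
roughly n**2 / 2**73.  A few lines of throwaway arithmetic:

    import math

    BITS = 72   # 12 characters * 6 bits each

    def handles_for_collision_prob(p):
        # Birthday approximation: P(collision) ~= n**2 / 2**(BITS + 1)
        return math.sqrt(p * 2.0 ** (BITS + 1))

    print("handles for a 0.1%% collision risk: %.2g"
          % handles_for_collision_prob(0.001))
    # -> about 3e9, matching the figure in the comment above.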

Index: __init__.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/__init__.py,v
retrieving revision 1.2
retrieving revision 1.3
diff -u -d -r1.2 -r1.3
--- __init__.py	2 Jun 2002 06:11:16 -0000	1.2
+++ __init__.py	1 Jul 2002 18:03:05 -0000	1.3
@@ -8,5 +8,19 @@
    XXXX write more on principal interfaces"""
 
 __version__ = "0.1"
+__all__ = [ "BuildMessage", "MMTPClient" ]
+
+import BuildMessage
+import Crypto
+import MMTPServer
+import PacketHandler
+import Common
+import HashLog
+import Modules
+import Queue
+import Config
+import MMTPClient
+import Packet
+import ServerInfo
+
 
-# XXXX __all__

Index: test.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/test.py,v
retrieving revision 1.8
retrieving revision 1.9
diff -u -d -r1.8 -r1.9
--- test.py	25 Jun 2002 11:41:08 -0000	1.8
+++ test.py	1 Jul 2002 18:03:05 -0000	1.9
@@ -17,6 +17,7 @@
 import sys
 import threading
 import time
+import atexit
 
 from mixminion.Common import MixError, MixFatalError
 
@@ -936,7 +937,7 @@
         self.sp1 = PacketHandler(self.pk1, h)
         self.sp2 = PacketHandler(self.pk2, h)
         self.sp3 = PacketHandler(self.pk3, h)
-        self.sp2_3 = PacketHandler((self.pk2,self.pk3), h)
+        self.sp2_3 = PacketHandler((self.pk2,self.pk3), (h,h))
 
     def tearDown(self):
         self.hlog.close()
@@ -1140,7 +1141,16 @@
 from mixminion.Common import waitForChildren
 from mixminion.Queue import Queue
 
-already = 0
+def removeTempDirs(*dirs):
+    print "Removing temporary dirs"
+    waitForChildren()
+    for d in dirs:
+        if os.path.isdir(d):
+            for fn in os.listdir(d):
+                os.unlink(os.path.join(d,fn))
+            os.rmdir(d)
+        elif os.path.exists(d):
+            os.unlink(d)
 
 class QueueTests(unittest.TestCase):
     def setUp(self):
@@ -1148,19 +1158,9 @@
         mixminion.Common.installSignalHandlers(child=1,hup=0,term=0)
         self.d1 = tempfile.mktemp("q1")
         self.d2 = tempfile.mktemp("q2")
+        self.d3 = tempfile.mktemp("q3")
+        atexit.register(removeTempDirs, self.d1, self.d2, self.d3)
         
-    def tearDown(self):
-        # First, wait until all the removes have finished.
-        waitForChildren()
-        
-        for d in (self.d1, self.d2):
-            if os.path.isdir(d):
-                for fn in os.listdir(d):
-                    os.unlink(os.path.join(d,fn))
-                os.rmdir(d)
-            elif os.path.exists(d):
-                os.unlink(d)
-
     def testCreateQueue(self):
         # Nonexistant dir.
         self.failUnlessRaises(MixFatalError, Queue, self.d1)
@@ -1196,8 +1196,8 @@
 
     def testQueueOps(self):
         #XXXX COMMENT ME
-        queue1 = Queue(self.d1, create=1)
-        queue2 = Queue(self.d2, create=1)
+        queue1 = Queue(self.d2, create=1)
+        queue2 = Queue(self.d3, create=1)
 
         handles = [queue1.queueMessage("Sample message %s" % i) 
                    for i in range(100)]
@@ -1214,11 +1214,12 @@
             self.assertEquals("Sample message %s" %i,
                               queue1.messageContents(h))
 
-        assert len(hdict) == len(handles) == 100     
+        assert len(hdict) == len(handles) == 100
 
         q2h = []
         for h in handles[:30]:
-            q2h.append(queue1.moveMessage(h, queue2))
+            nh = queue1.moveMessage(h, queue2)
+            q2h.append(nh)
 
         from string import atoi
         seen = {}
@@ -1228,6 +1229,7 @@
             i = atoi(c[15:])
             self.failIf(seen.has_key(i))
             seen[i]=1
+
         for i in range(30):
             self.failUnless(seen.has_key(i))
 
@@ -1259,11 +1261,14 @@
         queue1.abortMessage(f,h)
         self.failUnlessRaises(IOError, queue1.messageContents, h)
         self.assertEquals(queue1.count(), 41)
-        self.assert_(not os.path.exists(os.path.join(self.d1, "msg_"+h)))
+        self.assert_(not os.path.exists(os.path.join(self.d2, "msg_"+h)))
 
         queue1.removeAll()
         queue2.removeAll()
 
+        queue1.cleanQueue()    
+        queue2.cleanQueue()
+
 #----------------------------------------------------------------------
 # SIGHANDLERS
 # XXXX
@@ -1303,8 +1308,32 @@
         return server, listener, messagesIn
 
 class MMTPTests(unittest.TestCase):
+
+    def doTest(self, fn):
+        self.listener = self.server = None
+        try:
+            fn()
+        finally:
+            if self.listener is not None:
+                self.listener.shutdown()
+            if self.server is not None:
+                count = 0
+                while count < 100 and (self.server.readers or
+                                       self.server.writers):
+                    self.server.process(0.1)
+                    count = count + 1
+
     def testBlockingTransmission(self):
-        server, listener, messagesIn = _getMMTPServer() 
+        self.doTest(self._testBlockingTransmission)
+
+    def testNonblockingTransmission(self):
+        self.doTest(self._testNonblockingTransmission)
+    
+    def _testBlockingTransmission(self):
+        server, listener, messagesIn = _getMMTPServer()
+        self.listener = listener
+        self.server = server
+        
         messages = ["helloxxx"*4096, "helloyyy"*4096]
 
         server.process(0.1)
@@ -1315,13 +1344,16 @@
         while len(messagesIn) < 2:
             server.process(0.1)
         t.join()
+
+        for i in xrange(10):
+            server.process(0.1)
+
         self.failUnless(messagesIn == messages)
-        # Shutdown properly on failure. XXXX
-        listener.shutdown()
-        server.process(0.1)
-            
-    def testNonblockingTransmission(self):
-        server, listener, messagesIn = _getMMTPServer() 
+
+    def _testNonblockingTransmission(self):
+        server, listener, messagesIn = _getMMTPServer()
+        self.listener = listener
+        self.server = server
 
         messages = ["helloxxx"*4096, "helloyyy"*4096]
         async = mixminion.MMTPServer.AsyncServer()
@@ -1344,9 +1376,6 @@
 
         self.assertEquals(len(messagesIn), len(messages))
         self.failUnless(messagesIn == messages)
-        # Shutdown properly on failure. XXXX
-        listener.shutdown()
-        server.process(0.1)
         
 #----------------------------------------------------------------------