[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[minion-cvs] Docs, tests, and a couple of bugfixes for ClientUtils



Update of /home/minion/cvsroot/src/minion/lib/mixminion
In directory moria.mit.edu:/tmp/cvs-serv1354/lib/mixminion

Modified Files:
	ClientUtils.py test.py 
Log Message:
Docs, tests, and a couple of bugfixes for ClientUtils

Index: ClientUtils.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/ClientUtils.py,v
retrieving revision 1.7
retrieving revision 1.8
diff -u -d -r1.7 -r1.8
--- ClientUtils.py	7 Nov 2003 10:43:18 -0000	1.7
+++ ClientUtils.py	8 Nov 2003 05:35:57 -0000	1.8
@@ -25,21 +25,50 @@
 
 #----------------------------------------------------------------------
 class BadPassword(MixError):
+    """Exception raised when we try to access a password-protected resource
+       and the user doesn't give the right password"""
     pass
 
 class PasswordManager:
-    # passwords: name -> string
+    """A PasswordManager keeps track of a set of named passwords, so that
+       a user never has to enter any password more than once.  This is an
+       abstract class."""
+    ## Fields
+    # passwords: map from password name to string value of the password.
     def __init__(self):
+        """Create a new PasswordManager"""
         self.passwords = {}
     def _getPassword(self, name, prompt):
+        """Abstract function; subclasses must override.
+
+           Use the prompt 'prompt' to ask the user for the password
+           'name'.  Return what the user enters.
+        """   
         raise NotImplemented()
     def _getNewPassword(self, name, prompt):
+        """Abstract function; subclasses must override.
+
+           Use the prompt 'prompt' to ask the user for a _new_
+           password 'name'.  Usually, this will involve asking for
+           the password twice to confirm that the user hasn't mistyped.
+        """   
         raise NotImplemented()
     def setPassword(self, name, password):
+        """Change the internally cached value for the password named
+           'name' to 'password'."""
         self.passwords[name] = password
     def getPassword(self, name, prompt, confirmFn, maxTries=-1):
+        """Return the password named 'name', querying using the prompt
+           'prompt' if necessary.  Before returning a prospective
+           password, we call 'confirmFn' on it.  If confirmFn returns 1,
+           the password is correct.  If confirmFn returns 0, the password
+           is incorrect.  Queries the user at most 'maxTries' times before
+           giving up.  Raises BadPassword on failure.""" 
         if self.passwords.has_key(name):
-            return self.passwords[name]
+            pwd = self.passwords[name]
+            if confirmFn(pwd):
+                self.passwords[name] = pwd
+                return pwd
         for othername, pwd in self.passwords.items():
             if confirmFn(pwd):
                 self.passwords[name] = pwd
@@ -55,10 +84,13 @@
 
         raise BadPassword()
     def getNewPassword(self, name, prompt):
+        """Use 'prompt' to ask the user for a fresh password named 'name'."""
         self.passwords[name] = self._getNewPassword(name, prompt)
         return self.passwords[name]
 
 class CLIPasswordManager(PasswordManager):
+    """Implementation of PasswordManager that asks for passwords from the
+       command line."""
     def __init__(self):
         PasswordManager.__init__(self)
     def _getPassword(self, name, prompt):
@@ -68,7 +100,7 @@
 
 def getPassword_term(prompt):
     """Read a password from the console, then return it.  Use the string
-    'message' as a prompt."""
+       'prompt' as a prompt."""
     # getpass.getpass uses stdout by default .... but stdout may have
     # been redirected.  If stdout is not a terminal, write the message
     # to stderr instead.
@@ -104,16 +136,26 @@
         f.flush()
 
 #----------------------------------------------------------------------
+# Functions to save and load data to disk in password-encrypted files.
+#
+# The file format is:
+#     variable      [File-specific magic, used to make sure we have an
+#                    encrypted file.]
+#     8 bytes       [Random salt.]
+#     variable      [AES-CTR Encrypted data:
+#                             key is sha1(salt+password+salt)[:16],
+#                             value is data+sha1(data+salt+magic)]
+#
+# Note that this format does not conceal the length of the data.
 
 def readEncryptedFile(fname, password, magic):
-    """DOCDOC
-       return None on failure; raise  MixError on corrupt file.
+    """Read encrypted data from the file named 'fname', using the password
+       'password' and checking for the filetype 'magic'.  Returns the
+       plaintext file contents on success.
+
+       If the file is corrupt or the password is wrong, raises BadPassword.
+       If the magic is incorrect, raises ValueError.
     """
-    #  variable         [File specific magic]       "KEYRING1"
-    #  8                [8 bytes of salt]
-    #  variable         ENCRYPTED DATA:KEY=sha1(salt+password+salt)
-    #                                  DATA=data+
-    #                                                   sha1(data+salt+magic)
     s = readFile(fname, 1)
     if not s.startswith(magic):
         raise ValueError("Invalid versioning on %s"%fname)
@@ -131,22 +173,40 @@
     return data
 
 def writeEncryptedFile(fname, password, magic, data):
+    """Write 'data' into an encrypted file named 'fname', replacing it
+       if necessary.  Encrypts the data with the password 'password',
+       and uses the filetype 'magic'."""
     salt = mixminion.Crypto.getCommonPRNG().getBytes(8)
     key = mixminion.Crypto.sha1(salt+password+salt)[:16]
-    hash = mixminion.Crypto.sha1("".join([data+salt+magic]))
+    hash = mixminion.Crypto.sha1("".join([data,salt,magic]))
     encrypted = mixminion.Crypto.ctr_crypt(data+hash, key)
     writeFile(fname, "".join([magic,salt,encrypted]), binary=1)
 
 def readEncryptedPickled(fname, password, magic):
+    """Read the pickled object stored in the encrypted file 'fname'. Arguments
+       are as for 'readEncryptedFile'."""
     return cPickle.loads(readEncryptedFile(fname, password, magic))
 
 def writeEncryptedPickled(fname, password, magic, obj):
+    """Write 'obj' into encrypted file 'fname'. Arguments are as for
+       'writeEncryptedFile'."""
     data = cPickle.dumps(obj, 1)
     writeEncryptedFile(fname, password, magic, data)
 
 class LazyEncryptedPickled:
+    """Wrapper for a file containing an encrypted pickled object, to
+       perform password querying and loading on demand."""
     def __init__(self, fname, pwdManager, pwdName, queryPrompt, newPrompt,
                  magic, initFn):
+        """Create a new LazyEncryptedPickled object.  Arguments:
+              fname -- The name of the file to hold the encrypted object.
+              pwdManager -- A PasswordManager instance.
+              pwdName, queryPrompt, newPrompt -- Arguments used when getting
+                  passwords from the PasswordManager.
+              magic -- The filetype to use for the encrypted file.
+              initFn -- A callable object that returns a fresh value for
+                  a newly created encrypted file.
+        """
         self.fname = fname
         self.pwdManager = pwdManager
         self.pwdName = pwdName
@@ -158,6 +218,10 @@
         self.password = None
         self.initFn = initFn
     def load(self, create=0,password=None):
+        """Try to load the encrypted file from disk.  If 'password' is
+           not provided, query it from the password manager.  If the file
+           does not exist, and 'create' is true, get a new password and
+           create the file."""
         if self.loaded:
             return 
         elif os.path.exists(self.fname):
@@ -185,8 +249,9 @@
             self.save()
         else:
             return
-
     def _loadWithPassword(self, password):
+        """Helper function: tries to load the file with a given password.
+           If successful, return 1. Else return 0."""
         try:
             self.object = readEncryptedPickled(self.fname,password,self.magic)
             self.password = password
@@ -195,21 +260,28 @@
         except MixError:
             return 0
     def isLoaded(self):
+        """Return true iff this file has been successfully loaded."""
         return self.loaded
     def get(self):
+        """Returns the contents of this file. The file must first have
+           been loaded."""
         assert self.loaded
         return self.object
     def set(self, val):
+        """Set the contents of this file.  Does not save the file to
+           disk."""
         self.object = val
         self.loaded = 1
     def setPassword(self, pwd):
+        """Set the password on this file."""
         self.password = pwd
+        self.pwdManager.setPassword(self.pwdName, pwd)
     def save(self):
+        """Flush the current contents of this file to disk."""
         assert self.loaded and self.password is not None
         writeEncryptedPickled(self.fname, self.password, self.magic,
                               self.object)
         
-        
 # ----------------------------------------------------------------------
 
 class SURBLog(mixminion.Filestore.DBBase):
@@ -316,17 +388,24 @@
        tell us not to."""
     ## Fields:
     # dir -- a directory to store packets in.
-    # store -- an instance of ObjectStore.  The entries are of the
+    # store -- an instance of ObjectMetadataStore.  The objects are of the
     #    format:
     #           ("PACKET-0",
     #             a 32K string (the packet),
-    #             an instance of IPV4Info (the first hop),
+    #             an instance of IPV4Info or HostInfo (the first hop),
     #             the latest midnight preceding the time when this
     #                 packet was inserted into the queue
     #           )
-    # XXXX change this to be OO; add nicknames.
+    #    The metadata is of the format:
+    #           ("V0",
+    #             an instance of IPV4Info or HostInfo (the first hop),
+    #             the latest midnight preceding the time when this
+    #                 packet was inserted into the queue
+    #           )
+    #    [These formats are redundant so that 0.0.6 and 0.0.5 clients
+    #     stay backward compatible for now.]
+    #
     # XXXX006 write unit tests
-    # XXXX Switch to use metadata.
     def __init__(self, directory, prng=None):
         """Create a new ClientQueue object, storing packets in 'directory'
            and generating random filenames using 'prng'."""
@@ -334,7 +413,7 @@
         createPrivateDir(directory)
 
         # We used to name entries "pkt_X"; this has changed.
-        # XXXX006 remove this when it's no longer needed.
+        # XXXX007 remove this when it's no longer needed.
         for fn in os.listdir(directory):
             if fn.startswith("pkt_"):
                 handle = fn[4:]
@@ -347,19 +426,21 @@
 
         self.metadataLoaded = 0
 
-    def queuePacket(self, message, routing):
-        """Insert the 32K packet 'message' (to be delivered to 'routing')
+    def queuePacket(self, packet, routing, now=None):
+        """Insert the 32K packet 'packet' (to be delivered to 'routing')
            into the queue.  Return the handle of the newly inserted packet."""
+        if now is None:
+            now = time.time()
         mixminion.ClientMain.clientLock()
         try:
-            fmt = ("PACKET-0", message, routing, previousMidnight(time.time()))
-            meta = ("V0", routing, previousMidnight(time.time()))
+            fmt = ("PACKET-0", packet, routing, previousMidnight(now))
+            meta = ("V0", routing, previousMidnight(now))
             return self.store.queueObjectAndMetadata(fmt,meta)
         finally:
             mixminion.ClientMain.clientUnlock()
 
     def getHandles(self):
-        """Return a list of the handles of all messages currently in the
+        """Return a list of the handles of all packets currently in the
            queue."""
         mixminion.ClientMain.clientLock()
         try:
@@ -368,7 +449,7 @@
             mixminion.ClientMain.clientUnlock()
 
     def getRouting(self, handle):
-        """DOCDOC"""
+        """Return the routing information associated with the given handle."""
         self.loadMetadata()
         return self.store.getMetadata(handle)[1]
 
@@ -379,13 +460,13 @@
            CorruptedFile."""
         obj = self.store.getObject(handle)
         try:
-            magic, message, routing, when = obj
+            magic, packet, routing, when = obj
         except (ValueError, TypeError):
             magic = None
         if magic != "PACKET-0":
             LOG.error("Unrecognized packet format for %s",handle)
             return None
-        return message, routing, when
+        return packet, routing, when
 
     def packetExists(self, handle):
         """Return true iff the queue contains a packet with the handle
@@ -397,8 +478,9 @@
         self.store.removeMessage(handle)
 
     def inspectQueue(self, now=None):
-        """Print a message describing how many messages in the queue are headed
+        """Print a message describing how many packets in the queue are headed
            to which addresses."""
+        #XXXX006 refactor
         if now is None:
             now = time.time()
         handles = self.getHandles()
@@ -423,8 +505,7 @@
                 count, s.ip, s.port, days)
 
     def cleanQueue(self, maxAge=None, now=None):
-        """Remove all messages older than maxAge seconds from this
-           queue."""
+        """Remove all packets older than maxAge seconds from this queue."""
         if now is None:
             now = time.time()
         if maxAge is not None:
@@ -438,16 +519,17 @@
                     continue
                 if when < cutoff:
                     remove.append(h)
-            LOG.info("Removing %s old messages from queue", len(remove))
+            LOG.info("Removing %s old packets from queue", len(remove))
             for h in remove:
                 self.store.removeMessage(h)
         self.store.cleanQueue()
 
     def loadMetadata(self):
-        """DOCDOC"""
+        """Ensure that we've loaded metadata for this queue from disk."""
         if self.metadataLoaded:
             return
 
+        # Helper function: create metadata from a file without it.
         def fixupHandle(h,self=self):
             packet, routing, when = self.getPacket(h)
             return "V0", routing, when

Index: test.py
===================================================================
RCS file: /home/minion/cvsroot/src/minion/lib/mixminion/test.py,v
retrieving revision 1.161
retrieving revision 1.162
diff -u -d -r1.161 -r1.162
--- test.py	7 Nov 2003 09:05:00 -0000	1.161
+++ test.py	8 Nov 2003 05:35:57 -0000	1.162
@@ -5957,6 +5957,7 @@
         d = mix_mktemp()
         createPrivateDir(d)
         f1 = os.path.join(d, "foo")
+        # Test reading and writing.
         CU.writeEncryptedFile(f1, password="x", magic="ABC", data="xyzzyxyzzy")
         contents = readFile(f1)
         self.assertEquals(contents[:3], "ABC")
@@ -5969,8 +5970,129 @@
         self.assertEquals("xyzzyxyzzy",
               CU.readEncryptedFile(f1, "x", "ABC"))
 
-        #XXXX006 finish testing corner cases and pickles.
+        # Try reading with wrong password.
+        self.assertRaises(CU.BadPassword, CU.readEncryptedFile,
+                          f1, "nobodaddy", "ABC")
+
+        # Try reading with wrong magic.
+        self.assertRaises(ValueError, CU.readEncryptedFile,
+                          f1, "x", "ABX")
+
+        # Try empty data.
+        CU.writeEncryptedFile(f1, password="x", magic="ABC", data="")
+        self.assertEquals("", CU.readEncryptedFile(f1, "x", "ABC"))
+        
+        # Test pickles.
+        f2 = os.path.join(d, "bar")
+        CU.writeEncryptedPickled(f2, "pswd", "ZZZ", [1,2,3])
+        self.assertEquals([1,2,3],CU.readEncryptedPickled(f2,"pswd","ZZZ"))
+        CU.writeEncryptedPickled(f2, "pswd", "ZZZ", {9:10,11:12})
+        self.assertEquals({9:10,11:12},
+                          CU.readEncryptedPickled(f2,"pswd","ZZZ"))
+        
+        # Test LazyEncryptedPickle
+        class DummyPasswordManager(CU.PasswordManager):
+            def __init__(self,d):
+                mixminion.ClientUtils.PasswordManager.__init__(self)
+                self.d = d
+            def _getPassword(self,name,prompt):
+                return self.d.get(name)
+            def _getNewPassword(self,name,prompt):
+                return self.d.get(name)
+
+        f3 = os.path.join(d, "Baz")
+        dpm = DummyPasswordManager({"Password1" : "p1"})
+        lep = CU.LazyEncryptedPickled(f3, dpm, "Password1", "Q:", "N:",
+                                     "magic0", lambda: "x"*3)
+        # Don't create.
+        self.assert_(not lep.isLoaded())
+        lep.load(create=0)
+        self.assert_(not lep.isLoaded())
+        lep.load(create=1)
+        self.assert_(lep.isLoaded())
+        self.assertEquals("x"*3, lep.get())
+        self.assertEquals("x"*3, CU.readEncryptedPickled(f3,"p1","magic0"))
+
+        lep = CU.LazyEncryptedPickled(f3, dpm, "Password1", "Q:", "N:",
+                                     "magic0", lambda: "x"*3)
+        lep.load()
+        self.assertEquals("x"*3, lep.get())
+        dpm.d = {}
+        self.assertEquals("x"*3, lep.get())
+
+    def testSURBLog(self):
+        brb = BuildMessage.buildReplyBlock
+        SURBLog = mixminion.ClientUtils.SURBLog
+        ServerInfo = mixminion.ServerInfo.ServerInfo
+        dirname = mix_mktemp()
+        fname = os.path.join(dirname, "surblog")
+
+        # generate 3 SURBs.
+        examples = getExampleServerDescriptors()
+        alice = ServerInfo(string=examples["Alice"][0])
+        lola = ServerInfo(string=examples["Lola"][0])
+        joe = ServerInfo(string=examples["Joe"][0])
+        surbs = [brb([alice,lola,joe], SMTP_TYPE, "bjork@iceland", "x",
+                     time.time()+24*60*60)
+                 for _ in range(3)]
+
+        #FFFF check for skipping expired and shortlived SURBs.
+        
+        s = SURBLog(fname)
+        try:
+            self.assert_(not s.isSURBUsed(surbs[0]))
+            self.assert_(not s.isSURBUsed(surbs[1]))
+            s.markSURBUsed(surbs[0])
+            self.assert_(s.isSURBUsed(surbs[0]))
+            s.close()
+            s = SURBLog(fname)
+            self.assert_(s.isSURBUsed(surbs[0]))
+            self.assert_(not s.isSURBUsed(surbs[1]))
+            self.assert_(s.findUnusedSURBs(surbs)[0] is surbs[1])
+            one = s.findUnusedSURBs(surbs,1)
+            self.assertEquals(len(one),1)
+            two = s.findUnusedSURBs(surbs,2)
+            self.assert_(two[0] is surbs[1])
+            self.assert_(two[1] is surbs[2])
+            s.markSURBUsed(surbs[1])
+            self.assert_(s.findUnusedSURBs(surbs)[0] is surbs[2])
+            s.markSURBUsed(surbs[2])
+            self.assert_(s.findUnusedSURBs(surbs) == [])
+        finally:
+            s.close()
+
+    def testClientQueue(self):
+        CQ = mixminion.ClientUtils.ClientQueue
+        d = mix_mktemp()
+        now = time.time()
+        cq = CQ(d)
+        p1 = "Z"*(32*1024)
+        p2 = mixminion.Crypto.getCommonPRNG().getBytes(32*1024)
+        p3 = p2[:1024]*32
+        ipv4 = mixminion.Packet.IPV4Info("10.20.30.40",48099,"KZ"*10)
+        host = mixminion.Packet.MMTPHostInfo("bliznerty.potrzebie",48099,
+                                             "KZ"*10)
+        self.assertEquals(cq.getHandles(), [])
+        self.assert_(not cq.packetExists("Z"))
+        h1 = cq.queuePacket(p1, ipv4, now)
+        h2 = cq.queuePacket(p2, host, now-24*60*60*10)
+        self.assertEquals(ipv4, cq.getRouting(h1))
+        self.assert_(cq.packetExists(h1))
+        self.assertEquals(host, cq.getRouting(h2))
 
+        cq = CQ(d)
+        self.assertUnorderedEq(cq.getHandles(),[h1,h2])
+        self.assertEquals(host, cq.getRouting(h2))
+        v = cq.getPacket(h2)
+        self.assertEquals((host,previousMidnight(now-24*60*60*10)), v[1:])
+        self.assertLongStringEq(v[0], p2)
+        cq.cleanQueue(maxAge=24*60*60,now=now)
+        self.assertEquals([h1], cq.getHandles())
+        v = cq.getPacket(h1)
+        self.assertEquals((ipv4,previousMidnight(now)), v[1:])
+        self.assertLongStringEq(v[0], p1)
+        cq.removePacket(h1)
+        
 class ClientDirectoryTests(TestCase):
     def testClientDirectory(self):
         """Check out ClientMain's directory implementation"""
@@ -6476,46 +6598,6 @@
         parseFails("0x9999") # No data
         parseFails("0xFEEEF:zymurgy") # Hex literal out of range
 
-    def testSURBLog(self): #XXXX move this.
-        brb = BuildMessage.buildReplyBlock
-        SURBLog = mixminion.ClientUtils.SURBLog
-        ServerInfo = mixminion.ServerInfo.ServerInfo
-        dirname = mix_mktemp()
-        fname = os.path.join(dirname, "surblog")
-
-        # generate 3 SURBs.
-        examples = getExampleServerDescriptors()
-        alice = ServerInfo(string=examples["Alice"][0])
-        lola = ServerInfo(string=examples["Lola"][0])
-        joe = ServerInfo(string=examples["Joe"][0])
-        surbs = [brb([alice,lola,joe], SMTP_TYPE, "bjork@iceland", "x",
-                     time.time()+24*60*60)
-                 for _ in range(3)]
-
-        #FFFF check for skipping expired and shortlived SURBs.
-        
-        s = SURBLog(fname)
-        try:
-            self.assert_(not s.isSURBUsed(surbs[0]))
-            self.assert_(not s.isSURBUsed(surbs[1]))
-            s.markSURBUsed(surbs[0])
-            self.assert_(s.isSURBUsed(surbs[0]))
-            s.close()
-            s = SURBLog(fname)
-            self.assert_(s.isSURBUsed(surbs[0]))
-            self.assert_(not s.isSURBUsed(surbs[1]))
-            self.assert_(s.findUnusedSURBs(surbs)[0] is surbs[1])
-            one = s.findUnusedSURBs(surbs,1)
-            self.assertEquals(len(one),1)
-            two = s.findUnusedSURBs(surbs,2)
-            self.assert_(two[0] is surbs[1])
-            self.assert_(two[1] is surbs[2])
-            s.markSURBUsed(surbs[1])
-            self.assert_(s.findUnusedSURBs(surbs)[0] is surbs[2])
-            s.markSURBUsed(surbs[2])
-            self.assert_(s.findUnusedSURBs(surbs) == [])
-        finally:
-            s.close()
 
     def testClientKeyring(self):
         keydir = mix_mktemp()
@@ -6904,7 +6986,7 @@
     tc = loader.loadTestsFromTestCase
 
     if 0:
-        suite.addTest(tc(PacketHandlerTests))
+        suite.addTest(tc(ClientUtilTests))
         return suite
     testClasses = [MiscTests,
                    MinionlibCryptoTests,