[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[minion-cvs] Committing first version of mixminion code to CVS repos...



Update of /home/minion/cvsroot/src/minion/lib/mixminion
In directory moria.seul.org:/tmp/cvs-serv22386/lib/mixminion

Added Files:
	BuildMessage.py Common.py Config.py Crypto.py Formats.py 
	HashLog.py Modules.py ServerInfo.py ServerProcess.py 
	__init__.py benchmark.py test.py 
Log Message:
Committing first version of mixminion code to CVS repository.  Right
now, the code to generate and handle messages is in place, but we're
still missing all the networking logic... along with everything else
mentioned in TODO.

See HACKING if you want to make it go.

See TODO if you want to make it better.

Please don't circulate this beyond the list until we have more of it
working.

I'm going to have to work on my day job for a while, so I may be slow
writing more of this.  Nevertheless, I hope to get the
buildMessage/serverProcess logic well tested and debugged by this
weekend.


--- NEW FILE: BuildMessage.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: BuildMessage.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

from mixminion.Formats import *
import mixminion.Crypto as Crypto
import mixminion.Modules as Modules

__all__ = [ 'buildForwardMessage', 'buildReplyBlock', 'buildReplyMessage',
            'buildStatelessReplyBlock' ]

def buildForwardMessage(payload, exitType, exitInfo, path1, path2):
    """buildForwardMessage(payload, exitType, exitInfo, path1, path2) -> str

       Builds a forward message whose first header routes through path1 and
       whose second header routes through path2, delivering with
       exitType/exitInfo at the last hop of path2."""
    return _buildMessage(payload, exitType, exitInfo,
                         path1, path2=path2)

def buildReplyMessage(payload, exitType, exitInfo, path1, replyBlock):
    """buildReplyMessage(payload, exitType, exitInfo, path1, replyBlock) -> str

       Builds a message whose first header routes through path1 and whose
       second header is taken from replyBlock, a (headers, firstNode) pair
       such as the first element returned by buildReplyBlock."""
    return _buildMessage(payload, exitType, exitInfo, path1,
                         path2=None, reply=replyBlock)

# Bad interface: this shouldn't return a tuple.
def buildReplyBlock(path, exitType, exitInfo, prng):
    """buildReplyBlock(path, exitType, exitInfo, prng) -> ((headers, node), secrets)

       Builds a header routing through 'path' and ending with
       exitType/exitInfo.  Returns ((headers, path[0]), secrets), where
       secrets are the per-hop master secrets drawn from 'prng'."""
    secrets = [ prng.getBytes(SECRET_LEN) for node in path ]
    # _buildHeaders also draws its random header padding from the PRNG.
    headers = _buildHeaders(path, secrets, exitType, exitInfo, prng)
    return (headers, path[0]), secrets

# Bad interface: userkey should only be None if we trust the final node
# a lot.
def buildStatelessReplyBlock(path, prng, user, userKey=None, email=0):
    """buildStatelessReplyBlock(path, prng, user, userKey=None, email=0)
           -> ((headers, node), secrets)

       Builds a reply block whose per-hop secrets are all generated from a
       fresh 16-byte seed.  The exit info carries the tag "RTRN"+tag, where
       tag is the seed, AES-CTR encrypted under userKey if one is given.
       If email is true the exit is SMTP delivery to 'user' (userKey is then
       required); otherwise it is local delivery to 'user'.

       NOTE: the 'prng' argument is currently unused; the secrets come from
       a PRNG seeded with the new seed."""
    # LocalInfo is not part of Formats.__all__, so import it explicitly.
    from mixminion.Formats import LocalInfo

    if email:
        assert userKey
    seed = Crypto.trng(16)
    if userKey:
        tag = Crypto.ctr_crypt(seed, userKey)
    else:
        tag = seed
    if email:
        exitType = Modules.SMTP_TYPE
        exitInfo = SMTPInfo(user, "RTRN"+tag).pack()
    else:
        exitType = Modules.LOCAL_TYPE
        exitInfo = LocalInfo(user, "RTRN"+tag).pack()

    prng = Crypto.AESCounterPRNG(seed)
    return buildReplyBlock(path, exitType, exitInfo, prng)

#----------------------------------------------------------------------
def _buildMessage(payload, exitType, exitInfo,
                  path1, path2=None, reply=None, prng=None, paranoia=0):
    """_buildMessage(payload, exitType, exitInfo, path1, path2=None,
                     reply=None, prng=None, paranoia=0) -> str

       Builds a packed message.  The first header routes through path1 and
       ends with a SWAP_FWD to the first node of the second leg.  The second
       header is either built for path2 (ending with exitType/exitInfo) or
       taken from 'reply', a (headers, firstNode) pair.  At least one of
       path2 and reply must be given; path2 takes precedence.

       A payload shorter than PAYLOAD_LEN is padded with PRNG bytes.  If
       paranoia is true, per-hop secrets come from trng() instead of prng."""
    assert path2 or reply
    if prng is None:
        prng = Crypto.AESCounterPRNG()

    # ???? Payload padding/sizing must be handled in spec.
    if len(payload) < PAYLOAD_LEN:
        payload += prng.getBytes(PAYLOAD_LEN-len(payload))

    if paranoia:
        newSecret = Crypto.trng
    else:
        newSecret = prng.getBytes

    secrets1 = [ newSecret(SECRET_LEN) for node in path1 ]
    if path2:
        secrets2 = [ newSecret(SECRET_LEN) for node in path2 ]
        node = path2[0]
    else:
        # A reply block carries its own (already-applied) secrets.
        secrets2 = None
        node = reply[1]

    # The swap hop is told to forward to the first node of the second leg.
    info = IPV4Info(node.getAddr(), node.getPort(), node.getKeyID()).pack()
    headers1 = _buildHeaders(path1, secrets1, Modules.SWAP_FWD_TYPE, info,
                             prng)
    if path2:
        headers2 = _buildHeaders(path2, secrets2, exitType, exitInfo, prng)
    else:
        headers2 = reply[0]
    return _constructMessage(secrets1, secrets2, headers1, headers2, payload)


def _buildHeaders(path, secrets, exitType, exitInfo, prng):
    """_buildHeaders(path, secrets, exitType, exitInfo, prng) -> str

       Builds a header routing through the servers in 'path', using
       secrets[i] as the master secret of hop i.  Every hop but the last is
       told to FWD to the next node; the last hop gets routing type
       'exitType' and packed routing info 'exitInfo'.  'prng' supplies the
       random padding at the end of the innermost header."""
    import mixminion._minionlib as _ml

    hops = len(path)
    assert len(secrets) == hops

    # Calculate all routing info as (routingtype, packed routinginfo).
    routing = []
    for i in range(hops-1):
        nextNode = path[i+1]
        info = IPV4Info(nextNode.getAddr(), nextNode.getPort(),
                        nextNode.getKeyID())
        routing.append( (Modules.FWD_TYPE, info.pack()) )
    routing.append( (exitType, exitInfo) )

    # sizes[i] is the size, in blocks, of the subheader(s) for hop i.
    sizes = [ getTotalBlocksForRoutingInfo(len(ri)) for rt, ri in routing ]

    # XXXX Subheader.pack() only emits the first MAX_ROUTING_INFO_LEN bytes
    # of routing info; extra blocks for extended routing info are not yet
    # written into the header.
    totalSize = hops + sizes[-1] - 1
    assert totalSize*ENC_SUBHEADER_LEN <= HEADER_LEN

    # Calculate masks and junk.  junk[i] is the junk that hop i sees at the
    # end of its header, after hops 0..i-1 have each stripped a subheader
    # and appended their PRNG padding (so junk[0] is empty).
    masks = []
    junk = [ "" ]
    for secret, size in zip(secrets, sizes):
        ks = Crypto.Keyset(secret)
        nextMask = Crypto.prng(ks.get(Crypto.HEADER_SECRET_MODE), HEADER_LEN)
        padding = Crypto.prng(ks.get(Crypto.PRNG_MODE),
                              size*ENC_SUBHEADER_LEN)
        nextJunk = junk[-1] + padding
        # The server decrypts the whole tail with its mask, so the junk the
        # next hop sees is the padded tail XORed with this mask's tail.
        nextJunk = _ml.strxor(nextJunk, nextMask[HEADER_LEN-len(nextJunk):])
        junk.append(nextJunk)
        masks.append(nextMask)

    header = prng.getBytes(HEADER_LEN - totalSize*ENC_SUBHEADER_LEN)

    # Build from the innermost hop outwards.
    for i in range(hops-1, -1, -1):
        rt, ri = routing[i]
        # Mask everything after this hop's subheader.
        rest = _ml.strxor(header, masks[i][:len(header)])
        # Hop i checks the digest of everything after its subheader,
        # including the junk left by earlier hops.
        digest = Crypto.sha1(rest + junk[i])
        pubkey = Crypto.pk_from_modulus(path[i].getModulus())
        subhead = Subheader(MAJOR_NO, MINOR_NO,
                            secrets[i], digest,
                            rt, ri).pack()
        esh = Crypto.pk_encrypt(subhead, pubkey)
        header = esh + rest

    return header

 
# For a reply, secrets2==None
def _constructMessage(secrets1, secrets2, header1, header2, payload):
    """_constructMessage(secrets1, secrets2, header1, header2, payload) -> str

       Applies the layers of LIONESS encryption to header2 and the payload
       and returns the packed Message.  Payload layers for secrets2 (if any)
       are applied innermost-first; header2 is then hidden under keys
       derived from the payload; finally each secret in secrets1 adds a
       header2 layer and a payload layer.  The argument lists are not
       modified."""
    assert len(payload) == PAYLOAD_LEN
    assert len(header1) == len(header2) == HEADER_LEN

    if secrets2:
        # Copy before reversing so the caller's list is left untouched.
        secrets2 = secrets2[:]
        secrets2.reverse()
        for secret in secrets2:
            key = Crypto.Keyset(secret).getLionessKeys(
                Crypto.PAYLOAD_ENCRYPT_MODE)
            payload = Crypto.lioness_encrypt(payload, key)

    key = Crypto.lioness_keys_from_payload(payload)
    header2 = Crypto.lioness_encrypt(header2, key)

    secrets1 = secrets1[:]
    secrets1.reverse()
    for secret in secrets1:
        ks = Crypto.Keyset(secret)
        hkey = ks.getLionessKeys(Crypto.HEADER_ENCRYPT_MODE)
        pkey = ks.getLionessKeys(Crypto.PAYLOAD_ENCRYPT_MODE)
        header2 = Crypto.lioness_encrypt(header2, hkey)
        payload = Crypto.lioness_encrypt(payload, pkey)

    return Message(header1, header2, payload).pack()

--- NEW FILE: Common.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: Common.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

class MixError(Exception):
    """Base exception type for Mixminion errors (e.g. ParseError and
       ContentError derive from it)."""




--- NEW FILE: Config.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: Config.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

--- NEW FILE: Crypto.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: Crypto.py,v 1.1 2002/05/29 03:52:13 nickm Exp $
"""mixminion.Crypto

   This package contains XXXX"""

import sys
import mixminion._minionlib as _ml

__all__ = [ 'init_crypto', 'sha1',  'ctr_crypt', 'prng',
            'lioness_encrypt', 'lioness_decrypt', 'trng', 'pk_encrypt',
            'pk_decrypt', 'pk_generate', 'openssl_seed',
            'pk_get_modulus', 'pk_from_modulus',
            'pk_encode_private_key', 'pk_decode_private_key',
            'Keyset', 'AESCounterPRNG', 'lioness_keys_from_payload',
            'HEADER_SECRET_MODE', 'PRNG_MODE', 'HEADER_ENCRYPT_MODE',
            'PAYLOAD_ENCRYPT_MODE', 'HIDE_HEADER_MODE' ]

# Floor division keeps these integers (they are used as byte counts and
# slice bounds) even under true division.
AES_KEY_LEN = 128//8
DIGEST_LEN = 160//8

def init_crypto():
    """init_crypto()

       Initialize the crypto subsystem: check that /dev/urandom can be read,
       then seed OpenSSL's RNG with 40 bytes from it.  Exits the process with
       status 1 if the entropy source cannot be opened or read."""
    try:
        # Try to read /dev/urandom.
        trng(1)
    except (IOError, OSError):
        sys.stderr.write(
            "Couldn't initialize entropy source (/dev/urandom).  Bailing...\n")
        sys.exit(1)
    openssl_seed(40)

def sha1(s):
    """sha1(s) -> str

    Computes the SHA1 digest of the string s."""
    digest = _ml.sha1(s)
    return digest

def ctr_crypt(s, key, idx=0):
    """ctr_crypt(s, key, idx=0) -> str

       Encrypts (or, equivalently, decrypts) the string s with AES in
       counter mode under the 16-byte key 'key', with the counter starting
       at idx."""
    result = _ml.aes_ctr128_crypt(key, s, idx)
    return result

def prng(key,count,idx=0):
    """prng(key, count, idx=0) -> str

       Returns 'count' bytes of AES counter-mode keystream under 'key'
       (the encryption of a run of zero bytes), starting at counter idx."""
    stream = _ml.aes_ctr128_crypt(key, "", idx, count)
    return stream

def lioness_encrypt(s,key):
    """lioness_encrypt(s, (key1,key2,key3,key4)) -> str

    Encrypts s with the LIONESS super-pseudorandom permutation.  key1 and
    key3 must be 20 bytes, key2 and key4 16 bytes, and s longer than 20
    bytes."""
    assert len(key) == 4
    k1, k2, k3, k4 = key
    assert len(k1) == len(k3) == 20
    assert len(k2) == len(k4) == 16
    assert len(s) > 20

    lhs, rhs = s[:20], s[20:]
    del s
    #XXXX This slice makes me nervous
    # (ctr_crypt takes a 16-byte AES key, so only 16 of the 20 bytes count.)
    rhs = ctr_crypt(rhs, _ml.strxor(lhs, k1)[:16])
    lhs = _ml.strxor(lhs, _ml.sha1(rhs, k2))   # keyed SHA1 of right half
    rhs = ctr_crypt(rhs, _ml.strxor(lhs, k3)[:16])
    lhs = _ml.strxor(lhs, _ml.sha1(rhs, k4))
    return lhs + rhs

def lioness_decrypt(s,key):
    """lioness_decrypt(s, (key1,key2,key3,key4)) -> str

    Given a 16-byte key2 and key4, and a 20-byte key1 and key3, decrypts
    s using the LIONESS super-pseudorandom permutation.  This inverts
    lioness_encrypt by running its four rounds in reverse order."""

    assert len(key) == 4
    key1,key2,key3,key4 = key
    assert len(key1)==len(key3)==20
    assert len(key2)==len(key4)==16
    assert len(s) > 20

    left = s[:20]
    right = s[20:]
    del s
    #XXXX This slice makes me nervous
    left = _ml.strxor(left, _ml.sha1(right, key4))
    right = ctr_crypt(right, _ml.strxor(left, key3)[:16])
    left = _ml.strxor(left, _ml.sha1(right, key2))
    right = ctr_crypt(right, _ml.strxor(left, key1)[:16])
    return left + right

def openssl_seed(count):
    """openssl_seed(count)

       Feeds 'count' bytes obtained from trng() into OpenSSL's RNG."""
    entropy = trng(count)
    _ml.openssl_seed(entropy)

def trng(count):
    """trng(count) -> str

    Returns (count) bytes of random data read from the kernel entropy
    source /dev/urandom.  Raises IOError/OSError if it cannot be read."""
    # Binary mode: the data is raw bytes, not text.
    f = open('/dev/urandom', 'rb')
    try:
        return f.read(count)
    finally:
        # Close the descriptor even if read() fails.
        f.close()

# Fixed OAEP encoding parameter shared by pk_encrypt and pk_decrypt.
OAEP_PARAMETER = ("He who would make his own liberty secure, "
                  "must guard even his enemy from oppression.")

def pk_encrypt(data,key):
    """pk_encrypt(data,key)->str

    OAEP-pads data to the length of the RSA modulus of 'key', then returns
    the public-key encryption of the padded data."""
    modulusBytes = _ml.rsa_get_modulus_bytes(key)
    padded = _ml.add_oaep_padding(data, OAEP_PARAMETER, modulusBytes)
    # public key encrypt
    return _ml.rsa_crypt(key, padded, 1, 1)

def pk_decrypt(data,key):
    """pk_decrypt(data,key)->str

    Decrypts data with the private key in 'key' and returns the result with
    its OAEP padding checked and removed."""
    modulusBytes = _ml.rsa_get_modulus_bytes(key)
    # private key decrypt
    padded = _ml.rsa_crypt(key, data, 0, 0)
    return _ml.check_oaep_padding(padded, OAEP_PARAMETER, modulusBytes)

def pk_generate(bits=1024,e=65535):
    """pk_generate(bits=1024, e=65535) -> rsa

       Generate a new RSA keypair with 'bits' bits and public exponent 'e'.

       NOTE(review): the default e=65535 is composite (3*5*17*257); the
       conventional choice is 65537.  The default must stay in sync with
       pk_from_modulus, which rebuilds public keys from the modulus alone,
       so change both together if at all."""
    return _ml.rsa_generate(bits,e)

def pk_get_modulus(key):
    """pk_get_modulus(rsa)->long

       Returns the modulus (the first element of the public key) of 'key'."""
    publicParts = _ml.rsa_get_public_key(key)
    return publicParts[0]

def pk_from_modulus(n, e=65535):
    """pk_from_modulus(n, e=65535) -> rsa

       Given a modulus n and public exponent e, creates an RSA public key.
       The default exponent matches pk_generate's default."""
    # long() normalizes both values, so a plain-int default is equivalent to
    # the former 65535L literal.
    return _ml.rsa_make_public_key(long(n),long(e))

def pk_encode_private_key(key):
    """pk_encode_private_key(rsa)->str

       Returns an ASN1 encoding of the keypair 'key', suitable for storing
       outside the process."""
    encoded = _ml.rsa_encode_key(key, 0)
    return encoded

def pk_decode_private_key(s):
    """pk_decode_private_key(str)->rsa

       Reads an ASN1 representation of a keypair, as produced by
       pk_encode_private_key, from external storage."""
    return _ml.rsa_decode_key(s,0)

#----------------------------------------------------------------------

# Mode strings passed to Keyset.get()/getLionessKeys() to derive distinct
# keys from one master secret.
HEADER_SECRET_MODE   = "HEADER SECRET KEY"
PRNG_MODE            = "RANDOM JUNK"
HEADER_ENCRYPT_MODE  = "HEADER ENCRYPT"
PAYLOAD_ENCRYPT_MODE = "PAYLOAD ENCRYPT"
HIDE_HEADER_MODE     = "HIDE HEADER"

class Keyset:
    """A Keyset derives a family of keys from one master secret."""
    def __init__(self, master):
        """Keyset(master)

           Creates a keyset rooted at the master secret 'master'."""
        self.master = master

    def get(self, mode, bytes=AES_KEY_LEN):
        """ks.get(mode, bytes=AES_KEY_LEN) -> str

           Returns the first 'bytes' bytes of SHA1(master||mode)."""
        assert 0 < bytes <= DIGEST_LEN
        digest = sha1(self.master + mode)
        return digest[:bytes]

    def getLionessKeys(self, mode):
        """ks.getLionessKeys(mode) -> (key1, key2, key3, key4)

           Returns four LIONESS subkeys of 20, 16, 20 and 16 bytes, derived
           from 'mode' plus a per-subkey suffix, as described in the
           Mixminion specification."""
        subkeys = ((" (FIRST SUBKEY)", 20),
                   (" (SECOND SUBKEY)", 16),
                   (" (THIRD SUBKEY)", 20),
                   (" (FOURTH SUBKEY)", 16))
        return tuple([ self.get(mode + suffix, n) for suffix, n in subkeys ])

def lioness_keys_from_payload(payload):
    """lioness_keys_from_payload(payload) -> (key1, key2, key3, key4)

       Derives HIDE_HEADER_MODE LIONESS keys from the SHA1 digest of
       'payload'."""
    # XXXX Temporary method till George and I agree on a key schedule.
    return Keyset(sha1(payload)).getLionessKeys(HIDE_HEADER_MODE)

#---------------------------------------------------------------------

class AESCounterPRNG:
    """A pseudorandom byte source built on AES counter mode.

       Keystream is fetched in chunks of about _CHUNKSIZE bytes and buffered.
       The key is the 16-byte seed, read from trng() if none is given."""
    _CHUNKSIZE = 16*1024
    _KEYSIZE = 16
    def __init__(self, seed=None):
        """AESCounterPRNG(seed=None)

           Creates a generator keyed with 'seed'."""
        self.counter = 0
        self.bytes = ""
        if seed is None:
            seed = trng(AESCounterPRNG._KEYSIZE)
        self.key = seed

    def getBytes(self, n):
        """getBytes(n) -> str

           Returns the next n bytes of pseudorandom output."""
        have = len(self.bytes)
        if n <= have:
            out, self.bytes = self.bytes[:n], self.bytes[n:]
            return out
        need = n - have
        fetch = need + AESCounterPRNG._CHUNKSIZE
        fresh = prng(self.key, fetch, self.counter)
        self.counter += fetch
        out = self.bytes + fresh[:need]
        self.bytes = fresh[need:]
        return out

--- NEW FILE: Formats.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: Formats.py,v 1.1 2002/05/29 03:52:13 nickm Exp $
"""mixminion.Formats

   Functions, classes, and constants to parse and unparse Mixminion messages
   and related structures."""

__all__ = [ 'ParseError', 'Message', 'Header', 'Subheader',
            'parseMessage', 'parseHeader', 'parseSubheader',
            'getTotalBlocksForRoutingInfo',
            'IPV4Info', 'SMTPInfo', 'LocalInfo',
            'parseIPV4Info', 'parseSMTPInfo', 'parseLocalInfo',
            'ENC_SUBHEADER_LEN', 'HEADER_LEN',
            'PAYLOAD_LEN', 'MAJOR_NO', 'MINOR_NO',
            'SECRET_LEN']

import types, struct, unittest
import mixminion.Common

# Protocol version numbers.
MAJOR_NO = 0
MINOR_NO = 1

# Length of a subheader, once RSA-encoded.
ENC_SUBHEADER_LEN = 128

# Length of a Mixminion message (32K).
MESSAGE_LEN = 2 ** 15
# Length of a header section: 16 encoded subheaders.
HEADER_LEN  = 16 * ENC_SUBHEADER_LEN
# Length of a single payload: whatever the two headers leave over.
PAYLOAD_LEN = MESSAGE_LEN - 2 * HEADER_LEN

# Smallest possible size for a subheader.
MIN_SUBHEADER_LEN = 42
# Most information we can fit into a subheader.
MAX_SUBHEADER_LEN = 86
# Longest routing info that will fit in the main subheader.
MAX_ROUTING_INFO_LEN = MAX_SUBHEADER_LEN - MIN_SUBHEADER_LEN

# Length of a digest.
DIGEST_LEN = 20
# Length of a secret key.
SECRET_LEN = 16

# Most info that fits in a single ERI block.
ROUTING_INFO_PER_EXTENDED_SUBHEADER = ENC_SUBHEADER_LEN

class ParseError(mixminion.Common.MixError):
    """Raised when a message, or some piece of one, is badly formatted."""

def parseMessage(s):
    """parseMessage(s) -> Message

       Given a MESSAGE_LEN (32K) string, returns a Message object that breaks
       it into two HEADER_LEN headers and a payload.  Raises ParseError if
       s has the wrong length."""
    if len(s) != MESSAGE_LEN:
        raise ParseError("Bad message length")

    return Message(s[:HEADER_LEN],
                   s[HEADER_LEN:HEADER_LEN*2],
                   s[HEADER_LEN*2:])

class Message:
    """Represents a complete Mixminion packet.

       Fields: header1, header2, payload"""
    def __init__(self, header1, header2, payload):
        """Message(header1, header2, payload) -> msg

           Creates a Message from its two header strings and payload string."""
        self.header1, self.header2, self.payload = header1, header2, payload

    def pack(self):
        """Returns header1 + header2 + payload as a single string."""
        return self.header1 + self.header2 + self.payload

def parseHeader(s):
    """parseHeader(s) -> Header

       Wraps a HEADER_LEN (2K) string in a Header object; raises ParseError
       for any other length."""
    if len(s) == HEADER_LEN:
        return Header(s)
    raise ParseError("Bad header length")

class Header:
    """Represents a 2K Mixminion header, viewed as a sequence of
       ENC_SUBHEADER_LEN-byte encoded subheaders."""
    def __init__(self, contents):
        """Header(contents) -> header

           Wraps the raw header string 'contents'."""
        self.contents = contents

    def __getitem__(self, i):
        """header[i] -> str
           header[i:j] -> str

           Given an integer, returns the i'th encoded subheader of this
           header, for i in 0..15.  Given a slice object without a step (as
           Python 3 passes for header[i:j]), returns the same concatenation
           of subheaders as __getslice__."""
        if isinstance(i, slice):
            if i.step not in (None, 1):
                raise TypeError("Header slices do not support a step")
            start = i.start
            if start is None:
                start = 0
            if i.stop is None:
                return self.contents[start*ENC_SUBHEADER_LEN:]
            return self.__getslice__(start, i.stop)
        return self.contents[i*ENC_SUBHEADER_LEN:
                             (i+1)*ENC_SUBHEADER_LEN]

    def __getslice__(self, i, j):
        """header[i:j] -> str

           Returns the concatenated encoded subheaders i through j-1 of this
           header, for 0 <= i <= j <= 16."""
        return self.contents[i*ENC_SUBHEADER_LEN:
                             j*ENC_SUBHEADER_LEN]

# major, minor, secret, digest, routing length, routing type.
SH_UNPACK_PATTERN = "!BB%ds%dsHH" % (SECRET_LEN, DIGEST_LEN)

def parseSubheader(s):
    """parseSubheader(s) -> Subheader

       Converts a decoded (decrypted) Mixminion subheader string into a
       Subheader object.  Raises ParseError if s is shorter than
       MIN_SUBHEADER_LEN."""
    if len(s) < MIN_SUBHEADER_LEN:
        raise ParseError("Header too short")

    fixed, ri = s[:MIN_SUBHEADER_LEN], s[MIN_SUBHEADER_LEN:]
    (major, minor, secret, digest,
     rlen, rt) = struct.unpack(SH_UNPACK_PATTERN, fixed)
    # Drop any bytes past the declared routing-info length.
    ri = ri[:rlen]
    return Subheader(major, minor, secret, digest, rt, ri, rlen)

def getTotalBlocksForRoutingInfo(bytes):
    """getTotalBlocksForRoutingInfo(bytes) -> int

       Returns how many ENC_SUBHEADER_LEN-byte blocks a subheader needs when
       its routing info is 'bytes' long: one block for the main subheader
       (holding up to MAX_ROUTING_INFO_LEN bytes of routing info), plus one
       extra block per ROUTING_INFO_PER_EXTENDED_SUBHEADER bytes of the
       remainder, rounded up."""
    if bytes <= MAX_ROUTING_INFO_LEN:
        return 1
    extraBytes = bytes - MAX_ROUTING_INFO_LEN
    # Ceiling division: a remainder that exactly fills its last block must
    # not be given an additional, empty block.
    extraBlocks = ((extraBytes + ROUTING_INFO_PER_EXTENDED_SUBHEADER - 1)
                   // ROUTING_INFO_PER_EXTENDED_SUBHEADER)
    return 1 + extraBlocks
    
class Subheader:
    """Represents a decoded Mixminion subheader.

       Fields: major, minor, secret, digest, routinglen, routinginfo,
               routingtype."""
    def __init__(self, major, minor, secret, digest, routingtype,
                 routinginfo, routinglen=None):
        """Subheader(major, minor, secret, digest, routingtype, routinginfo,
                     routinglen=None)

           When routinglen is omitted it defaults to len(routinginfo)."""
        self.major = major
        self.minor = minor
        self.secret = secret
        self.digest = digest
        self.routingtype = routingtype
        self.routinginfo = routinginfo
        if routinglen is None:
            routinglen = len(routinginfo)
        self.routinglen = routinglen

    def __repr__(self):
        template = ("Subheader(major=%(major)r, minor=%(minor)r, "
                    "secret=%(secret)r, digest=%(digest)r, "
                    "routingtype=%(routingtype)r, routinginfo=%(routinginfo)r, "
                    "routinglen=%(routinglen)r)")
        return template % self.__dict__

    def setRoutingInfo(self, info):
        """Replaces the routinginfo and updates routinglen to match."""
        self.routinginfo = info
        self.routinglen = len(info)

    def isExtended(self):
        """Returns true iff the routinginfo is longer than fits in the main
           subheader block."""
        return self.routinglen > MAX_ROUTING_INFO_LEN

    def getNExtraBlocks(self):
        """Returns how many blocks beyond the main one the routinginfo
           requires."""
        return getTotalBlocksForRoutingInfo(self.routinglen) - 1

    def appendExtraBlocks(self, data):
        """appendExtraBlocks(str)

           Given the additional (decoded) blocks of routing info, appends
           them to routinginfo, trimming the result to routinglen."""
        nBlocks = self.getNExtraBlocks()
        assert len(data) == nBlocks * ENC_SUBHEADER_LEN
        combined = self.routinginfo + data[:nBlocks * ENC_SUBHEADER_LEN]
        self.routinginfo = combined[:self.routinglen]

    def pack(self):
        """Returns the (unencrypted) string form of this subheader; only the
           first MAX_ROUTING_INFO_LEN bytes of routinginfo are included."""
        assert self.routinglen == len(self.routinginfo)
        assert len(self.digest) == DIGEST_LEN
        assert len(self.secret) == SECRET_LEN
        fixed = struct.pack(SH_UNPACK_PATTERN,
                            self.major, self.minor, self.secret, self.digest,
                            self.routinglen, self.routingtype)
        return fixed + self.routinginfo[:MAX_ROUTING_INFO_LEN]

    def getExtraBlocks(self):
        """getExtraBlocks() -> [ str, ...]

           Returns the (unencrypted) blocks of routing info that did not fit
           in the main subheader, NUL-padding the last one."""
        if not self.isExtended():
            return []
        per = ROUTING_INFO_PER_EXTENDED_SUBHEADER
        overflow = self.routinginfo[MAX_ROUTING_INFO_LEN:]
        blocks = []
        for n in range(self.getNExtraBlocks()):
            chunk = overflow[n*per:(n+1)*per]
            chunk = chunk + '\000' * (per - len(chunk))
            blocks.append(chunk)
        return blocks

# IPV4 routing info: a 4-byte IP address, a 2-byte port, and a
# DIGEST_LEN-byte key ID.  NOTE(review): the IP is carried as a packed
# 4-byte string -- confirm this matches what ServerInfo provides.
IPV4_PAT = "!4sH%ds" % DIGEST_LEN

def parseIPV4Info(s):
    """parseIPV4Info(s) -> IPV4Info

       Converts routing info for an IPV4 address into an IPV4Info object.
       Raises ParseError if s is not exactly 4+2+DIGEST_LEN bytes."""
    if len(s) != 4+2+DIGEST_LEN:
        raise ParseError("IPV4 information with wrong length")
    ip, port, keyinfo = struct.unpack(IPV4_PAT, s)
    return IPV4Info(ip, port, keyinfo)

class IPV4Info:
    """Routing information for relaying to a server at an IPv4 address.

       Fields: ip, port, keyinfo (the DIGEST_LEN-byte key ID of the
       server)."""
    def __init__(self, ip, port, keyinfo):
        self.ip = ip
        self.port = port
        self.keyinfo = keyinfo

    def pack(self):
        """Returns this routing info packed according to IPV4_PAT."""
        assert len(self.keyinfo) == DIGEST_LEN
        return struct.pack(IPV4_PAT, self.ip, self.port, self.keyinfo)

def parseSMTPInfo(s):
    """parseSMTPInfo(s) -> SMTPInfo

       Splits SMTP routing info at its first NUL byte into an address and a
       tag; with no NUL, the whole string is the address and the tag is
       None."""
    pieces = s.split("\000", 1)
    if len(pieces) == 2:
        return SMTPInfo(pieces[0], pieces[1])
    return SMTPInfo(s, None)

class SMTPInfo:
    """Exit routing info for SMTP delivery: an email address plus an
       optional tag."""
    def __init__(self, email, tag):
        self.email = email
        self.tag = tag

    def pack(self):
        """Returns the address alone, or address NUL tag when a tag is set."""
        if self.tag is None:
            return self.email
        return "\000".join([self.email, self.tag])
        
def parseLocalInfo(s):
    """parseLocalInfo(s) -> LocalInfo

       Splits local-delivery routing info at its first NUL byte into a user
       name and a tag.  With no NUL, the whole string is the user and the
       tag is None, mirroring parseSMTPInfo."""
    pieces = s.split("\000", 1)
    if len(pieces) == 1:
        return LocalInfo(s, None)
    # Keep the entire remainder as the tag, not just its first character.
    return LocalInfo(pieces[0], pieces[1])
    
class LocalInfo:
    """Exit routing info for local delivery: a user name (which must not
       contain NUL) plus an optional tag."""
    def __init__(self, user, tag):
        self.user = user
        assert user.find('\000') == -1
        self.tag = tag

    def pack(self):
        """Returns the user alone, or user NUL tag when a tag is set
           (consistent with SMTPInfo.pack)."""
        if self.tag is None:
            return self.user
        return self.user+"\000"+self.tag

--- NEW FILE: HashLog.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: HashLog.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

import anydbm

__all__ = [ 'HashLog' ]

class HashLog:
    """A HashLog is a file containing a list of message digests that we've
       already processed.

       Each HashLog corresponds to a single public key (whose hash is the
       log's keyid).  A HashLog must persist for as long as the key does.

       It is not necessary to sync the HashLog to the disk every time a new
       message is seen; rather, the HashLog must be synced before any messages
       are sent to the network.

       HashLogs are implemented using Python's anydbm interface.  This defaults
       to using Berkeley DB, GDBM, or --if you have none of these-- a flat
       text file.

       The base HashLog implementation assumes an 8-bit-clean database that
       maps strings to strings."""
    def __init__(self, filename, keyid):
        """HashLog(filename, keyid) -> hashlog

           Opens (creating if needed) the database in 'filename' for the key
           'keyid'.  A new log records keyid under "KEYID"; an existing log
           with a different keyid only triggers a printed warning."""
        self.log = anydbm.open(filename, 'c')
        try:
            storedKeyid = self.log["KEYID"]
        except KeyError:
            self.log["KEYID"] = keyid
        else:
            if storedKeyid != keyid:
                #XXXX Need warning mechanism
                print("Mismatch on keyid")

    def seenHash(self, hash):
        """seenHash(hash) -> bool

           Returns 1 if 'hash' has been logged before, else 0."""
        try:
            self.log[hash]
        except KeyError:
            return 0
        return 1

    def logHash(self, hash):
        """logHash(hash)

           Records 'hash' in the database."""
        self.log[hash] = "1"

    def sync(self):
        """sync()

           Flushes pending changes to the filesystem."""
        self.log.sync()

    def close(self):
        """close()

           Closes the underlying database."""
        self.log.close()

--- NEW FILE: Modules.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: Modules.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

# Routing types handled by relay (non-exit) processing.
DROP_TYPE      = 0x0000
FWD_TYPE       = 0x0001
SWAP_FWD_TYPE  = 0x0002

# Exit routing types start at MIN_EXIT_TYPE.
MIN_EXIT_TYPE  = 0x0100
SMTP_TYPE      = MIN_EXIT_TYPE
LOCAL_TYPE     = MIN_EXIT_TYPE + 1



--- NEW FILE: ServerInfo.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: ServerInfo.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

#XXXX DOC

__all__ = [ 'ServerInfo' ]

#
# Stub class till we have the real thing
#
class ServerInfo:
    """Stub description of a Mixminion server: its address, port, RSA
       public modulus, and key ID."""
    def __init__(self, addr, port, modulus, keyid):
        self.addr = addr
        self.port = port
        self.modulus = modulus
        self.keyid = keyid

    def getAddr(self): return self.addr
    def getPort(self): return self.port
    def getModulus(self): return self.modulus
    def getKeyID(self): return self.keyid
    

--- NEW FILE: ServerProcess.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: ServerProcess.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

import mixminion.Crypto as Crypto
import mixminion.Formats as Formats
import mixminion.Modules as Modules
import mixminion.Common as Common

class ContentError(Common.MixError):
    """Raised when a message parses but its contents are unacceptable
       (bad version, digest, routing length, or routing type)."""

class ServerProcess:
    """Decodes incoming messages for one server key and dispatches them.

       privatekey     -- this server's RSA private key.
       hashlog        -- a HashLog (XXXX not yet consulted: no replay check).
       exitHandler    -- object whose processMessage() gets exit messages.
       forwardHandler -- object whose queue() gets messages to relay."""

    # XXXX Crypto defines no APPLICATION_KEY_MODE yet; this is a placeholder
    # mode string until the key schedule is settled.
    _APPLICATION_KEY_MODE = "APPLICATION KEY"

    def __init__(self, privatekey, hashlog, exitHandler, forwardHandler):
        self.privatekey = privatekey
        self.hashlog = hashlog
        self.exitHandler = exitHandler
        self.forwardHandler = forwardHandler

    # Raises ParseError, ContentError.
    def processMessage(self, msg):
        """processMessage(msg)

           Decodes the packed message msg and passes the result to the exit
           or forward handler; dropped messages are silently discarded."""
        r = self._processMessage(msg)
        if r is not None:
            method, args = r
            method(*args)

    # Raises ParseError, ContentError, SSLError.
    #  Returns oneof (None), (method, argl)
    def _processMessage(self, msg):
        msg = Formats.parseMessage(msg)
        # Wrap the raw string so that header1[i] indexes subheader blocks.
        header1 = Formats.parseHeader(msg.header1)

        subh = header1[0]
        subh = Crypto.pk_decrypt(subh, self.privatekey)
        subh = Formats.parseSubheader(subh)

        if subh.major != Formats.MAJOR_NO or subh.minor != Formats.MINOR_NO:
            raise ContentError("Invalid protocol version")

        digest = Crypto.sha1(header1[1:16])
        if digest != subh.digest:
            raise ContentError("Invalid digest")

        # XXXX Replay detection (self.hashlog) is not implemented yet.

        # XXXX Need to decrypt extra routing info.
        if subh.isExtended():
            nExtra = subh.getNExtraBlocks()
            if nExtra > 15:
                raise ContentError("Impossibly long routing info length")
            extra = header1[1:1+nExtra]
            subh.appendExtraBlocks(extra)
            remainingHeader = header1[1+nExtra:]
        else:
            remainingHeader = header1[1:]

        keys = Crypto.Keyset(subh.secret)
        rtype = subh.routingtype

        if rtype == Modules.DROP_TYPE:
            return None

        payload = Crypto.lioness_decrypt(
            msg.payload, keys.getLionessKeys(Crypto.PAYLOAD_ENCRYPT_MODE))

        # XXXX This doesn't match what George said.
        # SMTP_TYPE == MIN_EXIT_TYPE, so the comparison must be inclusive.
        if rtype >= Modules.MIN_EXIT_TYPE:
            return (self.exitHandler.processMessage,
                    (rtype, subh.routinginfo,
                     keys.get(self._APPLICATION_KEY_MODE),
                     payload))

        if rtype not in (Modules.SWAP_FWD_TYPE, Modules.FWD_TYPE):
            raise ContentError("Unrecognized mixminion type")

        # Pad the header back to full length with PRNG junk, then strip this
        # hop's counter-mode layer from the whole thing.
        remainingHeader = remainingHeader + \
                          Crypto.prng(keys.get(Crypto.PRNG_MODE),
                                      Formats.HEADER_LEN-len(remainingHeader))
        header1 = Crypto.ctr_crypt(remainingHeader,
                                   keys.get(Crypto.HEADER_SECRET_MODE))

        header2 = Crypto.lioness_decrypt(
            msg.header2, keys.getLionessKeys(Crypto.HEADER_ENCRYPT_MODE))

        if rtype == Modules.SWAP_FWD_TYPE:
            # Remove the payload-derived hiding layer from the already
            # decrypted header2 (presumably matching the sender, which
            # derives this key from the payload), then swap headers.
            header2 = Crypto.lioness_decrypt(
                header2, Crypto.lioness_keys_from_payload(payload))
            header1, header2 = header2, header1

        address = Formats.parseIPV4Info(subh.routinginfo)

        msg = Formats.Message(header1, header2, payload).pack()

        return (self.forwardHandler.queue, (address.ip,
                                            address.port,
                                            address.keyinfo,
                                            msg))

--- NEW FILE: __init__.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: __init__.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

__version__ = "0.1"

--- NEW FILE: benchmark.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: benchmark.py,v 1.1 2002/05/29 03:52:13 nickm Exp $
from time import time

# Maps an iteration count to the loop overhead subtracted by timeit_.
loop_overhead = {}
def timeit_(fn, iters, ov=1):
    """timeit_(fn, iters, ov=1) -> float

       Calls fn() 'iters' times and returns the average seconds per call.
       With ov=1, loop_overhead[iters] (0 if absent) is first subtracted
       from the total elapsed time; with ov=0 nothing is subtracted."""
    calls = [None]*iters
    overhead = (0, loop_overhead.get(iters, 0))[ov]
    start = time()
    for ignored in calls:
        fn()
    elapsed = time() - start
    return (elapsed - overhead) / float(iters)

# Calibrate the cost of calling an empty function.  timeit_ returns a
# per-call time but subtracts loop_overhead[iters] from its *total* elapsed
# time, so store the total overhead for 'iters' calls.  min_o/max_o record
# the smallest/largest per-call overhead observed.
min_o = 1.0
max_o = 0.0
for iters in [10**n for n in range(2,7)]:
    overhead = timeit_((lambda:(lambda:None)()), iters)
    loop_overhead[iters] = overhead * iters
    min_o = min(min_o, overhead)
    max_o = max(max_o, overhead)

def timestr(t):
    """timestr(t) -> str

       Formats a duration t (in seconds) with a sec/msec/usec/nsec unit."""
    if abs(t) >= 1.0:
        return "%.3f sec" % t
    elif abs(t) >= .001:
        return "%.3f msec" % (t*1000)
    elif abs(t) >= (.000001):
        return "%.3f usec" % (t*1000000)
    else:
        # Multiplying by 1e9 yields nanoseconds, not picoseconds.
        return "%.3f nsec" % (t*1000000000)

def timeit(fn,times):
    """timeit(fn, times) -> str

       Times 'times' calls of fn() and returns the per-call cost formatted
       by timestr()."""
    perCall = timeit_(fn, times)
    return timestr(perCall)

def spacestr(n):
    """spacestr(n) -> str

       Formats a byte count n with a bytes/KB/MB/GB unit."""
    size = abs(n)
    if size >= 1e10:
        return "%d GB" % (n//(1024*1024*1024))
    if size >= 1e7:
        return "%d MB" % (n//(1024*1024))
    if size >= 1e4:
        return "%d KB" % (n//1024)
    return "%d bytes" %n

#----------------------------------------------------------------------
import mixminion._minionlib as _ml
from Crypto import *
from Crypto import OAEP_PARAMETER

def cryptoTiming():
    loop_overhead = {}
    short = "Hello, Dali!"
    s1K = "8charstr"*128
    s2K = s1K*2
    s4K = s2K*2
    s8K = s4K*2
    s32K = s8K*4

    print "#==================== CRYPTO ======================="
    print "Timing overhead: %s...%s" % (timestr(min_o),timestr(max_o))

    print "SHA1 (short)", timeit((lambda : sha1(short)), 100000)
    print "SHA1 (8K)", timeit((lambda : sha1(s8K)), 10000)
    print "SHA1 (32K)", timeit((lambda : sha1(s32K)), 1000)

    shakey = "8charstr"*2
    print "Keyed SHA1 (short)",
    print timeit((lambda : _ml.sha1(short,shakey)), 100000)
    print "Keyed SHA1 (8K)", timeit((lambda : _ml.sha1(s8K, shakey)), 10000)
    print "Keyed SHA1 (32K)", timeit((lambda : _ml.sha1(s32K, shakey)), 1000)

    print "TRNG (20 byte)", timeit((lambda: trng(20)), 100)
    print "TRNG (128 byte)", timeit((lambda: trng(128)), 100)
    print "TRNG (1K)", timeit((lambda: trng(1024)), 100)

    print "xor (1K)", timeit((lambda: _ml.strxor(s1K,s1K)), 100000)
    print "xor (32K)", timeit((lambda: _ml.strxor(s32K,s32K)), 1000)

    key = "8charstr"*2
    print "aes (short)", timeit((lambda: ctr_crypt(short,key)), 100000)
    print "aes (1K)", timeit((lambda: ctr_crypt(s1K,key)), 10000)
    print "aes (32K)", timeit((lambda: ctr_crypt(s32K,key)), 100)

    print "prng (short)", timeit((lambda: prng(key,8)), 100000)
    print "prng (1K)", timeit((lambda: prng(key,1024)), 10000)
    print "prng (32)", timeit((lambda: prng(key,32768)), 100)

    lkey = Keyset("keymaterial foo bar baz").getLionessKeys("T")
    print "lioness E (1K)", timeit((lambda: lioness_encrypt(s1K, lkey)), 1000)
    print "lioness E (2K)", timeit((lambda: lioness_encrypt(s1K, lkey)), 1000)
    print "lioness E (4K)", timeit((lambda: lioness_encrypt(s4K, lkey)), 1000)
    print "lioness E (32K)", timeit((lambda: lioness_encrypt(s32K, lkey)), 100)
    print "lioness D (1K)", timeit((lambda: lioness_decrypt(s1K, lkey)), 1000)
    print "lioness D (2K)", timeit((lambda: lioness_decrypt(s1K, lkey)), 1000)
    print "lioness D (4K)", timeit((lambda: lioness_decrypt(s4K, lkey)), 1000)
    print "lioness D (32K)", timeit((lambda: lioness_decrypt(s32K, lkey)), 100)

    s70b = "10character"*7
    print "OAEP_add (70->128B)",
    print timeit((lambda: _ml.add_oaep_padding(s70b,OAEP_PARAMETER,128)),10000)
    r = _ml.add_oaep_padding(s70b, OAEP_PARAMETER,128)
    print "OAEP_check (128B->70B)",
    print timeit((lambda: _ml.check_oaep_padding(r,OAEP_PARAMETER,128)),10000)

    print "RSA generate (1024 bit)", timeit((lambda: pk_generate()),10)
    rsa = pk_generate()
    print "Pad+RSA public encrypt",
    print timeit((lambda: pk_encrypt(s70b, rsa)),1000)
    enc = pk_encrypt(s70b, rsa)
    print "Pad+RSA private decrypt", timeit((lambda: pk_decrypt(enc, rsa)),100)

    for (bits,it) in ((2048,10),(4096,3)):
        rsa2 = pk_generate(bits)
        enc = pk_encrypt(s70b, rsa2)
        print "Pad+RSA private decrypt (%d bit)"%bits,
        print timeit((lambda: pk_decrypt(enc, rsa2)),it)

#----------------------------------------------------------------------
def hashlogTiming(loads=(100, 1000, 10000, 100000)):
    """Print HashLog benchmark figures for logs of several sizes.

    loads -- sequence of entry counts; one benchmark run (see
             _hashlogTiming) is made for each count, in order.  The default
             reproduces the original fixed set: 100, 1000, 10000, 100000.
    """
    for load in loads:
        _hashlogTiming(load)

def _hashlogTiming(load):
    import tempfile, os
    from mixminion.Crypto import AESCounterPRNG
    from mixminion.HashLog import HashLog
    prng = AESCounterPRNG("a"*16)
    fname = tempfile.mktemp(".db")
    
    h = HashLog(fname, "A")
    hashes = [ prng.getBytes(20) for i in range(load) ]

    t = time()
    for hash in hashes:
        h.logHash(hash)
    t = time()-t
    print "Add entry (up to %s entries)" %load, timestr( t/float(load) )

    t = time()
    for hash in hashes[0:1000]:
        h.seenHash(hash)
    t = time()-t    
    print "Check entry [hit] (%s entries)" %load, timestr( t/1000.0 )

    hashes =[ prng.getBytes(20) for i in range(1000) ]
    t = time()
    for hash in hashes:
        h.seenHash(hash)
    t = time()-t   
    print "Check entry [miss] (%s entries)" %load, timestr( t/1000.0 )

    h.close()
    print "File size (%s entries)"%load, spacestr(os.stat(fname).st_size)
    os.unlink(fname)

#----------------------------------------------------------------------
def testLeaks1():
    print "Trying to leak (sha1,aes,xor,seed,oaep)"
    s20k="a"*20*1024
    key="a"*16
    while 1:
        if 1:
            _ml.sha1(s20k)
            _ml.sha1(s20k,s20k)
            _ml.aes_ctr128_crypt(key,s20k,0)
            _ml.aes_ctr128_crypt(key,s20k,2000)
            _ml.aes_ctr128_crypt(key,"",2000,20000)
            _ml.aes_ctr128_crypt(key,"",0,20000)
            _ml.aes_ctr128_crypt(key,s20k,0,2000)
            try:
                _ml.aes_ctr128_crypt("abc",s20k,0,2000)
            except:
                pass
            _ml.strxor(s20k,s20k)
            try:
                _ml.strxor(s20k,key)
            except:
                pass
            _ml.openssl_seed(s20k)
            r = _ml.add_oaep_padding("Hello",OAEP_PARAMETER,128)
            _ml.check_oaep_padding(r,OAEP_PARAMETER,128)
            try:
                _ml.check_oaep_padding("hello",OAEP_PARAMETER,128)
            except:
                pass
            try:
                _ml.add_oaep_padding(s20k,OAEP_PARAMETER,128)
            except:
                pass
            try:
                _ml.add_oaep_padding("a"*127,OAEP_PARAMETER,128)
            except:
                pass

def testLeaks2():
    print "Trying to leak (rsa)"

    s20 = "a"*20
    p = pk_generate(512)
    n,e = _ml.rsa_get_public_key(p)

    while 1:
        if 1:
            p = pk_generate(512)
            pk_decrypt(pk_encrypt(s20,p),p)
            for public in (0,1):
                x = _ml.rsa_encode_key(p,public)
                _ml.rsa_decode_key(x,public)
            _ml.rsa_get_public_key(p)
            _ml.rsa_make_public_key(n,e)

#----------------------------------------------------------------------

def timeAll():
    """Run every benchmark in this module: crypto primitives, then HashLog."""
    for benchmark in (cryptoTiming, hashlogTiming):
        benchmark()

if __name__ == '__main__':
    timeAll()
    # The leak tests below loop forever; uncomment one to run it (after the
    # timings) while watching the process's memory use.
    #testLeaks1()
    #testLeaks2()

--- NEW FILE: test.py ---
# Copyright 2002 Nick Mathewson.  See LICENSE for licensing information.
# $Id: test.py,v 1.1 2002/05/29 03:52:13 nickm Exp $

import unittest

#----------------------------------------------------------------------
import mixminion._minionlib as _ml

class MinionlibCryptoTests(unittest.TestCase):
    """Direct tests of the C primitives exported by mixminion._minionlib."""

    def hexread(self,s):
        """Decode an upper-case hex string into the bytes it represents.

        A trailing unpaired character, if any, is ignored.
        """
        r = []
        hexvals = "0123456789ABCDEF"
        for i in range(len(s) // 2):
            v1 = hexvals.index(s[i*2])
            v2 = hexvals.index(s[i*2+1])
            c = (v1 << 4) + v2
            assert 0 <= c < 256
            r.append(chr(c))
        return "".join(r)

    def test_sha1(self):
        s1 = _ml.sha1

        # Standard SHA-1 test vectors.
        self.assertEquals(s1("abc"),
               self.hexread("A9993E364706816ABA3E25717850C26C9CD0D89D"))

        s = s1("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq")
        self.assertEquals(s,
               self.hexread("84983E441C3BD26EBAAE4AA1F95129E5E54670F1"))

        # The two-argument form hashes key+data+key.
        self.assertEquals(s1("abc", "def"),
                          s1("defabcdef"))
        self.failUnlessRaises(TypeError, s1, 1)

    def test_xor(self):
        xor = _ml.strxor
        self.assertEquals(xor("abc", "\000\000\000"), "abc")
        self.assertEquals(xor("abc", "abc"), "\000\000\000")
        self.assertEquals(xor("\xEF\xF0\x12", "\x11\x22\x35"), '\xFE\xD2\x27')

        # Operands must have equal length.
        self.failUnlessRaises(TypeError, xor, "a", "bb")

    def test_aes(self):
        crypt = _ml.aes_ctr128_crypt

        # One of the test vectors from AES.
        key = "\x80" + "\x00" * 15
        expected = self.hexread("8EDD33D3C621E546455BD8BA1418BEC8")
        self.assertEquals(expected, crypt(key, key, 0))
        self.assertEquals(expected, crypt(key, key))
        # Starting at keystream offset 1 equals dropping the first byte.
        self.assertEquals(crypt(key, " "*100, 0)[1:], crypt(key, " "*99, 1))
        # Counter mode is its own inverse.
        self.assertEquals(" "*100, crypt(key,crypt(key, " "*100, 0),0))

        teststr = """I have seen the best ciphers of my generation
                     Destroyed by cryptanalysis, broken, insecure,
                     Implemented still in cryptographic libraries"""

        self.assertEquals(teststr,crypt("xyzz"*4,crypt("xyzz"*4,teststr)))

        # PRNG mode
        expected2 = self.hexread("0EDD33D3C621E546455BD8BA1418BEC8")
        self.assertEquals(expected2, crypt(key, "", 0, len(expected2)))
        self.assertEquals(expected2, crypt(key, "Z", 0, len(expected2)))
        self.assertEquals(expected2[5:], crypt(key, "", 5, len(expected2)-5))

        # Failing cases: keys of the wrong length.
        self.failUnlessRaises(TypeError, crypt, "a", teststr)
        self.failUnlessRaises(TypeError, crypt, "a"*17, teststr)

        self.assertEquals("", crypt(key,"",0,-1))

    def test_openssl_seed(self):
        _ml.openssl_seed("Hello")
        _ml.openssl_seed("")

    def test_oaep(self):
        x = _ml.add_oaep_padding("A", "B", 128)
        self.assertEquals("A",_ml.check_oaep_padding(x, "B", 128))

        # 86 bytes is the longest message that fits in 128 bytes.
        _ml.add_oaep_padding("A"*86, "B",128)
        self.failUnlessRaises(TypeError,
                              _ml.add_oaep_padding,"A"*300, "B", 128)
        self.failUnlessRaises(_ml.SSLError,
                              _ml.add_oaep_padding,"A"*87, "B", 128)
        # Corrupted padded blocks must be rejected.
        self.failUnlessRaises(_ml.SSLError,
                              _ml.check_oaep_padding,x[1:]+"Y","B",128)
        self.failUnlessRaises(_ml.SSLError,
                              _ml.check_oaep_padding,x[:-1]+"Y","B",128)

    def test_rsa(self):
        p = _ml.rsa_generate(1024, 65535)

        # Round-trip an OAEP-padded block through every combination of
        # public/private key and encrypt/decrypt direction.
        msg = "Now is the time for all anonymous parties"
        for pub1 in (0,1):
            for enc1 in (0,1):
                x = _ml.add_oaep_padding(msg, "B", 128)
                x2 = _ml.rsa_crypt(p, x, pub1, enc1)
                x3 = _ml.rsa_crypt(p, x2, [1,0][pub1], [1,0][enc1])
                self.assertEquals(x, x3)
                x4 = _ml.check_oaep_padding(x3, "B", 128)
                self.assertEquals(msg, x4)

        padded = _ml.add_oaep_padding(msg, "B", 128)
        # Too short
        self.failUnlessRaises(_ml.SSLError,_ml.rsa_crypt,p,"X",1,1)
        # Too long
        self.failUnlessRaises(_ml.SSLError,_ml.rsa_crypt,p,padded+"XXX",1,1)

        # Encoding then decoding a key yields an equivalent key.
        padhello = _ml.add_oaep_padding("Hello", "B", 128)
        for public in (0,1):
            x = _ml.rsa_encode_key(p,public)
            p2 = _ml.rsa_decode_key(x,public)
            x3 = _ml.rsa_encode_key(p2,public)
            self.assertEquals(x,x3)
            self.assertEquals(_ml.rsa_crypt(p,padhello,public,1),
                              _ml.rsa_crypt(p2,padhello,public,1))

        n,e = _ml.rsa_get_public_key(p)
        p2 = _ml.rsa_make_public_key(n,e)
        self.assertEquals((n,e), _ml.rsa_get_public_key(p2))
        self.assertEquals(65535,e)
        # A key rebuilt from (n,e) must encode to the same public key as p.
        self.assertEquals(_ml.rsa_encode_key(p,1), _ml.rsa_encode_key(p2,1))

        # Public-key encryption through each public-key object (the full
        # key, the one rebuilt from (n,e), and the decoded public encoding)
        # must decrypt correctly with the private key.
        p3 = _ml.rsa_decode_key(_ml.rsa_encode_key(p,1),1)
        msg1 = _ml.rsa_crypt(p, padhello, 1,1)
        msg2 = _ml.rsa_crypt(p2, padhello, 1,1)
        msg3 = _ml.rsa_crypt(p3, padhello, 1,1)
        self.assertEquals(padhello, _ml.rsa_crypt(p,msg1,0,0))
        self.assertEquals(padhello, _ml.rsa_crypt(p,msg2,0,0))
        self.assertEquals(padhello, _ml.rsa_crypt(p,msg3,0,0))
        # Private-key operations on public-only keys raise TypeError.
        self.failUnlessRaises(TypeError, _ml.rsa_crypt, p2, msg1, 0, 0)
        self.failUnlessRaises(TypeError, _ml.rsa_crypt, p3, msg1, 0, 0)
        self.failUnlessRaises(TypeError, _ml.rsa_encode_key, p2, 0)
        self.failUnlessRaises(TypeError, _ml.rsa_encode_key, p3, 0)
#----------------------------------------------------------------------
import mixminion.Crypto
from mixminion.Crypto import *

class CryptoTests(unittest.TestCase):
    def test_initcrypto(self):
        init_crypto()

    def test_wrappers(self):
        self.assertEquals(_ml.sha1("xyzzy"), sha1("xyzzy"))
        k = "xyzy"*4
        self.assertEquals(_ml.aes_ctr128_crypt(k,"hello",0),
                          ctr_crypt("hello",k))
        self.assertEquals(_ml.aes_ctr128_crypt(k,"hello",99),
                          ctr_crypt("hello",k,99))
        self.assertEquals(_ml.aes_ctr128_crypt(k,"",0,99), prng(k,99))
        self.assertEquals(_ml.aes_ctr128_crypt(k,"",3,99), prng(k,99,3))
        self.assertEquals(prng(k,100,0),prng(k,50,0)+prng(k,50,50))

    def test_rsa(self):
        eq = self.assertEquals
        k512 = pk_generate(512)
        k1024 = pk_generate()

        eq(512/8, _ml.rsa_get_modulus_bytes(k512))
        eq(1024/8, _ml.rsa_get_modulus_bytes(k1024))

        self.failUnless((1L<<511) < pk_get_modulus(k512) < (1L<<513))
        self.failUnless((1L<<1023) < pk_get_modulus(k1024) < (1L<<1024))

        msg="Good hello"
        pub512 = pk_from_modulus(pk_get_modulus(k512))
        pub1024 = pk_from_modulus(pk_get_modulus(k1024))

        eq(msg, pk_decrypt(pk_encrypt(msg, k512),k512))
        eq(msg, pk_decrypt(pk_encrypt(msg, pub512),k512))
        eq(msg, pk_decrypt(pk_encrypt(msg, k1024),k1024))
        eq(msg, pk_decrypt(pk_encrypt(msg, pub1024),k1024))

        eq(msg, _ml.check_oaep_padding(
                    _ml.rsa_crypt(k512, pk_encrypt(msg,k512), 0, 0),
                    mixminion.Crypto.OAEP_PARAMETER, 64))

        encoded = pk_encode_private_key(k512)
        decoded = pk_decode_private_key(encoded)
        eq(msg, pk_decrypt(pk_encrypt(msg, pub512),decoded))
        
    def test_trng(self):
        self.assertNotEquals(trng(40), trng(40))

    def test_lioness(self):
        enc = lioness_encrypt
        dec = lioness_decrypt
        key = ("ABCDE"*4, "ABCD"*4, "VWXYZ"*4, "WXYZ"*4)
        plain = mixminion.Crypto.OAEP_PARAMETER*100
        self.assertNotEquals(plain, enc(plain,key))
        self.assertNotEquals(plain, dec(plain,key))
        self.assertEquals(len(plain), len(enc(plain,key)))
        self.assertEquals(len(plain), len(dec(plain,key)))
        self.assertEquals(plain, dec(enc(plain,key),key))
        self.assertEquals(plain, enc(dec(plain,key),key))
        #XXXX check for correct values

    def test_keyset(self):
        s = sha1
        k = Keyset("a")
        eq = self.assertEquals
        eq(s("aFoo")[:10], k.get("Foo",10))
        eq(s("aBar")[:16], k.get("Bar"))
        eq( (s("aBaz (FIRST SUBKEY)"), s("aBaz (SECOND SUBKEY)")[:16],
             s("aBaz (THIRD SUBKEY)"), s("aBaz (FOURTH SUBKEY)")[:16]),
            k.getLionessKeys("Baz"))

    def test_aesprng(self):
        key ="aaaa"*4
        PRNG = AESCounterPRNG(key)
        self.assert_(prng(key,100000) == (
                          PRNG.getBytes(5)+PRNG.getBytes(16*1024-5)+
                          PRNG.getBytes(50)+PRNG.getBytes(32*1024)+
                          PRNG.getBytes(9)+PRNG.getBytes(10)+
                          PRNG.getBytes(15)+PRNG.getBytes(16000)+
                          PRNG.getBytes(34764)))

#----------------------------------------------------------------------
import mixminion.Formats
from mixminion.Formats import *

class FormatTests(unittest.TestCase):
    """Packing and parsing of Subheader objects from mixminion.Formats."""

    def test_subheader(self):
        eq = self.assertEquals
        secret = "abcdeabcdeabcdef"
        digest = "ABCDEFGHIJABCDEFGHIJ"

        # Case 1: routing info short enough to fit inside the subheader.
        short = Subheader(3, 0, secret, digest, 1, "Hello")

        packedShort = ("\003\000" + secret + digest +
                       "\000\005\000\001Hello")
        eq(short.pack(), packedShort)
        self.failIf(short.isExtended())
        eq(short.getNExtraBlocks(), 0)
        eq(short.getExtraBlocks(), [])

        parsed = parseSubheader(short.pack())
        for attr, value in (("major", 3), ("minor", 0),
                            ("secret", secret), ("digest", digest),
                            ("routingtype", 1), ("routinglen", 5),
                            ("routinginfo", "Hello")):
            eq(getattr(parsed, attr), value)
        self.failIf(parsed.isExtended())
        eq(parsed.pack(), packedShort)

        # Case 2: routing info long enough to spill into extra blocks.
        ts_eliot = "Who is the third who walks always beside you? / "+\
                   "When I count, there are only you and I together / "+\
                   "But when I look ahead up the white road / "+\
                   "There is always another one walking beside you"

        extended = Subheader(3, 9, secret, digest,
                             62, ts_eliot, len(ts_eliot))

        eq(len(ts_eliot), 186)

        packedLong = ("\003\011" + secret + digest + "\000\272\000\076" +
                      "Who is the third who walks always beside you")
        eq(len(packedLong), mixminion.Formats.MAX_SUBHEADER_LEN)
        eq(extended.pack(), packedLong)

        overflow = extended.getExtraBlocks()
        eq(len(overflow), 2)
        eq(overflow[0], "? / When I count, there are only you "+\
                        "and I together / But when I look ahead up the white "+\
                        "road / There is always another one walk")
        eq(overflow[1], "ing beside you"+(114*'\000'))

        parsed = parseSubheader(packedLong)
        for attr, value in (("major", 3), ("minor", 9),
                            ("secret", secret), ("digest", digest),
                            ("routingtype", 62), ("routinglen", 186)):
            eq(getattr(parsed, attr), value)
        self.failUnless(parsed.isExtended())
        eq(parsed.getNExtraBlocks(), 2)

        # Feeding the overflow back in restores the full routing info.
        parsed.appendExtraBlocks("".join(overflow))
        eq(parsed.routinginfo, ts_eliot)
        eq(parsed.pack(), packedLong)
        eq(parsed.getExtraBlocks(), overflow)

        #XXXX Need failing tests, routinginfo tests.
        
#----------------------------------------------------------------------
from mixminion.HashLog import HashLog

class HashLogTests(unittest.TestCase):
    """Tests for mixminion.HashLog.HashLog lookup and persistence."""

    def test_hashlog(self):
        import tempfile, os
        fname = tempfile.mktemp(".db")
        try:
            self.hashlogTestImpl(fname)
        finally:
            # Best-effort cleanup; the file may not exist if the test
            # failed early.
            try:
                os.unlink(fname)
            except OSError:
                pass

    def hashlogTestImpl(self,fname):
        """Log, look up, and reopen a HashLog stored at `fname`."""
        h = HashLog(fname, "Xyzzy")

        # These helpers look `h` up when called, so they follow each
        # reassignment of `h` below.
        notseen = lambda hash: self.assert_(not h.seenHash(hash))
        seen = lambda hash: self.assert_(h.seenHash(hash))
        log = lambda hash: h.logHash(hash)

        notseen("a")
        notseen("a"*20)
        notseen("\000"*10)
        notseen("\000")
        notseen("\277"*10)
        log("a")
        # Logging "a" must not make longer strings that start with "a" seen.
        notseen("a"*10)
        notseen("\000"*10)
        notseen("b")
        seen("a")

        log("b")
        seen("b")
        seen("a")

        log("\000")
        seen("\000")
        notseen("\000"*10)

        log("\000"*10)
        seen("\000"*10)

        log("\277"*20)
        seen("\277"*20)

        log("abcdef"*4)
        seen("abcdef"*4)

        # Reopening must preserve everything logged so far.
        h.close()
        h = HashLog(fname, "Xyzzy")
        seen("a")
        seen("b")
        seen("\277"*20)
        seen("abcdef"*4)
        seen("\000")
        seen("\000"*10)
        notseen(" ")
        notseen("\000"*5)

        notseen("ddddd")
        log("ddddd")
        seen("ddddd")

        h.close()
        h = HashLog(fname, "Xyzzy")
        seen("ddddd")

        h.close()

    # NOTE(review): the placeholders below look like Formats tests rather
    # than HashLog tests; consider moving them when they are written.
    def test_headers(self):
        pass #XXXX

    def test_message(self):
        pass #XXXX

    def test_ipv4info(self):
        pass #XXXX

    def test_smtpinfo(self):
        pass #XXXX

    def test_localinfo(self):
        pass #XXXX

#----------------------------------------------------------------------
import mixminion.ServerProcess
#----------------------------------------------------------------------
import mixminion.BuildMessage
#----------------------------------------------------------------------

def testSuite():
    """Build and return a TestSuite containing every test case in this file."""
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for caseClass in (MinionlibCryptoTests, CryptoTests,
                      FormatTests, HashLogTests):
        suite.addTest(loader.loadTestsFromTestCase(caseClass))
    return suite

def testAll():
    """Run the full test suite, reporting results as text."""
    runner = unittest.TextTestRunner()
    runner.run(testSuite())

if __name__ == '__main__':
    # Running this file directly executes the whole suite via testAll().
    testAll()