[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [ooni-probe/master] Implement fix for #493



commit b19cf6800c43c936a25ae452f85967b5be4c12d6
Author: Arturo Filastò <arturo@xxxxxxxxxxx>
Date:   Tue May 3 19:18:37 2016 +0200

    Implement fix for #493
    
    * Use a class factory to generate NetTestCase subclasses with injected
      localOptions.
    * Substantially rework how localOptions are handled.
    * Fix the NetTestCase object lifecycle to resolve issues that occur when
      multiple tests run concurrently.
---
 ooni/deck.py                |  73 ++++----
 ooni/director.py            |  12 +-
 ooni/errors.py              |   2 +
 ooni/nettest.py             | 394 ++++++++++++++++++++++++--------------------
 ooni/oonicli.py             |   5 +-
 ooni/reporter.py            |   3 +-
 ooni/tests/test_deck.py     |  56 ++++++-
 ooni/tests/test_director.py |  43 +++++
 ooni/tests/test_nettest.py  |  43 ++---
 ooni/tests/test_oonicli.py  |   7 +-
 10 files changed, 370 insertions(+), 268 deletions(-)

diff --git a/ooni/deck.py b/ooni/deck.py
index 4049103..4b6d894 100644
--- a/ooni/deck.py
+++ b/ooni/deck.py
@@ -136,6 +136,7 @@ class Deck(InputFile):
                 log.msg("Skipping...")
                 continue
             net_test_loader = NetTestLoader(test['options']['subargs'],
+                                            annotations=test['options'].get('annotations', {}),
                                             test_file=nettest_path)
             if test['options']['collector']:
                 net_test_loader.collector = test['options']['collector']
@@ -143,23 +144,13 @@ class Deck(InputFile):
 
     def insert(self, net_test_loader):
         """ Add a NetTestLoader to this test deck """
-
-        def has_test_helper(missing_option):
-            for rth in net_test_loader.requiredTestHelpers:
-                if missing_option == rth['option']:
-                    return True
-            return False
-
         try:
             net_test_loader.checkOptions()
             if net_test_loader.requiresTor:
                 self.requiresTor = True
-        except e.MissingRequiredOption as missing_options:
+        except e.MissingTestHelper:
             if not self.bouncer:
                 raise
-            for missing_option in missing_options.message:
-                if not has_test_helper(missing_option):
-                    raise
             self.requiresTor = True
 
         if net_test_loader.collector and net_test_loader.collector.startswith('https://'):
@@ -192,26 +183,25 @@ class Deck(InputFile):
         requires_collector = False
         for net_test_loader in self.netTestLoaders:
             nettest = {
-                'name': net_test_loader.testDetails['test_name'],
-                'version': net_test_loader.testDetails['test_version'],
+                'name': net_test_loader.testName,
+                'version': net_test_loader.testVersion,
                 'test-helpers': [],
                 'input-hashes': [x['hash'] for x in net_test_loader.inputFiles]
             }
             if not net_test_loader.collector and not self.no_collector:
                 requires_collector = True
 
-            for th in net_test_loader.requiredTestHelpers:
-                # {'name':'', 'option':'', 'test_class':''}
-                if th['test_class'].localOptions[th['option']]:
-                    continue
-                nettest['test-helpers'].append(th['name'])
+            if len(net_test_loader.missingTestHelpers) > 0:
                 requires_test_helpers = True
+                nettest['test-helpers'] += map(lambda x: x[1],
+                                               net_test_loader.missingTestHelpers)
 
             required_nettests.append(nettest)
 
         if not requires_test_helpers and not requires_collector:
             defer.returnValue(None)
 
+        log.debug("Looking up {}".format(required_nettests))
         response = yield self.oonibclient.lookupTestCollector(required_nettests)
         provided_net_tests = response['net-tests']
 
@@ -227,17 +217,18 @@ class Deck(InputFile):
                 return net_test['collector'], net_test['test-helpers']
 
         for net_test_loader in self.netTestLoaders:
-            log.msg("Setting collector and test helpers for %s" % net_test_loader.testDetails['test_name'])
+            log.msg("Setting collector and test helpers for %s" %
+                    net_test_loader.testName)
 
             collector, test_helpers = \
-                find_collector_and_test_helpers(net_test_loader.testDetails['test_name'],
-                                                net_test_loader.testDetails['test_version'],
+                find_collector_and_test_helpers(net_test_loader.testName,
+                                                net_test_loader.testVersion,
                                                 net_test_loader.inputFiles)
 
-            for th in net_test_loader.requiredTestHelpers:
-                if not th['test_class'].localOptions[th['option']]:
-                    th['test_class'].localOptions[th['option']] = test_helpers[th['name']].encode('utf-8')
-                net_test_loader.testHelpers[th['option']] = th['test_class'].localOptions[th['option']]
+            for option, name in net_test_loader.missingTestHelpers:
+                test_helper_address = test_helpers[name].encode('utf-8')
+                net_test_loader.localOptions[option] = test_helper_address
+                net_test_loader.testHelpers[option] = test_helper_address
 
             if not net_test_loader.collector:
                 net_test_loader.collector = collector.encode('utf-8')
@@ -252,11 +243,8 @@ class Deck(InputFile):
             if not net_test_loader.collector and not self.no_collector:
                 requires_collector.append(net_test_loader)
 
-            for th in net_test_loader.requiredTestHelpers:
-                # {'name':'', 'option':'', 'test_class':''}
-                if th['test_class'].localOptions[th['option']]:
-                    continue
-                required_test_helpers.append(th['name'])
+            required_test_helpers += map(lambda x: x[1],
+                                           net_test_loader.missingTestHelpers)
 
         if not required_test_helpers and not requires_collector:
             defer.returnValue(None)
@@ -265,33 +253,34 @@ class Deck(InputFile):
 
         for net_test_loader in self.netTestLoaders:
             log.msg("Setting collector and test helpers for %s" %
-                    net_test_loader.testDetails['test_name'])
+                    net_test_loader.testName)
 
             # Only set the collector if the no collector has been specified
             # from the command line or via the test deck.
-            if not net_test_loader.requiredTestHelpers and \
+            if len(net_test_loader.missingTestHelpers) == 0 and \
                             net_test_loader in requires_collector:
                 log.msg("Using the default collector: %s" %
                         response['default']['collector'])
                 net_test_loader.collector = response['default']['collector'].encode('utf-8')
                 continue
 
-            for th in net_test_loader.requiredTestHelpers:
-                # Only set helpers which are not already specified
-                if th['name'] not in required_test_helpers:
-                    continue
-                test_helper = response[th['name']]
-                log.msg("Using this helper: %s" % test_helper)
-                th['test_class'].localOptions[th['option']] = test_helper['address'].encode('utf-8')
+            for option, name in net_test_loader.missingTestHelpers:
+                test_helper_address = response[name]['address'].encode('utf-8')
+                test_helper_collector = \
+                    response[name]['collector'].encode('utf-8')
+
+                log.msg("Using this helper: %s" % test_helper_address)
+                net_test_loader.localOptions[option] = test_helper_address
+                net_test_loader.testHelpers[option] = test_helper_address
                 if net_test_loader in requires_collector:
-                    net_test_loader.collector = test_helper['collector'].encode('utf-8')
+                    net_test_loader.collector = test_helper_collector
 
     @defer.inlineCallbacks
     def fetchAndVerifyNetTestInput(self, net_test_loader):
         """ fetch and verify a single NetTest's inputs """
         log.debug("Fetching and verifying inputs")
         for i in net_test_loader.inputFiles:
-            if 'url' in i:
+            if i['url']:
                 log.debug("Downloading %s" % i['url'])
                 self.oonibclient.address = i['address']
 
@@ -305,4 +294,4 @@ class Deck(InputFile):
                 except AssertionError:
                     raise e.UnableToLoadDeckInput
 
-                i['test_class'].localOptions[i['key']] = input_file.cached_file
+                i['test_options'][i['key']] = input_file.cached_file
diff --git a/ooni/director.py b/ooni/director.py
index a60b91e..02619c2 100644
--- a/ooni/director.py
+++ b/ooni/director.py
@@ -241,21 +241,19 @@ class Director(object):
             net_test_loader:
                 an instance of :class:ooni.nettest.NetTestLoader
         """
-        # Here we set the test details again since the geoip lookups may
-        # not have already been done and probe_asn and probe_ip
-        # are not set.
-        net_test_loader.setTestDetails()
+        test_details = net_test_loader.getTestDetails()
+        test_cases = net_test_loader.getTestCases()
 
         if self.allTestsDone.called:
             self.allTestsDone = defer.Deferred()
 
         if config.privacy.includepcap:
-            self.startSniffing(net_test_loader.testDetails)
-        report = Report(net_test_loader.testDetails, report_filename,
+            self.startSniffing(test_details)
+        report = Report(test_details, report_filename,
                         self.reportEntryManager, collector_address,
                         no_yamloo)
 
-        net_test = NetTest(net_test_loader, report)
+        net_test = NetTest(test_cases, test_details, report)
         net_test.director = self
 
         yield net_test.report.open()
diff --git a/ooni/errors.py b/ooni/errors.py
index f98a09f..0412b50 100644
--- a/ooni/errors.py
+++ b/ooni/errors.py
@@ -195,6 +195,8 @@ class MissingRequiredOption(Exception):
     def __str__(self):
         return ','.join(self.message)
 
+class MissingTestHelper(MissingRequiredOption):
+    pass
 
 class OONIUsageError(usage.UsageError):
     def __init__(self, net_test_loader):
diff --git a/ooni/nettest.py b/ooni/nettest.py
index 9645400..074e8c2 100644
--- a/ooni/nettest.py
+++ b/ooni/nettest.py
@@ -8,6 +8,7 @@ from twisted.internet import defer
 from twisted.trial.runner import filenameToModule
 from twisted.python import usage, reflect
 
+from ooni import __version__ as ooniprobe_version
 from ooni import otime
 from ooni.tasks import Measurement
 from ooni.utils import log, sanitize_options, randomStr
@@ -128,140 +129,198 @@ def getNetTestInformation(net_test_file):
     return information
 
 
+def usageOptionsFactory(test_name, test_version):
+
+    class UsageOptions(usage.Options):
+        optParameters = []
+        optFlags = []
+
+        synopsis = "{} {} [options]".format(
+            os.path.basename(sys.argv[0]),
+            test_name
+        )
+
+        def opt_version(self):
+            """
+            Display the net_test version and exit.
+            """
+            print "{} version: {}".format(test_name, test_version)
+            sys.exit(0)
+
+    return UsageOptions
+
+def netTestCaseFactory(test_class, local_options):
+    class NetTestCaseWithLocalOptions(test_class):
+        localOptions = local_options
+    return NetTestCaseWithLocalOptions
+
+ONION_INPUT_REGEXP = re.compile("(httpo://[a-z0-9]{16}\.onion)/input/(["
+                                "a-z0-9]{64})$")
+
 class NetTestLoader(object):
     method_prefix = 'test'
     collector = None
     yamloo = True
-    requiresTor = False
-    reportID = None
 
-    def __init__(self, options, test_file=None, test_string=None):
-        self.onionInputRegex = re.compile(
-            "(httpo://[a-z0-9]{16}\.onion)/input/([a-z0-9]{64})$")
+    def __init__(self, options, test_file=None, test_string=None,
+                 annotations={}):
         self.options = options
-        self.testCases = []
-        self.annotations = {}
+        self.annotations = annotations
+
+        self.requiresTor = False
+
+        self.testName = ""
+        self.testVersion = ""
+        self.reportId = None
+
+        self.testHelpers = {}
+        self.missingTestHelpers = []
+        self.usageOptions = None
+        self.inputFiles = []
+
+        self._testCases = []
+        self.localOptions = None
 
         if test_file:
             self.loadNetTestFile(test_file)
         elif test_string:
             self.loadNetTestString(test_string)
 
-    @property
-    def requiredTestHelpers(self):
-        required_test_helpers = []
-        if not self.testCases:
-            return required_test_helpers
-
-        for test_class, test_methods in self.testCases:
-            for option, name in test_class.requiredTestHelpers.items():
-                required_test_helpers.append({
-                    'name': name,
-                    'option': option,
-                    'test_class': test_class
-                })
-        return required_test_helpers
-
-    @property
-    def inputFiles(self):
-        input_files = []
-        if not self.testCases:
-            return input_files
-
-        for test_class, test_methods in self.testCases:
-            if test_class.inputFile:
-                key = test_class.inputFile[0]
-                filename = test_class.localOptions[key]
-                if not filename:
-                    continue
-                input_file = {
-                    'key': key,
-                    'test_class': test_class
-                }
-                m = self.onionInputRegex.match(filename)
-                if m:
-                    input_file['url'] = filename
-                    input_file['address'] = m.group(1)
-                    input_file['hash'] = m.group(2)
-                else:
-                    input_file['filename'] = filename
-                    try:
-                        with open(filename) as f:
-                            h = sha256()
-                            for l in f:
-                                h.update(l)
-                    except:
-                        raise e.InvalidInputFile(filename)
-                    input_file['hash'] = h.hexdigest()
-                input_files.append(input_file)
-
-        return input_files
-
-    def setTestDetails(self):
-        from ooni import __version__ as software_version
-
-        input_file_hashes = []
-        for input_file in self.inputFiles:
-            input_file_hashes.append(input_file['hash'])
-
-        options = sanitize_options(self.options)
-        self.testDetails = {
-            'test_start_time': otime.timestampNowLongUTC(),
+    def getTestDetails(self):
+        return {
             'probe_asn': config.probe_ip.geodata['asn'],
             'probe_cc': config.probe_ip.geodata['countrycode'],
             'probe_ip': config.probe_ip.geodata['ip'],
             'probe_city': config.probe_ip.geodata['city'],
+            'software_name': 'ooniprobe',
+            'software_version': ooniprobe_version,
+            'options': sanitize_options(self.options),
+            'annotations': self.annotations,
+            'data_format_version': '0.2.0',
             'test_name': self.testName,
             'test_version': self.testVersion,
-            'software_name': 'ooniprobe',
-            'software_version': software_version,
-            'options': options,
-            'input_hashes': input_file_hashes,
-            'report_id': self.reportID,
             'test_helpers': self.testHelpers,
-            'annotations': self.annotations,
-            'data_format_version': '0.2.0'
+            'test_start_time': otime.timestampNowLongUTC(),
+            'input_hashes': [input_file['hash']
+                             for input_file in self.inputFiles],
+            'report_id': self.reportId
         }
 
-    def _parseNetTestOptions(self, klass):
+    def getTestCases(self):
         """
-        Helper method to assemble the options into a single UsageOptions object
+        Specialises the test_classes to include the local options.
+        :return:
         """
-        usage_options = klass.usageOptions
+        test_cases = []
+        for test_class, test_method in self._testCases:
+            test_cases.append((netTestCaseFactory(test_class,
+                                                  self.localOptions),
+                               test_method))
+        return test_cases
+
+    def _accumulateInputFiles(self, test_class):
+        if not test_class.inputFile:
+            return
 
-        if not hasattr(usage_options, 'optParameters'):
-            usage_options.optParameters = []
+        key = test_class.inputFile[0]
+        filename = self.localOptions[key]
+        if not filename:
+            return
+
+        input_file = {
+            'key': key,
+            'test_options': self.localOptions,
+            'hash': None,
+
+            'url': None,
+            'address': None,
+
+            'filename': None
+        }
+        m = ONION_INPUT_REGEXP.match(filename)
+        if m:
+            input_file['url'] = filename
+            input_file['address'] = m.group(1)
+            input_file['hash'] = m.group(2)
         else:
-            for parameter in usage_options.optParameters:
+            input_file['filename'] = filename
+            try:
+                with open(filename) as f:
+                    h = sha256()
+                    for l in f:
+                        h.update(l)
+            except:
+                raise e.InvalidInputFile(filename)
+            input_file['hash'] = h.hexdigest()
+        self.inputFiles.append(input_file)
+
+    def _accumulateTestOptions(self, test_class):
+        """
+        Accumulate the optParameters and optFlags for the NetTestCase class
+        into the usageOptions of the NetTestLoader.
+        """
+        if getattr(test_class.usageOptions, 'optParameters', None):
+            for parameter in test_class.usageOptions.optParameters:
+                # XXX should look into if this is still necessary, seems like
+                # something left over from a bug in some nettest.
+                # In theory optParameters should always have a length of 4.
                 if len(parameter) == 5:
                     parameter.pop()
+                self.usageOptions.optParameters.append(parameter)
 
-        if klass.inputFile:
-            usage_options.optParameters.append(klass.inputFile)
+        if getattr(test_class, 'inputFile', None):
+            self.usageOptions.optParameters.append(test_class.inputFile)
 
-        if klass.baseParameters:
-            for parameter in klass.baseParameters:
-                usage_options.optParameters.append(parameter)
+        if getattr(test_class, 'baseParameters', None):
+            for parameter in test_class.baseParameters:
+                self.usageOptions.optParameters.append(parameter)
 
-        if klass.baseFlags:
-            if not hasattr(usage_options, 'optFlags'):
-                usage_options.optFlags = []
-            for flag in klass.baseFlags:
-                usage_options.optFlags.append(flag)
+        if getattr(test_class, 'baseFlags', None):
+            for flag in test_class.baseFlags:
+                self.usageOptions.optFlags.append(flag)
 
-        return usage_options
-
-    @property
-    def usageOptions(self):
-        usage_options = None
-        for test_class, test_method in self.testCases:
-            if not usage_options:
-                usage_options = self._parseNetTestOptions(test_class)
-            else:
-                if usage_options != test_class.usageOptions:
-                    raise e.IncoherentOptions(usage_options.__name__,
-                                              test_class.usageOptions.__name__)
-        return usage_options
+    def parseLocalOptions(self):
+        """
+        Parses the localOptions for the NetTestLoader.
+        """
+        self.localOptions = self.usageOptions()
+        try:
+            self.localOptions.parseOptions(self.options)
+        except usage.UsageError:
+            tb = sys.exc_info()[2]
+            raise e.OONIUsageError(self), None, tb
+
+    def _checkTestClassOptions(self, test_class):
+        if test_class.requiresRoot and not hasRawSocketPermission():
+            raise e.InsufficientPrivileges
+        if test_class.requiresTor:
+            self.requiresTor = True
+        self._checkRequiredOptions(test_class)
+        self._setTestHelpers(test_class)
+        test_instance = netTestCaseFactory(test_class, self.localOptions)()
+        test_instance.requirements()
+
+    def _setTestHelpers(self, test_class):
+        for option, name in test_class.requiredTestHelpers.items():
+            if self.localOptions.get(option, None):
+                self.testHelpers[option] = self.localOptions[option]
+
+    def _checkRequiredOptions(self, test_class):
+        missing_options = []
+        for required_option in test_class.requiredOptions:
+            log.debug("Checking if %s is present" % required_option)
+            if required_option not in self.localOptions or \
+                    self.localOptions[required_option] is None:
+                missing_options.append(required_option)
+        missing_test_helpers = [opt in test_class.requiredTestHelpers.keys()
+                                for opt in missing_options]
+        if len(missing_test_helpers) and all(missing_test_helpers):
+            self.missingTestHelpers = map(lambda x:
+                                            (x, test_class.requiredTestHelpers[x]),
+                                          missing_options)
+            raise e.MissingTestHelper(missing_options, test_class)
+        elif missing_options:
+            raise e.MissingRequiredOption(missing_options, test_class)
 
     def loadNetTestString(self, net_test_string):
         """
@@ -280,12 +339,12 @@ class NetTestLoader(object):
         test_cases = []
         exec net_test_file_object.read() in ns
         for item in ns.itervalues():
-            test_cases.extend(self._get_test_methods(item))
+            test_cases.extend(self._getTestMethods(item))
 
         if not test_cases:
             raise e.NoTestCasesFound
 
-        self.setupTestCases(test_cases)
+        self._setupTestCases(test_cases)
 
     def loadNetTestFile(self, net_test_file):
         """
@@ -294,27 +353,26 @@ class NetTestLoader(object):
         test_cases = []
         module = filenameToModule(net_test_file)
         for __, item in getmembers(module):
-            test_cases.extend(self._get_test_methods(item))
+            test_cases.extend(self._getTestMethods(item))
 
         if not test_cases:
             raise e.NoTestCasesFound
 
-        self.setupTestCases(test_cases)
+        self._setupTestCases(test_cases)
 
-    def setupTestCases(self, test_cases):
+    def _setupTestCases(self, test_cases):
         """
         Creates all the necessary test_cases (a list of tuples containing the
         NetTestCase (test_class, test_method))
 
         example:
-            [(test_classA, test_method1),
-            (test_classA, test_method2),
-            (test_classA, test_method3),
-            (test_classA, test_method4),
-            (test_classA, test_method5),
-
-            (test_classB, test_method1),
-            (test_classB, test_method2)]
+            [(test_classA, [test_method1,
+                            test_method2,
+                            test_method3,
+                            test_method4,
+                            test_method5]),
+            (test_classB, [test_method1,
+                           test_method2])]
 
         Note: the inputs must be valid for test_classA and test_classB.
 
@@ -323,46 +381,37 @@ class NetTestLoader(object):
             generate the test_cases.
         """
         test_class, _ = test_cases[0]
-        self.testVersion = test_class.version
         self.testName = test_class_name_to_name(test_class.name)
-        self.testCases = test_cases
-        self.testClasses = set([])
-        self.testHelpers = {}
+        self.testVersion = test_class.version
+        self._testCases = test_cases
 
-        if config.reports.unique_id is True and not self.reportID:
-            self.reportID = randomStr(64)
+        self.usageOptions = usageOptionsFactory(self.testName,
+                                                self.testVersion)
 
-        for test_class, test_method in self.testCases:
-            self.testClasses.add(test_class)
+        if config.reports.unique_id is True:
+            self.reportId = randomStr(64)
+
+        for test_class, test_methods in self._testCases:
+            self._accumulateTestOptions(test_class)
 
     def checkOptions(self):
-        """
-        Call processTest and processOptions methods of each NetTestCase
-        """
-        for klass in self.testClasses:
-            options = self.usageOptions()
+        self.parseLocalOptions()
+        test_options_exc = None
+        usage_options = self._testCases[0][0].usageOptions
+        for test_class, test_methods in self._testCases:
             try:
-                options.parseOptions(self.options)
-            except usage.UsageError:
-                tb = sys.exc_info()[2]
-                raise e.OONIUsageError(self), None, tb
-
-            if options:
-                klass.localOptions = options
-            # XXX this class all needs to be refactored and this is kind of a
-            # hack.
-            self.setTestDetails()
-
-            test_instance = klass()
-            if test_instance.requiresRoot and not hasRawSocketPermission():
-                raise e.InsufficientPrivileges
-            if test_instance.requiresTor:
-                self.requiresTor = True
-            test_instance.requirements()
-            test_instance._checkRequiredOptions()
-            test_instance._checkValidOptions()
-
-    def _get_test_methods(self, item):
+                self._accumulateInputFiles(test_class)
+                self._checkTestClassOptions(test_class)
+                if usage_options != test_class.usageOptions:
+                    raise e.IncoherentOptions(usage_options.__name__,
+                                              test_class.usageOptions.__name__)
+            except Exception as exc:
+                test_options_exc = exc
+
+        if test_options_exc is not None:
+            raise test_options_exc
+
+    def _getTestMethods(self, item):
         """
         Look for test_ methods in subclasses of NetTestCase
         """
@@ -432,7 +481,7 @@ class NetTestState(object):
 class NetTest(object):
     director = None
 
-    def __init__(self, net_test_loader, report):
+    def __init__(self, test_cases, test_details, report):
         """
         net_test_loader:
              an instance of :class:ooni.nettest.NetTestLoader containing
@@ -442,9 +491,11 @@ class NetTest(object):
             an instance of :class:ooni.reporter.Reporter
         """
         self.report = report
-        self.testCases = net_test_loader.testCases
-        self.testClasses = net_test_loader.testClasses
-        self.testDetails = net_test_loader.testDetails
+
+        self.testDetails = test_details
+        self.testCases = test_cases
+
+        self.testInstances = []
 
         self.summary = {}
 
@@ -459,11 +510,18 @@ class NetTest(object):
     def __str__(self):
         return ' '.join(tc.name for tc, _ in self.testCases)
 
+    def uniqueClasses(self):
+        classes = []
+        for test_class, test_method in self.testCases:
+            if test_class not in classes:
+                classes.append(test_class)
+        return classes
+
     def doneNetTest(self, result):
         if self.summary:
             print "Summary for %s" % self.testDetails['test_name']
             print "------------" + "-"*len(self.testDetails['test_name'])
-            for test_class in self.testClasses:
+            for test_class in self.uniqueClasses():
                 test_instance = test_class()
                 test_instance.displaySummary(self.summary)
         if self.testDetails["report_id"]:
@@ -510,12 +568,10 @@ class NetTest(object):
 
     @defer.inlineCallbacks
     def initializeInputProcessor(self):
-        for test_class, _ in self.testCases:
+        for test_class, test_method in self.testCases:
             test_class.inputs = yield defer.maybeDeferred(
                 test_class().getInputProcessor
             )
-            if not test_class.inputs:
-                test_class.inputs = [None]
 
     def generateMeasurements(self):
         """
@@ -531,7 +587,7 @@ class NetTest(object):
                 test_instance._setUp()
                 test_instance.summary = self.summary
                 for method in test_methods:
-                    log.debug("Running %s %s" % (test_class, method))
+                    log.debug("Running %s %s" % (test_instance, method))
                     measurement = self.makeMeasurement(
                         test_instance,
                         method,
@@ -631,8 +687,6 @@ class NetTestCase(object):
     inputFile = None
     inputFilename = None
 
-    report = {}
-
     usageOptions = usage.Options
 
     optParameters = None
@@ -766,23 +820,7 @@ class NetTestCase(object):
         if self.inputs:
             return self.inputs
 
-        return None
-
-    def _checkValidOptions(self):
-        for option in self.localOptions:
-            if option not in self.usageOptions():
-                if not self.inputFile or option not in self.inputFile:
-                    raise e.InvalidOption
-
-    def _checkRequiredOptions(self):
-        missing_options = []
-        for required_option in self.requiredOptions:
-            log.debug("Checking if %s is present" % required_option)
-            if required_option not in self.localOptions or \
-                    self.localOptions[required_option] is None:
-                missing_options.append(required_option)
-        if missing_options:
-            raise e.MissingRequiredOption(missing_options, self)
+        return [None]
 
     def __repr__(self):
         return "<%s inputs=%s>" % (self.__class__, self.inputs)
diff --git a/ooni/oonicli.py b/ooni/oonicli.py
index 9327afc..e2d78e1 100644
--- a/ooni/oonicli.py
+++ b/ooni/oonicli.py
@@ -261,7 +261,8 @@ def createDeck(global_options, url=None):
             if any(global_options['subargs']):
                 args = global_options['subargs'] + args
             net_test_loader = NetTestLoader(args,
-                                            test_file=test_file)
+                                            test_file=test_file,
+                                            annotations=global_options['annotations'])
             if global_options['collector']:
                 net_test_loader.collector = global_options['collector']
             deck.insert(net_test_loader)
@@ -328,8 +329,6 @@ def runTestWithDirector(director, global_options, url=None, start_tor=True):
                 collector_address = setupCollector(global_options,
                                                    net_test_loader.collector)
 
-            net_test_loader.annotations = global_options['annotations']
-
             yield director.startNetTest(net_test_loader,
                                         global_options['reportfile'],
                                         collector_address,
diff --git a/ooni/reporter.py b/ooni/reporter.py
index 2ee8468..63d6f15 100644
--- a/ooni/reporter.py
+++ b/ooni/reporter.py
@@ -626,10 +626,9 @@ class Report(object):
                                                    self.collector_address)
 
         def created(report_id):
-            self.reportID = report_id
-            self.test_details['report_id'] = report_id
             if not self.oonib_reporter:
                 return
+            self.test_details['report_id'] = report_id
             return self.report_log.created(self.report_filename,
                                            self.collector_address,
                                            report_id)
diff --git a/ooni/tests/test_deck.py b/ooni/tests/test_deck.py
index d82c0eb..7b423af 100644
--- a/ooni/tests/test_deck.py
+++ b/ooni/tests/test_deck.py
@@ -44,6 +44,32 @@ class BaseTestCase(unittest.TestCase):
             test_file: manipulation/http_invalid_request_line
             testdeck: null
 """
+        self.dummy_deck_content_with_many_tests = """- options:
+            collector: null
+            help: 0
+            logfile: null
+            no-default-reporter: 0
+            parallelism: null
+            pcapfile: null
+            reportfile: null
+            resume: 0
+            subargs: [-b, "1.1.1.1"]
+            test_file: manipulation/http_invalid_request_line
+            testdeck: null
+- options:
+            collector: null
+            help: 0
+            logfile: null
+            no-default-reporter: 0
+            parallelism: null
+            pcapfile: null
+            reportfile: null
+            resume: 0
+            subargs: [-b, "2.2.2.2"]
+            test_file: manipulation/http_invalid_request_line
+            testdeck: null
+"""
+
 
 
 class TestInputFile(BaseTestCase):
@@ -127,10 +153,32 @@ class TestDeck(BaseTestCase):
         deck.bouncer = "httpo://foo.onion"
         deck.oonibclient = MockOONIBClient()
         deck.loadDeck(self.deck_file)
+
+        self.assertEqual(len(deck.netTestLoaders[0].missingTestHelpers), 1)
+
         yield deck.lookupTestHelpers()
 
-        assert deck.netTestLoaders[0].collector == 'httpo://thirteenchars1234.onion'
+        self.assertEqual(deck.netTestLoaders[0].collector,
+                         'httpo://thirteenchars1234.onion')
+
+        self.assertEqual(deck.netTestLoaders[0].localOptions['backend'],
+                         '127.0.0.1')
+
+
+    def test_deck_with_many_tests(self):
+        os.remove(self.deck_file)
+        deck_hash = sha256(self.dummy_deck_content_with_many_tests).hexdigest()
+        self.deck_file = os.path.join(self.cwd, deck_hash)
+        with open(self.deck_file, 'w+') as f:
+            f.write(self.dummy_deck_content_with_many_tests)
+        deck = Deck(decks_directory=".")
+        deck.loadDeck(self.deck_file)
 
-        required_test_helpers = deck.netTestLoaders[0].requiredTestHelpers
-        assert len(required_test_helpers) == 1
-        assert required_test_helpers[0]['test_class'].localOptions['backend'] == '127.0.0.1'
+        self.assertEqual(
+            deck.netTestLoaders[0].localOptions['backend'],
+            '1.1.1.1'
+        )
+        self.assertEqual(
+            deck.netTestLoaders[1].localOptions['backend'],
+            '2.2.2.2'
+        )
diff --git a/ooni/tests/test_director.py b/ooni/tests/test_director.py
index 61d504c..5875ccb 100644
--- a/ooni/tests/test_director.py
+++ b/ooni/tests/test_director.py
@@ -2,6 +2,7 @@ from mock import patch, MagicMock
 
 from ooni.settings import config
 from ooni.director import Director
+from ooni.nettest import NetTestLoader
 from ooni.tests.bases import ConfigTestCase
 
 from twisted.internet import defer
@@ -9,6 +10,31 @@ from twisted.trial import unittest
 
 from txtorcon import TorControlProtocol
 
+test_failing_twice = """
+from twisted.internet import defer, reactor
+from ooni.nettest import NetTestCase
+
+class TestFailingTwice(NetTestCase):
+    inputs = ["spam-{}".format(idx) for idx in range(50)]
+
+    def setUp(self):
+        self.summary[self.input] = self.summary.get(self.input, 0)
+
+    def test_a(self):
+        run_count = self.summary[self.input]
+        delay = float(self.input.split("-")[1])/1000
+        d = defer.Deferred()
+        def callback():
+            self.summary[self.input] += 1
+            if run_count < 3:
+                d.errback(Exception("Failing"))
+            else:
+                d.callback(self.summary[self.input])
+
+        reactor.callLater(delay, callback)
+        return d
+"""
+
 proto = MagicMock()
 proto.tor_protocol = TorControlProtocol()
 
@@ -52,6 +78,23 @@ class TestDirector(ConfigTestCase):
 
         return director_start_tor()
 
+    def test_run_test_fails_twice(self):
+        finished = defer.Deferred()
+
+        def net_test_done(net_test):
+            summary_items = net_test.summary.items()
+            self.assertEqual(len(summary_items), 50)
+            for input_name, run_count in summary_items:
+                self.assertEqual(run_count, 3)
+            finished.callback(None)
+
+        net_test_loader = NetTestLoader(('spam','ham'))
+        net_test_loader.loadNetTestString(test_failing_twice)
+        director = Director()
+        director.netTestDone = net_test_done
+        director.startNetTest(net_test_loader, None, no_yamloo=True)
+        return finished
+
 
 class TestStartSniffing(unittest.TestCase):
     def setUp(self):
diff --git a/ooni/tests/test_nettest.py b/ooni/tests/test_nettest.py
index e93111b..1108f11 100644
--- a/ooni/tests/test_nettest.py
+++ b/ooni/tests/test_nettest.py
@@ -186,13 +186,6 @@ class TestNetTest(unittest.TestCase):
                 uniq_test_methods.add(test_method)
         self.assertEqual(set(['test_a', 'test_b']), uniq_test_methods)
 
-    def verifyClasses(self, test_cases, control_classes):
-        actual_classes = set()
-        for test_class, test_methods in test_cases:
-            actual_classes.add(test_class.__name__)
-
-        self.assertEqual(actual_classes, control_classes)
-
     def test_load_net_test_from_file(self):
         """
         Given a file verify that the net test cases are properly
@@ -206,7 +199,7 @@ class TestNetTest(unittest.TestCase):
         ntl = NetTestLoader(dummyArgs)
         ntl.loadNetTestFile(net_test_file)
 
-        self.verifyMethods(ntl.testCases)
+        self.verifyMethods(ntl.getTestCases())
         os.unlink(net_test_file)
 
     def test_load_net_test_from_str(self):
@@ -217,24 +210,21 @@ class TestNetTest(unittest.TestCase):
         ntl = NetTestLoader(dummyArgs)
         ntl.loadNetTestString(net_test_string)
 
-        self.verifyMethods(ntl.testCases)
+        self.verifyMethods(ntl.getTestCases())
 
     def test_load_net_test_multiple(self):
         ntl = NetTestLoader(dummyArgs)
         ntl.loadNetTestString(double_net_test_string)
-
-        self.verifyMethods(ntl.testCases)
-        self.verifyClasses(ntl.testCases, set(('DummyTestCaseA', 'DummyTestCaseB')))
-
+        test_cases = ntl.getTestCases()
+        self.verifyMethods(test_cases)
         ntl.checkOptions()
 
     def test_load_net_test_multiple_different_options(self):
         ntl = NetTestLoader(dummyArgs)
         ntl.loadNetTestString(double_different_options_net_test_string)
 
-        self.verifyMethods(ntl.testCases)
-        self.verifyClasses(ntl.testCases, set(('DummyTestCaseA', 'DummyTestCaseB')))
-
+        test_cases = ntl.getTestCases()
+        self.verifyMethods(test_cases)
         self.assertRaises(IncoherentOptions, ntl.checkOptions)
 
     def test_load_with_option(self):
@@ -242,7 +232,7 @@ class TestNetTest(unittest.TestCase):
         ntl.loadNetTestString(net_test_string)
 
         self.assertIsInstance(ntl, NetTestLoader)
-        for test_klass, test_meth in ntl.testCases:
+        for test_klass, test_meth in ntl.getTestCases():
             for option in dummyOptions.keys():
                 self.assertIn(option, test_klass.usageOptions())
 
@@ -266,34 +256,29 @@ class TestNetTest(unittest.TestCase):
     def test_net_test_inputs(self):
         ntl = NetTestLoader(dummyArgsWithFile)
         ntl.loadNetTestString(net_test_string_with_file)
-
         ntl.checkOptions()
-        nt = NetTest(ntl, None)
+        nt = NetTest(ntl.getTestCases(), ntl.getTestDetails(), None)
         nt.initializeInputProcessor()
 
         # XXX: if you use the same test_class twice you will have consumed all
         # of its inputs!
         tested = set([])
-        for test_class, test_method in ntl.testCases:
-            if test_class not in tested:
-                tested.update([test_class])
-                self.assertEqual(len(list(test_class.inputs)), 10)
+        for test_instance, test_method, inputs in nt.testInstances:
+            self.assertEqual(len(list(inputs)), 10)
 
     def test_setup_local_options_in_test_cases(self):
         ntl = NetTestLoader(dummyArgs)
         ntl.loadNetTestString(net_test_string)
 
         ntl.checkOptions()
-
-        for test_class, test_method in ntl.testCases:
-            self.assertEqual(test_class.localOptions, dummyOptions)
+        self.assertEqual(dict(ntl.localOptions), dummyOptions)
 
     def test_generate_measurements_size(self):
         ntl = NetTestLoader(dummyArgsWithFile)
         ntl.loadNetTestString(net_test_string_with_file)
-
         ntl.checkOptions()
-        net_test = NetTest(ntl, None)
+
+        net_test = NetTest(ntl.getTestCases(), ntl.getTestDetails(), None)
 
         net_test.initializeInputProcessor()
         measurements = list(net_test.generateMeasurements())
@@ -321,7 +306,7 @@ class TestNetTest(unittest.TestCase):
         ntl = NetTestLoader(dummyArgs)
         ntl.loadNetTestString(net_test_root_required)
 
-        for test_class, method in ntl.testCases:
+        for test_class, methods in ntl.getTestCases():
             self.assertTrue(test_class.requiresRoot)
 
 
diff --git a/ooni/tests/test_oonicli.py b/ooni/tests/test_oonicli.py
index f5c8545..16e3f77 100644
--- a/ooni/tests/test_oonicli.py
+++ b/ooni/tests/test_oonicli.py
@@ -62,9 +62,7 @@ class TestRunDirector(ConfigTestCase):
         super(TestRunDirector, self).setUp()
         if not is_internet_connected():
             self.skipTest("You must be connected to the internet to run this test")
-        elif not hasRawSocketPermission():
-            self.skipTest("You must run this test as root or have the capabilities "
-            "cap_net_admin,cap_net_raw+eip")
+
         config.tor.socks_port = 9050
         config.tor.control_port = None
         self.filenames = ['example-input.txt']
@@ -165,6 +163,9 @@ class TestRunDirector(ConfigTestCase):
 
     @defer.inlineCallbacks
     def test_sniffing_activated(self):
+        if not hasRawSocketPermission():
+            self.skipTest("You must run this test as root or have the "
+                          "capabilities cap_net_admin,cap_net_raw+eip")
         self.skipTest("Not properly set packet capture?")
         filename = os.path.abspath('test_report.pcap')
         self.filenames.append(filename)



_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits