[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [stem/master] Using mock for controller unit tests



commit 8303c8d99ac6728367dabec3f6b2c9e43bb68c69
Author: Damian Johnson <atagar@xxxxxxxxxxxxxx>
Date:   Tue Jun 11 08:43:05 2013 -0700

    Using mock for controller unit tests
    
    The controller unit tests are another heavy user of mocking. Nice improvements, I'm really liking this library!
---
 test/unit/control/controller.py |  217 ++++++++++++++++++++-------------------
 1 file changed, 113 insertions(+), 104 deletions(-)

diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index a0deaf1..ce25d30 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -13,10 +13,11 @@ import stem.socket
 import stem.util.system
 import stem.version
 
+from mock import Mock, patch
+
 from stem import InvalidArguments, InvalidRequest, ProtocolError, UnsatisfiableRequest
 from stem.control import _parse_circ_path, Controller, EventType
 from stem.exit_policy import ExitPolicy
-from stem.socket import ControlSocket
 from test import mocking
 
 
@@ -24,14 +25,11 @@ class TestControl(unittest.TestCase):
   def setUp(self):
     socket = stem.socket.ControlSocket()
 
-    mocking.mock_method(Controller, "add_event_listener", mocking.no_op())
-    self.controller = Controller(socket)
-    mocking.revert_mocking()
-
-  def tearDown(self):
-    mocking.revert_mocking()
+    with patch('stem.control.Controller.add_event_listener', Mock()):
+      self.controller = Controller(socket)
 
-  def test_get_version(self):
+  @patch('stem.control.Controller.get_info')
+  def test_get_version(self, get_info_mock):
     """
     Exercises the get_version() method.
     """
@@ -40,7 +38,7 @@ class TestControl(unittest.TestCase):
       # Use one version for first check.
       version_2_1 = "0.2.1.32"
       version_2_1_object = stem.version.Version(version_2_1)
-      mocking.mock_method(Controller, "get_info", mocking.return_value(version_2_1))
+      get_info_mock.return_value = version_2_1
 
       # Return a version with a cold cache.
       self.assertEqual(version_2_1_object, self.controller.get_version())
@@ -48,7 +46,7 @@ class TestControl(unittest.TestCase):
       # Use a different version for second check.
       version_2_2 = "0.2.2.39"
       version_2_2_object = stem.version.Version(version_2_2)
-      mocking.mock_method(Controller, "get_info", mocking.return_value(version_2_2))
+      get_info_mock.return_value = version_2_2
 
       # Return a version with a hot cache, so it will be the old version.
       self.assertEqual(version_2_1_object, self.controller.get_version())
@@ -59,7 +57,7 @@ class TestControl(unittest.TestCase):
       self.assertEqual(version_2_2_object, self.controller.get_version())
 
       # Raise an exception in the get_info() call.
-      mocking.mock_method(Controller, "get_info", mocking.raise_exception(InvalidArguments))
+      get_info_mock.side_effect = InvalidArguments
 
       # Get a default value when the call fails.
       self.assertEqual(
@@ -72,26 +70,29 @@ class TestControl(unittest.TestCase):
 
       # Give a bad version.  The stem.version.Version ValueError should bubble up.
       version_A_42 = "0.A.42.spam"
-      mocking.mock_method(Controller, "get_info", mocking.return_value(version_A_42))
+      get_info_mock.return_value = version_A_42
+      get_info_mock.side_effect = None
       self.assertRaises(ValueError, self.controller.get_version)
     finally:
       # Turn caching back on before we leave.
       self.controller._is_caching_enabled = True
 
-  def test_get_exit_policy(self):
+  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.Controller.get_conf')
+  def test_get_exit_policy(self, get_conf_mock, get_info_mock):
     """
     Exercises the get_exit_policy() method.
     """
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("ExitPolicyRejectPrivate",): "1",
-      ("ExitPolicy", "multiple=True"): ["accept *:80,   accept *:443", "accept 43.5.5.5,reject *:22"]
-    }, is_method = True))
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "ExitPolicyRejectPrivate": "1",
+      "ExitPolicy": ["accept *:80,   accept *:443", "accept 43.5.5.5,reject *:22"],
+    }[param]
 
-    mocking.mock_method(Controller, "get_info", mocking.return_for_args({
-      ("address", None): "123.45.67.89",
-      ("exit-policy/default",): "reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*"
-    }, is_method = True))
+    get_info_mock.side_effect = lambda param, default = None: {
+      "address": "123.45.67.89",
+      "exit-policy/default": "reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*",
+    }[param]
 
     expected = ExitPolicy(
       'reject 0.0.0.0/8:*',  # private entries
@@ -120,7 +121,9 @@ class TestControl(unittest.TestCase):
 
     self.assertEqual(expected, self.controller.get_exit_policy())
 
-  def test_get_socks_listeners_old(self):
+  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.Controller.get_conf')
+  def test_get_socks_listeners_old(self, get_conf_mock, get_info_mock):
     """
     Exercises the get_socks_listeners() method as though talking to an old tor
     instance.
@@ -129,68 +132,75 @@ class TestControl(unittest.TestCase):
     # An old tor raises stem.InvalidArguments for get_info about socks, but
     # get_socks_listeners should work anyway.
 
-    mocking.mock_method(Controller, "get_info", mocking.raise_exception(InvalidArguments))
+    get_info_mock.side_effect = InvalidArguments
+
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "SocksPort": "9050",
+      "SocksListenAddress": ["127.0.0.1"],
+    }[param]
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("SocksPort",): "9050",
-      ("SocksListenAddress", "multiple=True"): ["127.0.0.1"]
-    }, is_method = True))
     self.assertEqual([('127.0.0.1', 9050)], self.controller.get_socks_listeners())
 
     # Again, an old tor, but SocksListenAddress overrides the port number.
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("SocksPort",): "9050",
-      ("SocksListenAddress", "multiple=True"): ["127.0.0.1:1112"]
-    }, is_method = True))
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "SocksPort": "9050",
+      "SocksListenAddress": ["127.0.0.1:1112"],
+    }[param]
+
     self.assertEqual([('127.0.0.1', 1112)], self.controller.get_socks_listeners())
 
     # Again, an old tor, but multiple listeners
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("SocksPort",): "9050",
-      ("SocksListenAddress", "multiple=True"): ["127.0.0.1:1112", "127.0.0.1:1114"]
-    }, is_method = True))
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "SocksPort": "9050",
+      "SocksListenAddress": ["127.0.0.1:1112", "127.0.0.1:1114"],
+    }[param]
+
     self.assertEqual([('127.0.0.1', 1112), ('127.0.0.1', 1114)], self.controller.get_socks_listeners())
 
     # Again, an old tor, but no SOCKS listeners
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("SocksPort",): "0",
-      ("SocksListenAddress", "multiple=True"): []
-    }, is_method = True))
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "SocksPort": "0",
+      "SocksListenAddress": [],
+    }[param]
+
     self.assertEqual([], self.controller.get_socks_listeners())
 
     # Where tor provides invalid ports or addresses
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("SocksPort",): "blarg",
-      ("SocksListenAddress", "multiple=True"): ["127.0.0.1"]
-    }, is_method = True))
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "SocksPort": "blarg",
+      "SocksListenAddress": ["127.0.0.1"],
+    }[param]
+
     self.assertRaises(stem.ProtocolError, self.controller.get_socks_listeners)
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("SocksPort",): "0",
-      ("SocksListenAddress", "multiple=True"): ["127.0.0.1:abc"]
-    }, is_method = True))
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "SocksPort": "0",
+      "SocksListenAddress": ["127.0.0.1:abc"],
+    }[param]
+
     self.assertRaises(stem.ProtocolError, self.controller.get_socks_listeners)
 
-    mocking.mock_method(Controller, "get_conf", mocking.return_for_args({
-      ("SocksPort",): "40",
-      ("SocksListenAddress", "multiple=True"): ["500.0.0.1"]
-    }, is_method = True))
+    get_conf_mock.side_effect = lambda param, **kwargs: {
+      "SocksPort": "40",
+      "SocksListenAddress": ["500.0.0.1"],
+    }[param]
+
     self.assertRaises(stem.ProtocolError, self.controller.get_socks_listeners)
 
-  def test_get_socks_listeners_new(self):
+  @patch('stem.control.Controller.get_info')
+  def test_get_socks_listeners_new(self, get_info_mock):
     """
     Exercises the get_socks_listeners() method as if talking to a newer tor
     instance.
     """
 
     # multiple SOCKS listeners
-    mocking.mock_method(Controller, "get_info", mocking.return_value(
-      '"127.0.0.1:1112" "127.0.0.1:1114"'
-    ))
+
+    get_info_mock.return_value = '"127.0.0.1:1112" "127.0.0.1:1114"'
 
     self.assertEqual(
       [('127.0.0.1', 1112), ('127.0.0.1', 1114)],
@@ -198,7 +208,8 @@ class TestControl(unittest.TestCase):
     )
 
     # no SOCKS listeners
-    mocking.mock_method(Controller, "get_info", mocking.return_value(""))
+
+    get_info_mock.return_value = ''
     self.assertEqual([], self.controller.get_socks_listeners())
 
     # check where GETINFO provides malformed content
@@ -211,19 +222,18 @@ class TestControl(unittest.TestCase):
     )
 
     for response in invalid_responses:
-      mocking.mock_method(Controller, "get_info", mocking.return_value(response))
+      get_info_mock.return_value = response
       self.assertRaises(stem.ProtocolError, self.controller.get_socks_listeners)
 
-  def test_get_protocolinfo(self):
+  @patch('stem.connection.get_protocolinfo')
+  def test_get_protocolinfo(self, get_protocolinfo_mock):
     """
     Exercises the get_protocolinfo() method.
     """
 
     # use the handy mocked protocolinfo response
 
-    mocking.mock(stem.connection.get_protocolinfo, mocking.return_value(
-      mocking.get_protocolinfo_response()
-    ))
+    get_protocolinfo_mock.return_value = mocking.get_protocolinfo_response()
 
     # compare the str representation of these object, because the class
     # does not have, nor need, a direct comparison operator
@@ -235,7 +245,7 @@ class TestControl(unittest.TestCase):
 
     # raise an exception in the stem.connection.get_protocolinfo() call
 
-    mocking.mock(stem.connection.get_protocolinfo, mocking.raise_exception(ProtocolError))
+    get_protocolinfo_mock.side_effect = ProtocolError
 
     # get a default value when the call fails
 
@@ -248,54 +258,53 @@ class TestControl(unittest.TestCase):
 
     self.assertRaises(ProtocolError, self.controller.get_protocolinfo)
 
+  @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
   def test_get_user_remote(self):
     """
     Exercise the get_user() method for a non-local socket.
     """
 
-    mocking.mock_method(ControlSocket, "is_localhost", mocking.return_false())
-
     self.assertRaises(ValueError, self.controller.get_user)
     self.assertEqual(123, self.controller.get_user(123))
 
+  @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
+  @patch('stem.control.Controller.get_info', Mock(return_value = 'atagar'))
   def test_get_user_by_getinfo(self):
     """
     Exercise the get_user() resolution via its getinfo option.
     """
 
-    mocking.mock_method(ControlSocket, "is_localhost", mocking.return_true())
-    mocking.mock_method(Controller, "get_info", mocking.return_value('atagar'))
     self.assertEqual('atagar', self.controller.get_user())
 
+  @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
+  @patch('stem.util.system.get_pid_by_name', Mock(return_value = 432))
+  @patch('stem.util.system.get_user', Mock(return_value = 'atagar'))
   def test_get_user_by_system(self):
     """
     Exercise the get_user() resolution via the system module.
     """
 
-    mocking.mock_method(ControlSocket, "is_localhost", mocking.return_true())
-    mocking.mock(stem.util.system.get_pid_by_name, mocking.return_value(432))
-    mocking.mock(stem.util.system.get_user, mocking.return_value('atagar'))
     self.assertEqual('atagar', self.controller.get_user())
 
+  @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
   def test_get_pid_remote(self):
     """
     Exercise the get_pid() method for a non-local socket.
     """
 
-    mocking.mock_method(ControlSocket, "is_localhost", mocking.return_false())
-
     self.assertRaises(ValueError, self.controller.get_pid)
     self.assertEqual(123, self.controller.get_pid(123))
 
+  @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
+  @patch('stem.control.Controller.get_info', Mock(return_value = '321'))
   def test_get_pid_by_getinfo(self):
     """
     Exercise the get_pid() resolution via its getinfo option.
     """
 
-    mocking.mock_method(ControlSocket, "is_localhost", mocking.return_true())
-    mocking.mock_method(Controller, "get_info", mocking.return_value('321'))
     self.assertEqual(321, self.controller.get_pid())
 
+  @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
   def test_get_pid_by_pid_file(self):
     """
     Exercise the get_pid() resolution via a PidFile.
@@ -304,29 +313,28 @@ class TestControl(unittest.TestCase):
     # It's a little inappropriate for us to be using tempfile in unit tests,
     # but this is more reliable than trying to mock open().
 
-    mocking.mock_method(ControlSocket, "is_localhost", mocking.return_true())
-
     pid_file_path = tempfile.mkstemp()[1]
 
     try:
       with open(pid_file_path, 'w') as pid_file:
         pid_file.write('432')
 
-      mocking.mock_method(Controller, "get_conf", mocking.return_value(pid_file_path))
-      self.assertEqual(432, self.controller.get_pid())
+      with patch('stem.control.Controller.get_conf', Mock(return_value = pid_file_path)):
+        self.assertEqual(432, self.controller.get_pid())
     finally:
       os.remove(pid_file_path)
 
+  @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
+  @patch('stem.util.system.get_pid_by_name', Mock(return_value = 432))
   def test_get_pid_by_name(self):
     """
     Exercise the get_pid() resolution via the process name.
     """
 
-    mocking.mock_method(ControlSocket, "is_localhost", mocking.return_true())
-    mocking.mock(stem.util.system.get_pid_by_name, mocking.return_value(432))
     self.assertEqual(432, self.controller.get_pid())
 
-  def test_get_network_status(self):
+  @patch('stem.control.Controller.get_info')
+  def test_get_network_status(self, get_info_mock):
     """
     Exercises the get_network_status() method.
     """
@@ -340,7 +348,7 @@ class TestControl(unittest.TestCase):
 
     # always return the same router status entry
 
-    mocking.mock_method(Controller, "get_info", mocking.return_value(desc))
+    get_info_mock.return_value = desc
 
     # pretend to get the router status entry with its name
 
@@ -358,7 +366,7 @@ class TestControl(unittest.TestCase):
 
     # raise an exception in the get_info() call
 
-    mocking.mock_method(Controller, "get_info", mocking.raise_exception(InvalidArguments))
+    get_info_mock.side_effect = InvalidArguments
 
     # get a default value when the call fails
 
@@ -371,25 +379,29 @@ class TestControl(unittest.TestCase):
 
     self.assertRaises(InvalidArguments, self.controller.get_network_status, nickname)
 
-  def test_event_listening(self):
+  @patch('stem.control.Controller.is_authenticated', Mock(return_value = True))
+  @patch('stem.control.Controller._attach_listeners', Mock(return_value = ([], [])))
+  @patch('stem.control.Controller.get_version')
+  def test_event_listening(self, get_version_mock):
     """
     Exercises the add_event_listener and remove_event_listener methods.
     """
 
     # set up for failure to create any events
-    mocking.mock_method(Controller, "is_authenticated", mocking.return_true())
-    mocking.mock_method(Controller, "_attach_listeners", mocking.return_value(([], [])))
-    mocking.mock_method(Controller, "get_version", mocking.return_value(stem.version.Version('0.1.0.14')))
-    self.assertRaises(InvalidRequest, self.controller.add_event_listener, mocking.no_op(), EventType.BW)
+
+    get_version_mock.return_value = stem.version.Version('0.1.0.14')
+    self.assertRaises(InvalidRequest, self.controller.add_event_listener, Mock(), EventType.BW)
 
     # set up to only fail newer events
-    mocking.mock_method(Controller, "get_version", mocking.return_value(stem.version.Version('0.2.0.35')))
+
+    get_version_mock.return_value = stem.version.Version('0.2.0.35')
 
     # EventType.BW is one of the earliest events
-    self.controller.add_event_listener(mocking.no_op(), EventType.BW)
+
+    self.controller.add_event_listener(Mock(), EventType.BW)
 
     # EventType.SIGNAL was added in tor version 0.2.3.1-alpha
-    self.assertRaises(InvalidRequest, self.controller.add_event_listener, mocking.no_op(), EventType.SIGNAL)
+    self.assertRaises(InvalidRequest, self.controller.add_event_listener, Mock(), EventType.SIGNAL)
 
   def test_get_streams(self):
     """
@@ -403,20 +415,17 @@ class TestControl(unittest.TestCase):
       ("3", "SUCCEEDED", "4", "10.10.10.1:80")
     )
 
-    responses = ["%s\r\n" % " ".join(entry) for entry in valid_streams]
-
-    mocking.mock_method(Controller, "get_info", mocking.return_value(
-      "".join(responses)
-    ))
+    response = "".join(["%s\r\n" % " ".join(entry) for entry in valid_streams])
 
-    streams = self.controller.get_streams()
-    self.assertEqual(len(valid_streams), len(streams))
+    with patch('stem.control.Controller.get_info', Mock(return_value = response)):
+      streams = self.controller.get_streams()
+      self.assertEqual(len(valid_streams), len(streams))
 
-    for index, stream in enumerate(streams):
-      self.assertEqual(valid_streams[index][0], stream.id)
-      self.assertEqual(valid_streams[index][1], stream.status)
-      self.assertEqual(valid_streams[index][2], stream.circ_id)
-      self.assertEqual(valid_streams[index][3], stream.target)
+      for index, stream in enumerate(streams):
+        self.assertEqual(valid_streams[index][0], stream.id)
+        self.assertEqual(valid_streams[index][1], stream.status)
+        self.assertEqual(valid_streams[index][2], stream.circ_id)
+        self.assertEqual(valid_streams[index][3], stream.target)
 
   def test_attach_stream(self):
     """
@@ -427,9 +436,9 @@ class TestControl(unittest.TestCase):
     # instance, it's already open).
 
     response = stem.response.ControlMessage.from_str("555 Connection is not managed by controller.\r\n")
-    mocking.mock_method(Controller, "msg", mocking.return_value(response))
 
-    self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
+    with patch('stem.control.Controller.msg', Mock(return_value = response)):
+      self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
 
   def test_parse_circ_path(self):
     """



_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits