[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [flashproxy/master] Add flashproxy.util.addr_family function.



commit 7787056bc7055bfe8efc5df0af20ca9e13a1ffaa
Author: David Fifield <david@xxxxxxxxxxxxxxx>
Date:   Sat Feb 1 17:03:33 2014 -0800

    Add flashproxy.util.addr_family function.
---
 flashproxy/test/test_util.py |   13 ++++++++++++-
 flashproxy/util.py           |    9 +++++++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/flashproxy/test/test_util.py b/flashproxy/test/test_util.py
index 935dd1f..e095f38 100644
--- a/flashproxy/test/test_util.py
+++ b/flashproxy/test/test_util.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python
 
+import socket
 import unittest
 
-from flashproxy.util import parse_addr_spec, canonical_ip
+from flashproxy.util import parse_addr_spec, canonical_ip, addr_family
 
 class ParseAddrSpecTest(unittest.TestCase):
     def test_ipv4(self):
@@ -39,5 +40,15 @@ class ParseAddrSpecTest(unittest.TestCase):
         """Test that canonical_ip does not do DNS resolution by default."""
         self.assertRaises(ValueError, canonical_ip, *parse_addr_spec("example.com:80"))
 
+class AddrFamilyTest(unittest.TestCase):
+    def test_ipv4(self):
+        self.assertEqual(addr_family("1.2.3.4"), socket.AF_INET)
+
+    def test_ipv6(self):
+        self.assertEqual(addr_family("1:2::3:4"), socket.AF_INET6)
+
+    def test_name(self):
+        self.assertRaises(socket.gaierror, addr_family, "localhost")
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/flashproxy/util.py b/flashproxy/util.py
index a53bdad..63cdef5 100644
--- a/flashproxy/util.py
+++ b/flashproxy/util.py
@@ -95,6 +95,12 @@ def canonical_ip(host, port, af=0):
     except that the host param must already be an IP address."""
     return resolve_to_ip(host, port, af, gai_flags=socket.AI_NUMERICHOST)
 
+def addr_family(ip):
+    """Return the address family of an IP address. Raises socket.gaierror if ip
+    is not a numeric IP."""
+    addrs = socket.getaddrinfo(ip, 0, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST)
+    return addrs[0][0]
+
 def format_addr(addr):
     host, port = addr
     host_str = u""
@@ -102,8 +108,7 @@ def format_addr(addr):
     if host is not None:
         # Numeric IPv6 address?
         try:
-            addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST)
-            af = addrs[0][0]
+            af = addr_family(host)
         except socket.gaierror, e:
             af = 0
         if af == socket.AF_INET6:



_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits