[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]
[tor-commits] [flashproxy/master] Add tests for WebSocket request handling.
commit 0c1afe22b38596675f9aa06d75fbb754ea915638
Author: David Fifield <david@xxxxxxxxxxxxxxx>
Date: Fri Sep 7 05:04:53 2012 -0700
Add tests for WebSocket request handling.
---
flashproxy-client-test | 169 ++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 169 insertions(+), 0 deletions(-)
diff --git a/flashproxy-client-test b/flashproxy-client-test
index 527093c..084a272 100755
--- a/flashproxy-client-test
+++ b/flashproxy-client-test
@@ -1,11 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import base64
+import cStringIO
+import httplib
import socket
import subprocess
import sys
import unittest
+try:
+ from hashlib import sha1
+except ImportError:
+ # Python 2.4 uses this name.
+ from sha import sha as sha1
+
# Special tricks to load a module whose filename contains a dash and doesn't end
# in ".py".
import imp
@@ -13,6 +22,7 @@ dont_write_bytecode = sys.dont_write_bytecode
sys.dont_write_bytecode = True
flashproxy = imp.load_source("flashproxy", "flashproxy-client")
parse_socks_request = flashproxy.parse_socks_request
+handle_websocket_request = flashproxy.handle_websocket_request
WebSocketDecoder = flashproxy.WebSocketDecoder
WebSocketEncoder = flashproxy.WebSocketEncoder
sys.dont_write_bytecode = dont_write_bytecode
@@ -40,6 +50,165 @@ class TestSocks(unittest.TestCase):
def test_parse_socks_request_hostname(self):
dest, port = parse_socks_request("\x04\x01\x99\x99\x00\x00\x00\x01userid\x00abc\x00")
+class DummySocket(object):
+ def __init__(self, read_fd, write_fd):
+ self.read_fd = read_fd
+ self.write_fd = write_fd
+ self.readp = 0
+
+ def read(self, *args, **kwargs):
+ self.read_fd.seek(self.readp, 0)
+ data = self.read_fd.read(*args, **kwargs)
+ self.readp = self.read_fd.tell()
+ return data
+
+ def readline(self, *args, **kwargs):
+ self.read_fd.seek(self.readp, 0)
+ data = self.read_fd.readline(*args, **kwargs)
+ self.readp = self.read_fd.tell()
+ return data
+
+ def recv(self, size, *args, **kwargs):
+ return self.read(size)
+
+ def write(self, data):
+ self.write_fd.seek(0, 2)
+ self.write_fd.write(data)
+
+ def send(self, data, *args, **kwargs):
+ return self.write(data)
+
+ def sendall(self, data, *args, **kwargs):
+ return self.write(data)
+
+ def makefile(self, *args, **kwargs):
+ return self
+
+def dummy_socketpair():
+ f1 = cStringIO.StringIO()
+ f2 = cStringIO.StringIO()
+ return (DummySocket(f1, f2), DummySocket(f2, f1))
+
+class HTTPRequest(object):
+ def __init__(self):
+ self.method = "GET"
+ self.path = "/"
+ self.headers = {}
+
+def transact_http(req):
+ l, r = dummy_socketpair()
+ r.send("%s %s HTTP/1.0\r\n" % (req.method, req.path))
+ for k, v in req.headers.items():
+ r.send("%s: %s\r\n" % (k, v))
+ r.send("\r\n")
+ protocols = handle_websocket_request(l)
+
+ resp = httplib.HTTPResponse(r)
+ resp.begin()
+ return resp, protocols
+
+class TestHandleWebSocketRequest(unittest.TestCase):
+ DEFAULT_KEY = "0123456789ABCDEF"
+ DEFAULT_KEY_BASE64 = base64.b64encode(DEFAULT_KEY)
+ MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+
+ @staticmethod
+ def default_req():
+ req = HTTPRequest()
+ req.method = "GET"
+ req.path = "/"
+ req.headers["Upgrade"] = "websocket"
+ req.headers["Connection"] = "Upgrade"
+ req.headers["Sec-WebSocket-Key"] = TestHandleWebSocketRequest.DEFAULT_KEY_BASE64
+ req.headers["Sec-WebSocket-Version"] = "13"
+
+ return req
+
+ def assert_ok(self, req):
+ resp, protocols = transact_http(req)
+ self.assertEqual(resp.status, 101)
+ self.assertEqual(resp.getheader("Upgrade").lower(), "websocket")
+ self.assertEqual(resp.getheader("Connection").lower(), "upgrade")
+ self.assertEqual(resp.getheader("Sec-WebSocket-Accept"), base64.b64encode(sha1(self.DEFAULT_KEY_BASE64 + self.MAGIC_GUID).digest()))
+ self.assertEqual(protocols, [])
+
+ def assert_not_ok(self, req):
+ resp, protocols = transact_http(req)
+ self.assertEqual(resp.status // 100, 4)
+ self.assertEqual(protocols, None)
+
+ def test_default(self):
+ req = self.default_req()
+ self.assert_ok(req)
+
+ def test_missing_upgrade(self):
+ req = self.default_req()
+ del req.headers["Upgrade"]
+ self.assert_not_ok(req)
+
+ def test_missing_connection(self):
+ req = self.default_req()
+ del req.headers["Connection"]
+ self.assert_not_ok(req)
+
+ def test_case_insensitivity(self):
+ """Test that the values of the Upgrade and Connection headers are
+ case-insensitive."""
+ req = self.default_req()
+ req.headers["Upgrade"] = req.headers["Upgrade"].lower()
+ self.assert_ok(req)
+ req.headers["Upgrade"] = req.headers["Upgrade"].upper()
+ self.assert_ok(req)
+ req.headers["Connection"] = req.headers["Connection"].lower()
+ self.assert_ok(req)
+ req.headers["Connection"] = req.headers["Connection"].upper()
+ self.assert_ok(req)
+
+ def test_bogus_key(self):
+ req = self.default_req()
+ req.headers["Sec-WebSocket-Key"] = base64.b64encode(self.DEFAULT_KEY[:-1])
+ self.assert_not_ok(req)
+
+ req.headers["Sec-WebSocket-Key"] = "///"
+ self.assert_not_ok(req)
+
+ def test_versions(self):
+ req = self.default_req()
+ req.headers["Sec-WebSocket-Version"] = "13"
+ self.assert_ok(req)
+ req.headers["Sec-WebSocket-Version"] = "8"
+ self.assert_ok(req)
+
+ req.headers["Sec-WebSocket-Version"] = "7"
+ self.assert_not_ok(req)
+ req.headers["Sec-WebSocket-Version"] = "9"
+ self.assert_not_ok(req)
+
+ del req.headers["Sec-WebSocket-Version"]
+ self.assert_not_ok(req)
+
+ def test_protocols(self):
+ req = self.default_req()
+ req.headers["Sec-WebSocket-Protocol"] = "base64"
+ resp, protocols = transact_http(req)
+ self.assertEqual(resp.status, 101)
+ self.assertEqual(protocols, ["base64"])
+ self.assertEqual(resp.getheader("Sec-WebSocket-Protocol"), "base64")
+
+ req = self.default_req()
+ req.headers["Sec-WebSocket-Protocol"] = "cat"
+ resp, protocols = transact_http(req)
+ self.assertEqual(resp.status, 101)
+ self.assertEqual(protocols, ["cat"])
+ self.assertEqual(resp.getheader("Sec-WebSocket-Protocol"), None)
+
+ req = self.default_req()
+ req.headers["Sec-WebSocket-Protocol"] = "cat, base64"
+ resp, protocols = transact_http(req)
+ self.assertEqual(resp.status, 101)
+ self.assertEqual(protocols, ["cat", "base64"])
+ self.assertEqual(resp.getheader("Sec-WebSocket-Protocol"), "base64")
+
def read_frames(dec):
frames = []
while True:
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits