[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [pluggable-transports/snowflake] 06/07: Use a sync.Pool to reuse packet buffers in QueuePacketConn.



This is an automated email from the git hooks/post-receive script.

meskio pushed a commit to branch main
in repository pluggable-transports/snowflake.

commit c097d5f3bc9e95403006527b90207dfb11ce6438
Author: David Fifield <david@xxxxxxxxxxxxxxx>
AuthorDate: Tue Apr 4 18:45:26 2023 -0600

    Use a sync.Pool to reuse packet buffers in QueuePacketConn.
    
    This is meant to reduce overall allocations. See past discussion at
    https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40260#note_2885524 ff.
---
 common/turbotunnel/queuepacketconn.go      | 47 +++++++++++++++----
 common/turbotunnel/queuepacketconn_test.go | 72 ++++++++++++++++++++++++++++--
 server/lib/http.go                         |  5 ++-
 server/lib/snowflake.go                    |  6 ++-
 4 files changed, 116 insertions(+), 14 deletions(-)

diff --git a/common/turbotunnel/queuepacketconn.go b/common/turbotunnel/queuepacketconn.go
index 5cdb559..6fcc3bf 100644
--- a/common/turbotunnel/queuepacketconn.go
+++ b/common/turbotunnel/queuepacketconn.go
@@ -27,23 +27,29 @@ type QueuePacketConn struct {
 	recvQueue chan taggedPacket
 	closeOnce sync.Once
 	closed    chan struct{}
+	mtu       int
+	// Pool of reusable mtu-sized buffers.
+	bufPool sync.Pool
 	// What error to return when the QueuePacketConn is closed.
 	err atomic.Value
 }
 
 // NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
-// for at least a duration of timeout.
-func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
+// for at least a duration of timeout. The maximum packet size is mtu.
+func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration, mtu int) *QueuePacketConn {
 	return &QueuePacketConn{
 		clients:   NewClientMap(timeout),
 		localAddr: localAddr,
 		recvQueue: make(chan taggedPacket, queueSize),
 		closed:    make(chan struct{}),
+		mtu:       mtu,
+		bufPool:   sync.Pool{New: func() interface{} { return make([]byte, mtu) }},
 	}
 }
 
 // QueueIncoming queues an incoming packet and its source address, to be
-// returned in a future call to ReadFrom.
+// returned in a future call to ReadFrom. If p is longer than the MTU, only its
+// first MTU bytes will be used.
 func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
 	select {
 	case <-c.closed:
@@ -52,12 +58,18 @@ func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
 	default:
 	}
 	// Copy the slice so that the caller may reuse it.
-	buf := make([]byte, len(p))
+	buf := c.bufPool.Get().([]byte)
+	if len(p) < cap(buf) {
+		buf = buf[:len(p)]
+	} else {
+		buf = buf[:cap(buf)]
+	}
 	copy(buf, p)
 	select {
 	case c.recvQueue <- taggedPacket{buf, addr}:
 	default:
 		// Drop the incoming packet if the receive queue is full.
+		c.Restore(buf)
 	}
 }
 
@@ -68,6 +80,16 @@ func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
 	return c.clients.SendQueue(addr)
 }
 
+// Restore adds a slice to the internal pool of packet buffers. Typically you
+// will call this with a slice from the OutgoingQueue channel once you are done
+// using it. (It is not an error to fail to do so; it will just result in more
+// allocations.)
+func (c *QueuePacketConn) Restore(p []byte) {
+	if cap(p) >= c.mtu {
+		c.bufPool.Put(p)
+	}
+}
+
 // ReadFrom returns a packet and address previously stored by QueueIncoming.
 func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
 	select {
@@ -79,12 +101,15 @@ func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
 	case <-c.closed:
 		return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
 	case packet := <-c.recvQueue:
-		return copy(p, packet.P), packet.Addr, nil
+		n := copy(p, packet.P)
+		c.Restore(packet.P)
+		return n, packet.Addr, nil
 	}
 }
 
 // WriteTo queues an outgoing packet for the given address. The queue can later
-// be retrieved using the OutgoingQueue method.
+// be retrieved using the OutgoingQueue method. If p is longer than the MTU,
+// only its first MTU bytes will be used.
 func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
 	select {
 	case <-c.closed:
@@ -92,14 +117,20 @@ func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
 	default:
 	}
 	// Copy the slice so that the caller may reuse it.
-	buf := make([]byte, len(p))
+	buf := c.bufPool.Get().([]byte)
+	if len(p) < cap(buf) {
+		buf = buf[:len(p)]
+	} else {
+		buf = buf[:cap(buf)]
+	}
 	copy(buf, p)
 	select {
 	case c.clients.SendQueue(addr) <- buf:
 		return len(buf), nil
 	default:
 		// Drop the outgoing packet if the send queue is full.
-		return len(buf), nil
+		c.Restore(buf)
+		return len(p), nil
 	}
 }
 
diff --git a/common/turbotunnel/queuepacketconn_test.go b/common/turbotunnel/queuepacketconn_test.go
index 37f46bc..b9f62c9 100644
--- a/common/turbotunnel/queuepacketconn_test.go
+++ b/common/turbotunnel/queuepacketconn_test.go
@@ -23,7 +23,7 @@ func (i intAddr) String() string  { return fmt.Sprintf("%d", i) }
 
 // Run with -benchmem to see memory allocations.
 func BenchmarkQueueIncoming(b *testing.B) {
-	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
+	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
 	defer conn.Close()
 
 	b.ResetTimer()
@@ -36,7 +36,7 @@ func BenchmarkQueueIncoming(b *testing.B) {
 
 // BenchmarkWriteTo benchmarks the QueuePacketConn.WriteTo function.
 func BenchmarkWriteTo(b *testing.B) {
-	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
+	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
 	defer conn.Close()
 
 	b.ResetTimer()
@@ -47,6 +47,72 @@ func BenchmarkWriteTo(b *testing.B) {
 	b.StopTimer()
 }
 
+// TestQueueIncomingOversize tests that QueueIncoming truncates packets that are
+// larger than the MTU.
+func TestQueueIncomingOversize(t *testing.T) {
+	const payload = "abcdefghijklmnopqrstuvwxyz"
+	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
+	defer conn.Close()
+	conn.QueueIncoming([]byte(payload), emptyAddr{})
+	var p [500]byte
+	n, _, err := conn.ReadFrom(p[:])
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(p[:n], []byte(payload[:len(payload)-1])) {
+		t.Fatalf("payload was %+q, expected %+q", p[:n], payload[:len(payload)-1])
+	}
+}
+
+// TestWriteToOversize tests that WriteTo truncates packets that are larger than
+// the MTU.
+func TestWriteToOversize(t *testing.T) {
+	const payload = "abcdefghijklmnopqrstuvwxyz"
+	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
+	defer conn.Close()
+	conn.WriteTo([]byte(payload), emptyAddr{})
+	p := <-conn.OutgoingQueue(emptyAddr{})
+	if !bytes.Equal(p, []byte(payload[:len(payload)-1])) {
+		t.Fatalf("payload was %+q, expected %+q", p, payload[:len(payload)-1])
+	}
+}
+
+// TestRestoreMTU tests that Restore ignores any inputs that are not at least
+// MTU-sized.
+func TestRestoreMTU(t *testing.T) {
+	const mtu = 500
+	const payload = "hello"
+	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
+	defer conn.Close()
+	conn.Restore(make([]byte, mtu-1))
+	// This WriteTo may use the short slice we just gave to Restore.
+	conn.WriteTo([]byte(payload), emptyAddr{})
+	// Read the queued slice and ensure its cap is exactly the MTU (the short slice was not reused).
+	p := <-conn.OutgoingQueue(emptyAddr{})
+	if cap(p) != mtu {
+		t.Fatalf("cap was %v, expected %v", cap(p), mtu)
+	}
+	// Check the payload while we're at it.
+	if !bytes.Equal(p, []byte(payload)) {
+		t.Fatalf("payload was %+q, expected %+q", p, payload)
+	}
+}
+
+// TestRestoreCap tests that Restore can use slices whose cap is at least the
+// MTU, even if the len is shorter.
+func TestRestoreCap(t *testing.T) {
+	const mtu = 500
+	const payload = "hello"
+	conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
+	defer conn.Close()
+	conn.Restore(make([]byte, 0, mtu))
+	conn.WriteTo([]byte(payload), emptyAddr{})
+	p := <-conn.OutgoingQueue(emptyAddr{})
+	if !bytes.Equal(p, []byte(payload)) {
+		t.Fatalf("payload was %+q, expected %+q", p, payload)
+	}
+}
+
 // DiscardPacketConn is a net.PacketConn whose ReadFrom method block forever and
 // whose WriteTo method discards whatever it is called with.
 type DiscardPacketConn struct{}
@@ -122,7 +188,7 @@ func TestQueuePacketConnWriteToKCP(t *testing.T) {
 		}
 	}()
 
-	pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
+	pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
 	defer pconn.Close()
 	addr1 := intAddr(1)
 	outgoing := pconn.OutgoingQueue(addr1)
diff --git a/server/lib/http.go b/server/lib/http.go
index 3a01884..8c0343f 100644
--- a/server/lib/http.go
+++ b/server/lib/http.go
@@ -69,10 +69,10 @@ type httpHandler struct {
 
 // newHTTPHandler creates a new http.Handler that exchanges encapsulated packets
 // over incoming WebSocket connections.
-func newHTTPHandler(localAddr net.Addr, numInstances int) *httpHandler {
+func newHTTPHandler(localAddr net.Addr, numInstances int, mtu int) *httpHandler {
 	pconns := make([]*turbotunnel.QueuePacketConn, 0, numInstances)
 	for i := 0; i < numInstances; i++ {
-		pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout))
+		pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout, mtu))
 	}
 
 	clientIDLookupKey := make([]byte, 16)
@@ -200,6 +200,7 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error
 					return
 				}
 				_, err := encapsulation.WriteData(bw, p)
+				pconn.Restore(p)
 				if err == nil {
 					err = bw.Flush()
 				}
diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go
index c4d3fbc..3c3c440 100644
--- a/server/lib/snowflake.go
+++ b/server/lib/snowflake.go
@@ -79,7 +79,11 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen
 		ln:     make([]*kcp.Listener, 0, numKCPInstances),
 	}
 
-	handler := newHTTPHandler(addr, numKCPInstances)
+	// kcp-go doesn't provide an accessor for the current MTU setting (and
+	// anyway we could not create a kcp.Listener without creating a
+	// net.PacketConn for it first), so assume the default kcp.IKCP_MTU_DEF
+	// (1400 bytes) and don't increase it elsewhere.
+	handler := newHTTPHandler(addr, numKCPInstances, kcp.IKCP_MTU_DEF)
 	server := &http.Server{
 		Addr:        addr.String(),
 		Handler:     handler,

-- 
To stop receiving notification emails like this one, please contact
the administrator of this repository.
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits