[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[tor-commits] [pluggable-transports/snowflake] 02/03: Parse ClientPollRequest version in DecodeClientPollRequest



This is an automated email from the git hooks/post-receive script.

arlo pushed a commit to branch main
in repository pluggable-transports/snowflake.

commit 829cacac5f7ecb2cc701a24061679814fc1841bc
Author: Arlo Breault <arlolra@xxxxxxxxx>
AuthorDate: Wed Mar 9 19:48:16 2022 -0500

    Parse ClientPollRequest version in DecodeClientPollRequest
    
    Parse the version in DecodeClientPollRequest instead of in
    IPC.ClientOffers.  This makes decoding consistent with
    EncodeClientPollRequest, which adds the version while serializing.
---
 broker/http.go                   |  5 +++--
 broker/ipc.go                    | 45 +++++++++-------------------------------
 client/lib/rendezvous.go         |  5 +++--
 client/lib/rendezvous_test.go    |  5 +++--
 common/messages/client.go        | 28 ++++++++++++++++++++-----
 common/messages/messages_test.go | 21 +++++++++----------
 6 files changed, 52 insertions(+), 57 deletions(-)

diff --git a/broker/http.go b/broker/http.go
index 3b0ba1f..7acc465 100644
--- a/broker/http.go
+++ b/broker/http.go
@@ -146,8 +146,9 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
 	if len(body) > 0 && body[0] == '{' {
 		isLegacy = true
 		req := messages.ClientPollRequest{
-			Offer: string(body),
-			NAT:   r.Header.Get("Snowflake-NAT-Type"),
+			Offer:   string(body),
+			NAT:     r.Header.Get("Snowflake-NAT-Type"),
+			Version: messages.ClientVersion1_0,
 		}
 		body, err = req.EncodeClientPollRequest()
 		if err != nil {
diff --git a/broker/ipc.go b/broker/ipc.go
index b8359f6..768c0b7 100644
--- a/broker/ipc.go
+++ b/broker/ipc.go
@@ -1,7 +1,6 @@
 package main
 
 import (
-	"bytes"
 	"container/heap"
 	"fmt"
 	"log"
@@ -21,12 +20,6 @@ const (
 	NATUnrestricted = "unrestricted"
 )
 
-type clientVersion int
-
-const (
-	v1 clientVersion = iota
-)
-
 type IPC struct {
 	ctx *BrokerContext
 }
@@ -132,32 +125,16 @@ func sendClientResponse(resp *messages.ClientPollResponse, response *[]byte) err
 }
 
 func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
-	var version clientVersion
-
 	startTime := time.Now()
-	body := arg.Body
 
-	parts := bytes.SplitN(body, []byte("\n"), 2)
-	if len(parts) < 2 {
-		// no version number found
-		err := fmt.Errorf("unsupported message version")
-		return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
-	}
-	body = parts[1]
-	if string(parts[0]) == "1.0" {
-		version = v1
-	} else {
-		err := fmt.Errorf("unsupported message version")
+	req, err := messages.DecodeClientPollRequest(arg.Body)
+	if err != nil {
 		return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
 	}
 
 	var offer *ClientOffer
-	switch version {
-	case v1:
-		req, err := messages.DecodeClientPollRequest(body)
-		if err != nil {
-			return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
-		}
+	switch req.Version {
+	case messages.ClientVersion1_0:
 		offer = &ClientOffer{
 			natType: req.NAT,
 			sdp:     []byte(req.Offer),
@@ -188,8 +165,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
 			i.ctx.metrics.clientRestrictedDeniedCount++
 		}
 		i.ctx.metrics.lock.Unlock()
-		switch version {
-		case v1:
+		switch req.Version {
+		case messages.ClientVersion1_0:
 			resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
 			return sendClientResponse(resp, response)
 		default:
@@ -204,8 +181,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
 	i.ctx.snowflakeLock.Unlock()
 	snowflake.offerChannel <- offer
 
-	var err error
-
 	// Wait for the answer to be returned on the channel or timeout.
 	select {
 	case answer := <-snowflake.answerChannel:
@@ -213,8 +188,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
 		i.ctx.metrics.clientProxyMatchCount++
 		i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
 		i.ctx.metrics.lock.Unlock()
-		switch version {
-		case v1:
+		switch req.Version {
+		case messages.ClientVersion1_0:
 			resp := &messages.ClientPollResponse{Answer: answer}
 			err = sendClientResponse(resp, response)
 		default:
@@ -224,8 +199,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
 		i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
 	case <-time.After(time.Second * ClientTimeout):
 		log.Println("Client: Timed out.")
-		switch version {
-		case v1:
+		switch req.Version {
+		case messages.ClientVersion1_0:
 			resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
 			err = sendClientResponse(resp, response)
 		default:
diff --git a/client/lib/rendezvous.go b/client/lib/rendezvous.go
index 0ce2744..e7543ad 100644
--- a/client/lib/rendezvous.go
+++ b/client/lib/rendezvous.go
@@ -122,8 +122,9 @@ func (bc *BrokerChannel) Negotiate(offer *webrtc.SessionDescription) (
 	// Encode the client poll request.
 	bc.lock.Lock()
 	req := &messages.ClientPollRequest{
-		Offer: offerSDP,
-		NAT:   bc.natType,
+		Offer:   offerSDP,
+		NAT:     bc.natType,
+		Version: messages.ClientVersion1_0,
 	}
 	encReq, err := req.EncodeClientPollRequest()
 	bc.lock.Unlock()
diff --git a/client/lib/rendezvous_test.go b/client/lib/rendezvous_test.go
index 21b9f57..a233e7d 100644
--- a/client/lib/rendezvous_test.go
+++ b/client/lib/rendezvous_test.go
@@ -43,8 +43,9 @@ func (t errorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 // offer.
 func makeEncPollReq(offer string) []byte {
 	encPollReq, err := (&messages.ClientPollRequest{
-		Offer: offer,
-		NAT:   nat.NATUnknown,
+		Offer:   offer,
+		NAT:     nat.NATUnknown,
+		Version: messages.ClientVersion1_0,
 	}).EncodeClientPollRequest()
 	if err != nil {
 		panic(err)
diff --git a/common/messages/client.go b/common/messages/client.go
index 5a7d73b..2a35594 100644
--- a/common/messages/client.go
+++ b/common/messages/client.go
@@ -4,13 +4,14 @@
 package messages
 
 import (
+	"bytes"
 	"encoding/json"
 	"fmt"
 
 	"git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat"
 )
 
-const ClientVersion = "1.0"
+const ClientVersion1_0 = "1.0"
 
 /* Client--Broker protocol v1.x specification:
 
@@ -49,24 +50,41 @@ for the error.
 */
 
 type ClientPollRequest struct {
-	Offer string `json:"offer"`
-	NAT   string `json:"nat"`
+	Offer   string `json:"offer"`
+	NAT     string `json:"nat"`
+	Version string `json:"-"`
 }
 
 // Encodes a poll message from a snowflake client
 func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) {
+	if req.Version != ClientVersion1_0 {
+		return nil, fmt.Errorf("unsupported message version")
+	}
 	body, err := json.Marshal(req)
 	if err != nil {
 		return nil, err
 	}
-	return append([]byte(ClientVersion+"\n"), body...), nil
+	return append([]byte(req.Version+"\n"), body...), nil
 }
 
 // Decodes a poll message from a snowflake client
 func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) {
+	parts := bytes.SplitN(data, []byte("\n"), 2)
+
+	if len(parts) < 2 {
+		// no version number found
+		return nil, fmt.Errorf("unsupported message version")
+	}
+
 	var message ClientPollRequest
 
-	err := json.Unmarshal(data, &message)
+	if string(parts[0]) == ClientVersion1_0 {
+		message.Version = ClientVersion1_0
+	} else {
+		return nil, fmt.Errorf("unsupported message version")
+	}
+
+	err := json.Unmarshal(parts[1], &message)
 	if err != nil {
 		return nil, err
 	}
diff --git a/common/messages/messages_test.go b/common/messages/messages_test.go
index 0d8b450..e0aa2a8 100644
--- a/common/messages/messages_test.go
+++ b/common/messages/messages_test.go
@@ -1,7 +1,6 @@
 package messages
 
 import (
-	"bytes"
 	"encoding/json"
 	"fmt"
 	"testing"
@@ -286,14 +285,16 @@ func TestDecodeClientPollRequest(t *testing.T) {
 				//version 1.0 client message
 				"unknown",
 				"fake",
-				`{"nat":"unknown","offer":"fake"}`,
+				`1.0
+{"nat":"unknown","offer":"fake"}`,
 				nil,
 			},
 			{
 				//version 1.0 client message
 				"unknown",
 				"fake",
-				`{"offer":"fake"}`,
+				`1.0
+{"offer":"fake"}`,
 				nil,
 			},
 			{
@@ -307,16 +308,17 @@ func TestDecodeClientPollRequest(t *testing.T) {
 				//no offer
 				"",
 				"",
-				`{"nat":"unknown"}`,
+				`1.0
+{"nat":"unknown"}`,
 				fmt.Errorf(""),
 			},
 		} {
 			req, err := DecodeClientPollRequest([]byte(test.data))
+			So(err, ShouldHaveSameTypeAs, test.err)
 			if test.err == nil {
 				So(req.NAT, ShouldResemble, test.natType)
 				So(req.Offer, ShouldResemble, test.offer)
 			}
-			So(err, ShouldHaveSameTypeAs, test.err)
 		}
 
 	})
@@ -325,15 +327,12 @@ func TestDecodeClientPollRequest(t *testing.T) {
 func TestEncodeClientPollRequests(t *testing.T) {
 	Convey("Context", t, func() {
 		req1 := &ClientPollRequest{
-			NAT:   "unknown",
-			Offer: "fake",
+			NAT:     "unknown",
+			Offer:   "fake",
+			Version: ClientVersion1_0,
 		}
 		b, err := req1.EncodeClientPollRequest()
 		So(err, ShouldEqual, nil)
-		fmt.Println(string(b))
-		parts := bytes.SplitN(b, []byte("\n"), 2)
-		So(string(parts[0]), ShouldEqual, "1.0")
-		b = parts[1]
 		req2, err := DecodeClientPollRequest(b)
 		So(err, ShouldEqual, nil)
 		So(req2, ShouldResemble, req1)

-- 
To stop receiving notification emails like this one, please contact
the administrator of this repository.
_______________________________________________
tor-commits mailing list
tor-commits@xxxxxxxxxxxxxxxxxxxx
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits