diff --git a/services/httpoverrpc/server/server.go b/services/httpoverrpc/server/server.go index 7233ddd5..87d864b8 100644 --- a/services/httpoverrpc/server/server.go +++ b/services/httpoverrpc/server/server.go @@ -22,9 +22,6 @@ import ( "context" "crypto/tls" "fmt" - "io" - "net/http" - "github.com/Snowflake-Labs/sansshell/services" pb "github.com/Snowflake-Labs/sansshell/services/httpoverrpc" sansshellserver "github.com/Snowflake-Labs/sansshell/services/sansshell/server" @@ -32,6 +29,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "io" + "net/http" + "strings" ) // Metrics @@ -62,9 +62,16 @@ func (s *server) Host(ctx context.Context, req *pb.HostHTTPRequest) (*pb.HTTPRep } // Set a default user agent that can be overridden in the request. httpReq.Header["User-Agent"] = []string{"sansshell/" + sansshellserver.Version} + for _, header := range req.Request.Headers { + if strings.ToLower(header.Key) == "host" { + // override the host with value from header + httpReq.Host = header.Values[0] + continue + } httpReq.Header[header.Key] = header.Values } + client := &http.Client{} if req.Tlsconfig != nil { client.Transport = &http.Transport{ diff --git a/services/httpoverrpc/server/server_test.go b/services/httpoverrpc/server/server_test.go index 66447bb2..9a252861 100644 --- a/services/httpoverrpc/server/server_test.go +++ b/services/httpoverrpc/server/server_test.go @@ -64,168 +64,226 @@ func TestMain(m *testing.M) { } func TestServer(t *testing.T) { - var err error - ctx := context.Background() - conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) - testutil.FatalOnErr("Failed to dial bufnet", err, t) - t.Cleanup(func() { conn.Close() }) - - client := httpoverrpc.NewHTTPOverRPCClient(conn) - - // Set up web server - m := http.NewServeMux() - m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) { - _, _ = httpResp.Write([]byte("hello world")) - }) - l, err := net.Listen("tcp4", "localhost:0") - if err != nil { - t.Fatal(err) - } - go func() { _ = http.Serve(l, m) }() - - _, p, err := net.SplitHostPort(l.Addr().String()) - if err != nil { - t.Fatal(err) - } - httpPort, err := strconv.Atoi(p) - if err != nil { - t.Fatal(err) - } - - got, err := client.Host(ctx, &httpoverrpc.HostHTTPRequest{ - Request: &httpoverrpc.HTTPRequest{ - Method: "GET", - RequestUri: "/", - }, - Port: int32(httpPort), - Protocol: "http", - }) - if err != nil { - t.Fatal(err) - } + t.Run("it should send request and get expected requests", func(t *testing.T) { + var err error + ctx := context.Background() + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + testutil.FatalOnErr("Failed to dial bufnet", err, t) + t.Cleanup(func() { conn.Close() }) - sort.Slice(got.Headers, func(i, j int) bool { - return got.Headers[i].Key < got.Headers[j].Key - }) - for _, h := range got.Headers { - if h.Key == "Date" { - // Clear out the date header because it varies based on time. - h.Values = nil - } - } - - want := &httpoverrpc.HTTPReply{ - StatusCode: 200, - Headers: []*httpoverrpc.Header{ - {Key: "Content-Length", Values: []string{"11"}}, - {Key: "Content-Type", Values: []string{"text/plain; charset=utf-8"}}, - {Key: "Date"}, - }, - Body: []byte("hello world"), - } - if !cmp.Equal(got, want, protocmp.Transform()) { - t.Fatalf("want %v, got %v", want, got) - } - - // test https post request and expect json response - type Data struct { - InstanceID int `json:"instanceId"` - IPAddress string `json:"ipAddress"` - } - - type Response struct { - Data Data `json:"data"` - Code *string `json:"code"` - Message *string `json:"message"` - Success bool `json:"success"` - } - m = http.NewServeMux() - m.HandleFunc("/register", func(httpResp http.ResponseWriter, httpReq *http.Request) { - if httpReq.Method == http.MethodPost { - httpResp.Header().Set("Content-Type", "application/json") - response := Response{ - Data: Data{ - InstanceID: 11, - IPAddress: "127.0.0.1", - }, - Code: nil, - Message: nil, - Success: true, + client := httpoverrpc.NewHTTPOverRPCClient(conn) + + // Set up web server + m := http.NewServeMux() + m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) { + _, _ = httpResp.Write([]byte("hello world")) + }) + l, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + go func() { _ = http.Serve(l, m) }() + + _, p, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + httpPort, err := strconv.Atoi(p) + if err != nil { + t.Fatal(err) + } + + got, err := client.Host(ctx, &httpoverrpc.HostHTTPRequest{ + Request: &httpoverrpc.HTTPRequest{ + Method: "GET", + RequestUri: "/", + }, + Port: int32(httpPort), + Protocol: "http", + }) + if err != nil { + t.Fatal(err) + } + + sort.Slice(got.Headers, func(i, j int) bool { + return got.Headers[i].Key < got.Headers[j].Key + }) + for _, h := range got.Headers { + if h.Key == "Date" { + // Clear out the date header because it varies based on time. + h.Values = nil } - err = json.NewEncoder(httpResp).Encode(response) - testutil.FatalOnErr("Failed to ", err, t) - } else { - http.Error(httpResp, "Invalid request method", http.StatusMethodNotAllowed) } - }) - server := httptest.NewTLSServer(m) - l = server.Listener - - httpClient := server.Client() - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - httpClient.Transport = tr - - _, p, err = net.SplitHostPort(l.Addr().String()) - if err != nil { - t.Fatal(err) - } - httpPort, err = strconv.Atoi(p) - if err != nil { - t.Fatal(err) - } - - got, err = client.Host(ctx, &httpoverrpc.HostHTTPRequest{ - Request: &httpoverrpc.HTTPRequest{ - Method: "POST", - RequestUri: "/register", - }, - Port: int32(httpPort), - Protocol: "https", - Hostname: "localhost", - Tlsconfig: &httpoverrpc.TLSConfig{ - InsecureSkipVerify: true, - }, - }) - if err != nil { - t.Fatal(err) - } + want := &httpoverrpc.HTTPReply{ + StatusCode: 200, + Headers: []*httpoverrpc.Header{ + {Key: "Content-Length", Values: []string{"11"}}, + {Key: "Content-Type", Values: []string{"text/plain; charset=utf-8"}}, + {Key: "Date"}, + }, + Body: []byte("hello world"), + } + if !cmp.Equal(got, want, protocmp.Transform()) { + t.Fatalf("want %v, got %v", want, got) + } + + // test https post request and expect json response + type Data struct { + InstanceID int `json:"instanceId"` + IPAddress string `json:"ipAddress"` + } + + type Response struct { + Data Data `json:"data"` + Code *string `json:"code"` + Message *string `json:"message"` + Success bool `json:"success"` + } + m = http.NewServeMux() + m.HandleFunc("/register", func(httpResp http.ResponseWriter, httpReq *http.Request) { + if httpReq.Method == http.MethodPost { + httpResp.Header().Set("Content-Type", "application/json") + response := Response{ + Data: Data{ + InstanceID: 11, + IPAddress: "127.0.0.1", + }, + Code: nil, + Message: nil, + Success: true, + } + err = json.NewEncoder(httpResp).Encode(response) + testutil.FatalOnErr("Failed to ", err, t) + } else { + http.Error(httpResp, "Invalid request method", http.StatusMethodNotAllowed) + } + }) + + server := httptest.NewTLSServer(m) + l = server.Listener + + httpClient := server.Client() + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + httpClient.Transport = tr - sort.Slice(got.Headers, func(i, j int) bool { - return got.Headers[i].Key < got.Headers[j].Key + _, p, err = net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + httpPort, err = strconv.Atoi(p) + if err != nil { + t.Fatal(err) + } + + got, err = client.Host(ctx, &httpoverrpc.HostHTTPRequest{ + Request: &httpoverrpc.HTTPRequest{ + Method: "POST", + RequestUri: "/register", + }, + Port: int32(httpPort), + Protocol: "https", + Hostname: "localhost", + Tlsconfig: &httpoverrpc.TLSConfig{ + InsecureSkipVerify: true, + }, + }) + if err != nil { + t.Fatal(err) + } + + sort.Slice(got.Headers, func(i, j int) bool { + return got.Headers[i].Key < got.Headers[j].Key + }) + for _, h := range got.Headers { + if h.Key == "Date" { + // Clear out the date header because it varies based on time. + h.Values = nil + } + } + wantBody := `{"data":{"instanceId":11,"ipAddress":"127.0.0.1"},"code":null,"message":null,"success":true}` + "\n" + contentLengthStr := strconv.Itoa(len(wantBody)) + want = &httpoverrpc.HTTPReply{ + StatusCode: 200, + Headers: []*httpoverrpc.Header{ + {Key: "Content-Length", Values: []string{contentLengthStr}}, + {Key: "Content-Type", Values: []string{"application/json"}}, + {Key: "Date"}, + }, + Body: []byte(wantBody), + } + if !cmp.Equal(got, want, protocmp.Transform()) { + t.Fatalf("want %v, got %v", want, got) + } + + // without insecureSkipVerify, should get an error + got, err = client.Host(ctx, &httpoverrpc.HostHTTPRequest{ + Request: &httpoverrpc.HTTPRequest{ + Method: "POST", + RequestUri: "/register", + }, + Port: int32(httpPort), + Protocol: "https", + Hostname: "localhost", + }) + assert.Error(t, err) }) - for _, h := range got.Headers { - if h.Key == "Date" { - // Clear out the date header because it varies based on time. - h.Values = nil - } - } - wantBody := `{"data":{"instanceId":11,"ipAddress":"127.0.0.1"},"code":null,"message":null,"success":true}` + "\n" - contentLengthStr := strconv.Itoa(len(wantBody)) - want = &httpoverrpc.HTTPReply{ - StatusCode: 200, - Headers: []*httpoverrpc.Header{ - {Key: "Content-Length", Values: []string{contentLengthStr}}, - {Key: "Content-Type", Values: []string{"application/json"}}, - {Key: "Date"}, - }, - Body: []byte(wantBody), - } - if !cmp.Equal(got, want, protocmp.Transform()) { - t.Fatalf("want %v, got %v", want, got) - } - - // without insecureSkipVerify, should get an error - got, err = client.Host(ctx, &httpoverrpc.HostHTTPRequest{ - Request: &httpoverrpc.HTTPRequest{ - Method: "POST", - RequestUri: "/register", - }, - Port: int32(httpPort), - Protocol: "https", - Hostname: "localhost", + + t.Run("It should send send provided host header to the server", func(t *testing.T) { + // ARRANGE + var err error + ctx := context.Background() + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + testutil.FatalOnErr("Failed to dial bufnet", err, t) + t.Cleanup(func() { conn.Close() }) + + client := httpoverrpc.NewHTTPOverRPCClient(conn) + + // Set up web server + m := http.NewServeMux() + getHostHeaderURI := "/get-host-header" + m.HandleFunc(getHostHeaderURI, func(httpResp http.ResponseWriter, httpReq *http.Request) { + // reply always with provided host + _, _ = httpResp.Write([]byte(httpReq.Host)) + }) + l, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + go func() { _ = http.Serve(l, m) }() + + _, p, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + httpPort, err := strconv.Atoi(p) + if err != nil { + t.Fatal(err) + } + + customHostHeader := "example.com" + + // ACT + resp, err := client.Host(ctx, &httpoverrpc.HostHTTPRequest{ + Request: &httpoverrpc.HTTPRequest{ + Method: "GET", + RequestUri: getHostHeaderURI, + Headers: []*httpoverrpc.Header{ + {Key: "Host", Values: []string{customHostHeader}}, + }, + }, + Port: int32(httpPort), + Protocol: "http", + }) + if err != nil { + t.Fatal(err) + } + + // ASSERT + if string(resp.Body) != customHostHeader { + t.Fatalf("Expected response body to be %q, got %q", customHostHeader, resp.Body) + } }) - assert.Error(t, err) }