286 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			286 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
| package context
 | |
| 
 | |
| import (
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"net/http/httputil"
 | |
| 	"net/url"
 | |
| 	"reflect"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
| func TestWithRequest(t *testing.T) {
 | |
| 	var req http.Request
 | |
| 
 | |
| 	start := time.Now()
 | |
| 	req.Method = http.MethodGet
 | |
| 	req.Host = "example.com"
 | |
| 	req.RequestURI = "/test-test"
 | |
| 	req.Header = make(http.Header)
 | |
| 	req.Header.Set("Referer", "foo.com/referer")
 | |
| 	req.Header.Set("User-Agent", "test/0.1")
 | |
| 
 | |
| 	ctx := WithRequest(Background(), &req)
 | |
| 	for _, testcase := range []struct {
 | |
| 		key      string
 | |
| 		expected interface{}
 | |
| 	}{
 | |
| 		{
 | |
| 			key:      "http.request",
 | |
| 			expected: &req,
 | |
| 		},
 | |
| 		{
 | |
| 			key: "http.request.id",
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.method",
 | |
| 			expected: req.Method,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.host",
 | |
| 			expected: req.Host,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.uri",
 | |
| 			expected: req.RequestURI,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.referer",
 | |
| 			expected: req.Referer(),
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.useragent",
 | |
| 			expected: req.UserAgent(),
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.remoteaddr",
 | |
| 			expected: req.RemoteAddr,
 | |
| 		},
 | |
| 		{
 | |
| 			key: "http.request.startedat",
 | |
| 		},
 | |
| 	} {
 | |
| 		v := ctx.Value(testcase.key)
 | |
| 
 | |
| 		if v == nil {
 | |
| 			t.Fatalf("value not found for %q", testcase.key)
 | |
| 		}
 | |
| 
 | |
| 		if testcase.expected != nil && v != testcase.expected {
 | |
| 			t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
 | |
| 		}
 | |
| 
 | |
| 		// Key specific checks!
 | |
| 		switch testcase.key {
 | |
| 		case "http.request.id":
 | |
| 			if _, ok := v.(string); !ok {
 | |
| 				t.Fatalf("request id not a string: %v", v)
 | |
| 			}
 | |
| 		case "http.request.startedat":
 | |
| 			vt, ok := v.(time.Time)
 | |
| 			if !ok {
 | |
| 				t.Fatalf("value not a time: %v", v)
 | |
| 			}
 | |
| 
 | |
| 			now := time.Now()
 | |
| 			if vt.After(now) {
 | |
| 				t.Fatalf("time generated too late: %v > %v", vt, now)
 | |
| 			}
 | |
| 
 | |
| 			if vt.Before(start) {
 | |
| 				t.Fatalf("time generated too early: %v < %v", vt, start)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type testResponseWriter struct {
 | |
| 	flushed bool
 | |
| 	status  int
 | |
| 	written int64
 | |
| 	header  http.Header
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) Header() http.Header {
 | |
| 	if trw.header == nil {
 | |
| 		trw.header = make(http.Header)
 | |
| 	}
 | |
| 
 | |
| 	return trw.header
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) Write(p []byte) (n int, err error) {
 | |
| 	if trw.status == 0 {
 | |
| 		trw.status = http.StatusOK
 | |
| 	}
 | |
| 
 | |
| 	n = len(p)
 | |
| 	trw.written += int64(n)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) WriteHeader(status int) {
 | |
| 	trw.status = status
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) Flush() {
 | |
| 	trw.flushed = true
 | |
| }
 | |
| 
 | |
| func TestWithResponseWriter(t *testing.T) {
 | |
| 	trw := testResponseWriter{}
 | |
| 	ctx, rw := WithResponseWriter(Background(), &trw)
 | |
| 
 | |
| 	if ctx.Value("http.response") != rw {
 | |
| 		t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), rw)
 | |
| 	}
 | |
| 
 | |
| 	grw, err := GetResponseWriter(ctx)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("error getting response writer: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if grw != rw {
 | |
| 		t.Fatalf("unexpected response writer returned: %#v != %#v", grw, rw)
 | |
| 	}
 | |
| 
 | |
| 	if ctx.Value("http.response.status") != 0 {
 | |
| 		t.Fatalf("response status should always be a number and should be zero here: %v != 0", ctx.Value("http.response.status"))
 | |
| 	}
 | |
| 
 | |
| 	if n, err := rw.Write(make([]byte, 1024)); err != nil {
 | |
| 		t.Fatalf("unexpected error writing: %v", err)
 | |
| 	} else if n != 1024 {
 | |
| 		t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
 | |
| 	}
 | |
| 
 | |
| 	if ctx.Value("http.response.status") != http.StatusOK {
 | |
| 		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	if ctx.Value("http.response.written") != int64(1024) {
 | |
| 		t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
 | |
| 	}
 | |
| 
 | |
| 	// Make sure flush propagates
 | |
| 	rw.(http.Flusher).Flush()
 | |
| 
 | |
| 	if !trw.flushed {
 | |
| 		t.Fatalf("response writer not flushed")
 | |
| 	}
 | |
| 
 | |
| 	// Write another status and make sure context is correct. This normally
 | |
| 	// wouldn't work except for in this contrived testcase.
 | |
| 	rw.WriteHeader(http.StatusBadRequest)
 | |
| 
 | |
| 	if ctx.Value("http.response.status") != http.StatusBadRequest {
 | |
| 		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestWithVars(t *testing.T) {
 | |
| 	var req http.Request
 | |
| 	vars := map[string]string{
 | |
| 		"foo": "asdf",
 | |
| 		"bar": "qwer",
 | |
| 	}
 | |
| 
 | |
| 	getVarsFromRequest = func(r *http.Request) map[string]string {
 | |
| 		if r != &req {
 | |
| 			t.Fatalf("unexpected request: %v != %v", r, req)
 | |
| 		}
 | |
| 
 | |
| 		return vars
 | |
| 	}
 | |
| 
 | |
| 	ctx := WithVars(Background(), &req)
 | |
| 	for _, testcase := range []struct {
 | |
| 		key      string
 | |
| 		expected interface{}
 | |
| 	}{
 | |
| 		{
 | |
| 			key:      "vars",
 | |
| 			expected: vars,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "vars.foo",
 | |
| 			expected: "asdf",
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "vars.bar",
 | |
| 			expected: "qwer",
 | |
| 		},
 | |
| 	} {
 | |
| 		v := ctx.Value(testcase.key)
 | |
| 
 | |
| 		if !reflect.DeepEqual(v, testcase.expected) {
 | |
| 			t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
 | |
| // RemoteAddr().  A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
 | |
| // at the transport layer to 127.0.0.1:<port> .  However, as the X-Forwarded-For header
 | |
| // just contains the IP address, it is different enough for testing.
 | |
| func TestRemoteAddr(t *testing.T) {
 | |
| 	var expectedRemote string
 | |
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		defer r.Body.Close()
 | |
| 
 | |
| 		if r.RemoteAddr == expectedRemote {
 | |
| 			t.Errorf("Unexpected matching remote addresses")
 | |
| 		}
 | |
| 
 | |
| 		actualRemote := RemoteAddr(r)
 | |
| 		if expectedRemote != actualRemote {
 | |
| 			t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
 | |
| 		}
 | |
| 
 | |
| 		w.WriteHeader(200)
 | |
| 	}))
 | |
| 
 | |
| 	defer backend.Close()
 | |
| 	backendURL, err := url.Parse(backend.URL)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	proxy := httputil.NewSingleHostReverseProxy(backendURL)
 | |
| 	frontend := httptest.NewServer(proxy)
 | |
| 	defer frontend.Close()
 | |
| 
 | |
| 	// X-Forwarded-For set by proxy
 | |
| 	expectedRemote = "127.0.0.1"
 | |
| 	proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	_, err = http.DefaultClient.Do(proxyReq)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// RemoteAddr in X-Real-Ip
 | |
| 	getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	expectedRemote = "1.2.3.4"
 | |
| 	getReq.Header["X-Real-ip"] = []string{expectedRemote}
 | |
| 	_, err = http.DefaultClient.Do(getReq)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// Valid X-Real-Ip and invalid X-Forwarded-For
 | |
| 	getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
 | |
| 	_, err = http.DefaultClient.Do(getReq)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 |