208 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			208 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
package context
 | 
						|
 | 
						|
import (
 | 
						|
	"net/http"
 | 
						|
	"reflect"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"golang.org/x/net/context"
 | 
						|
)
 | 
						|
 | 
						|
func TestWithRequest(t *testing.T) {
 | 
						|
	var req http.Request
 | 
						|
 | 
						|
	start := time.Now()
 | 
						|
	req.Method = "GET"
 | 
						|
	req.Host = "example.com"
 | 
						|
	req.RequestURI = "/test-test"
 | 
						|
	req.Header = make(http.Header)
 | 
						|
	req.Header.Set("Referer", "foo.com/referer")
 | 
						|
	req.Header.Set("User-Agent", "test/0.1")
 | 
						|
 | 
						|
	ctx := WithRequest(context.Background(), &req)
 | 
						|
	for _, testcase := range []struct {
 | 
						|
		key      string
 | 
						|
		expected interface{}
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			key:      "http.request",
 | 
						|
			expected: &req,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key: "http.request.id",
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "http.request.method",
 | 
						|
			expected: req.Method,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "http.request.host",
 | 
						|
			expected: req.Host,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "http.request.uri",
 | 
						|
			expected: req.RequestURI,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "http.request.referer",
 | 
						|
			expected: req.Referer(),
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "http.request.useragent",
 | 
						|
			expected: req.UserAgent(),
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "http.request.remoteaddr",
 | 
						|
			expected: req.RemoteAddr,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key: "http.request.startedat",
 | 
						|
		},
 | 
						|
	} {
 | 
						|
		v := ctx.Value(testcase.key)
 | 
						|
 | 
						|
		if v == nil {
 | 
						|
			t.Fatalf("value not found for %q", testcase.key)
 | 
						|
		}
 | 
						|
 | 
						|
		if testcase.expected != nil && v != testcase.expected {
 | 
						|
			t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
 | 
						|
		}
 | 
						|
 | 
						|
		// Key specific checks!
 | 
						|
		switch testcase.key {
 | 
						|
		case "http.request.id":
 | 
						|
			if _, ok := v.(string); !ok {
 | 
						|
				t.Fatalf("request id not a string: %v", v)
 | 
						|
			}
 | 
						|
		case "http.request.startedat":
 | 
						|
			vt, ok := v.(time.Time)
 | 
						|
			if !ok {
 | 
						|
				t.Fatalf("value not a time: %v", v)
 | 
						|
			}
 | 
						|
 | 
						|
			now := time.Now()
 | 
						|
			if vt.After(now) {
 | 
						|
				t.Fatalf("time generated too late: %v > %v", vt, now)
 | 
						|
			}
 | 
						|
 | 
						|
			if vt.Before(start) {
 | 
						|
				t.Fatalf("time generated too early: %v < %v", vt, start)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// testResponseWriter is an in-memory http.ResponseWriter that also provides a
// Flush method. It records headers, the status code, the number of body bytes
// written, and whether Flush was called, so tests can inspect what a wrapping
// writer forwarded to it.
type testResponseWriter struct {
	flushed bool        // set once Flush has been called
	status  int         // last status passed to WriteHeader, or 200 after an implicit write
	written int64       // running total of bytes accepted by Write
	header  http.Header // lazily allocated by Header
}

// Header returns the response header map, allocating it on first use.
func (w *testResponseWriter) Header() http.Header {
	if w.header != nil {
		return w.header
	}
	w.header = http.Header{}
	return w.header
}

// Write discards p while counting its length. As with net/http, a write that
// happens before any status was set records an implicit 200 OK.
func (w *testResponseWriter) Write(p []byte) (n int, err error) {
	if w.status == 0 {
		w.WriteHeader(http.StatusOK)
	}
	w.written += int64(len(p))
	return len(p), nil
}

// WriteHeader records status, overwriting any previously recorded value.
func (w *testResponseWriter) WriteHeader(status int) {
	w.status = status
}

// Flush only notes that a flush was requested.
func (w *testResponseWriter) Flush() {
	w.flushed = true
}
 | 
						|
 | 
						|
func TestWithResponseWriter(t *testing.T) {
 | 
						|
	trw := testResponseWriter{}
 | 
						|
	ctx, rw := WithResponseWriter(context.Background(), &trw)
 | 
						|
 | 
						|
	if ctx.Value("http.response") != &trw {
 | 
						|
		t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), &trw)
 | 
						|
	}
 | 
						|
 | 
						|
	if n, err := rw.Write(make([]byte, 1024)); err != nil {
 | 
						|
		t.Fatalf("unexpected error writing: %v", err)
 | 
						|
	} else if n != 1024 {
 | 
						|
		t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
 | 
						|
	}
 | 
						|
 | 
						|
	if ctx.Value("http.response.status") != http.StatusOK {
 | 
						|
		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
 | 
						|
	}
 | 
						|
 | 
						|
	if ctx.Value("http.response.written") != int64(1024) {
 | 
						|
		t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
 | 
						|
	}
 | 
						|
 | 
						|
	// Make sure flush propagates
 | 
						|
	rw.(http.Flusher).Flush()
 | 
						|
 | 
						|
	if !trw.flushed {
 | 
						|
		t.Fatalf("response writer not flushed")
 | 
						|
	}
 | 
						|
 | 
						|
	// Write another status and make sure context is correct. This normally
 | 
						|
	// wouldn't work except for in this contrived testcase.
 | 
						|
	rw.WriteHeader(http.StatusBadRequest)
 | 
						|
 | 
						|
	if ctx.Value("http.response.status") != http.StatusBadRequest {
 | 
						|
		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestWithVars(t *testing.T) {
 | 
						|
	var req http.Request
 | 
						|
	vars := map[string]string{
 | 
						|
		"foo": "asdf",
 | 
						|
		"bar": "qwer",
 | 
						|
	}
 | 
						|
 | 
						|
	getVarsFromRequest = func(r *http.Request) map[string]string {
 | 
						|
		if r != &req {
 | 
						|
			t.Fatalf("unexpected request: %v != %v", r, req)
 | 
						|
		}
 | 
						|
 | 
						|
		return vars
 | 
						|
	}
 | 
						|
 | 
						|
	ctx := WithVars(context.Background(), &req)
 | 
						|
	for _, testcase := range []struct {
 | 
						|
		key      string
 | 
						|
		expected interface{}
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			key:      "vars",
 | 
						|
			expected: vars,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "vars.foo",
 | 
						|
			expected: "asdf",
 | 
						|
		},
 | 
						|
		{
 | 
						|
			key:      "vars.bar",
 | 
						|
			expected: "qwer",
 | 
						|
		},
 | 
						|
	} {
 | 
						|
		v := ctx.Value(testcase.key)
 | 
						|
 | 
						|
		if !reflect.DeepEqual(v, testcase.expected) {
 | 
						|
			t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 |