Add unit tests for auth challenge and endpoint
Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)master
							parent
							
								
									174a732c94
								
							
						
					
					
						commit
						b1ba2183ee
					
				|  | @ -127,7 +127,7 @@ func expectTokenOrQuoted(s string) (value string, rest string) { | |||
| 			p := make([]byte, len(s)-1) | ||||
| 			j := copy(p, s[:i]) | ||||
| 			escape := true | ||||
| 			for i = i + i; i < len(s); i++ { | ||||
| 			for i = i + 1; i < len(s); i++ { | ||||
| 				b := s[i] | ||||
| 				switch { | ||||
| 				case escape: | ||||
|  |  | |||
|  | @ -0,0 +1,37 @@ | |||
| package client | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestAuthChallengeParse(t *testing.T) { | ||||
| 	header := http.Header{} | ||||
| 	header.Add("WWW-Authenticate", `Bearer realm="https://auth.example.com/token",service="registry.example.com",other=fun,slashed="he\"\l\lo"`) | ||||
| 
 | ||||
| 	challenges := parseAuthHeader(header) | ||||
| 	if len(challenges) != 1 { | ||||
| 		t.Fatalf("Unexpected number of auth challenges: %d, expected 1", len(challenges)) | ||||
| 	} | ||||
| 
 | ||||
| 	if expected := "bearer"; challenges[0].Scheme != expected { | ||||
| 		t.Fatalf("Unexpected scheme: %s, expected: %s", challenges[0].Scheme, expected) | ||||
| 	} | ||||
| 
 | ||||
| 	if expected := "https://auth.example.com/token"; challenges[0].Parameters["realm"] != expected { | ||||
| 		t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["realm"], expected) | ||||
| 	} | ||||
| 
 | ||||
| 	if expected := "registry.example.com"; challenges[0].Parameters["service"] != expected { | ||||
| 		t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["service"], expected) | ||||
| 	} | ||||
| 
 | ||||
| 	if expected := "fun"; challenges[0].Parameters["other"] != expected { | ||||
| 		t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["other"], expected) | ||||
| 	} | ||||
| 
 | ||||
| 	if expected := "he\"llo"; challenges[0].Parameters["slashed"] != expected { | ||||
| 		t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["slashed"], expected) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
|  | @ -117,6 +117,8 @@ func (e *RepositoryEndpoint) URLBuilder() (*v2.URLBuilder, error) { | |||
| 
 | ||||
| // HTTPClient returns a new HTTP client configured for this endpoint
 | ||||
| func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) { | ||||
| 	// TODO(dmcgowan): create http.Transport
 | ||||
| 
 | ||||
| 	transport := &repositoryTransport{ | ||||
| 		Header: e.Header, | ||||
| 	} | ||||
|  |  | |||
|  | @ -0,0 +1,259 @@ | |||
| package client | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/docker/distribution/testutil" | ||||
| ) | ||||
| 
 | ||||
| type testAuthenticationWrapper struct { | ||||
| 	headers   http.Header | ||||
| 	authCheck func(string) bool | ||||
| 	next      http.Handler | ||||
| } | ||||
| 
 | ||||
| func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { | ||||
| 	auth := r.Header.Get("Authorization") | ||||
| 	if auth == "" || !w.authCheck(auth) { | ||||
| 		h := rw.Header() | ||||
| 		for k, values := range w.headers { | ||||
| 			h[k] = values | ||||
| 		} | ||||
| 		rw.WriteHeader(http.StatusUnauthorized) | ||||
| 		return | ||||
| 	} | ||||
| 	w.next.ServeHTTP(rw, r) | ||||
| } | ||||
| 
 | ||||
| func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (*RepositoryEndpoint, func()) { | ||||
| 	h := testutil.NewHandler(rrm) | ||||
| 	wrapper := &testAuthenticationWrapper{ | ||||
| 
 | ||||
| 		headers: http.Header(map[string][]string{ | ||||
| 			"Docker-Distribution-API-Version": {"registry/2.0"}, | ||||
| 			"WWW-Authenticate":                {authenticate}, | ||||
| 		}), | ||||
| 		authCheck: authCheck, | ||||
| 		next:      h, | ||||
| 	} | ||||
| 
 | ||||
| 	s := httptest.NewServer(wrapper) | ||||
| 	e := RepositoryEndpoint{Endpoint: s.URL, Mirror: false} | ||||
| 	return &e, s.Close | ||||
| } | ||||
| 
 | ||||
| type testCredentialStore struct { | ||||
| 	username string | ||||
| 	password string | ||||
| } | ||||
| 
 | ||||
| func (tcs *testCredentialStore) Basic(*url.URL) (string, string) { | ||||
| 	return tcs.username, tcs.password | ||||
| } | ||||
| 
 | ||||
| func TestEndpointAuthorizeToken(t *testing.T) { | ||||
| 	service := "localhost.localdomain" | ||||
| 	repo1 := "some/registry" | ||||
| 	repo2 := "other/registry" | ||||
| 	scope1 := fmt.Sprintf("repository:%s:pull,push", repo1) | ||||
| 	scope2 := fmt.Sprintf("repository:%s:pull,push", repo2) | ||||
| 
 | ||||
| 	tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ | ||||
| 		{ | ||||
| 			Request: testutil.Request{ | ||||
| 				Method: "GET", | ||||
| 				Route:  fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope1), service), | ||||
| 			}, | ||||
| 			Response: testutil.Response{ | ||||
| 				StatusCode: http.StatusOK, | ||||
| 				Body:       []byte(`{"token":"statictoken"}`), | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Request: testutil.Request{ | ||||
| 				Method: "GET", | ||||
| 				Route:  fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope2), service), | ||||
| 			}, | ||||
| 			Response: testutil.Response{ | ||||
| 				StatusCode: http.StatusOK, | ||||
| 				Body:       []byte(`{"token":"badtoken"}`), | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	te, tc := testServer(tokenMap) | ||||
| 	defer tc() | ||||
| 
 | ||||
| 	m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ | ||||
| 		{ | ||||
| 			Request: testutil.Request{ | ||||
| 				Method: "GET", | ||||
| 				Route:  "/hello", | ||||
| 			}, | ||||
| 			Response: testutil.Response{ | ||||
| 				StatusCode: http.StatusAccepted, | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 
 | ||||
| 	authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service) | ||||
| 	validCheck := func(a string) bool { | ||||
| 		return a == "Bearer statictoken" | ||||
| 	} | ||||
| 	e, c := testServerWithAuth(m, authenicate, validCheck) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	client, err := e.HTTPClient(repo1) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error creating http client: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error sending get request: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != http.StatusAccepted { | ||||
| 		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) | ||||
| 	} | ||||
| 
 | ||||
| 	badCheck := func(a string) bool { | ||||
| 		return a == "Bearer statictoken" | ||||
| 	} | ||||
| 	e2, c2 := testServerWithAuth(m, authenicate, badCheck) | ||||
| 	defer c2() | ||||
| 
 | ||||
| 	client2, err := e2.HTTPClient(repo2) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error creating http client: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	req, _ = http.NewRequest("GET", e.Endpoint+"/hello", nil) | ||||
| 	resp, err = client2.Do(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error sending get request: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != http.StatusUnauthorized { | ||||
| 		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func basicAuth(username, password string) string { | ||||
| 	auth := username + ":" + password | ||||
| 	return base64.StdEncoding.EncodeToString([]byte(auth)) | ||||
| } | ||||
| 
 | ||||
| func TestEndpointAuthorizeTokenBasic(t *testing.T) { | ||||
| 	service := "localhost.localdomain" | ||||
| 	repo := "some/fun/registry" | ||||
| 	scope := fmt.Sprintf("repository:%s:pull,push", repo) | ||||
| 	username := "tokenuser" | ||||
| 	password := "superSecretPa$$word" | ||||
| 
 | ||||
| 	tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ | ||||
| 		{ | ||||
| 			Request: testutil.Request{ | ||||
| 				Method: "GET", | ||||
| 				Route:  fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service), | ||||
| 			}, | ||||
| 			Response: testutil.Response{ | ||||
| 				StatusCode: http.StatusOK, | ||||
| 				Body:       []byte(`{"token":"statictoken"}`), | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 
 | ||||
| 	authenicate1 := fmt.Sprintf("Basic realm=localhost") | ||||
| 	basicCheck := func(a string) bool { | ||||
| 		return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) | ||||
| 	} | ||||
| 	te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck) | ||||
| 	defer tc() | ||||
| 
 | ||||
| 	m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ | ||||
| 		{ | ||||
| 			Request: testutil.Request{ | ||||
| 				Method: "GET", | ||||
| 				Route:  "/hello", | ||||
| 			}, | ||||
| 			Response: testutil.Response{ | ||||
| 				StatusCode: http.StatusAccepted, | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 
 | ||||
| 	authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service) | ||||
| 	bearerCheck := func(a string) bool { | ||||
| 		return a == "Bearer statictoken" | ||||
| 	} | ||||
| 	e, c := testServerWithAuth(m, authenicate2, bearerCheck) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	e.Credentials = &testCredentialStore{ | ||||
| 		username: username, | ||||
| 		password: password, | ||||
| 	} | ||||
| 
 | ||||
| 	client, err := e.HTTPClient(repo) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error creating http client: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error sending get request: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != http.StatusAccepted { | ||||
| 		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEndpointAuthorizeBasic(t *testing.T) { | ||||
| 	m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ | ||||
| 		{ | ||||
| 			Request: testutil.Request{ | ||||
| 				Method: "GET", | ||||
| 				Route:  "/hello", | ||||
| 			}, | ||||
| 			Response: testutil.Response{ | ||||
| 				StatusCode: http.StatusAccepted, | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 
 | ||||
| 	username := "user1" | ||||
| 	password := "funSecretPa$$word" | ||||
| 	authenicate := fmt.Sprintf("Basic realm=localhost") | ||||
| 	validCheck := func(a string) bool { | ||||
| 		return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) | ||||
| 	} | ||||
| 	e, c := testServerWithAuth(m, authenicate, validCheck) | ||||
| 	defer c() | ||||
| 	e.Credentials = &testCredentialStore{ | ||||
| 		username: username, | ||||
| 		password: password, | ||||
| 	} | ||||
| 
 | ||||
| 	client, err := e.HTTPClient("test/repo/basic") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error creating http client: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error sending get request: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != http.StatusAccepted { | ||||
| 		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) | ||||
| 	} | ||||
| } | ||||
|  | @ -25,8 +25,8 @@ import ( | |||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
| 
 | ||||
| // NewRepositoryClient creates a new Repository for the given repository name and endpoint
 | ||||
| func NewRepositoryClient(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) { | ||||
| // NewRepository creates a new Repository for the given repository name and endpoint
 | ||||
| func NewRepository(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) { | ||||
| 	if err := v2.ValidateRespositoryName(name); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  |  | |||
|  | @ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e) | ||||
| 	r, err := NewRepository(context.Background(), "test.example.com/repo1", e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | @ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e) | ||||
| 	r, err := NewRepository(context.Background(), "test.example.com/repo1", e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | @ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), repo, e) | ||||
| 	r, err := NewRepository(context.Background(), repo, e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | @ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), repo, e) | ||||
| 	r, err := NewRepository(context.Background(), repo, e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | @ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), repo, e) | ||||
| 	r, err := NewRepository(context.Background(), repo, e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | @ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), repo, e) | ||||
| 	r, err := NewRepository(context.Background(), repo, e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | @ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), repo, e) | ||||
| 	r, err := NewRepository(context.Background(), repo, e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | @ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) { | |||
| 	e, c := testServer(m) | ||||
| 	defer c() | ||||
| 
 | ||||
| 	r, err := NewRepositoryClient(context.Background(), repo, e) | ||||
| 	r, err := NewRepository(context.Background(), repo, e) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | @ -40,16 +41,18 @@ type Request struct { | |||
| func (r Request) String() string { | ||||
| 	queryString := "" | ||||
| 	if len(r.QueryParams) > 0 { | ||||
| 		queryString = "?" | ||||
| 		keys := make([]string, 0, len(r.QueryParams)) | ||||
| 		queryParts := make([]string, 0, len(r.QueryParams)) | ||||
| 		for k := range r.QueryParams { | ||||
| 			keys = append(keys, k) | ||||
| 		} | ||||
| 		sort.Strings(keys) | ||||
| 		for _, k := range keys { | ||||
| 			queryString += strings.Join(r.QueryParams[k], "&") + "&" | ||||
| 			for _, val := range r.QueryParams[k] { | ||||
| 				queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val))) | ||||
| 			} | ||||
| 		} | ||||
| 		queryString = queryString[:len(queryString)-1] | ||||
| 		queryString = "?" + strings.Join(queryParts, "&") | ||||
| 	} | ||||
| 	return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body) | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue