Simplify configuration and transport
Repository creation now just takes in an http.RoundTripper. Authenticated requests or requests which require additional headers should use the NewTransport function along with a request modifier (such an an authentication handler). Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)master
							parent
							
								
									8b0ea19d39
								
							
						
					
					
						commit
						89c396e0f5
					
				| 
						 | 
				
			
			@ -124,13 +124,8 @@ func TestUploadReadFrom(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	repoConfig := &RepositoryConfig{}
 | 
			
		||||
	client, err := repoConfig.HTTPClient()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Error creating client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	layerUpload := &httpLayerUpload{
 | 
			
		||||
		client: client,
 | 
			
		||||
		client: &http.Client{},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Valid case
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,7 +20,7 @@ import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
// NewRepository creates a new Repository for the given repository name and endpoint
 | 
			
		||||
func NewRepository(ctx context.Context, name, endpoint string, repoConfig *RepositoryConfig) (distribution.Repository, error) {
 | 
			
		||||
func NewRepository(ctx context.Context, name, endpoint string, transport http.RoundTripper) (distribution.Repository, error) {
 | 
			
		||||
	if err := v2.ValidateRespositoryName(name); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -30,9 +30,10 @@ func NewRepository(ctx context.Context, name, endpoint string, repoConfig *Repos
 | 
			
		|||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := repoConfig.HTTPClient()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	client := &http.Client{
 | 
			
		||||
		Transport: transport,
 | 
			
		||||
		Timeout:   1 * time.Minute,
 | 
			
		||||
		// TODO(dmcgowan): create cookie jar
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &repository{
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), "test.example.com/repo1", e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), "test.example.com/repo1", e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -636,7 +636,7 @@ func TestManifestTags(t *testing.T) {
 | 
			
		|||
	e, c := testServer(m)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{})
 | 
			
		||||
	r, err := NewRepository(context.Background(), repo, e, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,12 +11,6 @@ import (
 | 
			
		|||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Authorizer is used to apply Authorization to an HTTP request
 | 
			
		||||
type Authorizer interface {
 | 
			
		||||
	// Authorizer updates an HTTP request with the needed authorization
 | 
			
		||||
	Authorize(req *http.Request) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AuthenticationHandler is an interface for authorizing a request from
 | 
			
		||||
// params from a "WWW-Authenicate" header for a single scheme.
 | 
			
		||||
type AuthenticationHandler interface {
 | 
			
		||||
| 
						 | 
				
			
			@ -31,54 +25,11 @@ type CredentialStore interface {
 | 
			
		|||
	Basic(*url.URL) (string, string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RepositoryConfig holds the base configuration needed to communicate
 | 
			
		||||
// with a registry including a method of authorization and HTTP headers.
 | 
			
		||||
type RepositoryConfig struct {
 | 
			
		||||
	Header     http.Header
 | 
			
		||||
	AuthSource Authorizer
 | 
			
		||||
 | 
			
		||||
	BaseTransport http.RoundTripper
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HTTPClient returns a new HTTP client configured for this configuration
 | 
			
		||||
func (rc *RepositoryConfig) HTTPClient() (*http.Client, error) {
 | 
			
		||||
	transport := &Transport{
 | 
			
		||||
		ExtraHeader: rc.Header,
 | 
			
		||||
		AuthSource:  rc.AuthSource,
 | 
			
		||||
		Base:        rc.BaseTransport,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := &http.Client{
 | 
			
		||||
		Transport: transport,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return client, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewTokenAuthorizer returns an authorizer which is capable of getting a token
 | 
			
		||||
// from a token server. The expected authorization method will be discovered
 | 
			
		||||
// by the authorizer, getting the token server endpoint from the URL being
 | 
			
		||||
// requested. Basic authentication may either be done to the token source or
 | 
			
		||||
// directly with the requested endpoint depending on the endpoint's
 | 
			
		||||
// WWW-Authenticate header.
 | 
			
		||||
func NewTokenAuthorizer(creds CredentialStore, transport http.RoundTripper, header http.Header, scope TokenScope) Authorizer {
 | 
			
		||||
	return &tokenAuthorizer{
 | 
			
		||||
		header:     header,
 | 
			
		||||
		challenges: map[string]map[string]authorizationChallenge{},
 | 
			
		||||
		handlers: []AuthenticationHandler{
 | 
			
		||||
			NewTokenHandler(transport, creds, scope, header),
 | 
			
		||||
			NewBasicHandler(creds),
 | 
			
		||||
		},
 | 
			
		||||
		transport: transport,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewAuthorizer creates an authorizer which can handle multiple authentication
 | 
			
		||||
// schemes. The handlers are tried in order, the higher priority authentication
 | 
			
		||||
// methods should be first.
 | 
			
		||||
func NewAuthorizer(transport http.RoundTripper, header http.Header, handlers ...AuthenticationHandler) Authorizer {
 | 
			
		||||
func NewAuthorizer(transport http.RoundTripper, handlers ...AuthenticationHandler) RequestModifier {
 | 
			
		||||
	return &tokenAuthorizer{
 | 
			
		||||
		header:     header,
 | 
			
		||||
		challenges: map[string]map[string]authorizationChallenge{},
 | 
			
		||||
		handlers:   handlers,
 | 
			
		||||
		transport:  transport,
 | 
			
		||||
| 
						 | 
				
			
			@ -86,7 +37,6 @@ func NewAuthorizer(transport http.RoundTripper, header http.Header, handlers ...
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
type tokenAuthorizer struct {
 | 
			
		||||
	header     http.Header
 | 
			
		||||
	challenges map[string]map[string]authorizationChallenge
 | 
			
		||||
	handlers   []AuthenticationHandler
 | 
			
		||||
	transport  http.RoundTripper
 | 
			
		||||
| 
						 | 
				
			
			@ -99,10 +49,7 @@ func (ta *tokenAuthorizer) ping(endpoint string) (map[string]authorizationChalle
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	client := &http.Client{
 | 
			
		||||
		Transport: &Transport{
 | 
			
		||||
			ExtraHeader: ta.header,
 | 
			
		||||
			Base:        ta.transport,
 | 
			
		||||
		},
 | 
			
		||||
		Transport: ta.transport,
 | 
			
		||||
		// Ping should fail fast
 | 
			
		||||
		Timeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -140,7 +87,7 @@ HeaderLoop:
 | 
			
		|||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ta *tokenAuthorizer) Authorize(req *http.Request) error {
 | 
			
		||||
func (ta *tokenAuthorizer) ModifyRequest(req *http.Request) error {
 | 
			
		||||
	v2Root := strings.Index(req.URL.Path, "/v2/")
 | 
			
		||||
	if v2Root == -1 {
 | 
			
		||||
		return nil
 | 
			
		||||
| 
						 | 
				
			
			@ -195,54 +142,52 @@ type TokenScope struct {
 | 
			
		|||
	Actions  []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewTokenHandler creates a new AuthenicationHandler which supports
 | 
			
		||||
// fetching tokens from a remote token server.
 | 
			
		||||
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope TokenScope, header http.Header) AuthenticationHandler {
 | 
			
		||||
	return &tokenHandler{
 | 
			
		||||
		header: header,
 | 
			
		||||
		creds:  creds,
 | 
			
		||||
		scope:  scope,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts TokenScope) String() string {
 | 
			
		||||
	return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ","))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *tokenHandler) client() *http.Client {
 | 
			
		||||
	return &http.Client{
 | 
			
		||||
		Transport: &Transport{
 | 
			
		||||
			ExtraHeader: ts.header,
 | 
			
		||||
			Base:        ts.transport,
 | 
			
		||||
		},
 | 
			
		||||
// NewTokenHandler creates a new AuthenicationHandler which supports
 | 
			
		||||
// fetching tokens from a remote token server.
 | 
			
		||||
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope TokenScope) AuthenticationHandler {
 | 
			
		||||
	return &tokenHandler{
 | 
			
		||||
		transport: transport,
 | 
			
		||||
		creds:     creds,
 | 
			
		||||
		scope:     scope,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *tokenHandler) Scheme() string {
 | 
			
		||||
func (th *tokenHandler) client() *http.Client {
 | 
			
		||||
	return &http.Client{
 | 
			
		||||
		Transport: th.transport,
 | 
			
		||||
		Timeout:   15 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (th *tokenHandler) Scheme() string {
 | 
			
		||||
	return "bearer"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
 | 
			
		||||
	if err := ts.refreshToken(params); err != nil {
 | 
			
		||||
func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
 | 
			
		||||
	if err := th.refreshToken(params); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.tokenCache))
 | 
			
		||||
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.tokenCache))
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *tokenHandler) refreshToken(params map[string]string) error {
 | 
			
		||||
	ts.tokenLock.Lock()
 | 
			
		||||
	defer ts.tokenLock.Unlock()
 | 
			
		||||
func (th *tokenHandler) refreshToken(params map[string]string) error {
 | 
			
		||||
	th.tokenLock.Lock()
 | 
			
		||||
	defer th.tokenLock.Unlock()
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	if now.After(ts.tokenExpiration) {
 | 
			
		||||
		token, err := ts.fetchToken(params)
 | 
			
		||||
	if now.After(th.tokenExpiration) {
 | 
			
		||||
		token, err := th.fetchToken(params)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		ts.tokenCache = token
 | 
			
		||||
		ts.tokenExpiration = now.Add(time.Minute)
 | 
			
		||||
		th.tokenCache = token
 | 
			
		||||
		th.tokenExpiration = now.Add(time.Minute)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
| 
						 | 
				
			
			@ -252,7 +197,7 @@ type tokenResponse struct {
 | 
			
		|||
	Token string `json:"token"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err error) {
 | 
			
		||||
func (th *tokenHandler) fetchToken(params map[string]string) (token string, err error) {
 | 
			
		||||
	//log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username)
 | 
			
		||||
	realm, ok := params["realm"]
 | 
			
		||||
	if !ok {
 | 
			
		||||
| 
						 | 
				
			
			@ -273,7 +218,7 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err
 | 
			
		|||
 | 
			
		||||
	reqParams := req.URL.Query()
 | 
			
		||||
	service := params["service"]
 | 
			
		||||
	scope := ts.scope.String()
 | 
			
		||||
	scope := th.scope.String()
 | 
			
		||||
 | 
			
		||||
	if service != "" {
 | 
			
		||||
		reqParams.Add("service", service)
 | 
			
		||||
| 
						 | 
				
			
			@ -283,8 +228,8 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err
 | 
			
		|||
		reqParams.Add("scope", scopeField)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ts.creds != nil {
 | 
			
		||||
		username, password := ts.creds.Basic(realmURL)
 | 
			
		||||
	if th.creds != nil {
 | 
			
		||||
		username, password := th.creds.Basic(realmURL)
 | 
			
		||||
		if username != "" && password != "" {
 | 
			
		||||
			reqParams.Add("account", username)
 | 
			
		||||
			req.SetBasicAuth(username, password)
 | 
			
		||||
| 
						 | 
				
			
			@ -293,7 +238,7 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err
 | 
			
		|||
 | 
			
		||||
	req.URL.RawQuery = reqParams.Encode()
 | 
			
		||||
 | 
			
		||||
	resp, err := ts.client().Do(req)
 | 
			
		||||
	resp, err := th.client().Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -116,14 +116,8 @@ func TestEndpointAuthorizeToken(t *testing.T) {
 | 
			
		|||
	e, c := testServerWithAuth(m, authenicate, validCheck)
 | 
			
		||||
	defer c()
 | 
			
		||||
 | 
			
		||||
	repo1Config := &RepositoryConfig{
 | 
			
		||||
		AuthSource: NewTokenAuthorizer(nil, nil, nil, tokenScope1),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := repo1Config.HTTPClient()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Error creating http client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope1)))
 | 
			
		||||
	client := &http.Client{Transport: transport1}
 | 
			
		||||
 | 
			
		||||
	req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
| 
						 | 
				
			
			@ -141,13 +135,8 @@ func TestEndpointAuthorizeToken(t *testing.T) {
 | 
			
		|||
	e2, c2 := testServerWithAuth(m, authenicate, badCheck)
 | 
			
		||||
	defer c2()
 | 
			
		||||
 | 
			
		||||
	repo2Config := &RepositoryConfig{
 | 
			
		||||
		AuthSource: NewTokenAuthorizer(nil, nil, nil, tokenScope2),
 | 
			
		||||
	}
 | 
			
		||||
	client2, err := repo2Config.HTTPClient()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Error creating http client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	transport2 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope2)))
 | 
			
		||||
	client2 := &http.Client{Transport: transport2}
 | 
			
		||||
 | 
			
		||||
	req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
 | 
			
		||||
	resp, err = client2.Do(req)
 | 
			
		||||
| 
						 | 
				
			
			@ -220,14 +209,9 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) {
 | 
			
		|||
		username: username,
 | 
			
		||||
		password: password,
 | 
			
		||||
	}
 | 
			
		||||
	repoConfig := &RepositoryConfig{
 | 
			
		||||
		AuthSource: NewTokenAuthorizer(creds, nil, nil, tokenScope),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := repoConfig.HTTPClient()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Error creating http client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, creds, tokenScope), NewBasicHandler(creds)))
 | 
			
		||||
	client := &http.Client{Transport: transport1}
 | 
			
		||||
 | 
			
		||||
	req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
| 
						 | 
				
			
			@ -265,14 +249,9 @@ func TestEndpointAuthorizeBasic(t *testing.T) {
 | 
			
		|||
		username: username,
 | 
			
		||||
		password: password,
 | 
			
		||||
	}
 | 
			
		||||
	repoConfig := &RepositoryConfig{
 | 
			
		||||
		AuthSource: NewTokenAuthorizer(creds, nil, nil, TokenScope{}),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := repoConfig.HTTPClient()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Error creating http client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	transport1 := NewTransport(nil, NewAuthorizer(nil, NewBasicHandler(creds)))
 | 
			
		||||
	client := &http.Client{Transport: transport1}
 | 
			
		||||
 | 
			
		||||
	req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,14 +6,36 @@ import (
 | 
			
		|||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Transport is an http.RoundTripper that makes registry HTTP requests,
 | 
			
		||||
// wrapping a base RoundTripper and adding an Authorization header
 | 
			
		||||
// from an Auth source
 | 
			
		||||
type Transport struct {
 | 
			
		||||
	AuthSource  Authorizer
 | 
			
		||||
	ExtraHeader http.Header
 | 
			
		||||
type RequestModifier interface {
 | 
			
		||||
	ModifyRequest(*http.Request) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	Base http.RoundTripper
 | 
			
		||||
type headerModifier http.Header
 | 
			
		||||
 | 
			
		||||
func NewHeaderRequestModifier(header http.Header) RequestModifier {
 | 
			
		||||
	return headerModifier(header)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h headerModifier) ModifyRequest(req *http.Request) error {
 | 
			
		||||
	for k, s := range http.Header(h) {
 | 
			
		||||
		req.Header[k] = append(req.Header[k], s...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
 | 
			
		||||
	return &transport{
 | 
			
		||||
		Modifiers: modifiers,
 | 
			
		||||
		Base:      base,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// transport is an http.RoundTripper that makes HTTP requests after
 | 
			
		||||
// copying and modifying the request
 | 
			
		||||
type transport struct {
 | 
			
		||||
	Modifiers []RequestModifier
 | 
			
		||||
	Base      http.RoundTripper
 | 
			
		||||
 | 
			
		||||
	mu     sync.Mutex                      // guards modReq
 | 
			
		||||
	modReq map[*http.Request]*http.Request // original -> modified
 | 
			
		||||
| 
						 | 
				
			
			@ -22,13 +44,14 @@ type Transport struct {
 | 
			
		|||
// RoundTrip authorizes and authenticates the request with an
 | 
			
		||||
// access token. If no token exists or token is expired,
 | 
			
		||||
// tries to refresh/fetch a new token.
 | 
			
		||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
			
		||||
	req2 := t.cloneRequest(req)
 | 
			
		||||
	if t.AuthSource != nil {
 | 
			
		||||
		if err := t.AuthSource.Authorize(req2); err != nil {
 | 
			
		||||
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
			
		||||
	req2 := cloneRequest(req)
 | 
			
		||||
	for _, modifier := range t.Modifiers {
 | 
			
		||||
		if err := modifier.ModifyRequest(req2); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.setModReq(req, req2)
 | 
			
		||||
	res, err := t.base().RoundTrip(req2)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -43,7 +66,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// CancelRequest cancels an in-flight request by closing its connection.
 | 
			
		||||
func (t *Transport) CancelRequest(req *http.Request) {
 | 
			
		||||
func (t *transport) CancelRequest(req *http.Request) {
 | 
			
		||||
	type canceler interface {
 | 
			
		||||
		CancelRequest(*http.Request)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -56,14 +79,14 @@ func (t *Transport) CancelRequest(req *http.Request) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Transport) base() http.RoundTripper {
 | 
			
		||||
func (t *transport) base() http.RoundTripper {
 | 
			
		||||
	if t.Base != nil {
 | 
			
		||||
		return t.Base
 | 
			
		||||
	}
 | 
			
		||||
	return http.DefaultTransport
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Transport) setModReq(orig, mod *http.Request) {
 | 
			
		||||
func (t *transport) setModReq(orig, mod *http.Request) {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	defer t.mu.Unlock()
 | 
			
		||||
	if t.modReq == nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -78,7 +101,7 @@ func (t *Transport) setModReq(orig, mod *http.Request) {
 | 
			
		|||
 | 
			
		||||
// cloneRequest returns a clone of the provided *http.Request.
 | 
			
		||||
// The clone is a shallow copy of the struct and its Header map.
 | 
			
		||||
func (t *Transport) cloneRequest(r *http.Request) *http.Request {
 | 
			
		||||
func cloneRequest(r *http.Request) *http.Request {
 | 
			
		||||
	// shallow copy of the struct
 | 
			
		||||
	r2 := new(http.Request)
 | 
			
		||||
	*r2 = *r
 | 
			
		||||
| 
						 | 
				
			
			@ -87,9 +110,7 @@ func (t *Transport) cloneRequest(r *http.Request) *http.Request {
 | 
			
		|||
	for k, s := range r.Header {
 | 
			
		||||
		r2.Header[k] = append([]string(nil), s...)
 | 
			
		||||
	}
 | 
			
		||||
	for k, s := range t.ExtraHeader {
 | 
			
		||||
		r2.Header[k] = append(r2.Header[k], s...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return r2
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue