commit
						b1b100cf01
					
				|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Challenge carries information from a WWW-Authenticate response header.
 | // Challenge carries information from a WWW-Authenticate response header.
 | ||||||
|  | @ -43,19 +44,26 @@ type ChallengeManager interface { | ||||||
| // perform requests on the endpoints or cache the responses
 | // perform requests on the endpoints or cache the responses
 | ||||||
| // to a backend.
 | // to a backend.
 | ||||||
| func NewSimpleChallengeManager() ChallengeManager { | func NewSimpleChallengeManager() ChallengeManager { | ||||||
| 	return simpleChallengeManager{} | 	return &simpleChallengeManager{ | ||||||
|  | 		Challanges: make(map[string][]Challenge), | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type simpleChallengeManager map[string][]Challenge | type simpleChallengeManager struct { | ||||||
|  | 	sync.RWMutex | ||||||
|  | 	Challanges map[string][]Challenge | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| func (m simpleChallengeManager) GetChallenges(endpoint url.URL) ([]Challenge, error) { | func (m *simpleChallengeManager) GetChallenges(endpoint url.URL) ([]Challenge, error) { | ||||||
| 	endpoint.Host = strings.ToLower(endpoint.Host) | 	endpoint.Host = strings.ToLower(endpoint.Host) | ||||||
| 
 | 
 | ||||||
| 	challenges := m[endpoint.String()] | 	m.RLock() | ||||||
|  | 	defer m.RUnlock() | ||||||
|  | 	challenges := m.Challanges[endpoint.String()] | ||||||
| 	return challenges, nil | 	return challenges, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m simpleChallengeManager) AddResponse(resp *http.Response) error { | func (m *simpleChallengeManager) AddResponse(resp *http.Response) error { | ||||||
| 	challenges := ResponseChallenges(resp) | 	challenges := ResponseChallenges(resp) | ||||||
| 	if resp.Request == nil { | 	if resp.Request == nil { | ||||||
| 		return fmt.Errorf("missing request reference") | 		return fmt.Errorf("missing request reference") | ||||||
|  | @ -65,7 +73,9 @@ func (m simpleChallengeManager) AddResponse(resp *http.Response) error { | ||||||
| 		Host:   strings.ToLower(resp.Request.URL.Host), | 		Host:   strings.ToLower(resp.Request.URL.Host), | ||||||
| 		Scheme: resp.Request.URL.Scheme, | 		Scheme: resp.Request.URL.Scheme, | ||||||
| 	} | 	} | ||||||
| 	m[urlCopy.String()] = challenges | 	m.Lock() | ||||||
|  | 	defer m.Unlock() | ||||||
|  | 	m.Challanges[urlCopy.String()] = challenges | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -43,6 +44,7 @@ func TestAuthChallengeParse(t *testing.T) { | ||||||
| func TestAuthChallengeNormalization(t *testing.T) { | func TestAuthChallengeNormalization(t *testing.T) { | ||||||
| 	testAuthChallengeNormalization(t, "reg.EXAMPLE.com") | 	testAuthChallengeNormalization(t, "reg.EXAMPLE.com") | ||||||
| 	testAuthChallengeNormalization(t, "bɿɒʜɔiɿ-ɿɘƚƨim-ƚol-ɒ-ƨʞnɒʜƚ.com") | 	testAuthChallengeNormalization(t, "bɿɒʜɔiɿ-ɿɘƚƨim-ƚol-ɒ-ƨʞnɒʜƚ.com") | ||||||
|  | 	testAuthChallengeConcurrent(t, "reg.EXAMPLE.com") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func testAuthChallengeNormalization(t *testing.T, host string) { | func testAuthChallengeNormalization(t *testing.T, host string) { | ||||||
|  | @ -79,3 +81,45 @@ func testAuthChallengeNormalization(t *testing.T, host string) { | ||||||
| 		t.Fatal("Expected challenge for lower-cased-host URL") | 		t.Fatal("Expected challenge for lower-cased-host URL") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func testAuthChallengeConcurrent(t *testing.T, host string) { | ||||||
|  | 
 | ||||||
|  | 	scm := NewSimpleChallengeManager() | ||||||
|  | 
 | ||||||
|  | 	url, err := url.Parse(fmt.Sprintf("http://%s/v2/", host)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp := &http.Response{ | ||||||
|  | 		Request: &http.Request{ | ||||||
|  | 			URL: url, | ||||||
|  | 		}, | ||||||
|  | 		Header:     make(http.Header), | ||||||
|  | 		StatusCode: http.StatusUnauthorized, | ||||||
|  | 	} | ||||||
|  | 	resp.Header.Add("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"https://%s/token\",service=\"registry.example.com\"", host)) | ||||||
|  | 	var s sync.WaitGroup | ||||||
|  | 	s.Add(2) | ||||||
|  | 	go func() { | ||||||
|  | 		defer s.Done() | ||||||
|  | 		for i := 0; i < 200; i++ { | ||||||
|  | 			err = scm.AddResponse(resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Error(err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 	go func() { | ||||||
|  | 		defer s.Done() | ||||||
|  | 		lowered := *url | ||||||
|  | 		lowered.Host = strings.ToLower(lowered.Host) | ||||||
|  | 		for k := 0; k < 200; k++ { | ||||||
|  | 			_, err := scm.GetChallenges(lowered) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Error(err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 	s.Wait() | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue