Implementation of a basic authentication scheme using standard .htpasswd files
Signed-off-by: BadZen <dave.trombley@gmail.com> Signed-off-by: Dave Trombley <dave.trombley@gmail.com>master
							parent
							
								
									7363323321
								
							
						
					
					
						commit
						8a204f59e7
					
				|  | @ -0,0 +1,112 @@ | |||
| // Package basic provides a simple authentication scheme that checks for the
 | ||||
| // user credential hash in an htpasswd formatted file in a configuration-determined
 | ||||
| // location.
 | ||||
| //
 | ||||
| // The use of SHA hashes (htpasswd -s) is enforced since MD5 is insecure and simple
 | ||||
| // system crypt() may be as well.
 | ||||
| //
 | ||||
| // This authentication method MUST be used under TLS, as simple token-replay attack is possible.
 | ||||
| 
 | ||||
| package basic | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	ctxu "github.com/docker/distribution/context" | ||||
| 	"github.com/docker/distribution/registry/auth" | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
| 
 | ||||
| type accessController struct { | ||||
| 	realm    string | ||||
| 	htpasswd *HTPasswd | ||||
| } | ||||
| 
 | ||||
| type challenge struct { | ||||
| 	realm string | ||||
| 	err   error | ||||
| } | ||||
| 
 | ||||
| var _ auth.AccessController = &accessController{} | ||||
| var ( | ||||
| 	ErrPasswordRequired  = errors.New("authorization credential required") | ||||
| 	ErrInvalidCredential = errors.New("invalid authorization credential") | ||||
| ) | ||||
| 
 | ||||
| func newAccessController(options map[string]interface{}) (auth.AccessController, error) { | ||||
| 	realm, present := options["realm"] | ||||
| 	if _, ok := realm.(string); !present || !ok { | ||||
| 		return nil, fmt.Errorf(`"realm" must be set for basic access controller`) | ||||
| 	} | ||||
| 
 | ||||
| 	path, present := options["path"] | ||||
| 	if _, ok := path.(string); !present || !ok { | ||||
| 		return nil, fmt.Errorf(`"path" must be set for basic access controller`) | ||||
| 	} | ||||
| 
 | ||||
| 	return &accessController{realm: realm.(string), htpasswd: NewHTPasswd(path.(string))}, nil | ||||
| } | ||||
| 
 | ||||
| func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { | ||||
| 	req, err := ctxu.GetRequest(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	authHeader := req.Header.Get("Authorization") | ||||
| 
 | ||||
| 	if authHeader == "" { | ||||
| 		challenge := challenge{ | ||||
| 			realm: ac.realm, | ||||
| 		} | ||||
| 		return nil, &challenge | ||||
| 	} | ||||
| 
 | ||||
| 	parts := strings.Split(req.Header.Get("Authorization"), " ") | ||||
| 
 | ||||
| 	challenge := challenge{ | ||||
| 		realm: ac.realm, | ||||
| 	} | ||||
| 
 | ||||
| 	if len(parts) != 2 || strings.ToLower(parts[0]) != "basic" { | ||||
| 		challenge.err = ErrPasswordRequired | ||||
| 		return nil, &challenge | ||||
| 	} | ||||
| 
 | ||||
| 	text, err := base64.StdEncoding.DecodeString(parts[1]) | ||||
| 	if err != nil { | ||||
| 		challenge.err = ErrInvalidCredential | ||||
| 		return nil, &challenge | ||||
| 	} | ||||
| 
 | ||||
| 	credential := strings.Split(string(text), ":") | ||||
| 	if len(credential) != 2 { | ||||
| 		challenge.err = ErrInvalidCredential | ||||
| 		return nil, &challenge | ||||
| 	} | ||||
| 
 | ||||
| 	if res, _ := ac.htpasswd.AuthenticateUser(credential[0], credential[1]); !res { | ||||
| 		challenge.err = ErrInvalidCredential | ||||
| 		return nil, &challenge | ||||
| 	} | ||||
| 
 | ||||
| 	return auth.WithUser(ctx, auth.UserInfo{Name: credential[0]}), nil | ||||
| } | ||||
| 
 | ||||
| func (ch *challenge) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 	header := fmt.Sprintf("Realm realm=%q", ch.realm) | ||||
| 	w.Header().Set("WWW-Authenticate", header) | ||||
| 	w.WriteHeader(http.StatusUnauthorized) | ||||
| } | ||||
| 
 | ||||
| func (ch *challenge) Error() string { | ||||
| 	return fmt.Sprintf("basic authentication challenge: %#v", ch) | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	auth.Register("basic", auth.InitFunc(newAccessController)) | ||||
| } | ||||
|  | @ -0,0 +1,100 @@ | |||
| package basic | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/docker/distribution/registry/auth" | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
| 
 | ||||
| func TestBasicAccessController(t *testing.T) { | ||||
| 
 | ||||
| 	testRealm := "The-Shire" | ||||
| 	testUser := "bilbo" | ||||
| 	testHtpasswdContent := "bilbo:{SHA}5siv5c0SHx681xU6GiSx9ZQryqs=" | ||||
| 
 | ||||
| 	tempFile, err := ioutil.TempFile("", "htpasswd-test") | ||||
| 	if err != nil { | ||||
| 		t.Fatal("could not create temporary htpasswd file") | ||||
| 	} | ||||
| 	if _, err = tempFile.WriteString(testHtpasswdContent); err != nil { | ||||
| 		t.Fatal("could not write temporary htpasswd file") | ||||
| 	} | ||||
| 
 | ||||
| 	options := map[string]interface{}{ | ||||
| 		"realm": testRealm, | ||||
| 		"path":  tempFile.Name(), | ||||
| 	} | ||||
| 
 | ||||
| 	accessController, err := newAccessController(options) | ||||
| 	if err != nil { | ||||
| 		t.Fatal("error creating access controller") | ||||
| 	} | ||||
| 
 | ||||
| 	tempFile.Close() | ||||
| 
 | ||||
| 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		ctx := context.WithValue(nil, "http.request", r) | ||||
| 		authCtx, err := accessController.Authorized(ctx) | ||||
| 		if err != nil { | ||||
| 			switch err := err.(type) { | ||||
| 			case auth.Challenge: | ||||
| 				err.ServeHTTP(w, r) | ||||
| 				return | ||||
| 			default: | ||||
| 				t.Fatalf("unexpected error authorizing request: %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		userInfo, ok := authCtx.Value("auth.user").(auth.UserInfo) | ||||
| 		if !ok { | ||||
| 			t.Fatal("basic accessController did not set auth.user context") | ||||
| 		} | ||||
| 
 | ||||
| 		if userInfo.Name != testUser { | ||||
| 			t.Fatalf("expected user name %q, got %q", testUser, userInfo.Name) | ||||
| 		} | ||||
| 
 | ||||
| 		w.WriteHeader(http.StatusNoContent) | ||||
| 	})) | ||||
| 
 | ||||
| 	client := &http.Client{ | ||||
| 		CheckRedirect: nil, | ||||
| 	} | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("GET", server.URL, nil) | ||||
| 	resp, err := client.Do(req) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("unexpected error during GET: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 
 | ||||
| 	// Request should not be authorized
 | ||||
| 	if resp.StatusCode != http.StatusUnauthorized { | ||||
| 		t.Fatalf("unexpected non-fail response status: %v != %v", resp.StatusCode, http.StatusUnauthorized) | ||||
| 	} | ||||
| 
 | ||||
| 	req, _ = http.NewRequest("GET", server.URL, nil) | ||||
| 
 | ||||
| 	sekrit := "bilbo:baggins" | ||||
| 	credential := "Basic " + base64.StdEncoding.EncodeToString([]byte(sekrit)) | ||||
| 
 | ||||
| 	req.Header.Set("Authorization", credential) | ||||
| 	resp, err = client.Do(req) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("unexpected error during GET: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 
 | ||||
| 	// Request should be authorized
 | ||||
| 	if resp.StatusCode != http.StatusNoContent { | ||||
| 		t.Fatalf("unexpected non-success response status: %v != %v", resp.StatusCode, http.StatusNoContent) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
|  | @ -0,0 +1,49 @@ | |||
| package basic | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/csv" | ||||
| 	"errors" | ||||
| 	"os" | ||||
| ) | ||||
| 
 | ||||
| var ErrSHARequired = errors.New("htpasswd file must use SHA (htpasswd -s)") | ||||
| 
 | ||||
| type HTPasswd struct { | ||||
| 	path   string | ||||
| 	reader *csv.Reader | ||||
| } | ||||
| 
 | ||||
| func NewHTPasswd(htpath string) *HTPasswd { | ||||
| 	return &HTPasswd{path: htpath} | ||||
| } | ||||
| 
 | ||||
| func (htpasswd *HTPasswd) AuthenticateUser(user string, pwd string) (bool, error) { | ||||
| 
 | ||||
| 	// Hash the credential.
 | ||||
| 	sha := sha1.New() | ||||
| 	sha.Write([]byte(pwd)) | ||||
| 	hash := base64.StdEncoding.EncodeToString(sha.Sum(nil)) | ||||
| 
 | ||||
| 	// Open the file.
 | ||||
| 	in, err := os.Open(htpasswd.path) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Parse the contents of the standard .htpasswd until we hit the end or find a match.
 | ||||
| 	reader := csv.NewReader(in) | ||||
| 	reader.Comma = ':' | ||||
| 	reader.Comment = '#' | ||||
| 	reader.TrimLeadingSpace = true | ||||
| 	for entry, readerr := reader.Read(); entry != nil || readerr != nil; entry, readerr = reader.Read() { | ||||
| 		if entry[0] == user { | ||||
| 			if len(entry[1]) < 6 || entry[1][0:5] != "{SHA}" { | ||||
| 				return false, ErrSHARequired | ||||
| 			} | ||||
| 			return entry[1][5:] == hash, nil | ||||
| 		} | ||||
| 	} | ||||
| 	return false, nil | ||||
| } | ||||
		Loading…
	
		Reference in New Issue