387 lines
		
	
	
		
			9.4 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			387 lines
		
	
	
		
			9.4 KiB
		
	
	
	
		
			Go
		
	
	
| package token
 | |
| 
 | |
| import (
 | |
| 	"crypto"
 | |
| 	"crypto/rand"
 | |
| 	"crypto/x509"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"encoding/pem"
 | |
| 	"fmt"
 | |
| 	"io/ioutil"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/docker/distribution/auth"
 | |
| 	"github.com/docker/libtrust"
 | |
| 	"golang.org/x/net/context"
 | |
| )
 | |
| 
 | |
| func makeRootKeys(numKeys int) ([]libtrust.PrivateKey, error) {
 | |
| 	keys := make([]libtrust.PrivateKey, 0, numKeys)
 | |
| 
 | |
| 	for i := 0; i < numKeys; i++ {
 | |
| 		key, err := libtrust.GenerateECP256PrivateKey()
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		keys = append(keys, key)
 | |
| 	}
 | |
| 
 | |
| 	return keys, nil
 | |
| }
 | |
| 
 | |
// makeSigningKeyWithChain returns a key to sign test tokens with.
//
// When depth is 0 the root key itself is returned unchanged. Otherwise depth
// fresh EC P-256 keys are generated: the first is certified by rootKey and
// each later one by its predecessor. The last key generated is returned with
// the resulting certificates attached as its "x5c" extended field
// (base64-encoded DER). The field is ordered so that index 0 holds the
// returned key's own certificate and the final entry holds the certificate
// issued by rootKey.
func makeSigningKeyWithChain(rootKey libtrust.PrivateKey, depth int) (libtrust.PrivateKey, error) {
	if depth == 0 {
		// No intermediates requested; sign directly with the root.
		return rootKey, nil
	}

	chain := make([]string, depth)
	issuer := rootKey
	var leaf libtrust.PrivateKey

	for slot := depth - 1; slot >= 0; slot-- {
		child, err := libtrust.GenerateECP256PrivateKey()
		if err != nil {
			return nil, err
		}

		caCert, err := libtrust.GenerateCACert(issuer, child)
		if err != nil {
			return nil, err
		}

		// Fill from the back so the most recently generated key's
		// certificate ends up first in the chain.
		chain[slot] = base64.StdEncoding.EncodeToString(caCert.Raw)
		issuer = child
		leaf = child
	}

	leaf.AddExtendedField("x5c", chain)

	return leaf, nil
}
 | |
| 
 | |
// makeRootCerts creates one self-signed CA certificate per key in rootKeys
// (each key is both signer and subject). The certificates are returned in the
// same order as the keys. The first generation error aborts and is returned
// with a nil slice.
func makeRootCerts(rootKeys []libtrust.PrivateKey) ([]*x509.Certificate, error) {
	certs := make([]*x509.Certificate, len(rootKeys))

	for idx := range rootKeys {
		selfSigned, err := libtrust.GenerateCACert(rootKeys[idx], rootKeys[idx])
		if err != nil {
			return nil, err
		}
		certs[idx] = selfSigned
	}

	return certs, nil
}
 | |
| 
 | |
| func makeTrustedKeyMap(rootKeys []libtrust.PrivateKey) map[string]libtrust.PublicKey {
 | |
| 	trustedKeys := make(map[string]libtrust.PublicKey, len(rootKeys))
 | |
| 
 | |
| 	for _, key := range rootKeys {
 | |
| 		trustedKeys[key.KeyID()] = key.PublicKey()
 | |
| 	}
 | |
| 
 | |
| 	return trustedKeys
 | |
| }
 | |
| 
 | |
// makeTestToken builds a compact ES256-signed JWT and parses it back with
// NewToken.
//
// The signing key comes from makeSigningKeyWithChain(rootKey, depth). With
// depth > 0 it is a freshly generated key carrying an x5c chain that leads
// back to rootKey. The signing key's public JWK is embedded in the JOSE
// header.
//
// The claim set contains:
//   - the given issuer, audience and access list;
//   - a fixed subject of "foo";
//   - a random JWT ID;
//   - a validity window from now until five minutes from now.
//
// Any failure while building or signing the token is returned as a
// descriptive error.
func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey libtrust.PrivateKey, depth int) (*Token, error) {
	signingKey, err := makeSigningKeyWithChain(rootKey, depth)
	if err != nil {
		return nil, fmt.Errorf("unable to make signing key with chain: %s", err)
	}

	rawJWK, err := signingKey.PublicKey().MarshalJSON()
	if err != nil {
		return nil, fmt.Errorf("unable to marshal signing key to JSON: %s", err)
	}

	joseHeader := &Header{
		Type:       "JWT",
		SigningAlg: "ES256",
		RawJWK:     json.RawMessage(rawJWK),
	}

	now := time.Now()

	// 15 bytes is a multiple of 3, so the base64 JWT ID has no padding.
	randomBytes := make([]byte, 15)
	if _, err = rand.Read(randomBytes); err != nil {
		return nil, fmt.Errorf("unable to read random bytes for jwt id: %s", err)
	}

	claimSet := &ClaimSet{
		Issuer:     issuer,
		Subject:    "foo",
		Audience:   audience,
		Expiration: now.Add(5 * time.Minute).Unix(),
		NotBefore:  now.Unix(),
		IssuedAt:   now.Unix(),
		JWTID:      base64.URLEncoding.EncodeToString(randomBytes),
		Access:     access,
	}

	var joseHeaderBytes, claimSetBytes []byte

	if joseHeaderBytes, err = json.Marshal(joseHeader); err != nil {
		return nil, fmt.Errorf("unable to marshal jose header: %s", err)
	}
	if claimSetBytes, err = json.Marshal(claimSet); err != nil {
		return nil, fmt.Errorf("unable to marshal claim set: %s", err)
	}

	// The signing input is "<b64url header>.<b64url claims>".
	encodedJoseHeader := joseBase64UrlEncode(joseHeaderBytes)
	encodedClaimSet := joseBase64UrlEncode(claimSetBytes)
	encodingToSign := fmt.Sprintf("%s.%s", encodedJoseHeader, encodedClaimSet)

	var signatureBytes []byte
	if signatureBytes, _, err = signingKey.Sign(strings.NewReader(encodingToSign), crypto.SHA256); err != nil {
		return nil, fmt.Errorf("unable to sign jwt payload: %s", err)
	}

	signature := joseBase64UrlEncode(signatureBytes)
	tokenString := fmt.Sprintf("%s.%s", encodingToSign, signature)

	return NewToken(tokenString)
}
 | |
| 
 | |
// TestTokenVerify builds four tokens, each under its own root key. The tokens
// carry intermediate certificate chains of increasing length: none, then 1, 2
// and 3 intermediates. All four tokens are created first. Each is then
// verified against options that trust every root, both as a certificate pool
// and as a trusted-key map.
func TestTokenVerify(t *testing.T) {
	const (
		numTokens = 4
		issuer    = "test-issuer"
		audience  = "test-audience"
	)
	access := []*ResourceActions{{
		Type:    "repository",
		Name:    "foo/bar",
		Actions: []string{"pull", "push"},
	}}

	rootKeys, err := makeRootKeys(numTokens)
	if err != nil {
		t.Fatal(err)
	}

	rootCerts, err := makeRootCerts(rootKeys)
	if err != nil {
		t.Fatal(err)
	}

	roots := x509.NewCertPool()
	for _, c := range rootCerts {
		roots.AddCert(c)
	}

	keysByID := makeTrustedKeyMap(rootKeys)

	// Token i is signed under rootKeys[i] with a chain of i intermediates.
	tokens := make([]*Token, numTokens)
	for depth := range tokens {
		tok, err := makeTestToken(issuer, audience, access, rootKeys[depth], depth)
		if err != nil {
			t.Fatal(err)
		}
		tokens[depth] = tok
	}

	opts := VerifyOptions{
		TrustedIssuers:    []string{issuer},
		AcceptedAudiences: []string{audience},
		Roots:             roots,
		TrustedKeys:       keysByID,
	}

	for _, tok := range tokens {
		if err := tok.Verify(opts); err != nil {
			t.Fatal(err)
		}
	}
}
 | |
| 
 | |
// writeTempRootCerts writes a PEM bundle to a new temporary file. The bundle
// holds one self-signed CA certificate per key in rootKeys. On success it
// returns the file's name, and the caller is responsible for removing it.
// On any failure, including an error reported when closing the file, no file
// is left behind and an empty name is returned with the error.
func writeTempRootCerts(rootKeys []libtrust.PrivateKey) (filename string, err error) {
	rootCerts, err := makeRootCerts(rootKeys)
	if err != nil {
		return "", err
	}

	tempFile, err := ioutil.TempFile("", "rootCertBundle")
	if err != nil {
		return "", err
	}

	// Close before returning so a write error surfaced by Close is reported.
	// Remove only after closing: removing an open file fails on some
	// platforms (e.g. Windows).
	defer func() {
		if closeErr := tempFile.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
		if err != nil {
			// Best-effort cleanup; the original error is what matters.
			os.Remove(tempFile.Name())
			filename = ""
		}
	}()

	for _, cert := range rootCerts {
		if err = pem.Encode(tempFile, &pem.Block{
			Type:  "CERTIFICATE",
			Bytes: cert.Raw,
		}); err != nil {
			return "", err
		}
	}

	return tempFile.Name(), nil
}
 | |
| 
 | |
// TestAccessController tests complete integration of the token auth package.
// It creates a token auth accessController from mocked options that trust only
// the first of two generated root keys. It then tries a few mock requests:
//   - don't supply a token; should error with challenge
//   - supply an invalid token; should error with challenge
//   - supply a token with insufficient access; should error with challenge
//   - supply a valid token; should not error, and should set auth.user
func TestAccessController(t *testing.T) {
	// Make 2 keys; only the first is to be a trusted root key.
	rootKeys, err := makeRootKeys(2)
	if err != nil {
		t.Fatal(err)
	}

	rootCertBundleFilename, err := writeTempRootCerts(rootKeys[:1])
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(rootCertBundleFilename)

	realm := "https://auth.example.com/token/"
	issuer := "test-issuer.example.com"
	service := "test-service.example.com"

	options := map[string]interface{}{
		"realm":          realm,
		"issuer":         issuer,
		"service":        service,
		"rootCertBundle": rootCertBundleFilename,
	}

	accessController, err := newAccessController(options)
	if err != nil {
		t.Fatal(err)
	}

	// 1. Make a mock http.Request with no token.
	req, err := http.NewRequest("GET", "http://example.com/foo", nil)
	if err != nil {
		t.Fatal(err)
	}

	testAccess := auth.Access{
		Resource: auth.Resource{
			Type: "foo",
			Name: "bar",
		},
		Action: "baz",
	}

	// Never derive from a nil parent: context.WithValue panics on one.
	// The request is stored under "http.request"; the same ctx is reused
	// below, and later steps only mutate req's headers.
	ctx := context.WithValue(context.Background(), "http.request", req)
	authCtx, err := accessController.Authorized(ctx, testAccess)
	challenge, ok := err.(auth.Challenge)
	if !ok {
		t.Fatal("accessController did not return a challenge")
	}

	if challenge.Error() != ErrTokenRequired.Error() {
		t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
	}

	if authCtx != nil {
		t.Fatalf("expected nil auth context but got %s", authCtx)
	}

	// 2. Supply an invalid token.
	token, err := makeTestToken(
		issuer, service,
		[]*ResourceActions{{
			Type:    testAccess.Type,
			Name:    testAccess.Name,
			Actions: []string{testAccess.Action},
		}},
		rootKeys[1], 1, // Everything is valid except the key which signed it.
	)
	if err != nil {
		t.Fatal(err)
	}

	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))

	authCtx, err = accessController.Authorized(ctx, testAccess)
	challenge, ok = err.(auth.Challenge)
	if !ok {
		t.Fatal("accessController did not return a challenge")
	}

	if challenge.Error() != ErrInvalidToken.Error() {
		t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInvalidToken)
	}

	if authCtx != nil {
		t.Fatalf("expected nil auth context but got %s", authCtx)
	}

	// 3. Supply a token with insufficient access.
	token, err = makeTestToken(
		issuer, service,
		[]*ResourceActions{}, // No access specified.
		rootKeys[0], 1,
	)
	if err != nil {
		t.Fatal(err)
	}

	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))

	authCtx, err = accessController.Authorized(ctx, testAccess)
	challenge, ok = err.(auth.Challenge)
	if !ok {
		t.Fatal("accessController did not return a challenge")
	}

	if challenge.Error() != ErrInsufficientScope.Error() {
		t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope)
	}

	if authCtx != nil {
		t.Fatalf("expected nil auth context but got %s", authCtx)
	}

	// 4. Supply a valid token signed under the trusted root with the
	// requested access.
	token, err = makeTestToken(
		issuer, service,
		[]*ResourceActions{{
			Type:    testAccess.Type,
			Name:    testAccess.Name,
			Actions: []string{testAccess.Action},
		}},
		rootKeys[0], 1,
	)
	if err != nil {
		t.Fatal(err)
	}

	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))

	authCtx, err = accessController.Authorized(ctx, testAccess)
	if err != nil {
		t.Fatalf("accessController returned unexpected error: %s", err)
	}

	// makeTestToken always uses subject "foo".
	userInfo, ok := authCtx.Value("auth.user").(auth.UserInfo)
	if !ok {
		t.Fatal("token accessController did not set auth.user context")
	}

	if userInfo.Name != "foo" {
		t.Fatalf("expected user name %q, got %q", "foo", userInfo.Name)
	}
}
 |