Add simple implementation of token server
Token server implementation currently functional with existing docker 1.9.x release and latest distribution release. Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)master
							parent
							
								
									caa2001e1f
								
							
						
					
					
						commit
						eaa9da0be3
					
				|  | @ -0,0 +1,202 @@ | ||||||
|  | package main | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"flag" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/Sirupsen/logrus" | ||||||
|  | 	"github.com/docker/distribution/context" | ||||||
|  | 	"github.com/docker/distribution/registry/api/errcode" | ||||||
|  | 	"github.com/docker/distribution/registry/auth" | ||||||
|  | 	_ "github.com/docker/distribution/registry/auth/htpasswd" | ||||||
|  | 	"github.com/docker/libtrust" | ||||||
|  | 	"github.com/gorilla/mux" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func main() { | ||||||
|  | 	var ( | ||||||
|  | 		issuer = &TokenIssuer{} | ||||||
|  | 		pkFile string | ||||||
|  | 		addr   string | ||||||
|  | 		debug  bool | ||||||
|  | 		err    error | ||||||
|  | 
 | ||||||
|  | 		passwdFile string | ||||||
|  | 		realm      string | ||||||
|  | 
 | ||||||
|  | 		cert    string | ||||||
|  | 		certKey string | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	flag.StringVar(&issuer.Issuer, "issuer", "distribution-token-server", "Issuer string for token") | ||||||
|  | 	flag.StringVar(&pkFile, "key", "", "Private key file") | ||||||
|  | 	flag.StringVar(&addr, "addr", "localhost:8080", "Address to listen on") | ||||||
|  | 	flag.BoolVar(&debug, "debug", false, "Debug mode") | ||||||
|  | 
 | ||||||
|  | 	flag.StringVar(&passwdFile, "passwd", ".htpasswd", "Passwd file") | ||||||
|  | 	flag.StringVar(&realm, "realm", "", "Authentication realm") | ||||||
|  | 
 | ||||||
|  | 	flag.StringVar(&cert, "tlscert", "", "Certificate file for TLS") | ||||||
|  | 	flag.StringVar(&certKey, "tlskey", "", "Certificate key for TLS") | ||||||
|  | 
 | ||||||
|  | 	flag.Parse() | ||||||
|  | 
 | ||||||
|  | 	if debug { | ||||||
|  | 		logrus.SetLevel(logrus.DebugLevel) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if pkFile == "" { | ||||||
|  | 		issuer.SigningKey, err = libtrust.GenerateECP256PrivateKey() | ||||||
|  | 		if err != nil { | ||||||
|  | 			logrus.Fatalf("Error generating private key: %v", err) | ||||||
|  | 		} | ||||||
|  | 		logrus.Debugf("Using newly generated key with id %s", issuer.SigningKey.KeyID()) | ||||||
|  | 	} else { | ||||||
|  | 		issuer.SigningKey, err = libtrust.LoadKeyFile(pkFile) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logrus.Fatalf("Error loading key file %s: %v", pkFile, err) | ||||||
|  | 		} | ||||||
|  | 		logrus.Debugf("Loaded private key with id %s", issuer.SigningKey.KeyID()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if realm == "" { | ||||||
|  | 		logrus.Fatalf("Must provide realm") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ac, err := auth.GetAccessController("htpasswd", map[string]interface{}{ | ||||||
|  | 		"realm": realm, | ||||||
|  | 		"path":  passwdFile, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logrus.Fatalf("Error initializing access controller: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	ts := &tokenServer{ | ||||||
|  | 		issuer:           issuer, | ||||||
|  | 		accessController: ac, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	router := mux.NewRouter() | ||||||
|  | 	router.Path("/token/").Methods("GET").Handler(handlerWithContext(ctx, ts.getToken)) | ||||||
|  | 
 | ||||||
|  | 	if cert == "" { | ||||||
|  | 		err = http.ListenAndServe(addr, router) | ||||||
|  | 	} else if certKey == "" { | ||||||
|  | 		logrus.Fatalf("Must provide certficate and key") | ||||||
|  | 	} else { | ||||||
|  | 		err = http.ListenAndServeTLS(addr, cert, certKey, router) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		logrus.Infof("Error serving: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // handlerWithContext wraps the given context-aware handler by setting up the
 | ||||||
|  | // request context from a base context.
 | ||||||
|  | func handlerWithContext(ctx context.Context, handler func(context.Context, http.ResponseWriter, *http.Request)) http.Handler { | ||||||
|  | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		ctx := context.WithRequest(ctx, r) | ||||||
|  | 		logger := context.GetRequestLogger(ctx) | ||||||
|  | 		ctx = context.WithLogger(ctx, logger) | ||||||
|  | 
 | ||||||
|  | 		handler(ctx, w, r) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func handleError(ctx context.Context, err error, w http.ResponseWriter) { | ||||||
|  | 	ctx, w = context.WithResponseWriter(ctx, w) | ||||||
|  | 
 | ||||||
|  | 	if serveErr := errcode.ServeJSON(w, err); serveErr != nil { | ||||||
|  | 		context.GetResponseLogger(ctx).Errorf("error sending error response: %v", serveErr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	context.GetResponseLogger(ctx).Info("application error") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type tokenServer struct { | ||||||
|  | 	issuer           *TokenIssuer | ||||||
|  | 	accessController auth.AccessController | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getToken handles authenticating the request and authorizing access to the
 | ||||||
|  | // requested scopes.
 | ||||||
|  | func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	context.GetLogger(ctx).Info("getToken") | ||||||
|  | 
 | ||||||
|  | 	params := r.URL.Query() | ||||||
|  | 	service := params.Get("service") | ||||||
|  | 	scopeSpecifiers := params["scope"] | ||||||
|  | 
 | ||||||
|  | 	requestedAccessList := ResolveScopeSpecifiers(scopeSpecifiers) | ||||||
|  | 
 | ||||||
|  | 	authorizedCtx, err := ts.accessController.Authorized(ctx, requestedAccessList...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		challenge, ok := err.(auth.Challenge) | ||||||
|  | 		if !ok { | ||||||
|  | 			handleError(ctx, err, w) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Get response context.
 | ||||||
|  | 		ctx, w = context.WithResponseWriter(ctx, w) | ||||||
|  | 
 | ||||||
|  | 		challenge.SetHeaders(w) | ||||||
|  | 		handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail(challenge.Error()), w) | ||||||
|  | 
 | ||||||
|  | 		context.GetResponseLogger(ctx).Info("authentication challenged") | ||||||
|  | 
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	ctx = authorizedCtx | ||||||
|  | 
 | ||||||
|  | 	// TODO(dmcgowan): handle case where this could panic?
 | ||||||
|  | 	username := ctx.Value("auth.user.name").(string) | ||||||
|  | 
 | ||||||
|  | 	ctx = context.WithValue(ctx, "acctSubject", username) | ||||||
|  | 	ctx = context.WithLogger(ctx, context.GetLogger(ctx, "acctSubject")) | ||||||
|  | 
 | ||||||
|  | 	context.GetLogger(ctx).Info("authenticated client") | ||||||
|  | 
 | ||||||
|  | 	ctx = context.WithValue(ctx, "requestedAccess", requestedAccessList) | ||||||
|  | 	ctx = context.WithLogger(ctx, context.GetLogger(ctx, "requestedAccess")) | ||||||
|  | 
 | ||||||
|  | 	scopePrefix := username + "/" | ||||||
|  | 	grantedAccessList := make([]auth.Access, 0, len(requestedAccessList)) | ||||||
|  | 	for _, access := range requestedAccessList { | ||||||
|  | 		if access.Type != "repository" { | ||||||
|  | 			context.GetLogger(ctx).Debugf("Skipping unsupported resource type: %s", access.Type) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if !strings.HasPrefix(access.Name, scopePrefix) { | ||||||
|  | 			context.GetLogger(ctx).Debugf("Resource scope not allowed: %s", access.Name) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		grantedAccessList = append(grantedAccessList, access) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ctx = context.WithValue(ctx, "grantedAccess", grantedAccessList) | ||||||
|  | 	ctx = context.WithLogger(ctx, context.GetLogger(ctx, "grantedAccess")) | ||||||
|  | 
 | ||||||
|  | 	token, err := ts.issuer.CreateJWT(username, service, grantedAccessList) | ||||||
|  | 	if err != nil { | ||||||
|  | 		handleError(ctx, err, w) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	context.GetLogger(ctx).Info("authorized client") | ||||||
|  | 
 | ||||||
|  | 	// Get response context.
 | ||||||
|  | 	ctx, w = context.WithResponseWriter(ctx, w) | ||||||
|  | 
 | ||||||
|  | 	w.Header().Set("Content-Type", "application/json") | ||||||
|  | 	json.NewEncoder(w).Encode(map[string]string{"token": token}) | ||||||
|  | 
 | ||||||
|  | 	context.GetResponseLogger(ctx).Info("getToken complete") | ||||||
|  | } | ||||||
|  | @ -0,0 +1,168 @@ | ||||||
|  | package main | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/docker/distribution/registry/auth" | ||||||
|  | 	"github.com/docker/distribution/registry/auth/token" | ||||||
|  | 	"github.com/docker/libtrust" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // ResolveScopeSpecifiers converts a list of scope specifiers from a token
 | ||||||
|  | // request's `scope` query parameters into a list of standard access objects.
 | ||||||
|  | func ResolveScopeSpecifiers(scopeSpecs []string) []auth.Access { | ||||||
|  | 	requestedAccessSet := make(map[auth.Access]struct{}, 2*len(scopeSpecs)) | ||||||
|  | 
 | ||||||
|  | 	for _, scopeSpecifier := range scopeSpecs { | ||||||
|  | 		// There should be 3 parts, separated by a `:` character.
 | ||||||
|  | 		parts := strings.SplitN(scopeSpecifier, ":", 3) | ||||||
|  | 
 | ||||||
|  | 		if len(parts) != 3 { | ||||||
|  | 			// Ignore malformed scope specifiers.
 | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		resourceType, resourceName, actions := parts[0], parts[1], parts[2] | ||||||
|  | 
 | ||||||
|  | 		// Actions should be a comma-separated list of actions.
 | ||||||
|  | 		for _, action := range strings.Split(actions, ",") { | ||||||
|  | 			requestedAccess := auth.Access{ | ||||||
|  | 				Resource: auth.Resource{ | ||||||
|  | 					Type: resourceType, | ||||||
|  | 					Name: resourceName, | ||||||
|  | 				}, | ||||||
|  | 				Action: action, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Add this access to the requested access set.
 | ||||||
|  | 			requestedAccessSet[requestedAccess] = struct{}{} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	requestedAccessList := make([]auth.Access, 0, len(requestedAccessSet)) | ||||||
|  | 	for requestedAccess := range requestedAccessSet { | ||||||
|  | 		requestedAccessList = append(requestedAccessList, requestedAccess) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return requestedAccessList | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TokenIssuer represents an issuer capable of generating JWT tokens
 | ||||||
|  | type TokenIssuer struct { | ||||||
|  | 	Issuer     string | ||||||
|  | 	SigningKey libtrust.PrivateKey | ||||||
|  | 	Expiration time.Duration | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CreateJWT creates and signs a JSON Web Token for the given account and
 | ||||||
|  | // audience with the granted access.
 | ||||||
|  | func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAccessList []auth.Access) (string, error) { | ||||||
|  | 	// Make a set of access entries to put in the token's claimset.
 | ||||||
|  | 	resourceActionSets := make(map[auth.Resource]map[string]struct{}, len(grantedAccessList)) | ||||||
|  | 	for _, access := range grantedAccessList { | ||||||
|  | 		actionSet, exists := resourceActionSets[access.Resource] | ||||||
|  | 		if !exists { | ||||||
|  | 			actionSet = map[string]struct{}{} | ||||||
|  | 			resourceActionSets[access.Resource] = actionSet | ||||||
|  | 		} | ||||||
|  | 		actionSet[access.Action] = struct{}{} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	accessEntries := make([]token.ResourceActions, 0, len(resourceActionSets)) | ||||||
|  | 	for resource, actionSet := range resourceActionSets { | ||||||
|  | 		actions := make([]string, 0, len(actionSet)) | ||||||
|  | 		for action := range actionSet { | ||||||
|  | 			actions = append(actions, action) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		accessEntries = append(accessEntries, token.ResourceActions{ | ||||||
|  | 			Type:    resource.Type, | ||||||
|  | 			Name:    resource.Name, | ||||||
|  | 			Actions: actions, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	randomBytes := make([]byte, 15) | ||||||
|  | 	_, err := io.ReadFull(rand.Reader, randomBytes) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	randomID := base64.URLEncoding.EncodeToString(randomBytes) | ||||||
|  | 
 | ||||||
|  | 	now := time.Now() | ||||||
|  | 
 | ||||||
|  | 	signingHash := crypto.SHA256 | ||||||
|  | 	var alg string | ||||||
|  | 	switch issuer.SigningKey.KeyType() { | ||||||
|  | 	case "RSA": | ||||||
|  | 		alg = "RS256" | ||||||
|  | 	case "EC": | ||||||
|  | 		alg = "ES256" | ||||||
|  | 	default: | ||||||
|  | 		panic(fmt.Errorf("unsupported signing key type %q", issuer.SigningKey.KeyType())) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	joseHeader := map[string]interface{}{ | ||||||
|  | 		"typ": "JWT", | ||||||
|  | 		"alg": alg, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if x5c := issuer.SigningKey.GetExtendedField("x5c"); x5c != nil { | ||||||
|  | 		joseHeader["x5c"] = x5c | ||||||
|  | 	} else { | ||||||
|  | 		joseHeader["jwk"] = issuer.SigningKey.PublicKey() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	exp := issuer.Expiration | ||||||
|  | 	if exp == 0 { | ||||||
|  | 		exp = 5 * time.Minute | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	claimSet := map[string]interface{}{ | ||||||
|  | 		"iss": issuer.Issuer, | ||||||
|  | 		"sub": subject, | ||||||
|  | 		"aud": audience, | ||||||
|  | 		"exp": now.Add(exp).Unix(), | ||||||
|  | 		"nbf": now.Unix(), | ||||||
|  | 		"iat": now.Unix(), | ||||||
|  | 		"jti": randomID, | ||||||
|  | 
 | ||||||
|  | 		"access": accessEntries, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var ( | ||||||
|  | 		joseHeaderBytes []byte | ||||||
|  | 		claimSetBytes   []byte | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if joseHeaderBytes, err = json.Marshal(joseHeader); err != nil { | ||||||
|  | 		return "", fmt.Errorf("unable to encode jose header: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if claimSetBytes, err = json.Marshal(claimSet); err != nil { | ||||||
|  | 		return "", fmt.Errorf("unable to encode claim set: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	encodedJoseHeader := joseBase64Encode(joseHeaderBytes) | ||||||
|  | 	encodedClaimSet := joseBase64Encode(claimSetBytes) | ||||||
|  | 	encodingToSign := fmt.Sprintf("%s.%s", encodedJoseHeader, encodedClaimSet) | ||||||
|  | 
 | ||||||
|  | 	var signatureBytes []byte | ||||||
|  | 	if signatureBytes, _, err = issuer.SigningKey.Sign(strings.NewReader(encodingToSign), signingHash); err != nil { | ||||||
|  | 		return "", fmt.Errorf("unable to sign jwt payload: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	signature := joseBase64Encode(signatureBytes) | ||||||
|  | 
 | ||||||
|  | 	return fmt.Sprintf("%s.%s", encodingToSign, signature), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func joseBase64Encode(data []byte) string { | ||||||
|  | 	return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue