Merge pull request #3742 from sagikazarmark/fix-aud-claim-list
Accept list of strings in audience claim in token authmaster
						commit
						29b5e79f82
					
				|  | @ -180,7 +180,7 @@ func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAcc | ||||||
| 	claimSet := token.ClaimSet{ | 	claimSet := token.ClaimSet{ | ||||||
| 		Issuer:     issuer.Issuer, | 		Issuer:     issuer.Issuer, | ||||||
| 		Subject:    subject, | 		Subject:    subject, | ||||||
| 		Audience:   audience, | 		Audience:   []string{audience}, | ||||||
| 		Expiration: now.Add(exp).Unix(), | 		Expiration: now.Add(exp).Unix(), | ||||||
| 		NotBefore:  now.Unix(), | 		NotBefore:  now.Unix(), | ||||||
| 		IssuedAt:   now.Unix(), | 		IssuedAt:   now.Unix(), | ||||||
|  |  | ||||||
|  | @ -42,13 +42,13 @@ type ResourceActions struct { | ||||||
| // ClaimSet describes the main section of a JSON Web Token.
 | // ClaimSet describes the main section of a JSON Web Token.
 | ||||||
| type ClaimSet struct { | type ClaimSet struct { | ||||||
| 	// Public claims
 | 	// Public claims
 | ||||||
| 	Issuer     string `json:"iss"` | 	Issuer     string       `json:"iss"` | ||||||
| 	Subject    string `json:"sub"` | 	Subject    string       `json:"sub"` | ||||||
| 	Audience   string `json:"aud"` | 	Audience   AudienceList `json:"aud"` | ||||||
| 	Expiration int64  `json:"exp"` | 	Expiration int64        `json:"exp"` | ||||||
| 	NotBefore  int64  `json:"nbf"` | 	NotBefore  int64        `json:"nbf"` | ||||||
| 	IssuedAt   int64  `json:"iat"` | 	IssuedAt   int64        `json:"iat"` | ||||||
| 	JWTID      string `json:"jti"` | 	JWTID      string       `json:"jti"` | ||||||
| 
 | 
 | ||||||
| 	// Private claims
 | 	// Private claims
 | ||||||
| 	Access []*ResourceActions `json:"access"` | 	Access []*ResourceActions `json:"access"` | ||||||
|  | @ -143,8 +143,8 @@ func (t *Token) Verify(verifyOpts VerifyOptions) error { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Verify that the Audience claim is allowed.
 | 	// Verify that the Audience claim is allowed.
 | ||||||
| 	if !contains(verifyOpts.AcceptedAudiences, t.Claims.Audience) { | 	if !containsAny(verifyOpts.AcceptedAudiences, t.Claims.Audience) { | ||||||
| 		log.Infof("token intended for another audience: %q", t.Claims.Audience) | 		log.Infof("token intended for another audience: %v", t.Claims.Audience) | ||||||
| 		return ErrInvalidToken | 		return ErrInvalidToken | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -116,7 +116,7 @@ func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey l | ||||||
| 	claimSet := &ClaimSet{ | 	claimSet := &ClaimSet{ | ||||||
| 		Issuer:     issuer, | 		Issuer:     issuer, | ||||||
| 		Subject:    "foo", | 		Subject:    "foo", | ||||||
| 		Audience:   audience, | 		Audience:   []string{audience}, | ||||||
| 		Expiration: exp.Unix(), | 		Expiration: exp.Unix(), | ||||||
| 		NotBefore:  now.Unix(), | 		NotBefore:  now.Unix(), | ||||||
| 		IssuedAt:   now.Unix(), | 		IssuedAt:   now.Unix(), | ||||||
|  |  | ||||||
|  | @ -0,0 +1,55 @@ | ||||||
|  | package token | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"reflect" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // AudienceList is a slice of strings that can be deserialized from either a single string value or a list of strings.
 | ||||||
|  | type AudienceList []string | ||||||
|  | 
 | ||||||
|  | func (s *AudienceList) UnmarshalJSON(data []byte) (err error) { | ||||||
|  | 	var value interface{} | ||||||
|  | 
 | ||||||
|  | 	if err = json.Unmarshal(data, &value); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch v := value.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		*s = []string{v} | ||||||
|  | 
 | ||||||
|  | 	case []string: | ||||||
|  | 		*s = v | ||||||
|  | 
 | ||||||
|  | 	case []interface{}: | ||||||
|  | 		var ss []string | ||||||
|  | 
 | ||||||
|  | 		for _, vv := range v { | ||||||
|  | 			vs, ok := vv.(string) | ||||||
|  | 			if !ok { | ||||||
|  | 				return &json.UnsupportedTypeError{ | ||||||
|  | 					Type: reflect.TypeOf(vv), | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			ss = append(ss, vs) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		*s = ss | ||||||
|  | 
 | ||||||
|  | 	case nil: | ||||||
|  | 		return nil | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return &json.UnsupportedTypeError{ | ||||||
|  | 			Type: reflect.TypeOf(v), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s AudienceList) MarshalJSON() (b []byte, err error) { | ||||||
|  | 	return json.Marshal([]string(s)) | ||||||
|  | } | ||||||
|  | @ -0,0 +1,85 @@ | ||||||
|  | package token | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestAudienceList_Unmarshal(t *testing.T) { | ||||||
|  | 	t.Run("OK", func(t *testing.T) { | ||||||
|  | 		testCases := []struct { | ||||||
|  | 			value    string | ||||||
|  | 			expected AudienceList | ||||||
|  | 		}{ | ||||||
|  | 			{ | ||||||
|  | 				value:    `"audience"`, | ||||||
|  | 				expected: AudienceList{"audience"}, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				value:    `["audience1", "audience2"]`, | ||||||
|  | 				expected: AudienceList{"audience1", "audience2"}, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				value:    `null`, | ||||||
|  | 				expected: nil, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, testCase := range testCases { | ||||||
|  | 			testCase := testCase | ||||||
|  | 
 | ||||||
|  | 			t.Run("", func(t *testing.T) { | ||||||
|  | 				var actual AudienceList | ||||||
|  | 
 | ||||||
|  | 				err := json.Unmarshal([]byte(testCase.value), &actual) | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatal(err) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				assertStringListEqual(t, testCase.expected, actual) | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Error", func(t *testing.T) { | ||||||
|  | 		var actual AudienceList | ||||||
|  | 
 | ||||||
|  | 		err := json.Unmarshal([]byte("1234"), &actual) | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Fatal("expected unmarshal to fail") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestAudienceList_Marshal(t *testing.T) { | ||||||
|  | 	value := AudienceList{"audience"} | ||||||
|  | 
 | ||||||
|  | 	expected := `["audience"]` | ||||||
|  | 
 | ||||||
|  | 	actual, err := json.Marshal(value) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if expected != string(actual) { | ||||||
|  | 		t.Errorf("expected marshaled list to be %v, got %v", expected, actual) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func assertStringListEqual(t *testing.T, expected []string, actual []string) { | ||||||
|  | 	t.Helper() | ||||||
|  | 
 | ||||||
|  | 	if len(expected) != len(actual) { | ||||||
|  | 		t.Errorf("length mismatch: expected %d long slice, got %d", len(expected), len(actual)) | ||||||
|  | 
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for i, v := range expected { | ||||||
|  | 		if v != actual[i] { | ||||||
|  | 			t.Errorf("expected %d. item to be %q, got %q", i, v, actual[i]) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -56,3 +56,14 @@ func contains(ss []string, q string) bool { | ||||||
| 
 | 
 | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // containsAny returns true if any of q is found in ss.
 | ||||||
|  | func containsAny(ss []string, q []string) bool { | ||||||
|  | 	for _, s := range ss { | ||||||
|  | 		if contains(q, s) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue