Merge pull request #3742 from sagikazarmark/fix-aud-claim-list
Accept list of strings in audience claim in token authmaster
						commit
						29b5e79f82
					
				|  | @ -180,7 +180,7 @@ func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAcc | |||
| 	claimSet := token.ClaimSet{ | ||||
| 		Issuer:     issuer.Issuer, | ||||
| 		Subject:    subject, | ||||
| 		Audience:   audience, | ||||
| 		Audience:   []string{audience}, | ||||
| 		Expiration: now.Add(exp).Unix(), | ||||
| 		NotBefore:  now.Unix(), | ||||
| 		IssuedAt:   now.Unix(), | ||||
|  |  | |||
|  | @ -42,13 +42,13 @@ type ResourceActions struct { | |||
| // ClaimSet describes the main section of a JSON Web Token.
 | ||||
| type ClaimSet struct { | ||||
| 	// Public claims
 | ||||
| 	Issuer     string `json:"iss"` | ||||
| 	Subject    string `json:"sub"` | ||||
| 	Audience   string `json:"aud"` | ||||
| 	Expiration int64  `json:"exp"` | ||||
| 	NotBefore  int64  `json:"nbf"` | ||||
| 	IssuedAt   int64  `json:"iat"` | ||||
| 	JWTID      string `json:"jti"` | ||||
| 	Issuer     string       `json:"iss"` | ||||
| 	Subject    string       `json:"sub"` | ||||
| 	Audience   AudienceList `json:"aud"` | ||||
| 	Expiration int64        `json:"exp"` | ||||
| 	NotBefore  int64        `json:"nbf"` | ||||
| 	IssuedAt   int64        `json:"iat"` | ||||
| 	JWTID      string       `json:"jti"` | ||||
| 
 | ||||
| 	// Private claims
 | ||||
| 	Access []*ResourceActions `json:"access"` | ||||
|  | @ -143,8 +143,8 @@ func (t *Token) Verify(verifyOpts VerifyOptions) error { | |||
| 	} | ||||
| 
 | ||||
| 	// Verify that the Audience claim is allowed.
 | ||||
| 	if !contains(verifyOpts.AcceptedAudiences, t.Claims.Audience) { | ||||
| 		log.Infof("token intended for another audience: %q", t.Claims.Audience) | ||||
| 	if !containsAny(verifyOpts.AcceptedAudiences, t.Claims.Audience) { | ||||
| 		log.Infof("token intended for another audience: %v", t.Claims.Audience) | ||||
| 		return ErrInvalidToken | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -116,7 +116,7 @@ func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey l | |||
| 	claimSet := &ClaimSet{ | ||||
| 		Issuer:     issuer, | ||||
| 		Subject:    "foo", | ||||
| 		Audience:   audience, | ||||
| 		Audience:   []string{audience}, | ||||
| 		Expiration: exp.Unix(), | ||||
| 		NotBefore:  now.Unix(), | ||||
| 		IssuedAt:   now.Unix(), | ||||
|  |  | |||
|  | @ -0,0 +1,55 @@ | |||
| package token | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // AudienceList is a slice of strings that can be deserialized from either a single string value or a list of strings.
 | ||||
| type AudienceList []string | ||||
| 
 | ||||
| func (s *AudienceList) UnmarshalJSON(data []byte) (err error) { | ||||
| 	var value interface{} | ||||
| 
 | ||||
| 	if err = json.Unmarshal(data, &value); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	switch v := value.(type) { | ||||
| 	case string: | ||||
| 		*s = []string{v} | ||||
| 
 | ||||
| 	case []string: | ||||
| 		*s = v | ||||
| 
 | ||||
| 	case []interface{}: | ||||
| 		var ss []string | ||||
| 
 | ||||
| 		for _, vv := range v { | ||||
| 			vs, ok := vv.(string) | ||||
| 			if !ok { | ||||
| 				return &json.UnsupportedTypeError{ | ||||
| 					Type: reflect.TypeOf(vv), | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			ss = append(ss, vs) | ||||
| 		} | ||||
| 
 | ||||
| 		*s = ss | ||||
| 
 | ||||
| 	case nil: | ||||
| 		return nil | ||||
| 
 | ||||
| 	default: | ||||
| 		return &json.UnsupportedTypeError{ | ||||
| 			Type: reflect.TypeOf(v), | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s AudienceList) MarshalJSON() (b []byte, err error) { | ||||
| 	return json.Marshal([]string(s)) | ||||
| } | ||||
|  | @ -0,0 +1,85 @@ | |||
| package token | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestAudienceList_Unmarshal(t *testing.T) { | ||||
| 	t.Run("OK", func(t *testing.T) { | ||||
| 		testCases := []struct { | ||||
| 			value    string | ||||
| 			expected AudienceList | ||||
| 		}{ | ||||
| 			{ | ||||
| 				value:    `"audience"`, | ||||
| 				expected: AudienceList{"audience"}, | ||||
| 			}, | ||||
| 			{ | ||||
| 				value:    `["audience1", "audience2"]`, | ||||
| 				expected: AudienceList{"audience1", "audience2"}, | ||||
| 			}, | ||||
| 			{ | ||||
| 				value:    `null`, | ||||
| 				expected: nil, | ||||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		for _, testCase := range testCases { | ||||
| 			testCase := testCase | ||||
| 
 | ||||
| 			t.Run("", func(t *testing.T) { | ||||
| 				var actual AudienceList | ||||
| 
 | ||||
| 				err := json.Unmarshal([]byte(testCase.value), &actual) | ||||
| 				if err != nil { | ||||
| 					t.Fatal(err) | ||||
| 				} | ||||
| 
 | ||||
| 				assertStringListEqual(t, testCase.expected, actual) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| 
 | ||||
| 	t.Run("Error", func(t *testing.T) { | ||||
| 		var actual AudienceList | ||||
| 
 | ||||
| 		err := json.Unmarshal([]byte("1234"), &actual) | ||||
| 		if err == nil { | ||||
| 			t.Fatal("expected unmarshal to fail") | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func TestAudienceList_Marshal(t *testing.T) { | ||||
| 	value := AudienceList{"audience"} | ||||
| 
 | ||||
| 	expected := `["audience"]` | ||||
| 
 | ||||
| 	actual, err := json.Marshal(value) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	if expected != string(actual) { | ||||
| 		t.Errorf("expected marshaled list to be %v, got %v", expected, actual) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func assertStringListEqual(t *testing.T, expected []string, actual []string) { | ||||
| 	t.Helper() | ||||
| 
 | ||||
| 	if len(expected) != len(actual) { | ||||
| 		t.Errorf("length mismatch: expected %d long slice, got %d", len(expected), len(actual)) | ||||
| 
 | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	for i, v := range expected { | ||||
| 		if v != actual[i] { | ||||
| 			t.Errorf("expected %d. item to be %q, got %q", i, v, actual[i]) | ||||
| 		} | ||||
| 
 | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | @ -56,3 +56,14 @@ func contains(ss []string, q string) bool { | |||
| 
 | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // containsAny returns true if any of q is found in ss.
 | ||||
| func containsAny(ss []string, q []string) bool { | ||||
| 	for _, s := range ss { | ||||
| 		if contains(q, s) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue