feat: add WeakStringList type to support lists in aud claim
Signed-off-by: Mark Sagi-Kazar <mark.sagikazar@gmail.com>master
							parent
							
								
									78b9c98c5c
								
							
						
					
					
						commit
						97fa1183bf
					
				|  | @ -0,0 +1,55 @@ | ||||||
|  | package token | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"reflect" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // WeakStringList is a slice of strings that can be deserialized from either a single string value or a list of strings.
 | ||||||
|  | type WeakStringList []string | ||||||
|  | 
 | ||||||
|  | func (s *WeakStringList) UnmarshalJSON(data []byte) (err error) { | ||||||
|  | 	var value interface{} | ||||||
|  | 
 | ||||||
|  | 	if err = json.Unmarshal(data, &value); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch v := value.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		*s = []string{v} | ||||||
|  | 
 | ||||||
|  | 	case []string: | ||||||
|  | 		*s = v | ||||||
|  | 
 | ||||||
|  | 	case []interface{}: | ||||||
|  | 		var ss []string | ||||||
|  | 
 | ||||||
|  | 		for _, vv := range v { | ||||||
|  | 			vs, ok := vv.(string) | ||||||
|  | 			if !ok { | ||||||
|  | 				return &json.UnsupportedTypeError{ | ||||||
|  | 					Type: reflect.TypeOf(vv), | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			ss = append(ss, vs) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		*s = ss | ||||||
|  | 
 | ||||||
|  | 	case nil: | ||||||
|  | 		return nil | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		return &json.UnsupportedTypeError{ | ||||||
|  | 			Type: reflect.TypeOf(v), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s WeakStringList) MarshalJSON() (b []byte, err error) { | ||||||
|  | 	return json.Marshal([]string(s)) | ||||||
|  | } | ||||||
|  | @ -0,0 +1,85 @@ | ||||||
|  | package token | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestWeakStringList_Unmarshal(t *testing.T) { | ||||||
|  | 	t.Run("OK", func(t *testing.T) { | ||||||
|  | 		testCases := []struct { | ||||||
|  | 			value    string | ||||||
|  | 			expected WeakStringList | ||||||
|  | 		}{ | ||||||
|  | 			{ | ||||||
|  | 				value:    `"audience"`, | ||||||
|  | 				expected: WeakStringList{"audience"}, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				value:    `["audience1", "audience2"]`, | ||||||
|  | 				expected: WeakStringList{"audience1", "audience2"}, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				value:    `null`, | ||||||
|  | 				expected: nil, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, testCase := range testCases { | ||||||
|  | 			testCase := testCase | ||||||
|  | 
 | ||||||
|  | 			t.Run("", func(t *testing.T) { | ||||||
|  | 				var actual WeakStringList | ||||||
|  | 
 | ||||||
|  | 				err := json.Unmarshal([]byte(testCase.value), &actual) | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatal(err) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				assertStringListEqual(t, testCase.expected, actual) | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Error", func(t *testing.T) { | ||||||
|  | 		var actual WeakStringList | ||||||
|  | 
 | ||||||
|  | 		err := json.Unmarshal([]byte("1234"), &actual) | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Fatal("expected unmarshal to fail") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestWeakStringList_Marshal(t *testing.T) { | ||||||
|  | 	value := WeakStringList{"audience"} | ||||||
|  | 
 | ||||||
|  | 	expected := `["audience"]` | ||||||
|  | 
 | ||||||
|  | 	actual, err := json.Marshal(value) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if expected != string(actual) { | ||||||
|  | 		t.Errorf("expected marshaled list to be %v, got %v", expected, actual) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func assertStringListEqual(t *testing.T, expected []string, actual []string) { | ||||||
|  | 	t.Helper() | ||||||
|  | 
 | ||||||
|  | 	if len(expected) != len(actual) { | ||||||
|  | 		t.Errorf("length mismatch: expected %d long slice, got %d", len(expected), len(actual)) | ||||||
|  | 
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for i, v := range expected { | ||||||
|  | 		if v != actual[i] { | ||||||
|  | 			t.Errorf("expected %d. item to be %q, got %q", i, v, actual[i]) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue