You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
495 lines
14 KiB
495 lines
14 KiB
package jwt |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"fmt" |
|
"math/rand" |
|
"net/http" |
|
"reflect" |
|
"strconv" |
|
"sync" |
|
"testing" |
|
"time" |
|
|
|
"github.com/golang-jwt/jwt/v4" |
|
|
|
"github.com/go-kratos/kratos/v2/middleware" |
|
"github.com/go-kratos/kratos/v2/transport" |
|
) |
|
|
|
type headerCarrier http.Header |
|
|
|
func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) } |
|
|
|
func (hc headerCarrier) Set(key string, value string) { http.Header(hc).Set(key, value) } |
|
|
|
// Keys lists the keys stored in this carrier. |
|
func (hc headerCarrier) Keys() []string { |
|
keys := make([]string, 0, len(hc)) |
|
for k := range http.Header(hc) { |
|
keys = append(keys, k) |
|
} |
|
return keys |
|
} |
|
|
|
func newTokenHeader(headerKey string, token string) *headerCarrier { |
|
header := &headerCarrier{} |
|
header.Set(headerKey, token) |
|
return header |
|
} |
|
|
|
type Transport struct { |
|
kind transport.Kind |
|
endpoint string |
|
operation string |
|
reqHeader transport.Header |
|
} |
|
|
|
func (tr *Transport) Kind() transport.Kind { |
|
return tr.kind |
|
} |
|
|
|
func (tr *Transport) Endpoint() string { |
|
return tr.endpoint |
|
} |
|
|
|
func (tr *Transport) Operation() string { |
|
return tr.operation |
|
} |
|
|
|
func (tr *Transport) RequestHeader() transport.Header { |
|
return tr.reqHeader |
|
} |
|
|
|
func (tr *Transport) ReplyHeader() transport.Header { |
|
return nil |
|
} |
|
|
|
type CustomerClaims struct { |
|
Name string `json:"name"` |
|
jwt.RegisteredClaims |
|
} |
|
|
|
func TestJWTServerParse(t *testing.T) { |
|
var ( |
|
errConcurrentWrite = errors.New("concurrent write claims") |
|
errParseClaims = errors.New("bad result, token claims is not CustomerClaims") |
|
) |
|
|
|
testKey := "testKey" |
|
tests := []struct { |
|
name string |
|
token func() string |
|
claims func() jwt.Claims |
|
exceptErr error |
|
key string |
|
goroutineNum int |
|
}{ |
|
{ |
|
name: "normal", |
|
token: func() string { |
|
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{}).SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
return fmt.Sprintf(bearerFormat, token) |
|
}, |
|
claims: func() jwt.Claims { |
|
return &CustomerClaims{} |
|
}, |
|
exceptErr: nil, |
|
key: testKey, |
|
goroutineNum: 1, |
|
}, |
|
{ |
|
name: "concurrent request", |
|
token: func() string { |
|
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{ |
|
Name: strconv.Itoa(rand.Int()), |
|
}).SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
return fmt.Sprintf(bearerFormat, token) |
|
}, |
|
claims: func() jwt.Claims { |
|
return &CustomerClaims{} |
|
}, |
|
exceptErr: nil, |
|
key: testKey, |
|
goroutineNum: 10000, |
|
}, |
|
} |
|
|
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
testToken, _ := FromContext(ctx) |
|
var name string |
|
if customerClaims, ok := testToken.(*CustomerClaims); ok { |
|
name = customerClaims.Name |
|
} else { |
|
return nil, errParseClaims |
|
} |
|
|
|
// mock biz |
|
time.Sleep(100 * time.Millisecond) |
|
|
|
if customerClaims, ok := testToken.(*CustomerClaims); ok { |
|
if name != customerClaims.Name { |
|
return nil, errConcurrentWrite |
|
} |
|
} else { |
|
return nil, errParseClaims |
|
} |
|
return "reply", nil |
|
} |
|
|
|
for _, test := range tests { |
|
t.Run(test.name, func(t *testing.T) { |
|
server := Server( |
|
func(token *jwt.Token) (interface{}, error) { return []byte(testKey), nil }, |
|
WithClaims(test.claims), |
|
)(next) |
|
wg := sync.WaitGroup{} |
|
for i := 0; i < test.goroutineNum; i++ { |
|
wg.Add(1) |
|
go func() { |
|
defer wg.Done() |
|
ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, test.token())}) |
|
_, err2 := server(ctx, test.name) |
|
if !errors.Is(test.exceptErr, err2) { |
|
t.Errorf("except error %v, but got %v", test.exceptErr, err2) |
|
} |
|
}() |
|
} |
|
wg.Wait() |
|
}) |
|
} |
|
} |
|
|
|
func TestServer(t *testing.T) { |
|
testKey := "testKey" |
|
mapClaims := jwt.MapClaims{} |
|
mapClaims["name"] = "xiaoli" |
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
|
token, err := claims.SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
token = fmt.Sprintf(bearerFormat, token) |
|
tests := []struct { |
|
name string |
|
ctx context.Context |
|
signingMethod jwt.SigningMethod |
|
exceptErr error |
|
key string |
|
}{ |
|
{ |
|
name: "normal", |
|
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
|
signingMethod: jwt.SigningMethodHS256, |
|
exceptErr: nil, |
|
key: testKey, |
|
}, |
|
{ |
|
name: "miss token", |
|
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: headerCarrier{}}), |
|
signingMethod: jwt.SigningMethodHS256, |
|
exceptErr: ErrMissingJwtToken, |
|
key: testKey, |
|
}, |
|
{ |
|
name: "token invalid", |
|
ctx: transport.NewServerContext(context.Background(), &Transport{ |
|
reqHeader: newTokenHeader(authorizationKey, fmt.Sprintf(bearerFormat, "12313123")), |
|
}), |
|
signingMethod: jwt.SigningMethodHS256, |
|
exceptErr: ErrTokenInvalid, |
|
key: testKey, |
|
}, |
|
{ |
|
name: "method invalid", |
|
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
|
signingMethod: jwt.SigningMethodES384, |
|
exceptErr: ErrUnSupportSigningMethod, |
|
key: testKey, |
|
}, |
|
{ |
|
name: "miss signing method", |
|
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
|
signingMethod: nil, |
|
exceptErr: nil, |
|
key: testKey, |
|
}, |
|
{ |
|
name: "miss signing method", |
|
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
|
signingMethod: nil, |
|
exceptErr: nil, |
|
key: testKey, |
|
}, |
|
} |
|
|
|
for _, test := range tests { |
|
t.Run(test.name, func(t *testing.T) { |
|
var testToken jwt.Claims |
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
t.Log(req) |
|
testToken, _ = FromContext(ctx) |
|
return "reply", nil |
|
} |
|
var server middleware.Handler |
|
if test.signingMethod != nil { |
|
server = Server(func(token *jwt.Token) (interface{}, error) { |
|
return []byte(test.key), nil |
|
}, WithSigningMethod(test.signingMethod))(next) |
|
} else { |
|
server = Server(func(token *jwt.Token) (interface{}, error) { |
|
return []byte(test.key), nil |
|
})(next) |
|
} |
|
_, err2 := server(test.ctx, test.name) |
|
if !errors.Is(test.exceptErr, err2) { |
|
t.Errorf("except error %v, but got %v", test.exceptErr, err2) |
|
} |
|
if test.exceptErr == nil { |
|
if testToken == nil { |
|
t.Errorf("except testToken not nil, but got nil") |
|
} |
|
_, ok := testToken.(jwt.MapClaims) |
|
if !ok { |
|
t.Errorf("except testToken is jwt.MapClaims, but got %T", testToken) |
|
} |
|
} |
|
}) |
|
} |
|
} |
|
|
|
func TestClient(t *testing.T) { |
|
testKey := "testKey" |
|
|
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{}) |
|
token, err := claims.SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
tProvider := func(*jwt.Token) (interface{}, error) { |
|
return []byte(testKey), nil |
|
} |
|
tests := []struct { |
|
name string |
|
expectError error |
|
tokenProvider jwt.Keyfunc |
|
}{ |
|
{ |
|
name: "normal", |
|
expectError: nil, |
|
tokenProvider: tProvider, |
|
}, |
|
{ |
|
name: "miss token provider", |
|
expectError: ErrNeedTokenProvider, |
|
tokenProvider: nil, |
|
}, |
|
} |
|
for _, test := range tests { |
|
t.Run(test.name, func(t *testing.T) { |
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
return "reply", nil |
|
} |
|
handler := Client(test.tokenProvider)(next) |
|
header := &headerCarrier{} |
|
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
|
if !errors.Is(test.expectError, err2) { |
|
t.Errorf("except error %v, but got %v", test.expectError, err2) |
|
} |
|
if err2 == nil { |
|
if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) { |
|
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) |
|
} |
|
} |
|
}) |
|
} |
|
} |
|
|
|
func TestTokenExpire(t *testing.T) { |
|
testKey := "testKey" |
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ |
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Millisecond)), |
|
}) |
|
token, err := claims.SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
token = fmt.Sprintf(bearerFormat, token) |
|
time.Sleep(time.Second) |
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
t.Log(req) |
|
return "reply", nil |
|
} |
|
ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}) |
|
server := Server(func(token *jwt.Token) (interface{}, error) { |
|
return []byte(testKey), nil |
|
}, WithSigningMethod(jwt.SigningMethodHS256))(next) |
|
_, err2 := server(ctx, "test expire token") |
|
if !errors.Is(ErrTokenExpired, err2) { |
|
t.Errorf("except error %v, but got %v", ErrTokenExpired, err2) |
|
} |
|
} |
|
|
|
func TestMissingKeyFunc(t *testing.T) { |
|
testKey := "testKey" |
|
mapClaims := jwt.MapClaims{} |
|
mapClaims["name"] = "xiaoli" |
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
|
token, err := claims.SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
token = fmt.Sprintf(bearerFormat, token) |
|
test := struct { |
|
name string |
|
ctx context.Context |
|
signingMethod jwt.SigningMethod |
|
exceptErr error |
|
key string |
|
}{ |
|
name: "miss key", |
|
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
|
signingMethod: jwt.SigningMethodHS256, |
|
exceptErr: ErrMissingKeyFunc, |
|
key: "", |
|
} |
|
|
|
var testToken jwt.Claims |
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
t.Log(req) |
|
testToken, _ = FromContext(ctx) |
|
return "reply", nil |
|
} |
|
server := Server(nil)(next) |
|
_, err2 := server(test.ctx, test.name) |
|
if !errors.Is(test.exceptErr, err2) { |
|
t.Errorf("except error %v, but got %v", test.exceptErr, err2) |
|
} |
|
if test.exceptErr == nil { |
|
if testToken == nil { |
|
t.Errorf("except testToken not nil, but got nil") |
|
} |
|
} |
|
} |
|
|
|
func TestClientWithClaims(t *testing.T) { |
|
testKey := "testKey" |
|
mapClaims := jwt.MapClaims{} |
|
mapClaims["name"] = "xiaoli" |
|
mapClaimsFunc := func() jwt.Claims { return mapClaims } |
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
|
token, err := claims.SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
tProvider := func(*jwt.Token) (interface{}, error) { |
|
return []byte(testKey), nil |
|
} |
|
test := struct { |
|
name string |
|
expectError error |
|
tokenProvider jwt.Keyfunc |
|
}{ |
|
name: "normal", |
|
expectError: nil, |
|
tokenProvider: tProvider, |
|
} |
|
|
|
t.Run(test.name, func(t *testing.T) { |
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
return "reply", nil |
|
} |
|
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next) |
|
header := &headerCarrier{} |
|
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
|
if !errors.Is(test.expectError, err2) { |
|
t.Errorf("except error %v, but got %v", test.expectError, err2) |
|
} |
|
if err2 == nil { |
|
if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) { |
|
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) |
|
} |
|
} |
|
}) |
|
} |
|
|
|
func TestClientWithHeader(t *testing.T) { |
|
testKey := "testKey" |
|
mapClaims := jwt.MapClaims{} |
|
mapClaims["name"] = "xiaoli" |
|
mapClaimsFunc := func() jwt.Claims { return mapClaims } |
|
tokenHeader := map[string]interface{}{ |
|
"test": "test", |
|
} |
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
|
for k, v := range tokenHeader { |
|
claims.Header[k] = v |
|
} |
|
token, err := claims.SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
tProvider := func(*jwt.Token) (interface{}, error) { |
|
return []byte(testKey), nil |
|
} |
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
return "reply", nil |
|
} |
|
handler := Client(tProvider, WithClaims(mapClaimsFunc), WithTokenHeader(tokenHeader))(next) |
|
header := &headerCarrier{} |
|
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
|
if err2 != nil { |
|
t.Errorf("except error nil, but got %v", err2) |
|
} |
|
if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) { |
|
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) |
|
} |
|
} |
|
|
|
func TestClientMissKey(t *testing.T) { |
|
testKey := "testKey" |
|
mapClaims := jwt.MapClaims{} |
|
mapClaims["name"] = "xiaoli" |
|
mapClaimsFunc := func() jwt.Claims { return mapClaims } |
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
|
token, err := claims.SignedString([]byte(testKey)) |
|
if err != nil { |
|
panic(err) |
|
} |
|
tProvider := func(*jwt.Token) (interface{}, error) { |
|
return nil, errors.New("some error") |
|
} |
|
test := struct { |
|
name string |
|
expectError error |
|
tokenProvider jwt.Keyfunc |
|
}{ |
|
name: "normal", |
|
expectError: ErrGetKey, |
|
tokenProvider: tProvider, |
|
} |
|
|
|
t.Run(test.name, func(t *testing.T) { |
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
return "reply", nil |
|
} |
|
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next) |
|
header := &headerCarrier{} |
|
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
|
if !errors.Is(test.expectError, err2) { |
|
t.Errorf("except error %v, but got %v", test.expectError, err2) |
|
} |
|
if err2 == nil { |
|
if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) { |
|
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) |
|
} |
|
} |
|
}) |
|
}
|
|
|