You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
180 lines
5.3 KiB
180 lines
5.3 KiB
package jwt |
|
|
|
import ( |
|
"context" |
|
"fmt" |
|
"strings" |
|
|
|
"github.com/golang-jwt/jwt/v4" |
|
|
|
"github.com/go-kratos/kratos/v2/errors" |
|
"github.com/go-kratos/kratos/v2/middleware" |
|
"github.com/go-kratos/kratos/v2/transport" |
|
) |
|
|
|
// authKey is the unexported context key under which NewContext stores token
// claims. Being a distinct unexported type, it cannot collide with keys
// defined by other packages.
type authKey struct{}
|
|
|
const (
	// bearerWord is the authentication scheme Server expects at the start of
	// the Authorization header (compared case-insensitively).
	bearerWord string = "Bearer"

	// bearerFormat is the fmt template Client uses to build the outgoing
	// Authorization value; %s receives the signed token.
	bearerFormat string = "Bearer %s"

	// authorizationKey is the name of the HTTP/gRPC request header that
	// carries the JWT.
	authorizationKey string = "Authorization"

	// reason is the error reason used by the Unauthorized errors this file
	// defines and returns.
	reason string = "UNAUTHORIZED"
)
|
|
|
var ( |
|
ErrMissingJwtToken = errors.Unauthorized(reason, "JWT token is missing") |
|
ErrMissingKeyFunc = errors.Unauthorized(reason, "keyFunc is missing") |
|
ErrTokenInvalid = errors.Unauthorized(reason, "Token is invalid") |
|
ErrTokenExpired = errors.Unauthorized(reason, "JWT token has expired") |
|
ErrTokenParseFail = errors.Unauthorized(reason, "Fail to parse JWT token ") |
|
ErrUnSupportSigningMethod = errors.Unauthorized(reason, "Wrong signing method") |
|
ErrWrongContext = errors.Unauthorized(reason, "Wrong context for middleware") |
|
ErrNeedTokenProvider = errors.Unauthorized(reason, "Token provider is missing") |
|
ErrSignToken = errors.Unauthorized(reason, "Can not sign token.Is the key correct?") |
|
ErrGetKey = errors.Unauthorized(reason, "Can not get key while signing token") |
|
) |
|
|
|
// Option is jwt option. |
|
type Option func(*options) |
|
|
|
// Parser is a jwt parser |
|
type options struct { |
|
signingMethod jwt.SigningMethod |
|
claims func() jwt.Claims |
|
tokenHeader map[string]interface{} |
|
} |
|
|
|
// WithSigningMethod with signing method option. |
|
func WithSigningMethod(method jwt.SigningMethod) Option { |
|
return func(o *options) { |
|
o.signingMethod = method |
|
} |
|
} |
|
|
|
// WithClaims with customer claim |
|
// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems |
|
// If you use it in Client, f only needs to return a single object to provide performance |
|
func WithClaims(f func() jwt.Claims) Option { |
|
return func(o *options) { |
|
o.claims = f |
|
} |
|
} |
|
|
|
// WithTokenHeader withe customer tokenHeader for client side |
|
func WithTokenHeader(header map[string]interface{}) Option { |
|
return func(o *options) { |
|
o.tokenHeader = header |
|
} |
|
} |
|
|
|
// Server is a server auth middleware. Check the token and extract the info from token. |
|
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { |
|
o := &options{ |
|
signingMethod: jwt.SigningMethodHS256, |
|
} |
|
for _, opt := range opts { |
|
opt(o) |
|
} |
|
return func(handler middleware.Handler) middleware.Handler { |
|
return func(ctx context.Context, req interface{}) (interface{}, error) { |
|
if header, ok := transport.FromServerContext(ctx); ok { |
|
if keyFunc == nil { |
|
return nil, ErrMissingKeyFunc |
|
} |
|
auths := strings.SplitN(header.RequestHeader().Get(authorizationKey), " ", 2) |
|
if len(auths) != 2 || !strings.EqualFold(auths[0], bearerWord) { |
|
return nil, ErrMissingJwtToken |
|
} |
|
jwtToken := auths[1] |
|
var ( |
|
tokenInfo *jwt.Token |
|
err error |
|
) |
|
if o.claims != nil { |
|
tokenInfo, err = jwt.ParseWithClaims(jwtToken, o.claims(), keyFunc) |
|
} else { |
|
tokenInfo, err = jwt.Parse(jwtToken, keyFunc) |
|
} |
|
if err != nil { |
|
ve, ok := err.(*jwt.ValidationError) |
|
if !ok { |
|
return nil, errors.Unauthorized(reason, err.Error()) |
|
} |
|
if ve.Errors&jwt.ValidationErrorMalformed != 0 { |
|
return nil, ErrTokenInvalid |
|
} |
|
if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { |
|
return nil, ErrTokenExpired |
|
} |
|
return nil, ErrTokenParseFail |
|
} |
|
if !tokenInfo.Valid { |
|
return nil, ErrTokenInvalid |
|
} |
|
if tokenInfo.Method != o.signingMethod { |
|
return nil, ErrUnSupportSigningMethod |
|
} |
|
ctx = NewContext(ctx, tokenInfo.Claims) |
|
return handler(ctx, req) |
|
} |
|
return nil, ErrWrongContext |
|
} |
|
} |
|
} |
|
|
|
// Client is a client jwt middleware. |
|
func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware { |
|
claims := jwt.RegisteredClaims{} |
|
o := &options{ |
|
signingMethod: jwt.SigningMethodHS256, |
|
claims: func() jwt.Claims { return claims }, |
|
} |
|
for _, opt := range opts { |
|
opt(o) |
|
} |
|
return func(handler middleware.Handler) middleware.Handler { |
|
return func(ctx context.Context, req interface{}) (interface{}, error) { |
|
if keyProvider == nil { |
|
return nil, ErrNeedTokenProvider |
|
} |
|
token := jwt.NewWithClaims(o.signingMethod, o.claims()) |
|
if o.tokenHeader != nil { |
|
for k, v := range o.tokenHeader { |
|
token.Header[k] = v |
|
} |
|
} |
|
key, err := keyProvider(token) |
|
if err != nil { |
|
return nil, ErrGetKey |
|
} |
|
tokenStr, err := token.SignedString(key) |
|
if err != nil { |
|
return nil, ErrSignToken |
|
} |
|
if clientContext, ok := transport.FromClientContext(ctx); ok { |
|
clientContext.RequestHeader().Set(authorizationKey, fmt.Sprintf(bearerFormat, tokenStr)) |
|
return handler(ctx, req) |
|
} |
|
return nil, ErrWrongContext |
|
} |
|
} |
|
} |
|
|
|
// NewContext put auth info into context |
|
func NewContext(ctx context.Context, info jwt.Claims) context.Context { |
|
return context.WithValue(ctx, authKey{}, info) |
|
} |
|
|
|
// FromContext extract auth info from context |
|
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) { |
|
token, ok = ctx.Value(authKey{}).(jwt.Claims) |
|
return |
|
}
|
|
|