// jwt.go (3.1 KB)

  1. package jws
  2. import (
  3. "errors"
  4. "github.com/dgrijalva/jwt-go"
  5. "github.com/gin-gonic/gin"
  6. "github.com/spf13/viper"
  7. "strings"
  8. "time"
  9. )
  // JWT signs and verifies tokens with a single shared secret.
  type JWT struct {
  	// SigningKey is the secret handed to the signing method in CreateToken
  	// and returned from the key function in ParseToken and RefreshToken.
  	SigningKey []byte
  }
  // Sentinel errors reported by ParseToken (and by the key functions used
  // while parsing). TokenMalformed and TokenInvalid are distinct values that
  // deliberately carry the same user-facing message.
  var (
  	// TokenExpired: the token's expiry has passed ("authorization has expired").
  	TokenExpired = errors.New("授权已过期")
  	// TokenNotValidYet: the token is not active yet ("authorization not yet in effect").
  	TokenNotValidYet = errors.New("授权未生效")
  	// TokenMalformed: the string is not a well-formed token ("no permission to access").
  	TokenMalformed = errors.New("无权限访问")
  	// TokenInvalid: any other verification failure ("no permission to access").
  	TokenInvalid = errors.New("无权限访问")
  )

  // Package-level settings. They are variables rather than constants, so they
  // can be reassigned at runtime.
  var (
  	// TokenHeaderName is the request header GetToken reads the token from.
  	TokenHeaderName = "Authorization"
  	// Claims is not referenced in this file.
  	// NOTE(review): presumably the gin context key under which parsed claims
  	// are stored by a middleware — confirm against callers.
  	Claims = "claims"
  	// DefaultSignKey is the signing secret NewJWT falls back to when neither
  	// configuration nor arguments provide one.
  	DefaultSignKey = "defaultSignKey"
  	// TokenTimeout is the token lifetime in seconds (30 days); CreateToken
  	// adds it to the current time to compute the expiry.
  	TokenTimeout = 60 * 60 * 24 * 30
  )
  25. // 载荷,可以加一些自己需要的信息
  26. type TokenClaims struct {
  27. ID int64 `json:"id"`
  28. Token string `json:"-"`
  29. jwt.StandardClaims
  30. }
  31. // 新建一个jwt实例
  32. func NewJWT(args ...string) *JWT {
  33. var signKey string
  34. if key := viper.GetString("jws.sign_key"); len(strings.TrimSpace(key)) > 0 {
  35. signKey = strings.TrimSpace(key)
  36. } else if len(args) > 0 {
  37. signKey = strings.TrimSpace(args[0])
  38. } else {
  39. signKey = DefaultSignKey
  40. }
  41. return &JWT{SigningKey: []byte(signKey)}
  42. }
  43. func GetToken(ctxt *gin.Context) string {
  44. if nil == ctxt || nil == ctxt.Request ||
  45. nil == ctxt.Request.Header || len(ctxt.Request.Header) == 0 {
  46. return ""
  47. }
  48. token := ctxt.Request.Header.Get(TokenHeaderName)
  49. if token == "" {
  50. token = ctxt.Request.Header.Get(strings.ToLower(TokenHeaderName))
  51. }
  52. return strings.TrimSpace(token)
  53. }
  54. // CreateToken 生成一个token
  55. func (j *JWT) CreateToken(claims TokenClaims) (string, error) {
  56. jwt.TimeFunc = time.Now
  57. claims.StandardClaims.ExpiresAt = time.Now().Add(time.Duration(TokenTimeout) * time.Second).Unix()
  58. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  59. token.Header["exp"] = claims.StandardClaims.ExpiresAt
  60. return token.SignedString(j.SigningKey)
  61. }
  62. // 解析Tokne
  63. func (j *JWT) ParseToken(tokenString string) (*TokenClaims, error) {
  64. token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
  65. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  66. return nil, TokenInvalid
  67. }
  68. return j.SigningKey, nil
  69. })
  70. if err != nil {
  71. if ve, ok := err.(*jwt.ValidationError); ok {
  72. if ve.Errors&jwt.ValidationErrorMalformed != 0 {
  73. return nil, TokenMalformed
  74. } else if ve.Errors&jwt.ValidationErrorExpired != 0 {
  75. // Token is expired
  76. return nil, TokenExpired
  77. } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
  78. return nil, TokenNotValidYet
  79. } else {
  80. return nil, TokenInvalid
  81. }
  82. }
  83. }
  84. if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid {
  85. return claims, nil
  86. }
  87. return nil, TokenInvalid
  88. }
  89. // 更新token
  90. func (j *JWT) RefreshToken(tokenString string) (string, error) {
  91. jwt.TimeFunc = func() time.Time {
  92. return time.Unix(0, 0)
  93. }
  94. token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
  95. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  96. return nil, TokenInvalid
  97. }
  98. return j.SigningKey, nil
  99. })
  100. if err != nil {
  101. return "", err
  102. }
  103. if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid {
  104. return j.CreateToken(*claims)
  105. }
  106. return "", TokenInvalid
  107. }