package jws

import (
	"errors"
	"github.com/dgrijalva/jwt-go"
	"github.com/gin-gonic/gin"
	"github.com/spf13/viper"
	"strings"
	"time"
)

// JWT holds the HMAC secret used to sign tokens (CreateToken) and to verify
// them (ParseToken, RefreshToken). Construct it with NewJWT.
type JWT struct {
	// SigningKey is the raw HMAC key: it is used for HS256 signing and is
	// returned as the verification key for HMAC-signed tokens.
	SigningKey []byte
}

// Sentinel errors returned by ParseToken; RefreshToken also returns
// TokenInvalid. TokenMalformed and TokenInvalid carry identical message text,
// so distinguish them by identity (== or errors.Is), never by Error().
var (
	TokenExpired     = errors.New("授权已过期") // the token's exp claim has passed
	TokenNotValidYet = errors.New("授权未生效") // the token's nbf claim is in the future
	TokenMalformed   = errors.New("无权限访问") // jwt-go reported the token as malformed
	TokenInvalid     = errors.New("无权限访问") // any other verification failure
)

// Package settings. They are variables rather than constants, so they can be
// reassigned at runtime.
var (
	// TokenHeaderName is the HTTP request header GetToken reads the token from.
	TokenHeaderName = "Authorization"

	// Claims is not referenced in this file; presumably it is the gin.Context
	// key under which parsed claims are stored — TODO confirm with callers.
	Claims = "claims"

	// DefaultSignKey is NewJWT's last-resort signing key.
	// NOTE(review): it is a fixed value visible in source, so anyone with this
	// code can forge tokens signed with it; make sure jws.sign_key is set.
	DefaultSignKey = "defaultSignKey"

	// TokenTimeout is the token lifetime, in seconds, that CreateToken applies
	// to the exp claim (30 days).
	TokenTimeout = 60 * 60 * 24 * 30
)

// TokenClaims is the JWT payload: an application-defined int64 ID plus the
// registered claims embedded from jwt.StandardClaims (CreateToken sets exp).
// Additional application fields can be added here as needed.
type TokenClaims struct {
	// ID is serialized as the "id" claim.
	// NOTE(review): presumably a user/account ID — confirm with callers.
	ID    int64  `json:"id"`
	// Token is excluded from JSON (tag "-"), so it is never written to or read
	// from the payload. Nothing in this file sets it; presumably callers keep
	// the raw token string here — TODO confirm.
	Token string `json:"-"`
	jwt.StandardClaims
}

// NewJWT creates a JWT signer/verifier. The signing key is taken from the
// first non-blank source, in this order of precedence:
//  1. the viper setting "jws.sign_key";
//  2. args[0], if given (further arguments are ignored);
//  3. DefaultSignKey.
//
// Surrounding whitespace is trimmed from whichever value is chosen.
//
// NOTE(review): DefaultSignKey is a fixed, source-visible value; any
// deployment that reaches the fallback issues forgeable tokens.
func NewJWT(args ...string) *JWT {
	signKey := strings.TrimSpace(viper.GetString("jws.sign_key"))
	if signKey == "" && len(args) > 0 {
		signKey = strings.TrimSpace(args[0])
	}
	// A blank explicit argument is treated like a blank config value: fall
	// back to the default instead of signing with an empty HMAC key.
	if signKey == "" {
		signKey = DefaultSignKey
	}
	return &JWT{SigningKey: []byte(signKey)}
}

// GetToken returns the token carried in the request header named by
// TokenHeaderName, with surrounding whitespace removed. An optional "Bearer "
// authentication scheme (matched case-insensitively, RFC 6750) is stripped,
// so both "<jwt>" and "Bearer <jwt>" yield "<jwt>". Values without that
// prefix are returned unchanged. It returns "" when ctxt or its request is nil
// or the header is absent.
func GetToken(ctxt *gin.Context) string {
	if ctxt == nil || ctxt.Request == nil {
		return ""
	}
	// http.Header.Get canonicalizes the key and is safe on a nil/empty map,
	// so a single lookup matches any capitalization of the header name.
	token := strings.TrimSpace(ctxt.Request.Header.Get(TokenHeaderName))

	const bearer = "bearer "
	if len(token) > len(bearer) && strings.EqualFold(token[:len(bearer)], bearer) {
		token = strings.TrimSpace(token[len(bearer):])
	}
	return token
}

// CreateToken signs claims with HS256 using j.SigningKey and returns the
// compact JWT string. The exp claim is always overwritten with "now +
// TokenTimeout seconds", and the same value is mirrored into the JOSE header
// under the "exp" key. claims is passed by value, so the caller's copy is not
// modified.
func (j *JWT) CreateToken(claims TokenClaims) (string, error) {
	// Point jwt-go's package-level clock back at the wall clock in case other
	// code replaced it.
	// NOTE(review): this writes shared global state on every call.
	jwt.TimeFunc = time.Now

	lifetime := time.Duration(TokenTimeout) * time.Second
	expiresAt := time.Now().Add(lifetime).Unix()
	claims.StandardClaims.ExpiresAt = expiresAt

	tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tok.Header["exp"] = expiresAt
	return tok.SignedString(j.SigningKey)
}

// ParseToken verifies tokenString and returns its claims.
//
// A token is accepted only if it uses an HMAC signing method, its signature
// verifies under j.SigningKey, and it passes jwt-go's standard time-based
// claim checks. Failures map to this package's sentinel errors:
//   - TokenMalformed:   jwt-go reported the string as malformed;
//   - TokenInvalid:     bad signature, non-HMAC algorithm, or any other failure;
//   - TokenExpired:     the signature is valid but the exp claim has passed;
//   - TokenNotValidYet: the signature is valid but the nbf claim is in the future.
func (j *JWT) ParseToken(tokenString string) (*TokenClaims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(t *jwt.Token) (interface{}, error) {
		// Only HMAC algorithms are accepted: the key returned below is a
		// shared secret, not a public key.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, TokenInvalid
		}
		return j.SigningKey, nil
	})
	if err != nil {
		ve, ok := err.(*jwt.ValidationError)
		if !ok {
			// Any other error type is treated as invalid rather than falling
			// through to inspect a token that may be nil.
			return nil, TokenInvalid
		}
		switch {
		case ve.Errors&jwt.ValidationErrorMalformed != 0:
			return nil, TokenMalformed
		case ve.Errors&(jwt.ValidationErrorUnverifiable|jwt.ValidationErrorSignatureInvalid) != 0:
			// jwt-go ORs claim errors and signature errors into one bit set.
			// An expired token with a bad signature is forged, not expired,
			// so signature failures take precedence over time-based claims.
			return nil, TokenInvalid
		case ve.Errors&jwt.ValidationErrorExpired != 0:
			return nil, TokenExpired
		case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
			return nil, TokenNotValidYet
		default:
			return nil, TokenInvalid
		}
	}
	if token == nil || !token.Valid {
		return nil, TokenInvalid
	}
	claims, ok := token.Claims.(*TokenClaims)
	if !ok {
		return nil, TokenInvalid
	}
	// jwt-go checks exp against the package-level jwt.TimeFunc, which any code
	// in the process can reassign. Re-check against the real clock so that a
	// replaced TimeFunc cannot let an expired token through.
	if !claims.VerifyExpiresAt(time.Now().Unix(), false) {
		return nil, TokenExpired
	}
	return claims, nil
}

// RefreshToken re-issues a token. It checks that tokenString is a well-formed
// HMAC token whose signature verifies under j.SigningKey. Its time-based
// claims (exp, nbf, iat) are deliberately not validated. It then returns a
// new token for the same claims with a fresh expiry from CreateToken.
//
// Parse/verification errors are returned unchanged (typically a
// *jwt.ValidationError). A token that parses but is not valid yields
// TokenInvalid.
//
// NOTE(review): because expiry is ignored, any correctly signed token can be
// refreshed indefinitely; consider bounding this if that is not intended.
func (j *JWT) RefreshToken(tokenString string) (string, error) {
	// Skip claim validation on a private Parser instead of reassigning the
	// package-level jwt.TimeFunc. That global is read by every concurrent
	// ParseToken call, so overriding it races with them and, if not restored,
	// disables expiry checks for the whole process.
	parser := &jwt.Parser{SkipClaimsValidation: true}
	token, err := parser.ParseWithClaims(tokenString, &TokenClaims{}, func(t *jwt.Token) (interface{}, error) {
		// Only HMAC algorithms are accepted: the key is a shared secret.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, TokenInvalid
		}
		return j.SigningKey, nil
	})
	if err != nil {
		return "", err
	}
	if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid {
		return j.CreateToken(*claims)
	}
	return "", TokenInvalid
}