255 lines
6.1 KiB
Go
255 lines
6.1 KiB
Go
package redis
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dgrijalva/jwt-go.git"
|
|
"github.com/google/uuid"
|
|
"src.dualinventive.com/go/authentication-service/internal/domain"
|
|
"src.dualinventive.com/go/authentication-service/internal/storage"
|
|
"src.dualinventive.com/go/lib/dilog"
|
|
"src.dualinventive.com/go/lib/kv"
|
|
)
|
|
|
|
// tokenExpirationDuration is the lifetime given to a newly generated token:
// generateSecret sets the expiry claim to now plus this duration.
const tokenExpirationDuration = 24 * time.Hour
|
|
|
|
//TokenRepository contains functions to manage tokens
|
|
type TokenRepository struct {
|
|
kvstore kv.KV
|
|
privateKey *rsa.PrivateKey
|
|
publicKey *rsa.PublicKey
|
|
logger dilog.Logger
|
|
}
|
|
|
|
//NewTokenRepository returns a new instance of a redis (or ram) backed repository
|
|
func NewTokenRepository(
|
|
host, port, privateKeyFile, publicKeyFile string,
|
|
logger dilog.Logger) (storage.TokenRepository, error) {
|
|
repo := &TokenRepository{}
|
|
repo.logger = logger
|
|
|
|
privateKey, err := configurePrivateKey(privateKeyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
publicKey, err := configurePublicKey(publicKeyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
repo.privateKey = privateKey
|
|
repo.publicKey = publicKey
|
|
|
|
kvstore, err := kv.New(kv.TypeRedis, fmt.Sprintf("%s:%s", host, port))
|
|
if err != nil {
|
|
return nil, NewErr(err)
|
|
}
|
|
|
|
repo.kvstore = kvstore
|
|
return repo, nil
|
|
}
|
|
|
|
//CreateToken creates and stores token
|
|
func (tr *TokenRepository) CreateToken(
|
|
userName string,
|
|
companyCode string,
|
|
userAgent string,
|
|
rights ...string) (*domain.Token, error) {
|
|
secret, err := tr.generateSecret(userName, userAgent, uuid.New().String(), rights)
|
|
if err != nil {
|
|
return nil, NewErr(err)
|
|
}
|
|
|
|
err = tr.kvstore.Set("token:"+secret, userName+"::"+companyCode)
|
|
if err != nil {
|
|
return nil, NewErr(err)
|
|
}
|
|
return &domain.Token{Secret: secret}, nil
|
|
}
|
|
|
|
//GetTokens returns tokens, if exists and is not expired, for given user name
|
|
func (tr *TokenRepository) GetTokens(userName string, companyCode string) ([]*domain.Token, error) {
|
|
var keys []*domain.Token
|
|
iter, err := tr.kvstore.Keys("token:*")
|
|
if err != nil {
|
|
return nil, NewErr(err)
|
|
}
|
|
for _, key := range iter {
|
|
val, err := tr.kvstore.Get(key)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if val == userName+"::"+companyCode {
|
|
keys = tr.processKey(key, keys)
|
|
}
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
//GetUserNameByToken returns a username that is bound to the given token.
|
|
func (tr *TokenRepository) GetUserNameByToken(token *domain.Token) (string, string, error) {
|
|
value, err := tr.kvstore.Get("token:" + token.Secret)
|
|
if err != nil {
|
|
return "", "", NewErr(err)
|
|
}
|
|
|
|
split := strings.Split(value, "::")
|
|
return split[0], split[1], nil
|
|
}
|
|
|
|
//IsValid checks if token is stored and not expired
|
|
func (tr *TokenRepository) IsValid(token *domain.Token, rights ...string) (valid bool, authorized bool, err error) {
|
|
exists, err := tr.kvstore.Exists("token:" + token.Secret)
|
|
if err != nil {
|
|
return false, false, NewErr(err)
|
|
}
|
|
if !exists {
|
|
return false, false, storage.ErrTokenNotFound
|
|
}
|
|
valid, authorized = tr.validateToken(token, rights)
|
|
if !valid {
|
|
return false, false, storage.ErrTokenInvalid
|
|
}
|
|
return valid, authorized, nil
|
|
}
|
|
|
|
//DeleteToken removes token if exists
|
|
func (tr *TokenRepository) DeleteToken(token *domain.Token) error {
|
|
ok, err := tr.kvstore.Del("token:" + token.Secret)
|
|
if err != nil {
|
|
// TODO same shit as above
|
|
return NewErr(err)
|
|
}
|
|
if !ok {
|
|
return storage.ErrTokenNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
//GetTokenPayload returns encoded token data.
|
|
func (tr *TokenRepository) GetTokenPayload(token *domain.Token) (opaqueID string, userAgent string, err error) {
|
|
var claims Claims
|
|
_, err = jwt.ParseWithClaims(token.Secret, &claims, func(token *jwt.Token) (interface{}, error) {
|
|
return tr.publicKey, nil
|
|
})
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
return claims.OpaqueID, claims.UserAgent, nil
|
|
}
|
|
|
|
func configurePrivateKey(filepath string) (*rsa.PrivateKey, error) {
|
|
//jwt takes care of key parsing
|
|
f, err := ioutil.ReadFile(filepath) //nolint:gosec
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return jwt.ParseRSAPrivateKeyFromPEM(f)
|
|
}
|
|
|
|
func configurePublicKey(filepath string) (*rsa.PublicKey, error) {
|
|
//jwt takes care of key parsing
|
|
f, err := ioutil.ReadFile(filepath) //nolint:gosec
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return jwt.ParseRSAPublicKeyFromPEM(f)
|
|
}
|
|
|
|
func (tr *TokenRepository) generateSecret(
|
|
username string,
|
|
userAgent string,
|
|
opaqueID string,
|
|
rights []string) (string, error) {
|
|
expirationTime := time.Now().Add(tokenExpirationDuration)
|
|
|
|
claims := Claims{
|
|
User: username,
|
|
UserAgent: userAgent,
|
|
OpaqueID: opaqueID,
|
|
Rights: rights,
|
|
StandardClaims: jwt.StandardClaims{
|
|
ExpiresAt: expirationTime.Unix(),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
tokenString, err := token.SignedString(tr.privateKey)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return tokenString, nil
|
|
}
|
|
|
|
func (tr *TokenRepository) processKey(key string, keys []*domain.Token) []*domain.Token {
|
|
secret := string([]rune(key)[6:])
|
|
token := &domain.Token{Secret: secret}
|
|
valid, _ := tr.validateToken(token, []string{})
|
|
if !valid {
|
|
_, err := tr.kvstore.Del(key)
|
|
if err != nil {
|
|
tr.logger.Warning("floating key detected %s\n", key)
|
|
}
|
|
} else {
|
|
keys = append(keys, token)
|
|
}
|
|
return keys
|
|
}
|
|
|
|
func (tr *TokenRepository) validateToken(token *domain.Token, rights []string) (valid bool, authorized bool) {
|
|
var claims Claims
|
|
tkn, err := jwt.ParseWithClaims(token.Secret, &claims, func(token *jwt.Token) (interface{}, error) {
|
|
return tr.publicKey, nil
|
|
})
|
|
if err != nil {
|
|
return false, false
|
|
}
|
|
if time.Now().Unix() > claims.ExpiresAt {
|
|
return false, false
|
|
}
|
|
|
|
return tkn.Valid, claims.Rights.containsAll(rights...)
|
|
|
|
}
|
|
|
|
//Claims contains token data which will be encrypted and form the JWT token including headers and signing.
|
|
type Claims struct {
|
|
jwt.StandardClaims
|
|
User string `json:"user"`
|
|
UserAgent string `json:"userAgent"`
|
|
OpaqueID string `json:"opaqueId"`
|
|
Rights Rights `json:"rights"`
|
|
}
|
|
|
|
// Rights is a list of right-codes.
type Rights []string

// containsAll reports whether every requested right-code is present in the
// list. It returns true when no rights are requested.
func (rs Rights) containsAll(rights ...string) bool {
	for i := 0; i < len(rights); i++ {
		if rs.contains(rights[i]) {
			continue
		}
		return false
	}
	return true
}

// contains reports whether right occurs in the list, using exact string
// comparison.
func (rs Rights) contains(right string) bool {
	for i := range rs {
		if rs[i] == right {
			return true
		}
	}
	return false
}
|