139 lines
3.7 KiB
Go
139 lines
3.7 KiB
Go
package cognito_auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/service/cognitoidentityprovider"
|
|
"github.com/aws/aws-sdk-go-v2/service/cognitoidentityprovider/types"
|
|
)
|
|
|
|
type Client struct {
|
|
CognitoClient *cognitoidentityprovider.Client
|
|
poolId string
|
|
clientId string
|
|
}
|
|
|
|
func NewAuthClient(poolID string, clientId string) (AuthClient, error) {
|
|
sdkConfig, err := config.LoadDefaultConfig(context.TODO())
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return nil, err
|
|
}
|
|
cognitoClient := cognitoidentityprovider.NewFromConfig(sdkConfig)
|
|
|
|
return &Client{CognitoClient: cognitoClient, poolId: poolID, clientId: clientId}, nil
|
|
}
|
|
|
|
type AuthClient interface {
|
|
SignUp(req SignUpRequest) error
|
|
ConfirmSignUp(req ConfirmSignUpRequest) error
|
|
SignIn(ctx context.Context, username, password string) (*AuthenticationResult, error)
|
|
}
|
|
|
|
// SignUpRequest carries the data needed to register a new user. The struct
// tags mark all three string fields as required and Email as an e-mail
// address for the binding validator.
type SignUpRequest struct {
	Username string `json:"username" binding:"required"`
	Password string `json:"password" binding:"required"`
	Email    string `json:"email" binding:"required,email"`

	// Context is the caller's request context, if any. It has no JSON tag.
	Context context.Context
}
|
|
|
|
func (c *Client) SignUp(req SignUpRequest) error {
|
|
secretHash := getSecretHash(req.Username, *&c.clientId)
|
|
params := &cognitoidentityprovider.SignUpInput{
|
|
ClientId: aws.String(c.clientId),
|
|
Username: aws.String(req.Username),
|
|
Password: aws.String(req.Password),
|
|
UserAttributes: []types.AttributeType{
|
|
{Name: aws.String("email"), Value: aws.String(req.Email)},
|
|
},
|
|
SecretHash: &secretHash,
|
|
}
|
|
|
|
_, err := c.CognitoClient.SignUp(context.Background(), params)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ConfirmSignUpRequest carries the data needed to confirm a pending
// registration.
type ConfirmSignUpRequest struct {
	// Context is the context passed to the Cognito call.
	Context context.Context

	// Username identifies the account being confirmed.
	Username string

	// ConfirmationCode is the code submitted to Cognito for that account.
	ConfirmationCode string
}
|
|
|
|
func (c *Client) ConfirmSignUp(req ConfirmSignUpRequest) error {
|
|
input := &cognitoidentityprovider.ConfirmSignUpInput{
|
|
ClientId: aws.String(c.clientId),
|
|
Username: aws.String(req.Username),
|
|
ConfirmationCode: aws.String(req.ConfirmationCode),
|
|
}
|
|
|
|
_, err := c.CognitoClient.ConfirmSignUp(req.Context, input)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SignInRequest is a username/password pair with JSON field names
// "username" and "password".
type SignInRequest struct {
	Username string `json:"username"`
	Password string `json:"password"`
}
|
|
|
|
// SignIn authenticates a user and returns tokens
|
|
func (c *Client) SignIn(ctx context.Context, username, password string) (*AuthenticationResult, error) {
|
|
authParams := map[string]string{
|
|
"USERNAME": username,
|
|
"PASSWORD": password,
|
|
}
|
|
|
|
input := &cognitoidentityprovider.InitiateAuthInput{
|
|
AuthFlow: types.AuthFlowTypeUserPasswordAuth,
|
|
ClientId: aws.String(c.clientId),
|
|
AuthParameters: authParams,
|
|
}
|
|
|
|
output, err := c.CognitoClient.InitiateAuth(ctx, input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if output.AuthenticationResult == nil {
|
|
return nil, errors.New("authentication result is nil")
|
|
}
|
|
|
|
return &AuthenticationResult{
|
|
IDToken: aws.ToString(output.AuthenticationResult.IdToken),
|
|
AccessToken: aws.ToString(output.AuthenticationResult.AccessToken),
|
|
RefreshToken: aws.ToString(output.AuthenticationResult.RefreshToken),
|
|
}, nil
|
|
}
|
|
|
|
// AuthenticationResult bundles the three tokens handed back by a successful
// SignIn.
type AuthenticationResult struct {
	IDToken      string // ID token
	AccessToken  string // access token
	RefreshToken string // refresh token
}
|
|
|
|
// getSecretHash returns the base64-encoded HMAC-SHA256 of username followed
// by clientID. The HMAC key is read from the COGNITO_SECRET environment
// variable; when that variable is unset or empty, a built-in default is used.
func getSecretHash(username string, clientID string) string {
	key := os.Getenv("COGNITO_SECRET")
	if len(key) == 0 {
		// NOTE(review): this fallback embeds what looks like a real app client
		// secret in source control. It is kept only to preserve behavior;
		// rotate it and require COGNITO_SECRET instead.
		key = "1bb1r4fegke1hcn6rjo8d38io5np0qcce7juhjb8hu4kvu6qfr3s"
	}

	payload := fmt.Sprintf("%s%s", username, clientID)

	hasher := hmac.New(sha256.New, []byte(key))
	hasher.Write([]byte(payload))
	digest := hasher.Sum(nil)

	return base64.StdEncoding.EncodeToString(digest)
}
|