155 lines
3.9 KiB
Go
155 lines
3.9 KiB
Go
package cauth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type Handlers struct {
|
|
oauth2Config *oauth2.Config
|
|
session SessionStorer
|
|
verifier *oidc.IDTokenVerifier
|
|
}
|
|
|
|
func NewHandler(oauth2Config *oauth2.Config, session SessionStorer, verifier *oidc.IDTokenVerifier) (*Handlers, error) {
|
|
return &Handlers{
|
|
oauth2Config: oauth2Config,
|
|
session: session,
|
|
verifier: verifier,
|
|
}, nil
|
|
}
|
|
|
|
// UserClaims is the subset of ID token claims decoded for a signed-in user.
// Each field is populated from the JSON claim named in its tag.
type UserClaims struct {
	// Email is the "email" claim.
	Email string `json:"email"`
	// Verified is the "email_verified" claim.
	Verified bool `json:"email_verified"`
	// Name is the "given_name" claim (not the full "name" claim).
	Name string `json:"given_name"`
	// Username is the "cognito:username" claim.
	Username string `json:"cognito:username"`
	// Picture is the "picture" claim.
	Picture string `json:"picture"`
	// Sub is the "sub" (subject) claim.
	Sub string `json:"sub"`
	// Groups is the "cognito:groups" claim; nil when the claim is absent.
	Groups []string `json:"cognito:groups"`
}
|
|
|
|
// generateState produces a random value for the OAuth2 "state" parameter:
// 16 bytes read from crypto/rand, encoded with the padded URL-safe base64
// alphabet. The error from the random source is returned unchanged.
func generateState() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
func (h *Handlers) SignIn(w http.ResponseWriter, r *http.Request) {
|
|
state, err := generateState()
|
|
if err != nil {
|
|
log.Println("Failed to generate state")
|
|
http.Error(w, "Something went wrong", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
session, err := h.session.Get(r)
|
|
if err != nil {
|
|
http.Error(w, "Failed to get session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
session.Values["state"] = state
|
|
err = session.Save(r, w)
|
|
if err != nil {
|
|
http.Error(w, "Failed to save session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, h.oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusFound)
|
|
}
|
|
|
|
// CallbackHandler handles the OAuth2 callback from Cognito
|
|
func (h *Handlers) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|
ctx := context.Background()
|
|
|
|
session, err := h.session.Get(r)
|
|
if err != nil {
|
|
http.Error(w, "Failed to get session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
state, ok := session.Values["state"].(string)
|
|
if !ok || state != r.URL.Query().Get("state") {
|
|
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
http.Error(w, "Code not found", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
|
|
if err != nil {
|
|
log.Printf("Failed to exchange token: %v", err)
|
|
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
|
if !ok {
|
|
http.Error(w, "No id_token field in oauth2 token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
idToken, err := h.verifier.Verify(ctx, rawIDToken)
|
|
if err != nil {
|
|
log.Printf("Failed to verify ID Token: %v", err)
|
|
http.Error(w, "Failed to verify ID Token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var claims UserClaims
|
|
|
|
if err := idToken.Claims(&claims); err != nil {
|
|
log.Printf("Failed to parse claims: %v", err)
|
|
http.Error(w, "Failed to parse claims", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
session.Values["access_token"] = oauth2Token.AccessToken
|
|
session.Values["user_info"] = claims
|
|
|
|
fmt.Println(claims)
|
|
session.Options.MaxAge = int(oauth2Token.Expiry.Sub(time.Now()).Seconds())
|
|
err = session.Save(r, w)
|
|
if err != nil {
|
|
log.Printf("Failed to save session: %v", err)
|
|
http.Error(w, "Failed to save session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|
|
|
|
// LogoutHandler clears the session and logs the user out
|
|
func (h *Handlers) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|
session, err := h.session.Get(r)
|
|
if err != nil {
|
|
http.Error(w, "Failed to get session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
session.Options.MaxAge = -1
|
|
err = session.Save(r, w)
|
|
if err != nil {
|
|
log.Printf("Failed to clear session: %v", err)
|
|
http.Error(w, "Failed to clear session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
}
|