// Package cauth provides HTTP handlers for an OAuth2 / OpenID Connect
// authorization-code login flow: sign-in redirect, provider callback, and
// logout.
package cauth

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"golang.org/x/oauth2"
)

// Handlers bundles the dependencies shared by the authentication endpoints.
type Handlers struct {
	oauth2Config *oauth2.Config        // builds the auth URL and exchanges codes
	session      SessionStorer         // loads the per-request session
	verifier     *oidc.IDTokenVerifier // validates ID tokens returned by the provider
}

// NewHandler assembles a Handlers from its dependencies. The arguments are
// stored as given without validation; the returned error is currently always
// nil.
func NewHandler(oauth2Config *oauth2.Config, session SessionStorer, verifier *oidc.IDTokenVerifier) (*Handlers, error) {
	return &Handlers{
		oauth2Config: oauth2Config,
		session:      session,
		verifier:     verifier,
	}, nil
}

// UserClaims is the subset of ID-token claims that CallbackHandler decodes and
// stores in the session under the "user_info" key.
type UserClaims struct {
	Email    string   `json:"email"`
	Verified bool     `json:"email_verified"`
	Name     string   `json:"given_name"`
	Username string   `json:"cognito:username"`
	Picture  string   `json:"picture"`
	Sub      string   `json:"sub"`
	Groups   []string `json:"cognito:groups"`
}

// callbackFailure describes why the provider round trip in CallbackHandler
// failed.
type callbackFailure struct {
	public string // response body sent to the client
	cause  error  // underlying error, written to the log only; may be nil
}

// generateState returns 16 bytes from crypto/rand encoded as URL-safe base64,
// for use as the OAuth2 "state" parameter.
func generateState() (string, error) {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("reading random bytes: %w", err)
	}
	return base64.URLEncoding.EncodeToString(b), nil
}

// SignIn starts the login flow. It stores a fresh random state value in the
// session under "state" and redirects the browser to the provider's
// authorization URL, requesting offline access.
func (h *Handlers) SignIn(w http.ResponseWriter, r *http.Request) {
	state, err := generateState()
	if err != nil {
		log.Printf("Failed to generate state: %v", err)
		http.Error(w, "Something went wrong", http.StatusInternalServerError)
		return
	}

	sess, err := h.session.Get(r)
	if err != nil {
		http.Error(w, "Failed to get session", http.StatusInternalServerError)
		return
	}

	// The callback compares this value against the "state" query parameter.
	sess.Values["state"] = state
	if err := sess.Save(r, w); err != nil {
		http.Error(w, "Failed to save session", http.StatusInternalServerError)
		return
	}

	http.Redirect(w, r, h.oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusFound)
}

// CallbackHandler handles the OAuth2 callback from Cognito.
//
// It checks the "state" query parameter against the value saved by SignIn,
// exchanges the "code" parameter for tokens, verifies the ID token, and stores
// the access token and decoded claims in the session before redirecting to "/".
// A state mismatch or missing code yields 400; every other failure yields 500.
func (h *Handlers) CallbackHandler(w http.ResponseWriter, r *http.Request) {
	sess, err := h.session.Get(r)
	if err != nil {
		http.Error(w, "Failed to get session", http.StatusInternalServerError)
		return
	}

	query := r.URL.Query()
	state, ok := sess.Values["state"].(string)
	if !ok || state != query.Get("state") {
		http.Error(w, "Invalid state parameter", http.StatusBadRequest)
		return
	}
	// The state is single-use: drop it so that a successful callback cannot
	// be replayed against the same session.
	delete(sess.Values, "state")

	code := query.Get("code")
	if code == "" {
		http.Error(w, "Code not found", http.StatusBadRequest)
		return
	}

	// Use the request context so the provider calls are abandoned when the
	// client goes away.
	token, claims, failure := h.resolveIdentity(r.Context(), code)
	if failure != nil {
		if failure.cause != nil {
			log.Printf("%s: %v", failure.public, failure.cause)
		}
		http.Error(w, failure.public, http.StatusInternalServerError)
		return
	}

	sess.Values["access_token"] = token.AccessToken
	// NOTE(review): if the session store serializes values with encoding/gob,
	// UserClaims must be registered with gob for Save to succeed - confirm.
	sess.Values["user_info"] = claims

	// A zero Expiry means the provider sent no expiry. Computing a lifetime
	// from it would produce a large negative MaxAge, so keep the session's
	// existing MaxAge in that case.
	if !token.Expiry.IsZero() {
		sess.Options.MaxAge = int(time.Until(token.Expiry).Seconds())
	}

	if err := sess.Save(r, w); err != nil {
		log.Printf("Failed to save session: %v", err)
		http.Error(w, "Failed to save session", http.StatusInternalServerError)
		return
	}

	http.Redirect(w, r, "/", http.StatusFound)
}

// resolveIdentity exchanges code for an OAuth2 token, verifies the ID token it
// carries, and decodes that token's claims. On failure the token is nil and the
// returned callbackFailure names the step that failed.
func (h *Handlers) resolveIdentity(ctx context.Context, code string) (*oauth2.Token, UserClaims, *callbackFailure) {
	token, err := h.oauth2Config.Exchange(ctx, code)
	if err != nil {
		return nil, UserClaims{}, &callbackFailure{public: "Failed to exchange token", cause: err}
	}

	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok {
		return nil, UserClaims{}, &callbackFailure{public: "No id_token field in oauth2 token"}
	}

	idToken, err := h.verifier.Verify(ctx, rawIDToken)
	if err != nil {
		return nil, UserClaims{}, &callbackFailure{public: "Failed to verify ID Token", cause: err}
	}

	var claims UserClaims
	if err := idToken.Claims(&claims); err != nil {
		return nil, UserClaims{}, &callbackFailure{public: "Failed to parse claims", cause: err}
	}

	return token, claims, nil
}

// LogoutHandler logs the user out by saving the session with a negative
// MaxAge, then redirects to "/" with 303 See Other.
func (h *Handlers) LogoutHandler(w http.ResponseWriter, r *http.Request) {
	sess, err := h.session.Get(r)
	if err != nil {
		http.Error(w, "Failed to get session", http.StatusInternalServerError)
		return
	}

	sess.Options.MaxAge = -1
	if err := sess.Save(r, w); err != nil {
		log.Printf("Failed to clear session: %v", err)
		http.Error(w, "Failed to clear session", http.StatusInternalServerError)
		return
	}

	http.Redirect(w, r, "/", http.StatusSeeOther)
}