// File: cauth/middleware.go

package cauth
import (
"context"
"fmt"
"github.com/lestrrat-go/jwx/jwk"
"log"
"net/http"
)
// contextKey is a private string type for request-context keys, so keys
// declared in this package cannot collide with plain string keys set by
// other packages.
type contextKey string

// userContextKey is the context key under which AddUserInfo stores the
// session's user info and from which GetUserFromContext reads it back.
const userContextKey contextKey = "user"
// Middleware bundles the dependencies shared by the HTTP middleware methods
// defined on it in this file (AddUserInfo, ProtectedRoute,
// ProtectedRouteWithRedirect). Build one with NewMiddleware.
type Middleware struct {
// s is the session store; every middleware method calls s.Get(r) to load
// the session for the incoming request.
s SessionStorer
// ck is the key set fetched by NewMiddleware. Nothing in the visible part
// of this file reads it.
ck jwk.Set
}
// NewMiddleware builds a Middleware around the session store s and the key
// set downloaded from cognitoUrl.
//
// The key set is fetched with jwk.Fetch using a background context, so the
// call is not bound to any caller-supplied deadline or cancellation. If the
// fetch fails the error is passed to log.Fatal, which logs it and terminates
// the process; this function therefore never returns nil.
func NewMiddleware(s SessionStorer, cognitoUrl string) *Middleware {
	keys, fetchErr := jwk.Fetch(context.Background(), cognitoUrl)
	if fetchErr != nil {
		log.Fatal(fetchErr)
	}

	return &Middleware{
		s:  s,
		ck: keys,
	}
}
// AddUserInfo returns a handler that attaches the session's "user_info"
// value to the request context before delegating to next.
//
// The middleware never rejects a request. When the session cannot be loaded,
// or its "access_token" value is nil or the empty string, next is called
// with the request unchanged. Otherwise next receives a copy of the request
// whose context carries session.Values["user_info"] under userContextKey;
// that value is stored as-is, without any type check.
func (m *Middleware) AddUserInfo(next http.Handler) http.Handler {
	enrich := func(w http.ResponseWriter, r *http.Request) {
		// Default to passing the request through untouched.
		req := r

		if sess, err := m.s.Get(r); err == nil {
			if tok := sess.Values["access_token"]; tok != nil && tok != "" {
				info := sess.Values["user_info"]
				// NOTE(review): this writes the user info to stdout on every
				// request that has a token; it looks like leftover debug
				// output. Confirm whether it should be removed.
				fmt.Println(info)
				req = r.WithContext(context.WithValue(r.Context(), userContextKey, info))
			}
		}

		next.ServeHTTP(w, req)
	}

	return http.HandlerFunc(enrich)
}
// ProtectedRoute Checks if session and token are present, if not return a 401 response
2024-10-26 10:27:18 +00:00
func (m *Middleware) ProtectedRoute(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := m.s.Get(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
2024-10-26 10:27:18 +00:00
return
}
token := session.Values["access_token"]
if token == "" || token == nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
// ProtectedRouteWithRedirect Checks if session and token are present, if not return a redirect to /login
func (m *Middleware) ProtectedRouteWithRedirect(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := m.s.Get(r)
if err != nil {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
token := session.Values["access_token"]
if token == "" || token == nil {
http.Redirect(w, r, "/login", http.StatusSeeOther)
2024-10-26 10:27:18 +00:00
return
}
next.ServeHTTP(w, r)
})
}
// GetUserFromContext returns the user claims stored in r's context under
// userContextKey.
//
// The result is never nil and is always a fresh copy, so callers may modify
// it without affecting the value held by the context. A pointer to an empty
// UserClaims is returned when the context has no value under the key, when
// the value is a nil *UserClaims, or when it has some other type.
//
// The stored value comes from untyped session data, so its dynamic type is
// not guaranteed. Both UserClaims and *UserClaims are accepted, and anything
// else is treated as "no user" instead of panicking on a failed assertion.
func GetUserFromContext(r *http.Request) *UserClaims {
	switch claims := r.Context().Value(userContextKey).(type) {
	case UserClaims:
		return &claims
	case *UserClaims:
		if claims != nil {
			clone := *claims
			return &clone
		}
	}
	return &UserClaims{}
}