99 lines
2.2 KiB
Go
99 lines
2.2 KiB
Go
package cauth
|
|
|
|
import (
	"context"
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/lestrrat-go/jwx/jwk"
)
|
|
|
|
// contextKey is the unexported type used for keys this package stores in a
// request's context.Context; a distinct named type keeps those keys from
// colliding with keys defined by other packages.
type contextKey string

// userContextKey is the context key under which AddUserInfo stores the
// session's "user_info" value and from which GetUserFromContext reads it.
const userContextKey contextKey = "user"
|
|
|
|
type Middleware struct {
|
|
s SessionStorer
|
|
ck jwk.Set
|
|
}
|
|
|
|
func NewMiddleware(s SessionStorer, cognitoUrl string) *Middleware {
|
|
cognitoKeySet, err := jwk.Fetch(context.Background(), cognitoUrl)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
return &Middleware{s, cognitoKeySet}
|
|
}
|
|
|
|
func (m *Middleware) AddUserInfo(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
session, err := m.s.Get(r)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
token := session.Values["access_token"]
|
|
|
|
if token == "" || token == nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
userInfo := session.Values["user_info"]
|
|
fmt.Println(userInfo)
|
|
|
|
ctx := context.WithValue(r.Context(), userContextKey, userInfo)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// ProtectedRoute Checks if session and token are present, if not return a 401 response
|
|
func (m *Middleware) ProtectedRoute(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
session, err := m.s.Get(r)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
token := session.Values["access_token"]
|
|
|
|
if token == "" || token == nil {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// ProtectedRouteWithRedirect Checks if session and token are present, if not return a redirect to /login
|
|
func (m *Middleware) ProtectedRouteWithRedirect(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
session, err := m.s.Get(r)
|
|
if err != nil {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
token := session.Values["access_token"]
|
|
|
|
if token == "" || token == nil {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func GetUserFromContext(r *http.Request) *UserClaims {
|
|
userOptional := r.Context().Value(userContextKey)
|
|
if userOptional != nil {
|
|
user := userOptional.(UserClaims)
|
|
return &user
|
|
}
|
|
|
|
return &UserClaims{}
|
|
}
|