Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions admin/server/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,8 @@ func (a *Authenticator) authStart(w http.ResponseWriter, r *http.Request, signup
}

// If this is part of the custom domain login flow, save that info in the cookie since we need that info when handling the auth callback.
customDomainFlow := false
if b, err := strconv.ParseBool(r.URL.Query().Get("custom_domain_flow")); err == nil && b {
sess.Values[cookieFieldCustomDomainFlow] = b
customDomainFlow = b
}

// Save cookie
Expand All @@ -191,6 +189,12 @@ func (a *Authenticator) authStart(w http.ResponseWriter, r *http.Request, signup
return
}

err := a.validateRedirectURL(r.Context(), redirect)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Redirect to <canonical-domain>/auth/login (custom domain flow)
host := originalHost(r)
if a.admin.URLs.IsCustomDomain(host) {
Expand All @@ -200,12 +204,6 @@ func (a *Authenticator) authStart(w http.ResponseWriter, r *http.Request, signup
return
}

err := a.validateRedirectURL(r.Context(), redirect, customDomainFlow)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Redirect to auth provider (canonical domain flow)
redirectURL := a.oauth2.AuthCodeURL(state)
if signup {
Expand Down Expand Up @@ -596,7 +594,7 @@ func (a *Authenticator) authLogoutProvider(w http.ResponseWriter, r *http.Reques
// Validate and set custom redirect destination in cookie for when the logout flow is over (if any)
redirect := r.URL.Query().Get("redirect")
if redirect != "" {
err := a.validateRedirectURL(r.Context(), redirect, true)
err := a.validateRedirectURL(r.Context(), redirect)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
Expand Down Expand Up @@ -728,13 +726,10 @@ func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
}
}

func (a *Authenticator) validateRedirectURL(ctx context.Context, redirect string, allowCustomDomains bool) error {
func (a *Authenticator) validateRedirectURL(ctx context.Context, redirect string) error {
if a.admin.URLs.IsSafeRedirectURL(redirect) {
return nil
}
if !allowCustomDomains {
return fmt.Errorf("redirect to %q is not allowed", redirect)
}

parsed, err := url.Parse(redirect)
if err != nil {
Expand Down
Loading