diff --git a/internal/handlers/allowlist_handler.go b/internal/handlers/allowlist_handler.go index 2b6dc8c..947dd68 100644 --- a/internal/handlers/allowlist_handler.go +++ b/internal/handlers/allowlist_handler.go @@ -3,8 +3,6 @@ package handlers import ( "encoding/json" "errors" - "fmt" - "io" "net/http" "runtime" "time" @@ -34,8 +32,7 @@ func AllowlistCreateHandler(w http.ResponseWriter, r *http.Request) { var createReq allowlistCreateRequest err := json.NewDecoder(r.Body).Decode(&createReq) if err != nil { - w.WriteHeader(400) - io.WriteString(w, "invalid json in body - only expected key is [ip]") + do400(w, "invalid json in body - only expected key is [ip]") return } @@ -43,9 +40,7 @@ func AllowlistCreateHandler(w http.ResponseWriter, r *http.Request) { err = db.AllowAddress(&store.Address{IP: createReq.IP, OrgID: id.Identity.OrgID}) if err != nil { - w.WriteHeader(500) - err = fmt.Errorf("error storing address: %w", err) - io.WriteString(w, err.Error()) + do500(w, "error storing address: "+err.Error()) return } @@ -61,8 +56,7 @@ func AllowlistDeleteHandler(w http.ResponseWriter, r *http.Request) { ip := chi.URLParam(r, "address") if ip == "" { - w.WriteHeader(400) - io.WriteString(w, "need address in path in the form `/v1/allowlist/{address}") + do400(w, "need address in path in the form `/v1/allowlist/{address}") return } @@ -71,14 +65,11 @@ func AllowlistDeleteHandler(w http.ResponseWriter, r *http.Request) { err := db.DenyAddress(&store.Address{IP: ip}) if err != nil { if errors.Is(err, store.ErrAddressNotAllowListed) { - w.WriteHeader(404) - io.WriteString(w, "ip not allowlisted") + doError(w, "ip not allowlisted", 404) return } - w.WriteHeader(500) - err = fmt.Errorf("error deleting addressaddress: %w", err) - io.WriteString(w, err.Error()) + do500(w, "error deleting addressaddress: %w"+err.Error()) return } @@ -97,9 +88,7 @@ func AllowlistListHandler(w http.ResponseWriter, r *http.Request) { runtime.Breakpoint() addrs, err := db.AllowedAddresses(id.Identity.OrgID) if err != nil { - w.WriteHeader(500) - err = fmt.Errorf("error listing addresses: %w", err) - io.WriteString(w, err.Error()) + do500(w, "error listing addresses: %w"+err.Error()) return } diff --git a/internal/handlers/registration_handler.go b/internal/handlers/registration_handler.go index ed86664..75b6639 100644 --- a/internal/handlers/registration_handler.go +++ b/internal/handlers/registration_handler.go @@ -77,6 +77,30 @@ func RegistrationListHandler(w http.ResponseWriter, r *http.Request) { } func RegistrationCreateHandler(w http.ResponseWriter, r *http.Request) { + id := identity.Get(r.Context()) + if !id.Identity.User.OrgAdmin { + doError(w, "user must be org admin to register satellite", 403) + return + } + if id.Identity.User.Username == "" { + do400(w, "[username] not present in identity header") + return + } + + db := store.GetStore() + + allowed, err := db.AllowedIP(&store.Address{ + IP: r.Header.Get("x-forwarded-for"), + OrgID: id.Identity.OrgID, + }) + if err != nil { + do500(w, "error listing ip addresses: "+err.Error()) + return + } + if !allowed { + doError(w, "address is not allowlisted", 403) + } + b, err := io.ReadAll(r.Body) if err != nil { do500(w, "failed to read body bytes: "+err.Error()) @@ -100,16 +124,6 @@ func RegistrationCreateHandler(w http.ResponseWriter, r *http.Request) { return } - id := identity.Get(r.Context()) - if !id.Identity.User.OrgAdmin { - doError(w, "user must be org admin to register satellite", 403) - return - } - if id.Identity.User.Username == "" { - do400(w, "[username] not present in identity header") - return - } - gatewayCN, err := getCertCN(r.Header.Get(CertHeader)) if err != nil { do400(w, err.Error()) @@ -121,7 +135,6 @@ func RegistrationCreateHandler(w http.ResponseWriter, r *http.Request) { return } - db := store.GetStore() _, err = db.Create(&store.Registration{ OrgID: id.Identity.OrgID, Username: id.Identity.User.Username,