Skip to content

Commit

Permalink
add ALLOWLIST_ENABLED configuration setting
Browse files Browse the repository at this point in the history
  • Loading branch information
lindgrenj6 committed Feb 5, 2024
1 parent c3b7958 commit ec206ba
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 28 deletions.
5 changes: 5 additions & 0 deletions deployments/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ objects:
value: ${TOKEN_TTL_DURATION}
- name: STORE_BACKEND
value: ${STORE_BACKEND}
- name: ALLOWLIST_ENABLED
value: ${ALLOWLIST_ENABLED}
- name: DISABLE_CATCHALL
value: ${DISABLE_CATCHALL}
- name: IS_INTERNAL_LABEL
Expand Down Expand Up @@ -325,3 +327,6 @@ parameters:
- name: CERT_DIR
description: the base directory where ssl certs are stored
value: "/certs"
- name: ALLOWLIST_ENABLED
description: whether to check registrations against the internal allowlist
value: "false"
3 changes: 3 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type MbopConfig struct {
KeyCloakTokenGrantType string
KeyCloakTokenClientID string

AllowlistEnabled bool
StoreBackend string
DatabaseHost string
DatabasePort string
Expand All @@ -61,6 +62,7 @@ func Get() *MbopConfig {
}

disableCatchAll, _ := strconv.ParseBool(fetchWithDefault("DISABLE_CATCHALL", "false"))
allowlistEnabled, _ := strconv.ParseBool(fetchWithDefault("ALLOWLIST_ENABLED", "false"))
debug, _ := strconv.ParseBool(fetchWithDefault("DEBUG", "false"))
certDir := fetchWithDefault("CERT_DIR", "/certs")
keyCloakTimeout, _ := strconv.ParseInt(fetchWithDefault("KEYCLOAK_TIMEOUT", "60"), 0, 64)
Expand Down Expand Up @@ -90,6 +92,7 @@ func Get() *MbopConfig {
DatabasePassword: fetchWithDefault("DATABASE_PASSWORD", ""),
DatabaseName: fetchWithDefault("DATABASE_NAME", "mbop"),
StoreBackend: fetchWithDefault("STORE_BACKEND", "memory"),
AllowlistEnabled: allowlistEnabled,

CognitoAppClientID: fetchWithDefault("COGNITO_APP_CLIENT_ID", ""),
CognitoAppClientSecret: fetchWithDefault("COGNITO_APP_CLIENT_SECRET", ""),
Expand Down
8 changes: 5 additions & 3 deletions internal/handlers/allowlist_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"encoding/json"
"errors"
"net/http"
"runtime"
"time"

"github.com/go-chi/chi/v5"
l "github.com/redhatinsights/mbop/internal/logger"
"github.com/redhatinsights/mbop/internal/store"
"github.com/redhatinsights/platform-go-middlewares/identity"
)
Expand Down Expand Up @@ -85,7 +85,6 @@ func AllowlistListHandler(w http.ResponseWriter, r *http.Request) {

db := store.GetStore()

runtime.Breakpoint()
addrs, err := db.AllowedAddresses(id.Identity.OrgID)
if err != nil {
do500(w, "error listing addresses: %w"+err.Error())
Expand All @@ -101,5 +100,8 @@ func AllowlistListHandler(w http.ResponseWriter, r *http.Request) {
}
}

json.NewEncoder(w).Encode(out)
err = json.NewEncoder(w).Encode(out)
if err != nil {
l.Log.Info("failed to encode response", "error", err)
}
}
42 changes: 23 additions & 19 deletions internal/handlers/registration_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/go-chi/chi/v5"
"github.com/redhatinsights/mbop/internal/config"
"github.com/redhatinsights/mbop/internal/store"
"github.com/redhatinsights/platform-go-middlewares/identity"
)
Expand Down Expand Up @@ -78,27 +79,21 @@ func RegistrationListHandler(w http.ResponseWriter, r *http.Request) {

func RegistrationCreateHandler(w http.ResponseWriter, r *http.Request) {
id := identity.Get(r.Context())
if !id.Identity.User.OrgAdmin {
doError(w, "user must be org admin to register satellite", 403)
return
}
if id.Identity.User.Username == "" {
do400(w, "[username] not present in identity header")
return
}

db := store.GetStore()

allowed, err := db.AllowedIP(&store.Address{
IP: r.Header.Get("x-forwarded-for"),
OrgID: id.Identity.OrgID,
})
if err != nil {
do500(w, "error listing ip addresses: "+err.Error())
return
}
if !allowed {
doError(w, "address is not allowlisted", 403)
if config.Get().AllowlistEnabled {
allowed, err := db.AllowedIP(&store.Address{
IP: r.Header.Get("x-forwarded-for"),
OrgID: id.Identity.OrgID,
})
if err != nil {
do500(w, "error listing ip addresses: "+err.Error())
return
}
if !allowed {
doError(w, "address is not allowlisted", 403)
return
}
}

b, err := io.ReadAll(r.Body)
Expand All @@ -124,6 +119,15 @@ func RegistrationCreateHandler(w http.ResponseWriter, r *http.Request) {
return
}

if !id.Identity.User.OrgAdmin {
doError(w, "user must be org admin to register satellite", 403)
return
}
if id.Identity.User.Username == "" {
do400(w, "[username] not present in identity header")
return
}

gatewayCN, err := getCertCN(r.Header.Get(CertHeader))
if err != nil {
do400(w, err.Error())
Expand Down
2 changes: 1 addition & 1 deletion internal/store/in_memory_store_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (m *inMemoryStore) Delete(orgID string, uid string) error {
return ErrRegistrationNotFound
}

func (m *inMemoryStore) AllowedAddresses(orgID string) ([]Address, error) {
func (m *inMemoryStore) AllowedAddresses(_ string) ([]Address, error) {
return m.allowedAddresses, nil
}
func (m *inMemoryStore) AllowedIP(ip *Address) (bool, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/store/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type RegistrationStore interface {
}

type AllowlistStore interface {
AllowedAddresses(orgId string) ([]Address, error)
AllowedAddresses(orgID string) ([]Address, error)
AllowedIP(ip *Address) (bool, error)
AllowAddress(ip *Address) error
DenyAddress(ip *Address) error
Expand Down
6 changes: 2 additions & 4 deletions internal/store/postgres_store_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package store
import (
"database/sql"
"encoding/json"
"runtime"
"time"

// the pgx driver for the database
Expand Down Expand Up @@ -210,16 +209,15 @@ func (p *postgresStore) DenyAddress(ip *Address) error {
return nil
}

func (p *postgresStore) AllowedAddresses(orgId string) ([]Address, error) {
func (p *postgresStore) AllowedAddresses(orgID string) ([]Address, error) {
rows, err := p.db.Query(`select
org_id, ip, created_at
from allowlist
where org_id = $1`, orgId)
where org_id = $1`, orgID)
if err != nil {
return nil, err
}

runtime.Breakpoint()
addresses := make([]Address, 0)
for rows.Next() {
var (
Expand Down

0 comments on commit ec206ba

Please sign in to comment.