diff --git a/internal/handlers/allowlist_handler.go b/internal/handlers/allowlist_handler.go index 1631a65..8088c9f 100644 --- a/internal/handlers/allowlist_handler.go +++ b/internal/handlers/allowlist_handler.go @@ -38,7 +38,7 @@ func AllowlistCreateHandler(w http.ResponseWriter, r *http.Request) { } if !strings.Contains(createReq.IPBlock, "/") { - createReq.IPBlock = createReq.IPBlock + "/32" + createReq.IPBlock += "/32" } _, _, err = net.ParseCIDR(createReq.IPBlock) diff --git a/internal/store/in_memory_store_impl.go b/internal/store/in_memory_store_impl.go index 591e9ee..cc6066c 100644 --- a/internal/store/in_memory_store_impl.go +++ b/internal/store/in_memory_store_impl.go @@ -80,7 +80,7 @@ func (m *inMemoryStore) Delete(orgID string, uid string) error { func (m *inMemoryStore) AllowedAddresses(_ string) ([]AllowlistBlock, error) { return m.allowedAddresses, nil } -func (m *inMemoryStore) AllowedIP(ip, orgID string) (bool, error) { +func (m *inMemoryStore) AllowedIP(ip, _ string) (bool, error) { for _, addr := range m.allowedAddresses { _, ipnet, _ := net.ParseCIDR(addr.IPBlock) if ipnet.Contains(net.ParseIP(ip)) { diff --git a/internal/store/migrations/5_add_allowlist.up.sql b/internal/store/migrations/5_add_allowlist.up.sql index e52f7b8..a17f746 100644 --- a/internal/store/migrations/5_add_allowlist.up.sql +++ b/internal/store/migrations/5_add_allowlist.up.sql @@ -5,5 +5,5 @@ create table if not exists public.allowlist( ); alter table if exists public.allowlist - add constraint allowlist_unique_cidr - primary key (ip_block); + add constraint allowlist_unique_cidr_per_org + primary key (ip_block, org_id); diff --git a/internal/store/postgres_store_impl.go b/internal/store/postgres_store_impl.go index ebde1a9..3f75c99 100644 --- a/internal/store/postgres_store_impl.go +++ b/internal/store/postgres_store_impl.go @@ -3,7 +3,6 @@ package store import ( "database/sql" "encoding/json" - "net" "time" // the pgx driver for the database @@ -174,32 +173,20 @@ func scanRegistration(row scanner) (*Registration, error) { } func (p *postgresStore) AllowedIP(ip string, orgID string) (bool, error) { - rows, err := p.db.Query(`select ip_block from allowlist where org_id = $1 or org_id = 'gateway'`, orgID) - if err != nil { - return false, err - } - - var blocks []string - for rows.Next() { - var block string - err := rows.Scan(&block) - if err != nil { - return false, nil - } + // turns out postgres can do this on the backend! see old code that accomplishes the same thing at commit dca8f2c - blocks = append(blocks, block) - } + // basically its doing a subquery selecting all blocks from the org or + // gateway and then shoving them into an array and checking if the ip exists + // in those blocks. + row := p.db.QueryRow(`select $1::inet << any(array(select ip_block from allowlist where org_id = $2)::inet[])`, ip, orgID) - // Loop over blocks and see if they contain the IP - for _, block := range blocks { - // ignoring the error because we sanitize them on insert - _, ipnet, _ := net.ParseCIDR(block) - if ipnet.Contains(net.ParseIP(ip)) { - return true, nil - } + var valid bool + err := row.Scan(&valid) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return false, err } - return false, nil + return valid, nil } func (p *postgresStore) AllowAddress(ip *AllowlistBlock) error { diff --git a/internal/store/postgres_store_impl_test.go b/internal/store/postgres_store_impl_test.go index c6585c7..63b5435 100644 --- a/internal/store/postgres_store_impl_test.go +++ b/internal/store/postgres_store_impl_test.go @@ -285,10 +285,10 @@ func (suite *TestSuite) TestIPAllowedHappyMultiple() { })) suite.Nil(suite.store.AllowAddress(&AllowlistBlock{ IPBlock: "192.168.1.1/24", - OrgID: "gateway", + OrgID: "1234", })) - for _, ip := range []string{"10.0.0.100", "192.168.1.100"} { + for _, ip := range []string{"10.0.0.100", "192.168.1.100", "10.0.0.20/32", "192.168.1.20/32"} { allowed, err := suite.store.AllowedIP(ip, "1234") suite.True(allowed) suite.Nil(err)