From 8946717bcce93be90a23a4f982ede121a1f3f875 Mon Sep 17 00:00:00 2001 From: Jacob Lindgren Date: Wed, 7 Feb 2024 10:45:09 -0600 Subject: [PATCH] change AllowedIP to use postgres types rather than in code --- internal/handlers/allowlist_handler.go | 2 +- internal/store/in_memory_store_impl.go | 2 +- internal/store/postgres_store_impl.go | 34 +++++++++------------- internal/store/postgres_store_impl_test.go | 2 +- 4 files changed, 17 insertions(+), 23 deletions(-) diff --git a/internal/handlers/allowlist_handler.go b/internal/handlers/allowlist_handler.go index 1631a65..8088c9f 100644 --- a/internal/handlers/allowlist_handler.go +++ b/internal/handlers/allowlist_handler.go @@ -38,7 +38,7 @@ func AllowlistCreateHandler(w http.ResponseWriter, r *http.Request) { } if !strings.Contains(createReq.IPBlock, "/") { - createReq.IPBlock = createReq.IPBlock + "/32" + createReq.IPBlock += "/32" } _, _, err = net.ParseCIDR(createReq.IPBlock) diff --git a/internal/store/in_memory_store_impl.go b/internal/store/in_memory_store_impl.go index 591e9ee..cc6066c 100644 --- a/internal/store/in_memory_store_impl.go +++ b/internal/store/in_memory_store_impl.go @@ -80,7 +80,7 @@ func (m *inMemoryStore) Delete(orgID string, uid string) error { func (m *inMemoryStore) AllowedAddresses(_ string) ([]AllowlistBlock, error) { return m.allowedAddresses, nil } -func (m *inMemoryStore) AllowedIP(ip, orgID string) (bool, error) { +func (m *inMemoryStore) AllowedIP(ip, _ string) (bool, error) { for _, addr := range m.allowedAddresses { _, ipnet, _ := net.ParseCIDR(addr.IPBlock) if ipnet.Contains(net.ParseIP(ip)) { diff --git a/internal/store/postgres_store_impl.go b/internal/store/postgres_store_impl.go index ebde1a9..af7939d 100644 --- a/internal/store/postgres_store_impl.go +++ b/internal/store/postgres_store_impl.go @@ -3,7 +3,6 @@ package store import ( "database/sql" "encoding/json" - "net" "time" // the pgx driver for the database @@ -174,32 +173,27 @@ func scanRegistration(row scanner) (*Registration, error) { } func (p *postgresStore) AllowedIP(ip string, orgID string) (bool, error) { - rows, err := p.db.Query(`select ip_block from allowlist where org_id = $1 or org_id = 'gateway'`, orgID) + // turns out postgres can do this on the backend! see old code that accomplishes the same thing at commit dca8f2c + + // basically its doing a subquery selecting all blocks from the org or + // gateway and then shoving them into an array and checking if the ip exists + // in those blocks. + stmt, err := p.db.Prepare(`select $1::inet << any(array(select ip_block from allowlist where org_id = $2 or org_id = 'gateway')::inet[])`) if err != nil { return false, err } - - var blocks []string - for rows.Next() { - var block string - err := rows.Scan(&block) - if err != nil { - return false, nil - } - - blocks = append(blocks, block) + row := stmt.QueryRow(ip, orgID) + if err != nil { + return false, err } - // Loop over blocks and see if they contain the IP - for _, block := range blocks { - // ignoring the error because we sanitize them on insert - _, ipnet, _ := net.ParseCIDR(block) - if ipnet.Contains(net.ParseIP(ip)) { - return true, nil - } + var valid bool + err = row.Scan(&valid) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return false, err } - return false, nil + return valid, nil } func (p *postgresStore) AllowAddress(ip *AllowlistBlock) error { diff --git a/internal/store/postgres_store_impl_test.go b/internal/store/postgres_store_impl_test.go index c6585c7..535e3df 100644 --- a/internal/store/postgres_store_impl_test.go +++ b/internal/store/postgres_store_impl_test.go @@ -288,7 +288,7 @@ func (suite *TestSuite) TestIPAllowedHappyMultiple() { OrgID: "gateway", })) - for _, ip := range []string{"10.0.0.100", "192.168.1.100"} { + for _, ip := range []string{"10.0.0.100", "192.168.1.100", "10.0.0.20/32", "192.168.1.20/32"} { allowed, err := suite.store.AllowedIP(ip, "1234") suite.True(allowed) suite.Nil(err)