Skip to content

Commit

Permalink
change AllowedIP to use postgres types rather than in code
Browse files Browse the repository at this point in the history
  • Loading branch information
lindgrenj6 committed Feb 7, 2024
1 parent 99856eb commit 43159c6
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 23 deletions.
2 changes: 1 addition & 1 deletion internal/handlers/allowlist_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func AllowlistCreateHandler(w http.ResponseWriter, r *http.Request) {
}

if !strings.Contains(createReq.IPBlock, "/") {
createReq.IPBlock = createReq.IPBlock + "/32"
createReq.IPBlock += "/32"
}

_, _, err = net.ParseCIDR(createReq.IPBlock)
Expand Down
2 changes: 1 addition & 1 deletion internal/store/in_memory_store_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (m *inMemoryStore) Delete(orgID string, uid string) error {
func (m *inMemoryStore) AllowedAddresses(_ string) ([]AllowlistBlock, error) {
return m.allowedAddresses, nil
}
func (m *inMemoryStore) AllowedIP(ip, orgID string) (bool, error) {
func (m *inMemoryStore) AllowedIP(ip, _ string) (bool, error) {
for _, addr := range m.allowedAddresses {
_, ipnet, _ := net.ParseCIDR(addr.IPBlock)
if ipnet.Contains(net.ParseIP(ip)) {
Expand Down
34 changes: 14 additions & 20 deletions internal/store/postgres_store_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package store
import (
"database/sql"
"encoding/json"
"net"
"time"

// the pgx driver for the database
Expand Down Expand Up @@ -174,32 +173,27 @@ func scanRegistration(row scanner) (*Registration, error) {
}

func (p *postgresStore) AllowedIP(ip string, orgID string) (bool, error) {
rows, err := p.db.Query(`select ip_block from allowlist where org_id = $1 or org_id = 'gateway'`, orgID)
// turns out postgres can do this on the backend! see old code that accomplishes the same thing at commit dca8f2c

// basically its doing a subquery selecting all blocks from the org or
// gateway and then shoving them into an array and checking if the ip exists
// in those blocks.
stmt, err := p.db.Prepare(`select $1::inet << any(array(select ip_block from allowlist where org_id = $2 or org_id = 'gateway')::inet[])`)
if err != nil {
return false, err
}

var blocks []string
for rows.Next() {
var block string
err := rows.Scan(&block)
if err != nil {
return false, nil
}

blocks = append(blocks, block)
row := stmt.QueryRow(ip, orgID)
if err != nil {
return false, err
}

// Loop over blocks and see if they contain the IP
for _, block := range blocks {
// ignoring the error because we sanitize them on insert
_, ipnet, _ := net.ParseCIDR(block)
if ipnet.Contains(net.ParseIP(ip)) {
return true, nil
}
var valid bool
err = row.Scan(&valid)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, err
}

return false, nil
return valid, nil
}

func (p *postgresStore) AllowAddress(ip *AllowlistBlock) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/store/postgres_store_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func (suite *TestSuite) TestIPAllowedHappyMultiple() {
OrgID: "gateway",
}))

for _, ip := range []string{"10.0.0.100", "192.168.1.100"} {
for _, ip := range []string{"10.0.0.100", "192.168.1.100", "10.0.0.20/32", "192.168.1.20/32"} {
allowed, err := suite.store.AllowedIP(ip, "1234")
suite.True(allowed)
suite.Nil(err)
Expand Down

0 comments on commit 43159c6

Please sign in to comment.