Skip to content

Commit

Permalink
change AllowedIP to use postgres types rather than in code
Browse files Browse the repository at this point in the history
  • Loading branch information
lindgrenj6 committed Feb 8, 2024
1 parent 99856eb commit fb3388a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 29 deletions.
2 changes: 1 addition & 1 deletion internal/handlers/allowlist_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func AllowlistCreateHandler(w http.ResponseWriter, r *http.Request) {
}

if !strings.Contains(createReq.IPBlock, "/") {
createReq.IPBlock = createReq.IPBlock + "/32"
createReq.IPBlock += "/32"
}

_, _, err = net.ParseCIDR(createReq.IPBlock)
Expand Down
2 changes: 1 addition & 1 deletion internal/store/in_memory_store_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (m *inMemoryStore) Delete(orgID string, uid string) error {
func (m *inMemoryStore) AllowedAddresses(_ string) ([]AllowlistBlock, error) {
return m.allowedAddresses, nil
}
func (m *inMemoryStore) AllowedIP(ip, orgID string) (bool, error) {
func (m *inMemoryStore) AllowedIP(ip, _ string) (bool, error) {
for _, addr := range m.allowedAddresses {
_, ipnet, _ := net.ParseCIDR(addr.IPBlock)
if ipnet.Contains(net.ParseIP(ip)) {
Expand Down
4 changes: 2 additions & 2 deletions internal/store/migrations/5_add_allowlist.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ create table if not exists public.allowlist(
);

alter table if exists public.allowlist
add constraint allowlist_unique_cidr
primary key (ip_block);
add constraint allowlist_unique_cidr_per_org
primary key (ip_block, org_id);
33 changes: 10 additions & 23 deletions internal/store/postgres_store_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package store
import (
"database/sql"
"encoding/json"
"net"
"time"

// the pgx driver for the database
Expand Down Expand Up @@ -174,32 +173,20 @@ func scanRegistration(row scanner) (*Registration, error) {
}

func (p *postgresStore) AllowedIP(ip string, orgID string) (bool, error) {
rows, err := p.db.Query(`select ip_block from allowlist where org_id = $1 or org_id = 'gateway'`, orgID)
if err != nil {
return false, err
}

var blocks []string
for rows.Next() {
var block string
err := rows.Scan(&block)
if err != nil {
return false, nil
}
// turns out postgres can do this on the backend! see old code that accomplishes the same thing at commit dca8f2c

blocks = append(blocks, block)
}
// basically its doing a subquery selecting all blocks from the org or
// gateway and then shoving them into an array and checking if the ip exists
// in those blocks.
row := p.db.QueryRow(`select $1::inet << any(array(select ip_block from allowlist where org_id = $2)::inet[])`, ip, orgID)

// Loop over blocks and see if they contain the IP
for _, block := range blocks {
// ignoring the error because we sanitize them on insert
_, ipnet, _ := net.ParseCIDR(block)
if ipnet.Contains(net.ParseIP(ip)) {
return true, nil
}
var valid bool
err := row.Scan(&valid)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, err
}

return false, nil
return valid, nil
}

func (p *postgresStore) AllowAddress(ip *AllowlistBlock) error {
Expand Down
4 changes: 2 additions & 2 deletions internal/store/postgres_store_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ func (suite *TestSuite) TestIPAllowedHappyMultiple() {
}))
suite.Nil(suite.store.AllowAddress(&AllowlistBlock{
IPBlock: "192.168.1.1/24",
OrgID: "gateway",
OrgID: "1234",
}))

for _, ip := range []string{"10.0.0.100", "192.168.1.100"} {
for _, ip := range []string{"10.0.0.100", "192.168.1.100", "10.0.0.20/32", "192.168.1.20/32"} {
allowed, err := suite.store.AllowedIP(ip, "1234")
suite.True(allowed)
suite.Nil(err)
Expand Down

0 comments on commit fb3388a

Please sign in to comment.