Skip to content

Commit 616cde2

Browse files
authored
Merge pull request #140 from oschwald/greg/fix-networks-within
Return first IP in network with NetworksWithin
2 parents c2fcd44 + b2df6c3 commit 616cde2

File tree

2 files changed

+79
-17
lines changed

2 files changed

+79
-17
lines changed

traverse.go

+7
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) *
9595
}
9696

9797
pointer, bit := r.traverseTree(ip, 0, uint(prefixLength))
98+
99+
// We could skip this when bit >= prefixLength if we assume that the network
100+
// passed in is in canonical form. However, given that this may not be the
101+
// case, it is safest to always take the mask. If this is hot code at some
102+
// point, we could eliminate the allocation of the net.IPMask by zeroing
103+
// out the bits in ip directly.
104+
ip = ip.Mask(net.CIDRMask(bit, len(ip)*8))
98105
networks.nodes = []netNode{
99106
{
100107
ip: ip,

traverse_test.go

+72-17
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package maxminddb
33
import (
44
"fmt"
55
"net"
6+
"strconv"
7+
"strings"
68
"testing"
79

810
"github.com/stretchr/testify/assert"
@@ -71,20 +73,50 @@ var tests = []networkTest{
7173
},
7274
},
7375
{
76+
// This is intentionally in non-canonical form to test
77+
// that we handle it correctly.
7478
Network: "1.1.1.1/30",
7579
Database: "ipv4",
7680
Expected: []string{
7781
"1.1.1.1/32",
7882
"1.1.1.2/31",
7983
},
8084
},
85+
{
86+
Network: "1.1.1.2/31",
87+
Database: "ipv4",
88+
Expected: []string{
89+
"1.1.1.2/31",
90+
},
91+
},
8192
{
8293
Network: "1.1.1.1/32",
8394
Database: "ipv4",
8495
Expected: []string{
8596
"1.1.1.1/32",
8697
},
8798
},
99+
{
100+
Network: "1.1.1.2/32",
101+
Database: "ipv4",
102+
Expected: []string{
103+
"1.1.1.2/31",
104+
},
105+
},
106+
{
107+
Network: "1.1.1.3/32",
108+
Database: "ipv4",
109+
Expected: []string{
110+
"1.1.1.2/31",
111+
},
112+
},
113+
{
114+
Network: "1.1.1.19/32",
115+
Database: "ipv4",
116+
Expected: []string{
117+
"1.1.1.16/28",
118+
},
119+
},
88120
{
89121
Network: "255.255.255.0/24",
90122
Database: "ipv4",
@@ -234,28 +266,51 @@ var tests = []networkTest{
234266
func TestNetworksWithin(t *testing.T) {
235267
for _, v := range tests {
236268
for _, recordSize := range []uint{24, 28, 32} {
237-
fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize))
238-
reader, err := Open(fileName)
239-
require.NoError(t, err, "unexpected error while opening database: %v", err)
269+
name := fmt.Sprintf(
270+
"%s-%d: %s, options: %v",
271+
v.Database,
272+
recordSize,
273+
v.Network,
274+
len(v.Options) != 0,
275+
)
276+
t.Run(name, func(t *testing.T) {
277+
fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize))
278+
reader, err := Open(fileName)
279+
require.NoError(t, err, "unexpected error while opening database: %v", err)
240280

241-
_, network, err := net.ParseCIDR(v.Network)
242-
require.NoError(t, err)
243-
n := reader.NetworksWithin(network, v.Options...)
244-
var innerIPs []string
281+
// We are purposely not using net.ParseCIDR so that we can pass in
282+
// values that aren't in canonical form.
283+
parts := strings.Split(v.Network, "/")
284+
ip := net.ParseIP(parts[0])
285+
if v := ip.To4(); v != nil {
286+
ip = v
287+
}
288+
prefixLength, err := strconv.Atoi(parts[1])
289+
require.NoError(t, err)
290+
mask := net.CIDRMask(prefixLength, len(ip)*8)
291+
network := &net.IPNet{
292+
IP: ip,
293+
Mask: mask,
294+
}
245295

246-
for n.Next() {
247-
record := struct {
248-
IP string `maxminddb:"ip"`
249-
}{}
250-
network, err := n.Network(&record)
251296
require.NoError(t, err)
252-
innerIPs = append(innerIPs, network.String())
253-
}
297+
n := reader.NetworksWithin(network, v.Options...)
298+
var innerIPs []string
254299

255-
assert.Equal(t, v.Expected, innerIPs)
256-
require.NoError(t, n.Err())
300+
for n.Next() {
301+
record := struct {
302+
IP string `maxminddb:"ip"`
303+
}{}
304+
network, err := n.Network(&record)
305+
require.NoError(t, err)
306+
innerIPs = append(innerIPs, network.String())
307+
}
257308

258-
require.NoError(t, reader.Close())
309+
assert.Equal(t, v.Expected, innerIPs)
310+
require.NoError(t, n.Err())
311+
312+
require.NoError(t, reader.Close())
313+
})
259314
}
260315
}
261316
}

0 commit comments

Comments
 (0)