diff --git a/pkg/ovn_ic_controller/ovn_ic_controller.go b/pkg/ovn_ic_controller/ovn_ic_controller.go index 9a86e65957b..3757862ef77 100644 --- a/pkg/ovn_ic_controller/ovn_ic_controller.go +++ b/pkg/ovn_ic_controller/ovn_ic_controller.go @@ -356,10 +356,10 @@ func (c *Controller) acquireLrpAddress(ts string) (string, error) { var ips []string v4Cidr, v6Cidr := util.SplitStringIP(cidr) if v4Cidr != "" { - ips = append(ips, util.GenerateRandomV4IP(v4Cidr)) + ips = append(ips, util.GenerateRandomIP(v4Cidr)) } if v6Cidr != "" { - ips = append(ips, util.GenerateRandomV6IP(v6Cidr)) + ips = append(ips, util.GenerateRandomIP(v6Cidr)) } random = strings.Join(ips, ",") // find a free address diff --git a/pkg/ovs/ovn-nb-logical_switch_port_test.go b/pkg/ovs/ovn-nb-logical_switch_port_test.go index 3fff4abaa38..9016bbdb677 100644 --- a/pkg/ovs/ovn-nb-logical_switch_port_test.go +++ b/pkg/ovs/ovn-nb-logical_switch_port_test.go @@ -885,7 +885,7 @@ func (suite *OvnClientTestSuite) testEnablePortLayer2forward() { lspName := "test-enable-port-l2-lsp" ns := "test-enable-port-l2-ns" pod := "test-enable-port-l2-pod" - ip := util.GenerateRandomV4IP("192.168.1.0/24") + ip := util.GenerateRandomIP("192.168.1.0/24") mac := util.GenerateMac() err := ovnClient.CreateBareLogicalSwitch(lsName) diff --git a/pkg/util/net.go b/pkg/util/net.go index 802efe79b94..bf59009a5be 100644 --- a/pkg/util/net.go +++ b/pkg/util/net.go @@ -78,12 +78,13 @@ func SubnetNumber(subnet string) string { func SubnetBroadcast(subnet string) string { _, cidr, _ := net.ParseCIDR(subnet) - maskLength, length := cidr.Mask.Size() - if maskLength+1 == length { + ones, bits := cidr.Mask.Size() + if ones+1 == bits { return "" } ipInt := IP2BigInt(cidr.IP.String()) - size := big.NewInt(0).Lsh(big.NewInt(1), uint(length-maskLength)) + zeros := uint(bits - ones) // #nosec G115 + size := big.NewInt(0).Lsh(big.NewInt(1), zeros) size = big.NewInt(0).Sub(size, big.NewInt(1)) return BigInt2Ip(ipInt.Add(ipInt, size)) } @@ -108,22 +109,17 @@ func LastIP(subnet string) (string, error) { if err != nil { return "", fmt.Errorf("%s is not a valid cidr", subnet) } - var length int - proto := CheckProtocol(subnet) - if proto == kubeovnv1.ProtocolIPv4 { - length = 32 - } else { - length = 128 - } - maskLength, _ := cidr.Mask.Size() + ipInt := IP2BigInt(cidr.IP.String()) - size := getCIDRSize(length, maskLength) + size := getCIDRSize(cidr) return BigInt2Ip(ipInt.Add(ipInt, size)), nil } -func getCIDRSize(length, maskLength int) *big.Int { - size := big.NewInt(0).Lsh(big.NewInt(1), uint(length-maskLength)) - if maskLength+1 == length { +func getCIDRSize(cidr *net.IPNet) *big.Int { + ones, bits := cidr.Mask.Size() + zeros := uint(bits - ones) // #nosec G115 + size := big.NewInt(0).Lsh(big.NewInt(1), zeros) + if ones+1 == bits { return big.NewInt(0).Sub(size, big.NewInt(1)) } return big.NewInt(0).Sub(size, big.NewInt(2)) @@ -212,31 +208,21 @@ func AddressCount(network *net.IPNet) float64 { return math.Pow(2, float64(bits-prefixLen)) - 2 } -func GenerateRandomV4IP(cidr string) string { - return genRandomIP(cidr, false) -} - -func GenerateRandomV6IP(cidr string) string { - return genRandomIP(cidr, true) -} - -func genRandomIP(cidr string, isIPv6 bool) string { - if len(strings.Split(cidr, "/")) != 2 { +func GenerateRandomIP(cidr string) string { + ip, network, err := net.ParseCIDR(cidr) + if err != nil { + klog.Errorf("failed to parse cidr %q: %v", cidr, err) return "" } - ip := strings.Split(cidr, "/")[0] - netMask, _ := strconv.Atoi(strings.Split(cidr, "/")[1]) - hostBits := 32 - netMask - if isIPv6 { - hostBits = 128 - netMask - } - add, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), uint(hostBits)-1)) + ones, bits := network.Mask.Size() + zeros := uint(bits - ones) // #nosec G115 + add, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), zeros-1)) if err != nil { - klog.Errorf("failed to generate random big int with bits %d: %v", hostBits, err) + klog.Errorf("failed to generate random big int with bits %d: %v", zeros, err) return "" } - t := big.NewInt(0).Add(IP2BigInt(ip), add) - return fmt.Sprintf("%s/%d", BigInt2Ip(t), netMask) + t := big.NewInt(0).Add(IP2BigInt(ip.String()), add) + return fmt.Sprintf("%s/%d", BigInt2Ip(t), ones) } func IPToString(ip string) string { diff --git a/pkg/util/net_test.go b/pkg/util/net_test.go index bbc9fd55f16..185cceb1a0c 100644 --- a/pkg/util/net_test.go +++ b/pkg/util/net_test.go @@ -475,14 +475,14 @@ func TestGenerateRandomV4IP(t *testing.T) { t.Run(c.name, func(t *testing.T) { _, IPNets, err := net.ParseCIDR(c.cidr) if err != nil { - ans := GenerateRandomV4IP(c.cidr) + ans := GenerateRandomIP(c.cidr) if c.want != ans { t.Errorf("%v expected %v, but %v got", c.cidr, c.want, ans) } } else { - ans := GenerateRandomV4IP(c.cidr) - if IPNets.Contains(net.ParseIP(GenerateRandomV4IP(c.cidr))) { + ans := GenerateRandomIP(c.cidr) + if IPNets.Contains(net.ParseIP(GenerateRandomIP(c.cidr))) { t.Errorf("%v expected %v, but %v got", c.cidr, c.want, ans) } @@ -514,7 +514,7 @@ func TestGenerateRandomV6IP(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ip := GenerateRandomV6IP(tt.cidr) + ip := GenerateRandomIP(tt.cidr) if tt.wantErr { if ip != "" { t.Errorf("GenerateRandomV6IP(%s) = %s; want empty string", tt.cidr, ip)