Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
* Plumb allow/deny CIDRs to HTTP client

* Make the DNS cache aware of the ControlContext function

---------

Co-authored-by: Till Faelligen <[email protected]>
  • Loading branch information
turt2live and S7evinK authored Jan 16, 2025
1 parent bf86bc9 commit c4f1e01
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 19 deletions.
106 changes: 91 additions & 15 deletions fclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

"github.com/matrix-org/gomatrix"
Expand All @@ -54,13 +55,15 @@ type UserInfo struct {
}

type clientOptions struct {
transport http.RoundTripper
dnsCache *DNSCache
timeout time.Duration
skipVerify bool
keepAlives bool
wellKnownSRV bool
userAgent string
transport http.RoundTripper
dnsCache *DNSCache
timeout time.Duration
skipVerify bool
keepAlives bool
wellKnownSRV bool
userAgent string
allowNetworks []string
denyNetworks []string
}

// ClientOption are supplied to NewClient or NewFederationClient.
Expand All @@ -82,6 +85,8 @@ func NewClient(options ...ClientOption) *Client {
clientOpts.dnsCache,
clientOpts.keepAlives,
clientOpts.wellKnownSRV,
clientOpts.allowNetworks,
clientOpts.denyNetworks,
)
}
client := &Client{
Expand Down Expand Up @@ -152,6 +157,15 @@ func WithUserAgent(userAgent string) ClientOption {
}
}

// WithAllowDenyNetworks sets the allowed and denied networks for the http client. By default,
// all networks are allowed. The deny list is checked before the allow list.
func WithAllowDenyNetworks(allowCIDRs []string, denyCIDRs []string) ClientOption {
return func(options *clientOptions) {
options.allowNetworks = allowCIDRs
options.denyNetworks = denyCIDRs
}
}

const destinationTripperLifetime = time.Minute * 5 // how long to keep an entry
const destinationTripperReapInterval = time.Minute // how often to check for dead entries

Expand All @@ -165,15 +179,17 @@ type destinationTripper struct {
dnsCache *DNSCache
keepAlives bool
wellKnownSRV bool
dialer *net.Dialer
}

func newDestinationTripper(skipVerify bool, dnsCache *DNSCache, keepAlives, wellKnownSRV bool) *destinationTripper {
func newDestinationTripper(skipVerify bool, dnsCache *DNSCache, keepAlives, wellKnownSRV bool, allowCIDRs []string, denyCIDRs []string) *destinationTripper {
tripper := &destinationTripper{
transports: make(map[string]*destinationTripperTransport),
skipVerify: skipVerify,
dnsCache: dnsCache,
keepAlives: keepAlives,
wellKnownSRV: wellKnownSRV,
dialer: newDestinationTripperDialer(allowCIDRs, denyCIDRs),
}
time.AfterFunc(destinationTripperReapInterval, tripper.reaper)
return tripper
Expand All @@ -195,11 +211,71 @@ func (f *destinationTripper) reaper() {
time.AfterFunc(destinationTripperReapInterval, f.reaper)
}

// destinationTripperDialer enforces dial timeouts on the federation requests. If
// newDestinationTripperDialer creates a dialer which enforces dial timeouts on the federation requests. If
// the TCP connection doesn't complete within 5 seconds, it's probably just not
// going to.
var destinationTripperDialer = &net.Dialer{
Timeout: time.Second * 5,
// The dialer can also be limited to CIDR ranges, if allow or deny networks is non-empty.
func newDestinationTripperDialer(allowNetworks []string, denyNetworks []string) *net.Dialer {
if len(allowNetworks) == 0 && len(denyNetworks) == 0 {
return &net.Dialer{
Timeout: time.Second * 5,
}
}

return &net.Dialer{
Timeout: time.Second * 5,
ControlContext: allowDenyNetworksControl(allowNetworks, denyNetworks),
}
}

// allowDenyNetworksControl is used to allow/deny access to certain networks
func allowDenyNetworksControl(allowNetworks, denyNetworks []string) func(_ context.Context, network string, address string, conn syscall.RawConn) error {
return func(_ context.Context, network string, address string, conn syscall.RawConn) error {
if network != "tcp4" && network != "tcp6" {
return fmt.Errorf("%s is not a safe network type", network)
}

host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("%s is not a valid host/port pair: %s", address, err)
}

ipaddress := net.ParseIP(host)
if ipaddress == nil {
return fmt.Errorf("%s is not a valid IP address", host)
}

if !isAllowed(ipaddress, allowNetworks, denyNetworks) {
return fmt.Errorf("%s is denied", address)
}

return nil // allow connection
}
}

func isAllowed(ip net.IP, allowCIDRs []string, denyCIDRs []string) bool {
if inRange(ip, denyCIDRs) {
return false
}
if inRange(ip, allowCIDRs) {
return true
}
return false // "should never happen"
}

func inRange(ip net.IP, CIDRs []string) bool {
for i := 0; i < len(CIDRs); i++ {
cidr := CIDRs[i]
_, network, err := net.ParseCIDR(cidr)
if err != nil {
return false
}
if network.Contains(ip) {
return true
}
}

return false
}

type destinationTripperTransport struct {
Expand All @@ -213,7 +289,7 @@ type destinationTripperTransport struct {
// We need to use one transport per TLS server name (instead of giving our round
// tripper a single transport) because there is no way to specify the TLS
// ServerName on a per-connection basis.
func (f *destinationTripper) getTransport(tlsServerName string) http.RoundTripper {
func (f *destinationTripper) getTransport(tlsServerName string, dialer *net.Dialer) http.RoundTripper {
f.transportsMutex.Lock()
defer f.transportsMutex.Unlock()

Expand All @@ -230,8 +306,8 @@ func (f *destinationTripper) getTransport(tlsServerName string) http.RoundTrippe
InsecureSkipVerify: f.skipVerify,
ClientSessionCache: tls.NewLRUClientSessionCache(0), // 0 = use default
},
Dial: destinationTripperDialer.Dial, // nolint: staticcheck
DialContext: destinationTripperDialer.DialContext,
Dial: dialer.Dial, // nolint: staticcheck
DialContext: dialer.DialContext,
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true, // if we can multiplex requests over HTTP/2, we should
},
Expand Down Expand Up @@ -296,7 +372,7 @@ retryResolution:
u := makeHTTPSURL(r.URL, result.Destination)
r.URL = &u
r.Host = string(result.Host)
resp, err = f.getTransport(result.TLSServerName).RoundTrip(r)
resp, err = f.getTransport(result.TLSServerName, f.dialer).RoundTrip(r)
if err == nil {
return resp, nil
}
Expand Down
9 changes: 6 additions & 3 deletions fclient/dnscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ type DNSCache struct {
size int
duration time.Duration
entries map[string]*dnsCacheEntry
dialer net.Dialer
}

func NewDNSCache(size int, duration time.Duration) *DNSCache {
func NewDNSCache(size int, duration time.Duration, allowNetworks, denyNetworks []string) *DNSCache {
return &DNSCache{
resolver: net.DefaultResolver,
size: size,
duration: duration,
entries: make(map[string]*dnsCacheEntry),
dialer: net.Dialer{
ControlContext: allowDenyNetworksControl(allowNetworks, denyNetworks),
},
}
}

Expand Down Expand Up @@ -100,7 +104,6 @@ func (c *DNSCache) DialContext(ctx context.Context, network, address string) (ne
// retried set to true. This stops us from recursing more than
// once.
retried := false
dialer := net.Dialer{}

retryLookup:
// Consult the cache for the hostname. This will cause the OS to
Expand All @@ -113,7 +116,7 @@ retryLookup:
// Try each address in the cached entry. If we successfully connect
// to one of those addresses then return the conn and stop there.
for _, addr := range entry.addrs {
conn, err := dialer.DialContext(ctx, "tcp", addr.String()+":"+port)
conn, err := c.dialer.DialContext(ctx, "tcp", addr.String()+":"+port)
if err != nil {
continue
}
Expand Down
2 changes: 1 addition & 1 deletion fclient/dnscache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (r *dummyNetResolver) LookupIPAddr(_ context.Context, hostname string) ([]n
}

func mustCreateCache(size int, lifetime time.Duration) *DNSCache {
cache := NewDNSCache(size, lifetime)
cache := NewDNSCache(size, lifetime, []string{}, []string{})
cache.resolver = &dummyNetResolver{}
return cache
}
Expand Down

0 comments on commit c4f1e01

Please sign in to comment.