From fd434712fbeccddf282b111f8dbb44a34e14a540 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 27 Jan 2025 18:30:29 +0100 Subject: [PATCH] stricter hostname validation and replace Signed-off-by: Kristoffer Dalby --- hscontrol/auth.go | 4 +-- hscontrol/db/node.go | 8 +++++- hscontrol/db/node_test.go | 51 +++++++++++++++++++++++++++++++++++--- hscontrol/db/users.go | 4 +-- hscontrol/types/node.go | 12 ++++++--- hscontrol/util/dns.go | 25 ++++++++++++------- hscontrol/util/dns_test.go | 2 +- hscontrol/util/string.go | 5 ++++ 8 files changed, 89 insertions(+), 22 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 9e22660d46..c8e7daa673 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -152,7 +152,7 @@ func (h *Headscale) handleRegister( newNode := types.RegisterNode{ Node: types.Node{ MachineKey: machineKey, - Hostname: regReq.Hostinfo.Hostname, + Hostname: strings.ToLower(regReq.Hostinfo.Hostname), NodeKey: regReq.NodeKey, LastSeen: &now, Expiry: &time.Time{}, @@ -386,7 +386,7 @@ func (h *Headscale) handleAuthKey( now := time.Now().UTC() nodeToRegister := types.Node{ - Hostname: registerRequest.Hostinfo.Hostname, + Hostname: strings.ToLower(registerRequest.Hostinfo.Hostname), UserID: pak.User.ID, User: pak.User, MachineKey: machineKey, diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index f722d9ab16..bf6510c5e4 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -262,7 +262,7 @@ func SetTags( func RenameNode(tx *gorm.DB, nodeID types.NodeID, newName string, ) error { - err := util.CheckForFQDNRules( + err := util.ValidateHostname( newName, ) if err != nil { @@ -459,6 +459,12 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.IPv4 = ipv4 node.IPv6 = ipv6 + if err := util.ValidateHostname(node.Hostname); err != nil { + newHostname := util.InvalidString() + log.Info().Err(err).Str("invalid-hostname", node.Hostname).Str("new-hostname", newHostname).Msgf("Invalid hostname, replacing") + node.Hostname = newHostname + } + if node.GivenName == "" { givenName, err := ensureUniqueGivenName(tx, node.Hostname) if err != nil { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 270fd91b2e..7f9659872c 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -756,7 +756,7 @@ func TestListEphemeralNodes(t *testing.T) { assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) } -func TestRenameNode(t *testing.T) { +func TestNodeNaming(t *testing.T) { db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) @@ -786,6 +786,26 @@ func TestRenameNode(t *testing.T) { RegisterMethod: util.RegisterMethodAuthKey, } + // Using non-ASCII characters in the hostname can + // break your network, so they should be replaced when registering + // a node. + // https://github.com/juanfont/headscale/issues/2343 + nodeInvalidHostname := types.Node{ + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "我的电脑", + UserID: user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + + nodeShortHostname := types.Node{ + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "a", + UserID: user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + err = db.DB.Save(&node).Error require.NoError(t, err) @@ -798,6 +818,11 @@ func TestRenameNode(t *testing.T) { return err } _, err = RegisterNode(tx, node2, nil, nil) + if err != nil { + return err + } + _, err = RegisterNode(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil) + _, err = RegisterNode(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil) return err }) require.NoError(t, err) @@ -805,10 +830,12 @@ func TestRenameNode(t *testing.T) { nodes, err := db.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 2) + assert.Len(t, nodes, 4) t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName) t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName) + t.Logf("node3 %s %s", nodes[2].Hostname, nodes[2].GivenName) + t.Logf("node4 %s %s", nodes[3].Hostname, nodes[3].GivenName) assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName) assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName) @@ -820,6 +847,10 @@ func TestRenameNode(t *testing.T) { assert.Len(t, nodes[1].Hostname, 4) assert.Len(t, nodes[0].GivenName, 4) assert.Len(t, nodes[1].GivenName, 13) + assert.Contains(t, nodes[2].Hostname, "invalid-") // invalid chars + assert.Contains(t, nodes[2].GivenName, "invalid-") + assert.Contains(t, nodes[3].Hostname, "invalid-") // too short + assert.Contains(t, nodes[3].GivenName, "invalid-") // Nodes can be renamed to a unique name err = db.Write(func(tx *gorm.DB) error { @@ -829,7 +860,7 @@ func TestRenameNode(t *testing.T) { nodes, err = db.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 2) + assert.Len(t, nodes, 4) assert.Equal(t, "test", nodes[0].Hostname) assert.Equal(t, "newname", nodes[0].GivenName) @@ -841,7 +872,7 @@ func TestRenameNode(t *testing.T) { nodes, err = db.ListNodes() require.NoError(t, err) - assert.Len(t, nodes, 2) + assert.Len(t, nodes, 4) assert.Equal(t, "test", nodes[0].Hostname) assert.Equal(t, "newname", nodes[0].GivenName) assert.Equal(t, "test", nodes[1].GivenName) @@ -851,4 +882,16 @@ func TestRenameNode(t *testing.T) { return RenameNode(tx, nodes[0].ID, "test") }) assert.ErrorContains(t, err, "name is not unique") + + // Rename invalid chars + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[2].ID, "我的电脑") + }) + assert.ErrorContains(t, err, "lowercase ASCII letters numbers") + + // Rename too short + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[3].ID, "a") + }) + assert.ErrorContains(t, err, "longer than 2 or more characters") } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 3fdc14a0bb..1b10be40b1 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -24,7 +24,7 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { // CreateUser creates a new User. Returns error if could not be created // or another user already exists. func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) { - err := util.CheckForFQDNRules(user.Name) + err := util.ValidateHostname(user.Name) if err != nil { return nil, err } @@ -89,7 +89,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { if err != nil { return err } - err = util.CheckForFQDNRules(newName) + err = util.ValidateHostname(newName) if err != nil { return err } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 36a6506231..c897d09435 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -368,12 +368,18 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { return } - if node.Hostname != hostInfo.Hostname { + newHostname := strings.ToLower(hostInfo.Hostname) + + if err := util.ValidateHostname(newHostname); err != nil { + newHostname = util.InvalidString() + } + + if node.Hostname != newHostname { if node.GivenNameHasBeenChanged() { - node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname) + node.GivenName = util.ConvertWithFQDNRules(newHostname) } - node.Hostname = hostInfo.Hostname + node.Hostname = newHostname } } diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index c87714d095..da028bab13 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -24,7 +24,7 @@ const ( var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") -var ErrInvalidUserName = errors.New("invalid user name") +var ErrInvalidHostName = errors.New("invalid hostname") // ValidateUsername checks if a username is valid. // It must be at least 2 characters long, start with a letter, and contain @@ -64,26 +64,33 @@ func ValidateUsername(username string) error { return nil } -func CheckForFQDNRules(name string) error { +func ValidateHostname(name string) error { + if len(name) < 2 { + return fmt.Errorf( + "hostname must be longer than 2 or more characters. %q doesn't comply with this rule: %w", + name, + ErrInvalidHostName, + ) + } if len(name) > LabelHostnameLength { return fmt.Errorf( - "DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", + "DNS segment must not be over 63 chars. %q doesn't comply with this rule: %w", name, - ErrInvalidUserName, + ErrInvalidHostName, ) } if strings.ToLower(name) != name { return fmt.Errorf( - "DNS segment should be lowercase. %v doesn't comply with this rule: %w", + "DNS segment should be lowercase. %q doesn't comply with this rule: %w", name, - ErrInvalidUserName, + ErrInvalidHostName, ) } if invalidDNSRegex.MatchString(name) { return fmt.Errorf( - "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", + "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %q doesn't comply with theses rules: %w", name, - ErrInvalidUserName, + ErrInvalidHostName, ) } @@ -244,7 +251,7 @@ func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { return "", fmt.Errorf( "label %v is more than 63 chars: %w", elt, - ErrInvalidUserName, + ErrInvalidHostName, ) } } diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go index 30652e4b98..dabcf07fd4 100644 --- a/hscontrol/util/dns_test.go +++ b/hscontrol/util/dns_test.go @@ -46,7 +46,7 @@ func TestCheckForFQDNRules(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { + if err := ValidateHostname(tt.args.name); (err != nil) != tt.wantErr { t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index 08769060bc..61e8c0e02f 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -57,6 +57,11 @@ func GenerateRandomStringDNSSafe(size int) (string, error) { return str[:size], nil } +func InvalidString() string { + hash, _ := GenerateRandomStringDNSSafe(8) + return "invalid-" + hash +} + func TailNodesToString(nodes []*tailcfg.Node) string { temp := make([]string, len(nodes))