Skip to content

Commit

Permalink
Add Parent(), Ancestors(), and VisitFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
Daisuke Maki committed Oct 11, 2024
1 parent 55775ec commit a68646f
Showing 1 changed file with 53 additions and 8 deletions.
61 changes: 53 additions & 8 deletions trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type Node[K cmp.Ordered, V any] interface {
Value() V
Children() iter.Seq[Node[K, V]]
AddChild(Node[K, V])
Parent() Node[K, V]
Ancestors() iter.Seq[Node[K, V]]
}

// New creates a new Trie object.
Expand All @@ -73,27 +75,45 @@ func (t *Trie[L, K, V]) Get(key L) (V, bool) {
for x := range iter {
tokens = append(tokens, x)
}
return get(t.root, tokens)
node, ok := getNode(t.root, tokens)
if !ok {
return zero, false
}
return node.Value(), true
}

func (t *Trie[L, K, V]) GetNode(key L) (Node[K, V], bool) {
iter, err := t.tokenizer.Tokenize(key)
if err != nil {
return nil, false
}

t.mu.RLock()
defer t.mu.RUnlock()
var tokens []K
for x := range iter {
tokens = append(tokens, x)
}
return getNode(t.root, tokens)
}

func get[K cmp.Ordered, V any](root Node[K, V], tokens []K) (V, bool) {
func getNode[K cmp.Ordered, V any](root Node[K, V], tokens []K) (Node[K, V], bool) {
if len(tokens) > 0 {
for child := range root.Children() {
if child.Key() == tokens[0] {
// found the current token in the children.
if len(tokens) == 1 {
// this is the node we're looking for
return child.Value(), true
return child, true
}
// we need to traverse down the trie
return get[K, V](child, tokens[1:])
return getNode[K, V](child, tokens[1:])
}
}
}

// if we got here, that means we couldn't find a common ancestor
var zero V
return zero, false
return nil, false
}

// Delete removes data associated with `key`. It returns true if the value
Expand Down Expand Up @@ -201,6 +221,7 @@ type node[K cmp.Ordered, V any] struct {
key K
value V
children []*node[K, V]
parent *node[K, V]
}

func newNode[K cmp.Ordered, V any]() *node[K, V] {
Expand All @@ -215,6 +236,22 @@ func (n *node[K, V]) Value() V {
return n.value
}

func (n *node[K, V]) Parent() Node[K, V] {
return n.parent
}

func (n *node[K, V]) Ancestors() iter.Seq[Node[K, V]] {
return func(yield func(Node[K, V]) bool) {
cur := n.parent
for cur != nil {
if !yield(cur) {
break
}
cur = cur.parent
}
}
}

func (n *node[K, V]) Children() iter.Seq[Node[K, V]] {
n.mu.RLock()
children := make([]*node[K, V], len(n.children))
Expand All @@ -235,7 +272,9 @@ func (n *node[K, V]) AddChild(child Node[K, V]) {
// Node[T] interface because we don't want the users to instantiate
// their own nodes... so this type conversion is safe.
//nolint:forcetypeassert
n.children = append(n.children, child.(*node[K, V]))
raw := child.(*node[K, V])
raw.parent = n
n.children = append(n.children, raw)
sort.Slice(n.children, func(i, j int) bool {
return n.children[i].Key() < n.children[j].Key()
})
Expand All @@ -250,6 +289,12 @@ type Visitor[K cmp.Ordered, V any] interface {
Visit(Node[K, V], VisitMetadata) bool
}

type VisitFunc[K cmp.Ordered, V any] func(Node[K, V], VisitMetadata) bool

func (f VisitFunc[K, V]) Visit(n Node[K, V], m VisitMetadata) bool {
return f(n, m)
}

func Walk[L any, K cmp.Ordered, V any](trie *Trie[L, K, V], v Visitor[K, V]) {
var meta VisitMetadata
meta.Depth = 1
Expand All @@ -273,7 +318,7 @@ func (dumper[K, V]) Visit(n Node[K, V], meta VisitMetadata) bool {
sb.WriteString(" ")
}

fmt.Fprintf(&sb, "%v: %v", n.Key(), n.Value())
fmt.Fprintf(&sb, "%q: %v", fmt.Sprintf("%v", n.Key()), n.Value())
fmt.Println(sb.String())
return true
}
Expand Down

0 comments on commit a68646f

Please sign in to comment.