Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary comparable type constraints from data structures #244

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions internal/treeenc/treeenc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package treeenc

import (
"encoding"
"encoding/json"
"reflect"
"strconv"
)

// KeyMarshaler is a helper type for marshaling keys of a tree.
// When marshaling a tree, we need first to convert tree key/value
// pairs to a standard Go map, and then marshal the map. However,
// Go maps can only have keys of comparable types, and so we wrap
// the key in a KeyMarshaler and implement the encoding.TextMarshaler
// interface to make it comparable and marshalable.
//
// The map should be declared as map[*KeyMarshaler[K]]V.
type KeyMarshaler[K any] struct {
Key K
}

var _ encoding.TextMarshaler = &KeyMarshaler[string]{}

func (m *KeyMarshaler[T]) MarshalText() ([]byte, error) {
kv := reflect.ValueOf(m.Key)
if tm, ok := kv.Interface().(encoding.TextMarshaler); ok {
if kv.Kind() == reflect.Pointer && kv.IsNil() {
return nil, nil
}
return tm.MarshalText()
}

var text string
switch kv.Kind() {
case reflect.String:
text = kv.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
text = strconv.FormatInt(kv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
text = strconv.FormatUint(kv.Uint(), 10)
default:
return nil, &json.UnsupportedTypeError{Type: kv.Type()}
}
return []byte(text), nil
}

// KeyUnmarshaler is a helper type for unmarshaling keys of a tree.
// When unmarshaling a tree, we first unmarshal the JSON into a Go
// map, and then convert the map to tree key/value pairs. Similar to
// KeyMarshaler, we wrap the key in a KeyUnmarshaler to make it
// unmarshalable by implementing the encoding.TextUnmarshaler interface.
//
// The map should be declared as map[KeyUnmarshaler[K]]V.
type KeyUnmarshaler[K any] struct {
Key *K
}

var _ encoding.TextUnmarshaler = &KeyUnmarshaler[string]{}

func (m *KeyUnmarshaler[K]) UnmarshalText(text []byte) error {
var key K
m.Key = &key

kv := reflect.ValueOf(m.Key)
if tu, ok := kv.Interface().(encoding.TextUnmarshaler); ok {
if kv.Kind() == reflect.Ptr && kv.IsNil() {
return nil
}
return tu.UnmarshalText(text)
}

var err error
kv = kv.Elem()
switch kv.Kind() {
case reflect.String:
kv.SetString(string(text))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var i int64
i, err = strconv.ParseInt(string(text), 10, 64)
kv.SetInt(i)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
var u uint64
u, err = strconv.ParseUint(string(text), 10, 64)
kv.SetUint(u)
default:
err = &json.UnsupportedTypeError{Type: kv.Type()}
}
return err
}
140 changes: 140 additions & 0 deletions internal/treeenc/treeenc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package treeenc

import (
"encoding/json"
"fmt"
"math"
"strings"
"testing"
)

type customType struct {
value string
}

func (c customType) MarshalText() ([]byte, error) {
return []byte(fmt.Sprintf("customType(%s)", c.value)), nil
}

func (c *customType) UnmarshalText(text []byte) error {
value := strings.TrimPrefix(string(text), "customType(")
value = strings.TrimSuffix(value, ")")
c.value = value
return nil
}

func TestKeyMarshaler(t *testing.T) {
strMap := make(map[*KeyMarshaler[string]]string)
strMap[&KeyMarshaler[string]{Key: "key"}] = "value"
data, err := json.Marshal(strMap)
if err != nil {
t.Fatalf("Got error: %v", err)
}
expected := "{\"key\":\"value\"}"
if string(data) != expected {
t.Fatalf("Expected %q, got %q", expected, string(data))
}

intMap := make(map[*KeyMarshaler[int]]int)
intMap[&KeyMarshaler[int]{Key: math.MinInt}] = math.MinInt
data, err = json.Marshal(&intMap)
if err != nil {
t.Fatalf("Got error: %v", err)
}
expected = "{\"-9223372036854775808\":-9223372036854775808}"
if string(data) != expected {
t.Fatalf("Expected %q, got %q", expected, string(data))
}

uintMap := make(map[*KeyMarshaler[uint]]uint)
uintMap[&KeyMarshaler[uint]{Key: math.MaxUint}] = math.MaxUint
data, err = json.Marshal(&uintMap)
if err != nil {
t.Fatalf("Got error: %v", err)
}
expected = "{\"18446744073709551615\":18446744073709551615}"
if string(data) != expected {
t.Fatalf("Expected %q, got %q", expected, string(data))
}

customMap := make(map[*KeyMarshaler[customType]]string)
customMap[&KeyMarshaler[customType]{Key: customType{value: "key"}}] = "value"
data, err = json.Marshal(&customMap)
if err != nil {
t.Fatalf("Got error: %v", err)
}
expected = "{\"customType(key)\":\"value\"}"
if string(data) != expected {
t.Fatalf("Expected %q, got %q", expected, string(data))
}
}

func TestKeyUnmarshaler(t *testing.T) {
strMap := make(map[KeyUnmarshaler[string]]string)
data := []byte("{\"key\":\"value\"}")
if err := json.Unmarshal(data, &strMap); err != nil {
t.Fatalf("Got error: %v", err)
}
if len(strMap) != 1 {
t.Fatalf("Expected 1 key, got %d", len(strMap))
}
for k, v := range strMap {
if *k.Key != "key" {
t.Fatalf("Expected key %q, got %q", "key", *k.Key)
}
if v != "value" {
t.Fatalf("Expected value %q, got %q", "value", v)
}
}

intMap := make(map[KeyUnmarshaler[int]]int)
data = []byte("{\"-9223372036854775808\":-9223372036854775808}")
if err := json.Unmarshal(data, &intMap); err != nil {
t.Fatalf("Got error: %v", err)
}
if len(intMap) != 1 {
t.Fatalf("Expected 1 key, got %d", len(intMap))
}
for k, v := range intMap {
if *k.Key != math.MinInt {
t.Fatalf("Expected key %d, got %d", math.MinInt, *k.Key)
}
if v != math.MinInt {
t.Fatalf("Expected value %d, got %d", math.MinInt, v)
}
}

uintMap := make(map[KeyUnmarshaler[uint]]uint)
data = []byte("{\"18446744073709551615\":18446744073709551615}")
if err := json.Unmarshal(data, &uintMap); err != nil {
t.Fatalf("Got error: %v", err)
}
if len(uintMap) != 1 {
t.Fatalf("Expected 1 key, got %d", len(uintMap))
}
for k, v := range uintMap {
if *k.Key != math.MaxUint {
t.Fatalf("Expected key %d, got %d", uint(math.MaxUint), *k.Key)
}
if v != math.MaxUint {
t.Fatalf("Expected value %d, got %d", uint(math.MaxUint), v)
}
}

customMap := make(map[KeyUnmarshaler[customType]]string)
data = []byte("{\"customType(key)\":\"value\"}")
if err := json.Unmarshal(data, &customMap); err != nil {
t.Fatalf("Got error: %v", err)
}
if len(customMap) != 1 {
t.Fatalf("Expected 1 key, got %d", len(customMap))
}
for k, v := range customMap {
if (*k.Key).value != "key" {
t.Fatalf("Expected key %q, got %q", "key", (*k.Key).value)
}
if v != "value" {
t.Fatalf("Expected value %q, got %q", "value", v)
}
}
}
18 changes: 13 additions & 5 deletions lists/arraylist/arraylist.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ import (
var _ lists.List[int] = (*List[int])(nil)

// List holds the elements in a slice
type List[T comparable] struct {
type List[T any] struct {
elements []T
size int
equal func(a, b T) bool
}

const (
Expand All @@ -34,7 +35,14 @@ const (

// New instantiates a new list and adds the passed values, if any, to the list
func New[T comparable](values ...T) *List[T] {
list := &List[T]{}
equal := func(a, b T) bool { return a == b }
return NewWith(equal, values...)
}

// NewWith instantiates a new list with the custom equal
// function and adds the passed values, if any, to the list.
func NewWith[T any](equal func(a, b T) bool, values ...T) *List[T] {
list := &List[T]{equal: equal}
if len(values) > 0 {
list.Add(values...)
}
Expand Down Expand Up @@ -85,7 +93,7 @@ func (list *List[T]) Contains(values ...T) bool {
for _, searchValue := range values {
found := false
for index := 0; index < list.size; index++ {
if list.elements[index] == searchValue {
if list.equal(list.elements[index], searchValue) {
found = true
break
}
Expand All @@ -99,7 +107,7 @@ func (list *List[T]) Contains(values ...T) bool {

// Values returns all elements in the list.
func (list *List[T]) Values() []T {
newElements := make([]T, list.size, list.size)
newElements := make([]T, list.size)
copy(newElements, list.elements[:list.size])
return newElements
}
Expand All @@ -110,7 +118,7 @@ func (list *List[T]) IndexOf(value T) int {
return -1
}
for index, element := range list.elements {
if element == value {
if list.equal(element, value) {
return index
}
}
Expand Down
2 changes: 1 addition & 1 deletion lists/arraylist/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import "github.com/emirpasic/gods/v2/containers"
var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil)

// Iterator holding the iterator's state
type Iterator[T comparable] struct {
type Iterator[T any] struct {
list *List[T]
index int
}
Expand Down
18 changes: 13 additions & 5 deletions lists/doublylinkedlist/doublylinkedlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,29 @@ import (
var _ lists.List[any] = (*List[any])(nil)

// List holds the elements, where each element points to the next and previous element
type List[T comparable] struct {
type List[T any] struct {
first *element[T]
last *element[T]
size int
equal func(a, b T) bool
}

type element[T comparable] struct {
type element[T any] struct {
value T
prev *element[T]
next *element[T]
}

// New instantiates a new list and adds the passed values, if any, to the list
func New[T comparable](values ...T) *List[T] {
list := &List[T]{}
equal := func(a, b T) bool { return a == b }
return NewWith(equal, values...)
}

// NewWith instantiates a new list with the custom equal
// function and adds the passed values, if any, to the list.
func NewWith[T any](equal func(a, b T) bool, values ...T) *List[T] {
list := &List[T]{equal: equal}
if len(values) > 0 {
list.Add(values...)
}
Expand Down Expand Up @@ -158,7 +166,7 @@ func (list *List[T]) Contains(values ...T) bool {
for _, value := range values {
found := false
for element := list.first; element != nil; element = element.next {
if element.value == value {
if list.equal(element.value, value) {
found = true
break
}
Expand All @@ -185,7 +193,7 @@ func (list *List[T]) IndexOf(value T) int {
return -1
}
for index, element := range list.Values() {
if element == value {
if list.equal(element, value) {
return index
}
}
Expand Down
2 changes: 1 addition & 1 deletion lists/doublylinkedlist/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import "github.com/emirpasic/gods/v2/containers"
var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil)

// Iterator holding the iterator's state
type Iterator[T comparable] struct {
type Iterator[T any] struct {
list *List[T]
index int
element *element[T]
Expand Down
2 changes: 1 addition & 1 deletion lists/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

// List interface that all lists implement
type List[T comparable] interface {
type List[T any] interface {
Get(index int) (T, bool)
Remove(index int)
Add(values ...T)
Expand Down
2 changes: 1 addition & 1 deletion lists/singlylinkedlist/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import "github.com/emirpasic/gods/v2/containers"
var _ containers.IteratorWithIndex[int] = (*Iterator[int])(nil)

// Iterator holding the iterator's state
type Iterator[T comparable] struct {
type Iterator[T any] struct {
list *List[T]
index int
element *element[T]
Expand Down
Loading