diff --git a/ptr/ptr.go b/ptr/ptr.go index 318a814..c4070fa 100644 --- a/ptr/ptr.go +++ b/ptr/ptr.go @@ -1,15 +1,20 @@ package ptr -import "github.com/go-board/std/cmp" +import ( + "github.com/go-board/std/cmp" +) func zero[T any]() (v T) { return } // Ref return reference of value func Ref[T any](t T) *T { return &t } -// Default return default value of type -func Default[T any]() *T { - return Ref(zero[T]()) +// RefOrNil return reference of value if it not the zero value, else return nil +func RefOrNil[T comparable](t T) *T { + if t == zero[T]() { + return nil + } + return &t } // ValueOr return value of pointer if not nil, else return default value. @@ -28,8 +33,12 @@ func ValueOrZero[T any](v *T) T { // Compare compares two pointer. If both non-nil, compare underlying data, // if both nil, return 0, non-nil pointer is always greater than nil pointer. func Compare[T cmp.Ordered](l, r *T) int { + return CompareBy(l, r, cmp.Compare[T]) +} + +func CompareBy[T any](l, r *T, cmp func(T, T) int) int { if l != nil && r != nil { - return cmp.Compare(*l, *r) + return cmp(*l, *r) } if l == nil && r == nil { return 0 @@ -43,8 +52,12 @@ func Compare[T cmp.Ordered](l, r *T) int { // Equal test whether two pointer are equal. If both non-nil, test underlying data, // if both nil, return true, else return false func Equal[T comparable](l, r *T) bool { + return EqualBy(l, r, cmp.Equal[T]) +} + +func EqualBy[T any](l, r *T, eq func(T, T) bool) bool { if l != nil && r != nil { - return *l == *r + return eq(*l, *r) } else if l == nil && r == nil { return true }