Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a chk package for checked types #2743

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions util/chk/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package chk_test

import (
"fmt"
"math"

"github.com/offchainlabs/nitro/util/chk"
)

func PublicWith(x uint64) (uint64, error) {
px, err := chk.NewPos64(x)
if err != nil {
return 0, err
}
return someOtherCalculationWith(px, someCalculationWith(px)).Val(), nil
}

func someCalculationWith(x chk.Pos64) chk.Pos64 {
// Other complicated logic here.
return safelyDoubleWith(x)
}

func someOtherCalculationWith(x, y chk.Pos64) chk.Pos64 {
// Other complicated logic here.
return safelyAddWith(x, y)
}

func safelyDoubleWith(x chk.Pos64) chk.Pos64 {
if x.Val() > math.MaxUint64/2 {
return chk.MustPos64(math.MaxUint64)
}
return chk.MustPos64(x.Val() / 2)
}

func safelyAddWith(x, y chk.Pos64) chk.Pos64 {
if x.Val() > math.MaxUint64-y.Val() {
return chk.MustPos64(math.MaxUint64)
}
return chk.MustPos64(x.Val() + y.Val())
}

func Example() {
r, _ := Public(10)
fmt.Println(r)
// Output: 30
}
61 changes: 61 additions & 0 deletions util/chk/example_without_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package chk_test

import (
"errors"
"fmt"
"math"
)

func Public(x uint64) (uint64, error) {
if x == 0 {
return 0, errors.New("x must be positive")
}
y, err := someCalculation(x)
if err != nil {
return 0, err
}
z, err := someOtherCalculation(x, y)
if err != nil {
return 0, err
}
return z, nil
}

func someCalculation(x uint64) (uint64, error) {
if x == 0 {
return 0, errors.New("x must be positive")
}
// Other complicated logic here.
return safelyDouble(x), nil
}

func someOtherCalculation(x, y uint64) (uint64, error) {
if x == 0 {
return 0, errors.New("x must be positive")
}
if y == 0 {
return 0, errors.New("y must be positive")
}
// Other complicated logic here.
return safelyAdd(x, y), nil
}

func safelyDouble(x uint64) uint64 {
if x > math.MaxUint64/2 {
return math.MaxUint64
}
return x * 2
}

func safelyAdd(x, y uint64) uint64 {
if x > math.MaxUint64-y {
return math.MaxUint64
}
return x + y
}

func Example_without() {
r, _ := Public(10)
fmt.Println(r)
// Output: 30
}
143 changes: 143 additions & 0 deletions util/chk/positive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Package chk supplies a set of checked types.
//
// These types can be used to avoid repeatedly checking the same checks
// on function and method arguments at multiple layers in your code's call
// stack.
//
// For exmpample, if you have a package which provides a public function which
// accepts a uint64, but inside that package you have other functions which all
// need to be able to be able to operate on strictly positive integers:
//
// Without the chk package you might write code like:
//
// func Public(x uint64) (uint64, error) {
// if x == 0 {
// return 0, errors.New("x must be positive")
// }
// y, err := someCalculation(x)
// if err != nil {
// return 0, err
// }
// z, err := someOtherCalculation(x, y)
// if err != nil {
// return 0, err
// }
// return z, nil
// }
//
// func someCalculation(x uint64) (uint64, error) {
// if x == 0 {
// return 0, errors.New("x must be positive")
// }
// // Other complicated logic here.
// return safelyDouble(x), nil
// }
//
// func someOtherCalculation(x, y uint64) (uint64, error) {
// if x == 0 {
// return 0, errors.New("x must be positive")
// }
// if y == 0 {
// return 0, errors.New("y must be positive")
// }
// // Other complicated logic here.
// return safelyAdd(x, y), nil
// }
//
// func safelyDouble(x uint64) uint64 {
// if x > math.MaxUint64/2 {
// return math.MaxUint64
// }
// return x * 2
// }
//
// func safelyAdd(x, y uint64) uint64 {
// if x > math.MaxUint64-y {
// return math.MaxUint64
// }
// return x + y
// }
//
// This sort of code is annoying to write and maintain, but it is necessary to
// enusure that a coding error in the future doesn't introduce some other caller
// of one of the internal functions which aren't guarded by a check for a
// positive value.
//
// With the chk package you can write code like this:
//
// func PublicWith(x uint64) (uint64, error) {
// px, err := chk.NewPos64(x)
// if err != nil {
// return 0, err
// }
// return someOtherCalculationWith(px, someCalculationWith(px)).Val(), nil
// }
//
// func someCalculationWith(x chk.Pos64) chk.Pos64 {
// // Other complicated logic here.
// return safelyDoubleWith(x)
// }
//
// func someOtherCalculationWith(x, y chk.Pos64) chk.Pos64 {
// // Other complicated logic here.
// return safelyAddWith(x, y)
// }
//
// func safelyDoubleWith(x chk.Pos64) chk.Pos64 {
// if x.Val() > math.MaxUint64/2 {
// return chk.MustPos64(math.MaxUint64)
// }
// return chk.MustPos64(x.Val() / 2)
// }
//
// func safelyAddWith(x, y chk.Pos64) chk.Pos64 {
// if x.Val() > math.MaxUint64-y.Val() {
// return chk.MustPos64(math.MaxUint64)
// }
// return chk.MustPos64(x.Val() + y.Val())
// }
//
// Of course, if you don't mind forcing clients of your package to depend on
// the chk package as well, you can just have your public funciton take a
// chk.Pos64 argument directly.
package chk

import (
"errors"
)

// Pos64 is a type which represents a positive uint64.
//
// The "zero" value of Pos64 is 1.
type Pos64 struct {
uint64
}

// NewPos64 returns a new Pos64 with the given value.
//
// errors if v is 0.
func NewPos64(v uint64) (Pos64, error) {
if v == 0 {
return Pos64{}, errors.New("v must be positive. got: 0")
}
return Pos64{v}, nil
}

// MustPos64 returns a new Pos64 with the given value.
//
// panics if v is 0.
func MustPos64(v uint64) Pos64 {
if v == 0 {
panic("v must be positive. got: 0")
}
return Pos64{v}
}

// Val returns the value of the Pos64.
func (p Pos64) Val() uint64 {
// The zero value of Pos64 is 1.
if p.uint64 == 0 {
return 1
}
return p.uint64
}
97 changes: 97 additions & 0 deletions util/chk/positive_external_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package chk_test

import (
"math"
"testing"

"github.com/offchainlabs/nitro/util/chk"
)

func TestNewPos64(t *testing.T) {
v, err := chk.NewPos64(1)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if v.Val() != 1 {
t.Errorf("v.Val() want 1, got %d", v)
}
}

func TestMustPos64(t *testing.T) {
v := chk.MustPos64(1)
if v.Val() != 1 {
t.Errorf("v.Val() want 1, got %d", v)
}
}

func TestNewPos64_error(t *testing.T) {
_, err := chk.NewPos64(0)
if err == nil {
t.Error("Expected an error, got nil")
}
if err.Error() != "v must be positive. got: 0" {
t.Errorf("Expected error message 'value must be positive', got '%s'", err.Error())
}
}

func TestMustPos64_panic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("Expected a panic, got nil")
}
}()
chk.MustPos64(0)
}

func BenchmarkAdding(b *testing.B) {
x := chk.MustPos64(1)
y := chk.MustPos64(2)
for i := 0; i < b.N; i++ {
_ = x.Val() + y.Val()
}
}

func BenchmarkAddingUint64(b *testing.B) {
x := uint64(1)
y := uint64(2)
for i := 0; i < b.N; i++ {
_ = x + y
}
}

// Test zero value.
func TestZeroValue(t *testing.T) {
var p chk.Pos64
if p.Val() != 1 {
t.Errorf("want 1, got %d", p.Val())
}
}

// Test MaxUint64 value.
func TestMaxUint64(t *testing.T) {
p := chk.MustPos64(math.MaxUint64)
if p.Val() != math.MaxUint64 {
t.Errorf("want math.MaxUint64, got %d", p.Val())
}
}

// Cations are always positive.
func handleCation(c chk.Pos64) uint64 {
return c.Val()
}

func TestPassingToFunction(t *testing.T) {
want := uint64(1)
got := handleCation(chk.MustPos64(1))
if got != want {
t.Errorf("want %d, got %d", want, got)
}
}

// Uncomment to see that these lines don't compile.
// func doesNotCompile() {
// _ = chk.Pos64{100}
// _ = chk.MustPos64(50).value
// handleCation(0)
// handleCation(uint64(0))
// }
18 changes: 18 additions & 0 deletions util/chk/positive_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package chk

import (
"testing"
"unsafe"
)

func TestSize(t *testing.T) {
// This test is here to ensure that the size of the Pos64 struct is 8 bytes.
u := uint64(24601)
p := Pos64{}

want := unsafe.Sizeof(u)
got := unsafe.Sizeof(p)
if got != want {
t.Errorf("Size of Pos64 want %d, got %d", want, got)
}
}
Loading