Skip to content

Commit

Permalink
refactor: abstract new/nil-pointer creation into pseudo.Constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
ARR4N committed Aug 26, 2024
1 parent 03c1687 commit 5781184
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 50 deletions.
24 changes: 24 additions & 0 deletions libevm/pseudo/constructor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package pseudo

// A Constructor returns newly constructed [Type] instances for a pre-registered
// concrete type.
type Constructor interface {
Zero() *Type
NewPointer() *Type
NilPointer() *Type
}

// NewConstructor returns a [Constructor] that builds `T` [Type] instances.
func NewConstructor[T any]() Constructor {
return ctor[T]{}
}

type ctor[T any] struct{}

func (ctor[T]) Zero() *Type { return Zero[T]().Type }
func (ctor[T]) NilPointer() *Type { return Zero[*T]().Type }

func (ctor[T]) NewPointer() *Type {
var x T
return From(&x).Type
}
45 changes: 45 additions & 0 deletions libevm/pseudo/constructor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package pseudo

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestConstructor(t *testing.T) {
testConstructor[uint](t)
testConstructor[string](t)
testConstructor[struct{ x string }](t)
}

func testConstructor[T any](t *testing.T) {
var zero T
t.Run(fmt.Sprintf("%T", zero), func(t *testing.T) {
ctor := NewConstructor[T]()

t.Run("NilPointer()", func(t *testing.T) {
got := get[*T](t, ctor.NilPointer())
assert.Nil(t, got)
})

t.Run("NewPointer()", func(t *testing.T) {
got := get[*T](t, ctor.NewPointer())
require.NotNil(t, got)
assert.Equal(t, zero, *got)
})

t.Run("Zero()", func(t *testing.T) {
got := get[T](t, ctor.Zero())
assert.Equal(t, zero, got)
})
})
}

func get[T any](t *testing.T, typ *Type) (x T) {
t.Helper()
val, err := NewValue[T](typ)
require.NoError(t, err, "NewValue[%T]()", x)
return val.Get()
}
91 changes: 41 additions & 50 deletions params/config.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,45 @@ func RegisterExtras[C any, R any](e Extras[C, R]) ExtraPayloadGetter[C, R] {
}
mustBeStruct[C]()
mustBeStruct[R]()
registeredExtras = &e
return ExtraPayloadGetter[C, R]{}
registeredExtras = &extraConstructors{
chainConfig: pseudo.NewConstructor[C](),
rules: pseudo.NewConstructor[R](),
newForRules: e.newForRules,
}
return e.getter()
}

// registeredExtras holds non-generic constructors for the [Extras] types
// registered via [RegisterExtras].
var registeredExtras *extraConstructors

type extraConstructors struct {
chainConfig, rules pseudo.Constructor
newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
}

func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
if e.NewRules == nil {
return registeredExtras.rules.NilPointer()
}
rExtra := e.NewRules(c, r, e.getter().FromChainConfig(c), blockNum, isMerge, timestamp)
return pseudo.From(rExtra).Type
}

func (*Extras[C, R]) getter() (g ExtraPayloadGetter[C, R]) { return }

// mustBeStruct panics if `T` isn't a struct.
func mustBeStruct[T any]() {
if k := reflect.TypeFor[T]().Kind(); k != reflect.Struct {
panic(notStructMessage[T]())
}
}

// registeredExtras holds the [Extras] registered via [RegisterExtras]. As we
// don't know `C` and `R` at compile time, it must be an interface.
var registeredExtras interface {
nilForChainConfig() *pseudo.Type
nilForRules() *pseudo.Type
newForChainConfig() *pseudo.Type
newForRules(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
// notStructMessage returns the message with which [mustBeStruct] might panic.
// It exists to avoid change-detector tests should the message contents change.
func notStructMessage[T any]() string {
var x T
return fmt.Sprintf("%T is not a struct", x)
}

// An ExtraPayloadGettter provides strongly typed access to the extra payloads
Expand All @@ -74,29 +102,15 @@ func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) *R {
return pseudo.MustNewValue[*R](r.extraPayload()).Get()
}

func mustBeStruct[T any]() {
var x T
if k := reflect.TypeOf(x).Kind(); k != reflect.Struct {
panic(notStructMessage[T]())
}
}

// notStructMessage returns the message with which [mustBeStruct] might panic.
// It exists to avoid change-detector tests should the message contents change.
func notStructMessage[T any]() string {
var x T
return fmt.Sprintf("%T is not a struct", x)
}

// UnmarshalJSON implements the [json.Unmarshaler] interface.
func (c *ChainConfig) UnmarshalJSON(data []byte) error {
type raw ChainConfig // doesn't inherit methods so avoids recursing back here (infinitely)
cc := &struct {
*raw
Extra *pseudo.Type `json:"extra"`
}{
raw: (*raw)(c), // embedded to achieve regular JSON unmarshalling
Extra: registeredExtras.nilForChainConfig(), // `c.extra` is otherwise unexported
raw: (*raw)(c), // embedded to achieve regular JSON unmarshalling
Extra: registeredExtras.chainConfig.NilPointer(), // `c.extra` is otherwise unexported
}

if err := json.Unmarshal(data, cc); err != nil {
Expand Down Expand Up @@ -143,7 +157,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type {
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
}
if c.extra == nil {
c.extra = registeredExtras.nilForChainConfig()
c.extra = registeredExtras.chainConfig.NilPointer()
}
return c.extra
}
Expand All @@ -155,30 +169,7 @@ func (r *Rules) extraPayload() *pseudo.Type {
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
}
if r.extra == nil {
r.extra = registeredExtras.nilForRules()
r.extra = registeredExtras.rules.NilPointer()
}
return r.extra
}

/**
* Start of Extras implementing the registeredExtras interface.
*/

func (Extras[C, R]) nilForChainConfig() *pseudo.Type { return pseudo.Zero[*C]().Type }
func (Extras[C, R]) nilForRules() *pseudo.Type { return pseudo.Zero[*R]().Type }

func (*Extras[C, R]) newForChainConfig() *pseudo.Type {
var x C
return pseudo.From(&x).Type
}

func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
if e.NewRules == nil {
return e.nilForRules()
}
return pseudo.From(e.NewRules(c, r, c.extra.Interface().(*C), blockNum, isMerge, timestamp)).Type
}

/**
* End of Extras implementing the registeredExtras interface.
*/

0 comments on commit 5781184

Please sign in to comment.