Skip to content

Commit

Permalink
Add tighter-scoped function types in /client (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCutter authored Nov 7, 2024
1 parent 8367fdb commit fb28144
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 193 deletions.
103 changes: 49 additions & 54 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,29 @@ var (
hasher = rfc6962.DefaultHasher
)

// Fetcher is the signature of a function which can retrieve arbitrary files from
// a log's data storage, via whatever appropriate mechanism.
// The path parameter is relative to the root of the log storage.
// CheckpointFetcherFunc is the signature of a function which can retrieve the latest
// checkpoint from a log's data storage.
//
// Note that the implementation of this MUST return (either directly or wrapped)
// an os.ErrIsNotExit when the file referenced by path does not exist, e.g. a HTTP
// an os.ErrIsNotExist when the file referenced by path does not exist, e.g. a HTTP
// based implementation MUST return this error when it receives a 404 StatusCode.
type Fetcher func(ctx context.Context, path string) ([]byte, error)
type CheckpointFetcherFunc func(ctx context.Context) ([]byte, error)

// TileFetcherFunc is the signature of a function which can fetch the raw data
// for a given tile.
//
// Note that the implementation of this MUST return (either directly or wrapped)
// an os.ErrIsNotExist when the file referenced by path does not exist, e.g. a HTTP
// based implementation MUST return this error when it receives a 404 StatusCode.
type TileFetcherFunc func(ctx context.Context, level, index, logSize uint64) ([]byte, error)

// EntryBundleFetcherFunc is the signature of a function which can fetch the raw data
// for a given entry bundle.
//
// Note that the implementation of this MUST return (either directly or wrapped)
// an os.ErrIsNotExist when the file referenced by path does not exist, e.g. a HTTP
// based implementation MUST return this error when it receives a 404 StatusCode.
type EntryBundleFetcherFunc func(ctx context.Context, bundleIndex, logSize uint64) ([]byte, error)

// ConsensusCheckpointFunc is a function which returns the largest checkpoint known which is
// signed by logSigV and satisfies some consensus algorithm.
Expand All @@ -55,16 +70,16 @@ type Fetcher func(ctx context.Context, path string) ([]byte, error)
type ConsensusCheckpointFunc func(ctx context.Context, logSigV note.Verifier, origin string) (*log.Checkpoint, []byte, *note.Note, error)

// UnilateralConsensus blindly trusts the source log, returning the checkpoint it provided.
func UnilateralConsensus(f Fetcher) ConsensusCheckpointFunc {
func UnilateralConsensus(f CheckpointFetcherFunc) ConsensusCheckpointFunc {
return func(ctx context.Context, logSigV note.Verifier, origin string) (*log.Checkpoint, []byte, *note.Note, error) {
return FetchCheckpoint(ctx, f, logSigV, origin)
}
}

// FetchCheckpoint retrieves and opens a checkpoint from the log.
// Returns both the parsed structure and the raw serialised checkpoint.
func FetchCheckpoint(ctx context.Context, f Fetcher, v note.Verifier, origin string) (*log.Checkpoint, []byte, *note.Note, error) {
cpRaw, err := f(ctx, layout.CheckpointPath)
func FetchCheckpoint(ctx context.Context, f CheckpointFetcherFunc, v note.Verifier, origin string) (*log.Checkpoint, []byte, *note.Note, error) {
cpRaw, err := f(ctx)
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -86,19 +101,18 @@ type ProofBuilder struct {
// NewProofBuilder creates a new ProofBuilder object for a given tree size.
// The returned ProofBuilder can be re-used for proofs related to a given tree size, but
// it is not thread-safe and should not be accessed concurrently.
func NewProofBuilder(ctx context.Context, cp log.Checkpoint, f Fetcher) (*ProofBuilder, error) {
tf := newTileFetcher(f, cp.Size)
func NewProofBuilder(ctx context.Context, cp log.Checkpoint, f TileFetcherFunc) (*ProofBuilder, error) {
pb := &ProofBuilder{
cp: cp,
nodeCache: newNodeCache(tf, cp.Size),
nodeCache: newNodeCache(f, cp.Size),
}
// Can't re-create the root of a zero size checkpoint other than by convention,
// so return early here in that case.
if cp.Size == 0 {
return pb, nil
}

hashes, err := FetchRangeNodes(ctx, cp.Size, tf)
hashes, err := FetchRangeNodes(ctx, cp.Size, f)
if err != nil {
return nil, fmt.Errorf("failed to fetch range nodes: %w", err)
}
Expand Down Expand Up @@ -165,8 +179,8 @@ func (pb *ProofBuilder) fetchNodes(ctx context.Context, nodes proof.Nodes) ([][]

// FetchRangeNodes returns the set of nodes representing the compact range covering
// a log of size s.
func FetchRangeNodes(ctx context.Context, s uint64, gt GetTileFunc) ([][]byte, error) {
nc := newNodeCache(gt, s)
func FetchRangeNodes(ctx context.Context, s uint64, f TileFetcherFunc) ([][]byte, error) {
nc := newNodeCache(f, s)
nIDs := make([]compact.NodeID, 0, compact.RangeSize(0, s))
nIDs = compact.RangeNodes(0, s, nIDs)
hashes := make([][]byte, 0, len(nIDs))
Expand All @@ -181,8 +195,8 @@ func FetchRangeNodes(ctx context.Context, s uint64, gt GetTileFunc) ([][]byte, e
}

// FetchLeafHashes fetches N consecutive leaf hashes starting with the leaf at index first.
func FetchLeafHashes(ctx context.Context, f Fetcher, first, N, logSize uint64) ([][]byte, error) {
nc := newNodeCache(newTileFetcher(f, logSize), logSize)
func FetchLeafHashes(ctx context.Context, f TileFetcherFunc, first, N, logSize uint64) ([][]byte, error) {
nc := newNodeCache(f, logSize)
hashes := make([][]byte, 0, N)
for i, end := first, first+N; i < end; i++ {
nID := compact.NodeID{Level: 0, Index: i}
Expand All @@ -203,21 +217,17 @@ type nodeCache struct {
logSize uint64
ephemeral map[compact.NodeID][]byte
tiles map[tileKey]api.HashTile
getTile GetTileFunc
getTile TileFetcherFunc
}

// GetTileFunc is the signature of a function which knows how to fetch a
// specific tile.
type GetTileFunc func(ctx context.Context, level, index uint64) (*api.HashTile, error)

// tileKey is used as a key in nodeCache's tile map.
type tileKey struct {
tileLevel uint64
tileIndex uint64
}

// newNodeCache creates a new nodeCache instance for a given log size.
func newNodeCache(f GetTileFunc, logSize uint64) nodeCache {
func newNodeCache(f TileFetcherFunc, logSize uint64) nodeCache {
return nodeCache{
logSize: logSize,
ephemeral: make(map[compact.NodeID][]byte),
Expand Down Expand Up @@ -245,12 +255,16 @@ func (n *nodeCache) GetNode(ctx context.Context, id compact.NodeID) ([]byte, err
tKey := tileKey{tileLevel, tileIndex}
t, ok := n.tiles[tKey]
if !ok {
tile, err := n.getTile(ctx, tileLevel, tileIndex)
tileRaw, err := n.getTile(ctx, tileLevel, tileIndex, n.logSize)
if err != nil {
return nil, fmt.Errorf("failed to fetch tile: %w", err)
}
t = *tile
n.tiles[tKey] = *tile
var tile api.HashTile
if err := tile.UnmarshalText(tileRaw); err != nil {
return nil, fmt.Errorf("failed to parse tile: %w", err)
}
t = tile
n.tiles[tKey] = tile
}
// We've got the tile, now we need to look up (or calculate) the node inside of it
numLeaves := 1 << nodeLevel
Expand All @@ -269,31 +283,10 @@ func (n *nodeCache) GetNode(ctx context.Context, id compact.NodeID) ([]byte, err
return r.GetRootHash(nil)
}

// newTileFetcher returns a GetTileFunc based on the passed in Fetcher and log size.
func newTileFetcher(f Fetcher, logSize uint64) GetTileFunc {
return func(ctx context.Context, level, index uint64) (*api.HashTile, error) {
p := layout.TilePath(level, index, logSize)
t, err := f(ctx, p)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read tile at %q: %w", p, err)
}
return nil, err
}

var tile api.HashTile
if err := tile.UnmarshalText(t); err != nil {
return nil, fmt.Errorf("failed to parse tile: %w", err)
}
return &tile, nil
}
}

// GetEntryBundle fetches the entry bundle at the given _tile index_.
func GetEntryBundle(ctx context.Context, f Fetcher, i, logSize uint64) (api.EntryBundle, error) {
func GetEntryBundle(ctx context.Context, f EntryBundleFetcherFunc, i, logSize uint64) (api.EntryBundle, error) {
bundle := api.EntryBundle{}
p := layout.EntriesPath(i, logSize)
sRaw, err := f(ctx, p)
sRaw, err := f(ctx, i, logSize)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return bundle, fmt.Errorf("leaf bundle at index %d not found: %v", i, err)
Expand All @@ -310,7 +303,8 @@ func GetEntryBundle(ctx context.Context, f Fetcher, i, logSize uint64) (api.Entr
// This tracker handles verification that updates to the tracked log state are
// consistent with previously seen states.
type LogStateTracker struct {
Fetcher Fetcher
CPFetcher CheckpointFetcherFunc
TileFetcher TileFetcherFunc
// Origin is the expected first line of checkpoints from the log.
Origin string
ConsensusCheckpoint ConsensusCheckpointFunc
Expand All @@ -331,10 +325,11 @@ type LogStateTracker struct {
// NewLogStateTracker creates a newly initialised tracker.
// If a serialised LogState representation is provided then this is used as the
// initial tracked state, otherwise a log state is fetched from the target log.
func NewLogStateTracker(ctx context.Context, f Fetcher, checkpointRaw []byte, nV note.Verifier, origin string, cc ConsensusCheckpointFunc) (LogStateTracker, error) {
func NewLogStateTracker(ctx context.Context, cpF CheckpointFetcherFunc, tF TileFetcherFunc, checkpointRaw []byte, nV note.Verifier, origin string, cc ConsensusCheckpointFunc) (LogStateTracker, error) {
ret := LogStateTracker{
ConsensusCheckpoint: cc,
Fetcher: f,
CPFetcher: cpF,
TileFetcher: tF,
LatestConsistent: log.Checkpoint{},
CheckpointNote: nil,
CpSigVerifier: nV,
Expand All @@ -347,7 +342,7 @@ func NewLogStateTracker(ctx context.Context, f Fetcher, checkpointRaw []byte, nV
return ret, err
}
ret.LatestConsistent = *cp
ret.ProofBuilder, err = NewProofBuilder(ctx, ret.LatestConsistent, ret.Fetcher)
ret.ProofBuilder, err = NewProofBuilder(ctx, ret.LatestConsistent, ret.TileFetcher)
if err != nil {
return ret, fmt.Errorf("NewProofBuilder: %v", err)
}
Expand Down Expand Up @@ -390,7 +385,7 @@ func (lst *LogStateTracker) Update(ctx context.Context) ([]byte, [][]byte, []byt
if err != nil {
return nil, nil, nil, err
}
builder, err := NewProofBuilder(ctx, *c, lst.Fetcher)
builder, err := NewProofBuilder(ctx, *c, lst.TileFetcher)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create proof builder: %w", err)
}
Expand Down Expand Up @@ -421,7 +416,7 @@ func (lst *LogStateTracker) Update(ctx context.Context) ([]byte, [][]byte, []byt
}

// CheckConsistency is a wapper function which simplifies verifying consistency between two or more checkpoints.
func CheckConsistency(ctx context.Context, f Fetcher, cp []log.Checkpoint) error {
func CheckConsistency(ctx context.Context, f TileFetcherFunc, cp []log.Checkpoint) error {
if l := len(cp); l < 2 {
return fmt.Errorf("passed %d checkpoints, need at least 2", l)
}
Expand Down
37 changes: 18 additions & 19 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"

"github.com/transparency-dev/formats/log"
"github.com/transparency-dev/merkle/compact"
"github.com/transparency-dev/trillian-tessera/api"
"github.com/transparency-dev/trillian-tessera/api/layout"
"golang.org/x/mod/sumdb/note"
)

Expand Down Expand Up @@ -76,6 +76,10 @@ func testLogFetcher(_ context.Context, p string) ([]byte, error) {
return os.ReadFile(path)
}

func testLogTileFetcher(ctx context.Context, l, i, s uint64) ([]byte, error) {
return testLogFetcher(ctx, layout.TilePath(l, i, s))
}

// fetchCheckpointShim allows fetcher requests for checkpoints to be intercepted.
type fetchCheckpointShim struct {
// Checkpoints holds raw checkpoints to be returned when the fetcher is asked to retrieve a checkpoint path.
Expand All @@ -86,17 +90,12 @@ type fetchCheckpointShim struct {
// Fetcher intercepts requests for the checkpoint file, returning the zero-th
// entry in the Checkpoints field. All other requests are passed through
// to the delegate fetcher.
func (f *fetchCheckpointShim) Fetcher(deleg Fetcher) Fetcher {
return func(ctx context.Context, path string) ([]byte, error) {
if strings.HasSuffix(path, "checkpoint") {
if len(f.Checkpoints) == 0 {
return nil, os.ErrNotExist
}
r := f.Checkpoints[0]
return r, nil
}
return deleg(ctx, path)
func (f *fetchCheckpointShim) FetchCheckpoint(ctx context.Context) ([]byte, error) {
if len(f.Checkpoints) == 0 {
return nil, os.ErrNotExist
}
r := f.Checkpoints[0]
return r, nil
}

// Advance causes subsequent intercepted checkpoint requests to return
Expand Down Expand Up @@ -177,8 +176,7 @@ func TestCheckLogStateTracker(t *testing.T) {
} {
t.Run(test.desc, func(t *testing.T) {
shim := fetchCheckpointShim{Checkpoints: test.cpRaws}
f := shim.Fetcher(testLogFetcher)
lst, err := NewLogStateTracker(ctx, f, testRawCheckpoints[0], testLogVerifier, testOrigin, UnilateralConsensus(f))
lst, err := NewLogStateTracker(ctx, shim.FetchCheckpoint, testLogTileFetcher, testRawCheckpoints[0], testLogVerifier, testOrigin, UnilateralConsensus(shim.FetchCheckpoint))
if err != nil {
t.Fatalf("NewLogStateTracker: %v", err)
}
Expand Down Expand Up @@ -300,7 +298,7 @@ func TestCheckConsistency(t *testing.T) {
},
} {
t.Run(test.desc, func(t *testing.T) {
err := CheckConsistency(ctx, testLogFetcher, test.cp)
err := CheckConsistency(ctx, testLogTileFetcher, test.cp)
if gotErr := err != nil; gotErr != test.wantErr {
t.Fatalf("wantErr: %t, got %v", test.wantErr, err)
}
Expand All @@ -310,11 +308,12 @@ func TestCheckConsistency(t *testing.T) {

func TestNodeCacheHandlesInvalidRequest(t *testing.T) {
ctx := context.Background()
wantBytes := []byte("one")
f := func(_ context.Context, _, _ uint64) (*api.HashTile, error) {
return &api.HashTile{
wantBytes := []byte("0123456789ABCDEF0123456789ABCDEF")
f := func(_ context.Context, _, _, _ uint64) ([]byte, error) {
h := &api.HashTile{
Nodes: [][]byte{wantBytes},
}, nil
}
return h.MarshalText()
}

// Large tree, but we're emulating skew since f, above, will return a tile which only knows about 1
Expand All @@ -340,7 +339,7 @@ func TestHandleZeroRoot(t *testing.T) {
if len(zeroCP.Hash) == 0 {
t.Fatal("BadTestData: checkpoint.0 has empty root hash")
}
if _, err := NewProofBuilder(context.Background(), zeroCP, testLogFetcher); err != nil {
if _, err := NewProofBuilder(context.Background(), zeroCP, testLogTileFetcher); err != nil {
t.Fatalf("NewProofBuilder: %v", err)
}
}
Loading

0 comments on commit fb28144

Please sign in to comment.