Skip to content

Commit

Permalink
additional checks, simplifications
Browse files Browse the repository at this point in the history
Since lightwalletd only works with Sapling and beyond, and since Sapling
requires v4 transactions, there's no need to distinguish between pre-v4
and v4 (we will never see pre-v4). This simplifies some of the
conditionals.

Also, add some additional checks, such as version group ID and consensus
branch ID.
  • Loading branch information
Larry Ruane committed Oct 18, 2021
1 parent 25949b2 commit 6f019a3
Showing 1 changed file with 55 additions and 58 deletions.
113 changes: 55 additions & 58 deletions parser/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,13 @@ func (tx *Transaction) ToCompact(index int) *walletrpc.CompactTx {
return ctx
}

// parse versions 2 through 4
func (tx *Transaction) parseV1to4(data []byte, version uint32) ([]byte, error) {
// parse version 4
func (tx *Transaction) parseV4(data []byte) ([]byte, error) {
s := bytestring.String(data)
var err error
if tx.nVersionGroupID != 0x892F2085 {
return nil, errors.New(fmt.Sprintf("version group ID %x must be 0x892F2085", tx.nVersionGroupID))
}
s, err = tx.ParseTransparent([]byte(s))
if err != nil {
return nil, err
Expand All @@ -373,45 +376,41 @@ func (tx *Transaction) parseV1to4(data []byte, version uint32) ([]byte, error) {
return nil, errors.New("could not read nLockTime")
}

if tx.fOverwintered {
if !s.ReadUint32(&tx.nExpiryHeight) {
return nil, errors.New("could not read nExpiryHeight")
}
if !s.ReadUint32(&tx.nExpiryHeight) {
return nil, errors.New("could not read nExpiryHeight")
}

var spendCount, outputCount int

if tx.version >= 4 {
if !s.ReadInt64(&tx.valueBalanceSapling) {
return nil, errors.New("could not read valueBalance")
}
if !s.ReadCompactSize(&spendCount) {
return nil, errors.New("could not read nShieldedSpend")
}
if spendCount > 0 {
tx.shieldedSpends = make([]*spend, spendCount)
for i := 0; i < spendCount; i++ {
newSpend := &spend{}
s, err = newSpend.ParseFromSlice([]byte(s), tx.version)
if err != nil {
return nil, errors.Wrap(err, "while parsing shielded Spend")
}
tx.shieldedSpends[i] = newSpend
if !s.ReadInt64(&tx.valueBalanceSapling) {
return nil, errors.New("could not read valueBalance")
}
if !s.ReadCompactSize(&spendCount) {
return nil, errors.New("could not read nShieldedSpend")
}
if spendCount > 0 {
tx.shieldedSpends = make([]*spend, spendCount)
for i := 0; i < spendCount; i++ {
newSpend := &spend{}
s, err = newSpend.ParseFromSlice([]byte(s), 4)
if err != nil {
return nil, errors.Wrap(err, "while parsing shielded Spend")
}
tx.shieldedSpends[i] = newSpend
}
if !s.ReadCompactSize(&outputCount) {
return nil, errors.New("could not read nShieldedOutput")
}
if outputCount > 0 {
tx.shieldedOutputs = make([]*output, outputCount)
for i := 0; i < outputCount; i++ {
newOutput := &output{}
s, err = newOutput.ParseFromSlice([]byte(s), tx.version)
if err != nil {
return nil, errors.Wrap(err, "while parsing shielded Output")
}
tx.shieldedOutputs[i] = newOutput
}
if !s.ReadCompactSize(&outputCount) {
return nil, errors.New("could not read nShieldedOutput")
}
if outputCount > 0 {
tx.shieldedOutputs = make([]*output, outputCount)
for i := 0; i < outputCount; i++ {
newOutput := &output{}
s, err = newOutput.ParseFromSlice([]byte(s), 4)
if err != nil {
return nil, errors.Wrap(err, "while parsing shielded Output")
}
tx.shieldedOutputs[i] = newOutput
}
}
var joinSplitCount int
Expand All @@ -422,7 +421,7 @@ func (tx *Transaction) parseV1to4(data []byte, version uint32) ([]byte, error) {
if joinSplitCount > 0 {
tx.joinSplits = make([]*joinSplit, joinSplitCount)
for i := 0; i < joinSplitCount; i++ {
js := &joinSplit{version: tx.version}
js := &joinSplit{version: 4}
s, err = js.ParseFromSlice([]byte(s))
if err != nil {
return nil, errors.Wrap(err, "while parsing JoinSplit")
Expand All @@ -438,21 +437,25 @@ func (tx *Transaction) parseV1to4(data []byte, version uint32) ([]byte, error) {
return nil, errors.New("could not read joinSplitSig")
}
}
if tx.version >= 4 {
if spendCount+outputCount > 0 && !s.ReadBytes(&tx.bindingSigSapling, 64) {
return nil, errors.New("could not read bindingSigSapling")
}
if spendCount+outputCount > 0 && !s.ReadBytes(&tx.bindingSigSapling, 64) {
return nil, errors.New("could not read bindingSigSapling")
}
return s, nil
}

// parse version 5
func (tx *Transaction) parseV5(data []byte, version uint32) ([]byte, error) {
func (tx *Transaction) parseV5(data []byte) ([]byte, error) {
s := bytestring.String(data)
var err error
if !s.ReadUint32(&tx.consensusBranchID) {
return nil, errors.New("could not read nVersionGroupId")
}
if tx.nVersionGroupID != 0x26A7270A {
return nil, errors.New(fmt.Sprintf("version number %d must be 0x26A7270A", tx.nVersionGroupID))
}
if tx.consensusBranchID != 0x37519621 {
return nil, errors.New("unknown consensusBranchID")
}
if !s.ReadUint32(&tx.nLockTime) {
return nil, errors.New("could not read nLockTime")
}
Expand Down Expand Up @@ -511,6 +514,9 @@ func (tx *Transaction) parseV5(data []byte, version uint32) ([]byte, error) {
if !s.ReadCompactSize(&actionsCount) {
return nil, errors.New("could not read nActionsOrchard")
}
if actionsCount >= (1 << 16) {
return nil, errors.New(fmt.Sprintf("actionsCount (%d) must be less than 2^16", actionsCount))
}
if !s.Skip(820 * actionsCount) {
return nil, errors.New("could not skip vActionsOrchard")
}
Expand Down Expand Up @@ -554,31 +560,22 @@ func (tx *Transaction) ParseFromSlice(data []byte) ([]byte, error) {
}

tx.fOverwintered = (header >> 31) == 1
tx.version = header & 0x7FFFFFFF

// Implement the effective version rule. From the spec section 7.1:
//
// "Version constraints apply to the effectiveVersion, which is equal to
// min(2, version) when fOverwintered = 0 and to version otherwise."
if !tx.fOverwintered && tx.version > 2 {
tx.version = 2
if !tx.fOverwintered {
return nil, errors.New("fOverwinter flag must be set")
}

// Spec says fOverwinter must be set for version 5.
if tx.version >= 5 && !tx.fOverwintered {
return nil, errors.New(fmt.Sprintf("version %d requires fOverwinter", tx.version))
tx.version = header & 0x7FFFFFFF
if tx.version < 4 {
return nil, errors.New(fmt.Sprintf("version number %d must be greater or equal to 4", tx.version))
}

if tx.version >= 3 {
if !s.ReadUint32(&tx.nVersionGroupID) {
return nil, errors.New("could not read nVersionGroupId")
}
if !s.ReadUint32(&tx.nVersionGroupID) {
return nil, errors.New("could not read nVersionGroupId")
}
// parse the main part of the transaction
if tx.version <= 4 {
s, err = tx.parseV1to4([]byte(s), tx.version)
s, err = tx.parseV4([]byte(s))
} else {
s, err = tx.parseV5([]byte(s), tx.version)
s, err = tx.parseV5([]byte(s))
}
if err != nil {
return nil, err
Expand Down

0 comments on commit 6f019a3

Please sign in to comment.