Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug of addition law of elliptic curve groups #81

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion crypto/ecpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import (
"github.com/bnb-chain/tss-lib/v2/tss"
)

var (
zero = big.NewInt(0)
)

// ECPoint convenience helper
type ECPoint struct {
curve elliptic.Curve
Expand All @@ -47,19 +51,63 @@ func NewECPointNoCurveCheck(curve elliptic.Curve, X, Y *big.Int) *ECPoint {
}

func (p *ECPoint) X() *big.Int {
if p.coords[0] == nil {
return nil
}
return new(big.Int).Set(p.coords[0])
}

func (p *ECPoint) Y() *big.Int {
if p.coords[1] == nil {
return nil
}
return new(big.Int).Set(p.coords[1])
}

func (p *ECPoint) Add(p1 *ECPoint) (*ECPoint, error) {
if p.X() == nil {
if p.Y() != nil {
return nil, fmt.Errorf("Add: the format of the point is wrong")
}
if p1.X() == nil {
if p1.Y() == nil {
return NewECPoint(p.curve, nil, nil)
}
return nil, fmt.Errorf("Add: the format of the point is wrong")
}
if p1.Y() == nil {
return nil, fmt.Errorf("Add: the format of the point is wrong")
}
return NewECPoint(p.curve, new(big.Int).Set(p1.X()), new(big.Int).Set(p1.Y()))
}
if p.Y() == nil {
return nil, fmt.Errorf("Add: the format of the point is wrong")
}
if p1.X() == nil {
if p1.X() != nil {
return nil, fmt.Errorf("Add: the format of the point is wrong")
}
return NewECPoint(p.curve, new(big.Int).Set(p.X()), new(big.Int).Set(p.Y()))
}

// The case : aG+(-a)G
tempNegative := new(big.Int).Neg(p1.Y())
tempNegative.Mod(tempNegative, p.curve.Params().P)
if tempNegative.Cmp(p.Y()) == 0 {
return NewECPoint(p.curve, nil, nil)
}

// The sum of the other cases
x, y := p.curve.Add(p.X(), p.Y(), p1.X(), p1.Y())
return NewECPoint(p.curve, x, y)
}

func (p *ECPoint) ScalarMult(k *big.Int) *ECPoint {
if new(big.Int).Mod(k, p.curve.Params().N).Cmp(zero) == 0 {
identity, _ := NewECPoint(p.curve, nil, nil)
return identity
}

x, y := p.curve.ScalarMult(p.X(), p.Y(), k.Bytes())
newP, err := NewECPoint(p.curve, x, y) // it must be on the curve, no need to check.
if err != nil {
Expand Down Expand Up @@ -88,6 +136,17 @@ func (p *ECPoint) Equals(p2 *ECPoint) bool {
if p == nil || p2 == nil {
return false
}
if p.X() == nil && p2.X() == nil {
if p.Y() == nil && p2.Y() == nil {
return true
}
}
if p.X() == nil || p.Y() == nil {
return false
}
if p2.X() == nil || p2.Y() == nil {
return false
}
return p.X().Cmp(p2.X()) == 0 && p.Y().Cmp(p2.Y()) == 0
}

Expand All @@ -97,14 +156,22 @@ func (p *ECPoint) SetCurve(curve elliptic.Curve) *ECPoint {
}

func (p *ECPoint) ValidateBasic() bool {
return p != nil && p.coords[0] != nil && p.coords[1] != nil && p.IsOnCurve()
if p == nil {
return false
}
return p.IsOnCurve()
}

func (p *ECPoint) EightInvEight() *ECPoint {
return p.ScalarMult(eight).ScalarMult(eightInv)
}

func ScalarBaseMult(curve elliptic.Curve, k *big.Int) *ECPoint {
if new(big.Int).Mod(k, curve.Params().N).Cmp(zero) == 0 {
p, _ := NewECPoint(curve, nil, nil)
return p
}

x, y := curve.ScalarBaseMult(k.Bytes())
p, err := NewECPoint(curve, x, y) // it must be on the curve, no need to check.
if err != nil {
Expand All @@ -114,6 +181,11 @@ func ScalarBaseMult(curve elliptic.Curve, k *big.Int) *ECPoint {
}

func isOnCurve(c elliptic.Curve, x, y *big.Int) bool {
// identity elemenet in the elliptic curve group
if x == nil && y == nil {
return true

}
if x == nil || y == nil {
return false
}
Expand Down
105 changes: 105 additions & 0 deletions crypto/ecpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
package crypto_test

import (

"crypto/elliptic"

"encoding/hex"
"encoding/json"

"math/big"
"reflect"
"testing"
Expand Down Expand Up @@ -120,7 +124,108 @@
}
}


func TestAddECPoints(t *testing.T) {

curveList := []*elliptic.CurveParams{elliptic.P224().Params(), elliptic.P256().Params(), elliptic.P384().Params()}

// Check 2 + (N-2) = identity element, where N is the order of a given elliptic curve group
for i := 0; i < len(curveList); i++ {
minus2 := big.NewInt(-2)
ECPoint1 := ScalarBaseMult(curveList[i], new(big.Int).Mod(minus2, curveList[i].N))
ECPoint2 := ScalarBaseMult(curveList[i], big.NewInt(2))

result, err := ECPoint1.Add(ECPoint2)

if err != nil {
t.Errorf("Add() error = %v", err)
}

if result.X() != nil || result.Y() != nil {
t.Errorf("Add() expect = nil,nil, got X = %v, Y=%v", result.X(), result.Y())
}
}

// Check identity + 5566*G = 5566G
for i := 0; i < len(curveList); i++ {
ECPoint1 := ScalarBaseMult(curveList[i], big.NewInt(0))
ECPoint2 := ScalarBaseMult(curveList[i], big.NewInt(5566))

result, err := ECPoint1.Add(ECPoint2)

if err != nil {
t.Errorf("Add() error = %v", err)
}

expect := ScalarBaseMult(curveList[i], big.NewInt(5566))

if result.X().Cmp(expect.X()) != 0 || result.Y().Cmp(expect.Y()) != 0 {
t.Errorf("Add() error = Two points not the same, result X = %v, Y=%v, expect X = %v, Y=%v", result.X(), result.Y(), expect.X(), expect.Y())
}
}

// Check 5566*G + identity = 5566G
for i := 0; i < len(curveList); i++ {
ECPoint1 := ScalarBaseMult(curveList[i], big.NewInt(5566))
ECPoint2 := ScalarBaseMult(curveList[i], big.NewInt(0))

result, err := ECPoint1.Add(ECPoint2)

if err != nil {
t.Errorf("Add() error = %v", err)
}

expect := ScalarBaseMult(curveList[i], big.NewInt(5566))

if result.X().Cmp(expect.X()) != 0 || result.Y().Cmp(expect.Y()) != 0 {
t.Errorf("Add() error = Two points not the same, result X = %v, Y=%v, expect X = %v, Y=%v", result.X(), result.Y(), expect.X(), expect.Y())
}
}

// Check 5*G +5*G = 10*G
for i := 0; i < len(curveList); i++ {
ECPoint1 := ScalarBaseMult(curveList[i], big.NewInt(5))
ECPoint2 := ScalarBaseMult(curveList[i], big.NewInt(5))

result, err := ECPoint1.Add(ECPoint2)

if err != nil {
t.Errorf("Add() error = %v", err)
}

expect := ScalarBaseMult(curveList[i], big.NewInt(10))

if result.X().Cmp(expect.X()) != 0 || result.Y().Cmp(expect.Y()) != 0 {
t.Errorf("Add() error = Two points not the same, result X = %v, Y=%v, expect X = %v, Y=%v", result.X(), result.Y(), expect.X(), expect.Y())
}
}
}

func TestScalarMult(t *testing.T) {
curveList := []*elliptic.CurveParams{elliptic.P224().Params(), elliptic.P256().Params(), elliptic.P384().Params()}

for i := 0; i < len(curveList); i++ {
ECPoint1 := ScalarBaseMult(curveList[i], big.NewInt(5))
result := ECPoint1.ScalarMult(curveList[i].N)

if result.X() != nil || result.Y() != nil {
t.Errorf("Add() expect = nil,nil, got X = %v, Y=%v", result.X(), result.Y())
}
}
}

func TestScalarBaseMult(t *testing.T) {
curveList := []*elliptic.CurveParams{elliptic.P224().Params(), elliptic.P256().Params(), elliptic.P384().Params()}

for i := 0; i < len(curveList); i++ {
result := ScalarBaseMult(curveList[i], big.NewInt(0))

if result.X() != nil || result.Y() != nil {
t.Errorf("Add() expect = nil,nil, got X = %v, Y=%v", result.X(), result.Y())
}
}

func TestS256EcpointJsonSerialization(t *testing.T) {

Check failure on line 228 in crypto/ecpoint_test.go

View workflow job for this annotation

GitHub Actions / Test

expected '(', found TestS256EcpointJsonSerialization
ec := btcec.S256()
tss.RegisterCurve("secp256k1", ec)

Expand Down
Loading