Skip to content

Commit

Permalink
more polishing in UnsignedBigInt impl
Browse files Browse the repository at this point in the history
  • Loading branch information
kushti committed Sep 17, 2024
1 parent 343a385 commit 1c2b99d
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 45 deletions.
24 changes: 24 additions & 0 deletions core/shared/src/main/scala/sigma/SigmaDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,15 @@ trait BigInt {
*/
def shiftRight(n: Int): BigInt

/**
* @return unsigned representation of this BigInt, or exception if its value is negative
*/
def toUnsigned: UnsignedBigInt

/**
* @return unsigned representation of this BigInt modulo `m`. Cryptographic mod operation is done, ie result is
* non-negative always
*/
def toUnsignedMod(m: UnsignedBigInt): UnsignedBigInt
}

Expand Down Expand Up @@ -297,6 +304,23 @@ trait UnsignedBigInt {
def subtractMod(that: UnsignedBigInt, m: UnsignedBigInt): UnsignedBigInt
def multiplyMod(that: UnsignedBigInt, m: UnsignedBigInt): UnsignedBigInt

/**
* @return a big integer whose value is `this xor that`
*/
def xor(that: UnsignedBigInt): UnsignedBigInt

/**
* @return a 256-bit signed integer whose value is (this << n). The shift distance, n, may be negative,
* in which case this method performs a right shift. (Computes floor(this * 2n).)
*/
def shiftLeft(n: Int): UnsignedBigInt

/**
* @return a 256-bit signed integer whose value is (this >> n). Sign extension is performed. The shift distance, n,
* may be negative, in which case this method performs a left shift. (Computes floor(this / 2n).)
*/
def shiftRight(n: Int): UnsignedBigInt

def toSigned(): BigInt
}

Expand Down
70 changes: 39 additions & 31 deletions core/shared/src/main/scala/sigma/ast/SType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ object SType {
SBoolean, SString, STuple, SGroupElement, SSigmaProp, SContext, SGlobal, SHeader, SPreHeader,
SAvlTree, SBox, SOption, SCollection, SBigInt
)
private val v6Types = v5Types ++ Seq(SByte, SShort, SInt, SLong)
private val v6Types = v5Types ++ Seq(SByte, SShort, SInt, SLong, SUnsignedBigInt)

private val v5TypesMap = v5Types.map { t => (t.typeId, t) }.toMap

Expand Down Expand Up @@ -398,6 +398,7 @@ case object SByte extends SPrimType with SEmbeddable with SNumericType with SMon
case i: Int => i.toByteExact
case l: Long => l.toByteExact
case bi: BigInt if VersionContext.current.isV6SoftForkActivated => bi.toByte // toByteExact from int is called under the hood
case ubi: UnsignedBigInt if VersionContext.current.isV6SoftForkActivated => ubi.toByte // toByteExact from int is called under the hood
case _ => sys.error(s"Cannot downcast value $v to the type $this")
}
}
Expand All @@ -420,6 +421,7 @@ case object SShort extends SPrimType with SEmbeddable with SNumericType with SMo
case i: Int => i.toShortExact
case l: Long => l.toShortExact
case bi: BigInt if VersionContext.current.isV6SoftForkActivated => bi.toShort // toShortExact from int is called under the hood
case ubi: UnsignedBigInt if VersionContext.current.isV6SoftForkActivated => ubi.toShort // toShortExact from int is called under the hood
case _ => sys.error(s"Cannot downcast value $v to the type $this")
}
}
Expand All @@ -444,6 +446,7 @@ case object SInt extends SPrimType with SEmbeddable with SNumericType with SMono
case i: Int => i
case l: Long => l.toIntExact
case bi: BigInt if VersionContext.current.isV6SoftForkActivated => bi.toInt
case ubi: UnsignedBigInt if VersionContext.current.isV6SoftForkActivated => ubi.toInt
case _ => sys.error(s"Cannot downcast value $v to the type $this")
}
}
Expand All @@ -470,6 +473,7 @@ case object SLong extends SPrimType with SEmbeddable with SNumericType with SMon
case i: Int => i.toLong
case l: Long => l
case bi: BigInt if VersionContext.current.isV6SoftForkActivated => bi.toLong
case ubi: UnsignedBigInt if VersionContext.current.isV6SoftForkActivated => ubi.toLong
case _ => sys.error(s"Cannot downcast value $v to the type $this")
}
}
Expand Down Expand Up @@ -687,15 +691,16 @@ object SOption extends STypeCompanion {

override val reprClass: RClass[_] = RClass(classOf[Option[_]])

type SBooleanOption = SOption[SBoolean.type]
type SByteOption = SOption[SByte.type]
type SShortOption = SOption[SShort.type]
type SIntOption = SOption[SInt.type]
type SLongOption = SOption[SLong.type]
type SBigIntOption = SOption[SBigInt.type]
type SGroupElementOption = SOption[SGroupElement.type]
type SBoxOption = SOption[SBox.type]
type SAvlTreeOption = SOption[SAvlTree.type]
type SBooleanOption = SOption[SBoolean.type]
type SByteOption = SOption[SByte.type]
type SShortOption = SOption[SShort.type]
type SIntOption = SOption[SInt.type]
type SLongOption = SOption[SLong.type]
type SBigIntOption = SOption[SBigInt.type]
type SUnsignedBigIntOption = SOption[SUnsignedBigInt.type]
type SGroupElementOption = SOption[SGroupElement.type]
type SBoxOption = SOption[SBox.type]
type SAvlTreeOption = SOption[SAvlTree.type]

/** This descriptors are instantiated once here and then reused. */
implicit val SByteOption = SOption(SByte)
Expand All @@ -704,6 +709,7 @@ object SOption extends STypeCompanion {
implicit val SIntOption = SOption(SInt)
implicit val SLongOption = SOption(SLong)
implicit val SBigIntOption = SOption(SBigInt)
implicit val SUnsignedBigIntOption = SOption(SUnsignedBigInt)
implicit val SBooleanOption = SOption(SBoolean)
implicit val SAvlTreeOption = SOption(SAvlTree)
implicit val SGroupElementOption = SOption(SGroupElement)
Expand Down Expand Up @@ -764,29 +770,31 @@ object SCollection extends STypeCompanion {
def apply[T <: SType](elemType: T): SCollection[T] = SCollectionType(elemType)
def apply[T <: SType](implicit elemType: T, ov: Overloaded1): SCollection[T] = SCollectionType(elemType)

type SBooleanArray = SCollection[SBoolean.type]
type SByteArray = SCollection[SByte.type]
type SShortArray = SCollection[SShort.type]
type SIntArray = SCollection[SInt.type]
type SLongArray = SCollection[SLong.type]
type SBigIntArray = SCollection[SBigInt.type]
type SGroupElementArray = SCollection[SGroupElement.type]
type SBoxArray = SCollection[SBox.type]
type SAvlTreeArray = SCollection[SAvlTree.type]
type SBooleanArray = SCollection[SBoolean.type]
type SByteArray = SCollection[SByte.type]
type SShortArray = SCollection[SShort.type]
type SIntArray = SCollection[SInt.type]
type SLongArray = SCollection[SLong.type]
type SBigIntArray = SCollection[SBigInt.type]
type SUnsignedBigIntArray = SCollection[SUnsignedBigInt.type]
type SGroupElementArray = SCollection[SGroupElement.type]
type SBoxArray = SCollection[SBox.type]
type SAvlTreeArray = SCollection[SAvlTree.type]

/** This descriptors are instantiated once here and then reused. */
val SBooleanArray = SCollection(SBoolean)
val SByteArray = SCollection(SByte)
val SByteArray2 = SCollection(SCollection(SByte))
val SShortArray = SCollection(SShort)
val SIntArray = SCollection(SInt)
val SLongArray = SCollection(SLong)
val SBigIntArray = SCollection(SBigInt)
val SGroupElementArray = SCollection(SGroupElement)
val SSigmaPropArray = SCollection(SSigmaProp)
val SBoxArray = SCollection(SBox)
val SAvlTreeArray = SCollection(SAvlTree)
val SHeaderArray = SCollection(SHeader)
val SBooleanArray = SCollection(SBoolean)
val SByteArray = SCollection(SByte)
val SByteArray2 = SCollection(SCollection(SByte))
val SShortArray = SCollection(SShort)
val SIntArray = SCollection(SInt)
val SLongArray = SCollection(SLong)
val SBigIntArray = SCollection(SBigInt)
val SUnsignedBigIntArray = SCollection(SUnsignedBigInt)
val SGroupElementArray = SCollection(SGroupElement)
val SSigmaPropArray = SCollection(SSigmaProp)
val SBoxArray = SCollection(SBox)
val SAvlTreeArray = SCollection(SAvlTree)
val SHeaderArray = SCollection(SHeader)
}

/** Type descriptor of tuple type. */
Expand Down
34 changes: 23 additions & 11 deletions core/shared/src/main/scala/sigma/data/CBigInt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ case class CBigInt(override val wrappedValue: BigInteger) extends BigInt with Wr

override def signum: Int = wrappedValue.signum()

override def add(that: BigInt): BigInt = CBigInt(wrappedValue.add(that.asInstanceOf[CBigInt].wrappedValue).to256BitValueExact)
override def add(that: BigInt): BigInt = CBigInt(wrappedValue.add(that.asInstanceOf[CBigInt].wrappedValue).toSignedBigIntValueExact)

override def subtract(that: BigInt): BigInt = CBigInt(wrappedValue.subtract(that.asInstanceOf[CBigInt].wrappedValue).to256BitValueExact)
override def subtract(that: BigInt): BigInt = CBigInt(wrappedValue.subtract(that.asInstanceOf[CBigInt].wrappedValue).toSignedBigIntValueExact)

override def multiply(that: BigInt): BigInt = CBigInt(wrappedValue.multiply(that.asInstanceOf[CBigInt].wrappedValue).to256BitValueExact)
override def multiply(that: BigInt): BigInt = CBigInt(wrappedValue.multiply(that.asInstanceOf[CBigInt].wrappedValue).toSignedBigIntValueExact)

override def divide(that: BigInt): BigInt = CBigInt(wrappedValue.divide(that.asInstanceOf[CBigInt].wrappedValue))

Expand All @@ -44,17 +44,17 @@ case class CBigInt(override val wrappedValue: BigInteger) extends BigInt with Wr

override def max(that: BigInt): BigInt = CBigInt(wrappedValue.max(that.asInstanceOf[CBigInt].wrappedValue))

override def negate(): BigInt = CBigInt(wrappedValue.negate().to256BitValueExact)
override def negate(): BigInt = CBigInt(wrappedValue.negate().toSignedBigIntValueExact)

override def and(that: BigInt): BigInt = CBigInt(wrappedValue.and(that.asInstanceOf[CBigInt].wrappedValue))

override def or(that: BigInt): BigInt = CBigInt(wrappedValue.or(that.asInstanceOf[CBigInt].wrappedValue))

override def xor(that: BigInt): BigInt = CBigInt(wrappedValue.xor(that.asInstanceOf[CBigInt].wrappedValue))

override def shiftLeft(n: Int): BigInt = CBigInt(wrappedValue.shiftLeft(n).to256BitValueExact)
override def shiftLeft(n: Int): BigInt = CBigInt(wrappedValue.shiftLeft(n).toSignedBigIntValueExact)

override def shiftRight(n: Int): BigInt = CBigInt(wrappedValue.shiftRight(n).to256BitValueExact)
override def shiftRight(n: Int): BigInt = CBigInt(wrappedValue.shiftRight(n).toSignedBigIntValueExact)

def toUnsigned: UnsignedBigInt = {
if(this.wrappedValue.compareTo(BigInteger.ZERO) < 0){
Expand Down Expand Up @@ -88,12 +88,13 @@ case class CUnsignedBigInt(override val wrappedValue: BigInteger) extends Unsign
override def compareTo(that: UnsignedBigInt): Int =
wrappedValue.compareTo(that.asInstanceOf[CUnsignedBigInt].wrappedValue)

//todo: consider result's bits limit
override def add(that: UnsignedBigInt): UnsignedBigInt = CUnsignedBigInt(wrappedValue.add(that.asInstanceOf[CUnsignedBigInt].wrappedValue).to256BitValueExact)
override def add(that: UnsignedBigInt): UnsignedBigInt = CUnsignedBigInt(wrappedValue.add(that.asInstanceOf[CUnsignedBigInt].wrappedValue).toUnsignedBigIntValueExact)

override def subtract(that: UnsignedBigInt): UnsignedBigInt = CUnsignedBigInt(wrappedValue.subtract(that.asInstanceOf[CUnsignedBigInt].wrappedValue).to256BitValueExact)
override def subtract(that: UnsignedBigInt): UnsignedBigInt = {
CUnsignedBigInt(wrappedValue.subtract(that.asInstanceOf[CUnsignedBigInt].wrappedValue).toUnsignedBigIntValueExact)
}

override def multiply(that: UnsignedBigInt): UnsignedBigInt = CUnsignedBigInt(wrappedValue.multiply(that.asInstanceOf[CUnsignedBigInt].wrappedValue).to256BitValueExact)
override def multiply(that: UnsignedBigInt): UnsignedBigInt = CUnsignedBigInt(wrappedValue.multiply(that.asInstanceOf[CUnsignedBigInt].wrappedValue).toUnsignedBigIntValueExact)

override def divide(that: UnsignedBigInt): UnsignedBigInt = CUnsignedBigInt(wrappedValue.divide(that.asInstanceOf[CUnsignedBigInt].wrappedValue))

Expand Down Expand Up @@ -129,7 +130,18 @@ case class CUnsignedBigInt(override val wrappedValue: BigInteger) extends Unsign
CUnsignedBigInt(wrappedValue.multiply(thatBi).mod(mBi))
}

/**
* @return a big integer whose value is `this xor that`
*/
def xor(that: UnsignedBigInt): UnsignedBigInt = {
CUnsignedBigInt(wrappedValue.xor(that.asInstanceOf[CUnsignedBigInt].wrappedValue))
}

override def shiftLeft(n: Int): UnsignedBigInt = CUnsignedBigInt(wrappedValue.shiftLeft(n).toUnsignedBigIntValueExact)

override def shiftRight(n: Int): UnsignedBigInt = CUnsignedBigInt(wrappedValue.shiftRight(n).toUnsignedBigIntValueExact)

override def toSigned(): BigInt = {
CBigInt(wrappedValue.to256BitValueExact)
CBigInt(wrappedValue.toSignedBigIntValueExact)
}
}
30 changes: 29 additions & 1 deletion core/shared/src/main/scala/sigma/reflection/ReflectionData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,35 @@ object ReflectionData {
)
)
}
//todo: add UnsignedBigInt
{
val clazz = classOf[sigma.UnsignedBigInt]
val paramTypes = Array[Class[_]](clazz)
registerClassEntry(clazz,
methods = Map(
mkMethod(clazz, "add", paramTypes) { (obj, args) =>
obj.asInstanceOf[UnsignedBigInt].add(args(0).asInstanceOf[UnsignedBigInt])
},
mkMethod(clazz, "max", paramTypes) { (obj, args) =>
obj.asInstanceOf[UnsignedBigInt].max(args(0).asInstanceOf[UnsignedBigInt])
},
mkMethod(clazz, "min", paramTypes) { (obj, args) =>
obj.asInstanceOf[UnsignedBigInt].min(args(0).asInstanceOf[UnsignedBigInt])
},
mkMethod(clazz, "subtract", paramTypes) { (obj, args) =>
obj.asInstanceOf[UnsignedBigInt].subtract(args(0).asInstanceOf[UnsignedBigInt])
},
mkMethod(clazz, "multiply", paramTypes) { (obj, args) =>
obj.asInstanceOf[UnsignedBigInt].multiply(args(0).asInstanceOf[UnsignedBigInt])
},
mkMethod(clazz, "mod", paramTypes) { (obj, args) =>
obj.asInstanceOf[UnsignedBigInt].mod(args(0).asInstanceOf[UnsignedBigInt])
},
mkMethod(clazz, "divide", paramTypes) { (obj, args) =>
obj.asInstanceOf[UnsignedBigInt].divide(args(0).asInstanceOf[UnsignedBigInt])
}
)
)
}
{
val clazz = classOf[CollBuilder]
registerClassEntry(clazz,
Expand Down
11 changes: 10 additions & 1 deletion core/shared/src/main/scala/sigma/util/Extensions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ object Extensions {
* not exactly fit in a 256 bit range.
* @see BigInteger#longValueExact
*/
@inline final def to256BitValueExact: BigInteger = {
@inline final def toSignedBigIntValueExact: BigInteger = {
// Comparing with 255 is correct because bitLength() method excludes the sign bit.
// For example, these are the boundary values:
// (new BigInteger("80" + "00" * 31, 16)).bitLength() = 256
Expand All @@ -217,6 +217,15 @@ object Extensions {
throw new ArithmeticException("BigInteger out of 256 bit range");
}

@inline final def toUnsignedBigIntValueExact: BigInteger = {
// todo: make the check soft-forkable
if (x.compareTo(BigInteger.ZERO) >= 0 && x.bitLength() <= 256) {
x
} else {
throw new ArithmeticException("Unsigned BigInteger out of 256 bit range or negative")
}
}

/** Converts `x` to [[sigma.BigInt]] */
def toBigInt: sigma.BigInt = CBigInt(x)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class CSigmaDslBuilder extends SigmaDslBuilder { dsl =>
}

override def byteArrayToBigInt(bytes: Coll[Byte]): BigInt = {
val bi = new BigInteger(bytes.toArray).to256BitValueExact
val bi = new BigInteger(bytes.toArray).toSignedBigIntValueExact
this.BigInt(bi)
}

Expand Down

0 comments on commit 1c2b99d

Please sign in to comment.