diff --git a/src/main/scala/org/ergoplatform/http/api/ErgoBaseApiRoute.scala b/src/main/scala/org/ergoplatform/http/api/ErgoBaseApiRoute.scala index 5c00e87902..dafc0bf0a1 100644 --- a/src/main/scala/org/ergoplatform/http/api/ErgoBaseApiRoute.scala +++ b/src/main/scala/org/ergoplatform/http/api/ErgoBaseApiRoute.scala @@ -91,9 +91,9 @@ trait ErgoBaseApiRoute extends ApiRoute with ApiCodecs { val maxTxCost = ergoSettings.nodeSettings.maxTransactionCost utxo.withMempool(mp) .validateWithCost(tx, maxTxCost) - .map(cost => UnconfirmedTransaction(tx, Some(cost), now, now, bytes, source = None)) + .map(cost => new UnconfirmedTransaction(tx, Some(cost), now, now, bytes, source = None)) case _ => - tx.statelessValidity().map(_ => UnconfirmedTransaction(tx, None, now, now, bytes, source = None)) + tx.statelessValidity().map(_ => new UnconfirmedTransaction(tx, None, now, now, bytes, source = None)) } } diff --git a/src/main/scala/org/ergoplatform/modifiers/mempool/UnconfirmedTransaction.scala b/src/main/scala/org/ergoplatform/modifiers/mempool/UnconfirmedTransaction.scala index 5c99a0e4e3..bd18f338d8 100644 --- a/src/main/scala/org/ergoplatform/modifiers/mempool/UnconfirmedTransaction.scala +++ b/src/main/scala/org/ergoplatform/modifiers/mempool/UnconfirmedTransaction.scala @@ -13,12 +13,12 @@ import scorex.util.{ModifierId, ScorexLogging} * @param transactionBytes - transaction bytes, to avoid serializations when we send it over the wire * @param source - peer which delivered the transaction (None if transaction submitted via API) */ -case class UnconfirmedTransaction(transaction: ErgoTransaction, - lastCost: Option[Int], - createdTime: Long, - lastCheckedTime: Long, - transactionBytes: Option[Array[Byte]], - source: Option[ConnectedPeer]) +class UnconfirmedTransaction(val transaction: ErgoTransaction, + val lastCost: Option[Int], + val createdTime: Long, + val lastCheckedTime: Long, + val transactionBytes: Option[Array[Byte]], + val source: Option[ConnectedPeer]) extends ScorexLogging { def id: ModifierId = transaction.id @@ -27,7 +27,13 @@ case class UnconfirmedTransaction(transaction: ErgoTransaction, * Updates cost and last checked time of unconfirmed transaction */ def withCost(cost: Int): UnconfirmedTransaction = { - copy(lastCost = Some(cost), lastCheckedTime = System.currentTimeMillis()) + new UnconfirmedTransaction( + transaction, + lastCost = Some(cost), + createdTime, + lastCheckedTime = System.currentTimeMillis(), + transactionBytes, + source) } override def equals(obj: Any): Boolean = obj match { @@ -42,12 +48,12 @@ object UnconfirmedTransaction { def apply(tx: ErgoTransaction, source: Option[ConnectedPeer]): UnconfirmedTransaction = { val now = System.currentTimeMillis() - UnconfirmedTransaction(tx, None, now, now, Some(tx.bytes), source) + new UnconfirmedTransaction(tx, None, now, now, Some(tx.bytes), source) } def apply(tx: ErgoTransaction, txBytes: Array[Byte], source: Option[ConnectedPeer]): UnconfirmedTransaction = { val now = System.currentTimeMillis() - UnconfirmedTransaction(tx, None, now, now, Some(txBytes), source) + new UnconfirmedTransaction(tx, None, now, now, Some(txBytes), source) } } diff --git a/src/main/scala/org/ergoplatform/nodeView/mempool/ErgoMemPool.scala b/src/main/scala/org/ergoplatform/nodeView/mempool/ErgoMemPool.scala index 34bbdc6466..8e5431c369 100644 --- a/src/main/scala/org/ergoplatform/nodeView/mempool/ErgoMemPool.scala +++ b/src/main/scala/org/ergoplatform/nodeView/mempool/ErgoMemPool.scala @@ -88,17 +88,13 @@ class ErgoMemPool private[mempool](private[mempool] val pool: OrderedTxPool, /** * Method to put a transaction into the memory pool. Validation of the transactions against - * the state is done in NodeVieHolder. This put() method can check whether a transaction is valid + * the state is done in NodeViewHolder. This put() method can check whether a transaction is valid * @param unconfirmedTx * @return Success(updatedPool), if transaction successfully added to the pool, Failure(_) otherwise */ def put(unconfirmedTx: UnconfirmedTransaction): ErgoMemPool = { - if (!pool.contains(unconfirmedTx.id)) { - val updatedPool = pool.put(unconfirmedTx, feeFactor(unconfirmedTx)) - new ErgoMemPool(updatedPool, stats, sortingOption) - } else { - this - } + val updatedPool = pool.put(unconfirmedTx, feeFactor(unconfirmedTx)) + new ErgoMemPool(updatedPool, stats, sortingOption) } def put(txs: TraversableOnce[UnconfirmedTransaction]): ErgoMemPool = { @@ -139,7 +135,8 @@ class ErgoMemPool private[mempool](private[mempool] val pool: OrderedTxPool, case None => log.warn(s"pool.get failed for $unconfirmedTransactionId") pool.orderedTransactions.valuesIterator.find(_.id == unconfirmedTransactionId) match { - case Some(utx) => invalidate(utx) + case Some(utx) => + invalidate(utx) case None => log.warn(s"Can't invalidate transaction $unconfirmedTransactionId as it is not in the pool") this diff --git a/src/main/scala/org/ergoplatform/nodeView/mempool/OrderedTxPool.scala b/src/main/scala/org/ergoplatform/nodeView/mempool/OrderedTxPool.scala index 30b92bcdfc..127edd8201 100644 --- a/src/main/scala/org/ergoplatform/nodeView/mempool/OrderedTxPool.scala +++ b/src/main/scala/org/ergoplatform/nodeView/mempool/OrderedTxPool.scala @@ -17,12 +17,12 @@ import scala.collection.immutable.TreeMap * @param outputs - mapping `box.id` -> `WeightedTxId(tx.id,tx.weight)` required for getting a transaction by its output box * @param inputs - mapping `box.id` -> `WeightedTxId(tx.id,tx.weight)` required for getting a transaction by its input box id */ -case class OrderedTxPool(orderedTransactions: TreeMap[WeightedTxId, UnconfirmedTransaction], - transactionsRegistry: TreeMap[ModifierId, WeightedTxId], - invalidatedTxIds: ApproximateCacheLike[String], - outputs: TreeMap[BoxId, WeightedTxId], - inputs: TreeMap[BoxId, WeightedTxId]) - (implicit settings: ErgoSettings) extends ScorexLogging { +class OrderedTxPool(val orderedTransactions: TreeMap[WeightedTxId, UnconfirmedTransaction], + val transactionsRegistry: TreeMap[ModifierId, WeightedTxId], + val invalidatedTxIds: ApproximateCacheLike[String], + val outputs: TreeMap[BoxId, WeightedTxId], + val inputs: TreeMap[BoxId, WeightedTxId]) + (implicit settings: ErgoSettings) extends ScorexLogging { import OrderedTxPool.weighted @@ -66,14 +66,26 @@ case class OrderedTxPool(orderedTransactions: TreeMap[WeightedTxId, UnconfirmedT */ def put(unconfirmedTx: UnconfirmedTransaction, feeFactor: Int): OrderedTxPool = { val tx = unconfirmedTx.transaction - val wtx = weighted(tx, feeFactor) - val newPool = OrderedTxPool( - orderedTransactions.updated(wtx, unconfirmedTx), - transactionsRegistry.updated(wtx.id, wtx), - invalidatedTxIds, - outputs ++ tx.outputs.map(_.id -> wtx), - inputs ++ tx.inputs.map(_.boxId -> wtx) - ).updateFamily(tx, wtx.weight, System.currentTimeMillis(), 0) + + val newPool = transactionsRegistry.get(tx.id) match { + case Some(wtx) => + new OrderedTxPool( + orderedTransactions.updated(wtx, unconfirmedTx), + transactionsRegistry, + invalidatedTxIds, + outputs, + inputs + ) + case None => + val wtx = weighted(tx, feeFactor) + new OrderedTxPool( + orderedTransactions.updated(wtx, unconfirmedTx), + transactionsRegistry.updated(wtx.id, wtx), + invalidatedTxIds, + outputs ++ tx.outputs.map(_.id -> wtx), + inputs ++ tx.inputs.map(_.boxId -> wtx) + ).updateFamily(tx, wtx.weight, System.currentTimeMillis(), 0) + } if (newPool.orderedTransactions.size > mempoolCapacity) { val victim = newPool.orderedTransactions.last._2 newPool.remove(victim) @@ -94,7 +106,7 @@ case class OrderedTxPool(orderedTransactions: TreeMap[WeightedTxId, UnconfirmedT def remove(tx: ErgoTransaction): OrderedTxPool = { transactionsRegistry.get(tx.id) match { case Some(wtx) => - OrderedTxPool( + new OrderedTxPool( orderedTransactions - wtx, transactionsRegistry - tx.id, invalidatedTxIds, @@ -107,11 +119,14 @@ case class OrderedTxPool(orderedTransactions: TreeMap[WeightedTxId, UnconfirmedT def remove(utx: UnconfirmedTransaction): OrderedTxPool = remove(utx.transaction) + /** + * Remove transaction from the pool and add it to invalidated transaction ids cache + */ def invalidate(unconfirmedTx: UnconfirmedTransaction): OrderedTxPool = { val tx = unconfirmedTx.transaction transactionsRegistry.get(tx.id) match { case Some(wtx) => - OrderedTxPool( + new OrderedTxPool( orderedTransactions - wtx, transactionsRegistry - tx.id, invalidatedTxIds.put(tx.id), @@ -119,17 +134,20 @@ case class OrderedTxPool(orderedTransactions: TreeMap[WeightedTxId, UnconfirmedT inputs -- tx.inputs.map(_.boxId) ).updateFamily(tx, -wtx.weight, System.currentTimeMillis(), depth = 0) case None => - OrderedTxPool(orderedTransactions, transactionsRegistry, invalidatedTxIds.put(tx.id), outputs, inputs) + if (orderedTransactions.valuesIterator.exists(utx => utx.id == tx.id)) { + new OrderedTxPool( + orderedTransactions.filter(_._2.id != tx.id), + transactionsRegistry - tx.id, + invalidatedTxIds.put(tx.id), + outputs -- tx.outputs.map(_.id), + inputs -- tx.inputs.map(_.boxId) + ) + } else { + new OrderedTxPool(orderedTransactions, transactionsRegistry, invalidatedTxIds.put(tx.id), outputs, inputs) + } } } - def filter(condition: UnconfirmedTransaction => Boolean): OrderedTxPool = { - orderedTransactions.foldLeft(this)((pool, entry) => { - val tx = entry._2 - if (condition(tx)) pool else pool.remove(tx) - }) - } - /** * Do not place transaction in the pool if the transaction known to be invalid, pool already has it, or the pool * is overfull. @@ -175,13 +193,14 @@ case class OrderedTxPool(orderedTransactions: TreeMap[WeightedTxId, UnconfirmedT this } else { - val uniqueTxIds: Set[WeightedTxId] = tx.inputs.flatMap(input => this.outputs.get(input.boxId))(collection.breakOut) + val uniqueTxIds: Set[WeightedTxId] = tx.inputs.flatMap(input => this.outputs.get(input.boxId)).toSet val parentTxs = uniqueTxIds.flatMap(wtx => this.orderedTransactions.get(wtx).map(ut => wtx -> ut)) parentTxs.foldLeft(this) { case (pool, (wtx, ut)) => val parent = ut.transaction val newWtx = WeightedTxId(wtx.id, wtx.weight + weight, wtx.feePerFactor, wtx.created) - val newPool = OrderedTxPool(pool.orderedTransactions - wtx + (newWtx -> ut), + val newPool = new OrderedTxPool( + pool.orderedTransactions - wtx + (newWtx -> ut), pool.transactionsRegistry.updated(parent.id, newWtx), invalidatedTxIds, parent.outputs.foldLeft(pool.outputs)((newOutputs, box) => newOutputs.updated(box.id, newWtx)), @@ -220,7 +239,7 @@ object OrderedTxPool { val cacheSettings = settings.cacheSettings.mempool val frontCacheSize = cacheSettings.invalidModifiersCacheSize val frontCacheExpiration = cacheSettings.invalidModifiersCacheExpiration - OrderedTxPool( + new OrderedTxPool( TreeMap.empty[WeightedTxId, UnconfirmedTransaction], TreeMap.empty[ModifierId, WeightedTxId], ExpiringApproximateCache.empty(frontCacheSize, frontCacheExpiration), diff --git a/src/test/scala/org/ergoplatform/nodeView/mempool/ErgoMemPoolSpec.scala b/src/test/scala/org/ergoplatform/nodeView/mempool/ErgoMemPoolSpec.scala index faaf40218e..c68d279f50 100644 --- a/src/test/scala/org/ergoplatform/nodeView/mempool/ErgoMemPoolSpec.scala +++ b/src/test/scala/org/ergoplatform/nodeView/mempool/ErgoMemPoolSpec.scala @@ -378,6 +378,20 @@ class ErgoMemPoolSpec extends AnyFlatSpec pool.size shouldBe 0 pool.stats.takenTxns shouldBe (family_depth + 1) * txs.size } + + it should "put not adding transaction twice" in { + val pool = ErgoMemPool.empty(settings).pool + val tx = invalidErgoTransactionGen.sample.get + val now = System.currentTimeMillis() + + val utx1 = new UnconfirmedTransaction(tx, None, now, now, None, None) + val utx2 = new UnconfirmedTransaction(tx, None, now, now, None, None) + val utx3 = new UnconfirmedTransaction(tx, None, now + 1, now + 1, None, None) + val updPool = pool.put(utx1, 100).remove(utx1).put(utx2, 500).put(utx3, 5000) + updPool.size shouldBe 1 + updPool.get(utx3.id).get.lastCheckedTime shouldBe (now + 1) + } + }