diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala index 7f29d58733b..e55e2c0f916 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala @@ -266,6 +266,7 @@ class DeltaLog private( * the new protocol version is not a superset of the original one used by the snapshot. */ def upgradeProtocol( + catalogTable: Option[CatalogTable], snapshot: Snapshot, newVersion: Protocol): Unit = { val currentVersion = snapshot.protocol @@ -277,7 +278,7 @@ class DeltaLog private( throw new ProtocolDowngradeException(currentVersion, newVersion) } - val txn = startTransaction(Some(snapshot)) + val txn = startTransaction(catalogTable, Some(snapshot)) try { SchemaMergingUtils.checkColumnNameDuplication(txn.metadata.schema, "in the table schema") } catch { @@ -288,11 +289,6 @@ class DeltaLog private( logConsole(s"Upgraded table at $dataPath to $newVersion.") } - // Test-only!! - private[delta] def upgradeProtocol(newVersion: Protocol): Unit = { - upgradeProtocol(unsafeVolatileSnapshot, newVersion) - } - /** * Get all actions starting from "startVersion" (inclusive). If `startVersion` doesn't exist, * return an empty Iterator. diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/CreateDeltaTableCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/CreateDeltaTableCommand.scala index 8e0722d95ab..2865dbab040 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/CreateDeltaTableCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/CreateDeltaTableCommand.scala @@ -612,7 +612,7 @@ case class CreateDeltaTableCommand( deltaLog: DeltaLog, tableWithLocation: CatalogTable, snapshotOpt: Option[Snapshot] = None): OptimisticTransaction = { - val txn = deltaLog.startTransaction(snapshotOpt) + val txn = deltaLog.startTransaction(None, snapshotOpt) // During CREATE/REPLACE, we synchronously run conversion (if Uniform is enabled) so // we always remove the post commit hook here. diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala index 1131b009888..98e8005e475 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/RestoreTableCommand.scala @@ -113,7 +113,7 @@ case class RestoreTableCommand( require(versionToRestore < latestVersion, s"Version to restore ($versionToRestore)" + s"should be less then last available version ($latestVersion)") - deltaLog.withNewTransaction { txn => + deltaLog.withNewTransaction(sourceTable.catalogTable) { txn => val latestSnapshot = txn.snapshot val snapshotToRestore = deltaLog.getSnapshotAt(versionToRestore) val latestSnapshotFiles = latestSnapshot.allFiles diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala index 2700bfb3522..3bcc40d7900 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala @@ -231,7 +231,8 @@ case class AlterTableDropFeatureDeltaCommand( val log = table.deltaLog val snapshot = log.update(checkIfUpdatedSinceTs = Some(snapshotRefreshStartTime)) val emptyCommitTS = System.nanoTime() - log.startTransaction(Some(snapshot)).commit(Nil, DeltaOperations.EmptyCommit) + log.startTransaction(table.catalogTable, Some(snapshot)) + .commit(Nil, DeltaOperations.EmptyCommit) log.checkpoint(log.update(checkIfUpdatedSinceTs = Some(emptyCommitTS))) } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala index d4f8cb7fdba..2a99f3baf95 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala @@ -1549,7 +1549,8 @@ trait DeltaAlterTableByPathTests extends DeltaAlterTableTestBase { override protected def createTable(schema: String, tblProperties: Map[String, String]): String = { val tmpDir = Utils.createTempDir().getCanonicalPath val (deltaLog, snapshot) = getDeltaLogWithSnapshot(tmpDir) - val txn = deltaLog.startTransaction(Some(snapshot)) + // This is a path-based table so we don't need to pass the catalogTable here + val txn = deltaLog.startTransaction(None, Some(snapshot)) val metadata = Metadata( schemaString = StructType.fromDDL(schema).json, configuration = tblProperties) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/rowtracking/RowTrackingConflictResolutionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/rowtracking/RowTrackingConflictResolutionSuite.scala index 0f3fff7b7f7..4221512c5d9 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/rowtracking/RowTrackingConflictResolutionSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/rowtracking/RowTrackingConflictResolutionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.delta.rowtracking import org.apache.spark.sql.delta.{DeltaLog, DeltaOperations, RowId, RowTrackingFeature} import org.apache.spark.sql.delta.actions.{Action, AddFile} import org.apache.spark.sql.delta.rowid.RowIdTestUtils +import org.apache.spark.sql.delta.test.DeltaTestImplicits._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.TableIdentifier diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/test/DeltaTestImplicits.scala b/spark/src/test/scala/org/apache/spark/sql/delta/test/DeltaTestImplicits.scala index c8de2f3a017..37cc2027b8f 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/test/DeltaTestImplicits.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/test/DeltaTestImplicits.scala @@ -104,6 +104,14 @@ object DeltaTestImplicits { def enableExpiredLogCleanup(): Boolean = { deltaLog.enableExpiredLogCleanup(snapshot.metadata) } + + def upgradeProtocol(newVersion: Protocol): Unit = { + upgradeProtocol(deltaLog.unsafeVolatileSnapshot, newVersion) + } + + def upgradeProtocol(snapshot: Snapshot, newVersion: Protocol): Unit = { + deltaLog.upgradeProtocol(None, snapshot, newVersion) + } } implicit class DeltaTableV2ObjectTestHelper(dt: DeltaTableV2.type) {