From dcd598a07c768b878be218bab7170659fb0a4797 Mon Sep 17 00:00:00 2001
From: Carmen Kwan <carmen.kwan@databricks.com>
Date: Sat, 4 May 2024 00:46:44 +0200
Subject: [PATCH] Python DeltaTableBuilder API

---
 python/delta/tables.py                | 78 ++++++++++++++++++++++++---
 python/delta/tests/test_deltatable.py | 51 +++++++++++++++++-
 2 files changed, 120 insertions(+), 9 deletions(-)

diff --git a/python/delta/tables.py b/python/delta/tables.py
index f9824eaa9d1..b4f328f4b98 100644
--- a/python/delta/tables.py
+++ b/python/delta/tables.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING, cast, overload, Any, Dict, Iterable, Optional, Union, NoReturn, List, Tuple
 )
@@ -25,7 +26,7 @@
 
 from pyspark import since
 from pyspark.sql import Column, DataFrame, functions, SparkSession
-from pyspark.sql.types import DataType, StructType, StructField
+from pyspark.sql.types import DataType, StructType, StructField, LongType
 
 
 if TYPE_CHECKING:
@@ -1060,6 +1061,19 @@ def __getNotMatchedBySourceBuilder(
                 DeltaTable._condition_to_jcolumn(condition))
 
 
+@dataclass
+class IdentityGenerator:
+    """
+    Identity generator specification for an identity column in a Delta table.
+    :param start: the start value for the identity column. Default is 1.
+    :type start: int
+    :param step: the step for the identity column (addColumn rejects 0). Default is 1.
+    :type step: int
+    """
+    start: int = 1
+    step: int = 1
+
+
 class DeltaTableBuilder(object):
     """
     Builder to specify how to create / replace a Delta table.
@@ -1164,7 +1178,8 @@ def addColumn(
         colName: str,
         dataType: Union[str, DataType],
         nullable: bool = True,
-        generatedAlwaysAs: Optional[str] = None,
+        generatedAlwaysAs: Optional[Union[str, IdentityGenerator]] = None,
+        generatedByDefaultAs: Optional[IdentityGenerator] = None,
         comment: Optional[str] = None,
     ) -> "DeltaTableBuilder":
         """
@@ -1177,9 +1192,15 @@ def addColumn(
         :param nullable: whether column is nullable
         :type nullable: bool
         :param generatedAlwaysAs: a SQL expression if the column is always generated
-                                  as a function of other columns.
+                                  as a function of other columns,
+                                  or an IdentityGenerator object if the column is always
+                                  generated by an identity generator.
+                                  See online documentation for details on Generated Columns.
+        :type generatedAlwaysAs: str or delta.tables.IdentityGenerator
+        :param generatedByDefaultAs: an IdentityGenerator object used to generate identity values
+                                     when the user does not provide a value for the column.
                                   See online documentation for details on Generated Columns.
-        :type generatedAlwaysAs: str
+        :type generatedByDefaultAs: delta.tables.IdentityGenerator
         :param comment: the column comment
         :type comment: str
 
@@ -1203,11 +1224,52 @@ def addColumn(
         if type(nullable) is not bool:
             self._raise_type_error("Column nullable must be bool.", [nullable])
         _col_jbuilder = _col_jbuilder.nullable(nullable)
+
+        if generatedAlwaysAs is not None and generatedByDefaultAs is not None:
+            raise ValueError(
+                "generatedByDefaultAs cannot be set with generatedAlwaysAs.",
+                [generatedByDefaultAs, generatedAlwaysAs]
+            )
+        # Identity columns must be LongType, which Spark SQL spells "bigint" or "long".
+        # Only a str or pyspark DataType that is definitely not long is rejected here.
+        _not_long = (isinstance(dataType, DataType) and not isinstance(dataType, LongType)) or (
+            isinstance(dataType, str) and dataType.strip().lower() not in ("bigint", "long"))
         if generatedAlwaysAs is not None:
-            if type(generatedAlwaysAs) is not str:
-                self._raise_type_error("Column generation expression must be str.",
-                                       [generatedAlwaysAs])
-            _col_jbuilder = _col_jbuilder.generatedAlwaysAs(generatedAlwaysAs)
+            if type(generatedAlwaysAs) is str:
+                _col_jbuilder = _col_jbuilder.generatedAlwaysAs(generatedAlwaysAs)
+            elif type(generatedAlwaysAs) is IdentityGenerator:
+                if _not_long:
+                    self._raise_type_error(
+                        "Column identity generation requires the column to be long (bigint).",
+                        [dataType],
+                    )
+                if generatedAlwaysAs.step == 0:
+                    raise ValueError("Column identity generation requires step to be non-zero.")
+                _col_jbuilder = _col_jbuilder.generatedAlwaysAsIdentity(
+                    generatedAlwaysAs.start, generatedAlwaysAs.step
+                )
+            else:
+                self._raise_type_error(
+                    "Column generation expression must be str or IdentityGenerator.",
+                    [generatedAlwaysAs]
+                )
+        elif generatedByDefaultAs is not None:
+            if type(generatedByDefaultAs) is not IdentityGenerator:
+                self._raise_type_error(
+                    "Column generation by default expression must be IdentityGenerator.",
+                    [generatedByDefaultAs]
+                )
+            if _not_long:
+                self._raise_type_error(
+                    "Column identity generation requires the column to be long (bigint).",
+                    [dataType],
+                )
+            if generatedByDefaultAs.step == 0:
+                raise ValueError("Column identity generation requires step to be non-zero.")
+            _col_jbuilder = _col_jbuilder.generatedByDefaultAsIdentity(
+                generatedByDefaultAs.start, generatedByDefaultAs.step
+            )
+
         if comment is not None:
             if type(comment) is not str:
                 self._raise_type_error("Column comment must be str.", [comment])
diff --git a/python/delta/tests/test_deltatable.py b/python/delta/tests/test_deltatable.py
index 28b960e57d1..58218823631 100644
--- a/python/delta/tests/test_deltatable.py
+++ b/python/delta/tests/test_deltatable.py
@@ -28,7 +28,7 @@
 from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DataType
 from pyspark.sql.utils import AnalysisException, ParseException
 
-from delta.tables import DeltaTable, DeltaTableBuilder, DeltaOptimizeBuilder
+from delta.tables import DeltaTable, DeltaTableBuilder, DeltaOptimizeBuilder, IdentityGenerator
 from delta.testing.utils import DeltaTestCase
 
 
@@ -977,6 +977,55 @@ def test_delta_table_builder_with_bad_args(self) -> None:
         with self.assertRaises(TypeError):
             builder.addColumn("a", "int", generatedAlwaysAs=1)  # type: ignore[arg-type]
 
+        # bad generatedAlwaysAs - column data type must be Long
+        with self.assertRaises(TypeError):
+            builder.addColumn(
+                "a",
+                "int",
+                generatedAlwaysAs=IdentityGenerator()
+            )  # type: ignore[arg-type]
+
+        # bad generatedAlwaysAs - step can't be 0
+        with self.assertRaises(ValueError):
+            builder.addColumn(
+                "a",
+                LongType(),
+                generatedAlwaysAs=IdentityGenerator(step=0)
+            )  # type: ignore[arg-type]
+
+        # bad generatedByDefaultAs - can't be set with generatedAlwaysAs
+        with self.assertRaises(ValueError):
+            builder.addColumn(
+                "a",
+                LongType(),
+                generatedAlwaysAs="",
+                generatedByDefaultAs=IdentityGenerator()
+            )  # type: ignore[arg-type]
+
+        # bad generatedByDefaultAs - argument type must be IdentityGenerator
+        with self.assertRaises(TypeError):
+            builder.addColumn(
+                "a",
+                LongType(),
+                generatedByDefaultAs=""
+            )  # type: ignore[arg-type]
+
+        # bad generatedByDefaultAs - column data type must be Long
+        with self.assertRaises(TypeError):
+            builder.addColumn(
+                "a",
+                "int",
+                generatedByDefaultAs=IdentityGenerator()
+            )  # type: ignore[arg-type]
+
+        # bad generatedByDefaultAs - step can't be 0
+        with self.assertRaises(ValueError):
+            builder.addColumn(
+                "a",
+                LongType(),
+                generatedByDefaultAs=IdentityGenerator(step=0)
+            )  # type: ignore[arg-type]
+
         # bad nullable
         with self.assertRaises(TypeError):
             builder.addColumn("a", "int", nullable=1)  # type: ignore[arg-type]