diff --git a/.github/docker/build_image.sh b/.github/docker/build_image.sh index a126fb02..11fbb404 100755 --- a/.github/docker/build_image.sh +++ b/.github/docker/build_image.sh @@ -33,8 +33,8 @@ cd ../.. # Build docker image docker build \ ---build-arg SPARK_URL=https://archive.apache.org/dist/spark/spark-3.4.0/spark-3.4.0-bin-hadoop3.tgz \ ---build-arg SPARK_BINARY_NAME=spark-3.4.0-bin-hadoop3.tgz \ +--build-arg SPARK_URL=https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz \ +--build-arg SPARK_BINARY_NAME=spark-3.5.0-bin-hadoop3.tgz \ --build-arg JDBC_URL=https://repo1.maven.org/maven2/net/snowflake/snowflake-jdbc/${TEST_JDBC_VERSION}/$JDBC_JAR_NAME \ --build-arg JDBC_BINARY_NAME=$JDBC_JAR_NAME \ --build-arg SPARK_CONNECTOR_LOCATION=target/scala-${TEST_SCALA_VERSION}/$SPARK_CONNECTOR_JAR_NAME \ diff --git a/.github/workflows/ClusterTest.yml b/.github/workflows/ClusterTest.yml deleted file mode 100644 index a46ef151..00000000 --- a/.github/workflows/ClusterTest.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Spark Cluster Test - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - test: - name: CT - runs-on: ubuntu-latest - strategy: - matrix: - scala_version: [ '2.12.11' ] - spark_version: [ '3.4.0' ] - use_copy_unload: [ 'true' ] - cloud_provider: [ 'gcp' ] - env: - SNOWFLAKE_TEST_CONFIG_SECRET: ${{ secrets.SNOWFLAKE_TEST_CONFIG_SECRET }} - TEST_SPARK_VERSION: '3.4' - DOCKER_IMAGE_TAG: 'snowflakedb/spark-base:3.4.0' - TEST_SCALA_VERSION: '2.12' - TEST_COMPILE_SCALA_VERSION: '2.12.11' - TEST_SPARK_CONNECTOR_VERSION: '2.16.0' - TEST_JDBC_VERSION: '3.17.0' - - steps: - - uses: actions/checkout@v2 - - name: Set up JDK 1.8 - uses: actions/setup-java@v1 - with: - java-version: 1.8 - - name: The test depens on GITHUB_RUN_ID and GITHUB_SHA - run: echo GITHUB_RUN_ID=${GITHUB_RUN_ID} GITHUB_SHA=${GITHUB_SHA} SCALA=$TEST_SCALA_VERSION SCALA_COMPILE=$TEST_COMPILE_SCALA_VERSION SPARK_CONN=$TEST_SPARK_CONNECTOR_VERSION SPARK=$TEST_SPARK_VERSION JDBC=$TEST_JDBC_VERSION - - name: Build docker image for spark base - run: .github/docker/build_image.sh - - name: Run docker-compose test - run: docker-compose -f .github/docker/docker-compose.yml up -d - - name: Decrypt snowflake.json for testing - run: ./.github/scripts/decrypt_secret.sh snowflake.travis.json snowflake.travis.json.gpg - - name: Wait test done and check result - run: .github/docker/check_result.sh - - name: Clean up the container and image - run: .github/docker/cleanup_docker.sh diff --git a/.github/workflows/IntegrationTest_2.12.yml b/.github/workflows/IntegrationTest_2.12.yml index 80235ce8..ba9b6f38 100644 --- a/.github/workflows/IntegrationTest_2.12.yml +++ b/.github/workflows/IntegrationTest_2.12.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.12.11' ] - spark_version: [ '3.4.0' ] + spark_version: [ '3.5.0' ] use_copy_unload: [ 'true', 'false' ] cloud_provider: [ 'aws', 'azure' ] # run_query_in_async can be removed after async mode is stable diff --git a/.github/workflows/IntegrationTest_2.13.yml b/.github/workflows/IntegrationTest_2.13.yml index 99999310..56913a92 100644 --- a/.github/workflows/IntegrationTest_2.13.yml +++ b/.github/workflows/IntegrationTest_2.13.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.13.9' ] - spark_version: [ '3.4.0' ] + spark_version: [ '3.5.0' ] use_copy_unload: [ 'true', 'false' ] cloud_provider: [ 'aws', 'azure' ] # run_query_in_async can be removed after async mode is stable diff --git 
a/.github/workflows/IntegrationTest_gcp_2.12.yml b/.github/workflows/IntegrationTest_gcp_2.12.yml index 674e0b8f..39f00ac8 100644 --- a/.github/workflows/IntegrationTest_gcp_2.12.yml +++ b/.github/workflows/IntegrationTest_gcp_2.12.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.12.11' ] - spark_version: [ '3.4.0' ] + spark_version: [ '3.5.0' ] use_copy_unload: [ 'false' ] cloud_provider: [ 'gcp' ] # run_query_in_async can be removed after async mode is stable diff --git a/.github/workflows/IntegrationTest_gcp_2.13.yml b/.github/workflows/IntegrationTest_gcp_2.13.yml index 4a0f2b64..a9a1b11b 100644 --- a/.github/workflows/IntegrationTest_gcp_2.13.yml +++ b/.github/workflows/IntegrationTest_gcp_2.13.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.13.9' ] - spark_version: [ '3.4.0' ] + spark_version: [ '3.5.0' ] use_copy_unload: [ 'false' ] cloud_provider: [ 'gcp' ] # run_query_in_async can be removed after async mode is stable diff --git a/ClusterTest/build.sbt b/ClusterTest/build.sbt deleted file mode 100644 index 71f0767d..00000000 --- a/ClusterTest/build.sbt +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2015-2019 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -val sparkConnectorVersion = "2.16.0" -val scalaVersionMajor = "2.12" -val sparkVersionMajor = "3.4" -val sparkVersion = s"${sparkVersionMajor}.0" -val testSparkVersion = sys.props.get("spark.testVersion").getOrElse(sparkVersion) - -unmanagedJars in Compile += file(s"../target/scala-${scalaVersionMajor}/" + - s"spark-snowflake_${scalaVersionMajor}-${sparkConnectorVersion}-spark_${sparkVersionMajor}.jar") - -lazy val root = project.withId("spark-snowflake").in(file(".")) - .settings( - name := "ClusterTest", - organization := "net.snowflake", - version := s"1.0", - scalaVersion := sys.props.getOrElse("SPARK_SCALA_VERSION", default = "2.12.11"), - crossScalaVersions := Seq("2.12.11"), - javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), - licenses += "Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0"), - credentials += Credentials(Path.userHome / ".ivy2" / ".credentials"), - resolvers += - "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", - libraryDependencies ++= Seq( - "net.snowflake" % "snowflake-jdbc" % "3.17.0", - "org.apache.commons" % "commons-lang3" % "3.5" % "provided, runtime", - "org.apache.spark" %% "spark-core" % testSparkVersion % "provided, runtime", - "org.apache.spark" %% "spark-sql" % testSparkVersion % "provided, runtime", - "org.apache.spark" %% "spark-hive" % testSparkVersion % "provided, runtime" - ), - ) diff --git a/ClusterTest/build/sbt-launch-1.2.8.jar b/ClusterTest/build/sbt-launch-1.2.8.jar deleted file mode 100644 index 26ab884e..00000000 Binary files a/ClusterTest/build/sbt-launch-1.2.8.jar and /dev/null differ diff --git a/ClusterTest/project/build.properties b/ClusterTest/project/build.properties deleted file mode 100644 index c0bab049..00000000 --- 
a/ClusterTest/project/build.properties +++ /dev/null @@ -1 +0,0 @@ -sbt.version=1.2.8 diff --git a/ClusterTest/project/plugins.sbt b/ClusterTest/project/plugins.sbt deleted file mode 100644 index c6f66607..00000000 --- a/ClusterTest/project/plugins.sbt +++ /dev/null @@ -1,8 +0,0 @@ -addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.1") - -addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.11") - -addSbtPlugin("org.foundweekends" % "sbt-bintray" % "0.5.4") - -addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.1") - diff --git a/ClusterTest/run_cluster_test.sh b/ClusterTest/run_cluster_test.sh deleted file mode 100755 index 987e540d..00000000 --- a/ClusterTest/run_cluster_test.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -echo "This script run inside of the docker container." - -echo "Below variables must be set by workflow and docker-conpose file" -echo "GITHUB_RUN_ID=$GITHUB_RUN_ID" -echo "GITHUB_SHA=$GITHUB_SHA" -echo "TEST_SPARK_CONNECTOR_VERSION=$TEST_SPARK_CONNECTOR_VERSION" -echo "TEST_SCALA_VERSION=$TEST_SCALA_VERSION" -echo "TEST_COMPILE_SCALA_VERSION=$TEST_COMPILE_SCALA_VERSION" -echo "TEST_JDBC_VERSION=$TEST_JDBC_VERSION" -echo "TEST_SPARK_VERSION=$TEST_SPARK_VERSION" -echo "SNOWFLAKE_TEST_CONFIG=$SNOWFLAKE_TEST_CONFIG" - -export SPARK_HOME=/users/spark -export SPARK_WORKDIR=/users/spark/work - -export SPARK_CONNECTOR_JAR_NAME=spark-snowflake_${TEST_SCALA_VERSION}-${TEST_SPARK_CONNECTOR_VERSION}-spark_${TEST_SPARK_VERSION}.jar -export JDBC_JAR_NAME=snowflake-jdbc-${TEST_JDBC_VERSION}.jar - -# Check test file exists -ls -al $SNOWFLAKE_TEST_CONFIG \ - $SPARK_WORKDIR/${SPARK_CONNECTOR_JAR_NAME} \ - $SPARK_WORKDIR/${JDBC_JAR_NAME} \ - $SPARK_WORKDIR/clustertest_${TEST_SCALA_VERSION}-1.0.jar \ - $SPARK_WORKDIR/ClusterTest.py - -echo "Important: if new test cases are added, script .github/docker/check_result.sh MUST be updated" -# Run pyspark test -# python3 has been installed in the container -$SPARK_HOME/bin/spark-submit \ - --jars $SPARK_WORKDIR/${SPARK_CONNECTOR_JAR_NAME},$SPARK_WORKDIR/${JDBC_JAR_NAME} \ - --conf "spark.pyspark.python=python3" --conf "spark.pyspark.driver.python=python3" \ - --conf "spark.executor.extraJavaOptions=-Djava.io.tmpdir=$SPARK_WORKDIR -Dnet.snowflake.jdbc.loggerImpl=net.snowflake.client.log.SLF4JLogger -Dlog4j.configuration=file://${SPARK_HOME}/conf/log4j_executor.properties" \ - --conf "spark.driver.extraJavaOptions=-Djava.io.tmpdir=$SPARK_WORKDIR -Dnet.snowflake.jdbc.loggerImpl=net.snowflake.client.log.SLF4JLogger -Dlog4j.configuration=file://${SPARK_HOME}/conf/log4j_driver.properties" \ - --master spark://master:7077 --deploy-mode client \ - $SPARK_WORKDIR/ClusterTest.py remote - -$SPARK_HOME/bin/spark-submit \ - --jars $SPARK_WORKDIR/${SPARK_CONNECTOR_JAR_NAME},$SPARK_WORKDIR/${JDBC_JAR_NAME} \ - --conf "spark.executor.extraJavaOptions=-Djava.io.tmpdir=$SPARK_WORKDIR -Dnet.snowflake.jdbc.loggerImpl=net.snowflake.client.log.SLF4JLogger -Dlog4j.configuration=file://${SPARK_HOME}/conf/log4j_executor.properties" \ - --conf "spark.driver.extraJavaOptions=-Djava.io.tmpdir=$SPARK_WORKDIR -Dnet.snowflake.jdbc.loggerImpl=net.snowflake.client.log.SLF4JLogger -Dlog4j.configuration=file://${SPARK_HOME}/conf/log4j_driver.properties" \ - --master spark://master:7077 --deploy-mode client \ - --class net.snowflake.spark.snowflake.ClusterTest \ - $SPARK_WORKDIR/clustertest_${TEST_SCALA_VERSION}-1.0.jar remote "net.snowflake.spark.snowflake.testsuite.BasicReadWriteSuite;" - -# Low memory partition upload test. Setting executor memory to 900MB. 
The actual usable heap size can be calculated as (900 - 300) * 0.6 = 360MB. -# Heap size is shared by RDD storage and execution memory together. The partition is 309 MB un compressed and 220 MB compressed. -# In this test we verify that with multipart upload, there is no OOM. Otherwise there is. -$SPARK_HOME/bin/spark-submit \ - --jars $SPARK_WORKDIR/${SPARK_CONNECTOR_JAR_NAME},$SPARK_WORKDIR/${JDBC_JAR_NAME} \ - --conf "spark.executor.extraJavaOptions=-Djava.io.tmpdir=$SPARK_WORKDIR -Dnet.snowflake.jdbc.loggerImpl=net.snowflake.client.log.SLF4JLogger -Dlog4j.configuration=file://${SPARK_HOME}/conf/log4j_executor.properties" \ - --conf "spark.driver.extraJavaOptions=-Djava.io.tmpdir=$SPARK_WORKDIR -Dnet.snowflake.jdbc.loggerImpl=net.snowflake.client.log.SLF4JLogger -Dlog4j.configuration=file://${SPARK_HOME}/conf/log4j_driver.properties" \ - --master spark://master:7077 --deploy-mode client \ - --class net.snowflake.spark.snowflake.ClusterTest \ - --driver-memory 600m \ - --executor-memory 600m \ - $SPARK_WORKDIR/clustertest_${TEST_SCALA_VERSION}-1.0.jar remote "net.snowflake.spark.snowflake.testsuite.LowMemoryStressSuite;" diff --git a/ClusterTest/src/main/python/ClusterTest.py b/ClusterTest/src/main/python/ClusterTest.py deleted file mode 100644 index 82664957..00000000 --- a/ClusterTest/src/main/python/ClusterTest.py +++ /dev/null @@ -1,135 +0,0 @@ -from pyspark import SparkConf, SparkContext -from pyspark.sql import * -from pyspark.sql.types import * -from pyspark import SparkConf, SparkContext -from pyspark.sql.functions import udf -import json -import os -import datetime -import sys - -# Below options are only used by external stage -# sc._jsc.hadoopConfiguration().set("fs.s3n.awsAccessKeyId", "") -# sc._jsc.hadoopConfiguration().set("fs.s3n.awsSecretAccessKey", "") - -# Utility function to get configration info from snowflake.travis.json -def getConfig(filename): - with open(filename) as f: - snowflake_travis_json = json.load(f) - - sfOptions = {} - sfOptions["sfCompress"] = snowflake_travis_json["common"]["sfCompress"] - sfOptions["sfSSL"] = snowflake_travis_json["common"]["sfSSL"] - sfOptions["dbtable"] = snowflake_travis_json["common"]["dbtable"] - sfOptions["runOption"] = snowflake_travis_json["common"]["runOption"] - sfOptions["sfTimeZone"] = snowflake_travis_json["common"]["sfTimeZone"] - sfOptions["sfDatabase"] = snowflake_travis_json["common"]["sfDatabase"] - sfOptions["sfSchema"] = snowflake_travis_json["common"]["sfSchema"] - sfOptions["sfWarehouse"] = snowflake_travis_json["common"]["sfWarehouse"] - sfOptions["sfUser"] = snowflake_travis_json["common"]["sfUser"] - sfOptions["pem_private_key"] = snowflake_travis_json["common"]["pem_private_key"] - - for item in snowflake_travis_json["account_info"]: - if item["name"] == 'aws': - sfOptions["sfURL"] = item["config"]["sfURL"] - data = item["config"]["sfURL"].split(".") - sfOptions["sfAccount"] = data[0] - - return sfOptions - - -# UDF for testing -def udf_plus_one(s): - return s + 1 - - -def writeResult(row): - row_seq = [row] - fields = [StructField("testCaseName", StringType(), True), \ - StructField("testStatus", StringType(), True), \ - StructField("githubRunId", StringType(), True), \ - StructField("commitID", StringType(), True), \ - StructField("testType", StringType(), True), \ - StructField("startTime", StringType(), True), \ - StructField("testRunTime", StringType(), True), \ - StructField("reason", StringType(), True), \ - ] - schema = StructType(fields) - df = spark.createDataFrame(row_seq, schema) - df.write \ - 
.format(SNOWFLAKE_SOURCE_NAME) \ - .options(**sfOptions) \ - .option("dbtable", targetTableName) \ - .option("sfSchema", targetSchema) \ - .mode('append') \ - .save() - - -# Retrieve file from env SNOWFLAKE_TEST_CONFIG. -filename = os.environ.get('SNOWFLAKE_TEST_CONFIG') -githubsha = os.environ.get('GITHUB_SHA') -githubrunid = os.environ.get('GITHUB_RUN_ID') -targetTableName = 'CLUSTER_TEST_RESULT_TABLE' -targetSchema = 'SPARK_TEST' -startDate = str(datetime.datetime.today()) -SNOWFLAKE_SOURCE_NAME = "net.snowflake.spark.snowflake" - -# sc = SparkContext("local", "Simple App") -# spark = SQLContext(sc) -# spark_conf = SparkConf().setMaster('local').setAppName('PythonClusterTest') -spark_session_builder = SparkSession.builder \ - .appName("PythonClusterTest") - -local_agv = sys.argv[1] -if local_agv.lower() == 'local': - spark_session_builder = spark_session_builder.master('local') - -spark = spark_session_builder.getOrCreate() - -sfOptions = getConfig(filename) - - -spark.udf.register("udfPlusOne", udf_plus_one) - -try: - testdf = spark.read.format(SNOWFLAKE_SOURCE_NAME) \ - .options(**sfOptions) \ - .option("dbtable", "ORDERS") \ - .option("sfSchema", "TPCH_SF1") \ - .load() - - testdf.createOrReplaceTempView("test_table_python") - - row_count = spark.sql("select count(O_TOTALPRICE) from test_table_python").collect()[0][0] - origin_sum = spark.sql("select sum(O_TOTALPRICE) from test_table_python").collect()[0][0] - plus_one_sum = spark.sql("select sum(udfPlusOne(O_TOTALPRICE)) from test_table_python").collect()[0][0] - - print ("row_count ", row_count) - print ("origin_sum ", origin_sum) - print ("plus_one_sum ", plus_one_sum) - - expect_result = int(origin_sum + row_count) - actual_result = int(plus_one_sum) - - print ("expect_result ", expect_result) - print ("actual_result ", actual_result) - - if expect_result == actual_result: - print("test is successful") - result_row = Row("ClusterTest.py", "Success", githubrunid, githubsha, 'Python', startDate, "not collected", - "success") - writeResult(result_row) - else: - reason = "result sum is incorrect, expect_result=" + str(expect_result) + " actual_result=" + str(actual_result) - print("test is fail: ", reason) - result_row = Row("ClusterTest.py", "Fail", githubrunid, githubsha, 'Python', startDate, "not collected", reason) - writeResult(result_row) - - # raise Exception('test exception') - -except Exception as e: - reason = str(e) - print("test raise exception: ", reason) - result_row = Row("ClusterTest.py", "Exception", githubrunid, githubsha, 'Python', startDate, "not collected", - reason) - writeResult(result_row) diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTest.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTest.scala deleted file mode 100644 index 5c535ede..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTest.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package net.snowflake.spark.snowflake - -import net.snowflake.spark.snowflake.testsuite.ClusterTestSuiteBase -import org.slf4j.{Logger, LoggerFactory} -import org.apache.spark.sql.SparkSession - -object ClusterTest { - val log: Logger = LoggerFactory.getLogger(getClass) - - val RemoteMode = "remote" - val LocalMode = "local" - - val TestSuiteSeparator = ";" - - // Driver function to run the test. - def main(args: Array[String]): Unit = { - log.info(s"Test Spark Connector: ${net.snowflake.spark.snowflake.Utils.VERSION}") - - val usage = s"""Two parameters are need: [local | remote] and - | testClassNames (using ';' to separate multiple classes) - |""".stripMargin - log.info(usage) - - if (args.length < 2) { - throw new Exception(s"At least two parameters are need. Usage: $usage") - } - - // Setup Spark session. - // local mode is introduced for debugging purpose - val runMode = args(0) - var sparkSessionBuilder = SparkSession - .builder() - .appName("Spark SQL basic example") - .config("spark.some.config.option", "some-value") - if (runMode.equalsIgnoreCase(LocalMode)) { - sparkSessionBuilder = sparkSessionBuilder - .config("spark.master", "local") - } - val spark = sparkSessionBuilder.getOrCreate() - - // Run specified test suites - val testSuiteNames = args(1).split(TestSuiteSeparator) - for (testSuiteName <- testSuiteNames) { - if (!testSuiteName.trim.isEmpty) { - // Retrieve commit ID from env. - val commitID = scala.util.Properties - .envOrElse(TestUtils.GITHUB_SHA, "commit id not set") - - // val testSuiteName = "net.snowflake.spark.snowflake.testsuite.BasicReadWriteSuite" - val resultBuilder = new ClusterTestResultBuilder() - .withTestType("Scala") - .withTestCaseName(testSuiteName) - .withCommitID(commitID) - .withTestStatus(TestUtils.TEST_RESULT_STATUS_INIT) - .withStartTimeInMill(System.currentTimeMillis()) - .withGithubRunId(TestUtils.githubRunId) - - try { - Class - .forName(testSuiteName) - .newInstance() - .asInstanceOf[ClusterTestSuiteBase] - .run(spark, resultBuilder) - } catch { - case e: Throwable => - log.error(e.getMessage) - resultBuilder - .withTestStatus(TestUtils.TEST_RESULT_STATUS_EXCEPTION) - .withReason(e.getMessage) - } finally { - // Set test end time. - resultBuilder - .withEndTimeInMill(System.currentTimeMillis()) - // Write test result - resultBuilder.build().writeToSnowflake() - } - } - } - - spark.stop() - } -} diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestCheckResult.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestCheckResult.scala deleted file mode 100644 index 06d526d6..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestCheckResult.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package net.snowflake.spark.snowflake - -// This class is used separately with others in .github/docker/check_result.sh -// If more libraries are used, the script needs to download the libraries and -// update classpath. -object ClusterTestCheckResult { - -// Driver function to run the test. - def main(args: Array[String]): Unit = { - val usage = s"Two parameters are need: testCaseCount and timeoutInSeconds" - println(usage) - - if (args.length < 2) { - throw new Exception(s"At least two parameters are need. Usage: $usage") - } - - val testCaseCount = args(0).toInt - var leftTime = args(1).toInt - // Check result in every 10 seconds - val checkInterval: Int = 10000 - val commitID = scala.util.Properties.envOrNone(TestUtils.GITHUB_SHA) - if (commitID.isEmpty) { - throw new Exception( - s"Caller has to set env variable ${TestUtils.GITHUB_SHA}" - ) - } - - val connection = TestUtils.getJDBCConnection(TestUtils.param) - - // Wait for all test cases are done. - var testDone = false - while (!testDone && leftTime > 0) { - // Sleep some time and then check result - Thread.sleep(checkInterval) - leftTime = leftTime - checkInterval / 1000 - - val resultSet = connection - .createStatement() - .executeQuery( - s"""select count(*) from ${TestUtils.CLUSTER_TEST_RESULT_TABLE} - | where githubRunId = '${TestUtils.githubRunId}' - |""".stripMargin - ) - resultSet.next() - - val finishedTestCount = resultSet.getInt(1) - println( - s"Finished test cases: $finishedTestCount, expected : $testCaseCount" - ) - - if (finishedTestCount == testCaseCount) { - testDone = true - } - } - - // Retrieve successful test case count - val resultSet = connection - .createStatement() - .executeQuery( - s"""select count(*) from ${TestUtils.CLUSTER_TEST_RESULT_TABLE} - | where githubRunId = '${TestUtils.githubRunId}' - | and testStatus = '${TestUtils.TEST_RESULT_STATUS_SUCCESS}' - |""".stripMargin - ) - resultSet.next() - val successCount = resultSet.getInt(1) - - // Print out all test cases results - val rs = connection - .createStatement() - .executeQuery(s"""select * from ${TestUtils.CLUSTER_TEST_RESULT_TABLE} - | where githubRunId = '${TestUtils.githubRunId}' - |""".stripMargin) - val rsmd = rs.getMetaData - val columnsNumber = rsmd.getColumnCount - - // Output column name - var sb = new StringBuilder - for (i <- 1 to columnsNumber) { - sb.append(rsmd.getColumnName(i)).append(", ") - } - println(sb.toString()) - - // Output test case result - while (rs.next) { - sb = new StringBuilder - for (i <- 1 to columnsNumber) { - sb.append(rs.getString(i)).append(", ") - } - println(sb.toString()) - } - - if (successCount != testCaseCount) { - throw new Exception( - s"Some test case fail: expected $testCaseCount, actual: $successCount" - ) - } else { - println(s"All test cases are PASS $testCaseCount") - } - } -} diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestResult.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestResult.scala deleted file mode 100644 index f865e48c..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestResult.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package net.snowflake.spark.snowflake - -import java.time.Instant - -class ClusterTestResult(builder: ClusterTestResultBuilder) { - val testType: String = builder.testType - val testCaseName: String = builder.testCaseName - val testStatus: String = builder.testStatus - val commitID: String = builder.commitID - val githubRunId: String = builder.githubRunId - val startTime: String = - Instant.ofEpochMilli(builder.startTimeInMillis).toString - val testRunTime: String = { - val usedTime = builder.endTimeInMillis - builder.startTimeInMillis - if (usedTime < 0) { - s"Wrong time: Start ${builder.endTimeInMillis} end: ${builder.startTimeInMillis}" - } else if (usedTime < 1000) { - s"$usedTime ms" - } else if (usedTime < 1000 * 60) { - "%.2f seconds".format(usedTime.toDouble / 1000) - } else { - "%.2f minutes".format(usedTime.toDouble / 1000 / 60) - } - } - val reason: String = builder.reason - - def writeToSnowflake(): Unit = { - val connection = TestUtils.getServerConnection(TestUtils.param) - - // Create test result table if it doesn't exist. - if (!DefaultJDBCWrapper.tableExists( - TestUtils.param, - TestUtils.CLUSTER_TEST_RESULT_TABLE - )) { - DefaultJDBCWrapper.executeInterruptibly( - connection, - s"""create table ${TestUtils.CLUSTER_TEST_RESULT_TABLE} ( - | testCaseName String, - | testStatus String, - | githubRunId String, - | commitID String, - | testType String, - | startTime String, - | testRunTime String, - | reason String ) - |""".stripMargin - ) - } - - // Write test result into table - DefaultJDBCWrapper.executeInterruptibly( - connection, - s"""insert into ${TestUtils.CLUSTER_TEST_RESULT_TABLE} values ( - | '$testCaseName' , - | '$testStatus' , - | '$githubRunId' , - | '$commitID' , - | '$testType', - | '$startTime' , - | '$testRunTime' , - | '$reason' - | ) """.stripMargin - ) - - connection.close() - } -} diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestResultBuilder.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestResultBuilder.scala deleted file mode 100644 index 837bc546..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/ClusterTestResultBuilder.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package net.snowflake.spark.snowflake - -class ClusterTestResultBuilder() { - private[snowflake] var testType = "Scala" // There are Scala test and Python test. 
- private[snowflake] var commitID: String = _ - private[snowflake] var githubRunId: String = _ - private[snowflake] var startTimeInMillis: Long = 0 - private[snowflake] var endTimeInMillis: Long = 0 - private[snowflake] var testCaseName: String = _ - private[snowflake] var testStatus: String = "NotStarted" - private[snowflake] var reason: String = "no reason" - - def build(): ClusterTestResult = { - new ClusterTestResult(this) - } - - def withTestType(testType: String): ClusterTestResultBuilder = { - this.testType = testType - this - } - def withGithubRunId(jobStartTime: String): ClusterTestResultBuilder = { - this.githubRunId = jobStartTime - this - } - def withCommitID(commitID: String): ClusterTestResultBuilder = { - this.commitID = commitID - this - } - def withStartTimeInMill(startTimeInMillis: Long): ClusterTestResultBuilder = { - this.startTimeInMillis = startTimeInMillis - this - } - def withEndTimeInMill(endTimeInMillis: Long): ClusterTestResultBuilder = { - this.endTimeInMillis = endTimeInMillis - this - } - def withTestCaseName(testCaseName: String): ClusterTestResultBuilder = { - this.testCaseName = testCaseName - this - } - def withTestStatus(testStatus: String): ClusterTestResultBuilder = { - this.testStatus = testStatus - this - } - def withReason(reason: String): ClusterTestResultBuilder = { - this.reason = reason - this - } -} diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/TestUtils.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/TestUtils.scala deleted file mode 100644 index 5183fc1a..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/TestUtils.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package net.snowflake.spark.snowflake - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.{ - JsonNode, - ObjectMapper -} -import net.snowflake.spark.snowflake.ClusterTest.log -import net.snowflake.spark.snowflake.Parameters.MergedParameters - -import scala.collection.mutable -import scala.io.Source - -object TestUtils { - val SNOWFLAKE_TEST_ACCOUNT = "SNOWFLAKE_TEST_ACCOUNT" - val SNOWFLAKE_TEST_CONFIG = "SNOWFLAKE_TEST_CONFIG" - val CLUSTER_TEST_RESULT_TABLE = "CLUSTER_TEST_RESULT_TABLE" - val GITHUB_SHA = "GITHUB_SHA" - val GITHUB_RUN_ID = "GITHUB_RUN_ID" - val SNOWFLAKE_NAME = "net.snowflake.spark.snowflake" - val JDBC_DRIVER = "net.snowflake.client.jdbc.SnowflakeDriver" - - // Test case result status - val TEST_RESULT_STATUS_INIT = "Initialized" - val TEST_RESULT_STATUS_START = "Started" - val TEST_RESULT_STATUS_SUCCESS = "Success" - val TEST_RESULT_STATUS_FAIL = "Fail" - val TEST_RESULT_STATUS_EXCEPTION = "Exception" - - lazy val githubRunId: String = { - val jobTime = System.getenv(TestUtils.GITHUB_RUN_ID) - if (jobTime == null) { - throw new Exception( - s"env variable ${TestUtils.GITHUB_RUN_ID} needs to be set" - ) - } - jobTime - } - - /** - * read sfOptions from json config e.g. snowflake.travis.json - */ - def loadJsonConfig(configFile: String): Option[Map[String, String]] = { - - var result: Map[String, String] = Map() - - def read(node: JsonNode): Unit = { - val itr = node.fields() - while (itr.hasNext) { - val entry = itr.next() - result = result + (entry.getKey -> entry.getValue.asText()) - } - } - - try { - val jsonConfigFile = Source.fromFile(configFile) - val file = jsonConfigFile.mkString - val mapper: ObjectMapper = new ObjectMapper() - val json = mapper.readTree(file) - val commonConfig = json.get("common") - val accountConfig = json.get("account_info") - val accountName: String = - Option(System.getenv(SNOWFLAKE_TEST_ACCOUNT)).getOrElse("aws") - - log.info(s"test account: $accountName") - - read(commonConfig) - - read( - ( - for (i <- 0 until accountConfig.size() - if accountConfig.get(i).get("name").asText() == accountName) - yield accountConfig.get(i).get("config") - ).head - ) - - log.info(s"load config from $configFile") - jsonConfigFile.close() - Some(result) - } catch { - case e: Throwable => - log.info(s"Can't read $configFile, reason: ${e.getMessage}") - None - } - } - - // Used for internal integration testing in SF env. - def readConfigValueFromEnv(name: String): Option[String] = { - scala.util.Properties.envOrNone(s"SPARK_CONN_ENV_${name.toUpperCase}") - } - - // Overwite options with environment variable settings. - def overWriteOptionsWithEnv( - sfOptions: Option[Map[String, String]] - ): Option[Map[String, String]] = { - if (sfOptions.isDefined) { - var resultOptions = new mutable.HashMap[String, String] - // Retrieve all options from Environment variables - Parameters.KNOWN_PARAMETERS foreach { param => - val opt = readConfigValueFromEnv(param) - if (opt.isDefined) { - log.info(s"Get config from env: $param") - resultOptions += (param -> opt.get) - } - } - // Merge the options that are not set Env - for ((key, value) <- sfOptions.get) { - if (resultOptions.get(key).isEmpty) { - resultOptions += (key -> value) - } - } - Some(resultOptions.toMap) - } else { - None - } - } - - // Load sfOptions from config file and env. 
- lazy val sfOptions: Map[String, String] = { - var configFile = System.getenv(SNOWFLAKE_TEST_CONFIG) - if (configFile == null) { - configFile = "snowflake.travis.json" - } - overWriteOptionsWithEnv(loadJsonConfig(configFile)).get - } - - // Load sfOptions from config file and env, but exclude "table" - lazy val sfOptionsNoTable: Map[String, String] = { - var resultOptions = new mutable.HashMap[String, String] - for ((key, value) <- sfOptions) { - if (key != "dbtable") { - resultOptions += (key -> value) - } - } - resultOptions.toMap - } - - // parameters for connection - lazy val param: MergedParameters = Parameters.mergeParameters(sfOptions) - - /** - * Get a connection based on the provided parameters - */ - def getJDBCConnection(params: MergedParameters): Connection = { - // Derive class name - try Class.forName("com.snowflake.client.jdbc.SnowflakeDriver") - catch { - case _: ClassNotFoundException => - System.err.println("Driver not found") - } - - val sfURL = params.sfURL - val jdbcURL = s"""jdbc:snowflake://$sfURL""" - - val jdbcProperties = new Properties() - - // Obligatory properties - jdbcProperties.put("db", params.sfDatabase) - jdbcProperties.put("schema", params.sfSchema) // Has a default - if (params.sfUser != null) { - // user is optional when using Oauth token - jdbcProperties.put("user", params.sfUser) - } - - params.privateKey match { - case Some(privateKey) => - jdbcProperties.put("privateKey", privateKey) - case None => - // Adding OAuth Token parameter - params.sfToken match { - case Some(value) => - jdbcProperties.put("token", value) - case None => jdbcProperties.put("password", params.sfPassword) - } - } - jdbcProperties.put("ssl", params.sfSSL) // Has a default - // Optional properties - if (params.sfAccount.isDefined) { - jdbcProperties.put("account", params.sfAccount.get) - } - if (params.sfWarehouse.isDefined) { - jdbcProperties.put("warehouse", params.sfWarehouse.get) - } - if (params.sfRole.isDefined) { - jdbcProperties.put("role", params.sfRole.get) - } - params.getTimeOutputFormat match { - case Some(value) => - jdbcProperties.put(Parameters.PARAM_TIME_OUTPUT_FORMAT, value) - case _ => // No default value for it. - } - params.getQueryResultFormat match { - case Some(value) => - jdbcProperties.put(Parameters.PARAM_JDBC_QUERY_RESULT_FORMAT, value) - case _ => // No default value for it. - } - - // Set up proxy info if it is configured. - params.setJDBCProxyIfNecessary(jdbcProperties) - - // Adding Authenticator parameter - params.sfAuthenticator match { - case Some(value) => - jdbcProperties.put("authenticator", value) - case _ => // No default value for it. - } - - // Always set CLIENT_SESSION_KEEP_ALIVE. 
- // Note, can be overridden with options - jdbcProperties.put("client_session_keep_alive", "true") - - // Force DECIMAL for NUMBER (SNOW-33227) - jdbcProperties.put("JDBC_TREAT_DECIMAL_AS_INT", "false") - - // Add extra properties from sfOptions - val extraOptions = params.sfExtraOptions - for ((k: String, v: Object) <- extraOptions) { - jdbcProperties.put(k.toLowerCase, v.toString) - } - - DriverManager.getConnection(jdbcURL, jdbcProperties) - } - - def getServerConnection(params: MergedParameters, - enableCache: Boolean = false): ServerConnection = - ServerConnection.getServerConnection(params, enableCache) -} diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/BasicReadWriteSuite.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/BasicReadWriteSuite.scala deleted file mode 100644 index 731fed1f..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/BasicReadWriteSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package net.snowflake.spark.snowflake.testsuite - -import net.snowflake.spark.snowflake.{ - ClusterTestResultBuilder, - DefaultJDBCWrapper, - ServerConnection, - TestUtils -} -import org.apache.spark.sql.SparkSession - -class BasicReadWriteSuite extends ClusterTestSuiteBase { - override def runImpl(sparkSession: SparkSession, - resultBuilder: ClusterTestResultBuilder): Unit = { - // its row count is 6.0M, compressed data size in SF is 157.7 MB. - val sourceSchema = "TPCH_SF1" - val sourceTableName = "LINEITEM" - val targetSchema = "spark_test" - val targetTableName = s"test_write_table_$randomSuffix" - - // Read write a basic table: - super.readWriteSnowflakeTable( - sparkSession, - resultBuilder, - TestUtils.sfOptionsNoTable, - sourceSchema, - sourceTableName, - targetSchema, - targetTableName - ) - - // If test is successful, drop the target table, - // otherwise, keep it for further investigation. - if (resultBuilder.testStatus == TestUtils.TEST_RESULT_STATUS_SUCCESS) { - val connection = TestUtils.getServerConnection(TestUtils.param) - connection - .createStatement() - .execute(s"drop table $targetSchema.$targetTableName") - connection.close() - } - } -} diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/ClusterTestSuiteBase.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/ClusterTestSuiteBase.scala deleted file mode 100644 index 4d2bed24..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/ClusterTestSuiteBase.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package net.snowflake.spark.snowflake.testsuite - -import net.snowflake.spark.snowflake.ClusterTest.log -import net.snowflake.spark.snowflake.{ClusterTestResultBuilder, TestUtils} -import org.apache.spark.sql.{SaveMode, SparkSession} - -import scala.util.Random - -trait ClusterTestSuiteBase { - def run(sparkSession: SparkSession, - resultBuilder: ClusterTestResultBuilder): Unit = { - // Start to run the test. - resultBuilder.withTestStatus(TestUtils.TEST_RESULT_STATUS_START) - - // Run the test implementation - runImpl(sparkSession, resultBuilder) - - // The test implementation should set up the test status - assert(resultBuilder.testStatus != TestUtils.TEST_RESULT_STATUS_START) - } - - // Each test case MUST implement this function. - def runImpl(sparkSession: SparkSession, - resultBuilder: ClusterTestResultBuilder): Unit - - protected def randomSuffix: String = Math.abs(Random.nextLong()).toString - - // Utility function to read one table and write to another. - protected def readWriteSnowflakeTable(sparkSession: SparkSession, - resultBuilder: ClusterTestResultBuilder, - sfOptionsNoTable: Map[String, String], - sourceSchema: String, - sourceTableName: String, - targetSchema: String, - targetTableName: String): Unit = { - val sqlContext = sparkSession.sqlContext - val tableNameInfo = - s"Source:$sourceSchema.$sourceTableName Target=$targetSchema.$targetTableName" - - // Read DataFrame. 
- val df = sqlContext.read - .format(TestUtils.SNOWFLAKE_NAME) - .options(sfOptionsNoTable) - // .option("query", s"select * from $sourceSchema.$sourceTableName limit 100000") - .option("dbtable", sourceTableName) - .option("sfSchema", sourceSchema) - .load() - - // Write DataFrame - df.write - .format(TestUtils.SNOWFLAKE_NAME) - .options(sfOptionsNoTable) - .option("dbtable", targetTableName) - .option("sfSchema", targetSchema) - .mode(SaveMode.Overwrite) - .save() - - // Source rowCount - val sourceRowCount = sparkSession.read - .format(TestUtils.SNOWFLAKE_NAME) - .options(sfOptionsNoTable) - .option("dbtable", sourceTableName) - .option("sfSchema", sourceSchema) - .load() - .count() - - // Target rowCount - val targetRowCount = sparkSession.read - .format(TestUtils.SNOWFLAKE_NAME) - .options(sfOptionsNoTable) - .option("dbtable", targetTableName) - .option("sfSchema", targetSchema) - .load() - .count() - - // verify row count to be equal - val rowCountInfo = - s"sourceRowCount=$sourceRowCount, targetRowCount=$targetRowCount" - log.info(rowCountInfo) - if (sourceRowCount != targetRowCount) { - resultBuilder - .withTestStatus(TestUtils.TEST_RESULT_STATUS_FAIL) - .withReason( - s"Read Write row count is incorrect: $tableNameInfo $rowCountInfo" - ) - return - } - - // Source HASH_AGG result - val sourceHashAgg = sparkSession.read - .format(TestUtils.SNOWFLAKE_NAME) - .options(sfOptionsNoTable) - .option( - "query", - s"select HASH_AGG(*) from $sourceSchema.$sourceTableName" - ) - .load() - .collect()(0)(0) - - // Target HASH_AGG result - val targetHashAgg = sparkSession.read - .format(TestUtils.SNOWFLAKE_NAME) - .options(sfOptionsNoTable) - .option( - "query", - s"select HASH_AGG(*) from $targetSchema.$targetTableName" - ) - .load() - .collect()(0)(0) - - val hashAggInfo = - s"sourceHashAgg=$sourceHashAgg targetHashAgg=$targetHashAgg" - // Verify hash agg to be equal - if (sourceHashAgg != targetHashAgg) { - resultBuilder - .withTestStatus(TestUtils.TEST_RESULT_STATUS_FAIL) - .withReason( - s"hash agg result is incorrect: $tableNameInfo $hashAggInfo" - ) - return - } - - // Test succeed. - resultBuilder - .withTestStatus(TestUtils.TEST_RESULT_STATUS_SUCCESS) - .withReason("Success") - } -} diff --git a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/LowMemoryStressSuite.scala b/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/LowMemoryStressSuite.scala deleted file mode 100644 index 46fa4a4e..00000000 --- a/ClusterTest/src/main/scala/net/snowflake/spark/snowflake/testsuite/LowMemoryStressSuite.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright 2015-2020 Snowflake Computing - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package net.snowflake.spark.snowflake.testsuite - -import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME -import net.snowflake.spark.snowflake.{ - ClusterTestResultBuilder, - DefaultJDBCWrapper, - Parameters, - ServerConnection, - TestUtils -} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SaveMode, SparkSession} -import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType} -import org.apache.spark.util.SizeEstimator -import org.slf4j.LoggerFactory - -import scala.util.Random - -class LowMemoryStressSuite extends ClusterTestSuiteBase { - private val log = LoggerFactory.getLogger(getClass) - - override def runImpl(sparkSession: SparkSession, - resultBuilder: ClusterTestResultBuilder): Unit = { - - def getRandomString(len: Int): String = { - Random.alphanumeric take len mkString "" - } - - val randomString1 = getRandomString(50000) - val randomString2 = getRandomString(50000) - val randomString3 = getRandomString(50000) - val randomString4 = getRandomString(50000) - val randomString5 = getRandomString(50000) - val randomString6 = getRandomString(50000) - val randomString7 = getRandomString(50000) - val randomString8 = getRandomString(50000) - val partitionCount = 1 - val rowCountPerPartition = 800 - // Create RDD which generates data with multiple partitions - val testRDD: RDD[Row] = sparkSession.sparkContext - .parallelize(Seq[Int](), partitionCount) - .mapPartitions { _ => { - (1 to rowCountPerPartition).map { i => { - Row(randomString1, randomString2, - randomString3, randomString4, - randomString5, randomString6, - randomString7, randomString8) - } - }.iterator - } - } - val schema = StructType( - List( - StructField("str1", StringType), - StructField("str2", StringType), - StructField("str3", StringType), - StructField("str4", StringType), - StructField("str5", StringType), - StructField("str6", StringType), - StructField("str7", StringType), - StructField("str8", StringType) - ) - ) - val test_big_partition = s"test_big_partition_$randomSuffix" - - // Convert RDD to DataFrame - val df = sparkSession.createDataFrame(testRDD, schema) - - // Write to snowflake - df.write - .format(TestUtils.SNOWFLAKE_NAME) - .options(TestUtils.sfOptionsNoTable) - .option("dbtable", test_big_partition) - .mode(SaveMode.Overwrite) - .save() - - log.info( - s"""Finished the first multi-part upload test.""".stripMargin) - - var noOOMError = true - try { - // Write to snowflake with multi-part feature off - df.write - .format(TestUtils.SNOWFLAKE_NAME) - .options(TestUtils.sfOptionsNoTable) - .option("dbtable", test_big_partition) - .option(Parameters.PARAM_USE_AWS_MULTIPLE_PARTS_UPLOAD, "off") - .mode(SaveMode.Overwrite) - .save() - } - catch { - case e: Throwable => { - // Test succeed - noOOMError = false - resultBuilder - .withTestStatus(TestUtils.TEST_RESULT_STATUS_SUCCESS) - .withReason("Success") - } - } - if (noOOMError) { - throw new Exception("Expecting OOM error but didn't catch that.") - } - - // If test is successful, drop the target table, - // otherwise, keep it for further investigation. 
- if (resultBuilder.testStatus == TestUtils.TEST_RESULT_STATUS_SUCCESS) { - val connection = TestUtils.getServerConnection(TestUtils.param) - connection - .createStatement() - .execute(s"drop table if exists $test_big_partition") - connection.close() - } - } -} diff --git a/build.sbt b/build.sbt index e718ccea..718cf9b4 100644 --- a/build.sbt +++ b/build.sbt @@ -16,8 +16,8 @@ import scala.util.Properties -val sparkVersion = "3.4" -val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.4.0") +val sparkVersion = "3.5" +val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.5.0") /* * Don't change the variable name "sparkConnectorVersion" because * @@ -26,7 +26,7 @@ val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.4.0") * Tests/jenkins/BumpUpSparkConnectorVersion/run.sh * in snowflake repository. */ -val sparkConnectorVersion = "2.16.0" +val sparkConnectorVersion = "3.0.0" lazy val ItTest = config("it") extend Test @@ -41,7 +41,7 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) .settings( name := "spark-snowflake", organization := "net.snowflake", - version := s"${sparkConnectorVersion}-spark_3.4", + version := s"${sparkConnectorVersion}", scalaVersion := sys.props.getOrElse("SPARK_SCALA_VERSION", default = "2.12.11"), // Spark 3.2 supports scala 2.12 and 2.13 crossScalaVersions := Seq("2.12.11", "2.13.9"), @@ -119,11 +119,6 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) Bing Li https://github.com/binglihub - - Mingli-Rui - Mingli Rui - https://github.com/Mingli-Rui - , publishTo := Some( diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala index 236b8fbd..a0cd6df7 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala @@ -875,7 +875,8 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { } // Negative Test - test("test COPY missing files and fails") { + // Disabled since Spark 3.5 because the test DF has only one partition + ignore("test COPY missing files and fails") { setupLargeResultTable // Don't run this test for use_copy_unload=true // because there are only 3 files (2 partitions) for this data size. diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeTelemetrySuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeTelemetrySuite.scala index ce74c359..e504fde6 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeTelemetrySuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeTelemetrySuite.scala @@ -87,8 +87,6 @@ class SnowflakeTelemetrySuite extends IntegrationSuiteBase { sparkSession.stop() val pythonTestFile = System.getProperty("user.dir") + "/src/test/python/unittest.py" - val dummyArchiveFile = System.getProperty("user.dir") + - "/ClusterTest/src/main/python/ClusterTest.py#environment" sparkSession = SparkSession.builder .master("local") // It may not be necessary for spark 2.X.
.config("spark.sql.legacy.timeParserPolicy", "LEGACY") .config("spark.files", pythonTestFile) - .config("spark.archives", dummyArchiveFile) .getOrCreate() // unit test val metric: ObjectNode = mapper.createObjectNode() val arrayNode = metric.putArray("dependencies") val dependencies = SnowflakeTelemetry.getSparkDependencies - assert(dependencies.length == 2) + assert(dependencies.length == 1) assert(dependencies.contains(pythonTestFile)) - assert(dependencies.contains(dummyArchiveFile)) // Integration test // A basis dataframe read @@ -126,9 +122,8 @@ class SnowflakeTelemetrySuite extends IntegrationSuiteBase { assert(clientInfoMessages.nonEmpty) clientInfoMessages.foreach { x => val sparkDependencies = x.get("data").get(TelemetryFieldNames.DEPENDENCIES) - assert(sparkDependencies.isArray && sparkDependencies.size() == 2) + assert(sparkDependencies.isArray && sparkDependencies.size() == 1) assert(nodeContains(sparkDependencies, pythonTestFile)) - assert(nodeContains(sparkDependencies, dummyArchiveFile)) } } finally { // reset default SparkSession diff --git a/src/it/scala/org/apache/spark/sql/SFDataFrameAggregateSuite.scala b/src/it/scala/org/apache/spark/sql/SFDataFrameAggregateSuite.scala deleted file mode 100644 index f2dcc6e1..00000000 --- a/src/it/scala/org/apache/spark/sql/SFDataFrameAggregateSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -package org.apache.spark.sql - -import org.apache.spark.sql.snowflake.{SFQueryTest, SFTestData, SFTestSessionBase} - -class SFDataFrameAggregateSuite - extends DataFrameAggregateSuite - with SFTestSessionBase - with SFQueryTest - with SFTestData { - - override def spark: SparkSession = getSnowflakeSession() - - override protected def blackList: Seq[String] = - Seq( - // replace by TS - zero moments, Snowflake aggregate suite returns null instead of Double.NaN - // Snowflake skewness function is called skew - "zero moments", - // Replaced by TS - SPARK-21580 ints in aggregation expressions are taken as group-by ordinal. - // Snowflake does not have table TestData2 stored in test database - "SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.", - // We have different SnowflakePlans, so testing the .toString does not make sense - "SPARK-24788: RelationalGroupedDataset.toString with unresolved exprs should not fail", - // Snowflake does not support float type, struct() - "SPARK-26021: NaN and -0.0 in grouping expressions", - "max_by", - "min_by", - // Below 2 test cases are new test cases for spark 3.2. - // The are not runnable for the negative test assertion: - // It expects to see an SparkException, but Spark connector test will write data - // to a table and read it back, so no SparkException is raised. 
- // val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), - // (Period.ofMonths(10), Duration.ofDays(10))) - // .toDF("year-month", "day") - // val error = intercept[SparkException] { - // checkAnswer(df2.select(sum($"year-month")), Nil) - // } - "SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`", - "SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`" - // Replace Spark exception by Snowflake exception, replaced by - // TS - SPARK-21896: Window functions inside aggregate functions - ) -} diff --git a/src/it/scala/org/apache/spark/sql/SFDataFrameSetOperationsSuite.scala b/src/it/scala/org/apache/spark/sql/SFDataFrameSetOperationsSuite.scala deleted file mode 100644 index 684b35d5..00000000 --- a/src/it/scala/org/apache/spark/sql/SFDataFrameSetOperationsSuite.scala +++ /dev/null @@ -1,19 +0,0 @@ -package org.apache.spark.sql - -import org.apache.spark.sql.snowflake.{SFQueryTest, SFTestData, SFTestSessionBase} - -class SFDataFrameSetOperationsSuite - extends DataFrameSetOperationsSuite - with SFTestSessionBase - with SFQueryTest - with SFTestData { - - override def spark: SparkSession = getSnowflakeSession() - - override protected def blackList: Seq[String] = - Seq( - "SPARK-25368 Incorrect predicate pushdown returns wrong result", - "exceptAll - nullability", - "intersectAll - nullability" - ) -} diff --git a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFramesSuite.scala b/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFramesSuite.scala deleted file mode 100644 index 96c6f403..00000000 --- a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFramesSuite.scala +++ /dev/null @@ -1,377 +0,0 @@ -package org.apache.spark.sql - -import net.snowflake.spark.snowflake.{SnowflakeConnectorUtils, TestUtils} -import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.snowflake.{SFQueryTest, SFTestData, SFTestSessionBase} - -class SFDataFrameWindowFramesSuite - extends SFQueryTest - with SFTestSessionBase - with SFTestData { - - import SFTestImplicits._ - - override def spark: SparkSession = getSnowflakeSession() - - override protected def blackList: Seq[String] = Seq.empty - - test("lead/lag with empty data frame") { - val df = Seq.empty[(Int, String)].toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - checkAnswer( - df.select( - lead("value", 1).over(window), - lag("value", 1).over(window)), - Nil) - } - - test("lead/lag with positive offset") { - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - checkAnswer( - df.select( - $"key", - lead("value", 1).over(window), - lag("value", 1).over(window)), - Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) - } - - test("reverse lead/lag with positive offset") { - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value".desc) - - checkAnswer( - df.select( - $"key", - lead("value", 1).over(window), - lag("value", 1).over(window)), - Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) - } - - test("lead/lag with negative offset") { - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - checkAnswer( - df.select( - $"key", - lead("value", -1).over(window), - lag("value", 
-1).over(window)), - Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil) - } - - test("reverse lead/lag with negative offset") { - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value".desc) - - checkAnswer( - df.select( - $"key", - lead("value", -1).over(window), - lag("value", -1).over(window)), - Row(1, null, "1") :: Row(1, "3", null) :: Row(2, null, "2") :: Row(2, "4", null) :: Nil) - } - - test("lead/lag with default value") { - val default = "n/a" - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4"), (2, "5")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - checkAnswer( - df.select( - $"key", - lead("value", 2, default).over(window), - lag("value", 2, default).over(window), - lead("value", -2, default).over(window), - lag("value", -2, default).over(window)), - Row(1, default, default, default, default) :: Row(1, default, default, default, default) :: - Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) :: - Row(2, default, default, default, default) :: Nil) - } - - test("rows/range between with empty data frame") { - val df = Seq.empty[(String, Int)].toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - checkAnswer( - df.select( - $"key", - first("value").over( - window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - first("value").over( - window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), - Nil) - } - - test("rows between should accept int/long values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), - Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - - val e = intercept[AnalysisException]( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) - assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) - } - - test("range between should accept at most one ORDER BY expression when unbounded") { - val df = Seq((1, 1)).toDF("key", "value") - val window = Window.orderBy($"key", $"value") - - checkAnswer( - df.select( - $"key", - min("key").over( - window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), - Seq(Row(1, 1)) - ) - - val e1 = intercept[AnalysisException]( - df.select( - min("key").over(window.rangeBetween(Window.unboundedPreceding, 1)))) - assert(e1.message.contains("A range window frame with value boundaries cannot be used in a " + - "window specification with multiple order by expressions")) - - val e2 = intercept[AnalysisException]( - df.select( - min("key").over(window.rangeBetween(-1, Window.unboundedFollowing)))) - assert(e2.message.contains("A range window frame with value boundaries cannot be used in a " + - "window specification with multiple order by expressions")) - - val e3 = intercept[AnalysisException]( - df.select( - min("key").over(window.rangeBetween(-1, 1)))) - assert(e3.message.contains("A range window frame with value boundaries cannot be used in a " + - "window specification with multiple order by expressions")) - } - - test("range between should accept numeric values only when bounded") { - val df = 
Seq("non_numeric").toDF("value") - val window = Window.orderBy($"value") - - checkAnswer( - df.select( - $"value", - min("value").over( - window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), - Row("non_numeric", "non_numeric") :: Nil) - - // The error message for 3.4 is different. - val expectedErrorMessage1 = - if (TestUtils.compareVersion(SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION, "3.4") >= 0) { - "The data type of the upper bound \"STRING\" does not match the expected data type" - } else { - "The data type of the upper bound 'string' does not match the expected data type" - } - val e1 = intercept[AnalysisException]( - df.select( - min("value").over(window.rangeBetween(Window.unboundedPreceding, 1)))) - assert(e1.message.contains(expectedErrorMessage1)) - - // The error message for 3.4 is different. - val expectedErrorMessage2 = - if (TestUtils.compareVersion(SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION, "3.4") >= 0) { - "The data type of the lower bound \"STRING\" does not match the expected data type" - } else { - "The data type of the lower bound 'string' does not match the expected data type" - } - val e2 = intercept[AnalysisException]( - df.select( - min("value").over(window.rangeBetween(-1, Window.unboundedFollowing)))) - assert(e2.message.contains(expectedErrorMessage2)) - - // The error message for 3.4 is different. - val expectedErrorMessage3 = - if (TestUtils.compareVersion(SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION, "3.4") >= 0) { - "The data type of the lower bound \"STRING\" does not match the expected data type" - } else { - "The data type of the lower bound 'string' does not match the expected data type" - } - val e3 = intercept[AnalysisException]( - df.select( - min("value").over(window.rangeBetween(-1, 1)))) - assert(e3.message.contains(expectedErrorMessage3)) - } - - test("range between should accept int/long values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), - Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), - Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) - ) - } - - test("unbounded rows/range between with aggregation") { - val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - checkAnswer( - df.select( - $"key", - sum("value").over(window. - rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - sum("value").over(window. 
- rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), - Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil) - } - - test("unbounded preceding/following rows between with aggregation") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key") - - checkAnswer( - df.select( - $"key", - last("key").over( - window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), - last("key").over( - window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), - Row(1, 1, 1) :: Row(2, 3, 2) :: Row(3, 3, 3) :: Row(1, 4, 1) :: Row(2, 4, 2) :: - Row(4, 4, 4) :: Nil) - } - - test("reverse unbounded preceding/following rows between with aggregation") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key".desc) - - checkAnswer( - df.select( - $"key", - last("key").over( - window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), - last("key").over( - window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), - Row(1, 1, 1) :: Row(3, 2, 3) :: Row(2, 2, 2) :: Row(4, 1, 4) :: Row(2, 1, 2) :: - Row(1, 1, 1) :: Nil) - } - - test("unbounded preceding/following range between with aggregation") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - val window = Window.partitionBy("value").orderBy("key") - - checkAnswer( - df.select( - $"key", - avg("key").over(window.rangeBetween(Window.unboundedPreceding, 1)) - .as("avg_key1"), - avg("key").over(window.rangeBetween(Window.currentRow, Window.unboundedFollowing)) - .as("avg_key2")), - Row(3, 3.0d, 4.0d) :: Row(5, 4.0d, 5.0d) :: Row(2, 2.0d, 17.0d / 4.0d) :: - Row(4, 11.0d / 3.0d, 5.0d) :: Row(5, 17.0d / 4.0d, 11.0d / 2.0d) :: - Row(6, 17.0d / 4.0d, 6.0d) :: Nil) - } - - // This is here to illustrate the fact that reverse order also reverses offsets. 
- test("reverse preceding/following range between with aggregation") { - val df = Seq(1, 2, 4, 3, 2, 1).toDF("value") - val window = Window.orderBy($"value".desc) - - checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Window.unboundedPreceding, 1)), - sum($"value").over(window.rangeBetween(1, Window.unboundedFollowing))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: - Row(2, 13, 2) :: Row(1, 13, null) :: Nil) - } - - test("sliding rows between with aggregation") { - val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2) - - checkAnswer( - df.select( - $"key", - avg("key").over(window)), - Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 3.0d / 2.0d) :: Row(2, 2.0d) :: - Row(2, 2.0d) :: Nil) - } - - test("reverse sliding rows between with aggregation") { - val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key".desc).rowsBetween(-1, 2) - - checkAnswer( - df.select( - $"key", - avg("key").over(window)), - Row(1, 1.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 4.0d / 3.0d) :: Row(2, 2.0d) :: - Row(2, 2.0d) :: Nil) - } - - test("sliding range between with aggregation") { - val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1) - - checkAnswer( - df.select( - $"key", - avg("key").over(window)), - Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 7.0d / 4.0d) :: Row(3, 5.0d / 2.0d) :: - Row(2, 2.0d) :: Row(2, 2.0d) :: Nil) - } - - test("reverse sliding range between with aggregation") { - val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window.partitionBy($"category").orderBy($"revenue".desc). 
- rangeBetween(-2000L, 1000L) - - checkAnswer( - df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) - } - - test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { - val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") - val res = - Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil - checkAnswer( - ds.withColumn("m", - lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), - res) - checkAnswer( - ds.withColumn("m", - lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), - res) - } -} diff --git a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFunctionsSuite.scala b/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFunctionsSuite.scala deleted file mode 100644 index 77d10e01..00000000 --- a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFunctionsSuite.scala +++ /dev/null @@ -1,821 +0,0 @@ -package org.apache.spark.sql - -import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.catalyst.optimizer.TransposeWindow -import org.apache.spark.sql.execution.exchange.Exchange -import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.snowflake.{SFQueryTest, SFTestData, SFTestSessionBase} -import org.apache.spark.sql.types.{ArrayType, DataType, LongType, StringType, StructType} -import org.apache.spark.util.Utils -import org.scalatest.Matchers.the - -class SFDataFrameWindowFunctionsSuite - extends SFQueryTest - with SFTestSessionBase - with SFTestData { - import SFTestImplicits._ - - /** - * Drops temporary view `viewNames` after calling `f`. - */ - protected def withTempView(viewNames: String*)(f: => Unit): Unit = { - Utils.tryWithSafeFinally(f) { - viewNames.foreach { viewName => - try spark.catalog.dropTempView(viewName) catch { - // If the test failed part way, we don't want to mask the failure by failing to remove - // temp views that never got created. - case _: NoSuchTableException => - } - } - } - } - - protected def sql(sqlText: String) = spark.sql(sqlText) - - override def spark: SparkSession = getSnowflakeSession() - - protected def sparkContext = spark.sparkContext - - override protected def blackList: Seq[String] = - Seq( - // TS - corr, covar_pop, stddev_pop functions in specific window - "corr, covar_pop, stddev_pop functions in specific window", - // TS - covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window - "covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window", - // TS - NaN and -0.0 in window partition keys - "NaN and -0.0 in window partition keys", - // TS - This case is commented out because the column name in error is upper case. 
- // actual message: cannot resolve '`invalid`' given input columns: [KEY, VALUE] - // expect message: cannot resolve '`invalid`' given input columns: [key, value] - "aggregation function on invalid column" - ) - - test("reuse window partitionBy") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = Window.partitionBy("key").orderBy("value") - - checkAnswer( - df.select( - lead("key", 1).over(w), - lead("value", 1).over(w)), - Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) - } - - test("reuse window orderBy") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = Window.orderBy("value").partitionBy("key") - - checkAnswer( - df.select( - lead("key", 1).over(w), - lead("value", 1).over(w)), - Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) - } - - test("rank functions in unspecific window") { - withTempView("window_table") { - val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - max("key").over(Window.partitionBy("value").orderBy("key")), - min("key").over(Window.partitionBy("value").orderBy("key")), - mean("key").over(Window.partitionBy("value").orderBy("key")), - count("key").over(Window.partitionBy("value").orderBy("key")), - sum("key").over(Window.partitionBy("value").orderBy("key")), - ntile(2).over(Window.partitionBy("value").orderBy("key")), - row_number().over(Window.partitionBy("value").orderBy("key")), - dense_rank().over(Window.partitionBy("value").orderBy("key")), - rank().over(Window.partitionBy("value").orderBy("key")), - cume_dist().over(Window.partitionBy("value").orderBy("key")), - percent_rank().over(Window.partitionBy("value").orderBy("key"))), - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) - } - } - - test("window function should fail if order by clause is not specified") { - val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") - val e = intercept[AnalysisException]( - // Here we missed .orderBy("key")! 
- df.select(row_number().over(Window.partitionBy("value"))).collect()) - assert(e.message.contains("requires window to be ordered")) - } - - test("corr, covar_pop, stddev_pop functions in specific window") { - val df = Seq( - ("a", "p1", 10.0, 20.0), - ("b", "p1", 20.0, 10.0), - ("c", "p2", 20.0, 20.0), - ("d", "p2", 20.0, 20.0), - ("e", "p3", 0.0, 0.0), - ("f", "p3", 6.0, 12.0), - ("g", "p3", 6.0, 12.0), - ("h", "p3", 8.0, 16.0), - ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") - checkAnswer( - df.select( - $"key", - corr("value1", "value2").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - covar_pop("value1", "value2") - .over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - var_pop("value1") - .over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev_pop("value1") - .over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - var_pop("value2") - .over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev_pop("value2") - .over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), - - // As stddev_pop(expr) = sqrt(var_pop(expr)) - // the "stddev_pop" column can be calculated from the "var_pop" column. - // - // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) - // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. - Seq( - Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), - Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), - Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), - Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), - Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) - } - - test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { - val df = Seq( - ("a", "p1", 10.0, 20.0), - ("b", "p1", 20.0, 10.0), - ("c", "p2", 20.0, 20.0), - ("d", "p2", 20.0, 20.0), - ("e", "p3", 0.0, 0.0), - ("f", "p3", 6.0, 12.0), - ("g", "p3", 6.0, 12.0), - ("h", "p3", 8.0, 16.0), - ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") - checkAnswer( - df.select( - $"key", - covar_samp("value1", "value2").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - var_samp("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - variance("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev_samp("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) - ), - Seq( - Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), - Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), - Row("c", 0.0, 0.0, 0.0, 0.0, 
0.0), - Row("d", 0.0, 0.0, 0.0, 0.0, 0.0), - Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) - } - - test("collect_list in ascending ordered window") { - val df = Seq( - ("a", "p1", "1"), - ("b", "p1", "2"), - ("c", "p1", "2"), - ("d", "p1", null), - ("e", "p1", "3"), - ("f", "p2", "10"), - ("g", "p2", "11"), - ("h", "p3", "20"), - ("i", "p4", null)).toDF("key", "partition", "value") - checkAnswer( - df.select( - $"key", - sort_array( - collect_list("value").over(Window.partitionBy($"partition").orderBy($"value") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), - Seq( - Row("a", Array("1", "2", "2", "3")), - Row("b", Array("1", "2", "2", "3")), - Row("c", Array("1", "2", "2", "3")), - Row("d", Array("1", "2", "2", "3")), - Row("e", Array("1", "2", "2", "3")), - Row("f", Array("10", "11")), - Row("g", Array("10", "11")), - Row("h", Array("20")), - Row("i", Array()))) - } - - test("collect_list in descending ordered window") { - val df = Seq( - ("a", "p1", "1"), - ("b", "p1", "2"), - ("c", "p1", "2"), - ("d", "p1", null), - ("e", "p1", "3"), - ("f", "p2", "10"), - ("g", "p2", "11"), - ("h", "p3", "20"), - ("i", "p4", null)).toDF("key", "partition", "value") - checkAnswer( - df.select( - $"key", - sort_array( - collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc) - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), - Seq( - Row("a", Array("1", "2", "2", "3")), - Row("b", Array("1", "2", "2", "3")), - Row("c", Array("1", "2", "2", "3")), - Row("d", Array("1", "2", "2", "3")), - Row("e", Array("1", "2", "2", "3")), - Row("f", Array("10", "11")), - Row("g", Array("10", "11")), - Row("h", Array("20")), - Row("i", Array()))) - } - - test("collect_set in window") { - val df = Seq( - ("a", "p1", "1"), - ("b", "p1", "2"), - ("c", "p1", "2"), - ("d", "p1", "3"), - ("e", "p1", "3"), - ("f", "p2", "10"), - ("g", "p2", "11"), - ("h", "p3", "20")).toDF("key", "partition", "value") - checkAnswer( - df.select( - $"key", - sort_array( - collect_set("value").over(Window.partitionBy($"partition").orderBy($"value") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), - Seq( - Row("a", Array("1", "2", "3")), - Row("b", Array("1", "2", "3")), - Row("c", Array("1", "2", "3")), - Row("d", Array("1", "2", "3")), - Row("e", Array("1", "2", "3")), - Row("f", Array("10", "11")), - Row("g", Array("10", "11")), - Row("h", Array("20")))) - } - - test("skewness and kurtosis functions in window") { - val df = Seq( - ("a", "p1", 1.0), - ("b", "p1", 1.0), - ("c", "p1", 2.0), - ("d", "p1", 2.0), - ("e", "p1", 3.0), - ("f", "p1", 3.0), - ("g", "p1", 3.0), - ("h", "p2", 1.0), - ("i", "p2", 2.0), - ("j", "p2", 5.0)).toDF("key", "partition", "value") - checkAnswer( - df.select( - $"key", - skewness("value").over(Window.partitionBy("partition").orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - kurtosis("value").over(Window.partitionBy("partition").orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), - // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() - Seq( - Row("a", -0.27238010581457267, -1.506920415224914), - Row("b", -0.27238010581457267, 
-1.506920415224914), - Row("c", -0.27238010581457267, -1.506920415224914), - Row("d", -0.27238010581457267, -1.506920415224914), - Row("e", -0.27238010581457267, -1.506920415224914), - Row("f", -0.27238010581457267, -1.506920415224914), - Row("g", -0.27238010581457267, -1.506920415224914), - Row("h", 0.5280049792181881, -1.5000000000000013), - Row("i", 0.5280049792181881, -1.5000000000000013), - Row("j", 0.5280049792181881, -1.5000000000000013))) - } - - test("aggregation function on invalid column") { - val df = Seq((1, "1")).toDF("key", "value") - val e = intercept[AnalysisException]( - df.select($"key", count("invalid").over())) - assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]")) - } - - test("numerical aggregate functions on string column") { - val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") - checkAnswer( - df.select($"key", - var_pop("value1").over(), - variance("value1").over(), - stddev_pop("value1").over(), - stddev("value1").over(), - sum("value1").over(), - mean("value1").over(), - avg("value1").over(), - corr("value1", "value2").over(), - covar_pop("value1", "value2").over(), - covar_samp("value1", "value2").over(), - skewness("value1").over(), - kurtosis("value1").over()), - Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) - } - - test("statistical functions") { - val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). - toDF("key", "value") - val window = Window.partitionBy($"key") - checkAnswer( - df.select( - $"key", - var_pop($"value").over(window), - var_samp($"value").over(window), - approx_count_distinct($"value").over(window)), - Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) - ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) - } - - test("window function with aggregates") { - val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). - toDF("key", "value") - val window = Window.orderBy() - checkAnswer( - df.groupBy($"key") - .agg( - sum($"value"), - sum(sum($"value")).over(window) - sum($"value")), - Seq(Row("a", 6, 9), Row("b", 9, 6))) - } - - test("SPARK-16195 empty over spec") { - withTempView("window_table") { - val df = Seq(("a", 1), ("a", 1), ("a", 2), ("b", 2)). 
- toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select($"key", $"value", sum($"value").over(), avg($"value").over()), - Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5))) - checkAnswer( - sql("select key, value, sum(value) over(), avg(value) over() from window_table"), - Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5))) - } - } - - test("window function with udaf") { - val udaf = new UserDefinedAggregateFunction { - def inputSchema: StructType = new StructType() - .add("a", LongType) - .add("b", LongType) - - def bufferSchema: StructType = new StructType() - .add("product", LongType) - - def dataType: DataType = LongType - - def deterministic: Boolean = true - - def initialize(buffer: MutableAggregationBuffer): Unit = { - buffer(0) = 0L - } - - def update(buffer: MutableAggregationBuffer, input: Row): Unit = { - if (!(input.isNullAt(0) || input.isNullAt(1))) { - buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) - } - } - - def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { - buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) - } - - def evaluate(buffer: Row): Any = - buffer.getLong(0) - } - val df = Seq( - ("a", 1, 1), - ("a", 1, 5), - ("a", 2, 10), - ("a", 2, -1), - ("b", 4, 7), - ("b", 3, 8), - ("b", 2, 4)) - .toDF("key", "a", "b") - val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L) - checkAnswer( - df.select( - $"key", - $"a", - $"b", - udaf($"a", $"b").over(window)), - Seq( - Row("a", 1, 1, 6), - Row("a", 1, 5, 6), - Row("a", 2, 10, 24), - Row("a", 2, -1, 24), - Row("b", 4, 7, 60), - Row("b", 3, 8, 32), - Row("b", 2, 4, 8))) - } - - test("window function with aggregator") { - val agg = udaf(new Aggregator[(Long, Long), Long, Long] { - def zero: Long = 0L - def reduce(b: Long, a: (Long, Long)): Long = b + (a._1 * a._2) - def merge(b1: Long, b2: Long): Long = b1 + b2 - def finish(r: Long): Long = r - def bufferEncoder: Encoder[Long] = Encoders.scalaLong - def outputEncoder: Encoder[Long] = Encoders.scalaLong - }) - - val df = Seq( - ("a", 1, 1), - ("a", 1, 5), - ("a", 2, 10), - ("a", 2, -1), - ("b", 4, 7), - ("b", 3, 8), - ("b", 2, 4)) - .toDF("key", "a", "b") - val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L) - checkAnswer( - df.select( - $"key", - $"a", - $"b", - agg($"a", $"b").over(window)), - Seq( - Row("a", 1, 1, 6), - Row("a", 1, 5, 6), - Row("a", 2, 10, 24), - Row("a", 2, -1, 24), - Row("b", 4, 7, 60), - Row("b", 3, 8, 32), - Row("b", 2, 4, 8))) - } - - test("null inputs") { - val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)) - .toDF("key", "value") - val window = Window.orderBy() - checkAnswer( - df.select( - $"key", - $"value", - avg(lit(null)).over(window), - sum(lit(null)).over(window)), - Seq( - Row("a", 1, null, null), - Row("a", 1, null, null), - Row("a", 2, null, null), - Row("a", 2, null, null), - Row("b", 4, null, null), - Row("b", 3, null, null), - Row("b", 2, null, null))) - } - - test("last/first with ignoreNulls") { - val nullStr: String = null - val df = Seq( - ("a", 0, nullStr), - ("a", 1, "x"), - ("a", 2, "y"), - ("a", 3, "z"), - ("a", 4, nullStr), - ("b", 1, nullStr), - ("b", 2, nullStr)). 
- toDF("key", "order", "value") - val window = Window.partitionBy($"key").orderBy($"order") - checkAnswer( - df.select( - $"key", - $"order", - first($"value").over(window), - first($"value", ignoreNulls = false).over(window), - first($"value", ignoreNulls = true).over(window), - last($"value").over(window), - last($"value", ignoreNulls = false).over(window), - last($"value", ignoreNulls = true).over(window)), - Seq( - Row("a", 0, null, null, null, null, null, null), - Row("a", 1, null, null, "x", "x", "x", "x"), - Row("a", 2, null, null, "x", "y", "y", "y"), - Row("a", 3, null, null, "x", "z", "z", "z"), - Row("a", 4, null, null, "x", null, null, "z"), - Row("b", 1, null, null, null, null, null, null), - Row("b", 2, null, null, null, null, null, null))) - } - - test("last/first on descending ordered window") { - val nullStr: String = null - val df = Seq( - ("a", 0, nullStr), - ("a", 1, "x"), - ("a", 2, "y"), - ("a", 3, "z"), - ("a", 4, "v"), - ("b", 1, "k"), - ("b", 2, "l"), - ("b", 3, nullStr)). - toDF("key", "order", "value") - val window = Window.partitionBy($"key").orderBy($"order".desc) - checkAnswer( - df.select( - $"key", - $"order", - first($"value").over(window), - first($"value", ignoreNulls = false).over(window), - first($"value", ignoreNulls = true).over(window), - last($"value").over(window), - last($"value", ignoreNulls = false).over(window), - last($"value", ignoreNulls = true).over(window)), - Seq( - Row("a", 0, "v", "v", "v", null, null, "x"), - Row("a", 1, "v", "v", "v", "x", "x", "x"), - Row("a", 2, "v", "v", "v", "y", "y", "y"), - Row("a", 3, "v", "v", "v", "z", "z", "z"), - Row("a", 4, "v", "v", "v", "v", "v", "v"), - Row("b", 1, null, null, "l", "k", "k", "k"), - Row("b", 2, null, null, "l", "l", "l", "l"), - Row("b", 3, null, null, null, null, null, null))) - } - - test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { - val src = Seq((0, 3, 5)).toDF("a", "b", "c") - .withColumn("Data", struct("a", "b")) - .drop("a") - .drop("b") - val winSpec = Window.partitionBy("Data.a", "Data.b").orderBy($"c".desc) - val df = src.select($"*", max("c").over(winSpec) as "max") - checkAnswer(df, Row(5, Row(0, 3), 5)) - } - - test("aggregation and rows between with unbounded + predicate pushdown") { - withTempView("window_table") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - val selectList = Seq($"key", $"value", - last("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), - last("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), - last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))) - - checkAnswer( - df.select(selectList: _*).where($"value" < "3"), - Seq(Row(1, "1", 1, 1, 1), Row(2, "2", 3, 2, 3), Row(3, "2", 3, 3, 3))) - } - } - - test("aggregation and range between with unbounded + predicate pushdown") { - withTempView("window_table") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - val selectList = Seq($"key", $"value", - last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)).equalTo("2") - .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) - .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) - 
.as("avg_key2"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 1)) - .as("avg_key3")) - - checkAnswer( - df.select(selectList: _*).where($"value" < 2), - Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) - } - } - - test("Window spill with less than the inMemoryThreshold") { - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "2", - SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { - assertNotSpilled(sparkContext, "select") { - df.select($"key", sum("value").over(window)).collect() - } - } - } - - test("Window spill with more than the inMemoryThreshold but less than the spillThreshold") { - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", - SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { - assertNotSpilled(sparkContext, "select") { - df.select($"key", sum("value").over(window)).collect() - } - } - } - - test("Window spill with more than the inMemoryThreshold and spillThreshold") { - val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") - val window = Window.partitionBy($"key").orderBy($"value") - - withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", - SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { - assertSpilled(sparkContext, "select") { - df.select($"key", sum("value").over(window)).collect() - } - } - } - - // Spark connector test doesn't support function toDS(). - // So comment out this test case to avoid compiling problem. - /* - test("SPARK-21258: complex object in combination with spilling") { - // Make sure we trigger the spilling path. - withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", - SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { - val sampleSchema = new StructType(). - add("f0", StringType). - add("f1", LongType). - add("f2", ArrayType(new StructType(). - add("f20", StringType))). - add("f3", ArrayType(new StructType(). 
- add("f30", StringType))) - - val w0 = Window.partitionBy("f0").orderBy("f1") - val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue) - - val c0 = first(struct($"f2", $"f3")).over(w0) as "c0" - val c1 = last(struct($"f2", $"f3")).over(w1) as "c1" - - val input = - """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]} - |{"f1":1497802179638} - |{"f1":1497802189347} - |{"f1":1497802189593} - |{"f1":1497802189597} - |{"f1":1497802189599} - |{"f1":1497802192103} - |{"f1":1497802193414} - |{"f1":1497802193577} - |{"f1":1497802193709} - |{"f1":1497802202883} - |{"f1":1497802203006} - |{"f1":1497802203743} - |{"f1":1497802203834} - |{"f1":1497802203887} - |{"f1":1497802203893} - |{"f1":1497802203976} - |{"f1":1497820168098} - |""".stripMargin.split("\n").toSeq - - import testImplicits._ - - assertSpilled(sparkContext, "select") { - spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } - } - } - } - */ - - test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { - def checkAnalysisError(df: => DataFrame, clause: String): Unit = { - val thrownException = the[AnalysisException] thrownBy { - df.queryExecution.analyzed - } - assert(thrownException.message.contains(s"window functions inside $clause clause")) - } - - checkAnalysisError( - testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1), "WHERE") - checkAnalysisError( - testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1), "WHERE") - checkAnalysisError( - testData2.groupBy($"a") - .agg(avg($"b").as("avgb")) - .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1), "WHERE") - checkAnalysisError( - testData2.groupBy($"a") - .agg(max($"b").as("maxb"), sum($"b").as("sumb")) - .where(rank().over(Window.orderBy($"a")) === 1), "WHERE") - checkAnalysisError( - testData2.groupBy($"a") - .agg(max($"b").as("maxb"), sum($"b").as("sumb")) - .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1), "WHERE") - - checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE") - checkAnalysisError( - sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE") - checkAnalysisError( - sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"), - "HAVING") - checkAnalysisError( - sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"), - "HAVING") - checkAnalysisError( - sql( - s"""SELECT a, MAX(b) - |FROM testData2 - |GROUP BY a - |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin), - "HAVING") - } - - // Spark connector test doesn't support function stripAQEPlan(). - // So comment out this test case to avoid compiling problem. 
- /* - test("window functions in multiple selects") { - val df = Seq( - ("S1", "P1", 100), - ("S1", "P1", 700), - ("S2", "P1", 200), - ("S2", "P2", 300) - ).toDF("sno", "pno", "qty") - - Seq(true, false).foreach { transposeWindowEnabled => - val excludedRules = if (transposeWindowEnabled) "" else TransposeWindow.ruleName - withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) { - val w1 = Window.partitionBy("sno") - val w2 = Window.partitionBy("sno", "pno") - - val select = df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) - .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")) - - val expectedNumExchanges = if (transposeWindowEnabled) 1 else 2 - val actualNumExchanges = stripAQEPlan(select.queryExecution.executedPlan).collect { - case e: Exchange => e - }.length - assert(actualNumExchanges == expectedNumExchanges) - - checkAnswer( - select, - Seq( - Row("S1", "P1", 100, 800, 800), - Row("S1", "P1", 700, 800, 800), - Row("S2", "P1", 200, 200, 500), - Row("S2", "P2", 300, 300, 500))) - } - } - } - */ - - test("NaN and -0.0 in window partition keys") { - val df = Seq( - (Float.NaN, Double.NaN), - (0.0f/0.0f, 0.0/0.0), - (0.0f, 0.0), - (-0.0f, -0.0)).toDF("f", "d") - - checkAnswer( - df.select($"f", count(lit(1)).over(Window.partitionBy("f", "d"))), - Seq( - Row(Float.NaN, 2), - Row(0.0f/0.0f, 2), - Row(0.0f, 2), - Row(-0.0f, 2))) - - // test with complicated window partition keys. - val windowSpec1 = Window.partitionBy(array("f"), struct("d")) - checkAnswer( - df.select($"f", count(lit(1)).over(windowSpec1)), - Seq( - Row(Float.NaN, 2), - Row(0.0f/0.0f, 2), - Row(0.0f, 2), - Row(-0.0f, 2))) - - val windowSpec2 = Window.partitionBy(array(struct("f")), struct(array("d"))) - checkAnswer( - df.select($"f", count(lit(1)).over(windowSpec2)), - Seq( - Row(Float.NaN, 2), - Row(0.0f/0.0f, 2), - Row(0.0f, 2), - Row(-0.0f, 2))) - - // test with df with complicated-type columns. - val df2 = Seq( - (Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))), - (Array(0.0f, -0.0f), Tuple2(0.0d, Double.NaN), Seq(Tuple2(0.0d, 0.0/0.0))) - ).toDF("arr", "stru", "arrOfStru") - val windowSpec3 = Window.partitionBy("arr", "stru", "arrOfStru") - checkAnswer( - df2.select($"arr", $"stru", $"arrOfStru", count(lit(1)).over(windowSpec3)), - Seq( - Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), Seq(Row(-0.0d, Double.NaN)), 2), - Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, 0.0/0.0)), 2))) - } -} diff --git a/src/it/scala/org/apache/spark/sql/SFDateFunctionsSuite.scala b/src/it/scala/org/apache/spark/sql/SFDateFunctionsSuite.scala deleted file mode 100644 index eb6015b3..00000000 --- a/src/it/scala/org/apache/spark/sql/SFDateFunctionsSuite.scala +++ /dev/null @@ -1,794 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat -import java.time.{Instant, LocalDateTime, ZoneId} -import java.util.concurrent.TimeUnit -import java.util.{Locale, TimeZone} - -import org.apache.spark.SparkUpgradeException -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{CEST, LA} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.snowflake.{SFQueryTest, SFTestSessionBase} -import org.apache.spark.sql.types.DoubleType -import org.apache.spark.unsafe.types.CalendarInterval - -class SFDateFunctionsSuite extends SFQueryTest with SFTestSessionBase { - import SFTestImplicits._ - protected def sql(sqlText: String) = spark.sql(sqlText) - - test("function current_date") { - val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") - val d0 = DateTimeUtils.currentDate(ZoneId.systemDefault()) - val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) - val d2 = - DateTimeUtils.fromJavaDate(sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) - val d3 = DateTimeUtils.currentDate(ZoneId.systemDefault()) - assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) - } - - test("function current_timestamp and now") { - val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") - checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) - - // Execution in one query should return the same value - checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) - - // Current timestamp should return the current timestamp ... - val before = System.currentTimeMillis - val got = sql("SELECT CURRENT_TIMESTAMP()").collect().head.getTimestamp(0).getTime - val after = System.currentTimeMillis - assert(got >= before && got <= after) - - // Now alias - checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = NOW()"""), Row(true)) - } - - test("timestamp comparison with date strings") { - val df = Seq( - (1, Timestamp.valueOf("2015-01-01 00:00:00")), - (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) - - checkAnswer( - df.select("t").filter($"t" >= "2014-06-01"), - Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) - } - - test("date comparison with date strings") { - val df = Seq((1, Date.valueOf("2015-01-01")), (2, Date.valueOf("2014-01-01"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Date.valueOf("2014-01-01")) :: Nil) - - checkAnswer(df.select("t").filter($"t" >= "2015"), Row(Date.valueOf("2015-01-01")) :: Nil) - } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) - val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) - val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) - val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) - - test("date format") { - Seq("legacy", "corrected").foreach { legacyParserPolicy => - withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> legacyParserPolicy) { - val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") - - checkAnswer( - df.select(date_format($"a", "y"), date_format($"b", "y"), date_format($"c", "y")), - Row("2015", "2015", "2013")) - - checkAnswer( - df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", 
"date_format(c, 'y')"), - Row("2015", "2015", "2013")) - } - } - } - - test("year") { - val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(year($"a"), year($"b"), year($"c")), Row(2015, 2015, 2013)) - - checkAnswer(df.selectExpr("year(a)", "year(b)", "year(c)"), Row(2015, 2015, 2013)) - } - - test("quarter") { - val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) - - val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(quarter($"a"), quarter($"b"), quarter($"c")), Row(2, 2, 4)) - - checkAnswer(df.selectExpr("quarter(a)", "quarter(b)", "quarter(c)"), Row(2, 2, 4)) - } - - test("month") { - val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(month($"a"), month($"b"), month($"c")), Row(4, 4, 4)) - - checkAnswer(df.selectExpr("month(a)", "month(b)", "month(c)"), Row(4, 4, 4)) - } - - test("dayofmonth") { - val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(dayofmonth($"a"), dayofmonth($"b"), dayofmonth($"c")), Row(8, 8, 8)) - - checkAnswer(df.selectExpr("day(a)", "day(b)", "dayofmonth(c)"), Row(8, 8, 8)) - } - - test("dayofyear") { - val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(dayofyear($"a"), dayofyear($"b"), dayofyear($"c")), Row(98, 98, 98)) - - checkAnswer(df.selectExpr("dayofyear(a)", "dayofyear(b)", "dayofyear(c)"), Row(98, 98, 98)) - } - - // SNOW-205533: disabled temporarily because test fails. Investigate pushdown bug -// test("hour") { -// val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") -// -// checkAnswer(df.select(hour($"a"), hour($"b"), hour($"c")), Row(0, 13, 13)) -// -// checkAnswer(df.selectExpr("hour(a)", "hour(b)", "hour(c)"), Row(0, 13, 13)) -// } - - test("minute") { - val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(minute($"a"), minute($"b"), minute($"c")), Row(0, 10, 10)) - - checkAnswer(df.selectExpr("minute(a)", "minute(b)", "minute(c)"), Row(0, 10, 10)) - } - - test("second") { - val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(second($"a"), second($"b"), second($"c")), Row(0, 15, 15)) - - checkAnswer(df.selectExpr("second(a)", "second(b)", "second(c)"), Row(0, 15, 15)) - } - - test("weekofyear") { - val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") - - checkAnswer(df.select(weekofyear($"a"), weekofyear($"b"), weekofyear($"c")), Row(15, 15, 15)) - - checkAnswer(df.selectExpr("weekofyear(a)", "weekofyear(b)", "weekofyear(c)"), Row(15, 15, 15)) - } - - // SNOW-205533: disabled temporarily because test fails. 
Investigate pushdown bug - test("function date_add") { - val st1 = "2015-06-01 12:34:56" - val st2 = "2015-06-02 12:34:56" - val t1 = Timestamp.valueOf(st1) - val t2 = Timestamp.valueOf(st2) - val s1 = "2015-06-01" - val s2 = "2015-06-02" - val d1 = Date.valueOf(s1) - val d2 = Date.valueOf(s2) - val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") - checkAnswer( - df.select(date_add(col("d"), 1)), - Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) - checkAnswer( - df.select(date_add(col("t"), 3)), - Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05")))) - checkAnswer( - df.select(date_add(col("s"), 5)), - Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07")))) - checkAnswer( - df.select(date_add(col("ss"), 7)), - Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09")))) - - checkAnswer( - df.withColumn("x", lit(1)).select(date_add(col("d"), col("x"))), - Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) - - checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null))) - checkAnswer( - df.selectExpr("""DATE_ADD(d, 1)"""), - Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) - } - - // SNOW-205533: disabled temporarily because test fails. Investigate pushdown bug - test("function date_sub") { - val st1 = "2015-06-01 12:34:56" - val st2 = "2015-06-02 12:34:56" - val t1 = Timestamp.valueOf(st1) - val t2 = Timestamp.valueOf(st2) - val s1 = "2015-06-01" - val s2 = "2015-06-02" - val d1 = Date.valueOf(s1) - val d2 = Date.valueOf(s2) - val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") - checkAnswer( - df.select(date_sub(col("d"), 1)), - Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) - checkAnswer( - df.select(date_sub(col("t"), 1)), - Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) - checkAnswer( - df.select(date_sub(col("s"), 1)), - Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) - checkAnswer( - df.select(date_sub(col("ss"), 1)), - Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) - checkAnswer(df.select(date_sub(lit(null), 1)).limit(1), Row(null)) - - checkAnswer( - df.withColumn("x", lit(1)).select(date_sub(col("d"), col("x"))), - Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) - - checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null))) - checkAnswer( - df.selectExpr("""DATE_SUB(d, 1)"""), - Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) - } - - /* - * In Spark 3.2, the unit list interval literals can not mix year-month fields - * (YEAR and MONTH) and day-time fields (WEEK, DAY, ..., MICROSECOND). - * For example, INTERVAL 1 month 1 hour is invalid in Spark 3.2. - * In Spark 3.1 and earlier, there is no such limitation and the literal - * returns value of CalendarIntervalType. To restore the behavior before - * Spark 3.2, you can set spark.sql.legacy.interval.enabled to true. - * - * So re-copy test("time_add") and test("time_sub"). 
- * If don't re-copy them, the test fails with error: - * Cannot mix year-month and day-time fields: INTERVAL'2 months 2 days 2 seconds' - */ - test("time_add") { - val t1 = Timestamp.valueOf("2015-07-31 23:59:59") - val t2 = Timestamp.valueOf("2015-12-31 00:00:00") - val d1 = Date.valueOf("2015-07-31") - val d2 = Date.valueOf("2015-12-31") - val i = new CalendarInterval(2, 2, 2000000L) - val day = "1 day" - val ym = "1 year 2 month" - val dt = "1 day 2 hour 3 minute 4 second 5 millisecond 6 microsecond" - val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") - checkAnswer( - df.selectExpr(s"d + INTERVAL'$ym'"), - Seq(Row(Date.valueOf("2016-09-30")), - Row(Date.valueOf("2017-02-28")))) - checkAnswer( - df.selectExpr(s"t + INTERVAL'$ym'"), - Seq(Row(Timestamp.valueOf("2016-09-30 23:59:59")), - Row(Timestamp.valueOf("2017-02-28 00:00:00")))) - checkAnswer( - df.selectExpr(s"d + INTERVAL'$dt'"), - Seq(Row(Timestamp.valueOf("2015-08-01 02:03:04.005006")), - Row(Timestamp.valueOf("2016-01-01 02:03:04.005006")))) - checkAnswer( - df.selectExpr(s"d + INTERVAL '$day'"), - Seq(Row(Date.valueOf("2015-08-01")), - Row(Date.valueOf("2016-01-01")))) - checkAnswer( - df.selectExpr(s"t + INTERVAL'$dt'"), - Seq(Row(Timestamp.valueOf("2015-08-02 02:03:03.005006")), - Row(Timestamp.valueOf("2016-01-01 02:03:04.005006")))) - withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { - checkAnswer( - df.selectExpr(s"d + INTERVAL'${i.toString}'"), - Seq(Row(Date.valueOf("2015-10-02")), Row(Date.valueOf("2016-03-02")))) - checkAnswer( - df.selectExpr(s"t + INTERVAL'${i.toString}'"), - Seq(Row(Timestamp.valueOf("2015-10-03 00:00:01")), - Row(Timestamp.valueOf("2016-03-02 00:00:02")))) - } - } - - test("time_sub") { - val t1 = Timestamp.valueOf("2015-10-01 00:00:01") - val t2 = Timestamp.valueOf("2016-02-29 00:00:02") - val d1 = Date.valueOf("2015-09-30") - val d2 = Date.valueOf("2016-02-29") - val i = new CalendarInterval(2, 2, 2000000L) - val day = "1 day" - val ym = "1 year 2 month" - val dt = "1 day 2 hour 3 minute 4 second 5 millisecond 6 microsecond" - val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") - checkAnswer( - df.selectExpr(s"d - INTERVAL'$ym'"), - Seq(Row(Date.valueOf("2014-07-30")), - Row(Date.valueOf("2014-12-29")))) - checkAnswer( - df.selectExpr(s"t - INTERVAL'$ym'"), - Seq(Row(Timestamp.valueOf("2014-08-01 00:00:01")), - Row(Timestamp.valueOf("2014-12-29 00:00:02")))) - checkAnswer( - df.selectExpr(s"d - INTERVAL'$dt'"), - Seq(Row(Timestamp.valueOf("2015-09-28 21:56:55.994994")), - Row(Timestamp.valueOf("2016-02-27 21:56:55.994994")))) - checkAnswer( - df.selectExpr(s"d - INTERVAL '$day'"), - Seq(Row(Date.valueOf("2015-09-29")), - Row(Date.valueOf("2016-02-28")))) - checkAnswer( - df.selectExpr(s"t - INTERVAL'$dt'"), - Seq(Row(Timestamp.valueOf("2015-09-29 21:56:56.994994")), - Row(Timestamp.valueOf("2016-02-27 21:56:57.994994")))) - withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { - checkAnswer( - df.selectExpr(s"d - INTERVAL'${i.toString}'"), - Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-12-26")))) - checkAnswer( - df.selectExpr(s"t - INTERVAL'${i.toString}'"), - Seq(Row(Timestamp.valueOf("2015-07-29 23:59:59")), - Row(Timestamp.valueOf("2015-12-27 00:00:00")))) - } - } - - test("function add_months") { - val d1 = Date.valueOf("2015-08-31") - val d2 = Date.valueOf("2015-02-28") - val df = Seq((1, d1), (2, d2)).toDF("n", "d") - checkAnswer( - df.select(add_months(col("d"), 1)), - Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-28")))) 
- checkAnswer( - df.selectExpr("add_months(d, -1)"), - Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-28")))) - checkAnswer( - df.withColumn("x", lit(1)).select(add_months(col("d"), col("x"))), - Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-28")))) - } - - test("function months_between") { - val d1 = Date.valueOf("2015-07-31") - val d2 = Date.valueOf("2015-02-16") - val t1 = Timestamp.valueOf("2014-09-30 23:30:00") - val t2 = Timestamp.valueOf("2015-09-16 12:00:00") - val s1 = "2014-09-15 11:30:00" - val s2 = "2015-10-01 00:00:00" - val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") - checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) - checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) - checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) - Seq(true, false).foreach { roundOff => - checkAnswer( - df.select(months_between(col("t"), col("d"), roundOff)), - Seq(Row(-10.0), Row(7.0))) - checkAnswer( - df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), - Seq(Row(0.5), Row(-0.5))) - } - } - - test("function last_day") { - val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d") - val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t") - checkAnswer( - df1.select(last_day(col("d"))), - Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) - checkAnswer( - df2.select(last_day(col("t"))), - Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) - } - - test("function next_day") { - val df1 = Seq(("mon", "2015-07-23"), ("tuesday", "2015-07-20")).toDF("dow", "d") - val df2 = Seq(("th", "2015-07-23 00:11:22"), ("xx", "2015-07-24 11:22:33")).toDF("dow", "t") - checkAnswer( - df1.select(next_day(col("d"), "MONDAY")), - Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27")))) - checkAnswer( - df2.select(next_day(col("t"), "th")), - Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) - } - - test("function to_date") { - val d1 = Date.valueOf("2015-07-22") - val d2 = Date.valueOf("2015-07-01") - val d3 = Date.valueOf("2014-12-31") - val t1 = Timestamp.valueOf("2015-07-22 10:00:00") - val t2 = Timestamp.valueOf("2014-12-31 23:59:59") - val t3 = Timestamp.valueOf("2014-12-31 23:59:59") - val s1 = "2015-07-22 10:00:00" - val s2 = "2014-12-31" - val s3 = "2014-31-12" - val df = Seq((d1, t1, s1), (d2, t2, s2), (d3, t3, s3)).toDF("d", "t", "s") - - checkAnswer( - df.select(to_date(col("t"))), - Seq( - Row(Date.valueOf("2015-07-22")), - Row(Date.valueOf("2014-12-31")), - Row(Date.valueOf("2014-12-31")))) - checkAnswer( - df.select(to_date(col("d"))), - Seq( - Row(Date.valueOf("2015-07-22")), - Row(Date.valueOf("2015-07-01")), - Row(Date.valueOf("2014-12-31")))) - checkAnswer( - df.select(to_date(col("s"))), - Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) - - checkAnswer( - df.selectExpr("to_date(t)"), - Seq( - Row(Date.valueOf("2015-07-22")), - Row(Date.valueOf("2014-12-31")), - Row(Date.valueOf("2014-12-31")))) - checkAnswer( - df.selectExpr("to_date(d)"), - Seq( - Row(Date.valueOf("2015-07-22")), - Row(Date.valueOf("2015-07-01")), - Row(Date.valueOf("2014-12-31")))) - checkAnswer( - df.selectExpr("to_date(s)"), - Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) - - // now with format - checkAnswer( - df.select(to_date(col("t"), "yyyy-MM-dd")), - Seq( - Row(Date.valueOf("2015-07-22")), - 
Row(Date.valueOf("2014-12-31")), - Row(Date.valueOf("2014-12-31")))) - checkAnswer( - df.select(to_date(col("d"), "yyyy-MM-dd")), - Seq( - Row(Date.valueOf("2015-07-22")), - Row(Date.valueOf("2015-07-01")), - Row(Date.valueOf("2014-12-31")))) - val confKey = SQLConf.LEGACY_TIME_PARSER_POLICY.key - withSQLConf(confKey -> "corrected") { - checkAnswer( - df.select(to_date(col("s"), "yyyy-MM-dd")), - Seq(Row(null), Row(Date.valueOf("2014-12-31")), Row(null))) - } - - // now switch format - checkAnswer( - df.select(to_date(col("s"), "yyyy-dd-MM")), - Seq(Row(null), Row(null), Row(Date.valueOf("2014-12-31")))) - - // invalid format - checkAnswer(df.select(to_date(col("s"), "yyyy-hh-MM")), Seq(Row(null), Row(null), Row(null))) - val e = intercept[SparkUpgradeException](df.select(to_date(col("s"), "yyyy-dd-aa")).collect()) - assert(e.getCause.isInstanceOf[IllegalArgumentException]) - assert(e.getMessage.contains("You may get a different result due to the upgrading to Spark")) - - // february - val x1 = "2016-02-29" - val x2 = "2017-02-29" - val df1 = Seq(x1, x2).toDF("x") - checkAnswer( - df1.select(to_date(col("x"))), - Row(Date.valueOf("2016-02-29")) :: Row(null) :: Nil) - } - - test("function trunc") { - val df = Seq( - (1, Timestamp.valueOf("2015-07-22 10:00:00")), - (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") - - checkAnswer( - df.select(trunc(col("t"), "YY")), - Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) - - checkAnswer( - df.selectExpr("trunc(t, 'Month')"), - Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) - } - - test("function date_trunc") { - val df = Seq( - (1, Timestamp.valueOf("2015-07-22 10:01:40.123456")), - (2, Timestamp.valueOf("2014-12-31 05:29:06.123456"))).toDF("i", "t") - - checkAnswer( - df.select(date_trunc("YY", col("t"))), - Seq( - Row(Timestamp.valueOf("2015-01-01 00:00:00")), - Row(Timestamp.valueOf("2014-01-01 00:00:00")))) - - checkAnswer( - df.selectExpr("date_trunc('MONTH', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-01 00:00:00")), - Row(Timestamp.valueOf("2014-12-01 00:00:00")))) - - checkAnswer( - df.selectExpr("date_trunc('DAY', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-22 00:00:00")), - Row(Timestamp.valueOf("2014-12-31 00:00:00")))) - - checkAnswer( - df.selectExpr("date_trunc('HOUR', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-22 10:00:00")), - Row(Timestamp.valueOf("2014-12-31 05:00:00")))) - - checkAnswer( - df.selectExpr("date_trunc('MINUTE', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-22 10:01:00")), - Row(Timestamp.valueOf("2014-12-31 05:29:00")))) - - checkAnswer( - df.selectExpr("date_trunc('SECOND', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-22 10:01:40")), - Row(Timestamp.valueOf("2014-12-31 05:29:06")))) - - checkAnswer( - df.selectExpr("date_trunc('WEEK', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-20 00:00:00")), - Row(Timestamp.valueOf("2014-12-29 00:00:00")))) - - checkAnswer( - df.selectExpr("date_trunc('QUARTER', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-01 00:00:00")), - Row(Timestamp.valueOf("2014-10-01 00:00:00")))) - - checkAnswer( - df.selectExpr("date_trunc('MILLISECOND', t)"), - Seq( - Row(Timestamp.valueOf("2015-07-22 10:01:40.123")), - Row(Timestamp.valueOf("2014-12-31 05:29:06.123")))) - } - - test("unsupported fmt fields for trunc/date_trunc results null") { - Seq("INVALID", "decade", "century", "millennium", "whatever", null).foreach { f => - checkAnswer( - Seq(Date.valueOf("2014-12-31")) - .toDF("dt") - .selectExpr(s"date_trunc('$f', dt)", 
"trunc(dt, '$f')"), - Row(null, null)) - } - } - - test("from_unixtime") { - Seq("corrected", "legacy").foreach { legacyParserPolicy => - withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> legacyParserPolicy) { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - val fmt3 = "yy-MM-dd HH-mm-ss" - val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - val df = - Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") - checkAnswer( - df.select(from_unixtime(col("a"))), - Seq( - Row(sdf1.format(new Timestamp(1000000))), - Row(sdf1.format(new Timestamp(-1000000))))) - checkAnswer( - df.select(from_unixtime(col("a"), fmt2)), - Seq( - Row(sdf2.format(new Timestamp(1000000))), - Row(sdf2.format(new Timestamp(-1000000))))) - checkAnswer( - df.select(from_unixtime(col("a"), fmt3)), - Seq( - Row(sdf3.format(new Timestamp(1000000))), - Row(sdf3.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr("from_unixtime(a)"), - Seq( - Row(sdf1.format(new Timestamp(1000000))), - Row(sdf1.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr(s"from_unixtime(a, '$fmt2')"), - Seq( - Row(sdf2.format(new Timestamp(1000000))), - Row(sdf2.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr(s"from_unixtime(a, '$fmt3')"), - Seq( - Row(sdf3.format(new Timestamp(1000000))), - Row(sdf3.format(new Timestamp(-1000000))))) - } - } - } - - test("datediff") { - val df = Seq( - ( - Date.valueOf("2015-07-24"), - Timestamp.valueOf("2015-07-24 01:00:00"), - "2015-07-23", - "2015-07-23 03:00:00"), - ( - Date.valueOf("2015-07-25"), - Timestamp.valueOf("2015-07-25 02:00:00"), - "2015-07-24", - "2015-07-24 04:00:00")).toDF("a", "b", "c", "d") - checkAnswer(df.select(datediff(col("a"), col("b"))), Seq(Row(0), Row(0))) - checkAnswer(df.select(datediff(col("a"), col("c"))), Seq(Row(1), Row(1))) - checkAnswer(df.select(datediff(col("d"), col("b"))), Seq(Row(-1), Row(-1))) - checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) - } - - test("to_timestamp with microseconds precision") { - withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { - val timestamp = "1970-01-01T00:00:00.123456Z" - val df = Seq(timestamp).toDF("t") - checkAnswer( - df.select(to_timestamp($"t", "yyyy-MM-dd'T'HH:mm:ss.SSSSSSX")), - Seq(Row(Instant.parse(timestamp)))) - } - } - - test("from_utc_timestamp with literal zone") { - val df = Seq( - (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), - (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00")).toDF("a", "b") - checkAnswer( - df.select(from_utc_timestamp(col("a"), LA.getId)), - Seq( - Row(Timestamp.valueOf("2015-07-23 17:00:00")), - Row(Timestamp.valueOf("2015-07-24 17:00:00")))) - checkAnswer( - df.select(from_utc_timestamp(col("b"), LA.getId)), - Seq( - Row(Timestamp.valueOf("2015-07-23 17:00:00")), - Row(Timestamp.valueOf("2015-07-24 17:00:00")))) - } - - test("from_utc_timestamp with column zone") { - val df = Seq( - (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", CEST.getId), - (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", LA.getId)) - .toDF("a", "b", "c") - checkAnswer( - df.select(from_utc_timestamp(col("a"), col("c"))), - Seq( - Row(Timestamp.valueOf("2015-07-24 02:00:00")), - Row(Timestamp.valueOf("2015-07-24 17:00:00")))) - checkAnswer( - df.select(from_utc_timestamp(col("b"), col("c"))), - Seq( - Row(Timestamp.valueOf("2015-07-24 02:00:00")), - 
Row(Timestamp.valueOf("2015-07-24 17:00:00")))) - } - - test("handling null field by date_part") { - val input = Seq(Date.valueOf("2019-09-20")).toDF("d") - Seq("date_part(null, d)", "date_part(null, date'2019-09-20')").foreach { expr => - val df = input.selectExpr(expr) - assert(df.schema.headOption.get.dataType == DoubleType) - checkAnswer(df, Row(null)) - } - } - - test("to_utc_timestamp with literal zone") { - val df = Seq( - (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), - (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00")).toDF("a", "b") - checkAnswer( - df.select(to_utc_timestamp(col("a"), LA.getId)), - Seq( - Row(Timestamp.valueOf("2015-07-24 07:00:00")), - Row(Timestamp.valueOf("2015-07-25 07:00:00")))) - checkAnswer( - df.select(to_utc_timestamp(col("b"), LA.getId)), - Seq( - Row(Timestamp.valueOf("2015-07-24 07:00:00")), - Row(Timestamp.valueOf("2015-07-25 07:00:00")))) - } - - test("to_utc_timestamp with column zone") { - val df = Seq( - (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", LA.getId), - (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", CEST.getId)) - .toDF("a", "b", "c") - checkAnswer( - df.select(to_utc_timestamp(col("a"), col("c"))), - Seq( - Row(Timestamp.valueOf("2015-07-24 07:00:00")), - Row(Timestamp.valueOf("2015-07-24 22:00:00")))) - checkAnswer( - df.select(to_utc_timestamp(col("b"), col("c"))), - Seq( - Row(Timestamp.valueOf("2015-07-24 07:00:00")), - Row(Timestamp.valueOf("2015-07-24 22:00:00")))) - } - - // SNOW-205533: disabled temporarily because test fails. Investigate pushdown bug -// test("SPARK-30668: use legacy timestamp parser in to_timestamp") { -// val confKey = SQLConf.LEGACY_TIME_PARSER_POLICY.key -// val df = Seq("2020-01-27T20:06:11.847-0800").toDF("ts") -// withSQLConf(confKey -> "legacy") { -// val expected = Timestamp.valueOf("2020-01-27 20:06:11.847") -// checkAnswer(df.select(to_timestamp(col("ts"), "yyyy-MM-dd'T'HH:mm:ss.SSSz")), Row(expected)) -// } -// withSQLConf(confKey -> "corrected") { -// checkAnswer(df.select(to_timestamp(col("ts"), "yyyy-MM-dd'T'HH:mm:ss.SSSz")), Row(null)) -// } -// //withSQLConf(confKey -> "exception") { -// // checkExceptionMessage(df.select(to_timestamp(col("ts"), "yyyy-MM-dd'T'HH:mm:ss.SSSz"))) -// // } -// } - - test("SPARK-30752: convert time zones on a daylight saving day") { - val systemTz = LA.getId - val sessionTz = "UTC" - val fromTz = "Asia/Hong_Kong" - val fromTs = "2019-11-03T12:00:00" // daylight saving date in America/Los_Angeles - val utsTs = "2019-11-03T04:00:00" - val defaultTz = TimeZone.getDefault - try { - TimeZone.setDefault(DateTimeUtils.getTimeZone(systemTz)) - withSQLConf( - SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true", - SQLConf.SESSION_LOCAL_TIMEZONE.key -> sessionTz) { - val expected = LocalDateTime - .parse(utsTs) - .atZone(DateTimeUtils.getZoneId(sessionTz)) - .toInstant - val df = Seq(fromTs).toDF("localTs") - checkAnswer(df.select(to_utc_timestamp(col("localTs"), fromTz)), Row(expected)) - } - } finally { - TimeZone.setDefault(defaultTz) - } - } - - // SNOW-205533: disabled temporarily because test fails. 
Investigate pushdown bug -// test("SPARK-30766: date_trunc of old timestamps to hours and days") { -// def checkTrunc(level: String, expected: String): Unit = { -// val df = Seq("0010-01-01 01:02:03.123456") -// .toDF() -// .select($"value".cast("timestamp").as("ts")) -// .select(date_trunc(level, $"ts").cast("string")) -// checkAnswer(df, Row(expected)) -// } -// -// checkTrunc("HOUR", "0010-01-01 01:00:00") -// checkTrunc("DAY", "0010-01-01 00:00:00") -// } - - // SNOW-205533: disabled temporarily because test fails. Investigate pushdown bug -// test("SPARK-30793: truncate timestamps before the epoch to seconds and minutes") { -// def checkTrunc(level: String, expected: String): Unit = { -// val df = Seq("1961-04-12 00:01:02.345") -// .toDF() -// .select($"value".cast("timestamp").as("ts")) -// .select(date_trunc(level, $"ts").cast("string")) -// checkAnswer(df, Row(expected)) -// } -// -// checkTrunc("SECOND", "1961-04-12 00:01:02") -// checkTrunc("MINUTE", "1961-04-12 00:01:00") -// } - - override protected def spark: SparkSession = getSnowflakeSession() - override protected def blackList: Seq[String] = Seq.empty - - private def secs(millis: Long): Long = TimeUnit.MILLISECONDS.toSeconds(millis) -} diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala index 4284b0a5..0fc6ca7c 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala @@ -28,12 +28,6 @@ object SnowflakeConnectorUtils { @transient lazy val log: Logger = LoggerFactory.getLogger(getClass.getName) - /** - * Check Spark version, if Spark version matches SUPPORT_SPARK_VERSION enable PushDown, - * otherwise disable it. - */ - val SUPPORT_SPARK_VERSION = "3.4" - // TODO: Improve error handling with retries, etc. @throws[SnowflakeConnectorException]
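For readers tracing the removed SUPPORT_SPARK_VERSION constant above: per its deleted scaladoc, pushdown was only enabled when the running Spark version matched the pinned "3.4". The sketch below illustrates that style of version gate; it is an illustration only, not the connector's actual implementation, and the object name VersionGateSketch and the helper isPushdownSupported are invented for the example.

import org.apache.spark.sql.SparkSession

object VersionGateSketch {
  // Stand-in for the removed constant; the connector previously pinned "3.4".
  private val SupportedSparkMajorMinor = "3.4"

  // True when the session's Spark version starts with the pinned major.minor,
  // i.e. the condition under which pushdown used to be enabled.
  def isPushdownSupported(spark: SparkSession): Boolean =
    spark.version.startsWith(SupportedSparkMajorMinor)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("pushdown-version-gate")
      .getOrCreate()
    // On Spark 3.5.x this returns false, so a pinned check like this would have
    // to be bumped in lockstep with every Spark upgrade.
    println(s"Spark ${spark.version}: gate says ${isPushdownSupported(spark)}")
    spark.stop()
  }
}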