From 79e92c3b41164475ae63aae3b357ce903685227c Mon Sep 17 00:00:00 2001 From: geooo109 Date: Thu, 19 Dec 2024 18:35:12 +0200 Subject: [PATCH] PR Feedback 1 --- sqlglot/dialects/tsql.py | 11 ++++++----- tests/dialects/test_tsql.py | 20 ++++++++++++++------ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b17c667b4a..7aa7a0e22b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -370,14 +370,14 @@ def _timestrtotime_sql(self: TSQL.Generator, expression: exp.TimeStrToTime): return sql -def _parse_datetrunc(args: t.List) -> exp.TimestampTrunc: +def _build_datetrunc(args: t.List) -> exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) - if isinstance(this, exp.Expression) and this.is_string: - this = exp.TimeStrToTime(this=this) + if this and this.is_string: + this = exp.cast(this, exp.DataType.Type.DATETIME2) - return exp.TimestampTrunc(unit=unit, this=this) + return exp.TimestampTrunc(this=this, unit=unit) class TSQL(Dialect): @@ -580,7 +580,7 @@ class Parser(parser.Parser): "SUSER_SNAME": exp.CurrentUser.from_arg_list, "SYSTEM_USER": exp.CurrentUser.from_arg_list, "TIMEFROMPARTS": _build_timefromparts, - "DATETRUNC": _parse_datetrunc, + "DATETRUNC": _build_datetrunc, } JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"} @@ -947,6 +947,7 @@ class Generator(generator.Generator): exp.Trim: trim_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), + exp.TimestampTrunc: lambda self, e: self.func("DATETRUNC", e.unit, e.this), } TRANSFORMS.pop(exp.ReturnsProperty) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b0ead2f52a..4c61780d76 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -2091,18 +2091,26 @@ def test_next_value_for(self): }, ) + # string literals in the DATETRUNC are casted as DATETIME2 def test_datetrunc(self): self.validate_all( - "SELECT DATETRUNC(month, '2021-12-08 11:30:15.1234567')", + "SELECT DATETRUNC(month, 'foo')", write={ - "duckdb": "SELECT DATE_TRUNC('MONTH', CAST('2021-12-08 11:30:15.1234567' AS TIMESTAMP))" + "duckdb": "SELECT DATE_TRUNC('MONTH', CAST('foo' AS TIMESTAMP))", + "tsql": "SELECT DATETRUNC(MONTH, CAST('foo' AS DATETIME2))", }, ) self.validate_all( - "SELECT DATETRUNC(year, DATEFROMPARTS(2010, 12, 31))", - write={"duckdb": "SELECT DATE_TRUNC('YEAR', MAKE_DATE(2010, 12, 31))"}, + "SELECT DATETRUNC(month, foo)", + write={ + "duckdb": "SELECT DATE_TRUNC('MONTH', foo)", + "tsql": "SELECT DATETRUNC(MONTH, foo)", + }, ) self.validate_all( - "SELECT DATETRUNC(year, CAST('2021-12-08' AS date))", - write={"duckdb": "SELECT DATE_TRUNC('YEAR', CAST('2021-12-08' AS DATE))"}, + "SELECT DATETRUNC(year, CAST('foo1' AS date))", + write={ + "duckdb": "SELECT DATE_TRUNC('YEAR', CAST('foo1' AS DATE))", + "tsql": "SELECT DATETRUNC(YEAR, CAST('foo1' AS DATE))", + }, )