Skip to content

Commit aa14aae

Browse files
authored
Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains (#13625)
* Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains * Spark 4.0: Support recursive delegate unwrapping to find ExtendedParser in parser chains
1 parent 0ce9b16 commit aa14aae

File tree

2 files changed

+273
-3
lines changed

2 files changed

+273
-3
lines changed
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.iceberg.spark;
20+
21+
import static org.assertj.core.api.Assertions.assertThat;
22+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
23+
import static org.mockito.Mockito.mock;
24+
import static org.mockito.Mockito.verify;
25+
import static org.mockito.Mockito.when;
26+
27+
import java.lang.reflect.Field;
28+
import java.util.Collections;
29+
import java.util.List;
30+
import org.apache.iceberg.NullOrder;
31+
import org.apache.iceberg.SortDirection;
32+
import org.apache.iceberg.expressions.Term;
33+
import org.apache.spark.sql.SparkSession;
34+
import org.apache.spark.sql.catalyst.parser.AbstractSqlParser;
35+
import org.apache.spark.sql.catalyst.parser.AstBuilder;
36+
import org.apache.spark.sql.catalyst.parser.ParserInterface;
37+
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser;
38+
import org.junit.jupiter.api.AfterAll;
39+
import org.junit.jupiter.api.AfterEach;
40+
import org.junit.jupiter.api.BeforeAll;
41+
import org.junit.jupiter.api.BeforeEach;
42+
import org.junit.jupiter.api.Test;
43+
44+
public class TestExtendedParser {
45+
46+
private static SparkSession spark;
47+
private static final String SQL_PARSER_FIELD = "sqlParser";
48+
private ParserInterface originalParser;
49+
50+
@BeforeAll
51+
public static void before() {
52+
spark = SparkSession.builder().master("local").appName("TestExtendedParser").getOrCreate();
53+
}
54+
55+
@AfterAll
56+
public static void after() {
57+
if (spark != null) {
58+
spark.stop();
59+
}
60+
}
61+
62+
@BeforeEach
63+
public void saveOriginalParser() throws Exception {
64+
Class<?> clazz = spark.sessionState().getClass();
65+
Field parserField = null;
66+
while (clazz != null && parserField == null) {
67+
try {
68+
parserField = clazz.getDeclaredField(SQL_PARSER_FIELD);
69+
} catch (NoSuchFieldException e) {
70+
clazz = clazz.getSuperclass();
71+
}
72+
}
73+
parserField.setAccessible(true);
74+
originalParser = (ParserInterface) parserField.get(spark.sessionState());
75+
}
76+
77+
@AfterEach
78+
public void restoreOriginalParser() throws Exception {
79+
setSessionStateParser(spark.sessionState(), originalParser);
80+
}
81+
82+
/**
83+
* Tests that the Iceberg extended SQL parser can correctly parse a sort order string and return
84+
* the expected RawOrderField.
85+
*
86+
* @throws Exception if reflection access fails
87+
*/
88+
@Test
89+
public void testParseSortOrderWithRealIcebergExtendedParser() throws Exception {
90+
ParserInterface origParser = null;
91+
Class<?> clazz = spark.sessionState().getClass();
92+
while (clazz != null && origParser == null) {
93+
try {
94+
Field parserField = clazz.getDeclaredField(SQL_PARSER_FIELD);
95+
parserField.setAccessible(true);
96+
origParser = (ParserInterface) parserField.get(spark.sessionState());
97+
} catch (NoSuchFieldException e) {
98+
clazz = clazz.getSuperclass();
99+
}
100+
}
101+
assertThat(origParser).isNotNull();
102+
103+
IcebergSparkSqlExtensionsParser icebergParser = new IcebergSparkSqlExtensionsParser(origParser);
104+
105+
setSessionStateParser(spark.sessionState(), icebergParser);
106+
107+
List<ExtendedParser.RawOrderField> fields =
108+
ExtendedParser.parseSortOrder(spark, "id ASC NULLS FIRST");
109+
110+
assertThat(fields).isNotEmpty();
111+
ExtendedParser.RawOrderField first = fields.get(0);
112+
assertThat(first.direction()).isEqualTo(SortDirection.ASC);
113+
assertThat(first.nullOrder()).isEqualTo(NullOrder.NULLS_FIRST);
114+
}
115+
116+
/**
117+
* Tests that parseSortOrder can find and use an ExtendedParser that is wrapped inside another
118+
* ParserInterface implementation.
119+
*
120+
* @throws Exception if reflection access fails
121+
*/
122+
@Test
123+
public void testParseSortOrderFindsNestedExtendedParser() throws Exception {
124+
ExtendedParser icebergParser = mock(ExtendedParser.class);
125+
126+
ExtendedParser.RawOrderField field =
127+
new ExtendedParser.RawOrderField(
128+
mock(Term.class), SortDirection.ASC, NullOrder.NULLS_FIRST);
129+
List<ExtendedParser.RawOrderField> expected = Collections.singletonList(field);
130+
131+
when(icebergParser.parseSortOrder("id ASC NULLS FIRST")).thenReturn(expected);
132+
133+
ParserInterface wrapper = new WrapperParser(icebergParser);
134+
135+
setSessionStateParser(spark.sessionState(), wrapper);
136+
137+
List<ExtendedParser.RawOrderField> result =
138+
ExtendedParser.parseSortOrder(spark, "id ASC NULLS FIRST");
139+
assertThat(result).isSameAs(expected);
140+
141+
verify(icebergParser).parseSortOrder("id ASC NULLS FIRST");
142+
}
143+
144+
/**
145+
* Tests that parseSortOrder throws an exception if no ExtendedParser instance can be found in the
146+
* parser chain.
147+
*
148+
* @throws Exception if reflection access fails
149+
*/
150+
@Test
151+
public void testParseSortOrderThrowsWhenNoExtendedParserFound() throws Exception {
152+
ParserInterface dummy = mock(ParserInterface.class);
153+
setSessionStateParser(spark.sessionState(), dummy);
154+
155+
assertThatThrownBy(() -> ExtendedParser.parseSortOrder(spark, "id ASC"))
156+
.isInstanceOf(IllegalStateException.class)
157+
.hasMessageContaining("Iceberg ExtendedParser");
158+
}
159+
160+
/**
161+
* Tests that parseSortOrder can find an ExtendedParser in a parent class field of the parser.
162+
*
163+
* @throws Exception if reflection access fails
164+
*/
165+
@Test
166+
public void testParseSortOrderFindsExtendedParserInParentClassField() throws Exception {
167+
ExtendedParser icebergParser = mock(ExtendedParser.class);
168+
ExtendedParser.RawOrderField field =
169+
new ExtendedParser.RawOrderField(
170+
mock(Term.class), SortDirection.ASC, NullOrder.NULLS_FIRST);
171+
List<ExtendedParser.RawOrderField> expected = Collections.singletonList(field);
172+
when(icebergParser.parseSortOrder("id ASC NULLS FIRST")).thenReturn(expected);
173+
ParserInterface parser = new GrandChildParser(icebergParser);
174+
setSessionStateParser(spark.sessionState(), parser);
175+
176+
List<ExtendedParser.RawOrderField> result =
177+
ExtendedParser.parseSortOrder(spark, "id ASC NULLS FIRST");
178+
assertThat(result).isSameAs(expected);
179+
verify(icebergParser).parseSortOrder("id ASC NULLS FIRST");
180+
}
181+
182+
private static void setSessionStateParser(Object sessionState, ParserInterface parser)
183+
throws Exception {
184+
Class<?> clazz = sessionState.getClass();
185+
Field targetField = null;
186+
while (clazz != null && targetField == null) {
187+
try {
188+
targetField = clazz.getDeclaredField(SQL_PARSER_FIELD);
189+
} catch (NoSuchFieldException e) {
190+
clazz = clazz.getSuperclass();
191+
}
192+
}
193+
if (targetField == null) {
194+
throw new IllegalStateException(
195+
"No suitable sqlParser field found in sessionState class hierarchy!");
196+
}
197+
targetField.setAccessible(true);
198+
targetField.set(sessionState, parser);
199+
}
200+
201+
private static class WrapperParser extends AbstractSqlParser {
202+
private final ParserInterface delegate;
203+
private String name;
204+
205+
WrapperParser(ParserInterface delegate) {
206+
this.delegate = delegate;
207+
this.name = "delegate";
208+
}
209+
210+
public ParserInterface getDelegate() {
211+
return delegate;
212+
}
213+
214+
@Override
215+
public AstBuilder astBuilder() {
216+
return null;
217+
}
218+
}
219+
220+
private static class ChildParser extends WrapperParser {
221+
ChildParser(ParserInterface parent) {
222+
super(parent);
223+
}
224+
}
225+
226+
private static class GrandChildParser extends ChildParser {
227+
GrandChildParser(ParserInterface parent) {
228+
super(parent);
229+
}
230+
}
231+
}

spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/ExtendedParser.java

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
*/
1919
package org.apache.iceberg.spark;
2020

21+
import java.lang.reflect.Field;
2122
import java.util.List;
2223
import org.apache.iceberg.NullOrder;
2324
import org.apache.iceberg.SortDirection;
@@ -52,10 +53,10 @@ public NullOrder nullOrder() {
5253
}
5354

5455
static List<RawOrderField> parseSortOrder(SparkSession spark, String orderString) {
55-
if (spark.sessionState().sqlParser() instanceof ExtendedParser) {
56-
ExtendedParser parser = (ExtendedParser) spark.sessionState().sqlParser();
56+
ExtendedParser extParser = findParser(spark.sessionState().sqlParser(), ExtendedParser.class);
57+
if (extParser != null) {
5758
try {
58-
return parser.parseSortOrder(orderString);
59+
return extParser.parseSortOrder(orderString);
5960
} catch (AnalysisException e) {
6061
throw new IllegalArgumentException(
6162
String.format("Unable to parse sortOrder: %s", orderString), e);
@@ -66,5 +67,43 @@ static List<RawOrderField> parseSortOrder(SparkSession spark, String orderString
6667
}
6768
}
6869

70+
private static <T> T findParser(ParserInterface parser, Class<T> clazz) {
71+
ParserInterface current = parser;
72+
while (current != null) {
73+
if (clazz.isInstance(current)) {
74+
return clazz.cast(current);
75+
}
76+
77+
ParserInterface next = getNextDelegateParser(current);
78+
if (next == null) {
79+
break;
80+
}
81+
82+
current = next;
83+
}
84+
85+
return null;
86+
}
87+
88+
private static ParserInterface getNextDelegateParser(ParserInterface parser) {
89+
try {
90+
Class<?> clazz = parser.getClass();
91+
while (clazz != null) {
92+
for (Field field : clazz.getDeclaredFields()) {
93+
field.setAccessible(true);
94+
Object value = field.get(parser);
95+
if (value instanceof ParserInterface && value != parser) {
96+
return (ParserInterface) value;
97+
}
98+
}
99+
clazz = clazz.getSuperclass();
100+
}
101+
} catch (Exception e) {
102+
// ignore
103+
}
104+
105+
return null;
106+
}
107+
69108
/**
 * Parses a sort order string into raw order fields.
 *
 * @param orderString the sort order text, e.g. {@code "id ASC NULLS FIRST"}
 * @return the parsed order fields
 * @throws AnalysisException declared for implementations to signal a parse failure; the static
 *     {@code parseSortOrder(SparkSession, String)} wrapper rethrows it as an
 *     IllegalArgumentException
 */
List<RawOrderField> parseSortOrder(String orderString) throws AnalysisException;
70109
}

0 commit comments

Comments
 (0)