[SPARK-46568][PYTHON] Make Python data source options a case-insensitive dictionary

### What changes were proposed in this pull request?

This PR updates the `options` field to use a case-insensitive dictionary, keeping its behavior consistent with the Scala side (which uses `CaseInsensitiveStringMap`). Currently, `options` is stored in a plain Python dictionary, which can be confusing to users. For instance:
```python
class MyDataSource(DataSource):
    def __init__(self, options):
        self.api_key = options.get("API_KEY") # <- This is None

spark.read.format(..).option("API_KEY", my_key).load(...)
```
Here, `options` will not contain "API_KEY" because all option keys are converted to lowercase on the Scala side.
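
After this change, `options` behaves case-insensitively, so the lookup above succeeds no matter how the option key was spelled in `.option(...)`. A minimal sketch of the intended behavior (the format name here is illustrative):
```python
class MyDataSource(DataSource):
    def __init__(self, options):
        # "API_KEY", "api_key", and "Api_Key" now all resolve to the same entry.
        self.api_key = options.get("API_KEY")  # <- No longer None

spark.read.format("my_source").option("API_KEY", my_key).load(...)
```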

### Why are the changes needed?

To improve usability.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#44564 from allisonwang-db/spark-46568-ds-options.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
allisonwang-db authored and dongjoon-hyun committed Jan 5, 2024
1 parent 20b6a32 commit a98c885
Showing 5 changed files with 115 additions and 12 deletions.
51 changes: 45 additions & 6 deletions python/pyspark/sql/datasource.py
@@ -15,18 +15,25 @@
# limitations under the License.
#
from abc import ABC, abstractmethod
from typing import final, Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING
from collections import UserDict
from typing import Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING

from pyspark.sql import Row
from pyspark.sql.types import StructType
from pyspark.errors import PySparkNotImplementedError

if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType
from pyspark.sql.session import SparkSession


__all__ = ["DataSource", "DataSourceReader", "DataSourceWriter", "DataSourceRegistration"]
__all__ = [
    "DataSource",
    "DataSourceReader",
    "DataSourceWriter",
    "DataSourceRegistration",
    "InputPartition",
    "WriterCommitMessage",
]


class DataSource(ABC):
@@ -45,15 +45,14 @@ class DataSource(ABC):
.. versionadded: 4.0.0
"""

    @final
    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
    def __init__(self, options: Dict[str, str]) -> None:
        """
        Initializes the data source with user-provided options.

        Parameters
        ----------
        options : dict
            A dictionary representing the options for this data source.
            A case-insensitive dictionary representing the options for this data source.

        Notes
        -----
@@ -403,3 +409,36 @@ def register(
        assert sc._jvm is not None
        ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
        self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)


class CaseInsensitiveDict(UserDict):
    """
    A case-insensitive map of string keys to values.

    This is used by Python data source options to ensure consistent case insensitivity.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.update(*args, **kwargs)

    def __setitem__(self, key: str, value: Any) -> None:
        super().__setitem__(key.lower(), value)

    def __getitem__(self, key: str) -> Any:
        return super().__getitem__(key.lower())

    def __delitem__(self, key: str) -> None:
        super().__delitem__(key.lower())

    def __contains__(self, key: object) -> bool:
        if isinstance(key, str):
            return super().__contains__(key.lower())
        return False

    def update(self, *args: Any, **kwargs: Any) -> None:
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def copy(self) -> "CaseInsensitiveDict":
        return type(self)(self)
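
Note that `CaseInsensitiveDict` extends `UserDict` rather than `dict`, so construction and `update` both route through the overridden `__setitem__`, and every key is normalized to lowercase on insertion. A minimal usage sketch (assuming the class is imported from `pyspark.sql.datasource`, where the diff above adds it):

```python
from pyspark.sql.datasource import CaseInsensitiveDict

opts = CaseInsensitiveDict({"API_KEY": "secret", "Path": "/tmp/data"})
assert opts["api_key"] == "secret"         # lookups ignore case
assert "PATH" in opts and "path" in opts   # so do membership checks
assert list(opts) == ["api_key", "path"]   # keys are stored lowercased
```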
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -26,6 +26,7 @@
    InputPartition,
    DataSourceWriter,
    WriterCommitMessage,
    CaseInsensitiveDict,
)
from pyspark.sql.types import Row, StructType
from pyspark.testing import assertDataFrameEqual
@@ -346,6 +347,26 @@ def test_custom_json_data_source_abort(self):
            text = file.read()
            assert text == "failed"

    def test_case_insensitive_dict(self):
        d = CaseInsensitiveDict({"foo": 1, "Bar": 2})
        self.assertEqual(d["foo"], d["FOO"])
        self.assertEqual(d["bar"], d["BAR"])
        self.assertTrue("baR" in d)
        d["BAR"] = 3
        self.assertEqual(d["BAR"], 3)
        # Test update
        d.update({"BaZ": 3})
        self.assertEqual(d["BAZ"], 3)
        d.update({"FOO": 4})
        self.assertEqual(d["foo"], 4)
        # Test delete
        del d["FoO"]
        self.assertFalse("FOO" in d)
        # Test copy
        d2 = d.copy()
        self.assertEqual(d2["BaR"], 3)
        self.assertEqual(d2["baz"], 3)


class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
    ...
6 changes: 3 additions & 3 deletions python/pyspark/sql/worker/create_data_source.py
@@ -29,7 +29,7 @@
    write_with_length,
    SpecialLengths,
)
from pyspark.sql.datasource import DataSource
from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
from pyspark.sql.types import _parse_datatype_json_string, StructType
from pyspark.util import handle_worker_exception
from pyspark.worker_util import (
@@ -120,7 +120,7 @@ def main(infile: IO, outfile: IO) -> None:
)

    # Receive the options.
    options = dict()
    options = CaseInsensitiveDict()
    num_options = read_int(infile)
    for _ in range(num_options):
        key = utf8_deserializer.loads(infile)
@@ -129,7 +129,7 @@ def main(infile: IO, outfile: IO) -> None:

    # Instantiate a data source.
    try:
        data_source = data_source_cls(options=options)
        data_source = data_source_cls(options=options)  # type: ignore
    except Exception as e:
        raise PySparkRuntimeError(
            error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
6 changes: 3 additions & 3 deletions python/pyspark/sql/worker/write_into_data_source.py
@@ -30,7 +30,7 @@
    SpecialLengths,
)
from pyspark.sql import Row
from pyspark.sql.datasource import DataSource, WriterCommitMessage
from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict
from pyspark.sql.types import (
    _parse_datatype_json_string,
    StructType,
@@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None:
    return_col_name = return_type[0].name

    # Receive the options.
    options = dict()
    options = CaseInsensitiveDict()
    num_options = read_int(infile)
    for _ in range(num_options):
        key = utf8_deserializer.loads(infile)
@@ -153,7 +153,7 @@ def main(infile: IO, outfile: IO) -> None:
    overwrite = read_bool(infile)

    # Instantiate a data source.
    data_source = data_source_cls(options=options)
    data_source = data_source_cls(options=options)  # type: ignore

    # Instantiate the data source writer.
    writer = data_source.writer(schema, overwrite)
43 changes: 43 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -790,4 +790,47 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
}
}

test("SPARK-46568: case insensitive options") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import (
| DataSource, DataSourceReader, DataSourceWriter, WriterCommitMessage)
|class SimpleDataSourceReader(DataSourceReader):
| def __init__(self, options):
| self.options = options
|
| def read(self, partition):
| foo = self.options.get("Foo")
| bar = self.options.get("BAR")
| baz = "BaZ" in self.options
| yield (foo, bar, baz)
|
|class SimpleDataSourceWriter(DataSourceWriter):
| def __init__(self, options):
| self.options = options
|
| def write(self, row):
| if "FOO" not in self.options or "BAR" not in self.options:
| raise Exception("FOO or BAR not found")
| return WriterCommitMessage()
|
|class $dataSourceName(DataSource):
| def schema(self) -> str:
| return "a string, b string, c string"
|
| def reader(self, schema):
| return SimpleDataSourceReader(self.options)
|
| def writer(self, schema, overwrite):
| return SimpleDataSourceWriter(self.options)
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
spark.dataSource.registerPython(dataSourceName, dataSource)
val df = spark.read.option("foo", 1).option("bar", 2).option("BAZ", 3)
.format(dataSourceName).load()
checkAnswer(df, Row("1", "2", "true"))
df.write.option("foo", 1).option("bar", 2).format(dataSourceName).mode("append").save()
}
}
