Merge branch 'main' into renaud.hartert/stream-reset
renaudhartert-db authored Nov 14, 2024
2 parents 6d2c183 + ee6e70a commit 7ecc27b
Showing 7 changed files with 70 additions and 55 deletions.
19 changes: 10 additions & 9 deletions databricks/sdk/config.py
@@ -92,15 +92,16 @@ class Config:
     max_connections_per_pool: int = ConfigAttribute()
     databricks_environment: Optional[DatabricksEnvironment] = None

-    def __init__(self,
-                 *,
-                 # Deprecated. Use credentials_strategy instead.
-                 credentials_provider: Optional[CredentialsStrategy] = None,
-                 credentials_strategy: Optional[CredentialsStrategy] = None,
-                 product=None,
-                 product_version=None,
-                 clock: Optional[Clock] = None,
-                 **kwargs):
+    def __init__(
+            self,
+            *,
+            # Deprecated. Use credentials_strategy instead.
+            credentials_provider: Optional[CredentialsStrategy] = None,
+            credentials_strategy: Optional[CredentialsStrategy] = None,
+            product=None,
+            product_version=None,
+            clock: Optional[Clock] = None,
+            **kwargs):
         self._header_factory = None
         self._inner = {}
         self._user_agent_other_info = []
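A minimal usage sketch (not part of this diff) of the signature above: credentials_provider is deprecated in favor of credentials_strategy, and remaining attributes such as host and token arrive through **kwargs. All values below are placeholders.

from databricks.sdk.config import Config

# Sketch with placeholder values; host and token are ConfigAttributes
# populated via **kwargs rather than named parameters.
cfg = Config(host='https://example.cloud.databricks.com',
             token='dapi-example',
             product='my-app',
             product_version='0.0.1')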
11 changes: 6 additions & 5 deletions databricks/sdk/credentials_provider.py
@@ -304,11 +304,12 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
     # detect Azure AD Tenant ID if it's not specified directly
     token_endpoint = cfg.oidc_endpoints.token_endpoint
     cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
-    inner = ClientCredentials(client_id=cfg.azure_client_id,
-                              client_secret="",  # we have no (rotatable) secrets in OIDC flow
-                              token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
-                              endpoint_params=params,
-                              use_params=True)
+    inner = ClientCredentials(
+        client_id=cfg.azure_client_id,
+        client_secret="",  # we have no (rotatable) secrets in OIDC flow
+        token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
+        endpoint_params=params,
+        use_params=True)

     def refreshed_headers() -> Dict[str, str]:
         token = inner.token()
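For context, a sketch of the header-factory pattern this hunk reformats: a token source such as ClientCredentials is wrapped in a closure that renders a fresh Authorization header on each call. The token shape (token_type, access_token) is inferred from the inner.token() usage above, not confirmed from this diff.

from typing import Callable, Dict

def make_refreshed_headers(inner) -> Callable[[], Dict[str, str]]:
    # `inner` is assumed to expose token() returning an object with
    # token_type and access_token attributes, as ClientCredentials does.
    def refreshed_headers() -> Dict[str, str]:
        token = inner.token()
        return {'Authorization': f'{token.token_type} {token.access_token}'}
    return refreshed_headers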
19 changes: 10 additions & 9 deletions tests/integration/test_auth.py
@@ -133,15 +133,16 @@ def _test_runtime_auth_from_jobs_inner(w, env_or_skip, random, dbr_versions, lib

     tasks = []
     for v in dbr_versions:
-        t = Task(task_key=f'test_{v.key.replace(".", "_")}',
-                 notebook_task=NotebookTask(notebook_path=notebook_path),
-                 new_cluster=ClusterSpec(
-                     spark_version=v.key,
-                     num_workers=1,
-                     instance_pool_id=instance_pool_id,
-                     # GCP uses "custom" data security mode by default, which does not support UC.
-                     data_security_mode=DataSecurityMode.SINGLE_USER),
-                 libraries=[library])
+        t = Task(
+            task_key=f'test_{v.key.replace(".", "_")}',
+            notebook_task=NotebookTask(notebook_path=notebook_path),
+            new_cluster=ClusterSpec(
+                spark_version=v.key,
+                num_workers=1,
+                instance_pool_id=instance_pool_id,
+                # GCP uses "custom" data security mode by default, which does not support UC.
+                data_security_mode=DataSecurityMode.SINGLE_USER),
+            libraries=[library])
         tasks.append(t)

     waiter = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks)
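A worked example of the task_key derivation above, using a hypothetical Spark version key: the replace call swaps dots for underscores so the version string can serve as a task key.

# '15.4.x-scala2.12' is an illustrative value, not taken from this diff.
key = f"test_{'15.4.x-scala2.12'.replace('.', '_')}"
assert key == 'test_15_4_x-scala2_12'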
25 changes: 13 additions & 12 deletions tests/integration/test_jobs.py
@@ -17,18 +17,19 @@ def test_submitting_jobs(w, random, env_or_skip):
     with w.dbfs.open(py_on_dbfs, write=True, overwrite=True) as f:
         f.write(b'import time; time.sleep(10); print("Hello, World!")')

-    waiter = w.jobs.submit(run_name=f'py-sdk-{random(8)}',
-                           tasks=[
-                               jobs.SubmitTask(
-                                   task_key='pi',
-                                   new_cluster=compute.ClusterSpec(
-                                       spark_version=w.clusters.select_spark_version(long_term_support=True),
-                                       # node_type_id=w.clusters.select_node_type(local_disk=True),
-                                       instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
-                                       num_workers=1),
-                                   spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
-                               )
-                           ])
+    waiter = w.jobs.submit(
+        run_name=f'py-sdk-{random(8)}',
+        tasks=[
+            jobs.SubmitTask(
+                task_key='pi',
+                new_cluster=compute.ClusterSpec(
+                    spark_version=w.clusters.select_spark_version(long_term_support=True),
+                    # node_type_id=w.clusters.select_node_type(local_disk=True),
+                    instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
+                    num_workers=1),
+                spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
+            )
+        ])

     logging.info(f'starting to poll: {waiter.run_id}')

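The waiter returned by jobs.submit exposes run_id immediately, as the log line above shows. A hedged sketch of the follow-up step, continuing from the test above (the timeout value here is an assumption, not taken from this diff):

import datetime

# Block until the submitted run reaches a terminal state.
run = waiter.result(timeout=datetime.timedelta(minutes=15))
logging.info(f'run finished: {run.run_page_url}')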
12 changes: 7 additions & 5 deletions tests/test_base_client.py
@@ -281,11 +281,13 @@ def inner(h: BaseHTTPRequestHandler):
     assert len(requests) == 2


-@pytest.mark.parametrize('chunk_size,expected_chunks,data_size',
-                         [(5, 20, 100),  # 100 / 5 bytes per chunk = 20 chunks
-                          (10, 10, 100),  # 100 / 10 bytes per chunk = 10 chunks
-                          (200, 1, 100),  # 100 / 200 bytes per chunk = 1 chunk
-                          ])
+@pytest.mark.parametrize(
+    'chunk_size,expected_chunks,data_size',
+    [
+        (5, 20, 100),  # 100 / 5 bytes per chunk = 20 chunks
+        (10, 10, 100),  # 100 / 10 bytes per chunk = 10 chunks
+        (200, 1, 100),  # 100 / 200 bytes per chunk = 1 chunk
+    ])
 def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
     rng = random.Random(42)
     test_data = bytes(rng.getrandbits(8) for _ in range(data_size))
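The three parametrized cases above are plain ceiling division, since the final chunk may be shorter than chunk_size; a quick standalone check:

import math

# A data_size-byte payload read chunk_size bytes at a time yields
# ceil(data_size / chunk_size) chunks.
for chunk_size, expected_chunks, data_size in [(5, 20, 100), (10, 10, 100), (200, 1, 100)]:
    assert math.ceil(data_size / chunk_size) == expected_chunks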
22 changes: 14 additions & 8 deletions tests/test_core.py
@@ -370,14 +370,20 @@ def inner(h: BaseHTTPRequestHandler):
     assert {'Authorization': 'Taker this-is-it'} == headers


-@pytest.mark.parametrize(['azure_environment', 'expected'],
-                         [('PUBLIC', ENVIRONMENTS['PUBLIC']), ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
-                          ('CHINA', ENVIRONMENTS['CHINA']), ('public', ENVIRONMENTS['PUBLIC']),
-                          ('usgovernment', ENVIRONMENTS['USGOVERNMENT']), ('china', ENVIRONMENTS['CHINA']),
-                          # Kept for historical compatibility
-                          ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
-                          ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
-                          ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ])
+@pytest.mark.parametrize(
+    ['azure_environment', 'expected'],
+    [
+        ('PUBLIC', ENVIRONMENTS['PUBLIC']),
+        ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']),
+        ('CHINA', ENVIRONMENTS['CHINA']),
+        ('public', ENVIRONMENTS['PUBLIC']),
+        ('usgovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('china', ENVIRONMENTS['CHINA']),
+        # Kept for historical compatibility
+        ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']),
+        ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']),
+        ('AzureChinaCloud', ENVIRONMENTS['CHINA']),
+    ])
 def test_azure_environment(azure_environment, expected):
     c = Config(credentials_strategy=noop_credentials,
                azure_workspace_resource_id='...',
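A hedged sketch (not the SDK's actual implementation) of the normalization these nine cases imply: legacy Azure*Cloud aliases map first, and everything else is upper-cased before the ENVIRONMENTS lookup.

# Hypothetical helper; the names here are chosen for illustration.
LEGACY_ALIASES = {
    'AzurePublicCloud': 'PUBLIC',
    'AzureUSGovernment': 'USGOVERNMENT',
    'AzureChinaCloud': 'CHINA',
}

def normalize_azure_environment(name: str) -> str:
    # Aliases first, then case-folding; satisfies every case in the table above.
    return LEGACY_ALIASES.get(name, name.upper())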
17 changes: 10 additions & 7 deletions tests/test_model_serving_auth.py
@@ -47,13 +47,16 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
     assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'


-@pytest.mark.parametrize("env_values, oauth_file_name", [
-    ([], "invalid_file_name"),  # Not in Model Serving and Invalid File Name
-    ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"),  # In Model Serving and Invalid File Name
-    ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
-      ], "invalid_file_name"),  # In Model Serving and Invalid File Name
-    ([], "tests/testdata/model-serving-test-token")  # Not in Model Serving and Valid File Name
-])
+@pytest.mark.parametrize(
+    "env_values, oauth_file_name",
+    [
+        ([], "invalid_file_name"),  # Not in Model Serving and Invalid File Name
+        ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"),  # In Model Serving and Invalid File Name
+        ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
+          ], "invalid_file_name"),  # In Model Serving and Invalid File Name
+        ([], "tests/testdata/model-serving-test-token")  # Not in Model Serving and Valid File Name
+    ])
 @raises(default_auth_base_error_message)
 def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
     # Guarantee that the tests defaults to env variables rather than config file.
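A hedged sketch of the environment detection these cases exercise; the variable names come from the parametrization above, but the predicate itself is an assumption about the SDK's logic.

import os

def in_model_serving_env() -> bool:
    # Either variable set to 'true' signals the model-serving environment.
    return (os.environ.get('IS_IN_DB_MODEL_SERVING_ENV') == 'true'
            or os.environ.get('IS_IN_DATABRICKS_MODEL_SERVING_ENV') == 'true')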
