Skip to content

Commit

Permalink
Revert to use the json payload instead of the SecurityScheme class
Browse files Browse the repository at this point in the history
  • Loading branch information
decko committed Aug 20, 2024
1 parent c6d09cb commit af23c4e
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 70 deletions.
82 changes: 54 additions & 28 deletions pulp-glue/pulp_glue/common/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ def __init__(self, security_scheme: SecurityScheme):

if self.security_type == "oauth2":
self.flows: OAuth2Flows = self.security_scheme["flows"]
client_credentials: t.Optional[ClientCredentials] = self.flows.get("clientCredentials")
client_credentials: t.Optional[ClientCredentials] = self.flows.get(
"clientCredentials")
if client_credentials:
self.flow_type: t.Optional[str] = "clientCredentials"
self.token_url: str = client_credentials["tokenUrl"]
self.scopes: OAuth2FlowsScopes = list(client_credentials.get("scopes").keys())
self.scopes: OAuth2FlowsScopes = list(
client_credentials.get("scopes").keys())

if self.security_type == "http":
self.scheme = self.security_scheme["scheme"]
Expand All @@ -81,7 +83,9 @@ def basic_auth(self) -> t.Optional[t.Union[t.Tuple[str, str], requests.auth.Auth
"""Implement this to provide means of http basic auth."""
return None

def oauth2_client_credentials_auth(self, flow: t.Any) -> t.Optional[t.Union[t.Tuple[str, str], requests.auth.AuthBase]]:
def oauth2_client_credentials_auth(
self, flow: t.Any
) -> t.Optional[t.Union[t.Tuple[str, str], requests.auth.AuthBase]]:
"""Implement this to provide other authentication methods."""
return None

Expand All @@ -100,10 +104,11 @@ def __call__(
authorized_schemes_types.add(security_schemes[name]["type"])

if "oauth2" in authorized_schemes_types:
oauth_flow = OpenAPISecurityScheme(
[flow for flow in authorized_schemes if flow["type"] == "oauth2"][0]
)
if oauth_flow.flow_type == "clientCredentials":
oauth_flow = [
flow for flow in authorized_schemes
if flow["type"] == "oauth2"
][0]
if "clientCredentials" in oauth_flow.get("flows"):
result = self.oauth2_client_credentials_auth(oauth_flow)
if result:
return result
Expand Down Expand Up @@ -165,7 +170,8 @@ def __init__(
if not validate_certs:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

self.debug_callback: t.Callable[[int, str], t.Any] = debug_callback or (lambda i, x: None)
self.debug_callback: t.Callable[[
int, str], t.Any] = debug_callback or (lambda i, x: None)
self.base_url: str = base_url
self.doc_path: str = doc_path
self.safe_calls_only: bool = safe_calls_only
Expand Down Expand Up @@ -211,7 +217,8 @@ def load_api(self, refresh_cache: bool = False) -> None:
apidoc_cache: str = os.path.join(
os.path.expanduser(xdg_cache_home),
"squeezer",
(self.base_url + "_" + self.doc_path).replace(":", "_").replace("/", "_") + "api.json",
(self.base_url + "_" + self.doc_path).replace(":",
"_").replace("/", "_") + "api.json",
)
try:
if refresh_cache:
Expand Down Expand Up @@ -243,7 +250,8 @@ def _parse_api(self, data: bytes) -> None:

def _download_api(self) -> bytes:
try:
response: requests.Response = self._session.get(urljoin(self.base_url, self.doc_path))
response: requests.Response = self._session.get(
urljoin(self.base_url, self.doc_path))
except requests.RequestException as e:
raise OpenAPIError(str(e))
response.raise_for_status()
Expand Down Expand Up @@ -282,7 +290,8 @@ def param_spec(
}
)
if required:
param_spec = {k: v for k, v in param_spec.items() if v.get("required", False)}
param_spec = {k: v for k, v in param_spec.items()
if v.get("required", False)}
return param_spec

def extract_params(
Expand Down Expand Up @@ -370,7 +379,8 @@ def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any:
break
else:
raise OpenAPIValidationError(
_("No schema in anyOf validated for {name}.").format(name=name)
_("No schema in anyOf validated for {name}.").format(
name=name)
)
elif oneOf:
old_value = value
Expand All @@ -380,12 +390,14 @@ def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any:
value = self.validate_schema(sub_schema, name, old_value)
if found_valid:
raise OpenAPIValidationError(
_("Multiple schemas in oneOf validated for {name}.").format(name=name)
_("Multiple schemas in oneOf validated for {name}.").format(
name=name)
)
found_valid = True
if not found_valid:
raise OpenAPIValidationError(
_("No schema in oneOf validated for {name}.").format(name=name)
_("No schema in oneOf validated for {name}.").format(
name=name)
)
elif not_schema:
try:
Expand All @@ -394,7 +406,8 @@ def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any:
pass
else:
raise OpenAPIValidationError(
_("Forbidden schema for {name} validated.").format(name=name)
_("Forbidden schema for {name} validated.").format(
name=name)
)
elif schema_type is None:
# Schema type is not specified.
Expand Down Expand Up @@ -428,7 +441,8 @@ def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any:
# TODO: Add more types here.
else:
raise OpenAPIError(
_("Type `{schema_type}` is not implemented yet.").format(schema_type=schema_type)
_("Type `{schema_type}` is not implemented yet.").format(
schema_type=schema_type)
)
return value

Expand All @@ -442,7 +456,8 @@ def validate_object(self, schema: t.Any, name: str, value: t.Any) -> t.Dict[str,
if properties or additional_properties is not None:
value = value.copy()
for property_name, property_value in value.items():
property_schema = properties.get(property_name, additional_properties)
property_schema = properties.get(
property_name, additional_properties)
if not property_schema:
raise OpenAPIValidationError(
_("Unexpected property '{property_name}' for '{name}' provided.").format(
Expand All @@ -464,7 +479,8 @@ def validate_object(self, schema: t.Any, name: str, value: t.Any) -> t.Dict[str,

def validate_array(self, schema: t.Any, name: str, value: t.Any) -> t.List[t.Any]:
if not isinstance(value, t.List):
raise OpenAPIValidationError(_("'{name}' is expected to be a list.").format(name=name))
raise OpenAPIValidationError(
_("'{name}' is expected to be a list.").format(name=name))
item_schema = schema["items"]
return [self.validate_schema(item_schema, name, item) for item in value]

Expand Down Expand Up @@ -549,7 +565,8 @@ def render_request_body(
request_body_spec = method_spec["requestBody"]
except KeyError:
if body is not None:
raise OpenAPIError(_("This operation does not expect a request body."))
raise OpenAPIError(
_("This operation does not expect a request body."))
return None, None, None, None
else:
body_required = request_body_spec.get("required", False)
Expand All @@ -562,7 +579,8 @@ def render_request_body(
content_type: t.Optional[str] = None
data: t.Optional[t.Dict[str, t.Any]] = None
json: t.Optional[t.Dict[str, t.Any]] = None
files: t.Optional[t.List[t.Tuple[str, t.Tuple[str, UploadType, str]]]] = None
files: t.Optional[t.List[t.Tuple[str,
t.Tuple[str, UploadType, str]]]] = None

candidate_content_types = [
"multipart/form-data",
Expand Down Expand Up @@ -614,7 +632,8 @@ def render_request_body(
else:
data[key] = value
if uploads:
files = [(key, upload_data) for key, upload_data in uploads.items()]
files = [(key, upload_data)
for key, upload_data in uploads.items()]
break
else:
# No known content-type left
Expand All @@ -641,7 +660,8 @@ def render_request(
validate_body: bool = True,
) -> requests.PreparedRequest:
method_spec = path_spec[method]
content_type, data, json, files = self.render_request_body(method_spec, body, validate_body)
content_type, data, json, files = self.render_request_body(
method_spec, body, validate_body)
security: t.List[t.Dict[str, t.List[str]]] = method_spec.get(
"security", self.api_spec.get("security", {})
)
Expand All @@ -650,7 +670,8 @@ def render_request(
# Bad idea, but you wanted it that way.
auth = None
else:
auth = self.auth_provider(security, self.api_spec["components"]["securitySchemes"])
auth = self.auth_provider(
security, self.api_spec["components"]["securitySchemes"])
else:
# No auth required? Don't provide it.
# No auth_provider available? Hope for the best (should do the trick for cert auth).
Expand Down Expand Up @@ -682,7 +703,8 @@ def parse_response(self, method_spec: t.Dict[str, t.Any], response: requests.Res
except KeyError:
# Fallback 201 -> 200
try:
response_spec = method_spec["responses"][str(100 * int(response.status_code / 100))]
response_spec = method_spec["responses"][str(
100 * int(response.status_code / 100))]
except KeyError:
raise OpenAPIError(
_("Unexpected response '{code}' (expected '{expected}').").format(
Expand Down Expand Up @@ -727,14 +749,17 @@ def call(
parameters = parameters.copy()

if any(self.extract_params("cookie", path_spec, method_spec, parameters)):
raise NotImplementedError(_("Cookie parameters are not implemented."))
raise NotImplementedError(
_("Cookie parameters are not implemented."))

headers = self.extract_params("header", path_spec, method_spec, parameters)
headers = self.extract_params(
"header", path_spec, method_spec, parameters)

for name, value in self.extract_params("path", path_spec, method_spec, parameters).items():
path = path.replace("{" + name + "}", value)

query_params = self.extract_params("query", path_spec, method_spec, parameters)
query_params = self.extract_params(
"query", path_spec, method_spec, parameters)

if any(parameters):
raise OpenAPIError(
Expand Down Expand Up @@ -773,7 +798,8 @@ def call(
except requests.RequestException as e:
raise OpenAPIError(str(e))
self.debug_callback(
1, _("Response: {status_code}").format(status_code=response.status_code)
1, _("Response: {status_code}").format(
status_code=response.status_code)
)
for key, value in response.headers.items():
self.debug_callback(2, f" {key}: {value}")
Expand Down
Loading

0 comments on commit af23c4e

Please sign in to comment.