Skip to content

Commit

Permalink
feat: implement file upload with rest client
Browse files Browse the repository at this point in the history
  • Loading branch information
jameszyao committed Apr 25, 2024
1 parent 9932123 commit 315d2f4
Show file tree
Hide file tree
Showing 6 changed files with 516 additions and 395 deletions.
3 changes: 2 additions & 1 deletion taskingai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import tool
from . import retrieval
from . import inference
from . import file
from ._version import __version__

__all__ = [
Expand All @@ -11,4 +12,4 @@
"retrieval",
"inference",
"__version__",
]
]
45 changes: 10 additions & 35 deletions taskingai/client/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import datetime
import json
import mimetypes
import os
import re
import tempfile
Expand Down Expand Up @@ -101,32 +100,6 @@ def deserialize(self, response, response_type: Type[BaseModel]):

return response_type(**data)

def prepare_post_parameters(self, post_params=None, files=None):
"""Builds form parameters.
:param post_params: Normal form parameters.
:param files: File parameters.
:return: Form parameters with files.
"""
params = []

if post_params:
params = post_params

if files:
for k, v in six.iteritems(files):
if not v:
continue
file_names = v if type(v) is list else [v]
for n in file_names:
with open(n, "rb") as f:
filename = os.path.basename(f.name)
filedata = f.read()
mimetype = mimetypes.guess_type(filename)[0] or "application/octet-stream"
params.append(tuple([k, tuple([filename, filedata, mimetype])]))

return params

def select_header_accept(self, accepts):
"""Returns `Accept` based on an array of accepts provided.
Expand Down Expand Up @@ -252,10 +225,6 @@ def __call_api(
# specified safe chars, encode everything
resource_path = resource_path.replace("{%s}" % k, quote(str(v), safe=config.safe_chars_for_path_param))

# post parameters
if post_params or files:
post_params = self.prepare_post_parameters(post_params, files)

# auth setting
self.update_params_for_auth(header_params, query_params, auth_settings)

Expand All @@ -272,6 +241,7 @@ def __call_api(
query_params=query_params,
headers=header_params,
post_params=post_params,
files=files,
body=body,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
Expand Down Expand Up @@ -367,6 +337,7 @@ def request(
query_params=None,
headers=None,
post_params=None,
files=None,
body=None,
_preload_content=True,
_request_timeout=None,
Expand Down Expand Up @@ -408,6 +379,7 @@ def request(
query_params=query_params,
headers=headers,
post_params=post_params,
files=files,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
body=body,
Expand All @@ -419,6 +391,7 @@ def request(
query_params=query_params,
headers=headers,
post_params=post_params,
files=files,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
body=body,
Expand All @@ -430,6 +403,7 @@ def request(
query_params=query_params,
headers=headers,
post_params=post_params,
files=files,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
body=body,
Expand Down Expand Up @@ -489,10 +463,6 @@ async def __call_api(
# specified safe chars, encode everything
resource_path = resource_path.replace("{%s}" % k, quote(str(v), safe=config.safe_chars_for_path_param))

# post parameters
if post_params or files:
post_params = self.prepare_post_parameters(post_params, files)

# auth setting
self.update_params_for_auth(header_params, query_params, auth_settings)

Expand All @@ -509,6 +479,7 @@ async def __call_api(
query_params=query_params,
headers=header_params,
post_params=post_params,
files=files,
body=body,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
Expand Down Expand Up @@ -601,6 +572,7 @@ async def request(
query_params=None,
headers=None,
post_params=None,
files=None,
body=None,
_preload_content=True,
_request_timeout=None,
Expand Down Expand Up @@ -642,6 +614,7 @@ async def request(
query_params=query_params,
headers=headers,
post_params=post_params,
files=files,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
body=body,
Expand All @@ -653,6 +626,7 @@ async def request(
query_params=query_params,
headers=headers,
post_params=post_params,
files=files,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
body=body,
Expand All @@ -664,6 +638,7 @@ async def request(
query_params=query_params,
headers=headers,
post_params=post_params,
files=files,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
body=body,
Expand Down
Loading

0 comments on commit 315d2f4

Please sign in to comment.