Commit 39cea8b

Merge remote-tracking branch 'origin/master' into push-v2-integration

2 parents: 747ea46 + fe833f1

7 files changed: +110 additions, -57 deletions


mergin/client.py

Lines changed: 34 additions & 18 deletions
@@ -18,8 +18,6 @@
 import warnings
 from time import sleep

-from typing import List
-
 from .common import (
     PUSH_ATTEMPT_WAIT,
     PUSH_ATTEMPTS,
@@ -75,15 +73,30 @@ class ServerType(Enum):


 def decode_token_data(token):
-    token_prefix = "Bearer ."
+    token_prefix = "Bearer "
     if not token.startswith(token_prefix):
-        raise TokenError(f"Token doesn't start with 'Bearer .': {token}")
+        raise TokenError(f"Token doesn't start with 'Bearer ': {token}")
     try:
-        data = token[len(token_prefix) :].split(".")[0]
-        # add proper base64 padding"
-        data += "=" * (-len(data) % 4)
-        decoded = zlib.decompress(base64.urlsafe_b64decode(data))
-        return json.loads(decoded)
+        token_raw = token[len(token_prefix) :]
+        is_compressed = False
+
+        # compressed tokens start with dot,
+        # see https://github.com/pallets/itsdangerous/blob/main/src/itsdangerous/url_safe.py#L55
+        if token_raw.startswith("."):
+            token_raw = token_raw.lstrip(".")
+            is_compressed = True
+
+        payload_raw = token_raw.split(".")[0]
+
+        # add proper base64 padding
+        payload_raw += "=" * (-len(payload_raw) % 4)
+        payload_data = base64.urlsafe_b64decode(payload_raw)
+
+        if is_compressed:
+            payload_data = zlib.decompress(payload_data)
+
+        return json.loads(payload_data)
+
     except (IndexError, TypeError, ValueError, zlib.error):
         raise TokenError(f"Invalid token data: {token}")
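
Note: the rewrite lets decode_token_data accept both compressed and uncompressed itsdangerous tokens; previously only zlib-compressed payloads (prefix "Bearer .") would decode. A minimal standalone sketch of the same scheme, with decode_payload as an illustrative name rather than the client's API:

    import base64
    import json
    import zlib

    def decode_payload(token: str) -> dict:
        token_raw = token[len("Bearer "):]
        is_compressed = token_raw.startswith(".")  # itsdangerous marks compression with a dot
        payload_raw = token_raw.lstrip(".").split(".")[0]
        payload_raw += "=" * (-len(payload_raw) % 4)  # restore stripped base64 padding
        payload = base64.urlsafe_b64decode(payload_raw)
        return json.loads(zlib.decompress(payload) if is_compressed else payload)

    # round-trip an uncompressed token built by hand
    body = base64.urlsafe_b64encode(json.dumps({"user": "alice"}).encode()).rstrip(b"=")
    token = "Bearer " + body.decode() + ".timestamp.signature"
    assert decode_payload(token) == {"user": "alice"}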

@@ -214,12 +227,14 @@ def user_agent_info(self):

     def validate_auth(self):
         """Validate that client has valid auth token or can be logged in."""
-
         if self._auth_session:
             # Refresh auth token if it expired or will expire very soon
-            delta = self._auth_session["expire"] - datetime.now(timezone.utc)
-            if delta.total_seconds() < 5:
-                self.log.info("Token has expired - refreshing...")
+            expire = self._auth_session.get("expire")
+            now = datetime.now(timezone.utc)
+            delta = expire - now
+            delta_seconds = delta.total_seconds()
+            if delta_seconds < 5:
+                self.log.debug(f"Token has expired: expire={expire} now={now} delta={delta_seconds:.1f}s")
                 if self._auth_params.get("login", None) and self._auth_params.get("password", None):
                     self.log.info("Token has expired - refreshing...")
                     self.login(self._auth_params["login"], self._auth_params["password"])
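
Note: the refresh rule is unchanged (re-login when fewer than 5 seconds of validity remain); the hunk only names the intermediate values and logs them at debug level. A small sketch of the comparison, assuming "expire" is timezone-aware (subtracting datetime.now(timezone.utc) from a naive datetime raises TypeError):

    from datetime import datetime, timedelta, timezone

    expire = datetime.now(timezone.utc) + timedelta(seconds=3)
    delta_seconds = (expire - datetime.now(timezone.utc)).total_seconds()
    assert delta_seconds < 5  # token about to expire -> trigger re-login
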
@@ -286,7 +301,7 @@ def get(self, path, data=None, headers={}, validate_auth=True):
         request = urllib.request.Request(url, headers=headers)
         return self._do_request(request, validate_auth=validate_auth)

-    def post(self, path, data=None, headers={}, validate_auth=True, query_params: dict[str, str] = None):
+    def post(self, path, data=None, headers={}, validate_auth=True, query_params: typing.Dict[str, str] = None):
         url = urllib.parse.urljoin(self.url, urllib.parse.quote(path))
         if query_params:
             url += "?" + urllib.parse.urlencode(query_params)
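
Note: built-in generics like dict[str, str] are only subscriptable at runtime on Python 3.9+, so the old annotation broke imports on older interpreters; typing.Dict is the backward-compatible spelling (this assumes an "import typing" exists elsewhere in the module, outside this hunk). The List -> typing.List changes further down follow the same reasoning:

    import typing

    def post(path: str, query_params: typing.Dict[str, str] = None):
        # with dict[str, str] instead, Python 3.8 fails at definition time:
        # TypeError: 'type' object is not subscriptable
        ...
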
@@ -318,7 +333,6 @@ def login(self, login, password):
         :type password: String
         """
         params = {"login": login, "password": password}
-        self._auth_session = None
         self.log.info(f"Going to log in user {login}")
         try:
             resp = self.post(
@@ -329,12 +343,14 @@ def login(self, login, password):
         except ClientError as e:
             self.log.info(f"Login problem: {e.detail}")
             raise LoginError(e.detail)
+        expires = dateutil.parser.parse(session["expire"])
         self._auth_session = {
             "token": f"Bearer {session['token']}",
-            "expire": dateutil.parser.parse(session["expire"]),
+            "expire": expires,
         }
         self._user_info = {"username": data["username"]}
         self.log.info(f"User {data['username']} successfully logged in.")
+        self.log.debug(f"The auth token expires at {expires}")
         return session

     def username(self):
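
Note: login() no longer clears self._auth_session up front, so a failed re-login keeps a still-valid session intact; the session is only replaced after the POST succeeds. A hypothetical caller (server URL, credentials, and the constructor call are illustrative):

    from mergin.client import MerginClient, LoginError  # assumed import path

    mc = MerginClient("https://app.merginmaps.com")
    try:
        session = mc.login("alice", "s3cret")  # returns the server session dict
        print(session["expire"])               # parsed into _auth_session["expire"]
    except LoginError as e:
        print(f"login failed: {e}")
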
@@ -1352,7 +1368,7 @@ def get_workspace_member(self, workspace_id: int, user_id: int) -> dict:
         resp = self.get(f"v2/workspaces/{workspace_id}/members/{user_id}")
         return json.load(resp)

-    def list_workspace_members(self, workspace_id: int) -> List[dict]:
+    def list_workspace_members(self, workspace_id: int) -> typing.List[dict]:
         """
         Get a list of workspace members
         """
@@ -1383,7 +1399,7 @@ def remove_workspace_member(self, workspace_id: int, user_id: int):
         self.check_collaborators_members_support()
         self.delete(f"v2/workspaces/{workspace_id}/members/{user_id}")

-    def list_project_collaborators(self, project_id: str) -> List[dict]:
+    def list_project_collaborators(self, project_id: str) -> typing.List[dict]:
         """
         Get a list of project collaborators
         """

mergin/client_pull.py

Lines changed: 42 additions & 32 deletions
@@ -44,15 +44,23 @@ class DownloadJob:
     """

     def __init__(
-        self, project_path, total_size, version, update_tasks, download_queue_items, directory, mp, project_info
+        self,
+        project_path,
+        total_size,
+        version,
+        update_tasks,
+        download_queue_items,
+        tmp_dir: tempfile.TemporaryDirectory,
+        mp,
+        project_info,
     ):
         self.project_path = project_path
         self.total_size = total_size  # size of data to download (in bytes)
         self.transferred_size = 0
         self.version = version
         self.update_tasks = update_tasks
         self.download_queue_items = download_queue_items
-        self.directory = directory  # project's directory
+        self.tmp_dir = tmp_dir
         self.mp = mp  # MerginProject instance
         self.is_cancelled = False
         self.project_info = project_info  # parsed JSON with project info returned from the server
@@ -96,7 +104,7 @@ def _do_download(item, mc, mp, project_path, job):
     job.transferred_size += item.size


-def _cleanup_failed_download(directory, mergin_project=None):
+def _cleanup_failed_download(mergin_project: MerginProject = None):
     """
     If a download job fails, there will be the newly created directory left behind with some
     temporary files in it. We want to remove it because a new download would fail because
@@ -109,7 +117,7 @@ def _cleanup_failed_download(directory, mergin_project=None):
         mergin_project.remove_logging_handler()

     # keep log file as it might contain useful debug info
-    log_file = os.path.join(directory, ".mergin", "client-log.txt")
+    log_file = os.path.join(mergin_project.dir, ".mergin", "client-log.txt")
     dest_path = None

     if os.path.exists(log_file):
@@ -118,7 +126,6 @@ def _cleanup_failed_download(directory, mergin_project=None):
             dest_path = tmp_file.name
         shutil.copyfile(log_file, dest_path)

-    shutil.rmtree(directory)
     return dest_path


@@ -138,6 +145,8 @@ def download_project_async(mc, project_path, directory, project_version=None):
     mp.log.info("--- version: " + mc.user_agent_info())
     mp.log.info(f"--- start download {project_path}")

+    tmp_dir = tempfile.TemporaryDirectory(prefix="python-api-client-", ignore_cleanup_errors=True, delete=True)
+
     try:
         # check whether we download the latest version or not
         latest_proj_info = mc.project_info(project_path)
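
Note: downloads now land in a tempfile.TemporaryDirectory owned by the job rather than in the target project directory, so a failed download leaves no half-written project behind. Worth flagging: ignore_cleanup_errors requires Python 3.10+ and the delete keyword Python 3.12+, so this hunk raises the minimum supported interpreter. A minimal sketch of the lifecycle (delete= omitted for portability):

    import tempfile

    tmp_dir = tempfile.TemporaryDirectory(prefix="python-api-client-", ignore_cleanup_errors=True)
    print(tmp_dir.name)  # str path, usable wherever a plain directory path was passed before
    tmp_dir.cleanup()    # removes the directory and all of its content
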
@@ -147,7 +156,7 @@ def download_project_async(mc, project_path, directory, project_version=None):
         project_info = latest_proj_info

     except ClientError:
-        _cleanup_failed_download(directory, mp)
+        _cleanup_failed_download(mp)
         raise

     version = project_info["version"] if project_info["version"] else "v0"
@@ -158,7 +167,7 @@ def download_project_async(mc, project_path, directory, project_version=None):
     update_tasks = []  # stuff to do at the end of download
     for file in project_info["files"]:
         file["version"] = version
-        items = _download_items(file, directory)
+        items = _download_items(file, tmp_dir.name)
         is_latest_version = project_version == latest_proj_info["version"]
         update_tasks.append(UpdateTask(file["path"], items, latest_version=is_latest_version))

@@ -172,7 +181,7 @@ def download_project_async(mc, project_path, directory, project_version=None):

     mp.log.info(f"will download {len(update_tasks)} files in {len(download_list)} chunks, total size {total_size}")

-    job = DownloadJob(project_path, total_size, version, update_tasks, download_list, directory, mp, project_info)
+    job = DownloadJob(project_path, total_size, version, update_tasks, download_list, tmp_dir, mp, project_info)

     # start download
     job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
@@ -203,7 +212,7 @@ def download_project_is_running(job):
             traceback_lines = traceback.format_exception(type(exc), exc, exc.__traceback__)
             job.mp.log.error("Error while downloading project: " + "".join(traceback_lines))
             job.mp.log.info("--- download aborted")
-            job.failure_log_file = _cleanup_failed_download(job.directory, job.mp)
+            job.failure_log_file = _cleanup_failed_download(job.mp)
             raise future.exception()
         if future.running():
             return True
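
Note: with the directory removal taken out of _cleanup_failed_download(), a failed job now only salvages the client log; the happy path releases the temporary directory in download_project_finalize() below. A hypothetical caller, using only functions defined in this module (mc is assumed to be an authenticated MerginClient):

    import time

    job = download_project_async(mc, "workspace/project", "/tmp/project")
    while download_project_is_running(job):
        time.sleep(0.5)  # could report job.transferred_size / job.total_size here
    download_project_finalize(job)  # applies update tasks, then job.tmp_dir.cleanup()
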
@@ -229,18 +238,20 @@ def download_project_finalize(job):
             traceback_lines = traceback.format_exception(type(exc), exc, exc.__traceback__)
             job.mp.log.error("Error while downloading project: " + "".join(traceback_lines))
             job.mp.log.info("--- download aborted")
-            job.failure_log_file = _cleanup_failed_download(job.directory, job.mp)
+            job.failure_log_file = _cleanup_failed_download(job.mp)
             raise future.exception()

     job.mp.log.info("--- download finished")

     for task in job.update_tasks:
         # right now only copy tasks...
-        task.apply(job.directory, job.mp)
+        task.apply(job.mp.dir, job.mp)

     # final update of project metadata
     job.mp.update_metadata(job.project_info)

+    job.tmp_dir.cleanup()
+

 def download_project_cancel(job):
     """
@@ -336,7 +347,7 @@ def __init__(
         version,
         files_to_merge,
         download_queue_items,
-        temp_dir,
+        tmp_dir,
         mp,
         project_info,
         basefiles_to_patch,
@@ -351,7 +362,7 @@ def __init__(
         self.version = version
         self.files_to_merge = files_to_merge  # list of FileToMerge instances
         self.download_queue_items = download_queue_items
-        self.temp_dir = temp_dir  # full path to temporary directory where we store downloaded files
+        self.tmp_dir = tmp_dir  # TemporaryDirectory instance where we store downloaded files
         self.mp = mp  # MerginProject instance
         self.is_cancelled = False
         self.project_info = project_info  # parsed JSON with project info returned from the server
@@ -413,8 +424,7 @@ def pull_project_async(mc, directory) -> PullJob:
     # then we just download the whole file
     _pulling_file_with_diffs = lambda f: "diffs" in f and len(f["diffs"]) != 0

-    temp_dir = mp.fpath_meta(f"fetch_{local_version}-{server_version}")
-    os.makedirs(temp_dir, exist_ok=True)
+    tmp_dir = tempfile.TemporaryDirectory(prefix="mm-pull-", ignore_cleanup_errors=True, delete=True)
     pull_changes = mp.get_pull_changes(server_info["files"])
     mp.log.debug("pull changes:\n" + pprint.pformat(pull_changes))
     fetch_files = []
@@ -441,10 +451,10 @@ def pull_project_async(mc, directory) -> PullJob:

     for file in fetch_files:
         diff_only = _pulling_file_with_diffs(file)
-        items = _download_items(file, temp_dir, diff_only)
+        items = _download_items(file, tmp_dir.name, diff_only)

         # figure out destination path for the file
-        file_dir = os.path.dirname(os.path.normpath(os.path.join(temp_dir, file["path"])))
+        file_dir = os.path.dirname(os.path.normpath(os.path.join(tmp_dir.name, file["path"])))
         basename = os.path.basename(file["diff"]["path"]) if diff_only else os.path.basename(file["path"])
         dest_file_path = os.path.join(file_dir, basename)
         os.makedirs(file_dir, exist_ok=True)
@@ -465,8 +475,8 @@ def pull_project_async(mc, directory) -> PullJob:
             file_path = file["path"]
             mp.log.info(f"missing base file for {file_path} -> going to download it (version {server_version})")
             file["version"] = server_version
-            items = _download_items(file, temp_dir, diff_only=False)
-            dest_file_path = mp.fpath(file["path"], temp_dir)
+            items = _download_items(file, tmp_dir.name, diff_only=False)
+            dest_file_path = mp.fpath(file["path"], tmp_dir.name)
             # dest_file_path = os.path.join(os.path.dirname(os.path.normpath(os.path.join(temp_dir, file['path']))), os.path.basename(file['path']))
             files_to_merge.append(FileToMerge(dest_file_path, items))
             continue
@@ -490,7 +500,7 @@ def pull_project_async(mc, directory) -> PullJob:
         server_version,
         files_to_merge,
         download_list,
-        temp_dir,
+        tmp_dir,
         mp,
         server_info,
         basefiles_to_patch,
@@ -604,10 +614,10 @@ def pull_project_finalize(job: PullJob):
     # download their full versions so we have them up-to-date for applying changes
     for file_path, file_diffs in job.basefiles_to_patch:
         basefile = job.mp.fpath_meta(file_path)
-        server_file = job.mp.fpath(file_path, job.temp_dir)
+        server_file = job.mp.fpath(file_path, job.tmp_dir.name)

         shutil.copy(basefile, server_file)
-        diffs = [job.mp.fpath(f, job.temp_dir) for f in file_diffs]
+        diffs = [job.mp.fpath(f, job.tmp_dir.name) for f in file_diffs]
         patch_error = job.mp.apply_diffs(server_file, diffs)
         if patch_error:
             # that's weird that we are unable to apply diffs to the basefile!
@@ -623,7 +633,7 @@ def pull_project_finalize(job: PullJob):
             raise ClientError("Cannot patch basefile {}! Please try syncing again.".format(basefile))

     try:
-        conflicts = job.mp.apply_pull_changes(job.pull_changes, job.temp_dir, job.project_info, job.mc)
+        conflicts = job.mp.apply_pull_changes(job.pull_changes, job.tmp_dir.name, job.project_info, job.mc)
     except Exception as e:
         job.mp.log.error("Failed to apply pull changes: " + str(e))
         job.mp.log.info("--- pull aborted")
@@ -636,7 +646,7 @@ def pull_project_finalize(job: PullJob):
     else:
         job.mp.log.info("--- pull finished -- at version " + job.mp.version())

-    shutil.rmtree(job.temp_dir)
+    job.tmp_dir.cleanup()  # delete our temporary dir and all its content
     return conflicts
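
Note: the pull path mirrors the download path: the fetch directory moved from a fetch_* folder inside .mergin to the system temp location (prefix "mm-pull-") and is released via job.tmp_dir.cleanup() in pull_project_finalize(). A hypothetical caller sketch; pull_project_is_running is assumed by analogy with the download helpers, and the None check assumes no job is returned when there is nothing to pull:

    import time

    job = pull_project_async(mc, "/tmp/project")
    if job is not None:  # assumption: nothing to pull -> no job
        while pull_project_is_running(job):
            time.sleep(0.5)
        conflicts = pull_project_finalize(job)  # also runs job.tmp_dir.cleanup()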

@@ -788,7 +798,7 @@ def download_files_async(
     mp.log.info(f"Got project info. version {project_info['version']}")

     # set temporary directory for download
-    temp_dir = tempfile.mkdtemp(prefix="python-api-client-")
+    tmp_dir = tempfile.mkdtemp(prefix="python-api-client-")

     if output_paths is None:
         output_paths = []
@@ -798,7 +808,7 @@ def download_files_async(
     if len(output_paths) != len(file_paths):
         warn = "Output file paths are not of the same length as file paths. Cannot store required files."
         mp.log.warning(warn)
-        shutil.rmtree(temp_dir)
+        shutil.rmtree(tmp_dir)
         raise ClientError(warn)

     download_list = []
@@ -812,7 +822,7 @@ def download_files_async(
         if file["path"] in file_paths:
             index = file_paths.index(file["path"])
             file["version"] = version
-            items = _download_items(file, temp_dir)
+            items = _download_items(file, tmp_dir)
             is_latest_version = version == latest_proj_info["version"]
             task = UpdateTask(file["path"], items, output_paths[index], latest_version=is_latest_version)
             download_list.extend(task.download_queue_items)
@@ -832,13 +842,13 @@ def download_files_async(
     if not download_list or missing_files:
         warn = f"No [{', '.join(missing_files)}] exists at version {version}"
         mp.log.warning(warn)
-        shutil.rmtree(temp_dir)
+        shutil.rmtree(tmp_dir)
         raise ClientError(warn)

     mp.log.info(
         f"will download files [{', '.join(files_to_download)}] in {len(download_list)} chunks, total size {total_size}"
     )
-    job = DownloadJob(project_path, total_size, version, update_tasks, download_list, temp_dir, mp, project_info)
+    job = DownloadJob(project_path, total_size, version, update_tasks, download_list, tmp_dir, mp, project_info)
     job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
     job.futures = []
     for item in download_list:
@@ -862,8 +872,8 @@ def download_files_finalize(job):
     job.mp.log.info("--- download finished")

     for task in job.update_tasks:
-        task.apply(job.directory, job.mp)
+        task.apply(job.tmp_dir, job.mp)

     # Remove temporary download directory
-    if job.directory is not None and os.path.exists(job.directory):
-        shutil.rmtree(job.directory)
+    if job.tmp_dir is not None and os.path.exists(job.tmp_dir):
+        shutil.rmtree(job.tmp_dir)
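
Note: unlike download_project_async above, download_files_async still uses tempfile.mkdtemp(), so job.tmp_dir in this function is a plain str path and the os.path.exists()/shutil.rmtree() idiom is correct here; only jobs built from a TemporaryDirectory own a .cleanup(). A minimal contrast of the two APIs:

    import os
    import shutil
    import tempfile

    path = tempfile.mkdtemp(prefix="python-api-client-")  # str; caller must delete it
    assert os.path.exists(path)
    shutil.rmtree(path)

    owned = tempfile.TemporaryDirectory(prefix="python-api-client-")
    owned.cleanup()  # object removes its own directory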
