#!/usr/bin/env python
#
# Copyright 2012 BloomReach, Inc.
# Portions Copyright 2014 Databricks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Super S3 command line tool.
"""
import sys, os, re, optparse, multiprocessing, fnmatch, time, hashlib, errno
import logging, traceback, types, threading, random, socket
IS_PYTHON2 = sys.version_info[0] == 2
if IS_PYTHON2:
from cStringIO import StringIO
import Queue
import ConfigParser
else:
from io import StringIO
import queue as Queue
import configparser as ConfigParser
def cmp(a, b):
return (a > b) - (a < b)
from functools import cmp_to_key
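# Note: under Python 3 the built-in cmp() is gone, so the shim above together
# with functools.cmp_to_key keeps the old two-argument comparators working,
# roughly like:
#   sorted(items, key=cmp_to_key(lambda x, y: cmp(x['name'], y['name'])))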
# We need boto 2.3.0 for multipart upload.
import boto
import boto.s3
import boto.s3.key
import boto.exception
##
## Global constants
##
S4CMD_VERSION = "1.5.23"
PATH_SEP = '/'
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S UTC'
TIMESTAMP_REGEX = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2}).(\d{3})Z')
TIMESTAMP_FORMAT = '%4s-%2s-%2s %2s:%2s'
SOCKET_TIMEOUT = 5 * 60 # in sec(s) (timeout if we don't receive any recv() callback)
socket.setdefaulttimeout(SOCKET_TIMEOUT)
# Global list for temp files.
TEMP_FILES = set()
# Environment variable names for S3 credentials.
S3_ACCESS_KEY_NAME = "S3_ACCESS_KEY"
S3_SECRET_KEY_NAME = "S3_SECRET_KEY"
##
## Utility classes
##
class Options:
  '''Default option class for all available options.
  Values default to those from the option parser and can be overridden by
  command-line options or set at run-time.
  '''
def __init__(self, opt = None):
parser = get_opt_parser()
for o in parser.option_list:
self.__dict__[o.dest] = o.default if (opt is None) or (opt.__dict__[o.dest] is None) else opt.__dict__[o.dest]
class Failure(RuntimeError):
'''Exception for runtime failures'''
pass
class InvalidArgument(RuntimeError):
'''Exception for invalid input parameters'''
pass
class RetryFailure(Exception):
'''Runtime failure that can be retried'''
pass
class S4cmdLoggingClass:
def __init__(self):
self.log = logging.Logger("s4cmd")
self.log.stream = sys.stderr
self.log_handler = logging.StreamHandler(self.log.stream)
self.log.addHandler(self.log_handler)
def configure(self, opt):
    '''Configure the logger based on command-line arguments'''
self.log_handler.setFormatter(logging.Formatter('%(message)s', DATETIME_FORMAT))
if opt.debug:
self.log.verbosity = 3
self.log_handler.setFormatter(logging.Formatter(
' (%(levelname).1s)%(filename)s:%(lineno)-4d %(message)s',
DATETIME_FORMAT))
self.log.setLevel(logging.DEBUG)
elif opt.verbose:
self.log.verbosity = 2
self.log.setLevel(logging.INFO)
else:
self.log.verbosity = 1
self.log.setLevel(logging.ERROR)
def get_loggers(self):
'''Return a list of the logger methods: (debug, info, warn, error)'''
return self.log.debug, self.log.info, self.log.warn, self.log.error
s4cmd_logging = S4cmdLoggingClass()
debug, info, warn, error = s4cmd_logging.get_loggers()
def get_default_thread_count():
return int(os.getenv('S4CMD_NUM_THREADS', multiprocessing.cpu_count() * 4))
def log_calls(func):
'''Decorator to log function calls.'''
def wrapper(*args, **kwargs):
callStr = "%s(%s)" % (func.__name__, ", ".join([repr(p) for p in args] + ["%s=%s" % (k, repr(v)) for (k, v) in list(kwargs.items())]))
debug(">> %s", callStr)
ret = func(*args, **kwargs)
debug("<< %s: %s", callStr, repr(ret))
return ret
return wrapper
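# For illustration, a call to a @log_calls-decorated method logs its arguments
# and return value at debug level, roughly of the form:
#   >> get_basename(<handler>, '/foo/bar/')
#   << get_basename(<handler>, '/foo/bar/'): 'bar'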
##
## Utility functions
##
def synchronized(func):
'''Decorator to synchronize function.'''
func.__lock__ = threading.Lock()
def synced_func(*args, **kargs):
with func.__lock__:
return func(*args, **kargs)
return synced_func
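# For illustration, @synchronized attaches one lock per decorated function, so
# concurrent calls from worker threads are serialized:
#   @synchronized
#   def update_counter():   # hypothetical example function
#     ...                   # only one thread at a time executes this body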
def clear_progress():
'''Clear previous progress message, if any.'''
progress('')
@synchronized
def progress(msg, *args):
'''Show current progress message to stderr.
  This function remembers the previous message so that, the next time it is
  called, it can clear the previous message before showing the next one.
'''
# Don't show any progress if the output is directed to a file.
if not sys.stderr.isatty():
return
text = (msg % args)
if progress.prev_message:
sys.stderr.write(' ' * len(progress.prev_message) + '\r')
sys.stderr.write(text + '\r')
progress.prev_message = text
progress.prev_message = None
@synchronized
def message(msg, *args):
'''Program message output.'''
clear_progress()
text = (msg % args)
sys.stdout.write(text + '\n')
def fail(message, exc_info = None, status = 1, stacktrace = False):
'''Utility function to handle runtime failures gracefully.
Show concise information if possible, then terminate program.
'''
text = message
if exc_info:
text += str(exc_info)
error(text)
if stacktrace:
error(traceback.format_exc())
clean_tempfiles()
if __name__ == '__main__':
sys.exit(status)
else:
raise RuntimeError(status)
@synchronized
def tempfile_get(target):
'''Get a temp filename for atomic download.'''
fn = '%s-%s.tmp' % (target, ''.join(random.Random().sample("0123456789abcdefghijklmnopqrstuvwxyz", 15)))
TEMP_FILES.add(fn)
return fn
@synchronized
def tempfile_set(tempfile, target):
'''Atomically rename and clean tempfile'''
if target:
os.rename(tempfile, target)
else:
os.unlink(tempfile)
  if tempfile in TEMP_FILES:
    TEMP_FILES.remove(tempfile)
def clean_tempfiles():
'''Clean up temp files'''
for fn in TEMP_FILES:
if os.path.exists(fn):
os.unlink(fn)
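# A rough sketch of how the temp-file helpers above are used for atomic
# downloads elsewhere in this file: write to a temporary name first, then
# rename into place only on success (paths are illustrative):
#   tmp = tempfile_get('/data/out.bin')    # e.g. '/data/out.bin-xxxxxxxxxxxxxxx.tmp'
#   ...download into tmp...
#   tempfile_set(tmp, '/data/out.bin')     # atomic rename into place
#   tempfile_set(tmp, None)                # or: discard the partial download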
class S3URL:
'''Simple wrapper for S3 URL.
  This class parses an S3 URL and provides accessors to each component.
'''
S3URL_PATTERN = re.compile(r'(s3[n]?)://([^/]+)[/]?(.*)')
def __init__(self, uri):
'''Initialization, parse S3 URL'''
try:
self.proto, self.bucket, self.path = S3URL.S3URL_PATTERN.match(uri).groups()
self.proto = 's3' # normalize s3n => s3
except:
raise InvalidArgument('Invalid S3 URI: %s' % uri)
def __str__(self):
'''Return the original S3 URL'''
return S3URL.combine(self.proto, self.bucket, self.path)
def get_fixed_path(self):
'''Get the fixed part of the path without wildcard'''
pi = self.path.split(PATH_SEP)
fi = []
for p in pi:
if '*' in p or '?' in p:
break
fi.append(p)
return PATH_SEP.join(fi)
@staticmethod
def combine(proto, bucket, path):
    '''Combine the components to generate an S3 URL string. No path
    normalization is done here. The path should not start with a slash.
'''
return '%s://%s/%s' % (proto, bucket, path)
@staticmethod
def is_valid(uri):
    '''Check if the given uri is a valid S3 URL'''
    return S3URL.S3URL_PATTERN.match(uri) is not None
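# Illustrative S3URL behavior (values shown are examples only):
#   u = S3URL('s3n://my-bucket/logs/2016/*.gz')
#   u.proto            == 's3'                 (s3n is normalized to s3)
#   u.bucket           == 'my-bucket'
#   u.path             == 'logs/2016/*.gz'
#   u.get_fixed_path() == 'logs/2016'          (portion before the first wildcard)
#   str(u)             == 's3://my-bucket/logs/2016/*.gz'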
class TaskQueue(Queue.Queue):
  '''Wrapper class around Queue.
  The default join() blocks the main thread in a way that cannot be
  interrupted by Ctrl-C, so we override join() to wait with a timeout and
  handle keyboard interrupts.
'''
def __init__(self):
Queue.Queue.__init__(self)
self.exc_info = None
def join(self):
'''Override original join() with a timeout and handle keyboard interrupt.'''
self.all_tasks_done.acquire()
try:
while self.unfinished_tasks:
self.all_tasks_done.wait(1000)
# Child thread has exceptions, fail main thread too.
if self.exc_info:
fail('[Thread Failure] ', exc_info = self.exc_info)
except KeyboardInterrupt:
raise Failure('Interrupted by user')
finally:
self.all_tasks_done.release()
def terminate(self, exc_info = None):
    '''Terminate all threads by draining the task queue and forcing the
    child threads to quit.
'''
if exc_info:
self.exc_info = exc_info
try:
while self.get_nowait():
self.task_done()
except Queue.Empty:
pass
class ThreadPool(object):
'''Utility class for thread pool.
This class needs to work with a utility class, which is derived from Worker.
'''
class Worker(threading.Thread):
    '''Utility thread worker class.
    This class handles all items in the task queue and executes them. It also
    handles runtime errors gracefully and provides automatic retry.
'''
def __init__(self, pool):
      '''Thread worker initialization.
      Set up values and start the thread right away.
'''
threading.Thread.__init__(self)
self.pool = pool
self.opt = pool.opt
self.daemon = True
self.start()
def run(self):
      '''Main thread worker execution.
      This function extracts items from the task queue and executes them
      accordingly. It retries a task that raises an exception by putting the
      same item back into the work queue.
'''
while True:
item = self.pool.tasks.get()
if not item:
break
try:
func_name, retry, args, kargs = item
self.__class__.__dict__[func_name](self, *args, **kargs)
except InvalidArgument as e:
self.pool.tasks.terminate(e)
fail('[Invalid Argument] ', exc_info = e)
except Failure as e:
self.pool.tasks.terminate(e)
fail('[Runtime Failure] ', exc_info = e)
# Also retry known S3ResponseError since S3 has transient
# errors from time to time.
# except boto.exception.S3ResponseError, e:
# self.pool.tasks.terminate(e)
# fail('[S3ResponseError] %s: %s' % (e.error_code, e.error_message))
except OSError as e:
self.pool.tasks.terminate(e)
fail('[OSError] %d: %s' % (e.errno, e.strerror))
except Exception as e:
# XXX Should we retry on all unknown exceptions?
if retry >= self.opt.retry:
self.pool.tasks.terminate(e)
fail('[Runtime Exception] ', exc_info = e, stacktrace = True)
else:
# Show content of exceptions.
error(e)
time.sleep(self.opt.retry_delay)
self.pool.tasks.put((func_name, retry + 1, args, kargs))
finally:
self.pool.processed()
self.pool.tasks.task_done()
def __init__(self, thread_class, opt):
    '''Constructor of ThreadPool.
    Create the workers; the pool automatically inherits all methods from
    thread_class by redirecting calls through __getattribute__().
'''
self.opt = opt
self.tasks = TaskQueue()
self.processed_tasks = 0
self.thread_class = thread_class
self.workers = []
for i in range(opt.num_threads):
self.workers.append(thread_class(self))
def __enter__(self):
'''Utility function for with statement'''
return self
def __exit__(self, exc_type, exc_value, traceback):
'''Utility function for with statement, wait for completion'''
self.join()
return isinstance(exc_value, TypeError)
def __getattribute__(self, name):
    '''Special attribute accessor to add tasks into the task queue.
    If the requested attribute is not found on ThreadPool, we look for a
    function of that name in the utility class. If one exists, the call is
    added to the task queue instead of being executed directly.
'''
try:
attr = super(ThreadPool, self).__getattribute__(name)
except AttributeError as e:
if name in self.thread_class.__dict__:
        # Here we masquerade the original function behind add_task(), so the
        # function call will be put into the task queue instead.
def deferred_task(*args, **kargs):
self.add_task(name, *args, **kargs)
attr = deferred_task
else:
raise AttributeError('Unable to resolve %s' % name)
return attr
def add_task(self, func_name, *args, **kargs):
'''Utility function to add a single task into task queue'''
self.tasks.put((func_name, 0, args, kargs))
def join(self):
    '''Utility function to wait for all tasks to complete'''
self.tasks.join()
# Force each thread to break loop.
for worker in self.workers:
self.tasks.put(None)
    # Wait for all threads to terminate.
for worker in self.workers:
worker.join()
worker.s3 = None
@synchronized
def processed(self):
'''Increase the processed task counter and show progress message'''
self.processed_tasks += 1
qsize = self.tasks.qsize()
if qsize > 0:
progress('[%d task(s) completed, %d remaining, %d thread(s)]', self.processed_tasks, qsize, len(self.workers))
else:
progress('[%d task(s) completed, %d thread(s)]', self.processed_tasks, len(self.workers))
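# Rough usage sketch for the pool above: calling a method that is defined on
# the worker class (but not on ThreadPool itself) is intercepted by
# __getattribute__ and queued as a task rather than executed inline:
#   pool = ThreadPool(ThreadUtil, opt)                  # opt: an Options instance
#   pool.download(None, 's3://bucket/key', '/tmp/key')  # enqueued for a worker
#   pool.join()                                         # wait for completion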
class S3Handler(object):
'''Core S3 class.
  This class provides the functions for all operations. It starts a thread
  pool to execute the tasks generated by each operation. See ThreadUtil for
more details about the tasks.
'''
S3_KEYS = None
@staticmethod
def s3_keys_from_env():
'''Retrieve S3 access keys from the environment, or None if not present.'''
env = os.environ
if S3_ACCESS_KEY_NAME in env and S3_SECRET_KEY_NAME in env:
keys = (env[S3_ACCESS_KEY_NAME], env[S3_SECRET_KEY_NAME])
debug("read S3 keys from environment")
return keys
else:
return None
@staticmethod
def s3_keys_from_s3cfg(opt):
'''Retrieve S3 access key settings from s3cmd's config file, if present; otherwise return None.'''
try:
if opt.s3cfg != None:
s3cfg_path = "%s" % opt.s3cfg
else:
s3cfg_path = "%s/.s3cfg" % os.environ["HOME"]
if not os.path.exists(s3cfg_path):
return None
config = ConfigParser.ConfigParser()
config.read(s3cfg_path)
keys = config.get("default", "access_key"), config.get("default", "secret_key")
debug("read S3 keys from $HOME/.s3cfg file")
return keys
except Exception as e:
info("could not read S3 keys from %s file; skipping (%s)", s3cfg_path, e)
return None
@staticmethod
def init_s3_keys(opt):
'''Initialize s3 access keys from environment variable or s3cfg config file.'''
S3Handler.S3_KEYS = S3Handler.s3_keys_from_env() or S3Handler.s3_keys_from_s3cfg(opt)
def __init__(self, opt):
'''Constructor, connect to S3 store'''
self.s3 = None
self.opt = opt
self.connect()
def __del__(self):
'''Destructor, stop s3 connection'''
if self.s3 is not None:
self.s3.close()
self.s3 = None
def connect(self):
'''Connect to S3 storage'''
try:
if S3Handler.S3_KEYS:
self.s3 = boto.connect_s3(S3Handler.S3_KEYS[0],
S3Handler.S3_KEYS[1],
is_secure = self.opt.use_ssl,
suppress_consec_slashes = False)
else:
self.s3 = boto.connect_s3(is_secure = self.opt.use_ssl,
suppress_consec_slashes = False)
except Exception as e:
raise RetryFailure('Unable to connect to s3: %s' % e)
@log_calls
def list_buckets(self):
'''List all buckets'''
result = []
for bucket in self.s3.get_all_buckets():
result.append({
'name': S3URL.combine('s3', bucket.name, ''),
'is_dir': True,
'size': 0,
'last_modified': bucket.creation_date
})
return result
@log_calls
def s3walk(self, basedir, show_dir = None):
    '''Walk through an S3 directory. This function initiates a walk from basedir.
    It also supports multiple wildcards.
'''
# Provide the default value from command line if no override.
if not show_dir:
show_dir = self.opt.show_dir
    # Trailing slash normalization: we want `ls s3://foo/bar/` to give the
    # same result as `ls s3://foo/bar`. Since we call partial_match() to check
    # wildcards, we need to ensure the number of slashes stays the same when
    # we do this.
if basedir[-1] == PATH_SEP:
basedir = basedir[0:-1]
s3url = S3URL(basedir)
result = []
pool = ThreadPool(ThreadUtil, self.opt)
pool.s3walk(s3url, s3url.get_fixed_path(), s3url.path, result)
pool.join()
# automatic directory detection
if not show_dir and len(result) == 1 and result[0]['is_dir']:
path = result[0]['name']
s3url = S3URL(path)
result = []
pool = ThreadPool(ThreadUtil, self.opt)
pool.s3walk(s3url, s3url.get_fixed_path(), s3url.path, result)
pool.join()
def compare(x, y):
'''Comparator for ls output'''
result = -cmp(x['is_dir'], y['is_dir'])
if result != 0:
return result
return cmp(x['name'], y['name'])
compare = cmp_to_key(compare)
return sorted(result, key=compare)
@log_calls
def local_walk(self, basedir):
'''Walk through local directories from root basedir'''
result = []
for root, dirs, files in os.walk(basedir):
for f in files:
result.append(os.path.join(root, f))
return result
@log_calls
def get_basename(self, path):
'''Unix style basename.
    This function returns 'bar' for '/foo/bar/' instead of an empty string.
    It is used to normalize away the trailing slash in the input.
'''
if path[-1] == PATH_SEP:
path = path[0:-1]
return os.path.basename(path)
def source_expand(self, source):
    '''Expand the wildcards for an S3 path. The shell performs this expansion
    for local paths; we emulate it here for S3 paths.
'''
result = []
if not isinstance(source, list):
source = [source]
for src in source:
      # XXX Hacky: we need to disable recursive mode while we expand the input
      # parameters; ideally this should be passed down as an override
      # parameter instead.
tmp = self.opt.recursive
self.opt.recursive = False
result += [f['name'] for f in self.s3walk(src, True)]
self.opt.recursive = tmp
if (len(result) == 0) and (not self.opt.ignore_empty_source):
fail("[Runtime Failure] Source doesn't exist.")
return result
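  # For example, assuming objects s3://bucket/a/1.txt and s3://bucket/a/2.txt
  # exist, source_expand('s3://bucket/a/*.txt') would return
  # ['s3://bucket/a/1.txt', 's3://bucket/a/2.txt'].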
@log_calls
def put_single_file(self, pool, source, target):
'''Upload a single file or a directory by adding a task into queue'''
if os.path.isdir(source):
if self.opt.recursive:
for f in [f for f in self.local_walk(source) if not os.path.isdir(f)]:
target_url = S3URL(target)
# deal with ./ or ../ here by normalizing the path.
joined_path = os.path.normpath(os.path.join(target_url.path, os.path.relpath(f, source)))
pool.upload(None, f, S3URL.combine('s3', target_url.bucket, joined_path))
else:
message('omitting directory "%s".' % source)
else:
pool.upload(None, source, target)
@log_calls
def put_files(self, source, target):
'''Upload files to S3.
    This function can handle multiple file uploads if source is a list.
    It also works in recursive mode, which copies all files and keeps the
    directory structure under the given source directory.
'''
pool = ThreadPool(ThreadUtil, self.opt)
if not isinstance(source, list):
source = [source]
if target[-1] == PATH_SEP:
for src in source:
self.put_single_file(pool, src, os.path.join(target, self.get_basename(src)))
else:
if len(source) == 1:
self.put_single_file(pool, source[0], target)
else:
raise Failure('Target "%s" is not a directory (with a trailing slash).' % target)
pool.join()
@log_calls
def update_privilege(self, source, target):
'''Get privileges from metadata of the source in s3, and apply them to target'''
s3url = S3URL(source)
bucket = self.s3.lookup(s3url.bucket, validate=self.opt.validate)
remoteKey = bucket.get_key(s3url.path)
if 'privilege' in remoteKey.metadata:
os.chmod(target, int(remoteKey.metadata['privilege'], 8))
@log_calls
def get_single_file(self, pool, source, target):
'''Download a single file or a directory by adding a task into queue'''
if source[-1] == PATH_SEP:
if self.opt.recursive:
basepath = S3URL(source).path
for f in [f for f in self.s3walk(source) if not f['is_dir']]:
pool.download(None, f['name'], os.path.join(target, os.path.relpath(S3URL(f['name']).path, basepath)))
else:
message('omitting directory "%s".' % source)
else:
pool.download(None, source, target)
@log_calls
def get_files(self, source, target):
'''Download files.
    This function can handle multiple files if the source S3 URL has wildcard
    characters. It also handles recursive mode by downloading all files and
    keeping the directory structure.
'''
pool = ThreadPool(ThreadUtil, self.opt)
source = self.source_expand(source)
if os.path.isdir(target):
for src in source:
self.get_single_file(pool, src, os.path.join(target, self.get_basename(S3URL(src).path)))
else:
if len(source) > 1:
raise Failure('Target "%s" is not a directory.' % target)
# Get file if it exists on s3 otherwise do nothing
elif len(source) == 1:
self.get_single_file(pool, source[0], target)
else:
        # Source expand may return an empty list only if ignore-empty-source is set to true
pass
pool.join()
@log_calls
def delete_removed_files(self, source, target):
'''Remove remote files that are not present in the local source.
'''
message("Deleting files found in %s and not in %s", source, target)
if os.path.isdir(source):
unecessary = []
basepath = S3URL(target).path
for f in [f for f in self.s3walk(target) if not f['is_dir']]:
local_name = os.path.join(source, os.path.relpath(S3URL(f['name']).path, basepath))
if not os.path.isfile(local_name):
message("%s not found locally, adding to delete queue", local_name)
unecessary.append(f['name'])
if len(unecessary) > 0:
pool = ThreadPool(ThreadUtil, self.opt)
for del_file in unecessary:
pool.delete(del_file)
pool.join()
else:
raise Failure('Source "%s" is not a directory.' % target)
@log_calls
def cp_single_file(self, pool, source, target, delete_source):
'''Copy a single file or a directory by adding a task into queue'''
if source[-1] == PATH_SEP:
if self.opt.recursive:
basepath = S3URL(source).path
for f in [f for f in self.s3walk(source) if not f['is_dir']]:
pool.copy(f['name'], os.path.join(target, os.path.relpath(S3URL(f['name']).path, basepath)), delete_source)
else:
message('omitting directory "%s".' % source)
else:
pool.copy(source, target, delete_source)
@log_calls
def cp_files(self, source, target, delete_source = False):
    '''Copy files.
    This function can handle multiple files if the source S3 URL has wildcard
    characters. It also handles recursive mode by copying all files and
    keeping the directory structure.
'''
pool = ThreadPool(ThreadUtil, self.opt)
source = self.source_expand(source)
if target[-1] == PATH_SEP:
for src in source:
self.cp_single_file(pool, src, os.path.join(target, self.get_basename(S3URL(src).path)), delete_source)
else:
if len(source) > 1:
raise Failure('Target "%s" is not a directory (with a trailing slash).' % target)
# Copy file if it exists otherwise do nothing
elif len(source) == 1:
self.cp_single_file(pool, source[0], target, delete_source)
else:
# Source expand may return empty list only if ignore-empty-source is set to true
pass
pool.join()
@log_calls
def del_files(self, source):
'''Delete files on S3'''
src_files = []
for key in self.s3walk(source):
if not key['is_dir']: # ignore directories
src_files.append(key['name'])
pool = ThreadPool(ThreadUtil, self.opt)
for src_file in src_files:
pool.delete(src_file)
pool.join()
@log_calls
def sync_files(self, source, target):
    '''Sync files between local storage and S3. Deletions are only implemented
    when syncing TO S3 (see --delete-removed). Otherwise this is identical to
    get/put -r -f --sync-check, with the exception of deletions.
'''
src_s3_url = S3URL.is_valid(source)
dst_s3_url = S3URL.is_valid(target)
if src_s3_url and not dst_s3_url:
self.get_files(source, target)
elif not src_s3_url and dst_s3_url:
self.put_files(source, target)
if self.opt.delete_removed:
self.delete_removed_files(source, target)
elif src_s3_url and dst_s3_url:
self.cp_files(source, target)
else:
raise InvalidArgument('No S3 URI provided')
@log_calls
def size(self, source):
'''Get the size component of the given s3url. If it is a
directory, combine the sizes of all the files under
    that directory. Subdirectories will not be counted unless the
    --recursive option is set.
'''
result = []
for src in self.source_expand(source):
size = 0
for f in self.s3walk(src):
size += f['size']
result.append((src, size))
return result
class ThreadUtil(S3Handler, ThreadPool.Worker):
'''Thread workers for S3 operations.
This class contains all thread workers for S3 operations.
  1) Expand source into a [source] list if it contains wildcard characters '*' or '?'.
     This is done by the shell for local paths, but we need to do it ourselves
     for S3 paths. Basically we treat [source] as the first-class source list.
2) Run the following algorithm:
if target is directory? (S3 path uses trailing slash to determine this)
for src in source:
copy src to target/src.basename
else
if source has only one element?
copy src to target
else
error "target should be a directory"!
3) Copy operations should work for both single file and directory:
def copy(src, target)
if src is a directory?
copy the whole directory recursively to target
else
copy the file src to target
'''
def __init__(self, pool):
'''Constructor'''
S3Handler.__init__(self, pool.opt)
ThreadPool.Worker.__init__(self, pool)
def reset(self):
'''Reset connection for retry.'''
if self.s3:
self.s3.close()
self.connect()
def auto_reset(func):
    '''Simple decorator for connection reset.
    This is necessary to get a clean connection if the previous S3 connection failed.
'''
def wrapper(self, *args, **kwargs):
self.reset()
return func(self, *args, **kwargs)
return wrapper
@log_calls
def mkdirs(self, target):
'''Ensure all directories are created for a given target file.'''
path = os.path.dirname(target)
if path and path != '/' and not os.path.isdir(path):
      # Multi-threading means there can be interleaved execution between the
      # check for the directory and its creation.
try:
os.makedirs(path)
except OSError as ose:
if ose.errno != errno.EEXIST:
raise Failure('Unable to create directory (%s)' % (path,))
@log_calls
def file_hash(self, filename, block_size = None):
'''Calculate MD5 hash code for a local file'''
if not block_size:
block_size = 2**20
m = hashlib.md5()
with open(filename, 'rb') as f:
while True:
data = f.read(block_size)
if not data:
break
m.update(data)
return m.hexdigest()
@log_calls
def sync_check(self, localFilename, remoteKey):
'''Check MD5 for a local file and a remote file.
Return True if they have the same md5 hash, otherwise False.
'''
if not remoteKey:
return False
if not os.path.exists(localFilename):
return False
localmd5 = self.file_hash(localFilename)
# check multiple md5 locations
return (remoteKey.etag and remoteKey.etag == '"%s"' % localmd5) or \
(remoteKey.md5 and remoteKey.md5 == localmd5) or \
('md5' in remoteKey.metadata and remoteKey.metadata['md5'] == localmd5)
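    # Note: the etag comparison above only matches single-part uploads, where
    # S3's ETag is the object's MD5 in quotes; multipart uploads have ETags of
    # the form '<hash>-<parts>' and will not match, so the md5 attribute and
    # the 'md5' metadata entry serve as fallbacks for those cases.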
@log_calls
def partial_match(self, path, filter_path):
    '''Partially match a path against a filter_path with wildcards.
    This function returns True if the path partially matches the filter path.
    It is used for walking through directories with multi-level wildcards.
'''
if not path or not filter_path:
return True
# trailing slash normalization
if path[-1] == PATH_SEP:
path = path[0:-1]
if filter_path[-1] == PATH_SEP:
filter_path += '*'
pi = path.split(PATH_SEP)
fi = filter_path.split(PATH_SEP)
    # Here, if we are in recursive mode, we allow pi to be longer than fi.
    # Otherwise, the length of pi should be less than or equal to the length of fi.
min_len = min(len(pi), len(fi))
matched = fnmatch.fnmatch(PATH_SEP.join(pi[0:min_len]), PATH_SEP.join(fi[0:min_len]))
return matched and (self.opt.recursive or len(pi) <= len(fi))
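    # Examples (non-recursive mode):
    #   partial_match('logs/2016', 'logs/*/app.log')          -> True
    #   partial_match('logs/2016/app.log', 'logs/*/app.log')  -> True
    #   partial_match('img/2016', 'logs/*/app.log')           -> False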
@log_calls
@auto_reset
def s3walk(self, s3url, s3dir, filter_path, result):
'''Thread worker for s3walk.
Recursively walk into all subdirectories if they still match the filter
path partially.
'''
bucket = self.s3.lookup(s3url.bucket, validate=self.opt.validate)
for key in bucket.list(s3dir, PATH_SEP):
if not self.partial_match(key.name, filter_path):
continue
# determine if it is a leaf node.
is_dir = (key.name[-1] == PATH_SEP)
if is_dir:
is_leaf = (key.name.count(PATH_SEP) == filter_path.count(PATH_SEP) + 1)
else:
is_leaf = (key.name.count(PATH_SEP) == filter_path.count(PATH_SEP))
if self.opt.recursive:
is_leaf = not is_dir
if is_leaf:
result.append({
'name': S3URL.combine(s3url.proto, s3url.bucket, key.name),
'is_dir': is_dir,
'size': key.size if not is_dir else 0,
'last_modified': key.last_modified if not is_dir else None
})
elif is_dir and key.name != s3dir: # bug?
self.pool.s3walk(s3url, key.name, filter_path, result)
class MultipartItem:
'''Utility class for multiple part upload/download.
This class is used to keep track of a single upload/download file, so
that we can initialize/finalize a file when needed.
'''
def __init__(self, id):
      '''Constructor.
      id is a unique identifier for a single download/upload file:
      - Download: the temporary file name.
      - Upload: the id of the multipart upload provided by S3.
'''
self.id = id
self.processed = 0
self.total = -1
@synchronized
def complete(self):
'''Increase the processed counter, and see if the file is completely
uploaded or downloaded.
'''
self.processed += 1
return self.processed == self.total
@log_calls
def get_file_splits(self, id, source, target, fsize, splitsize):
'''Get file splits for upload/download.'''
pos = 0
part = 1 # S3 part id starts from 1
mpi = ThreadUtil.MultipartItem(id)
splits = []
while pos < fsize:
chunk = min(splitsize, fsize - pos)
assert(chunk > 0)
splits.append((mpi, source, target, pos, chunk, part))
part += 1
pos += chunk
mpi.total = len(splits)
return splits
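    # For example, a 10 MB file with a 4 MB split size yields three parts of
    # 4 MB, 4 MB and 2 MB with S3 part ids 1..3, and mpi.total is set to 3.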
@log_calls
def get_file_privilege(self, source):
'''Get privileges of a local file'''
try:
return str(oct(os.stat(source).st_mode)[-3:])
except Exception as e:
      raise Failure('Could not get stat for %s, error_message = %s' % (source, e))
@log_calls
@auto_reset
def upload(self, mpi, source, target, pos = 0, chunk = 0, part = 0):
'''Thread worker for upload operation.'''
s3url = S3URL(target)
bucket = self.s3.lookup(s3url.bucket, validate=self.opt.validate)
# Initialization: Set up multithreaded uploads.
if not mpi:
fsize = os.path.getsize(source)
key = bucket.get_key(s3url.path)