Skip to content

Commit 0c23093

Browse files
committed
add concurrent progress bar
1 parent a69112b commit 0c23093

File tree

4 files changed

+256
-6
lines changed

4 files changed

+256
-6
lines changed

oss/downloader.go

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@ type DownloaderOptions struct {
2929
ClientOptions []func(*Options)
3030
}
3131

32+
type downloaderProgress struct {
33+
pr ProgressFunc
34+
written int64
35+
total int64
36+
mu sync.Mutex
37+
}
38+
39+
func (cpt *downloaderProgress) Write(b []byte) (n int, err error) {
40+
n = len(b)
41+
increment := int64(n)
42+
cpt.mu.Lock()
43+
defer cpt.mu.Unlock()
44+
cpt.written += increment
45+
cpt.pr(increment, cpt.written, cpt.total)
46+
return
47+
}
48+
3249
type Downloader struct {
3350
options DownloaderOptions
3451
client DownloadAPIClient
@@ -378,7 +395,7 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) {
378395
}
379396

380397
// writeChunkFn runs in worker goroutines to pull chunks off of the ch channel
381-
writeChunkFn := func(ch chan downloaderChunk) {
398+
writeChunkFn := func(ch chan downloaderChunk, progress *downloaderProgress) {
382399
defer wg.Done()
383400
var hash hash.Hash64
384401
if d.calcCRC {
@@ -395,7 +412,7 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) {
395412
continue
396413
}
397414

398-
dchunk, derr := d.downloadChunk(chunk, hash)
415+
dchunk, derr := d.downloadChunk(chunk, hash, progress)
399416

400417
if derr != nil && derr != io.EOF {
401418
saveErrFn(derr)
@@ -455,9 +472,16 @@ func (d *downloaderDelegate) download() (*DownloadResult, error) {
455472

456473
// Start the download workers
457474
ch := make(chan downloaderChunk, d.options.ParallelNum)
475+
var progress *downloaderProgress
476+
if d.request.ProgressFn != nil {
477+
progress = &downloaderProgress{
478+
pr: d.request.ProgressFn,
479+
total: d.sizeInBytes,
480+
}
481+
}
458482
for i := 0; i < d.options.ParallelNum; i++ {
459483
wg.Add(1)
460-
go writeChunkFn(ch)
484+
go writeChunkFn(ch, progress)
461485
}
462486

463487
// Start tracker worker if need track downloaded chunk
@@ -511,7 +535,7 @@ func (d *downloaderDelegate) incrWritten(n int64) {
511535
d.written += n
512536
}
513537

514-
func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash64) (downloadedChunk, error) {
538+
func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash64, progress *downloaderProgress) (downloadedChunk, error) {
515539
// Get the next byte range of data
516540
var request GetObjectRequest
517541
copyRequest(&request, d.request)
@@ -546,6 +570,15 @@ func (d *downloaderDelegate) downloadChunk(chunk downloaderChunk, hash hash.Hash
546570
r io.Reader = reader
547571
crc64 uint64 = 0
548572
)
573+
writer := io.MultiWriter()
574+
if progress != nil {
575+
writer = io.MultiWriter(writer, progress)
576+
}
577+
if hash != nil {
578+
hash.Reset()
579+
writer = io.MultiWriter(writer, hash)
580+
}
581+
r = io.TeeReader(reader, writer)
549582
if hash != nil {
550583
hash.Reset()
551584
r = io.TeeReader(reader, hash)

oss/downloader_mock_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,3 +1193,90 @@ func TestMockDownloaderCRCCheck(t *testing.T) {
11931193
io.Copy(hash, rfile)
11941194
assert.Equal(t, datasum, hash.Sum64())
11951195
}
1196+
1197+
func TestMockDownloaderWithProgress(t *testing.T) {
1198+
length := 3*1024*1024 + 1234
1199+
data := []byte(randStr(length))
1200+
gmtTime := getNowGMT()
1201+
tracker := &downloaderMockTracker{
1202+
lastModified: gmtTime,
1203+
data: data,
1204+
}
1205+
server := testSetupDownloaderMockServer(t, tracker)
1206+
defer server.Close()
1207+
assert.NotNil(t, server)
1208+
cfg := LoadDefaultConfig().
1209+
WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
1210+
WithRegion("cn-hangzhou").
1211+
WithEndpoint(server.URL).
1212+
WithReadWriteTimeout(300 * time.Second)
1213+
client := NewClient(cfg)
1214+
var n int64
1215+
d := client.NewDownloader(func(do *DownloaderOptions) {
1216+
do.ParallelNum = 1
1217+
do.PartSize = 1 * 1024 * 1024
1218+
})
1219+
assert.NotNil(t, d)
1220+
assert.NotNil(t, d.client)
1221+
assert.Equal(t, int64(1*1024*1024), d.options.PartSize)
1222+
assert.Equal(t, 1, d.options.ParallelNum)
1223+
// filePath is invalid
1224+
_, err := d.DownloadFile(
1225+
context.TODO(),
1226+
&GetObjectRequest{
1227+
Bucket: Ptr("bucket"),
1228+
Key: Ptr("key"),
1229+
ProgressFn: func(increment, transferred, total int64) {
1230+
n = transferred
1231+
},
1232+
}, "")
1233+
assert.NotNil(t, err)
1234+
assert.Contains(t, err.Error(), "invalid field, filePath")
1235+
localFile := randStr(8) + "-no-surfix"
1236+
defer func() {
1237+
os.Remove(localFile)
1238+
}()
1239+
_, err = d.DownloadFile(
1240+
context.TODO(),
1241+
&GetObjectRequest{
1242+
Bucket: Ptr("bucket"),
1243+
Key: Ptr("key"),
1244+
ProgressFn: func(increment, transferred, total int64) {
1245+
n = transferred
1246+
},
1247+
}, localFile)
1248+
assert.Nil(t, err)
1249+
assert.Equal(t, n, int64(length))
1250+
n = int64(0)
1251+
d = client.NewDownloader(func(do *DownloaderOptions) {
1252+
do.ParallelNum = 3
1253+
do.PartSize = 3 * 1024 * 1024
1254+
})
1255+
assert.NotNil(t, d)
1256+
assert.NotNil(t, d.client)
1257+
assert.Equal(t, int64(3*1024*1024), d.options.PartSize)
1258+
assert.Equal(t, 3, d.options.ParallelNum)
1259+
// filePath is invalid
1260+
_, err = d.DownloadFile(
1261+
context.TODO(),
1262+
&GetObjectRequest{
1263+
Bucket: Ptr("bucket"),
1264+
Key: Ptr("key"),
1265+
ProgressFn: func(increment, transferred, total int64) {
1266+
n = transferred
1267+
},
1268+
}, "")
1269+
assert.NotNil(t, err)
1270+
assert.Contains(t, err.Error(), "invalid field, filePath")
1271+
_, err = d.DownloadFile(
1272+
context.TODO(),
1273+
&GetObjectRequest{
1274+
Bucket: Ptr("bucket"),
1275+
Key: Ptr("key"),
1276+
ProgressFn: func(increment, transferred, total int64) {
1277+
n = transferred
1278+
},
1279+
}, localFile)
1280+
assert.Nil(t, err)
1281+
assert.Equal(t, n, int64(length))
1282+
}

oss/uploader.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,21 @@ type UploaderOptions struct {
2626
ClientOptions []func(*Options)
2727
}
2828

29+
type uploaderProgress struct {
30+
pr ProgressFunc
31+
written int64
32+
total int64
33+
mu sync.Mutex
34+
}
35+
36+
func (cpt *uploaderProgress) incrWritten(n int64) {
37+
increment := n
38+
cpt.mu.Lock()
39+
defer cpt.mu.Unlock()
40+
cpt.written += increment
41+
cpt.pr(increment, cpt.written, cpt.total)
42+
}
43+
2944
type Uploader struct {
3045
options UploaderOptions
3146
client UploadAPIClient
@@ -539,7 +554,7 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) {
539554
}
540555

541556
// readChunk runs in worker goroutines to pull chunks off of the ch channel
542-
readChunkFn := func(ch chan uploaderChunk) {
557+
readChunkFn := func(ch chan uploaderChunk, progress *uploaderProgress) {
543558
defer wg.Done()
544559
for {
545560
data, ok := <-ch
@@ -562,6 +577,9 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) {
562577
//fmt.Printf("UploadPart result: %#v, %#v\n", upResult, err)
563578

564579
if err == nil {
580+
if progress != nil {
581+
progress.incrWritten(int64(data.size))
582+
}
565583
mu.Lock()
566584
parts = append(parts, UploadPart{ETag: upResult.ETag, PartNumber: data.partNum})
567585
if enableCRC {
@@ -578,9 +596,16 @@ func (u *uploaderDelegate) multiPart() (*UploadResult, error) {
578596
}
579597

580598
ch := make(chan uploaderChunk, u.options.ParallelNum)
599+
var progress *uploaderProgress
600+
if u.request.ProgressFn != nil {
601+
progress = &uploaderProgress{
602+
pr: u.request.ProgressFn,
603+
total: u.totalSize,
604+
}
605+
}
581606
for i := 0; i < u.options.ParallelNum; i++ {
582607
wg.Add(1)
583-
go readChunkFn(ch)
608+
go readChunkFn(ch, progress)
584609
}
585610

586611
// Read and queue the parts

oss/uploader_mock_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,3 +1799,108 @@ func TestMockUploadCRC64Fail(t *testing.T) {
17991799
allCrc64ecma := fmt.Sprint(hashall.Sum64())
18001800
assert.Equal(t, dataCrc64ecma, allCrc64ecma)
18011801
}
1802+
1803+
func TestMockUploadSinglePartFromFileWithProgress(t *testing.T) {
1804+
partSize := DefaultUploadPartSize
1805+
length := 5*100*1024 + 123
1806+
partsNum := length/int(partSize) + 1
1807+
tracker := &uploaderMockTracker{
1808+
partNum: partsNum,
1809+
saveDate: make([][]byte, partsNum),
1810+
checkTime: make([]time.Time, partsNum),
1811+
timeout: make([]time.Duration, partsNum),
1812+
uploadPartErr: make([]bool, partsNum),
1813+
}
1814+
data := []byte(randStr(length))
1815+
hash := NewCRC64(0)
1816+
hash.Write(data)
1817+
dataCrc64ecma := fmt.Sprint(hash.Sum64())
1818+
localFile := randStr(8) + ".txt"
1819+
createFileFromByte(t, localFile, data)
1820+
defer func() {
1821+
os.Remove(localFile)
1822+
}()
1823+
server := testSetupUploaderMockServer(t, tracker)
1824+
defer server.Close()
1825+
assert.NotNil(t, server)
1826+
cfg := LoadDefaultConfig().
1827+
WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
1828+
WithRegion("cn-hangzhou").
1829+
WithEndpoint(server.URL).
1830+
WithReadWriteTimeout(300 * time.Second)
1831+
client := NewClient(cfg)
1832+
u := NewUploader(client)
1833+
assert.NotNil(t, u.client)
1834+
assert.Equal(t, DefaultUploadParallel, u.options.ParallelNum)
1835+
assert.Equal(t, DefaultUploadPartSize, u.options.PartSize)
1836+
n := int64(0)
1837+
result, err := u.UploadFile(context.TODO(), &PutObjectRequest{
1838+
Bucket: Ptr("bucket"),
1839+
Key: Ptr("key"),
1840+
ProgressFn: func(increment, transferred, total int64) {
1841+
n = transferred
1842+
fmt.Printf("increment:%#v, transferred:%#v, total:%#v\n", increment, transferred, total)
1843+
},
1844+
}, localFile)
1845+
assert.Nil(t, err)
1846+
assert.NotNil(t, result)
1847+
assert.Nil(t, result.UploadId)
1848+
assert.Equal(t, dataCrc64ecma, *result.HashCRC64)
1849+
assert.Equal(t, n, int64(length))
1850+
}
1851+
func TestMockUploadParallelFromFileWithProgress(t *testing.T) {
1852+
partSize := int64(100 * 1024)
1853+
length := 5*100*1024 + 123
1854+
partsNum := length/int(partSize) + 1
1855+
tracker := &uploaderMockTracker{
1856+
partNum: partsNum,
1857+
saveDate: make([][]byte, partsNum),
1858+
checkTime: make([]time.Time, partsNum),
1859+
timeout: make([]time.Duration, partsNum),
1860+
uploadPartErr: make([]bool, partsNum),
1861+
}
1862+
data := []byte(randStr(length))
1863+
hash := NewCRC64(0)
1864+
hash.Write(data)
1865+
dataCrc64ecma := fmt.Sprint(hash.Sum64())
1866+
localFile := randStr(8) + "-no-surfix"
1867+
createFileFromByte(t, localFile, data)
1868+
defer func() {
1869+
os.Remove(localFile)
1870+
}()
1871+
server := testSetupUploaderMockServer(t, tracker)
1872+
defer server.Close()
1873+
assert.NotNil(t, server)
1874+
cfg := LoadDefaultConfig().
1875+
WithCredentialsProvider(credentials.NewAnonymousCredentialsProvider()).
1876+
WithRegion("cn-hangzhou").
1877+
WithEndpoint(server.URL).
1878+
WithReadWriteTimeout(300 * time.Second)
1879+
client := NewClient(cfg)
1880+
u := NewUploader(client,
1881+
func(uo *UploaderOptions) {
1882+
uo.ParallelNum = 4
1883+
uo.PartSize = partSize
1884+
},
1885+
)
1886+
assert.Equal(t, 4, u.options.ParallelNum)
1887+
assert.Equal(t, partSize, u.options.PartSize)
1888+
tracker.timeout[0] = 1 * time.Second
1889+
tracker.timeout[2] = 500 * time.Millisecond
1890+
n := int64(0)
1891+
result, err := u.UploadFile(context.TODO(), &PutObjectRequest{
1892+
Bucket: Ptr("bucket"),
1893+
Key: Ptr("key"),
1894+
ProgressFn: func(increment, transferred, total int64) {
1895+
n = transferred
1896+
fmt.Printf("increment:%#v, transferred:%#v, total:%#v\n", increment, transferred, total)
1897+
},
1898+
}, localFile)
1899+
assert.Nil(t, err)
1900+
assert.NotNil(t, result)
1901+
assert.Nil(t, err)
1902+
assert.NotNil(t, result)
1903+
assert.Equal(t, "uploadId-1234", *result.UploadId)
1904+
assert.Equal(t, dataCrc64ecma, *result.HashCRC64)
1905+
assert.Equal(t, n, int64(length))
1906+
}

0 commit comments

Comments
 (0)