Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]kadai3-2-nejiyoshida #37

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions kadai3-2/nejiyoshida/downloader/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package downloader

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"runtime"
"sync"
)

type dlunit struct {
buf *bytes.Buffer //DLしたデータ
offset int64 //どの部分のデータなのか
err error
}

func CheckHead(url string) (int64, string, error) {

// 分割できるか&ファイルサイズ確認するためHead要求
resp, err := http.Head(url)
if err != nil {
return 0, "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return 0, "", err
}

//responseのContent-Lengthからサイズを確認
size := resp.ContentLength

//分割ダウンロードできるかどうか。bytesなら可能
dltype := resp.Header.Get("Accept-Ranges")

return size, dltype, nil
}

//一つのダウンロード単位
func download(ctx context.Context, url string, from, to int64) <-chan dlunit {
ch := make(chan dlunit)

go func() {
defer close(ch)

req, err := http.NewRequest(http.MethodGet, url, nil)

if err != nil {
ch <- dlunit{buf: nil, offset: 0, err: err}
return
}

//fromからtoまでリクエストするよう指定。RangeHeaderが利用できないと0から最後まで
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", from, to))
req = req.WithContext(ctx)

cli := http.DefaultClient
resp, err := cli.Do(req)
if err != nil {
ch <- dlunit{buf: nil, offset: 0, err: err}
return
}

defer resp.Body.Close()

var buf bytes.Buffer

//ゴルーチンで戻すためにDLしたパーツをバッファにコピー
_, err = io.Copy(&buf, resp.Body)

if err != nil {
ch <- dlunit{buf: nil, offset: 0, err: err}
return
}

ch <- dlunit{buf: &buf, offset: from, err: nil}

}()

return ch
}

//分割DLできないときの普通のDL
func Download(ctx context.Context, fp *os.File, url string, size int64) error {

//分割しないので0から最後(ファイルサイズ分)まで
p := <-download(ctx, url, 0, size)
if p.err != nil {
return p.err
}

//DLしたファイルを書き込む
_, err := io.Copy(fp, p.buf)
if err != nil {
return err
}

_, err = fp.Seek(0, io.SeekStart)
if err != nil {
return err
}

return nil
}

//並行DL
func ParallelDownload(ctx context.Context, fp *os.File, url string, size int64) error {

numcpu := runtime.NumCPU()
partsize := size / int64(numcpu)

//並行してダウンロードするためのスライス
dlunits := make([]<-chan dlunit, numcpu)

for i := 0; i < numcpu; i++ {
var from, to int64

if i == 0 {
from = 0
} else {
from = partsize*int64(i) + 1
}

if i == numcpu-1 {
to = size
} else {
to = from + partsize
}

dlunits[i] = download(ctx, url, from, to)
}

for p := range merge(dlunits...) {
if p.err != nil {
return p.err
}
//offsetの地点からそれぞれ書き込むことで分割ダウンロードしたものを組み合わせる
fp.WriteAt(p.buf.Bytes(), p.offset)
}

_, err := fp.Seek(0, io.SeekStart)
if err != nil {
return err
}

return nil
}

func merge(chs ...<-chan dlunit) <-chan dlunit {
var wg sync.WaitGroup
merged := make(chan dlunit)

wg.Add(len(chs))

for _, ch := range chs {
go func(ch <-chan dlunit) {
defer wg.Done()

p := <-ch
merged <- p

}(ch)
}

go func() {
wg.Wait()
close(merged)
}()

return merged
}
89 changes: 89 additions & 0 deletions kadai3-2/nejiyoshida/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package main

import (
"context"
"fmt"
"os"
"os/signal"
"strings"
"syscall"

"./downloader"
)

func main() {
if len(os.Args) < 2 {
fmt.Fprintln(os.Stderr, "not enough args")
os.Exit(1)
}

url := os.Args[1]

tmp := strings.Split(url, "/")

filename := tmp[len(tmp)-1]
fp, err := os.Create(filename)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(2)
}
defer fp.Close()

size, dltype, err := downloader.CheckHead(url)

if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(2)
}

//cancel時に他のゴルーチンもとじる
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

res := start(ctx, fp, url, dltype, size)

sig := make(chan os.Signal)
//ctrl+cを受けるようにする
signal.Notify(sig, syscall.SIGINT)

loop:
for {
select {
case err = <-res:
// cancel()が実行されるかエラーが戻ってくるとループを抜ける
break loop

case <-sig:
//ctrl+cなどで中断
fmt.Fprintln(os.Stderr, "ctrl+c received")
cancel()
}
}

if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(2)
}

}

func start(ctx context.Context, fp *os.File, url, dltype string, size int64) <-chan error {
ch := make(chan error)

go func() {
defer close(ch)

var err error

switch dltype {
case "bytes":
err = downloader.ParallelDownload(ctx, fp, url, size)
default:
err = downloader.Download(ctx, fp, url, size)
}

ch <- err
}()

return ch
}