From 868cc455fefa29f5816d4d6c3f20789f8a365899 Mon Sep 17 00:00:00 2001 From: Martin Date: Thu, 7 Mar 2024 17:21:07 +0100 Subject: [PATCH 1/3] pass io.ReadCloser instances instead of files on harddrive, fixes #536 --- cmd/server/download.go | 49 ++++-------- cmd/server/search_crypto_test.go | 8 +- cmd/server/search_handlers_bench_test.go | 8 +- cmd/server/search_us_csl.go | 3 + pkg/csl/csl_test.go | 3 +- pkg/csl/download.go | 11 +-- pkg/csl/download_eu.go | 10 +-- pkg/csl/download_eu_test.go | 34 +++++---- pkg/csl/download_test.go | 32 ++++---- pkg/csl/download_uk.go | 18 ++--- pkg/csl/download_uk_test.go | 64 +++++++++------- pkg/csl/reader.go | 9 +-- pkg/csl/reader_eu.go | 20 +---- pkg/csl/reader_eu_test.go | 7 +- pkg/csl/reader_test.go | 28 +++++-- pkg/csl/reader_uk.go | 24 +++--- pkg/csl/reader_uk_test.go | 14 +++- pkg/download/client.go | 66 +++++++---------- pkg/dpl/download.go | 15 +--- pkg/dpl/download_test.go | 34 +++++---- pkg/dpl/reader.go | 9 +-- pkg/dpl/reader_test.go | 13 +++- pkg/ofac/download.go | 3 +- pkg/ofac/download_test.go | 11 ++- pkg/ofac/reader.go | 94 +++++++++++------------- pkg/ofac/reader_test.go | 24 ++++-- 26 files changed, 304 insertions(+), 307 deletions(-) diff --git a/cmd/server/download.go b/cmd/server/download.go index 2c9493f5..fed6b803 100644 --- a/cmd/server/download.go +++ b/cmd/server/download.go @@ -163,37 +163,14 @@ func ofacRecords(logger log.Logger, initialDir string) (*ofac.Results, error) { if len(files) == 0 { return nil, errors.New("no OFAC Results") } - - var res *ofac.Results - for i := range files { - // Does order matter? - - if i == 0 { - rr, err := ofac.Read(files[i]) - if err != nil { - return nil, fmt.Errorf("reading %s: %v", files[i], err) - } - if rr != nil { - res = rr - } - } else { - rr, err := ofac.Read(files[i]) - if err != nil { - return nil, fmt.Errorf("read and replace: %v", err) - } - if rr != nil { - res.Addresses = append(res.Addresses, rr.Addresses...) - res.AlternateIdentities = append(res.AlternateIdentities, rr.AlternateIdentities...) - res.SDNs = append(res.SDNs, rr.SDNs...) - res.SDNComments = append(res.SDNComments, rr.SDNComments...) - } - } + res, err := ofac.Read(files) + if err != nil { + return nil, err } // Merge comments into SDNs res.SDNs = mergeSpilloverRecords(res.SDNs, res.SDNComments) - - return res, err + return res, nil } func mergeSpilloverRecords(sdns []*ofac.SDN, comments []*ofac.SDNComments) []*ofac.SDN { @@ -212,7 +189,8 @@ func dplRecords(logger log.Logger, initialDir string) ([]*dpl.DPL, error) { if err != nil { return nil, err } - return dpl.Read(file) + + return dpl.Read(file["dpl.txt"]) } func cslRecords(logger log.Logger, initialDir string) (*csl.CSL, error) { @@ -221,11 +199,11 @@ func cslRecords(logger log.Logger, initialDir string) (*csl.CSL, error) { logger.Warn().Logf("skipping CSL download: %v", err) return &csl.CSL{}, nil } - cslRecords, err := csl.ReadFile(file) + cslRecords, err := csl.ReadFile(file["csl.csv"]) if err != nil { return nil, err } - return cslRecords, err + return cslRecords, nil } func euCSLRecords(logger log.Logger, initialDir string) ([]*csl.EUCSLRecord, error) { @@ -235,11 +213,13 @@ func euCSLRecords(logger log.Logger, initialDir string) ([]*csl.EUCSLRecord, err // no error to return because we skip the download return nil, nil } - cslRecords, _, err := csl.ReadEUFile(file) + + cslRecords, _, err := csl.ParseEU(file["eu_csl.csv"]) if err != nil { return nil, err } return cslRecords, err + } func ukCSLRecords(logger log.Logger, initialDir string) ([]*csl.UKCSLRecord, error) { @@ -249,7 +229,7 @@ func ukCSLRecords(logger log.Logger, initialDir string) ([]*csl.UKCSLRecord, err // no error to return because we skip the download return nil, nil } - cslRecords, _, err := csl.ReadUKCSLFile(file) + cslRecords, _, err := csl.ReadUKCSLFile(file["ConList.csv"]) if err != nil { return nil, err } @@ -264,7 +244,7 @@ func ukSanctionsListRecords(logger log.Logger, initialDir string) ([]*csl.UKSanc return nil, nil } - records, _, err := csl.ReadUKSanctionsListFile(file) + records, _, err := csl.ReadUKSanctionsListFile(file["UK_Sanctions_List.ods"]) if err != nil { return nil, err } @@ -342,6 +322,9 @@ func (s *searcher) refreshData(initialDir string) (*DownloadStats, error) { lastDataRefreshFailure.WithLabelValues("CSL").Set(float64(time.Now().Unix())) stats.Errors = append(stats.Errors, fmt.Errorf("CSL: %v", err)) } + if consolidatedLists == nil { + consolidatedLists = new(csl.CSL) + } els := precomputeCSLEntities[csl.EL](consolidatedLists.ELs, s.pipe) meus := precomputeCSLEntities[csl.MEU](consolidatedLists.MEUs, s.pipe) ssis := precomputeCSLEntities[csl.SSI](consolidatedLists.SSIs, s.pipe) diff --git a/cmd/server/search_crypto_test.go b/cmd/server/search_crypto_test.go index 8c8309f9..962cab96 100644 --- a/cmd/server/search_crypto_test.go +++ b/cmd/server/search_crypto_test.go @@ -7,8 +7,10 @@ package main import ( "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" + "os" "path/filepath" "testing" @@ -25,7 +27,11 @@ var ( func init() { // Set SDN Comments - ofacResults, err := ofac.Read(filepath.Join("..", "..", "test", "testdata", "sdn_comments.csv")) + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "sdn_comments.csv")) + if err != nil { + panic(fmt.Sprintf("%v", err)) + } + ofacResults, err := ofac.Read(map[string]io.ReadCloser{"sdn_comments.csv": fd}) if err != nil { panic(fmt.Sprintf("ERROR reading sdn_comments.csv: %v", err)) } diff --git a/cmd/server/search_handlers_bench_test.go b/cmd/server/search_handlers_bench_test.go index 0966a072..0170296f 100644 --- a/cmd/server/search_handlers_bench_test.go +++ b/cmd/server/search_handlers_bench_test.go @@ -7,10 +7,12 @@ package main import ( "crypto/rand" "fmt" + "io" "math/big" "net/http" "net/http/httptest" "net/url" + "os" "path/filepath" "strings" "testing" @@ -56,7 +58,11 @@ func BenchmarkSearchHandler(b *testing.B) { } func BenchmarkJaroWinkler(b *testing.B) { - results, err := ofac.Read(filepath.Join("..", "..", "test", "testdata", "sdn.csv")) + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "sdn.csv")) + if err != nil { + b.Error(err) + } + results, err := ofac.Read(map[string]io.ReadCloser{"sdn.csv": fd}) require.NoError(b, err) require.Len(b, results.SDNs, 7379) diff --git a/cmd/server/search_us_csl.go b/cmd/server/search_us_csl.go index 95250cee..423e4205 100644 --- a/cmd/server/search_us_csl.go +++ b/cmd/server/search_us_csl.go @@ -39,6 +39,9 @@ func searchUSCSL(logger log.Logger, searcher *searcher) http.HandlerFunc { func precomputeCSLEntities[T any](items []*T, pipe *pipeliner) []*Result[T] { out := make([]*Result[T], len(items)) + if items == nil { + return out + } for i, item := range items { name := cslName(item) diff --git a/pkg/csl/csl_test.go b/pkg/csl/csl_test.go index fb9bcc45..efeb021c 100644 --- a/pkg/csl/csl_test.go +++ b/pkg/csl/csl_test.go @@ -14,7 +14,6 @@ import ( ) func TestCSL(t *testing.T) { - t.Skip("CSL is currently broken, looks like they require API access now") if testing.Short() { t.Skip("ignorning network test") @@ -27,7 +26,7 @@ func TestCSL(t *testing.T) { file, err := Download(logger, dir) require.NoError(t, err) - cslRecords, err := ReadFile(file) + cslRecords, err := ReadFile(file["csl.csv"]) require.NoError(t, err) if len(cslRecords.SSIs) == 0 { diff --git a/pkg/csl/download.go b/pkg/csl/download.go index f1da515f..c168e4c7 100644 --- a/pkg/csl/download.go +++ b/pkg/csl/download.go @@ -6,6 +6,7 @@ package csl import ( "fmt" + "io" "net/url" "os" @@ -19,22 +20,18 @@ var ( usDownloadURL = strx.Or(os.Getenv("CSL_DOWNLOAD_TEMPLATE"), os.Getenv("US_CSL_DOWNLOAD_URL"), publicUSDownloadURL) ) -func Download(logger log.Logger, initialDir string) (string, error) { +func Download(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) { dl := download.New(logger, download.HTTPClient) cslURL, err := buildDownloadURL(usDownloadURL) if err != nil { - return "", err + return nil, err } cslNameAndSource := make(map[string]string) cslNameAndSource["csl.csv"] = cslURL - file, err := dl.GetFiles(initialDir, cslNameAndSource) - if len(file) == 0 || err != nil { - return "", fmt.Errorf("csl download: %v", err) - } - return file[0], nil + return dl.GetFiles(initialDir, cslNameAndSource) } func buildDownloadURL(urlStr string) (string, error) { diff --git a/pkg/csl/download_eu.go b/pkg/csl/download_eu.go index 463b4c85..bf8d5ac2 100644 --- a/pkg/csl/download_eu.go +++ b/pkg/csl/download_eu.go @@ -6,6 +6,7 @@ package csl import ( "fmt" + "io" "os" "github.com/moov-io/base/log" @@ -22,15 +23,12 @@ var ( euDownloadURL = strx.Or(os.Getenv("EU_CSL_DOWNLOAD_URL"), publicEUDownloadURL) ) -func DownloadEU(logger log.Logger, initialDir string) (string, error) { +func DownloadEU(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) { dl := download.New(logger, download.HTTPClient) euCSLNameAndSource := make(map[string]string) euCSLNameAndSource["eu_csl.csv"] = euDownloadURL - file, err := dl.GetFiles(initialDir, euCSLNameAndSource) - if len(file) == 0 || err != nil { - return "", fmt.Errorf("eu csl download: %v", err) - } - return file[0], nil + return dl.GetFiles(initialDir, euCSLNameAndSource) + } diff --git a/pkg/csl/download_eu_test.go b/pkg/csl/download_eu_test.go index 0984009f..b99cc80a 100644 --- a/pkg/csl/download_eu_test.go +++ b/pkg/csl/download_eu_test.go @@ -6,6 +6,7 @@ package csl import ( "fmt" + "io" "os" "path/filepath" "strings" @@ -24,18 +25,19 @@ func TestEUDownload(t *testing.T) { t.Fatal(err) } fmt.Println("file in test: ", file) - if file == "" { + if len(file) == 0 { t.Fatal("no EU CSL file") } - defer os.RemoveAll(filepath.Dir(file)) - if !strings.EqualFold("eu_csl.csv", filepath.Base(file)) { - t.Errorf("unknown file %s", file) + for fn := range file { + if !strings.EqualFold("eu_csl.csv", filepath.Base(fn)) { + t.Errorf("unknown file %s", file) + } } } func TestEUDownload_initialDir(t *testing.T) { - dir, err := os.MkdirTemp("", "iniital-dir") + dir, err := os.MkdirTemp("", "initial-dir") if err != nil { t.Fatal(err) } @@ -55,19 +57,21 @@ func TestEUDownload_initialDir(t *testing.T) { if err != nil { t.Fatal(err) } - if file == "" { + if len(file) == 0 { t.Fatal("no EU CSL file") } - if strings.EqualFold("eu_csl.csv", filepath.Base(file)) { - bs, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - if v := string(bs); v != "file=eu_csl.csv" { - t.Errorf("eu_csl.csv: %v", v) + for fn, fd := range file { + if strings.EqualFold("eu_csl.csv", filepath.Base(fn)) { + bs, err := io.ReadAll(fd) + if err != nil { + t.Fatal(err) + } + if v := string(bs); v != "file=eu_csl.csv" { + t.Errorf("eu_csl.csv: %v", v) + } + } else { + t.Fatalf("unknown file: %v", file) } - } else { - t.Fatalf("unknown file: %v", file) } } diff --git a/pkg/csl/download_test.go b/pkg/csl/download_test.go index e7971691..aed855a9 100644 --- a/pkg/csl/download_test.go +++ b/pkg/csl/download_test.go @@ -5,6 +5,7 @@ package csl import ( + "io" "os" "path/filepath" "strings" @@ -22,13 +23,14 @@ func TestDownload(t *testing.T) { if err != nil { t.Fatal(err) } - if file == "" { + if len(file) == 0 { t.Fatal("no CSL file") } - defer os.RemoveAll(filepath.Dir(file)) - if !strings.EqualFold("csl.csv", filepath.Base(file)) { - t.Errorf("unknown file %s", file) + for fn := range file { + if !strings.EqualFold("csl.csv", filepath.Base(fn)) { + t.Errorf("unknown file %s", fn) + } } } @@ -55,20 +57,22 @@ func TestDownload_initialDir(t *testing.T) { if err != nil { t.Fatal(err) } - if file == "" { + if len(file) == 0 { t.Fatal("no CSL file") } - if strings.EqualFold("csl.csv", filepath.Base(file)) { - bs, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - if v := string(bs); v != "file=csl.csv" { - t.Errorf("csl.csv: %v", v) + for fn, fd := range file { + if strings.EqualFold("csl.csv", filepath.Base(fn)) { + bs, err := io.ReadAll(fd) + if err != nil { + t.Fatal(err) + } + if v := string(bs); v != "file=csl.csv" { + t.Errorf("csl.csv: %v", v) + } + } else { + t.Fatalf("unknown file: %v", file) } - } else { - t.Fatalf("unknown file: %v", file) } } diff --git a/pkg/csl/download_uk.go b/pkg/csl/download_uk.go index cd0d40c8..cfccf1dd 100644 --- a/pkg/csl/download_uk.go +++ b/pkg/csl/download_uk.go @@ -5,7 +5,7 @@ package csl import ( - "fmt" + "io" "os" "github.com/moov-io/base/log" @@ -23,28 +23,20 @@ var ( ukSanctionsListURL = strx.Or(os.Getenv("UK_SANCTIONS_LIST_URL"), publicUKSanctionsListURL) ) -func DownloadUKCSL(logger log.Logger, initialDir string) (string, error) { +func DownloadUKCSL(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) { dl := download.New(logger, download.HTTPClient) ukCSLNameAndSource := make(map[string]string) ukCSLNameAndSource["ConList.csv"] = ukCSLDownloadURL - file, err := dl.GetFiles(initialDir, ukCSLNameAndSource) - if len(file) == 0 || err != nil { - return "", fmt.Errorf("uk csl download: %v", err) - } - return file[0], nil + return dl.GetFiles(initialDir, ukCSLNameAndSource) } -func DownloadUKSanctionsList(logger log.Logger, initialDir string) (string, error) { +func DownloadUKSanctionsList(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) { dl := download.New(logger, download.HTTPClient) ukSanctionsNameAndSource := make(map[string]string) ukSanctionsNameAndSource["UK_Sanctions_List.ods"] = ukSanctionsListURL - file, err := dl.GetFiles(initialDir, ukSanctionsNameAndSource) - if len(file) == 0 || err != nil { - return "", fmt.Errorf("uk download: %v", err) - } - return file[0], nil + return dl.GetFiles(initialDir, ukSanctionsNameAndSource) } diff --git a/pkg/csl/download_uk_test.go b/pkg/csl/download_uk_test.go index cd0952c5..fd4432c4 100644 --- a/pkg/csl/download_uk_test.go +++ b/pkg/csl/download_uk_test.go @@ -6,6 +6,7 @@ package csl import ( "fmt" + "io" "os" "path/filepath" "strings" @@ -24,13 +25,14 @@ func TestUKCSLDownload(t *testing.T) { t.Fatal(err) } fmt.Println("file in test: ", file) - if file == "" { + if len(file) == 0 { t.Fatal("no UK CSL file") } - defer os.RemoveAll(filepath.Dir(file)) - if !strings.EqualFold("ConList.csv", filepath.Base(file)) { - t.Errorf("unknown file %s", file) + for fn := range file { + if !strings.EqualFold("ConList.csv", filepath.Base(fn)) { + t.Errorf("unknown file %s", file) + } } } @@ -55,20 +57,22 @@ func TestUKCSLDownload_initialDir(t *testing.T) { if err != nil { t.Fatal(err) } - if file == "" { + if len(file) == 0 { t.Fatal("no UK CSL file") } - if strings.EqualFold("ConList.csv", filepath.Base(file)) { - bs, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - if v := string(bs); v != "file=ConList.csv" { - t.Errorf("ConList.csv: %v", v) + for fn, fd := range file { + if strings.EqualFold("ConList.csv", filepath.Base(fn)) { + bs, err := io.ReadAll(fd) + if err != nil { + t.Fatal(err) + } + if v := string(bs); v != "file=ConList.csv" { + t.Errorf("ConList.csv: %v", v) + } + } else { + t.Fatalf("unknown file: %v", file) } - } else { - t.Fatalf("unknown file: %v", file) } } @@ -82,13 +86,14 @@ func TestUKSanctionsListDownload(t *testing.T) { t.Fatal(err) } fmt.Println("file in test: ", file) - if file == "" { + if len(file) == 0 { t.Fatal("no UK Sanctions List file") } - defer os.RemoveAll(filepath.Dir(file)) - if !strings.EqualFold("UK_Sanctions_List.ods", filepath.Base(file)) { - t.Errorf("unknown file %s", file) + for fn := range file { + if !strings.EqualFold("UK_Sanctions_List.ods", filepath.Base(fn)) { + t.Errorf("unknown file %s", file) + } } } @@ -113,19 +118,22 @@ func TestUKSanctionsListDownload_initialDir(t *testing.T) { if err != nil { t.Fatal(err) } - if file == "" { + + if len(file) == 0 { t.Fatal("no UK Sanctions List file") } - if strings.EqualFold("UK_Sanctions_List.ods", filepath.Base(file)) { - _, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) + for fn, fd := range file { + if strings.EqualFold("UK_Sanctions_List.ods", filepath.Base(fn)) { + _, err := io.ReadAll(fd) + if err != nil { + t.Fatal(err) + } + // if v := string(bs); v != "file=UK_Sanctions_List.ods" { + // t.Errorf("UK_Sanctions_List.ods: %v", v) + // } + } else { + t.Fatalf("unknown file: %v", file) } - // if v := string(bs); v != "file=UK_Sanctions_List.ods" { - // t.Errorf("UK_Sanctions_List.ods: %v", v) - // } - } else { - t.Fatalf("unknown file: %v", file) } } diff --git a/pkg/csl/reader.go b/pkg/csl/reader.go index ea6d0cb8..f969fcdb 100644 --- a/pkg/csl/reader.go +++ b/pkg/csl/reader.go @@ -4,17 +4,14 @@ import ( "encoding/csv" "errors" "io" - "os" "strings" ) -func ReadFile(path string) (*CSL, error) { - fd, err := os.Open(path) - if err != nil { - return nil, err +func ReadFile(fd io.ReadCloser) (*CSL, error) { + if fd == nil { + return nil, errors.New("CSL file is empty or missing") } defer fd.Close() - return Parse(fd) } diff --git a/pkg/csl/reader_eu.go b/pkg/csl/reader_eu.go index 8d746bf3..bba259af 100644 --- a/pkg/csl/reader_eu.go +++ b/pkg/csl/reader_eu.go @@ -5,26 +5,14 @@ import ( "errors" "fmt" "io" - "os" "strconv" ) -func ReadEUFile(path string) ([]*EUCSLRecord, EUCSL, error) { - fd, err := os.Open(path) - if err != nil { - return nil, nil, err - } - defer fd.Close() - - rows, rowsMap, err := ParseEU(fd) - if err != nil { - return nil, nil, err +func ParseEU(r io.ReadCloser) ([]*EUCSLRecord, EUCSL, error) { + if r == nil { + return nil, nil, errors.New("EU CSL file is empty or missing") } - - return rows, rowsMap, nil -} - -func ParseEU(r io.Reader) ([]*EUCSLRecord, EUCSL, error) { + defer r.Close() reader := csv.NewReader(r) // sets comma delim to ; and ignores " in non quoted field and size of columns // https://stackoverflow.com/questions/31326659/golang-csv-error-bare-in-non-quoted-field diff --git a/pkg/csl/reader_eu_test.go b/pkg/csl/reader_eu_test.go index b4cbd5cb..d2e7f4cd 100644 --- a/pkg/csl/reader_eu_test.go +++ b/pkg/csl/reader_eu_test.go @@ -1,6 +1,7 @@ package csl import ( + "os" "path/filepath" "testing" @@ -8,7 +9,11 @@ import ( ) func TestReadEU(t *testing.T) { - euCSL, euCSLMap, err := ReadEUFile(filepath.Join("..", "..", "test", "testdata", "eu_csl.csv")) + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "eu_csl.csv")) + if err != nil { + t.Error(err) + } + euCSL, euCSLMap, err := ParseEU(fd) if err != nil { t.Fatal(err) } diff --git a/pkg/csl/reader_test.go b/pkg/csl/reader_test.go index ea7e8686..8877067f 100644 --- a/pkg/csl/reader_test.go +++ b/pkg/csl/reader_test.go @@ -13,7 +13,11 @@ import ( ) func TestRead(t *testing.T) { - csl, err := ReadFile(filepath.Join("..", "..", "test", "testdata", "csl.csv")) + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "csl.csv")) + if err != nil { + t.Error(err) + } + csl, err := ReadFile(fd) if err != nil { t.Fatal(err) } @@ -52,8 +56,8 @@ func TestRead_missingRow(t *testing.T) { _, err = fd.WriteString(` \n invalid \n \n`) require.NoError(t, err) - - resp, err := ReadFile(fd.Name()) + fd.Seek(0, 0) + resp, err := ReadFile(fd) require.NoError(t, err) require.Len(t, resp.ELs, 0) @@ -62,7 +66,12 @@ func TestRead_missingRow(t *testing.T) { } func TestRead_invalidRow(t *testing.T) { - csl, err := ReadFile(filepath.Join("..", "..", "test", "testdata", "invalidFiles", "csl.csv")) + + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "invalidFiles", "csl.csv")) + if err != nil { + t.Error(err) + } + csl, err := ReadFile(fd) if err != nil { t.Fatal(err) } @@ -394,9 +403,9 @@ func Test__Issue326EL(t *testing.T) { if err := fd.Sync(); err != nil { t.Fatal(err) } - + fd.Seek(0, 0) // read the line back - csl, err := ReadFile(fd.Name()) + csl, err := ReadFile(fd) if err != nil { t.Fatal(err) } @@ -452,8 +461,11 @@ func Test_expandProgramsList(t *testing.T) { func TestCSL__UniqueIDs(t *testing.T) { // CSL datafiles have added a unique identifier as the first column. // We need verify the old and new file formats can be parsed. - - records, err := ReadFile(filepath.Join("..", "..", "test", "testdata", "csl-unique-ids.csv")) + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "csl-unique-ids.csv")) + if err != nil { + t.Error(err) + } + records, err := ReadFile(fd) if err != nil { t.Fatal(err) } diff --git a/pkg/csl/reader_uk.go b/pkg/csl/reader_uk.go index dceb0279..aaf53420 100644 --- a/pkg/csl/reader_uk.go +++ b/pkg/csl/reader_uk.go @@ -5,20 +5,15 @@ import ( "encoding/csv" "errors" "io" - "os" "strconv" "strings" "github.com/knieriem/odf/ods" ) -func ReadUKCSLFile(path string) ([]*UKCSLRecord, UKCSL, error) { - if path == "" { - return nil, nil, errors.New("path was empty for ukcsl file") - } - fd, err := os.Open(path) - if err != nil { - return nil, nil, err +func ReadUKCSLFile(fd io.ReadCloser) ([]*UKCSLRecord, UKCSL, error) { + if fd == nil { + return nil, nil, errors.New("uk CSL file is empty or missing") } defer fd.Close() @@ -200,11 +195,16 @@ func unmarshalUKCSLRecord(csvRecord []string, ukCSLRecord *UKCSLRecord) { } } -func ReadUKSanctionsListFile(path string) ([]*UKSanctionsListRecord, UKSanctionsListMap, error) { - if path == "" { - return nil, nil, errors.New("path was empty for uk sanctions list file") +func ReadUKSanctionsListFile(f io.ReadCloser) ([]*UKSanctionsListRecord, UKSanctionsListMap, error) { + if f == nil { + return nil, nil, errors.New("uk sanctions list file is empty or missing") + } + defer f.Close() + content, err := io.ReadAll(f) + if err != nil { + return nil, nil, err } - fd, err := ods.Open(path) + fd, err := ods.NewReader(bytes.NewReader(content), int64(len(content))) if err != nil { return nil, nil, err } diff --git a/pkg/csl/reader_uk_test.go b/pkg/csl/reader_uk_test.go index b0e5c11a..e8bb7200 100644 --- a/pkg/csl/reader_uk_test.go +++ b/pkg/csl/reader_uk_test.go @@ -2,6 +2,7 @@ package csl import ( "fmt" + "os" "path/filepath" "testing" @@ -9,7 +10,11 @@ import ( ) func TestReadUKCSL(t *testing.T) { - ukCSL, ukCSLMap, err := ReadUKCSLFile(filepath.Join("..", "..", "test", "testdata", "ConList.csv")) + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "ConList.csv")) + if err != nil { + t.Error(err) + } + ukCSL, ukCSLMap, err := ReadUKCSLFile(fd) if err != nil { t.Fatal(err) } @@ -71,8 +76,13 @@ func TestReadUKCSL(t *testing.T) { func TestReadUKSanctionsList(t *testing.T) { t.Setenv("WITH_UK_SANCTIONS_LIST", "false") + + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "UK_Sanctions_List.ods")) + if err != nil { + t.Error(err) + } // test we don't err on parsing the content - totalReport, report, err := ReadUKSanctionsListFile("../../test/testdata/UK_Sanctions_List.ods") + totalReport, report, err := ReadUKSanctionsListFile(fd) assert.NoError(t, err) // test that we get something more than an empty sanctions list record diff --git a/pkg/download/client.go b/pkg/download/client.go index 0fb9e20a..060981ad 100644 --- a/pkg/download/client.go +++ b/pkg/download/client.go @@ -22,7 +22,7 @@ import ( var ( HTTPClient = &http.Client{ - Timeout: 15 * time.Second, + Timeout: 45 * time.Second, } ) @@ -49,7 +49,7 @@ type Downloader struct { // initialDir is an optional filepath to look for files in before attempting to download. // // Callers are expected to cleanup the temp directory. -func (dl *Downloader) GetFiles(initialDir string, namesAndSources map[string]string) ([]string, error) { +func (dl *Downloader) GetFiles(initialDir string, namesAndSources map[string]string) (map[string]io.ReadCloser, error) { if dl == nil { return nil, errors.New("nil Downloader") } @@ -80,27 +80,29 @@ func (dl *Downloader) GetFiles(initialDir string, namesAndSources map[string]str } var mu sync.Mutex - var out []string - + out := make(map[string]io.ReadCloser) var wg sync.WaitGroup wg.Add(len(namesAndSources)) + +findfiles: for name, source := range namesAndSources { // Check if we have the file locally first - found := false - for i := range localFiles { - if strings.EqualFold(filepath.Base(localFiles[i].Name()), name) { - found = true + for _, file := range localFiles { + if strings.EqualFold(filepath.Base(file.Name()), name) { + fn := filepath.Join(dir, file.Name()) + fd, err := os.Open(fn) + if err != nil { + dl.Logger.Error().LogErrorf("could not read file from %v initialDir: %v", fn, err) + continue + } mu.Lock() - out = append(out, filepath.Join(dir, localFiles[i].Name())) + out[name] = fd mu.Unlock() - break + // file is found, skip downloading + wg.Done() + continue findfiles } } - // Skip downloading this file since we found it - if found { - wg.Done() - continue - } // Download missing files go func(wg *sync.WaitGroup, filename, downloadURL string) { @@ -109,17 +111,17 @@ func (dl *Downloader) GetFiles(initialDir string, namesAndSources map[string]str logger := dl.createLogger(filename, downloadURL) startTime := time.Now().In(time.UTC) - err := dl.retryDownload(dir, filename, downloadURL) + content, err := dl.retryDownload(downloadURL) dur := time.Now().In(time.UTC).Sub(startTime) if err != nil { logger.Error().LogErrorf("FAILURE after %v to download: %v", dur, err) - } else { - logger.Info().Logf("successful download after %v", dur) + return } + logger.Info().Logf("successful download after %v", dur) mu.Lock() - out = append(out, filepath.Join(dir, filename)) + out[filename] = content mu.Unlock() }(&wg, name, source) } @@ -140,41 +142,29 @@ func (dl *Downloader) createLogger(filename, downloadURL string) log.Logger { }) } -func (dl *Downloader) retryDownload(dir, filename, downloadURL string) error { +func (dl *Downloader) retryDownload(downloadURL string) (io.ReadCloser, error) { // Allow a couple retries for various sources (some are flakey) for i := 0; i < 3; i++ { req, err := http.NewRequest(http.MethodGet, downloadURL, nil) if err != nil { - return dl.Logger.Error().LogErrorf("error building HTTP request: %v", err).Err() + return nil, dl.Logger.Error().LogErrorf("error building HTTP request: %v", err).Err() } req.Header.Set("User-Agent", fmt.Sprintf("moov-io/watchman:%v", watchman.Version)) // in order to get passed europes 406 (Not Accepted) req.Header.Set("accept-language", "en-US,en;q=0.9") resp, err := dl.HTTP.Do(req) + if err != nil { dl.Logger.Error().LogErrorf("err while doing client request: ", err) time.Sleep(100 * time.Millisecond) - continue // retry + continue } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - dl.Logger.Error().LogErrorf("we experienced a problem in the dl: %v", resp.StatusCode) - } - - // Copy resp.Body into a file in our temp dir - fd, err := os.Create(filepath.Join(dir, filename)) - if err != nil { resp.Body.Close() - return fmt.Errorf("attempt %d failed to create file: %v", i, err) + continue } - - io.Copy(fd, resp.Body) // copy file contents - - // close the open files - fd.Close() - resp.Body.Close() - return nil // quit after successful download + return resp.Body, nil } - return nil + return nil, errors.New("error max retries reached while trying to obtain file") } diff --git a/pkg/dpl/download.go b/pkg/dpl/download.go index 369eaca3..989e7b66 100644 --- a/pkg/dpl/download.go +++ b/pkg/dpl/download.go @@ -6,8 +6,8 @@ package dpl import ( "fmt" + "io" "os" - "path/filepath" "github.com/moov-io/base/log" "github.com/moov-io/watchman/pkg/download" @@ -23,20 +23,11 @@ var ( ) // Download returns an array of absolute filepaths for files downloaded -func Download(logger log.Logger, initialDir string) (string, error) { +func Download(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) { dl := download.New(logger, download.HTTPClient) addrs := make(map[string]string) addrs["dpl.txt"] = fmt.Sprintf(dplDownloadTemplate, "dpl.txt") - files, err := dl.GetFiles(initialDir, addrs) - if len(files) == 0 || err != nil { - return "", fmt.Errorf("dpl download: %v", err) - } - for i := range files { - if filepath.Base(files[i]) == "dpl.txt" { - return files[i], nil - } - } - return "", nil + return dl.GetFiles(initialDir, addrs) } diff --git a/pkg/dpl/download_test.go b/pkg/dpl/download_test.go index ecd3cd06..1240427a 100644 --- a/pkg/dpl/download_test.go +++ b/pkg/dpl/download_test.go @@ -5,6 +5,7 @@ package dpl import ( + "io" "os" "path/filepath" "strings" @@ -22,13 +23,13 @@ func TestDownloader(t *testing.T) { if err != nil { t.Fatal(err) } - if file == "" { + if len(file) == 0 { t.Fatal("no DPL file") } - defer os.RemoveAll(filepath.Dir(file)) - - if !strings.EqualFold("dpl.txt", filepath.Base(file)) { - t.Errorf("unknown file %s", file) + for filename, _ := range file { + if !strings.EqualFold("dpl.txt", filepath.Base(filename)) { + t.Errorf("unknown file %s", file) + } } } @@ -54,19 +55,20 @@ func TestDownloader__initialDir(t *testing.T) { if err != nil { t.Fatal(err) } - if file == "" { + if len(file) == 0 { t.Fatal("no DPL file") } - - if strings.EqualFold("dpl.txt", filepath.Base(file)) { - bs, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - if v := string(bs); v != "file=dpl.txt" { - t.Errorf("dpl.txt: %v", v) + for fn, fd := range file { + if strings.EqualFold("dpl.txt", filepath.Base(fn)) { + bs, err := io.ReadAll(fd) + if err != nil { + t.Fatal(err) + } + if v := string(bs); v != "file=dpl.txt" { + t.Errorf("dpl.txt: %v", v) + } + } else { + t.Fatalf("unknown file: %v", file) } - } else { - t.Fatalf("unknown file: %v", file) } } diff --git a/pkg/dpl/reader.go b/pkg/dpl/reader.go index a1100f54..762967dc 100644 --- a/pkg/dpl/reader.go +++ b/pkg/dpl/reader.go @@ -8,17 +8,14 @@ import ( "encoding/csv" "errors" "io" - "os" ) // Read parses DPL records from a TXT file and populates the associated arrays. // // For more details on the raw DPL files see https://moov-io.github.io/watchman/file-structure.html -func Read(path string) ([]*DPL, error) { - // open txt file - f, err := os.Open(path) - if err != nil { - return nil, err +func Read(f io.ReadCloser) ([]*DPL, error) { + if f == nil { + return nil, errors.New("DPL file is empty or missing") } defer f.Close() diff --git a/pkg/dpl/reader_test.go b/pkg/dpl/reader_test.go index cdda618e..da9fcc01 100644 --- a/pkg/dpl/reader_test.go +++ b/pkg/dpl/reader_test.go @@ -5,12 +5,17 @@ package dpl import ( + "os" "path/filepath" "testing" ) func TestDPL__read(t *testing.T) { - dpls, err := Read(filepath.Join("..", "..", "test", "testdata", "dpl.txt")) + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "dpl.txt")) + if err != nil { + t.Error(err) + } + dpls, err := Read(fd) if err != nil { t.Fatal(err) } @@ -19,7 +24,11 @@ func TestDPL__read(t *testing.T) { } // this file is formatted incorrectly for DPL, so we expect all rows to be skipped - got, err := Read(filepath.Join("..", "..", "test", "testdata", "sdn.csv")) + fd, err = os.Open(filepath.Join("..", "..", "test", "testdata", "sdn.csv")) + if err != nil { + t.Error(err) + } + got, err := Read(fd) if err != nil { t.Fatal(err) } diff --git a/pkg/ofac/download.go b/pkg/ofac/download.go index 242ccd64..0574b82b 100644 --- a/pkg/ofac/download.go +++ b/pkg/ofac/download.go @@ -6,6 +6,7 @@ package ofac import ( "fmt" + "io" "os" "github.com/moov-io/base/log" @@ -28,7 +29,7 @@ var ( }() ) -func Download(logger log.Logger, initialDir string) ([]string, error) { +func Download(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) { dl := download.New(logger, download.HTTPClient) addrs := make(map[string]string) diff --git a/pkg/ofac/download_test.go b/pkg/ofac/download_test.go index ae8b2456..51207542 100644 --- a/pkg/ofac/download_test.go +++ b/pkg/ofac/download_test.go @@ -21,13 +21,12 @@ func TestDownloader(t *testing.T) { if err != nil { t.Fatal(err) } - defer os.RemoveAll(filepath.Dir(files[0])) if len(files) != 4 { t.Errorf("OFAC: found %d files", len(files)) } - for i := range files { - name := filepath.Base(files[i]) + for fn, _ := range files { + name := filepath.Base(fn) switch name { case "add.csv", "alt.csv", "sdn.csv", "sdn_comments.csv": continue @@ -59,10 +58,10 @@ func TestDownloader__initialDir(t *testing.T) { if err != nil { t.Fatal(err) } - for i := range files { - switch filepath.Base(files[i]) { + for fn, _ := range files { + switch filepath.Base(fn) { case "sdn.txt": - bs, err := os.ReadFile(files[i]) + bs, err := os.ReadFile(fn) if err != nil { t.Fatal(err) } diff --git a/pkg/ofac/reader.go b/pkg/ofac/reader.go index e28ade97..1c9fc3e1 100644 --- a/pkg/ofac/reader.go +++ b/pkg/ofac/reader.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "os" "path/filepath" "strings" ) @@ -17,37 +16,39 @@ import ( // Read will consume the file at path and attempt to parse it was a CSV OFAC file. // // For more details on the raw OFAC files see https://moov-io.github.io/watchman/file-structure.html -func Read(path string) (*Results, error) { - switch filepath.Base(path) { - case "add.csv": - res, err := csvAddressFile(path) - if err != nil { - return res, fmt.Errorf("add.csv: %v", err) - } - return res, err +func Read(files map[string]io.ReadCloser) (*Results, error) { + res := new(Results) + for filename, file := range files { + switch filepath.Base(filename) { + case "add.csv": + err := res.append(csvAddressFile(file)) + if err != nil { + return nil, fmt.Errorf("add.csv: %v", err) + } + case "alt.csv": + err := res.append(csvAlternateIdentityFile(file)) + if err != nil { + return nil, fmt.Errorf("add.csv: %v", err) + } - case "alt.csv": - res, err := csvAlternateIdentityFile(path) - if err != nil { - return res, fmt.Errorf("alt.csv: %v", err) - } - return res, err + case "sdn.csv": + err := res.append(csvSDNFile(file)) + if err != nil { + return nil, fmt.Errorf("add.csv: %v", err) + } - case "sdn.csv": - res, err := csvSDNFile(path) - if err != nil { - return res, fmt.Errorf("sdn.csv: %v", err) - } - return res, err + case "sdn_comments.csv": + err := res.append(csvSDNCommentsFile(file)) + if err != nil { + return nil, fmt.Errorf("add.csv: %v", err) + } - case "sdn_comments.csv": - res, err := csvSDNCommentsFile(path) - if err != nil { - return res, fmt.Errorf("sdn_comments.csv: %v", err) + default: + file.Close() + return nil, fmt.Errorf("error: file %s does not have a handler for processing", filename) } - return res, err } - return nil, nil + return res, nil } type Results struct { @@ -64,14 +65,19 @@ type Results struct { SDNComments []*SDNComments `json:"sdnComments"` } -func csvAddressFile(path string) (*Results, error) { - // Open CSV file - f, err := os.Open(path) +func (r *Results) append(rr *Results, err error) error { if err != nil { - return nil, err + return err } - defer f.Close() + r.Addresses = append(r.Addresses, rr.Addresses...) + r.AlternateIdentities = append(r.AlternateIdentities, rr.AlternateIdentities...) + r.SDNs = append(r.SDNs, rr.SDNs...) + r.SDNComments = append(r.SDNComments, rr.SDNComments...) + return nil +} +func csvAddressFile(f io.ReadCloser) (*Results, error) { + defer f.Close() var out []*Address // Read File into a Variable @@ -108,14 +114,8 @@ func csvAddressFile(path string) (*Results, error) { return &Results{Addresses: out}, nil } -func csvAlternateIdentityFile(path string) (*Results, error) { - // Open CSV file - f, err := os.Open(path) - if err != nil { - return nil, err - } +func csvAlternateIdentityFile(f io.ReadCloser) (*Results, error) { defer f.Close() - var out []*AlternateIdentity // Read File into a Variable @@ -150,14 +150,8 @@ func csvAlternateIdentityFile(path string) (*Results, error) { return &Results{AlternateIdentities: out}, nil } -func csvSDNFile(path string) (*Results, error) { - // Open CSV file - f, err := os.Open(path) - if err != nil { - return nil, err - } +func csvSDNFile(f io.ReadCloser) (*Results, error) { defer f.Close() - var out []*SDN // Read File into a Variable @@ -199,14 +193,8 @@ func csvSDNFile(path string) (*Results, error) { return &Results{SDNs: out}, nil } -func csvSDNCommentsFile(path string) (*Results, error) { - // Open CSV file - f, err := os.Open(path) - if err != nil { - return nil, err - } +func csvSDNCommentsFile(f io.ReadCloser) (*Results, error) { defer f.Close() - // Read File into a Variable r := csv.NewReader(f) r.LazyQuotes = true diff --git a/pkg/ofac/reader_test.go b/pkg/ofac/reader_test.go index 35bfc516..d3a60f74 100644 --- a/pkg/ofac/reader_test.go +++ b/pkg/ofac/reader_test.go @@ -5,6 +5,7 @@ package ofac import ( + "io" "os" "path/filepath" "reflect" @@ -15,28 +16,35 @@ import ( // TestOFAC__read validates reading an OFAC Address CSV File func TestOFAC__read(t *testing.T) { - res, err := Read(filepath.Join("..", "..", "test", "testdata", "add.csv")) + testdata := func(fn string) map[string]io.ReadCloser { + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", fn)) + if err != nil { + t.Error(err) + } + return map[string]io.ReadCloser{fn: fd} + } + res, err := Read(testdata("add.csv")) require.NoError(t, err) require.Len(t, res.Addresses, 11696) require.Len(t, res.AlternateIdentities, 0) require.Len(t, res.SDNs, 0) require.Len(t, res.SDNComments, 0) - res, err = Read(filepath.Join("..", "..", "test", "testdata", "alt.csv")) + res, err = Read(testdata("alt.csv")) require.NoError(t, err) require.Len(t, res.Addresses, 0) require.Len(t, res.AlternateIdentities, 9682) require.Len(t, res.SDNs, 0) require.Len(t, res.SDNComments, 0) - res, err = Read(filepath.Join("..", "..", "test", "testdata", "sdn.csv")) + res, err = Read(testdata("sdn.csv")) require.NoError(t, err) require.Len(t, res.Addresses, 0) require.Len(t, res.AlternateIdentities, 0) require.Len(t, res.SDNs, 7379) require.Len(t, res.SDNComments, 0) - res, err = Read(filepath.Join("..", "..", "test", "testdata", "sdn_comments.csv")) + res, err = Read(testdata("sdn_comments.csv")) require.NoError(t, err) require.Len(t, res.Addresses, 0) require.Len(t, res.AlternateIdentities, 0) @@ -98,9 +106,9 @@ func TestSDNComments(t *testing.T) { if _, err := fd.WriteString(`28264,"hone Number 8613314257947; alt. Phone Number 8618004121000; Identification Number 210302198701102136 (China); a.k.a. "blackjack1987"; a.k.a. "khaleesi"; Linked To: LAZARUS GROUP."`); err != nil { t.Fatal(err) } - + fd.Seek(0, 0) // read with lazy quotes enabled - if res, err := csvSDNCommentsFile(fd.Name()); err != nil { + if res, err := csvSDNCommentsFile(fd); err != nil { t.Errorf("unexpected error: %v", err) } else { if len(res.SDNComments) != 1 { @@ -118,8 +126,8 @@ func TestSDNComments_CryptoCurrencies(t *testing.T) { _, err = fd.WriteString(`42496," alt. Digital Currency Address - XBT 12jVCWW1ZhTLA5yVnroEJswqKwsfiZKsax; alt. Digital Currency Address - XBT 1J378PbmTKn2sEw6NBrSWVfjZLBZW3DZem; alt. Digital Currency Address - XBT 18aqbRhHupgvC9K8qEqD78phmTQQWs7B5d; alt. Digital Currency Address - XBT 16ti2EXaae5izfkUZ1Zc59HMcsdnHpP5QJ; Secondary sanctions risk: North Korea Sanctions Regulations, sections 510.201 and 510.210; Transactions Prohibited For Persons Owned or Controlled By U.S. Financial Institutions: North Korea Sanctions Regulations section 510.214; Passport E59165201 (China) expires 01 Sep 2025; Identification Number 371326198812157611 (China); a.k.a. 'WAKEMEUPUPUP'; a.k.a. 'FAST4RELEASE'; Linked To: LAZARUS GROUP."`) require.NoError(t, err) - - sdn, err := csvSDNCommentsFile(fd.Name()) + fd.Seek(0, 0) + sdn, err := csvSDNCommentsFile(fd) require.NoError(t, err) require.Len(t, sdn.SDNComments, 1) From 8772e83bd37cad17d2565ac4b14053edda2132bd Mon Sep 17 00:00:00 2001 From: Martin Date: Thu, 7 Mar 2024 19:41:30 +0100 Subject: [PATCH 2/3] tidy up code --- go.mod | 2 +- pkg/dpl/download_test.go | 2 +- pkg/dpl/reader.go | 22 ++++++++++------------ pkg/ofac/download_test.go | 4 ++-- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index 2b369274..e6f860fa 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 go4.org v0.0.0-20230225012048-214862532bf5 golang.org/x/oauth2 v0.14.0 + golang.org/x/sync v0.6.0 golang.org/x/text v0.14.0 ) @@ -37,7 +38,6 @@ require ( github.com/prometheus/procfs v0.12.0 // indirect github.com/rickar/cal/v2 v2.1.13 // indirect golang.org/x/crypto v0.17.0 // indirect - golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.15.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/protobuf v1.31.0 // indirect diff --git a/pkg/dpl/download_test.go b/pkg/dpl/download_test.go index 1240427a..b20b9a31 100644 --- a/pkg/dpl/download_test.go +++ b/pkg/dpl/download_test.go @@ -26,7 +26,7 @@ func TestDownloader(t *testing.T) { if len(file) == 0 { t.Fatal("no DPL file") } - for filename, _ := range file { + for filename := range file { if !strings.EqualFold("dpl.txt", filepath.Base(filename)) { t.Errorf("unknown file %s", file) } diff --git a/pkg/dpl/reader.go b/pkg/dpl/reader.go index 762967dc..2db09f0c 100644 --- a/pkg/dpl/reader.go +++ b/pkg/dpl/reader.go @@ -28,19 +28,17 @@ func Read(f io.ReadCloser) ([]*DPL, error) { for { line, err := reader.Read() if err != nil { - if err != nil { - // reached the last line - if errors.Is(err, io.EOF) { - break - } - // malformed row - if errors.Is(err, csv.ErrFieldCount) || - errors.Is(err, csv.ErrBareQuote) || - errors.Is(err, csv.ErrQuote) { - continue - } - return nil, err + // reached the last line + if errors.Is(err, io.EOF) { + break } + // malformed row + if errors.Is(err, csv.ErrFieldCount) || + errors.Is(err, csv.ErrBareQuote) || + errors.Is(err, csv.ErrQuote) { + continue + } + return nil, err } if len(line) < 12 || (len(line) >= 2 && line[1] == "Street_Address") { diff --git a/pkg/ofac/download_test.go b/pkg/ofac/download_test.go index 51207542..95229292 100644 --- a/pkg/ofac/download_test.go +++ b/pkg/ofac/download_test.go @@ -25,7 +25,7 @@ func TestDownloader(t *testing.T) { if len(files) != 4 { t.Errorf("OFAC: found %d files", len(files)) } - for fn, _ := range files { + for fn := range files { name := filepath.Base(fn) switch name { case "add.csv", "alt.csv", "sdn.csv", "sdn_comments.csv": @@ -58,7 +58,7 @@ func TestDownloader__initialDir(t *testing.T) { if err != nil { t.Fatal(err) } - for fn, _ := range files { + for fn := range files { switch filepath.Base(fn) { case "sdn.txt": bs, err := os.ReadFile(fn) From 87d30992960a3867ba5975edb3d8489f09492595 Mon Sep 17 00:00:00 2001 From: Martin Date: Fri, 8 Mar 2024 15:29:24 +0100 Subject: [PATCH 3/3] update downloader documentation to new behaviour and code cleanup --- pkg/download/client.go | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/pkg/download/client.go b/pkg/download/client.go index 060981ad..8c88bea6 100644 --- a/pkg/download/client.go +++ b/pkg/download/client.go @@ -43,13 +43,12 @@ type Downloader struct { Logger log.Logger } -// GetFiles will download all provided files, return their filepaths, and store them in a -// temporary directory and an error otherwise. +// GetFiles will initiate download of all provided files, return an io.ReadCloser to their content // // initialDir is an optional filepath to look for files in before attempting to download. // -// Callers are expected to cleanup the temp directory. -func (dl *Downloader) GetFiles(initialDir string, namesAndSources map[string]string) (map[string]io.ReadCloser, error) { +// Callers are expected to call the io.Closer interface method when they are done with the file +func (dl *Downloader) GetFiles(dir string, namesAndSources map[string]string) (map[string]io.ReadCloser, error) { if dl == nil { return nil, errors.New("nil Downloader") } @@ -61,19 +60,6 @@ func (dl *Downloader) GetFiles(initialDir string, namesAndSources map[string]str } // Check the initial directory for files we don't need to download - var dir string - if initialDir != "" { - dir = initialDir // empty, but use it as a directory - } - // Create a temporary directory for downloads if needed - if dir == "" { - temp, err := os.MkdirTemp("", "downloader") - if err != nil { - return nil, fmt.Errorf("downloader: unable to make temp dir: %v", err) - } - dir = temp - } - localFiles, err := os.ReadDir(dir) if err != nil { return nil, fmt.Errorf("readdir %s: %v", dir, err) @@ -93,6 +79,7 @@ findfiles: fd, err := os.Open(fn) if err != nil { dl.Logger.Error().LogErrorf("could not read file from %v initialDir: %v", fn, err) + fd.Close() continue } mu.Lock()