Skip to content

Commit

Permalink
Merge pull request #537 from m29h/issue-536
Browse files Browse the repository at this point in the history
pass io.ReadCloser instances instead of files on harddrive, fixes #536
  • Loading branch information
adamdecaf authored Mar 12, 2024
2 parents 5903d14 + 87d3099 commit 33df3aa
Show file tree
Hide file tree
Showing 27 changed files with 318 additions and 336 deletions.
49 changes: 16 additions & 33 deletions cmd/server/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,37 +163,14 @@ func ofacRecords(logger log.Logger, initialDir string) (*ofac.Results, error) {
if len(files) == 0 {
return nil, errors.New("no OFAC Results")
}

var res *ofac.Results
for i := range files {
// Does order matter?

if i == 0 {
rr, err := ofac.Read(files[i])
if err != nil {
return nil, fmt.Errorf("reading %s: %v", files[i], err)
}
if rr != nil {
res = rr
}
} else {
rr, err := ofac.Read(files[i])
if err != nil {
return nil, fmt.Errorf("read and replace: %v", err)
}
if rr != nil {
res.Addresses = append(res.Addresses, rr.Addresses...)
res.AlternateIdentities = append(res.AlternateIdentities, rr.AlternateIdentities...)
res.SDNs = append(res.SDNs, rr.SDNs...)
res.SDNComments = append(res.SDNComments, rr.SDNComments...)
}
}
res, err := ofac.Read(files)
if err != nil {
return nil, err
}

// Merge comments into SDNs
res.SDNs = mergeSpilloverRecords(res.SDNs, res.SDNComments)

return res, err
return res, nil
}

func mergeSpilloverRecords(sdns []*ofac.SDN, comments []*ofac.SDNComments) []*ofac.SDN {
Expand All @@ -212,7 +189,8 @@ func dplRecords(logger log.Logger, initialDir string) ([]*dpl.DPL, error) {
if err != nil {
return nil, err
}
return dpl.Read(file)

return dpl.Read(file["dpl.txt"])
}

func cslRecords(logger log.Logger, initialDir string) (*csl.CSL, error) {
Expand All @@ -221,11 +199,11 @@ func cslRecords(logger log.Logger, initialDir string) (*csl.CSL, error) {
logger.Warn().Logf("skipping CSL download: %v", err)
return &csl.CSL{}, nil
}
cslRecords, err := csl.ReadFile(file)
cslRecords, err := csl.ReadFile(file["csl.csv"])
if err != nil {
return nil, err
}
return cslRecords, err
return cslRecords, nil
}

func euCSLRecords(logger log.Logger, initialDir string) ([]*csl.EUCSLRecord, error) {
Expand All @@ -235,11 +213,13 @@ func euCSLRecords(logger log.Logger, initialDir string) ([]*csl.EUCSLRecord, err
// no error to return because we skip the download
return nil, nil
}
cslRecords, _, err := csl.ReadEUFile(file)

cslRecords, _, err := csl.ParseEU(file["eu_csl.csv"])
if err != nil {
return nil, err
}
return cslRecords, err

}

func ukCSLRecords(logger log.Logger, initialDir string) ([]*csl.UKCSLRecord, error) {
Expand All @@ -249,7 +229,7 @@ func ukCSLRecords(logger log.Logger, initialDir string) ([]*csl.UKCSLRecord, err
// no error to return because we skip the download
return nil, nil
}
cslRecords, _, err := csl.ReadUKCSLFile(file)
cslRecords, _, err := csl.ReadUKCSLFile(file["ConList.csv"])
if err != nil {
return nil, err
}
Expand All @@ -264,7 +244,7 @@ func ukSanctionsListRecords(logger log.Logger, initialDir string) ([]*csl.UKSanc
return nil, nil
}

records, _, err := csl.ReadUKSanctionsListFile(file)
records, _, err := csl.ReadUKSanctionsListFile(file["UK_Sanctions_List.ods"])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -342,6 +322,9 @@ func (s *searcher) refreshData(initialDir string) (*DownloadStats, error) {
lastDataRefreshFailure.WithLabelValues("CSL").Set(float64(time.Now().Unix()))
stats.Errors = append(stats.Errors, fmt.Errorf("CSL: %v", err))
}
if consolidatedLists == nil {
consolidatedLists = new(csl.CSL)
}
els := precomputeCSLEntities[csl.EL](consolidatedLists.ELs, s.pipe)
meus := precomputeCSLEntities[csl.MEU](consolidatedLists.MEUs, s.pipe)
ssis := precomputeCSLEntities[csl.SSI](consolidatedLists.SSIs, s.pipe)
Expand Down
8 changes: 7 additions & 1 deletion cmd/server/search_crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"

Expand All @@ -25,7 +27,11 @@ var (

func init() {
// Set SDN Comments
ofacResults, err := ofac.Read(filepath.Join("..", "..", "test", "testdata", "sdn_comments.csv"))
fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "sdn_comments.csv"))
if err != nil {
panic(fmt.Sprintf("%v", err))
}
ofacResults, err := ofac.Read(map[string]io.ReadCloser{"sdn_comments.csv": fd})
if err != nil {
panic(fmt.Sprintf("ERROR reading sdn_comments.csv: %v", err))
}
Expand Down
8 changes: 7 additions & 1 deletion cmd/server/search_handlers_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ package main
import (
"crypto/rand"
"fmt"
"io"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
Expand Down Expand Up @@ -56,7 +58,11 @@ func BenchmarkSearchHandler(b *testing.B) {
}

func BenchmarkJaroWinkler(b *testing.B) {
results, err := ofac.Read(filepath.Join("..", "..", "test", "testdata", "sdn.csv"))
fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "sdn.csv"))
if err != nil {
b.Error(err)
}
results, err := ofac.Read(map[string]io.ReadCloser{"sdn.csv": fd})
require.NoError(b, err)
require.Len(b, results.SDNs, 7379)

Expand Down
3 changes: 3 additions & 0 deletions cmd/server/search_us_csl.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ func searchUSCSL(logger log.Logger, searcher *searcher) http.HandlerFunc {

func precomputeCSLEntities[T any](items []*T, pipe *pipeliner) []*Result[T] {
out := make([]*Result[T], len(items))
if items == nil {
return out
}

for i, item := range items {
name := cslName(item)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ require (
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673
go4.org v0.0.0-20230225012048-214862532bf5
golang.org/x/oauth2 v0.14.0
golang.org/x/sync v0.6.0
golang.org/x/text v0.14.0
)

Expand All @@ -37,7 +38,6 @@ require (
github.com/prometheus/procfs v0.12.0 // indirect
github.com/rickar/cal/v2 v2.1.13 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.15.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/protobuf v1.31.0 // indirect
Expand Down
3 changes: 1 addition & 2 deletions pkg/csl/csl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
)

func TestCSL(t *testing.T) {
t.Skip("CSL is currently broken, looks like they require API access now")

if testing.Short() {
t.Skip("ignorning network test")
Expand All @@ -27,7 +26,7 @@ func TestCSL(t *testing.T) {
file, err := Download(logger, dir)
require.NoError(t, err)

cslRecords, err := ReadFile(file)
cslRecords, err := ReadFile(file["csl.csv"])
require.NoError(t, err)

if len(cslRecords.SSIs) == 0 {
Expand Down
11 changes: 4 additions & 7 deletions pkg/csl/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package csl

import (
"fmt"
"io"
"net/url"
"os"

Expand All @@ -19,22 +20,18 @@ var (
usDownloadURL = strx.Or(os.Getenv("CSL_DOWNLOAD_TEMPLATE"), os.Getenv("US_CSL_DOWNLOAD_URL"), publicUSDownloadURL)
)

func Download(logger log.Logger, initialDir string) (string, error) {
func Download(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) {
dl := download.New(logger, download.HTTPClient)

cslURL, err := buildDownloadURL(usDownloadURL)
if err != nil {
return "", err
return nil, err
}

cslNameAndSource := make(map[string]string)
cslNameAndSource["csl.csv"] = cslURL

file, err := dl.GetFiles(initialDir, cslNameAndSource)
if len(file) == 0 || err != nil {
return "", fmt.Errorf("csl download: %v", err)
}
return file[0], nil
return dl.GetFiles(initialDir, cslNameAndSource)
}

func buildDownloadURL(urlStr string) (string, error) {
Expand Down
10 changes: 4 additions & 6 deletions pkg/csl/download_eu.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package csl

import (
"fmt"
"io"
"os"

"github.com/moov-io/base/log"
Expand All @@ -22,15 +23,12 @@ var (
euDownloadURL = strx.Or(os.Getenv("EU_CSL_DOWNLOAD_URL"), publicEUDownloadURL)
)

func DownloadEU(logger log.Logger, initialDir string) (string, error) {
func DownloadEU(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) {
dl := download.New(logger, download.HTTPClient)

euCSLNameAndSource := make(map[string]string)
euCSLNameAndSource["eu_csl.csv"] = euDownloadURL

file, err := dl.GetFiles(initialDir, euCSLNameAndSource)
if len(file) == 0 || err != nil {
return "", fmt.Errorf("eu csl download: %v", err)
}
return file[0], nil
return dl.GetFiles(initialDir, euCSLNameAndSource)

}
34 changes: 19 additions & 15 deletions pkg/csl/download_eu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package csl

import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
Expand All @@ -24,18 +25,19 @@ func TestEUDownload(t *testing.T) {
t.Fatal(err)
}
fmt.Println("file in test: ", file)
if file == "" {
if len(file) == 0 {
t.Fatal("no EU CSL file")
}
defer os.RemoveAll(filepath.Dir(file))

if !strings.EqualFold("eu_csl.csv", filepath.Base(file)) {
t.Errorf("unknown file %s", file)
for fn := range file {
if !strings.EqualFold("eu_csl.csv", filepath.Base(fn)) {
t.Errorf("unknown file %s", file)
}
}
}

func TestEUDownload_initialDir(t *testing.T) {
dir, err := os.MkdirTemp("", "iniital-dir")
dir, err := os.MkdirTemp("", "initial-dir")
if err != nil {
t.Fatal(err)
}
Expand All @@ -55,19 +57,21 @@ func TestEUDownload_initialDir(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if file == "" {
if len(file) == 0 {
t.Fatal("no EU CSL file")
}

if strings.EqualFold("eu_csl.csv", filepath.Base(file)) {
bs, err := os.ReadFile(file)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=eu_csl.csv" {
t.Errorf("eu_csl.csv: %v", v)
for fn, fd := range file {
if strings.EqualFold("eu_csl.csv", filepath.Base(fn)) {
bs, err := io.ReadAll(fd)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=eu_csl.csv" {
t.Errorf("eu_csl.csv: %v", v)
}
} else {
t.Fatalf("unknown file: %v", file)
}
} else {
t.Fatalf("unknown file: %v", file)
}
}
32 changes: 18 additions & 14 deletions pkg/csl/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package csl

import (
"io"
"os"
"path/filepath"
"strings"
Expand All @@ -22,13 +23,14 @@ func TestDownload(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if file == "" {
if len(file) == 0 {
t.Fatal("no CSL file")
}
defer os.RemoveAll(filepath.Dir(file))

if !strings.EqualFold("csl.csv", filepath.Base(file)) {
t.Errorf("unknown file %s", file)
for fn := range file {
if !strings.EqualFold("csl.csv", filepath.Base(fn)) {
t.Errorf("unknown file %s", fn)
}
}
}

Expand All @@ -55,20 +57,22 @@ func TestDownload_initialDir(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if file == "" {
if len(file) == 0 {
t.Fatal("no CSL file")
}

if strings.EqualFold("csl.csv", filepath.Base(file)) {
bs, err := os.ReadFile(file)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=csl.csv" {
t.Errorf("csl.csv: %v", v)
for fn, fd := range file {
if strings.EqualFold("csl.csv", filepath.Base(fn)) {
bs, err := io.ReadAll(fd)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=csl.csv" {
t.Errorf("csl.csv: %v", v)
}
} else {
t.Fatalf("unknown file: %v", file)
}
} else {
t.Fatalf("unknown file: %v", file)
}
}

Expand Down
Loading

0 comments on commit 33df3aa

Please sign in to comment.