pass io.ReadCloser instances instead of files on harddrive, fixes #536 #537

Merged · 3 commits · Mar 12, 2024
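This PR changes the download/parse hand-off from paths on disk to in-memory readers: the download helpers now return a map[string]io.ReadCloser keyed by file name, and the parsers (ofac.Read, csl.ReadFile, csl.ParseEU, csl.ReadUKCSLFile, csl.ReadUKSanctionsListFile) consume those readers directly. A minimal sketch of the new call pattern for the US CSL list follows; the github.com/moov-io/watchman import path and the caller-side Close handling are assumptions for illustration, not spelled out in this diff.

package sketch

import (
	"github.com/moov-io/base/log"

	"github.com/moov-io/watchman/pkg/csl" // module path assumed
)

func loadCSL(logger log.Logger, initialDir string) (*csl.CSL, error) {
	// Download now returns readers keyed by file name (map[string]io.ReadCloser)
	// instead of a path to a file written on disk.
	files, err := csl.Download(logger, initialDir)
	if err != nil {
		return nil, err
	}
	defer func() {
		for _, rc := range files {
			rc.Close() // closing here is an assumption about caller responsibility
		}
	}()

	// Parsers read straight from the downloaded reader.
	return csl.ReadFile(files["csl.csv"])
}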
49 changes: 16 additions & 33 deletions cmd/server/download.go
@@ -163,37 +163,14 @@ func ofacRecords(logger log.Logger, initialDir string) (*ofac.Results, error) {
if len(files) == 0 {
return nil, errors.New("no OFAC Results")
}

var res *ofac.Results
for i := range files {
// Does order matter?

if i == 0 {
rr, err := ofac.Read(files[i])
if err != nil {
return nil, fmt.Errorf("reading %s: %v", files[i], err)
}
if rr != nil {
res = rr
}
} else {
rr, err := ofac.Read(files[i])
if err != nil {
return nil, fmt.Errorf("read and replace: %v", err)
}
if rr != nil {
res.Addresses = append(res.Addresses, rr.Addresses...)
res.AlternateIdentities = append(res.AlternateIdentities, rr.AlternateIdentities...)
res.SDNs = append(res.SDNs, rr.SDNs...)
res.SDNComments = append(res.SDNComments, rr.SDNComments...)
}
}
res, err := ofac.Read(files)
if err != nil {
return nil, err
}

// Merge comments into SDNs
res.SDNs = mergeSpilloverRecords(res.SDNs, res.SDNComments)

return res, err
return res, nil
}

func mergeSpilloverRecords(sdns []*ofac.SDN, comments []*ofac.SDNComments) []*ofac.SDN {
@@ -212,7 +189,8 @@ func dplRecords(logger log.Logger, initialDir string) ([]*dpl.DPL, error) {
if err != nil {
return nil, err
}
return dpl.Read(file)

return dpl.Read(file["dpl.txt"])
}

func cslRecords(logger log.Logger, initialDir string) (*csl.CSL, error) {
@@ -221,11 +199,11 @@ func cslRecords(logger log.Logger, initialDir string) (*csl.CSL, error) {
logger.Warn().Logf("skipping CSL download: %v", err)
return &csl.CSL{}, nil
}
cslRecords, err := csl.ReadFile(file)
cslRecords, err := csl.ReadFile(file["csl.csv"])
if err != nil {
return nil, err
}
return cslRecords, err
return cslRecords, nil
}

func euCSLRecords(logger log.Logger, initialDir string) ([]*csl.EUCSLRecord, error) {
@@ -235,11 +213,13 @@ func euCSLRecords(logger log.Logger, initialDir string) ([]*csl.EUCSLRecord, err
// no error to return because we skip the download
return nil, nil
}
cslRecords, _, err := csl.ReadEUFile(file)

cslRecords, _, err := csl.ParseEU(file["eu_csl.csv"])
if err != nil {
return nil, err
}
return cslRecords, err

}

func ukCSLRecords(logger log.Logger, initialDir string) ([]*csl.UKCSLRecord, error) {
@@ -249,7 +229,7 @@ func ukCSLRecords(logger log.Logger, initialDir string) ([]*csl.UKCSLRecord, err
// no error to return because we skip the download
return nil, nil
}
cslRecords, _, err := csl.ReadUKCSLFile(file)
cslRecords, _, err := csl.ReadUKCSLFile(file["ConList.csv"])
if err != nil {
return nil, err
}
@@ -264,7 +244,7 @@ func ukSanctionsListRecords(logger log.Logger, initialDir string) ([]*csl.UKSanc
return nil, nil
}

records, _, err := csl.ReadUKSanctionsListFile(file)
records, _, err := csl.ReadUKSanctionsListFile(file["UK_Sanctions_List.ods"])
if err != nil {
return nil, err
}
@@ -342,6 +322,9 @@ func (s *searcher) refreshData(initialDir string) (*DownloadStats, error) {
lastDataRefreshFailure.WithLabelValues("CSL").Set(float64(time.Now().Unix()))
stats.Errors = append(stats.Errors, fmt.Errorf("CSL: %v", err))
}
if consolidatedLists == nil {
consolidatedLists = new(csl.CSL)
}
els := precomputeCSLEntities[csl.EL](consolidatedLists.ELs, s.pipe)
meus := precomputeCSLEntities[csl.MEU](consolidatedLists.MEUs, s.pipe)
ssis := precomputeCSLEntities[csl.SSI](consolidatedLists.SSIs, s.pipe)
8 changes: 7 additions & 1 deletion cmd/server/search_crypto_test.go
@@ -7,8 +7,10 @@ package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"

@@ -25,7 +27,11 @@

func init() {
// Set SDN Comments
ofacResults, err := ofac.Read(filepath.Join("..", "..", "test", "testdata", "sdn_comments.csv"))
fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "sdn_comments.csv"))
if err != nil {
panic(fmt.Sprintf("%v", err))
}
ofacResults, err := ofac.Read(map[string]io.ReadCloser{"sdn_comments.csv": fd})
if err != nil {
panic(fmt.Sprintf("ERROR reading sdn_comments.csv: %v", err))
}
8 changes: 7 additions & 1 deletion cmd/server/search_handlers_bench_test.go
@@ -7,10 +7,12 @@ package main
import (
"crypto/rand"
"fmt"
"io"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
@@ -56,7 +58,11 @@ func BenchmarkSearchHandler(b *testing.B) {
}

func BenchmarkJaroWinkler(b *testing.B) {
results, err := ofac.Read(filepath.Join("..", "..", "test", "testdata", "sdn.csv"))
fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "sdn.csv"))
if err != nil {
b.Error(err)
}
results, err := ofac.Read(map[string]io.ReadCloser{"sdn.csv": fd})
require.NoError(b, err)
require.Len(b, results.SDNs, 7379)

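The updated tests all follow the same pattern: open the fixture with os.Open and pass it to ofac.Read wrapped in a map keyed by the file name the parser expects. As a sketch, that pattern could be factored into a small helper; the helper below is illustrative only and is not part of this PR.

package main

import (
	"io"
	"os"
	"path/filepath"
	"testing"
)

// openTestFile mirrors the fixture-loading pattern used in the updated tests:
// open the file and key it by the name the parser expects.
func openTestFile(tb testing.TB, name string) map[string]io.ReadCloser {
	tb.Helper()

	fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", name))
	if err != nil {
		tb.Fatal(err)
	}
	tb.Cleanup(func() { fd.Close() }) // release the fixture when the test or benchmark ends

	return map[string]io.ReadCloser{name: fd}
}

A benchmark could then call ofac.Read(openTestFile(b, "sdn.csv")) instead of opening the file inline.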
3 changes: 3 additions & 0 deletions cmd/server/search_us_csl.go
@@ -39,6 +39,9 @@ func searchUSCSL(logger log.Logger, searcher *searcher) http.HandlerFunc {

func precomputeCSLEntities[T any](items []*T, pipe *pipeliner) []*Result[T] {
out := make([]*Result[T], len(items))
if items == nil {
return out
}

for i, item := range items {
name := cslName(item)
2 changes: 1 addition & 1 deletion go.mod
@@ -20,6 +20,7 @@ require (
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673
go4.org v0.0.0-20230225012048-214862532bf5
golang.org/x/oauth2 v0.14.0
golang.org/x/sync v0.6.0
golang.org/x/text v0.14.0
)

@@ -37,7 +38,6 @@ require (
github.com/prometheus/procfs v0.12.0 // indirect
github.com/rickar/cal/v2 v2.1.13 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.15.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/protobuf v1.31.0 // indirect
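In go.mod, golang.org/x/sync moves from an indirect to a direct dependency, which suggests the download path now uses that module itself, most likely errgroup for fetching several list sources concurrently. That reading is an inference rather than something stated in the diff; the sketch below only shows the common errgroup pattern for assembling a map[string]io.ReadCloser from several URLs and is not code from this PR.

package sketch

import (
	"io"
	"net/http"
	"sync"

	"golang.org/x/sync/errgroup"
)

// fetchAll downloads each named URL concurrently and returns the response
// bodies keyed by file name. Purely illustrative of the errgroup pattern.
func fetchAll(urls map[string]string) (map[string]io.ReadCloser, error) {
	var (
		g   errgroup.Group
		mu  sync.Mutex
		out = make(map[string]io.ReadCloser, len(urls))
	)
	for name, u := range urls {
		name, u := name, u // capture range variables for the goroutine
		g.Go(func() error {
			resp, err := http.Get(u)
			if err != nil {
				return err
			}
			mu.Lock()
			out[name] = resp.Body // caller is responsible for closing
			mu.Unlock()
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return out, nil
}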
3 changes: 1 addition & 2 deletions pkg/csl/csl_test.go
@@ -14,7 +14,6 @@ import (
)

func TestCSL(t *testing.T) {
t.Skip("CSL is currently broken, looks like they require API access now")

if testing.Short() {
t.Skip("ignorning network test")
@@ -27,7 +26,7 @@ func TestCSL(t *testing.T) {
file, err := Download(logger, dir)
require.NoError(t, err)

cslRecords, err := ReadFile(file)
cslRecords, err := ReadFile(file["csl.csv"])
require.NoError(t, err)

if len(cslRecords.SSIs) == 0 {
11 changes: 4 additions & 7 deletions pkg/csl/download.go
@@ -6,6 +6,7 @@ package csl

import (
"fmt"
"io"
"net/url"
"os"

@@ -19,22 +20,18 @@
usDownloadURL = strx.Or(os.Getenv("CSL_DOWNLOAD_TEMPLATE"), os.Getenv("US_CSL_DOWNLOAD_URL"), publicUSDownloadURL)
)

func Download(logger log.Logger, initialDir string) (string, error) {
func Download(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) {
dl := download.New(logger, download.HTTPClient)

cslURL, err := buildDownloadURL(usDownloadURL)
if err != nil {
return "", err
return nil, err
}

cslNameAndSource := make(map[string]string)
cslNameAndSource["csl.csv"] = cslURL

file, err := dl.GetFiles(initialDir, cslNameAndSource)
if len(file) == 0 || err != nil {
return "", fmt.Errorf("csl download: %v", err)
}
return file[0], nil
return dl.GetFiles(initialDir, cslNameAndSource)
}

func buildDownloadURL(urlStr string) (string, error) {
10 changes: 4 additions & 6 deletions pkg/csl/download_eu.go
@@ -6,6 +6,7 @@ package csl

import (
"fmt"
"io"
"os"

"github.com/moov-io/base/log"
@@ -22,15 +23,12 @@
euDownloadURL = strx.Or(os.Getenv("EU_CSL_DOWNLOAD_URL"), publicEUDownloadURL)
)

func DownloadEU(logger log.Logger, initialDir string) (string, error) {
func DownloadEU(logger log.Logger, initialDir string) (map[string]io.ReadCloser, error) {
dl := download.New(logger, download.HTTPClient)

euCSLNameAndSource := make(map[string]string)
euCSLNameAndSource["eu_csl.csv"] = euDownloadURL

file, err := dl.GetFiles(initialDir, euCSLNameAndSource)
if len(file) == 0 || err != nil {
return "", fmt.Errorf("eu csl download: %v", err)
}
return file[0], nil
return dl.GetFiles(initialDir, euCSLNameAndSource)

}
34 changes: 19 additions & 15 deletions pkg/csl/download_eu_test.go
@@ -6,6 +6,7 @@ package csl

import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
@@ -24,18 +25,19 @@ func TestEUDownload(t *testing.T) {
t.Fatal(err)
}
fmt.Println("file in test: ", file)
if file == "" {
if len(file) == 0 {
t.Fatal("no EU CSL file")
}
defer os.RemoveAll(filepath.Dir(file))

if !strings.EqualFold("eu_csl.csv", filepath.Base(file)) {
t.Errorf("unknown file %s", file)
for fn := range file {
if !strings.EqualFold("eu_csl.csv", filepath.Base(fn)) {
t.Errorf("unknown file %s", file)
}
}
}

func TestEUDownload_initialDir(t *testing.T) {
dir, err := os.MkdirTemp("", "iniital-dir")
dir, err := os.MkdirTemp("", "initial-dir")
if err != nil {
t.Fatal(err)
}
@@ -55,19 +57,21 @@
if err != nil {
t.Fatal(err)
}
if file == "" {
if len(file) == 0 {
t.Fatal("no EU CSL file")
}

if strings.EqualFold("eu_csl.csv", filepath.Base(file)) {
bs, err := os.ReadFile(file)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=eu_csl.csv" {
t.Errorf("eu_csl.csv: %v", v)
for fn, fd := range file {
if strings.EqualFold("eu_csl.csv", filepath.Base(fn)) {
bs, err := io.ReadAll(fd)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=eu_csl.csv" {
t.Errorf("eu_csl.csv: %v", v)
}
} else {
t.Fatalf("unknown file: %v", file)
}
} else {
t.Fatalf("unknown file: %v", file)
}
}
32 changes: 18 additions & 14 deletions pkg/csl/download_test.go
@@ -5,6 +5,7 @@
package csl

import (
"io"
"os"
"path/filepath"
"strings"
@@ -22,13 +23,14 @@ func TestDownload(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if file == "" {
if len(file) == 0 {
t.Fatal("no CSL file")
}
defer os.RemoveAll(filepath.Dir(file))

if !strings.EqualFold("csl.csv", filepath.Base(file)) {
t.Errorf("unknown file %s", file)
for fn := range file {
if !strings.EqualFold("csl.csv", filepath.Base(fn)) {
t.Errorf("unknown file %s", fn)
}
}
}

@@ -55,20 +57,22 @@ func TestDownload_initialDir(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if file == "" {
if len(file) == 0 {
t.Fatal("no CSL file")
}

if strings.EqualFold("csl.csv", filepath.Base(file)) {
bs, err := os.ReadFile(file)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=csl.csv" {
t.Errorf("csl.csv: %v", v)
for fn, fd := range file {
if strings.EqualFold("csl.csv", filepath.Base(fn)) {
bs, err := io.ReadAll(fd)
if err != nil {
t.Fatal(err)
}
if v := string(bs); v != "file=csl.csv" {
t.Errorf("csl.csv: %v", v)
}
} else {
t.Fatalf("unknown file: %v", file)
}
} else {
t.Fatalf("unknown file: %v", file)
}
}
