diff --git a/CHANGELOG.md b/CHANGELOG.md index d4d7bba..740bef6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,8 +4,6 @@ * `geoipupdate` now supports retrying on more types of errors such as HTTP2 INTERNAL_ERROR. -* `HTTPReader` no longer retries on HTTP errors and therefore - `retryFor` was removed from `NewHTTPReader`. * Now `geoipupdate` doesn't requires the user to specify the config file even if all the other arguments are set via the environment variables. Reported by jsf84ksnf. GitHub #284. @@ -15,9 +13,17 @@ a database edition. * `/geoip/databases/{edition-id}/download` which is responsible for downloading the content of a database edition. This new endpoint redirects downloads to R2 - presigned URLs, so systems running geoipupdate need to be able to reach + presigned URLs, so systems running `geoipupdate` need to be able to + reach `mm-prod-geoip-databases.a2649acb697e2c09b632799562c076f2.r2.cloudflarestorage.com` in addition to `updates.maxmind.com`. +* BREAKING CHANGE: The public package API has been redesigned. The previous + API was not easy to use and had become a maintenance burden. We now + expose a `Client` at `github.com/maxmind/geoipupdate/client` with a + `Download()` method. The intention is to expose less of the `geoipupdate` + internals and provide a simpler and easier to use package. Many + previously exposed methods and types are now either internal only or have + been removed. ## 6.1.0 (2024-01-09) diff --git a/client/client.go b/client/client.go index a9158fc..3c55397 100644 --- a/client/client.go +++ b/client/client.go @@ -1,38 +1,65 @@ +// Package client is a client for downloading GeoIP2 and GeoLite2 MMDB +// databases. package client import ( + "fmt" "net/http" ) -// HTTPReader is a Reader that uses an HTTP client to retrieve -// databases. -type HTTPReader struct { - // client is an http client responsible of fetching database updates. - client *http.Client - // path is the request path. - path string - // accountID is used for request auth. - accountID int - // licenseKey is used for request auth. +// Client downloads GeoIP2 and GeoLite2 MMDB databases. +// +// After creation, it is valid for concurrent use. +type Client struct { + accountID int + endpoint string + httpClient *http.Client licenseKey string - // verbose turns on/off debug logs. - verbose bool } -// NewHTTPReader creates a Reader that downloads database updates via -// HTTP. -func NewHTTPReader( - path string, +// Option is an option for configuring Client. +type Option func(*Client) + +// WithEndpoint sets the base endpoint to use. By default we use +// https://updates.maxmind.com. +func WithEndpoint(endpoint string) Option { + return func(c *Client) { + c.endpoint = endpoint + } +} + +// WithHTTPClient sets the HTTP client to use. By default we use +// http.DefaultClient. +func WithHTTPClient(httpClient *http.Client) Option { + return func(c *Client) { + c.httpClient = httpClient + } +} + +// New creates a Client. +func New( accountID int, licenseKey string, - verbose bool, - httpClient *http.Client, -) *HTTPReader { - return &HTTPReader{ - client: httpClient, - path: path, + options ...Option, +) (Client, error) { + if accountID <= 0 { + return Client{}, fmt.Errorf("invalid account ID: %d", accountID) + } + + if licenseKey == "" { + return Client{}, fmt.Errorf("invalid license key: %s", licenseKey) + } + + c := Client{ accountID: accountID, + endpoint: "https://updates.maxmind.com", + httpClient: http.DefaultClient, licenseKey: licenseKey, - verbose: verbose, } + + for _, opt := range options { + opt(&c) + } + + return c, nil } diff --git a/client/download.go b/client/download.go index 1c9db3d..0f47b7c 100644 --- a/client/download.go +++ b/client/download.go @@ -7,8 +7,8 @@ import ( "errors" "fmt" "io" - "log" "net/http" + "net/url" "strconv" "strings" "time" @@ -17,81 +17,124 @@ import ( "github.com/maxmind/geoipupdate/v6/internal/vars" ) -// Read attempts to fetch database updates for a specific editionID. -// It takes an editionID and its previously downloaded hash if available -// as arguments and returns a ReadResult struct as a response. -// It's the responsibility of the Writer to close the io.ReadCloser -// included in the response after consumption. -func (r *HTTPReader) Read(ctx context.Context, editionID, hash string) (*ReadResult, error) { - result, err := r.get(ctx, editionID, hash) - if err != nil { - return nil, fmt.Errorf("getting update for %s: %w", editionID, err) - } - - return result, nil +// DownloadResponse describes the result of a Download call. +type DownloadResponse struct { + // LastModified is the date that the database was last modified. It will + // only be set if UpdateAvailable is true. + LastModified time.Time + + // MD5 is the string representation of the new database. It will only be set + // if UpdateAvailable is true. + MD5 string + + // Reader can be read to access the database itself. It will only contain a + // database if UpdateAvailable is true. + // + // If the Download call does not return an error, Reader will always be + // non-nil. + // + // If UpdateAvailable is true, the caller must read Reader to completion and + // close it. + Reader io.ReadCloser + + // UpdateAvailable is true if there is an update available for download. It + // will be false if the MD5 used in the Download call matches what the server + // currently has. + UpdateAvailable bool } -const downloadEndpoint = "%s/geoip/databases/%s/download?" - -// get makes an http request to fetch updates for a specific editionID if any. -func (r *HTTPReader) get( +// Download attempts to download the edition. +// +// The editionID parameter is a valid database edition ID, such as +// "GeoIP2-City". +// +// The MD5 parameter is a string representation of the MD5 sum of the database +// MMDB file you have previously downloaded. If you don't yet have one +// downloaded, this can be "". This is used to know if an update is available +// and avoid consuming resources if there is not. +// +// If the current MD5 checksum matches what the server currently has, no +// download is performed. +func (c Client) Download( ctx context.Context, - editionID string, - hash string, -) (result *ReadResult, err error) { - edition, err := r.getMetadata(ctx, editionID) + editionID, + md5 string, +) (DownloadResponse, error) { + metadata, err := c.getMetadata(ctx, editionID) if err != nil { - return nil, err + return DownloadResponse{}, err } - if edition.MD5 == hash { - if r.verbose { - log.Printf("No new updates available for %s", editionID) - } - return &ReadResult{EditionID: editionID, OldHash: hash, NewHash: hash}, nil + if metadata.MD5 == md5 { + return DownloadResponse{ + Reader: io.NopCloser(strings.NewReader("")), + UpdateAvailable: false, + }, nil + } + + reader, modifiedTime, err := c.download(ctx, editionID, metadata.Date) + if err != nil { + return DownloadResponse{}, err } - date := strings.ReplaceAll(edition.Date, "-", "") + return DownloadResponse{ + LastModified: modifiedTime, + MD5: metadata.MD5, + Reader: reader, + UpdateAvailable: true, + }, nil +} + +const downloadEndpoint = "%s/geoip/databases/%s/download?" + +func (c *Client) download( + ctx context.Context, + editionID, + date string, +) (io.ReadCloser, time.Time, error) { + date = strings.ReplaceAll(date, "-", "") params := url.Values{} params.Add("date", date) params.Add("suffix", "tar.gz") - escapedEdition := url.PathEscape(edition.EditionID) - requestURL := fmt.Sprintf(downloadEndpoint, r.path, escapedEdition) + params.Encode() + escapedEdition := url.PathEscape(editionID) + requestURL := fmt.Sprintf(downloadEndpoint, c.endpoint, escapedEdition) + params.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) if err != nil { - return nil, fmt.Errorf("creating download request: %w", err) + return nil, time.Time{}, fmt.Errorf("creating download request: %w", err) } req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) - req.SetBasicAuth(strconv.Itoa(r.accountID), r.licenseKey) + req.SetBasicAuth(strconv.Itoa(c.accountID), c.licenseKey) - response, err := r.client.Do(req) + response, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("performing download request: %w", err) + return nil, time.Time{}, fmt.Errorf("performing download request: %w", err) } // It is safe to close the response body reader as it wouldn't be // consumed in case this function returns an error. defer func() { if err != nil { + // TODO(horgh): Should we fully consume the body? response.Body.Close() } }() if response.StatusCode != http.StatusOK { + // TODO(horgh): Should we fully consume the body? //nolint:errcheck // we are already returning an error. buf, _ := io.ReadAll(io.LimitReader(response.Body, 256)) httpErr := internal.HTTPError{ Body: string(buf), StatusCode: response.StatusCode, } - return nil, fmt.Errorf("unexpected HTTP status code: %w", httpErr) + return nil, time.Time{}, fmt.Errorf("unexpected HTTP status code: %w", httpErr) } gzReader, err := gzip.NewReader(response.Body) if err != nil { - return nil, fmt.Errorf("encountered an error creating GZIP reader: %w", err) + return nil, time.Time{}, fmt.Errorf("encountered an error creating GZIP reader: %w", err) } defer func() { if err != nil { @@ -105,10 +148,10 @@ func (r *HTTPReader) get( for { header, err := tarReader.Next() if err == io.EOF { - return nil, errors.New("tar archive does not contain an mmdb file") + return nil, time.Time{}, errors.New("tar archive does not contain an mmdb file") } if err != nil { - return nil, fmt.Errorf("reading tar archive: %w", err) + return nil, time.Time{}, fmt.Errorf("reading tar archive: %w", err) } if strings.HasSuffix(header.Name, ".mmdb") { @@ -116,26 +159,18 @@ func (r *HTTPReader) get( } } - modifiedAt, err := parseTime(response.Header.Get("Last-Modified")) + lastModified, err := parseTime(response.Header.Get("Last-Modified")) if err != nil { - return nil, fmt.Errorf("reading Last-Modified header: %w", err) - } - - if r.verbose { - log.Printf("Updates available for %s", editionID) + return nil, time.Time{}, fmt.Errorf("reading Last-Modified header: %w", err) } - return &ReadResult{ - reader: editionReader{ + return editionReader{ Reader: tarReader, gzCloser: gzReader, responseCloser: response.Body, }, - EditionID: editionID, - OldHash: hash, - NewHash: edition.MD5, - ModifiedAt: modifiedAt, - }, nil + lastModified, + nil } // parseTime parses a string representation of a time into time.Time according to the diff --git a/client/metadata.go b/client/metadata.go index fd507ef..f83e732 100644 --- a/client/metadata.go +++ b/client/metadata.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "net/url" "strconv" @@ -24,24 +23,23 @@ type metadata struct { MD5 string `json:"md5"` } -func (r *HTTPReader) getMetadata(ctx context.Context, editionID string) (*metadata, error) { +func (c *Client) getMetadata( + ctx context.Context, + editionID string, +) (*metadata, error) { params := url.Values{} params.Add("edition_id", editionID) - metadataRequestURL := fmt.Sprintf(metadataEndpoint, r.path) + params.Encode() - - if r.verbose { - log.Printf("Requesting metadata for %s: %s", editionID, metadataRequestURL) - } + metadataRequestURL := fmt.Sprintf(metadataEndpoint, c.endpoint) + params.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataRequestURL, nil) if err != nil { return nil, fmt.Errorf("creating metadata request: %w", err) } req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) - req.SetBasicAuth(strconv.Itoa(r.accountID), r.licenseKey) + req.SetBasicAuth(strconv.Itoa(c.accountID), c.licenseKey) - response, err := r.client.Do(req) + response, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("performing metadata request: %w", err) } diff --git a/internal/geoipupdate/database/local_file_writer.go b/internal/geoipupdate/database/local_file_writer.go index 5b66745..b1d43a9 100644 --- a/internal/geoipupdate/database/local_file_writer.go +++ b/internal/geoipupdate/database/local_file_writer.go @@ -1,3 +1,4 @@ +// Package database writes MMDBs to disk. package database import ( @@ -46,30 +47,28 @@ func NewLocalFileWriter( } // Write writes the result struct returned by a Reader to a database file. -func (w *LocalFileWriter) Write(result *ReadResult) (err error) { - // exit early if we've got the latest database version. - if strings.EqualFold(result.OldHash, result.NewHash) { - if w.verbose { - log.Printf("Database %s up to date", result.EditionID) - } - return nil - } - +func (w *LocalFileWriter) Write( + editionID string, + reader io.ReadCloser, + newMD5 string, + lastModified time.Time, +) (err error) { defer func() { - if closeErr := result.reader.Close(); closeErr != nil { + _, _ = io.Copy(io.Discard, reader) //nolint:errcheck // Best effort. + if closeErr := reader.Close(); closeErr != nil { err = errors.Join( err, - fmt.Errorf("closing reader for %s: %w", result.EditionID, closeErr), + fmt.Errorf("closing reader for %s: %w", editionID, closeErr), ) } }() - databaseFilePath := w.getFilePath(result.EditionID) + databaseFilePath := w.getFilePath(editionID) // write the Reader's result into a temporary file. fw, err := newFileWriter(databaseFilePath + tempExtension) if err != nil { - return fmt.Errorf("setting up database writer for %s: %w", result.EditionID, err) + return fmt.Errorf("setting up database writer for %s: %w", editionID, err) } defer func() { if closeErr := fw.close(); closeErr != nil { @@ -80,13 +79,13 @@ func (w *LocalFileWriter) Write(result *ReadResult) (err error) { } }() - if err = fw.write(result.reader); err != nil { - return fmt.Errorf("writing to the temp file for %s: %w", result.EditionID, err) + if err = fw.write(reader); err != nil { + return fmt.Errorf("writing to the temp file for %s: %w", editionID, err) } // make sure the hash of the temp file matches the expected hash. - if err = fw.validateHash(result.NewHash); err != nil { - return fmt.Errorf("validating hash for %s: %w", result.EditionID, err) + if err = fw.validateHash(newMD5); err != nil { + return fmt.Errorf("validating hash for %s: %w", editionID, err) } // move the temoporary database file into its final location and @@ -102,13 +101,13 @@ func (w *LocalFileWriter) Write(result *ReadResult) (err error) { // check if we need to set the file's modified at time if w.preserveFileTime { - if err = setModifiedAtTime(databaseFilePath, result.ModifiedAt); err != nil { + if err = setModifiedAtTime(databaseFilePath, lastModified); err != nil { return err } } if w.verbose { - log.Printf("Database %s successfully updated: %+v", result.EditionID, result.NewHash) + log.Printf("Database %s successfully updated: %+v", editionID, newMD5) } return nil diff --git a/internal/geoipupdate/database/reader.go b/internal/geoipupdate/database/reader.go index 17f4935..b55ba76 100644 --- a/internal/geoipupdate/database/reader.go +++ b/internal/geoipupdate/database/reader.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io" "time" ) @@ -16,7 +15,6 @@ type Reader interface { // ReadResult is the struct returned by a Reader's Get method. type ReadResult struct { - reader io.ReadCloser EditionID string `json:"edition_id"` OldHash string `json:"old_hash"` NewHash string `json:"new_hash"` diff --git a/internal/geoipupdate/database/writer.go b/internal/geoipupdate/database/writer.go index e5f4015..95ec5bd 100644 --- a/internal/geoipupdate/database/writer.go +++ b/internal/geoipupdate/database/writer.go @@ -1,11 +1,16 @@ package database +import ( + "io" + "time" +) + // ZeroMD5 is the default value provided as an MD5 hash for a non-existent // database. const ZeroMD5 = "00000000000000000000000000000000" // Writer provides an interface for writing a database to a target location. type Writer interface { - Write(*ReadResult) error + Write(string, io.ReadCloser, string, time.Time) error GetHash(editionID string) (string, error) } diff --git a/internal/geoipupdate/geoip_updater.go b/internal/geoipupdate/geoip_updater.go index 0fc1187..e41c1f2 100644 --- a/internal/geoipupdate/geoip_updater.go +++ b/internal/geoipupdate/geoip_updater.go @@ -14,17 +14,22 @@ import ( "github.com/cenkalti/backoff/v4" + "github.com/maxmind/geoipupdate/v6/client" "github.com/maxmind/geoipupdate/v6/internal" "github.com/maxmind/geoipupdate/v6/internal/geoipupdate/database" ) +type updateClient interface { + Download(context.Context, string, string) (client.DownloadResponse, error) +} + // Updater uses config data to initiate a download or update // process for GeoIP databases. type Updater struct { - config *Config - reader database.Reader - output *log.Logger - writer database.Writer + config *Config + output *log.Logger + updateClient updateClient + writer database.Writer } // NewUpdater initialized a new Updater struct. @@ -36,13 +41,15 @@ func NewUpdater(config *Config) (*Updater, error) { } httpClient := &http.Client{Transport: transport} - reader := database.NewHTTPReader( - config.URL, + updateClient, err := client.New( config.AccountID, config.LicenseKey, - config.Verbose, - httpClient, + client.WithEndpoint(config.URL), + client.WithHTTPClient(httpClient), ) + if err != nil { + return nil, err + } writer, err := database.NewLocalFileWriter( config.DatabaseDirectory, @@ -54,10 +61,10 @@ func NewUpdater(config *Config) (*Updater, error) { } return &Updater{ - config: config, - reader: reader, - output: log.New(os.Stdout, "", 0), - writer: writer, + config: config, + output: log.New(os.Stdout, "", 0), + updateClient: updateClient, + writer: writer, }, nil } @@ -83,7 +90,7 @@ func (u *Updater) Run(ctx context.Context) error { for _, editionID := range u.config.EditionIDs { editionID := editionID processFunc := func(ctx context.Context) error { - edition, err := u.downloadEdition(ctx, editionID, u.reader, u.writer) + edition, err := u.downloadEdition(ctx, editionID, u.updateClient, u.writer) if err != nil { return err } @@ -120,7 +127,7 @@ func (u *Updater) Run(ctx context.Context) error { func (u *Updater) downloadEdition( ctx context.Context, editionID string, - r database.Reader, + uc updateClient, w database.Writer, ) (*database.ReadResult, error) { editionHash, err := w.GetHash(editionID) @@ -141,15 +148,41 @@ func (u *Updater) downloadEdition( var edition *database.ReadResult err = backoff.RetryNotify( func() error { - if edition, err = r.Read(ctx, editionID, editionHash); err != nil { + res, err := uc.Download(ctx, editionID, editionHash) + if err != nil { if internal.IsPermanentError(err) { return backoff.Permanent(err) } return err } + defer res.Reader.Close() + + if !res.UpdateAvailable { + if u.config.Verbose { + log.Printf("No new updates available for %s", editionID) + log.Printf("Database %s up to date", editionID) + } - if err = w.Write(edition); err != nil { + edition = &database.ReadResult{ + EditionID: editionID, + OldHash: editionHash, + NewHash: editionHash, + } + return nil + } + + if u.config.Verbose { + log.Printf("Updates available for %s", editionID) + } + + err = u.writer.Write( + editionID, + res.Reader, + res.MD5, + res.LastModified, + ) + if err != nil { if internal.IsPermanentError(err) { return backoff.Permanent(err) } @@ -157,6 +190,12 @@ func (u *Updater) downloadEdition( return err } + edition = &database.ReadResult{ + EditionID: editionID, + OldHash: editionHash, + NewHash: res.MD5, + ModifiedAt: res.LastModified, + } return nil }, b,