diff --git a/flytectl/cmd/upgrade/upgrade.go b/flytectl/cmd/upgrade/upgrade.go index deb88db5b8..9991f934e4 100644 --- a/flytectl/cmd/upgrade/upgrade.go +++ b/flytectl/cmd/upgrade/upgrade.go @@ -119,15 +119,15 @@ func isUpgradeSupported(goos platformutil.Platform) (bool, error) { return false, err } - compatible_version := strings.TrimPrefix(latest, fmt.Sprintf("%s/", github.FlytectlReleaseConfig.ExecutableName)) - if isGreater, err := util.IsVersionGreaterThan(compatible_version, stdlibversion.Version); err != nil { + compatibleVersion := strings.TrimPrefix(latest, fmt.Sprintf("%s/", github.FlytectlReleaseConfig.ExecutableName)) + if isGreater, err := util.IsVersionGreaterThan(compatibleVersion, stdlibversion.Version); err != nil { return false, err } else if !isGreater { fmt.Println("You already have the latest version of Flytectl") return false, nil } - message, err := github.GetUpgradeMessage(compatible_version, goos) + message, err := github.GetUpgradeMessage(compatibleVersion, goos) if err != nil { return false, err } diff --git a/flytectl/pkg/github/githubutil.go b/flytectl/pkg/github/githubutil.go index 845f7f75b6..babfaeace5 100644 --- a/flytectl/pkg/github/githubutil.go +++ b/flytectl/pkg/github/githubutil.go @@ -39,9 +39,10 @@ var Client GHRepoService // FlytectlReleaseConfig represent the updater config for flytectl binary var FlytectlReleaseConfig = &updater.Updater{ - Provider: &GitHubProvider{ + Provider: &GHProvider{ RepositoryURL: flytectlRepository, ArchiveName: getFlytectlAssetName(), + ghRepo: GetGHRepoService(), }, ExecutableName: flytectl, Version: stdlibversion.Version, diff --git a/flytectl/pkg/github/provider.go b/flytectl/pkg/github/provider.go index 0e37cee284..1d3b63e874 100644 --- a/flytectl/pkg/github/provider.go +++ b/flytectl/pkg/github/provider.go @@ -15,24 +15,15 @@ import ( "github.com/mouuff/go-rocket-update/pkg/provider" ) -// type GitHubProvider struct { -// provider.Github -// } - // Github provider finds a archive file in the repository's releases to provide files -type GitHubProvider struct { +type GHProvider struct { RepositoryURL string // Repository URL, example github.com/mouuff/go-rocket-update ArchiveName string // Archive name (the zip/tar.gz you upload for a release on github), example: binaries.zip tmpDir string // temporary directory this is used internally decompressProvider provider.Provider // provider used to decompress the downloaded archive archivePath string // path to the downloaded archive (should be in tmpDir) -} - -// githubTag struct used to unmarshal response from github -// https://api.github.com/repos/ownerName/projectName/tags -type githubTag struct { - Name string `json:"name"` + ghRepo GHRepoService // github repository service } // githubRepositoryInfo is used to get the name of the project and the owner name @@ -43,7 +34,7 @@ type githubRepositoryInfo struct { } // getRepositoryInfo parses the github repository URL -func (c *GitHubProvider) repositoryInfo() (*githubRepositoryInfo, error) { +func (c *GHProvider) repositoryInfo() (*githubRepositoryInfo, error) { re := regexp.MustCompile(`github\.com/(.*?)/(.*?)$`) submatches := re.FindAllStringSubmatch(c.RepositoryURL, 1) if len(submatches) < 1 { @@ -57,7 +48,7 @@ func (c *GitHubProvider) repositoryInfo() (*githubRepositoryInfo, error) { // getArchiveURL get the archive URL for the github repository // If no tag is provided then the latest version is selected -func (c *GitHubProvider) getArchiveURL(tag string) (string, error) { +func (c *GHProvider) getArchiveURL(tag string) (string, error) { if len(tag) == 0 { // Get latest version if no tag is provided var err error @@ -80,41 +71,45 @@ func (c *GitHubProvider) getArchiveURL(tag string) (string, error) { } // Open opens the provider -func (c *GitHubProvider) Open() (err error) { +func (c *GHProvider) Open() (err error) { archiveURL, err := c.getArchiveURL("") // get archive url for latest version if err != nil { - return + return err + } + req, err := http.NewRequest("GET", archiveURL, nil) + if err != nil { + return err } - resp, err := http.Get(archiveURL) + resp, err := http.DefaultClient.Do(req) if err != nil { - return + return err } defer resp.Body.Close() c.tmpDir, err = os.MkdirTemp("", "rocket-update") if err != nil { - return + return err } c.archivePath = filepath.Join(c.tmpDir, c.ArchiveName) archiveFile, err := os.Create(c.archivePath) if err != nil { - return + return err } _, err = io.Copy(archiveFile, resp.Body) archiveFile.Close() if err != nil { - return + return err } c.decompressProvider, err = provider.Decompress(c.archivePath) if err != nil { - return nil + return err } return c.decompressProvider.Open() } // Close closes the provider -func (c *GitHubProvider) Close() error { +func (c *GHProvider) Close() error { if c.decompressProvider != nil { c.decompressProvider.Close() c.decompressProvider = nil @@ -129,17 +124,17 @@ func (c *GitHubProvider) Close() error { } // GetLatestVersion gets the latest version -func (c *GitHubProvider) GetLatestVersion() (string, error) { +func (c *GHProvider) GetLatestVersion() (string, error) { tags, err := c.getReleases() if err != nil { return "", err } - latest_tag := tags[0].GetTagName() - return latest_tag, err + latestTag := tags[0].GetTagName() + return latestTag, err } -func (c *GitHubProvider) getReleases() ([]*go_github.RepositoryRelease, error) { - g := GetGHRepoService() +func (c *GHProvider) getReleases() ([]*go_github.RepositoryRelease, error) { + g := c.ghRepo releases, _, err := g.ListReleases(context.Background(), owner, flyte, &go_github.ListOptions{ PerPage: 100, }) @@ -156,7 +151,7 @@ func (c *GitHubProvider) getReleases() ([]*go_github.RepositoryRelease, error) { } // Walk walks all the files provided -func (c *GitHubProvider) Walk(walkFn provider.WalkFunc) error { +func (c *GHProvider) Walk(walkFn provider.WalkFunc) error { if c.decompressProvider == nil { // TODO specify error return provider.ErrNotOpenned @@ -165,6 +160,6 @@ func (c *GitHubProvider) Walk(walkFn provider.WalkFunc) error { } // Retrieve file relative to "provider" to destination -func (c *GitHubProvider) Retrieve(src string, dest string) error { +func (c *GHProvider) Retrieve(src string, dest string) error { return c.decompressProvider.Retrieve(src, dest) } diff --git a/flytectl/pkg/github/provider_test.go b/flytectl/pkg/github/provider_test.go new file mode 100644 index 0000000000..d9b18eee2d --- /dev/null +++ b/flytectl/pkg/github/provider_test.go @@ -0,0 +1,66 @@ +package github + +import ( + "testing" + + "github.com/flyteorg/flyte/flytectl/pkg/github/mocks" + go_github "github.com/google/go-github/v42/github" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGetLatestFlytectlVersion(t *testing.T) { + t.Run("Get latest release", func(t *testing.T) { + mockGh := &mocks.GHRepoService{} + // return a list of github releases + mockGh.OnListReleasesMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + []*go_github.RepositoryRelease{ + {TagName: go_github.String("flytectl/1.2.4")}, + {TagName: go_github.String("flytectl/1.2.3")}, + {TagName: go_github.String("other-1.0.0")}, + }, + nil, + nil, + ) + mockProvider := &GHProvider{ + RepositoryURL: flytectlRepository, + ArchiveName: getFlytectlAssetName(), + ghRepo: mockGh, + } + + latestVersion, err := mockProvider.GetLatestVersion() + assert.Nil(t, err) + assert.Equal(t, "flytectl/1.2.4", latestVersion) + }) +} + +func TestGetFlytectlReleases(t *testing.T) { + t.Run("Get releases", func(t *testing.T) { + mockGh := &mocks.GHRepoService{} + allReleases := []*go_github.RepositoryRelease{ + {TagName: go_github.String("flytectl/1.2.4")}, + {TagName: go_github.String("flytectl/1.2.3")}, + {TagName: go_github.String("other-1.0.0")}, + } + releases := []*go_github.RepositoryRelease{ + {TagName: go_github.String("flytectl/1.2.4")}, + {TagName: go_github.String("flytectl/1.2.3")}, + } + // return a list of github releases + mockGh.OnListReleasesMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + allReleases, + nil, + nil, + ) + mockProvider := &GHProvider{ + RepositoryURL: flytectlRepository, + ArchiveName: getFlytectlAssetName(), + ghRepo: mockGh, + } + + flytectlReleases, err := mockProvider.getReleases() + assert.Nil(t, err) + assert.Equal(t, releases, flytectlReleases) + }) +}