diff --git a/proxmox/client.go b/proxmox/client.go index 10213f39..ce9c5321 100644 --- a/proxmox/client.go +++ b/proxmox/client.go @@ -16,6 +16,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) @@ -26,14 +27,20 @@ const exitStatusSuccess = "OK" // Client - URL, user and password to specific Proxmox node type Client struct { - session *Session - ApiUrl string - Username string - Password string - Otp string - TaskTimeout int + session *Session + ApiUrl string + Username string + Password string + Otp string + TaskTimeout int + version *Version + versionMutex *sync.Mutex } +const ( + Client_Error_Nil string = "client may not be nil" +) + const ( VmRef_Error_Nil string = "vm reference may not be nil" ) @@ -105,7 +112,7 @@ func NewClient(apiUrl string, hclient *http.Client, http_headers string, tls *tl return nil, err } if err_s == nil { - client = &Client{session: sess, ApiUrl: apiUrl, TaskTimeout: taskTimeout} + client = &Client{session: sess, ApiUrl: apiUrl, TaskTimeout: taskTimeout, versionMutex: &sync.Mutex{}} } return client, err_s @@ -130,13 +137,19 @@ func (c *Client) Login(username string, password string, otp string) (err error) return c.session.Login(username, password, otp) } -func (c *Client) GetVersion() (data map[string]interface{}, err error) { - resp, err := c.session.Get("/version", nil, nil) - if err != nil { - return nil, err +// Updates the client's cached version information and returns it. +func (c *Client) GetVersion() (version Version, err error) { + params, err := c.GetItemConfigMapStringInterface("/version", "version", "data") + version = version.mapToSDK(params) + cachedVersion := Version{ // clones the struct + Major: version.Major, + Minor: version.Minor, + Patch: version.Patch, } - - return ResponseJSON(resp) + c.versionMutex.Lock() + c.version = &cachedVersion + c.versionMutex.Unlock() + return } func (c *Client) GetJsonRetryable(url string, data *map[string]interface{}, tries int) error { @@ -2213,3 +2226,59 @@ func (c *Client) CheckTask(resp *http.Response) (exitStatus string, err error) { } return c.WaitForCompletion(taskResponse) } + +// Returns the Client's cached version if it exists, otherwise fetches the version from the API. +func (c *Client) Version() (Version, error) { + if c == nil { + return Version{}, errors.New(Client_Error_Nil) + } + if c.version == nil { + return c.GetVersion() + } + c.versionMutex.Lock() + defer c.versionMutex.Unlock() + return Version{ + Major: c.version.Major, + Minor: c.version.Minor, + Patch: c.version.Patch, + }, nil +} + +type Version struct { + Major uint8 + Minor uint8 + Patch uint8 +} + +// Greater returns true if the version is greater than the other version. +func (v Version) Greater(other Version) bool { + return uint32(v.Major)*256*256+uint32(v.Minor)*256+uint32(v.Patch) > uint32(other.Major)*256*256+uint32(other.Minor)*256+uint32(other.Patch) +} + +func (Version) mapToSDK(params map[string]interface{}) (version Version) { + if itemValue, isSet := params["version"]; isSet { + rawVersion := strings.Split(itemValue.(string), ".") + if len(rawVersion) > 0 { + tmpMajor, _ := strconv.ParseUint(rawVersion[0], 10, 8) + version.Major = uint8(tmpMajor) + } + if len(rawVersion) > 1 { + tmpMinor, _ := strconv.ParseUint(rawVersion[1], 10, 8) + version.Minor = uint8(tmpMinor) + } + if len(rawVersion) > 2 { + tmpPatch, _ := strconv.ParseUint(rawVersion[2], 10, 8) + version.Patch = uint8(tmpPatch) + } + } + return +} + +// Smaller returns true if the version is less than the other version. +func (v Version) Smaller(other Version) bool { + return uint32(v.Major)*256*256+uint32(v.Minor)*256+uint32(v.Patch) < uint32(other.Major)*256*256+uint32(other.Minor)*256+uint32(other.Patch) +} + +func (v Version) String() string { + return strconv.FormatInt(int64(v.Major), 10) + "." + strconv.FormatInt(int64(v.Minor), 10) + "." + strconv.FormatInt(int64(v.Patch), 10) +} diff --git a/proxmox/client_test.go b/proxmox/client_test.go new file mode 100644 index 00000000..dec85678 --- /dev/null +++ b/proxmox/client_test.go @@ -0,0 +1,76 @@ +package proxmox + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Version_Greater(t *testing.T) { + type input struct { + a Version + b Version + } + tests := []struct { + name string + input input + output bool + }{ + {"a > b 0", input{Version{1, 0, 0}, Version{0, 0, 0}}, true}, + {"a > b 1", input{Version{0, 1, 0}, Version{0, 0, 255}}, true}, + {"a > b 2", input{Version{1, 0, 0}, Version{0, 255, 255}}, true}, + {"a < b 0", input{Version{7, 4, 1}, Version{7, 4, 2}}, false}, + {"a < b 1", input{Version{0, 0, 255}, Version{0, 1, 0}}, false}, + {"a < b 2", input{Version{0, 255, 255}, Version{1, 0, 0}}, false}, + {"a = b", input{Version{0, 0, 0}, Version{0, 0, 0}}, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.output, test.input.a.Greater(test.input.b)) + }) + } +} + +func Test_Version_mapToSDK(t *testing.T) { + tests := []struct { + name string + input map[string]interface{} + output Version + }{ + {"empty", map[string]interface{}{}, Version{}}, + {"full", map[string]interface{}{"version": "1.2.3"}, Version{1, 2, 3}}, + {"invalid", map[string]interface{}{"version": ""}, Version{}}, + {"major", map[string]interface{}{"version": "1"}, Version{1, 0, 0}}, + {"partial", map[string]interface{}{"version": "1.2"}, Version{1, 2, 0}}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.output, Version{}.mapToSDK(test.input)) + }) + } +} + +func Test_Version_Smaller(t *testing.T) { + type input struct { + a Version + b Version + } + tests := []struct { + name string + input input + output bool + }{ + {"a > b 0", input{Version{1, 0, 0}, Version{0, 0, 0}}, false}, + {"a > b 1", input{Version{0, 1, 0}, Version{0, 0, 255}}, false}, + {"a > b 2", input{Version{1, 0, 0}, Version{0, 255, 255}}, false}, + {"a < b 0", input{Version{7, 4, 1}, Version{7, 4, 2}}, true}, + {"a < b 1", input{Version{0, 0, 255}, Version{0, 1, 0}}, true}, + {"a < b 2", input{Version{0, 255, 255}, Version{1, 0, 0}}, true}, + {"a = b", input{Version{0, 0, 0}, Version{0, 0, 0}}, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.output, test.input.a.Smaller(test.input.b)) + }) + } +}