diff --git a/pkg/kubelet/kubelet_pod_lister.go b/pkg/kubelet/kubelet_pod_lister.go index 28a1161dc2..5f3c688bb3 100644 --- a/pkg/kubelet/kubelet_pod_lister.go +++ b/pkg/kubelet/kubelet_pod_lister.go @@ -37,8 +37,9 @@ const ( ) var ( - podURL string - client http.Client + podURL string + client http.Client + bearerToken string ) func init() { @@ -55,19 +56,20 @@ func init() { client = http.Client{} } -func httpGet(url string) (*http.Response, error) { - objToken, err := os.ReadFile(saPath) +func loadToken(path string) (string, error) { + objToken, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("failed to read from %q: %v", saPath, err) + return "", fmt.Errorf("failed to read from %q: %v", path, err) } - token := string(objToken) + return "Bearer " + string(objToken), nil +} - var bearer = "Bearer " + token +func doFetchPod(url string) (*http.Response, error) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody) if err != nil { return nil, err } - req.Header.Add("Authorization", bearer) + req.Header.Add("Authorization", bearerToken) resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to get response from %q: %v", url, err) @@ -75,9 +77,28 @@ func httpGet(url string) (*http.Response, error) { return resp, err } +func httpGet(path, url string) (*http.Response, error) { + var err error + if bearerToken == "" { + bearerToken, err = loadToken(path) + if err != nil { + return nil, fmt.Errorf("failed to read from %q: %v", path, err) + } + } + resp, err := doFetchPod(url) + if resp != nil && resp.StatusCode > 399 && resp.StatusCode < 500 { // if response in 4xx retry once + bearerToken, err = loadToken(path) + if err != nil { + return nil, fmt.Errorf("failed to read from %q: %v", path, err) + } + resp, err = doFetchPod(url) + } + return resp, err +} + // ListPods obtains PodList func (k *KubeletPodLister) ListPods() (*[]corev1.Pod, error) { - resp, err := httpGet(podURL) + resp, err := httpGet(saPath, podURL) if err != nil { return nil, fmt.Errorf("failed to get response: %v", err) } diff --git a/pkg/kubelet/kubelet_pod_lister_test.go b/pkg/kubelet/kubelet_pod_lister_test.go new file mode 100644 index 0000000000..8e5f1f9ff0 --- /dev/null +++ b/pkg/kubelet/kubelet_pod_lister_test.go @@ -0,0 +1,124 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kubelet + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "sync" + "testing" + + . "github.com/onsi/gomega" +) + +func TestDoFetchPod(t *testing.T) { + g := NewWithT(t) + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "dummy") + })) + defer svr.Close() + res, err := doFetchPod(svr.URL) + g.Expect(err).NotTo(HaveOccurred()) + defer res.Body.Close() + g.Expect(res.StatusCode).To(Equal(http.StatusOK)) + bearerToken = "dummy" + res, err = httpGet("", svr.URL) + g.Expect(err).NotTo(HaveOccurred()) + defer res.Body.Close() + g.Expect(res.StatusCode).To(Equal(http.StatusOK)) +} + +func TestDoFetchPodWithError(t *testing.T) { + g := NewWithT(t) + svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "dummy") + })) + res, err := doFetchPod(svr.URL) + g.Expect(err).To(HaveOccurred()) + g.Expect(res).To(BeNil()) + if res != nil { + defer res.Body.Close() + } + res, err = httpGet("", svr.URL) + if res != nil { + defer res.Body.Close() + } + g.Expect(err).To(HaveOccurred()) + g.Expect(res).To(BeNil()) +} + +func TestLoadToken(t *testing.T) { + g := NewWithT(t) + tmpDir, err := os.MkdirTemp("", "kepler-tmp-") + g.Expect(err).NotTo(HaveOccurred()) + defer os.RemoveAll(tmpDir) + + TokenFile, err := os.CreateTemp(tmpDir, "kubeletToken") + g.Expect(err).NotTo(HaveOccurred()) + _, err = TokenFile.WriteString("token") + g.Expect(err).NotTo(HaveOccurred()) + TokenFile.Close() + + bearerToken, err := loadToken(TokenFile.Name()) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(bearerToken).To(Equal("Bearer token")) + + var once sync.Once + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + once.Do(func() { + w.WriteHeader(http.StatusUnauthorized) + }) + fmt.Fprintf(w, "dummy") + })) + defer svr.Close() + + res, err := httpGet(TokenFile.Name(), svr.URL) + g.Expect(err).NotTo(HaveOccurred()) + defer res.Body.Close() + g.Expect(res.StatusCode).To(Equal(http.StatusOK)) +} + +func TestHttpGet(t *testing.T) { + g := NewWithT(t) + tmpDir, err := os.MkdirTemp("", "kepler-tmp-") + g.Expect(err).NotTo(HaveOccurred()) + defer os.RemoveAll(tmpDir) + + TokenFile, err := os.CreateTemp(tmpDir, "kubeletToken") + g.Expect(err).NotTo(HaveOccurred()) + _, err = TokenFile.WriteString("token") + g.Expect(err).NotTo(HaveOccurred()) + TokenFile.Close() + + bearerToken = "" + g.Expect(bearerToken).To(Equal("")) // need this to pass lint + var once sync.Once + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + once.Do(func() { + w.WriteHeader(http.StatusUnauthorized) + }) + fmt.Fprintf(w, "dummy") + })) + defer svr.Close() + + res, err := httpGet(TokenFile.Name(), svr.URL) + g.Expect(err).NotTo(HaveOccurred()) + defer res.Body.Close() + g.Expect(res.StatusCode).To(Equal(http.StatusOK)) +}