diff --git a/Dockerfile b/Dockerfile index baa7cfc..12c06ce 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,7 +17,7 @@ ARG SPARK_IMAGE=ghcr.io/datapunchorg/spark:spark-3.2-1642867779 -FROM golang:1.17.3-alpine as builder +FROM golang:1.21.3-alpine3.17 as builder WORKDIR /workspace @@ -35,7 +35,7 @@ COPY pkg/ pkg/ # Build RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 GO111MODULE=on go build -a -o /usr/bin/spark-ui-reverse-proxy main.go -FROM alpine:3.14 +FROM alpine:3.17 USER root COPY --from=builder /usr/bin/spark-ui-reverse-proxy /usr/bin/ diff --git a/pkg/handlers/sparkui.go b/pkg/handlers/sparkui.go index 2cfe146..b699f29 100644 --- a/pkg/handlers/sparkui.go +++ b/pkg/handlers/sparkui.go @@ -18,17 +18,21 @@ package handlers import ( "fmt" - "github.com/gin-gonic/gin" "log" "net/http" "net/http/httputil" "net/url" "regexp" "strings" + + "github.com/gin-gonic/gin" ) var sparkUIAppNameURLRegex = regexp.MustCompile("{{\\s*[$]appName\\s*}}") var sparkUIAppNamespaceURLRegex = regexp.MustCompile("{{\\s*[$]appNamespace\\s*}}") +var magicPaths = []string{ + "StreamingQuery/statistics", +} func getSparkUIServiceUrl(sparkUIServiceUrlFormat string, appName string, appNamespace string) string { return sparkUIAppNamespaceURLRegex.ReplaceAllString(sparkUIAppNameURLRegex.ReplaceAllString(sparkUIServiceUrlFormat, appName), appNamespace) @@ -48,15 +52,12 @@ func ServeSparkUI(c *gin.Context, config *ApiConfig, uiRootPath string) { path = "" } else { appName = path[0:index] - path = path[index + 1:] + path = path[index+1:] } // get url for the underlying Spark UI Kubernetes service, which is created by spark-on-k8s-operator sparkUIServiceUrl := getSparkUIServiceUrl(config.SparkUIServiceUrl, appName, config.SparkApplicationNamespace) - proxyBasePath := "" - if config.ModifyRedirectUrl { - proxyBasePath = fmt.Sprintf("%s/%s", uiRootPath, appName) - } - proxy, err := newReverseProxy(sparkUIServiceUrl, path, proxyBasePath) + proxyBasePath := fmt.Sprintf("%s/%s", uiRootPath, appName) + proxy, err := newReverseProxy(sparkUIServiceUrl, path, proxyBasePath, config.ModifyRedirectUrl) if err != nil { c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to create reverse proxy for application %s: %s", appName, err.Error())) return @@ -65,7 +66,7 @@ func ServeSparkUI(c *gin.Context, config *ApiConfig, uiRootPath string) { proxy.ServeHTTP(c.Writer, c.Request) } -func newReverseProxy(sparkUIServiceUrl string, targetPath string, proxyBasePath string) (*httputil.ReverseProxy, error) { +func newReverseProxy(sparkUIServiceUrl string, targetPath string, proxyBasePath string, modifyRedirectUrl bool) (*httputil.ReverseProxy, error) { log.Printf("Creating revers proxy for Spark UI service url %s", sparkUIServiceUrl) targetUrl := sparkUIServiceUrl if targetPath != "" { @@ -78,46 +79,61 @@ func newReverseProxy(sparkUIServiceUrl string, targetPath string, proxyBasePath if err != nil { return nil, fmt.Errorf("failed to parse target Spark UI url %s: %s", targetUrl, err.Error()) } + director := func(req *http.Request) { - url.RawQuery = req.URL.RawQuery - url.RawFragment = req.URL.RawFragment - log.Printf("Reverse proxy: serving backend url %s for originally requested url %s", url, req.URL) - req.URL = url + modifyRequest(req, url) } + modifyResponse := func(resp *http.Response) error { - if proxyBasePath != "" && resp.StatusCode == http.StatusFound { - // Append the proxy base path before the redirect path. - // Also modify redirect url to only contain path and not contain host name, - // so redirect will retain the original requested host name. - headerName := "Location" - locationHeaderValues := resp.Header[headerName] - if len(locationHeaderValues) > 0 { - newValues := make([]string, 0, len(locationHeaderValues)) - for _, oldHeaderValue := range locationHeaderValues { - parsedUrl, err := url.Parse(oldHeaderValue) - if err != nil { - log.Printf("Reverse proxy: invalid response header value %s: %s (backend url %s): %s", headerName, oldHeaderValue, url, err.Error()) - newValues = append(newValues, oldHeaderValue) - } else { - parsedUrl.Scheme = "" - parsedUrl.Host = "" - newPath := parsedUrl.Path - if !strings.HasPrefix(newPath, "/") { - newPath = "/" + newPath - } + return modifyResponseRedirect(resp, proxyBasePath, url, modifyRedirectUrl) + } + return &httputil.ReverseProxy{ + Director: director, + ModifyResponse: modifyResponse, + }, nil +} + +func modifyRequest(req *http.Request, url *url.URL) { + url.RawQuery = req.URL.RawQuery + url.RawFragment = req.URL.RawFragment + log.Printf("Reverse proxy: serving backend url %s for originally requested url %s", url, req.URL) + req.URL = url +} + +func modifyResponseRedirect(resp *http.Response, proxyBasePath string, url *url.URL, modifyRedirectUrl bool) error { + if modifyRedirectUrl && resp.StatusCode == http.StatusFound { + // Append the proxy base path before the redirect path. + // Also modify redirect url to only contain path and not contain host name, + // so redirect will retain the original requested host name. + headerName := "Location" + locationHeaderValues := resp.Header[headerName] + if len(locationHeaderValues) > 0 { + newValues := make([]string, 0, len(locationHeaderValues)) + for _, oldHeaderValue := range locationHeaderValues { + parsedUrl, err := url.Parse(oldHeaderValue) + if err != nil { + log.Printf("Reverse proxy: invalid response header value %s: %s (backend url %s): %s", headerName, oldHeaderValue, url, err.Error()) + newValues = append(newValues, oldHeaderValue) + } else { + parsedUrl.Scheme = "" + parsedUrl.Host = "" + newPath := parsedUrl.Path + if !strings.HasPrefix(newPath, "/") { + newPath = "/" + newPath + } + idx := strings.Index(strings.ToLower(newPath), strings.ToLower(proxyBasePath)) + if idx < 0 { parsedUrl.Path = proxyBasePath + newPath - newHeaderValue := parsedUrl.String() - log.Printf("Reverse proxy: modifying response header %s from %s to %s (backend url %s)", headerName, oldHeaderValue, newHeaderValue, url) - newValues = append(newValues, newHeaderValue) + } else { + parsedUrl.Path = newPath[:idx] + proxyBasePath + newPath[idx+len(proxyBasePath):] } + newHeaderValue := parsedUrl.String() + log.Printf("Reverse proxy: modifying response header %s from %s to %s (backend url %s)", headerName, oldHeaderValue, newHeaderValue, url) + newValues = append(newValues, newHeaderValue) } - resp.Header[headerName] = newValues } + resp.Header[headerName] = newValues } - return nil } - return &httputil.ReverseProxy{ - Director: director, - ModifyResponse: modifyResponse, - }, nil -} \ No newline at end of file + return nil +} diff --git a/pkg/handlers/sparkui_test.go b/pkg/handlers/sparkui_test.go index 9b449b4..9da3514 100644 --- a/pkg/handlers/sparkui_test.go +++ b/pkg/handlers/sparkui_test.go @@ -17,8 +17,11 @@ limitations under the License. package handlers import ( - "github.com/stretchr/testify/assert" + "net/http" + "net/url" "testing" + + "github.com/stretchr/testify/assert" ) func Test_getSparkUIServiceUrl(t *testing.T) { @@ -32,3 +35,40 @@ func Test_getSparkUIServiceUrl(t *testing.T) { getSparkUIServiceUrl( "http://{{$appName}}-ui-svc.{{$appNamespace}}.svc.cluster.local:4040", "app1", "ns1")) } + +func TestModifyRequest(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "/sparkui/a3ac46c8487ecb95/static/webui.js?id=87c23377-4a64-47d3-82d7-5da9b39801a5", nil) + assert.NoError(t, err, "unexpected error") + u, err := url.Parse("http://a3ac46c8487ecb95-ui-svc.cluster.local:4040/static/webui.js") + assert.NoError(t, err, "unexpected error") + modifyRequest(r, u) + t.Logf("url=%s", r.URL.String()) +} + +func TestModifyResponse(t *testing.T) { + headers := http.Header{} + headers.Add("Location", "/sparkui/StreamingQuery/statistics/?id=7ab24792-82e1-433b-a158-dc5792878f57") + resp := &http.Response{ + Status: http.StatusText(http.StatusFound), + StatusCode: http.StatusFound, + Header: headers, + } + u, err := url.Parse("http://a3ac46c8487ecb95-ui-svc.cluster.local:4040/StreamingQuery/statistics/") + assert.NoError(t, err, "unexpected error") + + err = modifyResponseRedirect(resp, "/sparkui/a3ac46c8487ecb95", u, true) + assert.NoError(t, err, "unexpected error") + t.Logf("\n\"/sparkui/a3ac46c8487ecb95\" -> url=%s", resp.Header["Location"][0]) + + err = modifyResponseRedirect(resp, "", u, false) + assert.NoError(t, err, "unexpected error") + t.Logf("\n\"\" -> url=%s", resp.Header["Location"][0]) + + err = modifyResponseRedirect(resp, "/", u, true) + assert.NoError(t, err, "unexpected error") + t.Logf("\n\"/\" -> url=%s", resp.Header["Location"][0]) + + err = modifyResponseRedirect(resp, "sparkui/StreamingQuery", u, true) + assert.NoError(t, err, "unexpected error") + t.Logf("\n\"/sparkui/StreamingQuery\" -> url=%s", resp.Header["Location"][0]) +}