Skip to content
This repository has been archived by the owner on Dec 7, 2020. It is now read-only.

Commit

Permalink
feat #524 Support mapping additional request paths to different upstr…
Browse files Browse the repository at this point in the history
…eam URLs
  • Loading branch information
chirino committed Apr 8, 2020
1 parent 008527a commit 4d9b550
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 15 deletions.
38 changes: 38 additions & 0 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ limitations under the License.
package main

import (
"errors"
"fmt"
"os"
"os/signal"
"reflect"
"strings"
"syscall"
"time"

Expand Down Expand Up @@ -220,6 +222,42 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) {
config.Resources = append(config.Resources, resource)
}
}
if cx.IsSet("upstream-url-paths") {
for _, x := range cx.StringSlice("upstream-url-paths") {
path, err := cliParseUpstreamURLPath(x)
if err != nil {
return fmt.Errorf("invalid upstream-url-paths %s, %s", x, err)
}
config.UpstreamPaths = append(config.UpstreamPaths, path)
}
}

return nil
}

func cliParseUpstreamURLPath(resource string) (r UpstreamURLPath, err error) {
if resource == "" {
return r, errors.New("no value given")
}
for _, x := range strings.Split(resource, "|") {
kp := strings.Split(x, "=")
if len(kp) != 2 {
return r, errors.New("config pair, should be (uri|upstream-url)=value")
}
switch kp[0] {
case "uri":
r.URL = kp[1]
case "upstream-url":
r.Upstream = kp[1]
default:
return r, fmt.Errorf("invalid identifier '%s', should be uri or upstream-url", kp[0])
}
}
if r.URL == "" {
return r, errors.New("uri config missing")
}
if r.Upstream == "" {
return r, errors.New("upstream-url config missing")
}
return r, err
}
9 changes: 9 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ func (r *Config) isValid() error {
if r.Upstream == "" {
return errors.New("you have not specified an upstream endpoint to proxy to")
}

if len(r.UpstreamPaths) > 0 {
for _, p := range r.UpstreamPaths {
if _, err := url.Parse(p.Upstream); err != nil {
return fmt.Errorf("the upstream endpoint `%s` is invalid, %s", p, err)
}
}
}

if _, err := url.Parse(r.Upstream); err != nil {
return fmt.Errorf("the upstream endpoint is invalid, %s", err)
}
Expand Down
9 changes: 9 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ var (
ErrDecryption = errors.New("failed to decrypt token")
)

type UpstreamURLPath struct {
// URL the url for the resource
URL string `json:"uri" yaml:"uri"`
// Upstream is the upstream endpoint i.e whom were proxying to
Upstream string `json:"upstream-url" yaml:"upstream-url"`
}

// Resource represents a url resource to protect
type Resource struct {
// URL the url for the resource
Expand Down Expand Up @@ -184,6 +191,8 @@ type Config struct {
Scopes []string `json:"scopes" yaml:"scopes" usage:"list of scopes requested when authenticating the user"`
// Upstream is the upstream endpoint i.e whom were proxying to
Upstream string `json:"upstream-url" yaml:"upstream-url" usage:"url for the upstream endpoint you wish to proxy" env:"UPSTREAM_URL"`
// Resources is a list of protected resources
UpstreamPaths []UpstreamURLPath `json:"upstream-url-paths" yaml:"upstream-url-paths" usage:"list of upstream url paths 'uri=/admin*|upstream-url=http://server1|uri=/data*|upstream-url=http://server2:8080'"`
// UpstreamCA is the path to a CA certificate in PEM format to validate the upstream certificate
UpstreamCA string `json:"upstream-ca" yaml:"upstream-ca" usage:"the path to a file container a CA certificate to validate the upstream tls endpoint"`
// Resources is a list of protected resources
Expand Down
17 changes: 12 additions & 5 deletions forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main
import (
"fmt"
"net/http"
"net/url"
"time"

"github.com/coreos/go-oidc/jose"
Expand Down Expand Up @@ -53,27 +54,33 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler {
req.Header.Set(k, v)
}

r.upstream.ServeHTTP(w, req)
})
}

func (r *oauthProxy) forwardToUpstream(upstreamUrl *url.URL, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// @note: by default goproxy only provides a forwarding proxy, thus all requests have to be absolute and we must update the host headers
req.URL.Host = r.endpoint.Host
req.URL.Scheme = r.endpoint.Scheme
req.URL.Host = upstreamUrl.Host
req.URL.Scheme = upstreamUrl.Scheme
if v := req.Header.Get("Host"); v != "" {
req.Host = v
req.Header.Del("Host")
} else if !r.config.PreserveHost {
req.Host = r.endpoint.Host
req.Host = upstreamUrl.Host
}

if isUpgradedConnection(req) {
r.log.Debug("upgrading the connnection", zap.String("client_ip", req.RemoteAddr))
if err := tryUpdateConnection(req, w, r.endpoint); err != nil {
if err := tryUpdateConnection(req, w, upstreamUrl); err != nil {
r.log.Error("failed to upgrade connection", zap.Error(err))
w.WriteHeader(http.StatusInternalServerError)
return
}
return
}

r.upstream.ServeHTTP(w, req)
next.ServeHTTP(w, req)
})
}

Expand Down
58 changes: 48 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func createLogger(config *Config) (*zap.Logger, error) {
// createReverseProxy creates a reverse proxy
func (r *oauthProxy) createReverseProxy() error {
r.log.Info("enabled reverse proxy mode, upstream url", zap.String("url", r.config.Upstream))
if err := r.createUpstreamProxy(r.endpoint); err != nil {
if err := r.createDefaultUpstreamProxy(); err != nil {
return err
}
engine := chi.NewRouter()
Expand Down Expand Up @@ -293,14 +293,15 @@ func (r *oauthProxy) createForwardingProxy() error {
if r.config.SkipUpstreamTLSVerify {
r.log.Warn("tls verification switched off. In forward signing mode it's recommended you verify! (--skip-upstream-tls-verify=false)")
}
if err := r.createUpstreamProxy(nil); err != nil {

proxy, err := r.createUpstreamProxy(nil)
if err != nil {
return err
}
//nolint:bodyclose
forwardingHandler := r.forwardProxyHandler()

// set the http handler
proxy := r.upstream.(*goproxy.ProxyHttpServer)
r.router = proxy

// setup the tls configuration
Expand Down Expand Up @@ -553,8 +554,47 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er
return listener, nil
}

func (r *oauthProxy) createDefaultUpstreamProxy() error {
defaultUpstream, err := r.createUpstreamProxy(r.endpoint)
if err != nil {
return err
}

if len(r.config.UpstreamPaths) > 0 {
engine := chi.NewRouter()

for _, x := range r.config.UpstreamPaths {
path := x
fmt.Printf("%s => %s\n", path.URL, path.Upstream)
upstreamUrl, err := url.Parse(path.Upstream)
if err != nil {
return err
}

proxy, err := r.createUpstreamProxy(upstreamUrl)
if err != nil {
return err
}

engine.Mount(path.URL, r.forwardToUpstream(upstreamUrl, proxy))

//engine.Mount(path.URL, http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
// fmt.Printf("hit %s => %s\n", path.URL, u)
// proxy.ServeHTTP(writer, request)
//}))
}

engine.NotFound(r.forwardToUpstream(r.endpoint, defaultUpstream).ServeHTTP)

r.upstream = engine
} else {
r.upstream = r.forwardToUpstream(r.endpoint, defaultUpstream)
}
return nil
}

// createUpstreamProxy create a reverse http proxy from the upstream
func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error {
func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) (*goproxy.ProxyHttpServer, error) {
dialer := (&net.Dialer{
KeepAlive: r.config.UpstreamKeepaliveTimeout,
Timeout: r.config.UpstreamTimeout,
Expand Down Expand Up @@ -583,7 +623,7 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error {
cert, err := ioutil.ReadFile(r.config.TLSClientCertificate)
if err != nil {
r.log.Error("unable to read client certificate", zap.String("path", r.config.TLSClientCertificate), zap.Error(err))
return err
return nil, err
}
pool := x509.NewCertPool()
pool.AppendCertsFromPEM(cert)
Expand All @@ -597,7 +637,7 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error {
r.log.Info("loading the upstream ca", zap.String("path", r.config.UpstreamCA))
ca, err := ioutil.ReadFile(r.config.UpstreamCA)
if err != nil {
return err
return nil, err
}
pool := x509.NewCertPool()
pool.AppendCertsFromPEM(ca)
Expand All @@ -614,10 +654,9 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error {
proxy.KeepDestinationHeaders = true
proxy.Logger = httplog.New(ioutil.Discard, "", 0)
proxy.KeepDestinationHeaders = true
r.upstream = proxy

// update the tls configuration of the reverse proxy
r.upstream.(*goproxy.ProxyHttpServer).Tr = &http.Transport{
proxy.Tr = &http.Transport{
Dial: dialer,
DisableKeepAlives: !r.config.UpstreamKeepalives,
ExpectContinueTimeout: r.config.UpstreamExpectContinueTimeout,
Expand All @@ -627,8 +666,7 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error {
MaxIdleConns: r.config.MaxIdleConns,
MaxIdleConnsPerHost: r.config.MaxIdleConnsPerHost,
}

return nil
return proxy, nil
}

// createTemplates loads the custom template
Expand Down
99 changes: 99 additions & 0 deletions server_upstream_paths_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
Copyright 2015 All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"
)

// fakeUpstreamService acts as a fake upstream service, returns the headers and request
type counterService struct {
Name string
HitCounter int64
}

func (f *counterService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Printf("counter %s\n", f.Name)
atomic.AddInt64(&f.HitCounter, 1)
}

//TestWebSocket is used to validate that the proxy reverse proxy WebSocket connections.
func TestUpstreamPaths(t *testing.T) {

// Setup an upstream service.
defaultSvc := &counterService{Name: "default"}
defaultSvcServer := httptest.NewServer(defaultSvc)
defer defaultSvcServer.Close()

adminSvc := &counterService{Name: "admin"}
adminSvcServer := httptest.NewServer(adminSvc)
defer adminSvcServer.Close()

dataSvc := &counterService{Name: "data"}
dataSvcServer := httptest.NewServer(dataSvc)
defer dataSvcServer.Close()

counters := func() []int64 {
return []int64{
atomic.AddInt64(&defaultSvc.HitCounter, 0),
atomic.AddInt64(&adminSvc.HitCounter, 0),
atomic.AddInt64(&dataSvc.HitCounter, 0),
}
}

// Setup the proxy.
config := newFakeKeycloakConfig()
config.Upstream = defaultSvcServer.URL
config.UpstreamPaths = []UpstreamURLPath{
{
URL: "/auth_all/white_listed/admin",
Upstream: adminSvcServer.URL,
},
{
URL: "/auth_all/white_listed/data",
Upstream: dataSvcServer.URL,
},
}

auth := newFakeAuthServer()
if config == nil {
config = newFakeKeycloakConfig()
}
config.DiscoveryURL = auth.getLocation()
config.RevocationEndpoint = auth.getRevocationURL()

proxy, err := newProxy(config)
require.NoError(t, err)

proxyServer := httptest.NewServer(proxy.router)
defer proxyServer.Close()

http.Get(proxyServer.URL + "/auth_all/white_listed/admin")
require.Equal(t, []int64{0, 1, 0}, counters())

http.Get(proxyServer.URL + "/auth_all/white_listed/other")
require.Equal(t, []int64{1, 1, 0}, counters())

http.Get(proxyServer.URL + "/auth_all/white_listed/data")
require.Equal(t, []int64{1, 1, 1}, counters())

}

0 comments on commit 4d9b550

Please sign in to comment.