Skip to content

Commit

Permalink
Merge pull request #2 from Wiston999/users-cache
Browse files Browse the repository at this point in the history
Some performance improvements
  • Loading branch information
meln5674 authored Jul 12, 2023
2 parents a0adf73 + 19c467e commit c26a3cd
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 28 deletions.
2 changes: 1 addition & 1 deletion e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var _ = Describe("The Nexus OIDC Proxy", Ordered, func() {
b.Navigate("https://nexus.nexus-oidc-proxy-it.cluster")

welcomeImg := `img.nxrm-welcome__logo`
Eventually(welcomeImg, "15s").Should(b.Exist())
Eventually(welcomeImg, "30s").Should(b.Exist())
})

It("Should show the user as logged in", func() {
Expand Down
2 changes: 1 addition & 1 deletion integration-test/nexus-oidc-proxy.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ nexus:
oidc:
wellKnownURL: https://keycloak.nexus-oidc-proxy-it/realms/integration-test/.well-known/openid-configuration
accessTokenHeader: X-Forwarded-Access-Token
syncInterval: 5m
syncInterval: 0s
userTemplate: |-
userId: '{{ .Token.Claims.preferred_username }}'
firstName: '{{ .Token.Claims.given_name }}'
Expand Down
112 changes: 86 additions & 26 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/Masterminds/sprig/v3"
"github.com/aquasecurity/yaml"
jwt "github.com/golang-jwt/jwt/v4"
flag "github.com/spf13/pflag"
"io/ioutil"
"log"
"math/rand"
Expand All @@ -18,8 +14,14 @@ import (
"net/url"
"os"
"strings"
"sync"
"text/template"
"time"

"github.com/Masterminds/sprig/v3"
"github.com/aquasecurity/yaml"
jwt "github.com/golang-jwt/jwt/v4"
flag "github.com/spf13/pflag"
)

const (
Expand Down Expand Up @@ -166,18 +168,35 @@ type ProxyCredentials struct {
}

type ProxyState struct {
Config *ProxyConfig
Credentials *ProxyCredentials
LastUserSync map[string]time.Time
Config *ProxyConfig
Credentials *ProxyCredentials
UsersCache sync.Map
httputil.ReverseProxy
*http.ServeMux
}

type UserCacheEntry struct {
User *NexusUser
LastSync time.Time
RolesLastSync time.Time
}

func NewProxy(config ProxyConfig, credentials ProxyCredentials) (*ProxyState, error) {
state := &ProxyState{
Config: &config,
Credentials: &credentials,
LastUserSync: make(map[string]time.Time),
Config: &config,
Credentials: &credentials,
}
users, err := state.GetUsers(nil)
if err != nil {
log.Printf("Error while warming up users cache: %s. Starting with empty cache", err)
} else {
for _, user := range users {
state.UsersCache.Store(user.UserID, UserCacheEntry{
User: &user,
LastSync: time.Now(),
})
}
log.Printf("Saved %d users in local cache\n", len(users))
}
state.ReverseProxy.Director = state.Director
state.ReverseProxy.ModifyResponse = state.ModifyResponse
Expand Down Expand Up @@ -274,34 +293,48 @@ func (p *ProxyState) GetDesiredUserRoles(token *jwt.Token) ([]string, error) {
return roles, nil
}

func (p *ProxyState) GetUser(userID string) (*NexusUser, bool, error) {
func (p *ProxyState) GetUsers(userID *string) ([]NexusUser, error) {
getUser := p.Config.Nexus.Upstream.Inner
getUser.Path += "service/rest/v1/security/users"
getUser.RawQuery = fmt.Sprintf("userId=%s&source=default", userID)
if userID != nil {
getUser.RawQuery = fmt.Sprintf("userId=%s&source=default", *userID)
} else {
getUser.RawQuery = "source=default"
}
req, err := http.NewRequest(http.MethodGet, getUser.String(), nil)
if err != nil {
return nil, false, err
return nil, err
}
req.SetBasicAuth(p.Credentials.Nexus.Username, p.Credentials.Nexus.Password)
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, false, err
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
body, _ := ioutil.ReadAll(res.Body)
return nil, false, fmt.Errorf("GET %s: %s - %s", &getUser, res.Status, string(body))
return nil, fmt.Errorf("GET %s: %s - %s", &getUser, res.Status, string(body))
}
result := make([]NexusUser, 0, 1)
err = json.NewDecoder(res.Body).Decode(&result)
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, nil
}
return result, nil
}

func (p *ProxyState) GetUser(userID string) (*NexusUser, bool, error) {
result, err := p.GetUsers(&userID)
if err != nil {
return nil, false, err
}
if len(result) == 0 {
return nil, false, nil
}
user := result[0]
return &user, true, nil
return &result[0], true, nil
}

func (p *ProxyState) CreateUser(user *NexusUser) error {
Expand Down Expand Up @@ -356,7 +389,7 @@ func (p *ProxyState) UpdateUser(user *NexusUser) error {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
if res.StatusCode < 200 || res.StatusCode >= 300 {
body, _ := ioutil.ReadAll(res.Body)
return fmt.Errorf("PUT %s: %s - %s", &putUser, res.Status, string(body))
}
Expand Down Expand Up @@ -404,8 +437,9 @@ func (p *ProxyState) ExtractClaims(r *http.Request) (token *jwt.Token, err error

func (p *ProxyState) Director(r *http.Request) {
incomingURL := *r.URL
*r.URL = p.Config.Nexus.Upstream.Inner
r.URL.Path += incomingURL.Path
r.URL.Host = p.Config.Nexus.Upstream.Inner.Host
r.URL.Scheme = p.Config.Nexus.Upstream.Inner.Scheme
r.URL.Path = p.Config.Nexus.Upstream.Inner.Path + incomingURL.Path
defer log.Println(r)
token, err := p.ExtractClaims(r)
if err != nil {
Expand All @@ -418,12 +452,32 @@ func (p *ProxyState) Director(r *http.Request) {
log.Println(err)
return
}
existingUser, found, err := p.GetUser(onboardedUser.UserID)
if err != nil {
log.Println(err)
return
var existingUser *NexusUser
var found bool
var cachedUser UserCacheEntry
cachedValue, exists := p.UsersCache.Load(onboardedUser.UserID)
if exists {
cachedUser = cachedValue.(UserCacheEntry)
}
// If user in cache and recently updated, use cached user
if exists && time.Now().Before(cachedUser.LastSync.Add(p.Config.OIDC.SyncInterval.Inner)) {
existingUser = cachedUser.User
found = true
log.Printf("Found user %s in local cache and in-sync\n", existingUser.UserID)
// Else retrieve user from nexus server
} else {
existingUser, found, err = p.GetUser(onboardedUser.UserID)
if err != nil {
log.Println(err)
return
}
if found {
cachedUser.LastSync = time.Now()
}
}
if !found {
// If user doesn't exist ensure we are not caching invalid data
p.UsersCache.Delete(onboardedUser.UserID)
err = p.CreateUser(onboardedUser)
if err != nil {
log.Println(err)
Expand All @@ -438,9 +492,14 @@ func (p *ProxyState) Director(r *http.Request) {
log.Printf("User %s did not exist after creation?\n", onboardedUser.UserID)
return
}
cachedUser.LastSync = time.Now()
}
// Update user information in case user is just created or refreshed from nexus server
cachedUser.User = existingUser
p.UsersCache.Store(onboardedUser.UserID, cachedUser)

r.Header.Add(p.Config.Nexus.RUTAuthHeader, existingUser.UserID)
lastSync := p.LastUserSync[existingUser.UserID]
lastSync := cachedUser.RolesLastSync
if time.Now().Before(lastSync.Add(p.Config.OIDC.SyncInterval.Inner)) {
return
}
Expand All @@ -457,7 +516,8 @@ func (p *ProxyState) Director(r *http.Request) {
log.Println(err)
return
}
p.LastUserSync[existingUser.UserID] = time.Now()
cachedUser.RolesLastSync = time.Now()
p.UsersCache.Store(existingUser.UserID, cachedUser)
}

func (p *ProxyState) ModifyResponse(resp *http.Response) error {
Expand Down

0 comments on commit c26a3cd

Please sign in to comment.