Skip to content

Commit

Permalink
Merge pull request #17 from chicagopcdc/pcdc_dev
Browse files Browse the repository at this point in the history
Pcdc dev
  • Loading branch information
grugna authored Mar 6, 2024
2 parents 7b01767 + 088d05e commit 290fcec
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 132 deletions.
51 changes: 33 additions & 18 deletions arborist/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,23 @@ type AuthMappingQuery struct {

type AuthMapping map[string][]Action


// TODO This is just a patch to filter out excessive resources. When transitioning to pelican import we should have a project_id = xyz parameter instead
var authMappingProjectExclusion = `
ARRAY[
'programs.pcdc.projects.20231114.%',
'programs.pcdc.projects.20230912.%',
'programs.pcdc.projects.20230523.%',
'programs.pcdc.projects.20230228.%',
'programs.pcdc.projects.20220808.%',
'programs.pcdc.projects.20220501_S01.%',
'programs.pcdc.projects.20220201.%',
'programs.pcdc.projects.20220110.%',
'programs.pcdc.projects.20211006.%',
'programs.pcdc.projects.20210915.%',
'programs.pcdc.projects.20210212.%'
]
`
// authMapping gets the auth mapping for the user with this username.
// The user's auth mapping includes the permissions of the `anonymous` and
// `logged-in` groups.
Expand Down Expand Up @@ -707,24 +724,12 @@ func authMapping(db *sqlx.DB, username string) (AuthMapping, *ErrorResponse) {
INNER JOIN policy_role ON policy_role.policy_id = policies.policy_id
INNER JOIN permission ON permission.role_id = policy_role.role_id
INNER JOIN resource ON resource.path <@ policy_resources.path
WHERE ltree2text(resource.path) NOT LIKE ALL (
ARRAY[
'programs.pcdc.projects.20231114.%',
'programs.pcdc.projects.20230912.%',
'programs.pcdc.projects.20230523.%',
'programs.pcdc.projects.20230228.%',
'programs.pcdc.projects.20220808.%',
'programs.pcdc.projects.20220501_S01.%',
'programs.pcdc.projects.20220201.%',
'programs.pcdc.projects.20220110.%',
'programs.pcdc.projects.20211006.%',
'programs.pcdc.projects.20210915.%',
'programs.pcdc.projects.20210212.%'
]
WHERE ltree2text(resource.path) NOT LIKE ALL (`

stmt += authMappingProjectExclusion
stmt += `
)
`

// where resource.path ~ (CAST('programs.pcdc.projects.20230228.*' AS lquery))
// where ltree2text(resource.path) not like 'programs.pcdc.projects.20220201.%' and ltree2text(resource.path) not like 'programs.pcdc.projects.20220808.%') as teat;

Expand Down Expand Up @@ -766,6 +771,12 @@ func authMappingForGroups(db *sqlx.DB, groups ...string) (AuthMapping, *ErrorRes
INNER JOIN policy_role ON policy_role.policy_id = policies.policy_id
INNER JOIN permission ON permission.role_id = policy_role.role_id
INNER JOIN resource ON resource.path <@ roots.path
WHERE ltree2text(resource.path) NOT LIKE ALL (`

stmt += authMappingProjectExclusion
stmt += `
)
`
// sqlx.In allows safely binding variable numbers of arguments as bindvars.
// See https://jmoiron.github.io/sqlx/#inQueries,
Expand All @@ -792,7 +803,6 @@ func authMappingForGroups(db *sqlx.DB, groups ...string) (AuthMapping, *ErrorRes
return mapping, nil
}


// authMappingForClient gets the auth mapping for a client ID.
// It does NOT includes the permissions of the `anonymous` and
// `logged-in` groups.
Expand All @@ -813,11 +823,16 @@ func authMappingForClient(db *sqlx.DB, clientID string) (AuthMapping, *ErrorResp
INNER JOIN policy_role ON policy_role.policy_id = policies.policy_id
INNER JOIN permission ON permission.role_id = policy_role.role_id
INNER JOIN resource ON resource.path <@ roots.path
WHERE ltree2text(resource.path) NOT LIKE ALL (`

stmt += authMappingProjectExclusion
stmt += `
)
`
err := db.Select(
&mappingQuery,
stmt,
clientID, // $1
clientID, // $1
)
if err != nil {
errResponse := newErrorResponse("mapping query failed", 500, &err)
Expand Down
117 changes: 82 additions & 35 deletions arborist/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (server *Server) Init() (*Server, error) {

// For some reason this is not allowed:
//
// `{resourcePath:/.+}`
// `{resourcePath:/.+}`
//
// so we put the slash at the front here and fix it in parseResourcePath.
const resourcePath string = `/{resourcePath:.+}`
Expand Down Expand Up @@ -102,7 +102,7 @@ func (server *Server) MakeRouter(out io.Writer) http.Handler {
router.HandleFunc("/health", server.handleHealth).Methods("GET")

router.Handle("/auth/mapping", http.HandlerFunc(server.handleAuthMappingGET)).Methods("GET")
router.Handle("/auth/mapping", http.HandlerFunc(server.parseJSON(server.handleAuthMappingPOST))).Methods("POST")
router.Handle("/auth/mapping", http.HandlerFunc(server.handleAuthMappingPOST)).Methods("POST")
router.Handle("/auth/proxy", http.HandlerFunc(server.handleAuthProxy)).Methods("GET")
router.Handle("/auth/request", http.HandlerFunc(server.parseJSON(server.handleAuthRequest))).Methods("POST")
router.Handle("/auth/resources", http.HandlerFunc(server.handleListAuthResourcesGET)).Methods("GET")
Expand Down Expand Up @@ -176,25 +176,30 @@ func (server *Server) MakeRouter(out io.Writer) http.Handler {
// handler signature.
func (server *Server) parseJSON(baseHandler func(http.ResponseWriter, *http.Request, []byte)) func(http.ResponseWriter, *http.Request) {
handler := func(w http.ResponseWriter, r *http.Request) {
if r.Body == nil {
response := newErrorResponse("expected JSON body in the request", 400, nil)
response.log.write(server.logger)
_ = response.write(w, r)
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
msg := fmt.Sprintf("could not parse valid JSON from request: %s", err.Error())
response := newErrorResponse(msg, 400, nil)
response.log.write(server.logger)
_ = response.write(w, r)
return
}
body := server.parseJsonBody(w, r)
baseHandler(w, r, body)
}
return handler
}

func (server *Server) parseJsonBody(w http.ResponseWriter, r *http.Request) []byte {
if r.Body == nil {
response := newErrorResponse("expected JSON body in the request", 400, nil)
response.log.write(server.logger)
_ = response.write(w, r)
return nil
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
msg := fmt.Sprintf("could not parse valid JSON from request: %s", err.Error())
response := newErrorResponse(msg, 400, nil)
response.log.write(server.logger)
_ = response.write(w, r)
return nil
}
return body
}

var regWhitespace *regexp.Regexp = regexp.MustCompile(`\s`)

func loggableJSON(bytes []byte) []byte {
Expand Down Expand Up @@ -275,33 +280,75 @@ func (server *Server) handleAuthMappingGET(w http.ResponseWriter, r *http.Reques
}
}

func (server *Server) handleAuthMappingPOST(w http.ResponseWriter, r *http.Request, body []byte) {
func (server *Server) handleAuthMappingPOST(w http.ResponseWriter, r *http.Request) {
var errResponse *ErrorResponse = nil
requestBody := struct {
Username string `json:"username"`
ClientID string `json:"clientID"`
ClientID string `json:"clientID"`
}{}
err := json.Unmarshal(body, &requestBody)
if err != nil {
msg := fmt.Sprintf("could not parse JSON: %s", err.Error())
server.logger.Info("tried to handle auth mapping request but input was invalid: %s", msg)
errResponse = newErrorResponse(msg, 400, nil)
}
if (requestBody.Username == "") == (requestBody.ClientID == "") {
msg := "must specify exactly one of `username` or `clientID`"
server.logger.Info(msg)
errResponse = newErrorResponse(msg, 400, nil)
}
if errResponse != nil {
_ = errResponse.write(w, r)
return

username := ""
clientID := ""
if authHeader := r.Header.Get("Authorization"); authHeader != "" {
server.logger.Info("Attempting to get username or clientID from jwt...")
userJWT := strings.TrimPrefix(authHeader, "Bearer ")
userJWT = strings.TrimPrefix(userJWT, "bearer ")
scopes := []string{"openid"}
info, err := server.decodeToken(userJWT, scopes)
if err != nil {
// Return 401 on failure to decode JWT
msg := fmt.Sprintf("tried to get username/client ID from jwt, but jwt decode failed: %s", err.Error())
server.logger.Info(msg)
errResponse = newErrorResponse(msg, 401, nil)
_ = errResponse.write(w, r)
return
}

// When there is a username, there could be a client ID too (token belonging to a client acting
// on behalf of a user). But this endpoint only supports returning the user's mapping, not
// the combination of user+client access. So ignore the client ID.
if info.username != "" {
username = info.username
server.logger.Info("found username in jwt: %s", username)
} else if info.clientID != "" {
clientID = info.clientID
server.logger.Info("found client ID in jwt: %s", clientID)
} else {
msg := "invalid token (no username or client ID)"
server.logger.Error(msg)
errResponse = newErrorResponse(msg, 401, nil)
_ = errResponse.write(w, r)
return
}
} else {
// If they are not present in the token, fallback on the request body
server.logger.Info("No jwt provided, checking request body")
body := server.parseJsonBody(w, r)
err := json.Unmarshal(body, &requestBody)
if err != nil {
msg := fmt.Sprintf("could not parse JSON: %s", err.Error())
server.logger.Error("tried to handle auth mapping request but input was invalid: %s", msg)
errResponse = newErrorResponse(msg, 400, nil)
} else {
username = requestBody.Username
clientID = requestBody.ClientID
if (username == "") == (clientID == "") {
msg := "must provide a token or specify exactly one of `username` or `clientID` in the request body"
server.logger.Info(msg)
errResponse = newErrorResponse(msg, 400, nil)
}
}
if errResponse != nil {
_ = errResponse.write(w, r)
return
}
}

var mappings AuthMapping
if requestBody.ClientID != "" {
mappings, errResponse = authMappingForClient(server.db, requestBody.ClientID)
if clientID != "" {
mappings, errResponse = authMappingForClient(server.db, clientID)
} else {
mappings, errResponse = authMapping(server.db, requestBody.Username)
mappings, errResponse = authMapping(server.db, username)
}
if errResponse != nil {
errResponse.log.write(server.logger)
Expand Down
Loading

0 comments on commit 290fcec

Please sign in to comment.