Skip to content

Commit

Permalink
get transactionID from request headers
Browse files Browse the repository at this point in the history
  • Loading branch information
karinamzalez committed Jan 31, 2024
1 parent 8ab75cd commit 2304de1
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion ssas/service/admin/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func routes() *chi.Mux {
r := chi.NewRouter()
m := monitoring.GetMonitor()

r.Use(gcmw.RequestID, service.NewTransactionID, service.NewAPILogger(), service.ConnectionClose, service.NewCtxLogger)
r.Use(gcmw.RequestID, service.GetTransactionID, service.NewAPILogger(), service.ConnectionClose, service.NewCtxLogger)
r.With(requireBasicAuth).Post(m.WrapHandler("/group", createGroup))
r.With(requireBasicAuth).Get(m.WrapHandler("/group", listGroups))
r.With(requireBasicAuth).Put(m.WrapHandler("/group/{id}", updateGroup))
Expand Down
5 changes: 2 additions & 3 deletions ssas/service/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/go-chi/chi/v5/middleware"
"github.com/sirupsen/logrus"
"github.com/twinj/uuid"

"github.com/CMSgov/bcda-ssas-app/log"
"github.com/CMSgov/bcda-ssas-app/ssas"
Expand Down Expand Up @@ -93,9 +92,9 @@ type CtxTransactionKeyType string
const CtxTransactionKey CtxTransactionKeyType = "ctxTransaction"

// Adds a transaction ID to the request context
func NewTransactionID(next http.Handler) http.Handler {
func GetTransactionID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), CtxTransactionKey, uuid.NewV4().String()))
r = r.WithContext(context.WithValue(r.Context(), CtxTransactionKey, r.Header.Get("transaction_id")))
next.ServeHTTP(w, r)
})
}
2 changes: 1 addition & 1 deletion ssas/service/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func start(ps *service.Server, as *service.Server, forwarder *http.Server) {

func newForwardingRouter() http.Handler {
r := chi.NewRouter()
r.Use(gcmw.RequestID, service.NewTransactionID, service.NewAPILogger(), service.ConnectionClose, service.NewCtxLogger)
r.Use(gcmw.RequestID, service.GetTransactionID, service.NewAPILogger(), service.ConnectionClose, service.NewCtxLogger)
r.Get("/*", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// TODO only forward requests for paths in our own host or resource server
url := "https://" + req.Host + req.URL.String()
Expand Down
10 changes: 5 additions & 5 deletions ssas/service/public/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (s *PublicMiddlewareTestSuite) TestRequireTokenAuthWithInvalidSignature() {
assert.NotNil(s.T(), err)
})
}
s.server = httptest.NewServer(s.CreateRouter(service.NewTransactionID, service.NewCtxLogger, parseToken, testForToken))
s.server = httptest.NewServer(s.CreateRouter(service.GetTransactionID, service.NewCtxLogger, parseToken, testForToken))
client := s.server.Client()

// Valid token should return a 200 response
Expand Down Expand Up @@ -77,7 +77,7 @@ func (s *PublicMiddlewareTestSuite) TestParseTokenEmptyToken() {
assert.NotNil(s.T(), err)
})
}
s.server = httptest.NewServer(s.CreateRouter(service.NewTransactionID, service.NewCtxLogger, parseToken, testForToken))
s.server = httptest.NewServer(s.CreateRouter(service.GetTransactionID, service.NewCtxLogger, parseToken, testForToken))
client := s.server.Client()

// Valid token should return a 200 response
Expand Down Expand Up @@ -111,7 +111,7 @@ func (s *PublicMiddlewareTestSuite) TestParseTokenValidToken() {
assert.Equal(s.T(), groupIDs, rd.AllowedGroupIDs)
})
}
s.server = httptest.NewServer(s.CreateRouter(service.NewTransactionID, service.NewCtxLogger, parseToken, testForToken))
s.server = httptest.NewServer(s.CreateRouter(service.GetTransactionID, service.NewCtxLogger, parseToken, testForToken))
client := s.server.Client()

_, ts, _ := MintRegistrationToken(oktaID, groupIDs)
Expand All @@ -131,7 +131,7 @@ func (s *PublicMiddlewareTestSuite) TestParseTokenValidToken() {
}

func (s *PublicMiddlewareTestSuite) TestRequireRegTokenAuthValidToken() {
s.server = httptest.NewServer(s.CreateRouter(service.NewTransactionID, service.NewCtxLogger, requireRegTokenAuth))
s.server = httptest.NewServer(s.CreateRouter(service.GetTransactionID, service.NewCtxLogger, requireRegTokenAuth))

// Valid token should return a 200 response
req, err := http.NewRequest("GET", s.server.URL, nil)
Expand Down Expand Up @@ -192,7 +192,7 @@ func (s *PublicMiddlewareTestSuite) TestRequireRegTokenAuthRevoked() {
}

func (s *PublicMiddlewareTestSuite) TestRequireRegTokenAuthEmptyToken() {
s.server = httptest.NewServer(s.CreateRouter(service.NewTransactionID, service.NewCtxLogger, requireRegTokenAuth))
s.server = httptest.NewServer(s.CreateRouter(service.GetTransactionID, service.NewCtxLogger, requireRegTokenAuth))
client := s.server.Client()

// Valid token should return a 200 response
Expand Down
2 changes: 1 addition & 1 deletion ssas/service/public/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func routes() *chi.Mux {
router := chi.NewRouter()
m := monitoring.GetMonitor()
//v1 Routes
router.Use(gcmw.RequestID, service.NewTransactionID, service.NewAPILogger(), service.ConnectionClose, service.NewCtxLogger)
router.Use(gcmw.RequestID, service.GetTransactionID, service.NewAPILogger(), service.ConnectionClose, service.NewCtxLogger)
router.Post(m.WrapHandler("/token", token))
router.Post(m.WrapHandler("/introspect", introspect))
router.With(parseToken, requireRegTokenAuth, readGroupID).Post(m.WrapHandler("/register", RegisterSystem))
Expand Down

0 comments on commit 2304de1

Please sign in to comment.