diff --git a/cmd/api/src/api/auth.go b/cmd/api/src/api/auth.go index 61f7621e33..c9be7787fb 100644 --- a/cmd/api/src/api/auth.go +++ b/cmd/api/src/api/auth.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package api @@ -35,7 +35,9 @@ import ( "github.com/specterops/bloodhound/log" "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/config" + "github.com/specterops/bloodhound/src/ctx" "github.com/specterops/bloodhound/src/database" + "github.com/specterops/bloodhound/src/database/types" "github.com/specterops/bloodhound/src/model" ) @@ -84,22 +86,73 @@ func NewAuthenticator(cfg config.Configuration, db database.Database, ctxInitial } } -func (s authenticator) LoginWithSecret(ctx context.Context, loginRequest LoginRequest) (LoginDetails, error) { +func (s authenticator) auditLogin(requestContext context.Context, commitID uuid.UUID, user model.User, loginRequest LoginRequest, status string, loginError error) { + bhCtx := ctx.Get(requestContext) + auditLog := model.AuditLog{ + Action: "LoginAttempt", + Fields: types.JSONUntypedObject{"username": loginRequest.Username}, + RequestID: bhCtx.RequestID, + SourceIpAddress: bhCtx.RequestIP, + Status: status, + CommitID: commitID, + } + + if user.PrincipalName != "" { + auditLog.ActorID = user.ID.String() + auditLog.ActorName = user.PrincipalName + auditLog.ActorEmail = user.EmailAddress.ValueOrZero() + } + + if status == string(model.AuditStatusFailure) { + auditLog.Fields["error"] = loginError + } + + s.db.CreateAuditLog(auditLog) +} + +func (s authenticator) validateSecretLogin(ctx context.Context, loginRequest LoginRequest) (model.User, string, error) { if user, err := s.db.LookupUser(loginRequest.Username); err != nil { if errors.Is(err, database.ErrNotFound) { - return LoginDetails{}, ErrInvalidAuth + return model.User{}, "", ErrInvalidAuth } - return LoginDetails{}, FormatDatabaseError(err) + return model.User{}, "", FormatDatabaseError(err) } else if user.AuthSecret == nil { - return LoginDetails{}, ErrNoUserSecret + return user, "", ErrNoUserSecret } else if err := s.ValidateSecret(ctx, loginRequest.Secret, *user.AuthSecret); err != nil { - return LoginDetails{}, err - } else if err := auth.ValidateTOTPSecret(loginRequest.OTP, *user.AuthSecret); err != nil { - return LoginDetails{}, err + return user, "", err + } else if err = auth.ValidateTOTPSecret(loginRequest.OTP, *user.AuthSecret); err != nil { + return user, "", err } else if sessionToken, err := s.CreateSession(user, *user.AuthSecret); err != nil { + return user, "", err + } else { + return user, sessionToken, nil + } +} + +func (s authenticator) LoginWithSecret(ctx context.Context, loginRequest LoginRequest) (LoginDetails, error) { + var ( + commitID uuid.UUID + err error + sessionToken string + user model.User + ) + + commitID, err = uuid.NewV4() + if err != nil { + log.Errorf("error generating commit ID for login: %s", err) + return LoginDetails{}, err + } + + s.auditLogin(ctx, commitID, user, loginRequest, string(model.AuditStatusIntent), err) + + user, sessionToken, err = s.validateSecretLogin(ctx, loginRequest) + + if err != nil { + s.auditLogin(ctx, commitID, user, loginRequest, string(model.AuditStatusFailure), err) return LoginDetails{}, err } else { + s.auditLogin(ctx, commitID, user, loginRequest, string(model.AuditStatusSuccess), err) return LoginDetails{ User: user, SessionToken: sessionToken, diff --git a/cmd/api/src/api/middleware/middleware.go b/cmd/api/src/api/middleware/middleware.go index 353f6fbc9a..8c2347cd11 100644 --- a/cmd/api/src/api/middleware/middleware.go +++ b/cmd/api/src/api/middleware/middleware.go @@ -19,6 +19,7 @@ package middleware import ( "context" "fmt" + "net" "net/http" "net/url" "strconv" @@ -149,11 +150,21 @@ func ContextMiddleware(next http.Handler) http.Handler { } func parseUserIP(r *http.Request) string { + var remoteIp string + + // The point of this code is to strip the port, so we don't need to save it. + if host, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + log.Warnf("Error parsing remoteAddress 's': %s", r.RemoteAddr, err) + remoteIp = r.RemoteAddr + } else { + remoteIp = host + } + if result := r.Header.Get("X-Forwarded-For"); result == "" { log.Warnf("No data found in X-Forwarded-For header") - return r.RemoteAddr + return remoteIp } else { - result += "," + r.RemoteAddr + result += "," + remoteIp return result } } diff --git a/cmd/api/src/api/middleware/middleware_internal_test.go b/cmd/api/src/api/middleware/middleware_internal_test.go index e1c26fa135..528204bd34 100644 --- a/cmd/api/src/api/middleware/middleware_internal_test.go +++ b/cmd/api/src/api/middleware/middleware_internal_test.go @@ -91,14 +91,14 @@ func TestParseUserIP_XForwardedFor_RemoteAddr(t *testing.T) { req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ",")) req.RemoteAddr = "0.0.0.0:3000" - require.Equal(t, parseUserIP(req), strings.Join([]string{ip1, ip2, ip3, req.RemoteAddr}, ",")) + require.Equal(t, parseUserIP(req), strings.Join([]string{ip1, ip2, ip3, "0.0.0.0"}, ",")) } func TestParseUserIP_RemoteAddrOnly(t *testing.T) { req, err := http.NewRequest("GET", "/teapot", nil) require.Nil(t, err) req.RemoteAddr = "0.0.0.0:3000" - require.Equal(t, parseUserIP(req), req.RemoteAddr) + require.Equal(t, parseUserIP(req), "0.0.0.0") } func TestParsePreferHeaderWait(t *testing.T) { diff --git a/cmd/api/src/database/audit.go b/cmd/api/src/database/audit.go index 3e3f5290d0..8992cd9dd8 100644 --- a/cmd/api/src/database/audit.go +++ b/cmd/api/src/database/audit.go @@ -47,6 +47,10 @@ func newAuditLog(context context.Context, entry model.AuditEntry, idResolver aut CommitID: entry.CommitID, } + if entry.ErrorMsg != "" { + auditLog.Fields["error"] = entry.ErrorMsg + } + authContext := bheCtx.AuthCtx if !authContext.Authenticated() { return auditLog, ErrAuthContextInvalid @@ -65,10 +69,14 @@ func (s *BloodhoundDB) AppendAuditLog(ctx context.Context, entry model.AuditEntr if auditLog, err := newAuditLog(ctx, entry, s.idResolver); err != nil && err != ErrAuthContextInvalid { return fmt.Errorf("audit log append: %w", err) } else { - return CheckError(s.db.Create(&auditLog)) + return s.CreateAuditLog(auditLog) } } +func (s *BloodhoundDB) CreateAuditLog(auditLog model.AuditLog) error { + return CheckError(s.db.Create(&auditLog)) +} + func (s *BloodhoundDB) ListAuditLogs(before, after time.Time, offset, limit int, order string, filter model.SQLFilter) (model.AuditLogs, int, error) { var ( auditLogs model.AuditLogs @@ -123,6 +131,7 @@ func (s *BloodhoundDB) AuditableTransaction(ctx context.Context, auditEntry mode if err != nil { auditEntry.Status = model.AuditStatusFailure + auditEntry.ErrorMsg = err.Error() } else { auditEntry.Status = model.AuditStatusSuccess } diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index a5b4643fcf..6f98839499 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -82,6 +82,7 @@ type Database interface { RawFirst(value any) error Wipe() error Migrate() error + CreateAuditLog(auditLog model.AuditLog) error AppendAuditLog(ctx context.Context, entry model.AuditEntry) error ListAuditLogs(before, after time.Time, offset, limit int, order string, filter model.SQLFilter) (model.AuditLogs, int, error) CreateRole(role model.Role) (model.Role, error) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index c8320a944b..86cfc70612 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -155,6 +155,20 @@ func (mr *MockDatabaseMockRecorder) CreateAssetGroupSelector(arg0, arg1, arg2 in return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).CreateAssetGroupSelector), arg0, arg1, arg2) } +// CreateAuditLog mocks base method. +func (m *MockDatabase) CreateAuditLog(arg0 model.AuditLog) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateAuditLog", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateAuditLog indicates an expected call of CreateAuditLog. +func (mr *MockDatabaseMockRecorder) CreateAuditLog(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuditLog", reflect.TypeOf((*MockDatabase)(nil).CreateAuditLog), arg0) +} + // CreateAuthSecret mocks base method. func (m *MockDatabase) CreateAuthSecret(arg0 context.Context, arg1 model.AuthSecret) (model.AuthSecret, error) { m.ctrl.T.Helper() diff --git a/packages/go/headers/generate.go b/packages/go/headers/cmd/generate.go similarity index 100% rename from packages/go/headers/generate.go rename to packages/go/headers/cmd/generate.go