From c326fb2eb1c4ccb56a2f1e21f26f2d2804f75dad Mon Sep 17 00:00:00 2001 From: yuli Date: Fri, 13 Dec 2024 16:53:58 +0200 Subject: [PATCH] improve session expiearion code and tests --- src/sessions_api.go | 56 ++++++++++++++++++++++++++++---------------- src/sessions_db.go | 4 +++- src/sessions_test.go | 28 +++++++++++++++++++++- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/src/sessions_api.go b/src/sessions_api.go index d94f4086..8fffdcc9 100644 --- a/src/sessions_api.go +++ b/src/sessions_api.go @@ -3,6 +3,7 @@ package main import ( "encoding/json" "fmt" + "log" "net/http" "strings" @@ -27,27 +28,35 @@ func (e mainEnv) createSession(w http.ResponseWriter, r *http.Request, ps httpro if e.enforceAdmin(w, r, event) == "" { return } - expiration := e.conf.Policy.MaxSessionRetentionPeriod - userJSON, err := getUserJSON(r, e.conf.Sms.DefaultCountry) + records, err := getJSONPostMap(r) if err != nil { returnError(w, r, "failed to decode request body", 405, err, event) return } - if len(userJSON.jsonData) == 0 { - returnError(w, r, "empty request body", 405, nil, event) + if len(records) == 0 { + returnError(w, r, "empty body", 405, nil, event) return } + expirationStr := getStringValue(records["expiration"]) + expiration := setExpiration(e.conf.Policy.MaxSessionRetentionPeriod, expirationStr) + log.Printf("Record expiration: %s", expiration) + userToken := getStringValue(records["token"]) + userLogin := getStringValue(records["login"]) + userEmail := getStringValue(records["email"]) + userPhone := getStringValue(records["phone"]) + userCustomIdx := getStringValue(records["custom"]) + var userBson bson.M - if len(userJSON.loginIdx) > 0 { - userBson, err = e.db.lookupUserRecordByIndex("login", userJSON.loginIdx, e.conf) - } else if len(userJSON.emailIdx) > 0 { - userBson, err = e.db.lookupUserRecordByIndex("email", userJSON.emailIdx, e.conf) - } else if len(userJSON.phoneIdx) > 0 { - userBson, err = e.db.lookupUserRecordByIndex("phone", userJSON.phoneIdx, e.conf) - } else if len(userJSON.customIdx) > 0 { - userBson, err = e.db.lookupUserRecordByIndex("custom", userJSON.customIdx, e.conf) - } else if len(userJSON.token) > 0 { - userBson, err = e.db.lookupUserRecord(userJSON.token) + if len(userLogin) > 0 { + userBson, err = e.db.lookupUserRecordByIndex("login", userLogin, e.conf) + } else if len(userEmail) > 0 { + userBson, err = e.db.lookupUserRecordByIndex("email", userEmail, e.conf) + } else if len(userPhone) > 0 { + userBson, err = e.db.lookupUserRecordByIndex("phone", userPhone, e.conf) + } else if len(userCustomIdx) > 0 { + userBson, err = e.db.lookupUserRecordByIndex("custom", userCustomIdx, e.conf) + } else if len(userToken) > 0 { + userBson, err = e.db.lookupUserRecord(userToken) } if err != nil { returnError(w, r, "internal error", 405, err, event) @@ -59,7 +68,12 @@ func (e mainEnv) createSession(w http.ResponseWriter, r *http.Request, ps httpro userTOKEN = userBson["token"].(string) event.Record = userTOKEN } - session, err = e.db.createSessionRecord(session, userTOKEN, expiration, userJSON.jsonData) + jsonData, err := json.Marshal(records) + if err != nil { + returnError(w, r, "internal error", 405, err, event) + return + } + session, err = e.db.createSessionRecord(session, userTOKEN, expiration, jsonData) if err != nil { returnError(w, r, "internal error", 405, err, event) return @@ -111,6 +125,7 @@ func (e mainEnv) newUserSession(w http.ResponseWriter, r *http.Request, ps httpr } expirationStr := getStringValue(records["expiration"]) expiration := setExpiration(e.conf.Policy.MaxSessionRetentionPeriod, expirationStr) + log.Printf("Record expiration: %s", expiration) jsonData, err := json.Marshal(records) if err != nil { returnError(w, r, "internal error", 405, err, event) @@ -167,21 +182,22 @@ func (e mainEnv) getUserSessions(w http.ResponseWriter, r *http.Request, ps http func (e mainEnv) getSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { session := ps.ByName("session") - var event *auditEvent + event := audit("get session", session, "session", session) defer func() { if event != nil { event.submit(e.db, e.conf) } }() when, record, userTOKEN, err := e.db.getSession(session) + if len(userTOKEN) > 0 { + event.Record = userTOKEN + e.db.store.DeleteExpired(storage.TblName.Sessions, "token", userTOKEN) + } if err != nil { returnError(w, r, err.Error(), 405, err, event) return } - if len(userTOKEN) > 0 { - event = audit("get session", session, "session", session) - event.Record = userTOKEN - } + if e.enforceAuth(w, r, event) == "" { return } diff --git a/src/sessions_db.go b/src/sessions_db.go index d9827b9d..cddefbef 100644 --- a/src/sessions_db.go +++ b/src/sessions_db.go @@ -18,11 +18,13 @@ type sessionEvent struct { func (dbobj dbcon) createSessionRecord(sessionUUID string, userTOKEN string, expiration string, data []byte) (string, error) { var endtime int32 var err error + now := int32(time.Now().Unix()) if len(expiration) > 0 { endtime, err = parseExpiration(expiration) if err != nil { return "", err } + //log.Printf("expiration set to: %d, now: %d", endtime, now) } recordKey, err := generateRecordKey() if err != nil { @@ -34,7 +36,6 @@ func (dbobj dbcon) createSessionRecord(sessionUUID string, userTOKEN string, exp } encodedStr := base64.StdEncoding.EncodeToString(encoded) bdoc := bson.M{} - now := int32(time.Now().Unix()) bdoc["token"] = userTOKEN bdoc["session"] = sessionUUID bdoc["endtime"] = endtime @@ -63,6 +64,7 @@ func (dbobj dbcon) getSession(sessionUUID string) (int32, []byte, string, error) } // check expiration now := int32(time.Now().Unix()) + //log.Printf("getSession checking now: %d exp %d", now, record["endtime"].(int32)) if now > record["endtime"].(int32) { return 0, nil, "", errors.New("session expired") } diff --git a/src/sessions_test.go b/src/sessions_test.go index a1b13eac..4e79bc02 100644 --- a/src/sessions_test.go +++ b/src/sessions_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" uuid "github.com/hashicorp/go-uuid" ) @@ -38,6 +39,31 @@ func helpGetUserLoginSessions(login string) (map[string]interface{}, error) { return helpServe(request) } +func TestCreateQuickSession(t *testing.T) { + userJSON := `{"login":"bob"}` + raw, err := helpCreateUser(userJSON) + if err != nil { + t.Fatalf("error: %s", err) + } + if _, found := raw["status"]; !found || raw["status"].(string) != "ok" { + t.Fatalf("failed to create user") + } + data := `{"expiration":"3s","cookie":"abcdefg","login":"bob"}` + sid, _ := uuid.GenerateUUID() + raw, _ = helpCreateSession(sid, data) + if _, ok := raw["status"]; !ok || raw["status"].(string) != "ok" { + t.Fatalf("failed to create session") + } + sessionTOKEN := raw["session"].(string) + time.Sleep(5 * time.Second) + log.Printf("After delay--------") + raw, _ = helpGetSession(sessionTOKEN) + log.Printf("Got: %v", raw) + if _, ok := raw["status"]; !ok || raw["status"].(string) == "ok" { + t.Fatalf("Should failed to get session") + } +} + func TestCreateSessionRecord(t *testing.T) { userJSON := `{"login":"alex"}` raw, err := helpCreateUser(userJSON) @@ -113,7 +139,7 @@ func TestCreateFakeUserSession(t *testing.T) { sid, _ := uuid.GenerateUUID() raw, _ := helpCreateSession(sid, data) if _, ok := raw["status"]; ok && raw["status"].(string) != "ok" { - t.Fatalf("Should fail to create session") + t.Fatalf("Failed to create anonymous session") } }