Skip to content

Commit

Permalink
improve session expiearion code and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stremovsky committed Dec 13, 2024
1 parent f5bfabd commit c326fb2
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 22 deletions.
56 changes: 36 additions & 20 deletions src/sessions_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"

Expand All @@ -27,27 +28,35 @@ func (e mainEnv) createSession(w http.ResponseWriter, r *http.Request, ps httpro
if e.enforceAdmin(w, r, event) == "" {
return
}
expiration := e.conf.Policy.MaxSessionRetentionPeriod
userJSON, err := getUserJSON(r, e.conf.Sms.DefaultCountry)
records, err := getJSONPostMap(r)
if err != nil {
returnError(w, r, "failed to decode request body", 405, err, event)
return
}
if len(userJSON.jsonData) == 0 {
returnError(w, r, "empty request body", 405, nil, event)
if len(records) == 0 {
returnError(w, r, "empty body", 405, nil, event)
return
}
expirationStr := getStringValue(records["expiration"])
expiration := setExpiration(e.conf.Policy.MaxSessionRetentionPeriod, expirationStr)
log.Printf("Record expiration: %s", expiration)
userToken := getStringValue(records["token"])
userLogin := getStringValue(records["login"])
userEmail := getStringValue(records["email"])
userPhone := getStringValue(records["phone"])
userCustomIdx := getStringValue(records["custom"])

var userBson bson.M
if len(userJSON.loginIdx) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("login", userJSON.loginIdx, e.conf)
} else if len(userJSON.emailIdx) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("email", userJSON.emailIdx, e.conf)
} else if len(userJSON.phoneIdx) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("phone", userJSON.phoneIdx, e.conf)
} else if len(userJSON.customIdx) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("custom", userJSON.customIdx, e.conf)
} else if len(userJSON.token) > 0 {
userBson, err = e.db.lookupUserRecord(userJSON.token)
if len(userLogin) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("login", userLogin, e.conf)
} else if len(userEmail) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("email", userEmail, e.conf)
} else if len(userPhone) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("phone", userPhone, e.conf)
} else if len(userCustomIdx) > 0 {
userBson, err = e.db.lookupUserRecordByIndex("custom", userCustomIdx, e.conf)
} else if len(userToken) > 0 {
userBson, err = e.db.lookupUserRecord(userToken)
}
if err != nil {
returnError(w, r, "internal error", 405, err, event)
Expand All @@ -59,7 +68,12 @@ func (e mainEnv) createSession(w http.ResponseWriter, r *http.Request, ps httpro
userTOKEN = userBson["token"].(string)
event.Record = userTOKEN
}
session, err = e.db.createSessionRecord(session, userTOKEN, expiration, userJSON.jsonData)
jsonData, err := json.Marshal(records)
if err != nil {
returnError(w, r, "internal error", 405, err, event)
return
}
session, err = e.db.createSessionRecord(session, userTOKEN, expiration, jsonData)
if err != nil {
returnError(w, r, "internal error", 405, err, event)
return
Expand Down Expand Up @@ -111,6 +125,7 @@ func (e mainEnv) newUserSession(w http.ResponseWriter, r *http.Request, ps httpr
}
expirationStr := getStringValue(records["expiration"])
expiration := setExpiration(e.conf.Policy.MaxSessionRetentionPeriod, expirationStr)
log.Printf("Record expiration: %s", expiration)
jsonData, err := json.Marshal(records)
if err != nil {
returnError(w, r, "internal error", 405, err, event)
Expand Down Expand Up @@ -167,21 +182,22 @@ func (e mainEnv) getUserSessions(w http.ResponseWriter, r *http.Request, ps http

func (e mainEnv) getSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
session := ps.ByName("session")
var event *auditEvent
event := audit("get session", session, "session", session)
defer func() {
if event != nil {
event.submit(e.db, e.conf)
}
}()
when, record, userTOKEN, err := e.db.getSession(session)
if len(userTOKEN) > 0 {
event.Record = userTOKEN
e.db.store.DeleteExpired(storage.TblName.Sessions, "token", userTOKEN)
}
if err != nil {
returnError(w, r, err.Error(), 405, err, event)
return
}
if len(userTOKEN) > 0 {
event = audit("get session", session, "session", session)
event.Record = userTOKEN
}

if e.enforceAuth(w, r, event) == "" {
return
}
Expand Down
4 changes: 3 additions & 1 deletion src/sessions_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ type sessionEvent struct {
func (dbobj dbcon) createSessionRecord(sessionUUID string, userTOKEN string, expiration string, data []byte) (string, error) {
var endtime int32
var err error
now := int32(time.Now().Unix())
if len(expiration) > 0 {
endtime, err = parseExpiration(expiration)
if err != nil {
return "", err
}
//log.Printf("expiration set to: %d, now: %d", endtime, now)
}
recordKey, err := generateRecordKey()
if err != nil {
Expand All @@ -34,7 +36,6 @@ func (dbobj dbcon) createSessionRecord(sessionUUID string, userTOKEN string, exp
}
encodedStr := base64.StdEncoding.EncodeToString(encoded)
bdoc := bson.M{}
now := int32(time.Now().Unix())
bdoc["token"] = userTOKEN
bdoc["session"] = sessionUUID
bdoc["endtime"] = endtime
Expand Down Expand Up @@ -63,6 +64,7 @@ func (dbobj dbcon) getSession(sessionUUID string) (int32, []byte, string, error)
}
// check expiration
now := int32(time.Now().Unix())
//log.Printf("getSession checking now: %d exp %d", now, record["endtime"].(int32))
if now > record["endtime"].(int32) {
return 0, nil, "", errors.New("session expired")
}
Expand Down
28 changes: 27 additions & 1 deletion src/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

uuid "github.com/hashicorp/go-uuid"
)
Expand Down Expand Up @@ -38,6 +39,31 @@ func helpGetUserLoginSessions(login string) (map[string]interface{}, error) {
return helpServe(request)
}

func TestCreateQuickSession(t *testing.T) {
userJSON := `{"login":"bob"}`
raw, err := helpCreateUser(userJSON)
if err != nil {
t.Fatalf("error: %s", err)
}
if _, found := raw["status"]; !found || raw["status"].(string) != "ok" {
t.Fatalf("failed to create user")
}
data := `{"expiration":"3s","cookie":"abcdefg","login":"bob"}`
sid, _ := uuid.GenerateUUID()
raw, _ = helpCreateSession(sid, data)
if _, ok := raw["status"]; !ok || raw["status"].(string) != "ok" {
t.Fatalf("failed to create session")
}
sessionTOKEN := raw["session"].(string)
time.Sleep(5 * time.Second)
log.Printf("After delay--------")
raw, _ = helpGetSession(sessionTOKEN)
log.Printf("Got: %v", raw)
if _, ok := raw["status"]; !ok || raw["status"].(string) == "ok" {
t.Fatalf("Should failed to get session")
}
}

func TestCreateSessionRecord(t *testing.T) {
userJSON := `{"login":"alex"}`
raw, err := helpCreateUser(userJSON)
Expand Down Expand Up @@ -113,7 +139,7 @@ func TestCreateFakeUserSession(t *testing.T) {
sid, _ := uuid.GenerateUUID()
raw, _ := helpCreateSession(sid, data)
if _, ok := raw["status"]; ok && raw["status"].(string) != "ok" {
t.Fatalf("Should fail to create session")
t.Fatalf("Failed to create anonymous session")
}
}

Expand Down

0 comments on commit c326fb2

Please sign in to comment.