Skip to content

Commit

Permalink
Merge pull request #156 from matrix-org/kegan/accurate-load-positions
Browse files Browse the repository at this point in the history
Remove global load position
  • Loading branch information
kegsay authored Jun 15, 2023
2 parents 37087f9 + 2a94a5a commit 409431f
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 65 deletions.
19 changes: 19 additions & 0 deletions state/event_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,25 @@ func (t *EventTable) LatestEventInRooms(txn *sqlx.Tx, roomIDs []string, highestN
return
}

func (t *EventTable) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
// the position (event nid) may be for a random different room, so we need to find the highest nid <= this position for this room
var events []Event
err = t.db.Select(
&events,
`SELECT event_nid, room_id FROM syncv3_events
WHERE event_nid IN (SELECT max(event_nid) FROM syncv3_events WHERE event_nid <= $1 AND room_id = ANY($2) GROUP BY room_id)`,
highestNID, pq.StringArray(roomIDs),
)
if err == sql.ErrNoRows {
err = nil
}
roomToNID = make(map[string]int64)
for _, ev := range events {
roomToNID[ev.RoomID] = ev.NID
}
return
}

func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
var events []Event
err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 ORDER BY event_nid ASC LIMIT $4`,
Expand Down
89 changes: 89 additions & 0 deletions state/event_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"bytes"
"database/sql"
"fmt"
"reflect"
"testing"

"github.com/jmoiron/sqlx"
"github.com/tidwall/gjson"

"github.com/matrix-org/sliding-sync/sqlutil"
Expand Down Expand Up @@ -871,6 +873,93 @@ func TestRemoveUnsignedTXNID(t *testing.T) {
}
}

func TestLatestEventNIDInRooms(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewEventTable(db)

var result map[string]int64
var err error
// Insert the following:
// - Room FIRST: [N]
// - Room SECOND: [N+1, N+2, N+3] (replace)
// - Room THIRD: [N+4] (max)
first := "!FIRST"
second := "!SECOND"
third := "!THIRD"
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
result, err = table.Insert(txn, []Event{
{
ID: "$N",
Type: "message",
RoomID: first,
JSON: []byte(`{}`),
},
{
ID: "$N+1",
Type: "message",
RoomID: second,
JSON: []byte(`{}`),
},
{
ID: "$N+2",
Type: "message",
RoomID: second,
JSON: []byte(`{}`),
},
{
ID: "$N+3",
Type: "message",
RoomID: second,
JSON: []byte(`{}`),
},
{
ID: "$N+4",
Type: "message",
RoomID: third,
JSON: []byte(`{}`),
},
}, false)
return err
})
assertNoError(t, err)

testCases := []struct {
roomIDs []string
highestNID int64
wantMap map[string]string
}{
// We should see FIRST=N, SECOND=N+3, THIRD=N+4 when querying LatestEventNIDInRooms with N+4
{
roomIDs: []string{first, second, third},
highestNID: result["$N+4"],
wantMap: map[string]string{
first: "$N", second: "$N+3", third: "$N+4",
},
},
// We should see FIRST=N, SECOND=N+2 when querying LatestEventNIDInRooms with N+2
{
roomIDs: []string{first, second, third},
highestNID: result["$N+2"],
wantMap: map[string]string{
first: "$N", second: "$N+2",
},
},
}
for _, tc := range testCases {
gotRoomToNID, err := table.LatestEventNIDInRooms(tc.roomIDs, int64(tc.highestNID))
assertNoError(t, err)
want := make(map[string]int64) // map event IDs to nids
for roomID, eventID := range tc.wantMap {
want[roomID] = int64(result[eventID])
}
if !reflect.DeepEqual(gotRoomToNID, want) {
t.Errorf("%+v: got %v want %v", tc, gotRoomToNID, want)
}
}

}

func TestEventTableSelectUnknownEventIDs(t *testing.T) {
db, close := connectToDB(t)
defer close()
Expand Down
39 changes: 26 additions & 13 deletions state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type StartupSnapshot struct {
AllJoinedMembers map[string][]string // room_id -> [user_id]
}

type LatestEvents struct {
Timeline []json.RawMessage
PrevBatch string
LatestNID int64
}

type Storage struct {
Accumulator *Accumulator
EventsTable *EventTable
Expand Down Expand Up @@ -584,16 +590,16 @@ func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []str
return
}

func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, limit int) (map[string][]json.RawMessage, map[string]string, error) {
func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, limit int) (map[string]*LatestEvents, error) {
roomIDToRanges, err := s.visibleEventNIDsBetweenForRooms(userID, roomIDs, 0, to)
if err != nil {
return nil, nil, err
return nil, err
}
result := make(map[string][]json.RawMessage, len(roomIDs))
prevBatches := make(map[string]string, len(roomIDs))
result := make(map[string]*LatestEvents, len(roomIDs))
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
for roomID, ranges := range roomIDToRanges {
var earliestEventNID int64
var latestEventNID int64
var roomEvents []json.RawMessage
// start at the most recent range as we want to return the most recent `limit` events
for i := len(ranges) - 1; i >= 0; i-- {
Expand All @@ -608,26 +614,33 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
}
// keep pushing to the front so we end up with A,B,C
for _, ev := range events {
if latestEventNID == 0 { // set first time and never again
latestEventNID = ev.NID
}
roomEvents = append([]json.RawMessage{ev.JSON}, roomEvents...)
earliestEventNID = ev.NID
if len(roomEvents) >= limit {
break
}
}
}
latestEvents := LatestEvents{
LatestNID: latestEventNID,
Timeline: roomEvents,
}
if earliestEventNID != 0 {
// the oldest event needs a prev batch token, so find one now
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(roomID, earliestEventNID)
if err != nil {
return fmt.Errorf("failed to select prev_batch for room %s : %s", roomID, err)
}
prevBatches[roomID] = prevBatch
latestEvents.PrevBatch = prevBatch
}
result[roomID] = roomEvents
result[roomID] = &latestEvents
}
return nil
})
return result, prevBatches, err
return result, err
}

func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []string, from, to int64) (map[string][][2]int64, error) {
Expand All @@ -641,7 +654,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin
return nil, fmt.Errorf("VisibleEventNIDsBetweenForRooms.SelectEventsWithTypeStateKeyInRooms: %s", err)
}
}
joinTimingsByRoomID, err := s.determineJoinedRoomsFromMemberships(membershipEvents)
joinTimingsAtFromByRoomID, err := s.determineJoinedRoomsFromMemberships(membershipEvents)
if err != nil {
return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err)
}
Expand All @@ -652,7 +665,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin
return nil, fmt.Errorf("failed to load membership events: %s", err)
}

return s.visibleEventNIDsWithData(joinTimingsByRoomID, membershipEvents, userID, from, to)
return s.visibleEventNIDsWithData(joinTimingsAtFromByRoomID, membershipEvents, userID, from, to)
}

// Work out the NID ranges to pull events from for this user. Given a from and to event nid stream position,
Expand Down Expand Up @@ -682,7 +695,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin
// - For Room E: from=1, to=15 returns { RoomE: [ [3,3], [13,15] ] } (tests invites)
func (s *Storage) VisibleEventNIDsBetween(userID string, from, to int64) (map[string][][2]int64, error) {
// load *ALL* joined rooms for this user at from (inclusive)
joinTimingsByRoomID, err := s.JoinedRoomsAfterPosition(userID, from)
joinTimingsAtFromByRoomID, err := s.JoinedRoomsAfterPosition(userID, from)
if err != nil {
return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err)
}
Expand All @@ -693,10 +706,10 @@ func (s *Storage) VisibleEventNIDsBetween(userID string, from, to int64) (map[st
return nil, fmt.Errorf("failed to load membership events: %s", err)
}

return s.visibleEventNIDsWithData(joinTimingsByRoomID, membershipEvents, userID, from, to)
return s.visibleEventNIDsWithData(joinTimingsAtFromByRoomID, membershipEvents, userID, from, to)
}

func (s *Storage) visibleEventNIDsWithData(joinTimingsByRoomID map[string]internal.EventMetadata, membershipEvents []Event, userID string, from, to int64) (map[string][][2]int64, error) {
func (s *Storage) visibleEventNIDsWithData(joinTimingsAtFromByRoomID map[string]internal.EventMetadata, membershipEvents []Event, userID string, from, to int64) (map[string][][2]int64, error) {
// load membership events in order and bucket based on room ID
roomIDToLogs := make(map[string][]membershipEvent)
for _, ev := range membershipEvents {
Expand Down Expand Up @@ -758,7 +771,7 @@ func (s *Storage) visibleEventNIDsWithData(joinTimingsByRoomID map[string]intern

// For each joined room, perform the algorithm and delete the logs afterwards
result := make(map[string][][2]int64)
for joinedRoomID, _ := range joinTimingsByRoomID {
for joinedRoomID, _ := range joinTimingsAtFromByRoomID {
roomResult := calculateVisibleEventNIDs(true, from, to, roomIDToLogs[joinedRoomID])
result[joinedRoomID] = roomResult
delete(roomIDToLogs, joinedRoomID)
Expand Down
24 changes: 19 additions & 5 deletions sync3/caches/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.C
// Dispatcher for new events.
type GlobalCache struct {
// LoadJoinedRoomsOverride allows tests to mock out the behaviour of LoadJoinedRooms.
LoadJoinedRoomsOverride func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error)
LoadJoinedRoomsOverride func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, latestNIDs map[string]int64, err error)

// inserts are done by v2 poll loops, selects are done by v3 request threads
// there are lots of overlapping keys as many users (threads) can be joined to the same room (key)
Expand Down Expand Up @@ -135,23 +135,37 @@ func (c *GlobalCache) copyRoom(roomID string) *internal.RoomMetadata {
// The two maps returned by this function have exactly the same set of keys. Each is nil
// iff a non-nil error is returned.
// TODO: remove with LoadRoomState?
// FIXME: return args are a mess
func (c *GlobalCache) LoadJoinedRooms(ctx context.Context, userID string) (
pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimingByRoomID map[string]internal.EventMetadata, err error,
pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimingByRoomID map[string]internal.EventMetadata,
latestNIDs map[string]int64, err error,
) {
if c.LoadJoinedRoomsOverride != nil {
return c.LoadJoinedRoomsOverride(userID)
}
initialLoadPosition, err := c.store.LatestEventNID()
if err != nil {
return 0, nil, nil, err
return 0, nil, nil, nil, err
}
joinTimingByRoomID, err = c.store.JoinedRoomsAfterPosition(userID, initialLoadPosition)
if err != nil {
return 0, nil, nil, err
return 0, nil, nil, nil, err
}
roomIDs := make([]string, len(joinTimingByRoomID))
i := 0
for roomID := range joinTimingByRoomID {
roomIDs[i] = roomID
i++
}

latestNIDs, err = c.store.EventsTable.LatestEventNIDInRooms(roomIDs, initialLoadPosition)
if err != nil {
return 0, nil, nil, nil, err
}

// TODO: no guarantee that this state is the same as latest unless called in a dispatcher loop
rooms := c.LoadRoomsFromMap(ctx, joinTimingByRoomID)
return initialLoadPosition, rooms, joinTimingByRoomID, nil
return initialLoadPosition, rooms, joinTimingByRoomID, latestNIDs, nil
}

func (c *GlobalCache) LoadStateEvent(ctx context.Context, roomID string, loadPosition int64, evType, stateKey string) json.RawMessage {
Expand Down
18 changes: 8 additions & 10 deletions sync3/caches/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ type UserRoomData struct {
HighlightCount int
Invite *InviteData

// these fields are set by LazyLoadTimelines and are per-function call, and are not persisted in-memory.
RequestedPrevBatch string
RequestedTimeline []json.RawMessage
// this field is set by LazyLoadTimelines and is per-function call, and is not persisted in-memory.
// The zero value of this safe to use (0 latest nid, no prev batch, no timeline).
RequestedLatestEvents state.LatestEvents

// TODO: should Canonicalised really be in RoomConMetadata? It's only set in SetRoom AFAICS
CanonicalisedName string // stripped leading symbols like #, all in lower case
Expand Down Expand Up @@ -218,7 +218,7 @@ func (c *UserCache) Unsubscribe(id int) {
func (c *UserCache) OnRegistered(ctx context.Context) error {
// select all spaces the user is a part of to seed the cache correctly. This has to be done in
// the OnRegistered callback which has locking guarantees. This is why...
_, joinedRooms, joinTimings, err := c.globalCache.LoadJoinedRooms(ctx, c.UserID)
_, joinedRooms, joinTimings, _, err := c.globalCache.LoadJoinedRooms(ctx, c.UserID)
if err != nil {
return fmt.Errorf("failed to load joined rooms: %s", err)
}
Expand Down Expand Up @@ -295,24 +295,22 @@ func (c *UserCache) LazyLoadTimelines(ctx context.Context, loadPos int64, roomID
return c.LazyRoomDataOverride(loadPos, roomIDs, maxTimelineEvents)
}
result := make(map[string]UserRoomData)
roomIDToEvents, roomIDToPrevBatch, err := c.store.LatestEventsInRooms(c.UserID, roomIDs, loadPos, maxTimelineEvents)
roomIDToLatestEvents, err := c.store.LatestEventsInRooms(c.UserID, roomIDs, loadPos, maxTimelineEvents)
if err != nil {
logger.Err(err).Strs("rooms", roomIDs).Msg("failed to get LatestEventsInRooms")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return nil
}
c.roomToDataMu.Lock()
for _, requestedRoomID := range roomIDs {
events := roomIDToEvents[requestedRoomID]
latestEvents := roomIDToLatestEvents[requestedRoomID]
urd, ok := c.roomToData[requestedRoomID]
if !ok {
urd = NewUserRoomData()
}
urd.RequestedTimeline = events
if len(events) > 0 {
urd.RequestedPrevBatch = roomIDToPrevBatch[requestedRoomID]
if latestEvents != nil {
urd.RequestedLatestEvents = *latestEvents
}

result[requestedRoomID] = urd
}
c.roomToDataMu.Unlock()
Expand Down
Loading

0 comments on commit 409431f

Please sign in to comment.