Skip to content

Commit

Permalink
Merge pull request #149 from matrix-org/dmr/extension-scoping-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson authored Aug 16, 2023
2 parents 2738d09 + 7e9815d commit 75b3d16
Show file tree
Hide file tree
Showing 12 changed files with 389 additions and 22 deletions.
9 changes: 8 additions & 1 deletion sync3/extensions/account_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ func TestLiveAccountDataAggregation(t *testing.T) {
ext := &AccountDataRequest{
Core: Core{
Enabled: &boolTrue,
Lists: []string{"*"},
Rooms: []string{"*"},
},
}
var res Response
var extCtx Context
var extCtx = Context{
AllSubscribedRooms: []string{roomA},
}
room1 := &caches.RoomAccountDataUpdate{
RoomUpdate: &dummyRoomUpdate{
roomID: roomA,
Expand Down Expand Up @@ -78,6 +82,9 @@ func TestLiveAccountDataAggregation(t *testing.T) {
room1.AccountData[0].Data, room1.AccountData[1].Data,
},
}
if res.AccountData == nil {
t.Fatalf("Didn't get account data: %v", res)
}
if !reflect.DeepEqual(res.AccountData.Rooms, wantRoomAccountData) {
t.Fatalf("got %+v\nwant %+v", res.AccountData.Rooms, wantRoomAccountData)
}
Expand Down
64 changes: 56 additions & 8 deletions sync3/extensions/extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type GenericRequest interface {
OnlyLists() []string
// Returns the value of the `rooms` JSON key. nil for "not specified".
OnlyRooms() []string
// InterpretAsInitial interprets this as an initial request rather than a delta, and
// overwrites fields accordingly. This can be useful when fields have default
// values, but is a little ugly. Use sparingly.
InterpretAsInitial()
// Overwrite fields in the request by side-effecting on this struct.
ApplyDelta(next GenericRequest)
// ProcessInitial provides a means for extensions to return data to clients immediately.
Expand Down Expand Up @@ -77,6 +81,17 @@ func (r *Core) OnlyRooms() []string {
return r.Rooms
}

func (r *Core) InterpretAsInitial() {
// An omitted/nil value for lists and rooms normally means "no change".
// If this extension has never been specified before, nil means "all lists/rooms".
if r.Lists == nil {
r.Lists = []string{"*"}
}
if r.Rooms == nil {
r.Rooms = []string{"*"}
}
}

func (r *Core) ApplyDelta(gnext GenericRequest) {
if gnext == nil {
return
Expand All @@ -100,22 +115,28 @@ func (r *Core) ApplyDelta(gnext GenericRequest) {
// according to the "core" extension scoping logic. Extensions are free to suppress
// updates for a room based on additional criteria.
func (r *Core) RoomInScope(roomID string, extCtx Context) bool {
// If the extension hasn't had its scope configured, process everything.
if r.Lists == nil && r.Rooms == nil {
return true
// First determine which rooms the extension is monitoring outside of any sliding windows.
roomsToMonitor := r.Rooms
if len(roomsToMonitor) > 0 && roomsToMonitor[0] == "*" {
roomsToMonitor = extCtx.AllSubscribedRooms
}

// If this extension has been explicitly subscribed to this room, process the update.
for _, roomInScope := range r.Rooms {
// Process the update if this room is one of those monitored rooms.
for _, roomInScope := range roomsToMonitor {
if roomInScope == roomID {
return true
}
}

// If the room belongs to one of the lists that this extension should process, process the update.
// Next determine which lists the extension is monitoring.
listsToMonitor := r.Lists
if len(listsToMonitor) > 0 && listsToMonitor[0] == "*" {
listsToMonitor = extCtx.AllLists
}

// Process the update if the room is visible in one of those lists.
visibleInLists := extCtx.RoomIDsToLists[roomID]
for _, visibleInList := range visibleInLists {
for _, shouldProcessList := range r.Lists {
for _, shouldProcessList := range listsToMonitor {
if visibleInList == shouldProcessList {
return true
}
Expand Down Expand Up @@ -175,6 +196,8 @@ func (r Request) EnabledExtensions() (exts []GenericRequest) {
return
}

// ApplyDelta applies the `next` request as a delta atop the previous Request r, and
// returns the result as a new Request.
func (r Request) ApplyDelta(next *Request) Request {
currFields := r.fields()
nextFields := next.fields()
Expand All @@ -187,6 +210,7 @@ func (r Request) ApplyDelta(next *Request) Request {
}
if isNil(curr) {
// the next field is what we want to apply
next.InterpretAsInitial()
currFields[i] = next
hasChanges = true
} else {
Expand All @@ -202,6 +226,24 @@ func (r Request) ApplyDelta(next *Request) Request {
return r
}

func (r *Request) InterpretAsInitial() {
if r.ToDevice != nil {
r.ToDevice.InterpretAsInitial()
}
if r.E2EE != nil {
r.E2EE.InterpretAsInitial()
}
if r.AccountData != nil {
r.AccountData.InterpretAsInitial()
}
if r.Typing != nil {
r.Typing.InterpretAsInitial()
}
if r.Receipts != nil {
r.Receipts.InterpretAsInitial()
}
}

// Response represents the top-level `extensions` key in the JSON response.
//
// To add a new extension, add a field here and in fields().
Expand Down Expand Up @@ -233,6 +275,8 @@ func (r Response) HasData(isInitial bool) bool {
return false
}

// Context is a summary of useful information about the sync3.Request and the state of
// the requester's connection.
type Context struct {
*Handler
// RoomIDToTimeline is a map from room IDs to slices of event IDs. The keys are the
Expand All @@ -253,6 +297,10 @@ type Context struct {
// enclose those sliding windows. Values should be nonnil and nonempty, and may
// contain multiple list names.
RoomIDsToLists map[string][]string
// AllLists is the slice of list names provided to the Sliding Window API.
AllLists []string
// AllSubscribedRooms is the slice of room IDs provided to the Room Subscription API.
AllSubscribedRooms []string
}

type HandlerInterface interface {
Expand Down
2 changes: 2 additions & 0 deletions sync3/extensions/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ func TestExtension_ApplyDelta(t *testing.T) {
AccountData: &AccountDataRequest{
Core: Core{
Enabled: &boolTrue,
Lists: []string{"*"},
Rooms: []string{"*"},
},
},
},
Expand Down
6 changes: 5 additions & 1 deletion sync3/extensions/receipts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ func TestLiveReceiptsAggregation(t *testing.T) {
ext := &ReceiptsRequest{
Core: Core{
Enabled: &boolTrue,
Lists: []string{"*"},
Rooms: []string{"*"},
},
}
var res Response
var extCtx Context
extCtx := Context{
AllSubscribedRooms: []string{roomA, roomB},
}
receiptA1 := &caches.ReceiptUpdate{
Receipt: internal.Receipt{
RoomID: roomA,
Expand Down
8 changes: 6 additions & 2 deletions sync3/extensions/typing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ func TestLiveTypingAggregation(t *testing.T) {
ext := &TypingRequest{
Core: Core{
Enabled: &boolTrue,
Lists: []string{"*"},
Rooms: []string{"*"},
},
}
var res Response
var extCtx Context
extCtx := Context{
AllSubscribedRooms: []string{roomA, roomB, roomC},
}
typingA1 := &caches.TypingUpdate{
RoomUpdate: &dummyRoomUpdate{
roomID: roomA,
Expand Down Expand Up @@ -83,6 +87,6 @@ func TestLiveTypingAggregation(t *testing.T) {
ext.AppendLive(ctx, &res, extCtx, eventC1)
want[roomC] = eventC1.GlobalRoomMetadata().TypingEvent
if !reflect.DeepEqual(res.Typing.Rooms, want) {
t.Fatalf("got %+v\nwant %+v", res.Typing.Rooms, want)
t.Fatalf("got %s\nwant %s", res.Typing.Rooms, want)
}
}
24 changes: 20 additions & 4 deletions sync3/handler/connstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,13 @@ func (s *ConnState) onIncomingRequest(reqCtx context.Context, req *sync3.Request
// is being notified about (e.g. for room account data)
extCtx, region := internal.StartSpan(reqCtx, "extensions")
response.Extensions = s.extensionsHandler.Handle(extCtx, s.muxedReq.Extensions, extensions.Context{
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
IsInitial: isInitial,
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
IsInitial: isInitial,
RoomIDsToLists: s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists),
AllSubscribedRooms: keys(s.roomSubscriptions),
AllLists: s.muxedReq.ListKeys(),
})
region.End()

Expand Down Expand Up @@ -758,3 +761,16 @@ func clampSliceRangeToListSize(ctx context.Context, r [2]int64, totalRooms int64
return [2]int64{r[0], lastIndexWithRoom}
}
}

// Returns a slice containing copies of the keys of the given map, in no particular
// order.
func keys[K comparable, V any](m map[K]V) []K {
if m == nil {
return nil
}
output := make([]K, len(m))
for key := range m {
output = append(output, key)
}
return output
}
12 changes: 7 additions & 5 deletions sync3/handler/connstate_live.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,13 @@ func (s *connStateLive) processUpdate(ctx context.Context, update caches.Update,
// pass event to extensions AFTER processing
roomIDsToLists := s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists)
s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extensions.Context{
IsInitial: false,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDsToLists: roomIDsToLists,
IsInitial: false,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDsToLists: roomIDsToLists,
AllSubscribedRooms: keys(s.roomSubscriptions),
AllLists: s.muxedReq.ListKeys(),
})
}

Expand Down
6 changes: 6 additions & 0 deletions sync3/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ import (
"sort"
)

// SliceRanges is a slice of integer pairs [a, b]. Each pair represents the integers x
// in the range a <= x <= b (note: closed at both ends). The slice as a whole represents
// the set of integers x in any of the slice's closed intervals.
//
// Within the slice, pairs are arranged in no particular order. Two pairs may represent
// overlapping ranges of integers; use Valid to test for this.
type SliceRanges [][2]int64

func (r SliceRanges) Valid() bool {
Expand Down
10 changes: 10 additions & 0 deletions sync3/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ type RequestListDelta struct {
// request.
func (r *Request) ApplyDelta(nextReq *Request) (result *Request, delta *RequestDelta) {
if r == nil {
nextReq.Extensions.InterpretAsInitial()
result = &Request{
Extensions: nextReq.Extensions,
}
Expand Down Expand Up @@ -470,6 +471,15 @@ func (r *Request) ApplyDelta(nextReq *Request) (result *Request, delta *RequestD
return
}

// ListKeys builds a slice containing the names of the lists this request has defined.
func (r *Request) ListKeys() []string {
listKeys := make([]string, 0, len(r.Lists))
for listKey, _ := range r.Lists {
listKeys = append(listKeys, listKey)
}
return listKeys
}

type RequestFilters struct {
Spaces []string `json:"spaces"`
IsDM *bool `json:"is_dm"`
Expand Down
Loading

0 comments on commit 75b3d16

Please sign in to comment.