Skip to content

Commit

Permalink
bridgev2/database: include portal receiver in reaction queries
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Nov 14, 2024
1 parent 3f9a637 commit 21aa329
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
34 changes: 17 additions & 17 deletions bridgev2/database/reaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ const (
getReactionBaseQuery = `
SELECT bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction
`
getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5`
getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1`
getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 ORDER BY timestamp DESC`
getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2`
getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3`
getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6`
getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 AND emoji_id=$5 ORDER BY message_part_id ASC LIMIT 1`
getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 ORDER BY timestamp DESC`
getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3`
getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4`
getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
upsertReactionQuery = `
INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata)
Expand All @@ -54,28 +54,28 @@ const (
DO UPDATE SET sender_mxid=excluded.sender_mxid, mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata
`
deleteReactionQuery = `
DELETE FROM reaction WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5
DELETE FROM reaction WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6
`
)

func (rq *ReactionQuery) GetByID(ctx context.Context, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, messageID, messagePartID, senderID, emojiID)
func (rq *ReactionQuery) GetByID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, receiver, messageID, messagePartID, senderID, emojiID)
}

func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, messageID, senderID, emojiID)
func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, receiver, messageID, senderID, emojiID)
}

func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, messageID, senderID)
func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, receiver, messageID, senderID)
}

func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, messageID networkid.MessageID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, messageID)
func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, receiver, messageID)
}

func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, messageID, partID)
func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, receiver, messageID, partID)
}

func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
Expand All @@ -89,7 +89,7 @@ func (rq *ReactionQuery) Upsert(ctx context.Context, reaction *Reaction) error {

func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error {
ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID)
return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID)
return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.Room.Receiver, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID)
}

func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
Expand Down
14 changes: 7 additions & 7 deletions bridgev2/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if portal.Bridge.Config.OutgoingMessageReID {
deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID)
}
existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID)
existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
return
Expand All @@ -1169,7 +1169,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
}
react.PreHandleResp = &preResp
if preResp.MaxReactions > 0 {
allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, reactionTarget.ID, preResp.SenderID)
allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, portal.Receiver, reactionTarget.ID, preResp.SenderID)
if err != nil {
log.Err(err).Msg("Failed to get all reactions to message by sender")
portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err))
Expand Down Expand Up @@ -2162,9 +2162,9 @@ func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventW

func (portal *Portal) getTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) {
if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok {
return portal.Bridge.DB.Reaction.GetByID(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
return portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
} else {
return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
}
}

Expand Down Expand Up @@ -2196,9 +2196,9 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User
}
var existingReactions []*database.Reaction
if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok {
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart())
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart())
} else {
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, evt.GetTargetMessage())
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, evt.GetTargetMessage())
}
existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction)
for _, existingReaction := range existingReactions {
Expand Down Expand Up @@ -2317,7 +2317,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi
return
}
emoji, emojiID := evt.GetReactionEmoji()
existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID)
existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
return
Expand Down

0 comments on commit 21aa329

Please sign in to comment.