diff --git a/pkg/connector/groupinfo.go b/pkg/connector/groupinfo.go index 3bf809d1..43e2b78b 100644 --- a/pkg/connector/groupinfo.go +++ b/pkg/connector/groupinfo.go @@ -20,6 +20,7 @@ import ( "context" "time" + "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/ptr" "maunium.net/go/mautrix/bridgev2" @@ -120,11 +121,18 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden } } for _, member := range groupInfo.PendingMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) @@ -136,11 +144,18 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden }) } for _, member := range groupInfo.BannedMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), Membership: event.MembershipBan, }) } @@ -246,18 +261,36 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { continue } + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID + } + mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) } for _, memberServiceID := range groupChange.DeletePendingMembers { - if memberServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if memberServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, memberServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = memberServiceID.UUID } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(memberServiceID.UUID), + EventSender: s.makeEventSender(aci), Membership: event.MembershipLeave, PrevMembership: event.MembershipInvite, }) @@ -276,11 +309,18 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddBannedMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), Membership: event.MembershipBan, }) } diff --git a/pkg/signalmeow/store/container.go b/pkg/signalmeow/store/container.go index 9edafef6..879716dc 100644 --- a/pkg/signalmeow/store/container.go +++ b/pkg/signalmeow/store/container.go @@ -19,6 +19,7 @@ var _ DeviceStore = (*Container)(nil) type DeviceStore interface { PutDevice(ctx context.Context, dd *DeviceData) error DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error) + DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) } // Container is a wrapper for a SQL database that can contain multiple signalmeow sessions. @@ -39,6 +40,7 @@ FROM signalmeow_device ` const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1" +const deviceByPNIQuery = getAllDevicesQuery + "Where pni_uuid=$1" func (c *Container) Upgrade(ctx context.Context) error { return c.db.Upgrade(ctx) @@ -122,6 +124,14 @@ func (c *Container) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, er return sess, err } +func (c *Container) DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) { + sess, err := c.scanDevice(c.db.QueryRow(ctx, deviceByPNIQuery, pni)) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return sess, err +} + const ( insertDeviceQuery = ` INSERT INTO signalmeow_device (