diff --git a/go.mod b/go.mod index c49fa0ff..e7b3e390 100644 --- a/go.mod +++ b/go.mod @@ -44,3 +44,10 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect ) +<<<<<<< HEAD +======= + +//replace maunium.net/go/mautrix => ../mautrix-go +//replace go.mau.fi/util => ../../Go/go-util +replace maunium.net/go/mautrix => github.com/maltee1/mautrix-go v0.0.0-20240808204140-9598e29d1124 +>>>>>>> aec8c02 (fixes) diff --git a/go.sum b/go.sum index 4972d7b5..ca3f9472 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,11 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +<<<<<<< HEAD +======= +github.com/maltee1/mautrix-go v0.0.0-20240808204140-9598e29d1124 h1:zgSOHfcfq6NXuHL+mo/IexMGksBmrZcVyPjipTjEILc= +github.com/maltee1/mautrix-go v0.0.0-20240808204140-9598e29d1124/go.mod h1:ZWyxoQxRTBxzWIMs0kQCVogZIY0clTu33h102veCT/Q= +>>>>>>> aec8c02 (fixes) github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index e86981c6..173ed278 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -373,14 +373,7 @@ func (s *SignalClient) HandleMatrixMembership(ctx context.Context, msg *bridgev2 } gc := &signalmeow.GroupChange{} role := signalmeow.GroupMember_DEFAULT - toMembership := msg.Type.GetTo() - if toMembership == event.MembershipBan { - gc.AddBannedMembers = []*signalmeow.BannedMember{{ - ServiceID: libsignalgo.NewACIServiceID(targetSignalID), - Timestamp: uint64(time.Now().UnixMilli()), - }} - } - if toMembership == event.MembershipInvite || msg.Type == bridgev2.AcceptKnock { + if msg.Type.To == event.MembershipInvite || msg.Type == bridgev2.AcceptKnock { levels, err := msg.Portal.Bridge.Matrix.GetPowerLevels(ctx, msg.Portal.MXID) if err != nil { log.Err(err).Msg("Couldn't get power levels") @@ -391,10 +384,10 @@ func (s *SignalClient) HandleMatrixMembership(ctx context.Context, msg *bridgev2 } switch msg.Type { case bridgev2.AcceptInvite: - gc.PromotePendingMembers = []*signalmeow.PromotePendingMember{&signalmeow.PromotePendingMember{ + gc.PromotePendingMembers = []*signalmeow.PromotePendingMember{{ ACI: targetSignalID, }} - case bridgev2.RevokeInvite, bridgev2.RejectInvite, bridgev2.BanInvited: + case bridgev2.RevokeInvite, bridgev2.RejectInvite: deletePendingMember := libsignalgo.NewACIServiceID(targetSignalID) gc.DeletePendingMembers = []*libsignalgo.ServiceID{&deletePendingMember} case bridgev2.Leave, bridgev2.Kick: @@ -426,11 +419,28 @@ func (s *SignalClient) HandleMatrixMembership(ctx context.Context, msg *bridgev2 ACI: targetSignalID, Role: role, }} - case bridgev2.RetractKnock, bridgev2.RejectKnock, bridgev2.BanKnocked: + case bridgev2.RetractKnock, bridgev2.RejectKnock: gc.DeleteRequestingMembers = []*uuid.UUID{&targetSignalID} + case bridgev2.BanKnocked, bridgev2.BanInvited, bridgev2.BanJoined, bridgev2.Ban: + gc.AddBannedMembers = []*signalmeow.BannedMember{{ + ServiceID: libsignalgo.NewACIServiceID(targetSignalID), + Timestamp: uint64(time.Now().UnixMilli()), + }} + switch msg.Type { + case bridgev2.BanJoined: + gc.DeleteMembers = []*uuid.UUID{&targetSignalID} + case bridgev2.BanInvited: + deletePendingMember := libsignalgo.NewACIServiceID(targetSignalID) + gc.DeletePendingMembers = []*libsignalgo.ServiceID{&deletePendingMember} + case bridgev2.BanKnocked: + gc.DeleteRequestingMembers = []*uuid.UUID{&targetSignalID} + } case bridgev2.Unban: unbanUser := libsignalgo.NewACIServiceID(targetSignalID) gc.DeleteBannedMembers = []*libsignalgo.ServiceID{&unbanUser} + default: + log.Debug().Msg("unsupported membership change") + return false, nil } _, groupID, err := signalid.ParsePortalID(msg.Portal.ID) if err != nil || groupID == "" { diff --git a/pkg/signalid/ids.go b/pkg/signalid/ids.go index e9c83113..1209e90b 100644 --- a/pkg/signalid/ids.go +++ b/pkg/signalid/ids.go @@ -40,14 +40,11 @@ func ParseUserID(userID networkid.UserID) (uuid.UUID, error) { } func ParseUserLoginID(userLoginID networkid.UserLoginID) (uuid.UUID, error) { - serviceID, err := ParseUserLoginIDAsServiceID(userLoginID) + userID, err := ParseUserLoginID(userLoginID) if err != nil { return uuid.Nil, err - } else if serviceID.Type != libsignalgo.ServiceIDTypeACI { - return uuid.Nil, fmt.Errorf("invalid user ID: expected ACI type") - } else { - return serviceID.UUID, nil } + return userID, nil } func ParseUserIDAsServiceID(userID networkid.UserID) (libsignalgo.ServiceID, error) {