diff --git a/cmd/meowlnir/main.go b/cmd/meowlnir/main.go index 5da33b0..b87db36 100644 --- a/cmd/meowlnir/main.go +++ b/cmd/meowlnir/main.go @@ -128,7 +128,7 @@ func (m *Meowlnir) Init(ctx context.Context, configPath string, noSaveConfig boo m.EvaluatorByProtectedRoom = make(map[id.RoomID]*policyeval.PolicyEvaluator) m.EvaluatorByManagementRoom = make(map[id.RoomID]*policyeval.PolicyEvaluator, len(m.Config.Meowlnir.ManagementRooms)) for _, roomID := range m.Config.Meowlnir.ManagementRooms { - m.EvaluatorByManagementRoom[roomID] = policyeval.NewPolicyEvaluator(m.Client, m.PolicyStore, roomID) + m.EvaluatorByManagementRoom[roomID] = policyeval.NewPolicyEvaluator(m.Client, m.PolicyStore, roomID, m.DB, m.SynapseDB) } m.Log.Debug().Msg("Preparing crypto helper") @@ -173,7 +173,12 @@ func (m *Meowlnir) ensureBotRegistered(ctx context.Context) { } func (m *Meowlnir) Run(ctx context.Context) { - err := m.StateStore.Upgrade(ctx) + err := m.DB.Upgrade(ctx) + if err != nil { + m.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to upgrade main db") + os.Exit(20) + } + err = m.StateStore.Upgrade(ctx) if err != nil { m.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to upgrade state store") os.Exit(20) diff --git a/config/event.go b/config/event.go index 79da135..f45c645 100644 --- a/config/event.go +++ b/config/event.go @@ -13,8 +13,9 @@ var ( ) type WatchedPolicyList struct { - RoomID id.RoomID `json:"room_id"` - Name string `json:"name"` + RoomID id.RoomID `json:"room_id"` + Name string `json:"name"` + AutoUnban bool `json:"auto_unban"` } type WatchedListsEventContent struct { diff --git a/database/action.go b/database/action.go index 4c9b6a6..7032bde 100644 --- a/database/action.go +++ b/database/action.go @@ -1,13 +1,76 @@ package database import ( + "context" + "time" + + "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) +const ( + getTakenActionBaseQuery = ` + SELECT target_user, in_room_id, action_type, policy_list, rule_entity, action, taken_at + FROM taken_action + ` + getTakenActionsByPolicyListQuery = getTakenActionBaseQuery + `WHERE policy_list=$1` + getTakenActionsByRuleEntityQuery = getTakenActionBaseQuery + `WHERE policy_list=$1 AND rule_entity=$2` + getTakenActionByTargetUserQuery = getTakenActionBaseQuery + `WHERE target_user=$1 AND action_type=$2` + insertTakenActionQuery = ` + INSERT INTO taken_action (target_user, in_room_id, action_type, policy_list, rule_entity, action, taken_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (target_user, in_room_id, action_type) DO UPDATE + SET policy_list=excluded.policy_list, rule_entity=excluded.rule_entity, action=excluded.action, taken_at=excluded.taken_at + ` +) + +type TakenActionQuery struct { + *dbutil.QueryHelper[*TakenAction] +} + +func (taq *TakenActionQuery) Put(ctx context.Context, ta *TakenAction) error { + return taq.Exec(ctx, insertTakenActionQuery, ta.sqlVariables()...) +} + +func (taq *TakenActionQuery) GetAllByPolicyList(ctx context.Context, policyList id.RoomID) ([]*TakenAction, error) { + return taq.QueryMany(ctx, getTakenActionsByPolicyListQuery, policyList) +} + +func (taq *TakenActionQuery) GetAllByRuleEntity(ctx context.Context, policyList id.RoomID, ruleEntity string) ([]*TakenAction, error) { + return taq.QueryMany(ctx, getTakenActionsByRuleEntityQuery, policyList, ruleEntity) +} + +func (taq *TakenActionQuery) GetAllByTargetUser(ctx context.Context, userID id.UserID, actionType TakenActionType) ([]*TakenAction, error) { + return taq.QueryMany(ctx, getTakenActionByTargetUserQuery, userID, actionType) +} + +type TakenActionType string + +const ( + TakenActionTypeBanOrUnban TakenActionType = "ban_or_unban" +) + type TakenAction struct { + TargetUser id.UserID + InRoomID id.RoomID + ActionType TakenActionType PolicyList id.RoomID RuleEntity string - TargetUser id.UserID Action event.PolicyRecommendation + TakenAt time.Time +} + +func (t *TakenAction) sqlVariables() []any { + return []any{t.TargetUser, t.InRoomID, t.ActionType, t.PolicyList, t.RuleEntity, t.Action, t.TakenAt.UnixMilli()} +} + +func (t *TakenAction) Scan(row dbutil.Scannable) (*TakenAction, error) { + var takenAt int64 + err := row.Scan(&t.TargetUser, &t.InRoomID, &t.ActionType, &t.PolicyList, &t.RuleEntity, &t.Action, &takenAt) + if err != nil { + return nil, err + } + t.TakenAt = time.UnixMilli(takenAt) + return t, nil } diff --git a/database/db.go b/database/db.go index 9c2a902..19b2476 100644 --- a/database/db.go +++ b/database/db.go @@ -2,14 +2,23 @@ package database import ( "go.mau.fi/util/dbutil" + + "go.mau.fi/meowlnir/database/upgrades" ) type Database struct { *dbutil.Database + TakenAction *TakenActionQuery } func New(db *dbutil.Database) *Database { + db.UpgradeTable = upgrades.Table return &Database{ Database: db, + TakenAction: &TakenActionQuery{ + QueryHelper: dbutil.MakeQueryHelper(db, func(qh *dbutil.QueryHelper[*TakenAction]) *TakenAction { + return &TakenAction{} + }), + }, } } diff --git a/database/upgrades/00-latest.sql b/database/upgrades/00-latest.sql index 48c6fb8..ab26fc1 100644 --- a/database/upgrades/00-latest.sql +++ b/database/upgrades/00-latest.sql @@ -1,8 +1,15 @@ -- v0 -> v1 (compatible with v1+): Latest schema CREATE TABLE taken_action ( - policy_list TEXT NOT NULL, - rule_entity TEXT NOT NULL, - target_user TEXT NOT NULL, - action TEXT NOT NULL, - taken_at BIGINT NOT NULL + target_user TEXT NOT NULL, + in_room_id TEXT NOT NULL, + action_type TEXT NOT NULL, + policy_list TEXT NOT NULL, + rule_entity TEXT NOT NULL, + action TEXT NOT NULL, + taken_at BIGINT NOT NULL, + + PRIMARY KEY (target_user, in_room_id, action_type) ); + +CREATE INDEX taken_action_list_idx ON taken_action (policy_list); +CREATE INDEX taken_action_entity_idx ON taken_action (policy_list, rule_entity); diff --git a/go.mod b/go.mod index 7587d74..c36a42c 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,13 @@ go 1.23 require ( github.com/lib/pq v1.10.9 - github.com/prometheus/client_golang v1.20.2 + github.com/prometheus/client_golang v1.20.3 github.com/rs/zerolog v1.33.0 - go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6 + go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 - maunium.net/go/mautrix v0.20.1-0.20240902204906-db8f2433a1db + maunium.net/go/mautrix v0.20.1-0.20240906145130-6b055b1475bd ) require ( @@ -24,7 +23,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/mattn/go-sqlite3 v1.14.23 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect github.com/prometheus/client_model v0.6.1 // indirect @@ -36,12 +35,10 @@ require ( github.com/tidwall/sjson v1.2.5 // indirect github.com/yuin/goldmark v1.7.4 // indirect golang.org/x/crypto v0.26.0 // indirect + golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect golang.org/x/net v0.28.0 // indirect - golang.org/x/sys v0.24.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) - -//replace maunium.net/go/mautrix => ../mautrix-go -//replace go.mau.fi/util => ../../Go/go-util diff --git a/go.sum b/go.sum index 8d30ecc..4ef1de3 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= +github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw= @@ -40,8 +40,8 @@ github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7c github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.20.2 h1:5ctymQzZlyOON1666svgwn3s6IKWgfbjsejTMiXIyjg= -github.com/prometheus/client_golang v1.20.2/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.20.3 h1:oPksm4K8B+Vt35tUhw6GbSNSgVlVSBH0qELP/7u83l4= +github.com/prometheus/client_golang v1.20.3/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= @@ -66,8 +66,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6 h1:cSLCabMKbR6rTPYRGWD2XaHo210BK3BtPg+CRC4A4og= -go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= +go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 h1:VZQlKBbeJ7KOlYSh6BnN5uWQTY/ypn/bJv0YyEd+pXc= +go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2/go.mod h1:WgYvbt9rVmoFeajP97NunQU7AjgvTPiNExN3oTHeePs= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= @@ -79,10 +79,10 @@ golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -94,5 +94,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.20.1-0.20240902204906-db8f2433a1db h1:kZ4iZLIPRZD62wqcnGwRATJCNjU6EC8vfEOewwX9pZQ= -maunium.net/go/mautrix v0.20.1-0.20240902204906-db8f2433a1db/go.mod h1:IXDDoX+dqBkNnrjDMouE3FUExiR+hhmaEFsvXG3HzfQ= +maunium.net/go/mautrix v0.20.1-0.20240906145130-6b055b1475bd h1:gfiJD2cPS9iUek1UI+DOUn08zogF4kmu7XYfBqSrAU4= +maunium.net/go/mautrix v0.20.1-0.20240906145130-6b055b1475bd/go.mod h1:l6nYvD5/FMSrAZ/IP1AqJV0b47SRl/0uQNRiy4CcSVk= diff --git a/policyeval/evaluate.go b/policyeval/evaluate.go index 9562e9b..e7a1c92 100644 --- a/policyeval/evaluate.go +++ b/policyeval/evaluate.go @@ -6,8 +6,10 @@ import ( "slices" "github.com/rs/zerolog" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "go.mau.fi/meowlnir/database" "go.mau.fi/meowlnir/policylist" ) @@ -20,26 +22,77 @@ func (pe *PolicyEvaluator) EvaluateAll(ctx context.Context) { func (pe *PolicyEvaluator) EvaluateAllMembers(ctx context.Context, members []id.UserID) { for _, member := range members { - pe.EvaluateNewMember(ctx, member) + pe.EvaluateUser(ctx, member) } } -func (pe *PolicyEvaluator) EvaluateNewMember(ctx context.Context, userID id.UserID) { +func (pe *PolicyEvaluator) EvaluateUser(ctx context.Context, userID id.UserID) { match := pe.Store.MatchUser(pe.GetWatchedLists(), userID) if match == nil { return } - zerolog.Ctx(ctx).Info(). + zerolog.Ctx(ctx).Debug(). Stringer("user_id", userID). - Any("recommendation", match.Recommendations()). Any("matches", match). - Msg("Matched user in membership event") + Msg("Found matches for user") + pe.ApplyPolicy(ctx, userID, match) } func (pe *PolicyEvaluator) EvaluateRemovedRule(ctx context.Context, policy *policylist.Policy) { - // TODO + if policy.Recommendation == event.PolicyRecommendationUnban { + // When an unban rule is removed, evaluate all joined users against the removed rule + // to see if they should be re-evaluated against all rules (and possibly banned) + pe.usersLock.RLock() + users := slices.Collect(maps.Keys(pe.users)) + pe.usersLock.RUnlock() + for _, userID := range users { + if policy.Pattern.Match(string(userID)) { + pe.EvaluateUser(ctx, userID) + } + } + } else { + // For ban rules, find users who were banned by the rule and re-evaluate them. + reevalTargets, err := pe.DB.TakenAction.GetAllByRuleEntity(ctx, policy.RoomID, policy.Entity) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("policy_entity", policy.Entity). + Msg("Failed to get actions taken for removed policy") + pe.sendNotice(ctx, "Database error in EvaluateRemovedRule (GetAllByRuleEntity): %v", err) + return + } + pe.ReevaluateActions(ctx, reevalTargets) + } } func (pe *PolicyEvaluator) EvaluateAddedRule(ctx context.Context, policy *policylist.Policy) { - // TODO + pe.usersLock.RLock() + users := slices.Collect(maps.Keys(pe.users)) + pe.usersLock.RUnlock() + for _, userID := range users { + if policy.Pattern.Match(string(userID)) { + pe.ApplyPolicy(ctx, userID, policylist.Match{policy}) + } + } +} + +func (pe *PolicyEvaluator) ReevaluateAffectedByLists(ctx context.Context, policyLists []id.RoomID) { + var reevalTargets []*database.TakenAction + for _, list := range policyLists { + targets, err := pe.DB.TakenAction.GetAllByPolicyList(ctx, list) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("policy_list_id", list). + Msg("Failed to get actions taken from policy list") + pe.sendNotice(ctx, "Database error in ReevaluateAffectedByLists (GetAllByPolicyList): %v", err) + continue + } + if reevalTargets == nil { + reevalTargets = targets + } else { + reevalTargets = append(reevalTargets, targets...) + } + } + pe.ReevaluateActions(ctx, reevalTargets) +} + +func (pe *PolicyEvaluator) ReevaluateActions(ctx context.Context, actions []*database.TakenAction) { + } diff --git a/policyeval/eventhandle.go b/policyeval/eventhandle.go index 7a77f3a..fe4c8ea 100644 --- a/policyeval/eventhandle.go +++ b/policyeval/eventhandle.go @@ -2,6 +2,7 @@ package policyeval import ( "context" + "fmt" "strings" "github.com/rs/zerolog" @@ -34,7 +35,40 @@ func (pe *PolicyEvaluator) HandleConfigChange(ctx context.Context, evt *event.Ev func (pe *PolicyEvaluator) HandleMember(ctx context.Context, evt *event.Event) { checkRules := pe.updateUser(id.UserID(evt.GetStateKey()), evt.RoomID, evt.Content.AsMember().Membership) if checkRules { - pe.EvaluateNewMember(ctx, id.UserID(evt.GetStateKey())) + pe.EvaluateUser(ctx, id.UserID(evt.GetStateKey())) + } +} + +func addActionString(rec event.PolicyRecommendation) string { + switch rec { + case event.PolicyRecommendationBan: + return "banned" + case event.PolicyRecommendationUnban: + return "added a ban exclusion for" + default: + return fmt.Sprintf("added a `%s` rule for", rec) + } +} + +func changeActionString(rec event.PolicyRecommendation) string { + switch rec { + case event.PolicyRecommendationBan: + return "ban" + case event.PolicyRecommendationUnban: + return "ban exclusion" + default: + return fmt.Sprintf("`%s`", rec) + } +} + +func removeActionString(rec event.PolicyRecommendation) string { + switch rec { + case event.PolicyRecommendationBan: + return "unbanned" + case event.PolicyRecommendationUnban: + return "removed a ban exclusion for" + default: + return fmt.Sprintf("removed a `%s` rule for", rec) } } @@ -47,27 +81,28 @@ func (pe *PolicyEvaluator) HandlePolicyListChange(ctx context.Context, policyRoo Any("added", added). Any("removed", removed). Msg("Policy list change") - if removed != nil && added != nil && removed.Entity == added.Entity { - // probably just a reason change (unless recommendation changed too) - } - if removed != nil && (added == nil || removed.Entity != added.Entity) { - pe.EvaluateRemovedRule(ctx, removed) - // TODO include entity type in message - pe.sendNotice(ctx, - "[%s](%s) (%s): [%s](%s) removed `%s`/`%s` rule matching `%s` for %s", - policyRoomMeta.Name, policyRoom.URI().MatrixToURL(), policyRoomMeta.Name, - removed.Sender, removed.Sender.URI().MatrixToURL(), - removed.EntityType, removed.Recommendation, removed.Entity, removed.Reason, - ) - } - if added != nil && (removed == nil || removed.Entity != added.Entity) { - pe.EvaluateAddedRule(ctx, added) - // TODO include entity type in message + removedAndAddedAreEquivalent := removed != nil && added != nil && removed.Entity == added.Entity && removed.Recommendation == added.Recommendation + if removedAndAddedAreEquivalent { pe.sendNotice(ctx, - "[%s](%s) (%s): [%s](%s) added `%s`/`%s` rule matching `%s` for %s", - policyRoomMeta.Name, policyRoom.URI().MatrixToURL(), policyRoomMeta.Name, - added.Sender, added.Sender.URI().MatrixToURL(), - added.EntityType, added.Recommendation, added.Entity, added.Reason, - ) + "[%s] [%s](%s) changed the %s reason for `%s` from `%s` to `%s`", + policyRoomMeta.Name, added.Sender, added.Sender.URI().MatrixToURL(), + changeActionString(added.Recommendation), added.Entity, removed.Reason, added.Reason) + } else { + if removed != nil { + pe.sendNotice(ctx, + "[%s] [%s](%s) %s %ss matching `%s` for %s", + policyRoomMeta.Name, removed.Sender, removed.Sender.URI().MatrixToURL(), + removeActionString(removed.Recommendation), removed.EntityType, removed.Entity, removed.Reason, + ) + pe.EvaluateRemovedRule(ctx, removed) + } + if added != nil { + pe.sendNotice(ctx, + "[%s] [%s](%s) %s %ss matching `%s` for %s", + policyRoomMeta.Name, added.Sender, added.Sender.URI().MatrixToURL(), + addActionString(added.Recommendation), added.EntityType, added.Entity, added.Reason, + ) + pe.EvaluateAddedRule(ctx, added) + } } } diff --git a/policyeval/execute.go b/policyeval/execute.go new file mode 100644 index 0000000..22bb27c --- /dev/null +++ b/policyeval/execute.go @@ -0,0 +1,156 @@ +package policyeval + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + "time" + + "github.com/rs/zerolog" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/meowlnir/database" + "go.mau.fi/meowlnir/policylist" +) + +func (pe *PolicyEvaluator) getRoomsUserIsIn(userID id.UserID) []id.RoomID { + pe.usersLock.RLock() + rooms := slices.Clone(pe.users[userID]) + pe.usersLock.RUnlock() + return rooms +} + +func (pe *PolicyEvaluator) ApplyPolicy(ctx context.Context, userID id.UserID, policy policylist.Match) { + if userID == pe.Client.UserID { + return + } + recs := policy.Recommendations() + rooms := pe.getRoomsUserIsIn(userID) + if recs.BanOrUnban != nil { + if recs.BanOrUnban.Recommendation == event.PolicyRecommendationBan { + for _, room := range rooms { + pe.ApplyBan(ctx, userID, room, recs.BanOrUnban) + } + } else { + // TODO unban if banned in some rooms? or just require doing that manually + //takenActions, err := pe.DB.TakenAction.GetAllByTargetUser(ctx, userID, database.TakenActionTypeBanOrUnban) + //if err != nil { + // zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to get taken actions") + // pe.sendNotice(ctx, "Database error in ApplyPolicy (GetAllByTargetUser): %v", err) + // return + //} + } + } +} + +func (pe *PolicyEvaluator) ApplyBan(ctx context.Context, userID id.UserID, roomID id.RoomID, policy *policylist.Policy) { + ta := &database.TakenAction{ + TargetUser: userID, + InRoomID: roomID, + ActionType: database.TakenActionTypeBanOrUnban, + PolicyList: policy.RoomID, + RuleEntity: policy.Entity, + Action: policy.Recommendation, + TakenAt: time.Now(), + } + var err error + if !pe.DryRun { + _, err = pe.Client.BanUser(ctx, roomID, &mautrix.ReqBanUser{ + Reason: policy.Reason, + UserID: userID, + }) + } + if err != nil { + var respErr mautrix.HTTPError + if errors.As(err, &respErr) { + err = respErr + } + zerolog.Ctx(ctx).Err(err).Any("attempted_action", ta).Msg("Failed to ban user") + pe.sendNotice(ctx, "Failed to ban [%s](%s) in [%s](%s) for %s: %v", userID, userID.URI().MatrixToURL(), roomID, roomID.URI().MatrixToURL(), policy.Reason, err) + return + } + err = pe.DB.TakenAction.Put(ctx, ta) + if err != nil { + zerolog.Ctx(ctx).Err(err).Any("taken_action", ta).Msg("Failed to save taken action") + pe.sendNotice(ctx, "Banned [%s](%s) in [%s](%s) for %s, but failed to save to database: %v", userID, userID.URI().MatrixToURL(), roomID, roomID.URI().MatrixToURL(), policy.Reason, err) + } else { + zerolog.Ctx(ctx).Info().Any("taken_action", ta).Msg("Took action") + pe.sendNotice(ctx, "Banned [%s](%s) in [%s](%s) for %s", userID, userID.URI().MatrixToURL(), roomID, roomID.URI().MatrixToURL(), policy.Reason) + } + if policy.Reason == "spam" { + go pe.RedactUser(context.WithoutCancel(ctx), userID, policy.Reason) + } +} + +func pluralize(value int, unit string) string { + if value == 1 { + return "1 " + unit + } + return fmt.Sprintf("%d %ss", value, unit) +} + +func (pe *PolicyEvaluator) RedactUser(ctx context.Context, userID id.UserID, reason string) { + events, err := pe.SynapseDB.GetEventsToRedact(ctx, userID, pe.GetProtectedRooms()) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("user_id", userID). + Msg("Failed to get events to redact") + pe.sendNotice(ctx, + "Failed to get events to redact for [%s](%s): %v", + userID, userID.URI().MatrixToURL(), err) + return + } else if len(events) == 0 { + return + } + var errorMessages []string + var redactedCount int + for roomID, roomEvents := range events { + successCount, failedCount := pe.redactEventsInRoom(ctx, userID, roomID, roomEvents, reason) + if failedCount > 0 { + errorMessages = append(errorMessages, fmt.Sprintf( + "* Failed to redact %d/%d events from [%s](%s) in [%s](%s)", + failedCount, failedCount+successCount, userID, userID.URI().MatrixToURL(), roomID, roomID.URI().MatrixToURL())) + } + redactedCount += successCount + } + output := fmt.Sprintf("Redacted %s across %s from [%s](%s)", + pluralize(redactedCount, "event"), pluralize(len(events), "room"), + userID, userID.URI().MatrixToURL()) + if len(errorMessages) > 0 { + output += "\n\n" + strings.Join(errorMessages, "\n") + } + pe.sendNotice(ctx, output) +} + +func (pe *PolicyEvaluator) redactEventsInRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, events []id.EventID, reason string) (successCount, failedCount int) { + for _, evtID := range events { + var resp *mautrix.RespSendEvent + var err error + if !pe.DryRun { + resp, err = pe.Client.RedactEvent(ctx, roomID, evtID, mautrix.ReqRedact{Reason: reason}) + } else { + resp = &mautrix.RespSendEvent{EventID: "$fake-redaction-id"} + } + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("sender", userID). + Stringer("room_id", roomID). + Stringer("event_id", evtID). + Msg("Failed to redact event") + failedCount++ + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("sender", userID). + Stringer("room_id", roomID). + Stringer("event_id", evtID). + Stringer("redaction_id", resp.EventID). + Msg("Successfully redacted event") + successCount++ + } + } + return +} diff --git a/policyeval/main.go b/policyeval/main.go index f363b8d..e09b6ef 100644 --- a/policyeval/main.go +++ b/policyeval/main.go @@ -15,12 +15,17 @@ import ( "maunium.net/go/mautrix/id" "go.mau.fi/meowlnir/config" + "go.mau.fi/meowlnir/database" "go.mau.fi/meowlnir/policylist" + "go.mau.fi/meowlnir/synapsedb" ) type PolicyEvaluator struct { - Client *mautrix.Client - Store *policylist.Store + Client *mautrix.Client + Store *policylist.Store + SynapseDB *synapsedb.SynapseDB + DB *database.Database + DryRun bool ManagementRoom id.RoomID Admins *exsync.Set[id.UserID] @@ -36,8 +41,10 @@ type PolicyEvaluator struct { usersLock sync.RWMutex } -func NewPolicyEvaluator(client *mautrix.Client, store *policylist.Store, managementRoom id.RoomID) *PolicyEvaluator { +func NewPolicyEvaluator(client *mautrix.Client, store *policylist.Store, managementRoom id.RoomID, db *database.Database, synapseDB *synapsedb.SynapseDB) *PolicyEvaluator { pe := &PolicyEvaluator{ + DB: db, + SynapseDB: synapseDB, Client: client, Store: store, ManagementRoom: managementRoom, diff --git a/policyeval/protectedrooms.go b/policyeval/protectedrooms.go index c1d8e67..71d93b3 100644 --- a/policyeval/protectedrooms.go +++ b/policyeval/protectedrooms.go @@ -3,7 +3,9 @@ package policyeval import ( "context" "fmt" + "maps" "slices" + "sync" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -11,6 +13,13 @@ import ( "go.mau.fi/meowlnir/config" ) +func (pe *PolicyEvaluator) GetProtectedRooms() []id.RoomID { + pe.usersLock.RLock() + rooms := slices.Collect(maps.Keys(pe.protectedRooms)) + pe.usersLock.RUnlock() + return rooms +} + func (pe *PolicyEvaluator) IsProtectedRoom(roomID id.RoomID) bool { pe.usersLock.RLock() _, protected := pe.protectedRooms[roomID] @@ -23,23 +32,34 @@ func (pe *PolicyEvaluator) handleProtectedRooms(ctx context.Context, evt *event. if !ok { return []string{"* Failed to parse protected rooms event"} } + var outLock sync.Mutex + reevalMembers := make(map[id.UserID]struct{}) + var wg sync.WaitGroup for _, roomID := range content.Rooms { if pe.IsProtectedRoom(roomID) { continue } - members, err := pe.Client.Members(ctx, roomID) - if err != nil { - out = append(out, fmt.Sprintf("* Failed to get room members for [%s](%s): %v", roomID, roomID.URI().MatrixToURL(), err)) - continue - } - pe.markAsProtectedRoom(roomID, members.Chunk) - if !isInitial { - memberUserIDs := make([]id.UserID, len(members.Chunk)) - for i, member := range members.Chunk { - memberUserIDs[i] = id.UserID(member.GetStateKey()) + wg.Add(1) + go func() { + defer wg.Done() + members, err := pe.Client.Members(ctx, roomID) + outLock.Lock() + defer outLock.Unlock() + if err != nil { + out = append(out, fmt.Sprintf("* Failed to get room members for [%s](%s): %v", roomID, roomID.URI().MatrixToURL(), err)) + return } - pe.EvaluateAllMembers(ctx, memberUserIDs) - } + pe.markAsProtectedRoom(roomID, members.Chunk) + if !isInitial { + for _, member := range members.Chunk { + reevalMembers[id.UserID(member.GetStateKey())] = struct{}{} + } + } + }() + } + wg.Wait() + if len(reevalMembers) > 0 { + pe.EvaluateAllMembers(ctx, slices.Collect(maps.Keys(reevalMembers))) } return } diff --git a/policyeval/watchedlists.go b/policyeval/watchedlists.go index c0f2a66..1f088fe 100644 --- a/policyeval/watchedlists.go +++ b/policyeval/watchedlists.go @@ -3,6 +3,7 @@ package policyeval import ( "context" "fmt" + "sync" "go.mau.fi/util/exslices" "maunium.net/go/mautrix/event" @@ -38,22 +39,35 @@ func (pe *PolicyEvaluator) handleWatchedLists(ctx context.Context, evt *event.Ev } watchedList := make([]id.RoomID, 0, len(content.Lists)) watchedMap := make(map[id.RoomID]*config.WatchedPolicyList, len(content.Lists)) + var outLock sync.Mutex + var wg sync.WaitGroup for _, listInfo := range content.Lists { if _, alreadyWatched := watchedMap[listInfo.RoomID]; alreadyWatched { + outLock.Lock() out = append(out, fmt.Sprintf("* Duplicate watched list [%s](%s)", listInfo.Name, listInfo.RoomID.URI().MatrixToURL())) + outLock.Unlock() continue } - if !pe.Store.Contains(listInfo.RoomID) { - state, err := pe.Client.State(ctx, listInfo.RoomID) - if err != nil { - out = append(out, fmt.Sprintf("* Failed to get room state for [%s](%s): %v", listInfo.Name, listInfo.RoomID.URI().MatrixToURL(), err)) - continue + wg.Add(1) + go func() { + defer wg.Done() + if !pe.Store.Contains(listInfo.RoomID) { + state, err := pe.Client.State(ctx, listInfo.RoomID) + if err != nil { + outLock.Lock() + out = append(out, fmt.Sprintf("* Failed to get room state for [%s](%s): %v", listInfo.Name, listInfo.RoomID.URI().MatrixToURL(), err)) + outLock.Unlock() + return + } + pe.Store.Add(listInfo.RoomID, state) } - pe.Store.Add(listInfo.RoomID, state) - } - watchedMap[listInfo.RoomID] = &listInfo - watchedList = append(watchedList, listInfo.RoomID) + outLock.Lock() + watchedMap[listInfo.RoomID] = &listInfo + watchedList = append(watchedList, listInfo.RoomID) + outLock.Unlock() + }() } + wg.Wait() pe.watchedListsLock.Lock() oldWatchedList := pe.watchedListsList pe.watchedListsMap = watchedMap @@ -61,12 +75,14 @@ func (pe *PolicyEvaluator) handleWatchedLists(ctx context.Context, evt *event.Ev pe.watchedListsLock.Unlock() if !isInitial { unsubscribed, subscribed := exslices.Diff(oldWatchedList, watchedList) - if len(unsubscribed) > 0 { - // TODO re-evaluate banned users who were affected by the removed lists - } - if len(subscribed) > 0 { - // TODO re-evaluate joined users - } + go func(ctx context.Context) { + if len(unsubscribed) > 0 { + pe.ReevaluateAffectedByLists(ctx, unsubscribed) + } + if len(subscribed) > 0 || len(unsubscribed) > 0 { + pe.EvaluateAll(ctx) + } + }(context.WithoutCancel(ctx)) } return } diff --git a/policylist/policy.go b/policylist/policy.go index 0782012..92c6164 100644 --- a/policylist/policy.go +++ b/policylist/policy.go @@ -24,21 +24,16 @@ type Policy struct { type Match []*Policy type Recommendations struct { - Ban bool - Unban bool + BanOrUnban *Policy } // Recommendations aggregates the recommendations in the match. func (m Match) Recommendations() (output Recommendations) { for _, policy := range m { switch policy.Recommendation { - case event.PolicyRecommendationBan: - if !output.Unban { - output.Ban = true - } - case event.PolicyRecommendationUnban: - if !output.Ban { - output.Unban = true + case event.PolicyRecommendationBan, event.PolicyRecommendationUnban: + if output.BanOrUnban == nil { + output.BanOrUnban = policy } } } diff --git a/policylist/store.go b/policylist/store.go index 6182f7b..7e9e82f 100644 --- a/policylist/store.go +++ b/policylist/store.go @@ -1,9 +1,10 @@ package policylist import ( + "maps" + "slices" "sync" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -85,7 +86,7 @@ func (s *Store) Contains(roomID id.RoomID) bool { func (s *Store) match(listIDs []id.RoomID, entity string, listGetter func(*Room) *List) (output Match) { if listIDs == nil { s.roomsLock.Lock() - listIDs = maps.Keys(s.rooms) + listIDs = slices.Collect(maps.Keys(s.rooms)) s.roomsLock.Unlock() } for _, roomID := range listIDs { diff --git a/synapsedb/db.go b/synapsedb/db.go index d21573d..16e8ccd 100644 --- a/synapsedb/db.go +++ b/synapsedb/db.go @@ -3,8 +3,11 @@ package synapsedb import ( "context" + "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exslices" + "maunium.net/go/mautrix/id" ) type SynapseDB struct { @@ -39,6 +42,34 @@ func (s *SynapseDB) CheckVersion(ctx context.Context) error { return nil } +const getUnredactedEventsBySenderInRoomQuery = ` + SELECT events.room_id, events.event_id + FROM events + LEFT JOIN redactions ON events.event_id=redactions.redacts + WHERE events.sender = $1 AND events.room_id = ANY($2) AND redactions.redacts IS NULL +` + +type roomEventTuple struct { + RoomID id.RoomID + EventID id.EventID +} + +var scanRoomEventTuple = dbutil.ConvertRowFn[roomEventTuple](func(row dbutil.Scannable) (t roomEventTuple, err error) { + err = row.Scan(&t.RoomID, &t.EventID) + return +}) + +func (s *SynapseDB) GetEventsToRedact(ctx context.Context, sender id.UserID, inRooms []id.RoomID) (map[id.RoomID][]id.EventID, error) { + output := make(map[id.RoomID][]id.EventID) + err := scanRoomEventTuple.NewRowIter( + s.DB.Query(ctx, getUnredactedEventsBySenderInRoomQuery, sender, pq.Array(exslices.CastToString[string](inRooms))), + ).Iter(func(tuple roomEventTuple) (bool, error) { + output[tuple.RoomID] = append(output[tuple.RoomID], tuple.EventID) + return true, nil + }) + return output, err +} + func (s *SynapseDB) Close() error { return s.DB.Close() }