diff --git a/src/Mewdeko/Modules/Server Management/Services/ChannelCommandService.cs b/src/Mewdeko/Modules/Server Management/Services/ChannelCommandService.cs index e2bf3c19f..888cc9695 100644 --- a/src/Mewdeko/Modules/Server Management/Services/ChannelCommandService.cs +++ b/src/Mewdeko/Modules/Server Management/Services/ChannelCommandService.cs @@ -193,37 +193,33 @@ public async Task StoreOriginalPermissions(IGuild guild) await using var context = await dbContext.GetContextAsync(); var channels = await guild.GetChannelsAsync(); - foreach (var channel in channels) - { - if (!IsRelevantChannel(channel)) continue; - - var permissionOverrides = channel.PermissionOverwrites; + var existingPermissions = await context.LockdownChannelPermissions + .Where(p => p.GuildId == guild.Id) + .ToListAsync(); - foreach (var overwrite in permissionOverrides) + var newPermissions = (from channel in channels + where IsRelevantChannel(channel) + let permissionOverwrites = channel.PermissionOverwrites + from overwrite in permissionOverwrites + let existingEntry = + existingPermissions.FirstOrDefault(p => p.ChannelId == channel.Id && p.TargetId == overwrite.TargetId) + where existingEntry == null + select new LockdownChannelPermissions { - // Store overrides for both roles and users - var existingEntry = await context.LockdownChannelPermissions - .FirstOrDefaultAsync(p => - p.GuildId == guild.Id && p.ChannelId == channel.Id && p.TargetId == overwrite.TargetId); - - if (existingEntry != null) continue; - - // Add new entry for each permission override - var newPermission = new LockdownChannelPermissions - { - GuildId = guild.Id, - ChannelId = channel.Id, - TargetId = overwrite.TargetId, - TargetType = overwrite.TargetType, // Role or User - AllowPermissions = GetRawPermissionValue(overwrite.Permissions.ToAllowList()), - DenyPermissions = GetRawPermissionValue(overwrite.Permissions.ToDenyList()) - }; - - await context.LockdownChannelPermissions.AddAsync(newPermission); - } + GuildId = guild.Id, + ChannelId = channel.Id, + TargetId = overwrite.TargetId, + TargetType = overwrite.TargetType, // Role or User + AllowPermissions = GetRawPermissionValue(overwrite.Permissions.ToAllowList()), + DenyPermissions = GetRawPermissionValue(overwrite.Permissions.ToDenyList()) + }).ToList(); + + // Add all new permissions in one batch + if (newPermissions.Count != 0) + { + await context.LockdownChannelPermissions.AddRangeAsync(newPermissions); + await context.SaveChangesAsync(); } - - await context.SaveChangesAsync(); } /// @@ -231,7 +227,7 @@ public async Task StoreOriginalPermissions(IGuild guild) /// /// The guild whose permissions are being removed. /// A task that represents the asynchronous operation. - private async Task RemovePermissions(IGuild guild) + private static async Task RemovePermissions(IGuild guild) { var channels = await guild.GetChannelsAsync(); @@ -239,37 +235,12 @@ private async Task RemovePermissions(IGuild guild) { if (!IsRelevantChannel(channel)) continue; - var permissionOverrides = channel.PermissionOverwrites; + var permissionOverrides = channel.PermissionOverwrites.Where(x => x.TargetId == guild.EveryoneRole.Id); - foreach (var overwrite in permissionOverrides) + if (permissionOverrides.Any()) { - if (overwrite.TargetId == guild.EveryoneRole.Id) - continue; - - switch (overwrite.TargetType) - { - // Remove permission overrides for both roles and users - case PermissionTarget.Role: - { - var role = guild.GetRole(overwrite.TargetId); - if (role != null) - { - await channel.RemovePermissionOverwriteAsync(role).ConfigureAwait(false); - } - - break; - } - case PermissionTarget.User: - { - var user = await guild.GetUserAsync(overwrite.TargetId); - if (user != null) - { - await channel.RemovePermissionOverwriteAsync(user).ConfigureAwait(false); - } - - break; - } - } + await channel.ModifyAsync(x => + x.PermissionOverwrites = new Optional>(permissionOverrides)); } } } @@ -283,97 +254,158 @@ private async Task RemovePermissions(IGuild guild) /// A task that represents the asynchronous operation. public async Task ApplyLockdown(IGuild guild) { - await StoreOriginalPermissions(guild); // Store all permissions first - await RemovePermissions(guild); // Remove all permissions from the channels, including @everyone + await StoreOriginalPermissions(guild); + await RemovePermissions(guild); var everyoneRole = guild.EveryoneRole; var channels = await guild.GetChannelsAsync(); await using var context = await dbContext.GetContextAsync(); - foreach (var channel in channels) - { - if (!IsRelevantChannel(channel)) continue; + var relevantChannels = channels.Where(IsRelevantChannel).ToList(); + var channelPermissions = new List<(IGuildChannel Channel, OverwritePermissions Permissions)>(); - // Retrieve the stored permissions for the @everyone role from the database + foreach (var channel in relevantChannels) + { var storedPerm = await context.LockdownChannelPermissions.FirstOrDefaultAsync(p => p.GuildId == guild.Id && p.ChannelId == channel.Id && p.TargetId == everyoneRole.Id && p.TargetType == PermissionTarget.Role); - var existingPerms = - // Reconstruct OverwritePermissions from stored permissions - storedPerm != null - ? new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions) - : - // No stored permissions; default to InheritAll - OverwritePermissions.InheritAll; + var existingPerms = storedPerm != null + ? new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions) + : OverwritePermissions.InheritAll; var lockdownPerms = channel switch { - // Modify permissions based on channel type - IVoiceChannel => existingPerms.Modify(connect: PermValue.Deny, speak: PermValue.Deny), - IForumChannel => existingPerms.Modify(sendMessages: PermValue.Deny, createPublicThreads: PermValue.Deny, - createPrivateThreads: PermValue.Deny), + IVoiceChannel => existingPerms.Modify(connect: PermValue.Deny, speak: PermValue.Deny, sendMessages: PermValue.Deny, sendMessagesInThreads: PermValue.Deny), + IForumChannel => existingPerms.Modify(sendMessagesInThreads: PermValue.Deny, + createPublicThreads: PermValue.Deny, + createPrivateThreads: PermValue.Deny, + sendMessages: PermValue.Allow), _ => existingPerms.Modify(sendMessages: PermValue.Deny, createPublicThreads: PermValue.Deny, createPrivateThreads: PermValue.Deny) }; - // Apply the modified permissions to the @everyone role - await channel.AddPermissionOverwriteAsync(everyoneRole, lockdownPerms).ConfigureAwait(false); + channelPermissions.Add((channel, lockdownPerms)); + } + + var groupedChannels = channelPermissions.GroupBy(x => x.Channel.GetType()); + + foreach (var group in groupedChannels) + { + if (group.Key == typeof(SocketTextChannel)) + { + var textChannels = group.Select(x => x.Channel).Cast().ToList(); + await ModifyTextChannelsAsync(textChannels, everyoneRole, group.Select(x => x.Permissions)); + } + else if (group.Key == typeof(SocketVoiceChannel)) + { + var voiceChannels = group.Select(x => x.Channel).Cast().ToList(); + await ModifyVoiceChannelsAsync(voiceChannels, everyoneRole, group.Select(x => x.Permissions)); + } + else if (group.Key == typeof(SocketForumChannel)) + { + var forumChannels = group.Select(x => x.Channel).Cast().ToList(); + await ModifyForumChannelsAsync(forumChannels, everyoneRole, group.Select(x => x.Permissions)); + } } } + private static async Task ModifyTextChannelsAsync(List channels, IRole everyoneRole, + IEnumerable permissions) + { + await Task.WhenAll(channels.Select((channel, index) => + channel.ModifyAsync(x => + { + x.PermissionOverwrites = new Optional>( + [ + new Overwrite(everyoneRole.Id, PermissionTarget.Role, permissions.ElementAt(index)) + ] + ); + }) + )); + } + + private static async Task ModifyForumChannelsAsync(List channels, IRole everyoneRole, + IEnumerable permissions) + { + await Task.WhenAll(channels.Select((channel, index) => + channel.ModifyAsync(x => + { + x.PermissionOverwrites = new Optional>( + [ + new Overwrite(everyoneRole.Id, PermissionTarget.Role, permissions.ElementAt(index)) + ] + ); + }) + )); + } + + private static async Task ModifyVoiceChannelsAsync(List channels, IRole everyoneRole, + IEnumerable permissions) + { + await Task.WhenAll(channels.Select((channel, index) => + channel.ModifyAsync(x => + { + x.PermissionOverwrites = new Optional>( + [ + new Overwrite(everyoneRole.Id, PermissionTarget.Role, permissions.ElementAt(index)) + ] + ); + }) + )); + } + /// /// Restores the original permissions for all roles and users in each relevant channel after the lockdown is lifted. /// /// The guild where the lockdown is being lifted. + /// /// A task that represents the asynchronous operation. public async Task RestoreOriginalPermissions(IGuild guild) { await using var context = await dbContext.GetContextAsync(); var channels = await guild.GetChannelsAsync(); - foreach (var channel in channels) + var relevantChannels = channels.Where(IsRelevantChannel).ToList(); + + var guildRoleIds = guild.Roles.Select(r => r.Id).ToHashSet(); + + var guildUserIds = (await guild.GetUsersAsync()).Select(u => u.Id).ToHashSet(); + + var storedPermissionsByChannel = await context.LockdownChannelPermissions + .Where(p => p.GuildId == guild.Id && relevantChannels.Select(c => c.Id).Contains(p.ChannelId)) + .GroupBy(p => p.ChannelId) + .ToDictionaryAsync(g => g.Key, g => g.ToList()); + + foreach (var channel in relevantChannels) { - if (!IsRelevantChannel(channel)) continue; + if (!storedPermissionsByChannel.TryGetValue(channel.Id, out var storedPermissions)) + continue; - var storedPermissions = await context.LockdownChannelPermissions - .Where(p => p.GuildId == guild.Id && p.ChannelId == channel.Id) - .ToListAsync(); + var overwrites = new List(); foreach (var storedPerm in storedPermissions) { + var permissions = new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions); + switch (storedPerm.TargetType) { - case PermissionTarget.Role: - { - var role = guild.GetRole(storedPerm.TargetId); - if (role != null) - { - var permissions = - new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions); - await channel.AddPermissionOverwriteAsync(role, permissions).ConfigureAwait(false); - } - + case PermissionTarget.Role when guildRoleIds.Contains(storedPerm.TargetId): + overwrites.Add(new Overwrite(storedPerm.TargetId, PermissionTarget.Role, permissions)); break; - } - case PermissionTarget.User: - { - var user = await guild.GetUserAsync(storedPerm.TargetId); - if (user != null) - { - var permissions = - new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions); - await channel.AddPermissionOverwriteAsync(user, permissions).ConfigureAwait(false); - } + case PermissionTarget.User when guildUserIds.Contains(storedPerm.TargetId): + overwrites.Add(new Overwrite(storedPerm.TargetId, PermissionTarget.User, permissions)); break; - } + default: + throw new ArgumentOutOfRangeException(); } } - // Remove the restored permissions from the database + await channel.ModifyAsync(x => x.PermissionOverwrites = new Optional>(overwrites)); + context.LockdownChannelPermissions.RemoveRange(storedPermissions); }