Skip to content

Commit

Permalink
make lockdown like 5x faster using bulk ops lmao
Browse files Browse the repository at this point in the history
  • Loading branch information
SylveonDeko committed Oct 4, 2024
1 parent 8fc625f commit 56564d1
Showing 1 changed file with 136 additions and 104 deletions.
240 changes: 136 additions & 104 deletions src/Mewdeko/Modules/Server Management/Services/ChannelCommandService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,83 +193,54 @@ public async Task StoreOriginalPermissions(IGuild guild)
await using var context = await dbContext.GetContextAsync();
var channels = await guild.GetChannelsAsync();

foreach (var channel in channels)
{
if (!IsRelevantChannel(channel)) continue;

var permissionOverrides = channel.PermissionOverwrites;
var existingPermissions = await context.LockdownChannelPermissions
.Where(p => p.GuildId == guild.Id)
.ToListAsync();

foreach (var overwrite in permissionOverrides)
var newPermissions = (from channel in channels
where IsRelevantChannel(channel)
let permissionOverwrites = channel.PermissionOverwrites
from overwrite in permissionOverwrites
let existingEntry =
existingPermissions.FirstOrDefault(p => p.ChannelId == channel.Id && p.TargetId == overwrite.TargetId)
where existingEntry == null
select new LockdownChannelPermissions
{
// Store overrides for both roles and users
var existingEntry = await context.LockdownChannelPermissions
.FirstOrDefaultAsync(p =>
p.GuildId == guild.Id && p.ChannelId == channel.Id && p.TargetId == overwrite.TargetId);

if (existingEntry != null) continue;

// Add new entry for each permission override
var newPermission = new LockdownChannelPermissions
{
GuildId = guild.Id,
ChannelId = channel.Id,
TargetId = overwrite.TargetId,
TargetType = overwrite.TargetType, // Role or User
AllowPermissions = GetRawPermissionValue(overwrite.Permissions.ToAllowList()),
DenyPermissions = GetRawPermissionValue(overwrite.Permissions.ToDenyList())
};

await context.LockdownChannelPermissions.AddAsync(newPermission);
}
GuildId = guild.Id,
ChannelId = channel.Id,
TargetId = overwrite.TargetId,
TargetType = overwrite.TargetType, // Role or User
AllowPermissions = GetRawPermissionValue(overwrite.Permissions.ToAllowList()),
DenyPermissions = GetRawPermissionValue(overwrite.Permissions.ToDenyList())
}).ToList();

// Add all new permissions in one batch
if (newPermissions.Count != 0)
{
await context.LockdownChannelPermissions.AddRangeAsync(newPermissions);
await context.SaveChangesAsync();
}

await context.SaveChangesAsync();
}

/// <summary>
/// Removes all permission overrides from all channels in the guild.
/// </summary>
/// <param name="guild">The guild whose permissions are being removed.</param>
/// <returns>A task that represents the asynchronous operation.</returns>
private async Task RemovePermissions(IGuild guild)
private static async Task RemovePermissions(IGuild guild)
{
var channels = await guild.GetChannelsAsync();

foreach (var channel in channels)
{
if (!IsRelevantChannel(channel)) continue;

var permissionOverrides = channel.PermissionOverwrites;
var permissionOverrides = channel.PermissionOverwrites.Where(x => x.TargetId == guild.EveryoneRole.Id);

foreach (var overwrite in permissionOverrides)
if (permissionOverrides.Any())
{
if (overwrite.TargetId == guild.EveryoneRole.Id)
continue;

switch (overwrite.TargetType)
{
// Remove permission overrides for both roles and users
case PermissionTarget.Role:
{
var role = guild.GetRole(overwrite.TargetId);
if (role != null)
{
await channel.RemovePermissionOverwriteAsync(role).ConfigureAwait(false);
}

break;
}
case PermissionTarget.User:
{
var user = await guild.GetUserAsync(overwrite.TargetId);
if (user != null)
{
await channel.RemovePermissionOverwriteAsync(user).ConfigureAwait(false);
}

break;
}
}
await channel.ModifyAsync(x =>
x.PermissionOverwrites = new Optional<IEnumerable<Overwrite>>(permissionOverrides));
}
}
}
Expand All @@ -283,97 +254,158 @@ private async Task RemovePermissions(IGuild guild)
/// <returns>A task that represents the asynchronous operation.</returns>
public async Task ApplyLockdown(IGuild guild)
{
await StoreOriginalPermissions(guild); // Store all permissions first
await RemovePermissions(guild); // Remove all permissions from the channels, including @everyone
await StoreOriginalPermissions(guild);
await RemovePermissions(guild);

var everyoneRole = guild.EveryoneRole;
var channels = await guild.GetChannelsAsync();

await using var context = await dbContext.GetContextAsync();

foreach (var channel in channels)
{
if (!IsRelevantChannel(channel)) continue;
var relevantChannels = channels.Where(IsRelevantChannel).ToList();
var channelPermissions = new List<(IGuildChannel Channel, OverwritePermissions Permissions)>();

// Retrieve the stored permissions for the @everyone role from the database
foreach (var channel in relevantChannels)
{
var storedPerm = await context.LockdownChannelPermissions.FirstOrDefaultAsync(p =>
p.GuildId == guild.Id && p.ChannelId == channel.Id && p.TargetId == everyoneRole.Id &&
p.TargetType == PermissionTarget.Role);

var existingPerms =
// Reconstruct OverwritePermissions from stored permissions
storedPerm != null
? new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions)
:
// No stored permissions; default to InheritAll
OverwritePermissions.InheritAll;
var existingPerms = storedPerm != null
? new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions)
: OverwritePermissions.InheritAll;

var lockdownPerms = channel switch
{
// Modify permissions based on channel type
IVoiceChannel => existingPerms.Modify(connect: PermValue.Deny, speak: PermValue.Deny),
IForumChannel => existingPerms.Modify(sendMessages: PermValue.Deny, createPublicThreads: PermValue.Deny,
createPrivateThreads: PermValue.Deny),
IVoiceChannel => existingPerms.Modify(connect: PermValue.Deny, speak: PermValue.Deny, sendMessages: PermValue.Deny, sendMessagesInThreads: PermValue.Deny),
IForumChannel => existingPerms.Modify(sendMessagesInThreads: PermValue.Deny,
createPublicThreads: PermValue.Deny,
createPrivateThreads: PermValue.Deny,
sendMessages: PermValue.Allow),
_ => existingPerms.Modify(sendMessages: PermValue.Deny, createPublicThreads: PermValue.Deny,
createPrivateThreads: PermValue.Deny)
};

// Apply the modified permissions to the @everyone role
await channel.AddPermissionOverwriteAsync(everyoneRole, lockdownPerms).ConfigureAwait(false);
channelPermissions.Add((channel, lockdownPerms));
}

var groupedChannels = channelPermissions.GroupBy(x => x.Channel.GetType());

foreach (var group in groupedChannels)
{
if (group.Key == typeof(SocketTextChannel))
{
var textChannels = group.Select(x => x.Channel).Cast<ITextChannel>().ToList();
await ModifyTextChannelsAsync(textChannels, everyoneRole, group.Select(x => x.Permissions));
}
else if (group.Key == typeof(SocketVoiceChannel))
{
var voiceChannels = group.Select(x => x.Channel).Cast<IVoiceChannel>().ToList();
await ModifyVoiceChannelsAsync(voiceChannels, everyoneRole, group.Select(x => x.Permissions));
}
else if (group.Key == typeof(SocketForumChannel))
{
var forumChannels = group.Select(x => x.Channel).Cast<IForumChannel>().ToList();
await ModifyForumChannelsAsync(forumChannels, everyoneRole, group.Select(x => x.Permissions));
}
}
}

private static async Task ModifyTextChannelsAsync(List<ITextChannel> channels, IRole everyoneRole,
IEnumerable<OverwritePermissions> permissions)
{
await Task.WhenAll(channels.Select((channel, index) =>
channel.ModifyAsync(x =>
{
x.PermissionOverwrites = new Optional<IEnumerable<Overwrite>>(
[
new Overwrite(everyoneRole.Id, PermissionTarget.Role, permissions.ElementAt(index))
]
);
})
));
}

private static async Task ModifyForumChannelsAsync(List<IForumChannel> channels, IRole everyoneRole,
IEnumerable<OverwritePermissions> permissions)
{
await Task.WhenAll(channels.Select((channel, index) =>
channel.ModifyAsync(x =>
{
x.PermissionOverwrites = new Optional<IEnumerable<Overwrite>>(
[
new Overwrite(everyoneRole.Id, PermissionTarget.Role, permissions.ElementAt(index))
]
);
})
));
}

private static async Task ModifyVoiceChannelsAsync(List<IVoiceChannel> channels, IRole everyoneRole,
IEnumerable<OverwritePermissions> permissions)
{
await Task.WhenAll(channels.Select((channel, index) =>
channel.ModifyAsync(x =>
{
x.PermissionOverwrites = new Optional<IEnumerable<Overwrite>>(
[
new Overwrite(everyoneRole.Id, PermissionTarget.Role, permissions.ElementAt(index))
]
);
})
));
}


/// <summary>
/// Restores the original permissions for all roles and users in each relevant channel after the lockdown is lifted.
/// </summary>
/// <param name="guild">The guild where the lockdown is being lifted.</param>
/// <exception cref="ArgumentOutOfRangeException"></exception>
/// <returns>A task that represents the asynchronous operation.</returns>
public async Task RestoreOriginalPermissions(IGuild guild)
{
await using var context = await dbContext.GetContextAsync();
var channels = await guild.GetChannelsAsync();

foreach (var channel in channels)
var relevantChannels = channels.Where(IsRelevantChannel).ToList();

var guildRoleIds = guild.Roles.Select(r => r.Id).ToHashSet();

var guildUserIds = (await guild.GetUsersAsync()).Select(u => u.Id).ToHashSet();

var storedPermissionsByChannel = await context.LockdownChannelPermissions
.Where(p => p.GuildId == guild.Id && relevantChannels.Select(c => c.Id).Contains(p.ChannelId))
.GroupBy(p => p.ChannelId)
.ToDictionaryAsync(g => g.Key, g => g.ToList());

foreach (var channel in relevantChannels)
{
if (!IsRelevantChannel(channel)) continue;
if (!storedPermissionsByChannel.TryGetValue(channel.Id, out var storedPermissions))
continue;

var storedPermissions = await context.LockdownChannelPermissions
.Where(p => p.GuildId == guild.Id && p.ChannelId == channel.Id)
.ToListAsync();
var overwrites = new List<Overwrite>();

foreach (var storedPerm in storedPermissions)
{
var permissions = new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions);

switch (storedPerm.TargetType)
{
case PermissionTarget.Role:
{
var role = guild.GetRole(storedPerm.TargetId);
if (role != null)
{
var permissions =
new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions);
await channel.AddPermissionOverwriteAsync(role, permissions).ConfigureAwait(false);
}

case PermissionTarget.Role when guildRoleIds.Contains(storedPerm.TargetId):
overwrites.Add(new Overwrite(storedPerm.TargetId, PermissionTarget.Role, permissions));
break;
}
case PermissionTarget.User:
{
var user = await guild.GetUserAsync(storedPerm.TargetId);
if (user != null)
{
var permissions =
new OverwritePermissions(storedPerm.AllowPermissions, storedPerm.DenyPermissions);
await channel.AddPermissionOverwriteAsync(user, permissions).ConfigureAwait(false);
}

case PermissionTarget.User when guildUserIds.Contains(storedPerm.TargetId):
overwrites.Add(new Overwrite(storedPerm.TargetId, PermissionTarget.User, permissions));
break;
}
default:
throw new ArgumentOutOfRangeException();
}
}

// Remove the restored permissions from the database
await channel.ModifyAsync(x => x.PermissionOverwrites = new Optional<IEnumerable<Overwrite>>(overwrites));

context.LockdownChannelPermissions.RemoveRange(storedPermissions);
}

Expand Down

0 comments on commit 56564d1

Please sign in to comment.