Skip to content

Commit

Permalink
fix gpt support, uses the actual api now instead of other libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
SylveonDeko committed Aug 9, 2024
1 parent dd6c008 commit 75122ee
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 86 deletions.
1 change: 0 additions & 1 deletion src/Mewdeko/Mewdeko.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
<PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL" Version="8.0.2" />
<PackageReference Include="NsfwSpy" Version="3.5.0"/>
<PackageReference Include="Octokit" Version="12.0.0" />
<PackageReference Include="OpenAI" Version="1.11.0" />
<PackageReference Include="Otp.NET" Version="1.4.0" />
<PackageReference Include="PokeApiNet" Version="4.0.0"/>
<PackageReference Include="QRCoder" Version="1.5.1" />
Expand Down
6 changes: 3 additions & 3 deletions src/Mewdeko/Modules/Highlights/Services/HighlightsService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ public class HighlightsService : INService, IReadyExecutor
/// <param name="client">The discord client</param>
/// <param name="cache">Fusion cache</param>
/// <param name="db">The database provider</param>
public HighlightsService(DiscordShardedClient client, IFusionCache cache, DbContextProvider dbProvider)
public HighlightsService(DiscordShardedClient client, IFusionCache cache, DbContextProvider dbProvider, EventHandler eventHandler)
{
this.client = client;
this.cache = cache;
this.dbProvider = dbProvider;
this.client.MessageReceived += StaggerHighlights;
this.client.UserIsTyping += AddHighlightTimer;
eventHandler.MessageReceived += StaggerHighlights;
eventHandler.UserIsTyping += AddHighlightTimer;
_ = HighlightLoop();
}

Expand Down
254 changes: 172 additions & 82 deletions src/Mewdeko/Modules/OwnerOnly/Services/OwnerOnlyService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Net.Http;
using System.Text;
using System.Threading;
Expand All @@ -12,9 +13,6 @@
using Microsoft.EntityFrameworkCore;
using Newtonsoft.Json;
using Octokit;
using OpenAI_API;
using OpenAI_API.Chat;
using OpenAI_API.Models;
using Serilog;
using StackExchange.Redis;
using Embed = Discord.Embed;
Expand All @@ -40,7 +38,7 @@ public class OwnerOnlyService : ILateExecutor, IReadyExecutor, INService
private readonly Replacer rep;
private readonly IBotStrings strings;
private readonly GuildSettingsService guildSettings;
private readonly ConcurrentDictionary<ulong, Conversation> conversations = new();
private static readonly Dictionary<ulong, Conversation> UserConversations = new Dictionary<ulong, Conversation>();

#pragma warning disable CS8714
private ConcurrentDictionary<ulong?, ConcurrentDictionary<int, Timer>> autoCommands =
Expand Down Expand Up @@ -88,16 +86,16 @@ public OwnerOnlyService(DiscordShardedClient client, CommandHandler cmdHandler,
httpFactory = factory;
this.bss = bss;
handler.MessageReceived += OnMessageReceived;
rep = new ReplacementBuilder()
.WithClient(client)
.WithProviders(phProviders)
.Build();
rep = new ReplacementBuilder()
.WithClient(client)
.WithProviders(phProviders)
.Build();

_ = Task.Run(RotatingStatuses);
_ = Task.Run(RotatingStatuses);

var sub = redis.GetSubscriber();
sub.Subscribe($"{this.creds.RedisKey()}_reload_images",
delegate { imgs.Reload(); }, CommandFlags.FireAndForget);
sub.Subscribe($"{this.creds.RedisKey()}_reload_images",
delegate { imgs.Reload(); }, CommandFlags.FireAndForget);

sub.Subscribe($"{this.creds.RedisKey()}_leave_guild", async (_, v) =>
{
Expand Down Expand Up @@ -315,103 +313,142 @@ private async Task OnMessageReceived(SocketMessage args)
return;
if (args is not IUserMessage usrMsg)
return;
try
// try
// {
if (args.Content is "deletesession")
{
var api = new OpenAI_API.OpenAIAPI(bss.Data.ChatGptKey);
if (args.Content is "deletesession")
if (UserConversations.TryGetValue(args.Author.Id, out _))
{
if (conversations.TryRemove(args.Author.Id, out _))
{
await usrMsg.SendConfirmReplyAsync("Session deleted");
return;
}

await usrMsg.SendConfirmReplyAsync("No session to delete");
ClearConversation(args.Author.Id);
await args.Channel.SendConfirmAsync("Conversation deleted.");
return;
}

await args.Channel.SendErrorAsync("You dont have a conversation saved.", bss.Data);
return;
}

await using var dbContext = await dbProvider.GetContextAsync();

(Database.Models.OwnerOnly actualItem, bool added) toUpdate = dbContext.OwnerOnly.Any()
? (await dbContext.OwnerOnly.FirstOrDefaultAsync(), false)
: (new Database.Models.OwnerOnly
{
GptTokensUsed = 0
}, true);
await using var dbContext = await dbProvider.GetContextAsync();

if (!conversations.TryGetValue(args.Author.Id, out var conversation))
(Database.Models.OwnerOnly actualItem, bool added) toUpdate = dbContext.OwnerOnly.Any()
? (await dbContext.OwnerOnly.FirstOrDefaultAsync(), false)
: (new Database.Models.OwnerOnly
{
conversation = StartNewConversation(args.Author, api);
conversations.TryAdd(args.Author.Id, conversation);
}

conversation.AppendUserInput(args.Content);
GptTokensUsed = 0
}, true);

var loadingMsg = await usrMsg.Channel.SendConfirmAsync($"{bss.Data.LoadingEmote} Awaiting response...");
await StreamResponseAndUpdateEmbedAsync(bss.Data.ChatGptKey, bss.Data.ChatGptModel,
bss.Data.ChatGptInitPrompt +
$"The users name is {args.Author}, you are in the discord server {guildChannel.Guild} and in the channel {guildChannel} and there are {(await guildChannel.GetUsersAsync().FlattenAsync()).Count()} users that can see this channel.",
loadingMsg, toUpdate, args.Author, args.Content);
//}
// catch (Exception e)
// {
// Log.Warning(e, "Error in ChatGPT");
// await usrMsg.SendErrorReplyAsync("Something went wrong, please try again later.");
// }
}

var loadingMsg = await usrMsg.Channel.SendConfirmAsync($"{bss.Data.LoadingEmote} Awaiting response...");
await StreamResponseAndUpdateEmbedAsync(conversation, loadingMsg, dbProvider, toUpdate, args.Author);
}
catch (Exception e)
{
Log.Warning(e, "Error in ChatGPT");
await usrMsg.SendErrorReplyAsync("Something went wrong, please try again later.");
}
private static void ClearConversation(ulong userId)
{
UserConversations.Remove(userId);
}

private Conversation StartNewConversation(SocketUser user, IOpenAIAPI api)
private async Task StreamResponseAndUpdateEmbedAsync(string apiKey, string model, string systemPrompt,
IUserMessage loadingMsg,
(Database.Models.OwnerOnly actualItem, bool added) toUpdate, SocketUser author, string userPrompt)
{
var modelToUse = bss.Data.ChatGptModel switch
using var httpClient = httpFactory.CreateClient();
httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {apiKey}");

// Get or create conversation for this user
if (!UserConversations.TryGetValue(author.Id, out var conversation))
{
"gpt-4-0613" => Model.GPT4_32k_Context,
"gpt4" or "gpt-4" => Model.GPT4,
_ => Model.ChatGPTTurbo
};
conversation = new Conversation();
conversation.Messages.Add(new Message
{
Role = "system", Content = systemPrompt
});
UserConversations[author.Id] = conversation;
}

var chat = api.Chat.CreateConversation(new ChatRequest
// Add user message to conversation
conversation.Messages.Add(new Message
{
MaxTokens = bss.Data.ChatGptMaxTokens, Temperature = bss.Data.ChatGptTemperature, Model = modelToUse
Role = "user", Content = userPrompt
});
chat.AppendSystemMessage(bss.Data.ChatGptInitPrompt);
chat.AppendSystemMessage($"The user's name is {user}.");
return chat;
}

private static async Task StreamResponseAndUpdateEmbedAsync(Conversation conversation, IUserMessage loadingMsg,
DbContextProvider dbProvider, (Database.Models.OwnerOnly actualItem, bool added) toUpdate, SocketUser author)
{
await using var dbContext = await dbProvider.GetContextAsync();
var requestBody = new
{
model,
messages = conversation.Messages.Select(m => new
{
role = m.Role, content = m.Content
}).ToArray(),
stream = true,
user = author.Id.ToString()
};

var content = new StringContent(System.Text.Json.JsonSerializer.Serialize(requestBody), Encoding.UTF8,
"application/json");

using var response = await httpClient.PostAsync("https://api.openai.com/v1/chat/completions", content);
response.EnsureSuccessStatusCode();

await using var stream = await response.Content.ReadAsStreamAsync();
using var reader = new StreamReader(stream);

var responseBuilder = new StringBuilder();
var lastUpdate = DateTimeOffset.UtcNow;
var totalTokens = 0;

await conversation.StreamResponseFromChatbotAsync(async partialResponse =>
while (!reader.EndOfStream)
{
responseBuilder.Append(partialResponse);
if (!((DateTimeOffset.UtcNow - lastUpdate).TotalSeconds >= 1)) return;
var line = await reader.ReadLineAsync();
if (string.IsNullOrEmpty(line) || line == "data: [DONE]") continue;

if (!line.StartsWith("data: ")) continue;
var json = line[6..];
var chatResponse = JsonConvert.DeserializeObject<ChatCompletionChunkResponse>(json);

if (chatResponse?.Choices is not { Count: > 0 }) continue;
var conversationContent = chatResponse.Choices[0].Delta?.Content;
if (string.IsNullOrEmpty(conversationContent)) continue;
responseBuilder.Append(conversationContent);
if (!((DateTimeOffset.UtcNow - lastUpdate).TotalSeconds >= 1)) continue;
lastUpdate = DateTimeOffset.UtcNow;
var embeds = BuildEmbeds(responseBuilder.ToString(), author, toUpdate.actualItem.GptTokensUsed,
conversation);
var embeds = BuildEmbeds(responseBuilder.ToString(), author,
toUpdate.actualItem.GptTokensUsed + totalTokens);
await loadingMsg.ModifyAsync(m => m.Embeds = embeds.ToArray());
}

// Add assistant's response to the conversation
conversation.Messages.Add(new Message
{
Role = "assistant", Content = responseBuilder.ToString()
});

var finalResponse = responseBuilder.ToString();
if (conversation.MostRecentApiResult.Usage != null)
// Trim conversation history if it gets too long
if (conversation.Messages.Count > 10)
{
toUpdate.actualItem.GptTokensUsed += conversation.MostRecentApiResult.Usage.TotalTokens;
conversation.Messages = conversation.Messages.Skip(conversation.Messages.Count - 10).ToList();
}

await using var dbContext = await dbProvider.GetContextAsync();
toUpdate.actualItem.GptTokensUsed += totalTokens;

if (toUpdate.added)
dbContext.OwnerOnly.Add(toUpdate.actualItem);
else
dbContext.OwnerOnly.Update(toUpdate.actualItem);
await dbContext.SaveChangesAsync();

var finalEmbeds = BuildEmbeds(finalResponse, author, toUpdate.actualItem.GptTokensUsed, conversation);
var finalEmbeds = BuildEmbeds(responseBuilder.ToString(), author, toUpdate.actualItem.GptTokensUsed);
await loadingMsg.ModifyAsync(m => m.Embeds = finalEmbeds.ToArray());
}

private static List<Embed> BuildEmbeds(string response, IUser requester, int totalTokensUsed,
Conversation conversation)
private static List<Embed> BuildEmbeds(string response, IUser requester, int totalTokensUsed)
{
var embeds = new List<Embed>();
var partIndex = 0;
Expand All @@ -429,7 +466,7 @@ private static List<Embed> BuildEmbeds(string response, IUser requester, int tot

if (partIndex + length == response.Length)
embedBuilder.WithFooter(
$"Requested by {requester.Username} | Response Tokens: {conversation.MostRecentApiResult.Usage?.TotalTokens} | Total Used: {totalTokensUsed}");
$"Requested by {requester.Username} | Total Tokens Used: {totalTokensUsed}");

embeds.Add(embedBuilder.Build());
partIndex += length;
Expand Down Expand Up @@ -531,13 +568,13 @@ public async Task OnReadyAsync()
(await dbContext.AutoCommands
.AsNoTracking()
.ToListAsyncEF())
.Where(x => x.Interval >= 5)
.AsEnumerable()
.GroupBy(x => x.GuildId)
.ToDictionary(x => x.Key,
y => y.ToDictionary(x => x.Id, TimerFromAutoCommand)
.ToConcurrent())
.ToConcurrent();
.Where(x => x.Interval >= 5)
.AsEnumerable()
.GroupBy(x => x.GuildId)
.ToDictionary(x => x.Key,
y => y.ToDictionary(x => x.Id, TimerFromAutoCommand)
.ToConcurrent())
.ToConcurrent();

foreach (var cmd in dbContext.AutoCommands.AsNoTracking().Where(x => x.Interval == 0))
{
Expand Down Expand Up @@ -584,7 +621,8 @@ private async Task RotatingStatuses()

if (!bss.Data.RotateStatuses) continue;

IReadOnlyList<RotatingPlayingStatus> rotatingStatuses = await dbContext.RotatingStatus.AsNoTracking().OrderBy(x => x.Id).ToListAsyncEF();
IReadOnlyList<RotatingPlayingStatus> rotatingStatuses =
await dbContext.RotatingStatus.AsNoTracking().OrderBy(x => x.Id).ToListAsyncEF();

if (rotatingStatuses.Count == 0)
continue;
Expand Down Expand Up @@ -707,8 +745,8 @@ public async Task AddNewAutoCommand(AutoCommand cmd)
{
await using var dbContext = await dbProvider.GetContextAsync();

dbContext.AutoCommands.Add(cmd);
await dbContext.SaveChangesAsync();
dbContext.AutoCommands.Add(cmd);
await dbContext.SaveChangesAsync();

if (cmd.Interval >= 5)
{
Expand Down Expand Up @@ -813,7 +851,6 @@ public async Task<bool> RemoveStartupCommand(int index)
dbContext.Remove(cmd);
await dbContext.SaveChangesAsync();
return true;

}

/// <summary>
Expand Down Expand Up @@ -955,4 +992,57 @@ public bool ForwardToAll()
bss.ModifyConfig(config => isToAll = config.ForwardToAllOwners = !config.ForwardToAllOwners);
return isToAll;
}

private class Choice
{
[JsonProperty("index", NullValueHandling = NullValueHandling.Ignore)]
public int? Index;

[JsonProperty("delta", NullValueHandling = NullValueHandling.Ignore)]
public Delta Delta;

[JsonProperty("logprobs", NullValueHandling = NullValueHandling.Ignore)]
public object Logprobs;

[JsonProperty("finish_reason", NullValueHandling = NullValueHandling.Ignore)]
public object FinishReason;
}

private class Delta
{
[JsonProperty("content", NullValueHandling = NullValueHandling.Ignore)]
public string Content;
}

private class ChatCompletionChunkResponse
{
[JsonProperty("id", NullValueHandling = NullValueHandling.Ignore)]
public string Id;

[JsonProperty("object", NullValueHandling = NullValueHandling.Ignore)]
public string Object;

[JsonProperty("created", NullValueHandling = NullValueHandling.Ignore)]
public int? Created;

[JsonProperty("model", NullValueHandling = NullValueHandling.Ignore)]
public string Model;

[JsonProperty("system_fingerprint", NullValueHandling = NullValueHandling.Ignore)]
public string SystemFingerprint;

[JsonProperty("choices", NullValueHandling = NullValueHandling.Ignore)]
public List<Choice> Choices;
}

private class Conversation
{
public List<Message> Messages { get; set; } = new List<Message>();
}

private class Message
{
public string Role { get; set; }
public string Content { get; set; }
}
}

0 comments on commit 75122ee

Please sign in to comment.