Skip to content

Commit

Permalink
fix used token count
Browse files Browse the repository at this point in the history
  • Loading branch information
SylveonDeko committed Aug 9, 2024
1 parent 75122ee commit 7c82768
Showing 1 changed file with 58 additions and 36 deletions.
94 changes: 58 additions & 36 deletions src/Mewdeko/Modules/OwnerOnly/Services/OwnerOnlyService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class OwnerOnlyService : ILateExecutor, IReadyExecutor, INService
private readonly Replacer rep;
private readonly IBotStrings strings;
private readonly GuildSettingsService guildSettings;
private static readonly Dictionary<ulong, Conversation> UserConversations = new Dictionary<ulong, Conversation>();
private static readonly Dictionary<ulong, Conversation> UserConversations = new();

#pragma warning disable CS8714
private ConcurrentDictionary<ulong?, ConcurrentDictionary<int, Timer>> autoCommands =
Expand Down Expand Up @@ -313,41 +313,41 @@ private async Task OnMessageReceived(SocketMessage args)
return;
if (args is not IUserMessage usrMsg)
return;
// try
// {
if (args.Content is "deletesession")
try
{
if (UserConversations.TryGetValue(args.Author.Id, out _))
if (args.Content is "deletesession")
{
ClearConversation(args.Author.Id);
await args.Channel.SendConfirmAsync("Conversation deleted.");
if (UserConversations.TryGetValue(args.Author.Id, out _))
{
ClearConversation(args.Author.Id);
await args.Channel.SendConfirmAsync("Conversation deleted.");
return;
}

await args.Channel.SendErrorAsync("You dont have a conversation saved.", bss.Data);
return;
}

await args.Channel.SendErrorAsync("You dont have a conversation saved.", bss.Data);
return;
}
await using var dbContext = await dbProvider.GetContextAsync();

await using var dbContext = await dbProvider.GetContextAsync();

(Database.Models.OwnerOnly actualItem, bool added) toUpdate = dbContext.OwnerOnly.Any()
? (await dbContext.OwnerOnly.FirstOrDefaultAsync(), false)
: (new Database.Models.OwnerOnly
{
GptTokensUsed = 0
}, true);

var loadingMsg = await usrMsg.Channel.SendConfirmAsync($"{bss.Data.LoadingEmote} Awaiting response...");
await StreamResponseAndUpdateEmbedAsync(bss.Data.ChatGptKey, bss.Data.ChatGptModel,
bss.Data.ChatGptInitPrompt +
$"The users name is {args.Author}, you are in the discord server {guildChannel.Guild} and in the channel {guildChannel} and there are {(await guildChannel.GetUsersAsync().FlattenAsync()).Count()} users that can see this channel.",
loadingMsg, toUpdate, args.Author, args.Content);
//}
// catch (Exception e)
// {
// Log.Warning(e, "Error in ChatGPT");
// await usrMsg.SendErrorReplyAsync("Something went wrong, please try again later.");
// }
(Database.Models.OwnerOnly actualItem, bool added) toUpdate = dbContext.OwnerOnly.Any()
? (await dbContext.OwnerOnly.FirstOrDefaultAsync(), false)
: (new Database.Models.OwnerOnly
{
GptTokensUsed = 0
}, true);

var loadingMsg = await usrMsg.Channel.SendConfirmAsync($"{bss.Data.LoadingEmote} Awaiting response...");
await StreamResponseAndUpdateEmbedAsync(bss.Data.ChatGptKey, bss.Data.ChatGptModel,
bss.Data.ChatGptInitPrompt +
$"The users name is {args.Author}, you are in the discord server {guildChannel.Guild} and in the channel {guildChannel} and there are {(await guildChannel.GetUsersAsync().FlattenAsync()).Count()} users that can see this channel.",
loadingMsg, toUpdate, args.Author, args.Content);
}
catch (Exception e)
{
Log.Warning(e, "Error in ChatGPT");
await usrMsg.SendErrorReplyAsync("Something went wrong, please try again later.");
}
}

private static void ClearConversation(ulong userId)
Expand All @@ -362,7 +362,6 @@ private async Task StreamResponseAndUpdateEmbedAsync(string apiKey, string model
using var httpClient = httpFactory.CreateClient();
httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {apiKey}");

// Get or create conversation for this user
if (!UserConversations.TryGetValue(author.Id, out var conversation))
{
conversation = new Conversation();
Expand All @@ -373,7 +372,6 @@ private async Task StreamResponseAndUpdateEmbedAsync(string apiKey, string model
UserConversations[author.Id] = conversation;
}

// Add user message to conversation
conversation.Messages.Add(new Message
{
Role = "user", Content = userPrompt
Expand All @@ -387,7 +385,11 @@ private async Task StreamResponseAndUpdateEmbedAsync(string apiKey, string model
role = m.Role, content = m.Content
}).ToArray(),
stream = true,
user = author.Id.ToString()
user = author.Id.ToString(),
stream_options = new
{
include_usage = true
}
};

var content = new StringContent(System.Text.Json.JsonSerializer.Serialize(requestBody), Encoding.UTF8,
Expand All @@ -402,6 +404,7 @@ private async Task StreamResponseAndUpdateEmbedAsync(string apiKey, string model
var responseBuilder = new StringBuilder();
var lastUpdate = DateTimeOffset.UtcNow;
var totalTokens = 0;
var lastResponse = new ChatCompletionChunkResponse();

while (!reader.EndOfStream)
{
Expand All @@ -411,15 +414,17 @@ private async Task StreamResponseAndUpdateEmbedAsync(string apiKey, string model
if (!line.StartsWith("data: ")) continue;
var json = line[6..];
var chatResponse = JsonConvert.DeserializeObject<ChatCompletionChunkResponse>(json);

lastResponse = chatResponse;
if (chatResponse?.Choices is not { Count: > 0 }) continue;
var conversationContent = chatResponse.Choices[0].Delta?.Content;
if (string.IsNullOrEmpty(conversationContent)) continue;
responseBuilder.Append(conversationContent);
if (!((DateTimeOffset.UtcNow - lastUpdate).TotalSeconds >= 1)) continue;
lastUpdate = DateTimeOffset.UtcNow;
var embeds = BuildEmbeds(responseBuilder.ToString(), author,
toUpdate.actualItem.GptTokensUsed + totalTokens);
chatResponse.Usage is not null ? toUpdate.actualItem.GptTokensUsed + chatResponse.Usage.TotalTokens : toUpdate.actualItem.GptTokensUsed);
if (chatResponse.Usage is not null)
totalTokens += chatResponse.Usage.TotalTokens;
await loadingMsg.ModifyAsync(m => m.Embeds = embeds.ToArray());
}

Expand All @@ -436,7 +441,7 @@ private async Task StreamResponseAndUpdateEmbedAsync(string apiKey, string model
}

await using var dbContext = await dbProvider.GetContextAsync();
toUpdate.actualItem.GptTokensUsed += totalTokens;
toUpdate.actualItem.GptTokensUsed += lastResponse.Usage.TotalTokens;

if (toUpdate.added)
dbContext.OwnerOnly.Add(toUpdate.actualItem);
Expand Down Expand Up @@ -1033,6 +1038,21 @@ private class ChatCompletionChunkResponse

[JsonProperty("choices", NullValueHandling = NullValueHandling.Ignore)]
public List<Choice> Choices;

[JsonProperty("usage", NullValueHandling = NullValueHandling.Include)]
public Usage? Usage;
}

private class Usage
{
[JsonProperty("prompt_tokens", NullValueHandling = NullValueHandling.Ignore)]
public int PromptTokens;

[JsonProperty("completion_tokens", NullValueHandling = NullValueHandling.Ignore)]
public int CompletionTokens;

[JsonProperty("total_tokens", NullValueHandling = NullValueHandling.Ignore)]
public int TotalTokens;
}

private class Conversation
Expand All @@ -1045,4 +1065,6 @@ private class Message
public string Role { get; set; }
public string Content { get; set; }
}


}

0 comments on commit 7c82768

Please sign in to comment.