Skip to content

Commit

Permalink
use newer version of chat api
Browse files Browse the repository at this point in the history
  • Loading branch information
sdcb committed Dec 17, 2024
1 parent 5459fb4 commit 66e850a
Show file tree
Hide file tree
Showing 14 changed files with 321 additions and 69 deletions.
71 changes: 49 additions & 22 deletions src/BE/Controllers/Chats/Chats/ChatController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
namespace Chats.BE.Controllers.Chats.Chats;

[Route("api/chats"), Authorize]
public class ChatController(ChatsDB db, CurrentUser currentUser, ILogger<ChatController> logger, IUrlEncryptionService idEncryption) : ControllerBase
public class ChatController(
ChatsDB db,
CurrentUser currentUser,
ILogger<ChatController> logger,
IUrlEncryptionService idEncryption,
ChatStopService stopService) : ControllerBase
{
[HttpPost]
public async Task<IActionResult> StartConversationStreamed(
Expand All @@ -30,7 +35,7 @@ public async Task<IActionResult> StartConversationStreamed(
[FromServices] ChatFactory conversationFactory,
[FromServices] UserModelManager userModelManager,
[FromServices] ClientInfoManager clientInfoManager,
[FromServices] FileUrlProvider fileDownloadUrlProvider,
[FromServices] FileUrlProvider fup,
CancellationToken cancellationToken)
{
InChatContext icc = new();
Expand Down Expand Up @@ -112,20 +117,21 @@ public async Task<IActionResult> StartConversationStreamed(
List<OpenAIChatMessage> messageToSend =
[
..(systemMessage != null ? [new SystemChatMessage(systemMessage.Content[0].ToString())] : Array.Empty<OpenAIChatMessage>()),
..await GetMessageTree(existingMessages, messageId).ToAsyncEnumerable().SelectAwait(async x => await x.ToOpenAI(fileDownloadUrlProvider, cancellationToken)).ToArrayAsync(cancellationToken),
..await GetMessageTree(existingMessages, messageId).ToAsyncEnumerable().SelectAwait(async x => await x.ToOpenAI(fup, cancellationToken)).ToArrayAsync(cancellationToken),
];

// new user message
MessageLiteDto userMessage;
MessageLiteDto userMessageLite;
Message? dbUserMessage = null;
if (messageId != null && existingMessages.TryGetValue(messageId.Value, out MessageLiteDto? parentMessage) && parentMessage.Role == DBChatRole.User)
{
// existing user message
userMessage = existingMessages[messageId!.Value];
userMessageLite = existingMessages[messageId!.Value];
}
else
{
// insert new user message
Message dbUserMessage = new()
dbUserMessage = new()
{
ChatId = chatId,
ChatRoleId = (byte)DBChatRole.User,
Expand All @@ -135,18 +141,19 @@ ..await GetMessageTree(existingMessages, messageId).ToAsyncEnumerable().SelectAw
};
db.Messages.Add(dbUserMessage);
await db.SaveChangesAsync(cancellationToken);
userMessage = new()
userMessageLite = new()
{
Id = dbUserMessage.Id,
Content = request.UserMessage.ToMessageContents(idEncryption),
Role = (DBChatRole)dbUserMessage.ChatRoleId,
ParentId = dbUserMessage.ParentId,
};
messageToSend.Add(await userMessage.ToOpenAI(fileDownloadUrlProvider, cancellationToken));
messageToSend.Add(await userMessageLite.ToOpenAI(fup, cancellationToken));
}

string? errorText = null;
bool everYield = false;
string? stopId = null;
try
{
if (userModel == null)
Expand All @@ -165,9 +172,11 @@ ..await GetMessageTree(existingMessages, messageId).ToAsyncEnumerable().SelectAw
Response.Headers.ContentType = "text/event-stream";
Response.Headers.CacheControl = "no-cache";
Response.Headers.Connection = "keep-alive";
stopId = stopService.CreateAndCombineCancellationToken(ref cancellationToken);
await YieldResponse(SseResponseLine.CreateStopId(stopId));
everYield = true;
}
await YieldResponse(new() { Result = seg.TextSegment, Success = true });
await YieldResponse(SseResponseLine.CreateSegment(seg.TextSegment));

if (cancellationToken.IsCancellationRequested)
{
Expand All @@ -184,7 +193,7 @@ ..await GetMessageTree(existingMessages, messageId).ToAsyncEnumerable().SelectAw
{
icc.FinishReason = DBFinishReason.UpstreamError;
errorText = e.Message;
logger.LogError(e, "Upstream error: {userMessageId}", userMessage.Id);
logger.LogError(e, "Upstream error: {userMessageId}", userMessageLite.Id);
}
catch (TaskCanceledException)
{
Expand All @@ -196,18 +205,22 @@ ..await GetMessageTree(existingMessages, messageId).ToAsyncEnumerable().SelectAw
{
icc.FinishReason = DBFinishReason.UnknownError;
errorText = "Unknown Error";
logger.LogError(e, "Error in conversation for message: {userMessageId}", userMessage.Id);
logger.LogError(e, "Error in conversation for message: {userMessageId}", userMessageLite.Id);
}
finally
{
// stop observing cancellation from here on: the code below records usage and deducts credit, so it must run to completion
cancellationToken = CancellationToken.None;
if (stopId != null)
{
stopService.Remove(stopId);
}
}

// success
// insert new assistant message
InternalChatSegment fullResponse = icc.FullResponse;
Message assistantMessage = new()
Message dbAssistantMessage = new()
{
ChatId = chatId,
ChatRoleId = (byte)DBChatRole.Assistant,
Expand All @@ -216,34 +229,35 @@ ..await GetMessageTree(existingMessages, messageId).ToAsyncEnumerable().SelectAw
MessageContent.FromText(fullResponse.TextSegment),
],
CreatedAt = DateTime.UtcNow,
ParentId = userMessage.Id,
ParentId = userMessageLite.Id,
};

if (errorText != null)
{
assistantMessage.MessageContents.Add(MessageContent.FromError(errorText));
await YieldResponse(new() { Result = errorText, Success = false });
dbAssistantMessage.MessageContents.Add(MessageContent.FromError(errorText));
await YieldResponse(SseResponseLine.CreateError(errorText));
}
assistantMessage.Usage = icc.ToUserModelUsage(currentUser.Id, await clientInfoManager.GetClientInfo(cancellationToken), isApi: false);
db.Messages.Add(assistantMessage);
dbAssistantMessage.Usage = icc.ToUserModelUsage(currentUser.Id, await clientInfoManager.GetClientInfo(cancellationToken), isApi: false);
db.Messages.Add(dbAssistantMessage);

await db.SaveChangesAsync(cancellationToken);
if (icc.Cost.CostBalance > 0)
{
_ = balanceService.AsyncUpdateBalance(currentUser.Id, CancellationToken.None);
await balanceService.UpdateBalance(db, currentUser.Id, CancellationToken.None);
}
if (icc.Cost.CostUsage)
{
_ = balanceService.AsyncUpdateUsage([userModel!.Id], CancellationToken.None);
await balanceService.UpdateUsage(db, userModel!.Id, CancellationToken.None);
}

await YieldResponse(SseResponseLine.CreateEnd(dbUserMessage, dbAssistantMessage, idEncryption, fup));
return new EmptyResult();
}

private readonly static ReadOnlyMemory<byte> dataU8 = "data:"u8.ToArray();
private readonly static ReadOnlyMemory<byte> lfu8 = "\n"u8.ToArray();
private readonly static ReadOnlyMemory<byte> dataU8 = "data: "u8.ToArray();
private readonly static ReadOnlyMemory<byte> lfu8 = "\r\n\r\n"u8.ToArray();

private async Task YieldResponse(SseResponseLine line)
private async Task YieldResponse<T>(SseResponseLine<T> line)
{
await Response.Body.WriteAsync(dataU8);
await Response.Body.WriteAsync(JsonSerializer.SerializeToUtf8Bytes(line, JSON.JsonSerializerOptions));
Expand All @@ -266,4 +280,17 @@ static LinkedList<MessageLiteDto> GetMessageTree(Dictionary<long, MessageLiteDto
}
return line;
}

/// <summary>
/// Asks the stop service to cancel the streamed chat registered under <paramref name="stopId"/>.
/// </summary>
/// <param name="stopId">Stop id taken from the route; the streaming endpoint sends it to the client as its first SSE line.</param>
/// <returns>200 OK when <c>TryCancel</c> reports success, otherwise 404 Not Found.</returns>
/// <remarks>
/// NOTE(review): no ownership check on the stop id is visible here; confirm that the stop service
/// (or the unguessability of the id) prevents one user from stopping another user's stream.
/// </remarks>
[HttpPost("stop/{stopId}")]
public IActionResult StopChat(string stopId)
{
    return stopService.TryCancel(stopId) ? Ok() : NotFound();
}
}
1 change: 0 additions & 1 deletion src/BE/Controllers/Chats/Chats/Dtos/MessageLiteDto.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using Chats.BE.DB;
using Chats.BE.DB.Enums;
using Chats.BE.DB.Extensions;
using Chats.BE.Services.ChatServices;
using Chats.BE.Services.FileServices;
using OpenAI.Chat;
Expand Down
13 changes: 13 additions & 0 deletions src/BE/Controllers/Chats/Chats/Dtos/SseEndMessage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Chats.BE.Controllers.Chats.Messages.Dtos;
using System.Text.Json.Serialization;

namespace Chats.BE.Controllers.Chats.Chats.Dtos;

/// <summary>
/// Payload of the closing line of a streamed chat response (the line whose kind is
/// <see cref="SseResponseKind.End"/>). It carries the request and response messages as DTOs.
/// </summary>
public record SseEndMessage
{
/// <summary>
/// DTO of the user message, or <c>null</c> when <c>SseResponseLine.CreateEnd</c> was called
/// without a user message.
/// </summary>
[JsonPropertyName("requestMessage")]
public required MessageDto? RequestMessage { get; init; }

/// <summary>DTO of the assistant message passed to <c>SseResponseLine.CreateEnd</c>.</summary>
[JsonPropertyName("responseMessage")]
public required MessageDto ResponseMessage { get; init; }
}
9 changes: 9 additions & 0 deletions src/BE/Controllers/Chats/Chats/Dtos/SseResponseKind.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Chats.BE.Controllers.Chats.Chats.Dtos;

/// <summary>
/// Discriminator for one line of the chat SSE stream; it is carried in the <c>k</c> JSON property
/// of <c>SseResponseLine&lt;T&gt;</c> and tells the reader how to interpret the <c>r</c> payload.
/// </summary>
/// <remarks>
/// NOTE(review): whether the value is written as a number or as a name depends on the serializer
/// options used by the controller, which are not visible here. Confirm before changing the
/// explicit values below.
/// </remarks>
public enum SseResponseKind
{
/// <summary>The payload is the stop id string (built by <c>SseResponseLine.CreateStopId</c>).</summary>
StopId = 0,
/// <summary>The payload is a text segment string (built by <c>SseResponseLine.CreateSegment</c>).</summary>
Segment = 1,
/// <summary>The payload is an error text string (built by <c>SseResponseLine.CreateError</c>).</summary>
Error = 2,
/// <summary>The payload is an <see cref="SseEndMessage"/> (built by <c>SseResponseLine.CreateEnd</c>).</summary>
End = 3,
}
95 changes: 89 additions & 6 deletions src/BE/Controllers/Chats/Chats/Dtos/SseResponseLine.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,95 @@
using System.Text.Json.Serialization;
using Chats.BE.Controllers.Chats.Messages.Dtos;
using Chats.BE.DB;
using Chats.BE.Services.ChatServices;
using Chats.BE.Services.FileServices;
using Chats.BE.Services.UrlEncryption;
using System.Text.Json.Serialization;

namespace Chats.BE.Controllers.Chats.Chats.Dtos;

public record SseResponseLine
public record SseResponseLine<T>
{
[JsonPropertyName("result")]
public required string Result { get; init; }
[JsonPropertyName("r")]
public required T Result { get; init; }

[JsonPropertyName("success")]
public required bool Success { get; init; }
[JsonPropertyName("k")]
public required SseResponseKind Kind { get; init; }
}

/// <summary>
/// Factory methods for the typed <see cref="SseResponseLine{T}"/> payloads, one per
/// <see cref="SseResponseKind"/>.
/// </summary>
public static class SseResponseLine
{
    /// <summary>Builds a <see cref="SseResponseKind.Segment"/> line carrying <paramref name="segment"/>.</summary>
    public static SseResponseLine<string> CreateSegment(string segment) => Text(SseResponseKind.Segment, segment);

    /// <summary>Builds a <see cref="SseResponseKind.Error"/> line carrying <paramref name="error"/>.</summary>
    public static SseResponseLine<string> CreateError(string error) => Text(SseResponseKind.Error, error);

    /// <summary>Builds a <see cref="SseResponseKind.StopId"/> line carrying <paramref name="stopId"/>.</summary>
    public static SseResponseLine<string> CreateStopId(string stopId) => Text(SseResponseKind.StopId, stopId);

    /// <summary>
    /// Builds the <see cref="SseResponseKind.End"/> line: both messages are converted to DTOs and
    /// wrapped in an <see cref="SseEndMessage"/>.
    /// </summary>
    /// <param name="userMessage">The user message, or <c>null</c>; a <c>null</c> here yields a <c>null</c> request DTO.</param>
    /// <param name="assistantMessage">The assistant message; its usage (when present) is included in the response DTO.</param>
    /// <param name="urlEncryptionService">Passed through to <c>ChatMessageTemp.ToDto</c>.</param>
    /// <param name="fup">Passed through to <c>ChatMessageTemp.ToDto</c>.</param>
    /// <returns>A line whose result holds the request DTO (possibly <c>null</c>) and the response DTO.</returns>
    public static SseResponseLine<SseEndMessage> CreateEnd(
        Message? userMessage,
        Message assistantMessage,
        IUrlEncryptionService urlEncryptionService,
        FileUrlProvider fup)
    {
        // The user side never carries usage figures; only the assistant side may.
        ChatMessageTemp? requestTemp = userMessage == null ? null : Snapshot(userMessage, withUsage: false);
        ChatMessageTemp responseTemp = Snapshot(assistantMessage, withUsage: true);

        MessageDto? requestDto = requestTemp?.ToDto(urlEncryptionService, fup);
        MessageDto responseDto = responseTemp.ToDto(urlEncryptionService, fup);

        return new SseResponseLine<SseEndMessage>
        {
            Kind = SseResponseKind.End,
            Result = new SseEndMessage
            {
                RequestMessage = requestDto,
                ResponseMessage = responseDto,
            },
        };
    }

    // Shared shape of every string-valued line: only the kind differs.
    private static SseResponseLine<string> Text(SseResponseKind kind, string value) => new()
    {
        Result = value,
        Kind = kind,
    };

    // Copies the fields of a database message into the intermediate ChatMessageTemp form.
    private static ChatMessageTemp Snapshot(Message message, bool withUsage) => new()
    {
        Content = [.. message.MessageContents],
        CreatedAt = message.CreatedAt,
        Id = message.Id,
        ParentId = message.ParentId,
        Role = (DBChatRole)message.ChatRoleId,
        Usage = withUsage ? SnapshotUsage(message) : null,
    };

    // Maps the message's usage record, if any, to ChatMessageTempUsage.
    // NOTE(review): dereferences Usage.UserModel and UserModel.Model without null checks;
    // confirm callers always populate these navigations before calling CreateEnd.
    private static ChatMessageTempUsage? SnapshotUsage(Message message)
    {
        var usage = message.Usage;
        if (usage == null)
        {
            return null;
        }

        return new ChatMessageTempUsage
        {
            // Reported duration excludes the preprocessing time.
            Duration = usage.TotalDurationMs - usage.PreprocessDurationMs,
            FirstTokenLatency = usage.FirstResponseDurationMs,
            InputPrice = usage.InputCost,
            InputTokens = usage.InputTokens,
            ModelId = usage.UserModel.ModelId,
            ModelName = usage.UserModel.Model.Name,
            OutputPrice = usage.OutputCost,
            OutputTokens = usage.OutputTokens,
            ReasoningTokens = usage.ReasoningTokens,
        };
    }
}
44 changes: 25 additions & 19 deletions src/BE/Controllers/Chats/Messages/Dtos/MessageDto.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,31 @@ public record FileDto
public required Uri Url { get; init; }
}

/// <summary>
/// Usage figures attached to a <c>ChatMessageTemp</c>. When <c>ChatMessageTemp.Usage</c> is
/// non-null, <c>ToDto</c> copies these values onto the DTO it returns; when it is null,
/// <c>ToDto</c> returns a <c>RequestMessageDto</c> instead.
/// </summary>
public record ChatMessageTempUsage
{
/// <summary>
/// Generation duration. <c>SseResponseLine.CreateEnd</c> fills this with
/// <c>TotalDurationMs - PreprocessDurationMs</c> (presumably milliseconds, going by those names).
/// </summary>
public required int Duration { get; init; }
/// <summary>Filled by <c>SseResponseLine.CreateEnd</c> from the usage record's <c>FirstResponseDurationMs</c>.</summary>
public required int FirstTokenLatency { get; init; }
/// <summary>Filled by <c>SseResponseLine.CreateEnd</c> from the usage record's <c>InputCost</c>.</summary>
public required decimal InputPrice { get; init; }
/// <summary>Input token count.</summary>
public required int InputTokens { get; init; }
/// <summary>Id of the model the usage is recorded against.</summary>
public required short ModelId { get; init; }
/// <summary>Name of the model the usage is recorded against.</summary>
public required string ModelName { get; init; }
/// <summary>Filled by <c>SseResponseLine.CreateEnd</c> from the usage record's <c>OutputCost</c>.</summary>
public required decimal OutputPrice { get; init; }
/// <summary>Output token count.</summary>
public required int OutputTokens { get; init; }
/// <summary>Reasoning token count.</summary>
public required int ReasoningTokens { get; init; }
}

public record ChatMessageTemp
{
public required long Id { get; init; }
public required long? ParentId { get; init; }
public required DBChatRole Role { get; init; }
public required MessageContent[] Content { get; init; }
public required int? InputTokens { get; init; }
public required int? OutputTokens { get; init; }
public required int? ReasoningTokens { get; init; }
public required decimal? InputPrice { get; init; }
public required decimal? OutputPrice { get; init; }
public required DateTime CreatedAt { get; init; }
public required int? Duration { get; init; }
public required int? FirstTokenLatency { get; init; }
public required short? ModelId { get; init; }
public required string? ModelName { get; init; }
public required ChatMessageTempUsage? Usage { get; init; }

public MessageDto ToDto(IUrlEncryptionService urlEncryption, FileUrlProvider fup)
{
if (ModelId == null)
if (Usage == null)
{
return new RequestMessageDto()
{
Expand All @@ -160,15 +165,16 @@ public MessageDto ToDto(IUrlEncryptionService urlEncryption, FileUrlProvider fup
Role = Role.ToString().ToLowerInvariant(),
Content = MessageContentResponse.FromSegments(Content, fup),
CreatedAt = CreatedAt,
InputTokens = InputTokens!.Value,
OutputTokens = OutputTokens!.Value,
InputPrice = InputPrice!.Value,
OutputPrice = OutputPrice!.Value,
ReasoningTokens = ReasoningTokens!.Value,
Duration = Duration!.Value,
FirstTokenLatency = FirstTokenLatency!.Value,
ModelId = ModelId!.Value,
ModelName = ModelName

InputTokens = Usage.InputTokens,
OutputTokens = Usage.OutputTokens,
InputPrice = Usage.InputPrice,
OutputPrice = Usage.OutputPrice,
ReasoningTokens = Usage.ReasoningTokens,
Duration = Usage.Duration,
FirstTokenLatency = Usage.FirstTokenLatency,
ModelId = Usage.ModelId,
ModelName = Usage.ModelName,
};
}
}
Expand Down
Loading

0 comments on commit 66e850a

Please sign in to comment.