
Commit

refactor: Use ExcelAsyncUtil to run task
kaspermarstal committed Jan 19, 2025
1 parent a4e9e77 commit fbcf791
Showing 2 changed files with 11 additions and 22 deletions.
11 changes: 5 additions & 6 deletions src/Cellm/AddIn/ExcelFunctions.cs
@@ -93,11 +93,10 @@ public static object PromptWith(
.AddUserMessage(userMessage)
.Build();

- // ExcelAsyncUtil yields Excel's main thread, Task.Run enables async/await in inner code
- return ExcelAsyncUtil.Run(nameof(PromptWith), new object[] { providerAndModel, instructionsOrContext, instructionsOrTemperature, temperature }, () =>
- {
-     return Task.Run(async () => await CallModelAsync(prompt, arguments.Provider)).GetAwaiter().GetResult();
- });
+ return ExcelAsyncUtil.RunTask(
+     nameof(PromptWith),
+     new object[] { providerAndModel, instructionsOrContext, instructionsOrTemperature, temperature },
+     () => CompleteAsync(prompt, arguments.Provider, null));
}
catch (CellmException ex)
{
@@ -116,7 +115,7 @@ public static object PromptWith(
/// <returns>A task that represents the asynchronous operation. The task result contains the model's response as a string.</returns>
/// <exception cref="CellmException">Thrown when an unexpected error occurs during the operation.</exception>

- internal static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
+ internal static async Task<string> CompleteAsync(Prompt prompt, Provider? provider = null, Uri? baseAddress = null)
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress, CancellationToken.None);
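For context, ExcelAsyncUtil.Run only accepts a synchronous delegate, so the old code had to block on the task with GetAwaiter().GetResult(), while ExcelAsyncUtil.RunTask accepts a Func<Task<T>> and lets Excel-DNA await it without a blocking bridge. A minimal, self-contained sketch of the two patterns, with a hypothetical DoWorkAsync helper standing in for CompleteAsync:

using System.Threading.Tasks;
using ExcelDna.Integration;

public static class AsyncPatternSketch
{
    // Hypothetical async helper standing in for CompleteAsync.
    private static async Task<string> DoWorkAsync(string input)
    {
        await Task.Delay(100);
        return $"done: {input}";
    }

    // Old pattern: ExcelAsyncUtil.Run takes a synchronous delegate, so the
    // async call is forced to block via GetAwaiter().GetResult().
    [ExcelFunction]
    public static object OldStyle(string input) =>
        ExcelAsyncUtil.Run(nameof(OldStyle), new object[] { input }, () =>
            Task.Run(async () => await DoWorkAsync(input)).GetAwaiter().GetResult());

    // New pattern: ExcelAsyncUtil.RunTask takes a Func<Task<T>> directly,
    // so no blocking bridge is needed.
    [ExcelFunction]
    public static object NewStyle(string input) =>
        ExcelAsyncUtil.RunTask(nameof(NewStyle), new object[] { input },
            () => DoWorkAsync(input));
}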
22 changes: 6 additions & 16 deletions src/Cellm/Models/Client.cs
@@ -7,34 +7,24 @@
using Cellm.Models.Providers.OpenAi;
using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
- using Microsoft.Extensions.Options;
using Polly.Timeout;

namespace Cellm.Models;

- public class Client(ISender sender, IOptions<ProviderConfiguration> providerConfiguration)
+ public class Client(ISender sender)
{
- private readonly ProviderConfiguration _providerConfiguration = providerConfiguration.Value;
-
- public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress, CancellationToken cancellationToken)
+ public async Task<Prompt> Send(Prompt prompt, Provider? provider, Uri? baseAddress, CancellationToken cancellationToken)
{
try
{
- provider ??= _providerConfiguration.DefaultProvider;
-
- if (!Enum.TryParse<Provider>(provider, true, out var parsedProvider))
- {
-     throw new ArgumentException($"Unsupported provider: {provider}");
- }
-
- IModelResponse response = parsedProvider switch
+ IModelResponse response = provider switch
{
- Provider.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider, baseAddress), cancellationToken),
+ Provider.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider.ToString(), baseAddress), cancellationToken),
Provider.Llamafile => await sender.Send(new LlamafileRequest(prompt), cancellationToken),
Provider.Ollama => await sender.Send(new OllamaRequest(prompt), cancellationToken),
Provider.OpenAi => await sender.Send(new OpenAiRequest(prompt), cancellationToken),
- Provider.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress), cancellationToken),
- _ => throw new InvalidOperationException($"Provider {parsedProvider} is defined but not implemented")
+ Provider.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress ?? throw new NullReferenceException($"{nameof(Provider.OpenAiCompatible)} requires BaseAddress")), cancellationToken),
+ _ => throw new NotSupportedException($"Provider {provider} is not supported")
};

return response.Prompt;
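For context, Client.Send now takes the strongly typed Provider enum rather than a string, so the Enum.TryParse call and default-provider fallback removed above have to happen before Send is called. A minimal sketch of what caller-side resolution could look like; the ProviderParser helper is hypothetical and not part of this commit, while the enum values are those shown in the switch above:

using System;

public enum Provider { Anthropic, Llamafile, Ollama, OpenAi, OpenAiCompatible }

public static class ProviderParser
{
    // Resolves a user-supplied provider name to the Provider enum,
    // falling back to a default when no name is given.
    public static Provider Resolve(string? providerName, Provider defaultProvider)
    {
        if (string.IsNullOrWhiteSpace(providerName))
        {
            return defaultProvider;
        }

        return Enum.TryParse<Provider>(providerName, ignoreCase: true, out var provider)
            ? provider
            : throw new ArgumentException($"Unsupported provider: {providerName}");
    }
}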
