Skip to content

Commit

Permalink
refactor: Use ExcelAsyncUtil to run task (#84)
Browse files Browse the repository at this point in the history
* refactor: Make argument provider an enum

* refactor: Use ExcelAsyncUtil to run task
  • Loading branch information
kaspermarstal authored Jan 19, 2025
1 parent 86bb2bf commit 3a85d6f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 27 deletions.
14 changes: 9 additions & 5 deletions src/Cellm/AddIn/ArgumentParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
using Cellm.Models.Providers;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;
using Microsoft.Office.Interop.Excel;

namespace Cellm.AddIn;

public record Arguments(string Provider, string Model, string Context, string Instructions, double Temperature);
public record Arguments(Provider Provider, string Model, string Context, string Instructions, double Temperature);

public class ArgumentParser
{
Expand Down Expand Up @@ -71,11 +70,16 @@ public ArgumentParser AddTemperature(object temperature)

public Arguments Parse()
{
var provider = _provider ?? _configuration
var providerAsString = _provider ?? _configuration
.GetSection(nameof(ProviderConfiguration))
.GetValue<string>(nameof(ProviderConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(ProviderConfiguration.DefaultProvider));

if (!Enum.TryParse<Provider>(providerAsString, true, out var provider))
{
throw new ArgumentException($"Unsupported default provider: {providerAsString}");
}

var model = _model ?? _configuration
.GetSection($"{provider}Configuration")
.GetValue<string>(nameof(IProviderConfiguration.DefaultModel))
Expand Down Expand Up @@ -147,9 +151,9 @@ private static string ParseCells(ExcelReference reference)
{
try
{
var app = (Application)ExcelDnaUtil.Application;
var app = ExcelDnaUtil.Application;
var sheetName = (string)XlCall.Excel(XlCall.xlSheetNm, reference);
sheetName = sheetName[(sheetName.LastIndexOf("]") + 1)..];
sheetName = sheetName[(sheetName.LastIndexOf(']') + 1)..];
var worksheet = app.Sheets[sheetName];

var tableBuilder = new StringBuilder();
Expand Down
11 changes: 5 additions & 6 deletions src/Cellm/AddIn/ExcelFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,10 @@ public static object PromptWith(
.AddUserMessage(userMessage)
.Build();

// ExcelAsyncUtil yields Excel's main thread, Task.Run enables async/await in inner code
return ExcelAsyncUtil.Run(nameof(PromptWith), new object[] { providerAndModel, instructionsOrContext, instructionsOrTemperature, temperature }, () =>
{
return Task.Run(async () => await CallModelAsync(prompt, arguments.Provider)).GetAwaiter().GetResult();
});
return ExcelAsyncUtil.RunTask(
nameof(PromptWith),
new object[] { providerAndModel, instructionsOrContext, instructionsOrTemperature, temperature },
() => CompleteAsync(prompt, arguments.Provider, null));
}
catch (CellmException ex)
{
Expand All @@ -116,7 +115,7 @@ public static object PromptWith(
/// <returns>A task that represents the asynchronous operation. The task result contains the model's response as a string.</returns>
/// <exception cref="CellmException">Thrown when an unexpected error occurs during the operation.</exception>

internal static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
internal static async Task<string> CompleteAsync(Prompt prompt, Provider? provider = null, Uri? baseAddress = null)
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress, CancellationToken.None);
Expand Down
22 changes: 6 additions & 16 deletions src/Cellm/Models/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,24 @@
using Cellm.Models.Providers.OpenAi;
using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
using Microsoft.Extensions.Options;
using Polly.Timeout;

namespace Cellm.Models;

public class Client(ISender sender, IOptions<ProviderConfiguration> providerConfiguration)
public class Client(ISender sender)
{
private readonly ProviderConfiguration _providerConfiguration = providerConfiguration.Value;

public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress, CancellationToken cancellationToken)
public async Task<Prompt> Send(Prompt prompt, Provider? provider, Uri? baseAddress, CancellationToken cancellationToken)
{
try
{
provider ??= _providerConfiguration.DefaultProvider;

if (!Enum.TryParse<Provider>(provider, true, out var parsedProvider))
{
throw new ArgumentException($"Unsupported provider: {provider}");
}

IModelResponse response = parsedProvider switch
IModelResponse response = provider switch
{
Provider.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider, baseAddress), cancellationToken),
Provider.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider.ToString(), baseAddress), cancellationToken),
Provider.Llamafile => await sender.Send(new LlamafileRequest(prompt), cancellationToken),
Provider.Ollama => await sender.Send(new OllamaRequest(prompt), cancellationToken),
Provider.OpenAi => await sender.Send(new OpenAiRequest(prompt), cancellationToken),
Provider.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress), cancellationToken),
_ => throw new InvalidOperationException($"Provider {parsedProvider} is defined but not implemented")
Provider.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress ?? throw new NullReferenceException($"{nameof(Provider.OpenAiCompatible)} requires BaseAddress")), cancellationToken),
_ => throw new NotSupportedException($"Provider {provider} is not supported")
};

return response.Prompt;
Expand Down

0 comments on commit 3a85d6f

Please sign in to comment.