Commit

refactor: Models (#75)

kaspermarstal authored Dec 28, 2024
1 parent 78fb85d commit 3d5da5d
Showing 70 changed files with 653 additions and 589 deletions.
8 changes: 4 additions & 4 deletions src/Cellm.Tests/IntegrationTests.cs
@@ -4,7 +4,7 @@

namespace Cellm.Tests;

[ExcelTestSettings(AddIn = @"..\..\..\..\Cellm\bin\Debug\net6.0-windows\Cellm-AddIn")]
[ExcelTestSettings(AddIn = @"..\..\..\..\Cellm\bin\Debug\net8.0-windows\Cellm-AddIn")]
public class ExcelTests : IDisposable
{
readonly Workbook _testWorkbook;
@@ -54,17 +54,17 @@ public void TestPromptWith()
Worksheet ws = (Worksheet)_testWorkbook.Sheets[1];
ws.Range["A1"].Value = "Respond with \"Hello World\"";
ws.Range["A2"].Formula = "=PROMPTWITH(\"Anthropic/claude-3-haiku-20240307\",A1)";
ExcelTestHelper.WaitForCellValue(ws.Range["A2"]);
Automation.Wait(5000);
Assert.Equal("Hello World", ws.Range["A2"].Text);

ws.Range["B1"].Value = "Respond with \"Hello World\"";
ws.Range["B2"].Formula = "=PROMPTWITH(\"OpenAI/gpt-4o-mini\",B1)";
ExcelTestHelper.WaitForCellValue(ws.Range["B2"]);
Automation.Wait(5000);
Assert.Equal("Hello World", ws.Range["B2"].Text);

ws.Range["C1"].Value = "Respond with \"Hello World\"";
ws.Range["C2"].Formula = "=PROMPTWITH(\"OpenAI/gemini-1.5-flash-latest\",C1)";
ExcelTestHelper.WaitForCellValue(ws.Range["C2"]);
Automation.Wait(5000);
Assert.Equal("Hello World", ws.Range["C2"].Text);
}
}
@@ -1,7 +1,6 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.Prompts;
using Cellm.Services.Configuration;
using Cellm.Models.Providers;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;
using Microsoft.Office.Interop.Excel;
@@ -10,7 +9,7 @@ namespace Cellm.AddIn;

public record Arguments(string Provider, string Model, string Context, string Instructions, double Temperature);

public class PromptArgumentParser
public class ArgumentParser
{
private string? _provider;
private string? _model;
@@ -20,12 +19,12 @@ public class PromptArgumentParser

private readonly IConfiguration _configuration;

public PromptArgumentParser(IConfiguration configuration)
public ArgumentParser(IConfiguration configuration)
{
_configuration = configuration;
}

public PromptArgumentParser AddProvider(object providerAndModel)
public ArgumentParser AddProvider(object providerAndModel)
{
_provider = providerAndModel switch
{
@@ -37,7 +36,7 @@ public PromptArgumentParser AddProvider(object providerAndModel)
return this;
}

public PromptArgumentParser AddModel(object providerAndModel)
public ArgumentParser AddModel(object providerAndModel)
{
_model = providerAndModel switch
{
@@ -49,21 +48,21 @@ public PromptArgumentParser AddModel(object providerAndModel)
return this;
}

public PromptArgumentParser AddInstructionsOrContext(object instructionsOrContext)
public ArgumentParser AddInstructionsOrContext(object instructionsOrContext)
{
_instructionsOrContext = instructionsOrContext;

return this;
}

public PromptArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
public ArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
{
_instructionsOrTemperature = instructionsOrTemperature;

return this;
}

public PromptArgumentParser AddTemperature(object temperature)
public ArgumentParser AddTemperature(object temperature)
{
_temperature = temperature;

@@ -73,19 +72,19 @@ public PromptArgumentParser AddTemperature(object temperature)
public Arguments Parse()
{
var provider = _provider ?? _configuration
.GetSection(nameof(CellmConfiguration))
.GetValue<string>(nameof(CellmConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultProvider));
.GetSection(nameof(ProviderConfiguration))
.GetValue<string>(nameof(ProviderConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(ProviderConfiguration.DefaultProvider));

var model = _model ?? _configuration
.GetSection(nameof(CellmConfiguration))
.GetValue<string>(nameof(CellmConfiguration.DefaultModel))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultModel));
.GetSection($"{provider}Configuration")
.GetValue<string>(nameof(IProviderConfiguration.DefaultModel))
?? throw new ArgumentException(nameof(IProviderConfiguration.DefaultModel));

var defaultTemperature = _configuration
.GetSection(nameof(CellmConfiguration))
.GetValue<double?>(nameof(CellmConfiguration.DefaultTemperature))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultTemperature));
.GetSection(nameof(ProviderConfiguration))
.GetValue<double?>(nameof(ProviderConfiguration.DefaultTemperature))
?? throw new ArgumentException(nameof(ProviderConfiguration.DefaultTemperature));

return (_instructionsOrContext, _instructionsOrTemperature, _temperature) switch
{
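
For context, a minimal sketch (not part of this diff) of the configuration shape the rewritten Parse() falls back to. The section and key names follow the code above; the in-memory values and the "OpenAiConfiguration" section name are illustrative placeholders, and string literals stand in for the nameof(...) expressions used in the real code.

using Microsoft.Extensions.Configuration;

// Sketch only: mirrors the keys Parse() reads when no explicit arguments are given.
var configuration = new ConfigurationBuilder()
    .AddInMemoryCollection(new Dictionary<string, string?>
    {
        ["ProviderConfiguration:DefaultProvider"] = "OpenAi",   // read via GetSection(nameof(ProviderConfiguration))
        ["ProviderConfiguration:DefaultTemperature"] = "0",
        ["OpenAiConfiguration:DefaultModel"] = "gpt-4o-mini",   // read via GetSection($"{provider}Configuration")
    })
    .Build();

var provider = configuration.GetSection("ProviderConfiguration").GetValue<string>("DefaultProvider");
var model = configuration.GetSection($"{provider}Configuration").GetValue<string>("DefaultModel");
var temperature = configuration.GetSection("ProviderConfiguration").GetValue<double?>("DefaultTemperature");

Note that the per-provider section name is now derived from the resolved provider, so each provider carries its own DefaultModel instead of the single global default the old CellmConfiguration held.
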
@@ -3,7 +3,7 @@

namespace Cellm.AddIn;

public class CellmAddIn : IExcelAddIn
public class ExcelAddin : IExcelAddIn
{
public void AutoOpen()
{
@@ -2,15 +2,15 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.Models;
using Cellm.Prompts;
using Cellm.Models.Prompts;
using Cellm.Models.Providers;
using Cellm.Services;
using Cellm.Services.Configuration;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;

namespace Cellm.AddIn;

public static class CellmFunctions
public static class ExcelFunctions
{
/// <summary>
/// Sends a prompt to the default model configured in CellmConfiguration.
@@ -35,8 +35,8 @@ public static object Prompt(
{
var configuration = ServiceLocator.Get<IConfiguration>();

var provider = configuration.GetSection(nameof(CellmConfiguration)).GetValue<string>(nameof(CellmConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultProvider));
var provider = configuration.GetSection(nameof(ProviderConfiguration)).GetValue<string>(nameof(ProviderConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(ProviderConfiguration.DefaultProvider));

var model = configuration.GetSection($"{provider}Configuration").GetValue<string>(nameof(IProviderConfiguration.DefaultModel))
?? throw new ArgumentException(nameof(IProviderConfiguration.DefaultModel));
@@ -73,7 +73,7 @@ public static object PromptWith(
{
try
{
var arguments = ServiceLocator.Get<PromptArgumentParser>()
var arguments = ServiceLocator.Get<ArgumentParser>()
.AddProvider(providerAndModel)
.AddModel(providerAndModel)
.AddInstructionsOrContext(instructionsOrContext)
@@ -116,7 +116,7 @@ public static object PromptWith(
/// <returns>A task that represents the asynchronous operation. The task result contains the model's response as a string.</returns>
/// <exception cref="CellmException">Thrown when an unexpected error occurs during the operation.</exception>

private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
internal static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress, CancellationToken.None);
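
A short sketch (not part of this diff) of what the private-to-internal change on CallModelAsync enables: other classes inside the Cellm assembly can now reuse the helper directly. The signature is taken from the diff; the wrapper methods and the localhost base address are hypothetical.

using Cellm.AddIn;
using Cellm.Models.Prompts;

internal static class ModelCalls
{
    // Default provider and endpoint, resolved from configuration by CallModelAsync.
    internal static Task<string> AskDefault(Prompt prompt) =>
        ExcelFunctions.CallModelAsync(prompt);

    // Explicit provider with a custom endpoint (placeholder URI).
    internal static Task<string> AskLocalServer(Prompt prompt) =>
        ExcelFunctions.CallModelAsync(prompt, "OpenAiCompatible", new Uri("http://localhost:8080/v1"));
}
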
@@ -1,4 +1,4 @@
namespace Cellm.Prompts;
namespace Cellm.AddIn;

internal static class SystemMessages
{
1 change: 1 addition & 0 deletions src/Cellm/Cellm.csproj
@@ -27,6 +27,7 @@
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.0.1-preview.1.24570.5" />
<PackageReference Include="Microsoft.Extensions.Caching.Hybrid" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.Configuration" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.FileSystemGlobbing" Version="9.0.0" />
5 changes: 0 additions & 5 deletions src/Cellm/Models/Anthropic/AnthropicResponse.cs

This file was deleted.

@@ -1,23 +1,23 @@
using System.Text.Json;
using Cellm.Services.Configuration;
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.Options;

namespace Cellm.Models.ModelRequestBehavior;
namespace Cellm.Models.Behaviors;

internal class CachingBehavior<TRequest, TResponse>(HybridCache cache, IOptions<CellmConfiguration> cellmConfiguration) : IPipelineBehavior<TRequest, TResponse>
internal class CacheBehavior<TRequest, TResponse>(HybridCache cache, IOptions<ProviderConfiguration> providerConfiguration) : IPipelineBehavior<TRequest, TResponse>
where TRequest : IModelRequest<TResponse>
where TResponse : IModelResponse
{
private readonly HybridCacheEntryOptions _cacheEntryOptions = new()
{
Expiration = TimeSpan.FromSeconds(cellmConfiguration.Value.CacheTimeoutInSeconds)
Expiration = TimeSpan.FromSeconds(providerConfiguration.Value.CacheTimeoutInSeconds)
};

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
if (cellmConfiguration.Value.EnableCache)
if (providerConfiguration.Value.EnableCache)
{
return await cache.GetOrCreateAsync(
JsonSerializer.Serialize(request.Prompt),
@@ -1,6 +1,6 @@
using MediatR;

namespace Cellm.Models.ModelRequestBehavior;
namespace Cellm.Models.Behaviors;

internal class SentryBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : notnull
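
The SentryBehavior body is collapsed in this view; as a rough, hypothetical illustration of the pipeline shape such a behavior follows (the actual implementation is not shown in this diff), a tracing behavior wraps next() in a Sentry transaction:

using MediatR;
using Sentry;

// Hypothetical sketch only; Cellm's real SentryBehavior is collapsed in this diff view.
internal class SentryBehaviorSketch<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
    where TRequest : notnull
{
    public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
    {
        var transaction = SentrySdk.StartTransaction(typeof(TRequest).Name, "cellm.model.request");
        try
        {
            return await next();
        }
        finally
        {
            transaction.Finish();
        }
    }
}
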
24 changes: 24 additions & 0 deletions src/Cellm/Models/Behaviors/ToolBehavior.cs
@@ -0,0 +1,24 @@
using Cellm.Models.Behaviors;
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Tools;

internal class ToolBehavior<TRequest, TResponse>(IOptions<ProviderConfiguration> providerConfiguration, IEnumerable<AIFunction> functions)
: IPipelineBehavior<TRequest, TResponse> where TRequest : IModelRequest<TResponse>
{
private readonly ProviderConfiguration _providerConfiguration = providerConfiguration.Value;
private readonly List<AITool> _tools = new(functions);

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
if (_providerConfiguration.EnableTools)
{
request.Prompt.Options.Tools = _tools;
}

return await next();
}
}
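
A hypothetical sketch (not part of this diff) of how the renamed open-generic behaviors could be registered with MediatR. AddMediatR, RegisterServicesFromAssemblyContaining and AddOpenBehavior are standard MediatR 12 APIs; Cellm's actual composition root is not included in this excerpt, so the ordering and the marker type used for assembly scanning are assumptions.

using Cellm.Models;
using Cellm.Models.Behaviors;
using Cellm.Models.Tools;
using MediatR;
using Microsoft.Extensions.DependencyInjection;

var services = new ServiceCollection();

services.AddMediatR(cfg =>
{
    // Assumption: requests and handlers live in the same assembly as Client.
    cfg.RegisterServicesFromAssemblyContaining<Client>();

    // Pipeline behaviors from this commit, outermost first.
    cfg.AddOpenBehavior(typeof(SentryBehavior<,>));
    cfg.AddOpenBehavior(typeof(CacheBehavior<,>));
    cfg.AddOpenBehavior(typeof(ToolBehavior<,>));
});

Because CacheBehavior and ToolBehavior constrain TRequest to IModelRequest<TResponse>, they only run for model requests, while SentryBehavior applies to every request.
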
57 changes: 22 additions & 35 deletions src/Cellm/Models/Client.cs
@@ -1,73 +1,60 @@
using System.Text.Json;
using Cellm.AddIn.Exceptions;
using Cellm.Models.Anthropic;
using Cellm.Models.Llamafile;
using Cellm.Models.Ollama;
using Cellm.Models.OpenAi;
using Cellm.Models.OpenAiCompatible;
using Cellm.Prompts;
using Cellm.Services.Configuration;
using Cellm.Models.Exceptions;
using Cellm.Models.Prompts;
using Cellm.Models.Providers;
using Cellm.Models.Providers.Anthropic;
using Cellm.Models.Providers.Llamafile;
using Cellm.Models.Providers.Ollama;
using Cellm.Models.Providers.OpenAi;
using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
using Microsoft.Extensions.Options;
using Polly.Timeout;

namespace Cellm.Models;

internal class Client(ISender sender, IOptions<CellmConfiguration> cellmConfiguration)
public class Client(ISender sender, IOptions<ProviderConfiguration> providerConfiguration)
{
private readonly CellmConfiguration _cellmConfiguration = cellmConfiguration.Value;
private readonly ProviderConfiguration _providerConfiguration = providerConfiguration.Value;

public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress, CancellationToken cancellationToken)
{
try
{
provider ??= _cellmConfiguration.DefaultProvider;
provider ??= _providerConfiguration.DefaultProvider;

if (!Enum.TryParse<Providers>(provider, true, out var parsedProvider))
if (!Enum.TryParse<Provider>(provider, true, out var parsedProvider))
{
throw new ArgumentException($"Unsupported provider: {provider}");
}

IModelResponse response = parsedProvider switch
{
Providers.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider, baseAddress), cancellationToken),
Providers.Llamafile => await sender.Send(new LlamafileRequest(prompt), cancellationToken),
Providers.Ollama => await sender.Send(new OllamaRequest(prompt), cancellationToken),
Providers.OpenAi => await sender.Send(new OpenAiRequest(prompt), cancellationToken),
Providers.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress), cancellationToken),
Provider.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider, baseAddress), cancellationToken),
Provider.Llamafile => await sender.Send(new LlamafileRequest(prompt), cancellationToken),
Provider.Ollama => await sender.Send(new OllamaRequest(prompt), cancellationToken),
Provider.OpenAi => await sender.Send(new OpenAiRequest(prompt), cancellationToken),
Provider.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress), cancellationToken),
_ => throw new InvalidOperationException($"Provider {parsedProvider} is defined but not implemented")
};

return response.Prompt;
}
catch (HttpRequestException ex)
{
throw new CellmException($"HTTP request failed: {ex.Message}", ex);
}
catch (JsonException ex)
{
throw new CellmException($"JSON processing failed: {ex.Message}", ex);
}
catch (NotSupportedException ex)
{
throw new CellmException($"Method not supported: {ex.Message}", ex);
}
catch (FileReaderException ex)
{
throw new CellmException($"File could not be read: {ex.Message}", ex);
throw new CellmModelException($"HTTP request failed: {ex.Message}", ex);
}
catch (NullReferenceException ex)
{
throw new CellmException($"Null reference error: {ex.Message}", ex);
throw new CellmModelException($"Null reference error: {ex.Message}", ex);
}
catch (TimeoutRejectedException ex)
{
throw new CellmException($"Request timed out: {ex.Message}", ex);
throw new CellmModelException($"Request timed out: {ex.Message}", ex);
}
catch (Exception ex) when (ex is not CellmException)
catch (Exception ex) when (ex is not CellmModelException)
{
// Handle any other unexpected exceptions
throw new CellmException($"An unexpected error occurred: {ex.Message}", ex);
throw new CellmModelException($"An unexpected error occurred: {ex.Message}", ex);
}
}
}
10 changes: 10 additions & 0 deletions src/Cellm/Models/Exceptions/CellmModelException.cs
@@ -0,0 +1,10 @@
namespace Cellm.Models.Exceptions;

public class CellmModelException : Exception
{
public CellmModelException(string message = "#CELLM_ERROR?")
: base(message) { }

public CellmModelException(string message, Exception inner)
: base(message, inner) { }
}
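
A brief sketch (not part of this diff) of the error surface after CellmException was replaced by CellmModelException around Client.Send. The Send signature and the "#CELLM_ERROR?" default message come from the diff; the wrapper method and the way the answer is extracted from the returned Prompt are assumptions.

using Cellm.Models;
using Cellm.Models.Exceptions;
using Cellm.Models.Prompts;

internal static class ClientUsage
{
    internal static async Task<string> TrySend(Client client, Prompt prompt, CancellationToken cancellationToken)
    {
        try
        {
            var answer = await client.Send(prompt, provider: null, baseAddress: null, cancellationToken);
            return answer.Messages.LastOrDefault()?.Text ?? string.Empty;  // assumption: the returned Prompt exposes chat messages
        }
        catch (CellmModelException ex)
        {
            // With the parameterless constructor the message defaults to "#CELLM_ERROR?",
            // which doubles as a cell-friendly error string.
            return ex.Message;
        }
    }
}
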
5 changes: 0 additions & 5 deletions src/Cellm/Models/Llamafile/LlamafileResponse.cs

This file was deleted.
