feat: Remove support for embedded Ollama and Llamafile servers #85

Merged · 4 commits · Jan 19, 2025
2 changes: 1 addition & 1 deletion src/Cellm/AddIn/ExcelAddin.cs
@@ -3,7 +3,7 @@

namespace Cellm.AddIn;

public class ExcelAddin : IExcelAddIn
public class ExcelAddIn : IExcelAddIn
{
public void AutoOpen()
{
3 changes: 1 addition & 2 deletions src/Cellm/Models/Behaviors/ToolBehavior.cs
@@ -1,5 +1,4 @@
using Cellm.Models.Behaviors;
using Cellm.Models.Providers;
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;
12 changes: 0 additions & 12 deletions src/Cellm/Models/Providers/Llamafile/LlamafileConfiguration.cs
@@ -2,25 +2,13 @@

internal class LlamafileConfiguration : IProviderConfiguration
{
public Uri LlamafileUrl { get; init; }

public Uri BaseAddress { get; init; }

public Dictionary<string, Uri> Models { get; init; }

public string DefaultModel { get; init; }

public bool Gpu { get; init; }

public int GpuLayers { get; init; }

public LlamafileConfiguration()
{
LlamafileUrl = default!;
BaseAddress = default!;
Models = default!;
DefaultModel = default!;
Gpu = false;
GpuLayers = 999;
}
}
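
For reference, the configuration left behind by these deletions is just the OpenAI-compatible endpoint settings. A sketch of the resulting class, reconstructed from the retained lines above (not the verbatim file):

namespace Cellm.Models.Providers.Llamafile;

// Sketch: only the endpoint-related members survive; LlamafileUrl, Gpu, and
// GpuLayers, which existed solely for the embedded server, are removed.
internal class LlamafileConfiguration : IProviderConfiguration
{
    public Uri BaseAddress { get; init; }

    public Dictionary<string, Uri> Models { get; init; }

    public string DefaultModel { get; init; }

    public LlamafileConfiguration()
    {
        BaseAddress = default!;
        Models = default!;
        DefaultModel = default!;
    }
}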
118 changes: 4 additions & 114 deletions src/Cellm/Models/Providers/Llamafile/LlamafileRequestHandler.cs
@@ -1,129 +1,19 @@
using System.Diagnostics;
using Cellm.Models.Exceptions;
using Cellm.Models.Local.Utilities;
using Cellm.Models.Providers.OpenAiCompatible;
using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Providers.Llamafile;

internal class LlamafileRequestHandler : IProviderRequestHandler<LlamafileRequest, LlamafileResponse>
internal class LlamafileRequestHandler(ISender sender, IOptions<LlamafileConfiguration> llamafileConfiguration) : IProviderRequestHandler<LlamafileRequest, LlamafileResponse>
{
private record Llamafile(string ModelPath, Uri BaseAddress, Process Process);

private readonly AsyncLazy<string> _llamafileExePath;
private readonly Dictionary<string, AsyncLazy<Llamafile>> _llamafiles;
private readonly ProcessManager _processManager;
private readonly FileManager _fileManager;
private readonly ServerManager _serverManager;

private readonly LlamafileConfiguration _llamafileConfiguration;

private readonly ISender _sender;
private readonly ILogger<LlamafileRequestHandler> _logger;

public LlamafileRequestHandler(
IOptions<LlamafileConfiguration> llamafileConfiguration,
ISender sender,
HttpClient httpClient,
FileManager fileManager,
ProcessManager processManager,
ServerManager serverManager,
ILogger<LlamafileRequestHandler> logger)
{
_llamafileConfiguration = llamafileConfiguration.Value;
_sender = sender;
_fileManager = fileManager;
_processManager = processManager;
_serverManager = serverManager;
_logger = logger;

_llamafileExePath = new AsyncLazy<string>(async () =>
{
var llamafileName = Path.GetFileName(_llamafileConfiguration.LlamafileUrl.Segments.Last());
return await _fileManager.DownloadFileIfNotExists(_llamafileConfiguration.LlamafileUrl, _fileManager.CreateCellmFilePath(CreateModelFileName($"{llamafileName}.exe"), "Llamafile"));
});

_llamafiles = _llamafileConfiguration.Models.ToDictionary(x => x.Key, x => new AsyncLazy<Llamafile>(async () =>
{
// Download Llamafile
var exePath = await _llamafileExePath;

// Download model
var modelPath = await _fileManager.DownloadFileIfNotExists(x.Value, _fileManager.CreateCellmFilePath(CreateModelFileName(x.Key), "Llamafile"));

// Start server
var baseAddress = new UriBuilder(
_llamafileConfiguration.BaseAddress.Scheme,
_llamafileConfiguration.BaseAddress.Host,
_serverManager.FindPort(),
_llamafileConfiguration.BaseAddress.AbsolutePath).Uri;

var process = await StartProcess(exePath, modelPath, baseAddress);

return new Llamafile(modelPath, baseAddress, process);
}));
}

public async Task<LlamafileResponse> Handle(LlamafileRequest request, CancellationToken cancellationToken)
{
// Start server on first call
var llamafile = await _llamafiles[request.Prompt.Options.ModelId ?? _llamafileConfiguration.DefaultModel];
request.Prompt.Options.ModelId ??= llamafileConfiguration.Value.DefaultModel;

var openAiResponse = await _sender.Send(new OpenAiCompatibleRequest(request.Prompt, llamafile.BaseAddress), cancellationToken);
var openAiResponse = await sender.Send(new OpenAiCompatibleRequest(request.Prompt, llamafileConfiguration.Value.BaseAddress), cancellationToken);

return new LlamafileResponse(openAiResponse.Prompt);
}

private async Task<Process> StartProcess(string exePath, string modelPath, Uri baseAddress)
{
var processStartInfo = new ProcessStartInfo(exePath);

processStartInfo.ArgumentList.Add("--server");
processStartInfo.ArgumentList.Add("--nobrowser");
processStartInfo.ArgumentList.Add("-m");
processStartInfo.ArgumentList.Add(modelPath);
processStartInfo.ArgumentList.Add("--host");
processStartInfo.ArgumentList.Add(baseAddress.Host);
processStartInfo.ArgumentList.Add("--port");
processStartInfo.ArgumentList.Add(baseAddress.Port.ToString());

if (_llamafileConfiguration.Gpu)
{
processStartInfo.Arguments += $"-ngl {_llamafileConfiguration.GpuLayers} ";
}

processStartInfo.UseShellExecute = false;
processStartInfo.CreateNoWindow = true;
processStartInfo.RedirectStandardError = true;
processStartInfo.RedirectStandardOutput = true;

var process = Process.Start(processStartInfo) ?? throw new CellmModelException("Failed to run Llamafile");

process.OutputDataReceived += (sender, e) =>
{
if (!string.IsNullOrEmpty(e.Data))
{
_logger.LogDebug(e.Data);
}
};

process.BeginOutputReadLine();
process.BeginErrorReadLine();

var uriBuilder = new UriBuilder(baseAddress.Scheme, baseAddress.Host, baseAddress.Port, "/health");
await _serverManager.WaitForServer(uriBuilder.Uri, process);

// Kill Llamafile when Excel exits or dies
_processManager.AssignProcessToExcel(process);

return process;
}

private static string CreateModelFileName(string modelName)
{
return $"Llamafile-{modelName}";
}
}

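Taken together, the Llamafile handler loses its AsyncLazy download-and-start machinery and becomes a thin forwarder: it defaults the model from configuration and hands the prompt to the OpenAI-compatible pipeline at the configured base address. A sketch of the resulting handler, pieced together from the retained and added lines above:

using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Providers.Llamafile;

// Sketch: no binary downloads, no child processes, no health checks.
internal class LlamafileRequestHandler(ISender sender, IOptions<LlamafileConfiguration> llamafileConfiguration)
    : IProviderRequestHandler<LlamafileRequest, LlamafileResponse>
{
    public async Task<LlamafileResponse> Handle(LlamafileRequest request, CancellationToken cancellationToken)
    {
        // Fall back to the configured default model when the prompt does not name one.
        request.Prompt.Options.ModelId ??= llamafileConfiguration.Value.DefaultModel;

        // Delegate the completion to the OpenAI-compatible request pipeline.
        var openAiResponse = await sender.Send(
            new OpenAiCompatibleRequest(request.Prompt, llamafileConfiguration.Value.BaseAddress),
            cancellationToken);

        return new LlamafileResponse(openAiResponse.Prompt);
    }
}
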
163 changes: 4 additions & 159 deletions src/Cellm/Models/Providers/Ollama/OllamaRequestHandler.cs
@@ -1,98 +1,15 @@
using System.Diagnostics;
using System.Net.Http.Json;
using System.Text;
using System.Text.Json;
using Cellm.Models.Exceptions;
using Cellm.Models.Local.Utilities;
using Cellm.Models.Prompts;
using Cellm.Models.Prompts;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Providers.Ollama;

internal class OllamaRequestHandler : IModelRequestHandler<OllamaRequest, OllamaResponse>
internal class OllamaRequestHandler(
[FromKeyedServices(Provider.Ollama)] IChatClient chatClient) : IModelRequestHandler<OllamaRequest, OllamaResponse>
{
private record OllamaServer(Uri BaseAddress, Process Process);

record Tags(List<Model> Models);
record Model(string Name);
record Progress(string Status);

private readonly IChatClient _chatClient;
private readonly OllamaConfiguration _ollamaConfiguration;
private readonly HttpClient _httpClient;
private readonly FileManager _fileManager;
private readonly ProcessManager _processManager;
private readonly ServerManager _serverManager;
private readonly ILogger<OllamaRequestHandler> _logger;

private readonly AsyncLazy<string> _ollamaExePath;
private readonly AsyncLazy<OllamaServer> _ollamaServer;

public OllamaRequestHandler(
[FromKeyedServices(Provider.Ollama)] IChatClient chatClient,
IHttpClientFactory httpClientFactory,
IOptions<OllamaConfiguration> ollamaConfiguration,
FileManager fileManager,
ProcessManager processManager,
ServerManager serverManager,
ILogger<OllamaRequestHandler> logger)
{
_chatClient = chatClient;
_httpClient = httpClientFactory.CreateClient(nameof(Provider.Ollama));
_ollamaConfiguration = ollamaConfiguration.Value;
_fileManager = fileManager;
_processManager = processManager;
_serverManager = serverManager;
_logger = logger;

_ollamaExePath = new AsyncLazy<string>(async () =>
{
var zipFileName = string.Join("-", _ollamaConfiguration.ZipUrl.Segments.Select(x => x.Replace("/", string.Empty)).TakeLast(2));
var zipFilePath = _fileManager.CreateCellmFilePath(zipFileName);

await _fileManager.DownloadFileIfNotExists(
_ollamaConfiguration.ZipUrl,
zipFilePath);

var ollamaPath = _fileManager.ExtractZipFileIfNotExtracted(
zipFilePath,
_fileManager.CreateCellmDirectory(nameof(Ollama), Path.GetFileNameWithoutExtension(zipFileName)));

return Path.Combine(ollamaPath, "ollama.exe");
});

_ollamaServer = new AsyncLazy<OllamaServer>(async () =>
{
var ollamaExePath = await _ollamaExePath;
var process = await StartProcess(ollamaExePath, _ollamaConfiguration.BaseAddress);

return new OllamaServer(_ollamaConfiguration.BaseAddress, process);
});
}

public async Task<OllamaResponse> Handle(OllamaRequest request, CancellationToken cancellationToken)
{
var serverIsRunning = await ServerIsRunning(_ollamaConfiguration.BaseAddress);
if (_ollamaConfiguration.EnableServer && !serverIsRunning)
{
_ = await _ollamaServer;
}

var modelIsDownloaded = await ModelIsDownloaded(
_ollamaConfiguration.BaseAddress,
request.Prompt.Options.ModelId ?? _ollamaConfiguration.DefaultModel);

if (!modelIsDownloaded)
{
await DownloadModel(
_ollamaConfiguration.BaseAddress,
request.Prompt.Options.ModelId ?? _ollamaConfiguration.DefaultModel);
}

var chatCompletion = await _chatClient.CompleteAsync(
var chatCompletion = await chatClient.CompleteAsync(
request.Prompt.Messages,
request.Prompt.Options,
cancellationToken);
@@ -103,76 +20,4 @@ await DownloadModel(

return new OllamaResponse(prompt);
}

private async Task<bool> ServerIsRunning(Uri baseAddress)
{
var response = await _httpClient.GetAsync(baseAddress);

return response.IsSuccessStatusCode;
}

private async Task<bool> ModelIsDownloaded(Uri baseAddress, string modelId)
{
var tags = await _httpClient.GetFromJsonAsync<Tags>("api/tags") ?? throw new CellmModelException();

return tags.Models.Select(x => x.Name).Contains(modelId);
}

private async Task DownloadModel(Uri baseAddress, string modelId)
{
try
{
var modelName = JsonSerializer.Serialize(new { name = modelId });
var modelStringContent = new StringContent(modelName, Encoding.UTF8, "application/json");
var response = await _httpClient.PostAsync("api/pull", modelStringContent);

response.EnsureSuccessStatusCode();

var progress = await response.Content.ReadFromJsonAsync<List<Progress>>();

if (progress is null || progress.Last().Status != "success")
{
throw new CellmModelException($"Ollama failed to download model {modelId}");
}
}
catch (HttpRequestException ex)
{
throw new CellmModelException($"Ollama failed to download model {modelId} or {modelId} does not exist", ex);
}
}

private async Task<Process> StartProcess(string ollamaExePath, Uri baseAddress)
{
var processStartInfo = new ProcessStartInfo(await _ollamaExePath);

processStartInfo.ArgumentList.Add("serve");
processStartInfo.EnvironmentVariables.Add("OLLAMA_HOST", baseAddress.ToString());

processStartInfo.UseShellExecute = false;
processStartInfo.CreateNoWindow = true;
processStartInfo.RedirectStandardError = true;
processStartInfo.RedirectStandardOutput = true;

var process = Process.Start(processStartInfo) ?? throw new CellmModelException("Failed to run Ollama");

process.OutputDataReceived += (sender, e) =>
{
if (!string.IsNullOrEmpty(e.Data))
{
_logger.LogDebug(e.Data);
Debug.WriteLine(e.Data);
}
};

process.BeginOutputReadLine();
process.BeginErrorReadLine();

var address = new Uri(baseAddress, "/v1/models");
await _serverManager.WaitForServer(address, process);

// Kill Ollama when Excel exits or dies
_processManager.AssignProcessToExcel(process);

return process;
}
}
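
The Ollama handler undergoes the same reduction: the zip download, the ollama serve process management, the api/tags and api/pull calls, and the health-check wait are all gone, leaving a handler that delegates straight to the keyed IChatClient. A rough sketch of the result, reconstructed from the retained lines (the collapsed hunk that folds the completion back into a Prompt is unchanged by this PR and only hinted at here):

using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;

namespace Cellm.Models.Providers.Ollama;

// Sketch: the handler now only wraps the Ollama-keyed IChatClient.
internal class OllamaRequestHandler(
    [FromKeyedServices(Provider.Ollama)] IChatClient chatClient) : IModelRequestHandler<OllamaRequest, OllamaResponse>
{
    public async Task<OllamaResponse> Handle(OllamaRequest request, CancellationToken cancellationToken)
    {
        var chatCompletion = await chatClient.CompleteAsync(
            request.Prompt.Messages,
            request.Prompt.Options,
            cancellationToken);

        // Placeholder for the collapsed, unchanged lines that append the completion
        // to the prompt before it is returned.
        var prompt = request.Prompt;

        return new OllamaResponse(prompt);
    }
}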
@@ -4,5 +4,5 @@ namespace Cellm.Models.Providers.OpenAiCompatible;

internal record OpenAiCompatibleRequest(
Prompt Prompt,
Uri? BaseAddress = null,
Uri BaseAddress,
string? ApiKey = null) : IModelRequest<OpenAiCompatibleResponse>;
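
With BaseAddress promoted from an optional parameter to a required one, every caller must now state which endpoint it is talking to; only ApiKey stays optional. A hypothetical call site (the prompt variable and the localhost URL are illustrative, not taken from the repository):

// Hypothetical usage: the endpoint must be supplied explicitly; ApiKey may be omitted.
var request = new OpenAiCompatibleRequest(
    prompt,                                  // an existing Cellm Prompt
    new Uri("http://localhost:11434/v1/"));  // illustrative local endpoint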
@@ -5,17 +5,16 @@

internal class OpenAiCompatibleRequestHandler(
OpenAiCompatibleChatClientFactory openAiCompatibleChatClientFactory,
IOptions<OpenAiCompatibleConfiguration> openAiCompatibleConfiguration)

Check warning (GitHub Actions / Build) on line 8 in src/Cellm/Models/Providers/OpenAiCompatible/OpenAiCompatibleRequestHandler.cs: Parameter 'openAiCompatibleConfiguration' is unread.
: IModelRequestHandler<OpenAiCompatibleRequest, OpenAiCompatibleResponse>
{
private readonly OpenAiCompatibleConfiguration _openAiCompatibleConfiguration = openAiCompatibleConfiguration.Value;

public async Task<OpenAiCompatibleResponse> Handle(OpenAiCompatibleRequest request, CancellationToken cancellationToken)
{
var chatClient = openAiCompatibleChatClientFactory.Create(
request.BaseAddress ?? _openAiCompatibleConfiguration.BaseAddress,
request.Prompt.Options.ModelId ?? _openAiCompatibleConfiguration.DefaultModel,
request.ApiKey ?? _openAiCompatibleConfiguration.ApiKey);
request.BaseAddress,
request.Prompt.Options.ModelId ?? string.Empty,
request.ApiKey ?? "API_KEY");

var chatCompletion = await chatClient.CompleteAsync(request.Prompt.Messages, request.Prompt.Options, cancellationToken);

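The handler consequently stops consulting OpenAiCompatibleConfiguration for fallbacks, which is exactly why the build now flags the openAiCompatibleConfiguration parameter as unread: the base address always comes from the request, and missing model IDs or API keys are replaced with neutral placeholders. Reconstructed from the added lines, the factory call now reads roughly:

// Sketch of the factory call after this change (from the added lines above).
var chatClient = openAiCompatibleChatClientFactory.Create(
    request.BaseAddress,                            // always supplied by the caller now
    request.Prompt.Options.ModelId ?? string.Empty, // no configured default model to fall back on
    request.ApiKey ?? "API_KEY");                   // dummy key for servers that ignore authentication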