Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several updates to web project #718

Merged
merged 12 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,5 @@ site/
/LLama.Unittest/Models/*.bin
/LLama.Unittest/Models/*.gguf

/LLama.Web/appsettings.Local.json
/LLama.Web/appsettings.Local.json
32 changes: 21 additions & 11 deletions LLama.Web/Common/ISessionConfig.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
namespace LLama.Web.Common
using System.ComponentModel;

namespace LLama.Web.Common;

public interface ISessionConfig
{
public interface ISessionConfig
{
string AntiPrompt { get; set; }
List<string> AntiPrompts { get; set; }
LLamaExecutorType ExecutorType { get; set; }
string Model { get; set; }
string OutputFilter { get; set; }
List<string> OutputFilters { get; set; }
string Prompt { get; set; }
}
string AntiPrompt { get; set; }

[DisplayName("Anti Prompts")]
List<string> AntiPrompts { get; set; }

[DisplayName("Executor Type")]
LLamaExecutorType ExecutorType { get; set; }

string Model { get; set; }

[DisplayName("Output Filter")]
string OutputFilter { get; set; }

List<string> OutputFilters { get; set; }

string Prompt { get; set; }
}
155 changes: 77 additions & 78 deletions LLama.Web/Common/ModelOptions.cs
Original file line number Diff line number Diff line change
@@ -1,117 +1,116 @@
using System.Text;
using System.Text;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Web.Common
namespace LLama.Web.Common;

public class ModelOptions
: ILLamaParams
{
public class ModelOptions
: ILLamaParams
{
/// <summary>
/// Model friendly name
/// </summary>
public string Name { get; set; }
/// <summary>
/// Model friendly name
/// </summary>
public string Name { get; set; }

/// <summary>
/// Max context instances allowed per model
/// </summary>
public int MaxInstances { get; set; }
/// <summary>
/// Max context instances allowed per model
Lamothe marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
public int MaxInstances { get; set; }

/// <inheritdoc />
public uint? ContextSize { get; set; }
/// <inheritdoc />
public uint? ContextSize { get; set; }

/// <inheritdoc />
public int MainGpu { get; set; } = 0;
/// <inheritdoc />
public int MainGpu { get; set; } = 0;

/// <inheritdoc />
public GPUSplitMode SplitMode { get; set; } = GPUSplitMode.None;
/// <inheritdoc />
public GPUSplitMode SplitMode { get; set; } = GPUSplitMode.None;

/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;
/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;

public uint SeqMax { get; }
public uint SeqMax { get; }

/// <inheritdoc />
public uint Seed { get; set; } = 1686349486;
/// <inheritdoc />
public uint Seed { get; set; } = 1686349486;

public bool Embeddings { get; }
public bool Embeddings { get; }

/// <inheritdoc />
public bool UseMemorymap { get; set; } = true;
/// <inheritdoc />
public bool UseMemorymap { get; set; } = true;

/// <inheritdoc />
public bool UseMemoryLock { get; set; } = false;
/// <inheritdoc />
public bool UseMemoryLock { get; set; } = false;

/// <inheritdoc />
public string ModelPath { get; set; }
/// <inheritdoc />
public string ModelPath { get; set; }

/// <inheritdoc />
public AdapterCollection LoraAdapters { get; set; } = new();
/// <inheritdoc />
public AdapterCollection LoraAdapters { get; set; } = new();

/// <inheritdoc />
public string LoraBase { get; set; } = string.Empty;
/// <inheritdoc />
public string LoraBase { get; set; } = string.Empty;

/// <inheritdoc />
public uint? Threads { get; set; }
/// <inheritdoc />
public uint? Threads { get; set; }

/// <inheritdoc />
public uint? BatchThreads { get; set; }
/// <inheritdoc />
public uint? BatchThreads { get; set; }

/// <inheritdoc />
public uint BatchSize { get; set; } = 512;
/// <inheritdoc />
public uint BatchSize { get; set; } = 512;

/// <inheritdoc />
public uint UBatchSize { get; set; } = 512;
/// <inheritdoc />
public uint UBatchSize { get; set; } = 512;

/// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new();
/// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new();

/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();
/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();

/// <inheritdoc />
public float? RopeFrequencyBase { get; set; }
/// <inheritdoc />
public float? RopeFrequencyBase { get; set; }

/// <inheritdoc />
public float? RopeFrequencyScale { get; set; }
/// <inheritdoc />
public float? RopeFrequencyScale { get; set; }

/// <inheritdoc />
public float? YarnExtrapolationFactor { get; set; }
/// <inheritdoc />
public float? YarnExtrapolationFactor { get; set; }

/// <inheritdoc />
public float? YarnAttentionFactor { get; set; }
/// <inheritdoc />
public float? YarnAttentionFactor { get; set; }

/// <inheritdoc />
public float? YarnBetaFast { get; set; }
/// <inheritdoc />
public float? YarnBetaFast { get; set; }

/// <inheritdoc />
public float? YarnBetaSlow { get; set; }
/// <inheritdoc />
public float? YarnBetaSlow { get; set; }

/// <inheritdoc />
public uint? YarnOriginalContext { get; set; }
/// <inheritdoc />
public uint? YarnOriginalContext { get; set; }

/// <inheritdoc />
public RopeScalingType? YarnScalingType { get; set; }
/// <inheritdoc />
public RopeScalingType? YarnScalingType { get; set; }

/// <inheritdoc />
public GGMLType? TypeK { get; set; }
/// <inheritdoc />
public GGMLType? TypeK { get; set; }

/// <inheritdoc />
public GGMLType? TypeV { get; set; }
/// <inheritdoc />
public GGMLType? TypeV { get; set; }

/// <inheritdoc />
public bool NoKqvOffload { get; set; }
/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public Encoding Encoding { get; set; } = Encoding.UTF8;
/// <inheritdoc />
public Encoding Encoding { get; set; } = Encoding.UTF8;

/// <inheritdoc />
public bool VocabOnly { get; set; }
/// <inheritdoc />
public bool VocabOnly { get; set; }

/// <inheritdoc />
public float DefragThreshold { get; set; }
/// <inheritdoc />
public float DefragThreshold { get; set; }

/// <inheritdoc />
public LLamaPoolingType PoolingType { get; set; }
}
/// <inheritdoc />
public LLamaPoolingType PoolingType { get; set; }
}
21 changes: 10 additions & 11 deletions LLama.Web/Common/SessionConfig.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
namespace LLama.Web.Common
namespace LLama.Web.Common;

public class SessionConfig : ISessionConfig
{
public class SessionConfig : ISessionConfig
{
public string Model { get; set; }
public string Prompt { get; set; }
public string Model { get; set; }
public string Prompt { get; set; }

public string AntiPrompt { get; set; }
public List<string> AntiPrompts { get; set; }
public string OutputFilter { get; set; }
public List<string> OutputFilters { get; set; }
public LLamaExecutorType ExecutorType { get; set; }
}
public string AntiPrompt { get; set; }
public List<string> AntiPrompts { get; set; }
public string OutputFilter { get; set; }
public List<string> OutputFilters { get; set; }
public LLamaExecutorType ExecutorType { get; set; }
}
92 changes: 44 additions & 48 deletions LLama.Web/Hubs/SessionConnectionHub.cs
Original file line number Diff line number Diff line change
@@ -1,67 +1,63 @@
using LLama.Web.Common;
using LLama.Web.Common;
using LLama.Web.Models;
using LLama.Web.Services;
using Microsoft.AspNetCore.SignalR;

namespace LLama.Web.Hubs
{
public class SessionConnectionHub : Hub<ISessionClient>
{
private readonly ILogger<SessionConnectionHub> _logger;
private readonly IModelSessionService _modelSessionService;
namespace LLama.Web.Hubs;

public SessionConnectionHub(ILogger<SessionConnectionHub> logger, IModelSessionService modelSessionService)
{
_logger = logger;
_modelSessionService = modelSessionService;
}
public class SessionConnectionHub : Hub<ISessionClient>
{
private readonly ILogger<SessionConnectionHub> _logger;
private readonly IModelSessionService _modelSessionService;

public override async Task OnConnectedAsync()
{
_logger.Log(LogLevel.Information, "[OnConnectedAsync], Id: {0}", Context.ConnectionId);
public SessionConnectionHub(ILogger<SessionConnectionHub> logger, IModelSessionService modelSessionService)
{
_logger = logger;
_modelSessionService = modelSessionService;
}

// Notify client of successful connection
await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Connected);
await base.OnConnectedAsync();
}
public override async Task OnConnectedAsync()
{
_logger.Log(LogLevel.Information, "[OnConnectedAsync], Id: {0}", Context.ConnectionId);

// Notify client of successful connection
await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Connected);
await base.OnConnectedAsync();
}

public override async Task OnDisconnectedAsync(Exception exception)
{
_logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);
public override async Task OnDisconnectedAsync(Exception exception)
{
_logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);

// Remove the connection's session on disconnect
await _modelSessionService.CloseAsync(Context.ConnectionId);
await base.OnDisconnectedAsync(exception);
}
// Remove the connection's session on disconnect
await _modelSessionService.CloseAsync(Context.ConnectionId);
await base.OnDisconnectedAsync(exception);
}

[HubMethodName("LoadModel")]
public async Task OnLoadModel(SessionConfig sessionConfig, InferenceOptions inferenceConfig)
{
_logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
await _modelSessionService.CloseAsync(Context.ConnectionId);

[HubMethodName("LoadModel")]
public async Task OnLoadModel(SessionConfig sessionConfig, InferenceOptions inferenceConfig)
// Create model session
var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig);
if (modelSession is null)
{
_logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
await _modelSessionService.CloseAsync(Context.ConnectionId);

// Create model session
var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig);
if (modelSession is null)
{
await Clients.Caller.OnError("Failed to create model session");
return;
}

// Notify client
await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Loaded);
await Clients.Caller.OnError("Failed to create model session");
return;
}

// Notify client
await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Loaded);
}

[HubMethodName("SendPrompt")]
public IAsyncEnumerable<TokenModel> OnSendPrompt(string prompt, InferenceOptions inferConfig, CancellationToken cancellationToken)
{
_logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId);
[HubMethodName("SendPrompt")]
public IAsyncEnumerable<TokenModel> OnSendPrompt(string prompt, InferenceOptions inferConfig, CancellationToken cancellationToken)
{
_logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId);

var linkedCancelationToken = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken);
return _modelSessionService.InferAsync(Context.ConnectionId, prompt, inferConfig, linkedCancelationToken.Token);
}
var linkedCancelationToken = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken);
return _modelSessionService.InferAsync(Context.ConnectionId, prompt, inferConfig, linkedCancelationToken.Token);
}
}
3 changes: 2 additions & 1 deletion LLama.Web/LLama.Web.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<TargetFramework>net8.0</TargetFramework>
<Nullable>disable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
</PropertyGroup>
Expand All @@ -15,6 +15,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Mvc.Razor.RuntimeCompilation" Version="7.0.18" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
</ItemGroup>

Expand Down
Loading