Skip to content

Commit

Permalink
HFPlanner
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Oct 30, 2023
1 parent 8bf61cb commit 3caa7a5
Show file tree
Hide file tree
Showing 39 changed files with 338 additions and 213 deletions.
12 changes: 12 additions & 0 deletions docs/architecture/hooks.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,16 @@ More information about conversation hook please go to [Conversation Hook](../con
```csharp
Task OnStateLoaded(ConversationState state);
Task OnStateChanged(string name, string preValue, string currentValue);
```

### Content Generating Hook
`IContentGeneratingHook`

Model content generating hook, it can be used for logging, metrics and tracing.
```csharp
// Before content generating.
Task BeforeGenerating(Agent agent, List<RoleDialogModel> conversations);

// After content generated.
Task AfterGenerated(RoleDialogModel message, TokenStatsModel tokenStats);
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace BotSharp.Abstraction.Conversations;
public interface IConversationService
{
IConversationStateService States { get; }
string ConversationId { get; }
Task<Conversation> NewConversation(Conversation conversation);
void SetConversationId(string conversationId, List<string> states);
Task<Conversation> GetConversation(string id);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using BotSharp.Abstraction.Functions.Models;
using BotSharp.Abstraction.Models;

namespace BotSharp.Abstraction.Conversations.Models;

public class RoleDialogModel : ITrackableMessage
{
/// <summary>
/// If Role is Assistant, it is same as user's message id.
/// </summary>
public string MessageId { get; set; }

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace BotSharp.Abstraction.Conversations.Models;
public class TokenStatsModel
{
public string Model { get; set; }
public string Prompt { get; set; }
public int PromptCount { get; set; }
public int CompletionCount { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ public interface ITextCompletion
/// <param name="model"></param>
void SetModelName(string model);

Task<string> GetCompletion(string text);
Task<string> GetCompletion(string text, string agentId, string messageId);
}
7 changes: 3 additions & 4 deletions src/Infrastructure/BotSharp.Abstraction/Planning/IExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ namespace BotSharp.Abstraction.Planning;

public interface IExecutor
{
Task<bool> Execute(IRoutingService routing,
Agent router,
Task<RoleDialogModel> Execute(IRoutingService routing,
FunctionCallFromLlm inst,
List<RoleDialogModel> dialogs,
RoleDialogModel message);
RoleDialogModel message,
List<RoleDialogModel> dialogs);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace BotSharp.Abstraction.Planning;
/// </summary>
public interface IPlaner
{
Task<FunctionCallFromLlm> GetNextInstruction(Agent router);
Task<FunctionCallFromLlm> GetNextInstruction(Agent router, string messageId);
Task<bool> AgentExecuting(FunctionCallFromLlm inst, RoleDialogModel message);
Task<bool> AgentExecuted(FunctionCallFromLlm inst, RoleDialogModel message);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using BotSharp.Abstraction.Functions.Models;
using BotSharp.Abstraction.Planning;

namespace BotSharp.Abstraction.Routing;

Expand All @@ -15,9 +14,7 @@ public interface IRoutingHandler
bool Enabled => true;
List<ParameterPropertyDef> Parameters => new List<ParameterPropertyDef>();

void SetRouter(Agent router) { }

void SetDialogs(List<RoleDialogModel> dialogs) { }
void SetDialogs(List<RoleDialogModel> dialogs);

Task<bool> Handle(IRoutingService routing, FunctionCallFromLlm inst, RoleDialogModel message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ namespace BotSharp.Abstraction.Routing;

public interface IRoutingService
{
List<RoleDialogModel> Dialogs { get; }
Agent Router { get; }
void ResetRecursiveCounter();
void RefreshDialogs();
Task<bool> InvokeAgent(string agentId, RoleDialogModel message);
Task<bool> InstructLoop(RoleDialogModel message);
Task<bool> ExecuteOnce(Agent agent, RoleDialogModel message);
Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs);
Task<RoleDialogModel> InstructLoop(RoleDialogModel message);
Task<RoleDialogModel> ExecuteOnce(Agent agent, RoleDialogModel message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ namespace BotSharp.Abstraction.Routing;

public abstract class RoutingHandlerBase
{
protected Agent _router;
protected readonly IServiceProvider _services;
protected readonly ILogger _logger;
protected RoutingSettings _settings;
Expand All @@ -20,11 +19,6 @@ public RoutingHandlerBase(IServiceProvider services,
_settings = settings;
}

public void SetRouter(Agent router)
{
_router = router;
}

public void SetDialogs(List<RoleDialogModel> dialogs)
{
_dialogs = dialogs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ namespace BotSharp.Core.Agents.Services;

public partial class AgentService
{
#if !DEBUG
[MemoryCache(10 * 60)]
#endif
public async Task<List<Agent>> GetAgents(bool? allowRouting = null)
{
var agents = _db.GetAgents(allowRouting: allowRouting);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ public static IServiceCollection AddBotSharp(this IServiceCollection services, I
services.AddSingleton((IServiceProvider x) => routingSettings);

services.AddScoped<NaivePlanner>();
services.AddScoped<ReasoningPlanner>();
services.AddScoped<HFPlanner>();
services.AddScoped<IPlaner>(provider =>
{
if (routingSettings.Planner == nameof(ReasoningPlanner))
return provider.GetRequiredService<ReasoningPlanner>();
if (routingSettings.Planner == nameof(HFPlanner))
return provider.GetRequiredService<HFPlanner>();
else
return provider.GetRequiredService<NaivePlanner>();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,18 @@ public async Task<bool> SendMessage(string agentId,
var routing = _services.GetRequiredService<IRoutingService>();
var settings = _services.GetRequiredService<RoutingSettings>();

var ret = agentId == settings.RouterId ?
var response = agentId == settings.RouterId ?
await routing.InstructLoop(message) :
await routing.ExecuteOnce(agent, message);

await HandleAssistantMessage(message, onMessageReceived);
await HandleAssistantMessage(response, onMessageReceived);

var statistics = _services.GetRequiredService<ITokenStatistics>();
statistics.PrintStatistics();

routing.ResetRecursiveCounter();
routing.RefreshDialogs();

return ret;
return true;
}

private async Task<Conversation> GetConversationRecord(string agentId)
Expand All @@ -86,15 +85,15 @@ private async Task<Conversation> GetConversationRecord(string agentId)
return converation;
}

private async Task HandleAssistantMessage(RoleDialogModel message, Func<RoleDialogModel, Task> onMessageReceived)
private async Task HandleAssistantMessage(RoleDialogModel response, Func<RoleDialogModel, Task> onMessageReceived)
{
var agentService = _services.GetRequiredService<IAgentService>();
var agent = await agentService.GetAgent(message.CurrentAgentId);
var agent = await agentService.GetAgent(response.CurrentAgentId);
var agentName = agent.Name;

var text = message.Role == AgentRole.Function ?
$"Sending [{agentName}] {message.FunctionName}: {message.Content}" :
$"Sending [{agentName}] {message.Role}: {message.Content}";
var text = response.Role == AgentRole.Function ?
$"Sending [{agentName}] {response.FunctionName}: {response.Content}" :
$"Sending [{agentName}] {response.Role}: {response.Content}";
#if DEBUG
Console.WriteLine(text, Color.Yellow);
#else
Expand All @@ -103,21 +102,21 @@ private async Task HandleAssistantMessage(RoleDialogModel message, Func<RoleDial

// Only read content from RichContent for UI rendering. When richContent is null, create a basic text message for richContent.
var state = _services.GetRequiredService<IConversationStateService>();
message.RichContent = message.RichContent ?? new RichContent<TextMessage>
response.RichContent = response.RichContent ?? new RichContent<TextMessage>
{
Recipient = new Recipient { Id = state.GetConversationId() },
Message = new TextMessage { Text = message.Content }
Message = new TextMessage { Text = response.Content }
};

var hooks = _services.GetServices<IConversationHook>().ToList();
foreach (var hook in hooks)
{
await hook.OnResponseGenerated(message);
await hook.OnResponseGenerated(response);
}

await onMessageReceived(message);
await onMessageReceived(response);

// Add to dialog history
_storage.Append(_conversationId, message);
_storage.Append(_conversationId, response);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using BotSharp.Abstraction.Conversations.Models;
using BotSharp.Abstraction.Repositories;

namespace BotSharp.Core.Conversations.Services;
Expand All @@ -12,6 +11,7 @@ public partial class ConversationService : IConversationService
private readonly IConversationStorage _storage;
private readonly IConversationStateService _state;
private string _conversationId;
public string ConversationId => _conversationId;

public IConversationStateService States => _state;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ public async Task<Conversation> Execute(string task, EvaluationRequest request)
};

var textCompletion = CompletionProvider.GetTextCompletion(_services);
RoleDialogModel response = default;
RoleDialogModel response = new RoleDialogModel(AgentRole.User, "");
var dialogs = new List<RoleDialogModel>();
int roundCount = 0;
while (true)
{
// var text = string.Join("\r\n", dialogs.Select(x => $"{x.Role}: {x.Content}"));
// text = instruction + $"\r\n###\r\n{text}\r\n{AgentRole.User}: ";
var question = await textCompletion.GetCompletion(prompt);
var question = await textCompletion.GetCompletion(prompt, request.AgentId, response.MessageId);
dialogs.Add(new RoleDialogModel(AgentRole.User, question));
prompt += question.Trim();

Expand All @@ -61,9 +61,14 @@ public async Task<Conversation> Execute(string task, EvaluationRequest request)

roundCount++;

if (roundCount > 10)
{
Console.WriteLine($"Conversation ended due to execced max round count {roundCount}", Color.Red);
break;
}

if (response.FunctionName == "conversation_end" ||
response.FunctionName == "human_intervention_needed" ||
roundCount > 5)
response.FunctionName == "human_intervention_needed")
{
Console.WriteLine($"Conversation ended by function {response.FunctionName}", Color.Green);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public async Task<InstructResult> Execute(string agentId, RoleDialogModel messag
agentService.RenderedTemplate(agent, templateName);

var completer = CompletionProvider.GetTextCompletion(_services);
var result = await completer.GetCompletion(prompt);
var result = await completer.GetCompletion(prompt, agentId, message.MessageId);
var response = new InstructResult
{
MessageId = message.MessageId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,33 @@
using BotSharp.Abstraction.Planning;
using BotSharp.Abstraction.Repositories;
using BotSharp.Abstraction.Routing.Models;
using BotSharp.Abstraction.Routing.Settings;
using BotSharp.Abstraction.Templating;

namespace BotSharp.Core.Planning;

public class ReasoningPlanner : IPlaner
/// <summary>
/// Human feedback based planner
/// </summary>
public class HFPlanner : IPlaner
{
private readonly IServiceProvider _services;
private readonly ILogger _logger;

public ReasoningPlanner(IServiceProvider services, ILogger<ReasoningPlanner> logger)
public HFPlanner(IServiceProvider services, ILogger<HFPlanner> logger)
{
_services = services;
_logger = logger;
}

public async Task<FunctionCallFromLlm> GetNextInstruction(Agent router)
public async Task<FunctionCallFromLlm> GetNextInstruction(Agent router, string messageId)
{
var next = GetNextStepPrompt(router);

RoleDialogModel response = default;
var inst = new FunctionCallFromLlm();

var completion = CompletionProvider.GetChatCompletion(_services,
model: "llm-gpt4");
var completion = CompletionProvider.GetChatCompletion(_services);

int retryCount = 0;
while (retryCount < 3)
Expand All @@ -36,6 +39,9 @@ public async Task<FunctionCallFromLlm> GetNextInstruction(Agent router)
response = completion.GetChatCompletions(router, new List<RoleDialogModel>
{
new RoleDialogModel(AgentRole.User, next)
{
MessageId = messageId
}
});

inst = response.Content.JsonContent<FunctionCallFromLlm>();
Expand All @@ -59,14 +65,14 @@ public async Task<FunctionCallFromLlm> GetNextInstruction(Agent router)

public async Task<bool> AgentExecuting(FunctionCallFromLlm inst, RoleDialogModel message)
{
message.Content = inst.Question;
message.FunctionArgs = JsonSerializer.Serialize(inst.Arguments);

var db = _services.GetRequiredService<IBotSharpRepository>();
var agent = db.GetAgents(inst.AgentName).FirstOrDefault();
if (!string.IsNullOrEmpty(inst.AgentName))
{
var db = _services.GetRequiredService<IBotSharpRepository>();
var agent = db.GetAgents(inst.AgentName).FirstOrDefault();

var context = _services.GetRequiredService<RoutingContext>();
context.Push(agent.Id);
var context = _services.GetRequiredService<RoutingContext>();
context.Push(agent.Id);
}

return true;
}
Expand All @@ -76,9 +82,6 @@ public async Task<bool> AgentExecuted(FunctionCallFromLlm inst, RoleDialogModel
var context = _services.GetRequiredService<RoutingContext>();
context.Pop();

// push Router to continue
// Make decision according to last agent's response

return true;
}

Expand Down
Loading

0 comments on commit 3caa7a5

Please sign in to comment.