Skip to content

Commit

Permalink
gai: Fix "party mode" extension script.
Browse files Browse the repository at this point in the history
  • Loading branch information
patniemeyer committed Dec 13, 2024
1 parent 7a3e504 commit 09770fe
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 43 deletions.
18 changes: 14 additions & 4 deletions gai-frontend/lib/chat/chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -204,19 +204,27 @@ class _ChatViewState extends State<ChatView> {
String? modelId,
String? modelName,
}) {
final message = ChatMessage(
_addChatMessage(ChatMessage(
source: source,
message: msg,
metadata: metadata,
sourceName: sourceName,
modelId: modelId,
modelName: modelName,
);
_addChatMessage(message);
));
}

// Add a message to the chat history and update the UI
void _addChatMessage(ChatMessage message) {
log('Adding message: ${message.message.truncate(64)}');

// Add the verbose model name for the model if not provided.
// Note: This should probably be pushed down to the UI logic to support localization.
if (message.modelName == null && message.modelId != null) {
final model = _modelManager.getModel(message.modelId!);
message = message.copyWith(modelName: model?.name);
}

setState(() {
_chatHistory.addMessage(message);
});
Expand Down Expand Up @@ -382,7 +390,6 @@ class _ChatViewState extends State<ChatView> {
chatResponse.message,
metadata: metadata,
modelId: modelId,
modelName: _modelManager.getModelOrDefaultNullable(modelId)?.name,
);
}

Expand Down Expand Up @@ -612,6 +619,9 @@ class _ChatViewState extends State<ChatView> {
onPartyModeChanged: () {
setState(() {
_partyMode = !_partyMode;
if (_partyMode) {
_multiSelectMode = true;
}
});
},
onClearChat: _clearChat,
Expand Down
49 changes: 28 additions & 21 deletions gai-frontend/lib/chat/chat_bubble.dart
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class ChatBubble extends StatelessWidget {
message.message,
style: const TextStyle(
fontFamily: 'Baloo2',
fontSize: 14, // 16px equivalent
fontSize: 14,
// 16px equivalent
height: 1.0,
fontWeight: FontWeight.normal,
color: Colors.white,
Expand All @@ -49,49 +50,51 @@ class ChatBubble extends StatelessWidget {
? Alignment.centerLeft
: Alignment.centerRight,
child: Container(
constraints: BoxConstraints(maxWidth: 0.6 * 800),
constraints: const BoxConstraints(maxWidth: 0.6 * 800),
child: Column(
crossAxisAlignment: src == ChatMessageSource.provider
? CrossAxisAlignment.start
crossAxisAlignment: src == ChatMessageSource.provider
? CrossAxisAlignment.start
: CrossAxisAlignment.end,
children: <Widget>[
// Header row with icon and name for both provider and user
Row(
mainAxisAlignment: src == ChatMessageSource.provider
? MainAxisAlignment.start
mainAxisAlignment: src == ChatMessageSource.provider
? MainAxisAlignment.start
: MainAxisAlignment.end,
crossAxisAlignment: CrossAxisAlignment.center,
children: [
if (src == ChatMessageSource.provider) ...[
Icon(
const Icon(
Icons.stars_rounded,
color: OrchidColors.blue_highlight,
size: iconSize,
),
SizedBox(width: iconSpacing),
const SizedBox(width: iconSpacing),
Text(
message.displayName ?? 'Chat',
style: TextStyle(
style: const TextStyle(
fontFamily: 'Baloo2',
fontSize: 14, // 16px equivalent
fontSize: 14,
// 16px equivalent
height: 1.0,
fontWeight: FontWeight.w500,
color: OrchidColors.blue_highlight,
),
),
] else ...[
Text(
const Text(
'You',
style: TextStyle(
fontFamily: 'Baloo2',
fontSize: 14, // 16px equivalent
fontSize: 14,
// 16px equivalent
height: 1.0,
fontWeight: FontWeight.w500,
color: OrchidColors.blue_highlight,
),
),
SizedBox(width: iconSpacing),
Icon(
const SizedBox(width: iconSpacing),
const Icon(
Icons.account_circle_rounded,
color: OrchidColors.blue_highlight,
size: iconSize,
Expand All @@ -103,12 +106,13 @@ class ChatBubble extends StatelessWidget {
// Message content with padding for provider messages
if (src == ChatMessageSource.provider)
Padding(
padding: EdgeInsets.only(left: iconTotalWidth),
padding: const EdgeInsets.only(left: iconTotalWidth),
child: SelectableText(
message.message,
style: const TextStyle(
fontFamily: 'Baloo2',
fontSize: 20, // 20px design spec
fontSize: 20,
// 20px design spec
height: 1.0,
fontWeight: FontWeight.normal,
color: Colors.white,
Expand All @@ -117,7 +121,8 @@ class ChatBubble extends StatelessWidget {
)
else
Container(
padding: const EdgeInsets.symmetric(horizontal: 25, vertical: 8),
padding:
const EdgeInsets.symmetric(horizontal: 25, vertical: 8),
decoration: BoxDecoration(
color: Colors.black,
borderRadius: BorderRadius.circular(10),
Expand All @@ -126,7 +131,8 @@ class ChatBubble extends StatelessWidget {
message.message,
style: const TextStyle(
fontFamily: 'Baloo2',
fontSize: 20, // 20px design spec
fontSize: 20,
// 20px design spec
height: 1.0,
fontWeight: FontWeight.normal,
color: Colors.white,
Expand All @@ -137,12 +143,13 @@ class ChatBubble extends StatelessWidget {
if (src == ChatMessageSource.provider) ...[
const SizedBox(height: 4),
Padding(
padding: EdgeInsets.only(left: iconTotalWidth),
padding: const EdgeInsets.only(left: iconTotalWidth),
child: SelectableText(
message.formatUsage(),
style: TextStyle(
style: const TextStyle(
fontFamily: 'Baloo2',
fontSize: 14, // 16px equivalent
fontSize: 14,
// 16px equivalent
height: 1.0,
fontWeight: FontWeight.normal,
color: OrchidColors.purpleCaption,
Expand Down
19 changes: 19 additions & 0 deletions gai-frontend/lib/chat/chat_message.dart
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ class ChatMessage {
return 'tokens: $prompt in, $completion out';
}

// Clone this immutable object with new values for some fields
ChatMessage copyWith({
ChatMessageSource? source,
String? message,
Map<String, dynamic>? metadata,
String? sourceName,
String? modelId,
String? modelName,
}) {
return ChatMessage(
source: source ?? this.source,
message: message ?? this.message,
metadata: metadata ?? this.metadata,
sourceName: sourceName ?? this.sourceName,
modelId: modelId ?? this.modelId,
modelName: modelName ?? this.modelName,
);
}

@override
String toString() {
return 'ChatMessage(source: $source, modelId: $modelId, model: $modelName, msg: ${message.substring(0, message.length.clamp(0, 50))}...)';
Expand Down
5 changes: 0 additions & 5 deletions gai-frontend/lib/chat/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ class ModelInfo {
);
}

// Format a chat message for this model.
Map<String, String> formatMessage(ChatMessage message) {
return ModelAPI.formatMessage(message, id);
}

// Format chat messages for this model.
List<Map<String, String>> formatMessages(List<ChatMessage> messages) {
return ModelAPI.formatMessages(messages, id);
Expand Down
2 changes: 1 addition & 1 deletion gai-frontend/lib/chat/provider_connection.dart
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ChatInferenceResponse {
return ChatMessage(
source: ChatMessageSource.provider,
message: message,
// sourceName: request.modelId,
sourceName: request.modelId,
metadata: metadata,
modelId: request.modelId,
);
Expand Down
25 changes: 13 additions & 12 deletions gai-frontend/lib/chat/scripting/extensions/party_mode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ function onUserPrompt(userPrompt: string): void {
chatSystemMessage('Extension: Party mode invoked');
chatClientMessage(userPrompt)

// Gather messages of source type 'client' or 'provider', irrespective of the model
// (Same as getConversation(), doing this for illustration)
const filteredMessages = getChatHistory().filter(
(message) =>
message.source === ChatMessageSource.CLIENT ||
message.source === ChatMessageSource.PROVIDER
);

// Send to each user-selected model
// Send to each user-selected model in turn
for (const model of getUserSelectedModels()) {
// Gather all messages of source type 'client' or 'provider', irrespective of source model.
// (Doing this inside the loop allows the models to see the previous models responses.)
const filteredMessages = getChatHistory().filter(
(message) =>
message.source === ChatMessageSource.CLIENT ||
message.source === ChatMessageSource.PROVIDER
);

// Send the messages to the model and add the response to the chat
console.log(`party_mode: Sending messages to model: ${model.name}`);
await chatSendToModel(filteredMessages, model.id);
const response = await chatSendToModel(filteredMessages, model.id);
addChatMessage(response);
}
})();
}

}
13 changes: 13 additions & 0 deletions gai-frontend/lib/chat/scripting/extensions/party_mode_min.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Implement "party mode" chat command

/// Let the IDE see the types from the chat_scripting_api during development.
/// <reference path="../chat_scripting_api.ts" />

function onUserPrompt(userPrompt: string): void {
(async () => {
chatClientMessage(userPrompt)
for (const model of getUserSelectedModels()) {
addChatMessage(await chatSendToModel(getConversation(), model.id));
}
})();
}

0 comments on commit 09770fe

Please sign in to comment.