Skip to content

Commit

Permalink
Prompt API: Add multimodal input IDL skeleton
Browse files Browse the repository at this point in the history
Add IDL for new LanguageModel[Factory] API types, etc., per:
  webmachinelearning/prompt-api#71

API use with currently supported input types is unchanged.
API use with new input types throws TypeErrors for now.

Move create() WPTs into a new file as separate tests.

Bug: 385173789, 385173368
Change-Id: Id8ca1c8410f1a97bb7d28b4bc568020ff0412698
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6216647
Reviewed-by: Clark DuVall <[email protected]>
Commit-Queue: Mike Wasserman <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1414456}
  • Loading branch information
Mike Wasserman authored and Chromium LUCI CQ committed Feb 1, 2025
1 parent 913f9cb commit 259d706
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 114 deletions.
22 changes: 20 additions & 2 deletions third_party/blink/renderer/bindings/generated_in_modules.gni
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,16 @@ generated_dictionary_sources_in_modules = [
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_address_init.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_clone_options.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_clone_options.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_create_core_options.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_create_core_options.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_create_options.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_create_options.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt_line_dict.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt_line_dict.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_line_dict.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_line_dict.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_content_dict.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_content_dict.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_options.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_options.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_detector_create_options.cc",
Expand Down Expand Up @@ -1420,6 +1426,8 @@ generated_enumeration_sources_in_modules = [
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt_role.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_role.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_role.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_type.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_type.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_capability_availability.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_capability_availability.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ai_rewriter_format.cc",
Expand Down Expand Up @@ -3254,6 +3262,16 @@ generated_typedef_sources_in_modules = [
generated_union_sources_in_modules = [
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_adproperties_adpropertiessequence.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_adproperties_adpropertiessequence.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_string.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_string.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelinitialpromptlinedict_ailanguagemodelpromptcontentdict_string_string.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelinitialpromptlinedict_ailanguagemodelpromptcontentdict_string_string.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_ailanguagemodelpromptlinedict_string_string.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_ailanguagemodelpromptlinedict_string_string.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_ailanguagemodelpromptcontentdictorstringorailanguagemodelpromptlinedictorstringsequence_ailanguagemodelpromptlinedict_string_string.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_ailanguagemodelpromptcontentdictorstringorailanguagemodelpromptlinedictorstringsequence_ailanguagemodelpromptlinedict_string_string.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_arraybuffer_arraybufferview_audiobuffer_blob_htmlaudioelement_htmlcanvaselement_htmlimageelement_htmlvideoelement_imagebitmap_imagedata_offscreencanvas_svgimageelement_string_videoframe.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_arraybuffer_arraybufferview_audiobuffer_blob_htmlaudioelement_htmlcanvaselement_htmlimageelement_htmlvideoelement_imagebitmap_imagedata_offscreencanvas_svgimageelement_string_videoframe.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_arraybuffer_arraybufferview_blob_usvstring_writeparams.cc",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_arraybuffer_arraybufferview_blob_usvstring_writeparams.h",
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_union_arraybuffer_arraybufferview_jsonwebkey.cc",
Expand Down
34 changes: 25 additions & 9 deletions third_party/blink/renderer/modules/ai/ai_language_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-blink.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-blink.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_ailanguagemodelpromptcontentdictorstringorailanguagemodelpromptlinedictorstringsequence_ailanguagemodelpromptlinedict_string_string.h"
#include "third_party/blink/renderer/core/dom/abort_signal.h"
#include "third_party/blink/renderer/core/dom/events/event.h"
#include "third_party/blink/renderer/modules/ai/ai_language_model_factory.h"
Expand Down Expand Up @@ -182,24 +183,32 @@ ExecutionContext* AILanguageModel::GetExecutionContext() const {

ScriptPromise<IDLString> AILanguageModel::prompt(
ScriptState* script_state,
const WTF::String& input,
const V8AILanguageModelPromptInput* input,
const AILanguageModelPromptOptions* options,
ExceptionState& exception_state) {
if (!script_state->ContextIsValid()) {
ThrowInvalidContextException(exception_state);
return ScriptPromise<IDLString>();
}

ScriptPromiseResolver<IDLString>* resolver =
MakeGarbageCollected<ScriptPromiseResolver<IDLString>>(script_state);
auto promise = resolver->Promise();

// The API impl only accepts a string for now, more to come soon!
if (!input->IsString()) {
resolver->RejectWithTypeError("Input type not supported");
return promise;
}
const WTF::String& input_string = input->GetAsString();

base::UmaHistogramEnumeration(AIMetrics::GetAIAPIUsageMetricName(
AIMetrics::AISessionType::kLanguageModel),
AIMetrics::AIAPI::kSessionPrompt);

base::UmaHistogramCounts1M(AIMetrics::GetAISessionRequestSizeMetricName(
AIMetrics::AISessionType::kLanguageModel),
int(input.CharactersSizeInBytes()));
ScriptPromiseResolver<IDLString>* resolver =
MakeGarbageCollected<ScriptPromiseResolver<IDLString>>(script_state);
auto promise = resolver->Promise();
int(input_string.CharactersSizeInBytes()));

if (!language_model_remote_) {
ThrowSessionDestroyedException(exception_state);
Expand All @@ -219,27 +228,34 @@ ScriptPromise<IDLString> AILanguageModel::prompt(
WrapWeakPersistent(this)),
WTF::BindRepeating(&AILanguageModel::OnContextOverflow,
WrapWeakPersistent(this)));
language_model_remote_->Prompt(input, std::move(pending_remote));
language_model_remote_->Prompt(input_string, std::move(pending_remote));
return promise;
}

ReadableStream* AILanguageModel::promptStreaming(
ScriptState* script_state,
const WTF::String& input,
const V8AILanguageModelPromptInput* input,
const AILanguageModelPromptOptions* options,
ExceptionState& exception_state) {
if (!script_state->ContextIsValid()) {
ThrowInvalidContextException(exception_state);
return nullptr;
}

// The API impl only accepts a string for now, more to come soon!
if (!input->IsString()) {
exception_state.ThrowTypeError("Input type not supported");
return nullptr;
}
const WTF::String& input_string = input->GetAsString();

base::UmaHistogramEnumeration(AIMetrics::GetAIAPIUsageMetricName(
AIMetrics::AISessionType::kLanguageModel),
AIMetrics::AIAPI::kSessionPromptStreaming);

base::UmaHistogramCounts1M(AIMetrics::GetAISessionRequestSizeMetricName(
AIMetrics::AISessionType::kLanguageModel),
int(input.CharactersSizeInBytes()));
int(input_string.CharactersSizeInBytes()));

if (!language_model_remote_) {
ThrowSessionDestroyedException(exception_state);
Expand All @@ -259,7 +275,7 @@ ReadableStream* AILanguageModel::promptStreaming(
WrapWeakPersistent(this)),
WTF::BindRepeating(&AILanguageModel::OnContextOverflow,
WrapWeakPersistent(this)));
language_model_remote_->Prompt(input, std::move(pending_remote));
language_model_remote_->Prompt(input_string, std::move(pending_remote));
return readable_stream;
}

Expand Down
5 changes: 3 additions & 2 deletions third_party/blink/renderer/modules/ai/ai_language_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "third_party/blink/renderer/bindings/core/v8/script_promise.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_clone_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_typedefs.h"
#include "third_party/blink/renderer/core/dom/events/event_target.h"
#include "third_party/blink/renderer/core/event_type_names.h"
#include "third_party/blink/renderer/core/execution_context/execution_context_lifecycle_observer.h"
Expand Down Expand Up @@ -42,11 +43,11 @@ class AILanguageModel final : public EventTarget,

// ai_language_model.idl implementation.
ScriptPromise<IDLString> prompt(ScriptState* script_state,
const WTF::String& input,
const V8AILanguageModelPromptInput* input,
const AILanguageModelPromptOptions* options,
ExceptionState& exception_state);
ReadableStream* promptStreaming(ScriptState* script_state,
const WTF::String& input,
const V8AILanguageModelPromptInput* input,
const AILanguageModelPromptOptions* options,
ExceptionState& exception_state);
ScriptPromise<IDLUnsignedLongLong> countPromptTokens(
Expand Down
4 changes: 2 additions & 2 deletions third_party/blink/renderer/modules/ai/ai_language_model.idl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ interface AILanguageModel : EventTarget {
RaisesException
]
Promise<DOMString> prompt(
DOMString input,
AILanguageModelPromptInput input,
optional AILanguageModelPromptOptions options = {}
);
[
Expand All @@ -33,7 +33,7 @@ interface AILanguageModel : EventTarget {
RaisesException
]
ReadableStream promptStreaming(
DOMString input,
AILanguageModelPromptInput input,
optional AILanguageModelPromptOptions options = {}
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,54 @@

// https://github.com/webmachinelearning/prompt-api

dictionary AILanguageModelInitialPrompt {
// The argument to the prompt() method and others like it
typedef (AILanguageModelPromptLine or sequence<AILanguageModelPromptLine>) AILanguageModelPromptInput;

// Initial prompt lines
dictionary AILanguageModelInitialPromptLineDict {
required AILanguageModelInitialPromptRole role;
required DOMString content;
required AILanguageModelPromptContent content;
};
typedef (
DOMString // interpreted as { role: "user", content: { type: "text", data: providedValue } }
or AILanguageModelPromptContent // interpreted as { role: "user", content: providedValue }
or AILanguageModelInitialPromptLineDict // canonical form
) AILanguageModelInitialPromptLine;

dictionary AILanguageModelCreateOptions {
AbortSignal signal;
AICreateMonitorCallback monitor;
// Prompt lines
// Canonical (fully spelled-out) form of a prompt line: a role paired with
// its content. The AILanguageModelPromptLine typedef also accepts shorthand
// forms (a bare DOMString or a bare AILanguageModelPromptContent), which its
// inline comments describe as being interpreted with role "user".
dictionary AILanguageModelPromptLineDict {
// Who the line is attributed to; AILanguageModelPromptRole has no "system"
// value, unlike AILanguageModelInitialPromptRole.
required AILanguageModelPromptRole role;
// Either a plain string or an AILanguageModelPromptContentDict.
required AILanguageModelPromptContent content;
};
typedef (
DOMString // interpreted as { role: "user", content: { type: "text", data: providedValue } }
or AILanguageModelPromptContent // interpreted as { role: "user", content: providedValue }
or AILanguageModelPromptLineDict // canonical form
) AILanguageModelPromptLine;

// Prompt content inside the lines
// Typed prompt content: a type tag plus the data payload it describes.
// NOTE(review): nothing in this IDL ties a given `type` to a particular
// member of the `data` union; presumably the implementation validates the
// pairing - confirm.
dictionary AILanguageModelPromptContentDict {
// One of "text", "image" or "audio" (see AILanguageModelPromptType).
required AILanguageModelPromptType type;
// The payload; AILanguageModelPromptData is a union of image, buffer, audio
// and string sources.
required AILanguageModelPromptData data;
};
typedef (DOMString or AILanguageModelPromptContentDict) AILanguageModelPromptContent;
typedef (ImageBitmapSource or BufferSource or AudioBuffer or HTMLAudioElement or DOMString) AILanguageModelPromptData;
enum AILanguageModelPromptType { "text", "image", "audio" };

// Prompt roles inside the lines
enum AILanguageModelInitialPromptRole { "system", "user", "assistant" };
enum AILanguageModelPromptRole { "user", "assistant" };

// Base dictionary of creation options; AILanguageModelCreateOptions inherits
// from it and adds signal, monitor, systemPrompt and initialPrompts.
// All members are optional.
dictionary AILanguageModelCreateCoreOptions {
// [EnforceRange] makes out-of-range or non-finite values throw a TypeError
// during conversion instead of being wrapped or clamped.
// NOTE(review): topK and temperature are presumably sampling parameters -
// confirm against the Prompt API explainer.
[EnforceRange] unsigned long topK;
float temperature;
sequence<DOMString> expectedInputLanguages;
// Input modalities the caller expects to use ("text", "image", "audio").
sequence<AILanguageModelPromptType> expectedInputTypes;
};

dictionary AILanguageModelCreateOptions : AILanguageModelCreateCoreOptions {
AbortSignal signal;
AICreateMonitorCallback monitor;
DOMString systemPrompt;
sequence<AILanguageModelInitialPrompt> initialPrompts;
sequence<AILanguageModelInitialPromptLine> initialPrompts;
};
43 changes: 37 additions & 6 deletions third_party/blink/renderer/modules/ai/ai_language_model_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-blink.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_create_monitor_callback.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_create_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt_line_dict.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_initial_prompt_role.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ai_language_model_prompt_content_dict.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelinitialpromptlinedict_ailanguagemodelpromptcontentdict_string_string.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_union_ailanguagemodelpromptcontentdict_string.h"
#include "third_party/blink/renderer/core/events/progress_event.h"
#include "third_party/blink/renderer/core/execution_context/execution_context.h"
#include "third_party/blink/renderer/modules/ai/ai.h"
Expand Down Expand Up @@ -275,6 +278,11 @@ ScriptPromise<AILanguageModel> AILanguageModelFactory::create(
return promise;
}

// The API impl does not yet support expectedInputTypes, more to come soon!
// Return right after rejecting, like every other validation failure in
// create(); otherwise execution would fall through and keep processing the
// remaining options even though the promise has already been rejected.
if (options->hasExpectedInputTypes()) {
resolver->RejectWithTypeError("expectedInputTypes not supported");
return promise;
}

if (options->hasSystemPrompt()) {
system_prompt = options->systemPrompt();
}
Expand All @@ -286,7 +294,14 @@ ScriptPromise<AILanguageModel> AILanguageModelFactory::create(
// Only the first prompt might have a `system` role, so it's handled
// separately.
auto* first_prompt = prompts.begin()->Get();
if (first_prompt->role() ==
// The API impl only accepts a line dict for now, more to come soon!
if (!first_prompt->IsAILanguageModelInitialPromptLineDict()) {
resolver->RejectWithTypeError("Input type not supported");
return promise;
}
auto* first_prompt_dict =
first_prompt->GetAsAILanguageModelInitialPromptLineDict();
if (first_prompt_dict->role() ==
V8AILanguageModelInitialPromptRole::Enum::kSystem) {
if (options->hasSystemPrompt()) {
// If the system prompt cannot be provided both from system prompt
Expand All @@ -295,23 +310,39 @@ ScriptPromise<AILanguageModel> AILanguageModelFactory::create(
kExceptionMessageSystemPromptIsDefinedMultipleTimes);
return promise;
}
system_prompt = first_prompt->content();
// The API impl only accepts a string for now, more to come soon!
if (!first_prompt_dict->content()->IsString()) {
resolver->RejectWithTypeError("Input type not supported");
return promise;
}
system_prompt = first_prompt_dict->content()->GetAsString();
start_index++;
}
for (size_t index = start_index; index < prompts.size(); ++index) {
auto prompt = prompts[index];
if (prompt->role() ==
// The API impl only accepts a line dict for now, more to come soon!
if (!prompt->IsAILanguageModelInitialPromptLineDict()) {
resolver->RejectWithTypeError("Input type not supported");
return promise;
}
auto* dict = prompt->GetAsAILanguageModelInitialPromptLineDict();
if (dict->role() ==
V8AILanguageModelInitialPromptRole::Enum::kSystem) {
// If any prompt except the first one has a `system` role, reject
// with a `TypeError`.
resolver->RejectWithTypeError(
kExceptionMessageSystemPromptIsNotTheFirst);
return promise;
}
// The API impl only accepts a string for now, more to come soon!
if (!dict->content()->IsString()) {
resolver->RejectWithTypeError("Input type not supported");
return promise;
}
initial_prompts.push_back(
mojom::blink::AILanguageModelInitialPrompt::New(
AILanguageModelInitialPromptRole(prompt->role()),
prompt->content()));
AILanguageModelInitialPromptRole(dict->role()),
dict->content()->GetAsString()));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

// https://github.com/webmachinelearning/prompt-api

enum AILanguageModelInitialPromptRole { "system", "user", "assistant" };
enum AILanguageModelPromptRole { "user", "assistant" };

[
Exposed=(Window,Worker),
SecureContext,
Expand Down
Loading

0 comments on commit 259d706

Please sign in to comment.