diff --git a/README.md b/README.md
index 63e948e..15d2955 100644
--- a/README.md
+++ b/README.md
@@ -14,40 +14,82 @@ export OPENAI_API_KEY=
 2. Add the following snippet at the beginning of your script:
 
 ```nextflow
-include { prompt } from 'plugin/nf-gpt'
+include { gptPromptForText } from 'plugin/nf-gpt'
 ```
 
-3. Use the `prompt` operator to perform a ChatGPT query and collect teh result to a map object having the schema
-of your choice, e.g.
+3. Use the `gptPromptForText` function to perform a ChatGPT prompt and get the response.
 
 ```
-include { prompt } from 'plugin/nf-gpt'
-
-def text = '''
-Extract information about a person from In 1968, amidst the fading echoes of Independence Day,
-a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe,
-marked the start of a new journey.
-'''
-
-channel
-    .of(text)
-    .prompt(schema: [firstName: 'string', lastName: 'string', birthDate: 'date (YYYY-MM-DD)'])
-    .view()
+include { gptPromptForText } from 'plugin/nf-gpt'
+
+println gptPromptForText('Tell me a joke')
+
 ```
 
-4. run using nextflow as usual
+4. Run using Nextflow as usual
 
 ```
 nextflow run 
 ```
 
-### Other example
+5. See the folder [examples] for more examples.
+
+
+## Reference
+
+### Function `gptPromptForText`
+
+The `gptPromptForText` function carries out a GPT chat prompt and returns the response message as a string. Example:
+
+
+```nextflow
+println gptPromptForText('Tell me a joke')
+```
+
+
+When the option `numOfChoices` is specified, the response is a list of strings.
+
+```nextflow
+def response = gptPromptForText('Tell me a joke', numOfChoices: 3)
+for( String it : response )
+    println it
+```
+
+Available options:
 
-See the folder [examples] for more examples
 
-### Options
+| name          | description |
+|---------------|-------------|
+| logitBias     | Accepts an object mapping each token (specified by its token ID in the tokenizer) to an associated bias value from -100 to 100 |
+| model         | The AI model to be used (default: `gpt-3.5-turbo`) |
+| maxTokens     | The maximum number of tokens that can be generated in the chat completion |
+| numOfChoices  | How many chat completion choices to generate for each input message (default: 1) |
+| temperature   | What sampling temperature to use, between 0 and 2 (default: `0.7`) |
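+
+For instance, several of the options above can be combined in a single call. This is a minimal sketch;
+`gpt-4` is just an assumed model name, use any model available to your account:
+
+```nextflow
+// ask for a short, more deterministic answer (low temperature, capped tokens)
+def answer = gptPromptForText('Tell me a joke', model: 'gpt-4', temperature: 0.2d, maxTokens: 100)
+println answer
+```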
+
+
+### Function `gptPromptForData`
+
+The `gptPromptForData` function carries out a GPT chat prompt and returns the response as a list of
+objects matching the specified schema. For example:
+
+```nextflow
+
+def query = '''
+    Extract information about a person from In 1968, amidst the fading echoes of Independence Day,
+    a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe,
+    marked the start of a new journey.
+    '''
+
+def response = gptPromptForData(query, schema: [firstName: 'string', lastName: 'string', birthDate: 'date (YYYY-MM-DD)'])
+
+println "First name: ${response[0].firstName}"
+println "Last name: ${response[0].lastName}"
+println "Birth date: ${response[0].birthDate}"
+```
+
+
+The following options are available:
 
-The `prompt` operator support those options
 
 | name          | description |
 |---------------|-------------|
@@ -56,6 +98,7 @@ The `prompt` operator support those options
 | schema        | The expected structure of the result object, represented as a map object in which the key represents the attribute name and the value the attribute type |
 | temperature   | What sampling temperature to use, between 0 and 2 (default: `0.7`) |
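+
+For instance, the `model` and `temperature` options can be combined with a schema. A minimal sketch,
+assuming the `gpt-4` model is available to your account:
+
+```nextflow
+def records = gptPromptForData(
+    'Who won the 100m sprint at the last three olympic games?',
+    schema: [athlete: 'string', year: 'number'],
+    model: 'gpt-4',
+    temperature: 0.5d)
+records.each { println it }
+```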
+
 
 ### Configuration file
 
 The following config options can be specified in the `nextflow.config` file:
 
@@ -70,7 +113,7 @@ The following config options can be specified in the `nextflow.config` file:
 | gpt.temperature  | What sampling temperature to use, between 0 and 2 (default: `0.7`) |
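+
+For example, a `nextflow.config` snippet putting these options together might look like the following.
+This is a sketch: the `gpt.retryPolicy` setting mirrors the plugin's retry options and is an assumption,
+while the other settings are documented in the table above:
+
+```groovy
+gpt {
+    model       = 'gpt-4'
+    temperature = 0.5
+    retryPolicy = [maxAttempts: 5, delay: '500ms']
+}
+```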
 
 
-## Testing and debugging
+## Development
 
 To build and test the plugin during development, configure a local Nextflow build with the following steps:
 
@@ -96,7 +139,7 @@ To build and test the plugin during development, configure a local Nextflow buil
    ./launch.sh run nextflow-io/hello -plugins nf-gpt
    ```
 
-## Testing without Nextflow build
+### Testing without Nextflow build
 
 The plugin can be tested without using a local Nextflow build using the following steps:
 
@@ -104,7 +147,7 @@ The plugin can be tested without using a local Nextflow build using the followin
 2. Copy `build/plugins/` to `$HOME/.nextflow/plugins`
 3. Create a pipeline that uses your plugin and run it: `nextflow run ./my-pipeline-script.nf`
 
-## Package, upload, and publish
+### Package, upload, and publish
 
 The project should be hosted in a GitHub repository whose name matches the name of the plugin,
 that is the name of the directory in the `plugins` folder (e.g. `nf-gpt`).
diff --git a/examples/example1.nf b/examples/example1.nf
index c23d095..e793480 100644
--- a/examples/example1.nf
+++ b/examples/example1.nf
@@ -1,12 +1,10 @@
-include { prompt } from 'plugin/nf-gpt'
+include { gptPromptForText } from 'plugin/nf-gpt'
 
-def text = '''
-Extract information about a person from In 1968, amidst the fading echoes of Independence Day,
-a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe,
-marked the start of a new journey.
-'''
+/*
+ * This example shows how to use the `gptPromptForText` function in the `map` operator
+ */
 
 channel
-    .of(text)
-    .prompt(schema: [firstName: 'string', lastName: 'string', birthDate: 'date (YYYY-MM-DD)'])
+    .of('Tell me a joke')
+    .map { gptPromptForText(it) }
     .view()
diff --git a/examples/example2.nf b/examples/example2.nf
index 308f687..91eb7d6 100644
--- a/examples/example2.nf
+++ b/examples/example2.nf
@@ -1,9 +1,18 @@
-include { prompt } from 'plugin/nf-gpt'
+include { gptPromptForText } from 'plugin/nf-gpt'
 
-def query = '''
-Who won most gold medals in swimming and Athletics categories during Barcelona 1992 and London 2012 olympic games?"
-'''
+/*
+ * This example shows how to use the `gptPromptForText` function in a process
+ */
 
-channel .of(query)
-    .prompt(schema: [athlete: 'string', numberOfMedals: 'number', location:'string', sport:'string'])
-    .view()
+process prompt {
+    input:
+    val query
+    output:
+    val response
+    exec:
+    response = gptPromptForText(query)
+}
+
+workflow {
+    prompt('Tell me a joke') | view
+}
diff --git a/examples/example3.nf b/examples/example3.nf
index c0b6bd5..a39cd8f 100644
--- a/examples/example3.nf
+++ b/examples/example3.nf
@@ -1,9 +1,16 @@
-include { prompt } from 'plugin/nf-gpt'
+include { gptPromptForData } from 'plugin/nf-gpt'
+
+/**
+ * This example shows how to perform a GPT prompt and map the response to a structured object
+ */
+
+def text = '''
+Extract information about a person from In 1968, amidst the fading echoes of Independence Day,
+a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe,
+marked the start of a new journey.
+'''
 
 channel
-    .fromList(['Barcelona, 1992', 'London, 2012'])
-    .combine(['Swimming', 'Athletics'])
-    .prompt(schema: [athlete: 'string', numberOfMedals: 'number', location: 'string', sport: 'string']) { edition, sport ->
-        "Who won most gold medals in $sport category during $edition olympic games?"
-    }
-    .view()
+    .of(text)
+    .flatMap { gptPromptForData(it, schema: [firstName: 'string', lastName: 'string', birthDate: 'date (YYYY-MM-DD)']) }
+    .view()
diff --git a/examples/example4.nf b/examples/example4.nf
new file mode 100644
index 0000000..7fddd4f
--- /dev/null
+++ b/examples/example4.nf
@@ -0,0 +1,16 @@
+include { gptPromptForData } from 'plugin/nf-gpt'
+
+/**
+ * This example shows how to perform a GPT prompt and map the response to a structured object
+ */
+
+
+def query = '''
+Who won most gold medals in swimming and Athletics categories during Barcelona 1992 and London 2012 olympic games?
+'''
+
+def RECORD = [athlete: 'string', numberOfMedals: 'number', location:'string', sport:'string']
+
+channel .of(query)
+    .flatMap { gptPromptForData(it, schema:RECORD, temperature: 2d) }
+    .view()
diff --git a/examples/example5.nf b/examples/example5.nf
new file mode 100644
index 0000000..61c5118
--- /dev/null
+++ b/examples/example5.nf
@@ -0,0 +1,16 @@
+include { gptPromptForData } from 'plugin/nf-gpt'
+
+/**
+ * This example shows how to perform multiple GPT prompts using the combine and flatMap operators
+ */
+
+
+channel
+    .fromList(['Barcelona, 1992', 'London, 2012'])
+    .combine(['Swimming', 'Athletics'])
+    .flatMap { edition, sport ->
+        gptPromptForData(
+            "Who won most gold medals in $sport category during $edition olympic games?",
+            schema: [athlete: 'string', numberOfMedals: 'number', location: 'string', sport: 'string'])
+    }
+    .view()
diff --git a/plugins/nf-gpt/build.gradle b/plugins/nf-gpt/build.gradle
index 8e6dcc9..c68e7db 100644
--- a/plugins/nf-gpt/build.gradle
+++ b/plugins/nf-gpt/build.gradle
@@ -59,7 +59,7 @@ dependencies {
     compileOnly 'org.slf4j:slf4j-api:1.7.10'
     compileOnly 'org.pf4j:pf4j:3.4.1'
     // add here plugin dependencies
-    api 'dev.langchain4j:langchain4j-open-ai:0.27.1'
+    api 'dev.langchain4j:langchain4j-open-ai:0.28.0'
 
     // test configuration
     testImplementation "org.apache.groovy:groovy:4.0.20"
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/client/GptChatCompletionRequest.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/client/GptChatCompletionRequest.groovy
new file mode 100644
index 0000000..3c71349
--- /dev/null
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/client/GptChatCompletionRequest.groovy
@@ -0,0 +1,127 @@
+/*
+ * Copyright 2013-2024, Seqera Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package nextflow.gpt.client
+
+import groovy.transform.Canonical
+import groovy.transform.CompileStatic
+import groovy.transform.ToString
+
+/**
+ * Model a GPT chat conversation create request object.
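+ *
+ * A minimal usage sketch (the field values below are illustrative only):
+ *
+ * <pre>
+ * def req = new GptChatCompletionRequest(
+ *         model: 'gpt-3.5-turbo',
+ *         messages: [new Message(role: 'user', content: 'Tell me a joke')],
+ *         n: 1)
+ * </pre>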
+ *
+ * See also
+ * https://platform.openai.com/docs/api-reference/chat/create
+ *
+ * @author Paolo Di Tommaso
+ */
+@CompileStatic
+@ToString(includePackage = false, includeNames = true)
+class GptChatCompletionRequest {
+
+    @ToString(includePackage = false, includeNames = true)
+    @CompileStatic
+    static class Message {
+        String role
+        String content
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    @CompileStatic
+    static class ToolMessage extends Message {
+        String name
+        String tool_call_id
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    @CompileStatic
+    static class Tool {
+        String type
+        Function function
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    @CompileStatic
+    static class Function {
+        String name
+        String description
+        Parameters parameters
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    @CompileStatic
+    static class Parameters {
+        String type
+        Map<String,Param> properties
+        List<String> required
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    @CompileStatic
+    static class Param {
+        String type
+        String description
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    @CompileStatic
+    @Canonical
+    static class ResponseFormat {
+        static final ResponseFormat TEXT = new ResponseFormat('text')
+        static final ResponseFormat JSON = new ResponseFormat('json_object')
+        final String type
+    }
+
+    /**
+     * ID of the model to use.
+     */
+    String model
+
+    /**
+     * A list of messages comprising the conversation so far
+     */
+    List<Message> messages
+
+    /**
+     * A list of tools the model may call
+     */
+    List<Tool> tools
+
+    String tool_choice
+
+    /**
+     * The maximum number of tokens that can be generated in the chat completion
+     */
+    Integer max_tokens
+
+    /**
+     * How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices
+     */
+    Integer n
+
+    /**
+     * What sampling temperature to use, between 0 and 2
+     */
+    Float temperature
+
+    /**
+     * Modify the likelihood of specified tokens appearing in the completion
+     */
+    Map<String,Integer> logit_bias
+
+    /**
+     * Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
+     */
+    ResponseFormat response_format
+}
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/client/GptChatCompletionResponse.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/client/GptChatCompletionResponse.groovy
new file mode 100644
index 0000000..efc7ed1
--- /dev/null
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/client/GptChatCompletionResponse.groovy
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2013-2024, Seqera Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package nextflow.gpt.client
+
+import groovy.transform.CompileStatic
+import groovy.transform.ToString
+/**
+ * Model the GPT chat conversation response object
+ *
+ * See also
+ * https://platform.openai.com/docs/api-reference/chat/object
+ *
+ * @author Paolo Di Tommaso
+ */
+@CompileStatic
+@ToString(includePackage = false, includeNames = true)
+class GptChatCompletionResponse {
+
+    @ToString(includePackage = false, includeNames = true)
+    static class Choice {
+        String finish_reason
+        Integer index
+        Message message
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    static class Message {
+        String role
+        String content
+        List<ToolCall> tool_calls
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    static class ToolCall {
+        String id
+        String type
+        Function function
+    }
+
+    @ToString(includePackage = false, includeNames = true)
+    static class Function {
+        String name
+        String arguments
+    }
+
+    String id
+    String object
+    Long created
+    String model
+    List<Choice> choices
+}
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/client/GptClient.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/client/GptClient.groovy
new file mode 100644
index 0000000..f293ad9
--- /dev/null
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/client/GptClient.groovy
@@ -0,0 +1,156 @@
+/*
+ * Copyright 2013-2024, Seqera Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package nextflow.gpt.client
+
+import java.net.http.HttpClient
+import java.net.http.HttpRequest
+import java.net.http.HttpResponse
+import java.time.temporal.ChronoUnit
+import java.util.concurrent.Executors
+import java.util.function.Predicate
+
+import com.google.gson.Gson
+import com.google.gson.reflect.TypeToken
+import dev.failsafe.Failsafe
+import dev.failsafe.RetryPolicy
+import dev.failsafe.event.EventListener
+import dev.failsafe.event.ExecutionAttemptedEvent
+import dev.failsafe.function.CheckedSupplier
+import groovy.transform.CompileStatic
+import groovy.transform.Memoized
+import groovy.util.logging.Slf4j
+import nextflow.gpt.config.GptConfig
+import nextflow.util.Threads
+/**
+ * HTTP client for GPT-based API conversations
+ *
+ * @author Paolo Di Tommaso
+ */
+@Slf4j
+@CompileStatic
+class GptClient {
+
+    final private String endpoint
+    final private GptConfig config
+    private HttpClient httpClient
+
+    @Memoized
+    static GptClient client(GptConfig config) {
+        new GptClient(config)
+    }
+
+    static GptClient client() {
+        return client(GptConfig.config())
+    }
+
+    /**
+     * Only for testing
+     */
+    protected GptClient() { }
+
+    protected GptClient(GptConfig config) {
+        this.config = config
+        this.endpoint = config.endpoint()
+        // create http client
+        this.httpClient = newHttpClient()
+    }
+
+    protected HttpClient newHttpClient() {
+        final builder = HttpClient.newBuilder()
+                .version(HttpClient.Version.HTTP_1_1)
+                .followRedirects(HttpClient.Redirect.NEVER)
+        // use virtual threads executor if enabled
+        if( Threads.useVirtual() )
+            builder.executor(Executors.newVirtualThreadPerTaskExecutor())
+        // build and return the new client
+        return builder.build()
+    }
+
+    GptChatCompletionResponse sendRequest(GptChatCompletionRequest request) {
+        return sendRequest0(request, 1)
+    }
+
+    GptChatCompletionResponse sendRequest0(GptChatCompletionRequest request, int attempt) {
+        assert endpoint, 'Missing ChatGPT endpoint'
+        assert !endpoint.endsWith('/'), "Endpoint url must not end with a slash - offending value: $endpoint"
+        assert config.apiKey(), "Missing ChatGPT API key"
+
+        final body = new Gson().toJson(request)
+        final uri = URI.create("${endpoint}/v1/chat/completions")
+        log.debug "ChatGPT request: $uri; attempt=$attempt - request: $body"
+        final req = HttpRequest.newBuilder()
+                .uri(uri)
+                .headers('Content-Type','application/json')
+                .headers('Authorization', "Bearer ${config.apiKey()}")
+                .POST(HttpRequest.BodyPublishers.ofString(body))
+                .build()
+
+        try {
+            final resp = httpSend(req)
+            log.debug "ChatGPT response: statusCode=${resp.statusCode()}; body=${resp.body()}"
+            if( resp.statusCode()==200 )
+                return jsonToCompletionResponse(resp.body())
+            else
+                throw new IllegalStateException("ChatGPT unexpected response: [${resp.statusCode()}] ${resp.body()}")
+        }
+        catch (IOException e) {
+            throw new IllegalStateException("Unable to connect ChatGPT service: $endpoint", e)
+        }
+    }
+
+    protected GptChatCompletionResponse jsonToCompletionResponse(String json) {
+        final type = new TypeToken<GptChatCompletionResponse>(){}.getType()
+        return new Gson().fromJson(json, type)
+    }
+
+    protected <T> RetryPolicy<T> retryPolicy(Predicate<? extends Throwable> cond, Predicate<T> handle) {
+        final cfg = config.retryOpts()
+        final listener = new EventListener<ExecutionAttemptedEvent<T>>() {
+            @Override
+            void accept(ExecutionAttemptedEvent<T> event) throws Throwable {
+                def msg = "GPT connection failure - attempt: ${event.attemptCount}"
+                if( event.lastResult!=null )
+                    msg += "; response: ${event.lastResult}"
+                if( event.lastFailure != null )
+                    msg += "; exception: [${event.lastFailure.class.name}] ${event.lastFailure.message}"
+                log.debug(msg)
+            }
+        }
+        return RetryPolicy.<T>builder()
+                .handleIf(cond)
+                .handleResultIf(handle)
+                .withBackoff(cfg.delay.toMillis(), cfg.maxDelay.toMillis(), ChronoUnit.MILLIS)
+                .withMaxAttempts(cfg.maxAttempts)
+                .withJitter(cfg.jitter)
+                .onRetry(listener)
+                .build()
+    }
+
+    protected HttpResponse<String> safeApply(CheckedSupplier<HttpResponse<String>> action) {
+        final retryOnException = (e -> e instanceof IOException) as Predicate<? extends Throwable>
+        final retryOnStatusCode = ((HttpResponse<String> resp) -> resp.statusCode() in SERVER_ERRORS) as Predicate<HttpResponse<String>>
+        final policy = retryPolicy(retryOnException, retryOnStatusCode)
+        return Failsafe.with(policy).get(action)
+    }
+
+    static private final List<Integer> SERVER_ERRORS = [429,500,502,503,504]
+
+    protected HttpResponse<String> httpSend(HttpRequest req) {
+        return safeApply(() -> httpClient.send(req, HttpResponse.BodyHandlers.ofString()))
+    }
+}
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/config/GptConfig.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/config/GptConfig.groovy
index 43819c0..7412c70 100644
--- a/plugins/nf-gpt/src/main/nextflow/gpt/config/GptConfig.groovy
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/config/GptConfig.groovy
@@ -19,6 +19,7 @@ package nextflow.gpt.config
 
 import groovy.transform.CompileStatic
 import groovy.transform.ToString
+import nextflow.Global
 import nextflow.Session
 import nextflow.SysEnv
 /**
@@ -39,9 +40,14 @@ class GptConfig {
     private String model
     private Double temperature
     private Integer maxTokens
+    private GptRetryOpts retryOpts
 
     static GptConfig config(Session session) {
-        new GptConfig(session.config.ai as Map ?: Collections.emptyMap(), SysEnv.get())
+        new GptConfig(session.config.gpt as Map ?: Collections.emptyMap(), SysEnv.get())
+    }
+
+    static GptConfig config() {
+        config(Global.session as Session)
     }
 
     GptConfig(Map opts, Map env) {
@@ -49,6 +55,7 @@ class GptConfig {
         this.model = opts.model ?: DEFAULT_MODEL
         this.apiKey = opts.apiKey ?: env.get('OPENAI_API_KEY')
         this.temperature = opts.temperature!=null ? opts.temperature as Double : DEFAULT_TEMPERATURE
+        this.retryOpts = new GptRetryOpts( opts.retryPolicy as Map ?: Map.of() )
     }
 
     String endpoint() {
@@ -70,4 +77,8 @@ class GptConfig {
     Integer maxTokens() {
         return maxTokens
     }
+
+    GptRetryOpts retryOpts() {
+        return retryOpts
+    }
 }
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/config/GptRetryOpts.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/config/GptRetryOpts.groovy
new file mode 100644
index 0000000..94a595c
--- /dev/null
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/config/GptRetryOpts.groovy
@@ -0,0 +1,29 @@
+package nextflow.gpt.config
+
+import groovy.transform.CompileStatic
+import groovy.transform.ToString
+import nextflow.util.Duration
+
+@ToString(includeNames = true, includePackage = false)
+@CompileStatic
+class GptRetryOpts {
+    Duration delay = Duration.of('450ms')
+    Duration maxDelay = Duration.of('90s')
+    int maxAttempts = 10
+    double jitter = 0.25
+
+    GptRetryOpts() {
+        this(Collections.emptyMap())
+    }
+
+    GptRetryOpts(Map config) {
+        if( config.delay )
+            delay = config.delay as Duration
+        if( config.maxDelay )
+            maxDelay = config.maxDelay as Duration
+        if( config.maxAttempts )
+            maxAttempts = config.maxAttempts as int
+        if( config.jitter )
+            jitter = config.jitter as double
+    }
+}
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptHelper.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptHelper.groovy
new file mode 100644
index 0000000..b2f4151
--- /dev/null
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptHelper.groovy
@@ -0,0 +1,95 @@
+package nextflow.gpt.prompt
+
+import dev.langchain4j.data.message.AiMessage
+import dev.langchain4j.data.message.ChatMessage
+import dev.langchain4j.data.message.SystemMessage
+import dev.langchain4j.data.message.UserMessage
+
+/**
+ * Helper methods for GPT conversation
+ *
+ * @author Paolo Di Tommaso
+ */
+class GptHelper {
+
+    static protected String renderSchema(Map schema) {
+        return 'You must answer strictly in the following JSON format: {"result": [' + schema0(schema) + '] }'
+    }
+
+    static protected String schema0(Object schema) {
+        if( schema instanceof List ) {
+            return "[" + (schema as List).collect(it -> schema0(it)).join(', ') + "]"
+        }
+        else if( schema instanceof Map ) {
+            return "{" + (schema as Map).collect( it -> "\"$it.key\": " + schema0(it.value) ).join(', ') + "}"
+        }
+        else if( schema instanceof CharSequence ) {
+            return "(type: $schema)"
+        }
+        else if( schema != null )
+            throw new IllegalArgumentException("Unexpected data type: ${schema.getClass().getName()}")
+        else
+            throw new IllegalArgumentException("Data structure cannot be null")
+    }
+
+    static protected List<Map<String,Object>> decodeResponse(Object response, Map schema) {
+        final result = decodeResponse0(response,schema)
+        if( !result )
+            throw new IllegalArgumentException("Response does not match expected schema: $schema - Offending value: $response")
+        return result
+    }
+
+    static protected List<Map<String,Object>> decodeResponse0(Object response, Map schema) {
+        final expected = schema.keySet()
+        if( response instanceof Map ) {
+            if( response.keySet()==expected ) {
+                return List.of(response as Map<String,Object>)
+            }
+            if( isIndexMap(response, schema) ) {
+                return new ArrayList<Map<String,Object>>(response.values() as Collection<Map<String,Object>>)
+            }
+            if( response.size()==1 ) {
+                return decodeResponse(response.values().first(), schema)
+            }
+        }
+
+        if( response instanceof List ) {
+            final it = (response as List).first()
+            if( it instanceof Map && it.keySet()==expected )
+                return response as List<Map<String,Object>>
+        }
+        return null
+    }
+
+    static protected boolean isIndexMap(Map response, Map schema) {
+        final keys = response.keySet()
+        // check all keys are integers e.g. 0, 1, 2
+        if( keys.every(it-> it.toString().isInteger() ) ) {
+            // take the first and check the object matches the schema
+            final it = response.values().first()
+            return it instanceof Map && it.keySet()==schema.keySet()
+        }
+        return false
+    }
+
+    static List<ChatMessage> messageToChat(List<Map<String,String>> messages) {
+        if( !messages )
+            throw new IllegalArgumentException("Missing 'messages' argument")
+        final result = new ArrayList<ChatMessage>()
+        for( Map<String,String> it : messages ) {
+            if( !it.role )
+                throw new IllegalArgumentException("Missing 'role' attribute - offending message: $messages")
+            if( !it.content )
+                throw new IllegalArgumentException("Missing 'content' attribute - offending message: $messages")
+            final msg = switch (it.role) {
+                case 'user' -> UserMessage.from(it.content)
+                case 'system' -> SystemMessage.from(it.content)
+                case 'ai' -> AiMessage.from(it.content)
+                default -> throw new IllegalArgumentException("Unsupported message role '${it.role}' - offending message: $messages")
+            }
+            result.add(msg)
+        }
+        return result
+    }
+
+}
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy
index c2a44d8..fa6eecb 100644
--- a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy
@@ -17,6 +17,7 @@
 
 package nextflow.gpt.prompt
 
+
 import static nextflow.util.CheckHelper.*
 
 import groovy.transform.CompileStatic
@@ -26,17 +27,22 @@ import nextflow.Channel
 import nextflow.Session
 import nextflow.extension.CH
 import nextflow.extension.DataflowHelper
+import nextflow.gpt.client.GptChatCompletionRequest
+import nextflow.gpt.client.GptClient
+import nextflow.gpt.config.GptConfig
 import nextflow.plugin.extension.Factory
+import nextflow.plugin.extension.Function
 import nextflow.plugin.extension.Operator
 import nextflow.plugin.extension.PluginExtensionPoint
 /**
+ * Implements GPT Chat extension methods
  *
  * @author Paolo Di Tommaso
  */
 @CompileStatic
 class GptPromptExtension extends PluginExtensionPoint {
 
-    static final private Map VALID_PROMPT_OPTS = [
+    static final private Map VALID_PROMPT_DATA_OPTS = [
             model: String,
             schema: Map,
             debug: Boolean,
@@ -44,6 +50,15 @@ class GptPromptExtension extends PluginExtensionPoint {
             maxTokens: Integer
     ]
 
+    static final private Map VALID_PROMPT_TEXT_OPTS = [
+            model: String,
+            debug: Boolean,
+            temperature: Double,
+            maxTokens: Integer,
+            numOfChoices: Integer,
+            logitBias: Map
+    ]
+
     private Session session
 
     @Override
@@ -54,7 +69,7 @@ class GptPromptExtension extends PluginExtensionPoint {
     @Factory
     DataflowWriteChannel fromPrompt(Map opts, String query) {
         // check params
-        checkParams( 'fromPrompt', opts, VALID_PROMPT_OPTS )
+        checkParams( 'fromPrompt', opts, VALID_PROMPT_DATA_OPTS )
         if( opts.schema == null )
             throw new IllegalArgumentException("Missing prompt schema")
         // create the client
@@ -63,6 +78,7 @@ class GptPromptExtension extends PluginExtensionPoint {
                 .withDebug(opts.debug as Boolean)
                 .withTemperature(opts.temperature as Double)
                 .withMaxToken(opts.maxTokens as Integer)
+                .withJsonResponseFormat()
                 .build()
         // run the prompt
         final response = ai.prompt(query, opts.schema as Map)
@@ -79,7 +95,7 @@ class GptPromptExtension extends PluginExtensionPoint {
     @Operator
     DataflowWriteChannel prompt(DataflowReadChannel source, Map opts, Closure template) {
         // check params
-        checkParams( 'prompt', opts, VALID_PROMPT_OPTS )
+        checkParams( 'prompt', opts, VALID_PROMPT_DATA_OPTS )
         if( opts.schema == null )
             throw new IllegalArgumentException("Missing prompt schema")
         // create the client
@@ -88,6 +104,7 @@ class GptPromptExtension extends PluginExtensionPoint {
                 .withDebug(opts.debug as Boolean)
                 .withTemperature(opts.temperature as Double)
                 .withMaxToken(opts.maxTokens as Integer)
+                .withJsonResponseFormat()
                 .build()
 
         final target = CH.createBy(source)
@@ -105,4 +122,75 @@ class GptPromptExtension extends PluginExtensionPoint {
             target.bind(it)
         }
     }
+
+    @Function
+    List<Map<String,Object>> gptPromptForData(Map opts, CharSequence query) {
+        // check params
+        checkParams( 'gptPromptForData', opts, VALID_PROMPT_DATA_OPTS )
+        if( opts.schema == null )
+            throw new IllegalArgumentException("Missing prompt schema")
+        // create the client
+        final ai = new GptPromptModel(session)
+                .withModel(opts.model as String)
+                .withDebug(opts.debug as Boolean)
+                .withTemperature(opts.temperature as Double)
+                .withMaxToken(opts.maxTokens as Integer)
+                .withJsonResponseFormat()
+                .build()
+
+        return ai.prompt(query.toString(), opts.schema as Map)
+    }
+
+    /**
+     * Carry out a GPT text prompt with a single user message
+     *
+     * @param opts
+     *      Hold the prompt options
+     * @param message
+     *      The prompt message content
+     * @return
+     *      The response content as a string or a list of strings when the {@code numOfChoices} option is specified
+     */
+    @Function
+    Object gptPromptForText(Map opts=Map.of(), String message) {
+        gptPromptForText(opts, List.of(Map.of('role','user', 'content',message)))
+    }
+
+    /**
+     * Carry out a GPT text prompt providing one or more messages
+     *
+     * @param opts
+     *      Hold the prompt options
+     * @param messages
+     *      Hold the messages to carry out the prompt, provided as a list of key-value pairs, where the key represents
+     *      the message role and the value the message content e.g.
+     *      {@code [ [role: 'system', content: 'You should act as a good guy'], [role: 'user', content: 'Tell me a joke'] ]}
+     * @return
+     *      The response content as a string or a list of strings when the {@code numOfChoices} option is specified
+     */
+    @Function
+    Object gptPromptForText(Map opts=Map.of(), List<Map<String,String>> messages) {
+        // check params
+        checkParams( 'gptPromptForText', opts, VALID_PROMPT_TEXT_OPTS )
+
+        final config = GptConfig.config(session)
+        final client = GptClient.client(config)
+        final model = opts.model ?: config.model()
+        final numOfChoices = opts.numOfChoices as Integer ?: 1
+        final temperature = opts.temperature as Double ?: config.temperature()
+        final msg = messages.collect ((Map<String,String> it)-> new GptChatCompletionRequest.Message(role:it.role, content:it.content))
+        final request = new GptChatCompletionRequest(
+                model: model,
+                temperature: temperature,
+                messages: msg,
+                n: numOfChoices,
+                max_tokens: opts.maxTokens as Integer,
+                logit_bias: opts.logitBias as Map<String,Integer>
+        )
+        final resp = client.sendRequest(request)
+        return opts.numOfChoices==null
+                ? resp.choices.get(0).message.content
+                : resp.choices.collect(it-> it.message.content)
+    }
+
 }
diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy
index 265d8c0..7298184 100644
--- a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy
+++ b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy
@@ -27,6 +27,9 @@ import groovy.util.logging.Slf4j
 import nextflow.Session
 import nextflow.gpt.config.GptConfig
 import nextflow.util.StringUtils
+
+import static nextflow.gpt.prompt.GptHelper.*
+
 /**
  * Simple AI client for OpenAI model
  *
@@ -36,6 +39,8 @@ import nextflow.util.StringUtils
 @CompileStatic
 class GptPromptModel {
 
+    private static final String JSON_OBJECT = "json_object"
+
     private GptConfig config
     private OpenAiChatModel client
 
@@ -43,6 +48,7 @@ class GptPromptModel {
     private boolean debug
     private Double temperature
     private Integer maxTokens
+    private String responseFormat
 
     GptPromptModel(Session session) {
         this.config = GptConfig.config(session)
@@ -68,6 +74,16 @@ class GptPromptModel {
         return this
     }
 
+    GptPromptModel withResponseFormat(String format) {
+        this.responseFormat = format
+        return this
+    }
+
+    GptPromptModel withJsonResponseFormat() {
+        this.responseFormat = JSON_OBJECT
+        return this
+    }
+
     GptPromptModel build() {
         final modelName = model ?: config.model()
         final temperature = this.temperature ?: config.temperature()
@@ -80,7 +96,7 @@ class GptPromptModel {
                 .logResponses(debug)
                 .temperature(temperature)
                 .maxTokens(tokens)
-                .responseFormat("json_object")
+                .responseFormat(responseFormat)
                 .build();
         return this
     }
@@ -88,6 +104,10 @@ class GptPromptModel {
     List<Map<String,Object>> prompt(List<ChatMessage> messages, Map schema) {
         if( !messages )
             throw new IllegalArgumentException("Missing AI prompt")
+        if( !schema )
+            throw new IllegalArgumentException("Missing AI prompt schema")
+        if( responseFormat!=JSON_OBJECT )
+            throw new IllegalStateException("AI prompt requires json_object response format")
         final all = new ArrayList<ChatMessage>(messages)
         all.add(SystemMessage.from(renderSchema(schema)))
         if( debug )
@@ -105,63 +125,9 @@ class GptPromptModel {
         return prompt(List.of(msg), schema)
     }
 
-    static protected String renderSchema(Map schema) {
-        return 'You must answer strictly in the following JSON format: {"result": [' + schema0(schema) + '] }'
-    }
-
-    static protected String schema0(Object schema) {
-        if( schema instanceof List ) {
-            return "[" + (schema as List).collect(it -> schema0(it)).join(', ') + "]"
-        }
-        else if( schema instanceof Map ) {
-            return "{" + (schema as Map).collect( it -> "\"$it.key\": " + schema0(it.value) ).join(', ') + "}"
-        }
-        else if( schema instanceof CharSequence ) {
-            return "(type: $schema)"
-        }
-        else if( schema != null )
-            throw new IllegalArgumentException("Unexpected data type: ${schema.getClass().getName()}")
-        else
-            throw new IllegalArgumentException("Data structure cannot be null")
-    }
-
-    static protected List<Map<String,Object>> decodeResponse(Object response, Map schema) {
-        final result = decodeResponse0(response,schema)
-        if( !result )
-            throw new IllegalArgumentException("Response does not match expected schema: $schema - Offending value: $response")
-        return result
-    }
-
-    static protected List<Map<String,Object>> decodeResponse0(Object response, Map schema) {
-        final expected = schema.keySet()
-        if( response instanceof Map ) {
-            if( response.keySet()==expected ) {
-                return List.of(response as Map<String,Object>)
-            }
-            if( isIndexMap(response, schema) ) {
-                return new ArrayList<Map<String,Object>>(response.values() as Collection<Map<String,Object>>)
-            }
-            if( response.size()==1 ) {
-                return decodeResponse(response.values().first(), schema)
-            }
-        }
-
-        if( response instanceof List ) {
-            final it = (response as List).first()
-            if( it instanceof Map && it.keySet()==expected )
-                return response as List<Map<String,Object>>
-        }
-        return null
-    }
-
-    static protected boolean isIndexMap(Map response, Map schema) {
-        final keys = response.keySet()
-        // check all key are integers e.g. 0, 1, 2
-        if( keys.every(it-> it.toString().isInteger() ) ) {
-            // take the first and check the object matches the scherma
-            final it = response.values().first()
-            return it instanceof Map && it.keySet()==schema.keySet()
-        }
-        return false
+    String generate(List<ChatMessage> messages) {
+        if( responseFormat )
+            throw new IllegalArgumentException("Response format '$responseFormat' not supported by the 'generate' function")
+        return client.generate(messages).content().text()
     }
 }
diff --git a/plugins/nf-gpt/src/resources/META-INF/MANIFEST.MF b/plugins/nf-gpt/src/resources/META-INF/MANIFEST.MF
index 531edd3..ecbd3bb 100644
--- a/plugins/nf-gpt/src/resources/META-INF/MANIFEST.MF
+++ b/plugins/nf-gpt/src/resources/META-INF/MANIFEST.MF
@@ -1,6 +1,6 @@
 Manifest-Version: 1.0
 Plugin-Class: nextflow.gpt.GptPlugin
 Plugin-Id: nf-gpt
-Plugin-Version: 0.2.0
+Plugin-Version: 0.3.0
 Plugin-Provider: Seqera Labs
 Plugin-Requires: >=24.01.0-edge
diff --git a/plugins/nf-gpt/src/test/nextflow/gpt/client/GptChatCompletionRequestTest.groovy b/plugins/nf-gpt/src/test/nextflow/gpt/client/GptChatCompletionRequestTest.groovy
new file mode 100644
index 0000000..12ebcc3
--- /dev/null
+++ b/plugins/nf-gpt/src/test/nextflow/gpt/client/GptChatCompletionRequestTest.groovy
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2013-2024, Seqera Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package nextflow.gpt.client
+
+import groovy.json.JsonOutput
+import spock.lang.Specification
+/**
+ *
+ * @author Paolo Di Tommaso
+ */
+class GptChatCompletionRequestTest extends Specification {
+
+    def 'should serialize a request' () {
+        given:
+        def p1 = new GptChatCompletionRequest.Param(type: 'string', description: 'Foo')
+        def p2 = new GptChatCompletionRequest.Param(type: 'string', description: 'Foo')
+        def parameters = new GptChatCompletionRequest.Parameters(type: 'object', properties: ['p1': p1, 'p2':p2], required: ['none'])
+        def fun = new GptChatCompletionRequest.Function(name: 'whats_the_weather_like', description: 'Just a description', parameters: parameters)
+        def tool = new GptChatCompletionRequest.Tool(type:'function', function: fun)
+        def msg = new GptChatCompletionRequest.Message(role: 'user', content: 'How do you do?')
+        and:
+        def request = new GptChatCompletionRequest(model: 'turbo', messages: [msg], tools: [tool])
+        when:
+        def json = JsonOutput.prettyPrint(JsonOutput.toJson(request))
+        then:
+        json == '''\
+            {
+                "model": "turbo",
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "How do you do?"
+                    }
+                ],
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "whats_the_weather_like",
+                            "description": "Just a description",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "p1": {
+                                        "type": "string",
+                                        "description": "Foo"
+                                    },
+                                    "p2": {
+                                        "type": "string",
+                                        "description": "Foo"
+                                    }
+                                },
+                                "required": [
+                                    "none"
+                                ]
+                            }
+                        }
+                    }
+                ],
+                "tool_choice": null,
+                "max_tokens": null,
+                "n": null,
+                "temperature": null,
+                "logit_bias": null,
+                "response_format": null
+            }
+            '''.stripIndent().rightTrim()
+
+    }
+}
diff --git a/plugins/nf-gpt/src/test/nextflow/gpt/client/GptClientTest.groovy b/plugins/nf-gpt/src/test/nextflow/gpt/client/GptClientTest.groovy
new file mode 100644
index 0000000..220da13
--- /dev/null
+++ b/plugins/nf-gpt/src/test/nextflow/gpt/client/GptClientTest.groovy
@@ -0,0 +1,203 @@
+/*
+ * Copyright 2013-2024, Seqera Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package nextflow.gpt.client
+
+import nextflow.Session
+import nextflow.gpt.config.GptConfig
+import spock.lang.Specification
+/**
+ *
+ * @author Paolo Di Tommaso
+ */
+class GptClientTest extends Specification {
+
+    def 'should parse response' () {
+        given:
+        def client = Spy(GptClient)
+        and:
+        def JSON = '''
+{
+    "id": "chatcmpl-8w6HJpPYdPLfViMbHiGUCzDXimUEC",
+    "object": "chat.completion",
+    "created": 1708857849,
+    "model": "gpt-3.5-turbo-0125",
+    "choices": [
+        {
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": null,
+                "tool_calls": [
+                    {
+                        "id": "call_Erqx0Rj6JLqOnOln8AOn5lXN",
+                        "type": "function",
+                        "function": {
+                            "name": "get_current_weather",
+                            "arguments": "{\\"location\\": \\"San Francisco, CA\\"}"
+                        }
+                    },
+                    {
+                        "id": "call_nKQtiSfcHwzUYVcbvAE57kk3",
+                        "type": "function",
+                        "function": {
+                            "name": "get_current_weather",
+                            "arguments": "{\\"location\\": \\"Tokyo\\"}"
+                        }
+                    },
+                    {
+                        "id": "call_myBENYchUxjtogARuOmzG7cL",
+                        "type": "function",
+                        "function": {
+                            "name": "get_current_weather",
+                            "arguments": "{\\"location\\": \\"Paris\\"}"
+                        }
+                    }
+                ]
+            },
+            "logprobs": null,
+            "finish_reason": "tool_calls"
+        }
+    ],
+    "usage": {
+        "prompt_tokens": 80,
+        "completion_tokens": 64,
+        "total_tokens": 144
+    },
+    "system_fingerprint": "fp_86156a94a0"
+}
+'''
+
+        when:
+        def resp = client.jsonToCompletionResponse(JSON)
+        then:
+        resp.id == 'chatcmpl-8w6HJpPYdPLfViMbHiGUCzDXimUEC'
+        resp.object == 'chat.completion'
+        resp.created == 1708857849
+        resp.model == 'gpt-3.5-turbo-0125'
+        and:
+        resp.choices.size() == 1
+        resp.choices.get(0).index == 0
+        resp.choices.get(0).finish_reason == 'tool_calls'
+        resp.choices.get(0).message.role == 'assistant'
+        resp.choices.get(0).message.content == null
+        and:
+        resp.choices.get(0).message.tool_calls.size() == 3
+        and:
+        resp.choices.get(0).message.tool_calls[0].type == 'function'
+        resp.choices.get(0).message.tool_calls[0].function.name == 'get_current_weather'
+        resp.choices.get(0).message.tool_calls[0].function.arguments == "{\"location\": \"San Francisco, CA\"}"
+        and:
+        resp.choices.get(0).message.tool_calls[1].type == 'function'
+        resp.choices.get(0).message.tool_calls[1].function.name == 'get_current_weather'
+        resp.choices.get(0).message.tool_calls[1].function.arguments == "{\"location\": \"Tokyo\"}"
+        and:
+        resp.choices.get(0).message.tool_calls[2].type == 'function'
+        resp.choices.get(0).message.tool_calls[2].function.name == 'get_current_weather'
+        resp.choices.get(0).message.tool_calls[2].function.arguments == "{\"location\": \"Paris\"}"
+    }
+
+    def 'should call tools' () {
+        given:
+        def query = '''
+            Check what's the weather like in San Francisco, Tokyo, and Paris,
+            then print the temperature for each city.
+            '''.stripIndent()
+        and:
+        def tools = [
+                new GptChatCompletionRequest.Tool(type:'function',
+                        function: new GptChatCompletionRequest.Function(
+                                name:'get_current_weather',
+                                description: 'Get the current weather in a given location',
+                                parameters: new GptChatCompletionRequest.Parameters(
+                                        type:'object',
+                                        properties: [ location: new GptChatCompletionRequest.Param(type:'string',description: 'The city and state, e.g. San Francisco, CA')],
+                                        required: []))),
+
+                new GptChatCompletionRequest.Tool(type:'function',
+                        function: new GptChatCompletionRequest.Function(
+                                name:'print_value',
+                                description: 'Print a generic value to the standard output',
+                                parameters: new GptChatCompletionRequest.Parameters(
+                                        type:'object',
+                                        properties: [ value: new GptChatCompletionRequest.Param(type:'string', description: 'The value to be printed')],
+                                        required: [])))
+
+        ]
+        List messages = [ new GptChatCompletionRequest.Message(role: 'user', content: query) ]
+        def request = new GptChatCompletionRequest(
+                model: 'gpt-3.5-turbo-0125',
+                messages: messages,
+                tools: tools,
+                tool_choice: 'auto' )
+
+        and:
+        def session = Mock(Session) {
+            getConfig() >> [:]
+        }
+        and:
+        def config = GptConfig.config(session)
+
+        when:
+        def response = new GptClient(config).sendRequest(request)
+        then:
+        response
+
+        when:
+        for( def choice : response.choices ) {
+            messages << choice.message
+
+            for( def tool : choice.message.tool_calls ) {
+                messages << new GptChatCompletionRequest.ToolMessage(
+                        role: 'tool',
+                        name: tool.function.name,
+                        content: '10',
+                        tool_call_id: tool.id )
+            }
+        }
+        and:
+        def request2 = new GptChatCompletionRequest( model: 'gpt-3.5-turbo-0125', messages: messages )
+        and:
+        def response2 = new GptClient(config).sendRequest(request2)
+        then:
+        response2
+
+    }
+
+    def 'should get structured output' () {
+        given:
+        def session = Mock(Session) { getConfig() >> [:] }
+        def config = GptConfig.config(session)
+        def client = new GptClient(config)
+        and:
+        def msg1 = new GptChatCompletionRequest.Message(role:'system', content:'You are a helpful assistant designed to output JSON. The top level object contains the attribute `result`. The result is a list of objects having two attributes: `year` and `location`')
+        def msg2 = new GptChatCompletionRequest.Message(role:'user', content:'List of all editions of olympic games.')
+        def request = new GptChatCompletionRequest(
+                model: 'gpt-3.5-turbo-0125',
+                messages: [msg1, msg2],
+                response_format: GptChatCompletionRequest.ResponseFormat.JSON
+        )
+
+        when:
+        def response = client.sendRequest(request)
+        and:
+        print response.choices[0].message.content
+        then:
+        true
+
+    }
+}
diff --git a/plugins/nf-gpt/src/test/nextflow/gpt/config/GptConfigTest.groovy b/plugins/nf-gpt/src/test/nextflow/gpt/config/GptConfigTest.groovy
index dd1b375..a7e66f1 100644
--- a/plugins/nf-gpt/src/test/nextflow/gpt/config/GptConfigTest.groovy
+++ b/plugins/nf-gpt/src/test/nextflow/gpt/config/GptConfigTest.groovy
@@ -29,7 +29,7 @@ class GptConfigTest extends Specification {
 
     def 'should create from session' () {
         given:
-        def CONFIG = [ai:[endpoint:'http://xyz.com', model:'gpt-4', apiKey: 'abc']]
+        def CONFIG = [gpt:[endpoint:'http://xyz.com', model:'gpt-4', apiKey: 'abc']]
         def session = Mock(Session) {getConfig()>>CONFIG }
 
         when:
diff --git a/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptHelperTest.groovy b/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptHelperTest.groovy
new file mode 100644
index 0000000..55c6d66
--- /dev/null
+++ b/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptHelperTest.groovy
@@ -0,0 +1,108 @@
+package nextflow.gpt.prompt
+
+import dev.langchain4j.data.message.AiMessage
+import dev.langchain4j.data.message.SystemMessage
+import dev.langchain4j.data.message.UserMessage
+import spock.lang.Specification
+
+/**
+ *
+ * @author Paolo Di Tommaso
+ */
+class GptHelperTest extends Specification {
+
+    def 'should render schema' () {
+        expect:
+        GptHelper.renderSchema([foo:'string']) == 'You must answer strictly in the following JSON format: {"result": [{"foo": (type: string)}] }'
+    }
+
+    def 'should render schema /1' () {
+        expect:
+        GptHelper.schema0([foo:'string']) == '{"foo": (type: string)}'
+    }
+
+    def 'should render schema /2' () {
+        expect:
+        GptHelper.schema0(SCHEMA) == EXPECTED
+
+        where:
+        SCHEMA                                  | EXPECTED
+        [:]                                     | '{}'
+        []                                      | '[]'
+        and:
+        [color:'string',count:'integer']        | '{"color": (type: string), "count": (type: integer)}'
+        [[color:'string',count:'integer']]      | '[{"color": (type: string), "count": (type: integer)}]'
+        [[color:'string'], [count:'integer']]   | '[{"color": (type: string)}, {"count": (type: integer)}]'
+    }
+
+    def 'should decode response to a list' () {
+        given:
+        def resp
+        def SCHEMA = [location:'string',year:'string']
+        List<Map<String,Object>> result
+
+        when: // a single object is given, then returns it as a list
+        resp = [location: 'foo', year:'2000']
+        result = GptHelper.decodeResponse0(resp, SCHEMA)
+        then:
+        result == [[location: 'foo', year:'2000']]
+
+        when: // a list of location is given
+        resp = [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']]
+        result = GptHelper.decodeResponse0(resp, SCHEMA)
+        then:
+        result == [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']]
+
+        when: // a list wrapped into a result object
+        resp = [ games: [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']] ]
+        result = GptHelper.decodeResponse0(resp, SCHEMA)
+        then:
+        result == [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']]
+
+        when: // an indexed map is returned
+        resp = [ 0: [location: 'rome', year:'2000'], 1: [location: 'barna', year:'2001'], 3: [location: 'london', year:'2002'] ]
+        result = GptHelper.decodeResponse0(resp, SCHEMA)
+        then:
+        result == [ [location: 'rome', year:'2000'], [location: 'barna', year:'2001'], [location: 'london', year:'2002']]
+    }
+
+    def 'should check it is an index map' () {
+        given:
+        def SCHEMA = [a: 'string', b: 'String']
+        expect:
+        GptHelper.isIndexMap([0: [a: 'this', b:'that'], 1: [a: 'foo', b:'bar']], SCHEMA)
+        GptHelper.isIndexMap(['0': [a: 'this', b:'that'], '1': [a: 'foo', b:'bar']], SCHEMA)
+        !GptHelper.isIndexMap(['x': [a: 'this', b:'that'], 'y': [a: 'foo', b:'bar']], SCHEMA)
+    }
+
+
+    def 'should convert map to chat message' () {
+        expect:
+        GptHelper.messageToChat(List.of([role:'user', content:'this'])) == [UserMessage.from('this') ]
+        GptHelper.messageToChat(List.of([role:'system', content:'that'])) == [SystemMessage.from('that') ]
+        GptHelper.messageToChat(List.of([role:'ai', content:'other'])) == [AiMessage.from('other') ]
+        and:
+        GptHelper.messageToChat(List.of([role:'user', content:'this'],[role:'system', content:'that'])) == [UserMessage.from('this'), SystemMessage.from('that')]
+
+        when:
+        GptHelper.messageToChat([])
+        then:
+        def e = thrown(IllegalArgumentException)
+        e.message == 'Missing \'messages\' argument'
+
+        when:
+        GptHelper.messageToChat([[foo:'one']])
+        then:
+        e = thrown(IllegalArgumentException)
+        e.message == 'Missing \'role\' attribute - offending message: [[foo:one]]'
+
+        when:
+        GptHelper.messageToChat([[role:'one', content:'something']])
+        then:
+        e = thrown(IllegalArgumentException)
+        e.message == 'Unsupported message role \'one\' - offending message: [[role:one, content:something]]'
+
+    }
+
+}
diff --git a/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptPromptExtensionTest.groovy b/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptPromptExtensionTest.groovy
new file mode 100644
index 0000000..3016daa
--- /dev/null
+++ b/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptPromptExtensionTest.groovy
@@ -0,0 +1,82 @@
+package nextflow.gpt.prompt
+
+import groovyx.gpars.dataflow.DataflowQueue
+import nextflow.Session
+import spock.lang.Requires
+import spock.lang.Specification
+/**
+ *
+ * @author Paolo Di Tommaso
+ */
+@Requires({ System.getenv('OPENAI_API_KEY') })
+class GptPromptExtensionTest extends Specification {
+
+    def 'should run a prompt as an operator' () {
+        given:
+        def PROMPT = 'Extract information about a person from In 1968, amidst the fading echoes of Independence Day, a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe, marked the start of a new journey.'
+        def SCHEMA = [
+                firstName: 'string',
+                lastName: 'string',
+                birthDate: 'date string (YYYY-MM-DD)'
+        ]
+        and:
+        def session = Mock(Session) { getConfig()>>[:] }
+        and:
+        def ext = new GptPromptExtension(); ext.init(session)
+        and:
+        def source = new DataflowQueue(); source.bind(PROMPT)
+
+        when:
+        def ret = (DataflowQueue) ext.prompt(source, [schema:SCHEMA])
+        then:
+        ret.getVal() == [firstName:'John', lastName:'Doe', birthDate:'1968-07-04']
+    }
+
+    def 'should run a prompt for data as a function' () {
+        given:
+        def PROMPT = 'Extract information about a person from In 1968, amidst the fading echoes of Independence Day, a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe, marked the start of a new journey.'
+        def SCHEMA = [
+                firstName: 'string',
+                lastName: 'string',
+                birthDate: 'date string (YYYY-MM-DD)'
+        ]
+        and:
+        def session = Mock(Session) { getConfig()>>[:] }
+        and:
+        def ext = new GptPromptExtension(); ext.init(session)
+
+        when:
+        def result = ext.gptPromptForData([schema:SCHEMA], PROMPT)
+        then:
+        result == [ [firstName:'John', lastName:'Doe', birthDate:'1968-07-04'] ]
+    }
+
+    def 'should run a prompt for text' () {
+        given:
+        def PROMPT = 'Extract information about a person from In 1968, amidst the fading echoes of Independence Day, a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe, marked the start of a new journey.'
+        and:
+        def session = Mock(Session) { getConfig()>>[:] }
+        and:
+        def ext = new GptPromptExtension(); ext.init(session)
+
+        when:
+        def ret = ext.gptPromptForText(PROMPT)
+        then:
+        ret.contains('1968')
+    }
+
+    def 'should run a prompt for text with multiple choices' () {
+        given:
+        def PROMPT = 'Extract information about a person from In 1968, amidst the fading echoes of Independence Day, a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe, marked the start of a new journey.'
+        and:
+        def session = Mock(Session) { getConfig()>>[:] }
+        and:
+        def ext = new GptPromptExtension(); ext.init(session)
+
+        when:
+        def ret = ext.gptPromptForText(PROMPT, numOfChoices: 1)
+        then:
+        ret[0].contains('1968')
+    }
+
+}
diff --git a/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptPromptModelTest.groovy b/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptPromptModelTest.groovy
index 8b37cbb..f0857af 100644
--- a/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptPromptModelTest.groovy
+++ b/plugins/nf-gpt/src/test/nextflow/gpt/prompt/GptPromptModelTest.groovy
@@ -26,29 +26,6 @@ import spock.lang.Specification
  */
 class GptPromptModelTest extends Specification {
 
-    def 'should render schema' () {
-        expect:
-        GptPromptModel.renderSchema([foo:'string']) == 'You must answer strictly in the following JSON format: {"result": [{"foo": (type: string)}] }'
-    }
-
-    def 'should render schema /1' () {
-        expect:
-        GptPromptModel.schema0([foo:'string']) == '{"foo": (type: string)}'
-    }
-
-    def 'should render schema /2' () {
-        expect:
-        GptPromptModel.schema0(SCHEMA) == EXPECTED
-
-        where:
-        SCHEMA                                  | EXPECTED
-        [:]                                     | '{}'
-        []                                      | '[]'
-        and:
-        [color:'string',count:'integer']        | '{"color": (type: string), "count": (type: integer)}'
-        [[color:'string',count:'integer']]      | '[{"color": (type: string), "count": (type: integer)}]'
-        [[color:'string'], [count:'integer']]   | '[{"color": (type: string)}, {"count": (type: integer)}]'
-    }
 
     @Requires({ System.getenv('OPENAI_API_KEY') })
     def 'should render json response' () {
@@ -61,7 +38,7 @@ class GptPromptModelTest extends Specification {
         ]
         and:
         def session = Mock(Session) { getConfig()>>[:] }
-        def model = new GptPromptModel(session).build()
+        def model = new GptPromptModel(session).withJsonResponseFormat().build()
 
         when:
         def result = model.prompt(PROMPT, SCHEMA)
@@ -71,44 +48,6 @@ class GptPromptModelTest extends Specification {
         result[0].birthDate == '1968-07-04'
     }
 
-    def 'should decode response to a list' () {
-        given:
-        def resp
-        def SCHEMA = [location:'string',year:'string']
-        List<Map<String,Object>> result
-
-        when: // a single object is given, then returns it as a list
-        resp = [location: 'foo', year:'2000']
-        result = GptPromptModel.decodeResponse0(resp, SCHEMA)
-        then:
-        result == [[location: 'foo', year:'2000']]
-
-        when: // a list of location is given
-        resp = [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']]
-        result = GptPromptModel.decodeResponse0(resp, SCHEMA)
-        then:
-        result == [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']]
-
-        when: // a list wrapped into a result object
-        resp = [ games: [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']] ]
-        result = GptPromptModel.decodeResponse0(resp, SCHEMA)
-        then:
-        result == [[location: 'foo', year:'2000'], [location: 'bar', year:'2001']]
-
-        when: // an indexed map is returned
-        resp = [ 0: [location: 'rome', year:'2000'], 1: [location: 'barna', year:'2001'], 3: [location: 'london', year:'2002'] ]
-        result = GptPromptModel.decodeResponse0(resp, SCHEMA)
-        then:
-        result == [ [location: 'rome', year:'2000'], [location: 'barna', year:'2001'], [location: 'london', year:'2002']]
-    }
-
-    def 'should check it is an index map' () {
-        given:
-        def SCHEMA = [a: 'string', b: 'String']
-        expect:
-        GptPromptModel.isIndexMap([0: [a: 'this', b:'that'], 1: [a: 'foo', b:'bar']], SCHEMA)
-        GptPromptModel.isIndexMap(['0': [a: 'this', b:'that'], '1': [a: 'foo', b:'bar']], SCHEMA)
-        !GptPromptModel.isIndexMap(['x': [a: 'this', b:'that'], 'y': [a: 'foo', b:'bar']], SCHEMA)
-    }
 
 }
diff --git a/prompt-eng.nf b/prompt-eng.nf
new file mode 100644
index 0000000..e69de29