diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy index a884f0f..c2a44d8 100644 --- a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy +++ b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptExtension.groovy @@ -61,6 +61,8 @@ class GptPromptExtension extends PluginExtensionPoint { final ai = new GptPromptModel(session) .withModel(opts.model as String) .withDebug(opts.debug as Boolean) + .withTemperature(opts.temperature as Double) + .withMaxToken(opts.maxTokens as Integer) .build() // run the prompt final response = ai.prompt(query, opts.schema as Map) @@ -84,6 +86,8 @@ class GptPromptExtension extends PluginExtensionPoint { final ai = new GptPromptModel(session) .withModel(opts.model as String) .withDebug(opts.debug as Boolean) + .withTemperature(opts.temperature as Double) + .withMaxToken(opts.maxTokens as Integer) .build() final target = CH.createBy(source) diff --git a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy index 7872c77..265d8c0 100644 --- a/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy +++ b/plugins/nf-gpt/src/main/nextflow/gpt/prompt/GptPromptModel.groovy @@ -70,16 +70,16 @@ class GptPromptModel { GptPromptModel build() { final modelName = model ?: config.model() - final temp = temperature ?: config.temperature() + final temperature = this.temperature ?: config.temperature() final tokens = maxTokens ?: config.maxTokens() - log.debug "Creating OpenAI chat model: $modelName; api-key: ${StringUtils.redact(config.apiKey())}; temperature: $temp; maxTokens: ${maxTokens}" + log.debug "Creating OpenAI chat model: $modelName; api-key: ${StringUtils.redact(config.apiKey())}; temperature: $temperature; maxTokens: ${maxTokens}" client = OpenAiChatModel.builder() .apiKey(config.apiKey()) .modelName(modelName) .logRequests(debug) .logResponses(debug) .temperature(temperature) - .maxTokens(maxTokens) + .maxTokens(tokens) .responseFormat("json_object") .build(); return this