diff --git a/force-app/main/default/classes/OpenAIChatModel.cls b/force-app/main/default/classes/OpenAIChatModel.cls index 5f064c1..6b2781a 100644 --- a/force-app/main/default/classes/OpenAIChatModel.cls +++ b/force-app/main/default/classes/OpenAIChatModel.cls @@ -6,45 +6,70 @@ public class OpenAIChatModel implements ChatModel.I { public Decimal tempature = 0.0; public Integer maxTokens = 2000; public string namedCrediential = 'OpenAI'; + public static Integer retryCount = 0; + public static final Integer maxRetries = 3; public String chat(ChatModel.Message[] messages) { - Http http = new Http(); - HttpRequest request = new HttpRequest(); - request.setEndpoint('callout:' + namedCrediential + '/chat/completions'); - request.setHeader('Content-Type', 'application/json'); - request.setHeader('Authorization', 'Bearer {!$Credential.OpenAI.API_KEY}'); - request.setMethod('POST'); - request.setTimeout(120000); + while(retryCount < maxRetries) { + retryCount++; - ChatAPIRequest requestBody = new ChatAPIRequest(); - requestBody.messages = messages; - requestBody.model = model; - requestBody.temperature = tempature; - requestBody.max_tokens = maxTokens; // TODO dynamicly set this by calculating tokens - requestBody.stream = false; + Http http = new Http(); + HttpRequest request = new HttpRequest(); - request.setBody(JSON.serialize(requestBody, true)); + request.setEndpoint('callout:' + namedCrediential + '/chat/completions'); + request.setHeader('Content-Type', 'application/json'); + request.setHeader('Authorization', 'Bearer {!$Credential.OpenAI.API_KEY}'); + request.setMethod('POST'); + request.setTimeout(120000); - HttpResponse response = http.send(request); + ChatAPIRequest requestBody = new ChatAPIRequest(); + requestBody.messages = messages; + requestBody.model = model; + requestBody.temperature = tempature; + requestBody.max_tokens = maxTokens; // TODO dynamicly set this by calculating tokens + requestBody.stream = false; - // if(response.getStatusCode() == 429) { - //retry? - // } + request.setBody(JSON.serialize(requestBody, true)); - if (response.getStatusCode() == 200) { - ChatAPICompletion results = (ChatAPICompletion) JSON.deserialize( - response.getBody(), - ChatAPICompletion.class - ); + HttpResponse response = http.send(request); + statusCode = response.getStatusCode(); + statusMessage = response.getStatus(); - OpenAIChatModel.Choice completion = results.choices[0]; - return completion.message?.content?.trim(); - } else { - system.debug(response.getBody()); - throw new OpenAIException( - 'OpenAI API returned status code ' + response.getStatusCode() - ); + switch on statusCode { + when 200 { // success + ChatAPICompletion results = (ChatAPICompletion) JSON.deserialize( + response.getBody(), + ChatAPICompletion.class + ); + + OpenAIChatModel.Choice completion = results.choices[0]; + return completion.message?.content?.trim(); + } + when else { + // ====================================== + // = OpenAI API Error Codes: + // ====================================== + // 401 - Invalid Authentication + // 401 - Incorrect API key provided + // 401 - You must be a member of an organization to use the API + // 429 - Rate limit reached for requests + // 429 - You exceeded your current quota, please check your plan and billing details + // 429 - The engine is currently overloaded, please try again later + // 500 - The server had an error while processing your request + // ====================================== + // Retry only for the case + // 429 - The engine is currently overloaded, please try again later + if(statusCode == 429 && statusMessage.contains('overloaded')) { + continue; // retry up to maxRetries times + } + // else throw exception + system.debug(response.getBody()); + throw new OpenAIException( + 'OpenAI API returned status code: ' + statusCode + '-' + statusMessage + ); + } + } } } diff --git a/force-app/main/default/classes/OpenAICompletionModel.cls b/force-app/main/default/classes/OpenAICompletionModel.cls index 85cdf05..c4c4edb 100644 --- a/force-app/main/default/classes/OpenAICompletionModel.cls +++ b/force-app/main/default/classes/OpenAICompletionModel.cls @@ -1,5 +1,7 @@ public with sharing class OpenAICompletionModel implements CompletionModel.I { private static final String ENDPOINT = 'https://api.openai.com/v1/'; + public static Integer retryCount = 0; + public static final Integer maxRetries = 3; private string apiKey; public OpenAICompletionModel(String apiKey) { @@ -7,41 +9,68 @@ public with sharing class OpenAICompletionModel implements CompletionModel.I { } public String complete(String prompt, String[] stop) { - Http http = new Http(); - HttpRequest request = new HttpRequest(); - - request.setEndpoint(ENDPOINT + 'completions'); - request.setHeader('Content-Type', 'application/json'); - request.setHeader('Authorization', 'Bearer ' + apiKey); - request.setMethod('POST'); - - Map requestBody = new Map{ - 'model' => 'text-davinci-003', - 'prompt' => prompt, - 'temperature' => 0.7, - 'max_tokens' => Integer.valueOf((prompt.length() / 4) + 500), - 'n' => 1, - 'stop' => stop - }; - - request.setBody(JSON.serialize(requestBody)); // serialize the Map as JSON - - HttpResponse response = http.send(request); - - if (response.getStatusCode() == 200) { - Map jsonResponse = (Map) JSON.deserializeUntyped( - response.getBody() - ); - - // System.debug(jsonResponse); - List completions = (List) jsonResponse.get('choices'); - Map completion = (Map) completions[0]; - String text = (String) completion.get('text'); - return text.trim(); - } else { - throw new OpenAIException( - 'OpenAI API returned status code ' + response.getStatusCode() - ); + while(retryCount < maxRetries) { + retryCount++; + + Http http = new Http(); + HttpRequest request = new HttpRequest(); + + request.setEndpoint(ENDPOINT + 'completions'); + request.setHeader('Content-Type', 'application/json'); + request.setHeader('Authorization', 'Bearer ' + apiKey); + request.setMethod('POST'); + + Map requestBody = new Map{ + 'model' => 'text-davinci-003', + 'prompt' => prompt, + 'temperature' => 0.7, + 'max_tokens' => Integer.valueOf((prompt.length() / 4) + 500), + 'n' => 1, + 'stop' => stop + }; + + request.setBody(JSON.serialize(requestBody)); // serialize the Map as JSON + + HttpResponse response = http.send(request); + statusCode = response.getStatusCode(); + statusMessage = response.getStatus(); + + switch on statusCode { + when 200 { // success + Map jsonResponse = (Map) JSON.deserializeUntyped( + response.getBody() + ); + + // System.debug(jsonResponse); + List completions = (List) jsonResponse.get('choices'); + Map completion = (Map) completions[0]; + String text = (String) completion.get('text'); + return text.trim(); + } + when else { + // ====================================== + // = OpenAI API Error Codes: + // ====================================== + // 401 - Invalid Authentication + // 401 - Incorrect API key provided + // 401 - You must be a member of an organization to use the API + // 429 - Rate limit reached for requests + // 429 - You exceeded your current quota, please check your plan and billing details + // 429 - The engine is currently overloaded, please try again later + // 500 - The server had an error while processing your request + // ====================================== + // Retry only for the case + // 429 - The engine is currently overloaded, please try again later + if(statusCode == 429 && statusMessage.contains('overloaded')) { + continue; // retry up to maxRetries times + } + // else throw exception + system.debug(response.getBody()); + throw new OpenAIException( + 'OpenAI API returned status code: ' + statusCode + '-' + statusMessage + ); + } + } } }