Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip LLM codemods when no service is available #418

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package io.codemodder.plugins.llm;

import com.google.inject.AbstractModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Provides configured LLM services. */
public final class LLMServiceModule extends AbstractModule {

private static final String OPENAI_KEY_NAME = "CODEMODDER_OPENAI_API_KEY";
private static final String AZURE_OPENAI_KEY_NAME = "CODEMODDER_AZURE_OPENAI_API_KEY";
private static final String AZURE_OPENAI_ENDPOINT = "CODEMODDER_AZURE_OPENAI_ENDPOINT";
private static final Logger logger = LoggerFactory.getLogger(LLMServiceModule.class);

@Override
protected void configure() {
Expand All @@ -22,19 +25,20 @@ protected void configure() {
+ " must be set");
}
if (azureOpenAIKey != null) {
logger.info("Using Azure OpenAI service with endpoint {}", azureOpenAIEndpoint);
bind(OpenAIService.class)
.toProvider(() -> OpenAIService.fromAzureOpenAI(azureOpenAIKey, azureOpenAIEndpoint));
return;
}

bind(OpenAIService.class).toProvider(() -> OpenAIService.fromOpenAI(getOpenAIToken()));
}

private String getOpenAIToken() {
final var openAIKey = System.getenv(OPENAI_KEY_NAME);
if (openAIKey == null) {
throw new IllegalArgumentException(OPENAI_KEY_NAME + " environment variable must be set");
if (openAIKey != null) {
logger.info("Using OpenAI service");
bind(OpenAIService.class).toProvider(() -> OpenAIService.fromOpenAI(openAIKey));
return;
}
return openAIKey;

logger.info("No LLM service available");
bind(OpenAIService.class).toProvider(OpenAIService::noServiceAvailable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class OpenAIService {
private final OpenAIClient api;
private static final int TIMEOUT_SECONDS = 90;
private final ModelMapper modelMapper;
private boolean serviceAvailable = true;

private static OpenAIClientBuilder builder(final KeyCredential key) {
HttpClientOptions clientOptions = new HttpClientOptions();
Expand All @@ -31,6 +32,12 @@ private static OpenAIClientBuilder builder(final KeyCredential key) {
.credential(key);
}

/**
 * Builds a sentinel instance with no backing API client.
 *
 * <p>Used by {@link #noServiceAvailable()} to represent "no LLM service configured": both the
 * model mapper and the API client are left {@code null}. NOTE(review): presumably callers are
 * expected to check {@link #isServiceAvailable()} before invoking any API method — confirm.
 *
 * @param serviceAvailable whether a real LLM service backs this instance
 */
OpenAIService(final boolean serviceAvailable) {
  this.serviceAvailable = serviceAvailable;
  this.modelMapper = null;
  this.api = null;
}

OpenAIService(final ModelMapper mapper, final KeyCredential key) {
this.modelMapper = mapper;
this.api = builder(key).buildClient();
Expand Down Expand Up @@ -66,6 +73,19 @@ public static OpenAIService fromAzureOpenAI(final String token, final String end
Objects.requireNonNull(endpoint));
}

/**
 * Creates a sentinel {@link OpenAIService} representing the absence of any configured LLM
 * service.
 *
 * @return an instance whose {@link #isServiceAvailable()} returns {@code false}
 */
public static OpenAIService noServiceAvailable() {
  return new OpenAIService(false);
}

/**
 * Returns whether a real LLM service backs this instance.
 *
 * @return {@code true} when constructed from a configured OpenAI/Azure OpenAI service;
 *     {@code false} for the sentinel created by {@link #noServiceAvailable()}
 */
public boolean isServiceAvailable() {
  return serviceAvailable;
}

/**
* Gets the completion for the given messages.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.codemodder.plugins.llm;

import io.codemodder.RuleSarif;
import io.codemodder.SarifPluginRawFileChanger;
import java.util.Objects;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add a comment here when to use this type since it's public.

/** A base class for LLM codemods that process SARIF and use the OpenAI service. */
public abstract class SarifPluginLLMCodemod extends SarifPluginRawFileChanger {
protected final OpenAIService openAI;

/**
 * Creates the codemod.
 *
 * @param sarif the SARIF results this codemod processes
 * @param openAI the OpenAI service used by the codemod; must not be {@code null}
 */
public SarifPluginLLMCodemod(final RuleSarif sarif, final OpenAIService openAI) {
  super(sarif);
  // Fail fast at construction rather than at first use of the service.
  this.openAI = Objects.requireNonNull(openAI);
}

/**
* Indicates whether the codemod should run.
*
* <p>Subclasses can override this method to add additional checks but should call
* super.shouldRun() to ensure the OpenAI service is available.
*/
@Override
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add a comment here that I think subclasses shouldn't override this without also considering the same logic?

public boolean shouldRun() {
  // Skip this codemod entirely when no LLM service is configured — this is the mechanism
  // that implements "skip LLM codemods when no service is available".
  return openAI.isServiceAvailable();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@
* </ol>
*/
public abstract class SarifToLLMForBinaryVerificationAndFixingCodemod
extends SarifPluginRawFileChanger {
extends SarifPluginLLMCodemod {

private final OpenAIService openAI;
private final Model model;

protected SarifToLLMForBinaryVerificationAndFixingCodemod(
final RuleSarif sarif, final OpenAIService openAI, final Model model) {
super(sarif);
this.openAI = Objects.requireNonNull(openAI);
super(sarif, openAI);
this.model = Objects.requireNonNull(model);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@
* <p>To accomplish that, we need the analysis to "bucket" the code into one of the above
* categories.
*/
public abstract class SarifToLLMForMultiOutcomeCodemod extends SarifPluginRawFileChanger {
public abstract class SarifToLLMForMultiOutcomeCodemod extends SarifPluginLLMCodemod {

private static final Logger logger =
LoggerFactory.getLogger(SarifToLLMForMultiOutcomeCodemod.class);
private final OpenAIService openAI;
private final List<LLMRemediationOutcome> remediationOutcomes;
private final Model categorizationModel;
private final Model codeChangingModel;
Expand All @@ -65,8 +64,7 @@ protected SarifToLLMForMultiOutcomeCodemod(
final List<LLMRemediationOutcome> remediationOutcomes,
final Model categorizationModel,
final Model codeChangingModel) {
super(sarif);
this.openAI = Objects.requireNonNull(openAI);
super(sarif, openAI);
this.remediationOutcomes = Objects.requireNonNull(remediationOutcomes);
if (remediationOutcomes.size() < 2) {
throw new IllegalArgumentException("must have 2+ remediation outcome");
Expand All @@ -78,7 +76,7 @@ protected SarifToLLMForMultiOutcomeCodemod(
@Override
public CodemodFileScanningResult onFileFound(
final CodemodInvocationContext context, final List<Result> results) {
logger.info("processing: {}", context.path());
logger.debug("processing: {}", context.path());

List<CodemodChange> changes = new ArrayList<>();
for (Result result : results) {
Expand Down
Loading