From 901144fe6e273674e7f3897cfa98f821afb05b32 Mon Sep 17 00:00:00 2001 From: Johnathan Gilday Date: Thu, 18 Jul 2024 14:33:37 -0400 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Added=20Tests=20for=20EnvironmentBa?= =?UTF-8?q?sedModelMapper=20(#422)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While working to confirm that the logic in codemodder-java and the platform is in sync, I found opportunities to add missing tests and fix some pitfalls. /towards #work --- .../llm/EnvironmentBasedModelMapper.java | 22 +++++-- .../llm/EnvironmentBasedModelMapperTest.java | 62 +++++++++++++++++++ 2 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 plugins/codemodder-plugin-llm/src/test/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapperTest.java diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapper.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapper.java index 12bb71aed..5bbbe47cf 100644 --- a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapper.java +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapper.java @@ -1,22 +1,34 @@ package io.codemodder.plugins.llm; import java.util.HashMap; +import java.util.Map; /** Mapper that maps models to their deployment names based on environment variables. */ final class EnvironmentBasedModelMapper implements ModelMapper { - private static final String DEPLOYMENT_TEMPLATE = "CODEMODDER_AZURE_OPENAI_%s_DEPLOYMENT"; private final HashMap map = new HashMap<>(); EnvironmentBasedModelMapper() { - for (Model m : StandardModel.values()) { - final var deployment = System.getenv(String.format(DEPLOYMENT_TEMPLATE, m)); - map.put(m, deployment == null ? m.id() : deployment); + this(System.getenv()); + } + + EnvironmentBasedModelMapper(final Map environment) { + for (final Model model : StandardModel.values()) { + final var name = String.format(DEPLOYMENT_TEMPLATE, toEnvironmentVariableCase(model.id())); + final var deployment = environment.getOrDefault(name, model.id()); + map.put(model, deployment); } } @Override public String getModelName(Model model) { - return map.get(model); + return map.getOrDefault(model, model.id()); } + + /** Converts a model ID to environment variable casing. */ + private static String toEnvironmentVariableCase(String input) { + return input.toUpperCase().replace('-', '_').replace('.', '_'); + } + + private static final String DEPLOYMENT_TEMPLATE = "CODEMODDER_AZURE_OPENAI_%s_DEPLOYMENT"; } diff --git a/plugins/codemodder-plugin-llm/src/test/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapperTest.java b/plugins/codemodder-plugin-llm/src/test/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapperTest.java new file mode 100644 index 000000000..f353bb2cc --- /dev/null +++ b/plugins/codemodder-plugin-llm/src/test/java/io/codemodder/plugins/llm/EnvironmentBasedModelMapperTest.java @@ -0,0 +1,62 @@ +package io.codemodder.plugins.llm; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import java.util.Map; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +/** Unit tests for {@link EnvironmentBasedModelMapper}. */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +final class EnvironmentBasedModelMapperTest { + + private EnvironmentBasedModelMapper mapper; + + @BeforeAll + void before() { + final var environment = + Map.of( + "CODEMODDER_AZURE_OPENAI_GPT_3_5_TURBO_0125_DEPLOYMENT", + "my-gpt-3.5-turbo", + "CODEMODDER_AZURE_OPENAI_GPT_4_0613_DEPLOYMENT", + "my-gpt-4", + "CODEMODDER_AZURE_OPENAI_GPT_4_TURBO_2024_04_09_DEPLOYMENT", + "my-gpt-4-turbo", + "CODEMODDER_AZURE_OPENAI_GPT_4O_2024_05_13_DEPLOYMENT", + "my-gpt-4o"); + mapper = new EnvironmentBasedModelMapper(environment); + } + + /** Spot checks one of the standard models to make sure the mapping works as expected */ + @Test + void it_maps_model_name_to_deployment() { + final var name = mapper.getModelName(StandardModel.GPT_3_5_TURBO_0125); + assertThat(name).isEqualTo("my-gpt-3.5-turbo"); + } + + /** + * This is a meta-test that fails when we add a new standard model but forget to update the + * mapping in {@link #before()} to ensure that all standard models are covered. + */ + @EnumSource(StandardModel.class) + @ParameterizedTest + void it_looks_up_all_standard_models(final Model model) { + final var name = mapper.getModelName(model); + assertThat(name).isNotEqualTo(model.id()).startsWith("my-gpt"); + } + + @Test + void it_returns_model_id_when_no_mapping_exists() { + // GIVEN some model that doesn't have a mapping + final var model = mock(Model.class, withSettings().stubOnly()); + when(model.id()).thenReturn("test"); + final var name = mapper.getModelName(model); + assertThat(name).isEqualTo(model.id()); + } +}