diff --git a/sdks/java/ml/build.gradle b/sdks/java/ml/build.gradle new file mode 100644 index 000000000000..7b6b071fa2a1 --- /dev/null +++ b/sdks/java/ml/build.gradle @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { id 'org.apache.beam.module' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml', +) +provideIntegrationTestingDependencies() +enableJavaPerformanceTesting() + +description = "Apache Beam :: SDKs :: Java :: ML" +ext.summary = "Java ML module" + +dependencies { + + +} diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts new file mode 100644 index 000000000000..f5eefd91514b --- /dev/null +++ b/sdks/java/ml/remoteinference/build.gradle.kts @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id("org.apache.beam.module") + id("java-library") +} + +description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference" + +dependencies { + // Core Beam SDK + implementation(project(":sdks:java:core")) + + implementation("com.openai:openai-java:4.3.0") + implementation("com.google.auto.value:auto-value:1.11.0") + implementation("com.google.auto.value:auto-value-annotations:1.11.0") + implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") + + // testing + testImplementation(project(":runners:direct-java")) + testImplementation("junit:junit:4.13.2") + testImplementation(project(":sdks:java:testing:test-utils")) +} + diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java new file mode 100644 index 000000000000..d0a922f01d8f --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.checkerframework.checker.nullness.qual.Nullable; + +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; + + +import com.google.auto.value.AutoValue; + +import java.util.ArrayList; +import java.util.List; + +@SuppressWarnings({ "rawtypes", "unchecked" }) +public class RemoteInference { + + public static Invoke invoke() { + return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null) + .build(); + } + + private RemoteInference() { + } + + @AutoValue + public abstract static class Invoke + extends PTransform, PCollection>>> { + + abstract @Nullable Class handler(); + + abstract @Nullable BaseModelParameters parameters(); + + abstract @Nullable BatchConfig batchConfig(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setHandler(Class modelHandler); + + abstract Builder setParameters(BaseModelParameters modelParameters); + + abstract Builder setBatchConfig(BatchConfig config); + + abstract Invoke build(); + } + + public Invoke handler(Class modelHandler) { + return builder().setHandler(modelHandler).build(); + } + + public Invoke withParameters(BaseModelParameters modelParameters) { + return builder().setParameters(modelParameters).build(); + } + + public Invoke withBatchConfig(BatchConfig config) { + return builder().setBatchConfig(config).build(); + } + + @Override + public PCollection>> expand(PCollection input) { + return input.apply(ParDo.of(new BatchElementsFn<>(this.batchConfig() != null ? this.batchConfig() + : this + .parameters() + .defaultBatchConfig()))) + .apply(ParDo.of(new RemoteInferenceFn<>(this))); + } + + static class RemoteInferenceFn + extends DoFn, Iterable>> { + + private final Class handlerClass; + private final BaseModelParameters parameters; + private transient BaseModelHandler handler; + + RemoteInferenceFn(Invoke spec) { + this.handlerClass = spec.handler(); + this.parameters = spec.parameters(); + } + + @Setup + public void setupHandler() { + try { + this.handler = handlerClass.getDeclaredConstructor().newInstance(); + this.handler.createClient(parameters); + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate handler: " + + handlerClass.getName(), e); + } + } + + @ProcessElement + public void processElement(ProcessContext c) { + Iterable> response = this.handler.request(c.element()); + c.output(response); + } + } + + public static class BatchElementsFn extends DoFn> { + private final BatchConfig config; + private List batch; + + public BatchElementsFn(BatchConfig config) { + this.config = config; + } + + @StartBundle + public void startBundle() { + batch = new ArrayList<>(); + } + + @ProcessElement + public void processElement(ProcessContext c) { + batch.add(c.element()); + if (batch.size() >= config.getMaxBatchSize()) { + c.output(new ArrayList<>(batch)); + batch.clear(); + } + } + + @FinishBundle + public void finishBundle(FinishBundleContext c) { + if (!batch.isEmpty()) { + c.output(new ArrayList<>(batch), null, null); + batch.clear(); + } + } + + } + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java new file mode 100644 index 000000000000..939406722b8f --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public abstract class BaseInput implements Serializable { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java new file mode 100644 index 000000000000..3ad6cdce84e5 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.util.List; + +public interface BaseModelHandler { + + // initialize the model with provided parameters + public void createClient(ParamT parameters); + + // Logic to invoke model provider + public Iterable> request(List input); + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java new file mode 100644 index 000000000000..bb56cb74a555 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public interface BaseModelParameters extends Serializable { + + public BatchConfig defaultBatchConfig(); +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java new file mode 100644 index 000000000000..2e7af0e7c2c1 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public abstract class BaseResponse implements Serializable { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java new file mode 100644 index 000000000000..a4497c1181f1 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java @@ -0,0 +1,45 @@ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public class BatchConfig implements Serializable { + + private final int minBatchSize; + private final int maxBatchSize; + + private BatchConfig(Builder builder) { + this.minBatchSize = builder.minBatchSize; + this.maxBatchSize = builder.maxBatchSize; + } + + public int getMinBatchSize() { + return minBatchSize; + } + + public int getMaxBatchSize() { + return maxBatchSize; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private int minBatchSize; + private int maxBatchSize; + + public Builder minBatchSize(int minBatchSize) { + this.minBatchSize = minBatchSize; + return this; + } + + public Builder maxBatchSize(int maxBatchSize) { + this.maxBatchSize = maxBatchSize; + return this; + } + + public BatchConfig build() { + return new BatchConfig(this); + } + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java new file mode 100644 index 000000000000..b19f64917479 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public class PredictionResult implements Serializable { + + private final InputT input; + private final OutputT output; + + private PredictionResult(InputT input, OutputT output) { + this.input = input; + this.output = output; + + } + + public InputT getInput() { + return input; + } + + public OutputT getOutput() { + return output; + } + + public static PredictionResult create(InputT input, OutputT output) { + return new PredictionResult<>(input, output); + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java new file mode 100644 index 000000000000..f30e820549ae --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; +import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; + +import java.util.List; +import java.util.stream.Collectors; + +public class OpenAIModelHandler + implements BaseModelHandler { + + private transient OpenAIClient client; + private transient StructuredResponseCreateParams clientParams; + private OpenAIModelParameters modelParameters; + + @Override + public void createClient(OpenAIModelParameters parameters) { + this.modelParameters = parameters; + this.client = OpenAIOkHttpClient.builder() + .apiKey(this.modelParameters.getApiKey()) + .build(); + } + + @Override + public Iterable> request(List input) { + + try { + // Convert input list to JSON string + String inputBatch = new ObjectMapper() + .writeValueAsString(input.stream().map(OpenAIModelInput::getInput).toList()); + + // Build structured response parameters + this.clientParams = ResponseCreateParams.builder() + .model(modelParameters.getModelName()) + .input(inputBatch) + .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) + .instructions(modelParameters.getInstructionPrompt()) + .build(); + + // Get structured output from the model + StructuredInputOutput structuredOutput = client.responses() + .create(clientParams) + .output() + .stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .findFirst() + .orElse(null); + + if (structuredOutput == null || structuredOutput.responses == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + // Map responses to PredictionResults + List> results = structuredOutput.responses.stream() + .map(response -> PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); + + return results; + + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize input batch", e); + } + } + + public static class Response { + @JsonProperty(required = true) + @JsonPropertyDescription("The input string") + public String input; + + @JsonProperty(required = true) + @JsonPropertyDescription("The output string") + public String output; + } + + public static class StructuredInputOutput { + @JsonProperty(required = true) + @JsonPropertyDescription("Array of input-output pairs") + public List responses; + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java new file mode 100644 index 000000000000..0500832def3f --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; + +public class OpenAIModelInput extends BaseInput { + + private final String input; + + private OpenAIModelInput(String input) { + + this.input = input; + } + + public String getInput() { + return input; + } + + public static OpenAIModelInput create(String input) { + return new OpenAIModelInput(input); + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java new file mode 100644 index 000000000000..e26308694828 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; +import org.apache.beam.sdk.ml.remoteinference.base.BatchConfig; + +public class OpenAIModelParameters implements BaseModelParameters { + + private final String apiKey; + private final String modelName; + private final String instructionPrompt; + private final BatchConfig batchConfig; + + private OpenAIModelParameters(Builder builder) { + this.apiKey = builder.apiKey; + this.modelName = builder.modelName; + this.instructionPrompt = builder.instructionPrompt; + this.batchConfig = BatchConfig.builder().maxBatchSize(1).minBatchSize(1).build(); + } + + public String getApiKey() { + return apiKey; + } + + public String getModelName() { + return modelName; + } + + public String getInstructionPrompt() { + return instructionPrompt; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public BatchConfig defaultBatchConfig() { + return batchConfig; + } + + public static class Builder { + private String apiKey; + private String modelName; + private String instructionPrompt; + + private Builder() { + } + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder instructionPrompt(String prompt) { + this.instructionPrompt = prompt; + return this; + } + + public OpenAIModelParameters build() { + return new OpenAIModelParameters(this); + } + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java new file mode 100644 index 000000000000..e65513851bc0 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; + +public class OpenAIModelResponse extends BaseResponse { + + private final String output; + + private OpenAIModelResponse(String output) { + this.output = output; + } + + public String getOutput() { + return output; + } + + public static OpenAIModelResponse create(String output) { + return new OpenAIModelResponse(output); + } +} diff --git a/sdks/java/ml/remoteinference/src/test/java/Example.java b/sdks/java/ml/remoteinference/src/test/java/Example.java new file mode 100644 index 000000000000..baf944342fe4 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/Example.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import org.apache.beam.runners.direct.DirectRunner; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.ml.remoteinference.RemoteInference; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelInput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelParameters; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelResponse; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.TypeDescriptor; + +public class Example { + public static void main(String[] args) { + + /*PipelineOptions options = PipelineOptionsFactory.create(); + options.setRunner(DirectRunner.class); + Pipeline p = Pipeline.create(options); + + p.apply("text", Create.of( + "An excellent B2B SaaS solution that streamlines business processes efficiently. The platform is user-friendly and highly reliable. Overall, it delivers great value for enterprise teams.")) + .apply(MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)) + .apply("inference", RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey("key") + .modelName("gpt-5-mini") + .instructionPrompt("Analyse sentiment as positive or negative") + .build())) + .apply("print output", ParDo.of(new DoFn() { + @ProcessElement + public void print(ProcessContext c) { + System.out.println("OUTPUT: " + c.element().getOutput()); + } + })); + + p.run();*/ + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 72c5194ec93d..fdfc5da6854c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -383,3 +383,5 @@ include("sdks:java:extensions:sql:iceberg") findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg" include("examples:java:iceberg") findProject(":examples:java:iceberg")?.name = "iceberg" + +include("sdks:java:ml:remoteinference") \ No newline at end of file