Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions sdks/java/ml/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Build script for the parent Java ML module (sdks/java/ml).
// NOTE(review): applyJavaNature, provideIntegrationTestingDependencies and
// enableJavaPerformanceTesting are not defined in this script; presumably they
// are contributed by the 'org.apache.beam.module' plugin - confirm.
plugins { id 'org.apache.beam.module' }
applyJavaNature(
automaticModuleName: 'org.apache.beam.sdk.ml',
)
provideIntegrationTestingDependencies()
enableJavaPerformanceTesting()

description = "Apache Beam :: SDKs :: Java :: ML"
ext.summary = "Java ML module"

// Intentionally empty: this script declares no dependencies of its own.
dependencies {


}
39 changes: 39 additions & 0 deletions sdks/java/ml/remoteinference/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Build script for the RemoteInference module (sdks/java/ml/remoteinference).
plugins {
    id("org.apache.beam.module")
    id("java-library")
}

description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference"

dependencies {
    // Core Beam SDK
    implementation(project(":sdks:java:core"))

    // OpenAI client used by the OpenAI model handler
    implementation("com.openai:openai-java:4.3.0")

    // AutoValue: the processor generates the AutoValue_* classes at compile time, so it
    // belongs on the annotation-processor path; only the annotations are needed by the code.
    annotationProcessor("com.google.auto.value:auto-value:1.11.0")
    implementation("com.google.auto.value:auto-value-annotations:1.11.0")

    // testing
    testImplementation(project(":runners:direct-java"))
    testImplementation("junit:junit:4.13.2")
    testImplementation(project(":sdks:java:testing:test-utils"))
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.remoteinference;

import org.apache.beam.sdk.ml.remoteinference.base.*;
import org.checkerframework.checker.nullness.qual.Nullable;

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;


import com.google.auto.value.AutoValue;

@SuppressWarnings({ "rawtypes", "unchecked" })
public class RemoteInference {

/**
 * Creates an {@link Invoke} transform with neither a handler class nor model parameters set.
 *
 * <p>Configure the returned transform with {@link Invoke#handler(Class)} and
 * {@link Invoke#withParameters(BaseModelParameters)}.
 *
 * @param <InputT> element type sent to the model
 * @param <OutputT> response type produced by the model
 * @return an unconfigured transform
 */
public static <InputT extends BaseInput, OutputT extends BaseResponse> Invoke<InputT, OutputT> invoke() {
  // Both properties are declared @Nullable, so an untouched builder already leaves them null;
  // explicitly passing null to the non-@Nullable setter parameter is unnecessary.
  return new AutoValue_RemoteInference_Invoke.Builder<InputT, OutputT>().build();
}

// Private constructor: the class is used only through its static invoke() factory.
private RemoteInference() {
}

@AutoValue
public abstract static class Invoke<InputT extends BaseInput, OutputT extends BaseResponse>
extends PTransform<PCollection<InputT>, PCollection<Iterable<PredictionResult<InputT, OutputT>>>> {

/** Handler implementation class to instantiate; null until one is set through the builder. */
abstract @Nullable Class<? extends BaseModelHandler> handler();

/** Parameters handed to the handler's {@code createClient}; may be null. */
abstract @Nullable BaseModelParameters parameters();

/** Builder from which {@code handler(Class)} and {@code withParameters} build their results. */
abstract Builder<InputT, OutputT> builder();

/** AutoValue builder for {@link Invoke}; one setter per property plus {@code build()}. */
@AutoValue.Builder
abstract static class Builder<InputT extends BaseInput, OutputT extends BaseResponse> {

/** Sets the handler implementation class. */
abstract Builder<InputT, OutputT> setHandler(Class<? extends BaseModelHandler> modelHandler);

/** Sets the parameters later passed to the handler. */
abstract Builder<InputT, OutputT> setParameters(BaseModelParameters modelParameters);

/** Builds the {@link Invoke} instance from the values set so far. */
abstract Invoke<InputT, OutputT> build();
}

/**
 * Returns a transform built from this one's builder with the handler class set.
 *
 * @param modelHandler handler implementation class to use
 * @return the resulting transform
 */
public Invoke<InputT, OutputT> handler(Class<? extends BaseModelHandler> modelHandler) {
  Builder<InputT, OutputT> withHandler = builder().setHandler(modelHandler);
  return withHandler.build();
}

/**
 * Returns a transform built from this one's builder with the model parameters set.
 *
 * @param modelParameters parameters to hand to the handler
 * @return the resulting transform
 */
public Invoke<InputT, OutputT> withParameters(BaseModelParameters modelParameters) {
  Builder<InputT, OutputT> withParams = builder().setParameters(modelParameters);
  return withParams.build();
}

@Override
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the Python implementation (as well as the cross-language implementation in Java) we're generally trying to return input-output pairs to make it easier to process the results downstream.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated output format.

public PCollection<Iterable<PredictionResult<InputT, OutputT>>> expand(PCollection<InputT> input) {
  // Wrap this transform's configuration in the DoFn, then run it over the input as a ParDo.
  RemoteInferenceFn<InputT, OutputT> inferenceFn = new RemoteInferenceFn<>(this);
  return input.apply(ParDo.of(inferenceFn));
}

/**
 * {@link DoFn} that passes each input element to a model handler and emits the handler's
 * results.
 *
 * <p>The handler is created reflectively in {@link #setupHandler()} from the configured class,
 * using its no-argument constructor, and is then given the configured parameters.
 */
static class RemoteInferenceFn<InputT extends BaseInput, OutputT extends BaseResponse>
    extends DoFn<InputT, Iterable<PredictionResult<InputT, OutputT>>> {

  private final Class<? extends BaseModelHandler> handlerClass;
  private final BaseModelParameters parameters;
  // transient: excluded from Java serialization; assigned in setupHandler().
  private transient BaseModelHandler handler;

  /**
   * Captures the handler class and parameters from the transform.
   *
   * @param spec the configured transform
   * @throws IllegalArgumentException if no handler class was configured on {@code spec}
   */
  RemoteInferenceFn(Invoke<InputT, OutputT> spec) {
    Class<? extends BaseModelHandler> configuredHandler = spec.handler();
    if (configuredHandler == null) {
      // Fail here rather than in setupHandler(), where a null class would raise a second
      // NullPointerException while the error message is being built.
      throw new IllegalArgumentException(
          "RemoteInference requires a model handler class; call handler(...) on the transform");
    }
    this.handlerClass = configuredHandler;
    this.parameters = spec.parameters();
  }

  /**
   * Instantiates the handler and lets it create its client.
   *
   * @throws RuntimeException wrapping any failure from instantiation or client creation
   */
  @Setup
  public void setupHandler() {
    try {
      this.handler = handlerClass.getDeclaredConstructor().newInstance();
      this.handler.createClient(parameters);
    } catch (Exception e) {
      throw new RuntimeException("Failed to instantiate handler: "
          + handlerClass.getName(), e);
    }
  }

  /** Sends the current element to the handler and outputs whatever it returns. */
  @ProcessElement
  public void processElement(ProcessContext c) {
    Iterable<PredictionResult<InputT, OutputT>> response = this.handler.request(c.element());
    c.output(response);
  }
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.remoteinference.base;

import java.io.Serializable;

/**
 * Serializable base class for model inputs.
 *
 * <p>It declares no members of its own; concrete input types extend it.
 */
public abstract class BaseInput implements Serializable {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.remoteinference.base;

/**
 * Contract for a model provider integration: create a client, then serve requests with it.
 *
 * @param <ParamT> parameter type used to create the client
 * @param <InputT> input type sent to the model
 * @param <OutputT> response type returned by the model
 */
public interface BaseModelHandler<
    ParamT extends BaseModelParameters,
    InputT extends BaseInput,
    OutputT extends BaseResponse> {

  /**
   * Initializes the model with the provided parameters.
   *
   * @param parameters provider-specific parameters
   */
  void createClient(ParamT parameters);

  /**
   * Invokes the model provider for one input.
   *
   * @param input the input to send
   * @return prediction results pairing inputs with outputs
   */
  Iterable<PredictionResult<InputT, OutputT>> request(InputT input);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.remoteinference.base;

import java.io.Serializable;

/**
 * Serializable marker interface for model parameters.
 *
 * <p>It declares no methods; implementations carry provider-specific settings.
 */
public interface BaseModelParameters extends Serializable {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.remoteinference.base;

import java.io.Serializable;

/**
 * Serializable base class for model responses.
 *
 * <p>It declares no members of its own; concrete response types extend it.
 */
public abstract class BaseResponse implements Serializable {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.remoteinference.base;

import java.io.Serializable;

public class PredictionResult<InputT, OutputT> implements Serializable {

private final InputT input;
private final OutputT output;

private PredictionResult(InputT input, OutputT output) {
this.input = input;
this.output = output;

}

public InputT getInput() {
return input;
}

public OutputT getOutput() {
return output;
}

public static <InputT, OutputT> PredictionResult<InputT, OutputT> create(InputT input, OutputT output) {
return new PredictionResult<>(input, output);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.remoteinference.openai;

import com.openai.client.OpenAIClient;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.models.responses.ResponseCreateParams;
import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler;
import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult;

import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link BaseModelHandler} that sends inputs to OpenAI through the client's responses endpoint
 * and returns the concatenated output text.
 *
 * <p>{@link #createClient(OpenAIModelParameters)} must be called before
 * {@link #request(OpenAIModelInput)}.
 */
public class OpenAIModelHandler
    implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {

  private transient OpenAIClient client;
  private OpenAIModelParameters modelParameters;

  /**
   * Stores the parameters and builds an OpenAI client using their API key.
   *
   * @param parameters supplies the API key and the model name
   */
  @Override
  public void createClient(OpenAIModelParameters parameters) {
    this.modelParameters = parameters;
    this.client = OpenAIOkHttpClient.builder()
        .apiKey(this.modelParameters.getApiKey())
        .build();
  }

  /**
   * Sends one input to the configured model and collects the text of the response.
   *
   * @param input the input whose text is sent to the model
   * @return a single-element list pairing {@code input} with the joined output text
   * @throws IllegalStateException if {@code createClient} has not been called
   */
  @Override
  public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(OpenAIModelInput input) {
    if (client == null) {
      throw new IllegalStateException("createClient must be called before request");
    }

    // Request parameters are specific to this call, so they are kept local rather than
    // written to an instance field.
    ResponseCreateParams requestParams = ResponseCreateParams.builder()
        .model(this.modelParameters.getModelName())
        .input(input.getInput())
        .build();

    // Flatten output items -> messages -> content parts -> output text, then join the text.
    String output = client.responses().create(requestParams).output().stream()
        .flatMap(item -> item.message().stream())
        .flatMap(message -> message.content().stream())
        .flatMap(content -> content.outputText().stream())
        .map(outputText -> outputText.text())
        .collect(Collectors.joining());

    return List.of(PredictionResult.create(input, OpenAIModelResponse.create(output)));
  }
}
Loading
Loading