Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java: Add example for AI-based image generation with Amazon Titan on Bedrock #5724

Merged
merged 4 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .doc_gen/metadata/bedrock-runtime_metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,14 @@ bedrock-runtime_InvokeTitanImage:
- description: Invoke the Amazon Titan image generation model.
snippet_tags:
- gov2.bedrock-runtime.InvokeTitanImage
Java:
versions:
- sdk_version: 2
github: javav2/example_code/bedrock-runtime
excerpts:
- description: Invoke the Amazon Titan image generation model.
snippet_tags:
- bedrock-runtime.java2.invoke_titan_image.main
PHP:
versions:
- sdk_version: 3
Expand All @@ -293,6 +301,23 @@ bedrock-runtime_InvokeTitanImage:
services:
bedrock-runtime: {InvokeModel}

bedrock-runtime_InvokeTitanImageAsync:
title: Asynchronously invoke the Amazon Titan on &BR; to generate images
title_abbrev: Asynchronous image generation with Amazon Titan
synopsis: asynchronously invoke the Amazon Titan on &BR; to generate images.
category:
languages:
Java:
versions:
- sdk_version: 2
github: javav2/example_code/bedrock-runtime
excerpts:
- description: Invoke the Amazon Titan image generation model (async).
snippet_tags:
- bedrock-runtime.java2.invoke_titan_image_async.main
services:
bedrock-runtime: {InvokeModel}

bedrock-runtime_InvokeModelWithResponseStream:
title: Invoke Anthropic Claude on &BR; to run an inference with a response stream
title_abbrev: Invoke Anthropic Claude on &BR; and process the response stream
Expand Down
20 changes: 11 additions & 9 deletions javav2/example_code/bedrock-runtime/Readme.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<!--Generated by WRITEME on 2023-11-24 20:21:03.808605 (UTC)-->
<!--Generated by WRITEME on 2023-12-06 16:52:24.629768 (UTC)-->
# Amazon Bedrock Runtime code examples for the SDK for Java 2.x

## Overview
Expand Down Expand Up @@ -35,15 +35,17 @@ For prerequisites, see the [README](../../README.md#Prerequisites) in the `javav

Code excerpts that show you how to call individual service functions.

* [Asynchronously invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L98) (`InvokeModel`)
* [Asynchronously invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L37) (`InvokeModel`)
* [Asynchronously invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L160) (`InvokeModel`)
* [Asynchronously invoke Stability.ai Stable Diffusion XL on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L219) (`InvokeModel`)
* [Invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L85) (`InvokeModel`)
* [Invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L36) (`InvokeModel`)
* [Asynchronous image generation with Amazon Titan](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L286) (`InvokeModel`)
* [Asynchronously invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L99) (`InvokeModel`)
* [Asynchronously invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L38) (`InvokeModel`)
* [Asynchronously invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L161) (`InvokeModel`)
* [Image generation with Amazon Titan](src/main/java/com/example/bedrockruntime/InvokeModel.java#L232) (`InvokeModel`)
* [Image generation with Stable Diffusion](src/main/java/com/example/bedrockruntime/InvokeModel.java#L179) (`InvokeModel`)
* [Image generation with Stable Diffusion using the async client](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L220) (`InvokeModel`)
* [Invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L86) (`InvokeModel`)
* [Invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L37) (`InvokeModel`)
* [Invoke Anthropic Claude on Amazon Bedrock and process the response stream](src/main/java/com/example/bedrockruntime/InvokeModelWithResponseStream.java#L34) (`InvokeModelWithResponseStream`)
* [Invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L133) (`InvokeModel`)
* [Invoke Stability.ai Stable Diffusion XL on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L178) (`InvokeModel`)
* [Invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L134) (`InvokeModel`)

## Run the examples

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class BedrockRuntimeUsageDemo {
private static final String JURASSIC2 = "ai21.j2-mid-v1";
private static final String LLAMA2 = "meta.llama2-13b-chat-v1";
private static final String STABLE_DIFFUSION = "stability.stable-diffusion-xl";
private static final String TITAN_IMAGE = "amazon.titan-image-generator-v1";

public static void main(String[] args) {
BedrockRuntimeUsageDemo.textToText();
Expand All @@ -39,9 +40,13 @@ public static void main(String[] args) {
private static void textToText() {

String prompt = "In one sentence, what is a large-language model?";
BedrockRuntimeUsageDemo.invoke(CLAUDE, prompt, null);
BedrockRuntimeUsageDemo.invoke(JURASSIC2, prompt, null);
BedrockRuntimeUsageDemo.invoke(LLAMA2, prompt, null);
BedrockRuntimeUsageDemo.invoke(CLAUDE, prompt);
BedrockRuntimeUsageDemo.invoke(JURASSIC2, prompt);
BedrockRuntimeUsageDemo.invoke(LLAMA2, prompt);
}

private static void invoke(String modelId, String prompt) {
invoke(modelId, prompt, null);
}

private static void invoke(String modelId, String prompt, String stylePreset) {
Expand All @@ -53,19 +58,19 @@ private static void invoke(String modelId, String prompt, String stylePreset) {
switch (modelId) {
case CLAUDE:
printResponse(invokeClaude(prompt));
return;
break;
case JURASSIC2:
printResponse(invokeJurassic2(prompt));
return;
break;
case LLAMA2:
printResponse(invokeLlama2(prompt));
return;
break;
case STABLE_DIFFUSION:
long seed = (random.nextLong() & 0xFFFFFFFFL);
String base64ImageData = invokeStableDiffusion(prompt, seed, stylePreset);
String imagePath = saveImage(base64ImageData);
System.out.printf("Success: The generated image has been saved to %s%n", imagePath);
return;
createImage(STABLE_DIFFUSION, prompt, random.nextLong() & 0xFFFFFFFFL, stylePreset);
break;
case TITAN_IMAGE:
createImage(TITAN_IMAGE, prompt, random.nextLong() & 0xFFFFFFFL);
break;
default:
throw new IllegalStateException("Unexpected value: " + modelId);
}
Expand All @@ -75,12 +80,24 @@ private static void invoke(String modelId, String prompt, String stylePreset) {
}
}

private static void createImage(String modelId, String prompt, long seed) {
createImage(modelId, prompt, seed, null);
}

private static void createImage(String modelId, String prompt, long seed, String stylePreset) {
String base64ImageData = (modelId.equals(STABLE_DIFFUSION))
? invokeStableDiffusion(prompt, seed, stylePreset)
: invokeTitanImage(prompt, seed);
String imagePath = saveImage(modelId, base64ImageData);
System.out.printf("Success: The generated image has been saved to %s%n", imagePath);
}

private static void textToTextWithResponseStream() {
String prompt = "What is a large-language model?";
BedrockRuntimeUsageDemo.invoke(CLAUDE, prompt);
BedrockRuntimeUsageDemo.invokeWithResponseStream(CLAUDE, prompt);
}

private static void invoke(String modelId, String prompt) {
private static void invokeWithResponseStream(String modelId, String prompt) {
System.out.println(new String(new char[88]).replace("\0", "-"));
System.out.printf("Invoking %s with response stream%n", modelId);
System.out.println("Prompt: " + prompt);
Expand All @@ -95,16 +112,17 @@ private static void invoke(String modelId, String prompt) {
}

private static void textToImage() {
String imagePrompt = "A sunset over the ocean";
String imagePrompt = "stylized picture of a cute old steampunk robot";
String stylePreset = "photographic";
BedrockRuntimeUsageDemo.invoke(STABLE_DIFFUSION, imagePrompt, stylePreset);
BedrockRuntimeUsageDemo.invoke(TITAN_IMAGE, imagePrompt);
}

private static void printResponse(String response) {
System.out.printf("Generated text: %s%n", response);
}

private static String saveImage(String base64ImageData) {
private static String saveImage(String modelId, String base64ImageData) {
try {
String directory = "output";
URI uri = InvokeModel.class.getProtectionDomain().getCodeSource().getLocation().toURI();
Expand All @@ -117,7 +135,7 @@ private static String saveImage(String base64ImageData) {
int i = 1;
String fileName;
do {
fileName = String.format("image_%d.png", i);
fileName = String.format("%s_%d.png", modelId, i);
i++;
} while (Files.exists(outputPath.resolve(fileName)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package com.example.bedrockruntime;

// snippet-start:[bedrock-runtime.java2.invoke_model.import]

import org.json.JSONArray;
import org.json.JSONObject;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
Expand Down Expand Up @@ -204,7 +205,7 @@ public static String invokeStableDiffusion(String prompt, long seed, String styl
.put("text_prompts", wrappedPrompt)
.put("seed", seed);

if (stylePreset != null && !stylePreset.isEmpty()) {
if (!(stylePreset == null || stylePreset.isEmpty())) {
payload.put("style_preset", stylePreset);
}

Expand All @@ -227,4 +228,60 @@ public static String invokeStableDiffusion(String prompt, long seed, String styl
return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_stable_diffusion.main]

// snippet-start:[bedrock-runtime.java2.invoke_titan_image.main]
/**
* Invokes the Amazon Titan image generation model to create an image using the input
* provided in the request body.
*
* @param prompt The prompt that you want Amazon Titan to use for image generation.
* @param seed The random noise seed for image generation (Range: 0 to 2147483647).
* @return A Base64-encoded string representing the generated image.
*/
public static String invokeTitanImage(String prompt, long seed) {
/*
The different model providers have individual request and response formats.
For the format, ranges, and default values for Titan Image models refer to:
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-image.html
*/
String titanImageModelId = "amazon.titan-image-generator-v1";

BedrockRuntimeClient client = BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.credentialsProvider(ProfileCredentialsProvider.create())
.build();

var textToImageParams = new JSONObject().put("text", prompt);

var imageGenerationConfig = new JSONObject()
.put("numberOfImages", 1)
.put("quality", "standard")
.put("cfgScale", 8.0)
.put("height", 512)
.put("width", 512)
.put("seed", seed);

JSONObject payload = new JSONObject()
.put("taskType", "TEXT_IMAGE")
.put("textToImageParams", textToImageParams)
.put("imageGenerationConfig", imageGenerationConfig);

InvokeModelRequest request = InvokeModelRequest.builder()
.body(SdkBytes.fromUtf8String(payload.toString()))
.modelId(titanImageModelId)
.contentType("application/json")
.accept("application/json")
.build();

InvokeModelResponse response = client.invokeModel(request);

JSONObject responseBody = new JSONObject(response.body().asUtf8String());

String base64ImageData = responseBody
.getJSONArray("images")
.getString(0);

return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_titan_image.main]
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

Expand Down Expand Up @@ -281,4 +282,74 @@ public static String invokeStableDiffusion(String prompt, long seed, String styl
return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_stable_diffusion_async.main]

// snippet-start:[bedrock-runtime.java2.invoke_titan_image_async.main]
/**
* Invokes the Amazon Titan image generation model to create an image using the input
* provided in the request body.
*
* @param prompt The prompt that you want Amazon Titan to use for image generation.
* @param seed The random noise seed for image generation (Range: 0 to 2147483647).
* @return A Base64-encoded string representing the generated image.
*/
public static String invokeTitanImage(String prompt, long seed) {
/*
The different model providers have individual request and response formats.
For the format, ranges, and default values for Titan Image models refer to:
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-image.html
*/
String titanImageModelId = "amazon.titan-image-generator-v1";

BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder()
.region(Region.US_EAST_1)
.credentialsProvider(ProfileCredentialsProvider.create())
.build();

var textToImageParams = new JSONObject().put("text", prompt);

var imageGenerationConfig = new JSONObject()
.put("numberOfImages", 1)
.put("quality", "standard")
.put("cfgScale", 8.0)
.put("height", 512)
.put("width", 512)
.put("seed", seed);

JSONObject payload = new JSONObject()
.put("taskType", "TEXT_IMAGE")
.put("textToImageParams", textToImageParams)
.put("imageGenerationConfig", imageGenerationConfig);

InvokeModelRequest request = InvokeModelRequest.builder()
.body(SdkBytes.fromUtf8String(payload.toString()))
.modelId(titanImageModelId)
.contentType("application/json")
.accept("application/json")
.build();

CompletableFuture<InvokeModelResponse> completableFuture = client.invokeModel(request)
.whenComplete((response, exception) -> {
if (exception != null) {
System.out.println("Model invocation failed: " + exception);
}
});

String base64ImageData = "";
try {
InvokeModelResponse response = completableFuture.get();
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
base64ImageData = responseBody
.getJSONArray("images")
.getString(0);

} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.err.println(e.getMessage());
} catch (ExecutionException e) {
System.err.println(e.getMessage());
}

return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_titan_image_async.main]
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,14 @@ void InvokeStableDiffusion() {
assertNotNullOrEmpty(base64Result);
System.out.println("Test async invoke Stable Diffusion passed.");
}

@Test
@Tag("IntegrationTest")
void InvokeTitanImage() {
String prompt = "A sunset over the ocean";
long seed = 0;
String base64Result = InvokeModelAsync.invokeTitanImage(prompt, seed);
assertNotNullOrEmpty(base64Result);
System.out.println("Test async invoke Titan Image passed.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,13 @@ void InvokeStableDiffusion() {
assertNotNullOrEmpty(base64Result);
System.out.println("Test sync invoke Stable Diffusion passed.");
}

@Test
@Tag("IntegrationTest")
void InvokeTitanImage() {
String prompt = "A sunset over the ocean";
String base64Result = InvokeModel.invokeTitanImage(prompt, 0);
assertNotNullOrEmpty(base64Result);
System.out.println("Test sync invoke Titan Image passed.");
}
}
Loading