Skip to content

Commit

Permalink
Java: Add example for AI-based image generation with Amazon Titan on …
Browse files Browse the repository at this point in the history
…Bedrock (awsdocs#5724)
  • Loading branch information
DennisTraub authored and max-webster committed Mar 15, 2024
1 parent 0d93945 commit 5c09c68
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 26 deletions.
25 changes: 25 additions & 0 deletions .doc_gen/metadata/bedrock-runtime_metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,14 @@ bedrock-runtime_InvokeTitanImage:
- description: Invoke the Amazon Titan image generation model.
snippet_tags:
- gov2.bedrock-runtime.InvokeTitanImage
Java:
versions:
- sdk_version: 2
github: javav2/example_code/bedrock-runtime
excerpts:
- description: Invoke the Amazon Titan image generation model.
snippet_tags:
- bedrock-runtime.java2.invoke_titan_image.main
PHP:
versions:
- sdk_version: 3
Expand All @@ -293,6 +301,23 @@ bedrock-runtime_InvokeTitanImage:
services:
bedrock-runtime: {InvokeModel}

bedrock-runtime_InvokeTitanImageAsync:
title: Asynchronously invoke the Amazon Titan on &BR; to generate images
title_abbrev: Asynchronous image generation with Amazon Titan
synopsis: asynchronously invoke the Amazon Titan on &BR; to generate images.
category:
languages:
Java:
versions:
- sdk_version: 2
github: javav2/example_code/bedrock-runtime
excerpts:
- description: Invoke the Amazon Titan image generation model (async).
snippet_tags:
- bedrock-runtime.java2.invoke_titan_image_async.main
services:
bedrock-runtime: {InvokeModel}

bedrock-runtime_InvokeModelWithResponseStream:
title: Invoke Anthropic Claude on &BR; to run an inference with a response stream
title_abbrev: Invoke Anthropic Claude on &BR; and process the response stream
Expand Down
20 changes: 11 additions & 9 deletions javav2/example_code/bedrock-runtime/Readme.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<!--Generated by WRITEME on 2023-11-24 20:21:03.808605 (UTC)-->
<!--Generated by WRITEME on 2023-12-06 16:52:24.629768 (UTC)-->
# Amazon Bedrock Runtime code examples for the SDK for Java 2.x

## Overview
Expand Down Expand Up @@ -35,15 +35,17 @@ For prerequisites, see the [README](../../README.md#Prerequisites) in the `javav

Code excerpts that show you how to call individual service functions.

* [Asynchronously invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L98) (`InvokeModel`)
* [Asynchronously invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L37) (`InvokeModel`)
* [Asynchronously invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L160) (`InvokeModel`)
* [Asynchronously invoke Stability.ai Stable Diffusion XL on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L219) (`InvokeModel`)
* [Invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L85) (`InvokeModel`)
* [Invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L36) (`InvokeModel`)
* [Asynchronous image generation with Amazon Titan](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L286) (`InvokeModel`)
* [Asynchronously invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L99) (`InvokeModel`)
* [Asynchronously invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L38) (`InvokeModel`)
* [Asynchronously invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L161) (`InvokeModel`)
* [Image generation with Amazon Titan](src/main/java/com/example/bedrockruntime/InvokeModel.java#L232) (`InvokeModel`)
* [Image generation with Stable Diffusion](src/main/java/com/example/bedrockruntime/InvokeModel.java#L179) (`InvokeModel`)
* [Image generation with Stable Diffusion using the async client](src/main/java/com/example/bedrockruntime/InvokeModelAsync.java#L220) (`InvokeModel`)
* [Invoke AI21 Labs Jurassic-2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L86) (`InvokeModel`)
* [Invoke Anthropic Claude on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L37) (`InvokeModel`)
* [Invoke Anthropic Claude on Amazon Bedrock and process the response stream](src/main/java/com/example/bedrockruntime/InvokeModelWithResponseStream.java#L34) (`InvokeModelWithResponseStream`)
* [Invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L133) (`InvokeModel`)
* [Invoke Stability.ai Stable Diffusion XL on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L178) (`InvokeModel`)
* [Invoke Meta Llama 2 on Amazon Bedrock](src/main/java/com/example/bedrockruntime/InvokeModel.java#L134) (`InvokeModel`)

## Run the examples

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class BedrockRuntimeUsageDemo {
private static final String JURASSIC2 = "ai21.j2-mid-v1";
private static final String LLAMA2 = "meta.llama2-13b-chat-v1";
private static final String STABLE_DIFFUSION = "stability.stable-diffusion-xl";
private static final String TITAN_IMAGE = "amazon.titan-image-generator-v1";

public static void main(String[] args) {
BedrockRuntimeUsageDemo.textToText();
Expand All @@ -39,9 +40,13 @@ public static void main(String[] args) {
private static void textToText() {

String prompt = "In one sentence, what is a large-language model?";
BedrockRuntimeUsageDemo.invoke(CLAUDE, prompt, null);
BedrockRuntimeUsageDemo.invoke(JURASSIC2, prompt, null);
BedrockRuntimeUsageDemo.invoke(LLAMA2, prompt, null);
BedrockRuntimeUsageDemo.invoke(CLAUDE, prompt);
BedrockRuntimeUsageDemo.invoke(JURASSIC2, prompt);
BedrockRuntimeUsageDemo.invoke(LLAMA2, prompt);
}

private static void invoke(String modelId, String prompt) {
invoke(modelId, prompt, null);
}

private static void invoke(String modelId, String prompt, String stylePreset) {
Expand All @@ -53,19 +58,19 @@ private static void invoke(String modelId, String prompt, String stylePreset) {
switch (modelId) {
case CLAUDE:
printResponse(invokeClaude(prompt));
return;
break;
case JURASSIC2:
printResponse(invokeJurassic2(prompt));
return;
break;
case LLAMA2:
printResponse(invokeLlama2(prompt));
return;
break;
case STABLE_DIFFUSION:
long seed = (random.nextLong() & 0xFFFFFFFFL);
String base64ImageData = invokeStableDiffusion(prompt, seed, stylePreset);
String imagePath = saveImage(base64ImageData);
System.out.printf("Success: The generated image has been saved to %s%n", imagePath);
return;
createImage(STABLE_DIFFUSION, prompt, random.nextLong() & 0xFFFFFFFFL, stylePreset);
break;
case TITAN_IMAGE:
createImage(TITAN_IMAGE, prompt, random.nextLong() & 0xFFFFFFFL);
break;
default:
throw new IllegalStateException("Unexpected value: " + modelId);
}
Expand All @@ -75,12 +80,24 @@ private static void invoke(String modelId, String prompt, String stylePreset) {
}
}

private static void createImage(String modelId, String prompt, long seed) {
createImage(modelId, prompt, seed, null);
}

private static void createImage(String modelId, String prompt, long seed, String stylePreset) {
String base64ImageData = (modelId.equals(STABLE_DIFFUSION))
? invokeStableDiffusion(prompt, seed, stylePreset)
: invokeTitanImage(prompt, seed);
String imagePath = saveImage(modelId, base64ImageData);
System.out.printf("Success: The generated image has been saved to %s%n", imagePath);
}

private static void textToTextWithResponseStream() {
String prompt = "What is a large-language model?";
BedrockRuntimeUsageDemo.invoke(CLAUDE, prompt);
BedrockRuntimeUsageDemo.invokeWithResponseStream(CLAUDE, prompt);
}

private static void invoke(String modelId, String prompt) {
private static void invokeWithResponseStream(String modelId, String prompt) {
System.out.println(new String(new char[88]).replace("\0", "-"));
System.out.printf("Invoking %s with response stream%n", modelId);
System.out.println("Prompt: " + prompt);
Expand All @@ -95,16 +112,17 @@ private static void invoke(String modelId, String prompt) {
}

private static void textToImage() {
String imagePrompt = "A sunset over the ocean";
String imagePrompt = "stylized picture of a cute old steampunk robot";
String stylePreset = "photographic";
BedrockRuntimeUsageDemo.invoke(STABLE_DIFFUSION, imagePrompt, stylePreset);
BedrockRuntimeUsageDemo.invoke(TITAN_IMAGE, imagePrompt);
}

private static void printResponse(String response) {
System.out.printf("Generated text: %s%n", response);
}

private static String saveImage(String base64ImageData) {
private static String saveImage(String modelId, String base64ImageData) {
try {
String directory = "output";
URI uri = InvokeModel.class.getProtectionDomain().getCodeSource().getLocation().toURI();
Expand All @@ -117,7 +135,7 @@ private static String saveImage(String base64ImageData) {
int i = 1;
String fileName;
do {
fileName = String.format("image_%d.png", i);
fileName = String.format("%s_%d.png", modelId, i);
i++;
} while (Files.exists(outputPath.resolve(fileName)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package com.example.bedrockruntime;

// snippet-start:[bedrock-runtime.java2.invoke_model.import]

import org.json.JSONArray;
import org.json.JSONObject;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
Expand Down Expand Up @@ -204,7 +205,7 @@ public static String invokeStableDiffusion(String prompt, long seed, String styl
.put("text_prompts", wrappedPrompt)
.put("seed", seed);

if (stylePreset != null && !stylePreset.isEmpty()) {
if (!(stylePreset == null || stylePreset.isEmpty())) {
payload.put("style_preset", stylePreset);
}

Expand All @@ -227,4 +228,60 @@ public static String invokeStableDiffusion(String prompt, long seed, String styl
return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_stable_diffusion.main]

// snippet-start:[bedrock-runtime.java2.invoke_titan_image.main]
/**
* Invokes the Amazon Titan image generation model to create an image using the input
* provided in the request body.
*
* @param prompt The prompt that you want Amazon Titan to use for image generation.
* @param seed The random noise seed for image generation (Range: 0 to 2147483647).
* @return A Base64-encoded string representing the generated image.
*/
public static String invokeTitanImage(String prompt, long seed) {
/*
The different model providers have individual request and response formats.
For the format, ranges, and default values for Titan Image models refer to:
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-image.html
*/
String titanImageModelId = "amazon.titan-image-generator-v1";

BedrockRuntimeClient client = BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.credentialsProvider(ProfileCredentialsProvider.create())
.build();

var textToImageParams = new JSONObject().put("text", prompt);

var imageGenerationConfig = new JSONObject()
.put("numberOfImages", 1)
.put("quality", "standard")
.put("cfgScale", 8.0)
.put("height", 512)
.put("width", 512)
.put("seed", seed);

JSONObject payload = new JSONObject()
.put("taskType", "TEXT_IMAGE")
.put("textToImageParams", textToImageParams)
.put("imageGenerationConfig", imageGenerationConfig);

InvokeModelRequest request = InvokeModelRequest.builder()
.body(SdkBytes.fromUtf8String(payload.toString()))
.modelId(titanImageModelId)
.contentType("application/json")
.accept("application/json")
.build();

InvokeModelResponse response = client.invokeModel(request);

JSONObject responseBody = new JSONObject(response.body().asUtf8String());

String base64ImageData = responseBody
.getJSONArray("images")
.getString(0);

return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_titan_image.main]
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

Expand Down Expand Up @@ -281,4 +282,74 @@ public static String invokeStableDiffusion(String prompt, long seed, String styl
return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_stable_diffusion_async.main]

// snippet-start:[bedrock-runtime.java2.invoke_titan_image_async.main]
/**
* Invokes the Amazon Titan image generation model to create an image using the input
* provided in the request body.
*
* @param prompt The prompt that you want Amazon Titan to use for image generation.
* @param seed The random noise seed for image generation (Range: 0 to 2147483647).
* @return A Base64-encoded string representing the generated image.
*/
public static String invokeTitanImage(String prompt, long seed) {
/*
The different model providers have individual request and response formats.
For the format, ranges, and default values for Titan Image models refer to:
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-image.html
*/
String titanImageModelId = "amazon.titan-image-generator-v1";

BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder()
.region(Region.US_EAST_1)
.credentialsProvider(ProfileCredentialsProvider.create())
.build();

var textToImageParams = new JSONObject().put("text", prompt);

var imageGenerationConfig = new JSONObject()
.put("numberOfImages", 1)
.put("quality", "standard")
.put("cfgScale", 8.0)
.put("height", 512)
.put("width", 512)
.put("seed", seed);

JSONObject payload = new JSONObject()
.put("taskType", "TEXT_IMAGE")
.put("textToImageParams", textToImageParams)
.put("imageGenerationConfig", imageGenerationConfig);

InvokeModelRequest request = InvokeModelRequest.builder()
.body(SdkBytes.fromUtf8String(payload.toString()))
.modelId(titanImageModelId)
.contentType("application/json")
.accept("application/json")
.build();

CompletableFuture<InvokeModelResponse> completableFuture = client.invokeModel(request)
.whenComplete((response, exception) -> {
if (exception != null) {
System.out.println("Model invocation failed: " + exception);
}
});

String base64ImageData = "";
try {
InvokeModelResponse response = completableFuture.get();
JSONObject responseBody = new JSONObject(response.body().asUtf8String());
base64ImageData = responseBody
.getJSONArray("images")
.getString(0);

} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.err.println(e.getMessage());
} catch (ExecutionException e) {
System.err.println(e.getMessage());
}

return base64ImageData;
}
// snippet-end:[bedrock-runtime.java2.invoke_titan_image_async.main]
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,14 @@ void InvokeStableDiffusion() {
assertNotNullOrEmpty(base64Result);
System.out.println("Test async invoke Stable Diffusion passed.");
}

@Test
@Tag("IntegrationTest")
void InvokeTitanImage() {
String prompt = "A sunset over the ocean";
long seed = 0;
String base64Result = InvokeModelAsync.invokeTitanImage(prompt, seed);
assertNotNullOrEmpty(base64Result);
System.out.println("Test async invoke Titan Image passed.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,13 @@ void InvokeStableDiffusion() {
assertNotNullOrEmpty(base64Result);
System.out.println("Test sync invoke Stable Diffusion passed.");
}

@Test
@Tag("IntegrationTest")
void InvokeTitanImage() {
String prompt = "A sunset over the ocean";
String base64Result = InvokeModel.invokeTitanImage(prompt, 0);
assertNotNullOrEmpty(base64Result);
System.out.println("Test sync invoke Titan Image passed.");
}
}

0 comments on commit 5c09c68

Please sign in to comment.