Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

describe DJL blocks as tensorflow #3178

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions api/src/main/java/ai/djl/nn/Blocks.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.convolutional.Convolution;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -167,4 +173,220 @@ public static String describe(Block block, String blockName, int beginAxis) {
}
return sb.toString();
}

/**
* Builds an equivalent tensorflow model from the the DJL model using functional or sequential
* API.
*
* @param trainer The trainer containing the DJL model
* @param functionalApi if <code>true</code>, keras's functional API is used, otherwise the
* sequential API. The model should be initialized when using the functional API.
* @return Python code
*/
public static String describeAsTensorflow(Trainer trainer, boolean functionalApi) {
Block block = trainer.getModel().getBlock();
String model =
describeAsTensorflow(block, "SequentialBlock", "", functionalApi ? "inputs" : null);
if (functionalApi) {
String inputLayer =
block.isInitialized()
? String.format(
"inputs = tf.keras.layers.InputLayer(input_shape = %s).output",
block.getInputShapes()[0].slice(1))
: "# define input tensor here";
return String.format(
"%s%n%s%nmodel = tf.keras.Model(inputs=inputs, outputs=outputs)%n%nloss = %s",
inputLayer, model, describeAsTensorflow(trainer.getLoss()));
}
return String.format(
"model = %s%n%nloss = %s", model, describeAsTensorflow(trainer.getLoss()));
}

static String describeAsTensorflow(Loss loss) {
switch (loss.getClass().getSimpleName()) {
case "SoftmaxCrossEntropyLoss":
return "tf.keras.losses.categorical_crossentropy";
default:
return "tf.keras.losses.mean_squared_error";
}
}

/**
* Builds a tensorflow layer equivalent to the passed {@link Block}.
*
* @param block the block to translate
* @param blockName the DJL name of the passed block, or <code>null</code> if the block's class
* name is to be used
* @param pythonName the name to be used for the keras layer name
* @param input if not <code>null</code>, the input tensor to call the layer with required by
* functional API, otherwise sequential API is used
* @return Python expression for sequential API or Python statements for functional API
*/
public static String describeAsTensorflow(
Block block, String blockName, String pythonName, String input) {
if (block instanceof LambdaBlock
&& !LambdaBlock.DEFAULT_NAME.equals(((LambdaBlock) block).getName())) {
blockName = ((LambdaBlock) block).getName();
}
switch (blockName) {
case "ParallelBlock":
{
Object[][] args = {{-1}};
return format("tf.keras.layers.Concatenate", args, block, pythonName, input);
}
case "batchFlatten":
{
Object[][] args = {};
return format("tf.keras.layers.Flatten", args, block, pythonName, input);
}
case "SequentialBlock":
{
Object[][] args = {{-1}};
String op =
pythonName.isEmpty()
? "tf.keras.models.Sequential"
: "tf.keras.Sequential";
return format(op, args, block, pythonName, input);
}
case "Add":
{
Object[][] args = {{-1}};
return format("tf.keras.layers.Add", args, block, pythonName, input);
}
case "Linear":
{
Object[][] args = {
{block.getOutputShapes(new Shape[] {new Shape(0)})[0].get(0)}
};
return format("tf.keras.layers.Dense", args, block, pythonName, input);
}
case "GELU":
case "Mish":
case "ReLU6":
case "ReLU":
case "SELU":
case "sigmoid":
case "softPlus":
case "softSign":
case "Tanh":
{
Object[][] args = {{"tf.keras.activations." + blockName.toLowerCase()}};
return format("tf.keras.layers.Activation", args, block, pythonName, input);
}
case "identity":
{
Object[][] args = {};
return format("tf.keras.layers.Identity", args, block, pythonName, input);
}
case "Conv2d":
{
Convolution conv = (Convolution) block;
String padding =
new Shape(0, 0).equals(conv.getPadding()) ? "'VALID'" : "'SAME'";
Object[][] args = {
{conv.getFilters(), "filters"},
{conv.getKernelShape(), "kernel_size"},
{conv.getStride(), "strides"},
{null, "padding", padding},
{conv.getDilation(), "dilation_rate"},
{null, "data_format", "'channels_first'"},
{null, "use_bias", conv.isIncludeBias()}
};
return format("tf.keras.layers.Conv2D", args, block, pythonName, input);
}

case "BatchNorm":
{
BatchNorm norm = (BatchNorm) block;
Object[][] args = {
{norm.getScale(), "scale"},
{norm.getCenter(), "center"},
{norm.getEpsilon(), "epsilon"},
{norm.getMomentum(), "momentum"},
{norm.getAxis(), "axis"}
};
return format(
"tf.keras.layers.BatchNormalization", args, block, pythonName, input);
}

case "globalAvgPool2d":
{
Object[][] args = {{null, "data_format", "'channels_first'"}};
return format(
"tf.keras.layers.GlobalAveragePooling2D",
args,
block,
pythonName,
input);
}
default:
{
Object[][] args = {{-1}};
return format(blockName, args, block, pythonName, input);
}
}
}

    // Renders a single Keras layer call "op(arg, ..., name='...')" and, for the functional API
    // (input != null), the Python assignment statements that wire the layer to its input tensor.
    // Args convention (built in describeAsTensorflow(Block, ...)):
    //   {-1}                -> expand the block's children in place
    //   {value}             -> positional argument
    //   {value, kw}         -> keyword argument kw=value
    //   {null, kw, literal} -> keyword argument kw=<literal>, passed through verbatim
    // NOTE(review): 'input' is reassigned inside the loop to chain sequential children; the
    // statement order here is load-bearing.
    static String format(String op, Object[][] args, Block block, String pythonName, String input) {
        String pref = "";
        StringBuilder sb = new StringBuilder(op + "(");
        for (Object[] arg : args) {
            // Pre-rendered literal (third element), used when arg[0] is null.
            String s = arg.length >= 3 ? String.valueOf(arg[2]) : null;
            if (Integer.valueOf(-1).equals(arg[0])) {
                // Sentinel -1: recursively render all children of a container block.
                List<String> nameOfLayers = new ArrayList<>();
                List<String> layers = new ArrayList<>();
                for (Pair<String, Block> pair : block.getChildren()) {
                    // Child keys appear to carry a 2-character position prefix before the block
                    // name (substring(2) / substring(0, 2)) — TODO confirm against Block naming.
                    String name = pair.getKey().substring(2);
                    String pythonNameOfLayer =
                            pythonName
                                    + (pythonName.isEmpty() ? "" : "_")
                                    + name
                                    + pair.getKey().substring(0, 2);
                    String layer =
                            describeAsTensorflow(pair.getValue(), name, pythonNameOfLayer, input);
                    layers.add(layer);
                    if (input != null) {
                        // Functional API: the child's output variable is the LHS of its last
                        // statement, i.e. the text between the final newline and " = ".
                        nameOfLayers.add(
                                layer.substring(
                                        layer.lastIndexOf('\n') + 1, layer.lastIndexOf(" = ")));
                        if (op.endsWith("Sequential")) {
                            // Chain: the next child consumes this child's output tensor.
                            input = nameOfLayers.get(nameOfLayers.size() - 1);
                        }
                    }
                }
                if (input != null) {
                    // Functional API: children become standalone statements emitted before this
                    // layer's own statement.
                    pref = layers.stream().collect(Collectors.joining("\n", "", "\n"));
                    if (!op.endsWith("Sequential")) {
                        // Non-sequential containers (Concatenate/Add) take the list of all
                        // child outputs as their call argument.
                        input = nameOfLayers.stream().collect(Collectors.joining(", ", "[", "]"));
                    }
                    continue;
                } else {
                    // Sequential API: children are nested inline as an indented Python list.
                    s =
                            layers.stream()
                                    .map(b -> b.replaceAll("(?m)^", " "))
                                    .collect(Collectors.joining(",\n", "[\n", "\n]"));
                }
            } else if (arg[0] != null) {
                s = arg[0].toString();
            } else if (s == null) {
                continue; // cannot resolve index, so skip
            }
            // Translate Java boolean literals to Python.
            s = "true".equals(s) ? "True" : "false".equals(s) ? "False" : s;
            if (arg.length >= 2 && arg[1] != null) {
                s = String.format("%s=%s", arg[1], s);
            }
            sb.append(s);
            sb.append(", ");
        }
        // The root layer (empty pythonName) is always named 'outputs' so the generated
        // tf.keras.Model(inputs=..., outputs=outputs) line resolves.
        String name = pythonName.isEmpty() ? "outputs" : pythonName;
        sb.append(String.format("name='%s'", name));
        sb.append(')');
        if (input != null) {
            if (op.endsWith("Sequential")) {
                // A sequential container reduces to its children; 'input' already names the
                // last child's output tensor.
                return String.format("%s%s = %s", pref, name, input);
            }
            return String.format("%s%s = %s(%s)", pref, name, sb, input);
        }
        return sb.toString();
    }
}
54 changes: 54 additions & 0 deletions api/src/main/java/ai/djl/nn/core/Add.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Block;
import ai.djl.nn.ParallelBlock;

import java.util.Collections;
import java.util.List;

/**
 * {@code Add} is a {@link Block} whose children form parallel branches in the network, combined
 * by {@link NDArrays#add(NDArray...)} into a single output.
 *
 * <p>{@code Add} has no direct parameters.
 */
public class Add extends ParallelBlock {

    /**
     * Creates an empty block whose branches, once added, are combined into a single output by
     * {@link NDArrays#add(NDArray...)}.
     */
    public Add() {
        this(Collections.emptyList());
    }

    /**
     * Creates a block whose branches are the given blocks, combined by {@link
     * NDArrays#add(NDArray...)} into a single output.
     *
     * @param blocks the blocks that form each of the parallel branches
     */
    public Add(List<Block> blocks) {
        super(Add::sumHeads, blocks);
    }

    /** Element-wise sums the head array of every branch output into a single-array list. */
    private static NDList sumHeads(List<NDList> branchOutputs) {
        NDArray[] heads = branchOutputs.stream().map(NDList::head).toArray(NDArray[]::new);
        return new NDList(NDArrays.add(heads));
    }
}
45 changes: 45 additions & 0 deletions api/src/main/java/ai/djl/nn/norm/BatchNorm.java
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,51 @@ public static NDList batchNorm(
input, runningMean, runningVar, gamma, beta, axis, momentum, eps, training);
}

    /**
     * Returns the axis along which the channel dimension is specified.
     *
     * @return the axis in which the channel is specified
     */
    public int getAxis() {
        return axis;
    }

    /**
     * Returns the epsilon value used to prevent division by 0.
     *
     * @return the epsilon value used to prevent division by 0
     */
    public float getEpsilon() {
        return epsilon;
    }

    /**
     * Returns the momentum used for the moving average.
     *
     * @return the momentum used for the moving average
     */
    public float getMomentum() {
        return momentum;
    }

    /**
     * Returns whether the offset {@code beta} is added to the normalized tensor.
     *
     * @return {@code true} if the offset {@code beta} is added to the normalized tensor
     */
    public boolean getCenter() {
        return center;
    }

    /**
     * Returns whether the result is multiplied by {@code gamma}.
     *
     * @return {@code true} if the result is multiplied by {@code gamma}
     */
    public boolean getScale() {
        return scale;
    }

/**
* Creates a builder to build a {@code BatchNorm}.
*
Expand Down
Loading
Loading