Skip to content

Commit

Permalink
Add compute on construction hashcode and protobuf size
Browse files Browse the repository at this point in the history
Signed-off-by: jasperpotts <[email protected]>
  • Loading branch information
jasperpotts committed Jan 9, 2025
1 parent 4088c86 commit 648f4b0
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.hedera.pbj.compiler.impl.MapField;
import com.hedera.pbj.compiler.impl.OneOfField;
import com.hedera.pbj.compiler.impl.SingleField;
import com.hedera.pbj.compiler.impl.generators.protobuf.StaticMeasureRecordMethodGenerator;
import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser;
import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser.MessageDefContext;
import edu.umd.cs.findbugs.annotations.NonNull;
Expand All @@ -31,6 +32,9 @@
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.gradle.internal.impldep.com.google.common.collect.Streams;

/**
* Code generator that parses protobuf files and generates nice Java source for record files for each message type and
Expand Down Expand Up @@ -77,8 +81,10 @@ public void generate(final MessageDefContext msgDef,
cleanDocStr(msgDef.docComment().getText().replaceAll("\n \\*\s*\n","\n * <p>\n"));
// The Javadoc "@Deprecated" tag, which is set if the protobuf schema says the field is deprecated
String deprecated = "";
// The list of fields, as defined in the protobuf schema
// The list of fields, as defined in the protobuf schema & precomputed fields
final List<Field> fields = new ArrayList<>();
// The list of fields, as defined in the protobuf schema
final List<Field> fieldsNoPrecomputed = new ArrayList<>();
// The generated Java code for an enum field if OneOf is used
final List<String> oneofEnums = new ArrayList<>();
// The generated Java code for getters if OneOf is used
Expand All @@ -92,6 +98,7 @@ public void generate(final MessageDefContext msgDef,
imports.add("com.hedera.pbj.runtime.io.buffer");
imports.add("com.hedera.pbj.runtime.io.stream");
imports.add("edu.umd.cs.findbugs.annotations");
imports.add("static "+modelPackage+".schema."+javaRecordName+"Schema");

// Iterate over all the items in the protobuf schema
for (final var item : msgDef.messageBody().messageElement()) {
Expand All @@ -116,6 +123,17 @@ public void generate(final MessageDefContext msgDef,
}
}

// collect all non precomputed fields
fieldsNoPrecomputed.addAll(fields);

// add precomputed fields to fields
fields.add(new SingleField(false, FieldType.FIXED32, -1,
"precomputedHashCode", null, null,
null, null, "Computed hash code, manual input ignored.", false, null));
fields.add(new SingleField(false, FieldType.FIXED32, -1,
"protobufEncodedSize", null, null,
null, null, "Computed protobuf encoded size, manual input ignored.", false, null));

// process field java doc and insert into record java doc
if (!fields.isEmpty()) {
String recordJavaDoc = javaDocComment.isEmpty() ? "/**\n * " + javaRecordName :
Expand All @@ -136,13 +154,25 @@ public void generate(final MessageDefContext msgDef,
// static codec and default instance
bodyContent +=
generateCodecFields(msgDef, lookupHelper, javaRecordName);
bodyContent += "\n";

// constructor
bodyContent += generateConstructor(javaRecordName, fields, true, msgDef, lookupHelper);
bodyContent += "\n";

bodyContent += generateHashCode(fields);
// precomputed constructor
bodyContent += generatePrecomputedConstructor(javaRecordName, fields, true, msgDef, lookupHelper);
bodyContent += "\n";

bodyContent += generateEquals(fields, javaRecordName);
bodyContent += StaticMeasureRecordMethodGenerator.generateStaticMeasureMethod(fieldsNoPrecomputed);
bodyContent += "\n";

// hashCode method
bodyContent += generateHashCode(fieldsNoPrecomputed);
bodyContent += "\n";

// equals method
bodyContent += generateEquals(fieldsNoPrecomputed, javaRecordName);

final List<Field> comparableFields = filterComparableFields(msgDef, lookupHelper, fields);
final boolean hasComparableFields = !comparableFields.isEmpty();
Expand All @@ -159,10 +189,10 @@ public void generate(final MessageDefContext msgDef,
bodyContent += "\n";

// builder copy & new builder methods
bodyContent = genrateBuilderFactoryMethods(bodyContent, fields);
bodyContent = genrateBuilderFactoryMethods(bodyContent, fieldsNoPrecomputed);

// generate builder
bodyContent += generateBuilder(msgDef, fields, lookupHelper);
bodyContent += generateBuilder(msgDef, fieldsNoPrecomputed, lookupHelper);
bodyContent += "\n";

// oneof enums
Expand Down Expand Up @@ -212,7 +242,9 @@ private static String generateClass(final String modelPackage,
import edu.umd.cs.findbugs.annotations.Nullable;
import edu.umd.cs.findbugs.annotations.NonNull;
import static java.util.Objects.requireNonNull;
import static com.hedera.pbj.runtime.ProtoWriterTools.*;
import static com.hedera.pbj.runtime.ProtoConstants.*;
$javaDocComment$deprecated
public record $javaRecordName(
$fields) $implementsComparable{
Expand Down Expand Up @@ -343,18 +375,32 @@ private static String generateHashCode(final List<Field> fields) {
String bodyContent =
"""
/**
* Override the default hashCode method for
* all other objects to make hashCode
* Override the default hashCode method for to make hashCode better distributed and follows protobuf rules
* for default values. This is important for backward compatibility.
*/
@Override
public int hashCode() {
return precomputedHashCode;
}
/**
* Compute the hashcode
*/
private static int computeHashCode($params) {
int result = 1;
""".indent(DEFAULT_INDENT);
if(DEFAULT != null) {
""".replace("$params",fields.stream()
.filter(f -> f.fieldNumber() != -1)
.map(field ->
field.javaFieldType() + " " + field.nameCamelFirstLower()
).collect(Collectors.joining(", ")))
.indent(DEFAULT_INDENT);

bodyContent += statements;

bodyContent +=
"""
}
long hashCode = result;
$hashCodeManipulation
return (int)hashCode;
Expand All @@ -364,6 +410,48 @@ public int hashCode() {
return bodyContent;
}

/**
* Generates a pre-populated constructor for a class.
* @param fields the fields to use for the code generation
* @return the generated code
*/
private static String generatePrecomputedConstructor(
final String constructorName,
final List<Field> fields,
final boolean shouldThrowOnOneOfNull,
final MessageDefContext msgDef,
final ContextualLookupHelper lookupHelper) {
String constructorCode = "this($constructorParams);"
.replace("$constructorParams",Stream.concat(
fields.stream()
.filter(f -> f.fieldNumber() != -1)
.map(Field::nameCamelFirstLower),
Stream.of("0","0")
).collect(Collectors.joining(", \n"+" ".repeat(DEFAULT_INDENT))))
.indent(DEFAULT_INDENT);
return """
/**
* Create a $constructorName without passing computed fields.
* $constructorParamDocs
*/
public $constructorName($constructorParams) {
$constructorCode }
"""
.replace("$constructorParamDocs",fields.stream()
.filter(f -> f.fieldNumber() != -1)
.map(field ->
"\n * The @param "+field.nameCamelFirstLower()+" "+
field.comment().replaceAll("\n", "\n * "+" ".repeat(field.nameCamelFirstLower().length()))
).collect(Collectors.joining(" ")))
.replace("$constructorName", constructorName)
.replace("$constructorParams",fields.stream()
.filter(f -> f.fieldNumber() != -1)
.map(field ->
field.javaFieldType() + " " + field.nameCamelFirstLower()
).collect(Collectors.joining(", ")))
.replace("$constructorCode",constructorCode.indent(DEFAULT_INDENT));
}

/**
* Generates a pre-populated constructor for a class.
* @param fields the fields to use for the code generation
Expand All @@ -389,7 +477,7 @@ private static String generateConstructor(
.replace("$constructorParamDocs",fields.stream().map(field ->
"\n * @param "+field.nameCamelFirstLower()+" "+
field.comment().replaceAll("\n", "\n * "+" ".repeat(field.nameCamelFirstLower().length()))
).collect(Collectors.joining(", ")))
).collect(Collectors.joining(" ")))
.replace("$constructorName", constructorName)
.replace("$constructorParams",fields.stream().map(field ->
field.javaFieldType() + " " + field.nameCamelFirstLower()
Expand All @@ -399,64 +487,51 @@ private static String generateConstructor(
if (shouldThrowOnOneOfNull && field instanceof OneOfField) {
sb.append(generateConstructorCodeForField(field)).append('\n');
}
switch (field.type()) {
case BYTES, STRING: {
sb.append("this.$name = $name != null ? $name : $default;"
.replace("$name", field.nameCamelFirstLower())
.replace("$default", getDefaultValue(field, msgDef, lookupHelper))
);
break;
}
case MAP: {
sb.append("this.$name = PbjMap.of($name);"
.replace("$name", field.nameCamelFirstLower())
);
break;
}
default:
if (field.repeated()) {
sb.append("this.$name = $name == null ? Collections.emptyList() : $name;".replace("$name", field.nameCamelFirstLower()));
} else {
sb.append("this.$name = $name;".replace("$name", field.nameCamelFirstLower()));
if (field.nameCamelFirstLower().equals("precomputedHashCode")) {
final String computeHashCode = "computeHashCode("
+ fields.stream()
.filter(f -> f.fieldNumber() != -1)
.map(f -> "this."+f.nameCamelFirstLower())
.collect(Collectors.joining(", "))
+ ")";
sb.append("this.precomputedHashCode = "+computeHashCode+";");
} else if (field.nameCamelFirstLower().equals("protobufEncodedSize")) {
final String computeProtobufSize = "computeProtobufSize("
+ fields.stream()
.filter(f -> f.fieldNumber() != -1)
.map(f -> "this."+f.nameCamelFirstLower())
.collect(Collectors.joining(", "))
+ ")";
sb.append("this.protobufEncodedSize = "+computeProtobufSize+";");
} else {
switch (field.type()) {
case BYTES, STRING: {
sb.append("this.$name = $name != null ? $name : $default;"
.replace("$name", field.nameCamelFirstLower())
.replace("$default", getDefaultValue(field, msgDef, lookupHelper))
);
break;
}
case MAP: {
sb.append("this.$name = PbjMap.of($name);"
.replace("$name", field.nameCamelFirstLower())
);
break;
}
break;
default:
if (field.repeated()) {
sb.append("this.$name = $name == null ? Collections.emptyList() : $name;".replace(
"$name", field.nameCamelFirstLower()));
} else {
sb.append("this.$name = $name;".replace("$name", field.nameCamelFirstLower()));
}
break;
}
}
return sb.toString();
}).collect(Collectors.joining("\n")).indent(DEFAULT_INDENT * 2));
}

/**
* Generates constructor code for the class
* @param fields the fields to use for the code generation
* @param javaRecordName the name of the class
* @return the generated code
*/
@NonNull
private static String generateConstructor(final List<Field> fields, final String javaRecordName) {
return """
/**
* Override the default constructor adding input validation
* %s
*/
public %s {
%s
}
"""
.formatted(
fields.stream().map(field -> "\n * @param " + field.nameCamelFirstLower() + " " +
field.comment()
.replaceAll("\n", "\n * " + " ".repeat(field.nameCamelFirstLower().length()))
).collect(Collectors.joining()),
javaRecordName,
fields.stream()
.filter(f -> f instanceof OneOfField)
.map(ModelGenerator::generateConstructorCodeForField)
.collect(Collectors.joining("\n"))
)
.indent(DEFAULT_INDENT);
}

/**
* Generates the constructor code for the class
* @param f the field to use for the code generation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ static String generateMeasureMethod(final String modelClassName, final List<Fiel
* @return The length in bytes that would be written
*/
public int measureRecord($modelClass data) {
int size = 0;
$fieldSizeOfLines
return size;
return data.protobufEncodedSize();
}
"""
.replace("$modelClass", modelClassName)
Expand Down
Loading

0 comments on commit 648f4b0

Please sign in to comment.