diff --git a/.gitattributes b/.gitattributes index 6d94640..4498dfb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,6 +5,14 @@ *.scala text eol=lf *.groovy text eol=lf *.conf text eol=lf +*.txt text eol=lf +*.proto text eol=lf +*.fbs text eol=lf +*.sql text eol=lf +*.csv text eol=lf +*.xml text eol=lf +*.ksh text eol=lf +*.py text eol=lf # Similarly for CRLF *.bat text eol=crlf @@ -12,3 +20,5 @@ # Perl scripts should always have LF (Unix default) line ending *.pl text eol=lf *.pm text eol=lf + + diff --git a/optimus/platform/projects/alarms/src/main/scala/optimus/tools/scalacplugins/entity/reporter/StagingAlarms.scala b/optimus/platform/projects/alarms/src/main/scala/optimus/tools/scalacplugins/entity/reporter/StagingAlarms.scala index 1face35..618e913 100644 --- a/optimus/platform/projects/alarms/src/main/scala/optimus/tools/scalacplugins/entity/reporter/StagingAlarms.scala +++ b/optimus/platform/projects/alarms/src/main/scala/optimus/tools/scalacplugins/entity/reporter/StagingAlarms.scala @@ -160,7 +160,7 @@ object StagingErrors extends OptimusErrorsBase with OptimusPluginAlarmHelper { warning0( 20013, StagingPhase.POST_TYPER_STANDARDS, - "[NEW]Suspicious use of implicit Predef.augmentString. Use .toSeq if you really want to treat it as a Seq[Char]. Consider if surrounding code should be a map rather than flatMap. Prefer + rather than ++ for String concatenation." + "Suspicious use of implicit Predef.augmentString. Use .toSeq if you really want to treat it as a Seq[Char]. Consider if surrounding code should be a map rather than flatMap. Prefer + rather than ++ for String concatenation." ) } diff --git a/optimus/platform/projects/annotations/src/main/scala/optimus/platform/_intern.java b/optimus/platform/projects/annotations/src/main/scala/optimus/platform/_intern.java index 5382743..c8c2063 100644 --- a/optimus/platform/projects/annotations/src/main/scala/optimus/platform/_intern.java +++ b/optimus/platform/projects/annotations/src/main/scala/optimus/platform/_intern.java @@ -16,9 +16,7 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; -/** - * Currently tells the DAL reader to intern this value before setting it - */ +/** Currently tells the DAL reader to intern this value before setting it */ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.FIELD) public @interface _intern {} diff --git a/optimus/platform/projects/breadcrumbs/src/main/scala/optimus/breadcrumbs/crumbs/Properties.scala b/optimus/platform/projects/breadcrumbs/src/main/scala/optimus/breadcrumbs/crumbs/Properties.scala index 0f2f183..8f9606e 100644 --- a/optimus/platform/projects/breadcrumbs/src/main/scala/optimus/breadcrumbs/crumbs/Properties.scala +++ b/optimus/platform/projects/breadcrumbs/src/main/scala/optimus/breadcrumbs/crumbs/Properties.scala @@ -750,6 +750,9 @@ object Properties extends KnownProperties { val numBadBaseline = propI val numFileReadErrors = propI val buildNumber = propI + + val oldCacheSize = propI + val newCacheSize = propI } final case class RequestsStallInfo(pluginType: StallPlugin.Value, reqCount: Int, req: Seq[String]) { diff --git a/optimus/platform/projects/collections/src/main/scala-2.12/optimus/collection/RedBlackHelper.java b/optimus/platform/projects/collections/src/main/scala-2.12/optimus/collection/RedBlackHelper.java index ff6e326..8423e08 100644 --- a/optimus/platform/projects/collections/src/main/scala-2.12/optimus/collection/RedBlackHelper.java +++ b/optimus/platform/projects/collections/src/main/scala-2.12/optimus/collection/RedBlackHelper.java @@ -23,8 +23,8 @@ class RedBlackHelper { private RedBlackHelper() {} - private final static Field treeMapField; - private final static Field treeSetField; + private static final Field treeMapField; + private static final Field treeSetField; static { try { @@ -37,23 +37,23 @@ private RedBlackHelper() {} } } - static scala.Option> maxBefore(A key, TreeMap map) throws IllegalAccessException { + static scala.Option> maxBefore(A key, TreeMap map) + throws IllegalAccessException { NewRedBlackTree.Tree tree = (NewRedBlackTree.Tree) treeMapField.get(map); Ordering ordering = map.ordering(); NewRedBlackTree.Tree result = NewRedBlackTree$.MODULE$.maxBefore(tree, key, ordering); - scala.Tuple2 y = result == null - ? null - : new scala.Tuple2(result.key(), result.value()); + scala.Tuple2 y = + result == null ? null : new scala.Tuple2(result.key(), result.value()); return Option.apply(y); } - static scala.Option> minAfter(A key, TreeMap map) throws IllegalAccessException { + static scala.Option> minAfter(A key, TreeMap map) + throws IllegalAccessException { NewRedBlackTree.Tree tree = (NewRedBlackTree.Tree) treeMapField.get(map); Ordering ordering = map.ordering(); NewRedBlackTree.Tree result = NewRedBlackTree$.MODULE$.minAfter(tree, key, ordering); - scala.Tuple2 y = result == null - ? null - : new scala.Tuple2(result.key(), result.value()); + scala.Tuple2 y = + result == null ? null : new scala.Tuple2(result.key(), result.value()); return Option.apply(y); } @@ -61,9 +61,7 @@ static scala.Option maxBefore(A key, TreeSet set) throws IllegalAccess NewRedBlackTree.Tree tree = (NewRedBlackTree.Tree) treeSetField.get(set); Ordering ordering = set.ordering(); NewRedBlackTree.Tree result = NewRedBlackTree$.MODULE$.maxBefore(tree, key, ordering); - A y = result == null - ? null - : result.key(); + A y = result == null ? null : result.key(); return Option.apply(y); } @@ -71,9 +69,7 @@ static scala.Option minAfter(A key, TreeSet set) throws IllegalAccessE NewRedBlackTree.Tree tree = (NewRedBlackTree.Tree) treeSetField.get(set); Ordering ordering = set.ordering(); NewRedBlackTree.Tree result = NewRedBlackTree$.MODULE$.minAfter(tree, key, ordering); - A y = result == null - ? null - : result.key(); + A y = result == null ? null : result.key(); return Option.apply(y); } } diff --git a/optimus/platform/projects/collections/src/main/scala-2.13/optimus/collection/RedBlackHelper.java b/optimus/platform/projects/collections/src/main/scala-2.13/optimus/collection/RedBlackHelper.java index 05573a4..639fcec 100644 --- a/optimus/platform/projects/collections/src/main/scala-2.13/optimus/collection/RedBlackHelper.java +++ b/optimus/platform/projects/collections/src/main/scala-2.13/optimus/collection/RedBlackHelper.java @@ -23,8 +23,8 @@ class RedBlackHelper { private RedBlackHelper() {} - private final static Field treeMapField; - private final static Field treeSetField; + private static final Field treeMapField; + private static final Field treeSetField; static { try { @@ -37,19 +37,23 @@ private RedBlackHelper() {} } } - static scala.Option> maxBefore(A key, TreeMap map) throws IllegalAccessException { + static scala.Option> maxBefore(A key, TreeMap map) + throws IllegalAccessException { RedBlackTree.Tree tree = (RedBlackTree.Tree) treeMapField.get(map); Ordering ordering = map.ordering(); RedBlackTree.Tree result = RedBlackTree$.MODULE$.maxBefore(tree, key, ordering); - scala.Tuple2 y = result == null ? null : new scala.Tuple2(result.key(), result.value()); + scala.Tuple2 y = + result == null ? null : new scala.Tuple2(result.key(), result.value()); return Option.apply(y); } - static scala.Option> minAfter(A key, TreeMap map) throws IllegalAccessException { + static scala.Option> minAfter(A key, TreeMap map) + throws IllegalAccessException { RedBlackTree.Tree tree = (RedBlackTree.Tree) treeMapField.get(map); Ordering ordering = map.ordering(); RedBlackTree.Tree result = RedBlackTree$.MODULE$.minAfter(tree, key, ordering); - scala.Tuple2 y = result == null ? null : new scala.Tuple2(result.key(), result.value()); + scala.Tuple2 y = + result == null ? null : new scala.Tuple2(result.key(), result.value()); return Option.apply(y); } diff --git a/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationCmds.java b/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationCmds.java index 848aace..17f5b99 100644 --- a/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationCmds.java +++ b/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationCmds.java @@ -246,30 +246,40 @@ public static void prefixCallWithDumpOnTransitivelyCached(String methodToPatch) /** * Inject prefix call InstrumentedModuleCtor.trigger in EvaluationContext.current * Only useful if you also executed markAllModuleCtors - * @see optimus.debug.InstrumentationCmds#markAllModuleCtors() - * @see optimus.debug.InstrumentedModuleCtor#trigger() - * @see optimus.debug.RTVerifierCategory#MODULE_CTOR_EC_CURRENT - * @see optimus.debug.RTVerifierCategory#MODULE_LAZY_VAL_EC_CURRENT + * @see InstrumentationCmds#markAllModuleCtors() + * @see InstrumentedModuleCtor#trigger() + * @see RTVerifierCategory#MODULE_CTOR_EC_CURRENT + * @see RTVerifierCategory#MODULE_LAZY_VAL_EC_CURRENT */ public static void prefixECCurrentWithTriggerIfInModuleCtor() { - prefixCall("optimus.graph.OGSchedulerContext.current", "optimus.debug.InstrumentedModuleCtor.trigger"); + var moduleCtorTrigger = "optimus.debug.InstrumentedModuleCtor.trigger"; + prefixCall("optimus.graph.OGSchedulerContext.current", moduleCtorTrigger); + /* eventually we will detect the methods with an annotation: + prefixCall("optimus.platform.ScenarioStack.getNode", moduleCtorTrigger); + prefixCall("optimus.platform.ScenarioStack.env", moduleCtorTrigger); + prefixCall("optimus.platform.ScenarioStack.getTrackingNodeID", moduleCtorTrigger); + prefixCall("optimus.platform.ScenarioStack.getParentTrackingNode", moduleCtorTrigger); + prefixCall("optimus.platform.ScenarioStack.pluginTags", moduleCtorTrigger); + prefixCall("optimus.platform.ScenarioStack.findPluginTag", moduleCtorTrigger); + */ } /** * When markAllModuleCtors is requested this function allows additions to the exclusion list * @param className JVM class name of the module - * @see optimus.debug.InstrumentationCmds#markAllModuleCtors() + * @see InstrumentationCmds#markAllModuleCtors() + * @see InstrumentationCmds#markAllEntityCtorsForSIDetection() */ - public static void excludeModuleFromModuleCtorReporting(String className) { + public static void excludeFromModuleOrEntityCtorReporting(String className) { var jvmName = className.replace('.', '/'); - InstrumentationConfig.addModuleExclusion(jvmName); + InstrumentationConfig.addModuleOrEntityExclusion(jvmName); } /** * When markAllModuleCtors or individual module bracketing is enabled, some call stacks can be disabled * @param methodToPatch fully specified method reference - * @see optimus.debug.InstrumentationCmds#markAllModuleCtors() - * @see optimus.debug.InstrumentationConfig#addModuleConstructionIntercept + * @see InstrumentationCmds#markAllModuleCtors() + * @see InstrumentationConfig#addModuleConstructionIntercept */ public static void excludeMethodFromModuleCtorReporting(String methodToPatch) { MethodRef mref = asMethodRef(methodToPatch); @@ -293,8 +303,8 @@ public static void markScenarioStackAsInitializing(String className) { * Instrument all entity constructors to call a prefix/postfix methods to mark/unmark entity ctors as running * @see InstrumentationCmds#reportFindingTweaksInEntityConstructor() * @see InstrumentationCmds#reportTouchingTweakableInEntityConstructor() - * @see optimus.debug.RTVerifierCategory#TWEAK_IN_ENTITY_CTOR - * @see optimus.debug.RTVerifierCategory#TWEAKABLE_IN_ENTITY_CTOR + * @see RTVerifierCategory#TWEAK_IN_ENTITY_CTOR + * @see RTVerifierCategory#TWEAKABLE_IN_ENTITY_CTOR */ public static void markAllEntityCtorsForSIDetection() { instrumentAllEntities = EntityInstrumentationType.markScenarioStack; @@ -302,8 +312,8 @@ public static void markAllEntityCtorsForSIDetection() { /** * Instrument all module constructors to call a prefix/postfix methods to mark/unmark module ctors as running - * @see optimus.debug.RTVerifierCategory#MODULE_CTOR_EC_CURRENT - * @see optimus.debug.RTVerifierCategory#MODULE_LAZY_VAL_EC_CURRENT + * @see RTVerifierCategory#MODULE_CTOR_EC_CURRENT + * @see RTVerifierCategory#MODULE_LAZY_VAL_EC_CURRENT */ public static void markAllModuleCtors() { instrumentAllModuleConstructors = true; @@ -311,9 +321,9 @@ public static void markAllModuleCtors() { /** * Instrument all classes that don't implement (and base class doesn't either) their own hashCode. - * Therefore relying on identity hashCodes with calls to optimus.debug.InstrumentedHashCodes#hashCode(java.lang.Object) + * Therefore relying on identity hashCodes with calls to InstrumentedHashCodes#hashCode(java.lang.Object) * @apiNote Use to flag values that use identity hashCode while being used as a key in property caching - * @see optimus.debug.InstrumentedHashCodes#hashCode(java.lang.Object) + * @see InstrumentedHashCodes#hashCode(java.lang.Object) */ public static void reportSuspiciousHashCodesCalls() { instrumentAllHashCodes = true; @@ -322,7 +332,7 @@ public static void reportSuspiciousHashCodesCalls() { /** * Instrument callouts and report touching tweakables or entity ctor (which should be RT) * @see InstrumentationCmds#markAllEntityCtorsForSIDetection() - * @see optimus.debug.RTVerifierCategory#TWEAKABLE_IN_ENTITY_CTOR + * @see RTVerifierCategory#TWEAKABLE_IN_ENTITY_CTOR */ public static void reportTouchingTweakableInEntityConstructor() { InstrumentationConfig.addVerifyScenarioStackCalls(); @@ -332,7 +342,7 @@ public static void reportTouchingTweakableInEntityConstructor() { /** * Instrument callouts and report touching tweaked values or entity ctor (which should be RT) * @see InstrumentationCmds#markAllEntityCtorsForSIDetection() - * @see optimus.debug.RTVerifierCategory#TWEAK_IN_ENTITY_CTOR + * @see RTVerifierCategory#TWEAK_IN_ENTITY_CTOR */ public static void reportFindingTweaksInEntityConstructor() { InstrumentationConfig.addVerifyScenarioStackCalls(); @@ -365,7 +375,7 @@ public static void traceSelfAndParentOnException() { /** * When traceSelfAndParentOnException or individual exception reporting is enabled, some call stacks can be disabled * @param methodToPatch fully specified method reference - * @see optimus.debug.InstrumentationCmds#traceSelfAndParentOnException() + * @see InstrumentationCmds#traceSelfAndParentOnException() */ public static void excludeMethodFromExceptionReporting(String methodToPatch) { MethodRef mref = asMethodRef(methodToPatch); diff --git a/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationConfig.java b/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationConfig.java index b3cc9a8..30c8db6 100644 --- a/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationConfig.java +++ b/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationConfig.java @@ -11,6 +11,8 @@ */ package optimus.debug; +import static org.objectweb.asm.Type.VOID_TYPE; + import java.util.ArrayList; import java.util.HashMap; import java.util.function.BiPredicate; @@ -27,7 +29,7 @@ enum EntityInstrumentationType { public class InstrumentationConfig { final private static HashMap clsPatches = new HashMap<>(); final private static HashMap entityClasses = new HashMap<>(); - final private static HashMap moduleExclusions = new HashMap<>(); + final private static HashMap moduleOrEntityExclusions = new HashMap<>(); final private static ArrayList multiClsPatches = new ArrayList<>(); public static boolean setExceptionHookToTraceAsNodeOnStartup; @@ -56,6 +58,9 @@ public class InstrumentationConfig { private static final String CACHED_VALUE_DESC = "L" + CACHED_VALUE_TYPE + ";"; private static final String CACHED_FUNC_DESC = "(I" + OBJECT_DESC + OBJECT_ARR_DESC + ")" + CACHED_VALUE_DESC; + public static final String THROWABLE = "java/lang/Throwable"; + public static final Type THROWABLE_TYPE = Type.getObjectType(THROWABLE); + static final String __constructedAt = "constructedAt"; private static final MethodRef traceAsNode = new MethodRef(IS, "traceAsNode"); @@ -76,8 +81,15 @@ public class InstrumentationConfig { static InstrumentationConfig.MethodRef iecPause = new InstrumentationConfig.MethodRef(IEC_TYPE, "pauseReporting", "()I"); static InstrumentationConfig.MethodRef iecResume = new InstrumentationConfig.MethodRef(IEC_TYPE, "resumeReporting", "(I)V"); - public final static String CALL_WITH_ARGS_NAME = "CallWithArgs"; - public final static String CALL_WITH_ARGS = IS + "$" + CALL_WITH_ARGS_NAME; + public final static String CWA_INNER_NAME = "CallWithArgs"; + public final static String CWA = IS + "$" + CWA_INNER_NAME; + public final static Type CWA_TYPE = Type.getObjectType(CWA); + + static final MethodRef cwaPrefix = new MethodRef(IS, "cwaPrefix", Type.getMethodDescriptor(CWA_TYPE, CWA_TYPE)); + static final MethodRef cwaSuffix = new MethodRef(IS, "cwaSuffix", Type.getMethodDescriptor(VOID_TYPE, CWA_TYPE, OBJECT_TYPE)); + static final MethodRef cwaSuffixOnException = + new MethodRef(IS, "cwaSuffixOnException", Type.getMethodDescriptor(VOID_TYPE, CWA_TYPE, THROWABLE_TYPE)); + static final MethodRef nativePrefix = new MethodRef(IS, "nativePrefix", "()V"); static final MethodRef nativeSuffix = new MethodRef(IS, "nativeSuffix", "()V"); @@ -122,15 +134,15 @@ public static boolean isEntity(String className, String superName) { return true; } - static boolean isModuleExcluded(String className) { - synchronized (moduleExclusions) { - return moduleExclusions.containsKey(className); + static boolean isModuleOrEntityExcluded(String className) { + synchronized (moduleOrEntityExclusions) { + return moduleOrEntityExclusions.containsKey(className); } } - static void addModuleExclusion(String className) { - synchronized (moduleExclusions) { - moduleExclusions.put(className, Boolean.TRUE); + static void addModuleOrEntityExclusion(String className) { + synchronized (moduleOrEntityExclusions) { + moduleOrEntityExclusions.put(className, Boolean.TRUE); } } @@ -188,18 +200,20 @@ static class MethodPatch { final public MethodRef from; public MethodRef prefix; public MethodRef suffix; + public MethodRef suffixOnException; // suffix call used if an exception was thrown FieldRef cacheInField; boolean checkAndReturn; + boolean localValueIsCallWithArgs; // constructs an object of a class derived from CallWithArgs instead of passing object[] boolean prefixWithID; boolean prefixWithThis; boolean prefixWithArgs; - boolean passLocalValue; + boolean passLocalValue; // saves the result of the prefix into a value, and than passes it to the suffix boolean suffixWithID; boolean suffixWithThis; - boolean suffixWithReturnValue; + boolean suffixWithReturnValue; // calls suffix, then it returns the original value boolean suffixWithArgs; boolean suffixNoArgumentBoxing; - boolean suffixWithCallArgs; /// All arguments are packaged into custom generated class + boolean wrapWithTryCatch; // adds a try catch and invokes suffixOnException in case of exception thrown FieldRef storeToField; ClassPatch classPatch; BiPredicate predicate; @@ -388,13 +402,21 @@ static MethodPatch addSuffixCall(MethodRef from, MethodRef to) { return methodPatch; } - public static MethodPatch addSuffixCallWithCallsArgs(MethodRef from, MethodRef to) { - var methodPatch = putIfAbsentMethodPatch(from); - methodPatch.suffix = to; - methodPatch.suffixWithCallArgs = true; - return methodPatch; + public static MethodPatch addDefaultRecording(MethodRef from) { + return addRecording(from, cwaPrefix, cwaSuffix, cwaSuffixOnException); } + public static MethodPatch addRecording(MethodRef from, MethodRef prefix, MethodRef suffix, MethodRef onException) { + MethodPatch methodPatch = putIfAbsentMethodPatch(from); + methodPatch.localValueIsCallWithArgs = true; + methodPatch.wrapWithTryCatch = true; + methodPatch.passLocalValue = true; + methodPatch.suffixWithReturnValue = true; + methodPatch.prefix = prefix; + methodPatch.suffix = suffix; + methodPatch.suffixOnException = onException; + return methodPatch; + } /** Hook should probably be added just once, but it's actually OK to call it multiple times */ private static void addExceptionHook() { @@ -488,12 +510,15 @@ static FieldRef addRecordPrefixCallIntoMemberWithStackTrace(String fieldName, Me return fieldRef; } - public static void addModuleConstructionIntercept(String clsName) { + public static ClassPatch addModuleConstructionIntercept(String clsName) { var mref = new MethodRef(clsName, ""); addPrefixCall(mref, InstrumentationConfig.imcEnterCtor, false, false); addSuffixCall(mref, InstrumentationConfig.imcExitCtor); - putIfAbsentClassPatch(clsName).bracketAllLzyComputes = true; + var classPatch = putIfAbsentClassPatch(clsName); + classPatch.bracketAllLzyComputes = true; + return classPatch; } + public static void addAllMethodPatchAndChangeSuper( Object id, Predicate classPredicate, diff --git a/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationInjector.java b/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationInjector.java index d42024f..bd02dc5 100644 --- a/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationInjector.java +++ b/optimus/platform/projects/entityagent/src/main/java/optimus/debug/InstrumentationInjector.java @@ -14,15 +14,7 @@ import static optimus.debug.EntityInstrumentationType.markScenarioStack; import static optimus.debug.EntityInstrumentationType.none; import static optimus.debug.EntityInstrumentationType.recordConstructedAt; -import static optimus.debug.InstrumentationConfig.CACHED_VALUE_TYPE; -import static optimus.debug.InstrumentationConfig.CALL_WITH_ARGS; -import static optimus.debug.InstrumentationConfig.OBJECT_ARR_DESC; -import static optimus.debug.InstrumentationConfig.OBJECT_DESC; -import static optimus.debug.InstrumentationConfig.OBJECT_TYPE; -import static optimus.debug.InstrumentationConfig.instrumentAllNativePackagePrefixes; -import static optimus.debug.InstrumentationConfig.patchForSuffixAsNode; -import static optimus.debug.InstrumentationConfig.patchForCachingMethod; -import static optimus.debug.InstrumentationConfig.patchForBracketingLzyCompute; +import static optimus.debug.InstrumentationConfig.*; import static optimus.debug.InstrumentationInjector.ENTITY_DESC; import static optimus.debug.InstrumentationInjector.SCALA_NOTHING; @@ -57,22 +49,22 @@ public class InstrumentationInjector implements ClassFileTransformer { public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, ProtectionDomain protectionDomain, byte[] bytes) throws IllegalClassFormatException { - InstrumentationConfig.ClassPatch patch = InstrumentationConfig.forClass(className); + ClassPatch patch = forClass(className); if (instrumentAllNativePackagePrefixes != null && className.startsWith(instrumentAllNativePackagePrefixes)) { - patch = new InstrumentationConfig.ClassPatch(); + patch = new ClassPatch(); patch.wrapNativeCalls = true; } - if (patch == null && !InstrumentationConfig.instrumentAnyGroups()) + if (patch == null && !instrumentAnyGroups()) return bytes; ClassReader crSource = new ClassReader(bytes); - var entityInstrType = InstrumentationConfig.instrumentAllEntities; - if (entityInstrType != none && InstrumentationConfig.isEntity(className, crSource.getSuperName())) { - if (entityInstrType == markScenarioStack) InstrumentationConfig.addMarkScenarioStackAsInitializing(className); - else if (entityInstrType == recordConstructedAt) InstrumentationConfig.recordConstructorInvocationSite(className); + var entityInstrType = instrumentAllEntities; + if (entityInstrType != none && shouldInstrumentEntity(className, crSource.getSuperName())) { + if (entityInstrType == markScenarioStack) addMarkScenarioStackAsInitializing(className); + else if (entityInstrType == recordConstructedAt) recordConstructorInvocationSite(className); if (patch == null) - patch = InstrumentationConfig.forClass(className); // Re-read reconfigured value + patch = forClass(className); // Re-read reconfigured value } boolean addHashCode = shouldAddHashCode(loader, crSource, className); @@ -80,22 +72,18 @@ public byte[] transform(ClassLoader loader, String className, Class classBein setForwardMethodToNewHashCode(className, patch); if (patch == null && addHashCode) { - patch = new InstrumentationConfig.ClassPatch(); + patch = new ClassPatch(); setForwardMethodToNewHashCode(className, patch); } - if (InstrumentationConfig.instrumentAllEntityApplies && shouldCacheApplyMethods(crSource, className)) { + if (instrumentAllEntityApplies && shouldCacheApplyMethods(crSource, className)) { if (patch == null) - patch = new InstrumentationConfig.ClassPatch(); + patch = new ClassPatch(); patch.cacheAllApplies = true; } - if (InstrumentationConfig.instrumentAllModuleConstructors && shouldAddModuleConstructorBracketing(loader, className)) { - InstrumentationConfig.addModuleConstructionIntercept(className); - if (patch == null) - patch = InstrumentationConfig.forClass(className); // Re-read reconfigured value - patch.bracketAllLzyComputes = true; - } + if (instrumentAllModuleConstructors && shouldInstrumentModuleCtor(loader, className)) + patch = addModuleConstructionIntercept(className); if (patch == null) return bytes; @@ -116,21 +104,18 @@ private boolean shouldCacheApplyMethods(ClassReader crSource, String className) return false; } - private boolean shouldAddModuleConstructorBracketing(ClassLoader loader, String className) { - if (loader == null) - return false; - - if (!className.endsWith("$")) - return false; // Looking for companion objects - if (className.startsWith("scala")) - return false; + private boolean shouldInstrumentEntity(String className, String superName) { + return isEntity(className, superName) && !isModuleOrEntityExcluded(className); + } - return !InstrumentationConfig.isModuleExcluded(className); + private boolean shouldInstrumentModuleCtor(ClassLoader loader, String className) { + var isModuleCtor = loader != null && className.endsWith("$") && !className.startsWith("scala"); + return isModuleCtor && !isModuleOrEntityExcluded(className); } private boolean shouldAddHashCode(ClassLoader loader, ClassReader crSource, String className) { - if (!InstrumentationConfig.instrumentAllHashCodes) + if (!instrumentAllHashCodes) return false; // Interfaces are not included if ((crSource.getAccess() & ACC_INTERFACE) != 0) @@ -150,18 +135,18 @@ private boolean shouldAddHashCode(ClassLoader loader, ClassReader crSource, Stri return !className.startsWith("sun/") && !className.startsWith("java/security"); } - private void setForwardMethodToNewHashCode(String className, InstrumentationConfig.ClassPatch patch) { - var mrHashCode = new InstrumentationConfig.MethodRef(className, "hashCode", "()I"); - patch.methodForward = new InstrumentationConfig.MethodForward(mrHashCode, InstrumentedHashCodes.mrHashCode); + private void setForwardMethodToNewHashCode(String className, ClassPatch patch) { + var mrHashCode = new MethodRef(className, "hashCode", "()I"); + patch.methodForward = new MethodForward(mrHashCode, InstrumentedHashCodes.mrHashCode); } } class InstrumentationInjectorAdapter extends ClassVisitor implements Opcodes { - private final InstrumentationConfig.ClassPatch classPatch; + private final ClassPatch classPatch; private final String className; private boolean seenForwardedMethod; - InstrumentationInjectorAdapter(InstrumentationConfig.ClassPatch patch, String className, ClassVisitor cv) { + InstrumentationInjectorAdapter(ClassPatch patch, String className, ClassVisitor cv) { super(ASM9, cv); this.classPatch = patch; this.className = className; @@ -243,7 +228,7 @@ private void writeNativeWrapper(int access, String name, String desc, String sig var mv = new GeneratorAdapter(mvWriter, useAccess, name, desc); mv.visitCode(); - mv.visitMethodInsn(INVOKESTATIC, InstrumentationConfig.nativePrefix.cls, InstrumentationConfig.nativePrefix.method, InstrumentationConfig.nativePrefix.descriptor, false); + mv.visitMethodInsn(INVOKESTATIC, nativePrefix.cls, nativePrefix.method, nativePrefix.descriptor, false); // call original method mv.loadArgs(); @@ -254,10 +239,10 @@ private void writeNativeWrapper(int access, String name, String desc, String sig mv.dup(); mv.loadArgs(); var descX= "(ZJ"+ OBJECT_DESC + "J" + OBJECT_DESC + ")V"; - mv.visitMethodInsn(INVOKESTATIC, InstrumentationConfig.nativeSuffix.cls, InstrumentationConfig.nativeSuffix.method, descX, false); + mv.visitMethodInsn(INVOKESTATIC, nativeSuffix.cls, nativeSuffix.method, descX, false); writeNativeMethodCall(name, desc); } else - mv.visitMethodInsn(INVOKESTATIC, InstrumentationConfig.nativeSuffix.cls, InstrumentationConfig.nativeSuffix.method, "()V", false); + mv.visitMethodInsn(INVOKESTATIC, nativeSuffix.cls, nativeSuffix.method, "()V", false); mv.returnValue(); mv.visitMaxs(0, 0); @@ -277,8 +262,8 @@ public MethodVisitor visitMethod(int access, String name, String desc, String si if (classPatch.allMethodsPatch != null) return new InstrumentationInjectorMethodVisitor(classPatch.allMethodsPatch, mv, access, name, desc); - InstrumentationConfig.MethodPatch methodPatch = classPatch.forMethod(name, desc); - if (methodPatch != null) + MethodPatch methodPatch = classPatch.forMethod(name, desc); + if (methodPatch != null && (methodPatch.predicate == null || methodPatch.predicate.test(name, desc))) return new InstrumentationInjectorMethodVisitor(methodPatch, mv, access, name, desc); else if (classPatch.cacheAllApplies && name.equals("apply") && isCreateEntityMethod(desc)) return new InstrumentationInjectorMethodVisitor(patchForCachingMethod(className, name), mv, access, name, desc); @@ -290,7 +275,7 @@ else if (classPatch.bracketAllLzyComputes && name.endsWith("$lzycompute")) return mv; // Just copy the entire method } - private void writeGetterMethod(InstrumentationConfig.GetterMethod getter) { + private void writeGetterMethod(GetterMethod getter) { var desc = getter.mRef.descriptor == null ? "()" + OBJECT_DESC : getter.mRef.descriptor; MethodVisitor mv = cv.visitMethod(ACC_PUBLIC, getter.mRef.method, desc, null, null); mv.visitCode(); @@ -327,7 +312,7 @@ private void writeEqualsForCachingOverride() { mv.visitEnd(); } - private void writeImplementForwardCall(InstrumentationConfig.MethodForward forwards) { + private void writeImplementForwardCall(MethodForward forwards) { assert forwards.from.descriptor != null; var mv = cv.visitMethod(ACC_PUBLIC, forwards.from.method, forwards.from.descriptor, null, null); var mv2 = new GeneratorAdapter(mv, ACC_PUBLIC, forwards.from.method, forwards.from.descriptor); @@ -368,28 +353,43 @@ public void visitEnd() { } class InstrumentationInjectorMethodVisitor extends AdviceAdapter implements Opcodes { - private final InstrumentationConfig.MethodPatch patch; + private final MethodPatch patch; private final Label __localValueStart = new Label(); private final Label __localValueEnd = new Label(); private int __localValue; // When local passing is enabled this will point to a slot for local var private Type localValueType; private String localValueDesc; + private String callWithArgsCtorDescriptor; + private String callWithArgsType; private int methodID; // If allocation requested private boolean thisIsAvailable; - private boolean doTransform; - InstrumentationInjectorMethodVisitor(InstrumentationConfig.MethodPatch patch, MethodVisitor mv, int access, + private final Label tryBlockStart = new Label(); + private final Label tryBlockEnd = new Label(); + private final Label catchBlockStart = new Label(); + private final Label catchBlockEnd = new Label(); + InstrumentationInjectorMethodVisitor(MethodPatch patch, MethodVisitor mv, int access, String name, String descriptor) { super(ASM9, mv, access, name, descriptor); this.patch = patch; - this.doTransform = patch.predicate == null || patch.predicate.test(name, methodDesc); - - if (patch.passLocalValue) { + if (patch.passLocalValue && !patch.localValueIsCallWithArgs) { localValueType = patch.prefix.descriptor != null ? Type.getMethodType(patch.prefix.descriptor).getReturnType() : OBJECT_TYPE; localValueDesc = localValueType.getDescriptor(); } + + if (patch.localValueIsCallWithArgs) { + // we generate a CallWithArgs instance... + Type thisOwner = Type.getObjectType(patch.from.cls); + callWithArgsType = CallWithArgsGenerator.generateClassName(getName()); + byte[] newBytes = CallWithArgsGenerator.create(callWithArgsType, thisOwner, getArgumentTypes(), getReturnType(), getName()); + DynamicClassLoader.loadClassInCurrentClassLoader(newBytes); + callWithArgsCtorDescriptor = CallWithArgsGenerator.getCtrDescriptor(thisOwner, getArgumentTypes()); + // ...and we pass our CallWithArgs instance to the suffix calls + localValueType = Type.getObjectType(InstrumentationConfig.CWA); + localValueDesc = localValueType.getDescriptor(); + } } private void dupReturnValueOrNullForVoid(int opcode, boolean boxValueTypes) { @@ -408,7 +408,7 @@ else if (opcode == LRETURN || opcode == DRETURN) private String loadMethodID() { if (methodID == 0) - methodID = InstrumentationConfig.allocateID(patch.from); + methodID = allocateID(patch.from); mv.visitIntInsn(SIPUSH, methodID); return "I"; } @@ -421,7 +421,15 @@ private String loadThisOrNull() { return OBJECT_DESC; } - private void ifNotZeroReturn(InstrumentationConfig.FieldRef fpatch) { + private String loadLocalValueIfRequested() { + if (patch.passLocalValue) { + mv.visitVarInsn(localValueType.getOpcode(ILOAD), __localValue); + return localValueType.getDescriptor(); + } + return ""; + } + + private void ifNotZeroReturn(FieldRef fpatch) { loadThis(); mv.visitFieldInsn(GETFIELD, patch.from.cls, fpatch.name, fpatch.type); Label label1 = new Label(); @@ -436,8 +444,10 @@ private void injectMethodPrefix() { if (patch.cacheInField != null) ifNotZeroReturn(patch.cacheInField); - if (patch.prefix == null) - return; + if (patch.prefix == null && !patch.localValueIsCallWithArgs) + return; // prefix is implied for localValueIsCallWithArgs + + MethodRef prefix = patch.prefix; if (patch.passLocalValue) { visitLabel(__localValueStart); @@ -448,23 +458,34 @@ private void injectMethodPrefix() { if (patch.prefixWithID) descriptor += loadMethodID(); - if(patch.prefixWithThis) { + if(patch.prefixWithThis) descriptor += loadThisOrNull(); - } + if (patch.prefixWithArgs) { loadArgArray(); descriptor += OBJECT_ARR_DESC; } + if (patch.passLocalValue || patch.storeToField != null) descriptor += ")" + OBJECT_DESC; else descriptor += ")V"; - // If descriptor was supplied just use that - if (patch.prefix.descriptor != null) - descriptor = patch.prefix.descriptor; + if (patch.localValueIsCallWithArgs) { + mv.visitTypeInsn(NEW, callWithArgsType); + dup(); + loadThis(); + loadArgs(); + mv.visitMethodInsn(INVOKESPECIAL, callWithArgsType, "", callWithArgsCtorDescriptor, false); + } + + if (prefix != null) { + if (prefix.descriptor != null) // If descriptor was supplied just use that + descriptor = prefix.descriptor; + + mv.visitMethodInsn(INVOKESTATIC, prefix.cls, prefix.method, descriptor, false); + } - mv.visitMethodInsn(Opcodes.INVOKESTATIC, patch.prefix.cls, patch.prefix.method, descriptor, false); if (patch.passLocalValue) { mv.visitVarInsn(localValueType.getOpcode(ISTORE), __localValue); } else if (patch.storeToField != null) { @@ -484,13 +505,17 @@ private void injectMethodPrefix() { mv.visitInsn(ARETURN); mv.visitLabel(continueLabel); } + + if (patch.wrapWithTryCatch) { + mv.visitTryCatchBlock(tryBlockStart, tryBlockEnd, catchBlockStart, THROWABLE); + mv.visitLabel(tryBlockStart); + } } @Override public void visitCode() { super.visitCode(); - if(doTransform) - injectMethodPrefix(); + injectMethodPrefix(); } @Override @@ -501,11 +526,8 @@ protected void onMethodEnter() { @Override protected void onMethodExit(int opcode) { - if(!doTransform) - return; - - if(opcode == ATHROW && patch.suffixNoArgumentBoxing) - return; // DO not generate exit call at the point of exception throw + if (opcode == ATHROW && (patch.suffixNoArgumentBoxing || patch.wrapWithTryCatch)) + return; // do not generate exit call at the point of exception throw if (patch.cacheInField != null) { dup(); @@ -523,32 +545,17 @@ protected void onMethodExit(int opcode) { dupReturnValueOrNullForVoid(opcode, !patch.suffixNoArgumentBoxing); descriptor += OBJECT_DESC; } - if (patch.passLocalValue) { - mv.visitVarInsn(localValueType.getOpcode(ILOAD), __localValue); - descriptor += localValueType.getDescriptor(); - } + + descriptor += loadLocalValueIfRequested(); + if (patch.suffixWithID) descriptor += loadMethodID(); + if (patch.suffixWithThis) descriptor += loadThisOrNull(); - if(patch.suffixWithArgs) { - loadArgs(); - } - if(patch.suffixWithCallArgs) { - Type thisOwner = Type.getObjectType(patch.from.cls); - String newClsName = CallWithArgsGenerator.generateClassName(getName()); - byte[] newBytes = CallWithArgsGenerator.create(newClsName, thisOwner, getArgumentTypes(), getReturnType(), getName()); - DynamicClassLoader.loadClassInCurrentClassLoader(newBytes); - String ctrDescriptor = CallWithArgsGenerator.getCtrDescriptor(thisOwner, getArgumentTypes()); - dup(); - mv.visitTypeInsn(NEW, newClsName); - dup(); - loadThis(); + + if (patch.suffixWithArgs) loadArgs(); - mv.visitMethodInsn(INVOKESPECIAL, newClsName, "", ctrDescriptor, false); - descriptor += getReturnType().getDescriptor(); - descriptor += "L" + CALL_WITH_ARGS + ";"; - } descriptor += ")V"; @@ -556,12 +563,12 @@ protected void onMethodExit(int opcode) { if (patch.suffix.descriptor != null) descriptor = patch.suffix.descriptor; - mv.visitMethodInsn(Opcodes.INVOKESTATIC, patch.suffix.cls, patch.suffix.method, descriptor, false); + mv.visitMethodInsn(INVOKESTATIC, patch.suffix.cls, patch.suffix.method, descriptor, false); } @Override public void visitMethodInsn(int opcodeAndSource, String owner, String name, String descriptor, boolean isInterface) { - if(patch.classPatch != null && patch.classPatch.replaceObjectAsBase != null && + if (patch.classPatch != null && patch.classPatch.replaceObjectAsBase != null && opcodeAndSource == INVOKESPECIAL && owner.equals(OBJECT_TYPE.getInternalName())) owner = patch.classPatch.replaceObjectAsBase; super.visitMethodInsn(opcodeAndSource, owner, name, descriptor, isInterface); @@ -569,10 +576,26 @@ public void visitMethodInsn(int opcodeAndSource, String owner, String name, Stri @Override public void visitMaxs(int maxStack, int maxLocals) { - if (doTransform && patch.passLocalValue) { + if (patch.wrapWithTryCatch) { + visitLabel(tryBlockEnd); + mv.visitInsn(NOP); + visitLabel(catchBlockStart); + + var suffixOnException = patch.suffixOnException; + if (suffixOnException != null) { + dup(); + var descriptor = "(" + THROWABLE_TYPE.getDescriptor() + loadLocalValueIfRequested() + ")V"; + mv.visitMethodInsn(INVOKESTATIC, suffixOnException.cls, suffixOnException.method, descriptor, false); + } + throwException(); + visitLabel(catchBlockEnd); + } + + if (patch.passLocalValue) { visitLabel(__localValueEnd); mv.visitLocalVariable("__locValue", localValueDesc, null, __localValueStart, __localValueEnd, __localValue); } + super.visitMaxs(maxStack, maxLocals); } } \ No newline at end of file diff --git a/optimus/platform/projects/entityagent/src/main/java/optimus/graph/rtverifier/CallWithArgsGenerator.java b/optimus/platform/projects/entityagent/src/main/java/optimus/graph/rtverifier/CallWithArgsGenerator.java index 8e721bc..c316093 100644 --- a/optimus/platform/projects/entityagent/src/main/java/optimus/graph/rtverifier/CallWithArgsGenerator.java +++ b/optimus/platform/projects/entityagent/src/main/java/optimus/graph/rtverifier/CallWithArgsGenerator.java @@ -14,8 +14,8 @@ import java.util.concurrent.atomic.AtomicInteger; import static optimus.debug.InstrumentationConfig.IS; -import static optimus.debug.InstrumentationConfig.CALL_WITH_ARGS_NAME; -import static optimus.debug.InstrumentationConfig.CALL_WITH_ARGS; +import static optimus.debug.InstrumentationConfig.CWA_INNER_NAME; +import static optimus.debug.InstrumentationConfig.CWA; import static optimus.debug.InstrumentationConfig.OBJECT_ARR_DESC; import static optimus.debug.InstrumentationConfig.OBJECT_TYPE; @@ -43,8 +43,8 @@ public static byte[] create(String className, Type originalOwner, Type[] args, T private static ClassWriter createClassWriter(String className, Type originalOwner, Type[] args, Type returnType, String originalMethod) { ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES); - cw.visit(V11, ACC_PUBLIC | ACC_SUPER, className, null, CALL_WITH_ARGS, null); - cw.visitInnerClass(CALL_WITH_ARGS, IS, CALL_WITH_ARGS_NAME, ACC_PUBLIC | ACC_STATIC | ACC_ABSTRACT); + cw.visit(V11, ACC_PUBLIC | ACC_SUPER, className, null, CWA, null); + cw.visitInnerClass(CWA, IS, CWA_INNER_NAME, ACC_PUBLIC | ACC_STATIC | ACC_ABSTRACT); generateFields(cw, originalOwner, args); generateCtor(cw, className, originalOwner, args); generateArgsMethod(cw, className, args); @@ -71,7 +71,7 @@ private static void generateCtor(ClassWriter cw, String className, Type original mv.visitCode(); mv.visitVarInsn(ALOAD, 0); - mv.visitMethodInsn(INVOKESPECIAL, CALL_WITH_ARGS, "", "()V", false); + mv.visitMethodInsn(INVOKESPECIAL, CWA, "", "()V", false); var argSlot = 1; argSlot += assignField(mv, className, argSlot, "original", originalOwner); @@ -128,11 +128,11 @@ private static void generateReApplyMethod(ClassWriter cw, String className, Type var mv = new GeneratorAdapter(mvWriter, ACC_PUBLIC, "reApply", descriptor); mv.visitCode(); - mv.visitVarInsn(ALOAD, 0); + mv.loadThis(); mv.visitFieldInsn(GETFIELD, className, "original", originalOwner.getDescriptor()); for (int i = 0; i < args.length; i++) { - mv.visitVarInsn(ALOAD, 0); + mv.loadThis(); mv.visitFieldInsn(GETFIELD, className, "arg" + i, args[i].getDescriptor()); } diff --git a/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitInterceptedException.java b/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitInterceptedException.java index e0a0894..895ff30 100644 --- a/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitInterceptedException.java +++ b/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitInterceptedException.java @@ -17,11 +17,14 @@ * tests). Utilized when exit.intercept system property is set to 'intercept-all' */ public class SystemExitInterceptedException extends RuntimeException { - public SystemExitInterceptedException() { - super("[EXIT-INTERCEPT] instrumented System.exit exception"); + private final int exitStatus; + + public SystemExitInterceptedException(int exitStatus) { + super("[EXIT-INTERCEPT] instrumented System.exit(" + exitStatus + ") exception"); + this.exitStatus = exitStatus; } - public SystemExitInterceptedException(String message) { - super(message); + public int exitStatus() { + return exitStatus; } } diff --git a/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitReplacement.java b/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitReplacement.java index 37b65c6..5483f01 100644 --- a/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitReplacement.java +++ b/optimus/platform/projects/entityagent/src/main/java/optimus/systemexit/SystemExitReplacement.java @@ -65,7 +65,7 @@ public static void exitImpl(int status) { hook.accept(status); } } - throw new SystemExitInterceptedException(); + throw new SystemExitInterceptedException(status); } else { logger.debug("[EXIT-INTERCEPT] normal exit"); diff --git a/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/DiffParser.java b/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/DiffParser.java index 8a96b61..f47af91 100644 --- a/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/DiffParser.java +++ b/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/DiffParser.java @@ -41,9 +41,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Every method here must be, otherwise the parser does not work properly. - */ +/** Every method here must be, otherwise the parser does not work properly. */ public enum DiffParser { INITIAL { @Override @@ -62,9 +60,7 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { return diff; } }, - /** - * The parser is in this state if it is currently parsing a header line. - */ + /** The parser is in this state if it is currently parsing a header line. */ HEADER { @Override public DiffParser nextState(ParseWindow window) { @@ -85,8 +81,8 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { }, /** * The parser is in this state if it is currently parsing the line containing the "from" file. - *

- * Example line:
+ * + *

Example line:
* {@code --- /path/to/file.txt} */ FROM_FILE { @@ -96,7 +92,8 @@ public DiffParser nextState(ParseWindow window) { if (isToFile(line)) { return TO_FILE; } else { - throw new IllegalStateException("A FROM_FILE line ('---') must be directly followed by a TO_FILE line ('+++')!"); + throw new IllegalStateException( + "A FROM_FILE line ('---') must be directly followed by a TO_FILE line ('+++')!"); } } @@ -107,8 +104,8 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { }, /** * The parser is in this state if it is currently parsing the line containing the "to" file. - *

- * Example line:
+ * + *

Example line:
* {@code +++ /path/to/file.txt} */ TO_FILE { @@ -118,7 +115,8 @@ public DiffParser nextState(ParseWindow window) { if (Hunk.isHunkStart(line)) { return HUNK_START; } else { - throw new IllegalStateException("A TO_FILE line ('+++') must be directly followed by a HUNK_START line ('@@')!"); + throw new IllegalStateException( + "A TO_FILE line ('+++') must be directly followed by a HUNK_START line ('@@')!"); } } @@ -129,8 +127,8 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { }, /** * The parser is in this state if it is currently parsing a line containing the header of a hunk. - *

- * Example line:
+ * + *

Example line:
* {@code @@ -1,5 +2,6 @@} */ HUNK_START { @@ -152,10 +150,10 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { } }, /** - * The parser is in this state if it is currently parsing a line containing a line that is in the first file, - * but not the second (a "from" line). - *

- * Example line:
+ * The parser is in this state if it is currently parsing a line containing a line that is in the + * first file, but not the second (a "from" line). + * + *

Example line:
* {@code - only the dash at the start is important} */ FROM_LINE { @@ -183,10 +181,10 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { } }, /** - * The parser is in this state if it is currently parsing a line containing a line that is in the second file, - * but not the first (a "to" line). - *

- * Example line:
+ * The parser is in this state if it is currently parsing a line containing a line that is in the + * second file, but not the first (a "to" line). + * + *

Example line:
* {@code + only the plus at the start is important} */ TO_LINE { @@ -214,8 +212,8 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { } }, /** - * The parser is in this state if it is currently parsing a line that is contained in both files (a "neutral" line). - * This line can contain any string. + * The parser is in this state if it is currently parsing a line that is contained in both files + * (a "neutral" line). This line can contain any string. */ NEUTRAL_LINE { @Override @@ -242,8 +240,8 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { } }, /** - * The parser is in this state if it is currently parsing a line that is the delimiter between two Diffs. - * This line is always a new line. + * The parser is in this state if it is currently parsing a line that is the delimiter between two + * Diffs. This line is always a new line. */ END { @Override @@ -259,8 +257,8 @@ public PartialDiff modifyDiff(PartialDiff diff, String currentLine) { }; /** - * Returns the next state of the state machine depending on the current state and the content of a window of lines around the line - * that is currently being parsed. + * Returns the next state of the state machine depending on the current state and the content of a + * window of lines around the line that is currently being parsed. * * @param window the window around the line currently being parsed. * @return the next state of the state machine. @@ -274,8 +272,9 @@ private static void logTransition(String currentLine, DiffParser fromState, Diff } public static List parse(String text) { - ParseWindow window = new ParseWindow(new ByteArrayInputStream(text.stripLeading() - .getBytes(StandardCharsets.UTF_8))); + ParseWindow window = + new ParseWindow( + new ByteArrayInputStream(text.stripLeading().getBytes(StandardCharsets.UTF_8))); DiffParser state = INITIAL; List parsedDiffs = new ArrayList<>(); PartialDiff currentDiff = PartialDiff.empty(); @@ -335,8 +334,7 @@ protected static boolean isEnd(String currentLine, ParseWindow window) { } else { return false; } - } private static final Logger logger = LoggerFactory.getLogger(DiffParser.class); -} \ No newline at end of file +} diff --git a/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/ParseWindow.java b/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/ParseWindow.java index 889fbd8..409a135 100644 --- a/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/ParseWindow.java +++ b/optimus/platform/projects/git-utils/src/main/scala/optimus/git/diffparser/ParseWindow.java @@ -36,9 +36,9 @@ import java.util.LinkedList; /** - * A {@link ParseWindow} slides through the lines of a input stream and - * offers methods to get the currently focused line as well as upcoming lines. - * It is backed by an automatically resizing {@link LinkedList} + * A {@link ParseWindow} slides through the lines of a input stream and offers methods to get the + * currently focused line as well as upcoming lines. It is backed by an automatically resizing + * {@link LinkedList} * * @author Tom Hombergs */ @@ -49,17 +49,15 @@ public ParseWindow(InputStream input) { } /** - * Looks ahead from the current line and retrieves a line that will be the - * focus line after the window has slided forward. + * Looks ahead from the current line and retrieves a line that will be the focus line after the + * window has slided forward. * - * @param distance the number of lines to look ahead. Must be greater or equal 0. - * 0 returns the focus line. 1 returns the first line after the - * current focus line and so on. Note that all lines up to the - * returned line will be held in memory until the window has - * slided past them, so be careful not to look ahead too far! - * @return the line identified by the distance parameter that lies ahead of - * the focus line. Returns null if the line cannot be read because - * it lies behind the end of the stream. + * @param distance the number of lines to look ahead. Must be greater or equal 0. 0 returns the + * focus line. 1 returns the first line after the current focus line and so on. Note that all + * lines up to the returned line will be held in memory until the window has slided past them, + * so be careful not to look ahead too far! + * @return the line identified by the distance parameter that lies ahead of the focus line. + * Returns null if the line cannot be read because it lies behind the end of the stream. */ public String getFutureLine(int distance) { try { @@ -68,7 +66,6 @@ public String getFutureLine(int distance) { } catch (IndexOutOfBoundsException ignored) { return null; } - } public void addLine(int pos, String line) { @@ -78,8 +75,7 @@ public void addLine(int pos, String line) { /** * Resizes the sliding window to the given size, if necessary. * - * @param newSize the new size of the window (i.e. the number of lines in the - * window). + * @param newSize the new size of the window (i.e. the number of lines in the window). */ private void resizeWindowIfNecessary(int newSize) { try { @@ -100,8 +96,8 @@ private void resizeWindowIfNecessary(int newSize) { /** * Slides the window forward one line. * - * @return the next line that is in the focus of this window or null if the - * end of the stream has been reached. + * @return the next line that is in the focus of this window or null if the end of the stream has + * been reached. */ public String slideForward() { try { @@ -121,7 +117,6 @@ public String slideForward() { } catch (IOException e) { throw new RuntimeException(e); } - } private String getNextLine() throws IOException { @@ -131,9 +126,9 @@ private String getNextLine() throws IOException { } /** - * Guarantees that a virtual blank line is injected at the end of the input - * stream to ensure the parser attempts to transition to the {@code END} - * state, if necessary, when the end of stream is reached. + * Guarantees that a virtual blank line is injected at the end of the input stream to ensure the + * parser attempts to transition to the {@code END} state, if necessary, when the end of stream is + * reached. */ private String getNextLineOrVirtualBlankLineAtEndOfStream(String nextLine) { if ((nextLine == null) && !isEndOfStream) { @@ -145,9 +140,8 @@ private String getNextLineOrVirtualBlankLineAtEndOfStream(String nextLine) { } /** - * Returns the line currently focused by this window. This is actually the - * same line as returned by {@link #slideForward()} but calling - * this method does not slide the window forward a step. + * Returns the line currently focused by this window. This is actually the same line as returned + * by {@link #slideForward()} but calling this method does not slide the window forward a step. * * @return the currently focused line. */ @@ -159,4 +153,4 @@ public String getFocusLine() { private LinkedList lineQueue = new LinkedList(); private int lineNumber = 0; private boolean isEndOfStream = false; -} \ No newline at end of file +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/MSNetSSLSocket.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/MSNetSSLSocket.java new file mode 100644 index 0000000..bb334c7 --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/MSNetSSLSocket.java @@ -0,0 +1,238 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import javax.annotation.Nullable; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.security.cert.CertificateException; + +import msjava.msnet.auth.MSNetAuthContext; +import msjava.msnet.auth.MSNetAuthStatus; +import msjava.msnet.ssl.SSLEncryptor; +import msjava.msnet.ssl.SSLEncryptorResult; +import msjava.msnet.ssl.SSLHandshaker; +import msjava.msnet.ssl.verification.CertificateVerifier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class MSNetSSLSocket extends MSNetTCPSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(MSNetSSLSocket.class); + + private enum SSLSocketState { + HANDSHAKE_NEEDED, + HANDSHAKE_COMPLETED, + READY_FOR_ENCRYPTING + } + + private final SSLEngine sslEngine; + private final SSLEncryptor sslEncryptor; + private final SSLHandshaker sslHandshaker; + private final CertificateVerifier certificateVerifier; + + private MSNetTCPSocketBuffer encryptedBufferToWrite = new MSNetTCPSocketBuffer(); + + private final MSNetTCPSocketBuffer tmpReadBuffer = new MSNetTCPSocketBuffer(); + + private SSLSocketState socketState = SSLSocketState.HANDSHAKE_NEEDED; + + MSNetSSLSocket( + MSNetTCPSocketImpl impl, + MSNetTCPSocketFactory parentFactory, + SSLEngine sslEngine, + boolean slicingBuffers) { + this(impl, parentFactory, sslEngine, new SSLEncryptor(sslEngine, slicingBuffers)); + } + + MSNetSSLSocket( + MSNetTCPSocketImpl impl, + MSNetTCPSocketFactory parentFactory, + SSLEngine sslEngine, + SSLEncryptor sslEncryptor) { + super(impl, parentFactory); + this.sslEngine = sslEngine; + this.sslEncryptor = sslEncryptor; + this.sslHandshaker = new SSLHandshaker(sslEngine); + this.certificateVerifier = new CertificateVerifier(sslEngine); + } + + @Override + // we need to make sure we propagate any exceptions and readBytes=-1 up the chain, which signify + // error on the channel + public MSNetIOStatus read(MSNetTCPSocketBuffer destBuffer) { + // ssl handshake, raw messages should be relayed + if (socketState != SSLSocketState.READY_FOR_ENCRYPTING) { + return super.read(destBuffer); + } + + MSNetIOStatus result = super.read(tmpReadBuffer); + boolean readBytes = result.getNumBytesProcessed() > 0; + // if we read something from the socket, try to decrypt it and store correct size of decrypted + // message + if (readBytes) { + decryptAndUpdateIOResult(tmpReadBuffer, destBuffer, result); + } + return result; + } + + @Override + public MSNetIOStatus write(MSNetTCPSocketBuffer buf) { + return write(Collections.singletonList(buf)); + } + + @Override + public MSNetIOStatus write(List bufs) { + // ssl handshake, simply write bytes to underlying socket + if (socketState != SSLSocketState.READY_FOR_ENCRYPTING) { + MSNetIOStatus write = super.write(bufs); + if (isLastHandshakeMessageSucessfullyWritten(write)) { + setReadyForEncrypting(); + } + return write; + } + + int bytesWrittenSum = 0; + MSNetIOStatus status = new MSNetIOStatus(); + + for (MSNetTCPSocketBuffer unencryptedBuf : bufs) { + copyAndEncrypt(unencryptedBuf, encryptedBufferToWrite); + status = writeBuffer(encryptedBufferToWrite); + + boolean fullyWritten = encryptedBufferToWrite.size() == 0; + if (fullyWritten) { + // set number of the unencrypted buf size here!!!! The connection above keeps + // track of unencrypted message sizes! + bytesWrittenSum += unencryptedBuf.size(); + } + // we stop on the following conditions: + // 1. We did not manage to fully write to socket and we will later retry OR + // 2. There was an error writing to socket and we need to propagate it up + if (!fullyWritten || status.inError()) { + break; + } + } + + // update status with accumulated written size + status.setNumBytesProcessed(bytesWrittenSum); + return status; + } + + private void copyAndEncrypt( + MSNetTCPSocketBuffer rawBuffer, MSNetTCPSocketBuffer encryptedBuffer) { + // only copy the buffer if the encrypted buffer has no remaining message to be sent + // if it does, that means that this buffer was already copied on previous iteration, and was not + // fully sent + boolean noPreviousMessage = encryptedBuffer.size() == 0; + if (noPreviousMessage && rawBuffer.size() != 0) { + encryptedBuffer.store(rawBuffer.peek()); + encrypt(encryptedBuffer); + } + } + + private MSNetIOStatus writeBuffer(MSNetTCPSocketBuffer buf) { + MSNetIOStatus write = super.write(buf); + buf.processed(write.getNumBytesProcessed()); + return write; + } + + private void encrypt(MSNetTCPSocketBuffer buf) { + if (socketState == SSLSocketState.READY_FOR_ENCRYPTING && buf.size() != 0) { + SSLEncryptorResult encrypt = sslEncryptor.encrypt(buf); + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "Buffer initial size: " + + encrypt.getBytesConsumed() + + ", encrypted buffer size: " + + encrypt.getBytesProduced()); + } + } + } + + // update io status with the size of decrypted message instead of the message size read from raw + // socket + private void decryptAndUpdateIOResult( + MSNetTCPSocketBuffer originBuffer, MSNetTCPSocketBuffer destBuffer, MSNetIOStatus result) { + int originBufferSize = originBuffer.size(); + SSLEncryptorResult decrypted = sslEncryptor.decrypt(originBuffer, destBuffer); + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "Decrypted: " + + decrypted.getBytesProduced() + + ", buffer still contains: " + + originBuffer.size() + + " " + + "encrypted bytes"); + } + + // UPDATE the result's bytes processed to reflect the actual decrypted message byte size + result.setNumBytesInMessage(originBufferSize); + result.setNumBytesProcessed(decrypted.getBytesProduced()); + } + + private boolean isLastHandshakeMessageSucessfullyWritten(MSNetIOStatus write) { + return socketState == SSLSocketState.HANDSHAKE_COMPLETED + && write.getNumBytesInMessage() == write.getNumBytesProcessed(); + } + + public MSNetTCPSocketBuffer getOutputBuffer() { + return sslHandshaker.getOutputBuffer(); + } + + public boolean doHandshake(MSNetTCPSocketBuffer netData) throws Exception { + return sslHandshaker.doHandshake(netData); + } + + public boolean verifyCertificates(boolean encryptionOnly, @Nullable String serviceHostname) + throws CertificateException, SSLException, java.security.cert.CertificateException { + return certificateVerifier.verify( + Optional.ofNullable(serviceHostname).orElse(this.getAddress().getHost()), encryptionOnly); + } + + private String getUserIdFromPrincipalName(String pname) { + return pname.split(",")[0].substring("CN=".length()).split("@")[0]; + } + + public MSNetAuthContext getAuthContext() throws SSLPeerUnverifiedException { + + String authMechanism = "SSL"; + String peerUserId = + getUserIdFromPrincipalName(sslEngine.getSession().getPeerPrincipal().getName()); + MSNetAuthStatus status = new MSNetAuthStatus(MSNetAuthStatus.Authenticated); + + return new MSNetAuthContext(status, peerUserId, peerUserId, authMechanism); + } + + public void setReadyForEncrypting() { + socketState = SSLSocketState.READY_FOR_ENCRYPTING; + + // Once handshake is finished we can free up the buffers that are not going to be used anymore. + sslHandshaker.cleanupBuffers(); + } + + public void setHandshakeCompleted() { + socketState = SSLSocketState.HANDSHAKE_COMPLETED; + } + + @Override + public void close() throws IOException { + super.close(); + sslEngine.closeOutbound(); + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/MSNetSSLSocketFactory.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/MSNetSSLSocketFactory.java new file mode 100644 index 0000000..f0c2c6b --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/MSNetSSLSocketFactory.java @@ -0,0 +1,110 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet; + +import java.io.IOException; +import java.util.Objects; + +import javax.annotation.Nullable; +import javax.net.ssl.SSLEngine; + +import msjava.msnet.ssl.SSLEngineBuilder; +import msjava.msnet.ssl.SSLEngineConfig; +import msjava.msnet.ssl.SSLEngineFactory; +import msjava.msnet.ssl.SSLEstablisher; + +/** + * For general library overview and code examples refer to the {@link SSLEstablisher} documentation. + */ +public class MSNetSSLSocketFactory extends MSNetTCPSocketFactoryNIOImpl { + + private final boolean slicingBuffers; + + @Nullable protected final SSLEngineBuilder sslEngineBuilder; + @Nullable private volatile SSLEngineFactory sslEngineFactory; + + public MSNetSSLSocketFactory() { + slicingBuffers = false; + this.sslEngineBuilder = new SSLEngineBuilder(new SSLEngineConfig()); + sslEngineFactory = null; + } + + public MSNetSSLSocketFactory(SSLEngineConfig sslEngineConfig) { + slicingBuffers = false; + this.sslEngineBuilder = new SSLEngineBuilder(sslEngineConfig); + sslEngineFactory = null; + } + + public MSNetSSLSocketFactory(SSLEngineFactory sslEngineFactory, boolean slicingBuffers) { + this.slicingBuffers = slicingBuffers; + this.sslEngineBuilder = null; + this.sslEngineFactory = Objects.requireNonNull(sslEngineFactory); + } + + @Override + public MSNetTCPSocket createMSNetTCPSocket(boolean isServer) throws MSNetIOException { + SSLEngine sslEngine = makeSSLEngine(isServer); + if (sslEngineFactory == null) { + return new MSNetSSLSocket( + createMSNetTCPSocketImpl(isServer), this, sslEngine, slicingBuffers); + } else { + final SSLEngineFactory sslEngineFactory = this.sslEngineFactory; + return new MSNetSSLSocket( + createMSNetTCPSocketImpl(isServer), this, sslEngine, slicingBuffers) { + @Override + public void close() throws IOException { + super.close(); + sslEngineFactory.dispose(sslEngine); + } + }; + } + } + + @Override + public MSNetTCPSocket acceptMSNetTCPSocket(MSNetTCPSocketImpl serverSocketImpl) + throws MSNetIOException { + SSLEngine sslEngine = makeSSLEngine(true); + if (sslEngineFactory == null) { + return new MSNetSSLSocket( + acceptMSNetTCPSocketImpl(serverSocketImpl), this, sslEngine, slicingBuffers); + } else { + final SSLEngineFactory sslEngineFactory = this.sslEngineFactory; + return new MSNetSSLSocket( + acceptMSNetTCPSocketImpl(serverSocketImpl), this, sslEngine, slicingBuffers) { + @Override + public void close() throws IOException { + super.close(); + sslEngineFactory.dispose(sslEngine); + } + }; + } + } + + private SSLEngine makeSSLEngine(boolean isServer) { + if (sslEngineBuilder != null) { + return sslEngineBuilder.build(isServer); + } else { + if (isServer) { + return Objects.requireNonNull(sslEngineFactory).createServerEngine(); + } else { + return Objects.requireNonNull(sslEngineFactory).createClientEngine(); + } + } + } + + public void setSslEngineFactory(SSLEngineFactory sslEngineFactory) { + Objects.requireNonNull(sslEngineFactory); + Objects.requireNonNull(this.sslEngineFactory); + this.sslEngineFactory = sslEngineFactory; + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEncryptor.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEncryptor.java new file mode 100644 index 0000000..af11787 --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEncryptor.java @@ -0,0 +1,189 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl; + +import static javax.net.ssl.SSLEngineResult.Status.BUFFER_UNDERFLOW; + +import java.nio.ByteBuffer; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; + +import msjava.msnet.MSNetByteBufferManager; +import msjava.msnet.MSNetTCPSocketBuffer; + +/** This class uses SSLEngine for encrypting/decrypting incoming buffer. */ +public class SSLEncryptor { + + private static final int ENCRYPT_BUFFER_SIZE = + Integer.parseInt(System.getProperty("msjava.msnet.ssl.encryptBufferSize", "-1")); + private static final int DECRYPT_BUFFER_SIZE = + Integer.parseInt(System.getProperty("msjava.msnet.ssl.decryptBufferSize", "-1")); + + private final SSLEngine sslEngine; + private final boolean slicingBuffers; + + /** + * SSL Engine encrypts and decrypts data in chunks. There are two local buffers that are reused + * for that. + * + *

* Bytes are consumed in chunks from the input buffer and the result is inserted to the one + * of corresponding encryptedBytesBuffer/decryptedBytesBuffer. * Then content of these buffers is + * copied to the result buffer. * This repeats until input buffer is empty. + */ + private ByteBuffer encryptedBytesBuffer; + + private ByteBuffer decryptedBytesBuffer; + + public SSLEncryptor(SSLEngine sslEngine, boolean slicingBuffers) { + this( + sslEngine, + ENCRYPT_BUFFER_SIZE > 0 + ? ENCRYPT_BUFFER_SIZE + : sslEngine.getSession().getPacketBufferSize(), + DECRYPT_BUFFER_SIZE > 0 + ? DECRYPT_BUFFER_SIZE + : sslEngine.getSession().getApplicationBufferSize(), + slicingBuffers); + } + + public SSLEncryptor( + SSLEngine sslEngine, int encryptBufferSize, int decryptBufferSize, boolean slicingBuffers) { + this.sslEngine = sslEngine; + this.slicingBuffers = slicingBuffers; + this.encryptedBytesBuffer = ByteBuffer.allocate(encryptBufferSize); + this.decryptedBytesBuffer = ByteBuffer.allocate(decryptBufferSize); + } + + /** + * Buffer passed as an argument is overwritten with encrypted data and returned as a part of + * SSLEncryptorResult. + */ + public SSLEncryptorResult encrypt(MSNetTCPSocketBuffer buf) { + ByteBuffer bytesToEncrypt = MSNetByteBufferManager.getInstance().getBuffer(buf.size(), false); + bytesToEncrypt.put(buf.retrieve()); + buf.clear(); + + MSNetTCPSocketBuffer resultBuffer = buf; + + SSLEngineResult result; + bytesToEncrypt.flip(); + int bytesConsumed = 0; + do { + int n = encryptedBytesBuffer.remaining() / 2; + if (n == 0) n = encryptedBytesBuffer.remaining(); + if (!slicingBuffers || bytesToEncrypt.remaining() <= n) { + result = encryptChunk(bytesToEncrypt, encryptedBytesBuffer); + addToBuffer(encryptedBytesBuffer, resultBuffer); + } else { + ByteBuffer slice = bytesToEncrypt.slice(); + slice.limit(n); + result = encryptChunk(slice, encryptedBytesBuffer); + bytesToEncrypt.position(bytesToEncrypt.position() + n); + addToBuffer(encryptedBytesBuffer, resultBuffer); + } + bytesConsumed += result.bytesConsumed(); + } while (bytesToEncrypt.hasRemaining() && result.bytesConsumed() != 0); + + return SSLEncryptorResult.success(resultBuffer.size(), bytesConsumed); + } + + /** + * We consume everything from the buffer til its empty or buffer underflow occurs. In case of + * buffer underflow we return only that piece of data that we managed to decrypt. + * + *

Returns buffer with decrypted data and statistics of the bytes consumed/produced. + */ + public SSLEncryptorResult decrypt( + MSNetTCPSocketBuffer originBuffer, MSNetTCPSocketBuffer destBuffer) { + ByteBuffer bytesToDecrypt = + MSNetByteBufferManager.getInstance().getBuffer(originBuffer.size(), false); + bytesToDecrypt.put(originBuffer.peek()); + + int bytesConsumed = 0; + int bytesProduced = 0; + SSLEngineResult result; + bytesToDecrypt.flip(); + do { + result = decryptChunk(bytesToDecrypt, decryptedBytesBuffer); + if (result.getStatus() == BUFFER_UNDERFLOW) { + return SSLEncryptorResult.bufferUnderflow(bytesProduced, bytesConsumed); + } + originBuffer.retrieve(result.bytesConsumed()); + bytesConsumed += result.bytesConsumed(); + bytesProduced += result.bytesProduced(); + addToBuffer(decryptedBytesBuffer, destBuffer); + } while (bytesToDecrypt.hasRemaining() && result.bytesConsumed() != 0); + + return SSLEncryptorResult.success(bytesProduced, bytesConsumed); + } + + private SSLEngineResult encryptChunk(ByteBuffer bytesToEncrypt, ByteBuffer encryptToBuffer) { + SSLEngineResult result; + + try { + result = sslEngine.wrap(bytesToEncrypt, encryptToBuffer); + } catch (SSLException e) { + throw new IllegalStateException("Unexpected exception during SSLEngine.warp", e); + } + + switch (result.getStatus()) { + case OK: + break; + case BUFFER_OVERFLOW: + // encryptToBuffer is flushed after each iteration - because of that buffer overflow should + // not happen + throw new IllegalStateException("Buffer overflow should not occur after wrap."); + case BUFFER_UNDERFLOW: + throw new IllegalStateException("Buffer underflow occurred after a wrap"); + case CLOSED: + throw new IllegalStateException("The sslEngine is closed when encrypting data"); + default: + throw new IllegalStateException( + "Invalid SSL status: " + result.getStatus() + " during encryption"); + } + return result; + } + + private SSLEngineResult decryptChunk(ByteBuffer bytesToDecrypt, ByteBuffer decryptToBuffer) { + SSLEngineResult result; + + try { + result = sslEngine.unwrap(bytesToDecrypt, decryptToBuffer); + } catch (SSLException e) { + throw new IllegalStateException("Unexpected exception during SSLEngine.unwarp", e); + } + + switch (result.getStatus()) { + case OK: + break; + case BUFFER_UNDERFLOW: + break; + case BUFFER_OVERFLOW: + // decryptToBuffer is flushed after each iteration - because of that buffer overflow should + // not happen + throw new IllegalStateException("buffer overflow should not occur."); + case CLOSED: + throw new IllegalStateException("The sslEngine is closed when decrypting data"); + } + + return result; + } + + private void addToBuffer(ByteBuffer srcBuffer, MSNetTCPSocketBuffer destBuffer) { + srcBuffer.flip(); + destBuffer.store(srcBuffer); + srcBuffer.clear(); + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEncryptorResult.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEncryptorResult.java new file mode 100644 index 0000000..60800ae --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEncryptorResult.java @@ -0,0 +1,56 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl; + +/** + * This class is returned as a result of calling encrypt/decrypt method on the SSLEncryptor object. + * + *

Contains statistics about how many bytes have been produced/consumed. + */ +public class SSLEncryptorResult { + + enum State { + OK, + BUFFER_UNDERFLOW + } + + private final int bytesConsumed; + private final int bytesProduced; + private final State state; + + static SSLEncryptorResult success(int bytesProduced, int bytesConsumed) { + return new SSLEncryptorResult(bytesConsumed, bytesProduced, State.OK); + } + + static SSLEncryptorResult bufferUnderflow(int bytesProduced, int bytesConsumed) { + return new SSLEncryptorResult(bytesConsumed, bytesProduced, State.BUFFER_UNDERFLOW); + } + + private SSLEncryptorResult(int bytesConsumed, int bytesProduced, State state) { + this.bytesConsumed = bytesConsumed; + this.bytesProduced = bytesProduced; + this.state = state; + } + + public int getBytesConsumed() { + return bytesConsumed; + } + + public int getBytesProduced() { + return bytesProduced; + } + + public State getState() { + return state; + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineBuilder.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineBuilder.java new file mode 100644 index 0000000..294d10e --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineBuilder.java @@ -0,0 +1,190 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.*; +import java.io.FileInputStream; +import java.io.InputStream; +import java.security.KeyStore; +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; + +/** Builds SSLEngine object based on the SSLEngineConfig */ +public class SSLEngineBuilder { + private static final String SSLENGINE_TYPE = System.getProperty("msjava.msnet.ssl.engine", "JDK"); + private static final Logger LOGGER = LoggerFactory.getLogger(SSLEngineBuilder.class); + + private final SSLEngineConfig config; + + private final KeyManagerFactory keyManagerFactory; + private final TrustManagerFactory trustManagerFactory; + + public SSLEngineBuilder(SSLEngineConfig config) { + this.config = config; + this.keyManagerFactory = + Optional.ofNullable(config.keystorePath).map(this::createKeyManagerFactory).orElse(null); + this.trustManagerFactory = + Optional.ofNullable(config.truststorePath) + .map(this::createTrustManagerFactory) + .orElse(null); + } + + public SSLEngine build(boolean isServer) { + try { + if (SSLENGINE_TYPE.equals("OpenSSL")) { + return createOpenSSLEngine(isServer); + } else if (SSLENGINE_TYPE.equals("JDK")) { + return createJdkSSLEngine(isServer); + } else { + throw new IllegalArgumentException( + "Unknown SSLEngine type " + SSLENGINE_TYPE + ", please use OpenSSL or JDK"); + } + } catch (Exception e) { + throw new IllegalStateException("Could not create SSLEngine", e); + } + } + + private SSLEngine createJdkSSLEngine(boolean isServer) throws Exception { + KeyManager[] keyManagers = + Optional.ofNullable(keyManagerFactory).map(KeyManagerFactory::getKeyManagers).orElse(null); + TrustManager[] trustManagers = + Optional.ofNullable(trustManagerFactory) + .map(TrustManagerFactory::getTrustManagers) + .orElse(null); + + SSLContext context = SSLContext.getInstance(SSLEngineConfig.TLS_PROTOCOL_VERSION); + context.init(keyManagers, trustManagers, new SecureRandom()); + + SSLEngine sslEngine = context.createSSLEngine(); + if (isServer) { + sslEngine.setUseClientMode(false); + if (config.enabledClientAuthSSL) { + sslEngine.setNeedClientAuth(true); + } + } else { + sslEngine.setUseClientMode(true); + } + + LOGGER.debug( + "Creating SSLEngine with keystore " + + config.keystorePath + + ", truststore=" + + config.truststorePath); + LOGGER.debug( + "Creating SSLEngine in Server mode=" + + isServer + + " and clientAuth=" + + config.enabledClientAuthSSL); + + return sslEngine; + } + + private List getCiphers() { + String ciphers = + System.getProperty("msjava.msnet.ssl.ciphers", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + return Arrays.asList(ciphers.split(":")); + } + + private SSLEngine createOpenSSLEngine(boolean isServer) throws Exception { + io.netty.handler.ssl.SslContext sslContext = null; + if (isServer) { + List ciphers = getCiphers(); + sslContext = + SslContextBuilder.forServer(keyManagerFactory) + .trustManager(trustManagerFactory) + .clientAuth(ClientAuth.REQUIRE) + .sslProvider(SslProvider.OPENSSL) + .protocols(SSLEngineConfig.TLS_PROTOCOL_VERSION) + .ciphers(ciphers) + .build(); + } else { + ClientAuth clientAuth = ClientAuth.NONE; + if (config.enabledClientAuthSSL) { + clientAuth = ClientAuth.REQUIRE; + } + sslContext = + SslContextBuilder.forClient() + .keyManager(keyManagerFactory) + .trustManager(trustManagerFactory) + .clientAuth(clientAuth) + .sslProvider(SslProvider.OPENSSL) + .protocols(SSLEngineConfig.TLS_PROTOCOL_VERSION) + .build(); + } + + SSLEngine sslEngine = sslContext.newEngine(PooledByteBufAllocator.DEFAULT); + + LOGGER.debug( + "Creating SSLEngine with keystore " + + config.keystorePath + + ", truststore=" + + config.truststorePath); + LOGGER.debug( + "Creating SSLEngine in Server mode=" + + isServer + + " and clientAuth=" + + config.enabledClientAuthSSL); + + return sslEngine; + } + + private KeyManagerFactory createKeyManagerFactory(String filepath) { + try { + assertPasswordNotNull("Keystore password cannot be set to null", config.keystorePassword); + assertPasswordNotNull("Key password cannot be set to null", config.keyPassword); + + KeyStore keyStore = KeyStore.getInstance("JKS"); + try (InputStream keyStoreIS = new FileInputStream(filepath)) { + keyStore.load(keyStoreIS, config.keystorePassword.toCharArray()); + } + KeyManagerFactory kmf = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(keyStore, config.keyPassword.toCharArray()); + return kmf; + } catch (Exception e) { + throw new IllegalArgumentException("Could not create key managers", e); + } + } + + private TrustManagerFactory createTrustManagerFactory(String filepath) { + try { + assertPasswordNotNull("Truststore password cannot be set to null", config.truststorePassword); + KeyStore trustStore = KeyStore.getInstance("JKS"); + try (InputStream trustStoreIS = new FileInputStream(filepath)) { + trustStore.load(trustStoreIS, config.truststorePassword.toCharArray()); + } + TrustManagerFactory trustFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustFactory.init(trustStore); + return trustFactory; + } catch (Exception e) { + throw new IllegalArgumentException("Could not create trust managers", e); + } + } + + private void assertPasswordNotNull(String message, String password) { + if (password == null) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineConfig.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineConfig.java new file mode 100644 index 0000000..cf0bb69 --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineConfig.java @@ -0,0 +1,109 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl; + +import static msjava.base.util.internal.SystemPropertyUtils.getBoolean; +import static msjava.base.util.internal.SystemPropertyUtils.getProperty; + +import msjava.base.slf4j.ContextLogger; +import org.slf4j.Logger; + +/** + * Config for SSLEngine. Default values are obtained from vm options. Can override any of them using + * corresponding methods. Protocol version used in SSL Engine: TLSv1.2 + * + *

For general library overview and code examples refer to the {@link SSLEstablisher} + * documentation. + */ +public class SSLEngineConfig { + private static final Logger LOGGER = ContextLogger.safeLogger(); + + public static final String TLS_PROTOCOL_VERSION = "TLSv1.2"; + + public static final String KEY_STORE_PATH_PARAMETER = "msjava.msnet.ssl.keystore"; + public static final String TRUST_STORE_PATH_PARAMETER = "msjava.msnet.ssl.truststore"; + public static final String KEY_STORE_PASSWORD_PARAMETER = "msjava.msnet.ssl.keystore_password"; + public static final String KEY_PASSWORD_PARAMETER = "msjava.msnet.ssl.key_password"; + public static final String TRUST_STORE_PASSWORD_PARAMETER = + "msjava.msnet.ssl.truststore_password"; + public static final String ENABLED_CLIENT_AUTH_PARAMETER = "msjava.msnet.ssl.enabled_client_auth"; + + public String keystorePassword = getProperty(KEY_STORE_PASSWORD_PARAMETER, LOGGER); + public String keyPassword = getProperty(KEY_PASSWORD_PARAMETER, LOGGER); + public String truststorePassword = getProperty(TRUST_STORE_PASSWORD_PARAMETER, LOGGER); + + public String keystorePath = getProperty(KEY_STORE_PATH_PARAMETER, LOGGER); + public String truststorePath = getProperty(TRUST_STORE_PATH_PARAMETER, LOGGER); + + boolean enabledClientAuthSSL = getBoolean(ENABLED_CLIENT_AUTH_PARAMETER, true, LOGGER); + + /** + * Overrides value obtained from vm arg: msjava.msnet.ssl.keystore_password + * + * @param keystorePassword + */ + public SSLEngineConfig withKeystorePassword(String keystorePassword) { + this.keystorePassword = keystorePassword; + return this; + } + + /** + * Overrides value obtained from vm arg: msjava.msnet.ssl.truststore + * + * @param trustStorePassword + */ + public SSLEngineConfig withTruststorePassword(String trustStorePassword) { + this.truststorePassword = trustStorePassword; + return this; + } + + /** + * Overrides value obtained from vm arg: msjava.msnet.ssl.key_password + * + * @param keyPassword + */ + public SSLEngineConfig withKeyPassword(String keyPassword) { + this.keyPassword = keyPassword; + return this; + } + + /** + * Overrides value obtained from vm arg: msjava.msnet.ssl.enabled_client_auth + * + * @param enabledClientAuthSSL + */ + public SSLEngineConfig withClientAuthEnabled(boolean enabledClientAuthSSL) { + this.enabledClientAuthSSL = enabledClientAuthSSL; + return this; + } + + /** + * Overrides value obtained from vm arg: msjava.msnet.ssl.keystore + * + * @param keystorePath + */ + public SSLEngineConfig withKeystorePath(String keystorePath) { + this.keystorePath = keystorePath; + return this; + } + + /** + * Overrides value obtained from vm arg: msjava.msnet.ssl.truststore + * + * @param truststorePath + */ + public SSLEngineConfig withTruststorePath(String truststorePath) { + this.truststorePath = truststorePath; + return this; + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineFactory.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineFactory.java new file mode 100644 index 0000000..b5d592a --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEngineFactory.java @@ -0,0 +1,23 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl; + +import javax.net.ssl.SSLEngine; + +public interface SSLEngineFactory { + SSLEngine createServerEngine(); + + SSLEngine createClientEngine(); + + void dispose(SSLEngine sslEngine); +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEstablisher.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEstablisher.java new file mode 100644 index 0000000..2efaa0f --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEstablisher.java @@ -0,0 +1,233 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package msjava.msnet.ssl; + +import javax.annotation.Nullable; + +import com.google.common.base.Stopwatch; +import msjava.msnet.MSNetAbstractTCPServer; +import msjava.msnet.MSNetEstablishStatus; +import msjava.msnet.MSNetEstablisher; +import msjava.msnet.MSNetEstablisherFactory; +import msjava.msnet.MSNetSSLSocket; +import msjava.msnet.MSNetSSLSocketFactory; +import msjava.msnet.MSNetTCPConnection; +import msjava.msnet.MSNetTCPSocket; +import msjava.msnet.MSNetTCPSocketBuffer; +import msjava.msnet.MSNetTCPSocketFactory; +import msjava.msnet.utils.MSNetConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * SSLEstablisher is used to to coordinate and supervise SSL handshake performed by {@link + * MSNetSSLSocket} + * + *

+ * + *

+ * + *

Steps needed to set up SSL connection between server and client:

+ * + *
    + *
  • Create respective {@link MSNetSSLSocketFactory} for both server and client connection. + *
  • Use {@link SSLEngineConfig} to configure {@link MSNetSSLSocketFactory}. For more info refer + * to the {@link SSLEngineConfig} doc. + *
  • Use {@link MSNetAbstractTCPServer#setSocketFactory(MSNetTCPSocketFactory)} to inject + * previously created {@link MSNetSSLSocketFactory} + *
  • Use {@link MSNetTCPConnection#setSocketFactory(MSNetTCPSocketFactory)} to set {@link + * MSNetSSLSocketFactory} for the client connection. + *
  • {@link MSNetSSLSocketFactory} can also be globally set using {@link + * MSNetConfiguration#setDefaultMSNetTCPSocketFactory(MSNetTCPSocketFactory)} + *
  • Use {@link MSNetAbstractTCPServer#addEstablisherFactory(MSNetEstablisherFactory)} to pass + * there {@link SSLEstablisherFactory} + *
  • Create {@SSLEstablisher} object using preferably {@link SSLEstablisherFactory} and pass it + * to the client connection {@link MSNetTCPConnection#addEstablisher(MSNetEstablisher)} + *
+ * + *

Code example. Server configuration:

+ * + *
+ *     SSLEngineConfig serverSSLConfig = new SSLEngineConfig()
+ *          .withKeyPassword(keyPassword)
+ *          .withKeystorePassword(keystorePassword)
+ *          .withTruststorePassword(trustStorePassword)
+ *          .withClientAuthEnabled(true)
+ *          .withKeystorePath(keystorePath)
+ *          .withTruststorePath(truststorePath);
+ *
+ *     MSNetSSLSocketFactory serverSocketFactory = new MSNetSSLSocketFactory(sslServerConfig);
+ *     server = new MSNetThreadPoolTCPServer(serverLoop, new MSNetInetAddress("localhost:0"), new MSNetID("test ssl server"));
+ *     server.addEstablisherFactory(new SSLEstablisherFactory());
+ *     server.setSocketFactory(serverSocketFactory);
+ *     server.start();
+ * 
+ * + *

Code example. Client connection configuration:

+ * + *
+ *     SSLEngineConfig clientSSLConfig = new SSLEngineConfig()
+ *          .withKeyPassword(keyPassword)
+ *          .withKeystorePassword(keystorePassword)
+ *          .withTruststorePassword(trustStorePassword)
+ *          .withKeystorePath(keystorePath)
+ *          .withTruststorePath(truststorePath);
+ *
+ *     MSNetSSLSocketFactory clientSocketFactory = new MSNetSSLSocketFactory(clientSSLConfig);
+ *     clientConnection = new MSNetTCPConnection(clientLoop, serverAddress, "test ssl connection");
+ *     clientConnection.addEstablisher(new SSLEstablisherFactory().createEstablisher());
+ *     clientConnection.setSocketFactory(clientSocketFactory);
+ * 
+ * + *

Library limitations

+ * + *
    + *
  • Session renegotiation is not supported + *
  • Sending user data during handshake is not supported + *
+ * + *

Support

+ * + *
    + *
  • This is an experimental library created by Optimus/DAL team and therefore is not supported + * by msjava team. + *
  • If you have any question regarding this particular library contact the DAL support mailing + * list. + *
+ */ +public class SSLEstablisher extends MSNetEstablisher { + private static final Logger LOGGER = LoggerFactory.getLogger(SSLEstablisher.class); + + /** Must be higher than MSNetSOCKS4Establisher.DEFAULT_ESTABLISH_PRIORITY */ + public static final int DEFAULT_ESTABLISHER_PRIORITY = 1100; + + private static final String ESTABLISHER_NAME = "SSL Establisher"; + + private final boolean encryptionOnly; + // The hostname which the client initiates connections to. It might differ from the hostname + // returned by reverse + // DNS because of DNS CNAME pointers, load balancers, service discovery methods. We need to check + // the original + // hostname to protect against hijack. + @Nullable private final String serviceHostname; + + private MSNetTCPConnection conn; + private MSNetSSLSocket sslSocket; + private MSNetEstablishStatus status = MSNetEstablishStatus.UNKNOWN; + + private boolean isServerSide; + + private Stopwatch stopwatch = Stopwatch.createUnstarted(); + + SSLEstablisher() { + this(DEFAULT_ESTABLISHER_PRIORITY, false, null); + } + + public SSLEstablisher(int priority, boolean encryptionOnly, @Nullable String serviceHostname) { + super(priority); + this.encryptionOnly = encryptionOnly; + this.serviceHostname = serviceHostname; + } + + @Override + public void init(boolean isServerSide, MSNetTCPConnection conn) { + this.isServerSide = isServerSide; + this.conn = conn; + } + + @Override + public MSNetEstablishStatus establish(MSNetTCPSocketBuffer readBuf) { + LOGGER.debug("Trying to perform handshake"); + + try { + startStopwatch(); + sslSocket = getSocketFromConnection(conn); + return doHandshake(readBuf); + } catch (Exception e) { + LOGGER.error("Could not establish handshake", e); + return MSNetEstablishStatus.FAILURE; + } + } + + private MSNetEstablishStatus doHandshake(MSNetTCPSocketBuffer readBuf) throws Exception { + status = handshake(readBuf); + if (status == MSNetEstablishStatus.COMPLETE) { + if (isServerSide) { + // Server side is actually the last one to send handshake data + // Because of that there is a intermediate state that is changed on the socket when last + // piece of handshake data is send. + sslSocket.setHandshakeCompleted(); + } else { + sslSocket.setReadyForEncrypting(); + } + + if (!encryptionOnly) { + conn.setAuthContext(sslSocket.getAuthContext()); + } + LOGGER.info( + "Successful handshake with {}, it took {}", conn.getAuthContext(), stopwatch.stop()); + } + + return status; + } + + private MSNetEstablishStatus handshake(MSNetTCPSocketBuffer readBuf) throws Exception { + if (!sslSocket.doHandshake(readBuf)) { + return MSNetEstablishStatus.CONTINUE; + } + + if (!sslSocket.verifyCertificates(encryptionOnly, serviceHostname)) { + LOGGER.error("Certificate validation failed. Failed to establish connection."); + return MSNetEstablishStatus.FAILURE; + } + + return MSNetEstablishStatus.COMPLETE; + } + + private void startStopwatch() { + if (!stopwatch.isRunning()) { + stopwatch.start(); + } + } + + private MSNetSSLSocket getSocketFromConnection(MSNetTCPConnection conn) { + MSNetTCPSocket socket = conn.getSocket(); + if (socket instanceof MSNetSSLSocket) { + return (MSNetSSLSocket) socket; + } + + throw new IllegalArgumentException( + "Connection is not a SSL Connection. Cannot extract SSLEngine from the socket."); + } + + @Override + public void cleanup() { + status = MSNetEstablishStatus.UNKNOWN; + stopwatch.reset(); + } + + @Override + public MSNetEstablishStatus getStatus() { + return status; + } + + @Override + public String getEstablisherName() { + return ESTABLISHER_NAME; + } + + @Override + public MSNetTCPSocketBuffer getOutputBuffer() { + LOGGER.debug("isServer=" + isServerSide + ". Sending " + sslSocket.getOutputBuffer()); + return sslSocket.getOutputBuffer(); + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEstablisherFactory.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEstablisherFactory.java new file mode 100644 index 0000000..665910b --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLEstablisherFactory.java @@ -0,0 +1,42 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl; + +import javax.annotation.Nullable; + +import msjava.msnet.MSNetEstablisher; +import msjava.msnet.MSNetEstablisherFactory; + +/** + * For general library overview and code examples refer to the {@link SSLEstablisher} documentation. + */ +public class SSLEstablisherFactory implements MSNetEstablisherFactory { + + private final boolean encryptionOnly; // no authentication + @Nullable private final String serviceHostname; + + public SSLEstablisherFactory() { + this(false, null); + } + + public SSLEstablisherFactory(boolean encryptionOnly, @Nullable String serviceHostname) { + this.encryptionOnly = encryptionOnly; + this.serviceHostname = serviceHostname; + } + + @Override + public MSNetEstablisher createEstablisher() { + return new SSLEstablisher( + SSLEstablisher.DEFAULT_ESTABLISHER_PRIORITY, encryptionOnly, serviceHostname); + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLHandshaker.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLHandshaker.java new file mode 100644 index 0000000..8934de8 --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/SSLHandshaker.java @@ -0,0 +1,211 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl; + +import msjava.msnet.MSNetTCPSocketBuffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; +import java.nio.ByteBuffer; + +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*; + +/** + * This class handles logic responsible to perform ssl handshake. More details in the doHandshake + * method's documentation. + */ +public class SSLHandshaker { + private static final Logger LOGGER = LoggerFactory.getLogger(SSLHandshaker.class); + + private static final ByteBuffer EMPTY_APP_DATA_BUFFER = ByteBuffer.allocate(0); + + private boolean handshakeStarted = false; + + private final SSLEngine engine; + + private ByteBuffer outgoingHandshakeData; + private ByteBuffer incomingHandshakeData; + private ByteBuffer peerAppData; + + private MSNetTCPSocketBuffer outputBuffer = new MSNetTCPSocketBuffer(); + + public SSLHandshaker(SSLEngine engine) { + this.engine = engine; + } + + /** + * SSL handshake consists of several data exchanges between server and client. Read is represented + * by unwrap and write by wrap. + * + *

Writes and reads may be followed by some tasks that need to be performed in a separate + * thread. + * + *

+ * + *

Typical flow of the SSLHandshake: Client SSL/TLS Message HandshakeStatus wrap() ClientHello + * NEED_UNWRAP unwrap() ServerHello/Cert/ServerHelloDone NEED_WRAP wrap() ClientKeyExchange + * NEED_WRAP wrap() ChangeCipherSpec NEED_WRAP wrap() Finished NEED_UNWRAP unwrap() + * ChangeCipherSpec NEED_UNWRAP unwrap() Finished FINISHED + * + *

More info about the protocol and java implementation: + * https://docs.oracle.com/javase/8/docs/technotes/guides/security/jsse/JSSERefGuide.html#SSLEngine + */ + public boolean doHandshake(MSNetTCPSocketBuffer netData) throws Exception { + beginHandshake(); + + SSLEngineResult.HandshakeStatus hs = unwrapWithSubsequentTasks(netData); + if (hs == SSLEngineResult.HandshakeStatus.FINISHED) { + return true; + } + + hs = wrapWithSubsequentTasks(); + if (hs == SSLEngineResult.HandshakeStatus.FINISHED) { + return true; + } + + LOGGER.trace("Continuing with handshake"); + return false; + } + + private void beginHandshake() throws SSLException { + if (!handshakeStarted) { + engine.beginHandshake(); + handshakeStarted = true; + + allocateBuffers(); + } + } + + private void allocateBuffers() { + SSLSession session = engine.getSession(); + outgoingHandshakeData = ByteBuffer.allocate(session.getPacketBufferSize()); + incomingHandshakeData = ByteBuffer.allocate(session.getPacketBufferSize()); + peerAppData = ByteBuffer.allocate(session.getApplicationBufferSize()); + } + + /** This function simply reads incoming data and changes ssl engine state. */ + private SSLEngineResult.HandshakeStatus unwrapWithSubsequentTasks(MSNetTCPSocketBuffer netData) + throws SSLException { + SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus(); + + boolean bufferUnderflow = false; + + // Because we send batched handshake data it might be needed to do few subsequent unwraps. + while ((hs == NEED_UNWRAP || hs == NEED_TASK) && !bufferUnderflow) { + + writeToHandshakeBufferAsMuchAsPossible(netData, incomingHandshakeData); + + switch (hs) { + case NEED_UNWRAP: + + // Process incoming handshaking data + incomingHandshakeData.flip(); + // Although peerAppData is a non empty buffer we don't expect at this point any user data. + SSLEngineResult sslEngineResult = engine.unwrap(incomingHandshakeData, peerAppData); + hs = sslEngineResult.getHandshakeStatus(); + incomingHandshakeData.compact(); + + SSLEngineResult.Status status = sslEngineResult.getStatus(); + switch (status) { + case OK: + break; + + case BUFFER_OVERFLOW: + throw new IllegalStateException( + "Application data is not supposed to be exchanged at this point"); + + case BUFFER_UNDERFLOW: + // Break the loop and wait for more data + bufferUnderflow = true; + break; + } + break; + case NEED_TASK: + hs = handleTask(); + break; + } + } + return hs; + } + + /** + * Since handshake data might be larger than max incomingHandshakeData(16k) size we have to split + * it and pass to the sslEngine in chunks. + */ + private void writeToHandshakeBufferAsMuchAsPossible( + MSNetTCPSocketBuffer netData, ByteBuffer incomingHandshakeData) { + int bytesToRetrieve = Math.min(incomingHandshakeData.remaining(), netData.size()); + incomingHandshakeData.put(netData.retrieve(bytesToRetrieve)); + } + + /** This functions writes data to the outgoingHandshakeData buffer. */ + private SSLEngineResult.HandshakeStatus wrapWithSubsequentTasks() throws SSLException { + SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus(); + + // It might be needed to do 3 subsequent wraps and then send them in batch. + while (hs == NEED_WRAP || hs == NEED_TASK) { + switch (hs) { + case NEED_WRAP: + outgoingHandshakeData.clear(); + + // Generate handshaking data but no user data to send. + SSLEngineResult sslEngineResult = + engine.wrap(EMPTY_APP_DATA_BUFFER, outgoingHandshakeData); + hs = sslEngineResult.getHandshakeStatus(); + + switch (sslEngineResult.getStatus()) { + case OK: + outgoingHandshakeData.flip(); + outputBuffer.store(outgoingHandshakeData); + break; + + case BUFFER_OVERFLOW: + throw new IllegalStateException("Buffer overflow should not happen during wrap"); + + case BUFFER_UNDERFLOW: + throw new IllegalStateException("Buffer underflow should not happen during wrap"); + } + break; + case NEED_TASK: + hs = handleTask(); + break; + } + } + + return hs; + } + + public MSNetTCPSocketBuffer getOutputBuffer() { + return outputBuffer; + } + + private SSLEngineResult.HandshakeStatus handleTask() { + Runnable task; + while ((task = engine.getDelegatedTask()) != null) { + task.run(); + } + + return engine.getHandshakeStatus(); + } + + public void cleanupBuffers() { + outgoingHandshakeData = null; + incomingHandshakeData = null; + peerAppData = null; + handshakeStarted = false; + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/CertificateExpirationValidator.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/CertificateExpirationValidator.java new file mode 100644 index 0000000..19167f7 --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/CertificateExpirationValidator.java @@ -0,0 +1,40 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl.verification; + +import java.security.cert.Certificate; +import java.security.cert.CertificateExpiredException; +import java.security.cert.CertificateNotYetValidException; +import java.security.cert.X509Certificate; + +import javax.net.ssl.SSLEngine; + +/** + * Client is always supposed to validate server. Server only if client auth is enabled. Otherwise it + * will not get client's certificate. + */ +class CertificateExpirationValidator { + + static boolean validate(SSLEngine engine, Certificate[] certs) + throws CertificateNotYetValidException, CertificateExpiredException { + if (engine.getUseClientMode() || engine.getNeedClientAuth()) { + for (Certificate cert : certs) { + if (cert instanceof X509Certificate) ((X509Certificate) cert).checkValidity(); + else + throw new CertificateNotYetValidException( + "Unsupported certificate: " + cert.getType() + " " + cert.toString()); + } + } + return true; + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/CertificateVerifier.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/CertificateVerifier.java new file mode 100644 index 0000000..8408f14 --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/CertificateVerifier.java @@ -0,0 +1,96 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl.verification; + +import java.security.cert.Certificate; +import java.security.cert.CertificateException; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; + +import msjava.base.slf4j.ContextLogger; +import msjava.base.util.internal.SystemPropertyUtils; + +/** + * Certificate verifier checks two things: 1. if peer certificate is not expired 2. checks whether + * DNS name in the peer certificate matches with the one we are connected to. + * + *

While client always performs step 1 server performs step 1 only when client auth is enabled. + * + *

Step 2 is performed when msjava.msnet.ssl.verification.enabledHostnameVerification is set tu + * true + */ +public class CertificateVerifier { + + private static final String VERIFY_HOSTNAME_PARAMETER = "msjava.msnet.ssl.verify_hostnames"; + private static final boolean DEFAULT_VERIFY_HOSTNAMES = + SystemPropertyUtils.getBoolean(VERIFY_HOSTNAME_PARAMETER, true, ContextLogger.safeLogger()); + + private static FallbackHostNameVerifier fallbackHostNameVerifier = (hostName, certs) -> false; + + private final boolean verifyHostnames; + private final SSLEngine sslEngine; + + public CertificateVerifier(SSLEngine sslEngine) { + this.verifyHostnames = DEFAULT_VERIFY_HOSTNAMES; + this.sslEngine = sslEngine; + } + + CertificateVerifier(boolean verifyHostnames, SSLEngine sslEngine) { + this.verifyHostnames = verifyHostnames; + this.sslEngine = sslEngine; + } + + public static void setFallbackHostNameVerifier( + FallbackHostNameVerifier fallbackHostNameVerifier) { + CertificateVerifier.fallbackHostNameVerifier = fallbackHostNameVerifier; + } + + public boolean verify(String hostName, boolean encryptionOnly) + throws SSLException, javax.security.cert.CertificateException, CertificateException { + if (isServerWithClientAuthDisabled(encryptionOnly)) { + return true; + } + Certificate[] certs = sslEngine.getSession().getPeerCertificates(); + + boolean expirationValidationResult = CertificateExpirationValidator.validate(sslEngine, certs); + boolean validateHostnameResult = + validateHostname(hostName, certs, !sslEngine.getUseClientMode()); + + return expirationValidationResult && validateHostnameResult; + } + + private boolean isServerWithClientAuthDisabled(boolean encryptionOnly) { + boolean isServer = !sslEngine.getUseClientMode(); + boolean clientAuthDisabled = !sslEngine.getNeedClientAuth(); + + return isServer && (clientAuthDisabled || encryptionOnly); + } + + private boolean validateHostname(String hostName, Certificate[] certs, boolean isServer) + throws CertificateException { + if (!verifyHostnames) { + return true; + } + + if (isServer) { + return true; + } + + if (!HostNameVerifier.verifyHostname(hostName, certs)) { + return fallbackHostNameVerifier.validatePeerCertificate(hostName, certs); + } + + return true; + } +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/FallbackHostNameVerifier.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/FallbackHostNameVerifier.java new file mode 100644 index 0000000..21e53fa --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/FallbackHostNameVerifier.java @@ -0,0 +1,20 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl.verification; + +import java.security.cert.Certificate; +import java.security.cert.CertificateException; + +public interface FallbackHostNameVerifier { + boolean validatePeerCertificate(String hostName, Certificate[] certs) throws CertificateException; +} diff --git a/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/HostNameVerifier.java b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/HostNameVerifier.java new file mode 100644 index 0000000..6ad964d --- /dev/null +++ b/optimus/platform/projects/msnet-ssl/src/main/java/msjava/msnet/ssl/verification/HostNameVerifier.java @@ -0,0 +1,135 @@ +/* + * Morgan Stanley makes this available to you under the Apache License, Version 2.0 (the "License"). + * You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. + * See the NOTICE file distributed with this work for additional information regarding copyright ownership. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msjava.msnet.ssl.verification; + +import java.io.ByteArrayInputStream; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import javax.naming.InvalidNameException; +import javax.naming.ldap.LdapName; +import javax.naming.ldap.Rdn; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Compares provided hostname with SAN or CN names from the certificates. */ +class HostNameVerifier { + private static final Logger LOGGER = LoggerFactory.getLogger(CertificateVerifier.class); + + static boolean verifyHostname(String hostName, Certificate[] certs) throws CertificateException { + Set allowedHostNames = getAllowedServerNames(convertToJavaX509Certificates(certs)); + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "Comparing received hostname {} with certificate host names {}", + hostName, + allowedHostNames); + } + + boolean hostNameMatches = + allowedHostNames.contains(hostName) + || (!hostName.endsWith(".") && allowedHostNames.contains(hostName + ".")); + if (!hostNameMatches) { + LOGGER.warn("The host name '{}' is not a legit host name of a server.", hostName); + } + + return hostNameMatches; + } + + public static X509Certificate[] convertToJavaX509Certificates(Certificate[] certs) + throws CertificateException { + CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); + X509Certificate[] convertedCerts = new X509Certificate[certs.length]; + for (int i = 0; i < certs.length; i++) { + convertedCerts[i] = convertCert(certFactory, certs[i]); + } + return convertedCerts; + } + + private static X509Certificate convertCert( + CertificateFactory certificateFactory, Certificate cert) throws CertificateException { + try { + ByteArrayInputStream stream = new ByteArrayInputStream(cert.getEncoded()); + return (X509Certificate) certificateFactory.generateCertificate(stream); + } catch (CertificateException e) { + throw new CertificateException("Could not convert cert to X509Certificate.", e); + } + } + + private static Set getAllowedServerNames(X509Certificate[] certs) + throws CertificateParsingException { + Set allowedSubjectAlternativeNames = new HashSet<>(); + for (X509Certificate cert : certs) { + allowedSubjectAlternativeNames.addAll(getSubjectAlternativeNames(cert)); + getCommonName(cert).ifPresent(allowedSubjectAlternativeNames::add); + } + return allowedSubjectAlternativeNames; + } + + private static Optional getCommonName(X509Certificate cert) { + /* + * The original code below contain `sun.security.x509.X509Name`, which can't be used as import for Java 11+. + * Packages sun.* hold internal stuff, and should not be used by thirdparty applications in general case. + * Since java 9 the module system has been introduced, java 9+ would "protect" these packages even in compile time. + * + * try { + * String commonName = new X500Name(cert.getSubjectX500Principal().getName()).getCommonName(); + * return Optional.of(commonName); + * } catch (IOException e) { + * // if anything happen when extra the common name from certs, we just ignore that + * } + * + * Therefore, alternative solution by 'LdapName' should be used to get common name in roughly RFC 1779 DN + * dn = "CN=commonName, OU=organizationUnit, ON=organizationName, LN=localityName, SN=stateName, ...."; + */ + try { + String dn = cert.getSubjectX500Principal().getName(); + LdapName ln = new LdapName(dn); + String commonName = ""; + for (Rdn rdn : ln.getRdns()) { + if (rdn.getType().equalsIgnoreCase("CN")) { + commonName = rdn.getValue().toString(); + break; + } + } + return Optional.of(commonName); + } catch (InvalidNameException e) { + // if anything happen when extra the common name from certs, we just ignore that + } + + return Optional.empty(); + } + + private static List getSubjectAlternativeNames(X509Certificate cert) + throws CertificateParsingException { + // this list contains host names but also some unexpected numbers + Collection> entries = + Optional.ofNullable(cert.getSubjectAlternativeNames()).orElse(Collections.emptyList()); + + return entries.stream() + .flatMap(Collection::stream) + .filter(entry -> entry instanceof String) + .map(entry -> (String) entry) + .collect(Collectors.toList()); + } +} diff --git a/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/BuilderProvider.java b/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/BuilderProvider.java index ad28842..59cadd3 100644 --- a/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/BuilderProvider.java +++ b/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/BuilderProvider.java @@ -15,13 +15,12 @@ import scala.collection.mutable.Builder; /** - * Scala's TraversableLike.newBuilder is private[this], which causes problems in AsyncBase, where - * we need to create new instances of wrapped collection classes. Unfortunately, Iterable.companion + * Scala's TraversableLike.newBuilder is private[this], which causes problems in AsyncBase, where we + * need to create new instances of wrapped collection classes. Unfortunately, Iterable.companion * returns companion objects that are degenerate, e.g. we get the same builder for every variety of * Set. * - * This workaround exposes newBuilder via Java, which does not support instance-level privacy. - * + *

This workaround exposes newBuilder via Java, which does not support instance-level privacy. */ public abstract class BuilderProvider implements HasNewBuilder { public static Builder exposedBuilder(HasNewBuilder coll) { diff --git a/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/CanEqual.java b/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/CanEqual.java index 490a656..0620032 100644 --- a/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/CanEqual.java +++ b/optimus/platform/projects/scala_compat/src/main/scala-2.12/optimus/scalacompat/collection/CanEqual.java @@ -15,28 +15,27 @@ public abstract class CanEqual { public static int knownSize(scala.collection.GenTraversableOnce p1) { return p1.sizeHintIfCheap(); } - public static boolean canEqual(scala.collection.GenTraversableOnce p1, scala.collection.GenTraversableOnce p2) { - if (p1 == p2) - return true; - if (p1 == null || p2 == null) - return false; + + public static boolean canEqual( + scala.collection.GenTraversableOnce p1, scala.collection.GenTraversableOnce p2) { + if (p1 == p2) return true; + if (p1 == null || p2 == null) return false; int size1 = p1.sizeHintIfCheap(); int size2 = p2.sizeHintIfCheap(); return size1 == -1 || size2 == -1 || size1 == size2; } public static boolean canEqual(String p1, String p2) { - if (p1 == p2) - return true; - if (p1 == null || p2 == null) - return false; + if (p1 == p2) return true; + if (p1 == null || p2 == null) return false; return p1.length() == p2.length() && p1.hashCode() == p2.hashCode(); } public static boolean canEqual(Enum p1, Enum p2) { return p1 == p2; } + public static boolean canEqual(Object p1, Object p2) { return ((p1 == null) == (p2 == null)); } -} \ No newline at end of file +} diff --git a/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/BuilderProvider.java b/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/BuilderProvider.java index 3f1cb85..6ce1f77 100644 --- a/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/BuilderProvider.java +++ b/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/BuilderProvider.java @@ -15,18 +15,18 @@ import scala.collection.mutable.Builder; /** - * Scala's IterableOps.newSpecificBuilder is protected, which causes problems in AsyncBase, where - * we need to create new instances of wrapped collection classes. Unfortunately, Iterable.companion + * Scala's IterableOps.newSpecificBuilder is protected, which causes problems in AsyncBase, where we + * need to create new instances of wrapped collection classes. Unfortunately, Iterable.companion * returns companion objects that are degenerate, e.g. we get the same builder for every variety of * Set. * - * This workaround exposes newSpecificBuilder via Java, which does not support Scala protected - * + *

This workaround exposes newSpecificBuilder via Java, which does not support Scala protected */ public abstract class BuilderProvider { public static Builder exposedBuilder(Object coll) { return ((IterableOps) coll).newSpecificBuilder(); } + public static String stringPrefix(Object coll) { return ((scala.collection.Iterable) coll).stringPrefix(); } diff --git a/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/CanEqual.java b/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/CanEqual.java index 25a0c99..d1dcc35 100644 --- a/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/CanEqual.java +++ b/optimus/platform/projects/scala_compat/src/main/scala-2.13/optimus/scalacompat/collection/CanEqual.java @@ -12,24 +12,21 @@ package optimus.scalacompat.collection; public abstract class CanEqual { - public static boolean canEqual(scala.collection.IterableOnce p1, scala.collection.IterableOnce p2) { - if (p1 == p2) - return true; - if (p1 == null || p2 == null) - return false; + public static boolean canEqual( + scala.collection.IterableOnce p1, scala.collection.IterableOnce p2) { + if (p1 == p2) return true; + if (p1 == null || p2 == null) return false; int size1 = p1.knownSize(); int size2 = p2.knownSize(); return size1 == -1 || size2 == -1 || size1 == size2; } public static boolean canEqual(String p1, String p2) { - if (p1 == p2) - return true; - if (p1 == null || p2 == null) - return false; + if (p1 == p2) return true; + if (p1 == null || p2 == null) return false; return p1.length() == p2.length() && p1.hashCode() == p2.hashCode(); } - //TODO Stable, OptimusCollections, things that we patch a hashcode into etc + // TODO Stable, OptimusCollections, things that we patch a hashcode into etc // public static boolean canEqual(Stable p1, Stable p2) { // if (p1 == p2) // return true; @@ -40,6 +37,7 @@ public static boolean canEqual(String p1, String p2) { public static boolean canEqual(Enum p1, Enum p2) { return p1 == p2; } + public static boolean canEqual(Object p1, Object p2) { return ((p1 == null) == (p2 == null)); } diff --git a/optimus/platform/projects/stagingplugin/src/main/scala/optimus/platform/annotations/internal/EntityMetaDataAnnotation.java b/optimus/platform/projects/stagingplugin/src/main/scala/optimus/platform/annotations/internal/EntityMetaDataAnnotation.java index a134167..7b641ef 100644 --- a/optimus/platform/projects/stagingplugin/src/main/scala/optimus/platform/annotations/internal/EntityMetaDataAnnotation.java +++ b/optimus/platform/projects/stagingplugin/src/main/scala/optimus/platform/annotations/internal/EntityMetaDataAnnotation.java @@ -17,7 +17,7 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; -@Target({ ElementType.TYPE }) +@Target({ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface EntityMetaDataAnnotation { @@ -35,8 +35,10 @@ // these fields are in line with the attribute names // if we add/modify fields in this annotation then we need to keep this is step - // additionally the code that generates this annotation in OptimusNames, AdjustAst should be kept in line - // this is also parsed in code, so MetaDataReader and EntityHierarchyManager will need corresponding changes + // additionally the code that generates this annotation in OptimusNames, AdjustAst should be kept + // in line + // this is also parsed in code, so MetaDataReader and EntityHierarchyManager will need + // corresponding changes public static String name_slotNumber = "slotNumber"; public static String name_explicitSlotNumber = "explicitSlotNumber"; diff --git a/optimus/platform/projects/utils/src/main/java/optimus/utils/datetime/ZonedDateTimeOps.java b/optimus/platform/projects/utils/src/main/java/optimus/utils/datetime/ZonedDateTimeOps.java index 0eac241..6375cef 100644 --- a/optimus/platform/projects/utils/src/main/java/optimus/utils/datetime/ZonedDateTimeOps.java +++ b/optimus/platform/projects/utils/src/main/java/optimus/utils/datetime/ZonedDateTimeOps.java @@ -20,13 +20,16 @@ public final class ZonedDateTimeOps { // Disallow instantiation private ZonedDateTimeOps() {} - // JSR310 accepted short ids for timezones (see java.time.ZoneId.SHORT_IDS), but java.time does not. + // JSR310 accepted short ids for timezones (see java.time.ZoneId.SHORT_IDS), but java.time does + // not. // - // Additionally in Java 8 there is a difference in parsing behaviour with datetime strings that contain + // Additionally in Java 8 there is a difference in parsing behaviour with datetime strings that + // contain // offset and timezone, made a bit worse by JSR310 insisting on a timezone being specified. // JSR310 looks at the zone and then applies an adjustment based on the offset provided: // - // scala> javax.time.calendar.ZonedDateTime.parse("2016-11-10T23:59:59.210-04:00[US/Eastern]").toInstant + // scala> + // javax.time.calendar.ZonedDateTime.parse("2016-11-10T23:59:59.210-04:00[US/Eastern]").toInstant // res0: javax.time.Instant = 2016-11-11T03:59:59.210Z // ^^ // scala> java.time.ZonedDateTime.parse("2016-11-10T23:59:59.210-04:00[US/Eastern]").toInstant @@ -35,19 +38,22 @@ private ZonedDateTimeOps() {} // // The java.time offset bug was fixed in Java 9 // (see https://stackoverflow.com/questions/56255020/zoneddatetime-change-behavior-jdk-8-11), - // but the short id issue remains, and because ZonedDateTime.parse is being used in places where the strings come + // but the short id issue remains, and because ZonedDateTime.parse is being used in places where + // the strings come // externally, we will need to preserve the JSR310 behaviour. public static ZonedDateTime parseTreatingOffsetAndZoneIdLikeJSR310(CharSequence chars) { try { return ZonedDateTime.parse(chars); } catch (DateTimeParseException e) { - // See if it failed because of use of short ids in the string. To be compatible with javax.time + // See if it failed because of use of short ids in the string. To be compatible with + // javax.time // we have to support them by substituting them in the string to be parsed. String fixed = expandShortId(chars.toString()); if (fixed != null) { return ZonedDateTime.parse(fixed); } else { - // Not the format we aim to deal with here so throw original exception. JSR310 would have done the same. + // Not the format we aim to deal with here so throw original exception. JSR310 would have + // done the same. throw e; } } @@ -65,6 +71,8 @@ private static String expandShortId(String str) { String compatibleId = ZoneId.SHORT_IDS.getOrDefault(zoneId, null); if (compatibleId != null) { return str.replace(zoneId, compatibleId); - } else { return null; } + } else { + return null; + } } } diff --git a/optimus/platform/projects/utils/src/main/java/optimus/utils/misc/Color.java b/optimus/platform/projects/utils/src/main/java/optimus/utils/misc/Color.java index ea263eb..26c674b 100644 --- a/optimus/platform/projects/utils/src/main/java/optimus/utils/misc/Color.java +++ b/optimus/platform/projects/utils/src/main/java/optimus/utils/misc/Color.java @@ -14,32 +14,32 @@ // Color definitions consistent with awt public class Color { - public final static Color white = new Color(255, 255, 255); - public final static Color WHITE = white; - public final static Color lightGray = new Color(192, 192, 192); - public final static Color LIGHT_GRAY = lightGray; - public final static Color gray = new Color(128, 128, 128); - public final static Color GRAY = gray; - public final static Color darkGray = new Color(64, 64, 64); - public final static Color DARK_GRAY = darkGray; - public final static Color black = new Color(0, 0, 0); - public final static Color BLACK = black; - public final static Color red = new Color(255, 0, 0); - public final static Color RED = red; - public final static Color pink = new Color(255, 175, 175); - public final static Color PINK = pink; - public final static Color orange = new Color(255, 200, 0); - public final static Color ORANGE = orange; - public final static Color yellow = new Color(255, 255, 0); - public final static Color YELLOW = yellow; - public final static Color green = new Color(0, 255, 0); - public final static Color GREEN = green; - public final static Color magenta = new Color(255, 0, 255); - public final static Color MAGENTA = magenta; - public final static Color cyan = new Color(0, 255, 255); - public final static Color CYAN = cyan; - public final static Color blue = new Color(0, 0, 255); - public final static Color BLUE = blue; + public static final Color white = new Color(255, 255, 255); + public static final Color WHITE = white; + public static final Color lightGray = new Color(192, 192, 192); + public static final Color LIGHT_GRAY = lightGray; + public static final Color gray = new Color(128, 128, 128); + public static final Color GRAY = gray; + public static final Color darkGray = new Color(64, 64, 64); + public static final Color DARK_GRAY = darkGray; + public static final Color black = new Color(0, 0, 0); + public static final Color BLACK = black; + public static final Color red = new Color(255, 0, 0); + public static final Color RED = red; + public static final Color pink = new Color(255, 175, 175); + public static final Color PINK = pink; + public static final Color orange = new Color(255, 200, 0); + public static final Color ORANGE = orange; + public static final Color yellow = new Color(255, 255, 0); + public static final Color YELLOW = yellow; + public static final Color green = new Color(0, 255, 0); + public static final Color GREEN = green; + public static final Color magenta = new Color(255, 0, 255); + public static final Color MAGENTA = magenta; + public static final Color cyan = new Color(0, 255, 255); + public static final Color CYAN = cyan; + public static final Color blue = new Color(0, 0, 255); + public static final Color BLUE = blue; private int value; private static void testColorValueRange(int r, int g, int b, int a) { @@ -63,7 +63,8 @@ private static void testColorValueRange(int r, int g, int b, int a) { badComponentString = badComponentString + " Blue"; } if (rangeError) { - throw new IllegalArgumentException("Color parameter outside of expected range:" + badComponentString); + throw new IllegalArgumentException( + "Color parameter outside of expected range:" + badComponentString); } } @@ -84,5 +85,4 @@ private static int getRGB(int r, int g, int b, int a) { testColorValueRange(r, g, b, a); return ((a & 0xFF) << 24) | ((r & 0xFF) << 16) | ((g & 0xFF) << 8) | ((b & 0xFF)); } - } diff --git a/optimus/platform/projects/utils/src/main/java/org/apache/zookeeper/OptimusClientCnxnSocketNetty.java b/optimus/platform/projects/utils/src/main/java/org/apache/zookeeper/OptimusClientCnxnSocketNetty.java index 615d336..645b19a 100644 --- a/optimus/platform/projects/utils/src/main/java/org/apache/zookeeper/OptimusClientCnxnSocketNetty.java +++ b/optimus/platform/projects/utils/src/main/java/org/apache/zookeeper/OptimusClientCnxnSocketNetty.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -16,7 +16,8 @@ * limitations under the License. */ -// N.B. This is adapted from ZK 3.5.5, we need to redo the patch after move to a later version of ZK!!! +// N.B. This is adapted from ZK 3.5.5, we need to redo the patch after move to a later version of +// ZK!!! package org.apache.zookeeper; @@ -66,9 +67,9 @@ import static org.apache.zookeeper.common.X509Exception.SSLContextException; /** - * ClientCnxnSocketNetty implements ClientCnxnSocket abstract methods. - * It's responsible for connecting to server, reading/writing network traffic and - * being a layer between network data and higher level packets. + * ClientCnxnSocketNetty implements ClientCnxnSocket abstract methods. It's responsible for + * connecting to server, reading/writing network traffic and being a layer between network data and + * higher level packets. */ public class OptimusClientCnxnSocketNetty extends ClientCnxnSocket { @@ -76,7 +77,9 @@ public class OptimusClientCnxnSocketNetty extends ClientCnxnSocket { static { if (!(Info.MAJOR == 3 && Info.MINOR == 5 && Info.MICRO == 5)) { - LOG.error("The Excepted ZK version is 3.5.5, but we have {} now, need to upgrade the OptimusClientCnxnSocketNetty class to fit the ZK version", Version.getVersion()); + LOG.error( + "The Excepted ZK version is 3.5.5, but we have {} now, need to upgrade the OptimusClientCnxnSocketNetty class to fit the ZK version", + Version.getVersion()); System.exit(255); } } @@ -104,7 +107,7 @@ public class OptimusClientCnxnSocketNetty extends ClientCnxnSocket { initProperties(); } - /** + /* * lifecycles diagram: *

* loop: @@ -145,67 +148,70 @@ private Bootstrap configureBootstrapAllocator(Bootstrap bootstrap) { void connect(InetSocketAddress addr) { firstConnect = new CountDownLatch(1); - Bootstrap bootstrap = new Bootstrap() - .group(eventLoopGroup) - .channel(NettyUtils.nioOrEpollSocketChannel()) - .option(ChannelOption.SO_LINGER, -1) - .option(ChannelOption.TCP_NODELAY, true) - .handler(new ZKClientPipelineFactory(addr.getHostString(), addr.getPort())); + Bootstrap bootstrap = + new Bootstrap() + .group(eventLoopGroup) + .channel(NettyUtils.nioOrEpollSocketChannel()) + .option(ChannelOption.SO_LINGER, -1) + .option(ChannelOption.TCP_NODELAY, true) + .handler(new ZKClientPipelineFactory(addr.getHostString(), addr.getPort())); bootstrap = configureBootstrapAllocator(bootstrap); bootstrap.validate(); connectLock.lock(); try { connectFuture = bootstrap.connect(addr); - connectFuture.addListener((ChannelFutureListener) channelFuture -> { - // this lock guarantees that channel won't be assigned after cleanup(). - boolean connected = false; - connectLock.lock(); - try { - if (!channelFuture.isSuccess()) { - LOG.info("future isn't success, cause:", channelFuture.cause()); - return; - } else if (connectFuture == null) { - LOG.info("connect attempt cancelled"); - // If the connect attempt was cancelled but succeeded - // anyway, make sure to close the channel, otherwise - // we may leak a file descriptor. - channelFuture.channel().close(); - return; - } - // setup channel, variables, connection, etc. - channel = channelFuture.channel(); - - disconnected.set(false); - initialized = false; - lenBuffer.clear(); - incomingBuffer = lenBuffer; - - sendThread.primeConnection(); - updateNow(); - updateLastSendAndHeard(); - - if (sendThread.tunnelAuthInProgress()) { - waitSasl.drainPermits(); - needSasl.set(true); - sendPrimePacket(); - } else { - needSasl.set(false); - } - connected = true; - } finally { - connectFuture = null; - connectLock.unlock(); - if (connected) { - LOG.info("channel is connected: {}", channelFuture.channel()); - } - // need to wake on connect success or failure to avoid - // timing out ClientCnxn.SendThread which may be - // blocked waiting for first connect in doTransport(). - wakeupCnxn(); - firstConnect.countDown(); - } - }); + connectFuture.addListener( + (ChannelFutureListener) + channelFuture -> { + // this lock guarantees that channel won't be assigned after cleanup(). + boolean connected = false; + connectLock.lock(); + try { + if (!channelFuture.isSuccess()) { + LOG.info("future isn't success, cause:", channelFuture.cause()); + return; + } else if (connectFuture == null) { + LOG.info("connect attempt cancelled"); + // If the connect attempt was cancelled but succeeded + // anyway, make sure to close the channel, otherwise + // we may leak a file descriptor. + channelFuture.channel().close(); + return; + } + // setup channel, variables, connection, etc. + channel = channelFuture.channel(); + + disconnected.set(false); + initialized = false; + lenBuffer.clear(); + incomingBuffer = lenBuffer; + + sendThread.primeConnection(); + updateNow(); + updateLastSendAndHeard(); + + if (sendThread.tunnelAuthInProgress()) { + waitSasl.drainPermits(); + needSasl.set(true); + sendPrimePacket(); + } else { + needSasl.set(false); + } + connected = true; + } finally { + connectFuture = null; + connectLock.unlock(); + if (connected) { + LOG.info("channel is connected: {}", channelFuture.channel()); + } + // need to wake on connect success or failure to avoid + // timing out ClientCnxn.SendThread which may be + // blocked waiting for first connect in doTransport(). + wakeupCnxn(); + firstConnect.countDown(); + } + }); } finally { connectLock.unlock(); } @@ -252,8 +258,7 @@ void saslCompleted() { } @Override - void connectionPrimed() { - } + void connectionPrimed() {} @Override void packetAdded() { @@ -285,9 +290,7 @@ private void wakeupCnxn() { } @Override - void doTransport(int waitTimeOut, - List pendingQueue, - ClientCnxn cnxn) + void doTransport(int waitTimeOut, List pendingQueue, ClientCnxn cnxn) throws IOException, InterruptedException { try { if (!firstConnect.await(waitTimeOut, TimeUnit.MILLISECONDS)) { @@ -310,9 +313,8 @@ void doTransport(int waitTimeOut, // channel disconnection happened if (disconnected.get()) { addBack(head); - throw new EndOfStreamException("channel for sessionid 0x" - + Long.toHexString(sessionId) - + " is lost"); + throw new EndOfStreamException( + "channel for sessionid 0x" + Long.toHexString(sessionId) + " is lost"); } if (head != null) { doWrite(pendingQueue, head, cnxn); @@ -330,9 +332,9 @@ private void addBack(Packet head) { /** * Sends a packet to the remote peer and flushes the channel. + * * @param p packet to send. - * @return a ChannelFuture that will complete when the write operation - * succeeds or fails. + * @return a ChannelFuture that will complete when the write operation succeeds or fails. */ private ChannelFuture sendPktAndFlush(Packet p) { return sendPkt(p, true); @@ -340,20 +342,21 @@ private ChannelFuture sendPktAndFlush(Packet p) { /** * Sends a packet to the remote peer but does not flush() the channel. + * * @param p packet to send. - * @return a ChannelFuture that will complete when the write operation - * succeeds or fails. + * @return a ChannelFuture that will complete when the write operation succeeds or fails. */ private ChannelFuture sendPktOnly(Packet p) { return sendPkt(p, false); } // Use a single listener instance to reduce GC - private final GenericFutureListener> onSendPktDoneListener = f -> { - if (f.isSuccess()) { - sentCount.getAndIncrement(); - } - }; + private final GenericFutureListener> onSendPktDoneListener = + f -> { + if (f.isSuccess()) { + sentCount.getAndIncrement(); + } + }; private ChannelFuture sendPkt(Packet p, boolean doFlush) { // Assuming the packet will be sent out successfully. Because if it fails, @@ -361,9 +364,8 @@ private ChannelFuture sendPkt(Packet p, boolean doFlush) { p.createBB(); updateLastSend(); final ByteBuf writeBuffer = Unpooled.wrappedBuffer(p.bb); - final ChannelFuture result = doFlush - ? channel.writeAndFlush(writeBuffer) - : channel.write(writeBuffer); + final ChannelFuture result = + doFlush ? channel.writeAndFlush(writeBuffer) : channel.write(writeBuffer); result.addListener(onSendPktDoneListener); return result; } @@ -373,17 +375,15 @@ private void sendPrimePacket() { sendPktAndFlush(outgoingQueue.remove()); } - /** - * doWrite handles writing the packets from outgoingQueue via network to server. - */ + /** doWrite handles writing the packets from outgoingQueue via network to server. */ private void doWrite(List pendingQueue, Packet p, ClientCnxn cnxn) { updateNow(); boolean anyPacketsSent = false; while (true) { if (p != WakeupPacket.getInstance()) { - if ((p.requestHeader != null) && - (p.requestHeader.getType() != ZooDefs.OpCode.ping) && - (p.requestHeader.getType() != ZooDefs.OpCode.auth)) { + if ((p.requestHeader != null) + && (p.requestHeader.getType() != ZooDefs.OpCode.ping) + && (p.requestHeader.getType() != ZooDefs.OpCode.auth)) { p.requestHeader.setXid(cnxn.getXid()); synchronized (pendingQueue) { pendingQueue.add(p); @@ -445,8 +445,7 @@ public static Packet getInstance() { } /** - * ZKClientPipelineFactory is the netty pipeline factory for this netty - * connection implementation. + * ZKClientPipelineFactory is the netty pipeline factory for this netty connection implementation. */ private class ZKClientPipelineFactory extends ChannelInitializer { private SSLContext sslContext = null; @@ -487,8 +486,8 @@ private synchronized void initSSL(ChannelPipeline pipeline) throws SSLContextExc } /** - * ZKClientHandler is the netty handler that sits in netty upstream last - * place. It mainly handles read traffic and helps synchronize connection state. + * ZKClientHandler is the netty handler that sits in netty upstream last place. It mainly handles + * read traffic and helps synchronize connection state. */ private class ZKClientHandler extends SimpleChannelInboundHandler { AtomicBoolean channelClosed = new AtomicBoolean(false); @@ -500,8 +499,8 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { } /** - * netty handler has encountered problems. We are cleaning it up and tell outside to close - * the channel/connection. + * netty handler has encountered problems. We are cleaning it up and tell outside to close the + * channel/connection. */ private void cleanup() { if (!channelClosed.compareAndSet(false, true)) { @@ -516,8 +515,7 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Excep updateNow(); while (buf.isReadable()) { if (incomingBuffer.remaining() > buf.readableBytes()) { - int newLimit = incomingBuffer.position() - + buf.readableBytes(); + int newLimit = incomingBuffer.position() + buf.readableBytes(); incomingBuffer.limit(newLimit); } buf.readBytes(incomingBuffer); @@ -555,20 +553,18 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } /** - * Sets the test ByteBufAllocator. This allocator will be used by all - * future instances of this class. - * It is not recommended to use this method outside of testing. - * @param allocator the ByteBufAllocator to use for all netty buffer - * allocations. + * Sets the test ByteBufAllocator. This allocator will be used by all future instances of this + * class. It is not recommended to use this method outside of testing. + * + * @param allocator the ByteBufAllocator to use for all netty buffer allocations. */ static void setTestAllocator(ByteBufAllocator allocator) { TEST_ALLOCATOR.set(allocator); } /** - * Clears the test ByteBufAllocator. The default allocator will be used - * by all future instances of this class. - * It is not recommended to use this method outside of testing. + * Clears the test ByteBufAllocator. The default allocator will be used by all future instances of + * this class. It is not recommended to use this method outside of testing. */ static void clearTestAllocator() { TEST_ALLOCATOR.set(null); diff --git a/optimus/platform/projects/utils/src/main/java/patch/MilliInstant.java b/optimus/platform/projects/utils/src/main/java/patch/MilliInstant.java index c3c9de9..fa496d0 100644 --- a/optimus/platform/projects/utils/src/main/java/patch/MilliInstant.java +++ b/optimus/platform/projects/utils/src/main/java/patch/MilliInstant.java @@ -15,18 +15,15 @@ import java.time.Instant; public class MilliInstant { - // TODO (OPTIMUS-33822): Delete all this when we know that > ms precision in Instant will not break the DAL. + // TODO (OPTIMUS-33822): Delete all this when we know that > ms precision in Instant will not + // break the DAL. - /** - Like java.time.now() but with precision limited to milliseconds - */ + /** Like java.time.now() but with precision limited to milliseconds */ public static Instant now() { return Instant.ofEpochMilli(Clock.systemUTC().millis()); } - /** - * Like java.time.now(Clock), but with precision limited to milliseconds. - */ + /** Like java.time.now(Clock), but with precision limited to milliseconds. */ public static Instant now(Clock clock) { return Instant.ofEpochMilli(clock.millis()); } diff --git a/optimus/platform/projects/utils/src/main/java/patch/MilliLocalDateTime.java b/optimus/platform/projects/utils/src/main/java/patch/MilliLocalDateTime.java index 6d87c82..b877251 100644 --- a/optimus/platform/projects/utils/src/main/java/patch/MilliLocalDateTime.java +++ b/optimus/platform/projects/utils/src/main/java/patch/MilliLocalDateTime.java @@ -16,14 +16,18 @@ import java.time.ZoneId; import java.time.temporal.ChronoUnit; -/** @see MilliInstant */ +/** + * @see MilliInstant + */ public class MilliLocalDateTime { public static LocalDateTime now() { return LocalDateTime.now().truncatedTo(ChronoUnit.MILLIS); } + public static LocalDateTime now(Clock clock) { return LocalDateTime.now(clock).truncatedTo(ChronoUnit.MILLIS); } + public static LocalDateTime now(ZoneId zone) { return LocalDateTime.now(zone).truncatedTo(ChronoUnit.MILLIS); } diff --git a/optimus/platform/projects/utils/src/main/java/patch/MilliLocalTime.java b/optimus/platform/projects/utils/src/main/java/patch/MilliLocalTime.java index 586311e..03712bf 100644 --- a/optimus/platform/projects/utils/src/main/java/patch/MilliLocalTime.java +++ b/optimus/platform/projects/utils/src/main/java/patch/MilliLocalTime.java @@ -14,7 +14,9 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; -/** @see MilliInstant */ +/** + * @see MilliInstant + */ public class MilliLocalTime { public static LocalTime now() { return LocalTime.now().truncatedTo(ChronoUnit.MILLIS); diff --git a/optimus/platform/projects/utils/src/main/java/patch/MilliZonedDateTime.java b/optimus/platform/projects/utils/src/main/java/patch/MilliZonedDateTime.java index 0215022..c7f0ef7 100644 --- a/optimus/platform/projects/utils/src/main/java/patch/MilliZonedDateTime.java +++ b/optimus/platform/projects/utils/src/main/java/patch/MilliZonedDateTime.java @@ -16,14 +16,18 @@ import java.time.ZonedDateTime; import java.time.temporal.ChronoUnit; -/** @see MilliInstant */ +/** + * @see MilliInstant + */ public class MilliZonedDateTime { public static ZonedDateTime now() { return ZonedDateTime.now().truncatedTo(ChronoUnit.MILLIS); } + public static ZonedDateTime now(Clock clock) { return ZonedDateTime.now(clock).truncatedTo(ChronoUnit.MILLIS); } + public static ZonedDateTime now(ZoneId zone) { return ZonedDateTime.now(zone).truncatedTo(ChronoUnit.MILLIS); } diff --git a/optimus/platform/projects/utils/src/main/scala/optimus/utils/ErrorIgnoringFileVisitor.scala b/optimus/platform/projects/utils/src/main/scala/optimus/utils/ErrorIgnoringFileVisitor.scala index 591670a..450f061 100644 --- a/optimus/platform/projects/utils/src/main/scala/optimus/utils/ErrorIgnoringFileVisitor.scala +++ b/optimus/platform/projects/utils/src/main/scala/optimus/utils/ErrorIgnoringFileVisitor.scala @@ -25,7 +25,6 @@ abstract class ErrorIgnoringFileVisitor extends SimpleFileVisitor[Path] { * longer there we probably don't need to do anything about them */ override def visitFileFailed(file: Path, exc: IOException): FileVisitResult = { - ErrorIgnoringFileVisitor.log.debug(s"Unable to visit path: $file (probably it was deleted while we were scanning)") FileVisitResult.CONTINUE } } diff --git a/optimus/platform/projects/utils/src/main/scala/optimus/utils/datetime/DateTimeStorable.java b/optimus/platform/projects/utils/src/main/scala/optimus/utils/datetime/DateTimeStorable.java index 4ac6881..f4afd96 100644 --- a/optimus/platform/projects/utils/src/main/scala/optimus/utils/datetime/DateTimeStorable.java +++ b/optimus/platform/projects/utils/src/main/scala/optimus/utils/datetime/DateTimeStorable.java @@ -16,101 +16,94 @@ import java.time.temporal.JulianFields; /** - *

- * DateTimeStorable provides static utility functions to convert {@link LocalDate} - * and {@link LocalDateTime} to binary format. - *

- * The binary format for LocalDate is the Modified Julian Day Number, stored in - * the last 24 bits (3 bytes) of an int. This gives a date range of 1858/11/17 - * to 47793/05/02. - *

- * The binary format for LocalDateTime stores the date part in the first 24 bits - * (3 bytes) and the time offset in the day in the last 40 bits of a long. The time - * part is stored in microsecond (1E-06) precision. This ordering of date and time - * allows for direct comparison and sorting of the binary values. Microsecond precision - * needs 37 bits, so 3 bits are unused. - *

+ * DateTimeStorable provides static utility functions to convert {@link LocalDate} and {@link + * LocalDateTime} to binary format. * + *

The binary format for LocalDate is the Modified Julian Day Number, stored in the last 24 bits + * (3 bytes) of an int. This gives a date range of 1858/11/17 to 47793/05/02. + * + *

The binary format for LocalDateTime stores the date part in the first 24 bits (3 bytes) and + * the time offset in the day in the last 40 bits of a long. The time part is stored in microsecond + * (1E-06) precision. This ordering of date and time allows for direct comparison and sorting of the + * binary values. Microsecond precision needs 37 bits, so 3 bits are unused. */ public class DateTimeStorable { - private static int maxMJD = (int)Math.pow( 2, 24 ) - 1; // 16777215 - private static long maxMicros = 86399999999L; // 24*60*60*1000*1000 - 1 - - /** - * Convert a {@link LocalDate} to a binary representation. - * - * @param date - * @return binary - * @throws IllegalArgumentException - */ - public static int storeDate( LocalDate date ) throws IllegalArgumentException { - if( date == null ) { - throw new IllegalArgumentException( String.format( "Need LocalDateTime, not null" ) ); - } - long mjd = date.getLong(JulianFields.MODIFIED_JULIAN_DAY); - if( mjd > maxMJD || mjd < 0 ) { - throw new IllegalArgumentException( String.format( - "Cannot store %d as julian days, out of range.", mjd ) ); - } - return (int) mjd; - } + private static int maxMJD = (int) Math.pow(2, 24) - 1; // 16777215 + private static long maxMicros = 86399999999L; // 24*60*60*1000*1000 - 1 - /** - * Convert binary representation to a {@link LocalDate}. - * - * @param binary - * @return {@link LocalDate} - * @throws IllegalArgumentException - */ - public static LocalDate restoreDate( int binary ) throws IllegalArgumentException { - int mjd = binary; - if( mjd > maxMJD || mjd < 0 ) { - throw new IllegalArgumentException( String.format( - "Cannot restore %d as julian days, out of range.", mjd ) ); - } - return LocalDate.ofEpochDay(0).with(JulianFields.MODIFIED_JULIAN_DAY, mjd ); - } + /** + * Convert a {@link LocalDate} to a binary representation. + * + * @param date + * @return binary + * @throws IllegalArgumentException + */ + public static int storeDate(LocalDate date) throws IllegalArgumentException { + if (date == null) { + throw new IllegalArgumentException(String.format("Need LocalDateTime, not null")); + } + long mjd = date.getLong(JulianFields.MODIFIED_JULIAN_DAY); + if (mjd > maxMJD || mjd < 0) { + throw new IllegalArgumentException( + String.format("Cannot store %d as julian days, out of range.", mjd)); + } + return (int) mjd; + } - /** - * Convert a {@link LocalDateTime} to a binary representation. - * - * @param dateTime - * @return binary - * @throws IllegalArgumentException - */ - public static long storeDateTime(LocalDateTime dateTime) throws IllegalArgumentException { - if( dateTime == null ) { - throw new IllegalArgumentException( String.format( "Need LocalDateTime, not null" ) ); - } - long mjd = storeDate( dateTime.toLocalDate() ); - long timeOffset = dateTime.getHour(); // hours - timeOffset = timeOffset*60 + dateTime.getMinute(); // minutes - timeOffset = timeOffset*60 + dateTime.getSecond(); // seconds - timeOffset = timeOffset*(long)1E06 + dateTime.getNano()/1000; // microseconds - return ( mjd << 40 ) + timeOffset; - } + /** + * Convert binary representation to a {@link LocalDate}. + * + * @param binary + * @return {@link LocalDate} + * @throws IllegalArgumentException + */ + public static LocalDate restoreDate(int binary) throws IllegalArgumentException { + int mjd = binary; + if (mjd > maxMJD || mjd < 0) { + throw new IllegalArgumentException( + String.format("Cannot restore %d as julian days, out of range.", mjd)); + } + return LocalDate.ofEpochDay(0).with(JulianFields.MODIFIED_JULIAN_DAY, mjd); + } - /** - * Convert a binary representation to a {@link LocalDateTime}. - * - * @param binary - * @return {@link LocalDateTime} - * @throws IllegalArgumentException - */ - public static LocalDateTime restoreDateTime( long binary ) throws IllegalArgumentException { - // 24 bits will always fit in valid range, no check needed - int mjd = (int)( binary >>> 40 ); - // Only use lower 40 bits, shifting left then right, perhaps use mask instead - long micros = ( binary << 24 ) >>> 24 ; - - if( micros > maxMicros || micros < 0 ) { - throw new IllegalArgumentException( String.format( - "Cannot restore %d micros, out of range. Max = %d", micros, maxMicros ) ); - } - LocalDate julian = LocalDate.ofEpochDay(0).with(JulianFields.MODIFIED_JULIAN_DAY, mjd); - LocalDateTime dateTime = julian.atStartOfDay().plusNanos( micros*1000 ); - return dateTime; - } -} + /** + * Convert a {@link LocalDateTime} to a binary representation. + * + * @param dateTime + * @return binary + * @throws IllegalArgumentException + */ + public static long storeDateTime(LocalDateTime dateTime) throws IllegalArgumentException { + if (dateTime == null) { + throw new IllegalArgumentException(String.format("Need LocalDateTime, not null")); + } + long mjd = storeDate(dateTime.toLocalDate()); + long timeOffset = dateTime.getHour(); // hours + timeOffset = timeOffset * 60 + dateTime.getMinute(); // minutes + timeOffset = timeOffset * 60 + dateTime.getSecond(); // seconds + timeOffset = timeOffset * (long) 1E06 + dateTime.getNano() / 1000; // microseconds + return (mjd << 40) + timeOffset; + } + /** + * Convert a binary representation to a {@link LocalDateTime}. + * + * @param binary + * @return {@link LocalDateTime} + * @throws IllegalArgumentException + */ + public static LocalDateTime restoreDateTime(long binary) throws IllegalArgumentException { + // 24 bits will always fit in valid range, no check needed + int mjd = (int) (binary >>> 40); + // Only use lower 40 bits, shifting left then right, perhaps use mask instead + long micros = (binary << 24) >>> 24; + if (micros > maxMicros || micros < 0) { + throw new IllegalArgumentException( + String.format("Cannot restore %d micros, out of range. Max = %d", micros, maxMicros)); + } + LocalDate julian = LocalDate.ofEpochDay(0).with(JulianFields.MODIFIED_JULIAN_DAY, mjd); + LocalDateTime dateTime = julian.atStartOfDay().plusNanos(micros * 1000); + return dateTime; + } +}