Skip to content

Commit 97b4cb3

Browse files
committed
Refactoring VectorAPIFeature
1 parent 3526afd commit 97b4cb3

File tree

1 file changed

+102
-84
lines changed

1 file changed

+102
-84
lines changed

substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/VectorAPIFeature.java

Lines changed: 102 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,16 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
116116
int maxVectorBits = Math.max(VectorAPISupport.singleton().getMaxVectorBytes() * Byte.SIZE, 64);
117117

118118
Class<?>[] vectorElements = new Class<?>[]{float.class, double.class, byte.class, short.class, int.class, long.class};
119-
String[] vectorElementNames = new String[]{"Float", "Double", "Byte", "Short", "Int", "Long"};
120-
int[] elementSizes = new int[]{32, 64, 8, 16, 32, 64};
119+
LaneType[] laneTypes = new LaneType[vectorElements.length];
120+
for (int i = 0; i < vectorElements.length; i++) {
121+
laneTypes[i] = LaneType.fromVectorElement(vectorElements[i], i + 1);
122+
}
123+
121124
String[] vectorSizes = new String[]{"64", "128", "256", "512", "Max"};
125+
Shape[] shapes = new Shape[vectorSizes.length];
126+
for (int i = 0; i < vectorSizes.length; i++) {
127+
shapes[i] = new Shape(vectorSizes[i], i + 1);
128+
}
122129

123130
Object maxBitShape = ReflectionUtil.readStaticField(vectorShapeClass, "S_Max_BIT");
124131
access.registerFieldValueTransformer(ReflectionUtil.lookupField(vectorShapeClass, "vectorBitSize"),
@@ -131,7 +138,7 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
131138
* named using an explicit bit size, e.g., S_256_BIT rather than S_Max_BIT.
132139
*/
133140
int maxSizeIndex = Math.min(Integer.numberOfTrailingZeros(maxVectorBits / 64), vectorSizes.length - 1);
134-
String maxSizeName = vectorSizes[maxSizeIndex];
141+
String maxSizeName = shapes[maxSizeIndex].shapeName();
135142
Object preferredShape = ReflectionUtil.readStaticField(vectorShapeClass, "S_" + maxSizeName + "_BIT");
136143

137144
/*
@@ -141,67 +148,53 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
141148
*/
142149
EconomicMap<Object, AbstractSpeciesStableFields> speciesStableFields = EconomicMap.create();
143150

151+
Class<?> laneTypeClass = ReflectionUtil.lookupClass(VECTOR_API_PACKAGE_NAME + ".LaneType");
152+
UNSAFE.ensureClassInitialized(laneTypeClass);
153+
144154
Class<?> speciesClass = ReflectionUtil.lookupClass(VECTOR_API_PACKAGE_NAME + ".AbstractSpecies");
145-
Object speciesCache = Array.newInstance(speciesClass, 7, 6);
155+
Object speciesCache = Array.newInstance(speciesClass, ReflectionUtil.readStaticField(laneTypeClass, "SK_LIMIT"), ReflectionUtil.readStaticField(vectorShapeClass, "SK_LIMIT"));
146156
UNSAFE.ensureClassInitialized(speciesClass);
147157

148-
for (Class<?> vectorElement : vectorElements) {
149-
String elementName = vectorElement.getName().substring(0, 1).toUpperCase(Locale.ROOT) + vectorElement.getName().substring(1);
150-
151-
String generalVectorName = VECTOR_API_PACKAGE_NAME + "." + elementName + "Vector";
152-
Class<?> vectorClass = ReflectionUtil.lookupClass(generalVectorName);
153-
UNSAFE.ensureClassInitialized(vectorClass);
154-
Method species = ReflectionUtil.lookupMethod(vectorClass, "species", vectorShapeClass);
155-
access.registerFieldValueTransformer(ReflectionUtil.lookupField(vectorClass, "SPECIES_PREFERRED"),
158+
for (LaneType laneType : laneTypes) {
159+
Method species = ReflectionUtil.lookupMethod(laneType.vectorClass(), "species", vectorShapeClass);
160+
access.registerFieldValueTransformer(ReflectionUtil.lookupField(laneType.vectorClass(), "SPECIES_PREFERRED"),
156161
(receiver, originalValue) -> ReflectionUtil.invokeMethod(species, null, preferredShape));
157162

158-
String maxVectorName = VECTOR_API_PACKAGE_NAME + "." + elementName + "MaxVector";
159-
Class<?> maxVectorClass = ReflectionUtil.lookupClass(maxVectorName);
160-
int laneCount = VectorAPISupport.singleton().getMaxLaneCount(vectorElement);
163+
Class<?> maxVectorClass = vectorClass(laneType, shapes[shapes.length - 1]);
164+
int laneCount = VectorAPISupport.singleton().getMaxLaneCount(laneType.elementClass());
161165
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "VSIZE"),
162166
(receiver, originalValue) -> maxVectorBits);
163167
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "VLENGTH"),
164168
(receiver, originalValue) -> laneCount);
165169
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "ZERO"),
166-
(receiver, originalValue) -> makeZeroVector(maxVectorClass, vectorElement, laneCount));
170+
(receiver, originalValue) -> makeZeroVector(maxVectorClass, laneType.elementClass(), laneCount));
167171
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "IOTA"),
168-
(receiver, originalValue) -> makeIotaVector(maxVectorClass, vectorElement, laneCount));
172+
(receiver, originalValue) -> makeIotaVector(maxVectorClass, laneType.elementClass(), laneCount));
169173
}
170174

171-
Class<?> laneTypeClass = ReflectionUtil.lookupClass(VECTOR_API_PACKAGE_NAME + ".LaneType");
172-
UNSAFE.ensureClassInitialized(laneTypeClass);
173-
174175
Class<?> valueLayoutClass = ReflectionUtil.lookupClass("java.lang.foreign.ValueLayout");
175176
Method valueLayoutVarHandle = ReflectionUtil.lookupMethod(valueLayoutClass, "varHandle");
176177

177-
for (int laneTypeIndex = 0; laneTypeIndex < vectorElementNames.length; laneTypeIndex++) {
178-
String elementName = vectorElementNames[laneTypeIndex];
179-
Class<?> vectorElement = vectorElements[laneTypeIndex];
180-
int laneTypeSwitchKey = laneTypeIndex + 1;
181-
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + elementName + "Vector";
182-
Class<?> vectorClass = ReflectionUtil.lookupClass(vectorClassName);
183-
178+
for (LaneType laneType : laneTypes) {
184179
// Ensure VarHandle used by memorySegmentGet/Set is initialized.
185180
// Java 22+: ValueLayout valueLayout = (...); valueLayout.varHandle();
186-
Object valueLayout = ReflectionUtil.readStaticField(vectorClass, "ELEMENT_LAYOUT");
181+
Object valueLayout = ReflectionUtil.readStaticField(laneType.vectorClass(), "ELEMENT_LAYOUT");
187182
ReflectionUtil.invokeMethod(valueLayoutVarHandle, valueLayout);
188183

189-
for (int vectorShapeIndex = 0; vectorShapeIndex < vectorSizes.length; vectorShapeIndex++) {
190-
String size = vectorSizes[vectorShapeIndex];
191-
int vectorShapeSwitchKey = vectorShapeIndex + 1;
192-
String fieldName = "SPECIES_" + size.toUpperCase(Locale.ROOT);
193-
Object species = ReflectionUtil.readStaticField(vectorClass, fieldName);
184+
for (Shape shape : shapes) {
185+
String fieldName = "SPECIES_" + shape.shapeName().toUpperCase(Locale.ROOT);
186+
Object species = ReflectionUtil.readStaticField(laneType.vectorClass(), fieldName);
194187

195-
int vectorBitSize = vectorShapeIndex == vectorSizes.length - 1 ? maxVectorBits : Integer.parseInt(size);
188+
int vectorBitSize = shape.shapeName().equals("Max") ? maxVectorBits : Integer.parseInt(shape.shapeName());
196189
int vectorByteSize = vectorBitSize / Byte.SIZE;
197-
int laneCount = vectorShapeIndex == vectorSizes.length - 1 ? VectorAPISupport.singleton().getMaxLaneCount(vectorElement) : vectorBitSize / elementSizes[laneTypeIndex];
190+
int laneCount = shape.shapeName().equals("Max") ? VectorAPISupport.singleton().getMaxLaneCount(laneType.elementClass()) : vectorBitSize / laneType.elementSize();
198191
int laneCountLog2P1 = Integer.numberOfTrailingZeros(laneCount) + 1;
199192
Method makeDummyVector = ReflectionUtil.lookupMethod(speciesClass, "makeDummyVector");
200193
Object dummyVector = ReflectionUtil.invokeMethod(makeDummyVector, species);
201-
Object laneType = ReflectionUtil.readStaticField(laneTypeClass, elementName.toUpperCase(Locale.ROOT));
202-
speciesStableFields.put(species, new AbstractSpeciesStableFields(laneCount, laneCountLog2P1, vectorBitSize, vectorByteSize, dummyVector, laneType));
194+
Object laneTypeObject = ReflectionUtil.readStaticField(laneTypeClass, laneType.elementName().toUpperCase(Locale.ROOT));
195+
speciesStableFields.put(species, new AbstractSpeciesStableFields(laneCount, laneCountLog2P1, vectorBitSize, vectorByteSize, dummyVector, laneTypeObject));
203196

204-
Array.set(Array.get(speciesCache, laneTypeSwitchKey), vectorShapeSwitchKey, species);
197+
Array.set(Array.get(speciesCache, laneType.switchKey()), shape.switchKey(), species);
205198
}
206199
}
207200

@@ -218,20 +211,15 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
218211
* intrinsify operations, we may need to access information about a type before the analysis
219212
* has seen it.
220213
*/
221-
for (Class<?> vectorElement : vectorElements) {
222-
String elementName = vectorElement.getName().substring(0, 1).toUpperCase(Locale.ROOT) + vectorElement.getName().substring(1);
223-
for (String size : vectorSizes) {
224-
String baseName = elementName + size;
225-
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + baseName + "Vector";
226-
Class<?> shuffleClass = ReflectionUtil.lookupClass(vectorClassName + "$" + baseName + "Shuffle");
227-
UNSAFE.ensureClassInitialized(shuffleClass);
214+
for (LaneType laneType : laneTypes) {
215+
for (Shape shape : shapes) {
216+
Class<?> shuffleClass = vectorShuffleClass(laneType, shape);
217+
Class<?> maskClass = vectorMaskClass(laneType, shape);
228218
access.registerAsUsed(shuffleClass);
229-
Class<?> maskClass = ReflectionUtil.lookupClass(vectorClassName + "$" + baseName + "Mask");
230-
UNSAFE.ensureClassInitialized(maskClass);
231219
access.registerAsUsed(maskClass);
232-
if (size.equals("Max")) {
233-
int laneCount = VectorAPISupport.singleton().getMaxLaneCount(vectorElement);
234-
Class<?> shuffleElement = (vectorElement == float.class ? int.class : vectorElement == double.class ? long.class : vectorElement);
220+
if (shape.shapeName().equals("Max")) {
221+
int laneCount = VectorAPISupport.singleton().getMaxLaneCount(laneType.elementClass());
222+
Class<?> shuffleElement = (laneType.elementClass() == float.class ? int.class : laneType.elementClass() == double.class ? long.class : laneType.elementClass());
235223
access.registerFieldValueTransformer(ReflectionUtil.lookupField(shuffleClass, "VLENGTH"),
236224
(receiver, originalValue) -> laneCount);
237225
access.registerFieldValueTransformer(ReflectionUtil.lookupField(shuffleClass, "IOTA"),
@@ -247,32 +235,26 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
247235
/* Warm up caches of arithmetic and conversion operations. */
248236
WarmupData warmupData = new WarmupData();
249237

250-
for (String elementName : vectorElementNames) {
251-
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + elementName + "Vector";
252-
Class<?> vectorClass = ReflectionUtil.lookupClass(vectorClassName);
253-
UNSAFE.ensureClassInitialized(vectorClass);
254-
warmupImplCache(vectorClass, "UN_IMPL", "unaryOperations", warmupData);
255-
warmupImplCache(vectorClass, "BIN_IMPL", "binaryOperations", warmupData);
256-
warmupImplCache(vectorClass, "TERN_IMPL", "ternaryOperations", warmupData);
257-
warmupImplCache(vectorClass, "REDUCE_IMPL", "reductionOperations", warmupData);
258-
if (!elementName.equals("Float") && !elementName.equals("Double")) {
259-
warmupImplCache(vectorClass, "BIN_INT_IMPL", "broadcastIntOperations", warmupData);
238+
for (LaneType laneType : laneTypes) {
239+
warmupImplCache(laneType.vectorClass(), "UN_IMPL", "unaryOperations", warmupData);
240+
warmupImplCache(laneType.vectorClass(), "BIN_IMPL", "binaryOperations", warmupData);
241+
warmupImplCache(laneType.vectorClass(), "TERN_IMPL", "ternaryOperations", warmupData);
242+
warmupImplCache(laneType.vectorClass(), "REDUCE_IMPL", "reductionOperations", warmupData);
243+
if (!laneType.elementName().equals("Float") && !laneType.elementName().equals("Double")) {
244+
warmupImplCache(laneType.vectorClass(), "BIN_INT_IMPL", "broadcastIntOperations", warmupData);
260245
}
261246
}
262247

263248
/* Warm up caches for mapping between lane types, used by shuffles. */
264249
Method asIntegral = ReflectionUtil.lookupMethod(speciesClass, "asIntegral");
265250
Method asFloating = ReflectionUtil.lookupMethod(speciesClass, "asFloating");
266-
for (String elementName : vectorElementNames) {
267-
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + elementName + "Vector";
268-
Class<?> vectorClass = ReflectionUtil.lookupClass(vectorClassName);
269-
UNSAFE.ensureClassInitialized(vectorClass);
270-
for (String size : vectorSizes) {
271-
String fieldName = "SPECIES_" + size.toUpperCase(Locale.ROOT);
272-
Object species = ReflectionUtil.readStaticField(vectorClass, fieldName);
251+
for (LaneType laneType : laneTypes) {
252+
for (Shape shape : shapes) {
253+
String fieldName = "SPECIES_" + shape.shapeName().toUpperCase(Locale.ROOT);
254+
Object species = ReflectionUtil.readStaticField(laneType.vectorClass(), fieldName);
273255
try {
274256
asIntegral.invoke(species);
275-
if (elementName.equals("Int") || elementName.equals("Long")) {
257+
if (laneType.elementName().equals("Int") || laneType.elementName().equals("Long")) {
276258
asFloating.invoke(species);
277259
}
278260
} catch (IllegalAccessException | InvocationTargetException ex) {
@@ -288,31 +270,67 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
288270
if (DeoptimizationSupport.enabled()) {
289271
/* Build a table of payload type descriptors for deoptimization. */
290272
VectorAPIDeoptimizationSupport deoptSupport = new VectorAPIDeoptimizationSupport();
291-
for (Class<?> vectorElement : vectorElements) {
292-
int elementBytes = JavaKind.fromJavaClass(vectorElement).getByteCount();
293-
String elementName = vectorElement.getName().substring(0, 1).toUpperCase(Locale.ROOT) + vectorElement.getName().substring(1);
294-
for (String size : vectorSizes) {
295-
int vectorLength = size.equals("Max")
296-
? VectorAPISupport.singleton().getMaxLaneCount(vectorElement)
297-
: (Integer.parseInt(size) / Byte.SIZE) / elementBytes;
298-
String baseName = elementName + size;
299-
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + baseName + "Vector";
300-
301-
Class<?> vectorClass = ReflectionUtil.lookupClass(vectorClassName);
302-
deoptSupport.putLayout(vectorClass, new VectorAPIDeoptimizationSupport.PayloadLayout(vectorElement, vectorLength));
303-
304-
Class<?> shuffleClass = ReflectionUtil.lookupClass(vectorClassName + "$" + baseName + "Shuffle");
305-
Class<?> shuffleElement = (vectorElement == float.class ? int.class : vectorElement == double.class ? long.class : vectorElement);
273+
for (LaneType laneType : laneTypes) {
274+
int elementBytes = laneType.elementSize() >> 3;
275+
for (Shape shape : shapes) {
276+
int vectorLength = shape.shapeName().equals("Max")
277+
? VectorAPISupport.singleton().getMaxLaneCount(laneType.elementClass())
278+
: (Integer.parseInt(shape.shapeName()) / Byte.SIZE) / elementBytes;
279+
Class<?> vectorClass = vectorClass(laneType, shape);
280+
deoptSupport.putLayout(vectorClass, new VectorAPIDeoptimizationSupport.PayloadLayout(laneType.elementClass(), vectorLength));
281+
282+
Class<?> shuffleClass = vectorShuffleClass(laneType, shape);
283+
Class<?> shuffleElement = (laneType.elementClass() == float.class ? int.class : laneType.elementClass() == double.class ? long.class : laneType.elementClass());
306284
deoptSupport.putLayout(shuffleClass, new VectorAPIDeoptimizationSupport.PayloadLayout(shuffleElement, vectorLength));
307285

308-
Class<?> maskClass = ReflectionUtil.lookupClass(vectorClassName + "$" + baseName + "Mask");
286+
Class<?> maskClass = vectorMaskClass(laneType, shape);
309287
deoptSupport.putLayout(maskClass, new VectorAPIDeoptimizationSupport.PayloadLayout(boolean.class, vectorLength));
310288
}
311289
}
312290
ImageSingletons.add(VectorAPIDeoptimizationSupport.class, deoptSupport);
313291
}
314292
}
315293

294+
private Class<?> vectorClass(LaneType laneType, Shape shape) {
295+
String baseName = laneType.elementName() + shape.shapeName();
296+
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + baseName + "Vector";
297+
Class<?> vectorClass = ReflectionUtil.lookupClass(vectorClassName);
298+
UNSAFE.ensureClassInitialized(vectorClass);
299+
return vectorClass;
300+
}
301+
302+
private Class<?> vectorShuffleClass(LaneType laneType, Shape shape) {
303+
String baseName = laneType.elementName() + shape.shapeName();
304+
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + baseName + "Vector";
305+
Class<?> shuffleClass = ReflectionUtil.lookupClass(vectorClassName + "$" + baseName + "Shuffle");
306+
UNSAFE.ensureClassInitialized(shuffleClass);
307+
return shuffleClass;
308+
}
309+
310+
private Class<?> vectorMaskClass(LaneType laneType, Shape shape) {
311+
String baseName = laneType.elementName() + shape.shapeName();
312+
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + baseName + "Vector";
313+
Class<?> maskClass = ReflectionUtil.lookupClass(vectorClassName + "$" + baseName + "Mask");
314+
UNSAFE.ensureClassInitialized(maskClass);
315+
return maskClass;
316+
}
317+
318+
private record LaneType(Class<?> elementClass, Class<?> vectorClass, String elementName, int elementSize, int switchKey) {
319+
320+
private static LaneType fromVectorElement(Class<?> elementClass, int switchKey) {
321+
String elementName = elementClass.getName().substring(0, 1).toUpperCase(Locale.ROOT) + elementClass.getName().substring(1);
322+
String generalVectorName = VECTOR_API_PACKAGE_NAME + "." + elementName + "Vector";
323+
Class<?> vectorClass = ReflectionUtil.lookupClass(generalVectorName);
324+
UNSAFE.ensureClassInitialized(vectorClass);
325+
int elementSize = JavaKind.fromJavaClass(elementClass).getBitCount();
326+
return new LaneType(elementClass, vectorClass, elementName, elementSize, switchKey);
327+
}
328+
}
329+
330+
private record Shape(String shapeName, int switchKey) {
331+
332+
}
333+
316334
private record AbstractSpeciesStableFields(int laneCount, int laneCountLog2P1, int vectorBitSize, int vectorByteSize, Object dummyVector, Object laneType) {
317335

318336
}

0 commit comments

Comments
 (0)