*
- * tornado-test -V uk.ac.manchester.tornado.unittests.foundation.TestFloats
+ * tornado-test -V uk.ac.manchester.tornado.unittests.foundation.TestFloats
*
*/
public class TestFloats extends TornadoTestBase {
@@ -44,7 +44,7 @@ public void testFloatsCopy() {
FloatArray a = new FloatArray(numElements);
TaskGraph taskGraph = new TaskGraph("s0") //
- .task("t0", TestKernels::testFloatCopy, a) //
+ .task("t0", TestKernels::testFloatCopy222, a) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, a);
ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java
index b5438e0cb3..94663bd483 100644
--- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java
+++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java
@@ -156,6 +156,22 @@ public static void testFloatCopy(FloatArray a) {
}
}
+ public static void testFloatCopy222(FloatArray a) {
+ for (@Parallel int i = 0; i < a.getSize(); i++) {
+ float x = a.get(i);
+ x = x + i;
+ }
+ }
+
+ public static FloatArray testFloatCopy2(FloatArray a) {
+ FloatArray temp = new FloatArray(a.getSize());
+ for (int i = 0; i < a.getSize(); i++) {
+ temp.set(i, 50.0f + a.get(i));
+ }
+
+ return temp;
+ }
+
public static void testDoublesCopy(DoubleArray a) {
for (@Parallel int i = 0; i < a.getSize(); i++) {
a.set(i, 50);
From e957e4fc8792373a9e0682f6dc61fb5c946a47ae Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 26 Jun 2024 13:56:35 +0300
Subject: [PATCH 06/54] Add TornadoMemorySegment and update methods in
FloatArray
---
.../tornado/api/types/arrays/FloatArray.java | 88 +++++++-------
.../types/arrays/TornadoMemorySegment.java | 45 ++++++++
.../plugins/OCLGraphBuilderPlugins.java | 107 +++---------------
3 files changed, 104 insertions(+), 136 deletions(-)
create mode 100644 tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java
index 2b54f7a0d8..ee281dea0b 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java
@@ -18,9 +18,7 @@
package uk.ac.manchester.tornado.api.types.arrays;
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.FloatBuffer;
import java.util.Arrays;
@@ -37,7 +35,7 @@
@SegmentElementSize(size = 4)
public final class FloatArray extends TornadoNativeArray {
private static final int FLOAT_BYTES = 4;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int numberOfElements;
@@ -58,9 +56,7 @@ public FloatArray(int numberOfElements) {
arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER;
baseIndex = arrayHeaderSize / FLOAT_BYTES;
segmentByteSize = numberOfElements * FLOAT_BYTES + arrayHeaderSize;
-
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -121,7 +117,7 @@ public static FloatArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / FLOAT_BYTES);
FloatArray floatArray = new FloatArray(numElements);
- MemorySegment.copy(segment, 0, floatArray.segment, floatArray.baseIndex * FLOAT_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, floatArray.segment.getSegment(), floatArray.baseIndex * FLOAT_BYTES, byteSize);
return floatArray;
}
@@ -139,6 +135,39 @@ public static FloatArray fromFloatBuffer(FloatBuffer buffer) {
return floatArray;
}
+ /**
+ * Factory method to initialize a {@link FloatArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code FloatArray} instance with.
+ */
+ public static void initialize(FloatArray array, float value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link FloatArray} instances into a single {@link FloatArray}.
+ *
+ * @param arrays
+ * Variable number of {@link FloatArray} objects to be concatenated.
+ * @return A new {@link FloatArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static FloatArray concat(FloatArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(FloatArray::getSize).sum();
+ FloatArray concatArray = new FloatArray(newSize);
+ long currentPositionBytes = 0;
+ for (FloatArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Converts the float data from off-heap to on-heap, by copying the values of a {@link FloatArray}
* instance into a new on-heap array.
@@ -162,7 +191,7 @@ public float[] toHeapArray() {
* The float value to store at the specified index.
*/
public void set(int index, float value) {
- segment.setAtIndex(JAVA_FLOAT, baseIndex + index, value);
+ segment.setSegmentAt(index, value, baseIndex);
}
/**
@@ -173,7 +202,7 @@ public void set(int index, float value) {
* @return
*/
public float get(int index) {
- return segment.getAtIndex(JAVA_FLOAT, baseIndex + index);
+ return segment.getSegmentFrom(index, baseIndex);
}
/**
@@ -197,7 +226,7 @@ public int getElementSize() {
*/
public void init(float value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_FLOAT, baseIndex + i, value);
+ segment.getSegment().setAtIndex(JAVA_FLOAT, baseIndex + i, value);
}
}
@@ -218,7 +247,7 @@ public int getSize() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -228,7 +257,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
+ return segment.getSegment();
}
/**
@@ -252,39 +281,6 @@ public long getNumBytesOfSegment() {
return segmentByteSize - TornadoNativeArray.ARRAY_HEADER;
}
- /**
- * Factory method to initialize a {@link FloatArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code FloatArray} instance with.
- */
- public static void initialize(FloatArray array, float value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link FloatArray} instances into a single {@link FloatArray}.
- *
- * @param arrays
- * Variable number of {@link FloatArray} objects to be concatenated.
- * @return A new {@link FloatArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static FloatArray concat(FloatArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(FloatArray::getSize).sum();
- FloatArray concatArray = new FloatArray(newSize);
- long currentPositionBytes = 0;
- for (FloatArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
-
/**
* Extracts a slice of elements from a given {@link FloatArray}, creating a new {@link FloatArray} instance.
*
@@ -304,7 +300,7 @@ public FloatArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * FLOAT_BYTES;
long sliceByteLength = length * FLOAT_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
FloatArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
new file mode 100644
index 0000000000..46941d9440
--- /dev/null
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2024, APT Group, Department of Computer Science,
+ * The University of Manchester.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+package uk.ac.manchester.tornado.api.types.arrays;
+
+import static java.lang.foreign.ValueLayout.JAVA_INT;
+
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+
+public class TornadoMemorySegment {
+ private MemorySegment segment;
+
+ public TornadoMemorySegment(long segmentByteSize, int basedIndex, int numElements) {
+ this.segment = Arena.ofAuto().allocate(segmentByteSize, 1);
+ segment.setAtIndex(JAVA_INT, 0, numElements);
+ }
+
+ public MemorySegment getSegment() {
+ return segment;
+ }
+
+ public void setSegmentAt(int index, float value, int baseIndex) {
+ segment.setAtIndex(ValueLayout.JAVA_FLOAT, baseIndex + index, value);
+ }
+
+ public float getSegmentFrom(int index, int baseIndex) {
+ return segment.getAtIndex(ValueLayout.JAVA_FLOAT, baseIndex + index);
+ }
+}
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
index 15726fab8a..faf9d9e42c 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
@@ -41,9 +41,10 @@
import static uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLIntBinaryIntrinsicNode.Operation.MIN;
import static uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLIntUnaryIntrinsicNode.Operation.POPCOUNT;
-import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
+import org.graalvm.word.LocationIdentity;
+
import jdk.graal.compiler.core.common.memory.BarrierType;
import jdk.graal.compiler.core.common.memory.MemoryOrderMode;
import jdk.graal.compiler.core.common.type.StampFactory;
@@ -52,6 +53,7 @@
import jdk.graal.compiler.nodes.FixedWithNextNode;
import jdk.graal.compiler.nodes.PiNode;
import jdk.graal.compiler.nodes.ValueNode;
+import jdk.graal.compiler.nodes.calc.AddNode;
import jdk.graal.compiler.nodes.calc.MulNode;
import jdk.graal.compiler.nodes.extended.BoxNode;
import jdk.graal.compiler.nodes.extended.JavaReadNode;
@@ -70,9 +72,6 @@
import jdk.graal.compiler.nodes.util.GraphUtil;
import jdk.graal.compiler.replacements.InlineDuringParsingPlugin;
import jdk.vm.ci.hotspot.HotSpotMetaAccessProvider;
-import jdk.vm.ci.meta.ResolvedJavaType;
-import org.graalvm.word.LocationIdentity;
-
import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaMethod;
@@ -80,7 +79,7 @@
import uk.ac.manchester.tornado.api.TornadoVMIntrinsics;
import uk.ac.manchester.tornado.api.exceptions.Debug;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoMemorySegment;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.AtomicAddNodeTemplate;
@@ -404,33 +403,29 @@ public static Class getValueLayoutClass(Class k) {
}
private static void registerMemoryAccessPlugins(InvocationPlugins plugins, HotSpotMetaAccessProvider metaAccessProvider) {
- Registration r = new Registration(plugins, FloatArray.class);
+ Registration r = new Registration(plugins, TornadoMemorySegment.class);
for (JavaKind kind : JavaKind.values()) {
if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- r.register(new InvocationPlugin("set", Receiver.class, int.class, float.class) {
+ r.register(new InvocationPlugin("setSegmentAt", Receiver.class, int.class, float.class, int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value) {
- System.out.println("Try to apply YYYY");
- MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
+ AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(4)));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
+ JavaWriteNode writeNode = new JavaWriteNode(JavaKind.Float, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
b.add(writeNode);
return true;
}
});
- r.register(new InvocationPlugin("get", Receiver.class, int.class) {
+ r.register(new InvocationPlugin("getSegmentFrom", Receiver.class, int.class, int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index) {
- System.out.println("Try to apply XXXX" + receiver.get().toString() + " " + targetMethod.getName());
- MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
- System.out.println("Try to apply XXXX 1111");
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
+ AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(4)));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- System.out.println("Try to apply XXXX 2222");
-
- JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
- System.out.println("Try to apply XXXX 3333");
- b.addPush(kind, readNode);
+ JavaReadNode readNode = new JavaReadNode(JavaKind.Float, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
+ b.addPush(JavaKind.Float, readNode);
return true;
}
});
@@ -438,74 +433,6 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
}
- // private static void registerMemoryAccessPlugins(InvocationPlugins plugins, HotSpotMetaAccessProvider metaAccessProvider) {
- // // public final ResolvedJavaType AbstractMemorySegmentImpl = lookupTypeOptional("jdk.internal.foreign.AbstractMemorySegmentImpl");
- // String AbstractMemorySegmentImpl = "jdk.internal.foreign.AbstractMemorySegmentImpl";
- //
- // // ResolvedJavaType memorySegmentImplType = types.AbstractMemorySegmentImpl;
- //
- // System.out.println("XXXXXXX REGISTER MEMORY SEGMENT");
- //
- // ResolvedJavaType resolvedType = null;
- // try {
- // // Class name you want to lookup
- // String className = "jdk.internal.foreign.AbstractMemorySegmentImpl";
- //
- // // Lookup the class using reflection
- // Class> clazz = Class.forName(className);
- // resolvedType = metaAccessProvider.lookupJavaType(clazz);
- //
- // // Use clazz as needed
- // } catch (ClassNotFoundException e) {
- // e.printStackTrace();
- // }
- // System.out.println("Resolved Class: " + resolvedType.getSourceFileName() + " " + resolvedType.toJavaName());
- //
- // Registration r = new Registration(plugins, new InvocationPlugins.ResolvedJavaSymbol(resolvedType));
- //
- // System.out.println("re " + r.toString());
- // r.register(new InvocationPlugin("setAtIndex", Receiver.class, ValueLayout.JAVA_FLOAT.getClass(), long.class, float.class) {
- //
- // @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
- // System.out.println("app XXX");
- //
- // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(JavaKind.Float.getByteCount())));
- // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- // JavaWriteNode writeNode = new JavaWriteNode(JavaKind.Float, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
- // b.add(writeNode);
- // return true;
- // }
- // });
- // for (JavaKind kind : JavaKind.values()) {
- // if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- // System.out.println("re " + kind.getJavaName());
- // r.register(new InvocationPlugin("getAtIndex", Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class) {
- // @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index) {
- // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
- // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- // JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
- // b.addPush(kind, readNode);
- // return true;
- // }
- // });
- // r.register(new InvocationPlugin("setAtIndex", Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class, kind.toJavaClass()) {
- //
- // @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
- // System.out.println("app XXX");
- // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
- // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- // JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
- // b.add(writeNode);
- // return true;
- // }
- // });
- // }
- // }
- // }
-
private static void registerTornadoVMIntrinsicsPlugins(InvocationPlugins plugins) {
final InvocationPlugin tprintfPlugin2 = new InvocationPlugin("tprintf", String.class, Object[].class) {
@Override
@@ -600,7 +527,7 @@ public boolean defaultHandler(GraphBuilderContext b, ResolvedJavaMethod targetMe
private static void registerOpenCLBuiltinPlugins(InvocationPlugins plugins) {
Registration r = new Registration(plugins, java.lang.Math.class);
- // We have to overwrite some of standard math plugins
+ // We have to overwrite some standard math plugins
r.setAllowOverwrite(true);
registerOpenCLOverridesForType(r, Float.TYPE, JavaKind.Float);
registerOpenCLOverridesForType(r, Double.TYPE, JavaKind.Double);
From 4f0afebcfe700cd3e9c2b01d140101d7add58b6d Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 26 Jun 2024 16:42:03 +0300
Subject: [PATCH 07/54] Enhance Tornado API with expanded memory segment
operations
---
.../tornado/api/types/arrays/ByteArray.java | 90 +++++++++----------
.../tornado/api/types/arrays/CharArray.java | 90 +++++++++----------
.../tornado/api/types/arrays/DoubleArray.java | 90 +++++++++----------
.../tornado/api/types/arrays/FloatArray.java | 8 +-
.../api/types/arrays/HalfFloatArray.java | 90 +++++++++----------
.../tornado/api/types/arrays/IntArray.java | 89 +++++++++---------
.../tornado/api/types/arrays/LongArray.java | 90 +++++++++----------
.../tornado/api/types/arrays/ShortArray.java | 90 +++++++++----------
.../types/arrays/TornadoMemorySegment.java | 53 ++++++++++-
.../plugins/OCLGraphBuilderPlugins.java | 50 +++--------
10 files changed, 361 insertions(+), 379 deletions(-)
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java
index 213f68fb2d..3049629296 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java
@@ -17,10 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_BYTE;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.util.Arrays;
@@ -37,7 +33,7 @@
@SegmentElementSize(size = 1)
public final class ByteArray extends TornadoNativeArray {
private static final int BYTE_BYTES = 1;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int numberOfElements;
private int arrayHeaderSize;
@@ -56,9 +52,7 @@ public ByteArray(int numberOfElements) {
arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER;
baseIndex = arrayHeaderSize / BYTE_BYTES;
segmentByteSize = numberOfElements * BYTE_BYTES + arrayHeaderSize;
-
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -119,7 +113,7 @@ public static ByteArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / BYTE_BYTES);
ByteArray byteArray = new ByteArray(numElements);
- MemorySegment.copy(segment, 0, byteArray.segment, byteArray.baseIndex * BYTE_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, byteArray.segment.getSegment(), byteArray.baseIndex * BYTE_BYTES, byteSize);
return byteArray;
}
@@ -137,6 +131,39 @@ public static ByteArray fromByteBuffer(ByteBuffer buffer) {
return byteArray;
}
+ /**
+ * Factory method to initialize a {@link ByteArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code ByteArray} instance with.
+ */
+ public static void initialize(ByteArray array, byte value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link ByteArray} instances into a single {@link ByteArray}.
+ *
+ * @param arrays
+ * Variable number of {@link ByteArray} objects to be concatenated.
+ * @return A new {@link ByteArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static ByteArray concat(ByteArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(ByteArray::getSize).sum();
+ ByteArray concatArray = new ByteArray(newSize);
+ long currentPositionBytes = 0;
+ for (ByteArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Converts the byte data from off-heap to on-heap, by copying the values of a {@link ByteArray}
* instance into a new on-heap array.
@@ -160,7 +187,7 @@ public byte[] toHeapArray() {
* The byte value to store at the specified index.
*/
public void set(int index, byte value) {
- segment.setAtIndex(JAVA_BYTE, baseIndex + index, value);
+ segment.setAtIndex(index, value, baseIndex);
}
/**
@@ -171,7 +198,7 @@ public void set(int index, byte value) {
* @return en element byte of the off-heap array
*/
public byte get(int index) {
- return segment.getAtIndex(JAVA_BYTE, baseIndex + index);
+ return segment.getByteAtIndex(index, baseIndex);
}
/**
@@ -195,7 +222,7 @@ public int getElementSize() {
*/
public void init(byte value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_BYTE, baseIndex + i, value);
+ segment.setAtIndex(baseIndex + i, value, baseIndex);
}
}
@@ -215,7 +242,7 @@ public int getSize() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -225,7 +252,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
+ return segment.getSegment();
}
/**
@@ -249,39 +276,6 @@ public long getNumBytesOfSegment() {
return segmentByteSize - TornadoNativeArray.ARRAY_HEADER;
}
- /**
- * Factory method to initialize a {@link ByteArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code ByteArray} instance with.
- */
- public static void initialize(ByteArray array, byte value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link ByteArray} instances into a single {@link ByteArray}.
- *
- * @param arrays
- * Variable number of {@link ByteArray} objects to be concatenated.
- * @return A new {@link ByteArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static ByteArray concat(ByteArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(ByteArray::getSize).sum();
- ByteArray concatArray = new ByteArray(newSize);
- long currentPositionBytes = 0;
- for (ByteArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
-
/**
* Extracts a slice of elements from a given {@link ByteArray}, creating a new {@link ByteArray} instance.
*
@@ -301,7 +295,7 @@ public ByteArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * BYTE_BYTES;
long sliceByteLength = length * BYTE_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
ByteArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java
index d4a4b70008..a3d47b4bd4 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java
@@ -17,10 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_CHAR;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.CharBuffer;
import java.util.Arrays;
@@ -37,7 +33,7 @@
@SegmentElementSize(size = 2)
public final class CharArray extends TornadoNativeArray {
private static final int CHAR_BYTES = 2;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int numberOfElements;
private int arrayHeaderSize;
@@ -56,9 +52,7 @@ public CharArray(int numberOfElements) {
arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER;
baseIndex = arrayHeaderSize / CHAR_BYTES;
segmentByteSize = numberOfElements * CHAR_BYTES + arrayHeaderSize;
-
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -119,7 +113,7 @@ public static CharArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / CHAR_BYTES);
CharArray charArray = new CharArray(numElements);
- MemorySegment.copy(segment, 0, charArray.segment, charArray.baseIndex * CHAR_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, charArray.segment.getSegment(), charArray.baseIndex * CHAR_BYTES, byteSize);
return charArray;
}
@@ -137,6 +131,39 @@ public static CharArray fromCharBuffer(CharBuffer buffer) {
return charArray;
}
+ /**
+ * Factory method to initialize a {@link CharArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code CharArray} instance with.
+ */
+ public static void initialize(CharArray array, char value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link CharArray} instances into a single {@link CharArray}.
+ *
+ * @param arrays
+ * Variable number of {@link CharArray} objects to be concatenated.
+ * @return A new {@link CharArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static CharArray concat(CharArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(CharArray::getSize).sum();
+ CharArray concatArray = new CharArray(newSize);
+ long currentPositionBytes = 0;
+ for (CharArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Sets all the values of the {@link CharArray} instance to \u0000, the default char value.
*/
@@ -173,7 +200,7 @@ public char[] toHeapArray() {
* The char value to store at the specified index.
*/
public void set(int index, char value) {
- segment.setAtIndex(JAVA_CHAR, baseIndex + index, value);
+ segment.setAtIndex(index, value, baseIndex);
}
/**
@@ -184,7 +211,7 @@ public void set(int index, char value) {
* @return
*/
public char get(int index) {
- return segment.getAtIndex(JAVA_CHAR, baseIndex + index);
+ return segment.getCharAtIndex(index, baseIndex);
}
/**
@@ -195,7 +222,7 @@ public char get(int index) {
*/
public void init(char value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_CHAR, baseIndex + i, value);
+ segment.setAtIndex(baseIndex + i, value, baseIndex);
}
}
@@ -216,7 +243,7 @@ public int getSize() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -226,7 +253,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
+ return segment.getSegment();
}
/**
@@ -250,39 +277,6 @@ public long getNumBytesOfSegment() {
return segmentByteSize - TornadoNativeArray.ARRAY_HEADER;
}
- /**
- * Factory method to initialize a {@link CharArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code CharArray} instance with.
- */
- public static void initialize(CharArray array, char value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link CharArray} instances into a single {@link CharArray}.
- *
- * @param arrays
- * Variable number of {@link CharArray} objects to be concatenated.
- * @return A new {@link CharArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static CharArray concat(CharArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(CharArray::getSize).sum();
- CharArray concatArray = new CharArray(newSize);
- long currentPositionBytes = 0;
- for (CharArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
-
/**
* Extracts a slice of elements from a given {@link CharArray}, creating a new {@link CharArray} instance.
*
@@ -302,7 +296,7 @@ public CharArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * CHAR_BYTES;
long sliceByteLength = length * CHAR_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
CharArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/DoubleArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/DoubleArray.java
index 5674350261..a272acf541 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/DoubleArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/DoubleArray.java
@@ -17,10 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_DOUBLE;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.DoubleBuffer;
import java.util.Arrays;
@@ -37,7 +33,7 @@
@SegmentElementSize(size = 8)
public final class DoubleArray extends TornadoNativeArray {
private static final int DOUBLE_BYTES = 8;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int numberOfElements;
private int arrayHeaderSize;
@@ -58,9 +54,7 @@ public DoubleArray(int numberOfElements) {
assert arrayHeaderSize >= 8;
baseIndex = arrayHeaderSize / DOUBLE_BYTES;
segmentByteSize = numberOfElements * DOUBLE_BYTES + arrayHeaderSize;
-
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -121,7 +115,7 @@ public static DoubleArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / DOUBLE_BYTES);
DoubleArray doubleArray = new DoubleArray(numElements);
- MemorySegment.copy(segment, 0, doubleArray.segment, doubleArray.baseIndex * DOUBLE_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, doubleArray.segment.getSegment(), doubleArray.baseIndex * DOUBLE_BYTES, byteSize);
return doubleArray;
}
@@ -139,6 +133,39 @@ public static DoubleArray fromDoubleBuffer(DoubleBuffer buffer) {
return doubleArray;
}
+ /**
+ * Factory method to initialize a {@link DoubleArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code DoubleArray} instance with.
+ */
+ public static void initialize(DoubleArray array, double value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link DoubleArray} instances into a single {@link DoubleArray}.
+ *
+ * @param arrays
+ * Variable number of {@link DoubleArray} objects to be concatenated.
+ * @return A new {@link DoubleArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static DoubleArray concat(DoubleArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(DoubleArray::getSize).sum();
+ DoubleArray concatArray = new DoubleArray(newSize);
+ long currentPositionBytes = 0;
+ for (DoubleArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Converts the double data from off-heap to on-heap, by copying the values of a {@link DoubleArray}
* instance into a new on-heap array.
@@ -162,7 +189,7 @@ public double[] toHeapArray() {
* The double value to store at the specified index.
*/
public void set(int index, double value) {
- segment.setAtIndex(JAVA_DOUBLE, baseIndex + index, value);
+ segment.setAtIndex(index, value, baseIndex);
}
/**
@@ -173,7 +200,7 @@ public void set(int index, double value) {
* @return
*/
public double get(int index) {
- return segment.getAtIndex(JAVA_DOUBLE, baseIndex + index);
+ return segment.getDoubleAtIndex(index, baseIndex);
}
/**
@@ -197,7 +224,7 @@ public int getElementSize() {
*/
public void init(double value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_DOUBLE, baseIndex + i, value);
+ segment.setAtIndex(i, value, baseIndex);
}
}
@@ -218,7 +245,7 @@ public int getSize() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -228,7 +255,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
+ return segment.getSegment();
}
/**
@@ -252,39 +279,6 @@ public long getNumBytesOfSegment() {
return segmentByteSize - TornadoNativeArray.ARRAY_HEADER;
}
- /**
- * Factory method to initialize a {@link DoubleArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code DoubleArray} instance with.
- */
- public static void initialize(DoubleArray array, double value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link DoubleArray} instances into a single {@link DoubleArray}.
- *
- * @param arrays
- * Variable number of {@link DoubleArray} objects to be concatenated.
- * @return A new {@link DoubleArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static DoubleArray concat(DoubleArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(DoubleArray::getSize).sum();
- DoubleArray concatArray = new DoubleArray(newSize);
- long currentPositionBytes = 0;
- for (DoubleArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
-
/**
* Extracts a slice of elements from a given {@link DoubleArray}, creating a new {@link DoubleArray} instance.
*
@@ -304,7 +298,7 @@ public DoubleArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * DOUBLE_BYTES;
long sliceByteLength = length * DOUBLE_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
DoubleArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java
index ee281dea0b..85527696ce 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/FloatArray.java
@@ -17,8 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
-
import java.lang.foreign.MemorySegment;
import java.nio.FloatBuffer;
import java.util.Arrays;
@@ -191,7 +189,7 @@ public float[] toHeapArray() {
* The float value to store at the specified index.
*/
public void set(int index, float value) {
- segment.setSegmentAt(index, value, baseIndex);
+ segment.setAtIndex(index, value, baseIndex);
}
/**
@@ -202,7 +200,7 @@ public void set(int index, float value) {
* @return
*/
public float get(int index) {
- return segment.getSegmentFrom(index, baseIndex);
+ return segment.getFloatAtIndex(index, baseIndex);
}
/**
@@ -226,7 +224,7 @@ public int getElementSize() {
*/
public void init(float value) {
for (int i = 0; i < getSize(); i++) {
- segment.getSegment().setAtIndex(JAVA_FLOAT, baseIndex + i, value);
+ segment.setAtIndex(i, value, baseIndex);
}
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java
index f53a164811..92eb6f90c0 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java
@@ -17,10 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-import static java.lang.foreign.ValueLayout.JAVA_SHORT;
-
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Arrays;
@@ -38,7 +34,7 @@
public final class HalfFloatArray extends TornadoNativeArray {
private static final int HALF_FLOAT_BYTES = 2;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int numberOfElements;
@@ -59,9 +55,7 @@ public HalfFloatArray(int numberOfElements) {
arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER;
baseIndex = arrayHeaderSize / HALF_FLOAT_BYTES;
segmentByteSize = numberOfElements * HALF_FLOAT_BYTES + arrayHeaderSize;
-
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -122,10 +116,43 @@ public static HalfFloatArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / HALF_FLOAT_BYTES);
HalfFloatArray halfFloatArray = new HalfFloatArray(numElements);
- MemorySegment.copy(segment, 0, halfFloatArray.segment, (long) halfFloatArray.baseIndex * HALF_FLOAT_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, halfFloatArray.segment.getSegment(), (long) halfFloatArray.baseIndex * HALF_FLOAT_BYTES, byteSize);
return halfFloatArray;
}
+ /**
+ * Factory method to initialize a {@link HalfFloatArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code HalfFloatArray} instance with.
+ */
+ public static void initialize(HalfFloatArray array, HalfFloat value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link HalfFloatArray} instances into a single {@link HalfFloatArray}.
+ *
+ * @param arrays
+ * Variable number of {@link HalfFloatArray} objects to be concatenated.
+ * @return A new {@link HalfFloatArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static HalfFloatArray concat(HalfFloatArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(HalfFloatArray::getSize).sum();
+ HalfFloatArray concatArray = new HalfFloatArray(newSize);
+ long currentPositionBytes = 0;
+ for (HalfFloatArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Converts the {@link HalfFloat} data from off-heap to on-heap, by copying the values of a {@link HalfFloatArray}
* instance into a new on-heap {@link HalfFloat}.
@@ -164,7 +191,7 @@ public short[] toShortArray() {
* The {@link HalfFloat} value to store at the specified index.
*/
public void set(int index, HalfFloat value) {
- segment.setAtIndex(JAVA_SHORT, baseIndex + index, value.getHalfFloatValue());
+ segment.setAtIndex(index, value.getHalfFloatValue(), baseIndex);
}
/**
@@ -175,7 +202,7 @@ public void set(int index, HalfFloat value) {
* @return
*/
public HalfFloat get(int index) {
- short halfFloatValue = segment.getAtIndex(JAVA_SHORT, baseIndex + index);
+ short halfFloatValue = segment.getShortAtIndex(index, baseIndex);
return new HalfFloat(halfFloatValue);
}
@@ -200,7 +227,7 @@ public int getElementSize() {
*/
public void init(HalfFloat value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_SHORT, baseIndex + i, value.getHalfFloatValue());
+ segment.setAtIndex(i, value.getHalfFloatValue(), baseIndex);
}
}
@@ -221,7 +248,7 @@ public int getSize() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -231,7 +258,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
+ return segment.getSegment();
}
/**
@@ -255,39 +282,6 @@ public long getNumBytesOfSegment() {
return segmentByteSize - TornadoNativeArray.ARRAY_HEADER;
}
- /**
- * Factory method to initialize a {@link HalfFloatArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code HalfFloatArray} instance with.
- */
- public static void initialize(HalfFloatArray array, HalfFloat value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link HalfFloatArray} instances into a single {@link HalfFloatArray}.
- *
- * @param arrays
- * Variable number of {@link HalfFloatArray} objects to be concatenated.
- * @return A new {@link HalfFloatArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static HalfFloatArray concat(HalfFloatArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(HalfFloatArray::getSize).sum();
- HalfFloatArray concatArray = new HalfFloatArray(newSize);
- long currentPositionBytes = 0;
- for (HalfFloatArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
-
/**
* Extracts a slice of elements from a given {@linkHalfFloatArray}, creating a new {@linkHalfFloatArray} instance.
*
@@ -307,7 +301,7 @@ public HalfFloatArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * HALF_FLOAT_BYTES;
long sliceByteLength = length * HALF_FLOAT_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
HalfFloatArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/IntArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/IntArray.java
index ffbf25c4db..46be448de6 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/IntArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/IntArray.java
@@ -17,9 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.IntBuffer;
import java.util.Arrays;
@@ -37,7 +34,7 @@
public final class IntArray extends TornadoNativeArray {
private static final int INT_BYTES = 4;
private int numberOfElements;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int arrayHeaderSize;
private int baseIndex;
@@ -55,9 +52,7 @@ public IntArray(int numberOfElements) {
arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER;
baseIndex = arrayHeaderSize / INT_BYTES;
segmentByteSize = numberOfElements * INT_BYTES + arrayHeaderSize;
-
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -118,7 +113,7 @@ public static IntArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / INT_BYTES);
IntArray intArray = new IntArray(numElements);
- MemorySegment.copy(segment, 0, intArray.segment, intArray.baseIndex * INT_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, intArray.segment.getSegment(), intArray.baseIndex * INT_BYTES, byteSize);
return intArray;
}
@@ -136,6 +131,39 @@ public static IntArray fromIntBuffer(IntBuffer buffer) {
return intArray;
}
+ /**
+ * Factory method to initialize a {@link IntArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code IntArray} instance with.
+ */
+ public static void initialize(IntArray array, int value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link IntArray} instances into a single {@link IntArray}.
+ *
+ * @param arrays
+ * Variable number of {@link IntArray} objects to be concatenated.
+ * @return A new {@link IntArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static IntArray concat(IntArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(IntArray::getSize).sum();
+ IntArray concatArray = new IntArray(newSize);
+ long currentPositionBytes = 0;
+ for (IntArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Converts the int data from off-heap to on-heap, by copying the values of a {@link IntArray}
* instance into a new on-heap array.
@@ -159,7 +187,7 @@ public int[] toHeapArray() {
* The int value to store at the specified index.
*/
public void set(int index, int value) {
- segment.setAtIndex(JAVA_INT, baseIndex + index, value);
+ segment.setAtIndex(index, value, baseIndex);
}
/**
@@ -170,7 +198,7 @@ public void set(int index, int value) {
* @return
*/
public int get(int index) {
- return segment.getAtIndex(JAVA_INT, baseIndex + index);
+ return segment.getIntAtIndex(index, baseIndex);
}
/**
@@ -194,7 +222,7 @@ public int getElementSize() {
*/
public void init(int value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_INT, baseIndex + i, value);
+ segment.setAtIndex(i, value, baseIndex);
}
}
@@ -236,7 +264,7 @@ public long getNumBytesOfSegment() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -246,40 +274,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
- }
-
- /**
- * Factory method to initialize a {@link IntArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code IntArray} instance with.
- */
- public static void initialize(IntArray array, int value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link IntArray} instances into a single {@link IntArray}.
- *
- * @param arrays
- * Variable number of {@link IntArray} objects to be concatenated.
- * @return A new {@link IntArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static IntArray concat(IntArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(IntArray::getSize).sum();
- IntArray concatArray = new IntArray(newSize);
- long currentPositionBytes = 0;
- for (IntArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
+ return segment.getSegment();
}
/**
@@ -301,7 +296,7 @@ public IntArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * INT_BYTES;
long sliceByteLength = length * INT_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
IntArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/LongArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/LongArray.java
index f945bf667b..931319624d 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/LongArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/LongArray.java
@@ -17,10 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-import static java.lang.foreign.ValueLayout.JAVA_LONG;
-
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.LongBuffer;
import java.util.Arrays;
@@ -37,7 +33,7 @@
@SegmentElementSize(size = 8)
public final class LongArray extends TornadoNativeArray {
private static final int LONG_BYTES = 8;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int numberOfElements;
private int arrayHeaderSize;
@@ -55,10 +51,8 @@ public LongArray(int numberOfElements) {
this.numberOfElements = numberOfElements;
arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER;
baseIndex = arrayHeaderSize / LONG_BYTES;
-
segmentByteSize = numberOfElements * LONG_BYTES + arrayHeaderSize;
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -119,7 +113,7 @@ public static LongArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / LONG_BYTES);
LongArray longArray = new LongArray(numElements);
- MemorySegment.copy(segment, 0, longArray.segment, longArray.baseIndex * LONG_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, longArray.segment.getSegment(), longArray.baseIndex * LONG_BYTES, byteSize);
return longArray;
}
@@ -137,6 +131,39 @@ public static LongArray fromLongBuffer(LongBuffer buffer) {
return longArray;
}
+ /**
+ * Factory method to initialize a {@link LongArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code LongArray} instance with.
+ */
+ public static void initialize(LongArray array, long value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link LongArray} instances into a single {@link LongArray}.
+ *
+ * @param arrays
+ * Variable number of {@link LongArray} objects to be concatenated.
+ * @return A new {@link LongArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static LongArray concat(LongArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(LongArray::getSize).sum();
+ LongArray concatArray = new LongArray(newSize);
+ long currentPositionBytes = 0;
+ for (LongArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Converts the long data from off-heap to on-heap, by copying the values of a {@link LongArray}
* instance into a new on-heap array.
@@ -160,7 +187,7 @@ public long[] toHeapArray() {
* The long value to store at the specified index.
*/
public void set(int index, long value) {
- segment.setAtIndex(JAVA_LONG, baseIndex + index, value);
+ segment.setAtIndex(index, value, baseIndex);
}
/**
@@ -171,7 +198,7 @@ public void set(int index, long value) {
* @return
*/
public long get(int index) {
- return segment.getAtIndex(JAVA_LONG, baseIndex + index);
+ return segment.getLongAtIndex(index, baseIndex);
}
/**
@@ -195,7 +222,7 @@ public int getElementSize() {
*/
public void init(long value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_LONG, baseIndex + i, value);
+ segment.setAtIndex(i, value, baseIndex);
}
}
@@ -216,7 +243,7 @@ public int getSize() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -226,7 +253,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
+ return segment.getSegment();
}
/**
@@ -250,39 +277,6 @@ public long getNumBytesOfSegment() {
return segmentByteSize - TornadoNativeArray.ARRAY_HEADER;
}
- /**
- * Factory method to initialize a {@link LongArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code LongArray} instance with.
- */
- public static void initialize(LongArray array, long value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link LongArray} instances into a single {@link LongArray}.
- *
- * @param arrays
- * Variable number of {@link LongArray} objects to be concatenated.
- * @return A new {@link LongArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static LongArray concat(LongArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(LongArray::getSize).sum();
- LongArray concatArray = new LongArray(newSize);
- long currentPositionBytes = 0;
- for (LongArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
-
/**
* Extracts a slice of elements from a given {@link LongArray}, creating a new {@link LongArray} instance.
*
@@ -302,7 +296,7 @@ public LongArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * LONG_BYTES;
long sliceByteLength = length * LONG_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
LongArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java
index c29df7836f..21f65acc4c 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java
@@ -17,10 +17,6 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-import static java.lang.foreign.ValueLayout.JAVA_SHORT;
-
-import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ShortBuffer;
import java.util.Arrays;
@@ -37,7 +33,7 @@
@SegmentElementSize(size = 2)
public final class ShortArray extends TornadoNativeArray {
private static final int SHORT_BYTES = 2;
- private MemorySegment segment;
+ private TornadoMemorySegment segment;
private int numberOfElements;
private int arrayHeaderSize;
@@ -57,9 +53,7 @@ public ShortArray(int numberOfElements) {
assert arrayHeaderSize >= 4;
baseIndex = arrayHeaderSize / SHORT_BYTES;
segmentByteSize = numberOfElements * SHORT_BYTES + arrayHeaderSize;
-
- segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numberOfElements);
+ segment = new TornadoMemorySegment(segmentByteSize, baseIndex, numberOfElements);
}
/**
@@ -120,7 +114,7 @@ public static ShortArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / SHORT_BYTES);
ShortArray shortArray = new ShortArray(numElements);
- MemorySegment.copy(segment, 0, shortArray.segment, shortArray.baseIndex * SHORT_BYTES, byteSize);
+ MemorySegment.copy(segment, 0, shortArray.segment.getSegment(), shortArray.baseIndex * SHORT_BYTES, byteSize);
return shortArray;
}
@@ -138,6 +132,39 @@ public static ShortArray fromShortBuffer(ShortBuffer buffer) {
return shortArray;
}
+ /**
+ * Factory method to initialize a {@link ShortArray}. This method can be invoked from a Task-Graph.
+ *
+ * @param array
+ * Input Array.
+ * @param value
+ * The float value to initialize the {@code ShortArray} instance with.
+ */
+ public static void initialize(ShortArray array, short value) {
+ for (@Parallel int i = 0; i < array.getSize(); i++) {
+ array.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link ShortArray} instances into a single {@link ShortArray}.
+ *
+ * @param arrays
+ * Variable number of {@link ShortArray} objects to be concatenated.
+ * @return A new {@link ShortArray} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static ShortArray concat(ShortArray... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(ShortArray::getSize).sum();
+ ShortArray concatArray = new ShortArray(newSize);
+ long currentPositionBytes = 0;
+ for (ShortArray array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
/**
* Converts the short data from off-heap to on-heap, by copying the values of a {@link ShortArray}
* instance into a new on-heap array.
@@ -161,7 +188,7 @@ public short[] toHeapArray() {
* The short value to store at the specified index.
*/
public void set(int index, short value) {
- segment.setAtIndex(JAVA_SHORT, baseIndex + index, value);
+ segment.setAtIndex(index, value, baseIndex);
}
/**
@@ -172,7 +199,7 @@ public void set(int index, short value) {
* @return
*/
public short get(int index) {
- return segment.getAtIndex(JAVA_SHORT, baseIndex + index);
+ return segment.getShortAtIndex(index, baseIndex);
}
/**
@@ -196,7 +223,7 @@ public int getElementSize() {
*/
public void init(short value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(JAVA_SHORT, baseIndex + i, value);
+ segment.setAtIndex(baseIndex + i, value, baseIndex);
}
}
@@ -217,7 +244,7 @@ public int getSize() {
*/
@Override
public MemorySegment getSegment() {
- return segment.asSlice(TornadoNativeArray.ARRAY_HEADER);
+ return segment.getSegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
}
/**
@@ -227,7 +254,7 @@ public MemorySegment getSegment() {
*/
@Override
public MemorySegment getSegmentWithHeader() {
- return segment;
+ return segment.getSegment();
}
/**
@@ -251,39 +278,6 @@ public long getNumBytesOfSegment() {
return segmentByteSize - TornadoNativeArray.ARRAY_HEADER;
}
- /**
- * Factory method to initialize a {@link ShortArray}. This method can be invoked from a Task-Graph.
- *
- * @param array
- * Input Array.
- * @param value
- * The float value to initialize the {@code ShortArray} instance with.
- */
- public static void initialize(ShortArray array, short value) {
- for (@Parallel int i = 0; i < array.getSize(); i++) {
- array.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link ShortArray} instances into a single {@link ShortArray}.
- *
- * @param arrays
- * Variable number of {@link ShortArray} objects to be concatenated.
- * @return A new {@link ShortArray} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static ShortArray concat(ShortArray... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(ShortArray::getSize).sum();
- ShortArray concatArray = new ShortArray(newSize);
- long currentPositionBytes = 0;
- for (ShortArray array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
-
/**
* Extracts a slice of elements from a given {@link ShortArray}, creating a new {@link ShortArray} instance.
*
@@ -302,7 +296,7 @@ public ShortArray slice(int offset, int length) {
long sliceOffsetInBytes = TornadoNativeArray.ARRAY_HEADER + offset * SHORT_BYTES;
long sliceByteLength = length * SHORT_BYTES;
- MemorySegment sliceSegment = segment.asSlice(sliceOffsetInBytes, sliceByteLength);
+ MemorySegment sliceSegment = segment.getSegment().asSlice(sliceOffsetInBytes, sliceByteLength);
ShortArray slice = fromSegment(sliceSegment);
return slice;
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
index 46941d9440..6f71b869e0 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
@@ -35,11 +35,60 @@ public MemorySegment getSegment() {
return segment;
}
- public void setSegmentAt(int index, float value, int baseIndex) {
+ public void setAtIndex(int index, float value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_FLOAT, baseIndex + index, value);
}
- public float getSegmentFrom(int index, int baseIndex) {
+ public float getFloatAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_FLOAT, baseIndex + index);
}
+
+ public void setAtIndex(int index, double value, int baseIndex) {
+ segment.setAtIndex(ValueLayout.JAVA_DOUBLE, baseIndex + index, value);
+ }
+
+ public double getDoubleAtIndex(int index, int baseIndex) {
+ return segment.getAtIndex(ValueLayout.JAVA_DOUBLE, baseIndex + index);
+ }
+
+ public void setAtIndex(int index, byte value, int baseIndex) {
+ segment.setAtIndex(ValueLayout.JAVA_BYTE, baseIndex + index, value);
+ }
+
+ public byte getByteAtIndex(int index, int baseIndex) {
+ return segment.getAtIndex(ValueLayout.JAVA_BYTE, baseIndex + index);
+ }
+
+ public void setAtIndex(int index, char value, int baseIndex) {
+ segment.setAtIndex(ValueLayout.JAVA_CHAR, baseIndex + index, value);
+ }
+
+ public char getCharAtIndex(int index, int baseIndex) {
+ return segment.getAtIndex(ValueLayout.JAVA_CHAR, baseIndex + index);
+ }
+
+ public void setAtIndex(int index, int value, int baseIndex) {
+ segment.setAtIndex(ValueLayout.JAVA_INT, baseIndex + index, value);
+ }
+
+ public int getIntAtIndex(int index, int baseIndex) {
+ return segment.getAtIndex(ValueLayout.JAVA_INT, baseIndex + index);
+ }
+
+ public void setAtIndex(int index, long value, int baseIndex) {
+ segment.setAtIndex(ValueLayout.JAVA_LONG, baseIndex + index, value);
+ }
+
+ public long getLongAtIndex(int index, int baseIndex) {
+ return segment.getAtIndex(ValueLayout.JAVA_LONG, baseIndex + index);
+ }
+
+ public void setAtIndex(int index, short value, int baseIndex) {
+ segment.setAtIndex(ValueLayout.JAVA_SHORT, baseIndex + index, value);
+ }
+
+ public short getShortAtIndex(int index, int baseIndex) {
+ return segment.getAtIndex(ValueLayout.JAVA_SHORT, baseIndex + index);
+ }
+
}
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
index faf9d9e42c..6e643d4640 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
@@ -41,8 +41,6 @@
import static uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLIntBinaryIntrinsicNode.Operation.MIN;
import static uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLIntUnaryIntrinsicNode.Operation.POPCOUNT;
-import java.lang.foreign.ValueLayout;
-
import org.graalvm.word.LocationIdentity;
import jdk.graal.compiler.core.common.memory.BarrierType;
@@ -78,7 +76,6 @@
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TornadoVMIntrinsics;
import uk.ac.manchester.tornado.api.exceptions.Debug;
-import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.api.types.arrays.TornadoMemorySegment;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
@@ -339,11 +336,11 @@ private static boolean printfHandler(GraphBuilderContext b, ResolvedJavaMethod t
int argIndex = 0;
for (Node n : newArrayNode.usages()) {
- if (n instanceof StoreIndexedNode) {
- StoreIndexedNode storeNode = (StoreIndexedNode) n;
+ if (n instanceof StoreIndexedNode storeIndexedNode) {
+ StoreIndexedNode storeNode = storeIndexedNode;
ValueNode value = storeNode.value();
- if (value instanceof BoxNode) {
- BoxNode box = (BoxNode) value;
+ if (value instanceof BoxNode boxNodeValue) {
+ BoxNode box = boxNodeValue;
value = box.getValue();
GraphUtil.unlinkFixedNode(box);
box.safeDelete();
@@ -366,8 +363,8 @@ private static boolean printfHandler(GraphBuilderContext b, ResolvedJavaMethod t
// unbuilt part of the graph. We also need to ensure that we
// do not leave any
// gaps inbetween fixed nodes
- if (n instanceof FixedWithNextNode) {
- GraphUtil.unlinkFixedNode((FixedWithNextNode) n);
+ if (n instanceof FixedWithNextNode fixedWithNextNode) {
+ GraphUtil.unlinkFixedNode(fixedWithNextNode);
}
n.clearInputs();
n.safeDelete();
@@ -380,51 +377,30 @@ private static boolean printfHandler(GraphBuilderContext b, ResolvedJavaMethod t
return true;
}
- public static Class getValueLayoutClass(Class k) {
- if (k == int.class) {
- return ValueLayout.OfInt.class;
- } else if (k == double.class) {
- return ValueLayout.OfDouble.class;
- } else if (k == float.class) {
- return ValueLayout.OfFloat.class;
- } else if (k == long.class) {
- return ValueLayout.OfLong.class;
- } else if (k == boolean.class) {
- return ValueLayout.OfBoolean.class;
- } else if (k == byte.class) {
- return ValueLayout.OfByte.class;
- } else if (k == char.class) {
- return ValueLayout.OfChar.class;
- } else if (k == short.class) {
- return ValueLayout.OfShort.class;
- } else {
- throw new TornadoRuntimeException("Class type " + k + " not supported.");
- }
- }
-
private static void registerMemoryAccessPlugins(InvocationPlugins plugins, HotSpotMetaAccessProvider metaAccessProvider) {
Registration r = new Registration(plugins, TornadoMemorySegment.class);
for (JavaKind kind : JavaKind.values()) {
if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- r.register(new InvocationPlugin("setSegmentAt", Receiver.class, int.class, float.class, int.class) {
+ System.out.println("KInd " + kind.getJavaName() + " " + kind.name());
+ r.register(new InvocationPlugin("setAtIndex", Receiver.class, int.class, kind.toJavaClass(), int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
- MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(4)));
+ MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- JavaWriteNode writeNode = new JavaWriteNode(JavaKind.Float, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
+ JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
b.add(writeNode);
return true;
}
});
- r.register(new InvocationPlugin("getSegmentFrom", Receiver.class, int.class, int.class) {
+ r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", Receiver.class, int.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
- MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(4)));
+ MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- JavaReadNode readNode = new JavaReadNode(JavaKind.Float, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
+ JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
b.addPush(JavaKind.Float, readNode);
return true;
}
From dc142cee304a3e8c37273f9b456a050c0ee657a5 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 26 Jun 2024 17:03:29 +0300
Subject: [PATCH 08/54] Update import statements and rearrange methods
---
.../graal/nodes/FixedArrayCopyNode.java | 23 +++++----
.../phases/TornadoFixedArrayCopyPhase.java | 49 +++++++++----------
.../runtime/sketcher/TornadoSketcher.java | 8 +--
3 files changed, 39 insertions(+), 41 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/FixedArrayCopyNode.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/FixedArrayCopyNode.java
index f182fc3b90..704e4be64d 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/FixedArrayCopyNode.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/FixedArrayCopyNode.java
@@ -21,30 +21,29 @@
*/
package uk.ac.manchester.tornado.drivers.opencl.graal.nodes;
+import jdk.graal.compiler.core.common.LIRKind;
+import jdk.graal.compiler.core.common.type.StampFactory;
+import jdk.graal.compiler.graph.NodeClass;
+import jdk.graal.compiler.lir.Variable;
+import jdk.graal.compiler.nodeinfo.NodeInfo;
+import jdk.graal.compiler.nodes.ValuePhiNode;
+import jdk.graal.compiler.nodes.calc.FloatingNode;
+import jdk.graal.compiler.nodes.spi.LIRLowerable;
+import jdk.graal.compiler.nodes.spi.NodeLIRBuilderTool;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaType;
import jdk.vm.ci.meta.Value;
-import org.graalvm.compiler.core.common.LIRKind;
-import org.graalvm.compiler.core.common.type.StampFactory;
-import org.graalvm.compiler.graph.NodeClass;
-import org.graalvm.compiler.lir.Variable;
-import org.graalvm.compiler.nodeinfo.NodeInfo;
-import org.graalvm.compiler.nodes.ValuePhiNode;
-import org.graalvm.compiler.nodes.calc.FloatingNode;
-import org.graalvm.compiler.nodes.spi.LIRLowerable;
-import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
-
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture;
import uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLBinary;
-import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLLIRStmt;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
+import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLLIRStmt;
/**
* This node generates a pointer copy between two arrays in private memory.
*/
@NodeInfo
-public class FixedArrayCopyNode extends FloatingNode implements LIRLowerable {
+public class FixedArrayCopyNode extends FloatingNode implements LIRLowerable {
public static final NodeClass TYPE = NodeClass.create(FixedArrayCopyNode.class);
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFixedArrayCopyPhase.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFixedArrayCopyPhase.java
index 0871d0e51f..dead8af3d7 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFixedArrayCopyPhase.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFixedArrayCopyPhase.java
@@ -21,27 +21,42 @@
*/
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;
-import jdk.vm.ci.meta.ResolvedJavaType;
-import org.graalvm.compiler.graph.Node;
-import org.graalvm.compiler.nodes.GraphState;
-import org.graalvm.compiler.nodes.StructuredGraph;
-import org.graalvm.compiler.nodes.ValuePhiNode;
-import org.graalvm.compiler.nodes.memory.address.OffsetAddressNode;
-import org.graalvm.compiler.phases.Phase;
+import java.util.Optional;
+import jdk.graal.compiler.graph.Node;
+import jdk.graal.compiler.nodes.GraphState;
+import jdk.graal.compiler.nodes.StructuredGraph;
+import jdk.graal.compiler.nodes.ValuePhiNode;
+import jdk.graal.compiler.nodes.memory.address.OffsetAddressNode;
+import jdk.graal.compiler.phases.Phase;
+import jdk.vm.ci.meta.ResolvedJavaType;
import uk.ac.manchester.tornado.api.exceptions.TornadoCompilationException;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.FixedArrayCopyNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.FixedArrayNode;
-import java.util.Optional;
-
/**
* This phase examines if a copy takes place between two arrays in private memory based on
* an if condition and, if so, inserts a {@link FixedArrayCopyNode} to generate an update in the references.
*/
public class TornadoFixedArrayCopyPhase extends Phase {
+ private static boolean isFixedArrayCopied(ValuePhiNode phiNode) {
+ return phiNode.usages().filter(OffsetAddressNode.class).isNotEmpty() && phiNode.values().filter(FixedArrayNode.class).isNotEmpty();
+ }
+
+ private static ValuePhiNode getPrivateArrayIndex(Node node) {
+ // identify the index
+ for (Node input : node.inputs()) {
+ if (input instanceof ValuePhiNode phiNode) {
+ return phiNode;
+ } else {
+ return getPrivateArrayIndex(input);
+ }
+ }
+ return null;
+ }
+
@Override
public Optional notApplicableTo(GraphState graphState) {
return ALWAYS_APPLICABLE;
@@ -67,20 +82,4 @@ protected void run(StructuredGraph graph) {
}
}
- private static boolean isFixedArrayCopied(ValuePhiNode phiNode) {
- return phiNode.usages().filter(OffsetAddressNode.class).isNotEmpty() && phiNode.values().filter(FixedArrayNode.class).isNotEmpty();
- }
-
- private static ValuePhiNode getPrivateArrayIndex(Node node) {
- // identify the index
- for (Node input : node.inputs()) {
- if (input instanceof ValuePhiNode phiNode) {
- return phiNode;
- } else {
- return getPrivateArrayIndex(input);
- }
- }
- return null;
- }
-
}
diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java
index 4a64a5677f..8338ccfd2a 100644
--- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java
+++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java
@@ -25,7 +25,7 @@
*/
package uk.ac.manchester.tornado.runtime.sketcher;
-import static org.graalvm.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
+import static jdk.graal.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.guarantee;
import static uk.ac.manchester.tornado.runtime.TornadoCoreRuntime.getDebugContext;
import static uk.ac.manchester.tornado.runtime.TornadoCoreRuntime.getOptions;
@@ -56,7 +56,6 @@
import jdk.graal.compiler.phases.common.DeadCodeEliminationPhase;
import jdk.graal.compiler.phases.tiers.HighTierContext;
import jdk.graal.compiler.phases.util.Providers;
-
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.common.Access;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
@@ -77,6 +76,7 @@ public class TornadoSketcher {
private static final Map> cache = new ConcurrentHashMap<>();
private static final TimerKey Sketcher = DebugContext.timer("Sketcher");
private static final OptimisticOptimizations optimisticOpts = OptimisticOptimizations.ALL;
+ private static TornadoLogger logger = new TornadoLogger();
private static boolean cacheContainsSketch(ResolvedJavaMethod method, int driverIndex, int deviceIndex) {
List entries = cache.get(method);
@@ -109,7 +109,7 @@ public static Sketch lookup(ResolvedJavaMethod resolvedMethod, int driverIndex,
}
guarantee(sketch != null, "No sketch available for %d:%d %s", driverIndex, deviceIndex, resolvedMethod.getName());
} catch (InterruptedException | ExecutionException e) {
- TornadoLogger.fatal("Failed to retrieve sketch for %d:%d %s ", driverIndex, deviceIndex, resolvedMethod.getName());
+ logger.fatal("Failed to retrieve sketch for %d:%d %s ", driverIndex, deviceIndex, resolvedMethod.getName());
if (TornadoOptions.DEBUG) {
e.printStackTrace();
}
@@ -136,7 +136,7 @@ static void buildSketch(SketchRequest request) {
private static Sketch buildSketch(ResolvedJavaMethod resolvedMethod, Providers providers, PhaseSuite graphBuilderSuite, TornadoSketchTier sketchTier, int backendIndex,
int deviceIndex) {
- TornadoLogger.info("Building sketch of %s", resolvedMethod.getName());
+ logger.info("Building sketch of %s", resolvedMethod.getName());
TornadoCompilerIdentifier id = new TornadoCompilerIdentifier("sketch-" + resolvedMethod.getName(), sketchId.getAndIncrement());
Builder builder = new Builder(getOptions(), getDebugContext(), AllowAssumptions.YES);
builder.method(resolvedMethod);
From fa04b2b7432eed5104aebc11c574672c261ad01d Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 26 Jun 2024 17:25:57 +0300
Subject: [PATCH 09/54] Fix array initialization and remove debug output
---
.../uk/ac/manchester/tornado/api/types/arrays/ByteArray.java | 2 +-
.../uk/ac/manchester/tornado/api/types/arrays/CharArray.java | 2 +-
.../opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java | 1 -
3 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java
index 3049629296..7a32a94c2c 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ByteArray.java
@@ -222,7 +222,7 @@ public int getElementSize() {
*/
public void init(byte value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(baseIndex + i, value, baseIndex);
+ segment.setAtIndex(i, value, baseIndex);
}
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java
index a3d47b4bd4..c1f4f4b2ce 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/CharArray.java
@@ -222,7 +222,7 @@ public char get(int index) {
*/
public void init(char value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(baseIndex + i, value, baseIndex);
+ segment.setAtIndex(i, value, baseIndex);
}
}
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
index 6e643d4640..a4d91d61aa 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
@@ -382,7 +382,6 @@ private static void registerMemoryAccessPlugins(InvocationPlugins plugins, HotSp
for (JavaKind kind : JavaKind.values()) {
if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- System.out.println("KInd " + kind.getJavaName() + " " + kind.name());
r.register(new InvocationPlugin("setAtIndex", Receiver.class, int.class, kind.toJavaClass(), int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
From d31dd6540df56f96a1f9beecb864b38168a405a4 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Tue, 2 Jul 2024 11:46:53 +0300
Subject: [PATCH 10/54] Swap 'setAtIndex' and 'getAtIndex' functionalities in
OCLGraphBuilderPlugins
---
.../plugins/OCLGraphBuilderPlugins.java | 17 +++++++++--------
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
index a4d91d61aa..e32fa6efcb 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
@@ -382,25 +382,26 @@ private static void registerMemoryAccessPlugins(InvocationPlugins plugins, HotSp
for (JavaKind kind : JavaKind.values()) {
if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- r.register(new InvocationPlugin("setAtIndex", Receiver.class, int.class, kind.toJavaClass(), int.class) {
+ r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", Receiver.class, int.class, int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
+ // System.out.println("kind: " + kind.name());
AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
- b.add(writeNode);
+ JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
+ b.addPush(kind, readNode);
return true;
}
});
- r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", Receiver.class, int.class, int.class) {
+ r.register(new InvocationPlugin("setAtIndex", Receiver.class, int.class, kind.toJavaClass(), int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
- b.addPush(JavaKind.Float, readNode);
+ JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
+ b.add(writeNode);
return true;
}
});
From 5d1019d4dd6c4c2e8c068c375dca4c17068f73a0 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 3 Jul 2024 15:43:41 +0300
Subject: [PATCH 11/54] Update OpenCL graph building and vector handling
---
Makefile | 2 +-
.../types/arrays/TornadoMemorySegment.java | 171 +++++++++++++++++-
.../plugins/OCLGraphBuilderPlugins.java | 46 +++--
.../compiler/plugins/OCLVectorPlugins.java | 5 +-
.../phases/TornadoAtomicsParametersPhase.java | 10 +-
5 files changed, 200 insertions(+), 34 deletions(-)
diff --git a/Makefile b/Makefile
index 9e2f795f08..e51bcfd37a 100644
--- a/Makefile
+++ b/Makefile
@@ -42,7 +42,7 @@ example:
tests:
rm -f tornado_unittests.log
tornado --devices
- tornado-test --ea --verbose
+ tornado-test --verbose
tornado-test --ea -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
test-native.sh
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
index 6f71b869e0..d81ca6dd8e 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoMemorySegment.java
@@ -17,78 +17,237 @@
*/
package uk.ac.manchester.tornado.api.types.arrays;
-import static java.lang.foreign.ValueLayout.JAVA_INT;
-
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
+/**
+ * The {@code TornadoMemorySegment} class provides a high-level interface for managing a
+ * {@link MemorySegment} with support for different data types.
+ *
+ * This class allows the allocation of a memory segment with a specific size, and provides
+ * methods to set and get values of various types at specific indices relative to a base index.
+ *
+ */
public class TornadoMemorySegment {
private MemorySegment segment;
-
- public TornadoMemorySegment(long segmentByteSize, int basedIndex, int numElements) {
+ private int baseIndex;
+
+ /**
+ * Constructs a {@code TornadoMemorySegment} with a specified byte size and base index.
+ *
+ * This constructor allocates a new memory segment of the specified byte size and
+ * initializes it with a given number of elements.
+ *
+ *
+ * @param segmentByteSize
+ * the size of the memory segment in bytes
+ * @param baseIndex
+ * the base index used for calculating the actual index in the memory segment
+ * @param numElements
+ * the number of elements to initialize in the segment
+ */
+ public TornadoMemorySegment(long segmentByteSize, int baseIndex, int numElements) {
this.segment = Arena.ofAuto().allocate(segmentByteSize, 1);
- segment.setAtIndex(JAVA_INT, 0, numElements);
+ this.baseIndex = baseIndex;
+ segment.setAtIndex(ValueLayout.JAVA_INT, 0, numElements);
}
+ /**
+ * Returns the underlying {@link MemorySegment}.
+ *
+ * @return the memory segment
+ */
public MemorySegment getSegment() {
return segment;
}
+ /**
+ * Sets a {@code float} value at the specified index.
+ *
+ * @param index
+ * the index where the value will be set
+ * @param value
+ * the {@code float} value to set
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ */
public void setAtIndex(int index, float value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_FLOAT, baseIndex + index, value);
}
+ /**
+ * Returns the {@code float} value at the specified index.
+ *
+ * @param index
+ * the index from which the value will be retrieved
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ * @return the {@code float} value at the specified index
+ */
public float getFloatAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_FLOAT, baseIndex + index);
}
+ /**
+ * Sets a {@code double} value at the specified index.
+ *
+ * @param index
+ * the index where the value will be set
+ * @param value
+ * the {@code double} value to set
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ */
public void setAtIndex(int index, double value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_DOUBLE, baseIndex + index, value);
}
+ /**
+ * Returns the {@code double} value at the specified index.
+ *
+ * @param index
+ * the index from which the value will be retrieved
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ * @return the {@code double} value at the specified index
+ */
public double getDoubleAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_DOUBLE, baseIndex + index);
}
+ /**
+ * Sets a {@code byte} value at the specified index.
+ *
+ * @param index
+ * the index where the value will be set
+ * @param value
+ * the {@code byte} value to set
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ */
public void setAtIndex(int index, byte value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_BYTE, baseIndex + index, value);
}
+ /**
+ * Returns the {@code byte} value at the specified index.
+ *
+ * @param index
+ * the index from which the value will be retrieved
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ * @return the {@code byte} value at the specified index
+ */
public byte getByteAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_BYTE, baseIndex + index);
}
+ /**
+ * Sets a {@code char} value at the specified index.
+ *
+ * @param index
+ * the index where the value will be set
+ * @param value
+ * the {@code char} value to set
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ */
public void setAtIndex(int index, char value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_CHAR, baseIndex + index, value);
}
+ /**
+ * Returns the {@code char} value at the specified index.
+ *
+ * @param index
+ * the index from which the value will be retrieved
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ * @return the {@code char} value at the specified index
+ */
public char getCharAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_CHAR, baseIndex + index);
}
+ /**
+ * Sets an {@code int} value at the specified index.
+ *
+ * @param index
+ * the index where the value will be set
+ * @param value
+ * the {@code int} value to set
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ */
public void setAtIndex(int index, int value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_INT, baseIndex + index, value);
}
+ /**
+ * Returns the {@code int} value at the specified index.
+ *
+ * @param index
+ * the index from which the value will be retrieved
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ * @return the {@code int} value at the specified index
+ */
public int getIntAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_INT, baseIndex + index);
}
+ /**
+ * Sets a {@code long} value at the specified index.
+ *
+ * @param index
+ * the index where the value will be set
+ * @param value
+ * the {@code long} value to set
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ */
public void setAtIndex(int index, long value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_LONG, baseIndex + index, value);
}
+ /**
+ * Returns the {@code long} value at the specified index.
+ *
+ * @param index
+ * the index from which the value will be retrieved
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ * @return the {@code long} value at the specified index
+ */
public long getLongAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_LONG, baseIndex + index);
}
+ /**
+ * Sets a {@code short} value at the specified index.
+ *
+ * @param index
+ * the index where the value will be set
+ * @param value
+ * the {@code short} value to set
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ */
public void setAtIndex(int index, short value, int baseIndex) {
segment.setAtIndex(ValueLayout.JAVA_SHORT, baseIndex + index, value);
}
+ /**
+ * Returns the {@code short} value at the specified index.
+ *
+ * @param index
+ * the index from which the value will be retrieved
+ * @param baseIndex
+ * the base index used for calculating the actual index
+ * @return the {@code short} value at the specified index
+ */
public short getShortAtIndex(int index, int baseIndex) {
return segment.getAtIndex(ValueLayout.JAVA_SHORT, baseIndex + index);
}
-
}
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
index e32fa6efcb..4e3bcee532 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
@@ -125,6 +125,7 @@ private static void registerTornadoVMAtomicsPlugins(Registration r) {
r.register(new InvocationPlugin("atomic_add", int[].class, Integer.TYPE, Integer.TYPE) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode array, ValueNode index, ValueNode inc) {
+ receiver.get(true);
AtomicAddNodeTemplate atomicIncNode = new AtomicAddNodeTemplate(array, index, inc);
b.addPush(JavaKind.Int, b.append(atomicIncNode));
return true;
@@ -146,6 +147,7 @@ private static void registerAtomicCall(Registration r, JavaKind returnedJavaKind
r.register(new InvocationPlugin("incrementAndGet", Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
b.addPush(returnedJavaKind, b.append(new IncAtomicNode(receiver.get())));
return true;
}
@@ -154,6 +156,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("decrementAndGet", Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
b.addPush(returnedJavaKind, b.append(new DecAtomicNode(receiver.get())));
return true;
}
@@ -162,6 +165,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("get", Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
b.addPush(returnedJavaKind, b.append(new GetAtomicNode(receiver.get())));
return true;
}
@@ -182,8 +186,8 @@ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, Va
// args[1] = arguments to the invoke node being substituted
// ========================================================
ValueNode initialValue = args[1];
- if (initialValue instanceof ConstantNode) {
- int value = Integer.parseInt(((ConstantNode) initialValue).getValue().toValueString());
+ if (initialValue instanceof ConstantNode constantNode) {
+ int value = Integer.parseInt(constantNode.getValue().toValueString());
if (value == 0) {
atomic.setInitialValue(initialValue);
} else {
@@ -205,19 +209,20 @@ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, Va
private static TornadoAtomicIntegerNode resolveReceiverAtomic(ValueNode thisObject) {
TornadoAtomicIntegerNode atomicNode = null;
- if (thisObject instanceof PiNode) {
- thisObject = ((PiNode) thisObject).getOriginalNode();
+ if (thisObject instanceof PiNode objectAsPiNode) {
+ thisObject = objectAsPiNode.getOriginalNode();
}
- if (thisObject instanceof TornadoAtomicIntegerNode) {
- atomicNode = (TornadoAtomicIntegerNode) thisObject;
+ if (thisObject instanceof TornadoAtomicIntegerNode returnedAtomicNode) {
+ atomicNode = returnedAtomicNode;
}
return atomicNode;
}
private static void registerLocalBarrier(Registration r) {
- r.register(new InvocationPlugin("localBarrier", Receiver.class) {
+ r.register(new InvocationPlugin("localBarrier") {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
OCLBarrierNode localBarrierNode = new OCLBarrierNode(OCLBarrierNode.OCLMemFenceFlags.LOCAL);
b.add(localBarrierNode);
return true;
@@ -226,9 +231,10 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerGlobalBarrier(Registration r) {
- r.register(new InvocationPlugin("globalBarrier", Receiver.class) {
+ r.register(new InvocationPlugin("globalBarrier") {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
OCLBarrierNode localBarrierNode = new OCLBarrierNode(OCLBarrierNode.OCLMemFenceFlags.GLOBAL);
b.add(localBarrierNode);
return true;
@@ -237,9 +243,10 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerIntLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateIntLocalArray", Receiver.class, int.class) {
+ r.register(new InvocationPlugin("allocateIntLocalArray", int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
b.push(returnedJavaKind, localArrayNode);
@@ -249,9 +256,10 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerLongLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateLongLocalArray", Receiver.class, int.class) {
+ r.register(new InvocationPlugin("allocateLongLocalArray", int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
b.push(returnedJavaKind, localArrayNode);
@@ -261,9 +269,10 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerFloatLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateFloatLocalArray", Receiver.class, int.class) {
+ r.register(new InvocationPlugin("allocateFloatLocalArray", int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
b.push(returnedJavaKind, localArrayNode);
@@ -273,9 +282,10 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerDoubleLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateDoubleLocalArray", Receiver.class, int.class) {
+ r.register(new InvocationPlugin("allocateDoubleLocalArray", int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
b.push(returnedJavaKind, localArrayNode);
@@ -356,13 +366,10 @@ private static boolean printfHandler(GraphBuilderContext b, ResolvedJavaMethod t
b.add(b.append(printfNode));
while (newArrayNode.hasUsages()) {
Node n = newArrayNode.usages().first();
- // need to remove all nodes from the graph that operate on
- // the new array,
- // however, we cannot remove all inputs as they may be used
- // by the currently
- // unbuilt part of the graph. We also need to ensure that we
- // do not leave any
- // gaps inbetween fixed nodes
+ // We need to remove all nodes from the graph that operate on the new array.
+ // However, we cannot remove all inputs, as they may be used by the currently
+ // unbuilt parts of the graph. We must also ensure that no gaps are left
+ // between fixed nodes.
if (n instanceof FixedWithNextNode fixedWithNextNode) {
GraphUtil.unlinkFixedNode(fixedWithNextNode);
}
@@ -385,7 +392,6 @@ private static void registerMemoryAccessPlugins(InvocationPlugins plugins, HotSp
r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", Receiver.class, int.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
- // System.out.println("kind: " + kind.name());
AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLVectorPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLVectorPlugins.java
index f6c2393b28..e60b3e08d5 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLVectorPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLVectorPlugins.java
@@ -42,7 +42,6 @@
import jdk.graal.compiler.nodes.java.StoreIndexedNode;
import jdk.graal.compiler.nodes.memory.address.AddressNode;
import jdk.graal.compiler.nodes.memory.address.OffsetAddressNode;
-
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import jdk.vm.ci.meta.ResolvedJavaType;
@@ -241,9 +240,10 @@ private static void registerVectorCollectionsPlugins(final InvocationPlugins plu
final Class> declaringClass = vectorKind.getJavaClass();
- final Registration r = new Registration(plugins, declaringClass);
+ final Registration r = new Registration(plugins, declaringClass).setAllowOverwrite(true);
r.register(new InvocationPlugin("loadFromArray", Receiver.class, storageType, int.class) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode array, ValueNode index) {
+ receiver.get(true);
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(vectorClass);
OCLKind kind = OCLKind.fromResolvedJavaType(resolvedType);
JavaKind elementKind = kind.getElementKind().asJavaKind();
@@ -256,6 +256,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("storeToArray", Receiver.class, vectorClass, storageType, int.class) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode value, ValueNode array, ValueNode index) {
+ receiver.get(true);
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(vectorClass);
OCLKind kind = OCLKind.fromResolvedJavaType(resolvedType);
JavaKind elementKind = kind.getElementKind().asJavaKind();
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
index 04707c55f4..38d1e98013 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
@@ -10,7 +10,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -31,7 +31,6 @@
import jdk.graal.compiler.nodes.StartNode;
import jdk.graal.compiler.nodes.StructuredGraph;
import jdk.graal.compiler.phases.Phase;
-
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.IncAtomicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.NodeAtomic;
@@ -59,11 +58,12 @@ protected void run(StructuredGraph graph) {
if (!filter.isEmpty()) {
for (NodeAtomic atomic : filter) {
- if (atomic.getAtomicNode() instanceof ParameterNode) {
+ if (atomic.getAtomicNode() instanceof ParameterNode parameterNodeAsAtomic) {
- ParameterNode atomicArgument = (ParameterNode) atomic.getAtomicNode();
+ ParameterNode atomicArgument = parameterNodeAsAtomic;
int indexNode = atomicArgument.index();
+ System.out.println("Index: " + indexNode);
TornadoAtomicIntegerNode newNode = new TornadoAtomicIntegerNode(OCLKind.INTEGER_ATOMIC_JAVA);
graph.addOrUnique(newNode);
newNode.assignIndexFromParameter(indexNode);
@@ -78,7 +78,7 @@ protected void run(StructuredGraph graph) {
newNode.setNext(first);
// Replace usages for this new node
- ParameterNode parameter = (ParameterNode) atomic.getAtomicNode();
+ ParameterNode parameter = parameterNodeAsAtomic;
newNode.replaceAtMatchingUsages(atomic, node -> !node.equals(atomic));
parameter.replaceAtMatchingUsages(newNode, node -> node.equals(atomic));
From d7881fbeeb125b8401eed7a856d91766cea8fd2d Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 3 Jul 2024 16:04:19 +0300
Subject: [PATCH 12/54] Refactor Tensor type classes for code consistency
---
.../tornado/api/types/tensors/TensorByte.java | 53 +++++++--------
.../tornado/api/types/tensors/TensorFP16.java | 65 +++++++++---------
.../tornado/api/types/tensors/TensorFP32.java | 67 +++++++++----------
.../tornado/api/types/tensors/TensorFP64.java | 67 +++++++++----------
.../api/types/tensors/TensorInt16.java | 67 +++++++++----------
.../api/types/tensors/TensorInt32.java | 67 +++++++++----------
.../api/types/tensors/TensorInt64.java | 67 +++++++++----------
7 files changed, 223 insertions(+), 230 deletions(-)
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorByte.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorByte.java
index 7cc5c48bf9..6041ea5950 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorByte.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorByte.java
@@ -17,14 +17,13 @@
*/
package uk.ac.manchester.tornado.api.types.tensors;
-import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
-import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
-import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
-
import java.lang.foreign.MemorySegment;
import java.util.Arrays;
-import static java.lang.foreign.ValueLayout.JAVA_BYTE;
+import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
+import uk.ac.manchester.tornado.api.types.tensors.Shape;
@SegmentElementSize(size = 1)
public final class TensorByte extends Tensor {
@@ -55,14 +54,33 @@ public TensorByte(Shape shape) {
this.tensorStorage = new ByteArray(numberOfElements);
}
+ /**
+ * Concatenates multiple {@link TensorByte} instances into a single {@link TensorByte}.
+ *
+ * @param arrays
+ * Variable number of {@link TensorByte} objects to be concatenated.
+ * @return A new {@link TensorByte} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static TensorByte concat(TensorByte... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(TensorByte::getSize).sum();
+ TensorByte concatArray = new TensorByte(new Shape(newSize));
+ long currentPositionBytes = 0;
+ for (TensorByte array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
public void init(byte value) {
for (int i = 0; i < getSize(); i++) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_BYTE, getBaseIndex() + i, value);
+ tensorStorage.set(i, value);
}
}
public void set(int index, byte value) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_BYTE, getBaseIndex() + index, value);
+ tensorStorage.set(index, value);
}
private long getBaseIndex() {
@@ -77,7 +95,7 @@ private long getBaseIndex() {
* @return
*/
public byte get(int index) {
- return tensorStorage.getSegmentWithHeader().getAtIndex(JAVA_BYTE, getBaseIndex() + index);
+ return tensorStorage.get(index);
}
@Override
@@ -129,23 +147,4 @@ public String getDTypeAsString() {
public DType getDType() {
return dType;
}
-
- /**
- * Concatenates multiple {@link TensorByte} instances into a single {@link TensorByte}.
- *
- * @param arrays
- * Variable number of {@link TensorByte} objects to be concatenated.
- * @return A new {@link TensorByte} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static TensorByte concat(TensorByte... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(TensorByte::getSize).sum();
- TensorByte concatArray = new TensorByte(new Shape(newSize));
- long currentPositionBytes = 0;
- for (TensorByte array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP16.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP16.java
index 6cb90356c4..fbce7588e7 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP16.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP16.java
@@ -17,16 +17,15 @@
*/
package uk.ac.manchester.tornado.api.types.tensors;
+import java.lang.foreign.MemorySegment;
+import java.util.Arrays;
+
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
-
-import java.lang.foreign.MemorySegment;
-import java.util.Arrays;
-
-import static java.lang.foreign.ValueLayout.JAVA_SHORT;
+import uk.ac.manchester.tornado.api.types.tensors.Shape;
@SegmentElementSize(size = 2)
public final class TensorFP16 extends Tensor {
@@ -56,14 +55,39 @@ public TensorFP16(Shape shape) {
this.tensorStorage = new HalfFloatArray(numberOfElements);
}
+ public static void initialize(TensorFP16 tensor, HalfFloat value) {
+ for (@Parallel int i = 0; i < tensor.getSize(); i++) {
+ tensor.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link TensorFP16} instances into a single {@link TensorFP16}.
+ *
+ * @param arrays
+ * Variable number of {@link TensorFP16} objects to be concatenated.
+ * @return A new {@link TensorFP16} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static TensorFP16 concat(TensorFP16... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(TensorFP16::getSize).sum();
+ TensorFP16 concatArray = new TensorFP16(new Shape(newSize));
+ long currentPositionBytes = 0;
+ for (TensorFP16 array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
public void init(HalfFloat value) {
for (int i = 0; i < getSize(); i++) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_SHORT, getBaseIndex() + i, value.getHalfFloatValue());
+ tensorStorage.set(i, value);
}
}
public void set(int index, HalfFloat value) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_SHORT, getBaseIndex() + index, value.getHalfFloatValue());
+ tensorStorage.set(index, value);
}
private long getBaseIndex() {
@@ -78,7 +102,7 @@ private long getBaseIndex() {
* @return
*/
public HalfFloat get(int index) {
- short halfFloatValue = tensorStorage.getSegmentWithHeader().getAtIndex(JAVA_SHORT, getBaseIndex() + index);
+ short halfFloatValue = tensorStorage.get(index).getHalfFloatValue();
return new HalfFloat(halfFloatValue);
}
@@ -131,29 +155,4 @@ public String getDTypeAsString() {
public DType getDType() {
return dType;
}
-
- public static void initialize(TensorFP16 tensor, HalfFloat value) {
- for (@Parallel int i = 0; i < tensor.getSize(); i++) {
- tensor.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link TensorFP16} instances into a single {@link TensorFP16}.
- *
- * @param arrays
- * Variable number of {@link TensorFP16} objects to be concatenated.
- * @return A new {@link TensorFP16} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static TensorFP16 concat(TensorFP16... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(TensorFP16::getSize).sum();
- TensorFP16 concatArray = new TensorFP16(new Shape(newSize));
- long currentPositionBytes = 0;
- for (TensorFP16 array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP32.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP32.java
index ca5616500b..23f4d17fed 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP32.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP32.java
@@ -17,16 +17,15 @@
*/
package uk.ac.manchester.tornado.api.types.tensors;
-import uk.ac.manchester.tornado.api.annotations.Parallel;
-import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
-
import java.lang.foreign.MemorySegment;
import java.nio.FloatBuffer;
import java.util.Arrays;
-import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
+import uk.ac.manchester.tornado.api.annotations.Parallel;
+import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
+import uk.ac.manchester.tornado.api.types.tensors.Shape;
@SegmentElementSize(size = 4)
public final class TensorFP32 extends Tensor {
@@ -57,14 +56,39 @@ public TensorFP32(Shape shape) {
this.tensorStorage = new FloatArray(numberOfElements);
}
+ public static void initialize(TensorFP32 tensor, short value) {
+ for (@Parallel int i = 0; i < tensor.getSize(); i++) {
+ tensor.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link TensorFP32} instances into a single {@link TensorFP32}.
+ *
+ * @param arrays
+ * Variable number of {@link TensorFP32} objects to be concatenated.
+ * @return A new {@link TensorFP32} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static TensorFP32 concat(TensorFP32... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(TensorFP32::getSize).sum();
+ TensorFP32 concatArray = new TensorFP32(new Shape(newSize));
+ long currentPositionBytes = 0;
+ for (TensorFP32 array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
public void init(float value) {
for (int i = 0; i < getSize(); i++) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_FLOAT, getBaseIndex() + i, value);
+ tensorStorage.set(i, value);
}
}
public void set(int index, float value) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_FLOAT, getBaseIndex() + index, value);
+ tensorStorage.set(index, value);
}
private long getBaseIndex() {
@@ -79,7 +103,7 @@ private long getBaseIndex() {
* @return
*/
public float get(int index) {
- return tensorStorage.getSegmentWithHeader().getAtIndex(JAVA_FLOAT, getBaseIndex() + index);
+ return tensorStorage.get(index);
}
@Override
@@ -143,29 +167,4 @@ public float[] toHeapArray() {
public FloatBuffer getFloatBuffer() {
return getSegment().asByteBuffer().asFloatBuffer();
}
-
- public static void initialize(TensorFP32 tensor, short value) {
- for (@Parallel int i = 0; i < tensor.getSize(); i++) {
- tensor.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link TensorFP32} instances into a single {@link TensorFP32}.
- *
- * @param arrays
- * Variable number of {@link TensorFP32} objects to be concatenated.
- * @return A new {@link TensorFP32} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static TensorFP32 concat(TensorFP32... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(TensorFP32::getSize).sum();
- TensorFP32 concatArray = new TensorFP32(new Shape(newSize));
- long currentPositionBytes = 0;
- for (TensorFP32 array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP64.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP64.java
index ef2022ae10..f75a74b95a 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP64.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorFP64.java
@@ -17,16 +17,15 @@
*/
package uk.ac.manchester.tornado.api.types.tensors;
-import uk.ac.manchester.tornado.api.annotations.Parallel;
-import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
-import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
-import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
-
import java.lang.foreign.MemorySegment;
import java.nio.DoubleBuffer;
import java.util.Arrays;
-import static java.lang.foreign.ValueLayout.JAVA_DOUBLE;
+import uk.ac.manchester.tornado.api.annotations.Parallel;
+import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
+import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
+import uk.ac.manchester.tornado.api.types.tensors.Shape;
@SegmentElementSize(size = 8)
public final class TensorFP64 extends Tensor {
@@ -57,14 +56,39 @@ public TensorFP64(Shape shape) {
this.tensorStorage = new DoubleArray(numberOfElements);
}
+ public static void initialize(TensorFP64 tensor, short value) {
+ for (@Parallel int i = 0; i < tensor.getSize(); i++) {
+ tensor.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link TensorFP64} instances into a single {@link TensorFP64}.
+ *
+ * @param arrays
+ * Variable number of {@link TensorFP64} objects to be concatenated.
+ * @return A new {@link TensorFP64} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static TensorFP64 concat(TensorFP64... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(TensorFP64::getSize).sum();
+ TensorFP64 concatArray = new TensorFP64(new Shape(newSize));
+ long currentPositionBytes = 0;
+ for (TensorFP64 array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
public void init(double value) {
for (int i = 0; i < getSize(); i++) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_DOUBLE, getBaseIndex() + i, value);
+ tensorStorage.set(i, value);
}
}
public void set(int index, double value) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_DOUBLE, getBaseIndex() + index, value);
+ tensorStorage.set(index, value);
}
private long getBaseIndex() {
@@ -79,7 +103,7 @@ private long getBaseIndex() {
* @return
*/
public double get(int index) {
- return tensorStorage.getSegmentWithHeader().getAtIndex(JAVA_DOUBLE, getBaseIndex() + index);
+ return tensorStorage.get(index);
}
@Override
@@ -135,29 +159,4 @@ public DType getDType() {
public DoubleBuffer getDoubleBuffer() {
return getSegment().asByteBuffer().asDoubleBuffer();
}
-
- public static void initialize(TensorFP64 tensor, short value) {
- for (@Parallel int i = 0; i < tensor.getSize(); i++) {
- tensor.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link TensorFP64} instances into a single {@link TensorFP64}.
- *
- * @param arrays
- * Variable number of {@link TensorFP64} objects to be concatenated.
- * @return A new {@link TensorFP64} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static TensorFP64 concat(TensorFP64... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(TensorFP64::getSize).sum();
- TensorFP64 concatArray = new TensorFP64(new Shape(newSize));
- long currentPositionBytes = 0;
- for (TensorFP64 array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt16.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt16.java
index 7d135e5c6c..3206cc90ac 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt16.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt16.java
@@ -17,17 +17,16 @@
*/
package uk.ac.manchester.tornado.api.types.tensors;
+import java.lang.foreign.MemorySegment;
+import java.nio.ShortBuffer;
+import java.util.Arrays;
+
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
-
-import java.lang.foreign.MemorySegment;
-import java.nio.ShortBuffer;
-import java.util.Arrays;
-
-import static java.lang.foreign.ValueLayout.JAVA_SHORT;
+import uk.ac.manchester.tornado.api.types.tensors.Shape;
@SegmentElementSize(size = 2)
public final class TensorInt16 extends Tensor {
@@ -58,14 +57,39 @@ public TensorInt16(Shape shape) {
this.tensorStorage = new ShortArray(numberOfElements);
}
+ public static void initialize(TensorInt16 tensor, short value) {
+ for (@Parallel int i = 0; i < tensor.getSize(); i++) {
+ tensor.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link TensorInt16} instances into a single {@link TensorInt16}.
+ *
+ * @param arrays
+ * Variable number of {@link TensorInt16} objects to be concatenated.
+ * @return A new {@link TensorInt16} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static TensorInt16 concat(TensorInt16... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(TensorInt16::getSize).sum();
+ TensorInt16 concatArray = new TensorInt16(new Shape(newSize));
+ long currentPositionBytes = 0;
+ for (TensorInt16 array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
public void init(short value) {
for (int i = 0; i < getSize(); i++) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_SHORT, getBaseIndex() + i, value);
+ tensorStorage.set(i, value);
}
}
public void set(int index, short value) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_SHORT, getBaseIndex() + index, value);
+ tensorStorage.set(index, value);
}
private long getBaseIndex() {
@@ -80,7 +104,7 @@ private long getBaseIndex() {
* @return
*/
public short get(int index) {
- return tensorStorage.getSegmentWithHeader().getAtIndex(JAVA_SHORT, getBaseIndex() + index);
+ return tensorStorage.get(index);
}
@Override
@@ -136,29 +160,4 @@ public DType getDType() {
public ShortBuffer getShortBuffer() {
return getSegment().asByteBuffer().asShortBuffer();
}
-
- public static void initialize(TensorInt16 tensor, short value) {
- for (@Parallel int i = 0; i < tensor.getSize(); i++) {
- tensor.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link TensorInt16} instances into a single {@link TensorInt16}.
- *
- * @param arrays
- * Variable number of {@link TensorInt16} objects to be concatenated.
- * @return A new {@link TensorInt16} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static TensorInt16 concat(TensorInt16... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(TensorInt16::getSize).sum();
- TensorInt16 concatArray = new TensorInt16(new Shape(newSize));
- long currentPositionBytes = 0;
- for (TensorInt16 array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt32.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt32.java
index 9f67dd0f0b..10dae979af 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt32.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt32.java
@@ -17,17 +17,16 @@
*/
package uk.ac.manchester.tornado.api.types.tensors;
+import java.lang.foreign.MemorySegment;
+import java.nio.IntBuffer;
+import java.util.Arrays;
+
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
-
-import java.lang.foreign.MemorySegment;
-import java.nio.IntBuffer;
-import java.util.Arrays;
-
-import static java.lang.foreign.ValueLayout.JAVA_INT;
+import uk.ac.manchester.tornado.api.types.tensors.Shape;
@SegmentElementSize(size = 4)
@@ -58,14 +57,39 @@ public TensorInt32(Shape shape) {
this.tensorStorage = new IntArray(numberOfElements);
}
+ public static void initialize(TensorInt32 tensor, int value) {
+ for (@Parallel int i = 0; i < tensor.getSize(); i++) {
+ tensor.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link TensorInt32} instances into a single {@link TensorInt32}.
+ *
+ * @param arrays
+ * Variable number of {@link TensorInt32} objects to be concatenated.
+ * @return A new {@link TensorInt32} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static TensorInt32 concat(TensorInt32... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(TensorInt32::getSize).sum();
+ TensorInt32 concatArray = new TensorInt32(new Shape(newSize));
+ long currentPositionBytes = 0;
+ for (TensorInt32 array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
public void init(int value) {
for (int i = 0; i < getSize(); i++) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_INT, getBaseIndex() + i, value);
+ tensorStorage.set(i, value);
}
}
public void set(int index, int value) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_INT, getBaseIndex() + index, value);
+ tensorStorage.set(index, value);
}
private long getBaseIndex() {
@@ -80,7 +104,7 @@ private long getBaseIndex() {
* @return
*/
public int get(int index) {
- return tensorStorage.getSegmentWithHeader().getAtIndex(JAVA_INT, getBaseIndex() + index);
+ return tensorStorage.get(index);
}
@Override
@@ -136,29 +160,4 @@ public DType getDType() {
public IntBuffer getIntBuffer() {
return getSegment().asByteBuffer().asIntBuffer();
}
-
- public static void initialize(TensorInt32 tensor, int value) {
- for (@Parallel int i = 0; i < tensor.getSize(); i++) {
- tensor.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link TensorInt32} instances into a single {@link TensorInt32}.
- *
- * @param arrays
- * Variable number of {@link TensorInt32} objects to be concatenated.
- * @return A new {@link TensorInt32} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static TensorInt32 concat(TensorInt32... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(TensorInt32::getSize).sum();
- TensorInt32 concatArray = new TensorInt32(new Shape(newSize));
- long currentPositionBytes = 0;
- for (TensorInt32 array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
}
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt64.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt64.java
index 63c61adf57..0b615ca127 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt64.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/tensors/TensorInt64.java
@@ -17,17 +17,16 @@
*/
package uk.ac.manchester.tornado.api.types.tensors;
+import java.lang.foreign.MemorySegment;
+import java.nio.LongBuffer;
+import java.util.Arrays;
+
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
-
-import java.lang.foreign.MemorySegment;
-import java.nio.LongBuffer;
-import java.util.Arrays;
-
-import static java.lang.foreign.ValueLayout.JAVA_LONG;
+import uk.ac.manchester.tornado.api.types.tensors.Shape;
@SegmentElementSize(size = 8)
public final class TensorInt64 extends Tensor {
@@ -57,14 +56,39 @@ public TensorInt64(Shape shape) {
this.tensorStorage = new LongArray(numberOfElements);
}
+ public static void initialize(TensorInt64 tensor, long value) {
+ for (@Parallel int i = 0; i < tensor.getSize(); i++) {
+ tensor.set(i, value);
+ }
+ }
+
+ /**
+ * Concatenates multiple {@link TensorInt64} instances into a single {@link TensorInt64}.
+ *
+ * @param arrays
+ * Variable number of {@link TensorInt64} objects to be concatenated.
+ * @return A new {@link TensorInt64} instance containing all the elements of the input arrays,
+ * concatenated in the order they were provided.
+ */
+ public static TensorInt64 concat(TensorInt64... arrays) {
+ int newSize = Arrays.stream(arrays).mapToInt(TensorInt64::getSize).sum();
+ TensorInt64 concatArray = new TensorInt64(new Shape(newSize));
+ long currentPositionBytes = 0;
+ for (TensorInt64 array : arrays) {
+ MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
+ currentPositionBytes += array.getNumBytesOfSegment();
+ }
+ return concatArray;
+ }
+
public void init(long value) {
for (int i = 0; i < getSize(); i++) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_LONG, getBaseIndex() + i, value);
+ tensorStorage.set(i, value);
}
}
public void set(int index, long value) {
- tensorStorage.getSegmentWithHeader().setAtIndex(JAVA_LONG, getBaseIndex() + index, value);
+ tensorStorage.set(index, value);
}
private long getBaseIndex() {
@@ -79,7 +103,7 @@ private long getBaseIndex() {
* @return
*/
public long get(int index) {
- return tensorStorage.getSegmentWithHeader().getAtIndex(JAVA_LONG, getBaseIndex() + index);
+ return tensorStorage.get(index);
}
@Override
@@ -135,29 +159,4 @@ public DType getDType() {
public LongBuffer getLongBuffer() {
return getSegment().asByteBuffer().asLongBuffer();
}
-
- public static void initialize(TensorInt64 tensor, long value) {
- for (@Parallel int i = 0; i < tensor.getSize(); i++) {
- tensor.set(i, value);
- }
- }
-
- /**
- * Concatenates multiple {@link TensorInt64} instances into a single {@link TensorInt64}.
- *
- * @param arrays
- * Variable number of {@link TensorInt64} objects to be concatenated.
- * @return A new {@link TensorInt64} instance containing all the elements of the input arrays,
- * concatenated in the order they were provided.
- */
- public static TensorInt64 concat(TensorInt64... arrays) {
- int newSize = Arrays.stream(arrays).mapToInt(TensorInt64::getSize).sum();
- TensorInt64 concatArray = new TensorInt64(new Shape(newSize));
- long currentPositionBytes = 0;
- for (TensorInt64 array : arrays) {
- MemorySegment.copy(array.getSegment(), 0, concatArray.getSegment(), currentPositionBytes, array.getNumBytesOfSegment());
- currentPositionBytes += array.getNumBytesOfSegment();
- }
- return concatArray;
- }
}
From 98f79bbb0be57db953ab2f3571fce25d0223ec90 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 3 Jul 2024 16:10:16 +0300
Subject: [PATCH 13/54] Fix short array index in initialization
---
.../uk/ac/manchester/tornado/api/types/arrays/ShortArray.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java
index 21f65acc4c..aa1cfeb2ad 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/ShortArray.java
@@ -223,7 +223,7 @@ public int getElementSize() {
*/
public void init(short value) {
for (int i = 0; i < getSize(); i++) {
- segment.setAtIndex(baseIndex + i, value, baseIndex);
+ segment.setAtIndex(i, value, baseIndex);
}
}
From 285a6f60afea01d669280fcce4c07552eadb72aa Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Jul 2024 13:42:02 +0300
Subject: [PATCH 14/54] Add initial value for atomic integers in Tornado
---
.../plugins/OCLAtomicIntegerPlugin.java | 7 +++-
.../graal/nodes/TornadoAtomicIntegerNode.java | 38 ++++++++-----------
.../phases/TornadoAtomicsParametersPhase.java | 7 +++-
3 files changed, 25 insertions(+), 27 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLAtomicIntegerPlugin.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLAtomicIntegerPlugin.java
index bd86b5d874..d52c08de65 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLAtomicIntegerPlugin.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLAtomicIntegerPlugin.java
@@ -21,10 +21,12 @@
*/
package uk.ac.manchester.tornado.drivers.opencl.graal.compiler.plugins;
+import jdk.graal.compiler.core.common.type.StampFactory;
+import jdk.graal.compiler.nodes.ConstantNode;
import jdk.graal.compiler.nodes.graphbuilderconf.GraphBuilderContext;
import jdk.graal.compiler.nodes.graphbuilderconf.NodePlugin;
-
import jdk.vm.ci.hotspot.HotSpotResolvedJavaType;
+import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaType;
import uk.ac.manchester.tornado.api.internal.annotations.Vector;
@@ -45,7 +47,8 @@ private boolean createAtomicIntegerInstance(GraphBuilderContext b, ResolvedJavaT
OCLKind kind = resolveOCLKind(type);
if (kind != OCLKind.ILLEGAL) {
if (kind == OCLKind.INTEGER_ATOMIC_JAVA) {
- b.push(JavaKind.Object, b.append(new TornadoAtomicIntegerNode(kind)));
+ ConstantNode initialValue = new ConstantNode(JavaConstant.forInt(0), StampFactory.forConstant(JavaConstant.forInt(0)));
+ b.push(JavaKind.Object, b.append(new TornadoAtomicIntegerNode(kind, initialValue)));
return true;
}
}
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/TornadoAtomicIntegerNode.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/TornadoAtomicIntegerNode.java
index 287423a408..efa075c8d7 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/TornadoAtomicIntegerNode.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/TornadoAtomicIntegerNode.java
@@ -10,7 +10,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -34,7 +34,6 @@
import jdk.graal.compiler.nodes.ValueNode;
import jdk.graal.compiler.nodes.spi.LIRLowerable;
import jdk.graal.compiler.nodes.spi.NodeLIRBuilderTool;
-
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLStampFactory;
@@ -47,35 +46,24 @@
public class TornadoAtomicIntegerNode extends FixedWithNextNode implements LIRLowerable {
public static final NodeClass TYPE = NodeClass.create(TornadoAtomicIntegerNode.class);
-
- private final OCLKind kind;
-
- private boolean ATOMIC_2_0 = false;
-
+ private static final int DEFAULT_VALUE = -1;
// How many atomics integers per graph
public static HashMap> globalAtomics = new HashMap<>();
-
// Mapping between:
// Java Method: -> { ParamIndex -> Position in the Atomic Buffer }
public static HashMap> globalAtomicsParameters = new HashMap<>();
-
- private static final int DEFAULT_VALUE = -1;
-
+ private final OCLKind kind;
@Input
- ValueNode initialValue;
-
+ protected ConstantNode initialValue;
+ private boolean ATOMIC_2_0 = false;
private int indexFromGlobalMemory;
private boolean atomicsByParameter = false;
- public TornadoAtomicIntegerNode(OCLKind kind) {
+ public TornadoAtomicIntegerNode(OCLKind kind, ConstantNode initialValue) {
super(TYPE, OCLStampFactory.getStampFor(kind));
this.kind = kind;
- this.initialValue = ConstantNode.forInt(0);
- }
-
- public void setInitialValue(ValueNode valueNode) {
- initialValue = valueNode;
+ this.initialValue = initialValue;
}
public void setInitialValueAtUsages(ValueNode valueNode) {
@@ -86,6 +74,10 @@ public ValueNode getInitialValue() {
return this.initialValue;
}
+ public void setInitialValue(ConstantNode valueNode) {
+ initialValue = valueNode;
+ }
+
private void generateExpressionForOpenCL2_0(NodeLIRBuilderTool gen) {
LIRGeneratorTool tool = gen.getLIRGeneratorTool();
Variable result = tool.newVariable(tool.getLIRKind(StampFactory.intValue()));
@@ -104,8 +96,8 @@ public int getIndexFromGlobalMemory() {
}
private int getIntFromValueNode() {
- if (initialValue instanceof ConstantNode) {
- ConstantNode c = (ConstantNode) initialValue;
+ if (initialValue instanceof ConstantNode initFromValueNode) {
+ ConstantNode c = initFromValueNode;
return Integer.parseInt(c.getValue().toValueString());
} else {
throw new TornadoRuntimeException("Value node not implemented for Atomics");
@@ -124,8 +116,8 @@ private void updateGlobalAtomicTable(HashMap positions, int paramIndex, int size
* buffer.
*
* @param paramIndex
- * Object parameter index taken from
- * {@link org.graalvm.compiler.nodes.ParameterNode}.
+ * Object parameter index taken from
+ * {@link org.graalvm.compiler.nodes.ParameterNode}.
*/
public synchronized void assignIndexFromParameter(int paramIndex) {
if (!globalAtomics.containsKey(this.graph().method())) {
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
index 38d1e98013..d3841434f5 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
@@ -23,6 +23,7 @@
import java.util.Optional;
+import jdk.graal.compiler.core.common.type.StampFactory;
import jdk.graal.compiler.graph.iterators.NodeIterable;
import jdk.graal.compiler.nodes.ConstantNode;
import jdk.graal.compiler.nodes.FixedNode;
@@ -31,6 +32,7 @@
import jdk.graal.compiler.nodes.StartNode;
import jdk.graal.compiler.nodes.StructuredGraph;
import jdk.graal.compiler.phases.Phase;
+import jdk.vm.ci.meta.JavaConstant;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.IncAtomicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.NodeAtomic;
@@ -63,8 +65,9 @@ protected void run(StructuredGraph graph) {
ParameterNode atomicArgument = parameterNodeAsAtomic;
int indexNode = atomicArgument.index();
- System.out.println("Index: " + indexNode);
- TornadoAtomicIntegerNode newNode = new TornadoAtomicIntegerNode(OCLKind.INTEGER_ATOMIC_JAVA);
+ ConstantNode initialValue = new ConstantNode(JavaConstant.forInt(-1), StampFactory.forConstant(JavaConstant.forInt(-1)));
+ graph.addOrUnique(initialValue);
+ TornadoAtomicIntegerNode newNode = new TornadoAtomicIntegerNode(OCLKind.INTEGER_ATOMIC_JAVA, initialValue);
graph.addOrUnique(newNode);
newNode.assignIndexFromParameter(indexNode);
From e2af57075ad8cf6168d01c5ddfc2b4a65616ce94 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Jul 2024 17:28:32 +0300
Subject: [PATCH 15/54] Refactor and optimize OCLGraphBuilder and
TornadoAtomicsParameters
---
.../plugins/OCLGraphBuilderPlugins.java | 27 ++++++++++++-------
.../phases/TornadoAtomicsParametersPhase.java | 6 ++---
2 files changed, 21 insertions(+), 12 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
index 4e3bcee532..c7878b1a99 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java
@@ -147,8 +147,9 @@ private static void registerAtomicCall(Registration r, JavaKind returnedJavaKind
r.register(new InvocationPlugin("incrementAndGet", Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
- receiver.get(true);
- b.addPush(returnedJavaKind, b.append(new IncAtomicNode(receiver.get())));
+ IncAtomicNode atomicNode = new IncAtomicNode(receiver.get(true));
+ b.getGraph().addOrUnique(atomicNode);
+ b.addPush(returnedJavaKind, atomicNode);
return true;
}
});
@@ -189,7 +190,7 @@ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, Va
if (initialValue instanceof ConstantNode constantNode) {
int value = Integer.parseInt(constantNode.getValue().toValueString());
if (value == 0) {
- atomic.setInitialValue(initialValue);
+ atomic.setInitialValue(constantNode);
} else {
atomic.setInitialValueAtUsages(initialValue);
}
@@ -219,7 +220,7 @@ private static TornadoAtomicIntegerNode resolveReceiverAtomic(ValueNode thisObje
}
private static void registerLocalBarrier(Registration r) {
- r.register(new InvocationPlugin("localBarrier") {
+ r.register(new InvocationPlugin("localBarrier", Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
receiver.get(true);
@@ -231,7 +232,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerGlobalBarrier(Registration r) {
- r.register(new InvocationPlugin("globalBarrier") {
+ r.register(new InvocationPlugin("globalBarrier", Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
receiver.get(true);
@@ -243,12 +244,14 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerIntLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateIntLocalArray", int.class) {
+ r.register(new InvocationPlugin("allocateIntLocalArray", Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -256,12 +259,14 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerLongLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateLongLocalArray", int.class) {
+ r.register(new InvocationPlugin("allocateLongLocalArray", Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -269,12 +274,14 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerFloatLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateFloatLocalArray", int.class) {
+ r.register(new InvocationPlugin("allocateFloatLocalArray", Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -282,12 +289,14 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
private static void registerDoubleLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
- r.register(new InvocationPlugin("allocateDoubleLocalArray", int.class) {
+ r.register(new InvocationPlugin("allocateDoubleLocalArray", Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(OCLArchitecture.localSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
index d3841434f5..fc76d55365 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
@@ -23,7 +23,6 @@
import java.util.Optional;
-import jdk.graal.compiler.core.common.type.StampFactory;
import jdk.graal.compiler.graph.iterators.NodeIterable;
import jdk.graal.compiler.nodes.ConstantNode;
import jdk.graal.compiler.nodes.FixedNode;
@@ -32,7 +31,6 @@
import jdk.graal.compiler.nodes.StartNode;
import jdk.graal.compiler.nodes.StructuredGraph;
import jdk.graal.compiler.phases.Phase;
-import jdk.vm.ci.meta.JavaConstant;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.IncAtomicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.NodeAtomic;
@@ -65,7 +63,9 @@ protected void run(StructuredGraph graph) {
ParameterNode atomicArgument = parameterNodeAsAtomic;
int indexNode = atomicArgument.index();
- ConstantNode initialValue = new ConstantNode(JavaConstant.forInt(-1), StampFactory.forConstant(JavaConstant.forInt(-1)));
+ // ConstantNode initialValue = new ConstantNode(JavaConstant.forInt(-1), StampFactory.forConstant(JavaConstant.forInt(-1)));
+ final ConstantNode initialValue = graph.addOrUnique(ConstantNode.forInt(0));
+
graph.addOrUnique(initialValue);
TornadoAtomicIntegerNode newNode = new TornadoAtomicIntegerNode(OCLKind.INTEGER_ATOMIC_JAVA, initialValue);
graph.addOrUnique(newNode);
From ea984436cba9c613d969feb4a0b9cb2e1e953c74 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Jul 2024 17:46:25 +0300
Subject: [PATCH 16/54] Refactor atomic parameter initialization in
TornadoAtomicsParametersPhase
---
.../opencl/graal/phases/TornadoAtomicsParametersPhase.java | 2 --
1 file changed, 2 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
index fc76d55365..a923669274 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoAtomicsParametersPhase.java
@@ -63,9 +63,7 @@ protected void run(StructuredGraph graph) {
ParameterNode atomicArgument = parameterNodeAsAtomic;
int indexNode = atomicArgument.index();
- // ConstantNode initialValue = new ConstantNode(JavaConstant.forInt(-1), StampFactory.forConstant(JavaConstant.forInt(-1)));
final ConstantNode initialValue = graph.addOrUnique(ConstantNode.forInt(0));
-
graph.addOrUnique(initialValue);
TornadoAtomicIntegerNode newNode = new TornadoAtomicIntegerNode(OCLKind.INTEGER_ATOMIC_JAVA, initialValue);
graph.addOrUnique(newNode);
From 7f92f13d6c4ab11c159a96d5e1f41b6009a95cfd Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Jul 2024 18:48:10 +0300
Subject: [PATCH 17/54] Refactor and clean up imports in PTX compiler code
---
pom.xml | 23 ++++--------
.../ptx/src/main/java/module-info.java | 37 ++++++-------------
.../compiler/PTXCompilationResultBuilder.java | 16 ++++----
.../ptx/graal/compiler/PTXHighTier.java | 11 +++---
.../ptx/graal/compiler/PTXLIRGenerator.java | 35 +++++++++++++-----
.../ptx/graal/compiler/PTXLowTier.java | 5 +--
.../ptx/graal/compiler/PTXMidTier.java | 7 ++--
.../drivers/ptx/graal/lir/PTXBinary.java | 9 ++---
.../drivers/ptx/graal/lir/PTXTernary.java | 12 +++---
.../drivers/ptx/graal/lir/PTXUnary.java | 7 ++--
10 files changed, 73 insertions(+), 89 deletions(-)
diff --git a/pom.xml b/pom.xml
index 23b0de5d41..1266b58ba8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -700,8 +700,7 @@
--add-exportsjdk.graal.compiler/jdk.graal.compiler.lir=tornado.drivers.ptx--add-exports
- jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.ptx
-
+ jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.ptx--add-exportsjdk.internal.vm.ci/jdk.vm.ci.meta=tornado.drivers.ptx--add-exports
@@ -1542,28 +1541,20 @@
--add-exportsjdk.internal.vm.ci/jdk.vm.ci.amd64=tornado.drivers.ptx--add-exports
- jdk.graal.compiler/jdk.graal.compiler.hotspot.meta=tornado.drivers.ptx
-
+ jdk.graal.compiler/jdk.graal.compiler.hotspot.meta=tornado.drivers.ptx--add-exports
-
- jdk.graal.compiler/jdk.graal.compiler.replacements.classfile=tornado.drivers.ptx
+ jdk.graal.compiler/jdk.graal.compiler.replacements.classfile=tornado.drivers.ptx
--add-exports
-
- jdk.graal.compiler/jdk.graal.compiler.core.common.alloc=tornado.drivers.ptx
-
+ jdk.graal.compiler/jdk.graal.compiler.core.common.alloc=tornado.drivers.ptx--add-exports
-
- jdk.graal.compiler/jdk.graal.compiler.core.common.util=tornado.drivers.ptx
-
+ jdk.graal.compiler/jdk.graal.compiler.core.common.util=tornado.drivers.ptx--add-exports
- jdk.graal.compiler/jdk.graal.compiler.core.common.cfg=tornado.drivers.ptx
-
+ jdk.graal.compiler/jdk.graal.compiler.core.common.cfg=tornado.drivers.ptx--add-exportsjdk.graal.compiler/jdk.graal.compiler.lir=tornado.drivers.ptx--add-exports
- jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.ptx
-
+ jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.ptx--add-exportsjdk.internal.vm.ci/jdk.vm.ci.meta=tornado.drivers.ptx--add-exports
diff --git a/tornado-drivers/ptx/src/main/java/module-info.java b/tornado-drivers/ptx/src/main/java/module-info.java
index 0ae1ee54f1..27ba377720 100644
--- a/tornado-drivers/ptx/src/main/java/module-info.java
+++ b/tornado-drivers/ptx/src/main/java/module-info.java
@@ -1,30 +1,15 @@
import uk.ac.manchester.tornado.runtime.TornadoBackendProvider;
-module tornado.drivers.ptx {
- requires transitive jdk.internal.vm.ci;
- requires transitive jdk.internal.vm.compiler;
- requires transitive org.graalvm.collections;
- requires transitive org.graalvm.word;
- requires transitive tornado.api;
- requires transitive tornado.runtime;
- requires tornado.drivers.common;
+module tornado.drivers.ptx{
+// requires transitive jdk.internal.vm.ci;
+// requires transitive jdk.internal.vm.compiler;
+// requires transitive org.graalvm.collections;
+// requires transitive org.graalvm.word;
+// requires transitive tornado.api;
+// requires transitive tornado.runtime;
+// requires tornado.drivers.common;
+requires java.base;requires transitive jdk.internal.vm.ci;requires transitive jdk.graal.compiler;requires transitive org.graalvm.collections;requires transitive org.graalvm.word;requires transitive tornado.api;requires transitive tornado.runtime;requires tornado.drivers.common;
- exports uk.ac.manchester.tornado.drivers.ptx;
- exports uk.ac.manchester.tornado.drivers.ptx.enums;
- exports uk.ac.manchester.tornado.drivers.ptx.graal;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.asm;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.backend;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.compiler;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.lir;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.meta;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.calc;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector;
- exports uk.ac.manchester.tornado.drivers.ptx.graal.phases;
- exports uk.ac.manchester.tornado.drivers.ptx.mm;
- exports uk.ac.manchester.tornado.drivers.ptx.runtime;
- exports uk.ac.manchester.tornado.drivers.ptx.power;
+exports uk.ac.manchester.tornado.drivers.ptx;exports uk.ac.manchester.tornado.drivers.ptx.enums;exports uk.ac.manchester.tornado.drivers.ptx.graal;exports uk.ac.manchester.tornado.drivers.ptx.graal.asm;exports uk.ac.manchester.tornado.drivers.ptx.graal.backend;exports uk.ac.manchester.tornado.drivers.ptx.graal.compiler;exports uk.ac.manchester.tornado.drivers.ptx.graal.lir;exports uk.ac.manchester.tornado.drivers.ptx.graal.meta;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.calc;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector;exports uk.ac.manchester.tornado.drivers.ptx.graal.phases;exports uk.ac.manchester.tornado.drivers.ptx.mm;exports uk.ac.manchester.tornado.drivers.ptx.runtime;exports uk.ac.manchester.tornado.drivers.ptx.power;
- provides TornadoBackendProvider with
- uk.ac.manchester.tornado.drivers.ptx.PTXTornadoDriverProvider;
-}
+provides TornadoBackendProvider with uk.ac.manchester.tornado.drivers.ptx.PTXTornadoDriverProvider;}
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXCompilationResultBuilder.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXCompilationResultBuilder.java
index a946bb7447..af2d81970f 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXCompilationResultBuilder.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXCompilationResultBuilder.java
@@ -10,7 +10,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -33,9 +33,9 @@
import org.graalvm.collections.EconomicMap;
import org.graalvm.collections.Equivalence;
+
import jdk.graal.compiler.asm.Assembler;
import jdk.graal.compiler.code.CompilationResult;
-import jdk.graal.compiler.core.common.spi.CodeGenProviders;
import jdk.graal.compiler.debug.DebugContext;
import jdk.graal.compiler.lir.LIR;
import jdk.graal.compiler.lir.LIRInstruction;
@@ -54,8 +54,8 @@
import jdk.graal.compiler.nodes.MergeNode;
import jdk.graal.compiler.nodes.cfg.ControlFlowGraph;
import jdk.graal.compiler.nodes.cfg.HIRBlock;
+import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.options.OptionValues;
-
import jdk.vm.ci.code.Register;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError;
@@ -76,7 +76,7 @@ public class PTXCompilationResultBuilder extends CompilationResultBuilder {
private PTXLIRGenerationResult lirGenRes;
private TaskMetaData meta;
- public PTXCompilationResultBuilder(CodeGenProviders providers, FrameMap frameMap, Assembler asm, DataBuilder dataBuilder, FrameContext frameContext, OptionValues options, DebugContext debug,
+ public PTXCompilationResultBuilder(CoreProviders providers, FrameMap frameMap, Assembler asm, DataBuilder dataBuilder, FrameContext frameContext, OptionValues options, DebugContext debug,
CompilationResult compilationResult, LIR lir) {
super(providers, frameMap, asm, dataBuilder, frameContext, options, debug, compilationResult, Register.None, EconomicMap.create(Equivalence.DEFAULT), NO_VERIFIERS, lir);
@@ -294,13 +294,13 @@ private boolean isTrueBranchWithEndNodeOrNotControlSplit(HIRBlock blockTrueBranc
* control Split (due to nested control-flow).
*
* @param HIRBlock
- * {@link HIRBlock}
+ * {@link HIRBlock}
* @param visitor
- * {@link PTXBlockVisitor}
+ * {@link PTXBlockVisitor}
* @param visited
- * {@link HashSet}
+ * {@link HashSet}
* @param pending
- * {@link HashMap}
+ * {@link HashMap}
*/
private void rescheduleTrueBranchConditionsIfNeeded(HIRBlock basicBlock, PTXBlockVisitor visitor, HashSet visited, HashMap pending) {
if (!basicBlock.isLoopHeader() && basicBlock.getDominator() != null && basicBlock.getDominator().getEndNode() instanceof IfNode) {
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java
index 7bd5ebf23f..7b2721f374 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java
@@ -24,11 +24,11 @@
package uk.ac.manchester.tornado.drivers.ptx.graal.compiler;
-import static org.graalvm.compiler.core.common.GraalOptions.ConditionalElimination;
-import static org.graalvm.compiler.core.common.GraalOptions.OptConvertDeoptsToGuards;
-import static org.graalvm.compiler.core.common.GraalOptions.PartialEscapeAnalysis;
-import static org.graalvm.compiler.core.phases.HighTier.Options.Inline;
-import static org.graalvm.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
+import static jdk.graal.compiler.core.common.GraalOptions.ConditionalElimination;
+import static jdk.graal.compiler.core.common.GraalOptions.OptConvertDeoptsToGuards;
+import static jdk.graal.compiler.core.common.GraalOptions.PartialEscapeAnalysis;
+import static jdk.graal.compiler.core.phases.HighTier.Options.Inline;
+import static jdk.graal.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
import jdk.graal.compiler.loop.phases.ConvertDeoptimizeToGuardPhase;
import jdk.graal.compiler.loop.phases.LoopFullUnrollPhase;
@@ -43,7 +43,6 @@
import jdk.graal.compiler.phases.common.inlining.InliningPhase;
import jdk.graal.compiler.phases.schedule.SchedulePhase;
import jdk.graal.compiler.virtual.phases.ea.PartialEscapePhase;
-
import jdk.vm.ci.meta.MetaAccessProvider;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.analysis.TornadoShapeAnalysis;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.ExceptionSuppression;
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLIRGenerator.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLIRGenerator.java
index 39a1352248..d3c7e361bd 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLIRGenerator.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLIRGenerator.java
@@ -10,7 +10,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -45,7 +45,6 @@
import jdk.graal.compiler.lir.gen.LIRGenerationResult;
import jdk.graal.compiler.lir.gen.LIRGenerator;
import jdk.graal.compiler.phases.util.Providers;
-
import jdk.vm.ci.code.Register;
import jdk.vm.ci.code.StackSlot;
import jdk.vm.ci.meta.AllocatableValue;
@@ -98,14 +97,11 @@ public static PTXBinaryOp getConditionalOp(Condition condition) {
case AT:
case GT:
return PTXBinaryOp.SETP_GT;
-
case EQ:
return PTXBinaryOp.SETP_EQ;
-
case BE:
case LE:
return PTXBinaryOp.SETP_LE;
-
case BT:
case LT:
return PTXBinaryOp.SETP_LT;
@@ -114,7 +110,6 @@ public static PTXBinaryOp getConditionalOp(Condition condition) {
default:
shouldNotReachHere();
break;
-
}
return null;
}
@@ -298,6 +293,16 @@ public void emitIntegerTestBranch(Value left, Value right, LabelRef trueDestinat
unimplemented();
}
+ @Override
+ public void emitOpMaskTestBranch(Value left, boolean negateLeft, Value right, LabelRef trueDestination, LabelRef falseDestination, double trueSuccessorProbability) {
+
+ }
+
+ @Override
+ public void emitOpMaskOrTestBranch(Value left, Value right, boolean allZeros, LabelRef trueDestination, LabelRef falseDestination, double trueSuccessorProbability) {
+
+ }
+
@Override
public Variable emitConditionalMove(PlatformKind cmpKind, Value left, Value right, Condition cond, boolean unorderedIsTrue, Value trueValue, Value falseValue) {
Logger.traceBuildLIR(Logger.BACKEND.PTX, "emitConditionalMove");
@@ -317,13 +322,13 @@ public Variable emitConditionalMove(PlatformKind cmpKind, Value left, Value righ
* based on a bitwise and operation between two values.
*
* @param leftVal
- * the left value of a condition
+ * the left value of a condition
* @param right
- * the right value of a condition
+ * the right value of a condition
* @param trueValue
- * the true value to move in the result
+ * the true value to move in the result
* @param falseValue
- * the false value to move in the result
+ * the false value to move in the result
* @return Variable: reference to the variable that contains the result
*/
@Override
@@ -346,6 +351,16 @@ public Variable emitIntegerTestMove(Value leftVal, Value right, Value trueValue,
return result;
}
+ @Override
+ public Variable emitOpMaskTestMove(Value leftVal, boolean negateLeft, Value right, Value trueValue, Value falseValue) {
+ return null;
+ }
+
+ @Override
+ public Variable emitOpMaskOrTestMove(Value leftVal, Value right, boolean allZeros, Value trueValue, Value falseValue) {
+ return null;
+ }
+
@Override
public Variable emitReverseBytes(Value operand) {
return null;
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLowTier.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLowTier.java
index 4d8893a0e4..7622f079a1 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLowTier.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXLowTier.java
@@ -24,8 +24,8 @@
package uk.ac.manchester.tornado.drivers.ptx.graal.compiler;
-import static org.graalvm.compiler.core.common.GraalOptions.ConditionalElimination;
-import static org.graalvm.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Required;
+import static jdk.graal.compiler.core.common.GraalOptions.ConditionalElimination;
+import static jdk.graal.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Required;
import jdk.graal.compiler.options.OptionValues;
import jdk.graal.compiler.phases.common.AddressLoweringByNodePhase;
@@ -35,7 +35,6 @@
import jdk.graal.compiler.phases.common.IterativeConditionalEliminationPhase;
import jdk.graal.compiler.phases.common.LowTierLoweringPhase;
import jdk.graal.compiler.phases.schedule.SchedulePhase;
-
import uk.ac.manchester.tornado.api.TornadoDeviceContext;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.analysis.TornadoFeatureExtraction;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoLoopCanonicalization;
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXMidTier.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXMidTier.java
index 68c30a5243..a26a0e51db 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXMidTier.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXMidTier.java
@@ -24,9 +24,9 @@
package uk.ac.manchester.tornado.drivers.ptx.graal.compiler;
-import static org.graalvm.compiler.core.common.GraalOptions.ConditionalElimination;
-import static org.graalvm.compiler.core.common.GraalOptions.OptFloatingReads;
-import static org.graalvm.compiler.core.common.GraalOptions.ReassociateExpressions;
+import static jdk.graal.compiler.core.common.GraalOptions.ConditionalElimination;
+import static jdk.graal.compiler.core.common.GraalOptions.OptFloatingReads;
+import static jdk.graal.compiler.core.common.GraalOptions.ReassociateExpressions;
import jdk.graal.compiler.options.OptionValues;
import jdk.graal.compiler.phases.common.CanonicalizerPhase;
@@ -35,7 +35,6 @@
import jdk.graal.compiler.phases.common.IterativeConditionalEliminationPhase;
import jdk.graal.compiler.phases.common.MidTierLoweringPhase;
import jdk.graal.compiler.phases.common.ReassociationPhase;
-
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.BoundCheckEliminationPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.ExceptionCheckingElimination;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPanamaSegmentsHeaderPhase;
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXBinary.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXBinary.java
index 2c8e433ed9..5b8d777245 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXBinary.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXBinary.java
@@ -12,7 +12,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -24,14 +24,13 @@
package uk.ac.manchester.tornado.drivers.ptx.graal.lir;
-import static org.graalvm.compiler.lir.LIRInstruction.Use;
import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler.PTXBinaryIntrinsic;
import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.TAB;
import jdk.graal.compiler.core.common.LIRKind;
+import jdk.graal.compiler.lir.LIRInstruction;
import jdk.graal.compiler.lir.Opcode;
import jdk.graal.compiler.lir.Variable;
-
import jdk.vm.ci.meta.Value;
import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler;
import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler.PTXBinaryOp;
@@ -47,9 +46,9 @@ protected static class BinaryConsumer extends PTXLIROp {
@Opcode
protected final PTXBinaryOp opcode;
- @Use
+ @LIRInstruction.Use
protected Value x;
- @Use
+ @LIRInstruction.Use
protected Value y;
protected BinaryConsumer(PTXBinaryOp opcode, LIRKind lirKind, Value x, Value y) {
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXTernary.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXTernary.java
index d4e102b959..de651c9f7d 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXTernary.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXTernary.java
@@ -12,7 +12,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -24,12 +24,10 @@
package uk.ac.manchester.tornado.drivers.ptx.graal.lir;
-import static org.graalvm.compiler.lir.LIRInstruction.Use;
-
import jdk.graal.compiler.core.common.LIRKind;
+import jdk.graal.compiler.lir.LIRInstruction;
import jdk.graal.compiler.lir.Opcode;
import jdk.graal.compiler.lir.Variable;
-
import jdk.vm.ci.meta.Value;
import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler;
import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler.PTXTernaryOp;
@@ -44,11 +42,11 @@ protected static class TernaryConsumer extends PTXLIROp {
@Opcode
protected final PTXTernaryOp opcode;
- @Use
+ @LIRInstruction.Use
protected Value x;
- @Use
+ @LIRInstruction.Use
protected Value y;
- @Use
+ @LIRInstruction.Use
protected Value z;
protected TernaryConsumer(PTXTernaryOp opcode, LIRKind lirKind, Value x, Value y, Value z) {
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXUnary.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXUnary.java
index 9bbfc5ac73..2fa5bca27b 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXUnary.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXUnary.java
@@ -12,7 +12,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -24,7 +24,6 @@
package uk.ac.manchester.tornado.drivers.ptx.graal.lir;
-import static org.graalvm.compiler.lir.LIRInstruction.Use;
import static uk.ac.manchester.tornado.drivers.ptx.graal.PTXArchitecture.paramSpace;
import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler.PTXUnaryOp;
import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.COMMA;
@@ -35,9 +34,9 @@
import jdk.graal.compiler.core.common.LIRKind;
import jdk.graal.compiler.lir.ConstantValue;
+import jdk.graal.compiler.lir.LIRInstruction;
import jdk.graal.compiler.lir.Opcode;
import jdk.graal.compiler.lir.Variable;
-
import jdk.vm.ci.meta.Value;
import uk.ac.manchester.tornado.drivers.ptx.graal.PTXArchitecture.PTXMemoryBase;
import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler;
@@ -50,7 +49,7 @@ public class PTXUnary {
* Abstract operation which consumes one input
*/
protected static class UnaryConsumer extends PTXLIROp {
- @Use
+ @LIRInstruction.Use
protected Value value;
@Opcode
From d572a4dfdc9471ed147f71bea210ba6053aa5251 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Jul 2024 18:52:59 +0300
Subject: [PATCH 18/54] Refactor code for improved readability and remove
redundancy
---
pom.xml | 3 +-
.../TornadoFloatingReadReplacement.java | 15 ++++----
.../ptx/src/main/java/module-info.java | 10 +-----
.../ptx/graal/nodes/FixedArrayCopyNode.java | 29 ++++++++-------
.../TornadoFloatingReadReplacement.java | 35 +++++++++----------
5 files changed, 40 insertions(+), 52 deletions(-)
diff --git a/pom.xml b/pom.xml
index 1266b58ba8..6f2a16f787 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1730,8 +1730,7 @@
--add-exportsjdk.graal.compiler/jdk.graal.compiler.lir=tornado.drivers.spirv--add-exports
- jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.spirv
-
+ jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.spirv--add-exportsjdk.internal.vm.ci/jdk.vm.ci.meta=tornado.drivers.spirv--add-exports
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFloatingReadReplacement.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFloatingReadReplacement.java
index e957db92b3..0209122f01 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFloatingReadReplacement.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoFloatingReadReplacement.java
@@ -21,6 +21,11 @@
*/
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;
+import static jdk.graal.compiler.graph.Graph.NodeEvent.NODE_ADDED;
+import static jdk.graal.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
+import static org.graalvm.word.LocationIdentity.any;
+import static org.graalvm.word.LocationIdentity.init;
+
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
@@ -29,6 +34,8 @@
import org.graalvm.collections.EconomicSet;
import org.graalvm.collections.Equivalence;
import org.graalvm.collections.UnmodifiableMapCursor;
+import org.graalvm.word.LocationIdentity;
+
import jdk.graal.compiler.core.common.cfg.Loop;
import jdk.graal.compiler.debug.DebugCloseable;
import jdk.graal.compiler.debug.GraalError;
@@ -73,17 +80,10 @@
import jdk.graal.compiler.phases.common.PostRunCanonicalizationPhase;
import jdk.graal.compiler.phases.common.util.EconomicSetNodeEventListener;
import jdk.graal.compiler.phases.graph.ReentrantNodeIterator;
-import org.graalvm.word.LocationIdentity;
-
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.FixedArrayNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLBarrierNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.vector.VectorLoadElementNode;
-import static jdk.graal.compiler.graph.Graph.NodeEvent.NODE_ADDED;
-import static jdk.graal.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
-import static org.graalvm.word.LocationIdentity.any;
-import static org.graalvm.word.LocationIdentity.init;
-
/**
* This phase modifies the functionality of the originally FloatingRead Phase
* from Graal. The original phase reschedules and replaces Read-Nodes with
@@ -246,7 +246,6 @@ protected void run(StructuredGraph graph, CoreProviders context) {
EconomicMap> modifiedInLoops = null;
if (graph.hasLoops()) {
modifiedInLoops = EconomicMap.create(Equivalence.IDENTITY);
- // ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
ControlFlowGraph cfg = ControlFlowGraph.newBuilder(graph).connectBlocks(true).computeLoops(true).computeFrequency(true).build();
for (Loop> l : cfg.getLoops()) {
diff --git a/tornado-drivers/ptx/src/main/java/module-info.java b/tornado-drivers/ptx/src/main/java/module-info.java
index 27ba377720..19fb6b1b37 100644
--- a/tornado-drivers/ptx/src/main/java/module-info.java
+++ b/tornado-drivers/ptx/src/main/java/module-info.java
@@ -1,14 +1,6 @@
import uk.ac.manchester.tornado.runtime.TornadoBackendProvider;
-module tornado.drivers.ptx{
-// requires transitive jdk.internal.vm.ci;
-// requires transitive jdk.internal.vm.compiler;
-// requires transitive org.graalvm.collections;
-// requires transitive org.graalvm.word;
-// requires transitive tornado.api;
-// requires transitive tornado.runtime;
-// requires tornado.drivers.common;
-requires java.base;requires transitive jdk.internal.vm.ci;requires transitive jdk.graal.compiler;requires transitive org.graalvm.collections;requires transitive org.graalvm.word;requires transitive tornado.api;requires transitive tornado.runtime;requires tornado.drivers.common;
+module tornado.drivers.ptx{requires java.base;requires transitive jdk.internal.vm.ci;requires transitive jdk.graal.compiler;requires transitive org.graalvm.collections;requires transitive org.graalvm.word;requires transitive tornado.api;requires transitive tornado.runtime;requires tornado.drivers.common;
exports uk.ac.manchester.tornado.drivers.ptx;exports uk.ac.manchester.tornado.drivers.ptx.enums;exports uk.ac.manchester.tornado.drivers.ptx.graal;exports uk.ac.manchester.tornado.drivers.ptx.graal.asm;exports uk.ac.manchester.tornado.drivers.ptx.graal.backend;exports uk.ac.manchester.tornado.drivers.ptx.graal.compiler;exports uk.ac.manchester.tornado.drivers.ptx.graal.lir;exports uk.ac.manchester.tornado.drivers.ptx.graal.meta;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.calc;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector;exports uk.ac.manchester.tornado.drivers.ptx.graal.phases;exports uk.ac.manchester.tornado.drivers.ptx.mm;exports uk.ac.manchester.tornado.drivers.ptx.runtime;exports uk.ac.manchester.tornado.drivers.ptx.power;
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/FixedArrayCopyNode.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/FixedArrayCopyNode.java
index 54780d0e1e..8e7c91e030 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/FixedArrayCopyNode.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/FixedArrayCopyNode.java
@@ -10,7 +10,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -21,27 +21,26 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal.nodes;
+import jdk.graal.compiler.core.common.LIRKind;
+import jdk.graal.compiler.core.common.type.StampFactory;
+import jdk.graal.compiler.core.common.type.TypeReference;
+import jdk.graal.compiler.graph.NodeClass;
+import jdk.graal.compiler.lir.gen.LIRGeneratorTool;
+import jdk.graal.compiler.nodeinfo.NodeInfo;
+import jdk.graal.compiler.nodes.ValueNode;
+import jdk.graal.compiler.nodes.ValuePhiNode;
+import jdk.graal.compiler.nodes.calc.FloatingNode;
+import jdk.graal.compiler.nodes.spi.LIRLowerable;
+import jdk.graal.compiler.nodes.spi.NodeLIRBuilderTool;
import jdk.vm.ci.meta.RawConstant;
import jdk.vm.ci.meta.ResolvedJavaType;
import jdk.vm.ci.meta.Value;
-import org.graalvm.compiler.core.common.LIRKind;
-import org.graalvm.compiler.core.common.type.StampFactory;
-import org.graalvm.compiler.core.common.type.TypeReference;
-import org.graalvm.compiler.graph.NodeClass;
-import org.graalvm.compiler.lir.gen.LIRGeneratorTool;
-import org.graalvm.compiler.nodeinfo.NodeInfo;
-import org.graalvm.compiler.nodes.ValueNode;
-import org.graalvm.compiler.nodes.ValuePhiNode;
-import org.graalvm.compiler.nodes.calc.FloatingNode;
-import org.graalvm.compiler.nodes.spi.LIRLowerable;
-import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
-
import uk.ac.manchester.tornado.drivers.ptx.graal.PTXArchitecture;
import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler;
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXBinary;
-import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXUnary;
-import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXLIRStmt;
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind;
+import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXLIRStmt;
+import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXUnary;
/**
* This node generates a pointer copy between two arrays in private memory.
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFloatingReadReplacement.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFloatingReadReplacement.java
index 2327980ae6..319f1b3fa7 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFloatingReadReplacement.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFloatingReadReplacement.java
@@ -10,7 +10,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -21,8 +21,6 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal.phases;
-import static org.graalvm.compiler.graph.Graph.NodeEvent.NODE_ADDED;
-import static org.graalvm.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
import static org.graalvm.word.LocationIdentity.any;
import static org.graalvm.word.LocationIdentity.init;
@@ -34,6 +32,8 @@
import org.graalvm.collections.EconomicSet;
import org.graalvm.collections.Equivalence;
import org.graalvm.collections.UnmodifiableMapCursor;
+import org.graalvm.word.LocationIdentity;
+
import jdk.graal.compiler.core.common.cfg.Loop;
import jdk.graal.compiler.debug.DebugCloseable;
import jdk.graal.compiler.debug.GraalError;
@@ -78,8 +78,6 @@
import jdk.graal.compiler.phases.common.PostRunCanonicalizationPhase;
import jdk.graal.compiler.phases.common.util.EconomicSetNodeEventListener;
import jdk.graal.compiler.phases.graph.ReentrantNodeIterator;
-import org.graalvm.word.LocationIdentity;
-
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.FixedArrayNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXBarrierNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector.VectorLoadElementNode;
@@ -107,7 +105,7 @@ public TornadoFloatingReadReplacement(CanonicalizerPhase canonicalizer) {
/**
* @param createMemoryMapNodes
- * a {@link MemoryMapNode} will be created for each return if this
+ * a {@link MemoryMapNode} will be created for each return if this
* @param canonicalizer
*/
public TornadoFloatingReadReplacement(boolean createMemoryMapNodes, CanonicalizerPhase canonicalizer) {
@@ -247,17 +245,17 @@ protected void run(StructuredGraph graph, CoreProviders context) {
EconomicMap> modifiedInLoops = null;
if (graph.hasLoops()) {
modifiedInLoops = EconomicMap.create(Equivalence.IDENTITY);
- ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
+ ControlFlowGraph cfg = ControlFlowGraph.newBuilder(graph).connectBlocks(true).computeLoops(true).computeFrequency(true).build();
for (Loop> l : cfg.getLoops()) {
HIRLoop loop = (HIRLoop) l;
processLoop(loop, modifiedInLoops);
}
}
- EconomicSetNodeEventListener listener = new EconomicSetNodeEventListener(EnumSet.of(NODE_ADDED, ZERO_USAGES));
+ EconomicSetNodeEventListener listener = new EconomicSetNodeEventListener(EnumSet.of(Graph.NodeEvent.NODE_ADDED, Graph.NodeEvent.ZERO_USAGES));
try (Graph.NodeEventScope nes = graph.trackNodeEvents(listener)) {
- ReentrantNodeIterator.apply(new FloatingReadPhase.FloatingReadClosure(modifiedInLoops, true, createMemoryMapNodes, initMemory), graph.start(),
- new FloatingReadPhase.MemoryMapImpl(graph.start()));
+ ReentrantNodeIterator.apply(new FloatingReadPhase.FloatingReadClosure(modifiedInLoops, true, createMemoryMapNodes, initMemory), graph.start(), new FloatingReadPhase.MemoryMapImpl(graph
+ .start()));
}
for (Node n : removeExternallyUsedNodes(listener.getNodes())) {
@@ -362,9 +360,9 @@ private static void processAccess(MemoryAccess access, TornadoFloatingReadReplac
/**
* @param accessNode
- * is a {@link FixedNode} that will be replaced by a
- * {@link FloatingNode}. This method checks if the node that is going
- * to be replaced has an {@link PTXBarrierNode} as next.
+ * is a {@link FixedNode} that will be replaced by a
+ * {@link FloatingNode}. This method checks if the node that is going
+ * to be replaced has an {@link PTXBarrierNode} as next.
*/
private static boolean isNextNodeBarrierNode(FloatableAccessNode accessNode) {
return (accessNode.next() instanceof PTXBarrierNode);
@@ -372,9 +370,9 @@ private static boolean isNextNodeBarrierNode(FloatableAccessNode accessNode) {
/**
* @param nextNode
- * is a {@link FixedNode} that will be replaced by a
- * {@link FloatingNode}. This method removes the redundant
- * {@link PTXBarrierNode}.
+ * is a {@link FixedNode} that will be replaced by a
+ * {@link FloatingNode}. This method removes the redundant
+ * {@link PTXBarrierNode}.
*/
private static void replaceRedundantBarrierNode(Node nextNode) {
nextNode.replaceAtUsages(nextNode.successors().first());
@@ -416,8 +414,9 @@ protected TornadoFloatingReadReplacement.MemoryMapImpl processNode(FixedNode nod
final LoopExitNode loopExitNode = (LoopExitNode) node;
final EconomicSet modifiedInLoop = modifiedInLoops.get(loopExitNode.loopBegin());
final boolean anyModified = modifiedInLoop.contains(LocationIdentity.any());
- state.getMap().replaceAll(
- (locationIdentity, memoryNode) -> (anyModified || modifiedInLoop.contains(locationIdentity)) ? ProxyNode.forMemory(memoryNode, loopExitNode, locationIdentity) : memoryNode);
+ state.getMap().replaceAll((locationIdentity, memoryNode) -> (anyModified || modifiedInLoop.contains(locationIdentity))
+ ? ProxyNode.forMemory(memoryNode, loopExitNode, locationIdentity)
+ : memoryNode);
}
if (node instanceof MemoryAnchorNode) {
From 9c2665f0b87e061bc80022f566275865de5aefb8 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Jul 2024 19:04:14 +0300
Subject: [PATCH 19/54] Refactor code and update exports in ptx drivers
---
.../src/etc/exportLists/ptx-exports | 120 +++++++++---------
.../ptx/graal/PTXLoweringProvider.java | 6 +-
.../plugins/PTXGraphBuilderPlugins.java | 68 +++++-----
3 files changed, 99 insertions(+), 95 deletions(-)
diff --git a/tornado-assembly/src/etc/exportLists/ptx-exports b/tornado-assembly/src/etc/exportLists/ptx-exports
index 5d787aac8c..27fccce481 100644
--- a/tornado-assembly/src/etc/exportLists/ptx-exports
+++ b/tornado-assembly/src/etc/exportLists/ptx-exports
@@ -24,68 +24,68 @@
--add-opens java.base/java.lang=tornado.drivers.ptx
--add-exports jdk.internal.vm.ci/jdk.vm.ci.common=tornado.drivers.ptx
--add-exports jdk.internal.vm.ci/jdk.vm.ci.amd64=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.hotspot.meta=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.replacements.classfile=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.alloc=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.util=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.cfg=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.framemap=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.hotspot.meta=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.replacements.classfile=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.alloc=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.util=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.cfg=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.ptx
--add-exports jdk.internal.vm.ci/jdk.vm.ci.meta=tornado.drivers.ptx
--add-exports jdk.internal.vm.ci/jdk.vm.ci.code=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.graph=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.graph.spi=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.gen=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodeinfo=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.calc=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.spi=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.code=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.debug=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.hotspot=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.java=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.asm=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.phases=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.graphbuilderconf=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.options=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.tiers=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.util=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.printer=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.graph=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.graph.spi=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.gen=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodeinfo=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.calc=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.spi=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.code=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.debug=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.hotspot=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.java=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.asm=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.phases=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.graphbuilderconf=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.options=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.tiers=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.util=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.printer=tornado.drivers.ptx
--add-exports jdk.internal.vm.ci/jdk.vm.ci.runtime=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.graph.iterators=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.java=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.bytecode=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.spi=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.api.replacements=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.replacements=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common.inlining=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.phases=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.type=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.extended=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.loop=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.loop.phases=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.debug=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.memory=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.util=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.graph.iterators=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.java=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.bytecode=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.spi=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.api.replacements=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.replacements=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common.inlining=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.phases=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.type=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.extended=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.loop=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.loop.phases=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.debug=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.memory=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.util=tornado.drivers.ptx
--add-opens jdk.internal.vm.ci/jdk.vm.ci.hotspot=tornado.drivers.ptx
--add-exports jdk.internal.vm.ci/jdk.vm.ci.hotspot=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.asm=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.cfg=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.schedule=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.virtual.phases.ea=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.ssa=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.calc=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.gen=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.match=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.memory.address=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.type=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.graph=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common.util=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common.util=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.graph=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.word=tornado.drivers.ptx
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.memory=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.asm=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.cfg=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.schedule=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.virtual.phases.ea=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.ssa=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.calc=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.gen=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.match=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.memory.address=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.type=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.graph=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common.util=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common.util=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.graph=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.word=tornado.drivers.ptx
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.memory=tornado.drivers.ptx
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXLoweringProvider.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXLoweringProvider.java
index f4fb9b9c67..d6b7091e26 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXLoweringProvider.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXLoweringProvider.java
@@ -21,11 +21,13 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal;
-import static org.graalvm.compiler.nodes.NamedLocationIdentity.ARRAY_LENGTH_LOCATION;
+import static jdk.graal.compiler.nodes.NamedLocationIdentity.ARRAY_LENGTH_LOCATION;
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.shouldNotReachHere;
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.unimplemented;
import static uk.ac.manchester.tornado.drivers.providers.TornadoMemoryOrder.GPU_MEMORY_MODE;
+import org.graalvm.word.LocationIdentity;
+
import jdk.graal.compiler.core.common.memory.BarrierType;
import jdk.graal.compiler.core.common.memory.MemoryExtendKind;
import jdk.graal.compiler.core.common.spi.ForeignCallsProvider;
@@ -75,8 +77,6 @@
import jdk.graal.compiler.phases.util.Providers;
import jdk.graal.compiler.replacements.DefaultJavaLoweringProvider;
import jdk.graal.compiler.replacements.SnippetCounter;
-import org.graalvm.word.LocationIdentity;
-
import jdk.vm.ci.code.TargetDescription;
import jdk.vm.ci.hotspot.HotSpotCallingConventionType;
import jdk.vm.ci.meta.ConstantReflectionProvider;
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
index f004e41c94..9eac2421be 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
@@ -39,8 +39,7 @@
import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode.Operation.MIN;
import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntUnaryIntrinsicNode.Operation.POPCOUNT;
-import java.lang.foreign.MemorySegment;
-import java.lang.foreign.ValueLayout;
+import org.graalvm.word.LocationIdentity;
import jdk.graal.compiler.core.common.memory.BarrierType;
import jdk.graal.compiler.core.common.memory.MemoryOrderMode;
@@ -64,14 +63,12 @@
import jdk.graal.compiler.nodes.memory.address.OffsetAddressNode;
import jdk.graal.compiler.nodes.util.GraphUtil;
import jdk.graal.compiler.replacements.InlineDuringParsingPlugin;
-import org.graalvm.word.LocationIdentity;
-
import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.exceptions.Debug;
-import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoMemorySegment;
import uk.ac.manchester.tornado.drivers.ptx.graal.PTXArchitecture;
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.LocalArrayNode;
@@ -399,36 +396,14 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
- public static Class getValueLayoutClass(Class k) {
- if (k == int.class) {
- return ValueLayout.OfInt.class;
- } else if (k == double.class) {
- return ValueLayout.OfDouble.class;
- } else if (k == float.class) {
- return ValueLayout.OfFloat.class;
- } else if (k == long.class) {
- return ValueLayout.OfLong.class;
- } else if (k == boolean.class) {
- return ValueLayout.OfBoolean.class;
- } else if (k == byte.class) {
- return ValueLayout.OfByte.class;
- } else if (k == char.class) {
- return ValueLayout.OfChar.class;
- } else if (k == short.class) {
- return ValueLayout.OfShort.class;
- } else {
- throw new TornadoRuntimeException("Class type " + k + " not supported.");
- }
- }
-
private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
- Registration r = new Registration(plugins, MemorySegment.class);
+ Registration r = new Registration(plugins, TornadoMemorySegment.class);
for (JavaKind kind : JavaKind.values()) {
if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- r.register(new InvocationPlugin("getAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class) {
+ r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", InvocationPlugin.Receiver.class, int.class, int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index) {
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
@@ -436,9 +411,9 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
return true;
}
});
- r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class, kind.toJavaClass()) {
+ r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, int.class, kind.toJavaClass(), int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
@@ -450,6 +425,35 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
}
+ // private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
+ // Registration r = new Registration(plugins, MemorySegment.class);
+ //
+ // for (JavaKind kind : JavaKind.values()) {
+ // if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
+ // r.register(new InvocationPlugin("getAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class) {
+ // @Override
+ // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index) {
+ // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
+ // JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
+ // b.addPush(kind, readNode);
+ // return true;
+ // }
+ // });
+ // r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class, kind.toJavaClass()) {
+ // @Override
+ // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
+ // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
+ // JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
+ // b.add(writeNode);
+ // return true;
+ // }
+ // });
+ // }
+ // }
+ // }
+
public static void registerNewInstancePlugins(Plugins plugins) {
plugins.appendNodePlugin(new PTXVectorNodePlugin());
}
From f7596edea5637a794be27945db90f2544b519f33 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 4 Jul 2024 19:04:29 +0300
Subject: [PATCH 20/54] Refactor TornadoFixedArrayCopyPhase class
---
.../phases/TornadoFixedArrayCopyPhase.java | 87 +++++++++----------
1 file changed, 43 insertions(+), 44 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFixedArrayCopyPhase.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFixedArrayCopyPhase.java
index fa47e0e736..63ccf523c0 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFixedArrayCopyPhase.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoFixedArrayCopyPhase.java
@@ -21,24 +21,23 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal.phases;
-import jdk.vm.ci.meta.ResolvedJavaType;
-import org.graalvm.compiler.graph.Node;
-import org.graalvm.compiler.nodes.GraphState;
-import org.graalvm.compiler.nodes.StructuredGraph;
-import org.graalvm.compiler.nodes.ValuePhiNode;
-import org.graalvm.compiler.nodes.memory.ReadNode;
-import org.graalvm.compiler.nodes.memory.address.OffsetAddressNode;
-import org.graalvm.compiler.phases.Phase;
+import java.util.ArrayList;
+import java.util.Optional;
+import jdk.graal.compiler.graph.Node;
+import jdk.graal.compiler.nodes.GraphState;
+import jdk.graal.compiler.nodes.StructuredGraph;
+import jdk.graal.compiler.nodes.ValuePhiNode;
+import jdk.graal.compiler.nodes.memory.ReadNode;
+import jdk.graal.compiler.nodes.memory.address.OffsetAddressNode;
+import jdk.graal.compiler.phases.Phase;
+import jdk.vm.ci.meta.ResolvedJavaType;
import uk.ac.manchester.tornado.api.exceptions.TornadoCompilationException;
import uk.ac.manchester.tornado.drivers.ptx.graal.PTXArchitecture;
+import uk.ac.manchester.tornado.drivers.ptx.graal.PTXStampFactory;
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.FixedArrayCopyNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.FixedArrayNode;
-import uk.ac.manchester.tornado.drivers.ptx.graal.PTXStampFactory;
-
-import java.util.ArrayList;
-import java.util.Optional;
/**
* This phase examines if a copy takes place between two arrays in private memory based on
@@ -46,6 +45,38 @@
*/
public class TornadoFixedArrayCopyPhase extends Phase {
+ private static boolean isFixedArrayCopied(ValuePhiNode phiNode) {
+ return phiNode.usages().filter(OffsetAddressNode.class).isNotEmpty() && phiNode.values().filter(FixedArrayNode.class).isNotEmpty();
+ }
+
+ private static void deleteFixed(Node n) {
+ Node pred = n.predecessor();
+ Node suc = n.successors().first();
+
+ n.replaceFirstSuccessor(suc, null);
+ n.replaceAtPredecessor(suc);
+ pred.replaceFirstSuccessor(n, suc);
+
+ for (Node us : n.usages()) {
+ n.removeUsage(us);
+ }
+ n.clearInputs();
+
+ n.safeDelete();
+ }
+
+ private static ValuePhiNode getPrivateArrayIndex(Node node) {
+ // identify the index
+ for (Node input : node.inputs()) {
+ if (input instanceof ValuePhiNode phiNode) {
+ return phiNode;
+ } else {
+ return getPrivateArrayIndex(input);
+ }
+ }
+ return null;
+ }
+
@Override
public Optional notApplicableTo(GraphState graphState) {
return ALWAYS_APPLICABLE;
@@ -84,36 +115,4 @@ protected void run(StructuredGraph graph) {
}
}
- private static boolean isFixedArrayCopied(ValuePhiNode phiNode) {
- return phiNode.usages().filter(OffsetAddressNode.class).isNotEmpty() && phiNode.values().filter(FixedArrayNode.class).isNotEmpty();
- }
-
- private static void deleteFixed(Node n) {
- Node pred = n.predecessor();
- Node suc = n.successors().first();
-
- n.replaceFirstSuccessor(suc, null);
- n.replaceAtPredecessor(suc);
- pred.replaceFirstSuccessor(n, suc);
-
- for (Node us : n.usages()) {
- n.removeUsage(us);
- }
- n.clearInputs();
-
- n.safeDelete();
- }
-
- private static ValuePhiNode getPrivateArrayIndex(Node node) {
- // identify the index
- for (Node input : node.inputs()) {
- if (input instanceof ValuePhiNode phiNode) {
- return phiNode;
- } else {
- return getPrivateArrayIndex(input);
- }
- }
- return null;
- }
-
}
From 8ef58b23834d9a791e290da08aac3e26da20444e Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 5 Jul 2024 14:40:44 +0300
Subject: [PATCH 21/54] Reorder code and fix HalfFloatPlaceholder receiver in
OCLHalfFloatPlugins
---
.../compiler/plugins/OCLHalfFloatPlugins.java | 7 +-
.../phases/TornadoHalfFloatReplacement.java | 279 +++++++++---------
...TornadoHalfFloatFixedGuardElimination.java | 5 +-
3 files changed, 154 insertions(+), 137 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
index 40f82dc2af..4f743255e9 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
@@ -27,16 +27,15 @@
import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugin;
import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugins;
import jdk.graal.compiler.nodes.graphbuilderconf.NodePlugin;
-
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode;
-import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode;
-import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder;
+import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance;
+import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode;
public class OCLHalfFloatPlugins {
@@ -99,7 +98,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("getHalfFloatValue", InvocationPlugin.Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
- b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get())));
+ b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get(true))));
return true;
}
});
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java
index 116b129956..ce334d165d 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java
@@ -24,9 +24,6 @@
import java.util.ArrayList;
import java.util.Optional;
-import jdk.vm.ci.meta.Constant;
-import jdk.vm.ci.meta.JavaKind;
-import jdk.vm.ci.meta.RawConstant;
import jdk.graal.compiler.core.common.type.StampFactory;
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.nodes.ConstantNode;
@@ -44,7 +41,9 @@
import jdk.graal.compiler.nodes.java.NewInstanceNode;
import jdk.graal.compiler.nodes.memory.address.AddressNode;
import jdk.graal.compiler.phases.BasePhase;
-
+import jdk.vm.ci.meta.Constant;
+import jdk.vm.ci.meta.JavaKind;
+import jdk.vm.ci.meta.RawConstant;
import uk.ac.manchester.tornado.api.internal.annotations.HalfType;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.ReadHalfFloatNode;
@@ -66,134 +65,6 @@
public class TornadoHalfFloatReplacement extends BasePhase {
- @Override
- public Optional notApplicableTo(GraphState graphState) {
- return ALWAYS_APPLICABLE;
- }
-
- protected void run(StructuredGraph graph, TornadoHighTierContext context) {
-
- for (ValueAnchorNode valueAnchorNode : graph.getNodes().filter(ValueAnchorNode.class)) {
- ArrayList deletePi = new ArrayList();
- for (Node valueAnchorNodeUsage : valueAnchorNode.usages()) {
- if (valueAnchorNodeUsage instanceof PiNode) {
- PiNode piNode = (PiNode) valueAnchorNodeUsage;
- piNode.replaceAtUsages(piNode.object());
- deletePi.add(piNode);
- }
- }
- for (PiNode p : deletePi) {
- p.safeDelete();
- }
- deleteFixed(valueAnchorNode);
- }
-
- // replace reads with halfFloat reads
- for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
- if (javaRead.successors().first() instanceof NewInstanceNode) {
- NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
- if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
- if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
- deleteFixed(newHalfFloatInstance);
- }
- AddressNode readingAddress = javaRead.getAddress();
- ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress);
- graph.addWithoutUnique(readHalfFloatNode);
- replaceFixed(javaRead, readHalfFloatNode);
- newInstanceNode.replaceAtUsages(readHalfFloatNode);
- deleteFixed(newInstanceNode);
- }
- }
- }
-
- for (NewInstanceNode newInstanceNode : graph.getNodes().filter(NewInstanceNode.class)) {
- if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
- if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
- ValueNode valueInput = newHalfFloatInstance.getValue();
- newInstanceNode.replaceAtUsages(valueInput);
- deleteFixed(newInstanceNode);
- deleteFixed(newHalfFloatInstance);
- }
- }
- }
-
- // replace writes with halfFloat writes
- for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) {
- if (isWriteHalfFloat(javaWrite)) {
- // This casting is safe to do as it is already checked by the isWriteHalfFloat function
- HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value();
- ValueNode writingValue;
- if (javaWrite.predecessor() instanceof NewHalfFloatInstance) {
- // if a new HalfFloat instance is written
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor();
- writingValue = newHalfFloatInstance.getValue();
- if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) {
- NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor();
- if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) {
- deleteFixed(newInstanceNode);
- deleteFixed(newHalfFloatInstance);
- }
- }
- } else {
- // if the result of an operation or a stored value is written
- writingValue = placeholder.getInput();
- }
- placeholder.replaceAtUsages(writingValue);
- placeholder.safeDelete();
- AddressNode writingAddress = javaWrite.getAddress();
- WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue);
- graph.addWithoutUnique(writeHalfFloatNode);
- replaceFixed(javaWrite, writeHalfFloatNode);
- deleteFixed(javaWrite);
- }
- }
-
- // replace the half float operator nodes with the corresponding regular operators
- replaceAddHalfFloatNodes(graph);
- replaceSubHalfFloatNodes(graph);
- replaceMultHalfFloatNodes(graph);
- replaceDivHalfFloatNodes(graph);
-
- // add after the loadindexedvector nodes the marker node to fix the offset of its read
-
- for (LoadIndexedVectorNode loadIndexedVectorNode : graph.getNodes().filter(LoadIndexedVectorNode.class)) {
- if (loadIndexedVectorNode.getOCLKind().isHalf()) {
- VectorHalfRead vectorHalfRead;
- if (loadIndexedVectorNode.index() instanceof ConstantNode) {
- ConstantNode offset = (ConstantNode) loadIndexedVectorNode.index();
- int offsetValue = Integer.valueOf(offset.getValue().toValueString());
- vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead(offsetValue));
- } else {
- vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead());
- }
- graph.addAfterFixed(loadIndexedVectorNode, vectorHalfRead);
- }
- }
-
- for (VectorValueNode vectorValueNode : graph.getNodes().filter(VectorValueNode.class)) {
- if (vectorValueNode.getOCLKind().isHalf()) {
- for (Node vectorElement : vectorValueNode.inputs()) {
- if (vectorElement instanceof VectorLoadElementNode) {
- VectorLoadElementNode vectorLoad = (VectorLoadElementNode) vectorElement;
- VectorLoadElementNode vectorLoadShort = new VectorLoadElementNode(OCLKind.HALF, vectorLoad.getVector(), vectorLoad.getLaneId());
- graph.addWithoutUnique(vectorLoadShort);
- vectorLoad.replaceAtUsages(vectorLoadShort);
- vectorLoad.safeDelete();
- } else if (vectorElement instanceof ConstantNode constantNode && constantNode.getValue().toValueString().contains("null")) {
- Constant zeroValue = new RawConstant(0);
- ConstantNode zero = new ConstantNode(zeroValue, StampFactory.forKind(JavaKind.Short));
- graph.addWithoutUnique(zero);
- constantNode.replaceAtUsages(zero);
- constantNode.safeDelete();
- }
- }
- }
- }
-
- }
-
private static ValueNode replaceAdd(AddHalfFloatNode addHalfFloatNode, StructuredGraph graph) {
ValueNode addNode;
ValueNode addX = getHalfOperand(addHalfFloatNode.getX(), graph);
@@ -267,6 +138,21 @@ private static ValueNode replaceMult(MultHalfFloatNode multHalfFloatNode, Struct
graph.addWithoutUnique(multNode);
}
+ if (multHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ PiNode piNode = multHalfFloatNode.usages().filter(PiNode.class).first();
+ if (piNode.inputs().filter(ValueAnchorNode.class).isNotEmpty()) {
+ ValueAnchorNode anchorNode = piNode.inputs().filter(ValueAnchorNode.class).first();
+ deleteFixed(anchorNode);
+ piNode.replaceAtUsages(multNode);
+ piNode.safeDelete();
+ } else {
+ piNode.replaceAtUsages(multNode);
+ piNode.safeDelete();
+ }
+ } else {
+ multHalfFloatNode.replaceAtUsages(multNode);
+ }
+
multHalfFloatNode.replaceAtUsages(multNode);
multHalfFloatNode.safeDelete();
return multNode;
@@ -361,4 +247,133 @@ private static void deleteFixed(Node node) {
}
}
+ @Override
+ public Optional notApplicableTo(GraphState graphState) {
+ return ALWAYS_APPLICABLE;
+ }
+
+ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
+
+ for (ValueAnchorNode valueAnchorNode : graph.getNodes().filter(ValueAnchorNode.class)) {
+ ArrayList deletePi = new ArrayList();
+ for (Node valueAnchorNodeUsage : valueAnchorNode.usages()) {
+ if (valueAnchorNodeUsage instanceof PiNode) {
+ PiNode piNode = (PiNode) valueAnchorNodeUsage;
+ piNode.replaceAtUsages(piNode.object());
+ deletePi.add(piNode);
+ }
+ }
+ for (PiNode p : deletePi) {
+ p.safeDelete();
+ }
+ deleteFixed(valueAnchorNode);
+ }
+
+ // replace reads with halfFloat reads
+ for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
+ if (javaRead.successors().first() instanceof NewInstanceNode) {
+ NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
+ if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
+ if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
+ deleteFixed(newHalfFloatInstance);
+ }
+ AddressNode readingAddress = javaRead.getAddress();
+ ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress);
+ graph.addWithoutUnique(readHalfFloatNode);
+ replaceFixed(javaRead, readHalfFloatNode);
+ newInstanceNode.replaceAtUsages(readHalfFloatNode);
+ deleteFixed(newInstanceNode);
+ }
+ }
+ }
+
+ for (NewInstanceNode newInstanceNode : graph.getNodes().filter(NewInstanceNode.class)) {
+ if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
+ if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
+ ValueNode valueInput = newHalfFloatInstance.getValue();
+ newInstanceNode.replaceAtUsages(valueInput);
+ deleteFixed(newInstanceNode);
+ deleteFixed(newHalfFloatInstance);
+ }
+ }
+ }
+
+ // replace writes with halfFloat writes
+ for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) {
+ if (isWriteHalfFloat(javaWrite)) {
+ // This casting is safe to do as it is already checked by the isWriteHalfFloat function
+ HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value();
+ ValueNode writingValue;
+ if (javaWrite.predecessor() instanceof NewHalfFloatInstance) {
+ // if a new HalfFloat instance is written
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor();
+ writingValue = newHalfFloatInstance.getValue();
+ if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) {
+ NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor();
+ if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) {
+ deleteFixed(newInstanceNode);
+ deleteFixed(newHalfFloatInstance);
+ }
+ }
+ } else {
+ // if the result of an operation or a stored value is written
+ writingValue = placeholder.getInput();
+ }
+ System.out.println("Ewring " + placeholder.toString() + " " + placeholder.inputs().first().toString());
+ placeholder.replaceAtUsages(writingValue);
+ placeholder.safeDelete();
+ AddressNode writingAddress = javaWrite.getAddress();
+ WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue);
+ graph.addWithoutUnique(writeHalfFloatNode);
+ replaceFixed(javaWrite, writeHalfFloatNode);
+ deleteFixed(javaWrite);
+ }
+ }
+
+ // replace the half float operator nodes with the corresponding regular operators
+ replaceAddHalfFloatNodes(graph);
+ replaceSubHalfFloatNodes(graph);
+ replaceMultHalfFloatNodes(graph);
+ replaceDivHalfFloatNodes(graph);
+
+ // add after the loadindexedvector nodes the marker node to fix the offset of its read
+
+ for (LoadIndexedVectorNode loadIndexedVectorNode : graph.getNodes().filter(LoadIndexedVectorNode.class)) {
+ if (loadIndexedVectorNode.getOCLKind().isHalf()) {
+ VectorHalfRead vectorHalfRead;
+ if (loadIndexedVectorNode.index() instanceof ConstantNode) {
+ ConstantNode offset = (ConstantNode) loadIndexedVectorNode.index();
+ int offsetValue = Integer.valueOf(offset.getValue().toValueString());
+ vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead(offsetValue));
+ } else {
+ vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead());
+ }
+ graph.addAfterFixed(loadIndexedVectorNode, vectorHalfRead);
+ }
+ }
+
+ for (VectorValueNode vectorValueNode : graph.getNodes().filter(VectorValueNode.class)) {
+ if (vectorValueNode.getOCLKind().isHalf()) {
+ for (Node vectorElement : vectorValueNode.inputs()) {
+ if (vectorElement instanceof VectorLoadElementNode) {
+ VectorLoadElementNode vectorLoad = (VectorLoadElementNode) vectorElement;
+ VectorLoadElementNode vectorLoadShort = new VectorLoadElementNode(OCLKind.HALF, vectorLoad.getVector(), vectorLoad.getLaneId());
+ graph.addWithoutUnique(vectorLoadShort);
+ vectorLoad.replaceAtUsages(vectorLoadShort);
+ vectorLoad.safeDelete();
+ } else if (vectorElement instanceof ConstantNode constantNode && constantNode.getValue().toValueString().contains("null")) {
+ Constant zeroValue = new RawConstant(0);
+ ConstantNode zero = new ConstantNode(zeroValue, StampFactory.forKind(JavaKind.Short));
+ graph.addWithoutUnique(zero);
+ constantNode.replaceAtUsages(zero);
+ constantNode.safeDelete();
+ }
+ }
+ }
+ }
+
+ }
+
}
diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java
index bfa96b8f3b..c098d7c21d 100644
--- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java
+++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java
@@ -32,7 +32,6 @@
import jdk.graal.compiler.nodes.ValueNode;
import jdk.graal.compiler.nodes.calc.IsNullNode;
import jdk.graal.compiler.phases.BasePhase;
-
import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder;
public class TornadoHalfFloatFixedGuardElimination extends BasePhase {
@@ -64,6 +63,10 @@ protected void run(StructuredGraph graph, TornadoSketchTierContext context) {
for (HalfFloatPlaceholder placeholderNode : graph.getNodes().filter(HalfFloatPlaceholder.class)) {
if (placeholderNode.getInput() instanceof PiNode placeholderInput) {
ValueNode halfFloatValue = placeholderInput.object();
+ if (halfFloatValue instanceof PiNode) {
+ nodesToBeDeleted.add(halfFloatValue);
+ halfFloatValue = (ValueNode) halfFloatValue.inputs().first();
+ }
FixedGuardNode placeholderGuard = (FixedGuardNode) placeholderInput.getGuard();
if (placeholderGuard.inputs().filter(IsNullNode.class).isNotEmpty()) {
IsNullNode isNullNode = placeholderGuard.inputs().filter(IsNullNode.class).first();
From 1d87e7048079cff93f14dd13381568625ea2edce Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 5 Jul 2024 15:24:46 +0300
Subject: [PATCH 22/54] Update method used in testFloat task
---
.../tornado/unittests/foundation/TestFloats.java | 2 +-
.../unittests/foundation/TestKernels.java | 16 ----------------
2 files changed, 1 insertion(+), 17 deletions(-)
diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestFloats.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestFloats.java
index fc63a68339..bb78cc6cb6 100644
--- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestFloats.java
+++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestFloats.java
@@ -45,7 +45,7 @@ public void testFloatsCopy() throws TornadoExecutionPlanException {
FloatArray a = new FloatArray(numElements);
TaskGraph taskGraph = new TaskGraph("s0") //
- .task("t0", TestKernels::testFloatCopy222, a) //
+ .task("t0", TestKernels::testFloatCopy, a) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, a);
ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java
index be0852e11e..2230c08925 100644
--- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java
+++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/foundation/TestKernels.java
@@ -156,22 +156,6 @@ public static void testFloatCopy(FloatArray a) {
}
}
- public static void testFloatCopy222(FloatArray a) {
- for (@Parallel int i = 0; i < a.getSize(); i++) {
- float x = a.get(i);
- x = x + i;
- }
- }
-
- public static FloatArray testFloatCopy2(FloatArray a) {
- FloatArray temp = new FloatArray(a.getSize());
- for (int i = 0; i < a.getSize(); i++) {
- temp.set(i, 50.0f + a.get(i));
- }
-
- return temp;
- }
-
public static void testDoublesCopy(DoubleArray a) {
for (@Parallel int i = 0; i < a.getSize(); i++) {
a.set(i, 50);
From d0314f928bc38e1203b971c222435c5a63a853df Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 5 Jul 2024 15:26:16 +0300
Subject: [PATCH 23/54] Remove unused code in ExceptionSuppression
---
.../compiler/phases/guards/ExceptionSuppression.java | 10 ----------
1 file changed, 10 deletions(-)
diff --git a/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/compiler/phases/guards/ExceptionSuppression.java b/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/compiler/phases/guards/ExceptionSuppression.java
index c50b100f22..ed5a031d83 100644
--- a/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/compiler/phases/guards/ExceptionSuppression.java
+++ b/tornado-drivers/drivers-common/src/main/java/uk/ac/manchester/tornado/drivers/common/compiler/phases/guards/ExceptionSuppression.java
@@ -30,9 +30,7 @@
import jdk.graal.compiler.nodes.LogicNode;
import jdk.graal.compiler.nodes.StructuredGraph;
import jdk.graal.compiler.nodes.extended.GuardedNode;
-import jdk.graal.compiler.nodes.extended.ValueAnchorNode;
import jdk.graal.compiler.phases.BasePhase;
-
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;
public class ExceptionSuppression extends BasePhase {
@@ -59,14 +57,6 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
}
});
- // graph.getNodes().filter(ValueAnchorNode.class).forEach(anchor -> {
- //// if (anchor. instanceof GuardNode guard) {
- //// guards.add(guard);
- //// conditions.add(guard.getCondition());
- //// // anchor.removeAnchoredNode();
- //// }
- // });
-
guards.forEach(guard -> {
guard.clearInputs();
guard.safeDelete();
From 2583246555f624637c68c6d14fcf0f75525cfd55 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 18 Jul 2024 18:56:07 +0300
Subject: [PATCH 24/54] Refactor code for plugins to be compliant with Graal
24.0.1
---
.../plugins/PTXGraphBuilderPlugins.java | 128 ++++++++++++++----
.../compiler/plugins/PTXMathPlugins.java | 27 ++--
.../compiler/plugins/PTXVectorPlugins.java | 11 +-
3 files changed, 120 insertions(+), 46 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
index 9eac2421be..43d4fc621b 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
@@ -23,24 +23,6 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal.compiler.plugins;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.FMAX;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.FMIN;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.POW;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.ATAN;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.COS;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.EXP;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.FABS;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.LOG;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.SIGN;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.SIN;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.TAN;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.TANH;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode.Operation.MAX;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode.Operation.MIN;
-import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntUnaryIntrinsicNode.Operation.POPCOUNT;
-
-import org.graalvm.word.LocationIdentity;
-
import jdk.graal.compiler.core.common.memory.BarrierType;
import jdk.graal.compiler.core.common.memory.MemoryOrderMode;
import jdk.graal.compiler.core.common.type.StampFactory;
@@ -48,7 +30,9 @@
import jdk.graal.compiler.nodes.ConstantNode;
import jdk.graal.compiler.nodes.FixedWithNextNode;
import jdk.graal.compiler.nodes.ValueNode;
+import jdk.graal.compiler.nodes.calc.AddNode;
import jdk.graal.compiler.nodes.calc.MulNode;
+import jdk.graal.compiler.nodes.calc.SignExtendNode;
import jdk.graal.compiler.nodes.extended.BoxNode;
import jdk.graal.compiler.nodes.extended.JavaReadNode;
import jdk.graal.compiler.nodes.extended.JavaWriteNode;
@@ -66,6 +50,7 @@
import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaMethod;
+import org.graalvm.word.LocationIdentity;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.exceptions.Debug;
import uk.ac.manchester.tornado.api.types.arrays.TornadoMemorySegment;
@@ -80,6 +65,22 @@
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PrintfNode;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.FMAX;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.FMIN;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.POW;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.ATAN;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.COS;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.EXP;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.FABS;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.LOG;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.SIGN;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.SIN;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.TAN;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode.Operation.TANH;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode.Operation.MAX;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode.Operation.MIN;
+import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntUnaryIntrinsicNode.Operation.POPCOUNT;
+
public class PTXGraphBuilderPlugins {
public static void registerInvocationPlugins(final Plugins ps, final InvocationPlugins plugins) {
@@ -198,6 +199,7 @@ private static void registerGlobalBarrier(Registration r) {
r.register(new InvocationPlugin("globalBarrier", InvocationPlugin.Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
PTXBarrierNode localBarrierNode = new PTXBarrierNode(1, -1);
b.add(localBarrierNode);
return true;
@@ -209,8 +211,11 @@ private static void registerIntLocalArray(Registration r, JavaKind returnedJavaK
r.register(new InvocationPlugin("allocateIntLocalArray", InvocationPlugin.Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(PTXArchitecture.sharedSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -221,8 +226,11 @@ private static void registerLongLocalArray(Registration r, JavaKind returnedJava
r.register(new InvocationPlugin("allocateLongLocalArray", InvocationPlugin.Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(PTXArchitecture.sharedSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -233,8 +241,11 @@ private static void registerFloatLocalArray(Registration r, JavaKind returnedJav
r.register(new InvocationPlugin("allocateFloatLocalArray", InvocationPlugin.Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(PTXArchitecture.sharedSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -245,8 +256,11 @@ private static void registerDoubleLocalArray(Registration r, JavaKind returnedJa
r.register(new InvocationPlugin("allocateDoubleLocalArray", InvocationPlugin.Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(PTXArchitecture.sharedSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -404,7 +418,9 @@ private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", InvocationPlugin.Receiver.class, int.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
- MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ System.out.println("APPLY -> " + kind.name());
+ AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
b.addPush(kind, readNode);
@@ -414,7 +430,12 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, int.class, kind.toJavaClass(), int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
- MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ System.out.println("APPLY set-> " + kind.name());
+
+ AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ SignExtendNode signExtend = new SignExtendNode(absoluteIndexNode.asNode(), 64);
+ b.getGraph().addOrUnique(signExtend);
+ MulNode mulNode = b.append(new MulNode(signExtend, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
b.add(writeNode);
@@ -426,34 +447,87 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
// private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
- // Registration r = new Registration(plugins, MemorySegment.class);
+ // Registration r = new Registration(plugins, TornadoMemorySegment.class);
//
// for (JavaKind kind : JavaKind.values()) {
// if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- // r.register(new InvocationPlugin("getAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class) {
+ // r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", InvocationPlugin.Receiver.class, int.class, int.class) {
// @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index) {
- // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
+ // // Constant kindBytes = new RawConstant(kind.getByteCount());
+ // // ConstantNode constantNode = new ConstantNode(kindBytes, StampFactory.forKind(JavaKind.Int));
+ // // b.getGraph().addOrUnique(constantNode);
+ // // AddNode addNode = new AddNode(index, baseIndex);
+ // // b.getGraph().addOrUnique(addNode);
+ // // MulNode mulNode = new MulNode(addNode, constantNode);
+ // // b.getGraph().addOrUnique(mulNode);
+ // // AddressNode addressNode = new OffsetAddressNode(receiver.get(), mulNode);
+ // // b.getGraph().addOrUnique(addressNode);
+ // // JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
+ // // b.addPush(kind, readNode);
+ // // return true;
+ // AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ // MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
// AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
// JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
// b.addPush(kind, readNode);
// return true;
// }
// });
- // r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class, kind.toJavaClass()) {
+ // r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, int.class, kind.toJavaClass(), int.class) {
// @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
- // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
+ // AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ // MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
// AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
// JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
// b.add(writeNode);
// return true;
+ // // Constant kindBytes = new RawConstant(kind.getByteCount());
+ // // ConstantNode constantNode = new ConstantNode(kindBytes, StampFactory.forKind(JavaKind.Int));
+ // // b.getGraph().addOrUnique(constantNode);
+ // // MulNode mulNode = new MulNode(index, constantNode);
+ // // b.getGraph().addOrUnique(mulNode);
+ // // AddressNode addressNode = new OffsetAddressNode(receiver.get(), mulNode);
+ // // b.getGraph().addOrUnique(addressNode);
+ // // JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
+ // // b.add(writeNode);
+ // // return true;
// }
// });
// }
// }
// }
+ // private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
+ // Registration r = new Registration(plugins, MemorySegment.class);
+ //
+ // for (JavaKind kind : JavaKind.values()) {
+ // if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
+ // r.register(new InvocationPlugin("getAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class) {
+ // @Override
+ // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index) {
+ // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
+ // JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
+ // b.addPush(kind, readNode);
+ // return true;
+ // }
+ // });
+ // r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class, kind.toJavaClass()) {
+ // @Override
+ // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
+ // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
+ // JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
+ // b.add(writeNode);
+ // return true;
+ // }
+ // });
+ // }
+ // }
+ // }
+
public static void registerNewInstancePlugins(Plugins plugins) {
plugins.appendNodePlugin(new PTXVectorNodePlugin());
}
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXMathPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXMathPlugins.java
index 6169fe0bee..7b21607a78 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXMathPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXMathPlugins.java
@@ -23,6 +23,19 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal.compiler.plugins;
+import jdk.graal.compiler.nodes.ValueNode;
+import jdk.graal.compiler.nodes.graphbuilderconf.GraphBuilderContext;
+import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugin;
+import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugins;
+import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugins.Registration;
+import jdk.vm.ci.meta.JavaKind;
+import jdk.vm.ci.meta.ResolvedJavaMethod;
+import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode;
+import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode;
+import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode;
+import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntUnaryIntrinsicNode;
+
import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.FMAX;
import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.FMIN;
import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode.Operation.POW;
@@ -44,20 +57,6 @@
import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode.Operation.MIN;
import static uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntUnaryIntrinsicNode.Operation.ABS;
-import jdk.graal.compiler.nodes.ValueNode;
-import jdk.graal.compiler.nodes.graphbuilderconf.GraphBuilderContext;
-import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugin;
-import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugins;
-import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugins.Registration;
-
-import jdk.vm.ci.meta.JavaKind;
-import jdk.vm.ci.meta.ResolvedJavaMethod;
-import uk.ac.manchester.tornado.api.math.TornadoMath;
-import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPBinaryIntrinsicNode;
-import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXFPUnaryIntrinsicNode;
-import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntBinaryIntrinsicNode;
-import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXIntUnaryIntrinsicNode;
-
public class PTXMathPlugins {
public static void registerTornadoMathPlugins(final InvocationPlugins plugins) {
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
index aa9f6453eb..ef1be960c6 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
@@ -23,8 +23,6 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal.compiler.plugins;
-import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.guarantee;
-
import jdk.graal.compiler.core.common.type.ObjectStamp;
import jdk.graal.compiler.core.common.type.StampPair;
import jdk.graal.compiler.nodes.ParameterNode;
@@ -41,7 +39,6 @@
import jdk.graal.compiler.nodes.java.StoreIndexedNode;
import jdk.graal.compiler.nodes.memory.address.AddressNode;
import jdk.graal.compiler.nodes.memory.address.OffsetAddressNode;
-
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import jdk.vm.ci.meta.ResolvedJavaType;
@@ -91,6 +88,8 @@
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector.VectorValueNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.PanamaPrivateMemoryNode;
+import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.guarantee;
+
public final class PTXVectorPlugins {
public static void registerPlugins(final Plugins ps, final InvocationPlugins plugins) {
@@ -231,6 +230,7 @@ private static void registerVectorCollectionsPlugins(final InvocationPlugins plu
final Registration r = new Registration(plugins, declaringClass);
r.register(new InvocationPlugin("loadFromArray", Receiver.class, storageType, int.class) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode array, ValueNode index) {
+ receiver.get(true);
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(vectorClass);
PTXKind kind = PTXKind.fromResolvedJavaType(resolvedType);
JavaKind elementKind = kind.getElementKind().asJavaKind();
@@ -243,6 +243,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("storeToArray", Receiver.class, vectorClass, storageType, int.class) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode value, ValueNode array, ValueNode index) {
+ receiver.get(true);
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(vectorClass);
PTXKind kind = PTXKind.fromResolvedJavaType(resolvedType);
JavaKind elementKind = kind.getElementKind().asJavaKind();
@@ -264,8 +265,8 @@ private static void registerVectorPlugins(final Plugins ps, final InvocationPlug
ps.appendNodePlugin(new NodePlugin() {
@Override
public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, ValueNode[] args) {
- if (method.getName().equals("") && (method.toString().contains("FloatArray.(int)") || method.toString().contains("DoubleArray.(int)") || method.toString().contains(
- "IntArray.(int)") || method.toString().contains("HalfFloatArray.(int)"))) {
+ if (method.getName().equals("") && (method.toString().contains("FloatArray.(int)") || method.toString().contains("DoubleArray.(int)") || method.toString()
+ .contains("IntArray.(int)") || method.toString().contains("HalfFloatArray.(int)"))) {
Class> javaType = resolveJavaClass(method.toString());
b.append(new PanamaPrivateMemoryNode(b.getMetaAccess().lookupJavaType(javaType), args[1]));
return true;
From ce6dad27de9950fa661a32e44eba464ef341856e Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 18 Jul 2024 19:16:57 +0300
Subject: [PATCH 25/54] Refactor PTXGraphBuilderPlugins in manually add
SingExtendNode
---
.../plugins/PTXGraphBuilderPlugins.java | 91 +------------------
1 file changed, 4 insertions(+), 87 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
index 43d4fc621b..36cbcee1d2 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
@@ -418,9 +418,10 @@ private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", InvocationPlugin.Receiver.class, int.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
- System.out.println("APPLY -> " + kind.name());
AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
- MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
+ SignExtendNode signExtend = new SignExtendNode(absoluteIndexNode.asNode(), 64);
+ b.getGraph().addOrUnique(signExtend);
+ MulNode mulNode = b.append(new MulNode(signExtend, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
b.addPush(kind, readNode);
@@ -430,8 +431,6 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, int.class, kind.toJavaClass(), int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
- System.out.println("APPLY set-> " + kind.name());
-
AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
SignExtendNode signExtend = new SignExtendNode(absoluteIndexNode.asNode(), 64);
b.getGraph().addOrUnique(signExtend);
@@ -445,89 +444,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
}
}
-
- // private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
- // Registration r = new Registration(plugins, TornadoMemorySegment.class);
- //
- // for (JavaKind kind : JavaKind.values()) {
- // if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- // r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", InvocationPlugin.Receiver.class, int.class, int.class) {
- // @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
- // // Constant kindBytes = new RawConstant(kind.getByteCount());
- // // ConstantNode constantNode = new ConstantNode(kindBytes, StampFactory.forKind(JavaKind.Int));
- // // b.getGraph().addOrUnique(constantNode);
- // // AddNode addNode = new AddNode(index, baseIndex);
- // // b.getGraph().addOrUnique(addNode);
- // // MulNode mulNode = new MulNode(addNode, constantNode);
- // // b.getGraph().addOrUnique(mulNode);
- // // AddressNode addressNode = new OffsetAddressNode(receiver.get(), mulNode);
- // // b.getGraph().addOrUnique(addressNode);
- // // JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
- // // b.addPush(kind, readNode);
- // // return true;
- // AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
- // MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
- // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- // JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
- // b.addPush(kind, readNode);
- // return true;
- // }
- // });
- // r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, int.class, kind.toJavaClass(), int.class) {
- // @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
- // AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
- // MulNode mulNode = b.append(new MulNode(absoluteIndexNode, ConstantNode.forInt(kind.getByteCount())));
- // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- // JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
- // b.add(writeNode);
- // return true;
- // // Constant kindBytes = new RawConstant(kind.getByteCount());
- // // ConstantNode constantNode = new ConstantNode(kindBytes, StampFactory.forKind(JavaKind.Int));
- // // b.getGraph().addOrUnique(constantNode);
- // // MulNode mulNode = new MulNode(index, constantNode);
- // // b.getGraph().addOrUnique(mulNode);
- // // AddressNode addressNode = new OffsetAddressNode(receiver.get(), mulNode);
- // // b.getGraph().addOrUnique(addressNode);
- // // JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
- // // b.add(writeNode);
- // // return true;
- // }
- // });
- // }
- // }
- // }
-
- // private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
- // Registration r = new Registration(plugins, MemorySegment.class);
- //
- // for (JavaKind kind : JavaKind.values()) {
- // if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- // r.register(new InvocationPlugin("getAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class) {
- // @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index) {
- // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
- // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- // JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
- // b.addPush(kind, readNode);
- // return true;
- // }
- // });
- // r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class, kind.toJavaClass()) {
- // @Override
- // public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
- // MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
- // AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
- // JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
- // b.add(writeNode);
- // return true;
- // }
- // });
- // }
- // }
- // }
-
+
public static void registerNewInstancePlugins(Plugins plugins) {
plugins.appendNodePlugin(new PTXVectorNodePlugin());
}
From b298c767c689e8453f50b170bb7c3209c6582b8c Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 18 Jul 2024 19:30:22 +0300
Subject: [PATCH 26/54] Update receiver handling in various plugins
---
.../ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java | 3 ++-
.../ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java | 4 ++--
.../drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java | 4 ++++
3 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
index 36cbcee1d2..886cd8de42 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java
@@ -188,6 +188,7 @@ private static void registerLocalBarrier(Registration r) {
r.register(new InvocationPlugin("localBarrier", InvocationPlugin.Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
PTXBarrierNode localBarrierNode = new PTXBarrierNode(0, -1);
b.add(localBarrierNode);
return true;
@@ -444,7 +445,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
}
}
-
+
public static void registerNewInstancePlugins(Plugins plugins) {
plugins.appendNodePlugin(new PTXVectorNodePlugin());
}
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java
index 369dfb22a1..db3ed46a0e 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java
@@ -29,7 +29,6 @@
import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugin;
import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugins;
import jdk.graal.compiler.nodes.graphbuilderconf.NodePlugin;
-
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.types.HalfFloat;
@@ -101,7 +100,8 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("getHalfFloatValue", InvocationPlugin.Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
- b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get())));
+ receiver.get(true);
+ b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get(true))));
return true;
}
});
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
index ef1be960c6..ab7c0312e0 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
@@ -277,6 +277,7 @@ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, Va
r.register(new InvocationPlugin("get", Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode laneId) {
+ receiver.get(true);
final VectorLoadElementNode loadElement = new VectorLoadElementNode(vectorKind.getElementKind(), receiver.get(), laneId);
b.push(javaElementKind, b.append(loadElement));
return true;
@@ -286,6 +287,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("set", Receiver.class, vectorKind.getJavaClass()) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode value) {
+ receiver.get(true);
if (receiver.get() instanceof ParameterNode) {
final AddressNode address = new OffsetAddressNode(receiver.get(), null);
final VectorStoreGlobalMemory store = new VectorStoreGlobalMemory(vectorKind, address, value);
@@ -299,6 +301,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("set", Receiver.class, int.class, elementType) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode laneId, ValueNode value) {
+ receiver.get(true);
final VectorStoreElementProxyNode store = new VectorStoreElementProxyNode(vectorKind.getElementKind(), receiver.get(), laneId, value);
b.add(b.append(store));
return true;
@@ -308,6 +311,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("set", Receiver.class, int.class, storageType) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode laneId, ValueNode value) {
+ receiver.get(true);
final VectorStoreElementProxyNode store = new VectorStoreElementProxyNode(vectorKind.getElementKind(), receiver.get(), laneId, value);
b.add(b.append(store));
return true;
From aa602ff3b92a159a541618d1ed2f7d7f5e01cb96 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 18 Jul 2024 19:31:44 +0300
Subject: [PATCH 27/54] Update method parameters in OCLHalfFloatPlugins
---
.../graal/compiler/plugins/OCLHalfFloatPlugins.java | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
index 4f743255e9..53567fbef5 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
@@ -59,7 +59,7 @@ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, Va
}
});
- r.register(new InvocationPlugin("add", HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("add", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
AddHalfFloatNode addNode = b.append(new AddHalfFloatNode(halfFloat1, halfFloat2));
@@ -68,7 +68,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
});
- r.register(new InvocationPlugin("sub", HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("sub", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
SubHalfFloatNode subNode = b.append(new SubHalfFloatNode(halfFloat1, halfFloat2));
@@ -77,7 +77,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
});
- r.register(new InvocationPlugin("mult", HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("mult", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
MultHalfFloatNode multNode = b.append(new MultHalfFloatNode(halfFloat1, halfFloat2));
@@ -86,7 +86,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
});
- r.register(new InvocationPlugin("div", HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("div", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
DivHalfFloatNode divNode = b.append(new DivHalfFloatNode(halfFloat1, halfFloat2));
From 4d92f4ae672c045683040e91a7d5a137e7c24414 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 19 Jul 2024 12:46:06 +0300
Subject: [PATCH 28/54] Refactor PTXStamp class and update PTXVectorPlugins
---
.../tornado/drivers/ptx/graal/PTXStamp.java | 42 ++++++-------------
.../compiler/plugins/PTXVectorPlugins.java | 2 +-
2 files changed, 14 insertions(+), 30 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java
index da054d7f13..e3951b0599 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java
@@ -21,14 +21,10 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal;
-import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.shouldNotReachHere;
-import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.unimplemented;
-
import jdk.graal.compiler.core.common.LIRKind;
import jdk.graal.compiler.core.common.spi.LIRKindTool;
import jdk.graal.compiler.core.common.type.ObjectStamp;
import jdk.graal.compiler.core.common.type.Stamp;
-
import jdk.vm.ci.meta.Constant;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.MemoryAccessProvider;
@@ -36,6 +32,9 @@
import jdk.vm.ci.meta.ResolvedJavaType;
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind;
+import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.shouldNotReachHere;
+import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.unimplemented;
+
public class PTXStamp extends ObjectStamp {
private static final ResolvedJavaType STAMP_TYPE = null;
@@ -71,33 +70,18 @@ public PTXKind getPTXKind() {
return kind;
}
- @Override
public JavaKind getStackKind() {
if (kind.isPrimitive()) {
- switch (kind) {
- case PRED:
- return JavaKind.Boolean;
- case S8:
- case U8:
- return JavaKind.Byte;
- case S16:
- case U16:
- case F16:
- case B16:
- return JavaKind.Short;
- case S32:
- case U32:
- return JavaKind.Int;
- case S64:
- case U64:
- return JavaKind.Long;
- case F32:
- return JavaKind.Float;
- case F64:
- return JavaKind.Double;
- default:
- return JavaKind.Illegal;
- }
+ return switch (kind) {
+ case PRED -> JavaKind.Boolean;
+ case S8, U8 -> JavaKind.Byte;
+ case S16, U16, F16, B16 -> JavaKind.Short;
+ case S32, U32 -> JavaKind.Int;
+ case S64, U64 -> JavaKind.Long;
+ case F32 -> JavaKind.Float;
+ case F64 -> JavaKind.Double;
+ default -> JavaKind.Illegal;
+ };
} else if (kind.isVector()) {
return JavaKind.Object;
}
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
index ab7c0312e0..7ca152f8a2 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXVectorPlugins.java
@@ -227,7 +227,7 @@ private static void registerVectorCollectionsPlugins(final InvocationPlugins plu
final Class> declaringClass = vectorKind.getJavaClass();
- final Registration r = new Registration(plugins, declaringClass);
+ final Registration r = new Registration(plugins, declaringClass).setAllowOverwrite(true);
r.register(new InvocationPlugin("loadFromArray", Receiver.class, storageType, int.class) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode array, ValueNode index) {
receiver.get(true);
From 0833d06b372e8edacbc31994654eb3b6c3d52f88 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 19 Jul 2024 12:32:53 +0100
Subject: [PATCH 29/54] Update module dependencies in SPIR-V and PTX drivers
---
.../ptx/src/main/java/module-info.java | 29 +++++++++++++++++--
.../spirv/src/main/java/module-info.java | 5 +++-
2 files changed, 30 insertions(+), 4 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/module-info.java b/tornado-drivers/ptx/src/main/java/module-info.java
index 19fb6b1b37..6a90f2e19a 100644
--- a/tornado-drivers/ptx/src/main/java/module-info.java
+++ b/tornado-drivers/ptx/src/main/java/module-info.java
@@ -1,7 +1,30 @@
import uk.ac.manchester.tornado.runtime.TornadoBackendProvider;
-module tornado.drivers.ptx{requires java.base;requires transitive jdk.internal.vm.ci;requires transitive jdk.graal.compiler;requires transitive org.graalvm.collections;requires transitive org.graalvm.word;requires transitive tornado.api;requires transitive tornado.runtime;requires tornado.drivers.common;
+module tornado.drivers.ptx {
+ requires java.base;
+ requires transitive jdk.internal.vm.ci;
+ requires transitive jdk.graal.compiler;
+ requires transitive org.graalvm.collections;
+ requires transitive org.graalvm.word;
+ requires transitive tornado.api;
+ requires transitive tornado.runtime;
+ requires tornado.drivers.common;
-exports uk.ac.manchester.tornado.drivers.ptx;exports uk.ac.manchester.tornado.drivers.ptx.enums;exports uk.ac.manchester.tornado.drivers.ptx.graal;exports uk.ac.manchester.tornado.drivers.ptx.graal.asm;exports uk.ac.manchester.tornado.drivers.ptx.graal.backend;exports uk.ac.manchester.tornado.drivers.ptx.graal.compiler;exports uk.ac.manchester.tornado.drivers.ptx.graal.lir;exports uk.ac.manchester.tornado.drivers.ptx.graal.meta;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.calc;exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector;exports uk.ac.manchester.tornado.drivers.ptx.graal.phases;exports uk.ac.manchester.tornado.drivers.ptx.mm;exports uk.ac.manchester.tornado.drivers.ptx.runtime;exports uk.ac.manchester.tornado.drivers.ptx.power;
+ exports uk.ac.manchester.tornado.drivers.ptx;
+ exports uk.ac.manchester.tornado.drivers.ptx.enums;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.asm;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.backend;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.compiler;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.lir;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.meta;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.calc;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector;
+ exports uk.ac.manchester.tornado.drivers.ptx.graal.phases;
+ exports uk.ac.manchester.tornado.drivers.ptx.mm;
+ exports uk.ac.manchester.tornado.drivers.ptx.runtime;
+ exports uk.ac.manchester.tornado.drivers.ptx.power;
-provides TornadoBackendProvider with uk.ac.manchester.tornado.drivers.ptx.PTXTornadoDriverProvider;}
+ provides TornadoBackendProvider with uk.ac.manchester.tornado.drivers.ptx.PTXTornadoDriverProvider;
+}
diff --git a/tornado-drivers/spirv/src/main/java/module-info.java b/tornado-drivers/spirv/src/main/java/module-info.java
index 6ad624dc83..7b6649a716 100644
--- a/tornado-drivers/spirv/src/main/java/module-info.java
+++ b/tornado-drivers/spirv/src/main/java/module-info.java
@@ -1,8 +1,11 @@
import uk.ac.manchester.tornado.runtime.TornadoBackendProvider;
module tornado.drivers.spirv {
+ requires java.base;
requires transitive jdk.internal.vm.ci;
- requires transitive jdk.internal.vm.compiler;
+ requires transitive jdk.graal.compiler;
+ requires transitive org.graalvm.collections;
+ requires transitive org.graalvm.word;
requires transitive tornado.api;
requires transitive tornado.runtime;
requires tornado.drivers.common;
From 6ca43342b6c4a7d09f3486020628e5bfb1d415cd Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 19 Jul 2024 12:57:58 +0100
Subject: [PATCH 30/54] Update package names in SPIRV exports list
---
.../src/etc/exportLists/spirv-exports | 120 +++++++++---------
1 file changed, 60 insertions(+), 60 deletions(-)
diff --git a/tornado-assembly/src/etc/exportLists/spirv-exports b/tornado-assembly/src/etc/exportLists/spirv-exports
index f965d1a3ca..da32aa95f8 100644
--- a/tornado-assembly/src/etc/exportLists/spirv-exports
+++ b/tornado-assembly/src/etc/exportLists/spirv-exports
@@ -24,68 +24,68 @@
--add-opens java.base/java.lang=tornado.drivers.spirv
--add-exports jdk.internal.vm.ci/jdk.vm.ci.common=tornado.drivers.spirv
--add-exports jdk.internal.vm.ci/jdk.vm.ci.amd64=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.hotspot.meta=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.replacements.classfile=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.alloc=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.util=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.cfg=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.framemap=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.hotspot.meta=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.replacements.classfile=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.alloc=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.util=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.cfg=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.framemap=tornado.drivers.spirv
--add-exports jdk.internal.vm.ci/jdk.vm.ci.meta=tornado.drivers.spirv
--add-exports jdk.internal.vm.ci/jdk.vm.ci.code=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.graph=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.graph.spi=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.gen=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodeinfo=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.calc=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.spi=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.code=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.debug=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.hotspot=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.java=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.asm=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.phases=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.graphbuilderconf=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.options=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.tiers=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.util=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.printer=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.graph=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.graph.spi=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.gen=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodeinfo=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.calc=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.spi=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.code=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.debug=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.hotspot=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.java=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.asm=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.phases=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.graphbuilderconf=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.options=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.tiers=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.util=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.printer=tornado.drivers.spirv
--add-exports jdk.internal.vm.ci/jdk.vm.ci.runtime=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.graph.iterators=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.java=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.bytecode=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.spi=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.api.replacements=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.replacements=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common.inlining=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.phases=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.type=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.extended=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.loop=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.loop.phases=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.debug=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.memory=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.util=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.graph.iterators=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.java=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.bytecode=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.spi=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.api.replacements=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.replacements=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common.inlining=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.phases=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.type=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.extended=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.loop=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.loop.phases=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.debug=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.memory=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.util=tornado.drivers.spirv
--add-opens jdk.internal.vm.ci/jdk.vm.ci.hotspot=tornado.drivers.spirv
--add-exports jdk.internal.vm.ci/jdk.vm.ci.hotspot=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.asm=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.cfg=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.schedule=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.virtual.phases.ea=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.lir.ssa=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.calc=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.gen=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.match=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.memory.address=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.type=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.graph=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common.util=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common.util=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.graph=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.word=tornado.drivers.spirv
---add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.memory=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.asm=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.cfg=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.schedule=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.virtual.phases.ea=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.lir.ssa=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.calc=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.gen=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.match=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.memory.address=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.nodes.type=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.graph=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common.util=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.common.util=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.phases.graph=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.word=tornado.drivers.spirv
+--add-exports jdk.graal.compiler/jdk.graal.compiler.core.common.memory=tornado.drivers.spirv
From 31ddc3897e20f01418a28964f58147f5bb4e5b8d Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 19 Jul 2024 12:58:21 +0100
Subject: [PATCH 31/54] Update import statements to use jdk.graal namespace
instead of org.graalvm
---
.../graal/SPIRVHotSpotBackendFactory.java | 24 +++++++-------
.../spirv/graal/SPIRVLoweringProvider.java | 12 +++----
.../SPIRVCompilationResultBuilder.java | 4 +--
.../spirv/graal/compiler/SPIRVCompiler.java | 32 +++++++++----------
.../spirv/graal/compiler/SPIRVHighTier.java | 13 ++++----
.../graal/compiler/SPIRVLIRGenerator.java | 28 +++++++++++++---
.../spirv/graal/compiler/SPIRVLowTier.java | 6 ++--
.../spirv/graal/compiler/SPIRVMidTier.java | 10 +++---
.../phases/TornadoFixedArrayCopyPhase.java | 13 ++++----
.../TornadoFloatingReadReplacement.java | 12 +++----
10 files changed, 87 insertions(+), 67 deletions(-)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVHotSpotBackendFactory.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVHotSpotBackendFactory.java
index b457ed4439..d51ca3671b 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVHotSpotBackendFactory.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVHotSpotBackendFactory.java
@@ -24,11 +24,11 @@
package uk.ac.manchester.tornado.drivers.spirv.graal;
import static jdk.vm.ci.common.InitTimer.timer;
-import static org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration.Plugins;
import jdk.graal.compiler.api.replacements.SnippetReflectionProvider;
import jdk.graal.compiler.core.common.spi.MetaAccessExtensionProvider;
import jdk.graal.compiler.hotspot.meta.HotSpotStampProvider;
+import jdk.graal.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration;
import jdk.graal.compiler.nodes.graphbuilderconf.InvocationPlugins;
import jdk.graal.compiler.nodes.loop.LoopsDataProviderImpl;
import jdk.graal.compiler.nodes.spi.LoopsDataProvider;
@@ -85,23 +85,23 @@ public static SPIRVBackend createJITCompiler(OptionValues options, HotSpotJVMCIR
HotSpotConstantReflectionProvider constantReflection = (HotSpotConstantReflectionProvider) jvmci.getConstantReflection();
// We specify an architecture of 64 bits
- SPIRVArchitecture architecture = new SPIRVArchitecture(SPIRVKind.OP_TYPE_INT_64, device.getByteOrder(), spirvRuntime);
+ uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVArchitecture architecture = new uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVArchitecture(SPIRVKind.OP_TYPE_INT_64, device.getByteOrder(), spirvRuntime);
SPIRVTargetDescription targetDescription = new SPIRVTargetDescription(architecture, false, SPIRV_STACK_ALIGNMENT, SPIRV_IMPLICIT_NULL_CHECK_LIMIT, SPIRV_INLINE_OBJECT, device
.isDeviceDoubleFPSupported(), device.getDeviceExtensions());
SPIRVDeviceContext deviceContext = context.getDeviceContext(device.getDeviceIndex());
- SPIRVCodeProvider codeProvider = new SPIRVCodeProvider(targetDescription);
+ uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVCodeProvider codeProvider = new uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVCodeProvider(targetDescription);
- SPIRVProviders providers;
- SPIRVSuitesProvider suites;
- SPIRVLoweringProvider lowerer;
- Plugins plugins;
+ uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVProviders providers;
+ uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVSuitesProvider suites;
+ uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVLoweringProvider lowerer;
+ GraphBuilderConfiguration.Plugins plugins;
try (InitTimer t = timer("create providers")) {
TornadoPlatformConfigurationProvider platformConfigurationProvider = new TornadoPlatformConfigurationProvider();
MetaAccessExtensionProvider metaAccessExtensionProvider = new TornadoMetaAccessExtensionProvider();
- lowerer = new SPIRVLoweringProvider(metaAccess, foreignCalls, platformConfigurationProvider, metaAccessExtensionProvider, constantReflection, vmConfig, targetDescription, false);
+ lowerer = new uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVLoweringProvider(metaAccess, foreignCalls, platformConfigurationProvider, metaAccessExtensionProvider, constantReflection, vmConfig, targetDescription, false);
WordTypes wordTypes = new TornadoWordTypes(metaAccess, SPIRVKind.OP_TYPE_FLOAT_32.asJavaKind());
LoopsDataProvider lpd = new LoopsDataProviderImpl();
@@ -115,9 +115,9 @@ public static SPIRVBackend createJITCompiler(OptionValues options, HotSpotJVMCIR
replacements.setGraphBuilderPlugins(plugins);
- suites = new SPIRVSuitesProvider(options, deviceContext, plugins, metaAccess, compilerConfiguration, addressLowering);
+ suites = new uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVSuitesProvider(options, deviceContext, plugins, metaAccess, compilerConfiguration, addressLowering);
- providers = new SPIRVProviders(metaAccess, codeProvider, constantReflection, constantFieldProvider, foreignCalls, lowerer, replacements, stampProvider, platformConfigurationProvider,
+ providers = new uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVProviders(metaAccess, codeProvider, constantReflection, constantFieldProvider, foreignCalls, lowerer, replacements, stampProvider, platformConfigurationProvider,
metaAccessExtensionProvider, snippetReflection, wordTypes, p.getLoopsDataProvider(), suites);
lowerer.initialize(options, new DummySnippetFactory(), providers);
@@ -137,10 +137,10 @@ public static SPIRVBackend createJITCompiler(OptionValues options, HotSpotJVMCIR
* {@link TornadoReplacements}
* @return Plugins for SPIRV
*/
- private static Plugins createGraphPlugins(HotSpotMetaAccessProvider metaAccess, TornadoReplacements replacements, SnippetReflectionProvider snippetReflectionProvider,
+ private static GraphBuilderConfiguration.Plugins createGraphPlugins(HotSpotMetaAccessProvider metaAccess, TornadoReplacements replacements, SnippetReflectionProvider snippetReflectionProvider,
LoweringProvider loweringProvider) {
InvocationPlugins invocationPlugins = new InvocationPlugins();
- Plugins plugins = new Plugins(invocationPlugins);
+ GraphBuilderConfiguration.Plugins plugins = new GraphBuilderConfiguration.Plugins(invocationPlugins);
SPIRVGraphBuilderPlugins.registerParametersPlugins(plugins);
SPIRVGraphBuilderPlugins.registerNewInstancePlugins(plugins);
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVLoweringProvider.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVLoweringProvider.java
index 94ab3d4c30..a87384a021 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVLoweringProvider.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVLoweringProvider.java
@@ -24,7 +24,7 @@
*/
package uk.ac.manchester.tornado.drivers.spirv.graal;
-import static org.graalvm.compiler.nodes.NamedLocationIdentity.ARRAY_LENGTH_LOCATION;
+import static jdk.graal.compiler.nodes.NamedLocationIdentity.ARRAY_LENGTH_LOCATION;
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.shouldNotReachHere;
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.unimplemented;
import static uk.ac.manchester.tornado.drivers.providers.TornadoMemoryOrder.GPU_MEMORY_MODE;
@@ -273,14 +273,14 @@ private void lowerReduceSnippets(Node node, LoweringTool tool) {
private void lowerLocalNewArray(StructuredGraph graph, int length, NewArrayNonVirtualizableNode newArray) {
LocalArrayNode localArrayNode;
ConstantNode newLengthNode = ConstantNode.forInt(length, graph);
- localArrayNode = graph.addWithoutUnique(new LocalArrayNode(SPIRVArchitecture.localSpace, newArray.elementType(), newLengthNode));
+ localArrayNode = graph.addWithoutUnique(new LocalArrayNode(uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVArchitecture.localSpace, newArray.elementType(), newLengthNode));
newArray.replaceAtUsages(localArrayNode);
}
private void lowerPrivateNewArray(StructuredGraph graph, int size, NewArrayNonVirtualizableNode newArray) {
FixedArrayNode fixedArrayNode;
final ConstantNode newLengthNode = ConstantNode.forInt(size, graph);
- fixedArrayNode = graph.addWithoutUnique(new FixedArrayNode(SPIRVArchitecture.privateSpace, newArray.elementType(), newLengthNode));
+ fixedArrayNode = graph.addWithoutUnique(new FixedArrayNode(uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVArchitecture.privateSpace, newArray.elementType(), newLengthNode));
newArray.replaceAtUsages(fixedArrayNode);
}
@@ -359,7 +359,7 @@ private void lowerInvoke(Invoke invoke, LoweringTool tool, StructuredGraph graph
ResolvedJavaType type = os.javaType(tool.getMetaAccess());
SPIRVKind kind = SPIRVKind.fromResolvedJavaTypeToVectorKind(type);
if (kind != SPIRVKind.ILLEGAL) {
- returnStampPair = StampPair.createSingle(SPIRVStampFactory.getStampFor(kind));
+ returnStampPair = StampPair.createSingle(uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVStampFactory.getStampFor(kind));
}
}
@@ -493,7 +493,7 @@ public void lowerLoadIndexedNode(LoadIndexedNode loadIndexed, LoweringTool tool)
AddressNode address;
Stamp loadStamp = loadIndexed.stamp(NodeView.DEFAULT);
- if (!(loadIndexed.stamp(NodeView.DEFAULT) instanceof SPIRVStamp)) {
+ if (!(loadIndexed.stamp(NodeView.DEFAULT) instanceof uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVStamp)) {
loadStamp = loadStamp(loadIndexed.stamp(NodeView.DEFAULT), elementKind, false);
}
address = createArrayAccess(graph, loadIndexed, elementKind);
@@ -532,7 +532,7 @@ private AbstractWriteNode createMemWriteNode(JavaKind elementKind, ValueNode val
}
ValueNode storeConvertValue = value;
Stamp valueStamp = value.stamp(NodeView.DEFAULT);
- if (!(valueStamp instanceof SPIRVStamp) || !((SPIRVStamp) valueStamp).getSPIRVKind().isVector()) {
+ if (!(valueStamp instanceof uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVStamp) || !((uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVStamp) valueStamp).getSPIRVKind().isVector()) {
storeConvertValue = implicitStoreConvert(graph, elementKind, value);
}
memoryWrite = graph.add(new WriteNode(address, NamedLocationIdentity.getArrayLocation(elementKind), storeConvertValue, BarrierType.NONE, TornadoMemoryOrder.GPU_MEMORY_MODE));
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java
index 54d77ec1ff..1462ee591b 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java
@@ -35,7 +35,7 @@
import org.graalvm.collections.Equivalence;
import jdk.graal.compiler.asm.Assembler;
import jdk.graal.compiler.code.CompilationResult;
-import jdk.graal.compiler.core.common.spi.CodeGenProviders;
+import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.debug.DebugContext;
import jdk.graal.compiler.lir.LIR;
import jdk.graal.compiler.lir.LIRInstruction;
@@ -75,7 +75,7 @@ public class SPIRVCompilationResultBuilder extends CompilationResultBuilder {
private boolean isParallel;
private SPIRVDeviceContext deviceContext;
- public SPIRVCompilationResultBuilder(CodeGenProviders providers, FrameMap frameMap, Assembler asm, DataBuilder dataBuilder, FrameContext frameContext, OptionValues options, DebugContext debug,
+ public SPIRVCompilationResultBuilder(CoreProviders providers, FrameMap frameMap, Assembler asm, DataBuilder dataBuilder, FrameContext frameContext, OptionValues options, DebugContext debug,
CompilationResult compilationResult, LIR lir) {
super(providers, frameMap, asm, dataBuilder, frameContext, options, debug, compilationResult, Register.None, EconomicMap.create(Equivalence.DEFAULT), NO_VERIFIERS, lir);
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompiler.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompiler.java
index 6377d72dce..b0b0458a78 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompiler.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompiler.java
@@ -23,7 +23,7 @@
*/
package uk.ac.manchester.tornado.drivers.spirv.graal.compiler;
-import static org.graalvm.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
+import static jdk.graal.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.guarantee;
import static uk.ac.manchester.tornado.runtime.TornadoCoreRuntime.getDebugContext;
import static uk.ac.manchester.tornado.runtime.common.TornadoOptions.DUMP_COMPILED_METHODS;
@@ -113,9 +113,9 @@ public class SPIRVCompiler {
private static final TimerKey EmitLIR = DebugContext.timer("SPIRVEmitLIR");
private static final TimerKey EmitCode = DebugContext.timer("SPIRVEmitCode");
- private static final SPIRVIRGenerationPhase LIR_GENERATION_PHASE = new SPIRVIRGenerationPhase();
+ private static final uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVIRGenerationPhase LIR_GENERATION_PHASE = new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVIRGenerationPhase();
- private synchronized static SPIRVCompilationResult compile(SPIRVCompilationRequest r) {
+ private synchronized static uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compile(SPIRVCompilationRequest r) {
assert !r.graph.isFrozen();
try (DebugContext.Scope s0 = getDebugContext().scope("GraalCompiler", r.graph, r.providers.getCodeCache()); DebugCloseable a = CompilerTimer.start(getDebugContext())) {
emitFrontEnd(r.providers, r.backend, r.installedCodeOwner, r.args, r.meta, r.graph, r.graphBuilderSuite, r.optimisticOpts, r.profilingInfo, r.suites, r.isKernel, r.buildGraph,
@@ -149,7 +149,7 @@ private static void emitFrontEnd(Providers providers, SPIRVBackend backend, Reso
/*
* Register metadata with all tornado phases
*/
- ((SPIRVCanonicalizer) suites.getHighTier().getCustomCanonicalizer()).setContext(providers.getMetaAccess(), installedCodeOwner, args, meta);
+ ((uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCanonicalizer) suites.getHighTier().getCustomCanonicalizer()).setContext(providers.getMetaAccess(), installedCodeOwner, args, meta);
final TornadoHighTierContext highTierContext = new TornadoHighTierContext(providers, graphBuilderSuite, optimisticOpts, installedCodeOwner, args, meta, isKernel, batchCompilationConfig);
if (buildGraph) {
@@ -178,7 +178,7 @@ private static void emitFrontEnd(Providers providers, SPIRVBackend backend, Reso
}
}
- private static void emitBackEnd(StructuredGraph graph, Object stub, ResolvedJavaMethod installedCodeOwner, SPIRVBackend backend, SPIRVCompilationResult compilationResult,
+ private static void emitBackEnd(StructuredGraph graph, Object stub, ResolvedJavaMethod installedCodeOwner, SPIRVBackend backend, uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compilationResult,
RegisterConfig registerConfig, TornadoLIRSuites lirSuites, boolean isKernel, boolean isParallel, TornadoProfiler profiler) {
try (DebugContext.Scope s = getDebugContext().scope("SPIRVBackend", graph.getLastSchedule()); DebugCloseable a = BackEnd.start(getDebugContext())) {
LIRGenerationResult lirGen = null;
@@ -196,7 +196,7 @@ private static void emitBackEnd(StructuredGraph graph, Object stub, ResolvedJava
}
private static LIRGenerationResult emitLIR(SPIRVBackend backend, StructuredGraph graph, Object stub, RegisterConfig registerConfig, TornadoLIRSuites lirSuites,
- SPIRVCompilationResult compilationResult, boolean isKernel) {
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compilationResult, boolean isKernel) {
try {
return emitLIR0(backend, graph, stub, registerConfig, lirSuites, compilationResult, isKernel);
} catch (Throwable e) {
@@ -224,13 +224,13 @@ private static LIRGenerationResult emitLIR0(SPIRVB
}
RegisterAllocationConfig registerAllocationConfig = backend.newRegisterAllocationConfig(registerConfig, new String[] {});
FrameMapBuilder frameMapBuilder = backend.newFrameMapBuilder(registerConfig);
- SPIRVLIRGenerationResult lirGenRes = (SPIRVLIRGenerationResult) backend.newLIRGenerationResult(graph.compilationId(), lir, frameMapBuilder, registerAllocationConfig);
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVLIRGenerationResult lirGenRes = (uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVLIRGenerationResult) backend.newLIRGenerationResult(graph.compilationId(), lir, frameMapBuilder, registerAllocationConfig);
lirGenRes.setMethodIndex(backend.getMethodIndex());
LIRGeneratorTool lirGen = backend.newLIRGenerator(lirGenRes);
NodeLIRBuilderTool nodeLirGen = backend.newNodeLIRBuilder(graph, lirGen);
// LIR generation
- SPIRVIRGenerationPhase.LIRGenerationContext context = new SPIRVIRGenerationPhase.LIRGenerationContext(lirGen, nodeLirGen, graph, schedule, isKernel);
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVIRGenerationPhase.LIRGenerationContext context = new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVIRGenerationPhase.LIRGenerationContext(lirGen, nodeLirGen, graph, schedule, isKernel);
LIR_GENERATION_PHASE.apply(backend.getTarget(), lirGenRes, context);
try (DebugContext.Scope s = getDebugContext().scope("LIRStages", nodeLirGen, lir)) {
@@ -256,10 +256,10 @@ private static LIRGenerationResult emitLowLevel(TargetDescription target, LIRGen
}
private static void emitCode(SPIRVBackend backend, Assumptions assumptions, ResolvedJavaMethod rootMethod, List methods, int bytecodeSize, LIRGenerationResult lirGen,
- SPIRVCompilationResult compilationResult, ResolvedJavaMethod installedCodeOwner, boolean isKernel, boolean isParallel, TornadoProfiler profiler) {
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compilationResult, ResolvedJavaMethod installedCodeOwner, boolean isKernel, boolean isParallel, TornadoProfiler profiler) {
try (DebugCloseable a = EmitCode.start(getDebugContext())) {
FrameMap frameMap = lirGen.getFrameMap();
- final SPIRVCompilationResultBuilder crb = backend.newCompilationResultBuilder(frameMap, compilationResult, isKernel, isParallel, lirGen.getLIR());
+ final uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResultBuilder crb = backend.newCompilationResultBuilder(frameMap, compilationResult, isKernel, isParallel, lirGen.getLIR());
backend.emitCode(crb, lirGen.getLIR(), installedCodeOwner, profiler);
if (assumptions != null && !assumptions.isEmpty()) {
@@ -317,7 +317,7 @@ public static String buildKernelName(String methodName) {
return sb.toString();
}
- public static SPIRVCompilationResult compileSketchForDevice(Sketch sketch, CompilableTask task, SPIRVProviders providers, SPIRVBackend backend, TornadoProfiler profiler) {
+ public static uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compileSketchForDevice(Sketch sketch, CompilableTask task, SPIRVProviders providers, SPIRVBackend backend, TornadoProfiler profiler) {
final StructuredGraph kernelGraph = (StructuredGraph) sketch.getGraph().copy(getDebugContext());
ResolvedJavaMethod resolvedJavaMethod = kernelGraph.method();
@@ -333,7 +333,7 @@ public static SPIRVCompilationResult compileSketchForDevice(Sketch sketch, Compi
OptimisticOptimizations optimisticOptimizations = OptimisticOptimizations.ALL;
ProfilingInfo profilingInfo = resolvedJavaMethod.getProfilingInfo();
- SPIRVCompilationResult kernelCompilationResult = new SPIRVCompilationResult(task.getId(), buildKernelName(resolvedJavaMethod.getName()), taskMeta);
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult kernelCompilationResult = new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult(task.getId(), buildKernelName(resolvedJavaMethod.getName()), taskMeta);
CompilationResultBuilderFactory factory = CompilationResultBuilderFactory.Default;
Set methods = new HashSet<>();
@@ -375,7 +375,7 @@ public static SPIRVCompilationResult compileSketchForDevice(Sketch sketch, Compi
Sketch currentSketch = TornadoSketcher.lookup(currentMethod, task.meta().getBackendIndex(), taskMeta.getDeviceIndex());
final StructuredGraph graph = (StructuredGraph) currentSketch.getGraph().copy(getDebugContext());
- final SPIRVCompilationResult compilationResult = new SPIRVCompilationResult(task.getId(), currentMethod.getName(), taskMeta);
+ final uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compilationResult = new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult(task.getId(), currentMethod.getName(), taskMeta);
// Share assembler across compilation results
compilationResult.setAssembler(kernelCompilationRequest.compilationResult.getAssembler());
@@ -474,7 +474,7 @@ public static class SPIRVCompilationRequest {
public final ProfilingInfo profilingInfo;
public final TornadoSuites suites;
public final TornadoLIRSuites lirSuites;
- public final SPIRVCompilationResult compilationResult;
+ public final uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compilationResult;
public final CompilationResultBuilderFactory factory;
public final boolean isKernel;
public final boolean buildGraph;
@@ -483,7 +483,7 @@ public static class SPIRVCompilationRequest {
public SPIRVCompilationRequest(StructuredGraph graph, ResolvedJavaMethod installedCodeOwner, Object[] args, TaskMetaData meta, Providers providers, SPIRVBackend backend,
PhaseSuite graphBuilderSuite, OptimisticOptimizations optimisticOpts, ProfilingInfo profilingInfo, TornadoSuites suites, TornadoLIRSuites lirSuites,
- SPIRVCompilationResult compilationResult, CompilationResultBuilderFactory factory, boolean isKernel, boolean buildGraph, BatchCompilationConfig batchCompilationConfig,
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult compilationResult, CompilationResultBuilderFactory factory, boolean isKernel, boolean buildGraph, BatchCompilationConfig batchCompilationConfig,
TornadoProfiler profiler) {
this.graph = graph;
this.installedCodeOwner = installedCodeOwner;
@@ -504,7 +504,7 @@ public SPIRVCompilationRequest(StructuredGraph graph, ResolvedJavaMethod install
this.profiler = profiler;
}
- public SPIRVCompilationResult execute() {
+ public uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult execute() {
return SPIRVCompiler.compile(this);
}
}
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java
index 07004a75f3..665510aa6a 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java
@@ -23,12 +23,11 @@
*
*/
package uk.ac.manchester.tornado.drivers.spirv.graal.compiler;
-
-import static org.graalvm.compiler.core.common.GraalOptions.ConditionalElimination;
-import static org.graalvm.compiler.core.common.GraalOptions.OptConvertDeoptsToGuards;
-import static org.graalvm.compiler.core.common.GraalOptions.PartialEscapeAnalysis;
-import static org.graalvm.compiler.core.phases.HighTier.Options.Inline;
-import static org.graalvm.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
+import static jdk.graal.compiler.core.common.GraalOptions.ConditionalElimination;
+import static jdk.graal.compiler.core.common.GraalOptions.OptConvertDeoptsToGuards;
+import static jdk.graal.compiler.core.common.GraalOptions.PartialEscapeAnalysis;
+import static jdk.graal.compiler.core.phases.HighTier.Options.Inline;
+import static jdk.graal.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Optional;
import jdk.graal.compiler.loop.phases.ConvertDeoptimizeToGuardPhase;
import jdk.graal.compiler.loop.phases.LoopFullUnrollPhase;
@@ -39,10 +38,10 @@
import jdk.graal.compiler.phases.common.DeadCodeEliminationPhase;
import jdk.graal.compiler.phases.common.HighTierLoweringPhase;
import jdk.graal.compiler.phases.common.IterativeConditionalEliminationPhase;
+import jdk.graal.compiler.phases.common.RemoveValueProxyPhase;
import jdk.graal.compiler.phases.common.inlining.InliningPhase;
import jdk.graal.compiler.phases.schedule.SchedulePhase;
import jdk.graal.compiler.virtual.phases.ea.PartialEscapePhase;
-
import jdk.vm.ci.meta.MetaAccessProvider;
import uk.ac.manchester.tornado.api.TornadoDeviceContext;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.analysis.TornadoShapeAnalysis;
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLIRGenerator.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLIRGenerator.java
index 44aa6193ec..62e0b61105 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLIRGenerator.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLIRGenerator.java
@@ -33,7 +33,7 @@
import jdk.graal.compiler.core.common.cfg.BasicBlock;
import jdk.graal.compiler.core.common.memory.BarrierType;
import jdk.graal.compiler.core.common.memory.MemoryOrderMode;
-import jdk.graal.compiler.core.common.spi.CodeGenProviders;
+import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.core.common.spi.ForeignCallLinkage;
import jdk.graal.compiler.core.common.type.Stamp;
import jdk.graal.compiler.lir.LIRFrameState;
@@ -77,8 +77,8 @@ public class SPIRVLIRGenerator extends LIRGenerator {
private SPIRVGenTool spirvGenTool;
private SPIRVBuiltinTool spirvBuiltinTool;
- public SPIRVLIRGenerator(CodeGenProviders providers, LIRGenerationResult lirGenRes, final int methodIndex) {
- super(new SPIRVLIRKindTool((SPIRVTargetDescription) providers.getCodeCache().getTarget()), new SPIRVArithmeticTool(), new SPIRVBarrierSetLIRGenerator(), new SPIRVMoveFactory(), providers,
+ public SPIRVLIRGenerator(CoreProviders providers, LIRGenerationResult lirGenRes, final int methodIndex) {
+ super(new SPIRVLIRKindTool((SPIRVTargetDescription) providers.getCodeCache().getTarget()), new SPIRVArithmeticTool(), new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVBarrierSetLIRGenerator(), new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVMoveFactory(), providers,
lirGenRes);
spirvGenTool = new SPIRVGenTool(this);
spirvBuiltinTool = new SPIRVBuiltinTool();
@@ -215,6 +215,16 @@ public void emitIntegerTestBranch(Value left, Value right, LabelRef trueDestinat
throw new RuntimeException("Not implemented yet");
}
+ @Override
+ public void emitOpMaskTestBranch(Value left, boolean negateLeft, Value right, LabelRef trueDestination, LabelRef falseDestination, double trueSuccessorProbability) {
+
+ }
+
+ @Override
+ public void emitOpMaskOrTestBranch(Value left, Value right, boolean allZeros, LabelRef trueDestination, LabelRef falseDestination, double trueSuccessorProbability) {
+
+ }
+
@Override
public Variable emitConditionalMove(PlatformKind cmpKind, Value leftVal, Value right, Condition cond, boolean unorderedIsTrue, Value trueValue, Value falseValue) {
Logger.traceBuildLIR(Logger.BACKEND.SPIRV, "emit TernaryBranch: " + leftVal + " " + cond + right + " ? " + trueValue + " : " + falseValue);
@@ -255,6 +265,16 @@ public Variable emitIntegerTestMove(Value leftVal, Value right, Value trueValue,
return result;
}
+ @Override
+ public Variable emitOpMaskTestMove(Value leftVal, boolean negateLeft, Value right, Value trueValue, Value falseValue) {
+ return null;
+ }
+
+ @Override
+ public Variable emitOpMaskOrTestMove(Value leftVal, Value right, boolean allZeros, Value trueValue, Value falseValue) {
+ return null;
+ }
+
@Override
public Variable emitReverseBytes(Value operand) {
return null;
@@ -355,7 +375,7 @@ public Variable newVariable(ValueKind> valueKind) {
// Format of the variable "_"
// variable.setName("spirv_" + spirvKind.getTypePrefix() + "_" + variable.index
// + "F" + methodIndex);
- SPIRVLIRGenerationResult res = (SPIRVLIRGenerationResult) getResult();
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVLIRGenerationResult res = (uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVLIRGenerationResult) getResult();
res.insertVariable(variable);
return variable;
}
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java
index ced8d7a638..bd1669bf5f 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java
@@ -24,17 +24,17 @@
*/
package uk.ac.manchester.tornado.drivers.spirv.graal.compiler;
-import static org.graalvm.compiler.core.common.GraalOptions.ConditionalElimination;
-import static org.graalvm.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Required;
+import static jdk.graal.compiler.core.common.GraalOptions.ConditionalElimination;
+import static jdk.graal.compiler.phases.common.DeadCodeEliminationPhase.Optionality.Required;
import jdk.graal.compiler.options.OptionValues;
import jdk.graal.compiler.phases.common.AddressLoweringByNodePhase;
import jdk.graal.compiler.phases.common.CanonicalizerPhase;
import jdk.graal.compiler.phases.common.DeadCodeEliminationPhase;
import jdk.graal.compiler.phases.common.FixReadsPhase;
+import jdk.graal.compiler.phases.common.UseTrappingNullChecksPhase;
import jdk.graal.compiler.phases.common.IterativeConditionalEliminationPhase;
import jdk.graal.compiler.phases.common.LowTierLoweringPhase;
-import jdk.graal.compiler.phases.common.UseTrappingNullChecksPhase;
import jdk.graal.compiler.phases.schedule.SchedulePhase;
import uk.ac.manchester.tornado.api.TornadoDeviceContext;
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java
index 611da3278f..bc38840192 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java
@@ -24,9 +24,9 @@
*/
package uk.ac.manchester.tornado.drivers.spirv.graal.compiler;
-import static org.graalvm.compiler.core.common.GraalOptions.ConditionalElimination;
-import static org.graalvm.compiler.core.common.GraalOptions.OptFloatingReads;
-import static org.graalvm.compiler.core.common.GraalOptions.ReassociateExpressions;
+import static jdk.graal.compiler.core.common.GraalOptions.ConditionalElimination;
+import static jdk.graal.compiler.core.common.GraalOptions.OptFloatingReads;
+import static jdk.graal.compiler.core.common.GraalOptions.ReassociateExpressions;
import jdk.graal.compiler.options.OptionValues;
import jdk.graal.compiler.phases.common.CanonicalizerPhase;
@@ -35,8 +35,8 @@
import jdk.graal.compiler.phases.common.IterativeConditionalEliminationPhase;
import jdk.graal.compiler.phases.common.MidTierLoweringPhase;
import jdk.graal.compiler.phases.common.ReassociationPhase;
-import jdk.graal.compiler.phases.common.RemoveValueProxyPhase;
+import jdk.graal.compiler.phases.common.RemoveValueProxyPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoPartialLoopUnroll;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.BoundCheckEliminationPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.ExceptionCheckingElimination;
@@ -81,7 +81,7 @@ public SPIRVMidTier(OptionValues options) {
appendPhase(new IterativeConditionalEliminationPhase(canonicalizer, true));
}
- appendPhase(new RemoveValueProxyPhase(canonicalizer));
+// appendPhase(new RemoveValueProxyPhase(canonicalizer));
appendPhase(new GuardLoweringPhase());
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFixedArrayCopyPhase.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFixedArrayCopyPhase.java
index a2bef8360c..dcb37d4a0d 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFixedArrayCopyPhase.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFixedArrayCopyPhase.java
@@ -24,18 +24,19 @@
*/
package uk.ac.manchester.tornado.drivers.spirv.graal.phases;
-import org.graalvm.compiler.nodes.GraphState;
-import org.graalvm.compiler.nodes.StructuredGraph;
-import org.graalvm.compiler.nodes.ValuePhiNode;
-import org.graalvm.compiler.nodes.memory.address.OffsetAddressNode;
-import org.graalvm.compiler.phases.BasePhase;
-
+import jdk.graal.compiler.nodes.memory.address.OffsetAddressNode;
+import jdk.graal.compiler.nodes.GraphState;
+import jdk.graal.compiler.nodes.StructuredGraph;
+import jdk.graal.compiler.nodes.ValuePhiNode;
+import jdk.graal.compiler.phases.BasePhase;
import uk.ac.manchester.tornado.api.exceptions.TornadoCompilationException;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoLowTierContext;
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.FixedArrayNode;
import java.util.Optional;
+import static jdk.graal.compiler.phases.BasePhase.ALWAYS_APPLICABLE;
+
public class TornadoFixedArrayCopyPhase extends BasePhase {
@Override
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFloatingReadReplacement.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFloatingReadReplacement.java
index 2b0abab62b..1c8fa3faac 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFloatingReadReplacement.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoFloatingReadReplacement.java
@@ -21,8 +21,8 @@
*/
package uk.ac.manchester.tornado.drivers.spirv.graal.phases;
-import static org.graalvm.compiler.graph.Graph.NodeEvent.NODE_ADDED;
-import static org.graalvm.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
+import static jdk.graal.compiler.graph.Graph.NodeEvent.NODE_ADDED;
+import static jdk.graal.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
import static org.graalvm.word.LocationIdentity.any;
import static org.graalvm.word.LocationIdentity.init;
@@ -252,7 +252,7 @@ protected void run(StructuredGraph graph, CoreProviders context) {
EconomicMap> modifiedInLoops = null;
if (graph.hasLoops()) {
modifiedInLoops = EconomicMap.create(Equivalence.IDENTITY);
- ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
+ ControlFlowGraph cfg = ControlFlowGraph.newBuilder(graph).connectBlocks(true).computeLoops(true).computeFrequency(true).build();
for (Loop> l : cfg.getLoops()) {
HIRLoop loop = (HIRLoop) l;
processLoop(loop, modifiedInLoops);
@@ -359,7 +359,7 @@ private static void processAnchor(MemoryAnchorNode anchor, TornadoFloatingReadRe
private static void processAccess(MemoryAccess access, TornadoFloatingReadReplacement.MemoryMapImpl state) {
LocationIdentity locationIdentity = access.getLocationIdentity();
- if (!locationIdentity.equals(LocationIdentity.any()) && locationIdentity.isMutable()) {
+ if (!locationIdentity.equals(any()) && locationIdentity.isMutable()) {
MemoryKill lastLocationAccess = state.getLastLocationAccess(locationIdentity);
access.setLastLocationAccess(lastLocationAccess);
}
@@ -420,7 +420,7 @@ protected TornadoFloatingReadReplacement.MemoryMapImpl processNode(FixedNode nod
if (node instanceof LoopExitNode) {
final LoopExitNode loopExitNode = (LoopExitNode) node;
final EconomicSet modifiedInLoop = modifiedInLoops.get(loopExitNode.loopBegin());
- final boolean anyModified = modifiedInLoop.contains(LocationIdentity.any());
+ final boolean anyModified = modifiedInLoop.contains(any());
state.getMap().replaceAll(
(locationIdentity, memoryNode) -> (anyModified || modifiedInLoop.contains(locationIdentity)) ? ProxyNode.forMemory(memoryNode, loopExitNode, locationIdentity) : memoryNode);
}
@@ -519,7 +519,7 @@ protected TornadoFloatingReadReplacement.MemoryMapImpl afterSplit(AbstractBeginN
protected EconomicMap processLoop(LoopBeginNode loop, TornadoFloatingReadReplacement.MemoryMapImpl initialState) {
EconomicSet modifiedLocations = modifiedInLoops.get(loop);
EconomicMap phis = EconomicMap.create(Equivalence.DEFAULT);
- if (modifiedLocations.contains(LocationIdentity.any())) {
+ if (modifiedLocations.contains(any())) {
// create phis for all locations if ANY is modified in the loop
modifiedLocations = EconomicSet.create(Equivalence.DEFAULT, modifiedLocations);
modifiedLocations.addAll(initialState.getMap().getKeys());
From 431c455be1093298e3c5b40d13a831bc94f880c4 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 19 Jul 2024 13:21:00 +0100
Subject: [PATCH 32/54] Update SPIRV vector and graph builder plugins
---
.../plugins/SPIRVGraphBuilderPlugins.java | 59 ++++++++-----------
.../compiler/plugins/SPIRVVectorPlugins.java | 9 ++-
2 files changed, 34 insertions(+), 34 deletions(-)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java
index 22a54313c4..58bf5c1dc9 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java
@@ -51,7 +51,9 @@
import jdk.graal.compiler.core.common.type.StampFactory;
import jdk.graal.compiler.nodes.ConstantNode;
import jdk.graal.compiler.nodes.ValueNode;
+import jdk.graal.compiler.nodes.calc.AddNode;
import jdk.graal.compiler.nodes.calc.MulNode;
+import jdk.graal.compiler.nodes.calc.SignExtendNode;
import jdk.graal.compiler.nodes.extended.JavaReadNode;
import jdk.graal.compiler.nodes.extended.JavaWriteNode;
import jdk.graal.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration.Plugins;
@@ -68,6 +70,7 @@
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoMemorySegment;
import uk.ac.manchester.tornado.drivers.common.logging.Logger;
import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVArchitecture;
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVKind;
@@ -84,11 +87,11 @@
public class SPIRVGraphBuilderPlugins {
public static void registerParametersPlugins(Plugins plugins) {
- SPIRVVectorPlugins.registerParameterPlugins(plugins);
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.plugins.SPIRVVectorPlugins.registerParameterPlugins(plugins);
}
public static void registerNewInstancePlugins(Plugins plugins) {
- plugins.appendNodePlugin(new SPIRVVectorNodePlugin());
+ plugins.appendNodePlugin(new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.plugins.SPIRVVectorNodePlugin());
// FIXME: Atomics for SPIRV Backend not implemented.
}
@@ -105,10 +108,10 @@ public static void registerInvocationPlugins(Plugins plugins, final InvocationPl
// Register plugins for the new API
registerKernelContextPlugins(invocationPlugins);
- SPIRVMathPlugins.registerTornadoMathPlugins(invocationPlugins);
- SPIRVVectorPlugins.registerPlugins(plugins, invocationPlugins);
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.plugins.SPIRVMathPlugins.registerTornadoMathPlugins(invocationPlugins);
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.plugins.SPIRVVectorPlugins.registerPlugins(plugins, invocationPlugins);
- SPIRVHalfFloatPlugins.registerPlugins(plugins, invocationPlugins);
+ uk.ac.manchester.tornado.drivers.spirv.graal.compiler.plugins.SPIRVHalfFloatPlugins.registerPlugins(plugins, invocationPlugins);
// Register plugins for Off-Heap Arrays with Panama
registerMemoryAccessPlugins(invocationPlugins);
}
@@ -154,6 +157,7 @@ private static void registerLocalBarrier(Registration r) {
r.register(new InvocationPlugin("localBarrier", InvocationPlugin.Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
SPIRVBarrierNode localBarrierNode = new SPIRVBarrierNode(SPIRVBarrierNode.SPIRVMemFenceFlags.LOCAL);
b.append(localBarrierNode);
return true;
@@ -165,6 +169,7 @@ private static void registerGlobalBarrier(Registration r) {
r.register(new InvocationPlugin("globalBarrier", InvocationPlugin.Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
SPIRVBarrierNode barrierNode = new SPIRVBarrierNode(SPIRVBarrierNode.SPIRVMemFenceFlags.GLOBAL);
b.append(barrierNode);
return true;
@@ -185,8 +190,11 @@ private static void registerLocalArray(Registration r, final String method, Java
r.register(new InvocationPlugin(method, InvocationPlugin.Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
+ receiver.get(true);
ConstantNode constantNode = new ConstantNode(size.asConstant(), StampFactory.forKind(JavaKind.Int));
+ b.getGraph().addOrUnique(constantNode);
LocalArrayNode localArrayNode = new LocalArrayNode(SPIRVArchitecture.localSpace, elementType, constantNode);
+ b.getGraph().addOrUnique(localArrayNode);
b.push(returnedJavaKind, localArrayNode);
return true;
}
@@ -367,47 +375,32 @@ private static void registerTornadoVMIntrinsicsPlugins(Plugins plugins) {
}
}
- public static Class> getValueLayoutClass(Class k) {
- if (k == int.class) {
- return ValueLayout.OfInt.class;
- } else if (k == double.class) {
- return ValueLayout.OfDouble.class;
- } else if (k == float.class) {
- return ValueLayout.OfFloat.class;
- } else if (k == long.class) {
- return ValueLayout.OfLong.class;
- } else if (k == boolean.class) {
- return ValueLayout.OfBoolean.class;
- } else if (k == byte.class) {
- return ValueLayout.OfByte.class;
- } else if (k == char.class) {
- return ValueLayout.OfChar.class;
- } else if (k == short.class) {
- return ValueLayout.OfShort.class;
- } else {
- throw new TornadoRuntimeException("Class type " + k + " not supported.");
- }
- }
private static void registerMemoryAccessPlugins(InvocationPlugins plugins) {
- Registration r = new Registration(plugins, MemorySegment.class);
+ Registration r = new Registration(plugins, TornadoMemorySegment.class);
for (JavaKind kind : JavaKind.values()) {
if (kind != JavaKind.Object && kind != JavaKind.Void && kind != JavaKind.Illegal) {
- r.register(new InvocationPlugin("getAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class) {
+ r.register(new InvocationPlugin("get" + kind.name() + "AtIndex", InvocationPlugin.Receiver.class, int.class, int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index) {
- MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode baseIndex) {
+ AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ SignExtendNode signExtend = new SignExtendNode(absoluteIndexNode.asNode(), 64);
+ b.getGraph().addOrUnique(signExtend);
+ MulNode mulNode = b.append(new MulNode(signExtend, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
JavaReadNode readNode = new JavaReadNode(kind, addressNode, LocationIdentity.any(), BarrierType.NONE, MemoryOrderMode.PLAIN, false);
b.addPush(kind, readNode);
return true;
}
});
- r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, getValueLayoutClass(kind.toJavaClass()), long.class, kind.toJavaClass()) {
+ r.register(new InvocationPlugin("setAtIndex", InvocationPlugin.Receiver.class, int.class, kind.toJavaClass(), int.class) {
@Override
- public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode layout, ValueNode index, ValueNode value) {
- MulNode mulNode = b.append(new MulNode(index, ConstantNode.forInt(kind.getByteCount())));
+ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode index, ValueNode value, ValueNode baseIndex) {
+ AddNode absoluteIndexNode = b.append(new AddNode(index, baseIndex));
+ SignExtendNode signExtend = new SignExtendNode(absoluteIndexNode.asNode(), 64);
+ b.getGraph().addOrUnique(signExtend);
+ MulNode mulNode = b.append(new MulNode(signExtend, ConstantNode.forInt(kind.getByteCount())));
AddressNode addressNode = b.append(new OffsetAddressNode(receiver.get(), mulNode));
JavaWriteNode writeNode = new JavaWriteNode(kind, addressNode, LocationIdentity.any(), value, BarrierType.NONE, false);
b.add(writeNode);
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVVectorPlugins.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVVectorPlugins.java
index 5ecbc08109..b66e730a20 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVVectorPlugins.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVVectorPlugins.java
@@ -221,9 +221,10 @@ private static void registerVectorCollectionsPlugins(final InvocationPlugins plu
final Class> declaringClass = vectorKind.getJavaClass();
- final Registration r = new Registration(plugins, declaringClass);
+ final Registration r = new Registration(plugins, declaringClass).setAllowOverwrite(true);
r.register(new InvocationPlugin("loadFromArray", Receiver.class, storageType, int.class) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode array, ValueNode index) {
+ receiver.get(true);
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(vectorClass);
SPIRVKind kind = SPIRVKind.fromResolvedJavaTypeToVectorKind(resolvedType);
JavaKind elementKind = kind.getElementKind().asJavaKind();
@@ -236,6 +237,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("storeToArray", Receiver.class, vectorClass, storageType, int.class) {
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode value, ValueNode array, ValueNode index) {
+ receiver.get(true);
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(vectorClass);
SPIRVKind kind = SPIRVKind.fromResolvedJavaTypeToVectorKind(resolvedType);
JavaKind elementKind = kind.getElementKind().asJavaKind();
@@ -271,6 +273,7 @@ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, Va
r.register(new InvocationPlugin("get", Receiver.class, int.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode laneId) {
+ receiver.get(true);
final VectorLoadElementNode loadElement = new VectorLoadElementNode(spirvVectorKind.getElementKind(), receiver.get(), laneId);
b.push(javaElementKind, b.append(loadElement));
return true;
@@ -280,6 +283,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("set", Receiver.class, int.class, storageType) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode laneId, ValueNode value) {
+ receiver.get(true);
final VectorStoreElementProxyNode store = new VectorStoreElementProxyNode(spirvVectorKind.getElementKind(), receiver.get(), laneId, value);
b.add(b.append(store));
return true;
@@ -289,6 +293,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("set", Receiver.class, spirvVectorKind.getJavaClass()) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode value) {
+ receiver.get(true);
if (receiver.get() instanceof ParameterNode) {
final AddressNode address = new OffsetAddressNode(receiver.get(), null);
final VectorStoreGlobalMemory store = new VectorStoreGlobalMemory(spirvVectorKind, address, value);
@@ -302,6 +307,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("set", Receiver.class, int.class, elementType) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode laneId, ValueNode value) {
+ receiver.get(true);
final VectorStoreElementProxyNode store = new VectorStoreElementProxyNode(spirvVectorKind.getElementKind(), receiver.get(), laneId, value);
b.add(b.append(store));
return true;
@@ -351,6 +357,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
r.register(new InvocationPlugin("getArray", Receiver.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
+ receiver.get(true);
final ResolvedJavaType resolvedType = b.getMetaAccess().lookupJavaType(declaringClass);
SPIRVKind kind = SPIRVKind.fromResolvedJavaTypeToVectorKind(resolvedType);
JavaKind elementKind = kind.getElementKind().asJavaKind();
From 429d8ad9ce0df4cad76a8501b72db256c8ec005d Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Fri, 19 Jul 2024 13:36:18 +0100
Subject: [PATCH 33/54] Refactor SPIRVCompilationResultBuilder class by
adjusting package references
---
.../SPIRVCompilationResultBuilder.java | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java
index 1462ee591b..5b18de6a52 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVCompilationResultBuilder.java
@@ -33,9 +33,9 @@
import org.graalvm.collections.EconomicMap;
import org.graalvm.collections.Equivalence;
+
import jdk.graal.compiler.asm.Assembler;
import jdk.graal.compiler.code.CompilationResult;
-import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.debug.DebugContext;
import jdk.graal.compiler.lir.LIR;
import jdk.graal.compiler.lir.LIRInstruction;
@@ -54,10 +54,11 @@
import jdk.graal.compiler.nodes.MergeNode;
import jdk.graal.compiler.nodes.cfg.ControlFlowGraph;
import jdk.graal.compiler.nodes.cfg.HIRBlock;
+import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.options.OptionValues;
-
import jdk.vm.ci.code.Register;
import jdk.vm.ci.meta.ResolvedJavaMethod;
+
import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError;
import uk.ac.manchester.tornado.drivers.common.logging.Logger;
import uk.ac.manchester.tornado.drivers.opencl.graal.compiler.OCLBlockVisitor;
@@ -77,7 +78,7 @@ public class SPIRVCompilationResultBuilder extends CompilationResultBuilder {
public SPIRVCompilationResultBuilder(CoreProviders providers, FrameMap frameMap, Assembler asm, DataBuilder dataBuilder, FrameContext frameContext, OptionValues options, DebugContext debug,
CompilationResult compilationResult, LIR lir) {
- super(providers, frameMap, asm, dataBuilder, frameContext, options, debug, compilationResult, Register.None, EconomicMap.create(Equivalence.DEFAULT), NO_VERIFIERS, lir);
+ super(providers, frameMap, asm, dataBuilder, frameContext, options, debug, compilationResult, Register.None, EconomicMap.create(Equivalence.DEFAULT), NO_VERIFIERS, lir);
nonInlinedMethods = new HashSet<>();
}
@@ -148,7 +149,7 @@ public void emit(LIR lir) {
final ControlFlowGraph cfg = (ControlFlowGraph) lir.getControlFlowGraph();
Logger.traceCodeGen(Logger.BACKEND.SPIRV, "Traversing CFG: ", cfg.graph.name);
cfg.computePostdominators();
- traverseControlFlowGraph(cfg, new SPIRVBlockVisitor(this));
+ traverseControlFlowGraph(cfg, new uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVBlockVisitor(this));
Logger.traceCodeGen(Logger.BACKEND.SPIRV, "Finished traversing CFG");
this.lir = null;
@@ -156,14 +157,14 @@ public void emit(LIR lir) {
}
- private void traverseControlFlowGraph(ControlFlowGraph cfg, SPIRVBlockVisitor visitor) {
+ private void traverseControlFlowGraph(ControlFlowGraph cfg, uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVBlockVisitor visitor) {
traverseControlFlowGraph(cfg.getStartBlock(), visitor, new HashSet<>(), new HashMap<>());
if (rescheduledBasicBlocks != null) {
rescheduledBasicBlocks.clear();
}
}
- private void rescheduleBasicBlock(HIRBlock basicBlock, SPIRVBlockVisitor visitor, HashSet visited, HashMap pending) {
+ private void rescheduleBasicBlock(HIRBlock basicBlock, uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVBlockVisitor visitor, HashSet visited, HashMap pending) {
HIRBlock block = pending.get(basicBlock);
visitor.enter(block);
visitor.exit(block);
@@ -217,7 +218,7 @@ private boolean isTrueBranchWithEndNodeOrNotControlSplit(HIRBlock blockTrueBranc
* @param pending
* {@link HashMap}
*/
- private void rescheduleTrueBranchConditionsIfNeeded(HIRBlock basicBlock, SPIRVBlockVisitor visitor, HashSet visited, HashMap pending) {
+ private void rescheduleTrueBranchConditionsIfNeeded(HIRBlock basicBlock, uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVBlockVisitor visitor, HashSet visited, HashMap pending) {
if (!basicBlock.isLoopHeader() && basicBlock.getDominator() != null && basicBlock.getDominator().getEndNode() instanceof IfNode) {
IfNode ifNode = (IfNode) basicBlock.getDominator().getEndNode();
HIRBlock blockTrueBranch = getBlockTrueBranch(basicBlock);
@@ -236,7 +237,7 @@ && isTrueBranchWithEndNodeOrNotControlSplit(blockTrueBranch))) {
}
}
- private void traverseControlFlowGraph(HIRBlock basicBlock, SPIRVBlockVisitor visitor, HashSet visited, HashMap pending) {
+ private void traverseControlFlowGraph(HIRBlock basicBlock, uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVBlockVisitor visitor, HashSet visited, HashMap pending) {
if (pending.containsKey(basicBlock) && !visited.contains(pending.get(basicBlock))) {
rescheduleBasicBlock(basicBlock, visitor, visited, pending);
@@ -357,7 +358,7 @@ public void addNonInlinedMethod(ResolvedJavaMethod targetMethod) {
}
public TaskMetaData getTaskMetaData() {
- return ((SPIRVCompilationResult) compilationResult).getMeta();
+ return ((uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResult) compilationResult).getMeta();
}
}
From 859597b78994ef06f656afcf754c9f2d7e575ba3 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Tue, 23 Jul 2024 15:14:05 +0300
Subject: [PATCH 34/54] Update TornadoHalfFloatReplacement class to handle
newly introduced Pi Nodes
---
.../phases/TornadoHalfFloatReplacement.java | 279 +++++++++---------
1 file changed, 144 insertions(+), 135 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
index 3d273afe16..3d14a0bba8 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
@@ -21,12 +21,6 @@
*/
package uk.ac.manchester.tornado.drivers.ptx.graal.phases;
-import java.util.ArrayList;
-import java.util.Optional;
-
-import jdk.vm.ci.meta.Constant;
-import jdk.vm.ci.meta.JavaKind;
-import jdk.vm.ci.meta.RawConstant;
import jdk.graal.compiler.core.common.type.StampFactory;
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.nodes.ConstantNode;
@@ -43,7 +37,9 @@
import jdk.graal.compiler.nodes.java.NewInstanceNode;
import jdk.graal.compiler.nodes.memory.address.AddressNode;
import jdk.graal.compiler.phases.BasePhase;
-
+import jdk.vm.ci.meta.Constant;
+import jdk.vm.ci.meta.JavaKind;
+import jdk.vm.ci.meta.RawConstant;
import uk.ac.manchester.tornado.api.internal.annotations.HalfType;
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXHalfFloatDivisionNode;
@@ -64,135 +60,10 @@
import uk.ac.manchester.tornado.runtime.graal.nodes.VectorHalfRead;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;
-public class TornadoHalfFloatReplacement extends BasePhase {
-
- @Override
- public Optional notApplicableTo(GraphState graphState) {
- return ALWAYS_APPLICABLE;
- }
-
- protected void run(StructuredGraph graph, TornadoHighTierContext context) {
-
- for (ValueAnchorNode valueAnchorNode : graph.getNodes().filter(ValueAnchorNode.class)) {
- ArrayList deletePi = new ArrayList();
- for (Node valueAnchorNodeUsage : valueAnchorNode.usages()) {
- if (valueAnchorNodeUsage instanceof PiNode) {
- PiNode piNode = (PiNode) valueAnchorNodeUsage;
- piNode.replaceAtUsages(piNode.object());
- deletePi.add(piNode);
- }
- }
- for (PiNode p : deletePi) {
- p.safeDelete();
- }
- deleteFixed(valueAnchorNode);
- }
-
- // replace reads with halfFloat reads
- for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
- if (javaRead.successors().first() instanceof NewInstanceNode) {
- NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
- if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
- if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
- deleteFixed(newHalfFloatInstance);
- }
- AddressNode readingAddress = javaRead.getAddress();
- ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress);
- graph.addWithoutUnique(readHalfFloatNode);
- replaceFixed(javaRead, readHalfFloatNode);
- newInstanceNode.replaceAtUsages(readHalfFloatNode);
- deleteFixed(newInstanceNode);
- }
- }
- }
-
- for (NewInstanceNode newInstanceNode : graph.getNodes().filter(NewInstanceNode.class)) {
- if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
- if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
- ValueNode valueInput = newHalfFloatInstance.getValue();
- newInstanceNode.replaceAtUsages(valueInput);
- deleteFixed(newInstanceNode);
- deleteFixed(newHalfFloatInstance);
- }
- }
- }
-
- // replace writes with halfFloat writes
- for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) {
- if (isWriteHalfFloat(javaWrite)) {
- // This casting is safe to do as it is already checked by the isWriteHalfFloat function
- HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value();
- ValueNode writingValue;
- if (javaWrite.predecessor() instanceof NewHalfFloatInstance) {
- // if a new HalfFloat instance is written
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor();
- writingValue = newHalfFloatInstance.getValue();
- if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) {
- NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor();
- if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) {
- deleteFixed(newInstanceNode);
- deleteFixed(newHalfFloatInstance);
- }
- }
- } else {
- // if the result of an operation or a stored value is written
- writingValue = placeholder.getInput();
- }
- placeholder.replaceAtUsages(writingValue);
- placeholder.safeDelete();
- AddressNode writingAddress = javaWrite.getAddress();
- WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue);
- graph.addWithoutUnique(writeHalfFloatNode);
- replaceFixed(javaWrite, writeHalfFloatNode);
- deleteFixed(javaWrite);
- }
- }
-
- // replace the half float operator nodes with the corresponding regular operators
- replaceAddHalfFloatNodes(graph);
- replaceSubHalfFloatNodes(graph);
- replaceMultHalfFloatNodes(graph);
- replaceDivHalfFloatNodes(graph);
-
- // add after the loadindexedvector nodes the marker node to fix the offset of its read
-
- for (LoadIndexedVectorNode loadIndexedVectorNode : graph.getNodes().filter(LoadIndexedVectorNode.class)) {
- if (loadIndexedVectorNode.getPtxKind().isHalf()) {
- VectorHalfRead vectorHalfRead;
- if (loadIndexedVectorNode.index() instanceof ConstantNode) {
- ConstantNode offset = (ConstantNode) loadIndexedVectorNode.index();
- int offsetValue = Integer.valueOf(offset.getValue().toValueString());
- vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead(offsetValue));
- } else {
- vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead());
- }
- graph.addAfterFixed(loadIndexedVectorNode, vectorHalfRead);
- }
- }
-
- for (VectorValueNode vectorValueNode : graph.getNodes().filter(VectorValueNode.class)) {
- if (vectorValueNode.getPTXKind().isHalf()) {
- for (Node vectorElement : vectorValueNode.inputs()) {
- if (vectorElement instanceof VectorLoadElementNode) {
- VectorLoadElementNode vectorLoad = (VectorLoadElementNode) vectorElement;
- VectorLoadElementNode vectorLoadShort = new VectorLoadElementNode(PTXKind.S16, vectorLoad.getVector(), vectorLoad.getLaneId());
- graph.addWithoutUnique(vectorLoadShort);
- vectorLoad.replaceAtUsages(vectorLoadShort);
- vectorLoad.safeDelete();
- } else if (vectorElement instanceof ConstantNode constantNode && constantNode.getValue().toValueString().contains("null")) {
- Constant zeroValue = new RawConstant(0);
- ConstantNode zero = new ConstantNode(zeroValue, StampFactory.forKind(JavaKind.Short));
- graph.addWithoutUnique(zero);
- constantNode.replaceAtUsages(zero);
- constantNode.safeDelete();
- }
- }
- }
- }
+import java.util.ArrayList;
+import java.util.Optional;
- }
+public class TornadoHalfFloatReplacement extends BasePhase {
private static ValueNode replaceAdd(AddHalfFloatNode addHalfFloatNode, StructuredGraph graph) {
ValueNode addNode;
@@ -358,4 +229,142 @@ private static void deleteFixed(Node node) {
node.safeDelete();
}
}
+
+ @Override
+ public Optional notApplicableTo(GraphState graphState) {
+ return ALWAYS_APPLICABLE;
+ }
+
+ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
+
+ for (ValueAnchorNode valueAnchorNode : graph.getNodes().filter(ValueAnchorNode.class)) {
+ ArrayList deletePi = new ArrayList();
+ for (Node valueAnchorNodeUsage : valueAnchorNode.usages()) {
+ if (valueAnchorNodeUsage instanceof PiNode) {
+ PiNode piNode = (PiNode) valueAnchorNodeUsage;
+ piNode.replaceAtUsages(piNode.object());
+ deletePi.add(piNode);
+ }
+ }
+ for (PiNode p : deletePi) {
+ p.safeDelete();
+ }
+ deleteFixed(valueAnchorNode);
+ }
+
+ // cleaup the reminder Pi nodes introduced since Graal 24.0.1
+ for (PiNode piNode : graph.getNodes().filter(PiNode.class)) {
+ for (Node piNodeUsages : piNode.usages()) {
+ if (piNodeUsages instanceof VectorValueNode) {
+ piNode.replaceAtUsages(piNode.object());
+ piNode.safeDelete();
+ }
+ }
+ }
+
+ // replace reads with halfFloat reads
+ for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
+ if (javaRead.successors().first() instanceof NewInstanceNode) {
+ NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
+ if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
+ if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
+ deleteFixed(newHalfFloatInstance);
+ }
+ AddressNode readingAddress = javaRead.getAddress();
+ ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress);
+ graph.addWithoutUnique(readHalfFloatNode);
+ replaceFixed(javaRead, readHalfFloatNode);
+ newInstanceNode.replaceAtUsages(readHalfFloatNode);
+ deleteFixed(newInstanceNode);
+ }
+ }
+ }
+
+ for (NewInstanceNode newInstanceNode : graph.getNodes().filter(NewInstanceNode.class)) {
+ if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
+ if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
+ ValueNode valueInput = newHalfFloatInstance.getValue();
+ newInstanceNode.replaceAtUsages(valueInput);
+ deleteFixed(newInstanceNode);
+ deleteFixed(newHalfFloatInstance);
+ }
+ }
+ }
+
+ // replace writes with halfFloat writes
+ for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) {
+ if (isWriteHalfFloat(javaWrite)) {
+ // This casting is safe to do as it is already checked by the isWriteHalfFloat function
+ HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value();
+ ValueNode writingValue;
+ if (javaWrite.predecessor() instanceof NewHalfFloatInstance) {
+ // if a new HalfFloat instance is written
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor();
+ writingValue = newHalfFloatInstance.getValue();
+ if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) {
+ NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor();
+ if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) {
+ deleteFixed(newInstanceNode);
+ deleteFixed(newHalfFloatInstance);
+ }
+ }
+ } else {
+ // if the result of an operation or a stored value is written
+ writingValue = placeholder.getInput();
+ }
+ placeholder.replaceAtUsages(writingValue);
+ placeholder.safeDelete();
+ AddressNode writingAddress = javaWrite.getAddress();
+ WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue);
+ graph.addWithoutUnique(writeHalfFloatNode);
+ replaceFixed(javaWrite, writeHalfFloatNode);
+ deleteFixed(javaWrite);
+ }
+ }
+
+ // replace the half float operator nodes with the corresponding regular operators
+ replaceAddHalfFloatNodes(graph);
+ replaceSubHalfFloatNodes(graph);
+ replaceMultHalfFloatNodes(graph);
+ replaceDivHalfFloatNodes(graph);
+
+ // add after the loadindexedvector nodes the marker node to fix the offset of its read
+
+ for (LoadIndexedVectorNode loadIndexedVectorNode : graph.getNodes().filter(LoadIndexedVectorNode.class)) {
+ if (loadIndexedVectorNode.getPtxKind().isHalf()) {
+ VectorHalfRead vectorHalfRead;
+ if (loadIndexedVectorNode.index() instanceof ConstantNode) {
+ ConstantNode offset = (ConstantNode) loadIndexedVectorNode.index();
+ int offsetValue = Integer.valueOf(offset.getValue().toValueString());
+ vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead(offsetValue));
+ } else {
+ vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead());
+ }
+ graph.addAfterFixed(loadIndexedVectorNode, vectorHalfRead);
+ }
+ }
+
+ for (VectorValueNode vectorValueNode : graph.getNodes().filter(VectorValueNode.class)) {
+ if (vectorValueNode.getPTXKind().isHalf()) {
+ for (Node vectorElement : vectorValueNode.inputs()) {
+ if (vectorElement instanceof VectorLoadElementNode) {
+ VectorLoadElementNode vectorLoad = (VectorLoadElementNode) vectorElement;
+ VectorLoadElementNode vectorLoadShort = new VectorLoadElementNode(PTXKind.S16, vectorLoad.getVector(), vectorLoad.getLaneId());
+ graph.addWithoutUnique(vectorLoadShort);
+ vectorLoad.replaceAtUsages(vectorLoadShort);
+ vectorLoad.safeDelete();
+ } else if (vectorElement instanceof ConstantNode constantNode && constantNode.getValue().toValueString().contains("null")) {
+ Constant zeroValue = new RawConstant(0);
+ ConstantNode zero = new ConstantNode(zeroValue, StampFactory.forKind(JavaKind.Short));
+ graph.addWithoutUnique(zero);
+ constantNode.replaceAtUsages(zero);
+ constantNode.safeDelete();
+ }
+ }
+ }
+ }
+
+ }
}
From d9f5b497a7e72b21ec4af262b41031986e861fb8 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Tue, 23 Jul 2024 16:26:05 +0300
Subject: [PATCH 35/54] Refactor InvocationPlugin registration in
OCLHalfFloatPlugins
---
.../graal/compiler/plugins/OCLHalfFloatPlugins.java | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
index 53567fbef5..4f743255e9 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java
@@ -59,7 +59,7 @@ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, Va
}
});
- r.register(new InvocationPlugin("add", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("add", HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
AddHalfFloatNode addNode = b.append(new AddHalfFloatNode(halfFloat1, halfFloat2));
@@ -68,7 +68,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
});
- r.register(new InvocationPlugin("sub", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("sub", HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
SubHalfFloatNode subNode = b.append(new SubHalfFloatNode(halfFloat1, halfFloat2));
@@ -77,7 +77,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
});
- r.register(new InvocationPlugin("mult", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("mult", HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
MultHalfFloatNode multNode = b.append(new MultHalfFloatNode(halfFloat1, halfFloat2));
@@ -86,7 +86,7 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
}
});
- r.register(new InvocationPlugin("div", InvocationPlugin.Receiver.class, HalfFloat.class, HalfFloat.class) {
+ r.register(new InvocationPlugin("div", HalfFloat.class, HalfFloat.class) {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) {
DivHalfFloatNode divNode = b.append(new DivHalfFloatNode(halfFloat1, halfFloat2));
From 2ea09261e54315632d52bfcd4e96a78920bb900f Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 24 Jul 2024 12:16:56 +0300
Subject: [PATCH 36/54] [WIP] Refactor half float replacement and guard
elimination methods for new PiNodes
---
.../compiler/plugins/PTXHalfFloatPlugin.java | 2 +
.../phases/TornadoHalfFloatReplacement.java | 50 +++++++++++++++++--
...TornadoHalfFloatFixedGuardElimination.java | 36 +++++++------
3 files changed, 69 insertions(+), 19 deletions(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java
index db3ed46a0e..31ccd50259 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java
@@ -101,6 +101,8 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) {
receiver.get(true);
+ // PiNode piNode = new PiNode(receiver.get(), StampFactory.forKind(JavaKind.Short));
+ // b.getGraph().addOrUnique(piNode);
b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get(true))));
return true;
}
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
index 3d14a0bba8..3bd22b8af4 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
@@ -113,11 +113,34 @@ private static ValueNode replaceSub(SubHalfFloatNode subHalfFloatNode, Structure
subNode = new SubNode(subX, subY);
graph.addWithoutUnique(subNode);
}
- subHalfFloatNode.replaceAtUsages(subNode);
+
+ PiNode piNode = null;
+ if (subHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = subHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(subNode);
+ piNode.safeDelete();
+ } else {
+ subHalfFloatNode.replaceAtUsages(subNode);
+ }
subHalfFloatNode.safeDelete();
+
return subNode;
}
+ private static ValueNode nodeToBeReplaced(ValueNode valueNode) {
+ PiNode piNode = null;
+ if (valueNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = valueNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ return piNode;
+ } else {
+ return valueNode;
+ }
+ }
+
private static void replaceSubHalfFloatNodes(StructuredGraph graph) {
for (SubHalfFloatNode subHalfFloatNode : graph.getNodes().filter(SubHalfFloatNode.class)) {
replaceSub(subHalfFloatNode, graph);
@@ -136,7 +159,18 @@ private static ValueNode replaceMult(MultHalfFloatNode multHalfFloatNode, Struct
multNode = new MulNode(multX, multY);
graph.addWithoutUnique(multNode);
}
- multHalfFloatNode.replaceAtUsages(multNode);
+
+ PiNode piNode = null;
+ if (multHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = multHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(multNode);
+ piNode.safeDelete();
+ } else {
+ multHalfFloatNode.replaceAtUsages(multNode);
+ }
+
multHalfFloatNode.safeDelete();
return multNode;
}
@@ -154,7 +188,17 @@ private static ValueNode replaceDiv(DivHalfFloatNode divHalfFloatNode, Structure
PTXHalfFloatDivisionNode divNode = new PTXHalfFloatDivisionNode(divX, divY);
graph.addWithoutUnique(divNode);
- divHalfFloatNode.replaceAtUsages(divNode);
+ PiNode piNode = null;
+ if (divHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = divHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(divNode);
+ piNode.safeDelete();
+ } else {
+ divHalfFloatNode.replaceAtUsages(divNode);
+ }
+
divHalfFloatNode.safeDelete();
return divNode;
}
diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java
index c098d7c21d..617c9279f9 100644
--- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java
+++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java
@@ -21,9 +21,6 @@
*/
package uk.ac.manchester.tornado.runtime.graal.phases;
-import java.util.ArrayList;
-import java.util.Optional;
-
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.nodes.FixedGuardNode;
import jdk.graal.compiler.nodes.GraphState;
@@ -34,6 +31,9 @@
import jdk.graal.compiler.phases.BasePhase;
import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder;
+import java.util.ArrayList;
+import java.util.Optional;
+
public class TornadoHalfFloatFixedGuardElimination extends BasePhase {
private static void deleteFixed(Node node) {
@@ -61,20 +61,24 @@ public Optional notApplicableTo(GraphState graphState) {
protected void run(StructuredGraph graph, TornadoSketchTierContext context) {
ArrayList nodesToBeDeleted = new ArrayList();
for (HalfFloatPlaceholder placeholderNode : graph.getNodes().filter(HalfFloatPlaceholder.class)) {
- if (placeholderNode.getInput() instanceof PiNode placeholderInput) {
- ValueNode halfFloatValue = placeholderInput.object();
- if (halfFloatValue instanceof PiNode) {
- nodesToBeDeleted.add(halfFloatValue);
- halfFloatValue = (ValueNode) halfFloatValue.inputs().first();
- }
- FixedGuardNode placeholderGuard = (FixedGuardNode) placeholderInput.getGuard();
- if (placeholderGuard.inputs().filter(IsNullNode.class).isNotEmpty()) {
- IsNullNode isNullNode = placeholderGuard.inputs().filter(IsNullNode.class).first();
- nodesToBeDeleted.add(isNullNode);
+ if (placeholderNode.getInput() instanceof PiNode placeholderPi) {
+ if (placeholderPi.object() instanceof ValueNode valueNodeToKeep) {
+ if (valueNodeToKeep.usages().filter(IsNullNode.class).isNotEmpty()) {
+ IsNullNode isNullNode = valueNodeToKeep.usages().filter(IsNullNode.class).first();
+
+ if (isNullNode.usages().first() instanceof FixedGuardNode) {
+ nodesToBeDeleted.add(placeholderPi);
+ }
+ }
+
+ FixedGuardNode placeholderGuard = (FixedGuardNode) placeholderPi.getGuard();
+ if (placeholderGuard.inputs().filter(IsNullNode.class).isNotEmpty()) {
+ IsNullNode isNullNode = placeholderGuard.inputs().filter(IsNullNode.class).first();
+ nodesToBeDeleted.add(isNullNode);
+ }
+ deleteFixed(placeholderGuard);
+ placeholderNode.setInput(valueNodeToKeep);
}
- deleteFixed(placeholderGuard);
- placeholderNode.setInput(halfFloatValue);
- nodesToBeDeleted.add(placeholderInput);
}
}
From d97f909d7961532bac6fac9133fd02c50c9c7b33 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 24 Jul 2024 13:15:23 +0300
Subject: [PATCH 37/54] Unify TornadoHalfFloatReplacement phase functionality
amonmg all backends
---
.../phases/TornadoHalfFloatReplacement.java | 49 ++-
.../phases/TornadoHalfFloatReplacement.java | 17 +-
.../phases/TornadoHalfFloatReplacement.java | 287 +++++++++---------
3 files changed, 198 insertions(+), 155 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java
index ce334d165d..959d72c51e 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java
@@ -21,9 +21,6 @@
*/
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;
-import java.util.ArrayList;
-import java.util.Optional;
-
import jdk.graal.compiler.core.common.type.StampFactory;
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.nodes.ConstantNode;
@@ -63,6 +60,9 @@
import uk.ac.manchester.tornado.runtime.graal.nodes.VectorHalfRead;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;
+import java.util.ArrayList;
+import java.util.Optional;
+
public class TornadoHalfFloatReplacement extends BasePhase {
private static ValueNode replaceAdd(AddHalfFloatNode addHalfFloatNode, StructuredGraph graph) {
@@ -114,8 +114,18 @@ public static ValueNode replaceSub(SubHalfFloatNode subHalfFloatNode, Structured
graph.addWithoutUnique(subNode);
}
- subHalfFloatNode.replaceAtUsages(subNode);
+ PiNode piNode = null;
+ if (subHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = subHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(subNode);
+ piNode.safeDelete();
+ } else {
+ subHalfFloatNode.replaceAtUsages(subNode);
+ }
subHalfFloatNode.safeDelete();
+
return subNode;
}
@@ -153,7 +163,17 @@ private static ValueNode replaceMult(MultHalfFloatNode multHalfFloatNode, Struct
multHalfFloatNode.replaceAtUsages(multNode);
}
- multHalfFloatNode.replaceAtUsages(multNode);
+ PiNode piNode = null;
+ if (multHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = multHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(multNode);
+ piNode.safeDelete();
+ } else {
+ multHalfFloatNode.replaceAtUsages(multNode);
+ }
+
multHalfFloatNode.safeDelete();
return multNode;
}
@@ -171,7 +191,17 @@ private static ValueNode replaceDiv(DivHalfFloatNode divHalfFloatNode, Structure
FloatDivNode divNode = new FloatDivNode(divX, divY);
graph.addWithoutUnique(divNode);
- divHalfFloatNode.replaceAtUsages(divNode);
+ PiNode piNode = null;
+ if (divHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = divHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(divNode);
+ piNode.safeDelete();
+ } else {
+ divHalfFloatNode.replaceAtUsages(divNode);
+ }
+
divHalfFloatNode.safeDelete();
return divNode;
}
@@ -321,7 +351,6 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
// if the result of an operation or a stored value is written
writingValue = placeholder.getInput();
}
- System.out.println("Ewring " + placeholder.toString() + " " + placeholder.inputs().first().toString());
placeholder.replaceAtUsages(writingValue);
placeholder.safeDelete();
AddressNode writingAddress = javaWrite.getAddress();
@@ -374,6 +403,12 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
}
}
+ for (HalfFloatPlaceholder placeholder : graph.getNodes().filter(HalfFloatPlaceholder.class)) {
+ ValueNode input = placeholder.getInput();
+ placeholder.replaceAtUsages(input);
+ placeholder.safeDelete();
+ }
+
}
}
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
index 3bd22b8af4..47acd25f69 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java
@@ -129,18 +129,6 @@ private static ValueNode replaceSub(SubHalfFloatNode subHalfFloatNode, Structure
return subNode;
}
- private static ValueNode nodeToBeReplaced(ValueNode valueNode) {
- PiNode piNode = null;
- if (valueNode.usages().filter(PiNode.class).isNotEmpty()) {
- piNode = valueNode.usages().filter(PiNode.class).first();
- }
- if (piNode != null) {
- return piNode;
- } else {
- return valueNode;
- }
- }
-
private static void replaceSubHalfFloatNodes(StructuredGraph graph) {
for (SubHalfFloatNode subHalfFloatNode : graph.getNodes().filter(SubHalfFloatNode.class)) {
replaceSub(subHalfFloatNode, graph);
@@ -410,5 +398,10 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
}
}
+ for (HalfFloatPlaceholder placeholder : graph.getNodes().filter(HalfFloatPlaceholder.class)) {
+ ValueNode input = placeholder.getInput();
+ placeholder.replaceAtUsages(input);
+ placeholder.safeDelete();
+ }
}
}
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
index 8fb07eac01..6f432d4b05 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
@@ -24,12 +24,6 @@
*/
package uk.ac.manchester.tornado.drivers.spirv.graal.phases;
-import java.util.ArrayList;
-import java.util.Optional;
-
-import jdk.vm.ci.meta.Constant;
-import jdk.vm.ci.meta.JavaKind;
-import jdk.vm.ci.meta.RawConstant;
import jdk.graal.compiler.core.common.type.StampFactory;
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.nodes.ConstantNode;
@@ -47,7 +41,9 @@
import jdk.graal.compiler.nodes.java.NewInstanceNode;
import jdk.graal.compiler.nodes.memory.address.AddressNode;
import jdk.graal.compiler.phases.BasePhase;
-
+import jdk.vm.ci.meta.Constant;
+import jdk.vm.ci.meta.JavaKind;
+import jdk.vm.ci.meta.RawConstant;
import uk.ac.manchester.tornado.api.internal.annotations.HalfType;
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVKind;
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.ReadHalfFloatNode;
@@ -67,135 +63,10 @@
import uk.ac.manchester.tornado.runtime.graal.nodes.VectorHalfRead;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;
-public class TornadoHalfFloatReplacement extends BasePhase {
-
- @Override
- public Optional notApplicableTo(GraphState graphState) {
- return ALWAYS_APPLICABLE;
- }
-
- protected void run(StructuredGraph graph, TornadoHighTierContext context) {
-
- for (ValueAnchorNode valueAnchorNode : graph.getNodes().filter(ValueAnchorNode.class)) {
- ArrayList deletePi = new ArrayList();
- for (Node valueAnchorNodeUsage : valueAnchorNode.usages()) {
- if (valueAnchorNodeUsage instanceof PiNode) {
- PiNode piNode = (PiNode) valueAnchorNodeUsage;
- piNode.replaceAtUsages(piNode.object());
- deletePi.add(piNode);
- }
- }
- for (PiNode p : deletePi) {
- p.safeDelete();
- }
- deleteFixed(valueAnchorNode);
- }
-
- // replace reads with halfFloat reads
- for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
- if (javaRead.successors().first() instanceof NewInstanceNode) {
- NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
- if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
- if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
- deleteFixed(newHalfFloatInstance);
- }
- AddressNode readingAddress = javaRead.getAddress();
- ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress);
- graph.addWithoutUnique(readHalfFloatNode);
- replaceFixed(javaRead, readHalfFloatNode);
- newInstanceNode.replaceAtUsages(readHalfFloatNode);
- deleteFixed(newInstanceNode);
- }
- }
- }
-
- for (NewInstanceNode newInstanceNode : graph.getNodes().filter(NewInstanceNode.class)) {
- if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
- if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
- ValueNode valueInput = newHalfFloatInstance.getValue();
- newInstanceNode.replaceAtUsages(valueInput);
- deleteFixed(newInstanceNode);
- deleteFixed(newHalfFloatInstance);
- }
- }
- }
-
- // replace writes with halfFloat writes
- for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) {
- if (isWriteHalfFloat(javaWrite)) {
- // This casting is safe to do as it is already checked by the isWriteHalfFloat function
- HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value();
- ValueNode writingValue;
- if (javaWrite.predecessor() instanceof NewHalfFloatInstance) {
- // if a new HalfFloat instance is written
- NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor();
- writingValue = newHalfFloatInstance.getValue();
- if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) {
- NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor();
- if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) {
- deleteFixed(newInstanceNode);
- deleteFixed(newHalfFloatInstance);
- }
- }
- } else {
- // if the result of an operation or a stored value is written
- writingValue = placeholder.getInput();
- }
- placeholder.replaceAtUsages(writingValue);
- placeholder.safeDelete();
- AddressNode writingAddress = javaWrite.getAddress();
- WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue);
- graph.addWithoutUnique(writeHalfFloatNode);
- replaceFixed(javaWrite, writeHalfFloatNode);
- deleteFixed(javaWrite);
- }
- }
-
- // replace the half float operator nodes with the corresponding regular operators
- replaceAddHalfFloatNodes(graph);
- replaceSubHalfFloatNodes(graph);
- replaceMultHalfFloatNodes(graph);
- replaceDivHalfFloatNodes(graph);
-
- // add after the loadindexedvector nodes the marker node to fix the offset of its read
-
- for (LoadIndexedVectorNode loadIndexedVectorNode : graph.getNodes().filter(LoadIndexedVectorNode.class)) {
- if (loadIndexedVectorNode.getSPIRVKind().isHalf()) {
- VectorHalfRead vectorHalfRead;
- if (loadIndexedVectorNode.index() instanceof ConstantNode) {
- ConstantNode offset = (ConstantNode) loadIndexedVectorNode.index();
- int offsetValue = Integer.valueOf(offset.getValue().toValueString());
- vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead(offsetValue));
- } else {
- vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead());
- }
- graph.addAfterFixed(loadIndexedVectorNode, vectorHalfRead);
- }
- }
-
- for (SPIRVVectorValueNode vectorValueNode : graph.getNodes().filter(SPIRVVectorValueNode.class)) {
- if (vectorValueNode.getSPIRVKind().isHalf()) {
- for (Node vectorElement : vectorValueNode.inputs()) {
- if (vectorElement instanceof VectorLoadElementNode) {
- VectorLoadElementNode vectorLoad = (VectorLoadElementNode) vectorElement;
- VectorLoadElementNode vectorLoadShort = new VectorLoadElementNode(SPIRVKind.OP_TYPE_FLOAT_16, vectorLoad.getVector(), vectorLoad.getLaneId());
- graph.addWithoutUnique(vectorLoadShort);
- vectorLoad.replaceAtUsages(vectorLoadShort);
- vectorLoad.safeDelete();
- } else if (vectorElement instanceof ConstantNode constantNode && constantNode.getValue().toValueString().contains("null")) {
- Constant zeroValue = new RawConstant(0);
- ConstantNode zero = new ConstantNode(zeroValue, StampFactory.forKind(JavaKind.Short));
- graph.addWithoutUnique(zero);
- constantNode.replaceAtUsages(zero);
- constantNode.safeDelete();
- }
- }
- }
- }
+import java.util.ArrayList;
+import java.util.Optional;
- }
+public class TornadoHalfFloatReplacement extends BasePhase {
private static ValueNode replaceAdd(AddHalfFloatNode addHalfFloatNode, StructuredGraph graph) {
boolean isVectorOperation = isVectorOp(addHalfFloatNode.getX(), addHalfFloatNode.getY());
@@ -298,7 +169,17 @@ private static ValueNode replaceDiv(DivHalfFloatNode divHalfFloatNode, Structure
FloatDivNode divNode = new FloatDivNode(divX, divY);
graph.addWithoutUnique(divNode);
- divHalfFloatNode.replaceAtUsages(divNode);
+ PiNode piNode = null;
+ if (divHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = divHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(divNode);
+ piNode.safeDelete();
+ } else {
+ divHalfFloatNode.replaceAtUsages(divNode);
+ }
+
divHalfFloatNode.safeDelete();
return divNode;
}
@@ -378,4 +259,138 @@ private static void deleteFixed(Node node) {
}
}
+ @Override
+ public Optional notApplicableTo(GraphState graphState) {
+ return ALWAYS_APPLICABLE;
+ }
+
+ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
+
+ for (ValueAnchorNode valueAnchorNode : graph.getNodes().filter(ValueAnchorNode.class)) {
+ ArrayList deletePi = new ArrayList();
+ for (Node valueAnchorNodeUsage : valueAnchorNode.usages()) {
+ if (valueAnchorNodeUsage instanceof PiNode) {
+ PiNode piNode = (PiNode) valueAnchorNodeUsage;
+ piNode.replaceAtUsages(piNode.object());
+ deletePi.add(piNode);
+ }
+ }
+ for (PiNode p : deletePi) {
+ p.safeDelete();
+ }
+ deleteFixed(valueAnchorNode);
+ }
+
+ // replace reads with halfFloat reads
+ for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
+ if (javaRead.successors().first() instanceof NewInstanceNode) {
+ NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
+ if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
+ if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
+ deleteFixed(newHalfFloatInstance);
+ }
+ AddressNode readingAddress = javaRead.getAddress();
+ ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress);
+ graph.addWithoutUnique(readHalfFloatNode);
+ replaceFixed(javaRead, readHalfFloatNode);
+ newInstanceNode.replaceAtUsages(readHalfFloatNode);
+ deleteFixed(newInstanceNode);
+ }
+ }
+ }
+
+ for (NewInstanceNode newInstanceNode : graph.getNodes().filter(NewInstanceNode.class)) {
+ if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
+ if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first();
+ ValueNode valueInput = newHalfFloatInstance.getValue();
+ newInstanceNode.replaceAtUsages(valueInput);
+ deleteFixed(newInstanceNode);
+ deleteFixed(newHalfFloatInstance);
+ }
+ }
+ }
+
+ // replace writes with halfFloat writes
+ for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) {
+ if (isWriteHalfFloat(javaWrite)) {
+ // This casting is safe to do as it is already checked by the isWriteHalfFloat function
+ HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value();
+ ValueNode writingValue;
+ if (javaWrite.predecessor() instanceof NewHalfFloatInstance) {
+ // if a new HalfFloat instance is written
+ NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor();
+ writingValue = newHalfFloatInstance.getValue();
+ if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) {
+ NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor();
+ if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) {
+ deleteFixed(newInstanceNode);
+ deleteFixed(newHalfFloatInstance);
+ }
+ }
+ } else {
+ // if the result of an operation or a stored value is written
+ writingValue = placeholder.getInput();
+ }
+ placeholder.replaceAtUsages(writingValue);
+ placeholder.safeDelete();
+ AddressNode writingAddress = javaWrite.getAddress();
+ WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue);
+ graph.addWithoutUnique(writeHalfFloatNode);
+ replaceFixed(javaWrite, writeHalfFloatNode);
+ deleteFixed(javaWrite);
+ }
+ }
+
+ // replace the half float operator nodes with the corresponding regular operators
+ replaceAddHalfFloatNodes(graph);
+ replaceSubHalfFloatNodes(graph);
+ replaceMultHalfFloatNodes(graph);
+ replaceDivHalfFloatNodes(graph);
+
+ // add after the loadindexedvector nodes the marker node to fix the offset of its read
+
+ for (LoadIndexedVectorNode loadIndexedVectorNode : graph.getNodes().filter(LoadIndexedVectorNode.class)) {
+ if (loadIndexedVectorNode.getSPIRVKind().isHalf()) {
+ VectorHalfRead vectorHalfRead;
+ if (loadIndexedVectorNode.index() instanceof ConstantNode) {
+ ConstantNode offset = (ConstantNode) loadIndexedVectorNode.index();
+ int offsetValue = Integer.valueOf(offset.getValue().toValueString());
+ vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead(offsetValue));
+ } else {
+ vectorHalfRead = graph.addWithoutUnique(new VectorHalfRead());
+ }
+ graph.addAfterFixed(loadIndexedVectorNode, vectorHalfRead);
+ }
+ }
+
+ for (SPIRVVectorValueNode vectorValueNode : graph.getNodes().filter(SPIRVVectorValueNode.class)) {
+ if (vectorValueNode.getSPIRVKind().isHalf()) {
+ for (Node vectorElement : vectorValueNode.inputs()) {
+ if (vectorElement instanceof VectorLoadElementNode) {
+ VectorLoadElementNode vectorLoad = (VectorLoadElementNode) vectorElement;
+ VectorLoadElementNode vectorLoadShort = new VectorLoadElementNode(SPIRVKind.OP_TYPE_FLOAT_16, vectorLoad.getVector(), vectorLoad.getLaneId());
+ graph.addWithoutUnique(vectorLoadShort);
+ vectorLoad.replaceAtUsages(vectorLoadShort);
+ vectorLoad.safeDelete();
+ } else if (vectorElement instanceof ConstantNode constantNode && constantNode.getValue().toValueString().contains("null")) {
+ Constant zeroValue = new RawConstant(0);
+ ConstantNode zero = new ConstantNode(zeroValue, StampFactory.forKind(JavaKind.Short));
+ graph.addWithoutUnique(zero);
+ constantNode.replaceAtUsages(zero);
+ constantNode.safeDelete();
+ }
+ }
+ }
+ }
+
+ for (HalfFloatPlaceholder placeholder : graph.getNodes().filter(HalfFloatPlaceholder.class)) {
+ ValueNode input = placeholder.getInput();
+ placeholder.replaceAtUsages(input);
+ placeholder.safeDelete();
+ }
+
+ }
+
}
From 302ff196a12a014da2c4ef0870f07bc1c53d55fe Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 24 Jul 2024 11:58:14 +0100
Subject: [PATCH 38/54] Implement cleanup of Pi nodes in
TornadoHalfFloatReplacement
---
.../spirv/graal/phases/TornadoHalfFloatReplacement.java | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
index 6f432d4b05..7294189b2a 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
@@ -281,6 +281,15 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
deleteFixed(valueAnchorNode);
}
+ // cleaup the reminder Pi nodes introduced since Graal 24.0.1
+ for (PiNode piNode : graph.getNodes().filter(PiNode.class)) {
+ for (Node piNodeUsages : piNode.usages()) {
+ if (piNodeUsages instanceof SPIRVVectorValueNode) {
+ piNode.replaceAtUsages(piNode.object());
+ piNode.safeDelete();
+ }
+ }
+ }
// replace reads with halfFloat reads
for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
if (javaRead.successors().first() instanceof NewInstanceNode) {
From faad244182a80fb793d5037be8a828ac8c77ab39 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Wed, 24 Jul 2024 19:17:34 +0300
Subject: [PATCH 39/54] Refactor `isCompatible` method in PTXStamp.java
---
.../ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java
index e3951b0599..4fdb8110ff 100644
--- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java
+++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXStamp.java
@@ -108,8 +108,13 @@ public boolean isCompatible(Constant constant) {
@Override
public boolean isCompatible(Stamp stamp) {
- if (stamp instanceof PTXStamp && ((PTXStamp) stamp).kind == kind) {
+ if (stamp instanceof PTXStamp) {
return true;
+ } else if (stamp instanceof ObjectStamp) {
+ PTXKind stampKind = PTXKind.fromResolvedJavaType(((ObjectStamp) stamp).type());
+ if (stampKind == kind) {
+ return true;
+ }
}
unimplemented("stamp is compat: %s + %s", this, stamp);
From f425242bf3080d7f5b61e66cc0720250506c6f46 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 25 Jul 2024 09:39:42 +0100
Subject: [PATCH 40/54] Add RemoveValueProxyPhase in SPIRVLowTier.java
---
.../tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java | 2 ++
1 file changed, 2 insertions(+)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java
index bd1669bf5f..b2113ced2b 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVLowTier.java
@@ -32,6 +32,7 @@
import jdk.graal.compiler.phases.common.CanonicalizerPhase;
import jdk.graal.compiler.phases.common.DeadCodeEliminationPhase;
import jdk.graal.compiler.phases.common.FixReadsPhase;
+import jdk.graal.compiler.phases.common.RemoveValueProxyPhase;
import jdk.graal.compiler.phases.common.UseTrappingNullChecksPhase;
import jdk.graal.compiler.phases.common.IterativeConditionalEliminationPhase;
import jdk.graal.compiler.phases.common.LowTierLoweringPhase;
@@ -87,6 +88,7 @@ public SPIRVLowTier(OptionValues options, TornadoDeviceContext deviceContext, Ad
appendPhase(new InverseSquareRootPhase());
}
+
// TODO Atomics Phase for SPIRV (this is the last thing to support)
appendPhase(new SchedulePhase(SchedulePhase.SchedulingStrategy.LATEST_OUT_OF_LOOPS));
From 22f1c696f30ce7b50f0d3142fc94f33839763bcb Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 25 Jul 2024 09:39:56 +0100
Subject: [PATCH 41/54] Improve PiNode replacement handling in SPIRV
transformations
---
.../phases/TornadoHalfFloatReplacement.java | 24 +++++++++++++++++--
1 file changed, 22 insertions(+), 2 deletions(-)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
index 7294189b2a..24463b934d 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java
@@ -118,8 +118,18 @@ private static ValueNode replaceSub(SubHalfFloatNode subHalfFloatNode, Structure
subNode = new SubNode(subX, subY);
graph.addWithoutUnique(subNode);
}
- subHalfFloatNode.replaceAtUsages(subNode);
+ PiNode piNode = null;
+ if (subHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = subHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(subNode);
+ piNode.safeDelete();
+ } else {
+ subHalfFloatNode.replaceAtUsages(subNode);
+ }
subHalfFloatNode.safeDelete();
+
return subNode;
}
@@ -150,7 +160,17 @@ private static ValueNode replaceMult(MultHalfFloatNode multHalfFloatNode, Struct
multNode = new MulNode(multX, multY);
graph.addWithoutUnique(multNode);
}
- multHalfFloatNode.replaceAtUsages(multNode);
+ PiNode piNode = null;
+ if (multHalfFloatNode.usages().filter(PiNode.class).isNotEmpty()) {
+ piNode = multHalfFloatNode.usages().filter(PiNode.class).first();
+ }
+ if (piNode != null) {
+ piNode.replaceAtUsages(multNode);
+ piNode.safeDelete();
+ } else {
+ multHalfFloatNode.replaceAtUsages(multNode);
+ }
+
multHalfFloatNode.safeDelete();
return multNode;
}
From 090f9d17fe418cefa00cd46d9b271de384933811 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Thu, 25 Jul 2024 09:40:10 +0100
Subject: [PATCH 42/54] Add RemoveValueProxyPhase to SPIRVMidTier pipeline
---
.../tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java
index bc38840192..506101c3e7 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVMidTier.java
@@ -75,14 +75,14 @@ public SPIRVMidTier(OptionValues options) {
appendPhase(new TornadoFloatingReadReplacement(canonicalizer));
}
+ appendPhase(new RemoveValueProxyPhase(canonicalizer));
+
appendPhase(canonicalizer);
if (ConditionalElimination.getValue(options)) {
appendPhase(new IterativeConditionalEliminationPhase(canonicalizer, true));
}
-// appendPhase(new RemoveValueProxyPhase(canonicalizer));
-
appendPhase(new GuardLoweringPhase());
appendPhase(canonicalizer);
From 0c28c94668197d43033ab97dce341d95d9982a3a Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Sat, 27 Jul 2024 13:07:47 +0100
Subject: [PATCH 43/54] Update Graal JAR version to 24.0.2
---
bin/pull_graal_jars.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/bin/pull_graal_jars.py b/bin/pull_graal_jars.py
index eb57791724..829b74b3db 100755
--- a/bin/pull_graal_jars.py
+++ b/bin/pull_graal_jars.py
@@ -26,7 +26,7 @@
# Constants
TARGET_DIR = "graalJars"
-VERSION = "24.0.1"
+VERSION = "24.0.2"
BASE_URL = "https://repo1.maven.org/maven2/org/graalvm"
GRAAL_JARS = [
f"compiler/compiler/{VERSION}/compiler-{VERSION}.jar",
From be1fdfa3a5691d1f47791dc4a39c94aff84eccb4 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Sat, 27 Jul 2024 13:38:30 +0100
Subject: [PATCH 44/54] Remove extra `--ea` flags from test commands
---
Makefile | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/Makefile b/Makefile
index cb57af98f6..ae75e3a9d0 100644
--- a/Makefile
+++ b/Makefile
@@ -43,14 +43,14 @@ tests:
rm -f tornado_unittests.log
tornado --devices
tornado-test --verbose
- tornado-test --ea -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
+ tornado-test -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
test-native.sh
fast-tests:
rm -f tornado_unittests.log
tornado --devices
- tornado-test --ea --verbose --quickPass
- tornado-test --ea -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
+ tornado-test --verbose --quickPass
+ tornado-test -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
test-native.sh
tests-spirv-levelzero:
From 2ffe8ac5aa2bda2b3549360fed3aa67d5bd5a1ab Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Mon, 29 Jul 2024 14:52:22 +0100
Subject: [PATCH 45/54] Update `isCompatible` to handle `ObjectStamp` in
`OCLStamp`
---
.../tornado/drivers/opencl/graal/OCLStamp.java | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLStamp.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLStamp.java
index ea78fd10af..c211b217a6 100644
--- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLStamp.java
+++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLStamp.java
@@ -123,9 +123,14 @@ public boolean isCompatible(Constant constant) {
@Override
public boolean isCompatible(Stamp stamp) {
- if (stamp instanceof OCLStamp && ((OCLStamp) stamp).oclKind == oclKind) {
- return true;
- }
+ if (stamp instanceof OCLStamp) {
+ return true;
+ } else if (stamp instanceof ObjectStamp) {
+ OCLKind stampKind = OCLKind.fromResolvedJavaType(((ObjectStamp) stamp).type());
+ if (stampKind == oclKind) {
+ return true;
+ }
+ }
unimplemented("stamp iscompat: %s + %s", this, stamp);
return false;
From cdd178479819e07acf1f756c27b88a222733a0a9 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Mon, 29 Jul 2024 15:46:35 +0100
Subject: [PATCH 46/54] Simplify SPIRVStamp compatibility check
---
.../manchester/tornado/drivers/spirv/graal/SPIRVStamp.java | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java
index a242eafbfe..ddfbd3efcd 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java
@@ -132,8 +132,13 @@ public boolean isCompatible(Constant constant) {
@Override
public boolean isCompatible(Stamp stamp) {
- if (stamp instanceof SPIRVStamp && ((SPIRVStamp) stamp).spirvKind == spirvKind) {
+ if (stamp instanceof SPIRVStamp) {
return true;
+ } else if (stamp instanceof ObjectStamp) {
+ SPIRVKind stampKind = SPIRVKind.fromResolvedJavaType(((ObjectStamp) stamp).type());
+ if (stampKind == spirvKind) {
+ return true;
+ }
}
unimplemented("stamp is compat: %s + %s", this, stamp);
return false;
From f5009224e3e57853db5f1de39c4c4c0155036307 Mon Sep 17 00:00:00 2001
From: Thanos Stratikopoulos
Date: Mon, 29 Jul 2024 17:57:48 +0300
Subject: [PATCH 47/54] [feat] Update graalvm version in pom file
---
pom.xml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pom.xml b/pom.xml
index 424b581a57..766fa3573b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -18,7 +18,7 @@
${project.version}
- 24.0.1
+ 24.0.2UTF-8${platform}false
From 6c58f20a3b650c449ffbb0502a8e3ef7ea5fc0d8 Mon Sep 17 00:00:00 2001
From: mikepapadim
Date: Tue, 30 Jul 2024 10:15:16 +0100
Subject: [PATCH 48/54] Refactor method to use fromJavaKind over
fromResolvedJavaType
---
.../ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java
index ddfbd3efcd..0de25ce19c 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/SPIRVStamp.java
@@ -135,7 +135,7 @@ public boolean isCompatible(Stamp stamp) {
if (stamp instanceof SPIRVStamp) {
return true;
} else if (stamp instanceof ObjectStamp) {
- SPIRVKind stampKind = SPIRVKind.fromResolvedJavaType(((ObjectStamp) stamp).type());
+ SPIRVKind stampKind = SPIRVKind.fromJavaKind(((ObjectStamp) stamp).type());
if (stampKind == spirvKind) {
return true;
}
From c0f6e63ec3f085181b9b5df157787c567546e924 Mon Sep 17 00:00:00 2001
From: Thanos Stratikopoulos <34061419+stratika@users.noreply.github.com>
Date: Wed, 31 Jul 2024 12:02:19 +0300
Subject: [PATCH 49/54] Update installer_config.py for JDK 22
---
bin/installer_config.py | 110 ++++++++++++++++++++--------------------
1 file changed, 55 insertions(+), 55 deletions(-)
diff --git a/bin/installer_config.py b/bin/installer_config.py
index 8d9c129abc..18693c852a 100644
--- a/bin/installer_config.py
+++ b/bin/installer_config.py
@@ -26,14 +26,14 @@
__APPLE__ = "darwin"
__WINDOWS__ = "windows"
-__JDK21__ = "jdk21"
-__GRAALVM21__ = "graalvm-jdk-21"
-__MANDREL21__ = "mandrel-jdk-21"
-__CORRETTO21__ = "corretto-jdk-21"
-__MICROSOFT21__ = "microsoft-jdk-21"
-__ZULU21__ = "zulu-jdk-21"
-__TEMURIN21__ = "temurin-jdk-21"
-__SAPMACHINE21__ = "sapmachine-jdk-21"
+__JDK21__ = "jdk22"
+__GRAALVM21__ = "graal-jdk-22"
+__MANDREL21__ = "mandrel-jdk-22"
+__CORRETTO21__ = "corretto-jdk-22"
+__MICROSOFT21__ = "microsoft-jdk-22"
+__ZULU21__ = "zulu-jdk-22"
+__TEMURIN21__ = "temurin-jdk-22"
+__SAPMACHINE21__ = "sapmachine-jdk-22"
## cmake
CMAKE = {
@@ -69,115 +69,115 @@
## JDK
JDK = {
- __JDK21__: {
+ __JDK22__: {
__LINUX__: {
- __X86_64__: "https://download.oracle.com/java/21/latest/jdk-21_linux-x64_bin.tar.gz",
- __ARM__: "https://download.oracle.com/java/21/latest/jdk-21_linux-aarch64_bin.tar.gz",
+ __X86_64__: "https://download.oracle.com/java/22/latest/jdk-22_linux-x64_bin.tar.gz",
+ __ARM__: "https://download.oracle.com/java/22/latest/jdk-22_linux-aarch64_bin.tar.gz",
},
__APPLE__: {
- __X86_64__: "https://download.oracle.com/java/21/latest/jdk-21_macos-x64_bin.tar.gz",
- __ARM__: "https://download.oracle.com/java/21/latest/jdk-21_macos-aarch64_bin.tar.gz",
+ __X86_64__: "https://download.oracle.com/java/22/latest/jdk-22_macos-x64_bin.tar.gz",
+ __ARM__: "https://download.oracle.com/java/22/latest/jdk-22_macos-aarch64_bin.tar.gz",
},
__WINDOWS__: {
- __X86_64__: "https://download.oracle.com/java/21/archive/jdk-21.0.1_windows-x64_bin.zip",
+ __X86_64__: "https://download.oracle.com/java/22/archive/jdk-22.0.2_windows-x64_bin.zip",
__ARM__: None,
},
},
- __GRAALVM21__: {
+ __GRAALVM22__: {
__LINUX__: {
- __X86_64__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.1/graalvm-community-jdk-21.0.1_linux-x64_bin.tar.gz",
- __ARM__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.1/graalvm-community-jdk-21.0.1_linux-aarch64_bin.tar.gz",
+ __X86_64__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-22.0.2/graalvm-community-jdk-22.0.2_linux-x64_bin.tar.gz",
+ __ARM__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-22.0.2/graalvm-community-jdk-22.0.2_linux-aarch64_bin.tar.gz",
},
__APPLE__: {
- __X86_64__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.1/graalvm-community-jdk-21.0.1_macos-x64_bin.tar.gz",
- __ARM__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.1/graalvm-community-jdk-21.0.1_macos-aarch64_bin.tar.gz",
+ __X86_64__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-22.0.2/graalvm-community-jdk-22.0.2_macos-x64_bin.tar.gz",
+ __ARM__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-22.0.2/graalvm-community-jdk-22.0.2_macos-aarch64_bin.tar.gz",
},
__WINDOWS__: {
- __X86_64__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.1/graalvm-community-jdk-21.0.1_windows-x64_bin.zip",
+ __X86_64__: "https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-22.0.2/graalvm-community-jdk-22.0.2_windows-x64_bin.zip",
__ARM__: None,
},
},
- __CORRETTO21__: {
+ __CORRETTO22__: {
__LINUX__: {
- __X86_64__: "https://corretto.aws/downloads/latest/amazon-corretto-21-x64-linux-jdk.tar.gz",
- __ARM__: "https://corretto.aws/downloads/latest/amazon-corretto-21-aarch64-linux-jdk.tar.gz",
+ __X86_64__: "https://corretto.aws/downloads/latest/amazon-corretto-22-x64-linux-jdk.tar.gz",
+ __ARM__: "https://corretto.aws/downloads/latest/amazon-corretto-22-aarch64-linux-jdk.tar.gz",
},
__APPLE__: {
- __X86_64__: "https://corretto.aws/downloads/latest/amazon-corretto-21-x64-macos-jdk.tar.gz",
- __ARM__: "https://corretto.aws/downloads/latest/amazon-corretto-21-aarch64-macos-jdk.tar.gz",
+ __X86_64__: "https://corretto.aws/downloads/latest/amazon-corretto-22-x64-macos-jdk.tar.gz",
+ __ARM__: "https://corretto.aws/downloads/latest/amazon-corretto-22-aarch64-macos-jdk.tar.gz",
},
__WINDOWS__: {
- __X86_64__: None,
+ __X86_64__: "https://corretto.aws/downloads/latest/amazon-corretto-22-x64-windows-jdk.zip",
__ARM__: None,
},
},
- __MANDREL21__: {
+ __MANDREL22__: {
__LINUX__: {
- __X86_64__: "https://github.com/graalvm/mandrel/releases/download/mandrel-23.1.0.0-Final/mandrel-java21-linux-amd64-23.1.0.0-Final.tar.gz",
- __ARM__: "https://github.com/graalvm/mandrel/releases/download/mandrel-23.1.0.0-Final/mandrel-java21-linux-aarch64-23.1.0.0-Final.tar.gz",
+ __X86_64__: "https://github.com/graalvm/mandrel/releases/download/mandrel-24.0.2.0-Final/mandrel-java22-linux-amd64-24.0.2.0-Final.tar.gz",
+ __ARM__: "https://github.com/graalvm/mandrel/releases/download/mandrel-24.0.2.0-Final/mandrel-java22-linux-aarch64-24.0.2.0-Final.tar.gz",
},
__APPLE__: {
__X86_64__: None,
- __ARM__: None,
+ __ARM__: "https://github.com/graalvm/mandrel/releases/download/mandrel-24.0.2.0-Final/mandrel-java22-macos-aarch64-24.0.2.0-Final.tar.gz",
},
__WINDOWS__: {
- __X86_64__: None,
+ __X86_64__: "https://github.com/graalvm/mandrel/releases/download/mandrel-24.0.2.0-Final/mandrel-java22-windows-amd64-24.0.2.0-Final.zip",
__ARM__: None,
},
},
- __MICROSOFT21__: {
+ __MICROSOFT22__: {
__LINUX__: {
- __X86_64__: "https://aka.ms/download-jdk/microsoft-jdk-21.0.3-linux-x64.tar.gz",
- __ARM__: "https://aka.ms/download-jdk/microsoft-jdk-21.0.3-linux-aarch64.tar.gz",
+ __X86_64__: None,
+ __ARM__: None,
},
__APPLE__: {
- __X86_64__: "https://aka.ms/download-jdk/microsoft-jdk-21.0.3-macos-x64.tar.gz",
- __ARM__: "https://aka.ms/download-jdk/microsoft-jdk-21.0.3-macos-aarch64.tar.gz",
+ __X86_64__: None,
+ __ARM__: None,
},
__WINDOWS__: {
- __X86_64__: "https://aka.ms/download-jdk/microsoft-jdk-21.0.3-windows-x64.zip",
- __ARM__: "https://aka.ms/download-jdk/microsoft-jdk-21.0.3-windows-aarch64.zip",
+ __X86_64__: None,
+ __ARM__: None,
},
},
- __ZULU21__: {
+ __ZULU22__: {
__LINUX__: {
- __X86_64__: "https://cdn.azul.com/zulu/bin/zulu21.28.85-ca-jdk21.0.0-linux_x64.tar.gz",
- __ARM__: "https://cdn.azul.com/zulu/bin/zulu21.28.85-ca-jdk21.0.0-linux_aarch64.tar.gz",
+ __X86_64__: "https://cdn.azul.com/zulu/bin/zulu22.32.15-ca-jdk22.0.2-linux_x64.tar.gz",
+ __ARM__: "https://cdn.azul.com/zulu/bin/zulu22.32.15-ca-jdk22.0.2-linux_aarch64.tar.gz",
},
__APPLE__: {
- __X86_64__: "https://cdn.azul.com/zulu/bin/zulu21.28.85-ca-jdk21.0.0-macosx_x64.tar.gz",
- __ARM__: "https://cdn.azul.com/zulu/bin/zulu21.28.85-ca-jdk21.0.0-macosx_aarch64.tar.gz",
+ __X86_64__: "https://cdn.azul.com/zulu/bin/zulu22.32.15-ca-jdk22.0.2-macosx_x64.tar.gz",
+ __ARM__: "https://cdn.azul.com/zulu/bin/zulu22.32.15-ca-jdk22.0.2-macosx_aarch64.tar.gz",
},
__WINDOWS__: {
__X86_64__: None,
__ARM__: None,
},
},
- __TEMURIN21__: {
+ __TEMURIN22__: {
__LINUX__: {
- __X86_64__: "https://github.com/adoptium/temurin21-binaries/releases/download/jdk-21.0.1%2B12/OpenJDK21U-jdk_x64_linux_hotspot_21.0.1_12.tar.gz",
- __ARM__: "https://github.com/adoptium/temurin21-binaries/releases/download/jdk-21.0.1%2B12/OpenJDK21U-jdk_aarch64_linux_hotspot_21.0.1_12.tar.gz",
+ __X86_64__: "https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.1%2B8/OpenJDK22U-jdk_x64_linux_hotspot_22.0.1_8.tar.gz",
+ __ARM__: "https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.1%2B8/OpenJDK22U-jdk_aarch64_linux_hotspot_22.0.1_8.tar.gz",
},
__APPLE__: {
- __X86_64__: "https://github.com/adoptium/temurin21-binaries/releases/download/jdk-21.0.1%2B12/OpenJDK21U-jdk_x64_mac_hotspot_21.0.1_12.tar.gz",
- __ARM__: "https://github.com/adoptium/temurin21-binaries/releases/download/jdk-21.0.1%2B12/OpenJDK21U-jdk_aarch64_mac_hotspot_21.0.1_12.tar.gz",
+ __X86_64__: "https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.1%2B8/OpenJDK22U-jdk_x64_mac_hotspot_22.0.1_8.tar.gz",
+ __ARM__: "https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.1%2B8/OpenJDK22U-jdk_aarch64_mac_hotspot_22.0.1_8.tar.gz",
},
__WINDOWS__: {
- __X86_64__: "https://github.com/adoptium/temurin21-binaries/releases/download/jdk-21.0.3%2B9/OpenJDK21U-jdk_x64_windows_hotspot_21.0.3_9.zip",
+ __X86_64__: "https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.1%2B8/OpenJDK22U-jdk_x64_windows_hotspot_22.0.1_8.zip",
__ARM__: None,
},
},
- __SAPMACHINE21__: {
+ __SAPMACHINE22__: {
__LINUX__: {
- __X86_64__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-21.0.3/sapmachine-jdk-21.0.3_linux-x64_bin.tar.gz",
- __ARM__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-21.0.3/sapmachine-jdk-21.0.3_linux-aarch64_bin.tar.gz",
+ __X86_64__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-22.0.2/sapmachine-jdk-22.0.2_linux-x64_bin.tar.gz",
+ __ARM__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-22.0.2/sapmachine-jdk-22.0.2_linux-aarch64_bin.tar.gz",
},
__APPLE__: {
- __X86_64__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-21.0.3/sapmachine-jdk-21.0.3_macos-x64_bin.tar.gz",
- __ARM__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-21.0.3/sapmachine-jdk-21.0.3_macos-aarch64_bin.tar.gz",
+ __X86_64__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-22.0.2/sapmachine-jdk-22.0.2_macos-x64_bin.tar.gz",
+ __ARM__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-22.0.2/sapmachine-jdk-22.0.2_macos-aarch64_bin.tar.gz",
},
__WINDOWS__: {
- __X86_64__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-21.0.3/sapmachine-jdk-21.0.3_windows-x64_bin.zip",
+ __X86_64__: "https://github.com/SAP/SapMachine/releases/download/sapmachine-22.0.2/sapmachine-jdk-22.0.2_windows-x64_bin.zip",
__ARM__: None,
},
},
From e63d1bd1be96b0fe1ef418980f08684e1ab0d477 Mon Sep 17 00:00:00 2001
From: Thanos Stratikopoulos <34061419+stratika@users.noreply.github.com>
Date: Wed, 31 Jul 2024 12:12:54 +0300
Subject: [PATCH 50/54] Update tornadovm-installer
---
bin/tornadovm-installer | 38 +++++++++++++++++++-------------------
1 file changed, 19 insertions(+), 19 deletions(-)
diff --git a/bin/tornadovm-installer b/bin/tornadovm-installer
index 972667fb35..76a55bc1ed 100755
--- a/bin/tornadovm-installer
+++ b/bin/tornadovm-installer
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
#
-# Copyright (c) 2013-2023, APT Group, Department of Computer Science,
+# Copyright (c) 2013-2024, APT Group, Department of Computer Science,
# The University of Manchester.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -36,14 +36,14 @@ __DIRECTORY_DEPENDENCIES__ = os.path.join("etc", "dependencies")
__VERSION__ = "v1.0.5"
__SUPPORTED_JDKS__ = [
- config.__JDK21__,
- config.__GRAALVM21__,
- config.__CORRETTO21__,
- config.__MICROSOFT21__,
- config.__MANDREL21__,
- config.__ZULU21__,
- config.__TEMURIN21__,
- config.__SAPMACHINE21__,
+ config.__JDK22__,
+ config.__GRAALVM22__,
+ config.__CORRETTO22__,
+ config.__MICROSOFT22__,
+ config.__MANDREL22__,
+ config.__ZULU22__,
+ config.__TEMURIN22__,
+ config.__SAPMACHINE22__,
]
__SUPPORTED_BACKENDS__ = ["opencl", "spirv", "ptx"]
@@ -326,10 +326,10 @@ class TornadoInstaller:
backend = self.composeBackendOption(args)
- makeJDK = "jdk21"
+ makeJDK = "jdk22"
polyglotOption = ""
if (args.javaHome != None and "graal" in args.javaHome) or (args.jdk != None and "graal" in args.jdk):
- makeJDK = "graal-jdk-21"
+ makeJDK = "graal-jdk-22"
polyglotOption = self.composePolyglotOption(args)
if args.javaHome != None:
@@ -358,14 +358,14 @@ def listSupportedJDKs():
"""
TornadoVM Installer - Select a JDK implementation to install with TornadoVM:
- jdk21 : Install TornadoVM with OpenJDK 21 (Oracle OpenJDK)
- graal-jdk-21 : Install TornadoVM with GraalVM and JDK 21 (GraalVM 23.1.0)
- mandrel-jdk-21 : Install TornadoVM with Mandrel and JDK 21 (GraalVM 23.1.0)
- corretto-jdk-21 : Install TornadoVM with Corretto JDK 21
- microsoft-jdk-21 : Install TornadoVM with Microsoft JDK 21
- zulu-jdk-21 : Install TornadoVM with Azul Zulu JDK 21
- temurin-jdk-21 : Install TornadoVM with Eclipse Temurin JDK 21
- sapmachine-jdk-21 : Install TornadoVM with SapMachine OpenJDK 21
+ jdk22 : Install TornadoVM with OpenJDK 22 (Oracle OpenJDK)
+ graal-jdk-22 : Install TornadoVM with GraalVM and JDK 22 (GraalVM 24.0.2)
+ mandrel-jdk-22 : Install TornadoVM with Mandrel and JDK 22 (GraalVM 24.0.2)
+ corretto-jdk-22 : Install TornadoVM with Corretto JDK 22
+ microsoft-jdk-22 : Install TornadoVM with Microsoft JDK 22
+ zulu-jdk-22 : Install TornadoVM with Azul Zulu JDK 22
+ temurin-jdk-22 : Install TornadoVM with Eclipse Temurin JDK 22
+ sapmachine-jdk-22 : Install TornadoVM with SapMachine OpenJDK 22
Usage:
$ ./bin/tornadovm-installer --jdk --backend
From a6c083b9ebd7094bef5584cd0e1c2bf39cb68155 Mon Sep 17 00:00:00 2001
From: Thanos Stratikopoulos <34061419+stratika@users.noreply.github.com>
Date: Wed, 31 Jul 2024 12:14:18 +0300
Subject: [PATCH 51/54] Update installer_config.py
---
bin/installer_config.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/bin/installer_config.py b/bin/installer_config.py
index 18693c852a..8665438ed2 100644
--- a/bin/installer_config.py
+++ b/bin/installer_config.py
@@ -26,13 +26,13 @@
__APPLE__ = "darwin"
__WINDOWS__ = "windows"
-__JDK21__ = "jdk22"
-__GRAALVM21__ = "graal-jdk-22"
-__MANDREL21__ = "mandrel-jdk-22"
-__CORRETTO21__ = "corretto-jdk-22"
-__MICROSOFT21__ = "microsoft-jdk-22"
-__ZULU21__ = "zulu-jdk-22"
-__TEMURIN21__ = "temurin-jdk-22"
+__JDK22__ = "jdk22"
+__GRAALVM22__ = "graal-jdk-22"
+__MANDREL22__ = "mandrel-jdk-22"
+__CORRETTO22__ = "corretto-jdk-22"
+__MICROSOFT22__ = "microsoft-jdk-22"
+__ZULU22__ = "zulu-jdk-22"
+__TEMURIN22__ = "temurin-jdk-22"
__SAPMACHINE21__ = "sapmachine-jdk-22"
## cmake
From dec5a8756d69aaff3dd9d0d89bdb10febf8d703a Mon Sep 17 00:00:00 2001
From: Thanos Stratikopoulos <34061419+stratika@users.noreply.github.com>
Date: Wed, 31 Jul 2024 12:14:47 +0300
Subject: [PATCH 52/54] Update installer_config.py
---
bin/installer_config.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/bin/installer_config.py b/bin/installer_config.py
index 8665438ed2..cb170994bc 100644
--- a/bin/installer_config.py
+++ b/bin/installer_config.py
@@ -33,7 +33,7 @@
__MICROSOFT22__ = "microsoft-jdk-22"
__ZULU22__ = "zulu-jdk-22"
__TEMURIN22__ = "temurin-jdk-22"
-__SAPMACHINE21__ = "sapmachine-jdk-22"
+__SAPMACHINE22__ = "sapmachine-jdk-22"
## cmake
CMAKE = {
From ac2795378a1fc3154f7d56f67c7c306f41b75ded Mon Sep 17 00:00:00 2001
From: Thanos Stratikopoulos <34061419+stratika@users.noreply.github.com>
Date: Thu, 1 Aug 2024 11:02:52 +0300
Subject: [PATCH 53/54] Update Makefile.mak
---
Makefile.mak | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/Makefile.mak b/Makefile.mak
index dbe5130f1c..f42a8876de 100644
--- a/Makefile.mak
+++ b/Makefile.mak
@@ -4,23 +4,23 @@ all: build
# nmake BACKENDS=""
BACKEND = opencl
-build jdk21:
- python bin\compile --jdk jdk21 --backend $(BACKEND)
+build jdk22:
+ python bin\compile --jdk jdk22 --backend $(BACKEND)
rebuild-deps:
- python bin\compile --jdk graal-jdk-21 --rebuild --backend $(BACKEND)
+ python bin\compile --jdk graal-jdk-22 --rebuild --backend $(BACKEND)
-graal-jdk-21:
- python bin\compile --jdk graal-jdk-21 --backend $(BACKEND)
+graal-jdk-22:
+ python bin\compile --jdk graal-jdk-22 --backend $(BACKEND)
polyglot:
- python bin\compile --jdk graal-jdk-21 --backend $(BACKEND) --polyglot
+ python bin\compile --jdk graal-jdk-22 --backend $(BACKEND) --polyglot
ptx:
- python bin\compile --jdk graal-jdk-21 --backend ptx,opencl
+ python bin\compile --jdk graal-jdk-22 --backend ptx,opencl
spirv:
- python bin\compile --jdk graal-jdk-21 --backend spirv,ptx,opencl
+ python bin\compile --jdk graal-jdk-22 --backend spirv,ptx,opencl
# Variable passed for the preparation of the Xilinx FPGA emulated target device. The default device is `xilinx_u50_gen3x16_xdma_201920_3`.
# make xilinx_emulation FPGA_PLATFORM= NUM_OF_FPGA_DEVICES=
From e6b5dcabc86dc2cea4210c7eadd5f6752cd06995 Mon Sep 17 00:00:00 2001
From: Thanos Stratikopoulos <34061419+stratika@users.noreply.github.com>
Date: Thu, 1 Aug 2024 11:35:46 +0300
Subject: [PATCH 54/54] Update Makefile.mak with new rules same as in Makefile
---
Makefile.mak | 25 +++++++++++++++++++++++--
1 file changed, 23 insertions(+), 2 deletions(-)
diff --git a/Makefile.mak b/Makefile.mak
index f42a8876de..4e0d50f28f 100644
--- a/Makefile.mak
+++ b/Makefile.mak
@@ -42,10 +42,31 @@ example:
tests:
del /f tornado_unittests.log
python %TORNADO_SDK%\bin\tornado --devices
- python %TORNADO_SDK%\bin\tornado-test --ea --verbose
- python %TORNADO_SDK%\bin\tornado-test --ea -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
+ python %TORNADO_SDK%\bin\tornado-test --verbose
+ python %TORNADO_SDK%\bin\tornado-test -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
%TORNADO_SDK%\bin\test-native.cmd
+fast-tests:
+ del /f tornado_unittests.log
+ python %TORNADO_SDK%\bin\tornado --devices
+ python %TORNADO_SDK%\bin\tornado-test --verbose --quickPass
+ python %TORNADO_SDK%\bin\tornado-test -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
+ test-native.sh
+
+tests-spirv-levelzero:
+ del /f tornado_unittests.log
+ python %TORNADO_SDK%\bin\tornado --jvm="-Dtornado.spirv.dispatcher=levelzero" uk.ac.manchester.tornado.drivers.TornadoDeviceQuery --params="verbose"
+ python %TORNADO_SDK%\bin\tornado-test --jvm="-Dtornado.spirv.dispatcher=levelzero" --ea --verbose
+ python %TORNADO_SDK%\bin\tornado-test --jvm="-Dtornado.spirv.dispatcher=levelzero"--ea -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
+ test-native.sh
+
+tests-spirv-opencl:
+ del /f tornado_unittests.log
+ python %TORNADO_SDK%\bin\tornado --jvm="-Dtornado.spirv.dispatcher=opencl" uk.ac.manchester.tornado.drivers.TornadoDeviceQuery --params="verbose"
+ python %TORNADO_SDK%\bin\tornado-test --jvm="-Dtornado.spirv.dispatcher=opencl" --ea --verbose
+ python %TORNADO_SDK%\bin\tornado-test --jvm="-Dtornado.spirv.dispatcher=opencl"--ea -V -J"-Dtornado.device.memory=1MB" uk.ac.manchester.tornado.unittests.fails.HeapFail#test03
+ test-native.sh
+
tests-opt:
python %TORNADO_SDK%\bin\tornado --devices
python %TORNADO_SDK%\bin\tornado-test -V --fast --ea --verbose -J"-Dtornado.spirv.loadstore=True" --printKernel