Add tests for reduction with sum of squares (RSS), computation of RNSNorm, RNSNorm fused with Matmul, FindMaxAttention and Softmax #593

Open: this pull request proposes to merge 29 commits into the develop base branch.
Commits (29)
57faae1
Add test to ensure sum of squares reduction code gen works
mikepapadim Nov 23, 2024
02186f9
Add test to mimic the computation in llama3 forward method
mikepapadim Nov 23, 2024
d2c14c8
Merge branch 'develop' into test/test_reduce_square_sum
mikepapadim Mar 12, 2025
7b6eae6
[Refactor] Rename LLMFusedKernelsTest to TestLLMFusedKernels and upda…
mikepapadim Mar 12, 2025
82a6137
[refactor] Rename MMwithBytes and LLMFusedKernelsTest classes to Test…
mikepapadim Mar 12, 2025
0114e1c
[add] Implement TestSoftMaxLayer with sequential softmax and parallel…
mikepapadim Mar 12, 2025
1999cc8
[add] Implement TestSoftMaxLayer with sequential softmax and parallel…
mikepapadim Mar 12, 2025
3d18a14
Add javadoc in SoftMax
mikepapadim Mar 12, 2025
5390a0f
Add SoftMax test in tornado-test
mikepapadim Mar 12, 2025
72aa06c
Implement TestFindMaxAttention with sequential and parallel max atten…
mikepapadim Mar 12, 2025
0d9c172
Make localmemory allocation to be dynamic
mikepapadim Mar 13, 2025
3016923
Format code
mikepapadim Mar 13, 2025
0ea8ab4
Add multihead attention test
mikepapadim Mar 14, 2025
368448d
Refactor multi-head attention tests for improved clarity and performance
mikepapadim Mar 14, 2025
1c5c179
Add TestWeightedSum for weighted sum calculation and debugging
mikepapadim Mar 14, 2025
ecbf76f
Merge branch 'develop' of github.com:beehive-lab/TornadoVM into test/…
mikepapadim Mar 14, 2025
7f34de6
Add test for FFN layer
mikepapadim Mar 14, 2025
02e528c
Add unit tests for fused layer with RoPE and RMSNorm
mikepapadim Mar 14, 2025
0f4e471
Add copyright headers to test files
mikepapadim Mar 14, 2025
7976b89
Add additional unit tests for attention and layer functionalities
mikepapadim Mar 17, 2025
f09620f
[fix] Add sqrt plugin for float type in SPIRV
stratika Mar 18, 2025
37334ce
[refactor][fix] Refactored class name and resolved the issue of seque…
stratika Mar 18, 2025
ae91eae
[test] Added the TestParallelFFNLayer to be run with assertions disab…
stratika Mar 18, 2025
12fc5d9
[fix] Replaced the Math functions with TornadoMath functions for the …
stratika Mar 18, 2025
b35351d
[revert] Removed sqrt plugin for float type in SPIRV, since this is f…
stratika Mar 18, 2025
910eaeb
[test] Removed print messages and added a flag with False as default …
stratika Mar 18, 2025
8eba390
[test] Added the enable assertions check when the script is invoked f…
stratika Mar 18, 2025
fd4a4f9
[test] Fix indentation error
stratika Mar 18, 2025
ed0d374
[test] Removed unnecessary parameter
stratika Mar 18, 2025
33 changes: 24 additions & 9 deletions tornado-assembly/src/bin/tornado-test
@@ -109,11 +109,12 @@ MONITOR_REGISTRY = {
# ################################################################################################################

class TestEntry:
def __init__(self, testName, testMethods=None, testParameters=None, monitorClass=None):
def __init__(self, testName, testMethods=None, testParameters=None, monitorClass=None, enableAssertions=None):
self.testName = testName
self.testMethods = testMethods
self.testParameters = testParameters
self.monitorClass = monitorClass
self.enableAssertions = enableAssertions


## List of classes to be tested. Include new unittest classes here
@@ -198,7 +199,18 @@ __TEST_THE_WORLD__ = [
TestEntry("uk.ac.manchester.tornado.unittests.codegen.CodeGenTest"),
TestEntry("uk.ac.manchester.tornado.unittests.atomics.TestAtomics"),
TestEntry("uk.ac.manchester.tornado.unittests.compute.ComputeTests"),
TestEntry("uk.ac.manchester.tornado.unittests.compute.MMwithBytes"),
TestEntry("uk.ac.manchester.tornado.unittests.llm.TestMatMulWithByteArrays"),
TestEntry("uk.ac.manchester.tornado.unittests.llm.TestRNSNormLayer"),
TestEntry("uk.ac.manchester.tornado.unittests.llm.TestSoftMaxLayer"),
TestEntry("uk.ac.manchester.tornado.unittests.llm.TestFindMaxAttention"),
TestEntry("uk.ac.manchester.tornado.unittests.llm.TestFusedLayer"),

## The TestParallelFFNLayer test is explicitly marked to run with the assertions disabled as it results in long compilation time
TestEntry("uk.ac.manchester.tornado.unittests.llm.TestParallelFFNLayer",
enableAssertions=False),

TestEntry("uk.ac.manchester.tornado.unittests.llm.TestWeightedSum"),
TestEntry("uk.ac.manchester.tornado.unittests.llm.TestMultiHeadAttention"),
TestEntry("uk.ac.manchester.tornado.unittests.dynamic.TestDynamic"),
TestEntry("uk.ac.manchester.tornado.unittests.vector.api.TestVectorAPI"),
TestEntry("uk.ac.manchester.tornado.unittests.api.TestConcat"),
@@ -284,7 +296,7 @@ __TORNADO_TESTS_WHITE_LIST__ = [

## Precision errors
"uk.ac.manchester.tornado.unittests.compute.ComputeTests#testNBodyBigNoWorker",
"uk.ac.manchester.tornado.unittests.compute.MMwithBytes#testMatrixMultiplicationWithBytes",
"uk.ac.manchester.tornado.unittests.llm.TestMatMulWithByteArrays#testMatrixMultiplicationWithBytes",
"uk.ac.manchester.tornado.unittests.compute.ComputeTests#testEuler",
"uk.ac.manchester.tornado.unittests.codegen.CodeGenTest#test02",
"uk.ac.manchester.tornado.unittests.reductions.TestReductionsFloats#testComputePi",
@@ -596,7 +608,10 @@ def runTests(args):

if (args.testClass != None):
options = "--jvm \"" + options + "\" "
cmd = TORNADO_CMD + options
ASSERTIONS_CMD = ""
if args.enable_assertions:
ASSERTIONS_CMD = ENABLE_ASSERTIONS
cmd = TORNADO_CMD + ASSERTIONS_CMD + options
command = appendTestRunnerClassToCmd(cmd, args)
command = command + " --params \"" + args.testClass + "\""
print(command)
@@ -659,7 +674,11 @@ def runTestTheWorld(options, args):
for testParam in t.testParameters:
command += " " + testParam

command = TORNADO_CMD + " --jvm \"" + command + "\" "
ASSERTIONS_CMD = ""
if args.enable_assertions and t.enableAssertions is not False:
ASSERTIONS_CMD = ENABLE_ASSERTIONS

command = TORNADO_CMD + ASSERTIONS_CMD + " --jvm \"" + command + "\" "

command = appendTestRunnerClassToCmd(command, args)
command = command + " --params \"" + t.testName
@@ -780,10 +799,6 @@ def main():
global javaVersion
javaVersion = getJavaVersion()

if (args.enable_assertions):
global TORNADO_CMD
TORNADO_CMD += ENABLE_ASSERTIONS

if (args.junit):
runWithJUnit(args)
else:
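Note on the enableAssertions change above: the new TestEntry field is effectively tri-state. A minimal Python sketch of the intended resolution logic (the helper name below is illustrative, not part of the patch):

def resolve_assertions(global_flag, per_test_override):
    # per_test_override is None by default (inherit the global enable-assertions flag);
    # an explicit False disables assertions for that test only, mirroring the
    # `args.enable_assertions and t.enableAssertions is not False` check above.
    return bool(global_flag) and per_test_override is not False

assert resolve_assertions(True, None) is True    # default entry, global flag on
assert resolve_assertions(True, False) is False  # e.g. TestParallelFFNLayer
assert resolve_assertions(False, None) is False  # global flag off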
New file: TestFindMaxAttention.java (208 additions)
@@ -0,0 +1,208 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester. All rights reserved.
* Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
*/
package uk.ac.manchester.tornado.unittests.llm;

import static org.junit.Assert.assertEquals;

import org.junit.Test;

import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid2D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
* <p>
* How to run the tests?
* </p>
*
* <code>
* tornado-test -V uk.ac.manchester.tornado.unittests.llm.TestFindMaxAttention
* </code>
*/
public class TestFindMaxAttention extends TornadoTestBase {

/**
* Sequential implementation for finding max attention scores, used as the reference for validation.
*
* @param pos
* The current position in the sequence (inclusive upper bound)
* @param seqLen
* The sequence length
* @param numHeads
* The number of attention heads
* @param attScores
* The attention scores array
* @return
* A new FloatArray containing the maximum values for each head
*/
public static FloatArray findMaxAttentionScoresSequential(int pos, int seqLen, int numHeads, FloatArray attScores) {
FloatArray maxValues = new FloatArray(numHeads);

// For each head, find the maximum score
for (int h = 0; h < numHeads; h++) {
float maxVal = Float.NEGATIVE_INFINITY;
int attOffset = h * seqLen;

// Find max in the range [0, pos]
for (int t = 0; t <= pos; t++) {
maxVal = Math.max(maxVal, attScores.get(attOffset + t));
}

maxValues.set(h, maxVal);
}

return maxValues;
}

/**
* Parallel implementation for finding maximum attention scores.
* Each head is processed by a separate work group, with threads collaborating
* through parallel reduction to find the maximum value.
*
* @param context
* The kernel execution context
* @param pos
* The current position in the sequence (inclusive upper bound)
* @param seqLen
* The sequence length
* @param attScores
* The attention scores array
* @param maxValues
* The output array to store maximum values for each head
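* @param localWorkgroupSize
* The size of the local memory buffer used for the reduction; must be at least the work-group size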
*/
public static void findMaxAttentionScores(KernelContext context, int pos, int seqLen, FloatArray attScores, FloatArray maxValues, int localWorkgroupSize) {
int globalId = context.globalIdx; // Global thread ID
int localId = context.localIdx; // Thread ID within work group
int workGroupSize = context.localGroupSizeX; // Threads per work group

// Calculate which head this thread is working on
int h = globalId / workGroupSize;

// Check if this thread should process a head (don't exceed numHeads)
if (h < maxValues.getSize()) {
float[] maxReduction = context.allocateFloatLocalArray(localWorkgroupSize);

// Attention scores offset for this head
int attOffset = h * seqLen;

// Find the maximum value for this thread's assigned elements
float maxVal = Float.NEGATIVE_INFINITY;

// Each thread processes a stride of elements
for (int t = localId; t <= pos; t += workGroupSize) {
maxVal = Math.max(maxVal, attScores.get(attOffset + t));
}

// Store in local memory for reduction
maxReduction[localId] = maxVal;

// Parallel reduction to find global maximum
for (int stride = workGroupSize / 2; stride > 0; stride /= 2) {
context.localBarrier();
if (localId < stride) {
maxReduction[localId] = Math.max(maxReduction[localId], maxReduction[localId + stride]);
}
}

// Only the first thread in each work group writes the result
if (localId == 0) {
maxValues.set(h, maxReduction[0]);
}
}
}

/**
* Test the parallel implementation against the sequential implementation.
*
* @throws TornadoExecutionPlanException
* If there's an error in the Tornado execution plan
*/
@Test
public void testFindMaxAttentionScores() throws TornadoExecutionPlanException {
// Define the problem configuration
final int numHeads = 16; // Number of attention heads
final int seqLen = 2048; // Sequence length
final int pos = 1024; // Current position in sequence
final int threadsPerHead = 256; // Set to a power of 2 (common for GPUs)

// Create input and output arrays
FloatArray attScores = new FloatArray(numHeads * seqLen);
FloatArray maxValues = new FloatArray(numHeads);
FloatArray expectedMaxValues;

// Initialize attention scores with random values between -5 and 5
for (int i = 0; i < attScores.getSize(); i++) {
attScores.set(i, (float) (Math.random() * 10 - 5));
}

// Compute expected max values using sequential implementation
expectedMaxValues = findMaxAttentionScoresSequential(pos, seqLen, numHeads, attScores);

// Set up the worker grid: one work group per head, so the global work size
// is a multiple of the local work size
int localSize = threadsPerHead; // Work group size (threads cooperating per head)
int globalSize = numHeads * localSize;

WorkerGrid worker = new WorkerGrid2D(globalSize, 1);
worker.setLocalWork(localSize, 1, 1);
worker.setGlobalWork(globalSize, 1, 1);

// Create grid scheduler
GridScheduler gridScheduler = new GridScheduler("s0.findMaxAttentionScores", worker);

// Create kernel context
KernelContext context = new KernelContext();

// Create task graph
TaskGraph taskGraph = new TaskGraph("s0")
//@formatter:off
.transferToDevice(DataTransferMode.FIRST_EXECUTION, attScores)
.task("findMaxAttentionScores", TestFindMaxAttention::findMaxAttentionScores,
context, pos,
seqLen, attScores,
maxValues, localSize)
.transferToHost(DataTransferMode.EVERY_EXECUTION, maxValues);
//@formatter:on

// Execute the task graph
ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.withGridScheduler(gridScheduler).execute();
}

// Validate results
for (int h = 0; h < numHeads; h++) {
float expected = expectedMaxValues.get(h);
float actual = maxValues.get(h);
assertEquals("Mismatch at head " + h, expected, actual, 1e-5f);
}
}
}