
Add tests for reduction with sum of squares (RSS), compute of RNSNorm, RNSNorm fused with Matmul, FindMaxAttention and Softmax #593

Open · wants to merge 29 commits into develop

Conversation

@mikepapadim (Member) commented Nov 24, 2024

Description

This PR adds tests to check the code generation and the validity of the results for some common operations required in the LLM architecture.

  1. Reduction with Sum of Squares (RSS): Ensures the correct functionality of reduction operations that compute the sum of squares, which is vital for various numerical algorithms.​

  2. Compute of RNSNorm: Validates the implementation of the RNSNorm function, a normalization technique that can enhance the stability and performance of neural networks (a sequential reference sketch of this computation is shown after this list).

  3. RNSNorm Fused with Matmul: Tests the integration of RNSNorm directly within matrix multiplication operations, aiming to optimize performance by reducing the number of separate computational steps.​

  4. FindMaxAttention: Assesses the functionality of the FindMaxAttention operation, which is essential in attention mechanisms commonly used in transformer models for tasks like natural language processing.​

  5. Softmax: Verifies the correctness of the Softmax function implementation, a critical component in various machine learning models, particularly in classification tasks.

  6. MultiHeadedAttention: Verifies the correctness of the multi-head attention; it breaks down the individual tasks/kernels and uses a reference Java implementation very close to Llama3.java.

  7. FFN: Verifies the correctness of the feed-forward network (FFN) layer.

  8. Fused with RoPE: Verifies the correctness of the fused layer with rotary position embeddings (RoPE).
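
For reference, the computation behind items 1–2 is essentially a sum-of-squares reduction followed by an RMSNorm-style normalization. A minimal sequential Java sketch is shown below; the method names, the use of plain float[] and the epsilon constant are illustrative assumptions and not taken from this PR, which uses TornadoVM's own data types and kernels.

// Sequential reference sketch (assumed names; not the PR's code).
static float sumOfSquares(float[] x) {
    float ss = 0.0f;
    for (float v : x) {
        ss += v * v;                         // reduction with sum of squares (RSS)
    }
    return ss;
}

static void rnsNorm(float[] out, float[] x, float[] weight, float eps) {
    float ss = sumOfSquares(x) / x.length;   // mean of squares
    float scale = 1.0f / (float) Math.sqrt(ss + eps);
    for (int i = 0; i < x.length; i++) {
        out[i] = weight[i] * (x[i] * scale); // normalize and apply the learned weight
    }
}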

Backend/s tested

Mark the backends affected by this PR.

  • OpenCL
  • PTX
  • SPIRV

OS tested

Mark the OS where this PR is tested.

  • Linux
  • OSx
  • Windows

Did you check on FPGAs?

If it is applicable, check your changes on FPGAs.

  • Yes
  • No

How to test the new patch?

tornado-test -V uk.ac.manchester.tornado.unittests.reductions.TestReductionsFloats#testReduceSumSquares

tornado-test -V uk.ac.manchester.tornado.unittests.llm.TestRNSNormLayer

tornado-test -V uk.ac.manchester.tornado.unittests.llm.TestSoftMaxLayer

tornado-test -V uk.ac.manchester.tornado.unittests.llm.TestFindMaxAttention

tornado-test -V uk.ac.manchester.tornado.unittests.llm.TestParallelFFNLayer

tornado-test -V uk.ac.manchester.tornado.unittests.llm.TestMultiHeadAttention

tornado-test -V uk.ac.manchester.tornado.unittests.llm.TestFusedLayer

@jjfumero (Member)

Include the new test in the test-suite:
tornado-assembly/src/bin/tornado-test

@jjfumero (Member)

The new test enters an infinite loop when running with the SPIR-V backend:

tornado-test -V uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest
/home/juan/tornadovm/TornadoVM/bin/sdk/bin/tornado --jvm "-Xmx6g -Dtornado.recover.bailout=False -Dtornado.unittests.verbose=True "  -m  tornado.unittests/uk.ac.manchester.tornado.unittests.tools.TornadoTestRunner  --params "uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest"

The PTX and OpenCL backends run fine.

@stratika (Collaborator) left a comment


I think we should add the new LLMFusedKernelsTest class to the tornado-test script so that it runs when Jenkins executes the unit-tests. In my setup, the tests pass for PTX, but the tests in the LLMFusedKernelsTest class do not finish when running with SPIR-V. I guess they are not supported for SPIR-V?


@Test
public void testRNSNormFusedWithMatMul() throws TornadoExecutionPlanException {
final int size = 2048;
Collaborator


I think we should add the following, unless it is supported:

assertNotBackend(TornadoVMBackendType.SPIRV);
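
As a sketch, the guard would sit at the top of the test method quoted above (assuming the class extends the unit-test base class that provides assertNotBackend):

@Test
public void testRNSNormFusedWithMatMul() throws TornadoExecutionPlanException {
    assertNotBackend(TornadoVMBackendType.SPIRV); // report the test as unsupported on SPIR-V
    final int size = 2048;
    // ... rest of the test unchanged
}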

@jjfumero (Member)

@mikepapadim, is this ready?

@mikepapadim mikepapadim changed the title Add tests for reduction with sum of squares (RSS), compute of RNSNorm, and RNSNorm fused with Matmul Add tests for reduction with sum of squares (RSS), compute of RNSNorm, RNSNorm fused with Matmul, and Softmax Mar 12, 2025
@mikepapadim (Member, Author)

@jjfumero @stratika @mairooni this is ready and now also includes the softmax implementation.

@jjfumero (Member)

Thanks @mikepapadim. Did you check with SPIR-V? Last time it got stuck in an infinite loop.

@mikepapadim (Member, Author)

Thanks @mikepapadim. Did you check with SPIR-V. Last time it got stuck in an infinite loop.

No, I need to switch machines for that

@stratika (Collaborator)

I confirm there seems to be something odd with SPIR-V. Let's wait, and if it is not supported for any reason we can mark it as unsupported. Currently it seems to be stuck. Note that we should test with both the OpenCL and Level Zero runtimes.

make BACKEND=spirv
tornado-test --quickPass --verbose --jvm="-Dtornado.unittests.device=0:0"
tornado-test --quickPass --verbose --jvm="-Dtornado.unittests.device=0:1"

@mikepapadim mikepapadim changed the title Add tests for reduction with sum of squares (RSS), compute of RNSNorm, RNSNorm fused with Matmul, and Softmax Add tests for reduction with sum of squares (RSS), compute of RNSNorm, RNSNorm fused with Matmul, FindMaxAttention and Softmax Mar 12, 2025
@stratika (Collaborator)

I can help with the issue in SPIR-V. I will have a look next week.

@mikepapadim (Member, Author)

TODO: add all to tornado-test script

@stratika (Collaborator) left a comment


Thanks, that's a great addition to the unit-tests. I have added some comments. Basically, the only thing is that the tests currently report many logged messages. I am not sure those messages are necessary when we run the tests. I think the most important thing is to see how many tests are failing, so I would focus on using assertEquals. Let me know of any thoughts; I will continue testing the new tests on my machine.

assertEquals("Mismatch at head " + h, expected, actual, 1e-5f);
}

System.out.println("All results match! Test passed.");
Collaborator


I would suggest removing the printed message. If the results pass validation, the test will pass.

System.out.println("Executing fused layer task graph...");
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.withGridScheduler(gridScheduler).execute();
System.out.println("Execution completed successfully");
Collaborator


Shouldn't we assert the validation of the results at this stage in order to trigger PASS or FAIL?
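
As a sketch, something like the following could replace the prints; outputReference and output are placeholder names, and the 1e-5f tolerance mirrors the one used elsewhere in this PR:

try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
    executionPlan.withGridScheduler(gridScheduler).execute();
}
// Compare against the sequential reference instead of printing a success message.
for (int i = 0; i < output.getSize(); i++) {
    assertEquals("Mismatch at index " + i, outputReference.get(i), output.get(i), 1e-5f);
}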


for (int i = 0; i < outputSeqLogits.getSize(); i++) {
float expected = outputSeqLogits.get(i); // Expected value from the sequential output
float actual = outputLogits.get(i); // Actual value from the RNS output
Collaborator


We are missing an assertEquals call for the expected and actual values.
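
Something along these lines, as a sketch (the message text and the 1e-5f tolerance are suggestions):

for (int i = 0; i < outputSeqLogits.getSize(); i++) {
    float expected = outputSeqLogits.get(i); // expected value from the sequential output
    float actual = outputLogits.get(i);      // actual value from the parallel output
    assertEquals("Mismatch at logit " + i, expected, actual, 1e-5f);
}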

Random random = new Random(42);

// Print head 0, dimension 0 values for clarity
System.out.println("Initializing test data:");
Collaborator


I would suggest removing the printing statement.

Comment on lines +158 to +168
System.out.println("Attention weights for head 0:");
for (int t = 0; t <= pos; t++) {
System.out.printf("Position %d: %.6f%n", t, attScores.get(0 * seqLen + t));
}

// Print the first few value vectors for head 0
System.out.println("First few values for head 0, dimension 0:");
for (int t = 0; t <= Math.min(pos, 3); t++) {
int valueOffset = loff + t * kvDim + (0 / kvMul) * headSize + 0;
System.out.printf("Position %d: %.6f%n", t, valueCache.get(valueOffset));
}
Collaborator


Are these print statements valuable when we run the unit-tests, or are they added for documentation purposes? If they are for documentation, perhaps we can add a flag that is false by default, in order to reduce the logged output when we run all unit-tests.

In my opinion, when we run the tests we should focus on PASS/FAIL results. What do you think?
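
For example, a sketch of how the quoted prints could be guarded; the flag name VERBOSE and its placement are assumptions:

private static final boolean VERBOSE = false; // debugging output off by default

// inside the test, wrap the existing prints:
if (VERBOSE) {
    System.out.println("Attention weights for head 0:");
    for (int t = 0; t <= pos; t++) {
        System.out.printf("Position %d: %.6f%n", t, attScores.get(0 * seqLen + t));
    }
}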

@stratika (Collaborator)

Also, one interesting note that I discussed with @mikepapadim: when we run the unit-tests with make fast-tests, this uses the tornado-test --ea --verbose --quickPass command, and the --ea flag, which enables assertions at the compiler level, adds a delay when we run the following test:

tornado-test  --ea -V uk.ac.manchester.tornado.unittests.llm.TestParallelFFNLayer

@mikepapadim (Member, Author)

I think with @stratika's changes this is good to go now.

@jjfumero (Member)

For the SPIR-V backend I get the following failures:

/home/juan/repos/tornadovm/bin/sdk/bin/tornado --jvm "-Xmx6g -Dtornado.recover.bailout=False -Dtornado.unittests.verbose=True "  -m  tornado.unittests/uk.ac.manchester.tornado.unittests.tools.TornadoTestRunner  --params "uk.ac.manchester.tornado.unittests.llm.TestFindMaxAttention"
WARNING: Using incubator modules: jdk.incubator.vector

Test: class uk.ac.manchester.tornado.unittests.llm.TestFindMaxAttention
	Running test: testFindMaxAttentionScores ................  [FAILED] 
		\_[REASON] Mismatch at head 0 expected:<4.9773192> but was:<0.0>
Test ran: 1, Failed: 1, Unsupported: 0

And:

/home/juan/repos/tornadovm/bin/sdk/bin/tornado --jvm "-Xmx6g -Dtornado.recover.bailout=False -Dtornado.unittests.verbose=True "  -m  tornado.unittests/uk.ac.manchester.tornado.unittests.tools.TornadoTestRunner  --params "uk.ac.manchester.tornado.unittests.llm.TestMultiHeadAttention"
WARNING: Using incubator modules: jdk.incubator.vector

Test: class uk.ac.manchester.tornado.unittests.llm.TestMultiHeadAttention
	Running test: testMultiHeadAttentionLlamaEquivalence ................  [FAILED] 
		\_[REASON] expected:<-0.052507114> but was:<NaN>
	Running test: testAttentionScoresCalculation ................  [PASS] 
	Running test: testFindMaxAttentionScores ................  [FAILED] 
		\_[REASON] Max value mismatch expected:<4.981786> but was:<0.0>
	Running test: testCalculateExpAndSum     ................  [FAILED] 
		\_[REASON] Sum value mismatch expected:<78781.76> but was:<0.0>
	Running test: testNormalizeSoftmax       ................  [PASS] 
	Running test: testComputeWeightedSum     ................  [PASS] 
	Running test: testSingleHeadFullPipeline ................  [FAILED] 
		\_[REASON] Output mismatch at dimension 0 expected:<0.010636929422616959> but was:<NaN>
	Running test: testMultiHeadAttentionFixed ................  [FAILED] 
		\_[REASON] Output mismatch at head 0, dimension 0 expected:<-0.052507114> but was:<NaN>
Test ran: 8, Failed: 5, Unsupported: 0

Both OpenCL and PTX are correct.

@jjfumero (Member)

My setup for the SPIR-V:

$ tornado --devices

Number of Tornado drivers: 2
Driver: SPIR-V
  Total number of SPIR-V devices  : 2
  Tornado device=0:0  (DEFAULT)
	SPIRV -- SPIRV OCL - Intel(R) UHD Graphics 770
		Global Memory Size: 28.8 GB
		Local Memory Size: 64.0 KB
		Workgroup Dimensions: 3
		Total Number of Block Threads: [512]
		Max WorkGroup Configuration: [512, 512, 512]
		Device OpenCL C version: OpenCL C 1.2

  Tornado device=0:1
	SPIRV -- SPIRV LevelZero - Intel(R) UHD Graphics 770
		Global Memory Size: 28.8 GB
		Local Memory Size: 64.0 KB
		Workgroup Dimensions: 3
		Total Number of Block Threads: [512]
		Max WorkGroup Configuration: [512, 512, 512]
		Device OpenCL C version:  (LEVEL ZERO) 1.5
