Add tests for reduction with sum of squares (RSS), compute of RNSNorm, RNSNorm fused with Matmul, FindMaxAttention and Softmax #593
base: develop
...-unittests/src/main/java/uk/ac/manchester/tornado/unittests/compute/LLMFusedKernelsTest.java
Include the new test in the test-suite:
The new test enters an infinite loop when running with the SPIR-V backend:
tornado-test -V uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest
/home/juan/tornadovm/TornadoVM/bin/sdk/bin/tornado --jvm "-Xmx6g -Dtornado.recover.bailout=False -Dtornado.unittests.verbose=True " -m tornado.unittests/uk.ac.manchester.tornado.unittests.tools.TornadoTestRunner --params "uk.ac.manchester.tornado.unittests.compute.LLMFusedKernelsTest"
The PTX and OpenCL backends run fine.
I think we should add the new LLMFusedKernelsTest class to the tornado-test script so that it runs when Jenkins runs the unit tests. In my setup, the tests pass for PTX, but the tests in the LLMFusedKernelsTest class do not finish when running with SPIR-V. I guess they are not supported for SPIR-V?
@Test
public void testRNSNormFusedWithMatMul() throws TornadoExecutionPlanException {
    final int size = 2048;
I think we should add the following, unless it is supported:
assertNotBackend(TornadoVMBackendType.SPIRV);
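A minimal sketch of what such a guard does, written with stand-in types rather than TornadoVM's own classes. `BackendType`, `UnsupportedBackendException`, and `BackendGuard` are hypothetical names used only for illustration; in the real test the check would presumably be the single `assertNotBackend(TornadoVMBackendType.SPIRV)` call quoted above, near the top of the test method.

```java
// Stand-in types for illustration only; these are NOT TornadoVM classes.
enum BackendType { OPENCL, PTX, SPIRV }

class UnsupportedBackendException extends RuntimeException {
    UnsupportedBackendException(String message) {
        super(message);
    }
}

final class BackendGuard {
    private BackendGuard() { }

    // Throws when the active backend is the one the test cannot run on,
    // so a runner can report "unsupported" instead of waiting on a hang.
    static void assertNotBackend(BackendType active, BackendType unsupported) {
        if (active == unsupported) {
            throw new UnsupportedBackendException("Test not supported on " + unsupported);
        }
    }
}
```

The point of the guard is that the test exits early with a distinct outcome on the excluded backend, while PTX and OpenCL runs proceed unchanged.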
@mikepapadim, is this ready?
…te copyright year
…MatMulWithByteArrays and TestRMSNormLayer, respectively, and update package structure
… reduction methods
… reduction methods
Thanks @mikepapadim. Did you check with SPIR-V? Last time it got stuck in an infinite loop.
No, I need to switch machines for that.
I confirm there seems to be something odd with SPIR-V. Let's wait, and if it is not supported for any reason we can mark it as unsupported. Currently it seems to be stuck. Note that we should test with both the OpenCL and Level Zero runtimes:
make BACKEND=spirv
tornado-test --quickPass --verbose --jvm="-Dtornado.unittests.device=0:0"
tornado-test --quickPass --verbose --jvm="-Dtornado.unittests.device=0:1"
…tion score calculations
I can help with the issue in SPIR-V. I will have a look next week.
TODO: add all to tornado-test script
…test_reduce_square_sum
Thanks, that's a great addition to the unit-tests. I have added some comments. Basically, the only thing is that the tests currently report many logged messages. I am not sure those messages are necessary when we run the tests. I think the most important thing is to see how many tests are failing, so I would focus on using assertEquals. Let me know your thoughts; I will continue testing the new tests on my machine.
    assertEquals("Mismatch at head " + h, expected, actual, 1e-5f);
}

System.out.println("All results match! Test passed.");
I would suggest removing the printed message. If the tests pass validation, they will be reported as passing.
tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/llm/TestFusedLayer.java
System.out.println("Executing fused layer task graph...");
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
    executionPlan.withGridScheduler(gridScheduler).execute();
    System.out.println("Execution completed successfully");
Shouldn't we assert the validation of the results at this stage in order to trigger PASS or FAIL?
for (int i = 0; i < outputSeqLogits.getSize(); i++) {
    float expected = outputSeqLogits.get(i); // Expected value from the sequential output
    float actual = outputLogits.get(i); // Actual value from the RNS output
We are missing the assertEquals call for the expected and actual values.
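A dependency-free sketch of the element-wise check being asked for. In the test itself this would presumably be JUnit's `assertEquals(message, expected, actual, delta)` inside the loop, the same form quoted earlier in this review with `1e-5f`. The `ResultCheck` class and `assertArrayClose` method are hypothetical names, and plain `float[]` arrays stand in for the TornadoVM array types.

```java
final class ResultCheck {
    private ResultCheck() { }

    // Fails on the first element whose absolute difference exceeds delta.
    // Identical values (including NaN vs NaN) are treated as equal.
    static void assertArrayClose(float[] expected, float[] actual, float delta) {
        if (expected.length != actual.length) {
            throw new AssertionError("Length mismatch: " + expected.length + " vs " + actual.length);
        }
        for (int i = 0; i < expected.length; i++) {
            boolean same = Float.compare(expected[i], actual[i]) == 0;
            boolean close = Math.abs(expected[i] - actual[i]) <= delta;
            if (!same && !close) {
                throw new AssertionError("Mismatch at index " + i
                        + ": expected " + expected[i] + " but was " + actual[i]);
            }
        }
    }
}
```

With a check like this in place the loop yields a PASS or FAIL on its own, so no printed confirmation is needed.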
...ado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/llm/TestParallelFFNLayer.java
Random random = new Random(42);

// Print head 0, dimension 0 values for clarity
System.out.println("Initializing test data:");
I would suggest removing the print statement.
System.out.println("Attention weights for head 0:");
for (int t = 0; t <= pos; t++) {
    System.out.printf("Position %d: %.6f%n", t, attScores.get(0 * seqLen + t));
}

// Print the first few value vectors for head 0
System.out.println("First few values for head 0, dimension 0:");
for (int t = 0; t <= Math.min(pos, 3); t++) {
    int valueOffset = loff + t * kvDim + (0 / kvMul) * headSize + 0;
    System.out.printf("Position %d: %.6f%n", t, valueCache.get(valueOffset));
}
Are these print statements valuable when we run the unit-tests, or were they added for documentation purposes? If they are for documentation, perhaps we can add a flag that is FALSE by default, in order to reduce the logged output when we run all unit-tests.
In my opinion, when we run the tests we should focus on PASS/FAIL results. What do you think?
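One possible shape for such a flag, sketched with a JVM system property that is false unless set explicitly. The property name `llm.tests.verbose` and the `TestLog` helper are hypothetical and not part of TornadoVM; whether to reuse an existing test property instead is a choice for the authors.

```java
final class TestLog {
    private TestLog() { }

    // Hypothetical property name. Boolean.getBoolean returns false when the
    // property is unset, so the tests stay quiet by default.
    static boolean isVerbose() {
        return Boolean.getBoolean("llm.tests.verbose");
    }

    // Prints only when -Dllm.tests.verbose=true was passed to the JVM.
    static void debug(String format, Object... args) {
        if (isVerbose()) {
            System.out.printf(format, args);
        }
    }
}
```

A call such as `TestLog.debug("Position %d: %.6f%n", t, value)` could then replace the unconditional `System.out.printf` lines, keeping the diagnostic output available on demand.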
Also, one interesting note that I discussed with @mikepapadim is that when we run the unit-tests with tornado-test --ea -V uk.ac.manchester.tornado.unittests.llm.TestParallelFFNLayer
…ntial kernel in SPIRV. The problem was triggered when we had an if condition block in a sequential kernel and we had some code out of the if block.
…led due to long compilation time
…or the Math class
…value in order to be able to debug
…or a particular test
I think with @stratika's changes this is good to go now.
For the SPIR-V I get the following failures:
And:
Both OpenCL and PTX are correct.
My setup for SPIR-V:
$ tornado --devices
Number of Tornado drivers: 2
Driver: SPIR-V
Total number of SPIR-V devices : 2
Tornado device=0:0 (DEFAULT)
SPIRV -- SPIRV OCL - Intel(R) UHD Graphics 770
Global Memory Size: 28.8 GB
Local Memory Size: 64.0 KB
Workgroup Dimensions: 3
Total Number of Block Threads: [512]
Max WorkGroup Configuration: [512, 512, 512]
Device OpenCL C version: OpenCL C 1.2
Tornado device=0:1
SPIRV -- SPIRV LevelZero - Intel(R) UHD Graphics 770
Global Memory Size: 28.8 GB
Local Memory Size: 64.0 KB
Workgroup Dimensions: 3
Total Number of Block Threads: [512]
Max WorkGroup Configuration: [512, 512, 512]
Device OpenCL C version: (LEVEL ZERO) 1.5
Description
This PR adds tests that check the code generation and the validity of the results for some common operations required in LLM architectures.
- Reduction with Sum of Squares (RSS): Ensures the correct functionality of reduction operations that compute the sum of squares, which is vital for various numerical algorithms.
- Compute of RNSNorm: Validates the implementation of the RNSNorm function, a normalization technique that can enhance the stability and performance of neural networks.
- RNSNorm Fused with Matmul: Tests the integration of RNSNorm directly within matrix multiplication operations, aiming to optimize performance by reducing the number of separate computational steps.
- FindMaxAttention: Assesses the functionality of the FindMaxAttention operation, which is essential in attention mechanisms commonly used in transformer models for tasks like natural language processing.
- Softmax: Verifies the correctness of the Softmax function implementation, a critical component in various machine learning models, particularly in classification tasks.
- MultiHeadedAttention: Verifies the correctness of the multi-headed attention. It breaks the computation down into individual tasks/kernels and uses a reference Java implementation very close to LLama3.java.
- FFN: Verifies the correctness of the layer.
- Fused with RoPE: Verifies correctness.
Backend/s tested
Mark the backends affected by this PR.
OS tested
Mark the OS where this PR is tested.
Did you check on FPGAs?
If it is applicable, check your changes on FPGAs.
How to test the new patch?
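One way to sanity-check the kernels listed in the description is to compare device output against a plain sequential Java version. The following is a minimal sketch, not code from this PR. It assumes that RNSNorm denotes RMS normalization (the commit list mentions TestRMSNormLayer) and that FindMaxAttention corresponds to the max-finding step of a numerically stable softmax. The class name, method names, and the `eps` parameter are hypothetical.

```java
final class SequentialReference {
    private SequentialReference() { }

    // Reduction: sum of x[i]^2 over the whole vector.
    static float sumOfSquares(float[] x) {
        float sum = 0.0f;
        for (float v : x) {
            sum += v * v;
        }
        return sum;
    }

    // RMS normalization: out[i] = weight[i] * x[i] / sqrt(mean(x^2) + eps).
    static float[] rmsNorm(float[] x, float[] weight, float eps) {
        float meanSquare = sumOfSquares(x) / x.length;
        float scale = (float) (1.0 / Math.sqrt(meanSquare + eps));
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = weight[i] * (scale * x[i]);
        }
        return out;
    }

    // Numerically stable softmax: subtract the maximum before exponentiating.
    static float[] softmax(float[] x) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : x) {
            max = Math.max(max, v);
        }
        float[] out = new float[x.length];
        float sum = 0.0f;
        for (int i = 0; i < x.length; i++) {
            out[i] = (float) Math.exp(x[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < x.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
}
```

A test would run the same inputs through the accelerated task graph and compare each element against these results with a small tolerance, since parallel reductions may accumulate in a different order.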