[Bugfix][Kernel] Fix incorrect output tokens when running ChatGLM-9b inference on MI250 #312
[Issue Description]
We encountered random and incorrect output tokens when running ChatGLM-9b inference with the latest rocm/vllm release on MI250.
[Root Cause]
The vectorized rms_norm_kernel stores intermediate results in the same data type as the input tensors (e.g. FP16), which causes precision loss and leads to incorrect output tokens.
[Solution]
Store intermediate results in float, which yields correct output tokens.
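To illustrate the idea, below is a minimal sketch of an RMSNorm kernel that keeps its intermediates in float (hypothetical kernel name and simplified structure, not the actual vLLM/ROCm source): the squared sum and the normalization factor stay in float even when the input type is FP16, and values are converted back to the input type only on the final store.

```cpp
#include <cuda_fp16.h>

// Sketch of an RMSNorm kernel that keeps intermediates in float.
// Launch with one block per token, blockDim.x a power of two, and
// blockDim.x * sizeof(float) bytes of dynamic shared memory.
template <typename scalar_t>
__global__ void rms_norm_kernel_sketch(scalar_t* __restrict__ out,
                                       const scalar_t* __restrict__ input,
                                       const scalar_t* __restrict__ weight,
                                       const float epsilon,
                                       const int hidden_size) {
  extern __shared__ float s_partial[];
  const scalar_t* row = input + static_cast<size_t>(blockIdx.x) * hidden_size;

  // Accumulate the squared sum in float, not in the input dtype.
  float sq_sum = 0.0f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    const float x = static_cast<float>(row[i]);
    sq_sum += x * x;
  }
  s_partial[threadIdx.x] = sq_sum;
  __syncthreads();

  // Simple shared-memory tree reduction (a production kernel would likely
  // use warp shuffles or a block-reduce primitive instead).
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_partial[threadIdx.x] += s_partial[threadIdx.x + stride];
    }
    __syncthreads();
  }

  __shared__ float s_inv_rms;
  if (threadIdx.x == 0) {
    s_inv_rms = rsqrtf(s_partial[0] / hidden_size + epsilon);
  }
  __syncthreads();

  // Normalize in float; cast back to the input dtype only on the final store.
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    const float x = static_cast<float>(row[i]);
    out[static_cast<size_t>(blockIdx.x) * hidden_size + i] =
        static_cast<scalar_t>(x * s_inv_rms * static_cast<float>(weight[i]));
  }
}
```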
We compared the outputs of the different kernels against PyTorch's standard op and verified the correctness of the final output tokens. The test results are below:
After the fix is applied, the number of mismatched elements is reduced significantly without impacting performance, the absolute and relative differences are close to or better than those of the non-vectorized kernel, and the final output tokens are correct.
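For context, the "mismatched elements" and absolute/relative difference figures above correspond to a comparison of roughly the following shape (an illustrative host-side helper, not the actual test script; the tolerance values are placeholders):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct CompareStats {
  long long mismatched = 0;
  double max_abs_diff = 0.0;
  double max_rel_diff = 0.0;
};

// Compare a kernel output against a reference computed by PyTorch's standard
// op (both copied back to host as float): count elements outside atol/rtol
// bounds and track the maximum absolute and relative differences.
CompareStats compare_outputs(const std::vector<float>& got,
                             const std::vector<float>& ref,
                             double atol = 1e-3, double rtol = 1e-2) {
  CompareStats s;
  for (size_t i = 0; i < got.size(); ++i) {
    const double g = got[i], r = ref[i];
    const double abs_diff = std::fabs(g - r);
    const double rel_diff = abs_diff / std::max(std::fabs(r), 1e-12);
    s.max_abs_diff = std::max(s.max_abs_diff, abs_diff);
    s.max_rel_diff = std::max(s.max_rel_diff, rel_diff);
    // Same mismatch rule as torch.allclose: |a - b| <= atol + rtol * |b|.
    if (abs_diff > atol + rtol * std::fabs(r)) ++s.mismatched;
  }
  return s;
}

int main() {
  // Dummy data just to show usage; in practice `got` is the kernel output
  // and `ref` is the PyTorch reference.
  std::vector<float> ref = {1.0f, 2.0f, 3.0f};
  std::vector<float> got = {1.0005f, 2.0f, 3.1f};
  const CompareStats s = compare_outputs(got, ref);
  std::printf("mismatched=%lld max_abs=%.6f max_rel=%.6f\n",
              s.mismatched, s.max_abs_diff, s.max_rel_diff);
  return 0;
}
```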