In the previous chapter 14. MAKING PREDICTION with LLAMA MODEL - 1, in this chapter, and in the following chapter 16. MAKING PREDICTION with LLAMA MODEL - 3, we walk through the LlamaTransformer.Forward(...) method and its components.
In the previous chapter, we initiated a for loop to iterate over the 32 layers defined in lt.Layers and called the LlamaTransformerBlock.Forward(...) method.
In this chapter, we delve into the details of the LlamaTransformerBlock.Forward(...) method.
Recap:
The Llama models use Pre-RMSNorm (Root Mean Square Layer Normalization). Because we perform Root Mean Square Layer Normalization before multiplying the current tensor with the normalization weights tensor, we call this normalization stage "pre-normalization".
Diagram: Forward Pass Through Attention Pre-normalization. For the complete diagram, click here.
In this chapter, we only cover how RMSNorm is called, to leave more space for other details. The details of pre-normalization and RMSNorm (Root Mean Square Layer Normalization) will be explained in chapter 16.1. Performing Forward Pass Through Output Prenormalization - RMSNorm.Forward(...).
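Although the full RMSNorm implementation is deferred to chapter 16.1, the following is a minimal, hedged sketch of the idea applied to one row of float32 values (the function name and epsilon value here are illustrative, not the project's actual API):

```go
package main

import (
	"fmt"
	"math"
)

// rmsNormRow is an illustrative RMSNorm over one row: each element is divided
// by the root mean square of the row, then scaled by the corresponding weight.
func rmsNormRow(x, weights []float32, eps float32) []float32 {
	var sumSquares float64
	for _, v := range x {
		sumSquares += float64(v) * float64(v)
	}
	rms := math.Sqrt(sumSquares/float64(len(x)) + float64(eps))
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = float32(float64(v)/rms) * weights[i]
	}
	return out
}

func main() {
	x := []float32{1, 2, 3, 4} // in the real model, each row has 4096 values
	w := []float32{1, 1, 1, 1} // normalization weights (all ones here for illustration)
	fmt.Println(rmsNormRow(x, w, 1e-5))
}
```

In the real flow, this normalization is applied row by row to the {22, 4096} tensor, so the output keeps the same shape.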
Now, we have the x tensor argument. In our case, at the first iteration, x is the input tensor; at the other iterations, x is the output of the previous LlamaTransformerBlock. In our case, at the first iteration, the shape of this tensor is {22, 4096}: 22 stands for the sequence length, 4096 stands for the embedding layer dimension. The resulting tensor normalizedX will have the same shape as the input, {22, 4096}.
from src/model/llamatransformer.go
func (ltb *LlamaTransformerBlock) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
var maskSize []int
if mask != nil {
maskSize = mask.Size
}
common.GLogger.DebugPrintf("LlamaTransformerBlock.Forward started for x: shape(%v), startPos: %d, freqsCis: shape(%v), mask: shape(%v)", x.Size, startPos, freqsCis.Size, maskSize)
common.GLogger.DebugPrintf("Calling RMSNorm for tensor x shape(%v) and LlamaTransformerBlock.attn_norm weights shape(%v) -> tensor normalizedX", x.Size, ltb.attn_norm.weights.Size)
normalizedX, err := ltb.attn_norm.Forward(infContext, x)
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling LlamaTransformerBlock.Forward for layer: 1 / 32, startPos: 0 -> tensor currentTensor ...
[DEBUG] ... LlamaTransformerBlock.Forward started for x: shape([22 4096]), startPos: 0, freqsCis: shape([22 64]), mask: shape([]) ...
[DEBUG] ... Calling RMSNorm for tensor x shape([22 4096]) and LlamaTransformerBlock.attn_norm weights shape([4096]) -> tensor normalizedX ...
from src/model/llamatransformer.go
func (ltb *LlamaTransformerBlock) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling LlamaAttention.Forward for tensor normalizedX shape(%v) and startPos: %d, freqsCis: shape(%v), mask: shape(%v) -> tensor h", normalizedX.Size, startPos, freqsCis.Size, maskSize)
h, err := ltb.attention.Forward(infContext, normalizedX, startPos, freqsCis, mask)
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling LlamaAttention.Forward for tensor normalizedX shape([22 4096]) and startPos: 0, freqsCis: shape([22 64]), mask: shape([22 22]) -> tensor h ...
Recap:
The most important part of transformer models, which lets them produce accurate outputs, is the attention mechanism. Each "block" of Llama consists of a self-attention part and a feed-forward neural network part. The details will be explained further; note that we also call these "blocks" "layers".
The attention mechanism is one of the key inventions that pushed language models forward. Llama models implement "multi-head attention", so in our model (Llama 3.1 8B-Instruct) we have 32 attention heads per layer, along with some other supporting components. In the following steps, we will walk through the details of an "attention module".
Important note:
In our case, the Llama 3.1 8B-Instruct model has 32 layers of transformer blocks, and each block contains an attention module with 32 attention heads. Both numbers are 32, but they count different concepts, so pay attention to avoid confusion.
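For reference, these are the configuration values this chapter keeps referring to for Llama 3.1 8B-Instruct. The constant names below are only illustrative shorthand (they approximate the project's modelArgs fields; they are not the actual code):

```go
// Illustrative only: the numbers used throughout this chapter.
const (
	nLayers  = 32                // transformer block ("layer") count
	nHeads   = 32                // attention (query) heads per layer
	nKVHeads = 8                 // key/value heads per layer
	headDim  = 128               // dimension of each attention head
	dim      = nHeads * headDim  // embedding dimension = 4096
	nRep     = nHeads / nKVHeads // how many times each K/V head is repeated later = 4
)
```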
Diagram: Forward Pass Through Attention Module. For the complete diagram, click here.
The LlamaAttention object consists of:
- attn_wq: attention query weights tensor with shape {N_Heads * HeadDim, Dim} = {32 * 128, 4096} = {4096, 4096},
- attn_wk: attention key weights tensor with shape {N_KVHeads * HeadDim, Dim} = {8 * 128, 4096} = {1024, 4096},
- attn_wv: attention value weights tensor with shape {N_KVHeads * HeadDim, Dim} = {8 * 128, 4096} = {1024, 4096},
- attn_wo: attention output weights tensor with shape {N_Heads * HeadDim, Dim} = {32 * 128, 4096} = {4096, 4096}.
Type definition:
from src/model/llamatransformer.go
type LlamaAttention struct {
LayerIndex int
N_Heads int
N_KVHeads int
N_Rep int
HeadDim int
attn_wq *ml.Tensor // Original: "layers.0.attention.wq.weight" | ggml: "blk.0.attn_q.weight" | [out_features, in_features] -> shape: [4096 4096] -> [N_Heads * HeadDim, Dim]
attn_wk *ml.Tensor // Original: "layers.0.attention.wk.weight" | ggml: "blk.0.attn_k.weight" | [out_features, in_features] -> shape: [1024 4096] -> [N_KVHeads * HeadDim, Dim]
attn_wv *ml.Tensor // Original: "layers.0.attention.wv.weight" | ggml: "blk.0.attn_v.weight" | [out_features, in_features] -> shape: [1024 4096] -> [N_KVHeads * HeadDim, Dim]
attn_wo *ml.Tensor // Original: "layers.0.attention.wo.weight" | ggml: "blk.0.attn_output.weight" | [out_features, in_features] -> shape: [4096 4096] -> [N_Heads * HeadDim, Dim]
}
Now, we have the x tensor argument. In our case, at the first iteration, x is the normalized form of the input tensor; at the other iterations, x is the normalized form of the output of the previous LlamaTransformerBlock. In our case, at the first iteration, the shape of this tensor is {22, 4096}: 22 stands for the sequence length, 4096 stands for the embedding layer dimension.
In our attention module, we have LlamaAttention.attn_wq, LlamaAttention.attn_wk, and LlamaAttention.attn_wv. They are the weight tensors of the current layer (one of the 32 layers), standing for "attention query weights", "attention key weights", and "attention value weights" respectively.
We need to perform a linear transformation of our x tensor with each of these three weight tensors independently, then take the results into the xq, xk, and xv tensors respectively. These operations are independent of each other, so we can run them in parallel. In this step, we provide a structure to call them in parallel, as follows.
The concepts used here were described in chapter 13.1. Preliminary Concepts. For now, know that the context and the WaitGroup are used to manage parallel operations. We call ml.LinearTransformation in goroutines.
After these 3 goroutines have run and finished, we take the resulting xq, xk, and xv tensors from the parallelResults map.
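The common.WaitGroupDone(&wg) helper used below turns a sync.WaitGroup into a channel so that it can participate in a select statement together with the cancellable context. The project's own implementation is not shown in this chapter; the sketch below is only how such a helper is typically written, under that assumption:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// waitGroupDone returns a channel that is closed once the WaitGroup counter
// reaches zero, so wg.Wait() can be used inside a select statement.
func waitGroupDone(wg *sync.WaitGroup) <-chan struct{} {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	return done
}

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		fmt.Println("worker finished")
	}()
	select {
	case <-waitGroupDone(&wg):
		fmt.Println("all goroutines are done")
	case <-time.After(time.Second):
		fmt.Println("timed out")
	}
}
```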
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
sequenceLength := x.Size[0]
ctx, cancel := context.WithCancelCause(context.Background())
var wg sync.WaitGroup
var mu sync.Mutex
parallelResults := make(map[string]*ml.Tensor)
common.GLogger.DebugPrintf("[Scheduling goroutine] ml.LinearTransformation for x shape(%v) and LlamaAttention.attn_wq weights shape(%v) -> tensor xq", x.Size, lat.attn_wq.Size)
wg.Add(1)
go func() {
defer wg.Done()
if ctx.Err() != nil {
return
}
...
}()
common.GLogger.DebugPrintf("[Scheduling goroutine] ml.LinearTransformation for x shape(%v) and LlamaAttention.attn_wk weights shape(%v) -> tensor xk", x.Size, lat.attn_wk.Size)
wg.Add(1)
go func() {
defer wg.Done()
if ctx.Err() != nil {
return
}
...
}()
common.GLogger.DebugPrintf("[Scheduling goroutine] ml.LinearTransformation for x shape(%v) and LlamaAttention.attn_wv weights shape(%v) -> tensor xv", x.Size, lat.attn_wv.Size)
wg.Add(1)
go func() {
defer wg.Done()
if ctx.Err() != nil {
return
}
...
}()
runtime.Gosched()
select {
case <-ctx.Done():
// Cancellation signal was received
return nil, context.Cause(ctx)
case <-common.WaitGroupDone(&wg):
runtime.Gosched()
}
xq := parallelResults["xq"]
xk := parallelResults["xk"]
xv := parallelResults["xv"]
}
We perform three ml.LinearTransformation calls, which set their results into the parallelResults map under mutex locks. The resulting xq tensor has a shape of {22, 4096}; the xk and xv tensors both have a shape of {22, 1024}.
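ml.LinearTransformation behaves like the usual linear layer y = x · Wᵀ, with the weights stored as [out_features, in_features]; that is why multiplying the {22, 4096} input by the {4096, 4096} attn_wq weights yields {22, 4096}, while the {1024, 4096} attn_wk and attn_wv weights yield {22, 1024}. A naive sketch over plain float32 slices (the real implementation works on ml.Tensor values with BFloat16 data, so this is only an illustration of the shape rule):

```go
// linearTransformation computes out[s][o] = sum_i x[s][i] * w[o][i],
// i.e. y = x · Wᵀ, with w stored as [outFeatures][inFeatures].
func linearTransformation(x [][]float32, w [][]float32) [][]float32 {
	outFeatures, inFeatures := len(w), len(w[0])
	out := make([][]float32, len(x))
	for s := range x {
		out[s] = make([]float32, outFeatures)
		for o := 0; o < outFeatures; o++ {
			var sum float32
			for i := 0; i < inFeatures; i++ {
				sum += x[s][i] * w[o][i]
			}
			out[s][o] = sum
		}
	}
	return out // shape: {len(x), outFeatures}
}
```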
Diagram: Forward Pass Through Attention Module - Calculating xq, xk, and xv. For the complete diagram, click here.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
go func() {
...
xq, err := ml.LinearTransformation(x, lat.attn_wq)
...
mu.Lock()
parallelResults["xq"] = xq
mu.Unlock()
}()
...
go func() {
...
xk, err := ml.LinearTransformation(x, lat.attn_wk)
...
mu.Lock()
parallelResults["xk"] = xk
mu.Unlock()
}()
...
go func() {
...
xv, err := ml.LinearTransformation(x, lat.attn_wv)
...
mu.Lock()
parallelResults["xv"] = xv
mu.Unlock()
}()
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... [Scheduling goroutine] ml.LinearTransformation for x shape([22 4096]) and LlamaAttention.attn_wq weights shape([4096 4096]) -> tensor xq ...
[DEBUG] ... [Scheduling goroutine] ml.LinearTransformation for x shape([22 4096]) and LlamaAttention.attn_wk weights shape([1024 4096]) -> tensor xk ...
[DEBUG] ... [Scheduling goroutine] ml.LinearTransformation for x shape([22 4096]) and LlamaAttention.attn_wv weights shape([1024 4096]) -> tensor xv ...
[DEBUG] ... [Calling in goroutine] ml.LinearTransformation for x shape([22 4096]) and LlamaAttention.attn_wv weights shape([1024 4096]) -> tensor xv ...
[DEBUG] ... [Calling in goroutine] ml.LinearTransformation for x shape([22 4096]) and LlamaAttention.attn_wq weights shape([4096 4096]) -> tensor xq ...
[DEBUG] ... [Calling in goroutine] ml.LinearTransformation for x shape([22 4096]) and LlamaAttention.attn_wk weights shape([1024 4096]) -> tensor xk ...
[DEBUG] ... Parallel results, xq: shape([22 4096]), xk: shape([22 1024]), xv: shape([22 1024]) ...
Note: As you can see in the logs above, the order of the [Scheduling goroutine] and [Calling in goroutine] lines differs, which shows they were executed in parallel.
Diagram: Forward Pass Through Attention Module - Do reshapings. For the complete diagram, click here.
The resulting xq tensor has a shape of {22, 4096}; the xk and xv tensors both have a shape of {22, 1024}. In our case, we implement "multi-head attention":

- On the xq tensor with 32 attention heads (according to modelArgs.N_Heads = 32),
- On the xk and xv tensors with 8 attention key/value heads (according to modelArgs.N_KVHeads = 8).

Our resulting tensors have the values of all attention heads combined into the last dimension (of size 4096 for xq, 1024 for xk and xv). Our modelArgs.HeadDim = 128, so the dimension of each attention head is 128.
Now, we need to reshape our tensors to differentiate each attention head.
- The xq tensor with shape {22, 4096} will be reshaped to {22, 32, 128}. The 22 stands for the sequence length, the 32 stands for the "attention head count" modelArgs.N_Heads, and the 128 stands for the "attention head dimension" modelArgs.HeadDim,
- The xk and xv tensors with shape {22, 1024} will be reshaped to {22, 8, 128}. The 22 stands for the sequence length, the 8 stands for the "attention key/value head count" modelArgs.N_KVHeads, and the 128 stands for the "attention head dimension" modelArgs.HeadDim.
But, wait! We have mentioned the "attention head count" as modelArgs.N_Heads, but in the code there are two concepts: modelArgs.N_Heads and modelArgs.N_KVHeads. modelArgs.N_Heads is used to specify the shape of the query tensor xq. modelArgs.N_KVHeads is used to specify the shape of the key xk and value xv tensors. In one of the further steps, we will apply a "Repeat K/V heads" operation to make our key xk and value xv tensor shapes equal to that of the query tensor xq.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
common.GLogger.DebugPrintf("Parallel results, xq: shape(%v), xk: shape(%v), xv: shape(%v)", xq.Size, xk.Size, xv.Size)
/*
Do reshapings
*/
var err error
if xq, err = xq.Reshape([]int{sequenceLength, lat.N_Heads, lat.HeadDim}); err != nil {
return nil, err
}
if xk, err = xk.Reshape([]int{sequenceLength, lat.N_KVHeads, lat.HeadDim}); err != nil {
return nil, err
}
if xv, err = xv.Reshape([]int{sequenceLength, lat.N_KVHeads, lat.HeadDim}); err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Reshaping results, xq: shape(%v), xk: shape(%v), xv: shape(%v)", xq.Size, xk.Size, xv.Size)
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Parallel results, xq: shape([22 4096]), xk: shape([22 1024]), xv: shape([22 1024]) ...
[DEBUG] ... Reshaping results, xq: shape([22 32 128]), xk: shape([22 8 128]), xv: shape([22 8 128]) ...
Diagram: Forward Pass Through Attention Module - Apply Rotary Embeddings. For the complete diagram, click here.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
/*
Apply rotary embeddings
*/
if xq, xk, err = applyRotaryEmbeddings(xq, xk, freqsCis); err != nil {
return nil, err
}
common.GLogger.DebugPrintf("applyRotaryEmbeddings results, xq: shape(%v), xk: shape(%v)", xq.Size, xk.Size)
...
}
For more information about how the freqsCis tensor is initialized, refer to 10. ROPE ROTARY POSITIONAL EMBEDDINGS and 10.BONUS. PRECOMPUTING FREQUENCY TENSOR.
During this step, we apply the RoPE (Rotary Positional Embeddings) approach proposed by the RoFormer paper over our query xq and key xk tensors.
Note: The shape information here is for the first iteration of our sample case. The shape samples written in the code comments are for a different case. The first dimension of the shapes stands for the sequence length, which varies with the prompt token count.
- Take the query tensor xq with shape {22, 32, 128}. Convert the tensor's data type from DT_BF16 to DT_COMPLEX and change the shape to {22, 32, 64} via the Tensor.ViewAsComplex64WithReshape(...) method; the result is assigned to the xq_ variable.
    This method:
    - Converts the data type of the tensor to DT_F32 (float32); the shape remains {22, 32, 128},
    - Reshapes the tensor to the shape {22, 32, 64, 2},
    - Converts each pair of float32 values in the last dimension into one complex64 value via Tensor.ViewAsComplex64(...); the new shape is {22, 32, 64} with data type DT_COMPLEX.

    See the torch.view_as_complex documentation for more information.
    Comment from Pytorch's documentation (link above):
    Torch's view_as_complex() is only supported for tensors with torch.dtype torch.float64 and torch.float32. The input is expected to have the last dimension of size 2. In addition, the tensor must have a stride of 1 for its last dimension. The strides of all other dimensions must be even numbers.
- Take the key tensor xk with shape {22, 8, 128}. Convert the tensor's data type from DT_BF16 to DT_COMPLEX and change the shape to {22, 8, 64} via the Tensor.ViewAsComplex64WithReshape(...) method; the result is assigned to the xk_ variable.
- Reshape the freqs_cis tensor from shape {22, 64} to shape {22, 1, 64}.
- Process xqOut:
    - Perform an element-wise multiplication of the xq_ tensor with shape {22, 32, 64} and the freqs_cis tensor with shape {22, 1, 64} via ml.MultiplyElementwise. The output shape is {22, 32, 64}; assign the result to the xqOut variable,
    - Convert the xqOut tensor's data type from DT_COMPLEX to DT_F32 (float32) and change the shape to {22, 32, 128} via the Tensor.ViewAsFloat32WithReshape(...) method (think of it as unpacking the complex pairs in the last dimension back into float pairs),
    - Convert the xqOut tensor's data type from DT_F32 (float32) to DT_BF16, keeping the same shape {22, 32, 128}.
- Process xkOut:
    - Perform an element-wise multiplication of the xk_ tensor with shape {22, 8, 64} and the freqs_cis tensor with shape {22, 1, 64} via ml.MultiplyElementwise. The output shape is {22, 8, 64}; assign the result to the xkOut variable,
    - Convert the xkOut tensor's data type from DT_COMPLEX to DT_F32 (float32) and change the shape to {22, 8, 128} via the Tensor.ViewAsFloat32WithReshape(...) method (think of it as unpacking the complex pairs in the last dimension back into float pairs),
    - Convert the xkOut tensor's data type from DT_F32 (float32) to DT_BF16, keeping the same shape {22, 8, 128}.
- Return the xqOut tensor with shape {22, 32, 128} and the xkOut tensor with shape {22, 8, 128} together as the result.
from src/model/llamatransformer.go
func applyRotaryEmbeddings(xq *ml.Tensor, xk *ml.Tensor, freqs_cis *ml.Tensor) (xqOut *ml.Tensor, xkOut *ml.Tensor, err error) {
// xq shape=[5,32,128] dtype=DT_BF16
xq_, err := xq.ViewAsComplex64WithReshape() // shape=[5,32,64] dtype=DT_COMPLEX
if err != nil {
return nil, nil, err
}
// xk shape=[5,8,128] dtype=DT_BF16
xk_, err := xk.ViewAsComplex64WithReshape() // shape=[5,8,64] dtype=DT_COMPLEX
if err != nil {
return nil, nil, err
}
// freqs_cis shape=[5, 64] dtype=DT_COMPLEX
if freqs_cis, err = freqs_cis.Reshape([]int{xq_.Size[0], 1, xq_.Size[2]}); err != nil { // shape=[5,1,64] dtype=DT_COMPLEX
return nil, nil, err
}
if xqOut, err = ml.MultiplyElementwise(xq_, freqs_cis); err != nil { // shape=[5,32,64] dtype=DT_COMPLEX
return nil, nil, err
}
if xqOut, err = xqOut.ViewAsFloat32WithReshape(); err != nil { // shape=[5,32,128] dtype=DT_F32
return nil, nil, err
}
if xqOut, err = xqOut.ToBFloat16(); err != nil { // shape=[5,32,128] dtype=DT_BF16
return nil, nil, err
}
if xkOut, err = ml.MultiplyElementwise(xk_, freqs_cis); err != nil { // shape=[5,8,64] dtype=DT_COMPLEX
return nil, nil, err
}
if xkOut, err = xkOut.ViewAsFloat32WithReshape(); err != nil { // shape=[5,8,128] dtype=DT_F32
return nil, nil, err
}
if xkOut, err = xkOut.ToBFloat16(); err != nil { // shape=[5,8,128] dtype=DT_BF16
return nil, nil, err
}
return xqOut, xkOut, nil
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... applyRotaryEmbeddings results, xq: shape([22 32 128]), xk: shape([22 8 128]) ...
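Viewing adjacent float32 pairs as complex numbers and multiplying them element-wise by freqs_cis is just a compact way of rotating each pair by a precomputed angle, without changing its magnitude. A standalone illustration of what happens to a single (x0, x1) pair for one position and frequency (the angle below is made up for the example; the real angles come from the precomputed freqs_cis tensor):

```go
package main

import (
	"fmt"
	"math"
	"math/cmplx"
)

func main() {
	// One adjacent pair from xq (or xk), viewed as a single complex number.
	pair := complex(0.5, -1.25) // corresponds to the float32 pair (0.5, -1.25)

	// One entry of freqs_cis: a unit complex number cos(θ) + i·sin(θ).
	theta := math.Pi / 6 // illustrative angle for this (position, frequency) entry
	freq := cmplx.Rect(1, theta)

	// Element-wise complex multiplication rotates the pair by θ.
	rotated := pair * freq

	fmt.Println(real(rotated), imag(rotated))        // the rotated (x0, x1) pair, later unpacked back to float32
	fmt.Println(cmplx.Abs(pair), cmplx.Abs(rotated)) // magnitudes are equal: only the angle changed
}
```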
Diagram: Forward Pass Through Attention Module - Update KV cache. For the complete diagram, click here.
We have initiated an "inference context" of type model.InferenceContext, in which we keep the state of one inference process. In this context object, we have two cache arrays: InferenceContext.CacheK and InferenceContext.CacheV, which stand for the "cache of keys" and the "cache of values" respectively. These arrays have 32 items, corresponding to the 32 layers. Each item is a tensor with shape {200, 8, 128}: 200 stands for the maximum sequence length inferenceArgs.SequenceLength, 8 stands for modelArgs.N_KVHeads, and 128 stands for modelArgs.HeadDim.
Here, in our case of the first iteration, we set the cache of the 0th layer. We set the slices of CacheK and CacheV in the index range 0 (startPos) to 22 (startPos + sequenceLength) to the xk and xv tensors respectively. The sequenceLength is the first dimension of the shape of the x tensor argument.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
/*
Update KV cache
*/
infContext.CacheK[lat.LayerIndex].SetSlice([]int{startPos}, []int{startPos + sequenceLength}, xk)
infContext.CacheV[lat.LayerIndex].SetSlice([]int{startPos}, []int{startPos + sequenceLength}, xv)
...
}
To make it easier to understand how the KV cache is updated, consider a sample:
- The prompt token count is 22,
- While generating the 1st token:
    - startPos is 0,
    - The shape of the x argument of LlamaAttention.Forward(...) is {22, 4096},
    - The shapes of xk and xv are {22, 8, 128},
    - We update the indices 0:22 of each cache with xk and xv.
- While generating the 2nd token:
    - startPos is 22,
    - The shape of the x argument of LlamaAttention.Forward(...) is {1, 4096} (in the iterations except the first, the tokens are processed one by one, so the first dimension is 1),
    - The shapes of xk and xv are {1, 8, 128},
    - We update the indices 22:23 of each cache with xk and xv.
- While generating the 3rd token:
    - startPos is 23,
    - The shape of the x argument of LlamaAttention.Forward(...) is {1, 4096} (in the iterations except the first, the tokens are processed one by one, so the first dimension is 1),
    - The shapes of xk and xv are {1, 8, 128},
    - We update the indices 23:24 of each cache with xk and xv.
- And so on...
Now, we take the cached keys and values for all positions so far. To make it easier to understand:
- The prompt token count is 22,
- While generating the 1st token:
    - startPos is 0,
    - The shape of the x argument of LlamaAttention.Forward(...) is {22, 4096},
    - We take the items at indices 0:22 of CacheK and CacheV into the keys and values tensors respectively, because startPos + sequenceLength = 22. The keys and values tensors have the shape {22, 8, 128}.
- While generating the 2nd token:
    - startPos is 22,
    - The shape of the x argument of LlamaAttention.Forward(...) is {1, 4096} (in the iterations except the first, the tokens are processed one by one, so the first dimension is 1),
    - We take the items at indices 0:23 of CacheK and CacheV into the keys and values tensors respectively, because startPos + sequenceLength = 23. The keys and values tensors have the shape {23, 8, 128}.
- While generating the 3rd token:
    - startPos is 23,
    - The shape of the x argument of LlamaAttention.Forward(...) is {1, 4096} (in the iterations except the first, the tokens are processed one by one, so the first dimension is 1),
    - We take the items at indices 0:24 of CacheK and CacheV into the keys and values tensors respectively, because startPos + sequenceLength = 24. The keys and values tensors have the shape {24, 8, 128}.
In this documentation, we cover only the first iteration, which generates the first token. So, in our case, we retrieve the items at indices 0:22, and the shapes of our keys and values tensors are {22, 8, 128}.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
/*
Retrieve cached KV so far
*/
keys, err := infContext.CacheK[lat.LayerIndex].Slice([]int{0}, []int{startPos + sequenceLength})
if err != nil {
return nil, err
}
values, err := infContext.CacheV[lat.LayerIndex].Slice([]int{0}, []int{startPos + sequenceLength})
if err != nil {
return nil, err
}
...
}
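The slice bookkeeping above can be summarized in one sentence: at each generation step, the cache is written at [startPos : startPos+sequenceLength] and read back from [0 : startPos+sequenceLength]. A hedged, tensor-free illustration of how these ranges evolve for a 22-token prompt:

```go
package main

import "fmt"

func main() {
	promptLen := 22
	startPos := 0
	for step := 1; step <= 3; step++ {
		// Only the first step processes the whole prompt; later steps process one token.
		sequenceLength := 1
		if step == 1 {
			sequenceLength = promptLen
		}
		fmt.Printf("step %d: write cache[%d:%d], read cache[0:%d]\n",
			step, startPos, startPos+sequenceLength, startPos+sequenceLength)
		startPos += sequenceLength
	}
	// Output:
	// step 1: write cache[0:22], read cache[0:22]
	// step 2: write cache[22:23], read cache[0:23]
	// step 3: write cache[23:24], read cache[0:24]
}
```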
The "Repeat K/V heads" step is only required if N_KVHeads < N_Heads; in our case (the Llama 3.1 8B-Instruct model), N_Rep = N_Heads / N_KVHeads = 4, so we need to apply this step.
This operation is applied to the keys and values tensors. In our case, they have the shape {22, 8, 128}: the 22 stands for the sequence length, the 8 stands for the "attention key/value head count" modelArgs.N_KVHeads, and the 128 stands for the "attention head dimension" modelArgs.HeadDim.
Then, we need to equalize the attention head counts of the keys and values tensors to that of the query tensor.
We've defined a function to achieve this: attentionRepeatKV. This function:
- Creates a new tensor (named expanded) with shape {sequenceLength, modelArgs.N_KVHeads, N_Rep, modelArgs.HeadDim} = {22, 8, 4, 128},
- Reshapes the input tensor from shape {sequenceLength, modelArgs.N_KVHeads, modelArgs.HeadDim} = {22, 8, 128} to {sequenceLength, modelArgs.N_KVHeads, 1, modelArgs.HeadDim} = {22, 8, 1, 128},
- By looping over sequenceLength, modelArgs.N_KVHeads, and N_Rep, copies each last-dimension (modelArgs.HeadDim) part N_Rep times into the new expanded tensor,
- Reshapes the expanded tensor from shape {sequenceLength, modelArgs.N_KVHeads, N_Rep, modelArgs.HeadDim} = {22, 8, 4, 128} to {sequenceLength, modelArgs.N_KVHeads * N_Rep, modelArgs.HeadDim} = {22, 32, 128},
- Now, we have a repeated (key or value) tensor with the same shape as the query tensor.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
/*
Repeat k/v heads if N_KVHeads < N_Heads
*/
// example shape=[5, 8, 128] (cacheLen + sequenceLength, N_KVHeads, HeadDim)
if keys, err = attentionRepeatKV(keys, lat.N_Rep); err != nil { // example shape=[5, 32, 128] (cacheLen + sequenceLength, N_Heads, HeadDim)
return nil, err
}
// example shape=[5, 8, 128] (cacheLen + sequenceLength, N_KVHeads, HeadDim)
if values, err = attentionRepeatKV(values, lat.N_Rep); err != nil { // example shape=[5, 32, 128] (cacheLen + sequenceLength, N_Heads, HeadDim)
return nil, err
}
...
}
from src/model/llamatransformer.go
func attentionRepeatKV(x *ml.Tensor, N_Rep int) (*ml.Tensor, error) {
// See: https://github.com/meta-llama/llama-models/blob/f45cdfd624b98b6655540f7101d8d9cb432e631c/models/llama3_1/reference_impl/model.py#L103
if N_Rep == 1 {
return x, nil
}
sequenceLength, n_KVHeads, headDim := x.Size[0], x.Size[1], x.Size[2]
expanded := ml.NewEmptyTensor([]int{sequenceLength, n_KVHeads, N_Rep, headDim}, x.DataType)
var err error
x, err = x.Reshape([]int{sequenceLength, n_KVHeads, 1, headDim})
if err != nil {
return nil, err
}
for i := 0; i < sequenceLength; i++ {
for j := 0; j < n_KVHeads; j++ {
slice, err := x.Slice([]int{i, j, 0}, []int{i, j, 1})
if err != nil {
return nil, err
}
for rep := 0; rep < N_Rep; rep++ {
if err = expanded.SetSlice([]int{i, j, rep}, []int{i, j, rep + 1}, slice); err != nil {
return nil, err
}
}
}
}
if expanded, err = expanded.Reshape([]int{sequenceLength, n_KVHeads * N_Rep, headDim}); err != nil {
return nil, err
}
return expanded, nil
}
Diagram: Forward Pass Through Attention Module - Do transposes. For the complete diagram, click here.
Note: After applying the previous "Repeat K/V heads" operation, all of our attention head counts are equal to modelArgs.N_Heads = 32; there are no more shapes with modelArgs.N_KVHeads = 8.
In this step, we need to perform some transpose operations:
- Transpose xq's 0th and 1st dimensions: from {sequenceLength, N_Heads, HeadDim} = {22, 32, 128} to {N_Heads, sequenceLength, HeadDim} = {32, 22, 128}.
    In the sample in the code comments, sequenceLength is 5, and the operation is from {5, 32, 128} to {32, 5, 128}.
- Transpose keys's 0th and 1st dimensions: from {sequenceLength, N_Heads, HeadDim} = {22, 32, 128} to {N_Heads, sequenceLength, HeadDim} = {32, 22, 128}.
    In the sample in the code comments, sequenceLength is 5, and the operation is from {5, 32, 128} to {32, 5, 128}.
- Transpose values's 0th and 1st dimensions: from {sequenceLength, N_Heads, HeadDim} = {22, 32, 128} to {N_Heads, sequenceLength, HeadDim} = {32, 22, 128}.
    In the sample in the code comments, sequenceLength is 5, and the operation is from {5, 32, 128} to {32, 5, 128}.
- Transpose keys's 1st and 2nd dimensions: from {N_Heads, sequenceLength, HeadDim} = {32, 22, 128} to {N_Heads, HeadDim, sequenceLength} = {32, 128, 22}.
    In the sample in the code comments, sequenceLength is 5, and the operation is from {32, 5, 128} to {32, 128, 5}.

The extra transpose of keys prepares it for the upcoming matrix multiplication with xq, which contracts over the HeadDim dimension.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
/*
Do transposes
*/
if xq, err = xq.Transpose(0, 1); err != nil { // from [5, 32, 128] -> example shape=[32, 5, 128] (N_Heads, sequenceLength, HeadDim)
return nil, err
}
if keys, err = keys.Transpose(0, 1); err != nil { // from [5, 32, 128] -> example shape=[32, 5, 128] (N_Heads, sequenceLength, HeadDim)
return nil, err
}
if values, err = values.Transpose(0, 1); err != nil { // from [5, 32, 128] -> example shape=[32, 5, 128] (N_Heads, sequenceLength, HeadDim)
return nil, err
}
if keys, err = keys.Transpose(1, 2); err != nil { // from [32, 5, 128] -> example shape=[32, 128, 5] (N_Heads, HeadDim, sequenceLength)
return nil, err
}
common.GLogger.DebugPrintf("Multiple transposing results, xq: shape(%v), keys: shape(%v), values: shape(%v)", xq.Size, keys.Size, values.Size)
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Multiple transposing results, xq: shape([32 22 128]), keys: shape([32 128 22]), values: shape([32 22 128]) ...
Diagram: Forward Pass Through Attention Module - Calculate scores. For the complete diagram, click here.
# Goal in Python manner:
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
We calculate the scores over which we will perform the Softmax operation.
- Perform a matrix multiplication of xq with shape {32, 22, 128} and keys with shape {32, 128, 22} (the transpose was already performed in the previous step),
- The square root of lat.HeadDim = 128 is approximately 11.3137, which is 11.3125 in BFloat16 form (see the check right after this list),
- Divide all items of the matrix multiplication result xqMatMulKeys by 11.3125 and assign the result to scores.
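As a quick sanity check of that scaling factor: √128 ≈ 11.3137, and keeping only the BFloat16 mantissa brings it to 11.3125. The snippet below emulates the conversion by truncating a float32 to its upper 16 bits; it is not the project's dtype.BFloat16fromFloat32, just an illustration of where 11.3125 comes from:

```go
package main

import (
	"fmt"
	"math"
)

// truncateToBFloat16 keeps only the upper 16 bits of a float32,
// a simplified (truncating) emulation of the BFloat16 format.
func truncateToBFloat16(f float32) float32 {
	bits := math.Float32bits(f) &^ 0xFFFF // clear the low 16 mantissa bits
	return math.Float32frombits(bits)
}

func main() {
	scale := float32(math.Sqrt(128)) // HeadDim = 128
	fmt.Println(scale)                     // ≈ 11.313708
	fmt.Println(truncateToBFloat16(scale)) // 11.3125, the value the scores are divided by
}
```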
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling ml.MatMul for xq shape(%v) and keys shape(%v) -> tensor xqMatMulKeys", xq.Size, keys.Size)
xqMatMulKeys, err := ml.MatMul(xq, keys) // matmul([32,5,128], [32,128,5]) -> example shape=[32,5,5] (N_Heads, sequenceLength, sequenceLength)
if err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Calling ml.DivToScalar for xqMatMulKeys shape(%v) and scalar -> tensor scores", xqMatMulKeys.Size)
scores, err := ml.DivToScalar(xqMatMulKeys, dtype.BFloat16fromFloat32(float32(math.Sqrt(float64(lat.HeadDim))))) // example shape=[32,5,5]
if err != nil {
return nil, err
}
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling ml.MatMul for xq shape([32 22 128]) and keys shape([32 128 22]) -> tensor xqMatMulKeys ...
[DEBUG] ... Calling ml.DivToScalar for xqMatMulKeys shape([32 22 22]) and scalar -> tensor scores ...
If a mask argument is given, we perform the masking operation. This is because Llama is an auto-regressive model and our mask is a triangular matrix consisting of 0s and -Inf (negative infinity) values. For more information, refer to 14.2.3. Creating the mask tensor.
By performing the ml.Add(...) operation on the scores tensor with shape {32, 22, 22} and the mask tensor with shape {22, 22}, we keep the items corresponding to 0 mask values and suppress the items corresponding to -Inf mask values (adding -Inf to a number makes the number -Inf).
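To see concretely what this masking does, here is a tiny standalone illustration for a 3-token sequence: the mask has -Inf above the diagonal, so adding it to the scores forces every "future" position to -Inf, which the upcoming softmax will turn into zero probability. This uses plain float64 arrays, not the project's ml.Add:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	negInf := math.Inf(-1)

	// Causal mask for a 3-token sequence: position i may only attend to positions <= i.
	mask := [3][3]float64{
		{0, negInf, negInf},
		{0, 0, negInf},
		{0, 0, 0},
	}

	// One attention head's raw scores for the same 3 positions.
	scores := [3][3]float64{
		{0.9, 0.3, 0.5},
		{0.2, 1.1, 0.4},
		{0.7, 0.6, 1.3},
	}

	// Element-wise addition: adding -Inf to a finite score yields -Inf.
	for i := range scores {
		for j := range scores[i] {
			scores[i][j] += mask[i][j]
		}
	}
	fmt.Println(scores) // [[0.9 -Inf -Inf] [0.2 1.1 -Inf] [0.7 0.6 1.3]]
}
```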
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
if mask != nil {
common.GLogger.DebugPrintf("Calling ml.Add to calculate scores shape(%v) + mask shape(%v) -> tensor scores", scores.Size, mask.Size)
if scores, err = ml.Add(scores, mask); err != nil { // example shape=[32,5,5]
return nil, err
}
} else {
common.GLogger.DebugPrintf("Skipping addition scores + mask")
}
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling ml.Add to calculate scores shape([32 22 22]) + mask shape([22 22]) -> tensor scores ...
# Goal in Python manner:
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
In this step, we perform Softmax operation over the scores.
For more information, refer to: torch.nn.Softmax.
To achieve this:
- Convert the scores tensor's data type to DT_F32 (float32),
- Call the ml.Softmax function,
- Convert the result's data type back to DT_BF16 and assign it to the scores tensor (a small illustration follows below).
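Softmax normalizes each row of the scores into probabilities that sum to 1; any -Inf entries left by the mask become exp(-Inf) = 0, so masked positions contribute nothing to the weighted sum of values. A minimal row-wise sketch (the project's ml.Softmax works on whole tensors along a given dimension; this is only an illustration):

```go
package main

import (
	"fmt"
	"math"
)

// softmaxRow applies softmax over one row; the row maximum is subtracted first
// for numerical stability, and -Inf entries end up as exactly 0.
func softmaxRow(row []float64) []float64 {
	maxVal := math.Inf(-1)
	for _, v := range row {
		if v > maxVal {
			maxVal = v
		}
	}
	out := make([]float64, len(row))
	var sum float64
	for i, v := range row {
		out[i] = math.Exp(v - maxVal)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	// The masked (last) position gets probability 0.
	fmt.Println(softmaxRow([]float64{0.2, 1.1, math.Inf(-1)})) // ≈ [0.289 0.711 0]
}
```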
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Converting scores tensor shape(%v) to Float32 tensor -> tensor scores", scores.Size)
scores, err = scores.ToFloat32() // example shape=[32,5,5] dtype=DT_F32
if err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Calling ml.Softmax for scores shape(%v) and dim %d -> tensor scores", scores.Size, len(scores.Size)-1)
if scores, err = ml.Softmax(scores, len(scores.Size)-1); err != nil { // example shape=[32,5,5] dtype=DT_F32
return nil, err
}
common.GLogger.DebugPrintf("Converting scores tensor shape(%v) to BFloat16 tensor -> tensor scores", scores.Size)
if scores, err = scores.ToBFloat16(); err != nil { // example shape=[32,5,5] (N_Heads, sequenceLength, sequenceLength) dtype=DT_BF16
return nil, err
}
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Converting scores tensor shape([32 22 22]) to Float32 tensor -> tensor scores ...
[DEBUG] ... Calling ml.Softmax for scores shape([32 22 22]) and dim 2 -> tensor scores ...
[DEBUG] ... Converting scores tensor shape([32 22 22]) to BFloat16 tensor -> tensor scores ...
Diagram: Forward Pass Through Attention Module - Calculate output. For the complete diagram, click here.
# Goal in Python manner:
output = torch.matmul(scores, values)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
- Perform a matrix multiplication of scores with shape {32, 22, 22} and values with shape {32, 22, 128}, and assign the result with shape {32, 22, 128} to the output tensor,
- Transpose output's 0th and 1st dimensions, giving the shape {22, 32, 128},
- Reshape output to the shape {22, 4096} = {sequenceLength, output.GetElementCount() / sequenceLength}; the 32 heads of 128 values each are concatenated back into the 4096-wide embedding dimension (see the check after this list).
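The shape arithmetic of these three steps can be checked with a few integers: the batched matmul contracts the last dimension of scores with the sequence dimension of values, the transpose swaps heads and sequence, and the final reshape concatenates the 32 heads of 128 values back into the 4096-wide embedding dimension:

```go
package main

import "fmt"

func main() {
	nHeads, seqLen, headDim := 32, 22, 128

	// matmul([nHeads, seqLen, seqLen], [nHeads, seqLen, headDim]) -> [nHeads, seqLen, headDim]
	afterMatMul := []int{nHeads, seqLen, headDim}

	// transpose of dimensions 0 and 1 -> [seqLen, nHeads, headDim]
	afterTranspose := []int{seqLen, nHeads, headDim}

	// reshape -> [seqLen, elementCount/seqLen] = [22, 4096]
	elementCount := seqLen * nHeads * headDim
	afterReshape := []int{seqLen, elementCount / seqLen}

	fmt.Println(afterMatMul, afterTranspose, afterReshape) // [32 22 128] [22 32 128] [22 4096]
}
```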
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling ml.MatMul for scores shape(%v) and values shape(%v) -> tensor output", scores.Size, values.Size)
output, err := ml.MatMul(scores, values)
if err != nil {
return nil, err
}
if output, err = output.Transpose(0, 1); err != nil {
return nil, err
}
outputTrailingSize := output.GetElementCount() / sequenceLength
if output, err = output.Reshape([]int{sequenceLength, outputTrailingSize}); err != nil {
return nil, err
}
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling ml.MatMul for scores shape([32 22 22]) and values shape([32 22 128]) -> tensor output ...
We have the output weights of our attention module in the lat.attn_wo tensor. We perform a linear transformation of our output tensor (with shape {22, 4096}) with the lat.attn_wo weights tensor (with shape {4096, 4096}). Then, we return this result with shape {22, 4096} as the output of the attention module.
from src/model/llamatransformer.go
func (lat *LlamaAttention) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
/*
Apply lat.attn_wo weights to output
*/
common.GLogger.DebugPrintf("Calling ml.LinearTransformation for output shape(%v) and LlamaAttention.attn_wo weights shape(%v) -> tensor output", output.Size, lat.attn_wo.Size)
// lat.attn_wo: [out_features, in_features] -> shape: [4096 4096] -> [N_Heads * HeadDim, Dim]
if output, err = ml.LinearTransformation(output, lat.attn_wo); err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Returning tensor output: shape(%v)", output.Size)
return output, nil
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling ml.LinearTransformation for output shape([22 4096]) and LlamaAttention.attn_wo weights shape([4096 4096]) -> tensor output ...
[DEBUG] ... Returning tensor output: shape([22 4096]) ...
Diagram: Add attention module output and current tensor. For the complete diagram, click here.
Now, we have returned to our latest position in the LlamaTransformerBlock.Forward(...) method.
We had the x tensor argument. In our case, at the first iteration, x is the input tensor; at the other iterations, x is the output of the previous LlamaTransformerBlock. At the first iteration, the shape of this tensor is {22, 4096}: 22 stands for the sequence length, 4096 stands for the embedding layer dimension. The normalized tensor normalizedX had the same shape as the input, {22, 4096}.
Also, we have the h tensor with shape {22, 4096}, which is the output of our attention module LlamaAttention.
We add the x and h tensors via the ml.Add(...) function and assign the result to the h tensor.
from src/model/llamatransformer.go
func (ltb *LlamaTransformerBlock) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling ml.Add to calculate x shape(%v) + h shape(%v) -> tensor h", x.Size, h.Size)
if h, err = ml.Add(x, h); err != nil {
return nil, err
}
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling ml.Add to calculate x shape([22 4096]) + h shape([22 4096]) -> tensor h ...
Diagram: Forward Pass Through Feed-Forward Pre-normalization. For the complete diagram, click here.
Now, we have the h tensor. We perform RMSNorm over the h tensor with the normalization weights ltb.ffn_norm, and assign the result to normalizedH, which will have the same shape as h, {22, 4096}.
from src/model/llamatransformer.go
func (ltb *LlamaTransformerBlock) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling RMSNorm for tensor h shape(%v) and LlamaTransformerBlock.ffn_norm weights shape(%v) -> tensor normalizedH", x.Size, ltb.ffn_norm.weights.Size)
normalizedH, err := ltb.ffn_norm.Forward(infContext, h)
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling RMSNorm for tensor h shape([22 4096]) and LlamaTransformerBlock.ffn_norm weights shape([4096]) -> tensor normalizedH ...
from src/model/llamatransformer.go
func (ltb *LlamaTransformerBlock) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling LlamaFeedForward.Forward for tensor normalizedH shape(%v) -> tensor ffnOutput", normalizedH.Size)
ffnOutput, err := ltb.feedForward.Forward(normalizedH)
...
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling LlamaFeedForward.Forward for tensor normalizedH shape([22 4096]) -> tensor ffnOutput ...
Diagram: Forward Pass Through Feed-Forward Module. For the complete diagram, click here.
# Goal in Python manner:
self.w2(F.silu(self.w1(x)) * self.w3(x))
# Python code with our variable names:
self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))
In this stage, we have a feed-forward neural network module consisting of multiple weight tensors for each of the 32 transformer block layers. They are named w1, w2, and w3 in the original Python repository, and ffn_gate, ffn_down, and ffn_up in our project, respectively.
The steps are:
- Perform a linear transformation of x with shape {22, 4096} and the lff.ffn_gate weights with shape {14336, 4096}; assign the resulting tensor with shape {22, 14336} to h,
- Perform the Sigmoid Linear Unit (SiLU) function over the h tensor via the ml.Silu(...) function. For more information, refer to: torch.nn.SiLU.
- Perform a linear transformation of x with shape {22, 4096} and the lff.ffn_up weights with shape {14336, 4096}; assign the resulting tensor with shape {22, 14336} to ffnUpX,
- Perform an element-wise multiplication of the h tensor with shape {22, 14336} and the ffnUpX tensor with shape {22, 14336} via ml.MultiplyElementwise; the output shape is {22, 14336}, assign the result to the h variable,
- Perform a linear transformation of h with shape {22, 14336} and the lff.ffn_down weights with shape {4096, 14336}; assign the resulting tensor with shape {22, 4096} to output,
- Return the output tensor.
from src/model/llamatransformer.go
func (lff *LlamaFeedForward) Forward(x *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling ml.LinearTransformation for x shape(%v) and LlamaFeedForward.ffn_gate weights shape(%v) -> tensor h", x.Size, lff.ffn_gate.Size)
h, err := ml.LinearTransformation(x, lff.ffn_gate)
if err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Calling ml.Silu for h shape(%v) -> tensor h", h.Size)
if h, err = ml.Silu(h); err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Calling ml.LinearTransformation for x shape(%v) and LlamaFeedForward.ffn_up weights shape(%v) -> tensor ffnUpX", x.Size, lff.ffn_up.Size)
ffnUpX, err := ml.LinearTransformation(x, lff.ffn_up)
if err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Calling ml.MultiplyElementwise for h shape(%v) and ffnUpX weights shape(%v) -> tensor ffnUpX", h.Size, ffnUpX.Size)
if h, err = ml.MultiplyElementwise(h, ffnUpX); err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Calling ml.LinearTransformation for h shape(%v) and LlamaFeedForward.ffn_down weights shape(%v) -> tensor output", h.Size, lff.ffn_down.Size)
output, err := ml.LinearTransformation(h, lff.ffn_down)
if err != nil {
return nil, err
}
return output, nil
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling ml.LinearTransformation for x shape([22 4096]) and LlamaFeedForward.ffn_gate weights shape([14336 4096]) -> tensor h ...
[DEBUG] ... Calling ml.LinearTransformation for x shape([22 4096]) and LlamaFeedForward.ffn_up weights shape([14336 4096]) -> tensor ffnUpX ...
[DEBUG] ... Calling ml.MultiplyElementwise for h shape([22 14336]) and ffnUpX weights shape([22 14336]) -> tensor ffnUpX ...
[DEBUG] ... Calling ml.LinearTransformation for h shape([22 14336]) and LlamaFeedForward.ffn_down weights shape([4096 14336]) -> tensor output ...
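The SiLU activation applied between ffn_gate and the element-wise multiplication is silu(x) = x · sigmoid(x). Below is a minimal, standalone sketch of the whole gating idea on single values; the real ml.Silu and ml.MultiplyElementwise operate on whole {22, 14336} tensors:

```go
package main

import (
	"fmt"
	"math"
)

// silu is the Sigmoid Linear Unit: x * sigmoid(x).
func silu(x float64) float64 {
	return x / (1 + math.Exp(-x))
}

func main() {
	// For one hidden unit: the gate path goes through SiLU and is then
	// multiplied element-wise with the up-projection path, mirroring
	// ffn_down(silu(ffn_gate(x)) * ffn_up(x)).
	gate := 0.8 // one element of ffn_gate(x)
	up := -1.5  // the corresponding element of ffn_up(x)
	fmt.Println(silu(gate) * up)
}
```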
Diagram: Add Feed-Forward module output and current tensor. For the complete diagram, click here.
Now, we have returned to our latest position in the LlamaTransformerBlock.Forward(...) method.
- We have the ffnOutput tensor with shape {22, 4096}, which is the output of our feed-forward neural network module LlamaFeedForward.
- Also, we had the h tensor as the current tensor with shape {22, 4096}, which was computed by adding x and the output of our attention module LlamaAttention.
- We add the h and ffnOutput tensors via the ml.Add(...) function and assign the result with shape {22, 4096} to the output tensor,
- Return it as the output of the LlamaTransformerBlock.
from src/model/llamatransformer.go
func (ltb *LlamaTransformerBlock) Forward(infContext *InferenceContext, x *ml.Tensor, startPos int, freqsCis *ml.Tensor, mask *ml.Tensor) (*ml.Tensor, error) {
...
common.GLogger.DebugPrintf("Calling ml.Add to calculate h shape(%v) + ffnOutput shape(%v) -> tensor output", h.Size, ffnOutput.Size)
output, err := ml.Add(h, ffnOutput)
if err != nil {
return nil, err
}
common.GLogger.DebugPrintf("Returning tensor output: shape(%v)", output.Size)
return output, nil
}
We can see output lines in the "debug.log" file if debugging is enabled, as follows:
[DEBUG] ... Calling ml.Add to calculate h shape([22 4096]) + ffnOutput shape([22 4096]) -> tensor output ...
[DEBUG] ... Returning tensor output: shape([22 4096]) ...
The flow will continue with the next LlamaTransformerBlock layer.