Skip to content

Commit

Permalink
fix: estimate
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed Aug 20, 2024
1 parent 74e97d2 commit 71623de
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions file_estimate.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,18 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
// IO,
// see https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L4930-L5002.
e.Devices[0].Weight.Input = GGUFBytesScalar(ipLs.Bytes())
var op GGUFBytesScalar
if _, ok := opLs.Get("output.weight"); ok {
e.Devices[0].Weight.Output = GGUFBytesScalar(opLs.Bytes())
op = GGUFBytesScalar(opLs.Bytes())
} else if a.AttentionCausal {
e.Devices[0].Weight.Output = GGUFBytesScalar(opLs.Bytes()) + e.Devices[0].Weight.Input /* duplicate the input layer */
op = GGUFBytesScalar(opLs.Bytes()) + e.Devices[0].Weight.Input /* duplicate the input layer */
}
if fullOffload {
for i := range e.Devices[1:] {
e.Devices[i+1].Weight.Output = op
}
} else {
e.Devices[0].Weight.Output = op
}
}

Expand Down Expand Up @@ -431,11 +439,12 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
rs := GGMLTypeF32.RowSizeOf([]uint64{uint64(a.SSMInnerSize)*nTokens + uint64(a.SSMStateSize)*uint64(a.SSMInnerSize)*nKV})
ssmInc += rs
}
cp := GGUFBytesScalar(convInc + ssmInc)
for i, d := range e.Devices[1:] {
if d.LastLayer < 0 {
if d.LastLayer < 0 && i != 0 {
continue
}
e.Devices[i+1].Computation.Compute = GGUFBytesScalar(convInc + ssmInc)
e.Devices[i+1].Computation.Compute = cp
}
default:
loadAttnInc, offloadAttnInc := uint64(0), uint64(0)
Expand Down Expand Up @@ -468,19 +477,15 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
case strings.HasSuffix(l.Name, ".attn_q.weight"):
rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
offloadAttnInc += rs * 2 // Qcur, Qcur + RoPE.
if !isOffloadOutputLayer {
loadAttnInc = rs // Vcur.
}
loadAttnInc = rs // Vcur.
rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
offloadAttnInc += rs // kq.
rs = o.CacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
offloadAttnInc += rs * 2 // k-?, v-?.
case strings.HasSuffix(l.Name, ".attn_qkv.weight"):
rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
offloadAttnInc += rs * 2 // Qcur, Qcur + RoPE.
if !isOffloadOutputLayer {
loadAttnInc = rs // Vcur.
}
loadAttnInc = rs // Vcur.
rs = GGMLTypeF32.RowSizeOf([]uint64{nKV, nTokens, a.AttentionHeadCount})
offloadAttnInc += rs // kq.
rs = o.CacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
Expand All @@ -493,12 +498,20 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
rs := GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[l.NDimensions-1], nTokens})
ffnInc += rs
}
e.Devices[0].Computation.Compute = GGUFBytesScalar(loadAttnInc)
switch {
case fullOffload:
e.Devices[0].Computation.Compute = GGUFBytesScalar(loadAttnInc*uint64(len(e.Devices)) + ffnInc)
case partialOffload:
e.Devices[0].Computation.Compute = GGUFBytesScalar(loadAttnInc + ffnInc)
case zeroOffload:
e.Devices[0].Computation.Compute = GGUFBytesScalar(loadAttnInc)
}
cp := GGUFBytesScalar(max(offloadAttnInc, ffnInc))
for i, d := range e.Devices[1:] {
if d.LastLayer < 0 {
if d.LastLayer < 0 && i != 0 {
continue
}
e.Devices[i+1].Computation.Compute = GGUFBytesScalar(max(offloadAttnInc, ffnInc))
e.Devices[i+1].Computation.Compute = cp
}
// Special case: we cannot use mmap for splitting expert weights in MoE.
if a.ExpertCount > 0 {
Expand Down Expand Up @@ -689,5 +702,5 @@ func (u LLaMACppKVCacheUsage) Sum() GGUFBytesScalar {
}

func (u LLaMACppComputationUsage) Sum() GGUFBytesScalar {
return u.Footprint + max(u.Input+u.Compute, u.Output)
return u.Footprint + u.Input + max(u.Compute, u.Output)
}

0 comments on commit 71623de

Please sign in to comment.