
Commit da074b2

Merge pull request #831 from admiralakber/chore/update-llama-cpp-b6482
Update llama.cpp to b6482 (3d4053f)
2 parents d3a31a9 + 192ab01 commit da074b2

File tree

5 files changed: +29 -60 lines

llama-cpp-2/Cargo.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -10,7 +10,7 @@ repository = "https://github.com/utilityai/llama-cpp-rs"

 [dependencies]
 enumflags2 = "0.7.12"
-llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.113" }
+llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.122" }
 thiserror = { workspace = true }
 tracing = { workspace = true }
 tracing-core = { workspace = true }
@@ -35,7 +35,7 @@ mtmd = ["llama-cpp-sys-2/mtmd"]


 [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies]
-llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.113", features = [
+llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.122", features = [
     "metal",
 ] }
```

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 17 additions & 31 deletions
```diff
@@ -28,7 +28,8 @@ impl LlamaContext<'_> {
     /// * `dest` - The sequence id to copy the cache to.
     /// * `size` - The size of the cache to copy.
     pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, 0, size) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_cp(mem, src, dest, 0, size) }
     }

     /// Copy the cache from one sequence to another.
@@ -57,9 +58,8 @@ impl LlamaContext<'_> {
         let p1 = p1
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
-        unsafe {
-            llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
-        }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_cp(mem, src, dest, p0, p1) };
         Ok(())
     }

@@ -92,18 +92,15 @@ impl LlamaContext<'_> {
         let p1 = p1
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
-        Ok(unsafe { llama_cpp_sys_2::llama_kv_self_seq_rm(self.context.as_ptr(), src, p0, p1) })
-    }
-
-    /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    #[must_use]
-    pub fn get_kv_cache_used_cells(&self) -> i32 {
-        unsafe { llama_cpp_sys_2::llama_kv_self_used_cells(self.context.as_ptr()) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        Ok(unsafe { llama_cpp_sys_2::llama_memory_seq_rm(mem, src, p0, p1) })
     }

     /// Clear the KV cache
     pub fn clear_kv_cache(&mut self) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_clear(self.context.as_ptr()) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        // clear both metadata and data buffers to match previous semantics
+        unsafe { llama_cpp_sys_2::llama_memory_clear(mem, true) }
     }

     /// Removes all tokens that do not belong to the specified sequence
@@ -112,7 +109,8 @@ impl LlamaContext<'_> {
     ///
     /// * `seq_id` - The sequence id to keep
     pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_keep(self.context.as_ptr(), seq_id) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_keep(mem, seq_id) }
     }

     #[allow(clippy::doc_markdown)]
@@ -146,9 +144,8 @@ impl LlamaContext<'_> {
         let p1 = p1
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
-        unsafe {
-            llama_cpp_sys_2::llama_kv_self_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
-        }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_add(mem, seq_id, p0, p1, delta) };
         Ok(())
     }

@@ -183,7 +180,8 @@ impl LlamaContext<'_> {
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
         let d = c_int::from(d.get());
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_div(mem, seq_id, p0, p1, d) }
         Ok(())
     }

@@ -194,19 +192,7 @@ impl LlamaContext<'_> {
     /// * `seq_id` - The sequence id to get the max position for
     #[must_use]
     pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_pos_max(self.context.as_ptr(), seq_id) }
-    }
-
-    /// Defragment the KV cache
-    /// This will be applied:
-    /// - lazily on next [`LlamaContext::decode`]
-    /// - explicitly with [`Self::kv_cache_update`]
-    pub fn kv_cache_defrag(&mut self) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_defrag(self.context.as_ptr()) }
-    }
-
-    /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    pub fn kv_cache_update(&mut self) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_update(self.context.as_ptr()) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_pos_max(mem, seq_id) }
     }
 }
```
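Every method above now follows the same two-step pattern: fetch the context's memory handle with `llama_get_memory`, then call the matching `llama_memory_*` function on it. The safe wrapper signatures are unchanged, so downstream callers compile as before; `get_kv_cache_used_cells`, `kv_cache_defrag`, and `kv_cache_update` are simply removed, since the reworked upstream memory API no longer exposes used-cell counts or explicit defragmentation. A minimal caller-side sketch (hypothetical code, not part of this diff, assuming `LlamaContext` is exported at `llama_cpp_2::context::LlamaContext`):

```rust
use llama_cpp_2::context::LlamaContext;

/// Fork the first `size` positions of `src`'s KV cache into `dst`.
/// These calls compiled against the old llama_kv_self_* bindings and
/// compile unchanged against the new llama_memory_* ones.
fn fork_sequence(ctx: &mut LlamaContext<'_>, src: i32, dst: i32, size: i32) {
    // Internally: llama_get_memory + llama_memory_seq_cp.
    ctx.copy_cache(src, dst, size);
    // Internally: llama_get_memory + llama_memory_seq_pos_max.
    let max_pos = ctx.kv_cache_seq_pos_max(dst);
    debug_assert!(max_pos < size, "copied cache cannot extend past `size`");
}
```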

llama-cpp-2/src/context/params.rs

Lines changed: 9 additions & 23 deletions
```diff
@@ -335,34 +335,20 @@ impl LlamaContextParams {
         self.context_params.n_ubatch
     }

-    /// Set the `flash_attention` parameter
-    ///
-    /// # Examples
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::params::LlamaContextParams;
-    /// let params = LlamaContextParams::default()
-    ///     .with_flash_attention(true);
-    /// assert_eq!(params.flash_attention(), true);
-    /// ```
+    /// Set the flash attention policy using llama.cpp enum
     #[must_use]
-    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
-        self.context_params.flash_attn = enabled;
+    pub fn with_flash_attention_policy(
+        mut self,
+        policy: llama_cpp_sys_2::llama_flash_attn_type,
+    ) -> Self {
+        self.context_params.flash_attn_type = policy;
         self
     }

-    /// Get the `flash_attention` parameter
-    ///
-    /// # Examples
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::params::LlamaContextParams;
-    /// let params = LlamaContextParams::default();
-    /// assert_eq!(params.flash_attention(), false);
-    /// ```
+    /// Get the flash attention policy
     #[must_use]
-    pub fn flash_attention(&self) -> bool {
-        self.context_params.flash_attn
+    pub fn flash_attention_policy(&self) -> llama_cpp_sys_2::llama_flash_attn_type {
+        self.context_params.flash_attn_type
     }

     /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
```
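Callers of the removed `with_flash_attention(bool)` must switch to the enum-based policy setter. A hypothetical migration sketch follows; the constant name assumes bindgen's default constified-enum naming for `llama_flash_attn_type` (upstream defines AUTO, DISABLED, and ENABLED variants), so check the generated `llama-cpp-sys-2` bindings for the exact identifier:

```rust
use llama_cpp_2::context::params::LlamaContextParams;

fn main() {
    // Before (removed in this PR):
    //     let params = LlamaContextParams::default().with_flash_attention(true);

    // After: pass an explicit upstream policy value instead of a bool.
    let policy = llama_cpp_sys_2::llama_flash_attn_type_LLAMA_FLASH_ATTN_TYPE_ENABLED;
    let params = LlamaContextParams::default().with_flash_attention_policy(policy);
    assert_eq!(params.flash_attention_policy(), policy);
}
```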

llama-cpp-2/src/model/params.rs

Lines changed: 0 additions & 3 deletions
```diff
@@ -180,9 +180,6 @@ impl LlamaModelParams {
     /// ```
     /// # use llama_cpp_2::model::params::LlamaModelParams;
     /// let params = LlamaModelParams::default();
-    /// #[cfg(not(target_os = "macos"))]
-    /// assert_eq!(params.n_gpu_layers(), 0, "n_gpu_layers should be 0");
-    /// #[cfg(target_os = "macos")]
     /// assert_eq!(params.n_gpu_layers(), 999, "n_gpu_layers should be 999");
     /// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
     /// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
```
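With the cfg-gated assertions dropped, the doctest now expects `n_gpu_layers() == 999` on every platform, reflecting that the updated llama.cpp defaults to GPU offload everywhere rather than only on macOS. A hypothetical sketch for callers who want CPU-only behavior back, assuming the crate's existing `with_n_gpu_layers` builder (not shown in this diff):

```rust
use llama_cpp_2::model::params::LlamaModelParams;

fn main() {
    // Default now requests (up to) 999 offloaded layers on all platforms.
    let params = LlamaModelParams::default();
    assert_eq!(params.n_gpu_layers(), 999);

    // Opt back into CPU-only inference explicitly.
    let cpu_only = LlamaModelParams::default().with_n_gpu_layers(0);
    assert_eq!(cpu_only.n_gpu_layers(), 0);
}
```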

llama-cpp-sys-2/llama.cpp

Submodule llama.cpp updated 431 files
