Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Feb 6, 2024
1 parent e8e32d1 commit 594c8b0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 16 deletions.
23 changes: 8 additions & 15 deletions rllm-cuda/src/llamacpp/tmodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,23 +208,16 @@ impl TModel {
}

fn sample_argmax(&self, logits: &Tensor) -> u32 {
#[cfg(feature = "tch")]
{
logits.argmax(0, false).int64_value(&[]) as u32
}
#[cfg(not(feature = "tch"))]
{
let data = logits.as_slice();
let mut top = data[0];
let mut top_idx = 0;
for (i, x) in data.iter().enumerate() {
if *x > top {
top = *x;
top_idx = i;
}
let data = logits.as_slice();
let mut top = data[0];
let mut top_idx = 0;
for (i, x) in data.iter().enumerate() {
if *x > top {
top = *x;
top_idx = i;
}
top_idx as u32
}
top_idx as u32
}

fn sample_multinomial(&self, state: &mut LogitsProcessor, prs: &Vec<f32>) -> Result<u32> {
Expand Down
1 change: 0 additions & 1 deletion rllm-cuda/src/llm/tmodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ impl ModelExec for TModel {
num_seqs: usize,
vocab_size: usize,
) -> Self::AiciBias {
#[cfg(feature = "tch")]
let tensor = Tensor::from_slice(slice)
.to(self.config.model.device)
.reshape(&[num_seqs as i64, vocab_size as i64]);
Expand Down

0 comments on commit 594c8b0

Please sign in to comment.