Skip to content

Commit

Permalink
Remove either dep and fix settings.json
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jul 26, 2024
1 parent 558b1a9 commit 5db24b1
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.testing.pytestEnabled": true
}
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ tracing-subscriber = "0.3.7"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
either = { version = "1.13.0", features = ["serde"] }

[profile.release-with-debug]
inherits = "release"
Expand Down
1 change: 0 additions & 1 deletion candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
either = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
Expand Down
17 changes: 6 additions & 11 deletions candle-examples/examples/llama/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use either::Either;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

Expand Down Expand Up @@ -171,7 +170,7 @@ fn main() -> Result<()> {
let eos_token_id = config.eos_token_id.or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(|x| model::LlamaEosToks(Either::Left(x)))
.map(|x| model::LlamaEosToks::Single(x))

Check failure on line 173 in candle-examples/examples/llama/main.rs

View workflow job for this annotation

GitHub Actions / Clippy

redundant closure
});
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
Expand Down Expand Up @@ -231,17 +230,13 @@ fn main() -> Result<()> {
tokens.push(next_token);

match eos_token_id {
Some(model::LlamaEosToks(Either::Left(eos_tok_id))) => {
if next_token == eos_tok_id {
break;
}
Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
break;
}
Some(model::LlamaEosToks(Either::Right(ref eos_ids))) => {
if eos_ids.contains(&next_token) {
break;
}
Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
break;
}
None => (),
_ => (),
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
Expand Down
1 change: 0 additions & 1 deletion candle-transformers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ byteorder = { workspace = true }
candle = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
candle-nn = { workspace = true }
either = { workspace = true }
fancy-regex = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
Expand Down
7 changes: 5 additions & 2 deletions candle-transformers/src/models/llama.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use either::Either;
use std::{collections::HashMap, f32::consts::PI};

pub const MAX_SEQ_LEN: usize = 4096;
Expand All @@ -24,7 +23,11 @@ pub struct Llama3RopeConfig {
pub rope_type: Llama3RopeType,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaEosToks(#[serde(with = "either::serde_untagged")] pub Either<u32, Vec<u32>>);
#[serde(untagged)]
pub enum LlamaEosToks {
Single(u32),
Multiple(Vec<u32>),
}

#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
Expand Down
3 changes: 1 addition & 2 deletions candle-transformers/src/models/llava/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::models::{
clip::{text_model::Activation, vision_model::ClipVisionConfig},
llama::{Config, LlamaEosToks},
};
use either::Either;
use serde::{Deserialize, Serialize};

// original config from liuhaotian/llava
Expand Down Expand Up @@ -74,7 +73,7 @@ impl LLaVAConfig {
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(LlamaEosToks(Either::Left(self.eos_token_id as u32))),
eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
use_flash_attn: false,
rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
}
Expand Down

0 comments on commit 5db24b1

Please sign in to comment.