Skip to content

Commit

Permalink
Merge branch 'main' of github.com:AnubhabB/candle into ones-impl
Browse files Browse the repository at this point in the history
  • Loading branch information
AnubhabB committed Sep 28, 2024
2 parents 2f37d25 + 62525e8 commit 453ae03
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 27 deletions.
22 changes: 0 additions & 22 deletions candle-examples/examples/mobileclip/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];

for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
Expand All @@ -70,67 +69,49 @@ fn load_images<T: AsRef<std::path::Path>>(
)?;
images.push(tensor);
}

let images = Tensor::stack(&images, 0)?;

Ok(images)
}

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();

let model_name = args.which.model_name();

let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);

let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};

let tokenizer = api.get("tokenizer.json")?;

let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

let config = &args.which.config();

let device = candle_examples::device(args.cpu)?;

let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};

let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;

let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
};

let model = mobileclip::MobileClipModel::new(vb, config)?;

let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;

let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;

let softmax_image = softmax(&logits_per_image, 1)?;

let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;

println!("softmax_image_vec: {:?}", softmax_image_vec);

let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();

let probability_per_image = probability_vec.len() / vec_imgs.len();

for (i, img) in vec_imgs.iter().enumerate() {
Expand Down Expand Up @@ -171,7 +152,6 @@ pub fn tokenize_sequences(
};

let mut tokens = vec![];

for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
Expand All @@ -185,8 +165,6 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}

let input_ids = Tensor::new(tokens, device)?;

Ok((input_ids, vec_seq))
}
1 change: 0 additions & 1 deletion candle-transformers/src/models/fastvit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,6 @@ fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Resul
.apply(&stage3)?
.apply(&stage4)?
.apply(&final_conv)?;

match &cls {
None => Ok(xs),
Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
Expand Down
4 changes: 0 additions & 4 deletions candle-transformers/src/models/mobileclip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ impl MobileClipConfig {
pub fn s1() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci1();

Self {
text_config,
vision_config,
Expand All @@ -32,7 +31,6 @@ impl MobileClipConfig {
pub fn s2() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci2();

Self {
text_config,
vision_config,
Expand All @@ -45,12 +43,10 @@ impl MobileClipModel {
pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;

let text_projection = vs.get(
(c.text_config.embed_dim, c.text_config.projection_dim),
"text.text_projection",
)?;

let logit_scale = vs.get(&[], "logit_scale")?;
Ok(Self {
text_model,
Expand Down

0 comments on commit 453ae03

Please sign in to comment.