Skip to content

Commit

Permalink
make DepthAnythingV2 more reusable (#2675)
Browse files Browse the repository at this point in the history
* make DepthAnythingV2 more reusable

* Fix clippy lints.

---------

Co-authored-by: laurent <[email protected]>
  • Loading branch information
edgarriba and LaurentMazare authored Dec 21, 2024
1 parent 67cab7d commit 5c2f893
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
6 changes: 2 additions & 4 deletions candle-examples/examples/depth_anything_v2/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::ffi::OsString;
use std::path::PathBuf;

use clap::Parser;
use std::{ffi::OsString, path::PathBuf, sync::Arc};

use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
Expand Down Expand Up @@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> {
};

let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?;

let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

Expand Down
44 changes: 25 additions & 19 deletions candle-transformers/src/models/depth_anything_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything)
//!
use std::sync::Arc;

use candle::D::Minus1;
use candle::{Module, Result, Tensor};
use candle_nn::ops::Identity;
Expand Down Expand Up @@ -365,16 +367,18 @@ impl Scratch {

const NUM_CHANNELS: usize = 4;

pub struct DPTHead<'a> {
conf: &'a DepthAnythingV2Config,
pub struct DPTHead {
projections: Vec<Conv2d>,
resize_layers: Vec<Box<dyn Module>>,
readout_projections: Vec<Sequential>,
scratch: Scratch,
use_class_token: bool,
input_image_size: usize,
target_patch_size: usize,
}

impl<'a> DPTHead<'a> {
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
impl DPTHead {
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
projections.push(conv2d(
Expand Down Expand Up @@ -445,20 +449,22 @@ impl<'a> DPTHead<'a> {
let scratch = Scratch::new(conf, vb.pp("scratch"))?;

Ok(Self {
conf,
projections,
resize_layers,
readout_projections,
scratch,
use_class_token: conf.use_class_token,
input_image_size: conf.input_image_size,
target_patch_size: conf.target_patch_size,
})
}
}

impl Module for DPTHead<'_> {
impl Module for DPTHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
for i in 0..NUM_CHANNELS {
let x = if self.conf.use_class_token {
let x = if self.use_class_token {
let x = xs.get(i)?.get(0)?;
let class_token = xs.get(i)?.get(1)?;
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
Expand All @@ -473,8 +479,8 @@ impl Module for DPTHead<'_> {
let x = x.permute((0, 2, 1))?.reshape((
x_dims[0],
x_dims[x_dims.len() - 1],
self.conf.target_patch_size,
self.conf.target_patch_size,
self.target_patch_size,
self.target_patch_size,
))?;
let x = self.projections[i].forward(&x)?;

Expand Down Expand Up @@ -515,25 +521,25 @@ impl Module for DPTHead<'_> {

let out = self.scratch.output_conv1.forward(&path1)?;

let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
let out = out.interpolate2d(self.input_image_size, self.input_image_size)?;

self.scratch.output_conv2.forward(&out)
}
}

pub struct DepthAnythingV2<'a> {
pretrained: &'a DinoVisionTransformer,
depth_head: DPTHead<'a>,
conf: &'a DepthAnythingV2Config,
pub struct DepthAnythingV2 {
pretrained: Arc<DinoVisionTransformer>,
depth_head: DPTHead,
conf: DepthAnythingV2Config,
}

impl<'a> DepthAnythingV2<'a> {
impl DepthAnythingV2 {
pub fn new(
pretrained: &'a DinoVisionTransformer,
conf: &'a DepthAnythingV2Config,
pretrained: Arc<DinoVisionTransformer>,
conf: DepthAnythingV2Config,
vb: VarBuilder,
) -> Result<Self> {
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?;

Ok(Self {
pretrained,
Expand All @@ -543,7 +549,7 @@ impl<'a> DepthAnythingV2<'a> {
}
}

impl Module for DepthAnythingV2<'_> {
impl Module for DepthAnythingV2 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let features = self.pretrained.get_intermediate_layers(
xs,
Expand Down

0 comments on commit 5c2f893

Please sign in to comment.