diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 18de03e1df8016..211193979db64a 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -689,6 +689,8 @@
title: NAT
- local: model_doc/poolformer
title: PoolFormer
+ - local: model_doc/prompt_depth_anything
+ title: Prompt Depth Anything
- local: model_doc/pvt
title: Pyramid Vision Transformer (PVT)
- local: model_doc/pvt_v2
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 967049d89cbe12..a96d7eee4f2fae 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -275,6 +275,7 @@ Flax), PyTorch, and/or TensorFlow.
| [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |
| [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ |
| [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ |
+| [PromptDepthAnything](model_doc/prompt_depth_anything) | ✅ | ❌ | ❌ |
| [ProphetNet](model_doc/prophetnet) | ✅ | ❌ | ❌ |
| [PVT](model_doc/pvt) | ✅ | ❌ | ❌ |
| [PVTv2](model_doc/pvt_v2) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/prompt_depth_anything.md b/docs/source/en/model_doc/prompt_depth_anything.md
new file mode 100644
index 00000000000000..d1fc0200904464
--- /dev/null
+++ b/docs/source/en/model_doc/prompt_depth_anything.md
@@ -0,0 +1,94 @@
+
+
+# Prompt Depth Anything
+
+## Overview
+
+The Prompt Depth Anything model was introduced in [Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation](https://arxiv.org/abs/2412.14015) by Haotong Lin, Sida Peng, Jingxiao Chen, Songyou Peng, Jiaming Sun, Minghuan Liu, Hujun Bao, Jiashi Feng, Xiaowei Zhou, Bingyi Kang.
+
+
+The abstract from the paper is as follows:
+
+*Prompts play a critical role in unleashing the power of language and vision foundation models for specific tasks. For the first time, we introduce prompting into depth foundation models, creating a new paradigm for metric depth estimation termed Prompt Depth Anything. Specifically, we use a low-cost LiDAR as the prompt to guide the Depth Anything model for accurate metric depth output, achieving up to 4K resolution. Our approach centers on a concise prompt fusion design that integrates the LiDAR at multiple scales within the depth decoder. To address training challenges posed by limited datasets containing both LiDAR depth and precise GT depth, we propose a scalable data pipeline that includes synthetic data LiDAR simulation and real data pseudo GT depth generation. Our approach sets new state-of-the-arts on the ARKitScenes and ScanNet++ datasets and benefits downstream applications, including 3D reconstruction and generalized robotic grasping.*
+
+
+
+ Prompt Depth Anything overview. Taken from the original paper.
+
+## Usage example
+
+The Transformers library allows you to use the model with just a few lines of code:
+
+```python
+>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
+>>> import torch
+>>> import numpy as np
+>>> from PIL import Image
+>>> import requests
+
+>>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+
+>>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+>>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+
+>>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
+>>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
+>>> # the prompt depth can be None, and the model will output a monocular relative depth.
+
+>>> # prepare image for the model
+>>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth)
+
+>>> with torch.no_grad():
+... outputs = model(**inputs)
+
+>>> # interpolate to original size
+>>> post_processed_output = image_processor.post_process_depth_estimation(
+... outputs,
+... target_sizes=[(image.height, image.width)],
+... )
+
+>>> # visualize the prediction
+>>> predicted_depth = post_processed_output[0]["predicted_depth"]
+>>> depth = predicted_depth * 1000
+>>> depth = depth.detach().cpu().numpy()
+>>> depth = Image.fromarray(depth.astype("uint16")) # mm
+```
+
+## Resources
+
+A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Prompt Depth Anything.
+
+- [Prompt Depth Anything Demo](https://huggingface.co/spaces/depth-anything/PromptDA)
+- [Prompt Depth Anything Interactive Results](https://promptda.github.io/interactive.html)
+
+If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
+
+## PromptDepthAnythingConfig
+
+[[autodoc]] PromptDepthAnythingConfig
+
+## PromptDepthAnythingForDepthEstimation
+
+[[autodoc]] PromptDepthAnythingForDepthEstimation
+ - forward
+
+## PromptDepthAnythingImageProcessor
+
+[[autodoc]] PromptDepthAnythingImageProcessor
+ - preprocess
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5510ac6c8ad512..e808276bce6a09 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -689,6 +689,7 @@
"models.plbart": ["PLBartConfig"],
"models.poolformer": ["PoolFormerConfig"],
"models.pop2piano": ["Pop2PianoConfig"],
+ "models.prompt_depth_anything": ["PromptDepthAnythingConfig"],
"models.prophetnet": [
"ProphetNetConfig",
"ProphetNetTokenizer",
@@ -1246,6 +1247,7 @@
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
_import_structure["models.pixtral"].append("PixtralImageProcessor")
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
+ _import_structure["models.prompt_depth_anything"].extend(["PromptDepthAnythingImageProcessor"])
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
@@ -3181,6 +3183,12 @@
"Pop2PianoPreTrainedModel",
]
)
+ _import_structure["models.prompt_depth_anything"].extend(
+ [
+ "PromptDepthAnythingForDepthEstimation",
+ "PromptDepthAnythingPreTrainedModel",
+ ]
+ )
_import_structure["models.prophetnet"].extend(
[
"ProphetNetDecoder",
@@ -5682,6 +5690,7 @@
from .models.pop2piano import (
Pop2PianoConfig,
)
+ from .models.prompt_depth_anything import PromptDepthAnythingConfig
from .models.prophetnet import (
ProphetNetConfig,
ProphetNetTokenizer,
@@ -6260,6 +6269,7 @@
PoolFormerFeatureExtractor,
PoolFormerImageProcessor,
)
+ from .models.prompt_depth_anything import PromptDepthAnythingImageProcessor
from .models.pvt import PvtImageProcessor
from .models.qwen2_vl import Qwen2VLImageProcessor
from .models.rt_detr import RTDetrImageProcessor
@@ -7819,6 +7829,10 @@
Pop2PianoForConditionalGeneration,
Pop2PianoPreTrainedModel,
)
+ from .models.prompt_depth_anything import (
+ PromptDepthAnythingForDepthEstimation,
+ PromptDepthAnythingPreTrainedModel,
+ )
from .models.prophetnet import (
ProphetNetDecoder,
ProphetNetEncoder,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 7fcaddde704cf7..fc59e75971effd 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -207,6 +207,7 @@
plbart,
poolformer,
pop2piano,
+ prompt_depth_anything,
prophetnet,
pvt,
pvt_v2,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 69ce8efa10c76c..e46fd0d18227ce 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -227,6 +227,7 @@
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
("pop2piano", "Pop2PianoConfig"),
+ ("prompt_depth_anything", "PromptDepthAnythingConfig"),
("prophetnet", "ProphetNetConfig"),
("pvt", "PvtConfig"),
("pvt_v2", "PvtV2Config"),
@@ -554,6 +555,7 @@
("plbart", "PLBart"),
("poolformer", "PoolFormer"),
("pop2piano", "Pop2Piano"),
+ ("prompt_depth_anything", "PromptDepthAnything"),
("prophetnet", "ProphetNet"),
("pvt", "PVT"),
("pvt_v2", "PVTv2"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index db25591eaa3544..d08b22721ada3a 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -123,6 +123,7 @@
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("poolformer", ("PoolFormerImageProcessor",)),
+ ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor",)),
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
("qwen2_vl", ("Qwen2VLImageProcessor",)),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e8a2dece432476..52cb7223923b91 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -893,6 +893,7 @@
("depth_anything", "DepthAnythingForDepthEstimation"),
("dpt", "DPTForDepthEstimation"),
("glpn", "GLPNForDepthEstimation"),
+ ("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"),
("zoedepth", "ZoeDepthForDepthEstimation"),
]
)
diff --git a/src/transformers/models/prompt_depth_anything/__init__.py b/src/transformers/models/prompt_depth_anything/__init__.py
new file mode 100644
index 00000000000000..3cb05f8e378874
--- /dev/null
+++ b/src/transformers/models/prompt_depth_anything/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_prompt_depth_anything import PromptDepthAnythingConfig
+ from .image_processing_prompt_depth_anything import PromptDepthAnythingImageProcessor
+ from .modeling_prompt_depth_anything import (
+ PromptDepthAnythingForDepthEstimation,
+ PromptDepthAnythingPreTrainedModel,
+ )
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py
new file mode 100644
index 00000000000000..4852afb9c84f51
--- /dev/null
+++ b/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py
@@ -0,0 +1,159 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_prompt_depth_anything.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto.configuration_auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class PromptDepthAnythingConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PromptDepthAnythingModel`]. It is used to instantiate a PromptDepthAnything
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the PromptDepthAnything
+ [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
+ The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
+ leverage the [`AutoBackbone`] API.
+ backbone (`str`, *optional*):
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to use pretrained weights for the backbone.
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+ API.
+ backbone_kwargs (`dict`, *optional*):
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size of the patches to extract from the backbone features.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ reassemble_hidden_size (`int`, *optional*, defaults to 384):
+ The number of input channels of the reassemble layers.
+ reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
+ The up/downsampling factors of the reassemble layers.
+ neck_hidden_sizes (`List[str]`, *optional*, defaults to `[48, 96, 192, 384]`):
+ The hidden sizes to project to for the feature maps of the backbone.
+ fusion_hidden_size (`int`, *optional*, defaults to 64):
+ The number of channels before fusion.
+ head_in_index (`int`, *optional*, defaults to -1):
+ The index of the features to use in the depth estimation head.
+ head_hidden_size (`int`, *optional*, defaults to 32):
+ The number of output channels in the second convolution of the depth estimation head.
+ depth_estimation_type (`str`, *optional*, defaults to `"relative"`):
+ The type of depth estimation to use. Can be one of `["relative", "metric"]`.
+ max_depth (`float`, *optional*):
+ The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models
+ and 80 for outdoor models. For "relative" depth estimation, this value is ignored.
+
+ Example:
+
+ ```python
+ >>> from transformers import PromptDepthAnythingConfig, PromptDepthAnythingForDepthEstimation
+
+ >>> # Initializing a PromptDepthAnything small style configuration
+ >>> configuration = PromptDepthAnythingConfig()
+
+ >>> # Initializing a model from the PromptDepthAnything small style configuration
+ >>> model = PromptDepthAnythingForDepthEstimation(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "prompt_depth_anything"
+
+ def __init__(
+ self,
+ backbone_config=None,
+ backbone=None,
+ use_pretrained_backbone=False,
+ use_timm_backbone=False,
+ backbone_kwargs=None,
+ patch_size=14,
+ initializer_range=0.02,
+ reassemble_hidden_size=384,
+ reassemble_factors=[4, 2, 1, 0.5],
+ neck_hidden_sizes=[48, 96, 192, 384],
+ fusion_hidden_size=64,
+ head_in_index=-1,
+ head_hidden_size=32,
+ depth_estimation_type="relative",
+ max_depth=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if backbone_config is None and backbone is None:
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.")
+ backbone_config = CONFIG_MAPPING["dinov2"](
+ image_size=518,
+ hidden_size=384,
+ num_attention_heads=6,
+ out_indices=[9, 10, 11, 12],
+ apply_layernorm=True,
+ reshape_hidden_states=False,
+ )
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.get("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ verify_backbone_config_arguments(
+ use_timm_backbone=use_timm_backbone,
+ use_pretrained_backbone=use_pretrained_backbone,
+ backbone=backbone,
+ backbone_config=backbone_config,
+ backbone_kwargs=backbone_kwargs,
+ )
+
+ self.backbone_config = backbone_config
+ self.backbone = backbone
+ self.use_pretrained_backbone = use_pretrained_backbone
+ self.use_timm_backbone = use_timm_backbone
+ self.backbone_kwargs = backbone_kwargs
+ self.reassemble_hidden_size = reassemble_hidden_size
+ self.patch_size = patch_size
+ self.initializer_range = initializer_range
+ self.reassemble_factors = reassemble_factors
+ self.neck_hidden_sizes = neck_hidden_sizes
+ self.fusion_hidden_size = fusion_hidden_size
+ self.head_in_index = head_in_index
+ self.head_hidden_size = head_hidden_size
+ if depth_estimation_type not in ["relative", "metric"]:
+ raise ValueError("depth_estimation_type must be one of ['relative', 'metric']")
+ self.depth_estimation_type = depth_estimation_type
+ self.max_depth = max_depth if max_depth else 1
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ output = copy.deepcopy(self.__dict__)
+
+ if output["backbone_config"] is not None:
+ output["backbone_config"] = self.backbone_config.to_dict()
+
+ output["model_type"] = self.__class__.model_type
+ return output
+
+
+__all__ = ["PromptDepthAnythingConfig"]
diff --git a/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py
new file mode 100644
index 00000000000000..8dfeff03ad2706
--- /dev/null
+++ b/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py
@@ -0,0 +1,292 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Prompt Depth Anything checkpoints from the original repository. URL:
+https://github.com/DepthAnything/PromptDA"""
+
+import argparse
+import re
+from pathlib import Path
+
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import (
+ Dinov2Config,
+ PromptDepthAnythingConfig,
+ PromptDepthAnythingForDepthEstimation,
+ PromptDepthAnythingImageProcessor,
+)
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def get_dpt_config(model_name):
+ if "small" in model_name or "vits" in model_name:
+ out_indices = [3, 6, 9, 12]
+ backbone_config = Dinov2Config.from_pretrained(
+ "facebook/dinov2-small", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
+ )
+ fusion_hidden_size = 64
+ neck_hidden_sizes = [48, 96, 192, 384]
+ elif "base" in model_name or "vitb" in model_name:
+ out_indices = [3, 6, 9, 12]
+ backbone_config = Dinov2Config.from_pretrained(
+ "facebook/dinov2-base", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
+ )
+ fusion_hidden_size = 128
+ neck_hidden_sizes = [96, 192, 384, 768]
+ elif "large" in model_name or "vitl" in model_name:
+ out_indices = [5, 12, 18, 24]
+ backbone_config = Dinov2Config.from_pretrained(
+ "facebook/dinov2-large", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
+ )
+ fusion_hidden_size = 256
+ neck_hidden_sizes = [256, 512, 1024, 1024]
+ else:
+ raise NotImplementedError(f"Model not supported: {model_name}")
+
+ depth_estimation_type = "metric"
+ max_depth = None
+
+ config = PromptDepthAnythingConfig(
+ reassemble_hidden_size=backbone_config.hidden_size,
+ patch_size=backbone_config.patch_size,
+ backbone_config=backbone_config,
+ fusion_hidden_size=fusion_hidden_size,
+ neck_hidden_sizes=neck_hidden_sizes,
+ depth_estimation_type=depth_estimation_type,
+ max_depth=max_depth,
+ )
+
+ return config
+
+
+def transform_qkv_weights(key, value, config):
+ if not key.startswith("qkv_transform"):
+ return value
+
+ layer_idx = int(key.split("_")[-1])
+ hidden_size = config.backbone_config.hidden_size
+
+ suffix = "bias" if "bias" in key else "weight"
+ return {
+ f"backbone.encoder.layer.{layer_idx}.attention.attention.query.{suffix}": value[:hidden_size],
+ f"backbone.encoder.layer.{layer_idx}.attention.attention.key.{suffix}": value[hidden_size : hidden_size * 2],
+ f"backbone.encoder.layer.{layer_idx}.attention.attention.value.{suffix}": value[-hidden_size:],
+ }
+
+
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ # Stem
+ r"pretrained.cls_token": r"backbone.embeddings.cls_token",
+ r"pretrained.mask_token": r"backbone.embeddings.mask_token",
+ r"pretrained.pos_embed": r"backbone.embeddings.position_embeddings",
+ r"pretrained.patch_embed.proj.(weight|bias)": r"backbone.embeddings.patch_embeddings.projection.\1",
+ # Backbone
+ r"pretrained.norm.(weight|bias)": r"backbone.layernorm.\1",
+ # Transformer layers
+ r"pretrained.blocks.(\d+).ls1.gamma": r"backbone.encoder.layer.\1.layer_scale1.lambda1",
+ r"pretrained.blocks.(\d+).ls2.gamma": r"backbone.encoder.layer.\1.layer_scale2.lambda1",
+ r"pretrained.blocks.(\d+).norm1.(weight|bias)": r"backbone.encoder.layer.\1.norm1.\2",
+ r"pretrained.blocks.(\d+).norm2.(weight|bias)": r"backbone.encoder.layer.\1.norm2.\2",
+ r"pretrained.blocks.(\d+).mlp.fc1.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc1.\2",
+ r"pretrained.blocks.(\d+).mlp.fc2.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc2.\2",
+ r"pretrained.blocks.(\d+).attn.proj.(weight|bias)": r"backbone.encoder.layer.\1.attention.output.dense.\2",
+ r"pretrained.blocks.(\d+).attn.qkv.(weight|bias)": r"qkv_transform_\2_\1",
+ # Neck
+ r"depth_head.projects.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.projection.\2",
+ r"depth_head.scratch.layer(\d+)_rn.weight": lambda m: f"neck.convs.{int(m.group(1))-1}.weight",
+ r"depth_head.resize_layers.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.resize.\2",
+ # Refinenet (with reversed indices)
+ r"depth_head.scratch.refinenet(\d+).out_conv.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.projection.{m.group(2)}",
+ r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer1.convolution1.{m.group(2)}",
+ r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer1.convolution2.{m.group(2)}",
+ r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer2.convolution1.{m.group(2)}",
+ r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer2.convolution2.{m.group(2)}",
+ r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution1.{m.group(2)}",
+ r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution2.{m.group(2)}",
+ r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution3.{m.group(2)}",
+ # Head
+ r"depth_head.scratch.output_conv1.(weight|bias)": r"head.conv1.\1",
+ r"depth_head.scratch.output_conv2.0.(weight|bias)": r"head.conv2.\1",
+ r"depth_head.scratch.output_conv2.2.(weight|bias)": r"head.conv3.\1",
+}
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
+ """
+ Convert old state dict keys to new keys using regex patterns.
+ """
+ output_dict = {}
+ if state_dict_keys is not None:
+ for old_key in state_dict_keys:
+ new_key = old_key
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+ match = re.match(pattern, old_key)
+ if match:
+ if callable(replacement):
+ new_key = replacement(match)
+ else:
+ new_key = re.sub(pattern, replacement, old_key)
+ break
+ output_dict[old_key] = new_key
+ return output_dict
+
+
+@torch.no_grad()
+def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits):
+ """
+ Copy/paste/tweak model's weights to our DPT structure.
+ """
+
+ # define DPT configuration
+ config = get_dpt_config(model_name)
+
+ model_name_to_repo = {
+ "prompt-depth-anything-vits": "depth-anything/prompt-depth-anything-vits",
+ "prompt-depth-anything-vits-transparent": "depth-anything/prompt-depth-anything-vits-transparent",
+ "prompt-depth-anything-vitl": "depth-anything/prompt-depth-anything-vitl",
+ }
+
+ # load original state_dict
+ repo_id = model_name_to_repo[model_name]
+ filename = name_to_checkpoint[model_name]
+ filepath = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{filename}",
+ )
+
+ state_dict = torch.load(filepath, map_location="cpu")["state_dict"]
+ state_dict = {key[9:]: state_dict[key] for key in state_dict}
+
+ # Convert state dict using mappings
+ key_mapping = convert_old_keys_to_new_keys(state_dict.keys())
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ new_key = key_mapping[key]
+ transformed_value = transform_qkv_weights(new_key, value, config)
+ if isinstance(transformed_value, dict):
+ new_state_dict.update(transformed_value)
+ else:
+ new_state_dict[new_key] = transformed_value
+
+ # load HuggingFace model
+ model = PromptDepthAnythingForDepthEstimation(config)
+ model.load_state_dict(new_state_dict, strict=False)
+ model.eval()
+
+ processor = PromptDepthAnythingImageProcessor(
+ do_resize=True,
+ size=756,
+ ensure_multiple_of=14,
+ keep_aspect_ratio=True,
+ do_rescale=True,
+ do_normalize=True,
+ image_mean=[0.485, 0.456, 0.406],
+ image_std=[0.229, 0.224, 0.225],
+ )
+ url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ prompt_depth_url = (
+ "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
+ )
+ prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
+
+ inputs = processor(image, return_tensors="pt", prompt_depth=prompt_depth)
+
+ # Verify forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+ predicted_depth = outputs.predicted_depth
+
+ print("Shape of predicted depth:", predicted_depth.shape)
+ print("First values:", predicted_depth[0, :3, :3])
+
+ # assert logits
+ if verify_logits:
+ expected_shape = torch.Size([1, 756, 1008])
+ if model_name == "prompt-depth-anything-vits":
+ expected_slice = torch.tensor(
+ [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]]
+ )
+ elif model_name == "prompt-depth-anything-vits-transparent":
+ expected_slice = torch.tensor(
+ [[3.0058, 3.0397, 3.0460], [3.0314, 3.0393, 3.0504], [3.0326, 3.0465, 3.0545]]
+ )
+ elif model_name == "prompt-depth-anything-vitl":
+ expected_slice = torch.tensor(
+ [[3.1336, 3.1358, 3.1363], [3.1368, 3.1267, 3.1414], [3.1397, 3.1385, 3.1448]]
+ )
+ else:
+ raise ValueError("Not supported")
+ assert predicted_depth.shape == torch.Size(expected_shape)
+ assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=5e-3) # 5mm tolerance
+ print("Looks ok!")
+
+ if pytorch_dump_folder_path is not None:
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+ print(f"Saving model and processor to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ print("Pushing model and processor to hub...")
+ model.push_to_hub(repo_id=f"{model_name.title()}-hf")
+ processor.push_to_hub(repo_id=f"{model_name.title()}-hf")
+
+
+name_to_checkpoint = {
+ "prompt-depth-anything-vits": "model.ckpt",
+ "prompt-depth-anything-vits-transparent": "model.ckpt",
+ "prompt-depth-anything-vitl": "model.ckpt",
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="prompt_depth_anything_vits",
+ type=str,
+ choices=name_to_checkpoint.keys(),
+ help="Name of the model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default=None,
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether to push the model to the hub after conversion.",
+ )
+ parser.add_argument(
+ "--verify_logits",
+ action="store_false",
+ required=False,
+ help="Whether to verify the logits after conversion.",
+ )
+
+ args = parser.parse_args()
+ convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits)
diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py
new file mode 100644
index 00000000000000..75de664cc46af2
--- /dev/null
+++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py
@@ -0,0 +1,512 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for PromptDepthAnything."""
+
+import math
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+
+
+if TYPE_CHECKING:
+ from ...modeling_outputs import DepthEstimatorOutput
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_torch_available,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_vision_available,
+ logging,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+
+def _constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
+ x = round(val / multiple) * multiple
+
+ if max_val is not None and x > max_val:
+ x = math.floor(val / multiple) * multiple
+
+ if x < min_val:
+ x = math.ceil(val / multiple) * multiple
+
+ return x
+
+
+def _get_resize_output_image_size(
+ input_image: np.ndarray,
+ output_size: Union[int, Iterable[int]],
+ keep_aspect_ratio: bool,
+ multiple: int,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> Tuple[int, int]:
+ output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
+
+ input_height, input_width = get_image_size(input_image, input_data_format)
+ output_height, output_width = output_size
+
+ # determine new height and width
+ scale_height = output_height / input_height
+ scale_width = output_width / input_width
+
+ if keep_aspect_ratio:
+ # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+
+ new_height = _constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
+ new_width = _constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
+
+ return (new_height, new_width)
+
+
+class PromptDepthAnythingImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a PromptDepthAnything image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the image after resizing. Can be overidden by `size` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
+ be overidden by `keep_aspect_ratio` in `preprocess`.
+ ensure_multiple_of (`int`, *optional*, defaults to 1):
+ If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden
+ by `ensure_multiple_of` in `preprocess`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overidden by `do_rescale` in
+ `preprocess`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overidden by `rescale_factor` in `preprocess`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `False`):
+ Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
+ combination with DPT.
+ size_divisor (`int`, *optional*):
+ If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
+ DINOv2 paper, which uses the model in combination with DPT.
+ prompt_scale_to_meter (`float`, *optional*, defaults to 0.001):
+ Scale factor to convert the prompt depth to meters.
+ """
+
+ model_input_names = ["pixel_values", "prompt_depth"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ keep_aspect_ratio: bool = False,
+ ensure_multiple_of: int = 1,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: bool = False,
+ size_divisor: int = None,
+ prompt_scale_to_meter: float = 0.001, # default unit is mm
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.size = size
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.ensure_multiple_of = ensure_multiple_of
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_pad = do_pad
+ self.size_divisor = size_divisor
+ self.prompt_scale_to_meter = prompt_scale_to_meter
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ keep_aspect_ratio: bool = False,
+ ensure_multiple_of: int = 1,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
+ is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
+ set, the image is resized to a size that is a multiple of this value.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Target size of the output image.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
+ ensure_multiple_of (`int`, *optional*, defaults to 1):
+ The image is resized to a size that is a multiple of this value.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+
+ output_size = _get_resize_output_image_size(
+ image,
+ output_size=(size["height"], size["width"]),
+ keep_aspect_ratio=keep_aspect_ratio,
+ multiple=ensure_multiple_of,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def pad_image(
+ self,
+ image: np.ndarray,
+ size_divisor: int,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Center pad an image to be a multiple of `multiple`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ size_divisor (`int`):
+ The width and height of the image will be padded to a multiple of this number.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+
+ def _get_pad(size, size_divisor):
+ new_size = math.ceil(size / size_divisor) * size_divisor
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ height, width = get_image_size(image, input_data_format)
+
+ pad_size_left, pad_size_right = _get_pad(height, size_divisor)
+ pad_size_top, pad_size_bottom = _get_pad(width, size_divisor)
+
+ padded_image = pad(
+ image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format
+ )
+ return padded_image
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ prompt_depth: Optional[ImageInput] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[int] = None,
+ keep_aspect_ratio: Optional[bool] = None,
+ ensure_multiple_of: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: Optional[bool] = None,
+ size_divisor: Optional[int] = None,
+ prompt_scale_to_meter: Optional[float] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ prompt_depth (`ImageInput`, *optional*):
+ Prompt depth to preprocess, which can be sparse depth obtained from multi-view geometry or
+ low-resolution depth from a depth sensor. Generally has shape (height, width), where height
+ and width can be smaller than those of the images. It's optional and can be None, which means no prompt depth
+ is used. If it is None, the output depth will be a monocular relative depth.
+ It is recommended to provide a prompt_scale_to_meter value, which is the scale factor to convert the prompt depth
+ to meters. This is useful when the prompt depth is not in meters.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after reszing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
+ possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
+ resized to a size that is a multiple of this value.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
+ Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If
+ True, the image will be resized to keep the aspect ratio and the size will be the maximum possible.
+ ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
+ Ensure that the image size is a multiple of this value.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ prompt_scale_to_meter (`float`, *optional*, defaults to `self.prompt_scale_to_meter`):
+ Scale factor to convert the prompt depth to meters.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
+ ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_pad=do_pad,
+ size_divisibility=size_divisor,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ preprocessed_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(
+ image=image,
+ size=size,
+ resample=resample,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=ensure_multiple_of,
+ input_data_format=input_data_format,
+ )
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ if do_pad:
+ image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ preprocessed_images.append(image)
+
+ images = preprocessed_images
+
+ data = {"pixel_values": images}
+ if prompt_depth is not None:
+ # prompt_depth is a list of images with shape (height, width)
+ # we need to convert it to a list of images with shape (1, height, width)
+ prompt_depths = make_list_of_images(prompt_depth, expected_ndims=2)
+ assert len(prompt_depths) == len(images)
+
+ # Validate prompt_depths has same length as images
+ if len(prompt_depths) != len(images):
+ raise ValueError(
+ f"Number of prompt depth images ({len(prompt_depths)}) does not match number of input images ({len(images)})"
+ )
+
+ if prompt_scale_to_meter is None:
+ prompt_scale_to_meter = self.prompt_scale_to_meter
+
+ processed_prompt_depths = []
+ for depth in prompt_depths:
+ depth = to_numpy_array(depth)
+ depth = depth * prompt_scale_to_meter
+ if depth.min() == depth.max():
+ # Prompt depth is invalid, min and max are the same.
+ # We can simply randomly select one pixel and set it to a small value.
+ EPS = 1e-6
+ random_x = np.random.randint(0, depth.shape[0])
+ random_y = np.random.randint(0, depth.shape[1])
+ depth[random_x, random_y] = depth[0, 0] + EPS
+ depth = depth[..., None].astype(np.float32)
+ depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format)
+
+ processed_prompt_depths.append(depth)
+ prompt_depths = processed_prompt_depths
+ data["prompt_depth"] = prompt_depths
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ # Copied from transformers.models.dpt.image_processing_dpt.DPTImageProcessor.post_process_depth_estimation with DPT->PromptDepthAnything
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
+ ) -> List[Dict[str, TensorType]]:
+ """
+ Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
+ Only supports PyTorch.
+
+ Args:
+ outputs ([`DepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
+
+ Returns:
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+
+ if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
+ for depth, target_size in zip(predicted_depth, target_sizes):
+ if target_size is not None:
+ depth = torch.nn.functional.interpolate(
+ depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
+ ).squeeze()
+
+ results.append({"predicted_depth": depth})
+
+ return results
+
+
+__all__ = ["PromptDepthAnythingImageProcessor"]
diff --git a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py
new file mode 100644
index 00000000000000..2b4a6ea0ef7f99
--- /dev/null
+++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py
@@ -0,0 +1,530 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_prompt_depth_anything.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers.utils.generic import torch_int
+
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...modeling_outputs import DepthEstimatorOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils.backbone_utils import load_backbone
+from .configuration_prompt_depth_anything import PromptDepthAnythingConfig
+
+
+# General docstring
+_CONFIG_FOR_DOC = "PromptDepthAnythingConfig"
+
+
+class PromptDepthAnythingLayer(nn.Module):
+ def __init__(self, config: PromptDepthAnythingConfig):
+ super().__init__()
+ self.convolution1 = nn.Conv2d(
+ 1,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+ self.activation1 = nn.ReLU()
+
+ self.convolution2 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+ self.activation2 = nn.ReLU()
+
+ self.convolution3 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+
+ def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.convolution1(prompt_depth)
+ hidden_state = self.activation1(hidden_state)
+ hidden_state = self.convolution2(hidden_state)
+ hidden_state = self.activation2(hidden_state)
+ hidden_state = self.convolution3(hidden_state)
+ return hidden_state
+
+
+class PromptDepthAnythingPreActResidualLayer(nn.Module):
+ """
+ ResidualConvUnit, pre-activate residual unit.
+
+ Args:
+ config (`[PromptDepthAnythingConfig]`):
+ Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.activation1 = nn.ReLU()
+ self.convolution1 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+
+ self.activation2 = nn.ReLU()
+ self.convolution2 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ residual = hidden_state
+ hidden_state = self.activation1(hidden_state)
+ hidden_state = self.convolution1(hidden_state)
+ hidden_state = self.activation2(hidden_state)
+ hidden_state = self.convolution2(hidden_state)
+
+ return hidden_state + residual
+
+
+class PromptDepthAnythingFeatureFusionLayer(nn.Module):
+ """Feature fusion layer, merges feature maps from different stages.
+
+ Args:
+ config (`[PromptDepthAnythingConfig]`):
+ Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config: PromptDepthAnythingConfig):
+ super().__init__()
+
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
+
+ self.residual_layer1 = PromptDepthAnythingPreActResidualLayer(config)
+ self.residual_layer2 = PromptDepthAnythingPreActResidualLayer(config)
+ self.prompt_depth_layer = PromptDepthAnythingLayer(config)
+
+ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None):
+ if residual is not None:
+ if hidden_state.shape != residual.shape:
+ residual = nn.functional.interpolate(
+ residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
+ )
+ hidden_state = hidden_state + self.residual_layer1(residual)
+
+ hidden_state = self.residual_layer2(hidden_state)
+
+ if prompt_depth is not None:
+ prompt_depth = nn.functional.interpolate(
+ prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
+ )
+ res = self.prompt_depth_layer(prompt_depth)
+ hidden_state = hidden_state + res
+
+ modifier = {"scale_factor": 2} if size is None else {"size": size}
+
+ hidden_state = nn.functional.interpolate(
+ hidden_state,
+ **modifier,
+ mode="bilinear",
+ align_corners=True,
+ )
+ hidden_state = self.projection(hidden_state)
+
+ return hidden_state
+
+
+class PromptDepthAnythingFeatureFusionStage(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layers = nn.ModuleList()
+ for _ in range(len(config.neck_hidden_sizes)):
+ self.layers.append(PromptDepthAnythingFeatureFusionLayer(config))
+
+ def forward(self, hidden_states, size=None, prompt_depth=None):
+ # reversing the hidden_states, we start from the last
+ hidden_states = hidden_states[::-1]
+
+ fused_hidden_states = []
+ fused_hidden_state = None
+
+ for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
+ size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
+
+ if fused_hidden_state is None:
+ # first layer only uses the last hidden_state
+ fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth)
+ else:
+ fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth)
+
+ fused_hidden_states.append(fused_hidden_state)
+
+ return fused_hidden_states
+
+
+class PromptDepthAnythingDepthEstimationHead(nn.Module):
+ """
+ Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
+ the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's
+ supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation
+ type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_in_index = config.head_in_index
+ self.patch_size = config.patch_size
+
+ features = config.fusion_hidden_size
+ self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
+ self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1)
+ self.activation1 = nn.ReLU()
+ self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0)
+ if config.depth_estimation_type == "relative":
+ self.activation2 = nn.ReLU()
+ elif config.depth_estimation_type == "metric":
+ self.activation2 = nn.Sigmoid()
+ else:
+ raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}")
+ self.max_depth = config.max_depth
+
+ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor:
+ hidden_states = hidden_states[-1]
+
+ predicted_depth = self.conv1(hidden_states)
+ target_height = torch_int(patch_height * self.patch_size)
+ target_width = torch_int(patch_width * self.patch_size)
+ predicted_depth = nn.functional.interpolate(
+ predicted_depth,
+ (target_height, target_width),
+ mode="bilinear",
+ align_corners=True,
+ )
+ predicted_depth = self.conv2(predicted_depth)
+ predicted_depth = self.activation1(predicted_depth)
+ predicted_depth = self.conv3(predicted_depth)
+ predicted_depth = self.activation2(predicted_depth)
+ # (batch_size, 1, height, width) -> (batch_size, height, width), which
+ # keeps the same behavior as Depth Anything v1 & v2
+ predicted_depth = predicted_depth.squeeze(dim=1)
+
+ return predicted_depth
+
+
+class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = PromptDepthAnythingConfig
+ base_model_prefix = "prompt_depth_anything"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+class PromptDepthAnythingReassembleLayer(nn.Module):
+ def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int):
+ super().__init__()
+ self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1)
+
+ # up/down sampling depending on factor
+ if factor > 1:
+ self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
+ elif factor == 1:
+ self.resize = nn.Identity()
+ elif factor < 1:
+ # so should downsample
+ stride = torch_int(1 / factor)
+ self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1)
+
+ def forward(self, hidden_state):
+ hidden_state = self.projection(hidden_state)
+ hidden_state = self.resize(hidden_state)
+
+ return hidden_state
+
+
+class PromptDepthAnythingReassembleStage(nn.Module):
+ """
+ This class reassembles the hidden states of the backbone into image-like feature representations at various
+ resolutions.
+
+ This happens in 3 stages:
+ 1. Take the patch embeddings and reshape them to image-like feature representations.
+ 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
+ 3. Resizing the spatial dimensions (height, width).
+
+ Args:
+ config (`[PromptDepthAnythingConfig]`):
+ Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.layers = nn.ModuleList()
+ for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors):
+ self.layers.append(PromptDepthAnythingReassembleLayer(config, channels=channels, factor=factor))
+
+ def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
+ """
+ Args:
+ hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
+ List of hidden states from the backbone.
+ """
+ out = []
+
+ for i, hidden_state in enumerate(hidden_states):
+ # reshape to (batch_size, num_channels, height, width)
+ hidden_state = hidden_state[:, 1:]
+ batch_size, _, num_channels = hidden_state.shape
+ hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ hidden_state = self.layers[i](hidden_state)
+ out.append(hidden_state)
+
+ return out
+
+
+class PromptDepthAnythingNeck(nn.Module):
+ """
+ PromptDepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
+ input and produces another list of tensors as output. For PromptDepthAnything, it includes 2 stages:
+
+ * PromptDepthAnythingReassembleStage
+ * PromptDepthAnythingFeatureFusionStage.
+
+ Args:
+ config (dict): config dict.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.reassemble_stage = PromptDepthAnythingReassembleStage(config)
+
+ self.convs = nn.ModuleList()
+ for channel in config.neck_hidden_sizes:
+ self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
+
+ # fusion
+ self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config)
+
+ def forward(
+ self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None
+ ) -> List[torch.Tensor]:
+ """
+ Args:
+ hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
+ List of hidden states from the backbone.
+ """
+ if not isinstance(hidden_states, (tuple, list)):
+ raise TypeError("hidden_states should be a tuple or list of tensors")
+
+ if len(hidden_states) != len(self.config.neck_hidden_sizes):
+ raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
+
+ # postprocess hidden states
+ hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
+
+ features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
+
+ # fusion blocks
+ output = self.fusion_stage(features, prompt_depth=prompt_depth)
+
+ return output
+
+
+PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
+ for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ """
+ Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
+ """,
+ PROMPT_DEPTH_ANYTHING_START_DOCSTRING,
+)
+class PromptDepthAnythingForDepthEstimation(PromptDepthAnythingPreTrainedModel):
+ _no_split_modules = ["DPTViTEmbeddings"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.backbone = load_backbone(config)
+ self.neck = PromptDepthAnythingNeck(config)
+ self.head = PromptDepthAnythingDepthEstimationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ prompt_depth: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth depth estimation maps for computing the loss.
+
+ Returns:
+
+ Examples:
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth depth estimation maps for computing the loss.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+ >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+
+ >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
+ >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
+
+ >>> # prepare image for the model
+ >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth)
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> # interpolate to original size
+ >>> post_processed_output = image_processor.post_process_depth_estimation(
+ ... outputs,
+ ... target_sizes=[(image.height, image.width)],
+ ... )
+
+ >>> # visualize the prediction
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
+ >>> depth = predicted_depth * 1000.
+ >>> depth = depth.detach().cpu().numpy()
+ >>> depth = Image.fromarray(depth.astype("uint16")) # mm
+ ```"""
+ loss = None
+ if labels is not None:
+ raise NotImplementedError("Training is not implemented yet")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ outputs = self.backbone.forward_with_filtered_kwargs(
+ pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
+ )
+ hidden_states = outputs.feature_maps
+
+ _, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+ patch_height = height // patch_size
+ patch_width = width // patch_size
+
+ if prompt_depth is not None:
+ # normalize prompt depth
+ batch_size = prompt_depth.shape[0]
+ depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values
+ depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values
+ depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1)
+ prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min)
+ # normalize done
+
+ hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth)
+
+ predicted_depth = self.head(hidden_states, patch_height, patch_width)
+ if prompt_depth is not None:
+ # denormalize predicted depth
+ depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1)
+ predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min
+ # denormalize done
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (predicted_depth,) + outputs[1:]
+ else:
+ output = (predicted_depth,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return DepthEstimatorOutput(
+ loss=loss,
+ predicted_depth=predicted_depth,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["PromptDepthAnythingForDepthEstimation", "PromptDepthAnythingPreTrainedModel"]
diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py
new file mode 100644
index 00000000000000..9d14b5fc7e1fdb
--- /dev/null
+++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py
@@ -0,0 +1,380 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers.models.depth_anything.configuration_depth_anything import DepthAnythingConfig
+from transformers.models.depth_anything.modeling_depth_anything import (
+ DepthAnythingDepthEstimationHead,
+ DepthAnythingFeatureFusionLayer,
+ DepthAnythingFeatureFusionStage,
+ DepthAnythingForDepthEstimation,
+ DepthAnythingNeck,
+ DepthAnythingReassembleStage,
+)
+from transformers.utils.generic import torch_int
+
+from ...file_utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from ...modeling_outputs import DepthEstimatorOutput
+from ...modeling_utils import PreTrainedModel
+
+
+_CONFIG_FOR_DOC = "PromptDepthAnythingConfig"
+
+
+class PromptDepthAnythingConfig(DepthAnythingConfig):
+ model_type = "prompt_depth_anything"
+
+
+class PromptDepthAnythingLayer(nn.Module):
+ def __init__(self, config: PromptDepthAnythingConfig):
+ super().__init__()
+ self.convolution1 = nn.Conv2d(
+ 1,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+ self.activation1 = nn.ReLU()
+
+ self.convolution2 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+ self.activation2 = nn.ReLU()
+
+ self.convolution3 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+
+ def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.convolution1(prompt_depth)
+ hidden_state = self.activation1(hidden_state)
+ hidden_state = self.convolution2(hidden_state)
+ hidden_state = self.activation2(hidden_state)
+ hidden_state = self.convolution3(hidden_state)
+ return hidden_state
+
+
+class PromptDepthAnythingFeatureFusionLayer(DepthAnythingFeatureFusionLayer):
+ def __init__(self, config: PromptDepthAnythingConfig):
+ super().__init__(config)
+ self.prompt_depth_layer = PromptDepthAnythingLayer(config)
+
+ def forward(self, hidden_state, residual=None, size=None, prompt_depth=None):
+ if residual is not None:
+ if hidden_state.shape != residual.shape:
+ residual = nn.functional.interpolate(
+ residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
+ )
+ hidden_state = hidden_state + self.residual_layer1(residual)
+
+ hidden_state = self.residual_layer2(hidden_state)
+
+ if prompt_depth is not None:
+ prompt_depth = nn.functional.interpolate(
+ prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
+ )
+ res = self.prompt_depth_layer(prompt_depth)
+ hidden_state = hidden_state + res
+
+ modifier = {"scale_factor": 2} if size is None else {"size": size}
+
+ hidden_state = nn.functional.interpolate(
+ hidden_state,
+ **modifier,
+ mode="bilinear",
+ align_corners=True,
+ )
+ hidden_state = self.projection(hidden_state)
+
+ return hidden_state
+
+
+class PromptDepthAnythingFeatureFusionStage(DepthAnythingFeatureFusionStage):
+ def forward(self, hidden_states, size=None, prompt_depth=None):
+ # reversing the hidden_states, we start from the last
+ hidden_states = hidden_states[::-1]
+
+ fused_hidden_states = []
+ fused_hidden_state = None
+
+ for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
+ size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
+
+ if fused_hidden_state is None:
+ # first layer only uses the last hidden_state
+ fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth)
+ else:
+ fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth)
+
+ fused_hidden_states.append(fused_hidden_state)
+
+ return fused_hidden_states
+
+
+class PromptDepthAnythingDepthEstimationHead(DepthAnythingDepthEstimationHead):
+ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor:
+ hidden_states = hidden_states[-1]
+
+ predicted_depth = self.conv1(hidden_states)
+ target_height = torch_int(patch_height * self.patch_size)
+ target_width = torch_int(patch_width * self.patch_size)
+ predicted_depth = nn.functional.interpolate(
+ predicted_depth,
+ (target_height, target_width),
+ mode="bilinear",
+ align_corners=True,
+ )
+ predicted_depth = self.conv2(predicted_depth)
+ predicted_depth = self.activation1(predicted_depth)
+ predicted_depth = self.conv3(predicted_depth)
+ predicted_depth = self.activation2(predicted_depth)
+ # (batch_size, 1, height, width) -> (batch_size, height, width), which
+ # keeps the same behavior as Depth Anything v1 & v2
+ predicted_depth = predicted_depth.squeeze(dim=1)
+
+ return predicted_depth
+
+
+PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
+ for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
+ Prompt depth is the sparse or low-resolution depth obtained from multi-view geometry or a
+ low-resolution depth sensor. It generally has shape (height, width), where height
+ and width can be smaller than those of the images. It is optional and can be None, which means no prompt depth
+ will be used. If it is None, the output will be a monocular relative depth.
+ The values are recommended to be in meters, but this is not necessary.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = PromptDepthAnythingConfig
+ base_model_prefix = "prompt_depth_anything"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+class PromptDepthAnythingReassembleLayer(nn.Module):
+ def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int):
+ super().__init__()
+ self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1)
+
+ # up/down sampling depending on factor
+ if factor > 1:
+ self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
+ elif factor == 1:
+ self.resize = nn.Identity()
+ elif factor < 1:
+ # so should downsample
+ stride = torch_int(1 / factor)
+ self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1)
+
+ def forward(self, hidden_state):
+ hidden_state = self.projection(hidden_state)
+ hidden_state = self.resize(hidden_state)
+
+ return hidden_state
+
+
+class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage):
+ pass
+
+
+class PromptDepthAnythingNeck(DepthAnythingNeck):
+ def forward(
+ self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None, prompt_depth=None
+ ) -> List[torch.Tensor]:
+ """
+ Args:
+ hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
+ List of hidden states from the backbone.
+ """
+ if not isinstance(hidden_states, (tuple, list)):
+ raise TypeError("hidden_states should be a tuple or list of tensors")
+
+ if len(hidden_states) != len(self.config.neck_hidden_sizes):
+ raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
+
+ # postprocess hidden states
+ hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
+
+ features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
+
+ # fusion blocks
+ output = self.fusion_stage(features, prompt_depth=prompt_depth)
+
+ return output
+
+
+@add_start_docstrings(
+ """
+ Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
+ """,
+ PROMPT_DEPTH_ANYTHING_START_DOCSTRING,
+)
+class PromptDepthAnythingForDepthEstimation(DepthAnythingForDepthEstimation):
+ @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ prompt_depth: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth depth estimation maps for computing the loss.
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+ >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+
+ >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
+ >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
+
+ >>> # prepare image for the model
+ >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth)
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> # interpolate to original size
+ >>> post_processed_output = image_processor.post_process_depth_estimation(
+ ... outputs,
+ ... target_sizes=[(image.height, image.width)],
+ ... )
+
+ >>> # visualize the prediction
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
+ >>> depth = predicted_depth * 1000.
+ >>> depth = depth.detach().cpu().numpy()
+ >>> depth = Image.fromarray(depth.astype("uint16")) # mm
+ ```"""
+ loss = None
+ if labels is not None:
+ raise NotImplementedError("Training is not implemented yet")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ outputs = self.backbone.forward_with_filtered_kwargs(
+ pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
+ )
+ hidden_states = outputs.feature_maps
+
+ _, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+ patch_height = height // patch_size
+ patch_width = width // patch_size
+
+ if prompt_depth is not None:
+ # normalize prompt depth
+ batch_size = prompt_depth.shape[0]
+ depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values
+ depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values
+ depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1)
+ prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min)
+ # normalize done
+
+ hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth)
+
+ predicted_depth = self.head(hidden_states, patch_height, patch_width)
+ if prompt_depth is not None:
+ # denormalize predicted depth
+ depth_min, depth_max = depth_min.squeeze(1), depth_max.squeeze(1)
+ predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min
+ # denormalize done
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (predicted_depth,) + outputs[1:]
+ else:
+ output = (predicted_depth,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return DepthEstimatorOutput(
+ loss=loss,
+ predicted_depth=predicted_depth,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "PromptDepthAnythingConfig",
+ "PromptDepthAnythingForDepthEstimation",
+ "PromptDepthAnythingPreTrainedModel",
+]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index e3463461ea07e5..11f5aa1c86a2fa 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -7605,6 +7605,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class PromptDepthAnythingForDepthEstimation(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class PromptDepthAnythingPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class ProphetNetDecoder(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index 3ebda4404aae9c..8d5dee958bd196 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -548,6 +548,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class PromptDepthAnythingImageProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class PvtImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/tests/models/prompt_depth_anything/__init__.py b/tests/models/prompt_depth_anything/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/prompt_depth_anything/test_image_processing_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_image_processing_prompt_depth_anything.py
new file mode 100644
index 00000000000000..7becbe5dfa5091
--- /dev/null
+++ b/tests/models/prompt_depth_anything/test_image_processing_prompt_depth_anything.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.file_utils import is_vision_available
+from transformers.testing_utils import require_torch, require_vision
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_vision_available():
+ from transformers import PromptDepthAnythingImageProcessor
+
+
+class PromptDepthAnythingImageProcessingTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=None,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ super().__init__()
+ size = size if size is not None else {"height": 18, "width": 18}
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_image_processor_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ }
+
+ def expected_output_image_shape(self, images):
+ return self.num_channels, self.size["height"], self.size["width"]
+
+ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+ return prepare_image_inputs(
+ batch_size=self.batch_size,
+ num_channels=self.num_channels,
+ min_resolution=self.min_resolution,
+ max_resolution=self.max_resolution,
+ equal_resolution=equal_resolution,
+ numpify=numpify,
+ torchify=torchify,
+ )
+
+
+@require_torch
+@require_vision
+class PromptDepthAnythingImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+ image_processing_class = PromptDepthAnythingImageProcessor if is_vision_available() else None
+
+ def setUp(self):
+ super().setUp()
+ self.image_processor_tester = PromptDepthAnythingImageProcessingTester(self)
+
+ @property
+ def image_processor_dict(self):
+ return self.image_processor_tester.prepare_image_processor_dict()
+
+ def test_image_processor_properties(self):
+ image_processing = self.image_processing_class(**self.image_processor_dict)
+ self.assertTrue(hasattr(image_processing, "image_mean"))
+ self.assertTrue(hasattr(image_processing, "image_std"))
+ self.assertTrue(hasattr(image_processing, "do_normalize"))
+ self.assertTrue(hasattr(image_processing, "do_resize"))
+ self.assertTrue(hasattr(image_processing, "size"))
+ self.assertTrue(hasattr(image_processing, "do_rescale"))
+ self.assertTrue(hasattr(image_processing, "rescale_factor"))
+ self.assertTrue(hasattr(image_processing, "do_pad"))
+ self.assertTrue(hasattr(image_processing, "size_divisor"))
+ self.assertTrue(hasattr(image_processing, "prompt_scale_to_meter"))
+
+ def test_image_processor_from_dict_with_kwargs(self):
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
+ self.assertEqual(image_processor.size, {"height": 18, "width": 18})
+
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
+ self.assertEqual(image_processor.size, {"height": 42, "width": 42})
+
+ def test_keep_aspect_ratio(self):
+ size = {"height": 512, "width": 512}
+ image_processor = PromptDepthAnythingImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32)
+
+ image = np.zeros((489, 640, 3))
+
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
+
+ self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
+
+ def test_prompt_depth_processing(self):
+ size = {"height": 756, "width": 756}
+ image_processor = PromptDepthAnythingImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32)
+
+ image = np.zeros((756, 1008, 3))
+ prompt_depth = np.random.random((192, 256))
+
+ outputs = image_processor(image, prompt_depth=prompt_depth, return_tensors="pt")
+ pixel_values = outputs.pixel_values
+ prompt_depth_values = outputs.prompt_depth
+
+ self.assertEqual(list(pixel_values.shape), [1, 3, 768, 1024])
+ self.assertEqual(list(prompt_depth_values.shape), [1, 1, 192, 256])
diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py
new file mode 100644
index 00000000000000..3e95670fc46034
--- /dev/null
+++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Prompt Depth Anything model."""
+
+import unittest
+
+import requests
+
+from transformers import Dinov2Config, PromptDepthAnythingConfig
+from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import PromptDepthAnythingForDepthEstimation
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoImageProcessor
+
+
+class PromptDepthAnythingModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ num_channels=3,
+ image_size=32,
+ patch_size=16,
+ use_labels=True,
+ num_labels=3,
+ is_training=True,
+ hidden_size=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ intermediate_size=8,
+ out_features=["stage1", "stage2"],
+ apply_layernorm=False,
+ reshape_hidden_states=False,
+ neck_hidden_sizes=[2, 2],
+ fusion_hidden_size=6,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.out_features = out_features
+ self.apply_layernorm = apply_layernorm
+ self.reshape_hidden_states = reshape_hidden_states
+ self.use_labels = use_labels
+ self.num_labels = num_labels
+ self.is_training = is_training
+ self.neck_hidden_sizes = neck_hidden_sizes
+ self.fusion_hidden_size = fusion_hidden_size
+ self.seq_length = (self.image_size // self.patch_size) ** 2 + 1
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
+
+ prompt_depth = floats_tensor([self.batch_size, 1, self.image_size // 4, self.image_size // 4])
+
+ config = self.get_config()
+
+ return config, pixel_values, labels, prompt_depth
+
+ def get_config(self):
+ return PromptDepthAnythingConfig(
+ backbone_config=self.get_backbone_config(),
+ reassemble_hidden_size=self.hidden_size,
+ patch_size=self.patch_size,
+ neck_hidden_sizes=self.neck_hidden_sizes,
+ fusion_hidden_size=self.fusion_hidden_size,
+ )
+
+ def get_backbone_config(self):
+ return Dinov2Config(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ is_training=self.is_training,
+ out_features=self.out_features,
+ reshape_hidden_states=self.reshape_hidden_states,
+ )
+
+ def create_and_check_for_depth_estimation(self, config, pixel_values, labels, prompt_depth):
+ config.num_labels = self.num_labels
+ model = PromptDepthAnythingForDepthEstimation(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, prompt_depth=prompt_depth)
+ self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels, prompt_depth = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values, "prompt_depth": prompt_depth}
+ return config, inputs_dict
+
+
+@require_torch
+class PromptDepthAnythingModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as Prompt Depth Anything does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (PromptDepthAnythingForDepthEstimation,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {"depth-estimation": PromptDepthAnythingForDepthEstimation} if is_torch_available() else {}
+ )
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = PromptDepthAnythingModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=PromptDepthAnythingConfig,
+ has_text_modality=False,
+ hidden_size=37,
+ common_properties=["patch_size"],
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(
+ reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings"
+ )
+ def test_inputs_embeds(self):
+ pass
+
+ def test_for_depth_estimation(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)
+
+ @unittest.skip(reason="Prompt Depth Anything does not support training yet")
+ def test_training(self):
+ pass
+
+ @unittest.skip(reason="Prompt Depth Anything does not support training yet")
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(
+ reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings"
+ )
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model")
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @unittest.skip(
+ reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant(self):
+ pass
+
+ @unittest.skip(
+ reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_name = "depth-anything/prompt-depth-anything-vits-hf"
+ model = PromptDepthAnythingForDepthEstimation.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ def test_backbone_selection(self):
+ def _validate_backbone_init():
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ self.assertEqual(len(model.backbone.out_indices), 2)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ config.backbone = "facebook/dinov2-small"
+ config.use_pretrained_backbone = True
+ config.use_timm_backbone = False
+ config.backbone_config = None
+ config.backbone_kwargs = {"out_indices": [-2, -1]}
+ _validate_backbone_init()
+
+
+def prepare_img():
+ url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
+ image = Image.open(requests.get(url, stream=True).raw)
+ return image
+
+
+def prepare_prompt_depth():
+ prompt_depth_url = (
+ "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
+ )
+ prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
+ return prompt_depth
+
+
+@require_torch
+@require_vision
+@slow
+class PromptDepthAnythingModelIntegrationTest(unittest.TestCase):
+ def test_inference_wo_prompt_depth(self):
+ image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+ model = PromptDepthAnythingForDepthEstimation.from_pretrained(
+ "depth-anything/prompt-depth-anything-vits-hf"
+ ).to(torch_device)
+
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+ predicted_depth = outputs.predicted_depth
+
+ expected_shape = torch.Size([1, 756, 1008])
+ self.assertEqual(predicted_depth.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[0.5029, 0.5120, 0.5176], [0.4998, 0.5147, 0.5197], [0.4973, 0.5201, 0.5241]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3))
+
+ def test_inference(self):
+ image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+ model = PromptDepthAnythingForDepthEstimation.from_pretrained(
+ "depth-anything/prompt-depth-anything-vits-hf"
+ ).to(torch_device)
+
+ image = prepare_img()
+ prompt_depth = prepare_prompt_depth()
+ inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth).to(torch_device)
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+ predicted_depth = outputs.predicted_depth
+
+ expected_shape = torch.Size([1, 756, 1008])
+ self.assertEqual(predicted_depth.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3))
+
+ def test_export(self):
+ for strict in [True, False]:
+ with self.subTest(strict=strict):
+ if not is_torch_greater_or_equal_than_2_4:
+ self.skipTest(reason="This test requires torch >= 2.4 to run.")
+ model = (
+ PromptDepthAnythingForDepthEstimation.from_pretrained(
+ "depth-anything/prompt-depth-anything-vits-hf"
+ )
+ .to(torch_device)
+ .eval()
+ )
+ image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
+ image = prepare_img()
+ prompt_depth = prepare_prompt_depth()
+ inputs = image_processor(images=image, prompt_depth=prompt_depth, return_tensors="pt").to(torch_device)
+
+ exported_program = torch.export.export(
+ model,
+ args=(inputs["pixel_values"], inputs["prompt_depth"]),
+ strict=strict,
+ )
+ with torch.no_grad():
+ eager_outputs = model(**inputs)
+ exported_outputs = exported_program.module().forward(inputs["pixel_values"])
+ self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape)
+ self.assertTrue(
+ torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4)
+ )