From 9278dc0f674d31447df15f0f4b4600e45e9462cc Mon Sep 17 00:00:00 2001
From: singjc
Date: Wed, 30 Oct 2024 00:03:47 -0400
Subject: [PATCH 1/5] change: BertEncoder struct to public

---
 candle-transformers/src/models/bert.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index bdc0385deb..fcd528dc87 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -411,7 +411,7 @@ impl BertLayer {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
+pub struct BertEncoder {
     layers: Vec<BertLayer>,
     span: tracing::Span,
 }

From 7572e6161f10d2592e44f751273e6357857e0c5b Mon Sep 17 00:00:00 2001
From: singjc
Date: Wed, 30 Oct 2024 00:51:45 -0400
Subject: [PATCH 2/5] change: make certain fields in Config struct public

---
 candle-transformers/src/models/bert.rs | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index fcd528dc87..80440237cc 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -46,12 +46,12 @@ enum PositionEmbeddingType {
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
     pub hidden_act: HiddenAct,
-    hidden_dropout_prob: f64,
+    pub hidden_dropout_prob: f64,
     max_position_embeddings: usize,
     type_vocab_size: usize,
     initializer_range: f64,
@@ -376,7 +376,7 @@ impl BertOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
+pub struct BertLayer {
     attention: BertAttention,
     intermediate: BertIntermediate,
     output: BertOutput,
@@ -417,7 +417,7 @@ pub struct BertEncoder {
 }
 
 impl BertEncoder {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
             .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
@@ -425,7 +425,7 @@ impl BertEncoder {
         Ok(BertEncoder { layers, span })
     }
 
-    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+    pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
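[Note: with patches 1 and 2 applied, `BertEncoder::load` and `BertEncoder::forward` become callable outside of `BertModel`. A minimal sketch of standalone construction, assuming fresh default-initialized weights from a `VarMap`; the "encoder" prefix mirrors how `BertModel` names its sub-module, everything else here is illustrative rather than part of the patch.]

use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};
use candle_transformers::models::bert::{BertEncoder, Config};

fn build_encoder(config: &Config, device: &Device) -> anyhow::Result<BertEncoder> {
    // Fresh, default-initialized variables; real code would mmap a checkpoint instead.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
    // `BertModel` keeps its encoder weights under the "encoder" prefix.
    Ok(BertEncoder::load(vb.pp("encoder"), config)?)
}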
From cfa7e17cffc70c700ba914eb6305720f70783266 Mon Sep 17 00:00:00 2001
From: singjc
Date: Wed, 30 Oct 2024 10:47:57 -0400
Subject: [PATCH 3/5] change: all fields in bert config struct to be public

---
 candle-transformers/src/models/bert.rs | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 80440237cc..4de0891222 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -45,24 +45,24 @@ enum PositionEmbeddingType {
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
+    pub vocab_size: usize,
     pub hidden_size: usize,
     pub num_hidden_layers: usize,
     pub num_attention_heads: usize,
     pub intermediate_size: usize,
     pub hidden_act: HiddenAct,
     pub hidden_dropout_prob: f64,
-    max_position_embeddings: usize,
-    type_vocab_size: usize,
-    initializer_range: f64,
-    layer_norm_eps: f64,
-    pad_token_id: usize,
+    pub max_position_embeddings: usize,
+    pub type_vocab_size: usize,
+    pub initializer_range: f64,
+    pub layer_norm_eps: f64,
+    pub pad_token_id: usize,
     #[serde(default)]
-    position_embedding_type: PositionEmbeddingType,
+    pub position_embedding_type: PositionEmbeddingType,
     #[serde(default)]
-    use_cache: bool,
-    classifier_dropout: Option<f64>,
-    model_type: Option<String>,
+    pub use_cache: bool,
+    pub classifier_dropout: Option<f64>,
+    pub model_type: Option<String>,
 }
 
 impl Default for Config {

From fe0dd76ec1865018533b617fdfaa7a4231ee8408 Mon Sep 17 00:00:00 2001
From: singjc
Date: Wed, 30 Oct 2024 14:10:28 -0400
Subject: [PATCH 4/5] change: add clone to bert encoder and others

---
 candle-transformers/src/models/bert.rs | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 4de0891222..f9e6648075 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -13,6 +13,7 @@ pub enum HiddenAct {
     Relu,
 }
 
+#[derive(Clone)]
 struct HiddenActLayer {
     act: HiddenAct,
     span: tracing::Span,
@@ -112,6 +113,7 @@ impl Config {
     }
 }
 
+#[derive(Clone)]
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
@@ -190,6 +192,7 @@ impl BertEmbeddings {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
@@ -257,6 +260,7 @@ impl BertSelfAttention {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -290,6 +294,7 @@ impl BertSelfOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+#[derive(Clone)]
 struct BertAttention {
     self_attention: BertSelfAttention,
     self_output: BertSelfOutput,
@@ -316,6 +321,7 @@ impl BertAttention {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+#[derive(Clone)]
 struct BertIntermediate {
     dense: Linear,
     intermediate_act: HiddenActLayer,
@@ -343,6 +349,7 @@ impl Module for BertIntermediate {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+#[derive(Clone)]
 struct BertOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -376,6 +383,7 @@ impl BertOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
+#[derive(Clone)]
 pub struct BertLayer {
     attention: BertAttention,
     intermediate: BertIntermediate,
     output: BertOutput,
@@ -411,8 +419,9 @@ impl BertLayer {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+#[derive(Clone)]
 pub struct BertEncoder {
-    layers: Vec<BertLayer>,
+    pub layers: Vec<BertLayer>,
     span: tracing::Span,
 }

From cae6768d5916833585a494d01e2a245bf15fd1ea Mon Sep 17 00:00:00 2001
From: Laurent
Date: Wed, 4 Dec 2024 21:15:25 +0100
Subject: [PATCH 5/5] Clippy fix.

---
 candle-transformers/src/models/bert.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 06e65feade..0ff62c4f3e 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -47,7 +47,7 @@ impl HiddenActLayer {
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
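[Note: applied in sequence, the series makes the encoder stack reusable on its own: `Config` is fully public and deserializable, `BertEncoder` loads and runs outside `BertModel`, and the `Clone` derives allow cheap duplication of a loaded encoder. Below is an illustrative end-to-end sketch, not part of the patches: the config and checkpoint file names are placeholders, the "encoder" weight prefix mirrors `BertModel`'s naming, and the all-zero additive attention mask of shape (batch, 1, 1, seq) — assumed to match the extended-mask convention of this candle version — simply means "attend everywhere".]

use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertEncoder, Config};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // `Config` derives Deserialize, and its fields are now public for inspection.
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    println!("{} layers, hidden size {}", config.num_hidden_layers, config.hidden_size);

    // Placeholder checkpoint path; `BertModel` keeps encoder weights under "encoder".
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let encoder = BertEncoder::load(vb.pp("encoder"), &config)?;
    let _standby = encoder.clone(); // possible now that BertEncoder derives Clone

    // Dummy hidden states (batch=1, seq=8) standing in for BertEmbeddings output,
    // plus an all-zero additive attention mask (nothing masked out).
    let hidden = Tensor::zeros((1, 8, config.hidden_size), DType::F32, &device)?;
    let mask = Tensor::zeros((1, 1, 1, 8), DType::F32, &device)?;
    let out = encoder.forward(&hidden, &mask)?;
    println!("output dims: {:?}", out.dims());
    Ok(())
}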