Commit ff42721
feat: normalize cls embeddings
myeolinmalchi committed Jan 29, 2025
1 parent dfb31ad commit ff42721
Showing 1 changed file with 16 additions and 12 deletions.
utils/embed/llama_cpp/session.py
@@ -61,14 +61,21 @@ def compute_score(self, lw1: Dict[int, float], lw2: Dict[int, float]):
             scores += weight * lw2[token]
         return scores
 
+    def _normalize_hidden_state(self, hidden_state: List[List[float]]):
+        tensors = torch.tensor(hidden_state)
+        norm = torch.nn.LayerNorm(tensors.size())
+        tensors = norm(tensors)
+
+        return tensors
+
     def _sparse_embedding(
-        self, hidden_state: List[List[float]], input_ids: List[int]
+        self, hidden_state: torch.Tensor, input_ids: List[int]
     ):
-        sparse_tensors = torch.tensor(hidden_state)
-        norm = torch.nn.LayerNorm(sparse_tensors.size())
-        sparse_tensors = norm(sparse_tensors)
+        #sparse_tensors = torch.tensor(hidden_state)
+        #norm = torch.nn.LayerNorm(sparse_tensors.size())
+        #sparse_tensors = norm(sparse_tensors)
 
-        sparse_tensors = torch.nan_to_num(sparse_tensors, 0)
+        sparse_tensors = torch.nan_to_num(hidden_state, 0)
 
         token_weights_tensor = torch.relu(self.sparse_linear(sparse_tensors))
         input_ids_tensor = torch.tensor(input_ids)
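For reference, a minimal standalone sketch (not part of the commit) of what the extracted _normalize_hidden_state helper computes: normalized_shape is set to the full two-dimensional size of the tensor, so the LayerNorm takes its mean and variance over every element at once, and with the default affine parameters (weight of ones, bias of zeros) this amounts to standardizing the whole hidden state. The sample values below are illustrative.

import torch

# Illustrative tokens x dim hidden state (values are made up)
hidden_state = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

tensors = torch.tensor(hidden_state)
# normalized_shape == tensors.size(): statistics are computed over all elements
norm = torch.nn.LayerNorm(tensors.size())
normalized = norm(tensors)

# Up to eps, this matches a plain whole-tensor standardization
manual = (tensors - tensors.mean()) / torch.sqrt(tensors.var(unbiased=False) + norm.eps)
assert torch.allclose(normalized, manual, atol=1e-5)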
@@ -146,7 +153,6 @@ def embed_(
         # decode and fetch embeddings
         input_ids: List[List[int]] = []
         seq_embeddings: List[List[List[float]]] = []
-        cls_embeddings: List[List[float]] = []
 
         def decode_batch(seq_sizes: List[int]):
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@@ -164,7 +170,6 @@ def decode_batch(seq_sizes: List[int]):
             ]
 
             seq_embeddings.append(embeddings)
-            cls_embeddings.append(embeddings[0])
 
             pos += size
 
@@ -213,18 +218,17 @@ def decode_batch(seq_sizes: List[int]):
         llama_cpp.llama_perf_context_print(self._ctx.ctx)
 
         lexical_weights: List[Dict[int, float]] = []
-        token_weights = []
+        cls_embeddings: List[List[float]] = []
         for idx, input_id in enumerate(input_ids):
-            _, token_weight = self._sparse_embedding(
-                seq_embeddings[idx], input_id
-            )
+            normalized = self._normalize_hidden_state(seq_embeddings[idx])
+            _, token_weight = self._sparse_embedding(normalized, input_id)
 
             lexical_weights.append(
                 self._process_token_weights(
                     token_weight.detach().numpy(), input_id
                 )
             )
-            token_weights.append(token_weight)
+            cls_embeddings.append(normalized[0].detach().numpy())
 
         outputs = [
             EmbedResult(
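Net effect of the change: the per-sequence LayerNorm now runs once in the embed_ loop, and both the sparse lexical weights and the dense CLS vector (row 0 of the normalized hidden state) are derived from the same normalized tensor, instead of the CLS embedding being captured raw inside decode_batch. As a hypothetical usage sketch (the EmbedResult wiring is not shown in this hunk), two such CLS vectors could be compared by cosine similarity:

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Standard cosine similarity between two dense CLS embeddings
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# cls_embeddings[i] is the normalized first-token vector for sequence i:
# score = cosine_similarity(cls_embeddings[0], cls_embeddings[1])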
