
Commit

Fix to normalization. Global normalization was being applied, when element-wise normalization (per-vector normalization) should be applied instead.
Unobtainiumrock committed Jan 27, 2025
1 parent f2abc34 commit 9ba8760
Showing 2 changed files with 83 additions and 39 deletions.
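The core of the change: global normalization divides the whole (N, d) embedding matrix by a single scalar norm, so individual rows generally do not end up with unit length, whereas element-wise (per-vector) normalization divides each row by its own L2 norm, which is what inner-product search needs to behave like cosine similarity. A minimal NumPy sketch of the two behaviors (illustrative only; the example values are made up and not taken from the diff):

import numpy as np

embeddings = np.array([[3.0, 4.0],
                       [0.5, 0.5]], dtype=np.float32)

# Global normalization: one scalar norm for the whole matrix.
global_norm = np.linalg.norm(embeddings)             # ~5.05
globally_normalized = embeddings / global_norm
print(np.linalg.norm(globally_normalized, axis=1))   # ~[0.99, 0.14] -- rows are NOT unit length

# Per-vector (row-wise) normalization: each embedding gets unit L2 norm.
row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
per_vector_normalized = embeddings / row_norms
print(np.linalg.norm(per_vector_normalized, axis=1)) # [1. 1.]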
18 changes: 11 additions & 7 deletions adalflow/adalflow/components/retriever/faiss_retriever.py
@@ -1,5 +1,6 @@
"""Semantic search/embedding-based retriever using FAISS."""

import faiss
from typing import (
List,
Optional,
@@ -24,22 +25,23 @@
RetrieverStrQueryType,
EmbedderOutputType,
)
from adalflow.core.functional import normalize_np_array, is_normalized
from adalflow.core.functional import normalize_embeddings, is_normalized

from adalflow.utils.lazy_import import safe_import, OptionalPackages

safe_import(OptionalPackages.FAISS.value[0], OptionalPackages.FAISS.value[1])
import faiss

log = logging.getLogger(__name__)

FAISSRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray] # single embedding
# single embedding
FAISSRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray]
FAISSRetrieverDocumentsType = Sequence[FAISSRetrieverDocumentEmbeddingType]

FAISSRetrieverEmbeddingQueryType = Union[
List[float], List[List[float]], np.ndarray
] # single embedding or list of embeddings
FAISSRetrieverQueryType = Union[RetrieverStrQueryType, FAISSRetrieverEmbeddingQueryType]
FAISSRetrieverQueryType = Union[RetrieverStrQueryType,
FAISSRetrieverEmbeddingQueryType]
FAISSRetrieverQueriesType = Sequence[FAISSRetrieverQueryType]
FAISSRetrieverQueriesStrType = Sequence[RetrieverStrQueryType]
FAISSRetrieverQueriesEmbeddingType = Sequence[FAISSRetrieverEmbeddingQueryType]
@@ -161,7 +163,8 @@ def build_index_from_documents(
If you are using Document format, pass them as [doc.vector for doc in documents]
"""
if document_map_func:
assert callable(document_map_func), "document_map_func should be callable"
assert callable(
document_map_func), "document_map_func should be callable"
documents = [document_map_func(doc) for doc in documents]
try:
self.documents = documents
@@ -183,7 +186,7 @@ def build_index_from_documents(
log.warning(
"Embeddings are not normalized, normalizing the embeddings"
)
self.xb = normalize_np_array(self.xb)
self.xb = normalize_embeddings(self.xb)

self._preprare_faiss_index_from_np_array(self.xb)
log.info(f"Index built with {self.total_documents} chunks")
@@ -295,7 +298,8 @@ def retrieve_string_queries(
output: RetrieverOutputType = [
RetrieverOutput(doc_indices=[], query=query) for query in queries
]
retrieved_output: RetrieverOutputType = self._to_retriever_output(Ind, D)
retrieved_output: RetrieverOutputType = self._to_retriever_output(
Ind, D)

# fill in the doc_indices and score for valid queries
for i, per_query_output in enumerate(retrieved_output):
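In the retriever change above, the relevant flow is: stack the document embeddings, check is_normalized, and if the check fails apply row-wise normalization before handing the matrix to FAISS, since an inner-product index only returns cosine-similarity scores when every vector has unit norm. A hedged sketch of that flow in plain NumPy/FAISS (the IndexFlatIP choice and the helper name are assumptions for illustration, not taken from the file):

import faiss
import numpy as np

def build_ip_index(embeddings: np.ndarray) -> "faiss.IndexFlatIP":
    # FAISS expects float32 input.
    xb = np.asarray(embeddings, dtype=np.float32)
    norms = np.linalg.norm(xb, axis=1, keepdims=True)
    if not np.allclose(norms, 1.0, atol=1e-4):
        # Row-wise normalization so inner product == cosine similarity.
        norms[norms < 1e-12] = 1e-12
        xb = xb / norms
    index = faiss.IndexFlatIP(xb.shape[1])  # assumed index type
    index.add(xb)
    return index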
104 changes: 72 additions & 32 deletions adalflow/adalflow/core/functional.py
@@ -55,7 +55,8 @@ def custom_asdict(
tuples, lists, and dicts.
"""
if not is_dataclass_instance(obj):
raise TypeError("custom_asdict() should be called on dataclass instances")
raise TypeError(
"custom_asdict() should be called on dataclass instances")
return _asdict_inner(obj, dict_factory, exclude or {})


@@ -254,15 +255,18 @@ class TrecDataList:
): # Optional[Address] will be false, and true for each check

log.debug(
f"{is_dataclass(cls)} of {cls}, {is_potential_dataclass(cls)} of {cls}"
f"{is_dataclass(cls)} of {cls}, {
is_potential_dataclass(cls)} of {cls}"
)
# Ensure the data is a dictionary
if not isinstance(data, dict):
raise ValueError(
f"Expected data of type dict for {cls}, but got {type(data).__name__}"
f"Expected data of type dict for {
cls}, but got {type(data).__name__}"
)
cls_type = extract_dataclass_type(cls)
fieldtypes = {f.name: f.type for f in cls_type.__dataclass_fields__.values()}
fieldtypes = {
f.name: f.type for f in cls_type.__dataclass_fields__.values()}

restored_data = cls_type(
**{
@@ -277,11 +281,13 @@ class TrecDataList:
for item in data:
if check_data_class_field_args_zero(cls):
# restore the value to its dataclass type
restored_data.append(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.append(
dataclass_obj_from_dict(cls.__args__[0], item))

elif check_if_class_field_args_zero_exists(cls):
# Use the original data [Any]
restored_data.append(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.append(
dataclass_obj_from_dict(cls.__args__[0], item))

else:
restored_data.append(item)
@@ -293,10 +299,12 @@ class TrecDataList:
for item in data:
if check_data_class_field_args_zero(cls):
# restore the value to its dataclass type
restored_data.add(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.add(
dataclass_obj_from_dict(cls.__args__[0], item))
elif check_if_class_field_args_zero_exists(cls):
# Use the original data [Any]
restored_data.add(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.add(
dataclass_obj_from_dict(cls.__args__[0], item))

else:
# Use the original data [Any]
@@ -319,7 +327,8 @@ class TrecDataList:
return data
# else normal data like int, str, float, etc.
else:
log.debug(f"Not datclass, or list, or dict: {cls}, use the original data.")
log.debug(f"Not datclass, or list, or dict: {
cls}, use the original data.")
return data


@@ -393,7 +402,8 @@ def get_type_schema(
if arg is not type(None)
]
return (
f"Optional[{types[0]}]" if len(types) == 1 else f"Union[{', '.join(types)}]"
f"Optional[{types[0]}]" if len(
types) == 1 else f"Union[{', '.join(types)}]"
)
elif origin in {List, list}:
args = get_args(type_obj)
@@ -414,21 +424,22 @@
elif origin in {Set, set}:
args = get_args(type_obj)
return (
f"Set[{get_type_schema(args[0],exclude, type_var_map)}]" if args else "Set"
f"Set[{get_type_schema(
args[0], exclude, type_var_map)}]" if args else "Set"
)

elif origin is Sequence:
args = get_args(type_obj)
return (
f"Sequence[{get_type_schema(args[0], exclude,type_var_map)}]"
f"Sequence[{get_type_schema(args[0], exclude, type_var_map)}]"
if args
else "Sequence"
)

elif origin in {Tuple, tuple}:
args = get_args(type_obj)
if args:
return f"Tuple[{', '.join(get_type_schema(arg,exclude,type_var_map) for arg in args)}]"
return f"Tuple[{', '.join(get_type_schema(arg, exclude, type_var_map) for arg in args)}]"
return "Tuple"

elif is_dataclass(type_obj):
@@ -496,7 +507,8 @@ def get_dataclass_schema(
# prepare field schema, it weill be done recursively for nested dataclasses

field_type = type_var_map.get(f.type, f.type)
field_schema = {"type": get_type_schema(field_type, exclude, type_var_map)}
field_schema = {"type": get_type_schema(
field_type, exclude, type_var_map)}

# check required field
is_required = _is_required_field(f)
@@ -588,7 +600,8 @@ def example_function(x: int, y: str = "default") -> int:
param_type = type_hints.get(param_name, "Any")
if parameter.default == Parameter.empty:
schema["required"].append(param_name)
schema["properties"][param_name] = {"type": get_type_schema(param_type)}
schema["properties"][param_name] = {
"type": get_type_schema(param_type)}
else:
schema["properties"][param_name] = {
"type": get_type_schema(param_type),
@@ -659,7 +672,8 @@ def evaluate_ast_node(node: ast.AST, context_map: Dict[str, Any] = None):
return output_fun
# TODO: raise the error back to the caller so that the llm can get the error message
except KeyError as e:
log.error(f"Error: {e}, {node.id} does not exist in the context_map.")
log.error(f"Error: {e}, {
node.id} does not exist in the context_map.")
raise ValueError(
f"Error: {e}, {node.id} does not exist in the context_map."
)
@@ -669,7 +683,8 @@ def evaluate_ast_node(node: ast.AST, context_map: Dict[str, Any] = None):

elif isinstance(
node, ast.Call
): # another fun or class as argument and value, e.g. add( multiply(4,5), 3)
# another fun or class as argument and value, e.g. add( multiply(4,5), 3)
):
func = evaluate_ast_node(node.func, context_map)
args = [evaluate_ast_node(arg, context_map) for arg in node.args]
kwargs = {
@@ -712,11 +727,13 @@ def parse_function_call_expr(
if isinstance(tree.body, ast.Call):
# Extract the function name
func_name = (
tree.body.func.id if isinstance(tree.body.func, ast.Name) else None
tree.body.func.id if isinstance(
tree.body.func, ast.Name) else None
)

# Prepare the list of arguments and keyword arguments
args = [evaluate_ast_node(arg, context_map) for arg in tree.body.args]
args = [evaluate_ast_node(arg, context_map)
for arg in tree.body.args]
keywords = {
kw.arg: evaluate_ast_node(kw.value, context_map)
for kw in tree.body.keywords
@@ -889,13 +906,32 @@ def is_normalized(v: VECTOR_TYPE, tol=1e-4) -> bool:
return np.abs(norm - 1) < tol


def normalize_np_array(v: np.ndarray) -> np.ndarray:
# Compute the norm of the vector (assuming v is 1D)
norm = np.linalg.norm(v)
# Normalize the vector
normalized_v = v / norm
# Return the normalized vector
return normalized_v
def normalize_embeddings(v: np.ndarray) -> np.ndarray:
"""
Normalize embeddings to have L2 norm = 1.
Handles both:
- 1D arrays: a single embedding (shape = (d,))
- 2D arrays: multiple embeddings (shape = (N, d))
"""
if v.ndim == 1:
# Single embedding vector
norm = np.linalg.norm(v)
if norm == 0:
norm = 1e-12 # Avoid division by zero
return (v / norm).astype(np.float32)
elif v.ndim == 2:
# Multiple embeddings: row-wise normalization
# norms: shape = (N,1)
norms = np.linalg.norm(v, axis=1, keepdims=True)
# Avoid division by zero for rows that might be zero
norms[norms < 1e-12] = 1e-12
return (v / norms).astype(np.float32)
else:
raise ValueError(
f"normalize_np_array expects 1D or 2D input. Got shape {v.shape}"
)


def normalize_vector(v: VECTOR_TYPE) -> List[float]:
@@ -1086,7 +1122,7 @@ def extract_json_str(text: str, add_missing_right_brace: bool = True) -> str:
"Incomplete JSON object found and add_missing_right_brace is False."
)

return text[start : end + 1]
return text[start: end + 1]


def extract_list_str(text: str, add_missing_right_bracket: bool = True) -> str:
@@ -1137,7 +1173,7 @@ def extract_list_str(text: str, add_missing_right_bracket: bool = True) -> str:
"Incomplete list found and add_missing_right_bracket is False."
)

return text[start : end + 1]
return text[start: end + 1]


def extract_yaml_str(text: str) -> str:
@@ -1222,7 +1258,8 @@ def parse_json_str_to_obj(json_str: str) -> Union[Dict[str, Any], List[Any]]:
return json_obj
except json.JSONDecodeError as e:
log.info(
f"Got invalid JSON object with json.loads. Error: {e}. Got JSON string: {json_str}"
f"Got invalid JSON object with json.loads. Error: {
e}. Got JSON string: {json_str}"
)
# 2nd attemp after fixing the json string
try:
@@ -1246,7 +1283,8 @@ def parse_json_str_to_obj(json_str: str) -> Union[Dict[str, Any], List[Any]]:
return json_obj
except yaml.YAMLError as e:
raise ValueError(
f"Got invalid JSON object with yaml.safe_load. Error: {e}. Got JSON string: {json_str}"
f"Got invalid JSON object with yaml.safe_load. Error: {
e}. Got JSON string: {json_str}"
)


Expand All @@ -1269,7 +1307,8 @@ def random_sample(

if not replace and num_shots > dataset_size:
log.debug(
f"num_shots {num_shots} is larger than the dataset size {dataset_size}"
f"num_shots {num_shots} is larger than the dataset size {
dataset_size}"
)
num_shots = dataset_size

@@ -1282,6 +1321,7 @@
# Normalize weights to sum to 1
weights = weights / weights.sum()

indices = np.random.choice(len(dataset), size=num_shots, replace=replace, p=weights)
indices = np.random.choice(
len(dataset), size=num_shots, replace=replace, p=weights)

return [dataset[i] for i in indices]

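With the new helper, a single embedding and a batch of embeddings are normalized through the same call; a short usage sketch, assuming is_normalized accepts a single 1-D vector as shown in its signature above (the example vectors are made up):

import numpy as np
from adalflow.core.functional import normalize_embeddings, is_normalized

single = np.array([3.0, 4.0])               # shape (d,)
batch = np.array([[3.0, 4.0], [1.0, 0.0]])  # shape (N, d)

print(normalize_embeddings(single))                 # [0.6 0.8]
print(normalize_embeddings(batch))                  # each row has unit L2 norm
print(is_normalized(normalize_embeddings(single)))  # True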