Commit 615a019

Merge branch 'main' into li

liyin2015 committed Jan 28, 2025
2 parents: 231cca5 + a975e85
Showing 13 changed files with 1,347 additions and 860 deletions.

.github/workflows/python-test.yml (4 changes: 2 additions & 2 deletions)

@@ -14,7 +14,7 @@ jobs:
         python-version: ['3.9', '3.10', '3.11', '3.12']

     steps:
-      - uses: actions/checkout@v3 # Updated to the latest version
+      - uses: actions/checkout@v4 # Updated to the latest version
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4 # Updated to the latest version
         with:
@@ -37,7 +37,7 @@ jobs:
           poetry run pytest
       - name: Upload pytest results as an artifact (optional)
-        uses: actions/upload-artifact@v3 # Updated to the latest version
+        uses: actions/upload-artifact@v4 # Updated to the latest version
        if: always() # Always run this step to ensure test results are saved even if previous steps fail
        with:
          name: pytest-results

adalflow/adalflow/components/data_process/text_splitter.py (7 changes: 4 additions & 3 deletions)

@@ -43,7 +43,6 @@
 DEFAULT_CHUNK_SIZE = 800
 DEFAULT_CHUNK_OVERLAP = 200

-tokenizer = Tokenizer()


 class TextSplitter(Component):
@@ -156,6 +155,8 @@ class TextSplitter(Component):
     # Document(id=e7b617b2-3927-4248-afce-ec0fc247ac8b, text='to illustrate.', meta_data=None, vector=[], parent_doc_id=doc1, order=2, score=None)
     """

+    tokenizer = Tokenizer()
+
     def __init__(
         self,
         split_by: Literal["word", "sentence", "page", "passage", "token"] = "word",
@@ -301,7 +302,7 @@ def call(self, documents: DocumentSplitterInputType) -> DocumentSplitterOutputType:
     def _split_text_into_units(self, text: str, separator: str) -> List[str]:
         """Split text based on the specified separator."""
         if self.split_by == "token":
-            splits = tokenizer.encode(text)
+            splits = TextSplitter.tokenizer.encode(text)
         else:
             splits = text.split(separator)
         log.info(f"Text split by '{separator}' into {len(splits)} parts.")
@@ -344,7 +345,7 @@ def _merge_units_to_chunks(
         if self.split_by == "token":
             # decode each chunk here
-            chunks = [tokenizer.decode(chunk) for chunk in chunks]
+            chunks = [TextSplitter.tokenizer.decode(chunk) for chunk in chunks]

         log.info(f"Merged into {len(chunks)} chunks.")
         return chunks
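
The net effect of this diff: Tokenizer moves from a module-level global to a class attribute, so it is instantiated once at class-definition time and shared by every TextSplitter instance. A minimal usage sketch under that change; the import paths and Document fields follow this file's docstring, while the chunk_size/chunk_overlap arguments are assumed from the defaults above:

from adalflow.components.data_process.text_splitter import TextSplitter
from adalflow.core.types import Document

# With split_by="token", the splitter encodes/decodes through the shared
# class-level TextSplitter.tokenizer instead of a module global.
splitter = TextSplitter(split_by="token", chunk_size=256, chunk_overlap=32)
doc = Document(text="Example text to illustrate the splitter.", id="doc1")
for chunk in splitter.call(documents=[doc]):
    print(chunk.order, chunk.text)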

adalflow/adalflow/components/model_client/bedrock_client.py (74 changes: 62 additions & 12 deletions)

@@ -1,7 +1,8 @@
 """AWS Bedrock ModelClient integration."""

+import json
 import os
-from typing import Dict, Optional, Any, Callable
+from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
 import backoff
 import logging

@@ -26,7 +27,6 @@ def get_first_message_content(completion: Dict) -> str:
     r"""When we only need the content of the first message.
     It is the default parser for chat completion."""
     return completion["output"]["message"]["content"][0]["text"]
-    return completion["output"]["message"]["content"][0]["text"]


 __all__ = [
@@ -117,6 +117,7 @@ def __init__(
         self._aws_connection_timeout = aws_connection_timeout
         self._aws_read_timeout = aws_read_timeout

+        self._client = None
         self.session = None
         self.sync_client = self.init_sync_client()
         self.chat_completion_parser = (
@@ -158,16 +159,51 @@ def init_sync_client(self):
     def init_async_client(self):
         raise NotImplementedError("Async call not implemented yet.")

-    def parse_chat_completion(self, completion):
-        log.debug(f"completion: {completion}")
+    def handle_stream_response(self, stream: dict) -> GeneratorType:
+        r"""Handle the stream response from bedrock. Yield the chunks.
+
+        Args:
+            stream (dict): The stream response generator from bedrock.
+
+        Returns:
+            GeneratorType: A generator that yields the chunks from the bedrock stream.
+        """
+        try:
+            stream: GeneratorType = stream["stream"]
+            for chunk in stream:
+                log.debug(f"Raw chunk: {chunk}")
+                yield chunk
+        except Exception as e:
+            log.debug(f"Error in handle_stream_response: {e}")
+            raise
+
+    def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
+        r"""Parse the completion and assign it to the raw_response attribute.
+
+        If the completion is a stream, it is handled by the handle_stream_response
+        method, which returns a generator. Otherwise, the completion is parsed using
+        the get_first_message_content method.
+
+        Args:
+            completion (dict): The completion response from the bedrock API call.
+
+        Returns:
+            GeneratorOutput: A generator output object with the parsed completion. May
+            carry a generator as the raw response if the completion is a stream.
+        """
         try:
-            data = completion["output"]["message"]["content"][0]["text"]
-            usage = self.track_completion_usage(completion)
-            return GeneratorOutput(data=None, usage=usage, raw_response=data)
+            usage = None
+            data = self.chat_completion_parser(completion)
+            if not isinstance(data, GeneratorType):
+                # Streaming completion usage tracking is not implemented.
+                usage = self.track_completion_usage(completion)
+            return GeneratorOutput(
+                data=None, error=None, raw_response=data, usage=usage
+            )
         except Exception as e:
-            log.error(f"Error parsing completion: {e}")
+            log.error(f"Error parsing the completion: {e}")
             return GeneratorOutput(
-                data=None, error=str(e), raw_response=str(completion)
+                data=None, error=str(e), raw_response=json.dumps(completion)
             )

     def track_completion_usage(self, completion: Dict) -> CompletionUsage:
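
Worth noting in the hunk above: parse_chat_completion no longer hard-codes the first-message lookup; it delegates to whatever parser is configured and only computes usage when the parser's result is not a generator. A stripped-down sketch of that dispatch, with illustrative names rather than AdalFlow's:

from typing import Any, Callable, Dict, Generator as GeneratorType


def parse(completion: Dict, parser: Callable[[Dict], Any]) -> Dict:
    usage = None
    data = parser(completion)
    if not isinstance(data, GeneratorType):
        # A non-streaming completion is fully materialized, so usage can be
        # computed up front; a stream would have to be accounted for as it
        # is consumed.
        usage = {"output_chars": len(str(data))}  # toy stand-in for token usage
    return {"raw_response": data, "usage": usage, "error": None}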

@@ -191,6 +227,7 @@ def list_models(self):
             print(f"  Description: {model['description']}")
             print(f"  Provider: {model['provider']}")
             print("")
+
         except Exception as e:
             print(f"Error listing models: {e}")

@@ -222,14 +259,27 @@ def convert_inputs_to_api_kwargs(
             bedrock_runtime_exceptions.ModelErrorException,
             bedrock_runtime_exceptions.ValidationException,
         ),
-        max_time=5,
+        max_time=2,
     )
-    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+    def call(
+        self,
+        api_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> dict:
         """
         kwargs is the combined input and model_kwargs
         """
         if model_type == ModelType.LLM:
-            return self.sync_client.converse(**api_kwargs)
+            if "stream" in api_kwargs and api_kwargs.get("stream", False):
+                log.debug("Streaming call")
+                # stream is not a valid parameter for bedrock
+                api_kwargs.pop("stream", None)
+                self.chat_completion_parser = self.handle_stream_response
+                return self.sync_client.converse_stream(**api_kwargs)
+            else:
+                api_kwargs.pop("stream", None)
+                return self.sync_client.converse(**api_kwargs)
         else:
             raise ValueError(f"model_type {model_type} is not supported")
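
Taken together, the two hunks above let a single client serve both modes: a "stream" flag in api_kwargs routes the request to converse_stream and swaps the parser to handle_stream_response before the flag is popped. A hypothetical end-to-end sketch, assuming the BedrockAPIClient class exported by this module, configured AWS credentials, and a placeholder model id and message shape from the Bedrock Converse API:

from adalflow.components.model_client.bedrock_client import BedrockAPIClient
from adalflow.core.types import ModelType

client = BedrockAPIClient()
api_kwargs = {
    "modelId": "anthropic.claude-3-haiku-20240307-v1:0",  # placeholder model id
    "messages": [{"role": "user", "content": [{"text": "Say hello."}]}],
    "stream": True,  # popped before the boto3 call; selects converse_stream()
}
completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
output = client.parse_chat_completion(completion)
for chunk in output.raw_response:  # a generator when streaming was requested
    print(chunk)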

adalflow/adalflow/components/retriever/faiss_retriever.py (15 changes: 10 additions & 5 deletions)

@@ -1,5 +1,6 @@
 """Semantic search/embedding-based retriever using FAISS."""

+import faiss
 from typing import (
     List,
     Optional,
@@ -29,17 +30,18 @@
 from adalflow.utils.lazy_import import safe_import, OptionalPackages

 safe_import(OptionalPackages.FAISS.value[0], OptionalPackages.FAISS.value[1])
-import faiss

 log = logging.getLogger(__name__)

-FAISSRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray]  # single embedding
+# single embedding
+FAISSRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray]
 FAISSRetrieverDocumentsType = Sequence[FAISSRetrieverDocumentEmbeddingType]

 FAISSRetrieverEmbeddingQueryType = Union[
     List[float], List[List[float]], np.ndarray
 ]  # single embedding or list of embeddings
-FAISSRetrieverQueryType = Union[RetrieverStrQueryType, FAISSRetrieverEmbeddingQueryType]
+FAISSRetrieverQueryType = Union[RetrieverStrQueryType,
+                                FAISSRetrieverEmbeddingQueryType]
 FAISSRetrieverQueriesType = Sequence[FAISSRetrieverQueryType]
 FAISSRetrieverQueriesStrType = Sequence[RetrieverStrQueryType]
 FAISSRetrieverQueriesEmbeddingType = Sequence[FAISSRetrieverEmbeddingQueryType]
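
For context on the import being moved in this hunk: safe_import is AdalFlow's optional-dependency guard, intended to run before the real import so a missing package fails with an actionable message rather than a bare ModuleNotFoundError. A generic sketch of the pattern (not AdalFlow's implementation; the names and install hint are illustrative):

import importlib


def safe_import(module_name: str, install_hint: str):
    # Import an optional dependency, or fail with a message that tells the
    # user how to install it.
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(f"'{module_name}' is required: {install_hint}") from e


faiss = safe_import("faiss", "pip install faiss-cpu")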

@@ -161,7 +163,8 @@ def build_index_from_documents(
         If you are using Document format, pass them as [doc.vector for doc in documents]
         """
         if document_map_func:
-            assert callable(document_map_func), "document_map_func should be callable"
+            assert callable(
+                document_map_func), "document_map_func should be callable"
             documents = [document_map_func(doc) for doc in documents]
         try:
             self.documents = documents

@@ -194,6 +197,7 @@ def build_index_from_documents(
             raise e

     def _convert_cosine_similarity_to_probability(self, D: np.ndarray) -> np.ndarray:
+        D = np.clip(D, -1, 1)
         D = (D + 1) / 2
         D = np.round(D, 3)
         return D
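
The added np.clip guards the cosine-to-probability mapping: floating-point error can leave inner products of normalized vectors slightly outside [-1, 1], which would otherwise map outside [0, 1]. A small worked example of the same three steps:

import numpy as np

D = np.array([[1.0000002, 0.5, -1.0000001]])  # float drift just past +/-1
D = np.clip(D, -1, 1)  # clamp back into the valid cosine range
D = (D + 1) / 2        # map [-1, 1] onto [0, 1]
D = np.round(D, 3)
print(D)  # [[1.   0.75 0.  ]]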

@@ -295,7 +299,8 @@ def retrieve_string_queries(
         output: RetrieverOutputType = [
             RetrieverOutput(doc_indices=[], query=query) for query in queries
         ]
-        retrieved_output: RetrieverOutputType = self._to_retriever_output(Ind, D)
+        retrieved_output: RetrieverOutputType = self._to_retriever_output(
+            Ind, D)

         # fill in the doc_indices and score for valid queries
         for i, per_query_output in enumerate(retrieved_output):