Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
cleanup function + library to have a clean name / interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 committed Jan 5, 2024
1 parent 6d47661 commit ef24a81
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 63 deletions.
6 changes: 3 additions & 3 deletions examples/fast-api-server/fast_api_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain.embeddings import OpenAIEmbeddings
from langchain.storage import RedisStore
from openassistants.core.assistant import Assistant
from openassistants.functions.crud import OpenAPICRUD, PythonCRUD
from openassistants.functions.crud import OpenAPILibrary, PythonLibrary
from openassistants.utils.langchain_util import LangChainCachedEmbeddings
from openassistants_fastapi import RouteAssistants, create_router

Expand All @@ -15,9 +15,9 @@


# create a library with the custom function
custom_python_lib = PythonCRUD(functions=[find_email_by_name_function])
custom_python_lib = PythonLibrary(functions=[find_email_by_name_function])

openapi_lib = OpenAPICRUD(
openapi_lib = OpenAPILibrary(
spec="https://petstore3.swagger.io/api/v3/openapi.json",
base_url="https://petstore3.swagger.io/api/v3",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from openassistants.functions.base import (
BaseFunction,
FunctionExecutionDependency,
IBaseFunction,
IFunction,
)
from openassistants.functions.utils import AsyncStreamVersion


class IndexFunction(BaseFunction):
type: Literal["IndexFunction"] = "IndexFunction"
functions: Callable[[], Awaitable[List[IBaseFunction]]]
functions: Callable[[], Awaitable[List[IFunction]]]

async def execute(
self, deps: FunctionExecutionDependency
Expand Down
39 changes: 21 additions & 18 deletions packages/openassistants/openassistants/core/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from openassistants.data_models.function_input import FunctionCall, FunctionInputRequest
from openassistants.functions.base import (
FunctionExecutionDependency,
IBaseFunction,
IEntity,
IFunction,
IFunctionLibrary,
)
from openassistants.functions.crud import FunctionCRUD, LocalCRUD, PythonCRUD
from openassistants.functions.crud import LocalFunctionLibrary, PythonLibrary
from openassistants.llm_function_calling.entity_resolution import resolve_entities
from openassistants.llm_function_calling.fallback import perform_general_qa
from openassistants.llm_function_calling.infilling import (
Expand All @@ -36,14 +37,14 @@ class Assistant:
function_summarization: BaseChatModel
function_fallback: BaseChatModel
entity_embedding_model: Embeddings
function_libraries: List[FunctionCRUD]
function_libraries: List[IFunctionLibrary]
scope_description: str

_cached_all_functions: List[IBaseFunction]
_cached_all_functions: List[IFunction]

def __init__(
self,
libraries: List[str | FunctionCRUD],
libraries: List[str | IFunctionLibrary],
function_identification: Optional[BaseChatModel] = None,
function_infilling: Optional[BaseChatModel] = None,
function_summarization: Optional[BaseChatModel] = None,
Expand Down Expand Up @@ -71,12 +72,14 @@ def __init__(
entity_embedding_model or LangChainCachedEmbeddings(OpenAIEmbeddings())
)
self.function_libraries = [
library if isinstance(library, FunctionCRUD) else LocalCRUD(library)
library
if isinstance(library, IFunctionLibrary)
else LocalFunctionLibrary(library)
for library in libraries
]

if add_index:
index_func: IBaseFunction = IndexFunction(
index_func: IFunction = IndexFunction(
id="index",
display_name="List functions",
description=(
Expand All @@ -91,19 +94,19 @@ def __init__(
functions=self.get_all_functions,
)

self.function_libraries.append(PythonCRUD(functions=[index_func]))
self.function_libraries.append(PythonLibrary(functions=[index_func]))

self._cached_all_functions = []

async def get_all_functions(self) -> List[IBaseFunction]:
async def get_all_functions(self) -> List[IFunction]:
if not self._cached_all_functions:
functions = []
functions: List[IFunction] = []
for library in self.function_libraries:
functions.extend(await library.aread_all())
functions.extend(await library.get_all_functions())
self._cached_all_functions = functions
return self._cached_all_functions

async def get_function_by_id(self, function_id: str) -> Optional[IBaseFunction]:
async def get_function_by_id(self, function_id: str) -> Optional[IFunction]:
functions = await self.get_all_functions()
for function in functions:
if function.get_id() == function_id:
Expand All @@ -112,7 +115,7 @@ async def get_function_by_id(self, function_id: str) -> Optional[IBaseFunction]:

async def execute_function(
self,
function: IBaseFunction,
function: IFunction,
func_args: Dict[str, Any],
dependencies: Dict[str, Any],
):
Expand All @@ -138,7 +141,7 @@ async def do_infilling(
self,
dependencies: dict,
message: OpasUserMessage,
selected_function: IBaseFunction,
selected_function: IFunction,
args_json_schema: dict,
entities_info: Dict[str, List[IEntity]],
) -> Tuple[bool, dict]:
Expand Down Expand Up @@ -194,12 +197,12 @@ async def do_infilling(
async def handle_user_plaintext(
self,
message: OpasUserMessage,
all_functions: List[IBaseFunction],
all_functions: List[IFunction],
dependencies: Dict[str, Any],
autorun: bool,
force_select_function: Optional[str],
) -> AsyncStreamVersion[List[OpasMessage]]:
selected_function: Optional[IBaseFunction] = None
selected_function: Optional[IFunction] = None
# perform entity resolution
chat_history: List[OpasMessage] = dependencies.get("chat_history") # type: ignore

Expand Down Expand Up @@ -299,13 +302,13 @@ async def handle_user_plaintext(
async def handle_user_input(
self,
message: OpasUserMessage,
all_functions: List[IBaseFunction],
all_functions: List[IFunction],
dependencies: Dict[str, Any],
) -> AsyncStreamVersion[List[OpasMessage]]:
if message.input_response is None:
raise ValueError("message must have input_response")

selected_function: Optional[IBaseFunction] = None
selected_function: Optional[IFunction] = None

for f in all_functions:
if f.get_id() == message.input_response.name:
Expand Down
6 changes: 3 additions & 3 deletions packages/openassistants/openassistants/eval/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from openassistants.data_models.function_input import FunctionCall
from openassistants.data_models.function_output import DataFrameOutput, TextOutput
from openassistants.functions.base import IBaseFunction
from openassistants.functions.base import IFunction
from openassistants.utils.async_utils import last_value
from pydantic import BaseModel, ConfigDict

Expand Down Expand Up @@ -88,7 +88,7 @@ async def run_function_invocation(
async def get_function(
self,
assistant: Assistant,
) -> IBaseFunction:
) -> IFunction:
base_function = await assistant.get_function_by_id(self.function)
if base_function is None:
raise ValueError("Function not found")
Expand Down Expand Up @@ -159,7 +159,7 @@ class FunctionInteractionResponseNode(BaseModel):
user_input_response: OpasUserMessage
assistant_function_invocation: OpasAssistantMessage
function_response: OpasFunctionMessage
function_spec: IBaseFunction
function_spec: IFunction


class FunctionInteractionResponse(FunctionInteractionResponseNode):
Expand Down
10 changes: 8 additions & 2 deletions packages/openassistants/openassistants/functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_entities(self) -> Sequence[IEntity]:
pass


class IBaseFunction(abc.ABC):
class IFunction(abc.ABC):
@abc.abstractmethod
def get_id(self) -> str:
pass
Expand Down Expand Up @@ -125,7 +125,7 @@ class BaseFunctionParameters(BaseModel):
json_schema: JSONSchema = EMPTY_JSON_SCHEMA


class BaseFunction(IBaseFunction, BaseModel, abc.ABC):
class BaseFunction(IFunction, BaseModel, abc.ABC):
id: str
type: str
display_name: Optional[str] = None
Expand Down Expand Up @@ -157,3 +157,9 @@ def get_parameters_json_schema(self) -> JSONSchema:

async def get_entity_configs(self) -> Mapping[str, IEntityConfig]:
return {}


class IFunctionLibrary(abc.ABC):
@abc.abstractmethod
async def get_all_functions(self) -> Sequence[IFunction]:
pass
43 changes: 19 additions & 24 deletions packages/openassistants/openassistants/functions/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from openassistants.functions.base import (
BaseFunction,
BaseFunctionParameters,
IBaseFunction,
IFunction,
IFunctionLibrary,
)
from openassistants.utils import yaml as yaml_utils
from pydantic import Field, TypeAdapter
Expand All @@ -46,27 +47,32 @@
]


class FunctionCRUD(abc.ABC):
class BaseFileLibrary(IFunctionLibrary, abc.ABC):
@abc.abstractmethod
def read(self, slug: str) -> Optional[IBaseFunction]:
def read(self, slug: str) -> Optional[IFunction]:
pass

@abc.abstractmethod
def list_ids(self) -> List[str]:
pass

async def aread(self, function_id: str) -> Optional[IBaseFunction]:
async def aread(self, function_id: str) -> Optional[IFunction]:
return await run_in_threadpool(self.read, function_id)

async def alist_ids(self) -> List[str]:
return await run_in_threadpool(self.list_ids)

async def aread_all(self) -> List[IBaseFunction]:
async def get_all_functions(self) -> Sequence[IFunction]:
ids = await self.alist_ids()
return await asyncio.gather(*[self.aread(f_id) for f_id in ids]) # type: ignore
funcs: List[IFunction | None] = await asyncio.gather( # type: ignore
*[self.aread(f_id) for f_id in ids]
)
if None in funcs:
raise RuntimeError("Failed to load all functions")
return funcs # type: ignore


class LocalCRUD(FunctionCRUD):
class LocalFunctionLibrary(BaseFileLibrary):
def __init__(self, library_id: str, directory: str = "library"):
self.library_id = library_id
self.directory = Path(directory) / library_id
Expand All @@ -84,32 +90,21 @@ def read(self, function_id: str) -> Optional[BaseFunction]:
except Exception as e:
raise RuntimeError(f"Failed to load: {function_id}") from e

async def aread_all(self) -> List[IBaseFunction]:
ids = self.list_ids()
return [self.read(f_id) for f_id in ids] # type: ignore

def list_ids(self) -> List[str]:
return [
file.stem for file in self.directory.iterdir() if file.suffix == ".yaml"
]


class PythonCRUD(FunctionCRUD):
def __init__(self, functions: Sequence[IBaseFunction]):
class PythonLibrary(IFunctionLibrary):
def __init__(self, functions: Sequence[IFunction]):
self.functions = functions

def read(self, slug: str) -> Optional[IBaseFunction]:
for function in self.functions:
if function.get_id() == slug:
return function

return None

def list_ids(self) -> List[str]:
return [function.get_id() for function in self.functions]
async def get_all_functions(self) -> Sequence[IFunction]:
return self.functions


class OpenAPICRUD(PythonCRUD):
class OpenAPILibrary(PythonLibrary):
openapi: OpenAPISpec

@staticmethod
Expand Down Expand Up @@ -169,6 +164,6 @@ def __init__(self, spec: Union[OpenAPISpec, str], base_url: Optional[str]):
self.openapi.servers[0].url = base_url

openai_functions = openapi_spec_to_openai_fn(self.openapi)
functions = OpenAPICRUD.openai_fns_to_openapi_function(openai_functions)
functions = OpenAPILibrary.openai_fns_to_openapi_function(openai_functions)

super().__init__(functions)
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from langchain.vectorstores.usearch import USearch
from openassistants.data_models.chat_messages import OpasMessage
from openassistants.functions.base import (
IBaseFunction,
IEntity,
IEntityConfig,
IFunction,
)
from openassistants.llm_function_calling.infilling import generate_arguments

Expand Down Expand Up @@ -64,7 +64,7 @@ async def _get_entities(


async def resolve_entities(
function: IBaseFunction,
function: IFunction,
function_infilling_llm: BaseChatModel,
embeddings: Embeddings,
user_query: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from langchain.chat_models.base import BaseChatModel
from langchain.schema.messages import HumanMessage
from openassistants.data_models.chat_messages import OpasMessage
from openassistants.functions.base import IBaseFunction, IEntity
from openassistants.functions.base import IEntity, IFunction
from openassistants.llm_function_calling.utils import (
build_chat_history_prompt,
generate_to_json,
)


async def generate_argument_decisions_schema(function: IBaseFunction):
async def generate_argument_decisions_schema(function: IFunction):
# Start with the base schema
json_schema = function.get_parameters_json_schema()

Expand Down Expand Up @@ -50,7 +50,7 @@ class NestedObject(TypedDict):


async def generate_argument_decisions(
function: IBaseFunction,
function: IFunction,
chat: BaseChatModel,
user_query: str,
chat_history: List[OpasMessage],
Expand Down Expand Up @@ -93,7 +93,7 @@ def entity_to_json_schema_obj(entity: IEntity):


async def generate_arguments(
function: IBaseFunction,
function: IFunction,
chat: BaseChatModel,
user_query: str,
chat_history: List[OpasMessage],
Expand Down
Loading

0 comments on commit ef24a81

Please sign in to comment.