Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename AI to RAGClient and add compat names #578

Merged
merged 5 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions edgedb/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,30 @@
TYPE_CHECKING = False
if TYPE_CHECKING:
from gel.ai import * # noqa
create_ai = create_rag_client # noqa
EdgeDBAI = RAGClient # noqa
create_async_ai = create_async_rag_client # noqa
AsyncEdgeDBAI = AsyncRAGClient # noqa
AIOptions = RAGOptions # noqa
import gel.ai as _mod
import sys as _sys
_cur = _sys.modules['edgedb.ai']
for _k in vars(_mod):
if not _k.startswith('__') or _k in ('__all__', '__doc__'):
setattr(_cur, _k, getattr(_mod, _k))
_cur.create_ai = _mod.create_rag_client
_cur.EdgeDBAI = _mod.RAGClient
_cur.create_async_ai = _mod.create_async_rag_client
_cur.AsyncEdgeDBAI = _mod.AsyncRAGClient
_cur.AIOptions = _mod.RAGOptions
if hasattr(_cur, '__all__'):
_cur.__all__ = _cur.__all__ + [
'create_ai',
'EdgeDBAI',
'create_async_ai',
'AsyncEdgeDBAI',
'AIOptions',
]
del _cur
del _sys
del _mod
Expand Down
16 changes: 8 additions & 8 deletions gel/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
# limitations under the License.
#

from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext
from .core import create_ai, EdgeDBAI
from .core import create_async_ai, AsyncEdgeDBAI
from .types import RAGOptions, ChatParticipantRole, Prompt, QueryContext
from .core import create_rag_client, RAGClient
from .core import create_async_rag_client, AsyncRAGClient

__all__ = [
"AIOptions",
"RAGOptions",
"ChatParticipantRole",
"Prompt",
"QueryContext",
"create_ai",
"EdgeDBAI",
"create_async_ai",
"AsyncEdgeDBAI",
"create_rag_client",
"RAGClient",
"create_async_rag_client",
"AsyncRAGClient",
]
20 changes: 10 additions & 10 deletions gel/ai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,27 @@
from . import types


def create_ai(client: gel.Client, **kwargs) -> EdgeDBAI:
def create_rag_client(client: gel.Client, **kwargs) -> RAGClient:
client.ensure_connected()
return EdgeDBAI(client, types.AIOptions(**kwargs))
return RAGClient(client, types.RAGOptions(**kwargs))


async def create_async_ai(
async def create_async_rag_client(
client: gel.AsyncIOClient, **kwargs
) -> AsyncEdgeDBAI:
) -> AsyncRAGClient:
await client.ensure_connected()
return AsyncEdgeDBAI(client, types.AIOptions(**kwargs))
return AsyncRAGClient(client, types.RAGOptions(**kwargs))


class BaseEdgeDBAI:
options: types.AIOptions
class BaseRAGClient:
options: types.RAGOptions
context: types.QueryContext
client_cls = NotImplemented

def __init__(
self,
client: typing.Union[gel.Client, gel.AsyncIOClient],
options: types.AIOptions,
options: types.RAGOptions,
**kwargs,
):
pool = client._impl
Expand Down Expand Up @@ -103,7 +103,7 @@ def _make_rag_request(
)


class EdgeDBAI(BaseEdgeDBAI):
class RAGClient(BaseRAGClient):
client: httpx.Client

def _init_client(self, **kwargs):
Expand Down Expand Up @@ -146,7 +146,7 @@ def generate_embeddings(self, *inputs: str, model: str) -> list[float]:
return resp.json()["data"][0]["embedding"]


class AsyncEdgeDBAI(BaseEdgeDBAI):
class AsyncRAGClient(BaseRAGClient):
client: httpx.AsyncClient

def _init_client(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions gel/ai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ class Prompt:


@dc.dataclass
class AIOptions:
class RAGOptions:
model: str
prompt: typing.Optional[Prompt] = None

def derive(self, kwargs):
return AIOptions(**{**dc.asdict(self), **kwargs})
return RAGOptions(**{**dc.asdict(self), **kwargs})


@dc.dataclass
Expand Down
4 changes: 2 additions & 2 deletions tools/gen_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
if __name__ == '__main__':
this = pathlib.Path(__file__)

errors_fn = this.parent.parent / 'edgedb' / 'errors' / '__init__.py'
init_fn = this.parent.parent / 'edgedb' / '__init__.py'
errors_fn = this.parent.parent / 'gel' / 'errors' / '__init__.py'
init_fn = this.parent.parent / 'gel' / '__init__.py'

with open(errors_fn, 'rt') as f:
errors_txt = f.read()
Expand Down
32 changes: 29 additions & 3 deletions tools/make_import_shims.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import os
import sys

MODS = sorted(['gel', 'gel._taskgroup', 'gel._version', 'gel.abstract', 'gel.ai', 'gel.ai.core', 'gel.ai.types', 'gel.asyncio_client', 'gel.base_client', 'gel.blocking_client', 'gel.codegen', 'gel.color', 'gel.con_utils', 'gel.credentials', 'gel.datatypes', 'gel.datatypes.datatypes', 'gel.datatypes.range', 'gel.describe', 'gel.enums', 'gel.errors', 'gel.errors._base', 'gel.errors.tags', 'gel.introspect', 'gel.options', 'gel.pgproto', 'gel.pgproto.pgproto', 'gel.pgproto.types', 'gel.platform', 'gel.protocol', 'gel.protocol.asyncio_proto', 'gel.protocol.blocking_proto', 'gel.protocol.protocol', 'gel.scram', 'gel.scram.saslprep', 'gel.transaction'])

COMPAT = {
'gel.ai': {
'create_ai': 'create_rag_client',
'EdgeDBAI': 'RAGClient',
'create_async_ai': 'create_async_rag_client',
'AsyncEdgeDBAI': 'AsyncRAGClient',
'AIOptions': 'RAGOptions',
},
}


def main():
Expand All @@ -12,7 +19,10 @@ def main():
nmod = 'edgedb' + mod[len('gel'):]
slash_name = nmod.replace('.', '/')
if is_package:
os.mkdir(slash_name)
try:
os.mkdir(slash_name)
except FileExistsError:
pass
fname = slash_name + '/__init__.py'
else:
fname = slash_name + '.py'
Expand All @@ -25,12 +35,28 @@ def main():
TYPE_CHECKING = False
if TYPE_CHECKING:
from {mod} import * # noqa
''')
if mod in COMPAT:
for k, v in COMPAT[mod].items():
f.write(f' {k} = {v} # noqa\n')
f.write(f'''\
import {mod} as _mod
import sys as _sys
_cur = _sys.modules['{nmod}']
for _k in vars(_mod):
if not _k.startswith('__') or _k in ('__all__', '__doc__'):
setattr(_cur, _k, getattr(_mod, _k))
''')
if mod in COMPAT:
for k, v in COMPAT[mod].items():
f.write(f"_cur.{k} = _mod.{v}\n")
f.write(f'''\
if hasattr(_cur, '__all__'):
_cur.__all__ = _cur.__all__ + [
{',\n '.join(repr(k) for k in COMPAT[mod])},
]
''')
f.write(f'''\
del _cur
del _sys
del _mod
Expand Down
Loading