Skip to content

Commit

Permalink
Merge pull request #103 from imotai/main
Browse files Browse the repository at this point in the history
add test cases for kernel and og_up
  • Loading branch information
imotai authored Oct 3, 2023
2 parents e265439 + 8faff75 commit 3c2916c
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 32 deletions.
2 changes: 1 addition & 1 deletion chat/src/og_terminal/terminal_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def show_help(console):


def gen_a_random_emoji():
    """Return a random emoji value chosen uniformly from the EMOJI table.

    EMOJI and EMOJI_KEYS are module-level globals defined elsewhere in
    this file; EMOJI_KEYS is assumed to be an indexable sequence of keys
    into the EMOJI mapping — TODO confirm.
    """
    # random.choice is the idiomatic uniform pick and removes the
    # off-by-one risk of hand-rolled randint(0, len(...) - 1) indexing.
    return EMOJI[random.choice(EMOJI_KEYS)]


Expand Down
2 changes: 2 additions & 0 deletions chat/tests/test_chat_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ def test_parse_number():
numbers = parse_numbers(test_text)
assert numbers
assert numbers[0] == "0"


17 changes: 13 additions & 4 deletions kernel/src/og_kernel/server/kernel_rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,13 @@ def __init__(self):
message="api key is required",
details=[],
)

os.makedirs(config["config_root_path"], exist_ok=True)
os.makedirs(config["workspace"], exist_ok=True)
config_root_path = config["config_root_path"]
workspace = config["workspace"]
logger.info(
f"start kernel rpc with config root path {config_root_path} and workspace {workspace}"
)

async def stop(
self, request: kernel_server_pb2.StopKernelRequest, context: ServicerContext
Expand Down Expand Up @@ -142,9 +147,11 @@ async def start(
"""
kernel_name = request.kernel_name if request.kernel_name else "python3"
if kernel_name in self.kms and self.kms[kernel_name]:
logger.warning("the kernel has been started")
logger.warning(
"the request will be ignored for that the kernel has been started"
)
return kernel_server_pb2.StartKernelResponse(
code=1, msg="the kernel has been started"
code=0, msg="the kernel has been started"
)
logging.info("create a new kernel with kernel_name %s" % kernel_name)
connection_file = "%s/kernel-%s.json" % (
Expand All @@ -165,7 +172,9 @@ async def download(
"""
download file
"""
target_filename = "%s/%s" % (config["workspace"], request.filename)
filename = request.filename
workspace = config["workspace"]
        target_filename = f"{workspace}{os.sep}{filename}"
if not await aio_os.path.exists(target_filename):
await context.abort(10, "%s filename do not exist" % request.filename)
async with aiofiles.open(target_filename, "rb") as afp:
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ click==8.1.7
discord.py==2.3.2
openai==0.28.1
build==1.0.3
python-dotenv==1.0.0

29 changes: 3 additions & 26 deletions sdk/src/og_sdk/agent_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from og_proto.agent_server_pb2_grpc import AgentServerStub
import aiofiles
from typing import AsyncIterable
from .utils import generate_chunk, generate_async_chunk

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,19 +85,9 @@ def upload_file(self, filepath, filename):
"""

# TODO limit the file size
def generate_trunk(filepath, filename) -> common_pb2.FileChunk:
try:
with open(filepath, "rb") as fp:
while True:
chunk = fp.read(1024 * 128)
if not chunk:
break
yield common_pb2.FileChunk(buffer=chunk, filename=filename)
except Exception as ex:
logger.error("fail to read file %s" % ex)

return self.stub.upload(
generate_trunk(filepath, filename), metadata=self.metadata
generate_chunk(filepath, filename), metadata=self.metadata
)

def prompt(self, prompt, files=[]):
Expand Down Expand Up @@ -189,22 +180,8 @@ async def upload_file(self, filepath, filename):
"""
upload file to agent
"""

# TODO limit the file size
async def generate_trunk(
filepath, filename
) -> AsyncIterable[common_pb2.FileChunk]:
try:
async with aiofiles.open(filepath, "rb") as afp:
while True:
chunk = await afp.read(1024 * 128)
if not chunk:
break
yield common_pb2.FileChunk(buffer=chunk, filename=filename)
except Exception as ex:
logger.error("fail to read file %s", ex)

return await self.upload_binary(generate_trunk(filepath, filename))
return await self.upload_binary(generate_async_chunk(filepath, filename))

def close(self):
if self.channel:
Expand Down
32 changes: 32 additions & 0 deletions sdk/src/og_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,38 @@
import re
import string
import random
import aiofiles
import logging
from og_proto import agent_server_pb2, common_pb2
from typing import AsyncIterable

logger = logging.getLogger(__name__)


def generate_chunk(filepath, filename) -> common_pb2.FileChunk:
    """Yield the file at *filepath* as a stream of FileChunk messages.

    Reads the file in 128 KiB binary chunks; each chunk is wrapped in a
    common_pb2.FileChunk tagged with *filename* (the name the receiver
    should store the data under).

    NOTE(review): a read error is only logged and the generator then
    stops, so the consumer may receive a truncated stream with no error
    signal — confirm callers tolerate partial uploads.
    """
    try:
        with open(filepath, "rb") as fp:
            while True:
                chunk = fp.read(1024 * 128)
                if not chunk:
                    break
                yield common_pb2.FileChunk(buffer=chunk, filename=filename)
    except Exception as ex:
        # Lazy %-style logging args (matches generate_async_chunk in this
        # same module) instead of eager "%" string formatting.
        logger.error("fail to read file %s", ex)


async def generate_async_chunk(
    filepath, filename
) -> AsyncIterable[common_pb2.FileChunk]:
    """Asynchronously stream the file at *filepath* as FileChunk messages.

    Async twin of generate_chunk: reads the file in 128 KiB binary
    chunks via aiofiles and wraps each one in a common_pb2.FileChunk
    tagged with *filename*. A read error is logged and the stream ends.
    """
    chunk_size = 1024 * 128
    try:
        async with aiofiles.open(filepath, "rb") as afp:
            while chunk := await afp.read(chunk_size):
                yield common_pb2.FileChunk(buffer=chunk, filename=filename)
    except Exception as ex:
        logger.error("fail to read file %s", ex)


def process_char_stream(stream):
Expand Down
1 change: 1 addition & 0 deletions sdk/tests/agent_sdk_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
import json
import random
import logging
from tempfile import gettempdir
from pathlib import Path
from og_sdk.agent_sdk import AgentSDK
Expand Down
32 changes: 32 additions & 0 deletions sdk/tests/kernel_sdk_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
# limitations under the License.

""" """
import os
import asyncio
import pytest
import logging
import json
from og_sdk.kernel_sdk import KernelSDK
from og_sdk.utils import generate_async_chunk
from og_proto.kernel_server_pb2 import ExecuteResponse
import aiofiles
from typing import AsyncIterable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -52,6 +56,34 @@ async def test_bad_sdk(bad_kernel_sdk):
assert True


@pytest.mark.asyncio
async def test_upload_and_download_smoke_test(kernel_sdk):
    """Round-trip smoke test: upload this test file, then download it back.

    Verifies the upload response's reported length and the total number
    of downloaded bytes both match the file's size on disk.
    """
    kernel_sdk.connect()
    path = os.path.abspath(__file__)
    response = await kernel_sdk.upload_binary(
        generate_async_chunk(path, "kernel_sdk_tests.py")
    )
    assert response
    file_stats = os.stat(path)
    assert response.length == file_stats.st_size, "bad upload file size"
    length = 0
    async for chunk in kernel_sdk.download_file("kernel_sdk_tests.py"):
        length += len(chunk.buffer)
    # Message fixed: this assertion checks the *download*, not the upload.
    assert length == file_stats.st_size, "bad download file size"


@pytest.mark.asyncio
async def test_stop_kernel(kernel_sdk):
    """Ensure a (possibly freshly started) kernel can be stopped cleanly."""
    kernel_sdk.connect()
    # connect() must have produced a usable stub before any RPC is made.
    assert kernel_sdk.stub is not None
    if not await kernel_sdk.is_alive():
        await kernel_sdk.start()
    assert await kernel_sdk.is_alive()
    stop_response = await kernel_sdk.stop()
    assert stop_response.code == 0
    assert not await kernel_sdk.is_alive()


@pytest.mark.asyncio
async def test_sdk_smoke_test(kernel_sdk):
kernel_sdk.connect()
Expand Down
1 change: 1 addition & 0 deletions start_sandbox.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ cat <<EOF> ~/.octogen/config
endpoint=127.0.0.1:9528
api_key=${KERNEL_RPC_KEY}
EOF
og_ping
2 changes: 1 addition & 1 deletion up/src/og_up/up.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_latest_release_version(repo_name, live, segments):
refresh(live, segments)
r = requests.get(f"https://api.github.com/repos/{repo_name}/releases/latest")
old_segment = segments.pop()
version = r.json().get("name", "")
version = r.json().get("name", "").strip()
if not version:
segments.append(("❌", "Get octogen latest version failed", version))
else:
Expand Down
73 changes: 73 additions & 0 deletions up/tests/up_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sys
import pytest
import tempfile
import logging
from rich.live import Live
from rich.console import Console
from og_up.up import run_with_realtime_print
Expand All @@ -19,7 +20,79 @@
from og_up.up import get_latest_release_version
from og_up.up import start_octogen_for_codellama
from og_up.up import random_str
from og_up.up import generate_agent_common, generate_agent_azure_openai, generate_agent_openai, generate_agent_codellama
from og_up.up import generate_kernel_env
from rich.console import Group
from dotenv import dotenv_values

logger = logging.getLogger(__name__)


def test_generate_kernel_env():
    """generate_kernel_env must write kernel/.env containing the given
    rpc key and the default rpc port 9527."""
    segments = []
    console = Console()
    workdir = tempfile.mkdtemp(prefix="octogen")
    kernel_rpc_key = "rpc_key"
    with Live(Group(*segments), console=console) as live:
        generate_kernel_env(live, segments, workdir, kernel_rpc_key)
    env = dotenv_values(f"{workdir}/kernel/.env")
    assert env["rpc_key"] == kernel_rpc_key, "bad rpc key"
    assert env["rpc_port"] == "9527", "bad rpc port"


def test_generate_agent_codellama():
    """generate_agent_codellama must write agent/.env pointing the agent
    at a local codellama server with the given admin key."""
    segments = []
    console = Console()
    workdir = tempfile.mkdtemp(prefix="octogen")
    agent_admin_key = "admin_key"
    with Live(Group(*segments), console=console) as live:
        generate_agent_codellama(live, segments, workdir, agent_admin_key)
    env = dotenv_values(f"{workdir}/agent/.env")
    assert env["llm_key"] == "codellama", "bad llm key"
    assert (
        env["llama_api_base"] == "http://127.0.0.1:8080"
    ), "bad codellama server endpoint"
    assert env["admin_key"] == agent_admin_key, "bad admin key"


def test_generate_agent_env_openai():
    """generate_agent_openai must write agent/.env carrying through the
    OpenAI api key, model name and admin key it was given."""
    segments = []
    console = Console()
    workdir = tempfile.mkdtemp(prefix="octogen")
    agent_admin_key = "admin_key"
    api_key = "openai_key"
    model_name = "gpt-4-0613"
    with Live(Group(*segments), console=console) as live:
        generate_agent_openai(
            live, segments, workdir, agent_admin_key, api_key, model_name
        )
    env = dotenv_values(f"{workdir}/agent/.env")
    assert env["llm_key"] == "openai", "bad llm key"
    assert env["openai_api_key"] == api_key, "bad api key"
    assert env["openai_api_model"] == model_name, "bad model"
    assert env["admin_key"] == agent_admin_key, "bad admin key"


def test_generate_agent_env_azure_openai():
    """generate_agent_azure_openai must write agent/.env carrying through
    the Azure OpenAI api base, key, deployment and admin key."""
    segments = []
    console = Console()
    workdir = tempfile.mkdtemp(prefix="octogen")
    agent_admin_key = "admin_key"
    api_key = "openai_key"
    deployment_name = "octogen"
    endpoint = "azure"
    with Live(Group(*segments), console=console) as live:
        generate_agent_azure_openai(
            live, segments, workdir, agent_admin_key, api_key, deployment_name, endpoint
        )
    env = dotenv_values(f"{workdir}/agent/.env")
    assert env["llm_key"] == "azure_openai", "bad llm key"
    assert env["openai_api_base"] == endpoint, "bad api base"
    assert env["openai_api_key"] == api_key, "bad api key"
    assert env["openai_api_deployment"] == deployment_name, "bad deployment"
    assert env["admin_key"] == agent_admin_key, "bad admin key"


def test_run_print():
Expand Down

0 comments on commit 3c2916c

Please sign in to comment.