Skip to content

Commit

Permalink
Fix type hints and imports in agent_bean module
Browse files: browse the repository at this point in the history
  • Loading branch information
kolergy committed Mar 17, 2024
1 parent b5ea1c9 commit ae3d913
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion agent_bean/agent_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_available_functions(self) -> List[str]:
return self.functions_str_list


def get_special_tokens(self, model_name:str) -> [dict]:
def get_special_tokens(self, model_name:str) -> Dict:
"""get the special tokens used by the model"""
keys = ["model_sys_delim", "model_usr_delim"]
out = {k:self.setup['models_list'][model_name][k] for k in keys }
Expand Down
7 changes: 4 additions & 3 deletions agent_bean/models_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import torch
import tiktoken

from typing import List
from langchain_community.chat_models import ChatOpenAI
from agent_bean.system_info import SystemInfo
from agent_bean.transformers_model import TfModel , TransformersEmbeddings
from agent_bean.ollama_model import OllamaModel #, OllamaEmbeddings
from agent_bean.ollama_model import OllamaModel , OllamaEmbeddings
from agent_bean.mistral_model import MistralModel, MistralEmbeddings
from transformers import GenerationConfig
#from agent_bean.google_vertexai_model import VertexAIModel, VertexAIEmbeddings
Expand Down Expand Up @@ -67,7 +68,7 @@ def setup_update(self, setup: dict) -> None:
for m in self.active_models.keys():
self.deinstantiate_model(m)

def get_available_models(self) -> [str]:
def get_available_models(self) -> List[str]:
"""Return a list of available model names."""
return list(self.setup['models_list'].keys())

Expand Down Expand Up @@ -177,7 +178,7 @@ def predict(self, model_name:str, prompt:str ) -> str:
return None


def decode(self, model_name:str, tokens:[float]) -> str:
def decode(self, model_name:str, tokens:List[float]) -> str:
"""Decode a sequence of tokens using a specified model's embeddings.
Args:
Expand Down
16 changes: 8 additions & 8 deletions agent_bean/ollama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from typing import List
from agent_bean.system_info import SystemInfo

#class OllamaEmbeddings:
# """This class wraps the Ollama tokenizer to be used like embeddings"""
# def __init__(self, tokenizer: ollama.OllamaTokenizer) -> None:
# self.tokenizer = tokenizer

# def encode(self, text: str) -> List[int]:
# """Return the token IDs for the text"""
# return self.tokenizer.encode(text)
class OllamaEmbeddings:
"""This class wraps the Ollama tokenizer to be used like embeddings"""
def __init__(self, tokenizer: ollama.embeddings) -> None:
self.tokenizer = tokenizer

def encode(self, text: str) -> List[int]:
"""Return the token IDs for the text"""
return self.tokenizer.encode(text)

# def decode(self, tokens: List[int]) -> str:
# """Return the text for the token IDs"""
Expand Down
12 changes: 6 additions & 6 deletions agent_bean_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def write(self, message):
def flush(self):
self.terminal.flush()
self.log.flush()

def isatty(self):
return False
return False

sys.stdout = Logger("output.log")
sys.stderr = Logger("output.log")
Expand Down Expand Up @@ -109,8 +109,8 @@ def update_ram():
v_ram_used_Gb.append(agent.si.get_v_ram_used() )
time_s.append( time.time() - start_time )

df = pd.DataFrame({'time_s': time_s, 'ram_used_Gb': ram_used_Gb, 'v_ram_used_Gb': v_ram_used_Gb})
df = df.tail(4998) # Limit the number of points gradio crashes at 5000
df = pd.DataFrame({'time_s': time_s, 'ram_used_Gb': ram_used_Gb})
df = df.tail(4999) # Limit the number of points gradio crashes at 5000
update_ram = gr.LinePlot(
value = df,
title = ram_label,
Expand All @@ -125,8 +125,8 @@ def update_ram():
def update_v_ram():
""" Update the v_ram plot """
#print(f"update_v_ram() called, elapsed: {time.time() - start_time:6.2f} s. v_ram_used_Gb: {agent.si.get_v_ram_used():6.2f} Gb. v_ram_total_Gb: {agent.si.get_v_ram_total():6.2f} Gb. v_ram_free_Gb: {agent.si.get_v_ram_free():6.2f} Gb.")
df = pd.DataFrame({'time_s': time_s, 'ram_used_Gb': ram_used_Gb, 'v_ram_used_Gb': v_ram_used_Gb})
df = df.tail(4998) # Limit the number of points gradio crashes at 5000
df = pd.DataFrame({'time_s': time_s, 'v_ram_used_Gb': v_ram_used_Gb})
df = df.tail(4999) # Limit the number of points gradio crashes at 5000
update_v_ram = gr.LinePlot(
value = df,
title = v_ram_label,
Expand Down

0 comments on commit ae3d913

Please sign in to comment.