Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: manual linting corrections to pass flake8 #85

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/frontend_service/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ async def get_connector():

async def handle_error_response(response):
if response.status != 200:
return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}"
return (f"Error sending {response.method} request "
f" to {str(response.url)}): {await response.text()}")


async def create_client_session() -> aiohttp.ClientSession:
Expand Down Expand Up @@ -96,14 +97,18 @@ async def init_agent(history: list[messages.BaseMessage]) -> UserAgent:
return_intermediate_steps=True,
)
# Create new prompt template
tool_strings = "\n".join([f"> {tool.name}: {tool.description}" for tool in tools])
tool_strings = "\n".join(
[f"> {tool.name}: {tool.description}" for tool in tools]
)
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = FORMAT_INSTRUCTIONS.format(
tool_names=tool_names,
)
today_date = date.today().strftime("%Y-%m-%d")
today = f"Today is {today_date}."
template = "\n\n".join([PREFIX, tool_strings, format_instructions, SUFFIX, today])
template = "\n\n".join(
[PREFIX, tool_strings, format_instructions, SUFFIX, today]
)
human_message_template = "{input}\n\n{agent_scratchpad}"
prompt = ChatPromptTemplate.from_messages(
[("system", template), ("human", human_message_template)]
Expand All @@ -113,23 +118,27 @@ async def init_agent(history: list[messages.BaseMessage]) -> UserAgent:
return UserAgent(client, agent)


PREFIX = """SFO Airport Assistant helps travelers find their way at the airport.
PREFIX = """SFO Airport Assistant helps travelers find their way at the
airport.

Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to
complex multi-query questions that require passing results from one query to another. As a language model, Assistant is
able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding
conversations and provide responses that are coherent and relevant to the topic at hand.
Assistant is designed to be able to assist with a wide range of tasks, from
answering simple questions to complex multi-query questions that require
passing results from one query to another. As a language model, Assistant is
able to generate human-like text based on the input it receives, allowing it to
engage in natural-sounding conversations and provide responses that are
coherent and relevant to the topic at hand.

Overall, Assistant is a powerful tool that can help answer a wide range of questions pertaining to the San
Francisco Airport. SFO Airport Assistant is here to assist. It currently does not have access to user info.
Overall, Assistant is a powerful tool that can help answer a wide range of
questions pertaining to the San Francisco Airport. SFO Airport Assistant is
here to assist. It currently does not have access to user info.

TOOLS:
------

Assistant has access to the following tools:"""

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name)
and an action_input key (tool input).
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an
action key (tool name) and an action_input key (tool input).

Valid "action" values: "Final Answer" or {tool_names}

Expand Down Expand Up @@ -162,7 +171,8 @@ async def init_agent(history: list[messages.BaseMessage]) -> UserAgent:
```"""

SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate.
If using a tool, reminder to ALWAYS respond with a valid json blob of a single action.
If using a tool, reminder to ALWAYS respond with a valid json blob of a
single action.
Format is Action:```$JSON_BLOB```then Observation:.
Thought:

Expand Down
25 changes: 17 additions & 8 deletions src/frontend_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def lifespan(app: FastAPI):
async def index(request: Request):
"""Render the default template."""
# Agent setup
agent = await get_agent(request.session, user_id_token=None)
await get_agent(request.session, user_id_token=None)
return templates.TemplateResponse(
"index.html",
{
Expand All @@ -79,7 +79,9 @@ async def login_google(
form_data = await request.form()
user_id_token = form_data.get("credential")
if user_id_token is None:
raise HTTPException(status_code=401, detail="No user credentials found")
raise HTTPException(
status_code=401, detail="No user credentials found"
)
# create new request session
_ = await get_agent(request.session, str(user_id_token))
print("Logged in to Google.")
Expand All @@ -97,11 +99,14 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
raise HTTPException(status_code=400, detail="Error: No user query")
if "uuid" not in request.session:
raise HTTPException(
status_code=400, detail="Error: Invoke index handler before start chatting"
status_code=400,
detail="Error: Invoke index handler before start chatting",
)

# Add user message to chat history
request.session["history"].append(message_to_dict(HumanMessage(content=prompt)))
request.session["history"].append(
message_to_dict(HumanMessage(content=prompt))
)
user_agent = await get_agent(request.session, user_id_token=None)
try:
print(prompt)
Expand All @@ -113,7 +118,9 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
)
return markdown(response["output"])
except Exception as err:
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")
raise HTTPException(
status_code=500, detail=f"Error invoking agent: {err}"
)


async def get_agent(session: dict[str, Any], user_id_token: Optional[str]):
Expand All @@ -124,7 +131,9 @@ async def get_agent(session: dict[str, Any], user_id_token: Optional[str]):
if "history" not in session:
session["history"] = messages_to_dict(BASE_HISTORY)
if id not in user_agents:
user_agents[id] = await init_agent(messages_from_dict(session["history"]))
user_agents[id] = await init_agent(
messages_from_dict(session["history"])
)
user_agent = user_agents[id]
if user_id_token is not None:
user_agent.client.headers["User-Id-Token"] = f"Bearer {user_id_token}"
Expand All @@ -136,12 +145,12 @@ async def reset(request: Request):
"""Reset agent"""

if "uuid" not in request.session:
raise HTTPException(status_code=400, detail=f"No session to reset.")
raise HTTPException(status_code=400, detail="No session to reset.")

uuid = request.session["uuid"]
global user_agents
if uuid not in user_agents.keys():
raise HTTPException(status_code=500, detail=f"Current agent not found")
raise HTTPException(status_code=500, detail="Current agent not found")

await user_agents[uuid].client.close()
del user_agents[uuid]
Expand Down
4 changes: 3 additions & 1 deletion src/frontend_service/run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ async def main():
app = init_app(client_id=CLIENT_ID, secret_key=SECRET_KEY)
if app is None:
raise TypeError("app not instantiated")
server = uvicorn.Server(uvicorn.Config(app, host=HOST, port=PORT, log_level="info"))
server = uvicorn.Server(
uvicorn.Config(app, host=HOST, port=PORT, log_level="info")
)
await server.serve()


Expand Down
75 changes: 46 additions & 29 deletions src/frontend_service/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
import google.oauth2.id_token # type: ignore
from google.auth import compute_engine # type: ignore
from google.auth.transport.requests import Request # type: ignore
from langchain.agents.agent import ExceptionTool # type: ignore
from langchain.tools import StructuredTool
from pydantic.v1 import BaseModel, Field

# URL to connect to the backend services
BASE_URL = os.getenv("SERVICE_URL", default="127.0.0.1") # f"{SERVICE_URL}:8080"
BASE_URL = os.getenv(
"SERVICE_URL", default="127.0.0.1"
) # f"{SERVICE_URL}:8080"
SERVICE_ACCOUNT_EMAIL = os.getenv("SERVICE_ACCOUNT_EMAIL", default=None)
CREDENTIALS = None

Expand Down Expand Up @@ -64,7 +65,7 @@ def get_id_token():
def get_headers(client: aiohttp.ClientSession):
"""Helper method to generate ID tokens for authenticated requests"""
headers = client.headers
if not "http://" in BASE_URL:
if "http://" not in BASE_URL:
# Append ID Token to make authenticated requests to Cloud Run services
headers["Authorization"] = f"Bearer {get_id_token()}"
return headers
Expand Down Expand Up @@ -93,10 +94,12 @@ async def search_airports(country: str, city: str, name: str):
num = 2
response_json = await response.json()
if len(response_json) < 1:
return "There are no airports matching that query. Let the user know there are no results."
return ("There are no airports matching that query. Let the user "
"know there are no results.")
elif len(response_json) > num:
return (
f"There are {len(response_json)} airports matching that query. Here are the first {num} results:\n"
f"There are {len(response_json)} airports matching that "
f"query. Here are the first {num} results:\n"
+ " ".join([f"{response_json[i]}" for i in range(num)])
)
else:
Expand Down Expand Up @@ -127,7 +130,9 @@ class ListFlights(BaseModel):
departure_airport: Optional[str] = Field(
description="Departure airport 3-letter code",
)
arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code")
arrival_airport: Optional[str] = Field(
description="Arrival airport 3-letter code"
)
date: Optional[str] = Field(description="Date of flight departure")


Expand All @@ -151,10 +156,12 @@ async def list_flights(
num = 2
response_json = await response.json()
if len(response_json) < 1:
return "There are no flights matching that query. Let the user know there are no results."
return ("There are no flights matching that query. Let the user "
"know there are no results.")
elif len(response_json) > num:
return (
f"There are {len(response_json)} flights matching that query. Here are the first {num} results:\n"
f"There are {len(response_json)} flights matching that query. "
f"Here are the first {num} results:\n"
+ " ".join([f"{response_json[i]}" for i in range(num)])
)
else:
Expand Down Expand Up @@ -240,10 +247,12 @@ async def initialize_tools(client: aiohttp.ClientSession):
coroutine=generate_search_airports(client),
name="Search Airport",
description="""
Use this tool to list all airports matching search criteria.
Takes at least one of country, city, name, or all and returns all matching airports.
The agent can decide to return the results directly to the user.
Input of this tool must be in JSON format and include all three inputs - country, city, name.
Use this tool to list all airports matching search
criteria. Takes at least one of country, city, name, or
all and returns all matching airports. The agent can
decide to return the results directly to the user.
Input of this tool must be in JSON format and include
all three inputs - country, city, name.
Example:
{{
"country": "United States",
Expand All @@ -269,23 +278,28 @@ async def initialize_tools(client: aiohttp.ClientSession):
coroutine=generate_search_flights_by_number(client),
name="Search Flights By Flight Number",
description="""
Use this tool to get info for a specific flight. Do NOT use this tool with a flight id.
Takes an airline and flight number and returns info on the flight.
Do NOT guess an airline or flight number.
A flight number is a code for an airline service consisting of two-character
airline designator and a 1 to 4 digit number ex. OO123, DL 1234, BA 405, AS 3452.
If the tool returns more than one option choose the date closes to today.
Use this tool to get info for a specific flight. Do NOT
use this tool with a flight id. Takes an airline and
flight number and returns info on the flight. Do NOT
guess an airline or flight number. A flight number is a
code for an airline service consisting of two-character
airline designator and a 1 to 4 digit number ex. OO123,
DL 1234, BA 405, AS 3452. If the tool returns more than
one option choose the date closes to today.
""",
args_schema=FlightNumberInput,
),
StructuredTool.from_function(
coroutine=generate_list_flights(client),
name="List Flights",
description="""
Use this tool to list all flights matching search criteria.
Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights.
The agent can decide to return the results directly to the user.
Input of this tool must be in JSON format and include all three inputs - arrival_airport, departure_airport, and date.
Use this tool to list all flights matching search
criteria. Takes an arrival airport, a departure
airport, or both, filters by date and returns all
matching flights. The agent can decide to return the
results directly to the user. Input of this tool must
be in JSON format and include all three inputs -
arrival_airport, departure_airport, and date.
Example:
{{
"departure_airport": "SFO",
Expand All @@ -311,12 +325,14 @@ async def initialize_tools(client: aiohttp.ClientSession):
coroutine=generate_search_amenities(client),
name="Search Amenities",
description="""
Use this tool to search amenities by name or to recommended airport amenities at SFO.
If user provides flight info, use 'Get Flight' and 'Get Flights by Number'
first to get gate info and location.
Only recommend amenities that are returned by this query.
Find amenities close to the user by matching the terminal and then comparing
the gate numbers. Gate number iterate by letter and number, example A1 A2 A3
Use this tool to search amenities by name or to
recommended airport amenities at SFO. If user provides
flight info, use 'Get Flight' and 'Get Flights by
Number' first to get gate info and location. Only
recommend amenities that are returned by this query.
Find amenities close to the user by matching the
terminal and then comparing the gate numbers. Gate
number iterate by letter and number, example A1 A2 A3
B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1.
""",
args_schema=QueryInput,
Expand Down Expand Up @@ -361,7 +377,8 @@ async def initialize_tools(client: aiohttp.ClientSession):
name="List Tickets",
description="""
Use this tool to list a user's flight tickets.
Takes no input and returns a list of current user's flight tickets.
Takes no input and returns a list of current user's
flight tickets.
""",
),
]
2 changes: 0 additions & 2 deletions src/retrieval_service/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Copy link
Contributor Author

@glasnt glasnt Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on previous, at a guess this is what was meant? (Based on a similar pattern in src/retrieval_service/datastore/init.py, datastore/providers/init.py)

Suggested change
# limitations under the License.
# limitations under the License.
from .app import EMBEDDING_MODEL_NAME, init_app, parse_config
__ALL__ = [EMBEDDING_MODEL_NAME, init_app, parse_config]


from .app import EMBEDDING_MODEL_NAME, init_app, parse_config
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was removed because linting, but this breaks the process because of current imports in run_app

17 changes: 12 additions & 5 deletions src/retrieval_service/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ipaddress import IPv4Address, IPv6Address
from typing import Optional

import yaml
import os
from fastapi import FastAPI
from langchain_google_vertexai import VertexAIEmbeddings
Expand All @@ -41,11 +40,17 @@ def parse_config() -> AppConfig:
config["host"] = os.environ.get("APP_HOST", "127.0.0.1")
config["port"] = os.environ.get("APP_PORT", 8080)
config["datastore"] = {}
config["datastore"]["kind"] = os.environ.get("DB_KIND", "cloudsql-postgres")
config["datastore"]["kind"] = os.environ.get(
"DB_KIND", "cloudsql-postgres"
)
config["datastore"]["project"] = os.environ.get("DB_PROJECT", "my-project")
config["datastore"]["region"] = os.environ.get("DB_REGION", "us-central1")
config["datastore"]["instance"] = os.environ.get("DB_INSTANCE", "my-instance")
config["datastore"]["database"] = os.environ.get("DB_NAME", "assistantdemo")
config["datastore"]["instance"] = os.environ.get(
"DB_INSTANCE", "my-instance"
)
config["datastore"]["database"] = os.environ.get(
"DB_NAME", "assistantdemo"
)
config["datastore"]["user"] = os.environ.get("DB_USER", "postgres")
config["datastore"]["password"] = os.environ.get("DB_PASSWORD", "password")
return AppConfig(**config)
Expand All @@ -55,7 +60,9 @@ def parse_config() -> AppConfig:
def gen_init(cfg: AppConfig):
async def initialize_datastore(app: FastAPI):
app.state.datastore = await datastore.create(cfg.datastore)
app.state.embed_service = VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME)
app.state.embed_service = VertexAIEmbeddings(
model_name=EMBEDDING_MODEL_NAME
)
yield
await app.state.datastore.close()

Expand Down
Loading