Skip to content

Commit

Permalink
run linter on repo
Browse files: browse the repository at this point in the history
  • Loading branch information
alexnicita committed Jul 31, 2024
1 parent f116f8c commit 99930ca
Show file tree
Hide file tree
Showing 16 changed files with 242 additions and 166 deletions.
5 changes: 0 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,3 @@ repos:
hooks:
- id: black
language_version: python3.9
- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
hooks:
- id: flake8
language_version: python3.9
8 changes: 4 additions & 4 deletions ai/llm/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from api.polymarket.gamma import GammaMarketClient
from polymarket.agents.ai.llm.prompts import Prompter

class Executor:

class Executor:
def __init__(self):
load_dotenv()
self.prompter = Prompter()
Expand All @@ -26,15 +26,15 @@ def get_llm_response(self, user_input: str) -> str:
result = self.llm.invoke(messages)
return result.content


def get_superforecast(self, event_title: str, market_question: str, outcome: str) -> str:
def get_superforecast(
self, event_title: str, market_question: str, outcome: str
) -> str:
messages = prompts.superforecaster(
event_title=event_title, market_question=market_question, outcome=outcome
)
result = self.llm.invoke(messages)
return result.content


def get_polymarket_llm(self, user_input: str) -> str:
client = GammaMarketClient()
data1 = client.get_current_events()
Expand Down
19 changes: 7 additions & 12 deletions ai/llm/prompts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
class Prompter:

def generate_simple_ai_trader(market_description: str, relevant_info: str) -> str:
return f"""
Expand All @@ -12,14 +11,12 @@ def generate_simple_ai_trader(market_description: str, relevant_info: str) -> st
Do you buy or sell? How much?
"""


def market_analyst(self) -> str:
return f"""
You are a market analyst that takes a description of an event and produces a market forecast.
Assign a probability estimate to the event occurring described by the user
"""


def sentiment_analyzer(question: str, outcome: str) -> float:
return f"""
You are a political scientist trained in media analysis.
Expand All @@ -31,7 +28,6 @@ def sentiment_analyzer(question: str, outcome: str) -> float:
"""


def superforecaster(event_title: str, market_question: str, outcome: str) -> str:
return f"""
You are a Superforecaster tasked with correctly predicting the likelihood of events.
Expand Down Expand Up @@ -66,8 +62,9 @@ def superforecaster(event_title: str, market_question: str, outcome: str) -> str
I believe {market_question} has a likelihood {float} for outcome of {outcome}.
"""


def prompts_polymarket(data1: str, data2: str, market_question: str, outcome: str) -> str:
def prompts_polymarket(
data1: str, data2: str, market_question: str, outcome: str
) -> str:
current_market_data = str(data1)
current_event_data = str(data2)
return f"""
Expand All @@ -84,7 +81,6 @@ def prompts_polymarket(data1: str, data2: str, market_question: str, outcome: st
I believe {market_question} has a likelihood {float} for outcome of {outcome}.
"""


def prompts_polymarket(data1: str, data2: str, user_input: str) -> str:
current_market_data = str(data1)
current_event_data = str(data2)
Expand All @@ -99,11 +95,9 @@ def prompts_polymarket(data1: str, data2: str, user_input: str) -> str:
"""


def routing(system_message: str) -> str:
return f"""You are an expert at routing a user question to the appropriate data source. """


def multiquery(question: str) -> str:
return f"""
You're an AI assistant. Your task is to generate five different versions
Expand All @@ -114,22 +108,23 @@ def multiquery(question: str) -> str:
"""


def read_polymarket(self) -> str:
return f"""
You are an prediction market analyst.
"""


def polymarket_analyst_api(self) -> str:
return f"""You are an AI assistant for analyzing prediction markets.
You will be provided with json output for api data from Polymarket.
Polymarket is an online prediction market that lets users Bet on the outcome of future events in a wide range of topics, like sports, politics, and pop culture.
Get accurate real-time probabilities of the events that matter most to you. """

def filter_events(self) -> str:
return self.polymarket_analyst_api(self) + """
return (
self.polymarket_analyst_api(self)
+ """
Filter events for the ones you will be best at trading on profitably.
"""
)
35 changes: 22 additions & 13 deletions ai/rag/polymarket_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,52 @@
from langchain_community.document_loaders import JSONLoader
from api.polymarket.gamma import GammaMarketClient


class PolymarketRAG:
def __init__(self, local_db_directory=None, embedding_function=None) -> None:
self.gamma_client = GammaMarketClient()
self.local_db_directory = local_db_directory
self.embedding_function = embedding_function

def load_json_from_local(self, json_file_path=None, vector_db_directory='./local_db'):

def load_json_from_local(
self, json_file_path=None, vector_db_directory="./local_db"
):
loader = JSONLoader(
file_path=json_file_path,
jq_schema='.[].description',
text_content=False)
file_path=json_file_path, jq_schema=".[].description", text_content=False
)
loaded_docs = loader.load()

embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
db2 = Chroma.from_documents(loaded_docs, embedding_function, persist_directory=vector_db_directory)
db2 = Chroma.from_documents(
loaded_docs, embedding_function, persist_directory=vector_db_directory
)

return db2

def create_local_markets_rag(self, local_directory='./local_db'):
def create_local_markets_rag(self, local_directory="./local_db"):
all_markets = self.gamma_client.get_all_current_markets()

if not os.path.isdir(local_directory):
os.mkdir(local_directory)

local_file_path = f'{local_directory}/all-current-markets_{time.time()}.json'
local_file_path = f"{local_directory}/all-current-markets_{time.time()}.json"

with open(local_file_path, 'w+') as output_file:
with open(local_file_path, "w+") as output_file:
json.dump(all_markets, output_file)

self.load_json_from_local(json_file_path=local_file_path, vector_db_directory=local_directory)


self.load_json_from_local(
json_file_path=local_file_path, vector_db_directory=local_directory
)

def query_local_markets_rag(self, local_directory=None, query=None):
embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
local_db = Chroma(persist_directory=local_directory, embedding_function=embedding_function)
local_db = Chroma(
persist_directory=local_directory, embedding_function=embedding_function
)
response_docs = local_db.similarity_search_with_score(query=query)
return response_docs


# TODO:
# 1. Pull available markets
# 2. Prompt to find a market to trade on
Expand Down
46 changes: 26 additions & 20 deletions ai/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,64 @@
import json


# TODO: this could be done using regex for sleeker but less readable code.
def parse_camel_case(key):
output = ''
output = ""
for char in key:
if char.isupper():
output += ' '
output += " "
output += char.lower()
else:
output += char
return output


# Experimenting with converting JSON fields into natural language to improve retrieval from the text embedding model
def preprocess_market_object(market_object):
description = market_object['description']
description = market_object["description"]

for k, v in market_object.items():
if k == 'description':
if k == "description":
continue
if isinstance(v, bool):
description += f' This market is{" not" if not v else ""} {parse_camel_case(k)}.'

if k in ['volume', 'liquidity']:
description += (
f' This market is{" not" if not v else ""} {parse_camel_case(k)}.'
)

if k in ["volume", "liquidity"]:
description += f" This market has a current {k} of {v}."
print('\n\ndescription:', description)
print("\n\ndescription:", description)

market_object['description'] = description
market_object["description"] = description

return market_object


def preprocess_local_json(file_path, preprocessor_function):
with open(file_path, 'r+') as open_file:
with open(file_path, "r+") as open_file:
data = json.load(open_file)

output = []
for obj in data:
preprocessed_json = preprocessor_function(obj)
output.append(preprocessed_json)
split_path = file_path.split('.')

split_path = file_path.split(".")
new_file_path = split_path[0] + "_preprocessed." + split_path[1]
with open(new_file_path, 'w+') as output_file:
with open(new_file_path, "w+") as output_file:
json.dump(output, output_file)


# Options for improving search:
# 1. Translate JSON params into natural language
# 2. Metadata function with post-filtering on metadata kv pairs
def metadata_func(record: dict, metadata: dict) -> dict:
print('record:', record)
print('meta:', metadata)
print("record:", record)
print("meta:", metadata)
for k, v in record.items():
metadata[k] = v

del metadata['description']
del metadata['events']

return metadata
del metadata["description"]
del metadata["events"]

return metadata
45 changes: 29 additions & 16 deletions api/polymarket/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from api.polymarket.types import ClobReward
from api.polymarket.types import Tag


class GammaMarketClient:
def __init__(self):
self.gamma_url = "https://gamma-api.polymarket.com"
Expand All @@ -29,9 +30,13 @@ def parse_pydantic_market(self, market_object):

# These two fields below are returned as stringified lists from the api
if "outcomePrices" in market_object:
market_object["outcomePrices"] = json.loads(market_object["outcomePrices"])
market_object["outcomePrices"] = json.loads(
market_object["outcomePrices"]
)
if "clobTokenIds" in market_object:
market_object["clobTokenIds"] = json.loads(market_object["clobTokenIds"])
market_object["clobTokenIds"] = json.loads(
market_object["clobTokenIds"]
)

return Market(**market_object)
except Exception as err:
Expand Down Expand Up @@ -66,15 +71,19 @@ def parse_pydantic_event(self, event_object):
except Exception as err:
print(f"[parse_event] Caught exception: {err}")

def get_markets(self, querystring_params={}, parse_pydantic=False, local_file_path=None):
def get_markets(
self, querystring_params={}, parse_pydantic=False, local_file_path=None
):
if parse_pydantic and local_file_path is not None:
raise Exception('Cannot use "parse_pydantic" and "local_file" params simultaneously.')

raise Exception(
'Cannot use "parse_pydantic" and "local_file" params simultaneously.'
)

response = httpx.get(self.gamma_markets_endpoint, params=querystring_params)
if response.status_code == 200:
data = response.json()
if local_file_path is not None:
with open(local_file_path, 'w+') as out_file:
with open(local_file_path, "w+") as out_file:
json.dump(data, out_file)
elif not parse_pydantic:
return data
Expand All @@ -87,15 +96,19 @@ def get_markets(self, querystring_params={}, parse_pydantic=False, local_file_pa
print(f"Error response returned from api: HTTP {response.status_code}")
raise Exception()

def get_events(self, querystring_params={}, parse_pydantic=False, local_file_path=None):
def get_events(
self, querystring_params={}, parse_pydantic=False, local_file_path=None
):
if parse_pydantic and local_file_path is not None:
raise Exception('Cannot use "parse_pydantic" and "local_file" params simultaneously.')

raise Exception(
'Cannot use "parse_pydantic" and "local_file" params simultaneously.'
)

response = httpx.get(self.gamma_events_endpoint, params=querystring_params)
if response.status_code == 200:
data = response.json()
if local_file_path is not None:
with open(local_file_path, 'w+') as out_file:
with open(local_file_path, "w+") as out_file:
json.dump(data, out_file)
elif not parse_pydantic:
return data
Expand Down Expand Up @@ -128,19 +141,19 @@ def get_all_current_markets(self, limit=100):
all_markets = []
while True:
params = {
'active': True,
'closed': False,
'archived': False,
'limit': limit,
'offset': offset
"active": True,
"closed": False,
"archived": False,
"limit": limit,
"offset": offset,
}
market_batch = self.get_markets(querystring_params=params)
all_markets.extend(market_batch)

if len(market_batch) < limit:
break
offset += limit

return all_markets

def get_current_events(self, limit=4):
Expand Down
Loading

0 comments on commit 99930ca

Please sign in to comment.