Skip to content

Commit

Permalink
chore: Add experiment field to PredictRequest and handle its absence …
Browse files Browse the repository at this point in the history
…in predict function
  • Loading branch information
amirnd51 committed Aug 19, 2024
1 parent 7ba1ade commit 3d5c030
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
29 changes: 21 additions & 8 deletions python_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def get_experiment(experiment_id: str):

# Fetch all results
rows = cur.fetchall()

print(rows)
if not rows:
raise Exception(f"No experiment found with ID {experiment_id}")

Expand Down Expand Up @@ -307,6 +307,7 @@ class PredictRequest(BaseModel):
model: int
traceLevel: str
config : Optional[dict] = Field(default={})
experiment : Optional[str] = Field(default=None)

@app.post("/predict")
async def predict(request: PredictRequest):
Expand All @@ -332,9 +333,12 @@ async def predict(request: PredictRequest):
has_multi_input=True
config = request.config
# print(inputs[0])
first_input=inputs[0]



experiment_id=request.experiment
# print(experiment_id)


# experiment_id=create_expriement( cur, conn)

trail= get_trial_by_model_and_input( model_id, inputs)
# print(trail)
Expand All @@ -343,14 +347,23 @@ async def predict(request: PredictRequest):
if trail and trail[2] is not None:
# print(trail[2])
experiment_id = trail[0]
cur,conn=get_db_cur_con()
trial_id = trail[1]
return {"experimentId": experiment_id, "trialId": trial_id, "model_id": model_id, "input_url": inputs[0]}

model=get_model_by_id(model_id,cur,conn)
if not experiment_id:
experiment_id=create_expriement(cur, conn)

return {"experimentId": experiment_id, "trialId": trial_id, "model_id": model["name"], "input_url": inputs}
else:
#create a new trial and generate a new uuid experiment
cur,conn=get_db_cur_con()
if not experiment_id:
experiment_id=create_expriement(cur, conn)
# create a new trial and generate a new uuid experiment
# cur,conn=get_db_cur_con()


experiment_id=create_expriement( cur, conn)


trial_id=create_trial( model_id, experiment_id, cur, conn)
create_trial_inputs(trial_id, inputs, cur, conn)
Expand All @@ -369,7 +382,7 @@ async def predict(request: PredictRequest):
message= makePredictMessage(architecture, batch_size, desired_result_modality, gpu, inputs,has_multi_input,context,config, model["name"], trace_level, 0, "localhost:6831")

sendPredictMessage(message,queue_name,trial_id)
return {"experimentId": experiment_id, "trialId": trial_id, "model_id": model["name"]}
return {"experimentId": experiment_id, "trialId": trial_id, "model_id": model["name"],"input_url": inputs}



Expand Down
4 changes: 3 additions & 1 deletion python_api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def create_trial( model_id, experiment_id, cur, conn):
return trial_id

def create_trial_inputs(trial_id, inputs, cur, conn):
cur.execute("INSERT INTO trial_inputs (created_at,updated_at,trial_id, url) VALUES (%s,%s,%s, %s)", ( datetime.now(), datetime.now(),trial_id, json.dumps(inputs)))
cur.execute("SELECT MAX(id) as id FROM trial_inputs")
max_id = int(cur.fetchone()["id"])
cur.execute("INSERT INTO trial_inputs (id,created_at,updated_at,trial_id, url) VALUES (%s,%s,%s,%s, %s)", (max_id+1, datetime.now(), datetime.now(),trial_id, json.dumps(inputs)))
conn.commit()

def create_expriement( cur, conn):
Expand Down

0 comments on commit 3d5c030

Please sign in to comment.