Skip to content

Commit

Permalink
Merge pull request #87 from ai-cfia/84-Save-pipeline-with-the-inference
Browse files Browse the repository at this point in the history
new pipeline_id column
  • Loading branch information
sylvanie85 authored Jul 29, 2024
2 parents d397464 + 1831051 commit 875212a
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 53 deletions.
19 changes: 13 additions & 6 deletions datastore/Nachet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,18 @@ async def register_inference_result(
"""
try:
trimmed_inference = inference_metadata.build_inference_import(inference_dict)

model_name = inference_dict["models"][0]["name"]
pipeline_id = machine_learning.get_pipeline_id_from_model_name(cursor, model_name)
inference_dict["pipeline_id"] = str(pipeline_id)

inference_id = inference.new_inference(
cursor, trimmed_inference, user_id, picture_id, type
cursor, trimmed_inference, user_id, picture_id, type, pipeline_id
)
nb_object = int(inference_dict["totalBoxes"])
inference_dict["inference_id"] = str(inference_id)


# loop through the boxes
for box_index in range(nb_object):
# TODO: adapt for multiple types of objects
Expand Down Expand Up @@ -343,14 +350,14 @@ async def new_correction_inference_feedback(cursor, inference_dict, type: int =
TODO: doc
"""
try:
if "inference_id" in inference_dict.keys():
inference_id = inference_dict["inference_id"]
if "inferenceId" in inference_dict.keys():
inference_id = inference_dict["inferenceId"]
else:
raise InferenceFeedbackError(
"Error: inference_id not found in the given infence_dict"
)
if "user_id" in inference_dict.keys():
user_id = inference_dict["user_id"]
if "userId" in inference_dict.keys():
user_id = inference_dict["userId"]
if not (user.is_a_user_id(cursor, user_id)):
raise InferenceFeedbackError(
f"Error: user_id {user_id} not found in the database"
Expand All @@ -371,7 +378,7 @@ async def new_correction_inference_feedback(cursor, inference_dict, type: int =
f"Error: Inference {inference_id} is already verified"
)
for object in inference_dict["boxes"]:
box_id = object["box_id"]
box_id = object["boxId"]
seed_name = object["label"]
seed_id = object["classId"]
# flag_seed = False
Expand Down
39 changes: 3 additions & 36 deletions datastore/db/bytebase/schema_nachet_0.0.11.sql
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ IF (EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = 'nache
"is_default" boolean NOT NULL default false,
"data" json NOT NULL
);


Alter table "nachet_0.0.11"."inference" ADD "pipeline_id" uuid REFERENCES "nachet_0.0.11"."pipeline"(id);

CREATE TRIGGER "pipeline_default_trigger" BEFORE insert OR UPDATE ON "nachet_0.0.11"."pipeline"
FOR EACH ROW EXECUTE FUNCTION "nachet_0.0.11".pipeline_default_trigger();

Expand Down Expand Up @@ -166,41 +168,6 @@ FOR EACH ROW
WHEN (NEW.verified = true)
EXECUTE FUNCTION verified_inference();


-- Trigger function for the `inference` table
CREATE OR REPLACE FUNCTION "nachet_0.0.11".update_inference_timestamp()
RETURNS TRIGGER AS $$
BEGIN
NEW.update_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- Trigger for the `inference` table
CREATE TRIGGER inference_update_before
BEFORE UPDATE ON "nachet_0.0.11".inference
FOR EACH ROW
EXECUTE FUNCTION "nachet_0.0.11".update_inference_timestamp();

-- Trigger function for the `object` table
CREATE OR REPLACE FUNCTION "nachet_0.0.11".update_object_timestamp()
RETURNS TRIGGER AS $$
BEGIN
NEW.update_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

updated_at

-- Trigger for the `inference` table
CREATE TRIGGER object_update_before
BEFORE UPDATE ON "nachet_0.0.11".object
FOR EACH ROW
EXECUTE FUNCTION "nachet_0.0.11".update_object_timestamp();



INSERT INTO "nachet_0.0.11".seed(name) VALUES
('Brassica napus'),
('Brassica juncea'),
Expand Down
8 changes: 5 additions & 3 deletions datastore/db/queries/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class InferenceAlreadyVerifiedError(Exception):
"""

def new_inference(cursor, inference, user_id: str, picture_id:str,type):
def new_inference(cursor, inference, user_id: str, picture_id:str,type, pipeline_id:str):
"""
This function uploads a new inference to the database.
Expand All @@ -42,10 +42,11 @@ def new_inference(cursor, inference, user_id: str, picture_id:str,type):
inference(
inference,
picture_id,
user_id
user_id,
pipeline_id
)
VALUES
(%s,%s,%s)
(%s,%s,%s,%s)
RETURNING id
"""
cursor.execute(
Expand All @@ -54,6 +55,7 @@ def new_inference(cursor, inference, user_id: str, picture_id:str,type):
inference,
picture_id,
user_id,
pipeline_id,
),
)
inference_id=cursor.fetchone()[0]
Expand Down
41 changes: 37 additions & 4 deletions datastore/db/queries/machine_learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ class NonExistingTaskEWarning(UserWarning):
pass
class PipelineCreationError(Exception):
pass
class PipelineNotFoundError(Exception):
pass

def new_pipeline(cursor, pipeline,pipeline_name, model_ids, active:bool=False):
"""
Expand Down Expand Up @@ -240,7 +242,40 @@ def new_pipeline_model(cursor,pipeline_id,model_id):
return pipeline_model_id
except(Exception):
raise PipelineCreationError("Error: pipeline model not uploaded")


def get_pipeline_id_from_model_name(cursor, model_name: str):
    """
    This function gets the pipeline id from the model name.

    Parameters:
    - cursor (cursor): The cursor of the database.
    - model_name (str): The name of the model.

    Returns:
    - The UUID of the pipeline the model belongs to.

    Raises:
    - PipelineNotFoundError: if no pipeline is linked to the given model name
      (or any other database error occurs while looking it up).
    """
    try:
        query = """
            SELECT
                pm.pipeline_id
            FROM
                pipeline_model as pm
            LEFT JOIN
                model as m on m.id=pm.model_id
            WHERE
                m.name= %s
                """
        cursor.execute(
            query,
            (
                model_name,
            ),
        )
        # fetchone() returns None when no row matches; subscripting None
        # raises TypeError, which the handler below converts into the
        # domain-specific PipelineNotFoundError.
        pipeline_id = cursor.fetchone()[0]
        return pipeline_id
    except Exception as e:
        # Chain the original exception so the real database failure is
        # preserved in the traceback instead of being silently discarded.
        raise PipelineNotFoundError(
            f"Error: error finding pipeline for model {model_name}"
        ) from e

def new_model(cursor, model,name,endpoint_name,task_id:int):
"""
This function creates a new model in the database.
Expand Down Expand Up @@ -369,10 +404,8 @@ def get_model_id_from_name(cursor,model_name:str):
)
model_id=cursor.fetchone()[0]
return model_id
except(ValueError):
raise NonExistingTaskEWarning(f"Warning: the given model '{model_name}' was not found")
except(Exception):
raise PipelineCreationError("Error: model not found")
raise PipelineNotFoundError(f"Error: model not found for model name : {model_name}")

def get_model_id_from_endpoint(cursor,endpoint_name:str):
"""
Expand Down
16 changes: 12 additions & 4 deletions tests/Nachet/test_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,14 +712,19 @@ def setUp(self):
self.cursor, self.user_id, self.inference, picture_id, model_id
)
)
self.registered_inference["user_id"] = self.user_id
# to match frontend field names :
self.registered_inference["inferenceId"] = self.registered_inference["inference_id"]
self.registered_inference.pop("inference_id")
self.registered_inference["userId"] = self.user_id
self.mock_box = {"topX": 123, "topY": 456, "bottomX": 789, "bottomY": 123}
self.inference_id = self.registered_inference.get("inference_id")
self.inference_id = self.registered_inference.get("inferenceId")
self.boxes_id = []
self.top_id = []
self.unreal_seed_id = Nachet.seed.new_seed(self.cursor, "unreal_seed")
for box in self.registered_inference["boxes"]:
self.boxes_id.append(box["box_id"])
box["boxId"] = box["box_id"]
box.pop("box_id")
self.boxes_id.append(box["boxId"])
self.top_id.append(box["top_id"])
box["classId"] = Nachet.seed.get_seed_id(self.cursor, box["label"])

Expand Down Expand Up @@ -872,7 +877,7 @@ def test_new_correction_inference_feedback_box_edited(self):
)
for box in self.registered_inference["boxes"]:
object_db = Nachet.inference.get_inference_object(
self.cursor, box["box_id"]
self.cursor, box["boxId"]
)
# The new box metadata must be updated
self.assertDictEqual(object_db[1], self.mock_box)
Expand Down Expand Up @@ -950,3 +955,6 @@ def test_new_correction_inference_feedback_unknown_seed(self):
self.assertTrue(validator.is_valid_uuid(str(object_db[4])))
# valid column must be true
self.assertTrue(object_db[6])

if __name__ == "__main__":
unittest.main()

0 comments on commit 875212a

Please sign in to comment.