Skip to content

Commit

Permalink
change nli model (#167)
Browse files Browse the repository at this point in the history
* change nli model

* Fix bug in hallucination

---------

Co-authored-by: Shuguang Chen <[email protected]>
  • Loading branch information
cotran2 and nehcgs authored Oct 10, 2024
1 parent 3b7c586 commit f9e3a05
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion arch/src/consts.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5";
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/deberta-base-nli";
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
Expand Down
2 changes: 1 addition & 1 deletion model_server/app/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_embedding_model(


def get_zero_shot_model(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli"),
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
):
print("Loading Zero-shot Model...")

Expand Down
7 changes: 4 additions & 3 deletions model_server/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,18 @@ async def hallucination(req: HallucinationRequest, res: Response):
if "arch_messages" in req.parameters:
req.parameters.pop("arch_messages")

candidate_labels = [f"{k} is {v}" for k, v in req.parameters.items()]
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}

predictions = classifier(
req.prompt,
candidate_labels=candidate_labels,
candidate_labels=list(candidate_labels.keys()),
hypothesis_template="{}",
multi_label=True,
)

params_scores = {
k[0]: s for k, s in zip(req.parameters.items(), predictions["scores"])
candidate_labels[label]: score
for label, score in zip(predictions["labels"], predictions["scores"])
}

logger.info(
Expand Down

0 comments on commit f9e3a05

Please sign in to comment.