diff --git a/Cargo.lock b/Cargo.lock
index 2ed9cff0d..96083ba76 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1990,6 +1990,7 @@ dependencies = [
  "image",
  "init-tracing-opentelemetry",
  "itertools 0.12.1",
+ "lazy_static",
  "lorax-client",
  "metrics",
  "metrics-exporter-prometheus",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 2b38a0f62..93f69cf34 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -23,6 +23,7 @@ futures = "0.3.26"
 home = "=0.5.9"
 hf-hub = { version = "0.3.0", features = ["tokio"] }
 h2 = "0.3.26"
+lazy_static = "1.4.0"
 lorax-client = { path = "client" }
 metrics = "0.21.0"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 5b39ec37d..3e8aa9db6 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -15,13 +15,15 @@ use futures::stream::StreamExt;
 /// Batching and inference logic
 use itertools::izip;
 use itertools::multizip;
+use lazy_static::lazy_static;
 use lorax_client::{
-    Batch, CachedBatch, ClassifyPredictionList, ClientError, Embedding, GeneratedText, Generation,
-    NextTokens, PreloadedAdapter, ShardedClient,
+    input_chunk, Batch, CachedBatch, ClassifyPredictionList, ClientError, Embedding, GeneratedText,
+    Generation, NextTokens, PreloadedAdapter, ShardedClient,
 };
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
 use nohash_hasher::IntMap;
+use regex::Regex;
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
 use std::sync::{
@@ -36,6 +38,261 @@ use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Span};
 
+lazy_static! {
+    static ref MIN_LENGTH_PER_ENTITY_RUST: HashMap<&'static str, usize> = {
+        let mut m = HashMap::new();
+        m.insert("merchant", 2);
+        m.insert("url", 5);
+        m.insert("product", 5);
+        m.insert("person", 5);
+        m.insert("payment_platform", 2);
+        m.insert("store number", 5); // Note: this key contains a space, unlike the other snake_case keys
+        m.insert("location", 2);
+        m.insert("account_number", 5);
+        m.insert("phone", 7);
+        m.insert("date", 5);
+        m.insert("tx_acronym", 3);
+        m
+    };
+
+    static ref POSSIBLE_ENTITY_TYPES_RUST: HashSet<String> =
+        MIN_LENGTH_PER_ENTITY_RUST.keys().map(|s| s.to_string()).collect();
+
+    // URL_PATTERNS
+    static ref URL_NON_ALPHANUMERIC: Regex = Regex::new(r"[^a-z0-9\-_\./]").unwrap();
+    static ref URL_WHITESPACE: Regex = Regex::new(r"\s+").unwrap();
+    static ref URL_HTTP_WWW: Regex = Regex::new(r"https?www\.").unwrap();
+    static ref URL_HTTP_WWW_NO_DOT: Regex = Regex::new(r"https?www").unwrap();
+    static ref URL_HTTP_PREFIX: Regex = Regex::new(r"^https?").unwrap();
+    static ref URL_WWW_PREFIX: Regex = Regex::new(r"^www").unwrap();
+    static ref URL_LEADING_DOT: Regex = Regex::new(r"^\.").unwrap();
+    static ref URL_TRAILING_DOT: Regex = Regex::new(r"\.$").unwrap();
+    static ref URL_TRAILING_SLASH: Regex = Regex::new(r"\/$").unwrap();
+
+    // MERCHANT_PRODUCT_PATTERNS
+    static ref MP_HTTP_WWW: Regex = Regex::new(r"https?www\.").unwrap();
+    static ref MP_HTTP_WWW_NO_DOT: Regex = Regex::new(r"https?www").unwrap();
+    static ref MP_HTTP_PREFIX: Regex = Regex::new(r"^https?").unwrap();
+    static ref MP_WWW_PREFIX: Regex = Regex::new(r"^www").unwrap();
+    static ref MP_POSSESSIVE: Regex = Regex::new(r"'s").unwrap();
+    static ref MP_NON_ALPHANUMERIC: Regex = Regex::new(r"[^a-z0-9 \+]").unwrap();
+    static ref MP_WHITESPACE: Regex = Regex::new(r"\s+").unwrap();
+
+    // MERCHANT_HACK_PATTERN
+    static ref MERCHANT_HACK: Regex = Regex::new(r"peacock ([a-z0-9]{5})").unwrap();
([a-z0-9]{5})").unwrap(); + + // --- Precompiled Patterns for tx_annotation_via_regex --- + static ref EXCLUDE_PATTERNS_RUST: HashMap<&'static str, Regex> = { + let mut m = HashMap::new(); + m.insert("late_payment_fee", Regex::new(r"loan.?advance").unwrap()); + m.insert("annual_fee", Regex::new(r"l.?a.?fit|laf|\*").unwrap()); + m.insert("other_fee", Regex::new(r"planet.?fit|club4fit").unwrap()); + m.insert("round_up", Regex::new(r"saloon|night|rodeo").unwrap()); + m.insert("interest_earned", Regex::new(r"payment.*principal|credit.*card.*payment").unwrap()); + m.insert("investment_transfer", Regex::new(r"robinhood(\b|\W)*(card|road|rd)").unwrap()); + m + }; + + static ref DEBIT_PATTERNS_RUST: Vec<(&'static str, Regex)> = vec![ + ("overdraft_fee", Regex::new(r"(^|\b|\W)(ove?rdra?ft|courtesy\spay|overdrawn)($|\b|\W)").unwrap()), + ("insufficient_funds_fee", Regex::new(r"(^|\b|\W)(sufficient\sfund.?|insufficient\s(fund.?|ch.?.?g.?)|nsf)($|\b|\W)").unwrap()), + ("foreign_transaction_fee", Regex::new(r"(^|\b|\W)(foreign|international|int.?l|intrntl).*\W(fee|cha?r?ge?|surcharge|se?r?vi?ce?.?cha?r?ge?)($|\b|\W)|(^|\b|\W)(fee|cha?r?ge?|surcharge|se?r?vi?ce?.?cha?r?ge?).*(foreign|international|int.?l|intrntl)($|\b|\W)").unwrap()), + ("late_payment_fee", Regex::new(r"(^|\b|\W)(late\scharge|past\sdue)($|\b|\W)|(^|\b|\W)late.*(\b|\W)fee").unwrap()), + ("annual_fee", Regex::new(r"(^|\b|\W)annual.*fee($|\b|\W)").unwrap()), + ("other_fee", Regex::new(r"(^|\b|\W)(fee|fees|feex)($|\b|\W)").unwrap()), + ("round_up", Regex::new(r"(^|\b|\W)(round.?up|save.?your.?cha|your.?change|pocket.?change|make.?cents|roll.?up.?change|edge\sup|(change.?up|good.?cents).?tra?n.?sfe?r)($|\b|\W)").unwrap()), + ("interest_charge", Regex::new(r"(^|\b|\W)interest($|\b|\W)").unwrap()), + ("service_charge", Regex::new(r"(^|\b|\W)(surcharge|se?r?vi?ce?.?cha?r?ge?|atm.?charge|service\s(charge|chg|chrg|charges))($|\b|\W)|(^|\b|\W)charge.?$").unwrap()), + ("atm_withdrawal", Regex::new(r"atm.*(withdrawal|withdrwl|w\/d|withdraw|(\b|\W)w\sd)(\b|\W)").unwrap()), + ("other_withdrawal", Regex::new(r"^(.?mobile.?|payments?|recurring|recur|rec|point.?of.?sale|p.?o.?s|by|received|e?check|checking|chk|dda|cash|card|credit|debit|remote|internet|online|virtual|image|overdraft|uncollected|protection|descriptive|cash|share|electronic|visa|stripe|counter|regular|external|customer|savings|changeup|branch|banking|e|od|at|scan|made.*branch.*|view|(\b|\d|\W)|(withdrawal|withdrwl|w\/d|withdraw|w\sd)){0,}.?(withdrawal|withdrwl|w\/d|withdraw|w\sd).?(.?mobile.?|payments?|recurring|recur|rec|point.?of.?sale|p.?o.?s|by|received|e?check|checking|chk|dda|cash|card|credit|debit|remote|internet|online|virtual|image|overdraft|uncollected|protection|descriptive|cash|share|electronic|visa|stripe|counter|regular|external|customer|savings|changeup|branch|banking|e|od|at|scan|made.*branch.*|view|(\b|\d|\W)|(0[1-9]|1[0-2])/([0-2][0-9]|3[01])|(made|at).*branch.*|(withdrawal|withdrwl|w\/d|withdraw|w\sd)){0,}$").unwrap()), + ("credit_card_payment", Regex::new(r"(card|cc).?(bill)?.?payment|(chase|to).?credit.?ca?rd|payment.*thank|online.*thank|credit.?c.?rd.?auto.?pay|crcardpmt|pa?y?me?n?t.*(\b|\W)cardmember.?serv|cardmember.?serv(\b|\W)").unwrap()), + ("paypal_inst_xfer_or_venmo", Regex::new(r"(^|\b|\W)(fro?m?|to)(\b|\W).*venmo($|\b|\W)|(^|\b|\W)venmo.*(\b|\W)(fro?m?|to)($|\b|\W)|paypal.*xfer|xfer.*paypal|paypal instant").unwrap()), + ("investment_transfer", Regex::new(r"(^|\b|\W)(acorns|credit 
+        ("savings_transfer", Regex::new(r"(^|\b|\W)(fro?m?|to|tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)(\b|\W).*(\b|\W)(savings?|sa?v|autosave)($|\b|\W)|(^|\b|\W)(saving|savings|sa?v|autosave)(\b|\W).*(fro?m?|to|tra?n.?sfe?r|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)($|\b|\W)|(move|save).?your.?pay|autosave|save as you go").unwrap()),
+        ("internal_transfer", Regex::new(r"(^|\b|\W)((fro?m?|to|tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)(\b|\W).*(\b|\W)(checking?|ch?k|shares?|internal|invest|varo|rewards?|between|vault)|manual.?(db|cr)-bkrg|(fro?m?)\s.*(\b|\W)to|personal.*account.*personal.*account)($|\b|\W)|(^|\b|\W)(checking|vault|ch?k)\sto($|\b|\W)|internal.*tra?n.?sfe?r|inter_account").unwrap()),
+        ("other_transfer", Regex::new(r"(^|\b|\W)(tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t|phone.?trans)($|\b|\W)").unwrap()),
+        ("check", Regex::new(r"(^|deposited|checked|cashed|returned|substitute|counter|dda|priority|inclearing|paid|teller)\s?check(\d|\b|\W){0,}$|^check.?$|^check no category|^check\spaid(\d|\b|\W){0,}$").unwrap()),
+        ("daily_cash_adjustment", Regex::new(r"daily cash adjustment").unwrap()),
+        ("monthly_installment_payment", Regex::new(r"monthly installments.*(\b|\W)of(\b|\W)").unwrap()),
+    ];
+
+    static ref CREDIT_PATTERNS_RUST: Vec<(&'static str, Regex)> = vec![
+        ("fee_reversal", Regex::new(r"(^|\b|\W)(fee|ove?rdra?ft|nsf|(in|non|un)?.?sufficient|charge|courtesy\spay|surcharge|se?r?vi?ce?.?cha?r?ge?|past\sdue|atm).*(\b|\W)(rev\W|revers|returne?d?|re?fu?nd|waive.?|rebate|reimb|adj|credit|correct)|(^|\b|\W)(rev\W|revers|returne?d?|refund|waive.?|rebate|reimb|adj|credit).*(\b|\W)(fee|ove?rdra?ft|nsf|sufficient|charge|courtesy\spay|surcharge|svc.?chg|past\sdue)").unwrap()),
+        ("round_up", Regex::new(r"(^|\b|\W)(round.?up|save.?your.?cha|your.?change|pocket.?change|make.?cents|roll.?up.?change)($|\b|\W)").unwrap()),
+        ("cash_back_redemption", Regex::new(r"cash.?(back|redemption)|reward").unwrap()),
+        ("interest_earned", Regex::new(r"interest|(^|\b|\W)(int|apy).?earned").unwrap()),
+        ("dividend", Regex::new(r"(^|\b|\W)dividends?($|\b|\W)").unwrap()),
+        ("cash_advance", Regex::new(r"(^|\b|\W)(pay|cash)(\b|\W).*(\b|\W)advance|instacash").unwrap()),
+        ("paypal_inst_xfer_or_venmo", Regex::new(r"(^|\b|\W)(fro?m?|to)(\b|\W).*venmo($|\b|\W)|(^|\b|\W)venmo.*(\b|\W)(fro?m?|to)($|\b|\W)|paypal.*xfer|xfer.*paypal|paypal instant").unwrap()),
+        ("investment_transfer", Regex::new(r"(^|\b|\W)(acorns|credit builder|coinbasei?n?c?(.com)?|wealth.?front|betterment|qapital|m1(\b|\W)+(finance|payments?)|schwab.?brokerage|(withdrawal|investment).*edward.?jones|edward.?jones.*investment|raymond.?james.*(brokerage|deposit)|crypto.?com.*888.?824.?8817|webull|vanguard(\b|\W)+(buy|sell|investments?|mkt)|merrill.?lynch|pershing.?brokerage|(brokerage|withdrawal).?pershing|robinhood|fidelity.?(brokera?g?e?|fiis|investm?e?n?t?s?))($|\b|\W)|(^|\b|\W)(from|to)\sbrokerage(\b|\W)").unwrap()),
+        ("savings_transfer", Regex::new(r"(^|\b|\W)(fro?m?|to|tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)(\b|\W).*(\b|\W)(savings?|sa?v|autosave)($|\b|\W)|(^|\b|\W)(saving|savings|sa?v|autosave)(\b|\W).*(fro?m?|to|tra?n.?sfe?r|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)($|\b|\W)|(move|save).?your.?pay|autosave|save as you go").unwrap()),
+        ("internal_transfer", Regex::new(r"(^|\b|\W)((fro?m?|to|tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)(\b|\W).*(\b|\W)(checking?|ch?k|shares?|internal|invest|varo|rewards?|between|vault)|manual.?(db|cr)-bkrg|(fro?m?)\s.*(\b|\W)to|personal.*account.*personal.*account)($|\b|\W)|(^|\b|\W)(checking|vault|ch?k)\sto($|\b|\W)|internal.*tra?n.?sfe?r|inter_account").unwrap()),
+        ("other_transfer", Regex::new(r"(^|\b|\W)(tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t|phone.?trans)($|\b|\W)|real.?time.*credit|tran.*ppd").unwrap()),
+        ("atm_deposit", Regex::new(r"(^|\b|\W)atm.*(dep|deposit)($|\b|\W)").unwrap()),
+        ("check_return", Regex::new(r"return.*(check|checks|chk)|(check|checks|chk).*return").unwrap()),
+        ("credit_card_payment", Regex::new(r"directpay|(card|cc).?.?payment|payment.*thank|online.*thank|credit.?c.?rd.?auto.?pay|crcardpmt|(chase|bank.?of.?america)\scredit\sca?rd").unwrap()),
+        ("statement_credit", Regex::new(r"(^|\b|\W)(statement|redemption|tra?ve?l).?credit($|\b|\W)|points.*tra?ve?").unwrap()),
+        ("other_deposit", Regex::new(r"^(.?mobile.?|payments?|recurring|recur|rec|point.?of.?sale|p.?o.?s|by|received|e?check|checking|chk|dda|cash|card|credit|debit|remote|internet|online|virtual|image|overdraft|uncollected|protection|descriptive|cash|share|electronic|visa|stripe|counter|regular|external|customer|savings|changeup|branch|banking|e|od|at|scan|made.*branch.*|view|(\d|\b|\W)|deposit){0,}.?deposit.?(.?mobile.?|payments?|recurring|recur|rec|point.?of.?sale|p.?o.?s|by|received|e?check|checking|chk|dda|cash|card|credit|debit|remote|internet|online|virtual|image|overdraft|uncollected|protection|descriptive|cash|share|electronic|visa|stripe|counter|regular|external|customer|savings|changeup|branch|banking|e|od|at|scan|made.*branch.*|view|(\d|\b|\W)|deposit){0,}$|counter credit").unwrap()),
+        ("gov_payment", Regex::new(r"(^|\b|\W)(tax|tpg products|taxrefunds?)($|\b|\W)|(taxrfd|oag).*ppd|treas 310|state.?of.*(pa?y?me?n?t|refu?n?d?|depo?s?i?t?)|taxrfd|dept\s?of\s?rev|h.?r.?block").unwrap()),
+        ("payroll", Regex::new(r"deposit\sdfas|dfas.?in|payroll|salary|paycheck|(dir|direct)\s?(dep|deposit)|social.?security|(child|edi|ihss|dd).*(\b|\W)ppd\sid(\b|\W)|one@work").unwrap()),
+        ("courtesy_credit", Regex::new(r"(^|\b|\W)courtesy.?credit($|\b|\W)").unwrap()),
+    ];
+}
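Worth noting for reviewers: both pattern lists are scanned top to bottom, first match wins, and a per-category exclude pattern can veto a positive match (see `debits_regex_rust` below). A minimal sketch of that interaction, not part of the diff; the test description is hypothetical and the statics above are assumed in scope:

```rust
#[cfg(test)]
mod pattern_precedence_tests {
    use super::*;

    #[test]
    fn exclude_pattern_vetoes_a_positive_match() {
        // Hypothetical description: a gym charge that merely mentions "annual fee".
        let s = "la fitness annual fee";

        // The positive annual_fee pattern matches...
        let (_, pattern) = DEBIT_PATTERNS_RUST
            .iter()
            .find(|(name, _)| *name == "annual_fee")
            .unwrap();
        assert!(pattern.is_match(s));

        // ...but the annual_fee exclude pattern (l.?a.?fit|laf|\*) also matches,
        // so the scan must skip this category...
        assert!(EXCLUDE_PATTERNS_RUST["annual_fee"].is_match(s));

        // ...and fall through to the next matching category.
        assert_eq!(
            debits_regex_rust("la fitness annual fee"),
            Some("other_fee".to_string())
        );
    }
}
```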
+ ("savings_transfer", Regex::new(r"(^|\b|\W)(fro?m?|to|tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)(\b|\W).*(\b|\W)(savings?|sa?v|autosave)($|\b|\W)|(^|\b|\W)(saving|savings|sa?v|autosave)(\b|\W).*(fro?m?|to|tra?n.?sfe?r|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)($|\b|\W)|(move|save).?your.?pay|autosave|save as you go").unwrap()), + ("internal_transfer", Regex::new(r"(^|\b|\W)((fro?m?|to|tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t)(\b|\W).*(\b|\W)(checking?|ch?k|shares?|internal|invest|varo|rewards?|between|vault)|manual.?(db|cr)-bkrg|(fro?m?)\s.*(\b|\W)to|personal.*account.*personal.*account)($|\b|\W)|(^|\b|\W)(checking|vault|ch?k)\sto($|\b|\W)|internal.*tra?n.?sfe?r|inter_account").unwrap()), + ("other_transfer", Regex::new(r"(^|\b|\W)(tra?n?.?sfe?r?|xfer|tr?s?fr|transf|transferencia|p2p|a2a|v2v|acco?u?n?t.?to.?acco?u?n?t|phone.?trans)($|\b|\W)|real.?time.*credit|tran.*ppd").unwrap()), + ("atm_deposit", Regex::new(r"(^|\b|\W)atm.*(dep|deposit)($|\b|\W)").unwrap()), + ("check_return", Regex::new(r"return.*(check|checks|chk)|(check|checks|chk).*return").unwrap()), + ("credit_card_payment", Regex::new(r"directpay|(card|cc).?.?payment|payment.*thank|online.*thank|credit.?c.?rd.?auto.?pay|crcardpmt|(chase|bank.?of.?america)\scredit\sca?rd").unwrap()), + ("statement_credit", Regex::new(r"(^|\b|\W)(statement|redemption|tra?ve?l).?credit($|\b|\W)|points.*tra?ve?").unwrap()), + ("other_deposit", Regex::new(r"^(.?mobile.?|payments?|recurring|recur|rec|point.?of.?sale|p.?o.?s|by|received|e?check|checking|chk|dda|cash|card|credit|debit|remote|internet|online|virtual|image|overdraft|uncollected|protection|descriptive|cash|share|electronic|visa|stripe|counter|regular|external|customer|savings|changeup|branch|banking|e|od|at|scan|made.*branch.*|view|(\d|\b|\W)|deposit){0,}.?deposit.?(.?mobile.?|payments?|recurring|recur|rec|point.?of.?sale|p.?o.?s|by|received|e?check|checking|chk|dda|cash|card|credit|debit|remote|internet|online|virtual|image|overdraft|uncollected|protection|descriptive|cash|share|electronic|visa|stripe|counter|regular|external|customer|savings|changeup|branch|banking|e|od|at|scan|made.*branch.*|view|(\d|\b|\W)|deposit){0,}$|counter credit").unwrap()), + ("gov_payment", Regex::new(r"(^|\b|\W)(tax|tpg products|taxrefunds?)($|\b|\W)|(taxrfd|oag).*ppd|treas 310|state.?of.*(pa?y?me?n?t|refu?n?d?|depo?s?i?t?)|taxrfd|dept\s?of\s?rev|h.?r.?block").unwrap()), + ("payroll", Regex::new(r"deposit\sdfas|dfas.?in|payroll|salary|paycheck|(dir|direct)\s?(dep|deposit)|social.?security|(child|edi|ihss|dd).*(\b|\W)ppd\sid(\b|\W)|one@work").unwrap()), + ("courtesy_credit", Regex::new(r"(^|\b|\W)courtesy.?credit($|\b|\W)").unwrap()), + ]; +} + +fn postprocess_entity_rust(entity_type: &str, value: &str) -> Option { + if entity_type == "0" { + return None; + } + let mut processed_value = value.to_lowercase().trim().to_string(); + if processed_value.is_empty() { + return None; + } + + match entity_type { + "url" => { + processed_value = URL_NON_ALPHANUMERIC + .replace_all(&processed_value, "") + .to_string(); + processed_value = URL_WHITESPACE.replace_all(&processed_value, "").to_string(); + processed_value = URL_HTTP_WWW.replace_all(&processed_value, "").to_string(); + processed_value = URL_HTTP_WWW_NO_DOT + .replace_all(&processed_value, "") + .to_string(); + processed_value = URL_HTTP_PREFIX + .replace_all(&processed_value, "") + .to_string(); + processed_value = 
+
+fn preprocess_description_rust(s: &str) -> String {
+    s.to_lowercase()
+}
+
+fn debits_regex_rust(description: &str) -> Option<String> {
+    let s = preprocess_description_rust(description);
+    for (name, pattern) in DEBIT_PATTERNS_RUST.iter() {
+        if pattern.is_match(&s) {
+            if let Some(exclude_pattern) = EXCLUDE_PATTERNS_RUST.get(name) {
+                if exclude_pattern.is_match(&s) {
+                    continue;
+                }
+            }
+            return Some(name.to_string());
+        }
+    }
+    None
+}
+
+fn credit_regex_rust(description: &str) -> Option<String> {
+    let s = preprocess_description_rust(description);
+    for (name, pattern) in CREDIT_PATTERNS_RUST.iter() {
+        if pattern.is_match(&s) {
+            if let Some(exclude_pattern) = EXCLUDE_PATTERNS_RUST.get(name) {
+                if exclude_pattern.is_match(&s) {
+                    continue;
+                }
+            }
+            return Some(name.to_string());
+        }
+    }
+    None
+}
+
+fn tx_annotation_via_regex(description: String, amount: f32) -> Option<String> {
+    if amount > 0.0 {
+        debits_regex_rust(&description)
+    } else {
+        credit_regex_rust(&description)
+    }
+}
+
+fn create_linkable_fields_rust(
+    raw_ner: &[Entity],
+    possible_entities: &HashSet<String>,
+    threshold: f32,
+) -> HashMap<String, Option<String>> {
+    let mut accumulated_entities: HashMap<String, Vec<String>> = HashMap::new();
+
+    for entity in raw_ner {
+        if entity.entity_group != "0" && entity.score >= threshold {
+            if let Some(processed_word) =
+                postprocess_entity_rust(&entity.entity_group, &entity.word)
+            {
+                let key = format!("{}_entity", entity.entity_group);
+                accumulated_entities
+                    .entry(key)
+                    .or_default()
+                    .push(processed_word);
+            }
+        }
+    }
+
+    let mut final_linkable_fields: HashMap<String, Option<String>> = HashMap::new();
+    for entity_base_type in possible_entities {
+        let key = format!("{}_entity", entity_base_type);
+        if let Some(words_vec) = accumulated_entities.get_mut(&key) {
+            words_vec.sort_unstable();
+            words_vec.dedup();
+            final_linkable_fields.insert(key, Some(words_vec.join(" ")));
+        } else {
+            final_linkable_fields.insert(key, None);
+        }
+    }
+    final_linkable_fields
+}
+
 #[derive(Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<ChatMessage>,
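A quick sanity check of the sign convention in `tx_annotation_via_regex` (positive amounts are debits, zero or negative amounts are credits). This sketch is not part of the diff and the descriptions are hypothetical:

```rust
#[cfg(test)]
mod annotation_tests {
    use super::*;

    #[test]
    fn amount_sign_selects_the_pattern_list() {
        // Positive amount -> debit patterns.
        assert_eq!(
            tx_annotation_via_regex("ATM WITHDRAWAL 06/12".to_string(), 40.0),
            Some("atm_withdrawal".to_string())
        );
        // Non-positive amount -> credit patterns.
        assert_eq!(
            tx_annotation_via_regex("PAYROLL DEPOSIT".to_string(), -1250.0),
            Some("payroll".to_string())
        );
    }
}
```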
@@ -644,14 +901,20 @@ impl Infer {
             None,
         );
 
-        let inputs = request.inputs.clone();
+        let inputs_clone_for_ner = request.inputs.original_description.clone(); // Clone for aggregate_ner_output_simple
+        let inputs_clone_for_response = request.inputs.original_description.clone(); // Clone for the final response
+        let description_for_annotation = request.inputs.original_description.clone(); // Clone for tx_annotation_via_regex
+
+        let input_descriptions_val = request.inputs.original_description; // Move the original into validation
+        let amount = request.inputs.amount;
+
         let (tokenized_inputs, input_length) = self
             .validation
-            .validate_input(request.inputs, true, None, Some(1))
+            .validate_input(input_descriptions_val, true, None, Some(1))
             .await?;
 
         let valid_request = ValidClassifyRequest {
-            inputs: inputs.clone(),
+            inputs: inputs_clone_for_response, // Pass the clone intended for the response
             tokenized_inputs,
             input_length: input_length as u32,
             adapter: adapter.clone(),
@@ -676,7 +939,7 @@ impl Infer {
         );
 
         // Return values
-        let mut return_entities = None;
+        let mut raw_ner_entities: Option<Vec<Entity>> = None;
         let mut result_start = None;
         let mut result_queued = None;
@@ -706,20 +969,38 @@ impl Infer {
                     id: _,
                 } => {
                     let entities = aggregate_ner_output_simple(
-                        inputs.clone(),
+                        inputs_clone_for_ner.clone(), // Use the clone for NER aggregation
                         predictions,
                         self.tokenizer.clone().unwrap(),
                     );
-                    return_entities = Some(entities);
+                    raw_ner_entities = Some(entities);
                    result_start = Some(start);
                    result_queued = Some(queued);
                }
            }
        }
 
-        if let Some(return_entities) = return_entities {
+        if let Some(entities) = raw_ner_entities {
+            // TODO: Make threshold configurable, e.g., from request.parameters
+            const NER_POSTPROCESSING_THRESHOLD: f32 = 0.8;
+            let linkable_fields = create_linkable_fields_rust(
+                &entities,
+                &POSSIBLE_ENTITY_TYPES_RUST,
+                NER_POSTPROCESSING_THRESHOLD,
+            );
+
+            let annotation = tx_annotation_via_regex(
+                description_for_annotation, // Use the cloned description
+                amount,
+            );
+
+            // Add annotation to linkable_fields
+            let mut updated_linkable_fields = linkable_fields;
+            updated_linkable_fields.insert("annotation".to_string(), annotation);
+
             Ok(InferClassifyResponse {
-                predictions: return_entities,
+                raw_ner: entities,
+                linkable_fields: updated_linkable_fields,
                 queued: result_queued.unwrap(),
                 start: result_start.unwrap(),
             })
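For context, the wire shape this produces, sketched with `serde_json` and illustrative values (field names come from `ClassifyResponse` below; the `Entity` serialization is assumed from its fields, and most `*_entity` keys are omitted here):

```rust
// Illustrative /classify response after this change; values are made up.
use serde_json::json;

fn example_classify_response() -> serde_json::Value {
    json!({
        "raw_ner": [
            { "entity_group": "merchant", "score": 0.97, "word": "netflix", "start": 0, "end": 7 }
        ],
        "linkable_fields": {
            "merchant_entity": "netflix",
            "url_entity": null,   // entity type produced no surviving value
            "annotation": null    // no regex category matched
        }
    })
}
```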
@@ -760,31 +1041,39 @@ impl Infer {
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
 
-        let request_id_map: HashMap<u64, String> = request
+        // Store ClassifyInput to have access to amount later
+        let request_id_to_input_map: HashMap<u64, ClassifyInput> = request
             .inputs
             .iter()
             .enumerate()
-            .map(|(id, input)| (id as u64, input.clone()))
+            .map(|(id, input_item)| (id as u64, input_item.clone())) // input_item is &crate::ClassifyInput
             .collect();
 
         // Call validate_input on every input in the request and await the results
         let futures: Vec<_> = request
             .inputs
             .iter()
-            .map(|input| {
-                self.validation
-                    .validate_input(input.clone(), true, None, Some(1))
+            .map(|input_item| {
+                // input_item is &crate::ClassifyInput
+                self.validation.validate_input(
+                    input_item.original_description.clone(),
+                    true,
+                    None,
+                    Some(1),
+                )
             })
             .collect();
 
         let all_tokenized_inputs = try_join_all(futures).await?;
 
-        for ((id, r_inputs), (tokenized_inputs, input_length)) in
+        for ((id, _r_inputs), (tokenized_inputs, input_length)) in
+            // _r_inputs is &crate::ClassifyInput, not directly used here
             request.inputs.iter().enumerate().zip(all_tokenized_inputs)
         {
-            let inputs = r_inputs.to_string().clone();
+            let input_item_for_valid_request = request_id_to_input_map.get(&(id as u64)).unwrap(); // This is &crate::ClassifyInput
+
             let valid_request = ValidClassifyRequest {
-                inputs,
+                inputs: input_item_for_valid_request.original_description.clone(), // Pass the description string
                 tokenized_inputs,
                 input_length: input_length as u32,
                 adapter: adapter.clone(),
@@ -809,28 +1098,49 @@ impl Infer {
         drop(response_tx); // Close the sending end
 
         // Return values
-
-        let mut all_entities = HashMap::new();
+        let mut all_responses_map = HashMap::new();
 
         let mut stream = UnboundedReceiverStream::new(response_rx);
         while let Some(response) = stream.next().await {
             match response? {
-                // Add prefill tokens
                 InferStreamResponse::Classify {
-                    predictions,
+                    predictions, // This is ClassifyPredictionList
                     start,
                     queued,
                     id,
                 } => {
-                    let request_inputs = request_id_map.get(&id.unwrap()).unwrap().clone();
-                    let entities = aggregate_ner_output_simple(
-                        request_inputs,
-                        predictions.clone(),
+                    let request_id =
+                        id.expect("Classify response in batch missing ID. This is a bug.");
+
+                    let full_input_item = request_id_to_input_map
+                        .get(&request_id)
+                        .expect("Request ID not found in map. This is a bug.");
+
+                    let description_for_processing = full_input_item.original_description.clone();
+                    let amount_for_processing = full_input_item.amount;
+
+                    let raw_ner_entities = aggregate_ner_output_simple(
+                        description_for_processing.clone(),
+                        predictions.clone(), // ClassifyPredictionList
                         self.tokenizer.clone().unwrap(),
                     );
-                    all_entities.insert(
-                        id.unwrap(),
+
+                    const NER_POSTPROCESSING_THRESHOLD: f32 = 0.8;
+                    let mut linkable_fields = create_linkable_fields_rust(
+                        &raw_ner_entities,
+                        &POSSIBLE_ENTITY_TYPES_RUST,
+                        NER_POSTPROCESSING_THRESHOLD,
+                    );
+
+                    let annotation =
+                        tx_annotation_via_regex(description_for_processing, amount_for_processing);
+
+                    linkable_fields.insert("annotation".to_string(), annotation);
+
+                    all_responses_map.insert(
+                        request_id,
                         InferClassifyResponse {
-                            predictions: entities,
+                            raw_ner: raw_ner_entities,
+                            linkable_fields,
                             queued,
                             start,
                         },
@@ -843,21 +1153,21 @@ impl Infer {
                 }
             }
         }
 
-        if all_entities.is_empty() {
+        if all_responses_map.is_empty() {
             let err = InferError::ClassificationFailure;
             metrics::increment_counter!("lorax_request_failure", "err" => "classification_failure");
             tracing::error!("{err}");
             Err(err)
         } else {
-            let mut sorted_responses: Vec<_> = all_entities.into_iter().collect();
-            sorted_responses.sort_by_key(|&(id, _)| id);
+            let mut sorted_responses_vec: Vec<_> = all_responses_map.into_iter().collect();
+            sorted_responses_vec.sort_by_key(|&(id, _)| id);
 
-            let sorted_responses: Vec<InferClassifyResponse> = sorted_responses
+            let final_sorted_responses: Vec<InferClassifyResponse> = sorted_responses_vec
                 .into_iter()
                 .map(|(_, response)| response)
                 .collect();
-            Ok(sorted_responses)
+            Ok(final_sorted_responses)
         }
     }
@@ -1747,7 +2057,8 @@ impl InferError {
 #[derive(Debug)]
 pub(crate) struct InferClassifyResponse {
-    pub(crate) predictions: Vec<Entity>,
+    pub(crate) raw_ner: Vec<Entity>, // Renamed from predictions
+    pub(crate) linkable_fields: HashMap<String, Option<String>>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
 }
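One more hypothetical check for `create_linkable_fields_rust`: every `<type>_entity` key is always present, with an explicit `None` for types that produced nothing. This sketch assumes the `Entity` field layout from `lib.rs` (entity_group, score, word, start, end) and is not part of the diff:

```rust
#[cfg(test)]
mod linkable_fields_tests {
    use super::*;

    #[test]
    fn every_entity_type_gets_a_key() {
        let ner = vec![Entity {
            entity_group: "merchant".to_string(),
            score: 0.95,
            word: "Netflix.com".to_string(),
            start: 0,
            end: 11,
        }];
        let fields = create_linkable_fields_rust(&ner, &POSSIBLE_ENTITY_TYPES_RUST, 0.8);

        // "Netflix.com" is lowercased and stripped of punctuation.
        assert_eq!(fields["merchant_entity"].as_deref(), Some("netflixcom"));
        // Types with no surviving entity are explicit None, not absent.
        assert!(fields["url_entity"].is_none());
        assert_eq!(fields.len(), POSSIBLE_ENTITY_TYPES_RUST.len());
    }
}
```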
diff --git a/router/src/lib.rs b/router/src/lib.rs
index c3cf2cedc..c59b29fca 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -22,6 +22,7 @@ use loader::AdapterLoader;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use server::prepare_chat_input;
+use std::collections::HashMap;
 use utoipa::ToSchema;
 use validation::Validation;
@@ -1210,14 +1211,27 @@ struct CompatEmbedding {
     object: String,
 }
 
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+struct ClassifyInput {
+    original_description: String,
+    amount: f32,
+}
+
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+struct ClassifyParameters {
+    aggregation_strategy: String,
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 struct ClassifyRequest {
-    inputs: String,
+    inputs: ClassifyInput,
+    parameters: Option<ClassifyParameters>,
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 struct BatchClassifyRequest {
-    inputs: Vec<String>,
+    inputs: Vec<ClassifyInput>,
+    parameters: Option<ClassifyParameters>,
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
@@ -1236,6 +1250,17 @@ struct Entity {
     end: usize,
 }
 
+#[derive(Debug, Serialize)]
+struct ClassifyResponse {
+    raw_ner: Vec<Entity>,
+    linkable_fields: HashMap<String, Option<String>>,
+}
+
+#[derive(Debug, Serialize)]
+struct BatchClassifyResponse {
+    responses: Vec<ClassifyResponse>,
+}
+
 impl From<EntityMessage> for Entity {
     fn from(entity: EntityMessage) -> Self {
         Entity {
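For API consumers, the `/classify` request body now nests the description and amount. A hypothetical payload matching the new `ClassifyRequest`/`ClassifyInput` shapes (the description and the `aggregation_strategy` value are illustrative):

```rust
// Hypothetical request payload; deserialization sketch only.
use serde_json::json;

fn example_classify_request() -> serde_json::Value {
    json!({
        "inputs": {
            "original_description": "NETFLIX.COM 866-579-7172 CA",
            "amount": 15.49
        },
        // parameters is optional; the strategy value here is illustrative
        "parameters": { "aggregation_strategy": "simple" }
    })
}
```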
{}", req.inputs); + tracing::debug!("Input: {:?}", req.inputs); let response = infer.classify(req).await?; // Timings @@ -1919,15 +1924,23 @@ async fn classify( "lorax_request_inference_duration", inference_time.as_secs_f64() ); + + // The `responses` Vec is already in the correct structure for serialization. + // Record metrics for each response. metrics::histogram!( "lorax_request_classify_output_count", - response.predictions.len() as f64 + response.raw_ner.len() as f64 ); - tracing::debug!("Output: {:?}", response.predictions); + tracing::debug!("Output: {:?}", response.raw_ner); tracing::info!("Success"); - - Ok((headers, Json(response.predictions))) + Ok(( + headers, + Json(ClassifyResponse { + raw_ner: response.raw_ner, + linkable_fields: response.linkable_fields, + }), + )) } #[utoipa::path( @@ -1936,7 +1949,7 @@ async fn classify( path = "/classify_batch", request_body = BatchClassifyRequest, responses( - (status = 200, description = "Classifications", body = BatchClassifyResponse), + (status = 200, description = "Classifications", body = Vec), (status = 500, description = "Incomplete classification", body = ErrorResponse), ) )] @@ -1944,7 +1957,7 @@ async fn classify( async fn classify_batch( infer: Extension, Json(req): Json, -) -> Result<(HeaderMap, Json>>), (StatusCode, Json)> { +) -> Result<(HeaderMap, Json>), (StatusCode, Json)> { let span = tracing::Span::current(); let start_time = Instant::now(); metrics::increment_counter!("lorax_request_count"); @@ -1955,6 +1968,9 @@ async fn classify_batch( // Timings let now = Instant::now(); let total_time = start_time.elapsed(); + + // These calculations should still work as `queued` and `start` are accessible in Rust, + // even if skipped during serialization. let mut validation_times = Vec::with_capacity(responses.len()); let mut queue_times = Vec::with_capacity(responses.len()); let mut inference_times = Vec::with_capacity(responses.len()); @@ -2010,20 +2026,28 @@ async fn classify_batch( inference_time.as_secs_f64() ); - let batch_entity_vec: Vec> = responses - .into_iter() - .map(|r| { - let entity_vec = r.predictions; - metrics::histogram!( - "lorax_request_classify_output_count", - entity_vec.len() as f64 - ); - entity_vec - }) - .collect(); - tracing::debug!("Output: {:?}", batch_entity_vec); + // The `responses` Vec is already in the correct structure for serialization. + // Record metrics for each response. + for r in &responses { + metrics::histogram!( + "lorax_request_classify_output_count", + r.raw_ner.len() as f64 + ); + } + tracing::debug!("Output: {:?}", responses); tracing::info!("Success"); - Ok((headers, Json(batch_entity_vec))) + Ok(( + headers, + Json( + responses + .into_iter() + .map(|r| ClassifyResponse { + raw_ner: r.raw_ner, + linkable_fields: r.linkable_fields, + }) + .collect(), + ), + )) } /// Tokenize inputs