Skip to content

Commit

Permalink
Updated nlp.py script to account for the repackaging of the riva.clie…
Browse files Browse the repository at this point in the history
…nt Python module

Signed-off-by: Sven Chilton <[email protected]>
  • Loading branch information
svenchilton committed Dec 13, 2022
1 parent ed7d1c9 commit 63a22f9
Showing 1 changed file with 54 additions and 33 deletions.
87 changes: 54 additions & 33 deletions virtual-assistant-rasa/rasa-weatherbot/riva_local/nlp/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@
# README.md file.
# ==============================================================================

import riva_api.riva_nlp_pb2 as rnlp
import riva_api.riva_nlp_pb2_grpc as rnlp_srv
import riva.client
from riva.client.proto.riva_nlp_pb2 import (
AnalyzeIntentResponse,
NaturalQueryResponse,
TokenClassResponse
)

import grpc
from config import riva_config, rivanlp_config
import requests
import json

channel = grpc.insecure_channel(riva_config["RIVA_SPEECH_API_URL"])
riva_nlp = rnlp_srv.RivaLanguageUnderstandingStub(channel)
auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"])
riva_nlp = riva.client.NLPService(auth)

# The default intent object for specifying the city
SPECIFY_CITY_INTENT = { 'value': "specify_city", 'confidence': 1 }
Expand All @@ -28,7 +32,7 @@
nlu_fallback_threshold = rivanlp_config["NLU_FALLBACK_THRESHOLD"] if "NLU_FALLBACK_THRESHOLD" in rivanlp_config else NLU_FALLBACK_THRESHOLD


def get_intent_rnlp(resp, result):
def get_intent_old(resp, result):
if hasattr(resp, 'intent') and (hasattr(resp, 'domain') and resp.domain.class_name != "nomatch.none"):
result['intent'] = { 'value': resp.intent.class_name, 'confidence': resp.intent.score }
parentclassintent_index = result['intent']['value'].find(".")
Expand All @@ -43,7 +47,7 @@ def set_nlu_fallback_intent(result):
print("[Riva NLU] Intent set to ", NLU_FALLBACK_INTENT)


def get_entities_rnlp(resp, result):
def get_entities(resp, result):
location_found_flag = False
all_entities_class = {}
all_entities = []
Expand Down Expand Up @@ -72,7 +76,7 @@ def get_entities_rnlp(resp, result):
result['entities'].append(entity)
break
return location_found_flag


def get_entities_rcnlp(resp, result):
location_found_flag = False
Expand All @@ -89,54 +93,71 @@ def get_entities_rcnlp(resp, result):


def get_riva_output(text):
# Submit an AnalyzeIntentRequest. We do not provide a domain with the query, so a domain
# Submit an AnalyzeIntent request. We do not provide a domain with the query, so a domain
# classifier is run first, and based on the inferred value from the domain classifier,
# the query is run through the appropriate intent/slot classifier
# Note: the detected domain is also returned in the response.
result = {'intent': None, 'entities': []}
try:
req = rnlp.AnalyzeIntentRequest()
req.query = str(text)
# The <domain_name> is appended to "riva_intent_" to look for a model "riva_intent_<domain_name>"
# So the model "riva_intent_<domain_name>" needs to be preloaded in riva server.
# In this case the domain is weather and the model being used is "riva_intent_weather-misc".
req.options.domain = "weather"
resp = riva_nlp.AnalyzeIntent(req)
options = riva.client.AnalyzeIntentOptions(lang='en-US', domain='weather')

resp: AnalyzeIntentResponse = riva_nlp.analyze_intent(text, options)

except Exception as inst:
# An exception occurred
print("[Riva NLU] ERROR: Error during NLU request: " + str(inst))
return {"error": 'Riva_NLU_Error', "error_message": str(inst)}
get_intent_rnlp(resp, result)
location_found_flag = get_entities_rnlp(resp, result)
if not location_found_flag:
print("[Riva NLU] Error during NLU request")
return {'riva_error': 'riva_error'}
entities = {}
get_intent(resp, entities)
get_slots(resp, entities)
if 'location' not in entities:
if verbose:
print(f"[Riva NLU] Did not find any location in the string: {text}\n"
"[Riva NLU] Checking again using NER model")
try:
req = rnlp.TokenClassRequest()
req.model.model_name = "riva_ner"
req.text.append(text)
resp_ner = riva_nlp.ClassifyTokens(req)
model_name = "riva_ner"
resp_ner: TokenClassResponse = riva_nlp.classify_tokens(text, model_name)
except Exception as inst:
# An exception occurred
print("[Riva NLU] ERROR: Error during NLU request (riva_ner): " + str(inst))
return {"error": 'Riva_NLU_Error', "error_message": str(inst)}
print("[Riva NLU] Error during NLU request (riva_ner)")
return {'riva_error': 'riva_error'}

if verbose:
print(f"[Riva NLU] NER response results: \n {resp_ner.results[0].results}\n")
print("[Riva NLU] Location Entities:")
location_found_flag = get_entities_rcnlp(resp_ner, result)
if location_found_flag:
if not result['intent']:
result['intent'] = SPECIFY_CITY_INTENT
loc_count = 0
for result in resp_ner.results[0].results:
if result.label[0].class_name == "LOC":
if verbose:
print("[Riva NLU] Intent set to ", SPECIFY_CITY_INTENT)
else:
print(f"[Riva NLU] Location found: {result.token}") # Flow unhandled for multiple location input
loc_count += 1
entities['location'] = result.token
if loc_count == 0:
if verbose:
print("[Riva NLU] No location found in string using NER LOC")
set_nlu_fallback_intent(result)
print("[Riva NLU] Checking response domain")
if resp.domain.class_name == "nomatch.none":
# as a final resort try QA API
if enable_qa == "true":
if verbose:
print("[Riva NLU] Checking using QA API")
riva_misty_profile = requests.get(nlp_config["RIVA_MISTY_PROFILE"]).text # Live pull from Cloud
qa_resp = get_qa_answer(riva_misty_profile, text, p_threshold)
if not qa_resp['result'] == '':
if verbose:
print("[Riva NLU] received qa result")
entities['intent'] = 'qa_answer'
entities['answer_span'] = qa_resp['result']
entities['query'] = text
else:
entities['intent'] = 'riva_error'
else:
entities['intent'] = 'riva_error'
if verbose:
print("[Riva NLU] Final Riva NLU Output: ", result)
return result
print("[Riva NLU] This is what entities contain: ", entities)
return entities


def get_intent_and_entities(text):
Expand Down

0 comments on commit 63a22f9

Please sign in to comment.