From 97cf028f177a18590bea6d7a922636e60f93bd25 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 26 Jun 2024 19:07:16 +0800 Subject: [PATCH] fix unchoose file error --- wren-ai-service/eval/data_curation/app.py | 41 +++++++++++++---------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/wren-ai-service/eval/data_curation/app.py b/wren-ai-service/eval/data_curation/app.py index 666b753970..f2bf8a7630 100644 --- a/wren-ai-service/eval/data_curation/app.py +++ b/wren-ai-service/eval/data_curation/app.py @@ -61,27 +61,32 @@ def on_click_generate_question_sql_pairs(llm_client: AsyncClient): def on_click_setup_uploaded_file(): uploaded_file = st.session_state.get("uploaded_mdl_file") - match = re.match( - r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$", - uploaded_file.name, - ) - if not match: - st.error( - f"the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}" + if uploaded_file: + match = re.match( + r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$", + uploaded_file.name, ) - st.stop() + if not match: + st.error( + f"the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}" + ) + st.stop() - data_source = match.group(1) - st.session_state["data_source"] = data_source - st.session_state["mdl_json"] = orjson.loads( - uploaded_file.getvalue().decode("utf-8") - ) + data_source = match.group(1) + st.session_state["data_source"] = data_source + st.session_state["mdl_json"] = orjson.loads( + uploaded_file.getvalue().decode("utf-8") + ) - st.session_state["connection_info"] = { - "project_id": os.getenv("bigquery.project-id"), - "dataset_id": os.getenv("bigquery.dataset-id"), - "credentials": os.getenv("bigquery.credentials-key"), - } + st.session_state["connection_info"] = { + "project_id": os.getenv("bigquery.project-id"), + "dataset_id": os.getenv("bigquery.dataset-id"), + "credentials": os.getenv("bigquery.credentials-key"), + } + else: + st.session_state["data_source"] = None + st.session_state["mdl_json"] = None + st.session_state["connection_info"] = None def on_change_sql(i: int, key: str):