Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Luodian authored Jul 11, 2024
1 parent a5c1869 commit 78e0c3b
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions lmms_eval/tasks/vcr_wiki/utils.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,46 @@
import datetime
import json

import os
from difflib import SequenceMatcher as SM
from functools import partial

import evaluate
import numpy as np
import spacy
from nltk.util import ngrams
from spacy.cli import download

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

# Download the English and Chinese models
try:
nlp_en = spacy.load("en_core_web_sm")
except Exception as e:
download("en_core_web_sm")
nlp_en = spacy.load("en_core_web_sm")
from difflib import SequenceMatcher as SM
from functools import partial

try:
nlp_zh = spacy.load("zh_core_web_sm")
except Exception as e:
download("zh_core_web_sm")
nlp_zh = spacy.load("zh_core_web_sm")
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
from loguru import logger as eval_logger

nlp = {"en": nlp_en, "zh": nlp_zh}
rouge = evaluate.load("rouge")
with open(Path(__file__).parent / "_default_template_vcr_yaml", "r") as f:
raw_data = f.readlines()
safe_data = []
for i, line in enumerate(raw_data):
# remove function definition since yaml load cannot handle it
if "!function" not in line:
safe_data.append(line)

from loguru import logger as eval_logger
config = yaml.safe_load("".join(safe_data))

dir_name = os.path.dirname(os.path.abspath(__file__))
# Download the English and Chinese models
if config["load_package"]:
try:
nlp_en = spacy.load("en_core_web_sm")
nlp_zh = spacy.load("zh_core_web_sm")
nlp = {"en": nlp_en, "zh": nlp_zh}
rouge = evaluate.load("rouge")
except Exception as e:
eval_logger.debug(f"Failed to load spacy models: {e}")
download("en_core_web_sm")
nlp_en = spacy.load("en_core_web_sm")
download("zh_core_web_sm")
nlp_zh = spacy.load("zh_core_web_sm")
else:
nlp = {"en": None, "zh": None}
rouge = None
eval_logger.debug("Spacy models not loaded due to load_package is False. Please set load_package to True in the config file to load them.")

aggregate_results_template = {
"max_sim_val": 0,
Expand Down

0 comments on commit 78e0c3b

Please sign in to comment.