diff --git a/explainaboard/analysis/feature_funcs.py b/explainaboard/analysis/feature_funcs.py
index 2b5ec275..e89e3f09 100644
--- a/explainaboard/analysis/feature_funcs.py
+++ b/explainaboard/analysis/feature_funcs.py
@@ -11,7 +11,7 @@
 from explainaboard.info import SysOutputInfo
 from explainaboard.utils import basic_words
 from explainaboard.utils.logging import progress
-from explainaboard.utils.tokenizer import Tokenizer
+from explainaboard.utils.tokenizer import SingleSpaceTokenizer, Tokenizer
 from explainaboard.utils.typing_utils import unwrap
 
 
@@ -68,17 +68,10 @@ def get_basic_words(text: str) -> float:
     Returns:
         The ratio of basic words.
     """
-    value_list = text.split(' ')
-    n_words = len(value_list)
-    n_basic_words = 0
-
-    for word in value_list:
-
-        lower = word.lower()
-        if lower in basic_words.BASIC_WORDS:
-            n_basic_words = n_basic_words + 1
-
-    return n_basic_words * 1.0 / n_words
+    tokens = SingleSpaceTokenizer()(text)
+    assert len(tokens) > 0, f"BUG: no tokens obtained from the text: '{text}'"
+    n_basic_words = sum(1 for t in tokens if t.lower() in basic_words.BASIC_WORDS)
+    return n_basic_words / len(tokens)
 
 
 def get_lexical_richness(text: str) -> float:
diff --git a/explainaboard/analysis/feature_funcs_test.py b/explainaboard/analysis/feature_funcs_test.py
new file mode 100644
index 00000000..3e4f8827
--- /dev/null
+++ b/explainaboard/analysis/feature_funcs_test.py
@@ -0,0 +1,35 @@
+"""Tests for explainaboard.analysis.feature_funcs."""
+
+
+import unittest
+
+from explainaboard.analysis.feature_funcs import get_basic_words
+
+
+class FeatureFuncsTest(unittest.TestCase):
+    def test_get_basic_words(self) -> None:
+        # All examples should match exactly (same float division on both sides).
+
+        # zero word
+        self.assertEqual(get_basic_words(""), 0.0)
+        self.assertEqual(get_basic_words(" "), 0.0)
+
+        # one word
+        self.assertEqual(get_basic_words("the"), 1.0)
+        self.assertEqual(get_basic_words(" the"), 0.5)
+        self.assertEqual(get_basic_words(" the "), 1 / 3)
+        self.assertEqual(get_basic_words("USA"), 0.0)
+
+        # two words
+        self.assertEqual(get_basic_words("United States"), 0.0)
+        self.assertEqual(get_basic_words("The USA"), 0.5)
+        self.assertEqual(get_basic_words("The country"), 1.0)
+
+        # check capitalization
+        self.assertEqual(get_basic_words("The THE the tHE"), 1.0)
+
+        # check punctuation
+        self.assertEqual(get_basic_words("It is."), 0.5)
+        self.assertEqual(get_basic_words("It is ."), 2 / 3)
+        self.assertEqual(get_basic_words("It, is"), 0.5)
+        self.assertEqual(get_basic_words("It , is"), 2 / 3)
diff --git a/explainaboard/utils/tokenizer_test.py b/explainaboard/utils/tokenizer_test.py
index f07f0e0c..467350d4 100644
--- a/explainaboard/utils/tokenizer_test.py
+++ b/explainaboard/utils/tokenizer_test.py
@@ -44,6 +44,38 @@ def test_from_orig_and_tokens_invalid(self) -> None:
 
 
 class TokenizerSerializerTest(unittest.TestCase):
+    def test_empty(self) -> None:
+        tokens = SingleSpaceTokenizer()("")
+        self.assertEqual(len(tokens), 1)
+        self.assertEqual(tokens.strs, [""])
+        self.assertEqual(tokens.positions, [0])
+
+    def test_only_0x20(self) -> None:
+        # NOTE(review): the expected 4 tokens / positions [0, 1, 2, 3] require the
+        # input to be exactly three 0x20 spaces; the collapsed rendering had lost them.
+        tokens = SingleSpaceTokenizer()("   ")
+        self.assertEqual(len(tokens), 4)
+        self.assertEqual(tokens.strs, ["", "", "", ""])
+        self.assertEqual(tokens.positions, [0, 1, 2, 3])
+
+    def test_isspace(self) -> None:
+        tokens = SingleSpaceTokenizer()("\t\v \n\r\f")
+        self.assertEqual(len(tokens), 2)
+        self.assertEqual(tokens.strs, ["\t\v", "\n\r\f"])
+        self.assertEqual(tokens.positions, [0, 3])
+
+    def test_sentence(self) -> None:
+        tokens = SingleSpaceTokenizer()("May the force be with you.")
+        self.assertEqual(len(tokens), 6)
+        self.assertEqual(tokens.strs, ["May", "the", "force", "be", "with", "you."])
+        self.assertEqual(tokens.positions, [0, 4, 8, 14, 17, 22])
+
+    def test_sentence_with_extra_whitespaces(self) -> None:
+        # NOTE(review): the expected strs/positions (empty token at position 5)
+        # require a DOUBLE space between "May" and "the"; the collapsed rendering
+        # had reduced it to a single space.
+        tokens = SingleSpaceTokenizer()(" May  the force\nbe with you. ")
+        self.assertEqual(len(tokens), 8)
+        self.assertEqual(
+            tokens.strs, ["", "May", "", "the", "force\nbe", "with", "you.", ""]
+        )
+        self.assertEqual(tokens.positions, [0, 1, 5, 6, 10, 19, 24, 29])
+
     def test_serialize(self) -> None:
         serializer = PrimitiveSerializer()
         self.assertEqual(