diff --git a/server/data.py b/server/data.py
index fdcf317..707a220 100644
--- a/server/data.py
+++ b/server/data.py
@@ -30,7 +30,7 @@ def sent_2_words(sent: str) -> List[str]:
     return words
 
 
-def corpus_to_words(file_path: str, vocab_path: str = str(VOCAB_PATH)):
+def corpus_to_words(file_path: str):
     my_counter: Counter = Counter()
     with open(file_path, "r", encoding="utf-8") as fl:
         for sent in tqdm(fl, desc="Precess file"):
@@ -40,7 +40,10 @@ def corpus_to_words(file_path: str, vocab_path: str = str(VOCAB_PATH)):
     min_count = max([10, max_cnt / 100])
 
     selected_words = [word for word, count in my_counter.items() if (min_count < count <= max_cnt)]
+    return selected_words
+
 
+def save_words(selected_words, vocab_path: str = str(VOCAB_PATH)):
     with open(vocab_path, "w", encoding="utf-8") as fl:
         for w in selected_words:
             fl.write(w + "\n")
@@ -179,7 +182,8 @@ def add_new_text(bucket_name=BUCKET_SPLIT_TEXTS, destination_bucket_name=BUCKET_
     to_copy = sorted(set(all_files) - set(copied_files))[0]
     print(to_copy)
     download_blob(bucket_name, to_copy, DATA_PATH / to_copy)
-    corpus_to_words(DATA_PATH / to_copy)
+    words = corpus_to_words(DATA_PATH / to_copy)
+    save_words(words)
     copy_blob(bucket_name, to_copy, destination_bucket_name, to_copy)
diff --git a/the_hat_game/data.py b/the_hat_game/data.py
new file mode 100644
index 0000000..28ae721
--- /dev/null
+++ b/the_hat_game/data.py
@@ -0,0 +1,32 @@
+import re
+from collections import Counter
+from typing import List
+
+from nltk.corpus import stopwords
+from nltk.stem import WordNetLemmatizer
+from nltk.tokenize import word_tokenize
+from tqdm.auto import tqdm
+
+STOP_WORDS = stopwords.words("english")
+LEMMATIZER = WordNetLemmatizer()
+
+
+def sent_2_words(sent: str) -> List[str]:
+    sent = sent.lower()
+    sent = re.sub("[^a-z]+", " ", sent)
+    words = word_tokenize(sent)
+    words = [LEMMATIZER.lemmatize(word) for word in words if ((word not in STOP_WORDS) and len(word.strip()) > 3)]
+    return words
+
+
+def corpus_to_words(file_path: str):
+    my_counter: Counter = Counter()
+    with open(file_path, "r", encoding="utf-8") as fl:
+        for sent in tqdm(fl, desc="Process file"):
+            my_counter.update(sent_2_words(sent))
+
+    max_cnt = max(count for word, count in my_counter.items()) / 10
+    min_count = max([10, max_cnt / 100])
+
+    selected_words = [word for word, count in my_counter.items() if (min_count < count <= max_cnt)]
+    return selected_words
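
The refactor above separates vocabulary extraction from persistence: corpus_to_words now returns the filtered word list, and save_words writes it to disk. A minimal sketch of the new call pattern, assuming server/data.py is importable as server.data (the file paths are placeholders, not files from the repo):

    from the_hat_game.data import corpus_to_words
    from server.data import save_words

    # Build the filtered vocabulary from a raw corpus, one sentence per line.
    words = corpus_to_words("corpus.txt")

    # Persist it, one word per line; vocab_path defaults to the server's VOCAB_PATH.
    save_words(words, vocab_path="vocab.txt")
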
diff --git a/the_hat_game/game.py b/the_hat_game/game.py
index 008ad9f..3208a58 100644
--- a/the_hat_game/game.py
+++ b/the_hat_game/game.py
@@ -13,6 +13,7 @@
 import the_hat_game.nltk_setup  # noqa: F401
 from the_hat_game.loggers import c_handler, logger
 from the_hat_game.players import RemotePlayer
+from the_hat_game.data import corpus_to_words
 
 
 class Game:
@@ -24,8 +25,23 @@ def __init__(
         n_rounds,
         n_explain_words,
         n_guessing_words,
+        corpus_path=None,
+        vocab_path=None,
         random_state=None,
     ):
+        """Main class for the game.
+        params:
+        - players: list of AbstractPlayer - players in the game
+        - words: list of str - all words used for guessing
+        - criteria: 'hard' or 'soft' - game criteria
+        - n_rounds: int - number of rounds
+        - n_explain_words: int - number of words for explaining
+        - n_guessing_words: int - number of words for guessing
+        - corpus_path: str - path to the corpus used to build the vocabulary (for criteria='hard')
+        - vocab_path: str - path to the vocabulary file (for criteria='hard')
+        NOTE: at most one of corpus_path and vocab_path may be defined
+        NOTE: if no vocabulary is defined, nltk wordnet is used to filter out non-existing words
+        """
         assert len(players) >= 2
         assert criteria in ("hard", "soft")
         self.players = players
@@ -36,6 +52,21 @@ def __init__(
         self.n_guessing_words = n_guessing_words
         self.random_state = random_state
         self.stemmer = SnowballStemmer("english")
+        if corpus_path is not None:
+            assert vocab_path is None, "corpus and vocabulary cannot be defined at the same time"
+            self.whitelist = corpus_to_words(corpus_path)
+        elif vocab_path is not None:
+            with open(vocab_path, encoding="utf-8") as f:
+                vocab_words = f.readlines()
+            self.whitelist = [word.strip() for word in vocab_words]
+        else:
+            self.whitelist = None
+        # add all words used for guessing to the non-empty whitelist
+        if self.whitelist is not None:
+            self.whitelist += [
+                word for word in self.words
+                if word not in self.whitelist
+            ]
 
     def remove_repeated_words(self, words):
         unique_words = []
@@ -54,9 +85,11 @@ def remove_same_rooted_words(self, word, word_list):
         cleared_word_list = [w for w in word_list if self.stemmer.stem(w) != root]
         return cleared_word_list
 
-    @staticmethod
-    def remove_non_existing_words(words):
-        existing_words = [w for w in words if len(wordnet.synsets(w)) > 0]
+    def remove_non_existing_words(self, words):
+        if self.whitelist is not None:
+            existing_words = [w for w in words if w in self.whitelist]
+        else:
+            existing_words = [w for w in words if len(wordnet.synsets(w)) > 0]
         return existing_words
 
     def create_word_list(self, player, word, n_words):
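
For reference, a usage sketch of the new constructor options; the player objects, word list, and paths below are illustrative placeholders, not repo fixtures:

    from the_hat_game.game import Game

    game = Game(
        players=[player_a, player_b],       # any two AbstractPlayer instances
        words=["apple", "river", "cloud"],  # pool of words to guess
        criteria="hard",
        n_rounds=2,
        n_explain_words=5,
        n_guessing_words=5,
        vocab_path="vocab.txt",             # or corpus_path="corpus.txt", not both
    )

Passing both corpus_path and vocab_path trips the assertion; passing neither leaves the whitelist as None, so remove_non_existing_words falls back to nltk wordnet lookups, matching the previous behaviour.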