diff --git a/generator/generation/emoji_generator.py b/generator/generation/emoji_generator.py
index 6895b5d8..1b1c6d12 100644
--- a/generator/generation/emoji_generator.py
+++ b/generator/generation/emoji_generator.py
@@ -5,11 +5,18 @@
 
 from omegaconf import DictConfig
 
+from .combination_limiter import CombinationLimiter
 from .name_generator import NameGenerator
 
 logger = logging.getLogger('generator')
 
 
+def order_product(*args):
+    return [tuple(i[1] for i in p) for p in
+            sorted(product(*map(enumerate, args)),
+                   key=lambda x: (sum(y[0] for y in x), x))]
+
+
 class EmojiGenerator(NameGenerator):
     """
     Replaces words with their corresponding emojis
@@ -21,7 +28,12 @@ def __init__(self, config: DictConfig):
         with open(config.generation.name2emoji_path, 'r', encoding='utf-8') as f:
             self.name2emoji = json.load(f)
 
+        self.combination_limiter = CombinationLimiter(self.limit)
+
     def generate(self, tokens: Tuple[str, ...], params: dict[str, Any]) -> List[Tuple[str, ...]]:
         all_possibilities = [[token] + self.name2emoji.get(token, []) for token in tokens]
+
+        all_possibilities = self.combination_limiter.limit(all_possibilities)
+
         # skipping the item in which all the original tokens are preserved
-        return list(islice(product(*all_possibilities), 1, self.limit))
+        return list(islice(order_product(*all_possibilities), 1, self.limit))
diff --git a/tests/test_name_generators.py b/tests/test_name_generators.py
index 40586a85..82b061b7 100644
--- a/tests/test_name_generators.py
+++ b/tests/test_name_generators.py
@@ -248,13 +248,11 @@ def test_emoji_generator():
     all_tokenized = [gn.tokens for gn in generated_names]
 
-    print(all_tokenized)
-
     assert ('🥰', 'your', '🤩') in all_tokenized
     assert ('🥰', 'your', '👀') in all_tokenized
-    assert ('🥰', 'your', '🥽') in all_tokenized
-    assert ('🥰', 'your', 'eyes') in all_tokenized
-    assert ('adore', 'your', '👀') in all_tokenized
+    assert ('🥰', 'your', '😵‍💫') in all_tokenized
+    assert ('🥰', 'your', 'eyes') in all_tokenized[:2]
+    assert ('adore', 'your', '👀') in all_tokenized[:2]
     assert ('adore', 'your', 'eyes') not in all_tokenized
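For reference, a standalone sketch of the ordering that order_product produces (the token lists below are made up for illustration and are not taken from name2emoji.json). Tuples are sorted by the sum of each element's index in its source list, so the all-original tuple (index sum 0) comes first, which is exactly the item that islice(..., 1, self.limit) skips, and lightly substituted names surface before heavily substituted ones.

from itertools import product

def order_product(*args):
    # Same ordering as in the diff: sort the full cartesian product by the sum
    # of each element's index in its source list, breaking ties by the
    # enumerated tuples themselves (i.e. earlier choices in earlier positions win).
    return [tuple(i[1] for i in p) for p in
            sorted(product(*map(enumerate, args)),
                   key=lambda x: (sum(y[0] for y in x), x))]

# Hypothetical possibility lists: index 0 is the original token, later indices are emoji.
all_possibilities = [['adore', '🥰'], ['your'], ['eyes', '👀', '🤩']]

for combo in order_product(*all_possibilities):
    print(combo)
# ('adore', 'your', 'eyes')   <- all originals; skipped by islice(..., 1, self.limit)
# ('adore', 'your', '👀')
# ('🥰', 'your', 'eyes')
# ('adore', 'your', '🤩')
# ('🥰', 'your', '👀')
# ('🥰', 'your', '🤩')

With this ordering, the two single-substitution names ('adore', 'your', '👀') and ('🥰', 'your', 'eyes') are the first results after the skip, which is what the tightened assertions against all_tokenized[:2] in the test rely on.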