order emoji replacements
djstrong committed Dec 5, 2022
1 parent f31ee3d commit 49da038
Showing 2 changed files with 16 additions and 6 deletions.
14 changes: 13 additions & 1 deletion generator/generation/emoji_generator.py
@@ -5,11 +5,18 @@
 
 from omegaconf import DictConfig
 
+from .combination_limiter import CombinationLimiter
 from .name_generator import NameGenerator
 
 logger = logging.getLogger('generator')
 
 
+def order_product(*args):
+    return [tuple(i[1] for i in p) for p in
+            sorted(product(*map(enumerate, args)),
+                   key=lambda x: (sum(y[0] for y in x), x))]
+
+
 class EmojiGenerator(NameGenerator):
     """
     Replaces words with their corresponding emojis
@@ -21,7 +28,12 @@ def __init__(self, config: DictConfig):
         with open(config.generation.name2emoji_path, 'r', encoding='utf-8') as f:
             self.name2emoji = json.load(f)
 
+        self.combination_limiter = CombinationLimiter(self.limit)
+
     def generate(self, tokens: Tuple[str, ...], params: dict[str, Any]) -> List[Tuple[str, ...]]:
         all_possibilities = [[token] + self.name2emoji.get(token, []) for token in tokens]
 
+        all_possibilities = self.combination_limiter.limit(all_possibilities)
+
         # skipping the item in which all the original tokens are preserved
-        return list(islice(product(*all_possibilities), 1, self.limit))
+        return list(islice(order_product(*all_possibilities), 1, self.limit))
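
The new order_product sorts the full cartesian product so that combinations stay close to the original tokens: each candidate list is enumerated (position 0 is the original token) and whole tuples are ranked by the sum of the chosen positions, with lexicographic tie-breaking. Since sorting materializes the entire product, the CombinationLimiter now caps the candidate lists before order_product runs. A minimal sketch of the ordering, using hypothetical candidate lists (the real ones come from the name2emoji JSON):

from itertools import islice, product

def order_product(*args):
    # enumerate() tags each candidate with its position; position 0 is the
    # original token, so the key (sum of positions, tuple) ranks combinations
    # by how far they stray from the input, breaking ties lexicographically
    return [tuple(i[1] for i in p) for p in
            sorted(product(*map(enumerate, args)),
                   key=lambda x: (sum(y[0] for y in x), x))]

possibilities = [['adore', 'πŸ₯°'], ['your'], ['eyes', 'πŸ‘€', '🀩']]  # hypothetical
ordered = order_product(*possibilities)
assert ordered[0] == ('adore', 'your', 'eyes')  # all originals; generate() skips it
print(list(islice(ordered, 1, 5)))
# [('adore', 'your', 'πŸ‘€'), ('πŸ₯°', 'your', 'eyes'),
#  ('adore', 'your', '🀩'), ('πŸ₯°', 'your', 'πŸ‘€')]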
8 changes: 3 additions & 5 deletions tests/test_name_generators.py
@@ -248,13 +248,11 @@ def test_emoji_generator():
 
     all_tokenized = [gn.tokens for gn in generated_names]
 
-    print(all_tokenized)
-
     assert ('πŸ₯°', 'your', '🀩') in all_tokenized
     assert ('πŸ₯°', 'your', 'πŸ‘€') in all_tokenized
     assert ('πŸ₯°', 'your', 'πŸ₯½') in all_tokenized
-    assert ('πŸ₯°', 'your', 'eyes') in all_tokenized
-    assert ('adore', 'your', 'πŸ‘€') in all_tokenized
+    assert ('πŸ₯°', 'your', 'πŸ˜΅β€πŸ’«') in all_tokenized
+    assert ('πŸ₯°', 'your', 'eyes') in all_tokenized[:2]
+    assert ('adore', 'your', 'πŸ‘€') in all_tokenized[:2]
 
     assert ('adore', 'your', 'eyes') not in all_tokenized
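With the ordering now deterministic, the test can assert positions instead of bare membership: the two single-substitution names must land in the first two results once the all-original tuple is dropped. A small self-contained check of that expectation, again with hypothetical candidate lists:

from itertools import product

def order_product(*args):  # same helper as in the diff above
    return [tuple(i[1] for i in p) for p in
            sorted(product(*map(enumerate, args)),
                   key=lambda x: (sum(y[0] for y in x), x))]

# one substitution gives a position sum of 1, so these two names sort directly
# after the all-original tuple that generate() skips
first_two = order_product(['adore', 'πŸ₯°'], ['your'], ['eyes', 'πŸ‘€'])[1:3]
assert set(first_two) == {('πŸ₯°', 'your', 'eyes'), ('adore', 'your', 'πŸ‘€')}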

0 comments on commit 49da038
