Commit: impl
fonhorst committed Jan 3, 2024
1 parent daa4138 commit 2880bf4
Showing 1 changed file with 52 additions and 39 deletions.
91 changes: 52 additions & 39 deletions autotm/preprocessing/cooc.py
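Summary: the diff below removes the __save_dictionary helper and replaces the single nested co-occurrence dictionary with four flat dictionaries keyed by ordered token-id pairs: co-occurrence term frequency (TF) and document frequency (DF) for pairs, plus per-term TF and DF counters, all returned as a CoocDictionaries dataclass.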
@@ -8,7 +8,7 @@
 import tempfile
 import time
 from dataclasses import dataclass
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Optional, List, Set
 
 import artm
 from six import iteritems
@@ -34,54 +34,56 @@ def __create_batch_dictionary(batch):
     return batch_dictionary
 
 
-def __save_dictionary(cooc_dictionary, num_tokens):
-    with open('cooc_data.txt', 'w') as fout:
-        for index in range(num_tokens):
-            if index in cooc_dictionary:
-                for key, value in iteritems(cooc_dictionary[index]):
-                    fout.write(u'{0} {1} {2}\n'.format(index, key, value))
-
-
-def __process_batch(global_cooc_dictionary, batch, window_size, dictionary):
-    batch_dictionary = __create_batch_dictionary(batch)
-
-    def __process_window(token_ids, token_weights):
-        for j in range(1, len(token_ids)):
-            value = min(token_weights[0], token_weights[j])
-            token_index_1 = dictionary[batch_dictionary[token_ids[0]]]
-            token_index_2 = dictionary[batch_dictionary[token_ids[j]]]
-
-            if token_index_1 in global_cooc_dictionary:
-                if token_index_2 in global_cooc_dictionary:
-                    if token_index_2 in global_cooc_dictionary[token_index_1]:
-                        global_cooc_dictionary[token_index_1][token_index_2] += value
-                    else:
-                        if token_index_1 in global_cooc_dictionary[token_index_2]:
-                            global_cooc_dictionary[token_index_2][token_index_1] += value
-                        else:
-                            global_cooc_dictionary[token_index_1][token_index_2] = value
-                else:
-                    if token_index_2 in global_cooc_dictionary[token_index_1]:
-                        global_cooc_dictionary[token_index_1][token_index_2] += value
-                    else:
-                        global_cooc_dictionary[token_index_1][token_index_2] = value
-            else:
-                if token_index_2 in global_cooc_dictionary:
-                    if token_index_1 in global_cooc_dictionary[token_index_2]:
-                        global_cooc_dictionary[token_index_2][token_index_1] += value
-                    else:
-                        global_cooc_dictionary[token_index_2][token_index_1] = value
-                else:
-                    global_cooc_dictionary[token_index_1] = {}
-                    global_cooc_dictionary[token_index_1][token_index_2] = value
+def __process_batch(
+        global_cooc_df_dictionary,
+        global_cooc_tf_dictionary,
+        global_cooc_df_term_dictionary,
+        global_cooc_tf_term_dictionary,
+        batch,
+        window_size,
+        dictionary):
+    batch_dictionary = __create_batch_dictionary(batch)
+
+    def __process_window_df(global_cooc_dict, global_word_dict, token_ids,
+                            doc_seen_pairs: Set[Tuple[int, int]], doc_seen_words: Set[int]):
+        # Document frequency: a pair or word is counted at most once per document.
+        for j in range(1, len(token_ids)):
+            token_index_1 = dictionary[batch_dictionary[token_ids[0]]]
+            token_index_2 = dictionary[batch_dictionary[token_ids[j]]]
+
+            # Order the pair so (a, b) and (b, a) share one dictionary entry.
+            token_pair = (min(token_index_1, token_index_2), max(token_index_1, token_index_2))
+
+            if token_pair not in doc_seen_pairs:
+                global_cooc_dict[token_pair] = global_cooc_dict.get(token_pair, 0.0) + 1.0
+                doc_seen_pairs.add(token_pair)
+
+            if token_index_1 not in doc_seen_words:
+                global_word_dict[token_index_1] = global_word_dict.get(token_index_1, 0.0) + 1.0
+                doc_seen_words.add(token_index_1)
+
+            if token_index_2 not in doc_seen_words:
+                global_word_dict[token_index_2] = global_word_dict.get(token_index_2, 0.0) + 1.0
+                doc_seen_words.add(token_index_2)
+
+    def __process_window_tf(global_cooc_dict, global_word_dict, token_ids, token_weights: List[float]):
+        # Term frequency: every co-occurrence contributes, weighted by the smaller token weight.
+        for j in range(1, len(token_ids)):
+            value = min(token_weights[0], token_weights[j])
+            token_index_1 = dictionary[batch_dictionary[token_ids[0]]]
+            token_index_2 = dictionary[batch_dictionary[token_ids[j]]]
+
+            token_pair = (min(token_index_1, token_index_2), max(token_index_1, token_index_2))
+            global_cooc_dict[token_pair] = global_cooc_dict.get(token_pair, 0.0) + value
+            global_word_dict[token_index_1] = global_word_dict.get(token_index_1, 0.0) + 1.0
+            global_word_dict[token_index_2] = global_word_dict.get(token_index_2, 0.0) + 1.0
 
     for item in batch.item:
+        doc_seen_pairs = set()
+        doc_seen_words = set()
         real_window_size = window_size if window_size > 0 else len(item.token_id)
         for window_start_id in range(len(item.token_id)):
             end_index = window_start_id + real_window_size
             token_ids = item.token_id[window_start_id: end_index if end_index < len(item.token_id) else len(item.token_id)]
             token_weights = item.token_weight[window_start_id: end_index if end_index < len(item.token_id) else len(item.token_id)]
-            __process_window(token_ids, token_weights)
+            __process_window_df(global_cooc_df_dictionary, global_cooc_df_term_dictionary, token_ids, doc_seen_pairs, doc_seen_words)
+            __process_window_tf(global_cooc_tf_dictionary, global_cooc_tf_term_dictionary, token_ids, token_weights)
 
 
 def __size(global_cooc_dictionary):
     result = sys.getsizeof(global_cooc_dictionary)
@@ -118,21 +120,32 @@ def calculate_cooc(batches_path: str, vocab_path: str, window_size: int=10) -> CoocDictionaries:
         for index, line in enumerate(fin):
             dictionary[line.split(' ')[0][0: -1]] = index
 
-    # tf dict
-    global_cooc_dictionary = {}
+    global_cooc_df_dictionary = dict()
+    global_cooc_tf_dictionary = dict()
+    global_cooc_df_term_dictionary = dict()
+    global_cooc_tf_term_dictionary = dict()
     for index, filename in enumerate(batches_list):
         local_time_start = time.time()
         logger.debug('Processing batch: %s' % index)
         current_batch = artm.messages.Batch()
         with open(filename, 'rb') as fin:
             current_batch.ParseFromString(fin.read())
-        __process_batch(global_cooc_dictionary, current_batch, window_size, dictionary)
+        __process_batch(
+            global_cooc_df_dictionary, global_cooc_tf_dictionary,
+            global_cooc_df_term_dictionary, global_cooc_tf_term_dictionary,
+            current_batch, window_size, dictionary
+        )
 
         logger.debug('Finished batch, elapsed time: %s' % (time.time() - local_time_start))
 
     logger.info(
         'Finished cooc dict collection, elapsed time: %s, size: %s Gb'
-        % (time.time() - global_time_start, __size(global_cooc_dictionary) / 1000000000.0)
+        % (time.time() - global_time_start, __size(global_cooc_tf_dictionary) / 1000000000.0)
     )
 
-    return cooc_df_dict, global_cooc_dictionary
+    return CoocDictionaries(
+        cooc_df=global_cooc_df_dictionary,
+        cooc_tf=global_cooc_tf_dictionary,
+        cooc_df_term=global_cooc_df_term_dictionary,
+        cooc_tf_term=global_cooc_tf_term_dictionary
+    )
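For context, a minimal usage sketch of the updated function under this change. The data paths below are hypothetical, and the field names follow the CoocDictionaries return statement above.

    from autotm.preprocessing.cooc import calculate_cooc

    # Hypothetical paths: a directory of ARTM batch files and a vocabulary file.
    cooc = calculate_cooc(batches_path='data/batches', vocab_path='data/vocab.txt', window_size=10)

    # Pair dictionaries are keyed by ordered (min_id, max_id) token-id tuples.
    pair, tf_value = next(iter(cooc.cooc_tf.items()))
    print(pair, tf_value, cooc.cooc_df.get(pair, 0.0))

Note that a non-positive window_size falls back to whole-document windows, since real_window_size then defaults to len(item.token_id).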
