diff --git a/autotm/preprocessing/cooc.py b/autotm/preprocessing/cooc.py
index 24abd7b..b673e4d 100644
--- a/autotm/preprocessing/cooc.py
+++ b/autotm/preprocessing/cooc.py
@@ -8,7 +8,7 @@ import tempfile
 import time
 from dataclasses import dataclass
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Optional, List, Set
 
 import artm
 from six import iteritems
 
@@ -34,54 +34,58 @@ def __create_batch_dictionary(batch):
     return batch_dictionary
 
 
-def __save_dictionary(cooc_dictionary, num_tokens):
-    with open('cooc_data.txt', 'w') as fout:
-        for index in range(num_tokens):
-            if index in cooc_dictionary:
-                for key, value in iteritems(cooc_dictionary[index]):
-                    fout.write(u'{0} {1} {2}\n'.format(index, key, value))
+def __process_batch(
+        global_cooc_df_dictionary,
+        global_cooc_tf_dictionary,
+        global_cooc_df_term_dictionary,
+        global_cooc_tf_term_dictionary,
+        batch,
+        window_size,
+        dictionary):
+    batch_dictionary = __create_batch_dictionary(batch)
 
+    def __process_window_df(global_cooc_dict, global_word_dict, token_ids,
+                            doc_seen_pairs: Set[Tuple[int, int]], doc_seen_words: Set[int]):
+        for j in range(1, len(token_ids)):
+            token_index_1 = dictionary[batch_dictionary[token_ids[0]]]
+            token_index_2 = dictionary[batch_dictionary[token_ids[j]]]
 
-def __process_batch(global_cooc_dictionary, batch, window_size, dictionary):
-    batch_dictionary = __create_batch_dictionary(batch)
+            token_pair = (min(token_index_1, token_index_2), max(token_index_1, token_index_2))
 
-    def __process_window(token_ids, token_weights):
+            if token_pair not in doc_seen_pairs:
+                global_cooc_dict[token_pair] = global_cooc_dict.get(token_pair, 0.0) + 1.0
+                doc_seen_pairs.add(token_pair)
+
+            if token_index_1 not in doc_seen_words:
+                global_word_dict[token_index_1] = global_word_dict.get(token_index_1, 0.0) + 1.0
+                doc_seen_words.add(token_index_1)
+
+            if token_index_2 not in doc_seen_words:
+                global_word_dict[token_index_2] = global_word_dict.get(token_index_2, 0.0) + 1.0
+                doc_seen_words.add(token_index_2)
+
+    def __process_window_tf(global_cooc_dict, global_word_dict, token_ids, token_weights: List[float]):
         for j in range(1, len(token_ids)):
             value = min(token_weights[0], token_weights[j])
             token_index_1 = dictionary[batch_dictionary[token_ids[0]]]
             token_index_2 = dictionary[batch_dictionary[token_ids[j]]]
-            if token_index_1 in global_cooc_dictionary:
-                if token_index_2 in global_cooc_dictionary:
-                    if token_index_2 in global_cooc_dictionary[token_index_1]:
-                        global_cooc_dictionary[token_index_1][token_index_2] += value
-                    else:
-                        if token_index_1 in global_cooc_dictionary[token_index_2]:
-                            global_cooc_dictionary[token_index_2][token_index_1] += value
-                        else:
-                            global_cooc_dictionary[token_index_1][token_index_2] = value
-                else:
-                    if token_index_2 in global_cooc_dictionary[token_index_1]:
-                        global_cooc_dictionary[token_index_1][token_index_2] += value
-                    else:
-                        global_cooc_dictionary[token_index_1][token_index_2] = value
-            else:
-                if token_index_2 in global_cooc_dictionary:
-                    if token_index_1 in global_cooc_dictionary[token_index_2]:
-                        global_cooc_dictionary[token_index_2][token_index_1] += value
-                    else:
-                        global_cooc_dictionary[token_index_2][token_index_1] = value
-                else:
-                    global_cooc_dictionary[token_index_1] = {}
-                    global_cooc_dictionary[token_index_1][token_index_2] = value
+            token_pair = (min(token_index_1, token_index_2), max(token_index_1, token_index_2))
+            global_cooc_dict[token_pair] = global_cooc_dict.get(token_pair, 0.0) + value
+            global_word_dict[token_index_1] = global_word_dict.get(token_index_1, 0.0) + 1.0
+            global_word_dict[token_index_2] = global_word_dict.get(token_index_2, 0.0) + 1.0
 
     for item in batch.item:
+        doc_seen_pairs = set()
+        doc_seen_words = set()
         real_window_size = window_size if window_size > 0 else len(item.token_id)
         for window_start_id in range(len(item.token_id)):
             end_index = window_start_id + real_window_size
             token_ids = item.token_id[window_start_id: end_index if end_index < len(item.token_id) else len(item.token_id)]
             token_weights = item.token_weight[window_start_id: end_index if end_index < len(item.token_id) else len(item.token_id)]
-            __process_window(token_ids, token_weights)
+            __process_window_df(global_cooc_df_dictionary, global_cooc_df_term_dictionary, token_ids, doc_seen_pairs, doc_seen_words)
+            __process_window_tf(global_cooc_tf_dictionary, global_cooc_tf_term_dictionary, token_ids, token_weights)
+
 
 def __size(global_cooc_dictionary):
     result = sys.getsizeof(global_cooc_dictionary)
@@ -118,21 +120,32 @@ def calculate_cooc(batches_path: str, vocab_path: str, window_size: int=10) -> C
         for index, line in enumerate(fin):
             dictionary[line.split(' ')[0][0: -1]] = index
 
-    # tf dict
-    global_cooc_dictionary = {}
+    global_cooc_df_dictionary = dict()
+    global_cooc_tf_dictionary = dict()
+    global_cooc_df_term_dictionary = dict()
+    global_cooc_tf_term_dictionary = dict()
     for index, filename in enumerate(batches_list):
         local_time_start = time.time()
         logger.debug('Processing batch: %s' % index)
         current_batch = artm.messages.Batch()
         with open(filename, 'rb') as fin:
             current_batch.ParseFromString(fin.read())
-        __process_batch(global_cooc_dictionary, current_batch, window_size, dictionary)
+        __process_batch(
+            global_cooc_df_dictionary, global_cooc_tf_dictionary,
+            global_cooc_df_term_dictionary, global_cooc_tf_term_dictionary,
+            current_batch, window_size, dictionary
+        )
        logger.debug('Finished batch, elapsed time: %s' % (time.time() - local_time_start))
 
     logger.info(
         'Finished cooc dict collection, elapsed time: %s, size: %s Gb'
-        % (time.time() - global_time_start, __size(global_cooc_dictionary) / 1000000000.0)
+        % (time.time() - global_time_start, __size(global_cooc_tf_dictionary) / 1000000000.0)
     )
 
-    return cooc_df_dict, global_cooc_dictionary
+    return CoocDictionaries(
+        cooc_df=global_cooc_df_dictionary,
+        cooc_tf=global_cooc_tf_dictionary,
+        cooc_df_term=global_cooc_df_term_dictionary,
+        cooc_tf_term=global_cooc_tf_term_dictionary
+    )
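
For reference, a minimal sketch of the counting scheme the new `__process_window_df`/`__process_window_tf` helpers implement: pairs are keyed by order-independent `(min_id, max_id)` tuples, tf accumulates on every window hit, and df is deduplicated per document through a seen-set. `count_cooc` and the toy documents below are hypothetical, and unit weights stand in for the real `min(token_weights[0], token_weights[j])` increment:

```python
from typing import Dict, List, Set, Tuple

def count_cooc(docs: List[List[int]], window_size: int = 10):
    # Hypothetical re-statement of the diff's logic on plain token-id lists.
    cooc_tf: Dict[Tuple[int, int], float] = {}
    cooc_df: Dict[Tuple[int, int], float] = {}
    for token_ids in docs:
        seen_pairs: Set[Tuple[int, int]] = set()  # reset per document, as in the diff
        real_window = window_size if window_size > 0 else len(token_ids)
        for start in range(len(token_ids)):
            window = token_ids[start:start + real_window]
            for j in range(1, len(window)):
                # Symmetric key, so (a, b) and (b, a) share one entry.
                pair = (min(window[0], window[j]), max(window[0], window[j]))
                cooc_tf[pair] = cooc_tf.get(pair, 0.0) + 1.0  # unit weight for simplicity
                if pair not in seen_pairs:                    # df: at most once per document
                    cooc_df[pair] = cooc_df.get(pair, 0.0) + 1.0
                    seen_pairs.add(pair)
    return cooc_tf, cooc_df

tf, df = count_cooc([[0, 1, 2, 1], [1, 2]], window_size=2)
assert tf[(1, 2)] == 3.0  # three window hits across both documents
assert df[(1, 2)] == 2.0  # but only two documents contain the pair
```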
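
Callers now get all four dictionaries from a single pass over the batches. A hypothetical call against the refactored entry point, with placeholder paths and assuming `CoocDictionaries` exposes its fields as attributes (it is constructed with keyword fields in the diff):

```python
from autotm.preprocessing.cooc import calculate_cooc

cooc = calculate_cooc(batches_path='data/batches', vocab_path='data/vocab.txt', window_size=10)

# Pair counts are keyed by symmetric (min_id, max_id) tuples of vocab indices.
print(cooc.cooc_tf.get((0, 1), 0.0))  # summed min-weight co-occurrences
print(cooc.cooc_df.get((0, 1), 0.0))  # number of documents containing the pair
print(cooc.cooc_df_term.get(0, 0.0))  # per-term document frequency
```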