Skip to content

Commit

Permalink
attempt to fix dataprocessing for transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
SecroLoL committed Jan 12, 2024
1 parent 023dd95 commit 02215ce
Showing 1 changed file with 56 additions and 21 deletions.
77 changes: 56 additions & 21 deletions stanza/models/lemma_classifier/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,35 +39,70 @@ def load_dataset(data_path: str, batch_size=DEFAULT_BATCH_SIZE, get_counts: bool
# but we can still use those labels in a confusion matrix
label_decoder = dict(label_decoder)

print(label_decoder, f"Should be strings to ints")

with open(data_path, "r+", encoding="utf-8") as f:
sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), defaultdict(str)
data_processor = prepare_dataset.DataProcessor("", [], "")
sentences_data = data_processor.read_processed_data(data_path)

for idx, sentence in enumerate(sentences_data):
# TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatibility reasons
words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
if None in [words, target_idx, upos_tags, label]:
raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")
for i, line in enumerate(f.readlines()):
split = line.split(" ")
num_tokens = split[-1]
sentence = split[: num_tokens]
upos_tags = split[num_tokens: num_tokens * 2]
tgt_idx = int(split[-3])
lemma = split[-2]

if i < 3:
print(sentences)
print(upos_tags)
print(tgt_idx)
print(lemma)

label_id = label_decoder.get(label, None)
label_id = label_decoder.get(lemma, None)
if label_id is None:
label_decoder[label] = len(label_decoder) # create a new ID for the unknown label

converted_upos_tags = [] # convert upos tags to upos IDs
label_decoder[lemma] = len(label_decoder)

print(upos_to_id)
converted_upos_tags = []
for upos_tag in upos_tags:
upos_id = upos_to_id.get(upos_tag, None)
if upos_id is None:
upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag
upos_idx = upos_to_id.get(upos_tag, None)
if upos_idx is None:
upos_to_id[upos_tag] = len(upos_to_id)
converted_upos_tags.append(upos_to_id[upos_tag])

sentences.append(words)
indices.append(target_idx)
sentences.append(sentence)
indices.append(tgt_idx)
upos_ids.append(converted_upos_tags)
labels.append(label_decoder[label])
labels.append(label_decoder[lemma])

if get_counts:
counts[label_decoder[label]] += 1
if get_counts:
counts[label_decoder[lemma]] += 1

# data_processor = prepare_dataset.DataProcessor("", [], "")
# sentences_data = data_processor.read_processed_data(data_path)

# for idx, sentence in enumerate(sentences_data):
# # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatibility reasons
# words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
# if None in [words, target_idx, upos_tags, label]:
# raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

# label_id = label_decoder.get(label, None)
# if label_id is None:
# label_decoder[label] = len(label_decoder) # create a new ID for the unknown label

# converted_upos_tags = [] # convert upos tags to upos IDs
# for upos_tag in upos_tags:
# upos_id = upos_to_id.get(upos_tag, None)
# if upos_id is None:
# upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag
# converted_upos_tags.append(upos_to_id[upos_tag])

# sentences.append(words)
# indices.append(target_idx)
# upos_ids.append(converted_upos_tags)
# labels.append(label_decoder[label])

# if get_counts:
# counts[label_decoder[label]] += 1

sentence_batches = [sentences[i: i + batch_size] for i in range(0, len(sentences), batch_size)]
indices_batches = [torch.tensor(indices[i: i + batch_size]) for i in range(0, len(indices), batch_size)]
Expand Down

0 comments on commit 02215ce

Please sign in to comment.