From 3cfd9759fd30bba04040209ff7ad1ffad7a53d0c Mon Sep 17 00:00:00 2001
From: Selene Baez Santamaria
Date: Wed, 8 Mar 2023 14:58:57 +0100
Subject: [PATCH] modify DatasetWalker to be subscriptable

---
 scripts/dataset_walker.py | 61 ++++++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 23 deletions(-)

diff --git a/scripts/dataset_walker.py b/scripts/dataset_walker.py
index f2bceb1..b07d0a1 100644
--- a/scripts/dataset_walker.py
+++ b/scripts/dataset_walker.py
@@ -1,12 +1,13 @@
-import os
 import json
+import os
 
 from scripts.knowledge_reader import KnowledgeReader
 
+
 class DatasetWalker(object):
     def __init__(self, dataset, dataroot, labels=False, labels_file=None, incl_knowledge=False):
         path = os.path.join(os.path.abspath(dataroot))
-        
+
         if dataset not in ['train', 'val']:
             raise ValueError('Wrong dataset name: %s' % (dataset))
 
@@ -31,29 +32,43 @@ def __iter__(self):
         if self.labels is not None:
             for log, label in zip(self.logs, self.labels):
                 if self._incl_knowledge is True and label['target'] is True:
-                    for idx, snippet in enumerate(label['knowledge']):
-                        domain = snippet['domain']
-                        entity_id = snippet['entity_id']
-                        doc_type = snippet['doc_type']
-                        doc_id = snippet['doc_id']
-
-                        if doc_type == 'review':
-                            sent_id = snippet['sent_id']
-                            sent = self._knowledge.get_review_sent(domain, entity_id, doc_id, sent_id)
-                            label['knowledge'][idx]['sent'] = sent
-
-                        elif doc_type == 'faq':
-                            doc = self._knowledge.get_faq_doc(domain, entity_id, doc_id)
-                            question = doc['question']
-                            answer = doc['answer']
-
-                            label['knowledge'][idx]['question'] = question
-                            label['knowledge'][idx]['answer'] = answer
-
-                yield(log, label)
+                    label = self.resolve_knowledge(label)
+                yield (log, label)
         else:
             for log in self.logs:
-                yield(log, None)
+                yield (log, None)
 
     def __len__(self, ):
         return len(self.logs)
+
+    def __getitem__(self, item):
+        if self.labels is not None:
+            log = self.logs[item]
+            label = self.labels[item]
+            if self._incl_knowledge is True and label['target'] is True:
+                label = self.resolve_knowledge(label)
+            return (log, label)
+        else:
+            return (self.logs[item], None)
+
+    def resolve_knowledge(self, label):
+        for idx, snippet in enumerate(label['knowledge']):
+            domain = snippet['domain']
+            entity_id = snippet['entity_id']
+            doc_type = snippet['doc_type']
+            doc_id = snippet['doc_id']
+
+            if doc_type == 'review':
+                sent_id = snippet['sent_id']
+                sent = self._knowledge.get_review_sent(domain, entity_id, doc_id, sent_id)
+                label['knowledge'][idx]['sent'] = sent
+
+            elif doc_type == 'faq':
+                doc = self._knowledge.get_faq_doc(domain, entity_id, doc_id)
+                question = doc['question']
+                answer = doc['answer']
+
+                label['knowledge'][idx]['question'] = question
+                label['knowledge'][idx]['answer'] = answer
+
+        return label