diff --git a/index.py b/index.py
index 3a39fbe..5fe7e01 100644
--- a/index.py
+++ b/index.py
@@ -6,6 +6,20 @@
 
 
 class Index:
+    # Generator name handed to index.Generator for each supported collection.
+    _col_name_col_generator_map = {
+        'robust04': 'DefaultLuceneDocumentGenerator',
+        'core17': 'DefaultLuceneDocumentGenerator',
+        'core18': 'WashingtonPostGenerator'
+    }
+
+    # Collection type handed to collection.Collection for each supported collection.
+    _col_name_col_type_map = {
+        'robust04': 'TrecCollection',
+        'core17': 'NewYorkTimesCollection',
+        'core18': 'WashingtonPostCollection'
+    }
+
     def __init__(self, **kwargs):
         self.arguments = self.get_arguments(kwargs)
         self.connection = duckdb.connect(self.arguments['database'])
@@ -18,22 +32,28 @@ def __init__(self, **kwargs):
     def get_arguments(kwargs):
         arguments = {
             'database': None,
-            'collection': None
+            'collection_location': None,
+            'collection_name': None
         }
         for key, item in arguments.items():
             if kwargs.get(key) is not None:
                 arguments[key] = kwargs.get(key)
         if arguments['database'] is None:
             raise IOError('Database path needs to be provided.')
-        if arguments['collection'] is None:
+        if arguments['collection_location'] is None:
             raise IOError('Collection path needs to be provided.')
+        if arguments['collection_name'] is None:
+            raise IOError('Collection name needs to be provided.')
+        if arguments['collection_name'] not in ['robust04', 'core17', 'core18']:
+            raise IOError('Collection name needs to be one of: robust04, core17, core18')
         return arguments
 
     def create_input_table(self):
         self.cursor.execute(f"CREATE TABLE documents(id VARCHAR, body VARCHAR)")
         self.connection.begin()
-        c = collection.Collection('TrecCollection', self.arguments['collection'])
-        generator = index.Generator('DefaultLuceneDocumentGenerator')
+        c = collection.Collection(self._col_name_col_type_map[self.arguments['collection_name']],
+                                  self.arguments['collection_location'])
+        generator = index.Generator(self._col_name_col_generator_map[self.arguments['collection_name']])
         for fs in c:
             for i, doc in enumerate(fs):
                 if i % 10000 == 0:
@@ -41,7 +61,7 @@ def create_input_table(self):
                     self.connection.begin()
                 try:
                     parsed = generator.create_document(doc)
-                except:
-                    pass
+                except Exception:  # The document could not be parsed (e.g. it is empty)
+                    continue  # skip it: `parsed` is unbound or left over from the previous document
                 doc_id = parsed.get("id")
                 contents = parsed.get("contents")
@@ -59,8 +79,14 @@ def create_input_table(self):
                         metavar='[file]',
                         help='Location of the database.')
     parser.add_argument('-c',
-                        '--collection',
+                        '--collection_location',
                         required=True,
                         metavar='[directory]',
                         help='Location of the collection.')
+    parser.add_argument('-n',
+                        '--collection_name',
+                        required=True,
+                        metavar='[string]',
+                        choices=['robust04', 'core17', 'core18'],
+                        help='Name of the collection')
     Index(**vars(parser.parse_args()))
diff --git a/search.py b/search.py
index 3275ccc..f27339e 100644
--- a/search.py
+++ b/search.py
@@ -28,6 +28,8 @@ def get_arguments(kwargs):
             raise IOError('Database path needs to be provided.')
         if arguments['collection_name'] is None:
             raise IOError('Collection name needs to be provided.')
+        if arguments['collection_name'] not in ['robust04', 'core17', 'core18']:
+            raise IOError('Collection name needs to be one of: robust04, core17, core18')
         if arguments['outfile'] is None:
             raise IOError('Output file needs to be provided.')
         return arguments
@@ -66,6 +68,7 @@ def run_topics(self):
                         '--collection_name',
                         required=True,
                         metavar='[string]',
+                        choices=['robust04', 'core17', 'core18'],
                         help='Name of the collection.')
     parser.add_argument('-o',
                         '--outfile',