Skip to content

Commit

Permalink
Also support core17 & core18
Browse files Browse the repository at this point in the history
  • Loading branch information
chriskamphuis committed Jan 21, 2021
1 parent 050c161 commit c244fbd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
36 changes: 30 additions & 6 deletions index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@

class Index:

# Maps each supported collection name to the document-generator name that
# create_input_table passes to index.Generator.
# NOTE(review): the values look like Anserini generator class names -- confirm.
_col_name_col_generator_map = {
'robust04': 'DefaultLuceneDocumentGenerator',
'core17': 'DefaultLuceneDocumentGenerator',
'core18': 'WashingtonPostGenerator'
}

# Maps each supported collection name to the collection type name that
# create_input_table passes to collection.Collection together with the
# collection location.
# NOTE(review): the values look like Anserini collection class names -- confirm.
_col_name_col_type_map = {
'robust04': 'TrecCollection',
'core17': 'NewYorkTimesCollection',
'core18': 'WashingtonPostCollection'
}

def __init__(self, **kwargs):
self.arguments = self.get_arguments(kwargs)
self.connection = duckdb.connect(self.arguments['database'])
Expand All @@ -18,30 +30,36 @@ def __init__(self, **kwargs):
def get_arguments(kwargs):
    """Extract and validate the indexing arguments from ``kwargs``.

    Only the keys ``database``, ``collection_location`` and
    ``collection_name`` are read; any other entry in ``kwargs`` is ignored.
    A key that is missing or explicitly ``None`` counts as not provided.

    Args:
        kwargs: Mapping of argument names to values (for example the
            ``vars()`` of an argparse namespace).

    Returns:
        A new dict holding exactly the three recognised keys.

    Raises:
        IOError: If any of the three arguments is missing, or if
            ``collection_name`` is not one of the supported collections.
    """
    supported_collections = ('robust04', 'core17', 'core18')

    # dict.get returns None for absent keys, which is also the "not
    # provided" marker the checks below rely on.
    arguments = {
        key: kwargs.get(key)
        for key in ('database', 'collection_location', 'collection_name')
    }

    if arguments['database'] is None:
        raise IOError('Database path needs to be provided.')
    if arguments['collection_location'] is None:
        raise IOError('Collection path needs to be provided.')
    if arguments['collection_name'] is None:
        raise IOError('Collection name needs to be provided.')
    if arguments['collection_name'] not in supported_collections:
        raise IOError('Collection name needs to be one of: robust04, core17, core18')
    return arguments

def create_input_table(self):
self.cursor.execute(f"CREATE TABLE documents(id VARCHAR, body VARCHAR)")
self.connection.begin()
c = collection.Collection('TrecCollection', self.arguments['collection'])
generator = index.Generator('DefaultLuceneDocumentGenerator')
c = collection.Collection(self._col_name_col_type_map[self.arguments['collection_name']],
self.arguments['collection_location'])
generator = index.Generator(self._col_name_col_generator_map[self.arguments['collection_name']])
for fs in c:
for i, doc in enumerate(fs):
if i % 10000 == 0:
self.connection.commit()
self.connection.begin()
try:
parsed = generator.create_document(doc)
except:
except: # The document is empty
pass
doc_id = parsed.get("id")
contents = parsed.get("contents")
Expand All @@ -59,8 +77,14 @@ def create_input_table(self):
metavar='[file]',
help='Location of the database.')
# Required location of the collection on disk. The long option must be
# '--collection_location' alone: argparse derives the destination attribute
# from the first long option string, and Index.get_arguments looks the value
# up under the key 'collection_location'.
parser.add_argument('-c',
                    '--collection_location',
                    required=True,
                    metavar='[directory]',
                    help='Location of the collection.')
# Required collection name, restricted to the supported collections.
# metavar is a plain string so the help output shows "[string]", matching
# the other options; a list would be rendered as its Python repr.
parser.add_argument('-n',
                    '--collection_name',
                    required=True,
                    metavar='[string]',
                    choices=['robust04', 'core17', 'core18'],
                    help='Name of the collection')
# Build the index straight from the parsed command-line options; each option
# is handed to Index as a keyword argument.
Index(**vars(parser.parse_args()))
3 changes: 3 additions & 0 deletions search.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def get_arguments(kwargs):
raise IOError('Database path needs to be provided.')
if arguments['collection_name'] is None:
raise IOError('Collection name needs to be provided.')
if arguments['collection_name'] not in ['robust04', 'core17', 'core18']:
raise IOError('Collection name needs to be one of: robust04, core17, core18')
if arguments['outfile'] is None:
raise IOError('Output file needs to be provided.')
return arguments
Expand Down Expand Up @@ -66,6 +68,7 @@ def run_topics(self):
'--collection_name',
required=True,
metavar='[string]',
choices=['robust04', 'core17', 'core18'],
help='Name of the collection.')
parser.add_argument('-o',
'--outfile',
Expand Down

0 comments on commit c244fbd

Please sign in to comment.