diff --git a/ailab/db/crawler/__init__.py b/ailab/db/crawler/__init__.py
index 89964cc..0c452dc 100644
--- a/ailab/db/crawler/__init__.py
+++ b/ailab/db/crawler/__init__.py
@@ -210,15 +210,13 @@ def fetch_crawl_row(cursor, url):
     assert 'html_content' in row.keys()
     return row
 
-def fetch_chunk_token_row(cursor, url):
+def fetch_chunk_token_row(cursor, id):
-    """Fetch the most recent chunk token for a given chunk id."""
+    """Fetch one chunk/token row for the given chunk id (None if no row matches)."""
-    data = db.parse_postgresql_url(url)
+    data = {'id': id}
     cursor.execute(
-        "SELECT chunk.id as chunk_id, token.id as token_id, tokens FROM chunk"
-        " JOIN token ON chunk.id = token.chunk_id"
-        " WHERE chunk.id = %(id)s LIMIT 1",
+        """SELECT chunk.id as chunk_id, token.id as token_id, tokens FROM chunk
+        JOIN token ON chunk.id = token.chunk_id
+        WHERE chunk.id = %(id)s LIMIT 1""",
         data
     )
-    # psycopg.extras.DictRow is not a real dict and will convert
-    # to string as a list so we force convert to dict
     return cursor.fetchone()
diff --git a/tests/test_db_crawler.py b/tests/test_db_crawler.py
index f3fcd0a..f3d98f5 100644
--- a/tests/test_db_crawler.py
+++ b/tests/test_db_crawler.py
@@ -53,20 +53,6 @@ def test_fetch_crawl_row_by_postgresql_url(self):
             row['title'],
             "Sampling procedures - Canadian Food Inspection Agency")
 
-    def test_fetch_chunk_row(self):
-        """sample test to check if fetch_chunk_row works"""
-        url = db.create_postgresql_url(
-            "DBNAME",
-            "chunk",
-            "469812c5-190c-4e56-9f88-c8621592bcb5")
-        with db.cursor(self.connection) as cursor:
-            row = crawler.fetch_chunk_token_row(cursor, url)
-        self.connection.rollback()
-        self.assertTrue(isinstance(row, dict))
-        self.assertEqual(len(row['tokens']), 76)
-        self.assertEqual(str(row['chunk_id']), "469812c5-190c-4e56-9f88-c8621592bcb5")
-        self.assertEqual(str(row['token_id']), 'dbb7b498-2cbf-4ae9-aa10-3169cc72f285')
-
     def test_fetch_chunk_id_without_embedding(self):
         """sample test to check if fetch_chunk_id_without_embedding works"""
         with db.cursor(self.connection) as cursor:
@@ -115,24 +101,24 @@ def test_store_embedding_item(self):
         self.connection.rollback()
         self.assertEqual(item["token_id"], "be612259-9b52-42fd-8d0b-d72120efa3b6")
 
-    # def test_fetch_crawl_ids_without_chunk(self):
-    #     """Test fetching crawl IDs without a chunk."""
-    #     with db.cursor(self.connection) as cursor:
-    #         id = crawler.fetch_crawl_ids_without_chunk(cursor)
-    #     self.connection.rollback()
-    #     self.assertTrue()
+    def test_fetch_crawl_ids_without_chunk(self):
+        """Test fetching crawl IDs without a chunk."""
+        with db.cursor(self.connection) as cursor:
+            crawl_ids = crawler.fetch_crawl_ids_without_chunk(cursor)
+        self.connection.rollback()
+        self.assertEqual(crawl_ids, [])
 
-    # def test_fetch_crawl_row(self):
-    #     """Test fetching a crawl row."""
-    #     with db.cursor(self.connection) as cursor:
-    #         id = crawler.fetch_crawl_row(cursor, "https://example.com")
-    #     self.connection.rollback()
-    #     self.assertFalse()
+    def test_fetch_crawl_row(self):
+        """Test fetching a crawl row."""
+        with db.cursor(self.connection) as cursor:
+            row = crawler.fetch_crawl_row(cursor, "https://inspection.canada.ca/a-propos-de-l-acia/structure-organisationnelle/mandat/fra/1299780188624/1319164463699")
+        self.connection.rollback()
+        self.assertEqual(row['title'], "Mandat - Agence canadienne d'inspection des aliments")
 
-    # def test_fetch_chunk_token_row(self):
-    #     """Test fetching a chunk token row."""
-    #     with db.cursor(self.connection) as cursor:
-    #         crawler.fetch_chunk_token_row(cursor, "https://example.com")
-    #     self.connection.rollback()
-    #     self.assertFalse()
+    def test_fetch_chunk_token_row(self):
+        """Test fetching a chunk token row."""
+        with db.cursor(self.connection) as cursor:
+            row = crawler.fetch_chunk_token_row(cursor, "469812c5-190c-4e56-9f88-c8621592bcb5")
+        self.connection.rollback()
+        self.assertEqual(str(row['chunk_id']), "469812c5-190c-4e56-9f88-c8621592bcb5")
 