Skip to content

Commit

Permalink
allow retrying of urls (#298)
Browse files: browse the repository at this point in the history
  • Loading branch information
lordlabuckdas authored Jun 29, 2021
1 parent 67aff95 commit a2b8e3b
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions snare/cloner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -95,7 +95,7 @@ async def process_link(self, url, level, check_host=False):
):
return None
if url.human_repr() not in self.visited_urls and (level + 1) <= self.max_depth:
await self.new_urls.put((url, level + 1))
await self.new_urls.put({"url": url, "level": level + 1, "try_count": 0})

res = None
try:
Expand Down Expand Up @@ -151,20 +151,22 @@ def _make_filename(self, url):
hash_name = m.hexdigest()
return file_name, hash_name

async def fetch_data(self, driver, current_url):
async def fetch_data(self, driver, current_url, level, try_count):
raise NotImplementedError

async def get_body(self, driver):
while not self.new_urls.empty():
print(animation[self.itr], end="\r")
self.itr = (self.itr + 1) % len(animation)
current_url, level = await self.new_urls.get()
current_url, level, try_count = (await self.new_urls.get()).values()
if try_count > 2:
continue
if current_url.human_repr() in self.visited_urls:
continue
self.visited_urls.append(current_url.human_repr())
file_name, hash_name = self._make_filename(current_url)
self.logger.debug("Cloned file: %s", file_name)
data, headers, content_type = await self.fetch_data(driver, current_url)
data, headers, content_type = await self.fetch_data(driver, current_url, level, try_count)

if data is not None:
self.meta[file_name]["hash"] = hash_name
Expand All @@ -183,13 +185,13 @@ async def get_body(self, driver):
if not carved_url.is_absolute():
carved_url = self.root.join(carved_url)
if carved_url.human_repr() not in self.visited_urls:
await self.new_urls.put((carved_url, level + 1))
await self.new_urls.put({"url": carved_url, "level": level + 1, "try_count": 0})

try:
with open(os.path.join(self.target_path, hash_name), "wb") as index_fh:
index_fh.write(data)
except TypeError:
await self.new_urls.put((current_url, level))
await self.new_urls.put({"url": current_url, "level": level, "try_count": try_count + 1})

async def get_root_host(self):
try:
Expand All @@ -204,7 +206,7 @@ async def get_root_host(self):


class SimpleCloner(BaseCloner):
async def fetch_data(self, session, current_url):
async def fetch_data(self, session, current_url, level, try_count):
data = None
headers = []
content_type = None
Expand All @@ -215,6 +217,7 @@ async def fetch_data(self, session, current_url):
data = await response.read()
except (aiohttp.ClientError, asyncio.TimeoutError) as client_error:
self.logger.error(client_error)
await self.new_urls.put({"url": current_url, "level": level, "try_count": try_count + 1})
else:
await response.release()
return [data, headers, content_type]
Expand All @@ -229,7 +232,7 @@ def get_content_type(headers):
return val.split(";")[0]
return None

async def fetch_data(self, browser, current_url):
async def fetch_data(self, browser, current_url, level, try_count):
data = None
headers = []
content_type = None
Expand All @@ -242,6 +245,7 @@ async def fetch_data(self, browser, current_url):
data = await response.buffer()
except Exception as err:
self.logger.error(err)
await self.new_urls.put({"url": current_url, "level": level, "try_count": try_count + 1})
finally:
if page:
await page.close()
Expand All @@ -268,8 +272,8 @@ async def run(self):
else:
driver = await launch()
try:
await self.runner.new_urls.put((self.runner.root, 0))
await self.runner.new_urls.put((self.runner.error_page, 0))
await self.runner.new_urls.put({"url": self.runner.root, "level": 0, "try_count": 0})
await self.runner.new_urls.put({"url": self.runner.error_page, "level": 0, "try_count": 0})
await self.runner.get_body(driver)
except KeyboardInterrupt:
raise
Expand Down

0 comments on commit a2b8e3b

Please sign in to comment.