From b7bdaea78eef75d7e9e60f2483506c162dc7d2bc Mon Sep 17 00:00:00 2001 From: lordlabuckdas <55460753+lordlabuckdas@users.noreply.github.com> Date: Mon, 28 Jun 2021 20:09:10 +0530 Subject: [PATCH] allow retrying of urls --- snare/cloner.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/snare/cloner.py b/snare/cloner.py index b910696b..2b9f80a5 100644 --- a/snare/cloner.py +++ b/snare/cloner.py @@ -95,7 +95,7 @@ async def process_link(self, url, level, check_host=False): ): return None if url.human_repr() not in self.visited_urls and (level + 1) <= self.max_depth: - await self.new_urls.put((url, level + 1)) + await self.new_urls.put({"url": url, "level": level + 1, "try_count": 0}) res = None try: @@ -151,20 +151,20 @@ def _make_filename(self, url): hash_name = m.hexdigest() return file_name, hash_name - async def fetch_data(self, driver, current_url): + async def fetch_data(self, driver, current_url, level, try_count): raise NotImplementedError async def get_body(self, driver): while not self.new_urls.empty(): print(animation[self.itr], end="\r") self.itr = (self.itr + 1) % len(animation) - current_url, level = await self.new_urls.get() + current_url, level, try_count = (await self.new_urls.get()).values() if current_url.human_repr() in self.visited_urls: continue self.visited_urls.append(current_url.human_repr()) file_name, hash_name = self._make_filename(current_url) self.logger.debug("Cloned file: %s", file_name) - data, headers, content_type = await self.fetch_data(driver, current_url) + data, headers, content_type = await self.fetch_data(driver, current_url, level, try_count) if data is not None: self.meta[file_name]["hash"] = hash_name @@ -183,13 +183,13 @@ async def get_body(self, driver): if not carved_url.is_absolute(): carved_url = self.root.join(carved_url) if carved_url.human_repr() not in self.visited_urls: - await self.new_urls.put((carved_url, level + 1)) + await self.new_urls.put({"url": carved_url, "level": level + 1, "try_count": 0}) try: with open(os.path.join(self.target_path, hash_name), "wb") as index_fh: index_fh.write(data) except TypeError: - await self.new_urls.put((current_url, level)) + await self.new_urls.put({"url": current_url, "level": level, "try_count": try_count + 1}) async def get_root_host(self): try: @@ -204,7 +204,7 @@ async def get_root_host(self): class SimpleCloner(BaseCloner): - async def fetch_data(self, session, current_url): + async def fetch_data(self, session, current_url, level, try_count): data = None headers = [] content_type = None @@ -215,6 +215,7 @@ async def fetch_data(self, session, current_url): data = await response.read() except (aiohttp.ClientError, asyncio.TimeoutError) as client_error: self.logger.error(client_error) + await self.new_urls.put({"url": current_url, "level": level, "try_count": try_count + 1}) else: await response.release() return [data, headers, content_type] @@ -229,7 +230,7 @@ def get_content_type(headers): return val.split(";")[0] return None - async def fetch_data(self, browser, current_url): + async def fetch_data(self, browser, current_url, level, try_count): data = None headers = [] content_type = None @@ -242,6 +243,7 @@ async def fetch_data(self, browser, current_url): data = await response.buffer() except Exception as err: self.logger.error(err) + await self.new_urls.put({"url": current_url, "level": level, "try_count": try_count + 1}) finally: if page: await page.close() @@ -268,8 +270,8 @@ async def run(self): else: driver = await launch() try: - await self.runner.new_urls.put((self.runner.root, 0)) - await self.runner.new_urls.put((self.runner.error_page, 0)) + await self.runner.new_urls.put({"url": self.runner.root, "level": 0, "try_count": 0}) + await self.runner.new_urls.put({"url": self.runner.error_page, "level": 0, "try_count": 0}) await self.runner.get_body(driver) except KeyboardInterrupt: raise