Update github_secret_scanner.py
vmwclabot2 authored Feb 14, 2025
1 parent c04a49f commit f3d7eb7
Showing 1 changed file with 124 additions and 85 deletions.
209 changes: 124 additions & 85 deletions scripts/github_secret_scanner.py
@@ -1,136 +1,176 @@
import argparse
import requests
import logging
import csv
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from datetime import datetime, timedelta
import sys
import time
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry


# Logger class to manage logging levels and messages
class Logger:
    """Singleton: the first instantiation configures the root logger; later
    calls return the same instance and ignore their log_level argument."""
    _instance = None

    def __new__(cls, log_level='INFO'):
        if cls._instance is None:
            cls._instance = super(Logger, cls).__new__(cls)
            numeric_level = getattr(logging, log_level.upper(), None)
            if not isinstance(numeric_level, int):
                raise ValueError(f'Invalid log level: {log_level}')
            logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
        return cls._instance

    def __getattr__(self, name):
        # Delegate info()/error()/exception() etc. to the logging module so a
        # Logger instance can be used like the module itself.
        return getattr(logging, name)
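
# A minimal usage sketch (illustrative only): both names below refer to the
# same singleton, and only the first call's log level takes effect.
#   log = Logger('DEBUG')    # configures logging.basicConfig once
#   same = Logger('ERROR')   # returns the existing instance; 'ERROR' is ignored
#   assert log is same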


# GitHubClient class to handle GitHub API interactions (repos, secret scanning)
class GitHubClient:
    def __init__(self, token, max_retries=3):
        self.token = token
        self.base_url = "https://api.github.com"
        self.headers = {"Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json"}
        self.max_retries = max_retries
        self.session = self._create_session()  # Create the session once and reuse it
        self.logger = Logger()  # Reuse the Logger singleton
        self.rate_limit_remaining = None
        self.rate_limit_reset = None

    def _create_session(self):
        """Create a requests session with retry logic."""
        session = requests.Session()
        session.headers.update(self.headers)
        retry_strategy = Retry(
            total=self.max_retries,
            backoff_factor=2,  # Exponential backoff
            status_forcelist=[429, 500, 502, 503, 504],  # Retry on these status codes
            allowed_methods=["GET"]  # Only retry GET requests
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("https://", adapter)
        session.mount("http://", adapter)  # Usually not needed for the GitHub API
        return session
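
    # Retry schedule sketch (assuming urllib3's backoff_factor * 2**(retry - 1)
    # formula): with backoff_factor=2 and total=3 the sleeps are roughly
    # 2s, 4s, then 8s; exact timing varies across urllib3 versions.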

    def _handle_rate_limit(self):
        """Checks remaining rate limit and waits if necessary."""
        if self.rate_limit_remaining is None or self.rate_limit_remaining < 50:  # Threshold
            self.validate_token()  # Updates the rate limits

        if self.rate_limit_remaining < 10:
            wait_time = (self.rate_limit_reset - datetime.now()).total_seconds() + 5  # Add a buffer
            if wait_time > 0:
                self.logger.info(f"Rate limit approaching. Waiting for {wait_time:.0f} seconds.")
                time.sleep(wait_time)
                self.validate_token()  # Refresh after waiting

    def _request(self, method, url, **kwargs):
        """Centralized request handling with rate limit checks and error handling."""
        self._handle_rate_limit()
        try:
            response = self.session.request(method, url, **kwargs)
            response.raise_for_status()

            # GitHub reports quota in headers, e.g. X-RateLimit-Remaining: 4321
            # and X-RateLimit-Reset: 1739520000 (epoch seconds); cache both.
            if 'X-RateLimit-Remaining' in response.headers:
                self.rate_limit_remaining = int(response.headers['X-RateLimit-Remaining'])
                self.rate_limit_reset = datetime.fromtimestamp(int(response.headers['X-RateLimit-Reset']))

            return response
        except requests.exceptions.RequestException as e:
            self.logger.exception(f"Request failed: {e}")  # exception() logs the full traceback
            raise

    def validate_token(self):
        url = f"{self.base_url}/rate_limit"
        try:
            # Call the session directly: going through _request() would
            # re-enter _handle_rate_limit() and recurse before any limits
            # are known.
            response = self.session.get(url)
            response.raise_for_status()
            rate_limit = response.json()['resources']['core']
            self.rate_limit_remaining = rate_limit['remaining']
            self.rate_limit_reset = datetime.fromtimestamp(rate_limit['reset'])

            if self.rate_limit_remaining > 0:
                self.logger.info(f"Rate Limit: {self.rate_limit_remaining} remaining. Reset at {self.rate_limit_reset}")
            else:
                self.logger.error(f"Rate limit exceeded. Reset at {self.rate_limit_reset}")
                raise Exception("Rate limit exceeded.")
        except requests.exceptions.RequestException as e:
            self.logger.exception(f"Token validation failed: {e}")
            raise

    @lru_cache(maxsize=100)
    def fetch_default_branch(self, org, repo):
        """Fetch default branch of a repo (cached to avoid redundant calls)."""
        url = f"{self.base_url}/repos/{org}/{repo}"
        try:
            response = self._request("GET", url)
            return response.json()['default_branch']
        except requests.exceptions.RequestException as e:
            self.logger.exception(f"Failed to fetch default branch for {repo}: {e}")
            raise

    def fetch_repositories(self, org):
        repos = []
        url = f"{self.base_url}/orgs/{org}/repos?per_page=100"
        try:
            while url:
                response = self._request("GET", url)
                repos.extend(response.json())
                url = response.links.get('next', {}).get('url')
        except requests.exceptions.RequestException as e:
            self.logger.exception(f"Failed to fetch repositories for {org}: {e}")
            raise
        return repos

    def fetch_secret_alerts(self, org, repo):
        alerts = []
        url = f"{self.base_url}/repos/{org}/{repo}/secret-scanning/alerts?per_page=100&state=open"  # Only open alerts
        try:
            while url:
                response = self._request("GET", url)
                alerts.extend(response.json())
                url = response.links.get('next', {}).get('url')
        except requests.exceptions.RequestException as e:
            self.logger.exception(f"Failed to fetch alerts for {repo}: {e}")
            raise
        return alerts
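
    # Pagination note: GitHub sends a Link header such as
    #   Link: <https://api.github.com/organizations/1/repos?page=2>; rel="next"
    # which requests parses into response.links; the loops above follow
    # links['next']['url'] until no 'next' relation remains.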


# SecretScanner class to handle the logic related to scanning and reporting
class SecretScanner:
    def __init__(self, org, token, output_file, include_inactive=False, log_level='INFO', max_workers=10, max_retries=3):
        self.org = org
        self.token = token
        self.output_file = output_file
        self.include_inactive = include_inactive
        self.max_workers = max_workers
        # Initialize the Logger singleton first so the requested log level is
        # in effect before GitHubClient grabs the same instance.
        self.logger = Logger(log_level)
        self.client = GitHubClient(self.token, max_retries)

    def is_alert_active(self, org, repo, alert):
        """Checks if an alert is active (still present at the HEAD of the default branch)."""
        try:
            default_branch = self.client.fetch_default_branch(org, repo)
            if 'locations' in alert and alert['locations']:
                for location in alert['locations']:
                    if location.get('type') == 'commit':
                        details_url = location.get('details_url')
                        if details_url:  # Some locations lack a details_url
                            location_details = self.client._request("GET", details_url).json()
                            path = location_details.get('path')
                            commit_sha = location_details.get('commit_sha')
                            if path and commit_sha:  # Details do not always include a path
                                commits_url = f"{self.client.base_url}/repos/{org}/{repo}/commits?path={path}&sha={default_branch}"
                                commits = self.client._request("GET", commits_url).json()

                                if commits and commits[0]['sha'] == commit_sha:
                                    return True
            return False  # Not found, or no locations
        except requests.exceptions.RequestException as e:
            self.logger.exception(f"Error checking if alert is active for {repo}: {e}")
            return False  # Treat as inactive when errors occur
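
    # Note: activity is a heuristic here. An alert counts as active only when
    # the most recent commit touching the secret's path on the default branch
    # is the very commit recorded in the alert's location details.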

    def generate_report(self):
        try:
@@ -145,21 +185,19 @@ def generate_report(self):
writer.writerow(["Repository", "Alert ID", "Secret Type", "Status", "Alert URL", "Last Updated"])

with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# Submit all fetch_secret_alerts tasks
future_to_repo = {
executor.submit(self.client.fetch_secret_alerts, self.org, repo['name']): repo
for repo in repos
}

# Process results as they become available
for future in as_completed(future_to_repo):
repo = future_to_repo[future]
try:
alerts = future.result() # Get the result (list of alerts)
alerts = future.result()
for alert in alerts:
is_active = self.is_alert_active(self.org, repo['name'], alert)
if self.include_inactive or is_active:
status = "Active" if is_active else "Inactive" #Explicit state.
status = "Active" if is_active else "Inactive"
writer.writerow([
repo['name'],
alert['number'],
                                        alert['secret_type'],
                                        status,
                                        alert['html_url'],
                                        alert['updated_at']
                                    ])
                        except Exception as e:
                            self.logger.exception(f"Error processing alerts for {repo['name']}: {e}")

            self.logger.info(f"Report generated: {self.output_file}")

        except Exception as e:
            self.logger.exception(f"Failed to generate report: {e}")
            sys.exit(1)  # Exit with an error code to signal failure to the workflow
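
    # The resulting CSV looks roughly like this (values illustrative):
    #   Repository,Alert ID,Secret Type,Status,Alert URL,Last Updated
    #   api-server,42,github_personal_access_token,Active,https://github.com/my-org/api-server/security/secret-scanning/42,2025-01-31T12:00:00Z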


# ReportGenerator class to process the report (kept as-is, since no issues found)
class ReportGenerator:
    @staticmethod
    def count_alerts(input_file):
        total = 0
        with open(input_file, mode='r', encoding='utf-8') as file:
            reader = csv.reader(file)
            next(reader)  # Skip header
            for _ in reader:
                total += 1
        return total

    @staticmethod
    def count_active_alerts(input_file):
        active = 0
        with open(input_file, mode='r', encoding='utf-8') as file:
            reader = csv.reader(file)
            next(reader)  # Skip header
            for row in reader:
                if row[3] == "Active":  # Column 3 is "Status"
                    active += 1
        return active


# Main function to handle arguments and initiate the scanning process
def main():
    parser = argparse.ArgumentParser(description="GitHub Secret Scanner")
@@ -209,24 +249,23 @@ def main():
parser.add_argument("--include-inactive", action='store_true', help="Include inactive alerts in the report")
parser.add_argument("--log-level", default="INFO", help="Logging level")
parser.add_argument("--max-workers", type=int, default=10, help="Maximum concurrent workers")
parser.add_argument("--max-retries", type=int, default=3, help="Maximum retries for API requests") # Added argument

args = parser.parse_args()
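
    # Example invocation (hypothetical org/token/output values):
    #   python github_secret_scanner.py --org my-org --token ghp_xxxx \
    #       --output report.csv --include-inactive --log-level DEBUG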

    try:
        # Instantiate SecretScanner and generate the report
        scanner = SecretScanner(args.org, args.token, args.output, args.include_inactive,
                                args.log_level, args.max_workers, args.max_retries)
        scanner.generate_report()

        # Process report statistics
        total_alerts = ReportGenerator.count_alerts(args.output)
        active_alerts = ReportGenerator.count_active_alerts(args.output)

        logging.info(f"Total alerts found: {total_alerts}")
        logging.info(f"Active alerts: {active_alerts}")

    except Exception as e:
        logging.exception(f"An error occurred: {e}")
        sys.exit(1)  # Exit with an error code so the CI workflow sees the failure


if __name__ == "__main__":
    main()
