Skip to content

Commit

Permalink
Merge pull request #4 from duo-labs/andresriancho-master
Browse files Browse the repository at this point in the history
Combines Andre's changes with some edits to ensure no profile needs to be set
  • Loading branch information
jordan-wright authored Apr 17, 2020
2 parents 8430dd5 + d44d783 commit 515f29e
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions resources/partitioner/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,30 @@
import datetime
import re
import os
import argparse
import sys
from pathlib import Path

__version__ = "1.0.0"

def get_session():
parser = argparse.ArgumentParser()
parser.add_argument(
'--profile',
help='AWS profile from ~/.aws/config',
required=False,
default=None
)
args = parser.parse_args()

try:
session = boto3.Session(profile_name=args.profile)
except Exception as e:
print('%s' % e)
sys.exit(1)

return session


class athena_querier:
database = ""
Expand All @@ -19,7 +41,7 @@ class athena_querier:
def __init__(self, database, output):
self.database = database
self.output_bucket = output
self.athena = boto3.client("athena")
self.athena = get_session().client("athena")

self.query(
"CREATE DATABASE IF NOT EXISTS {db} {comment}".format(
Expand Down Expand Up @@ -98,15 +120,16 @@ def wait_for_query_to_complete(self, queryExecutionId):


def main():

print("Starting cloudtrail_partitioner {}".format(__version__))

# Read config
config = {}
try:
config_file = Path("../../config/config.yaml")
with open(config_file, "r") as stream:
config = yaml.safe_load(stream)
except Exception as e:
logging.info("Unable to open config file, will try getting config from environment variables")
print("Unable to open config file, will try getting config from environment variables")

# Override the config file with the environment variables
if 'S3_BUCKET_CONTAINING_LOGS' in os.environ:
Expand All @@ -126,11 +149,11 @@ def main():
raise Exception("No configuration info found")

# Check the credentials and get the current region and account id
sts = boto3.client("sts")
sts = get_session().client("sts")
identity = sts.get_caller_identity()
logging.info("Using AWS identity: {}".format(identity["Arn"]))
print("Using AWS identity: {}".format(identity["Arn"]))
current_account_id = identity["Account"]
current_region = boto3.session.Session().region_name
current_region = get_session().region_name

# Get the default output bucket if one is not given
if config['output_s3_bucket'] == 'default':
Expand All @@ -143,16 +166,20 @@ def main():
athena = athena_querier(config['database'], "s3://"+config['output_s3_bucket'])

# Get all regions (needed for creating partitions)
ec2 = boto3.client("ec2")
ec2 = get_session().client("ec2")
region_response = ec2.describe_regions(AllRegions=True)["Regions"]
regions = []
for region in region_response:
regions.append(region["RegionName"])

# Ensure the CloudTrail log folder has the expected contents
s3 = boto3.client("s3")
s3 = get_session().client("s3")
log_path_prefix = config["cloudtrail_prefix"]

# Users will most likely forget this in the config, so we add it here
if not log_path_prefix.endswith('/'):
log_path_prefix += '/'

# Ensure we're running in the same region as the bucket
bucket_location = s3.get_bucket_location(Bucket=config["s3_bucket_containing_logs"])["LocationConstraint"]
if bucket_location is None:
Expand All @@ -167,15 +194,15 @@ def main():
Delimiter="/",
MaxKeys=1,
)

if "CommonPrefixes" not in resp or len(resp["CommonPrefixes"]) == 0:
exit(
"ERROR: S3 bucket has no contents. Ensure you have logs at s3://{bucket}/{path}".format(
bucket=config["s3_bucket_containing_logs"], path=log_path_prefix
)
)

if resp["CommonPrefixes"][0]["Prefix"] != "AWSLogs/":
if resp["CommonPrefixes"][0]["Prefix"] != log_path_prefix + "AWSLogs/":
exit(
"ERROR: S3 bucket path is incorrect. Ensure you have logs at s3://{bucket}/{path}/AWSLogs".format(
bucket=config["s3_bucket_containing_logs"], path=log_path_prefix
Expand Down Expand Up @@ -208,7 +235,7 @@ def main():
elif re.match("^[0-d]{12}$", directory_name):
accounts.append({"account_id": directory_name, "path_prefix": prefix})
else:
logger.info("Unexpected folder: {}".format(directory_name))
print("Unexpected folder: {}".format(directory_name))

# String to hold the SQL query that creates a view to allow searching all the tables.
view_query = ""
Expand Down

0 comments on commit 515f29e

Please sign in to comment.