Skip to content

Commit

Permalink
cleaning + linting
Browse files: browse the repository at this point in the history
  • Loading branch information
arthurprevot committed Nov 27, 2023
1 parent a8e1c2d commit 3e05fb3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
11 changes: 5 additions & 6 deletions yaetos/airflow_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
from textwrap import dedent, indent


def get_template(params, param_extras):

params['KeepJobFlowAliveWhenNoSteps'] = params['deploy_args'].get('leave_on', False)
Expand All @@ -19,17 +20,15 @@ def get_template(params, param_extras):
""".format(**params)

instance_groups_extra = instance_groups_extra if params['emr_core_instances'] != 0 else ''
params['instance_groups_extra'] = indent(instance_groups_extra, ' '*12)
params['instance_groups_extra'] = indent(instance_groups_extra, ' ' * 12)

# Set extra params, params not available in template but overloadable
# import ipdb; ipdb.set_trace() # will drop to python terminal here to inspect # noqa: E702
lines = ''
for item in param_extras.keys():
entries = item.replace('airflow.', '').split('.')
entries = '"]["'.join(entries)
line = f'DAG_ARGS["{entries}"] = {param_extras[item]}\n' + ' ' * 4
lines += line

params['extras'] = lines

template = """
Expand Down Expand Up @@ -129,21 +128,21 @@ def get_template(params, param_extras):
]
with DAG(**DAG_ARGS) as dag:
cluster_creator = EmrCreateJobFlowOperator(
task_id='start_emr_cluster',
aws_conn_id='aws_default',
emr_conn_id='emr_default',
job_flow_overrides=CLUSTER_JOB_FLOW_OVERRIDES
)
step_adder = EmrAddStepsOperator(
task_id='add_steps',
job_flow_id="{{{{ task_instance.xcom_pull(task_ids='start_emr_cluster', key='return_value') }}}}",
aws_conn_id='aws_default',
steps=EMR_STEPS,
)
step_checker = EmrStepSensor(
task_id='watch_step',
job_flow_id="{{{{ task_instance.xcom_pull('start_emr_cluster', key='return_value') }}}}",
Expand Down
8 changes: 4 additions & 4 deletions yaetos/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, deploy_args, app_args):
# Paths
self.s3_logs = CPt(app_args.get('s3_logs', 's3://').replace('{root_path}', self.app_args.get('root_path', '')))
self.s3_bucket_logs = self.s3_logs.bucket
self.metadata_folder = 'pipelines_metadata' # TODO remove
self.metadata_folder = 'pipelines_metadata' # TODO remove
self.pipeline_name = self.generate_pipeline_name(self.deploy_args['mode'], self.app_args['job_name'], self.user) # format: some_job.some_user.20181204.153429
self.job_log_path = self.get_job_log_path() # format: yaetos/logs/some_job.some_user.20181204.153429
self.job_log_path_with_bucket = '{}/{}'.format(self.s3_bucket_logs, self.job_log_path) # format: bucket-tempo/yaetos/logs/some_job.some_user.20181204.153429
Expand Down Expand Up @@ -695,7 +695,7 @@ def create_dags(self):
start_date = 'None'
else:
start_date = f'dateutil.parser.parse("{start_date}")'

# Set schedule, should be string evaluable in python, or string compatible with airflow
freq_input = self.deploy_args.get('frequency', '@once')
if freq_input.startswith('{') and freq_input.endswith('}'):
Expand All @@ -705,7 +705,7 @@ def create_dags(self):
else:
schedule = f"'{freq_input}'"

params={
params = {
'ec2_instance_slaves': self.ec2_instance_slaves,
'emr_core_instances': self.emr_core_instances,
'package_path_with_bucket': self.package_path_with_bucket,
Expand All @@ -730,7 +730,7 @@ def create_dags(self):
content = get_template(params, param_extras)
if not os.path.isdir(self.DAGS):
os.mkdir(self.DAGS)

job_dag_name = self.set_job_dag_name(self.app_args['job_name'])
fname = self.DAGS / Pt(job_dag_name)

Expand Down
2 changes: 1 addition & 1 deletion yaetos/pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def save_pandas_local(df, path, save_method='to_csv', save_kwargs={}):
# --- other ----

def query_pandas(query_str, dfs):
assert DUCKDB_SETUP == True
assert DUCKDB_SETUP is True
con = duckdb.connect(database=':memory:')
for key, value in dfs.items():
con.register(key, value)
Expand Down

0 comments on commit 3e05fb3

Please sign in to comment.