25
25
from api .pcm_globals import set_auth_cookies_in_context , logger , auth_cookies
26
26
from api .security .csrf .constants import CSRF_COOKIE_NAME
27
27
from api .security .csrf .csrf import csrf_needed
28
- from api .utils import disable_auth
28
+ from api .utils import disable_auth , read_and_delete_ssm_output_from_cloudwatch
29
29
from api .validation import validated
30
30
from api .validation .schemas import PCProxyArgs , PCProxyBody
31
31
32
32
USER_POOL_ID = os .getenv ("USER_POOL_ID" )
33
33
AUTH_PATH = os .getenv ("AUTH_PATH" )
34
34
API_BASE_URL = os .getenv ("API_BASE_URL" )
35
- API_VERSION = os .getenv ("API_VERSION" , "3.1.0" )
35
+ API_VERSION = sorted (set (os .getenv ("API_VERSION" , "3.1.0" ).strip ().split ("," )), key = lambda x : [- int (n ) for n in x .split ('.' )])
36
+ # Default version must be highest version so that it can be used for read operations due to backwards compatibility
37
+ DEFAULT_API_VERSION = API_VERSION [0 ]
36
38
API_USER_ROLE = os .getenv ("API_USER_ROLE" )
37
39
OIDC_PROVIDER = os .getenv ("OIDC_PROVIDER" )
38
40
CLIENT_ID = os .getenv ("CLIENT_ID" )
39
41
CLIENT_SECRET = os .getenv ("CLIENT_SECRET" )
40
42
SECRET_ID = os .getenv ("SECRET_ID" )
41
- SITE_URL = os .getenv ("SITE_URL" , API_BASE_URL )
42
43
SCOPES_LIST = os .getenv ("SCOPES_LIST" )
43
44
REGION = os .getenv ("AWS_DEFAULT_REGION" )
44
45
TOKEN_URL = os .getenv ("TOKEN_URL" , f"{ AUTH_PATH } /oauth2/token" )
47
48
JWKS_URL = os .getenv ("JWKS_URL" )
48
49
AUDIENCE = os .getenv ("AUDIENCE" )
49
50
USER_ROLES_CLAIM = os .getenv ("USER_ROLES_CLAIM" , "cognito:groups" )
51
+ SSM_LOG_GROUP_NAME = os .getenv ("SSM_LOG_GROUP_NAME" )
52
+ ARG_VERSION = "version"
50
53
51
54
try :
52
55
if (not USER_POOL_ID or USER_POOL_ID == "" ) and SECRET_ID :
62
65
JWKS_URL = os .getenv ("JWKS_URL" ,
63
66
f"https://cognito-idp.{ REGION } .amazonaws.com/{ USER_POOL_ID } /" ".well-known/jwks.json" )
64
67
68
+ def create_url_map (url_list ):
69
+ url_map = {}
70
+ if url_list :
71
+ for url in url_list .split ("," ):
72
+ if url :
73
+ pair = url .split ("=" )
74
+ url_map [pair [0 ]] = pair [1 ]
75
+ return url_map
76
+
77
+ API_BASE_URL_MAPPING = create_url_map (API_BASE_URL )
78
+ SITE_URL = os .getenv ("SITE_URL" , API_BASE_URL_MAPPING .get (DEFAULT_API_VERSION ))
79
+
80
+
65
81
66
82
def jwt_decode (token , audience = None , access_token = None ):
67
83
return jwt .decode (
@@ -164,7 +180,7 @@ def authenticate(groups):
164
180
165
181
if (not groups ):
166
182
return abort (403 )
167
-
183
+
168
184
jwt_roles = set (decoded .get (USER_ROLES_CLAIM , []))
169
185
groups_granted = groups .intersection (jwt_roles )
170
186
if len (groups_granted ) == 0 :
@@ -190,7 +206,7 @@ def get_scopes_list():
190
206
191
207
def get_redirect_uri ():
192
208
return f"{ SITE_URL } /login"
193
-
209
+
194
210
# Local Endpoints
195
211
196
212
@@ -232,9 +248,9 @@ def ec2_action():
232
248
def get_cluster_config_text (cluster_name , region = None ):
233
249
url = f"/v3/clusters/{ cluster_name } "
234
250
if region :
235
- info_resp = sigv4_request ("GET" , API_BASE_URL , url , params = {"region" : region })
251
+ info_resp = sigv4_request ("GET" , get_base_url ( request ) , url , params = {"region" : region })
236
252
else :
237
- info_resp = sigv4_request ("GET" , API_BASE_URL , url )
253
+ info_resp = sigv4_request ("GET" , get_base_url ( request ) , url )
238
254
if info_resp .status_code != 200 :
239
255
abort (info_resp .status_code )
240
256
@@ -264,10 +280,16 @@ def ssm_command(region, instance_id, user, run_command):
264
280
DocumentName = "AWS-RunShellScript" ,
265
281
Comment = f"Run ssm command." ,
266
282
Parameters = {"commands" : [command ]},
283
+ CloudWatchOutputConfig = {
284
+ 'CloudWatchLogGroupName' : SSM_LOG_GROUP_NAME ,
285
+ 'CloudWatchOutputEnabled' : True
286
+ },
267
287
)
268
288
269
289
command_id = ssm_resp ["Command" ]["CommandId" ]
270
290
291
+ logger .info (f"Submitted SSM command { command_id } " )
292
+
271
293
# Wait for command to complete
272
294
time .sleep (0.75 )
273
295
while time .time () - start < 60 :
@@ -282,7 +304,13 @@ def ssm_command(region, instance_id, user, run_command):
282
304
if status ["Status" ] != "Success" :
283
305
raise Exception (status ["StandardErrorContent" ])
284
306
285
- output = status ["StandardOutputContent" ]
307
+ output = read_and_delete_ssm_output_from_cloudwatch (
308
+ region = region ,
309
+ log_group_name = SSM_LOG_GROUP_NAME ,
310
+ command_id = command_id ,
311
+ instance_id = instance_id ,
312
+ )
313
+
286
314
return output
287
315
288
316
@@ -352,7 +380,7 @@ def sacct():
352
380
user ,
353
381
f"sacct { sacct_args } --json "
354
382
+ "| jq -c .jobs[0:120]\\ |\\ map\\ ({name,user,partition,state,job_id,exit_code\\ }\\ )" ,
355
- )
383
+ )
356
384
if type (accounting ) is tuple :
357
385
return accounting
358
386
else :
@@ -471,7 +499,7 @@ def get_dcv_session():
471
499
472
500
473
501
def get_custom_image_config ():
474
- image_info = sigv4_request ("GET" , API_BASE_URL , f"/v3/images/custom/{ request .args .get ('image_id' )} " ).json ()
502
+ image_info = sigv4_request ("GET" , get_base_url ( request ) , f"/v3/images/custom/{ request .args .get ('image_id' )} " ).json ()
475
503
configuration = requests .get (image_info ["imageConfiguration" ]["url" ])
476
504
return configuration .text
477
505
@@ -553,13 +581,7 @@ def get_instance_types():
553
581
ec2 = boto3 .client ("ec2" , config = config )
554
582
else :
555
583
ec2 = boto3 .client ("ec2" )
556
- filters = [
557
- {"Name" : "current-generation" , "Values" : ["true" ]},
558
- {"Name" : "instance-type" ,
559
- "Values" : [
560
- "c5*" , "c6*" , "c7*" , "g4*" , "g5*" , "g6*" , "hpc*" , "p3*" , "p4*" , "p5*" , "t2*" , "t3*" , "m6*" , "m7*" , "r*"
561
- ]},
562
- ]
584
+ filters = [{"Name" : "current-generation" , "Values" : ["true" ]}]
563
585
instance_paginator = ec2 .get_paginator ("describe_instance_types" )
564
586
instances_paginator = instance_paginator .paginate (Filters = filters )
565
587
instance_types = []
@@ -583,9 +605,9 @@ def _get_identity_from_token(decoded, claims):
583
605
identity ["username" ] = decoded ["username" ]
584
606
585
607
for claim in claims :
586
- if claim in decoded :
587
- identity ["attributes" ][claim ] = decoded [claim ]
588
-
608
+ if claim in decoded :
609
+ identity ["attributes" ][claim ] = decoded [claim ]
610
+
589
611
return identity
590
612
591
613
def get_identity ():
@@ -722,14 +744,20 @@ def _get_params(_request):
722
744
params .pop ("path" )
723
745
return params
724
746
747
+ def get_base_url (request ):
748
+ version = request .args .get (ARG_VERSION )
749
+ if version and str (version ) in API_VERSION :
750
+ return API_BASE_URL_MAPPING [str (version )]
751
+ return API_BASE_URL_MAPPING [DEFAULT_API_VERSION ]
752
+
725
753
726
754
pc = Blueprint ('pc' , __name__ )
727
755
728
756
@pc .get ('/' , strict_slashes = False )
729
757
@authenticated ({'admin' })
730
758
@validated (params = PCProxyArgs )
731
759
def pc_proxy_get ():
732
- response = sigv4_request (request .method , API_BASE_URL , request .args .get ("path" ), _get_params (request ))
760
+ response = sigv4_request (request .method , get_base_url ( request ) , request .args .get ("path" ), _get_params (request ))
733
761
return response .json (), response .status_code
734
762
735
763
@pc .route ('/' , methods = ['POST' ,'PUT' ,'PATCH' ,'DELETE' ], strict_slashes = False )
@@ -743,5 +771,5 @@ def pc_proxy():
743
771
except :
744
772
pass
745
773
746
- response = sigv4_request (request .method , API_BASE_URL , request .args .get ("path" ), _get_params (request ), body = body )
774
+ response = sigv4_request (request .method , get_base_url ( request ) , request .args .get ("path" ), _get_params (request ), body = body )
747
775
return response .json (), response .status_code
0 commit comments