Skip to content

Commit 1b016c1

Browse files
committed
Add network validation script executed in the sagemaker_ui_post_startup script
**Description** This change introduces a network validation script that tests whether certain AWS services are reachable by making read-only API calls with a set timeout. If a call exceeds the timeout, the script infers that the cause is a bad network setup, such as lacking the internet access or VPC endpoint needed to make the call. API calls that resolve (succeed or fail) within the timeout are taken to indicate a proper network setup. AWS services for compute connections and Git are checked in this script. More specifically, the script lists the DataZone connections to determine which services need to be checked. The unreachable services are aggregated and surfaced by writing to post-startup-status.json, which triggers the error notification in the IDE. **Testing** Tested in a SMUS portal with internet access, without internet access, and without internet access but with VPC endpoints to DataZone and S3.
1 parent 0fc2d44 commit 1b016c1

File tree

4 files changed

+400
-0
lines changed

4 files changed

+400
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
#!/bin/bash
# Network validation: probe required AWS services with timed read-only
# API calls and record the unreachable ones in a JSON status file.
set -eux

# --- Positional parameters ---
# $1: storage flag; defaults to "1" (Git storage).
is_s3_storage=${1:-"1"}
# $2: where the unreachable-services JSON is written.
network_validation_file=${2:-"/tmp/.network_validation.json"}
9+
10+
#######################################
# Write {"UnreachableServices": "<value>"} to the JSON status file.
# Best-effort by design: every failure is reported on stderr and
# swallowed (return 0) so that reporting can never abort the caller.
# Globals:   network_validation_file (read)
# Arguments: $1 - comma-separated service names (may be empty)
# Returns:   0 always
#######################################
write_unreachable_services_to_file() {
  local value="$1"
  local file="$network_validation_file"

  # Create the file if it doesn't exist
  if [ ! -f "$file" ]; then
    touch "$file" || {
      echo "Failed to create $file" >&2
      return 0
    }
  fi

  # Check file is writable
  if [ ! -w "$file" ]; then
    echo "Error: $file is not writable" >&2
    return 0
  fi

  # Guard the jq write: the function deliberately swallows the earlier
  # failures, so a jq failure must not kill the script under `set -e`.
  jq -n --arg value "$value" '{"UnreachableServices": $value}' > "$file" \
    || echo "Error: failed to write $file" >&2
}
32+
33+
# Configure AWS CLI region using environment variable REGION_NAME
aws configure set region "${REGION_NAME}"
echo "Successfully configured region to ${REGION_NAME}"

# Metadata file location containing DataZone info
sourceMetaData=/opt/ml/metadata/resource-metadata.json

# Extract necessary DataZone metadata fields via jq.
# BUGFIX: `// empty` maps missing/null fields to "" — plain `jq -r`
# emits the literal string "null", which the `-n` endpoint test below
# would accept and then pass `--endpoint-url null` to the CLI.
dataZoneDomainId=$(jq -r '.AdditionalMetadata.DataZoneDomainId // empty' < "$sourceMetaData")
dataZoneProjectId=$(jq -r '.AdditionalMetadata.DataZoneProjectId // empty' < "$sourceMetaData")
dataZoneEndPoint=$(jq -r '.AdditionalMetadata.DataZoneEndpoint // empty' < "$sourceMetaData")
dataZoneDomainRegion=$(jq -r '.AdditionalMetadata.DataZoneDomainRegion // empty' < "$sourceMetaData")
s3Path=$(jq -r '.AdditionalMetadata.ProjectS3Path // empty' < "$sourceMetaData")

# Extract bucket name from s3://bucket/...; empty string if no S3 path.
s3ValidationBucket=$(echo "${s3Path:-}" | sed -E 's#s3://([^/]+).*#\1#')

# List the project's DataZone connections (they determine which services
# must be checked), using the custom endpoint when one is configured.
if [ -n "$dataZoneEndPoint" ]; then
  response=$(aws datazone list-connections \
    --endpoint-url "$dataZoneEndPoint" \
    --domain-identifier "$dataZoneDomainId" \
    --project-identifier "$dataZoneProjectId" \
    --region "$dataZoneDomainRegion")
else
  response=$(aws datazone list-connections \
    --domain-identifier "$dataZoneDomainId" \
    --project-identifier "$dataZoneProjectId" \
    --region "$dataZoneDomainRegion")
fi
63+
64+
# Extract each connection item as a compact JSON string
connection_items=$(echo "$response" | jq -c '.items[]')

# Required AWS services for compute connections and Git.
# STS and S3 are always checked.
declare -A SERVICE_COMMANDS=(
  ["STS"]="aws sts get-caller-identity"
  ["S3"]="aws s3api list-objects --bucket \"$s3ValidationBucket\" --max-items 1"
)

# Track connection types found for conditional checks
declare -A seen_types=()

# Iterate over each connection to populate service commands conditionally
while IFS= read -r item; do
  # <<< always yields at least one line; skip it when there are no items.
  [ -n "$item" ] || continue

  # Extract connection type
  type=$(echo "$item" | jq -r '.type')
  seen_types["$type"]=1

  # For SPARK connections, check for Glue and EMR properties
  if [[ "$type" == "SPARK" ]]; then
    # If sparkGlueProperties present, add Glue check
    if echo "$item" | jq -e '.props.sparkGlueProperties' > /dev/null; then
      SERVICE_COMMANDS["Glue"]="aws glue get-databases --max-items 1"
    fi

    # EMR Serverless: derive the application ID from the compute ARN
    emr_arn=$(echo "$item" | jq -r '.props.sparkEmrProperties.computeArn // empty')
    if [[ "$emr_arn" == *"emr-serverless"* && "$emr_arn" == *"/applications/"* ]]; then
      emr_app_id=$(echo "$emr_arn" | sed -E 's#.*/applications/([^/]+)#\1#')

      # Only set the service command if the application ID is valid
      if [[ -n "$emr_app_id" ]]; then
        SERVICE_COMMANDS["EMR Serverless"]="aws emr-serverless get-application --application-id \"$emr_app_id\""
      fi
    fi
  fi
done <<< "$connection_items"

# BUGFIX: use ${...:-} on the associative-array lookups — under `set -u`
# referencing an unset key is a fatal "unbound variable" error, which
# crashed the script whenever that connection type was absent.
# Add Athena if an ATHENA connection was found.
if [[ -n "${seen_types["ATHENA"]:-}" ]]; then
  SERVICE_COMMANDS["Athena"]="aws athena list-data-catalogs --max-items 1"
fi

# Add Redshift checks if a REDSHIFT connection was found.
if [[ -n "${seen_types["REDSHIFT"]:-}" ]]; then
  SERVICE_COMMANDS["Redshift Clusters"]="aws redshift describe-clusters --max-records 20"
  SERVICE_COMMANDS["Redshift Serverless"]="aws redshift-serverless list-namespaces --max-results 1"
fi

# Git storage (flag == 1): check CodeConnections connectivity; the
# Domain Execution role carries the CodeConnections permissions, hence
# the dedicated credentials profile.
if [[ "$is_s3_storage" == "1" ]]; then
  SERVICE_COMMANDS["CodeConnections"]="aws codeconnections list-connections --max-results 1 --profile DomainExecutionRoleCreds"
fi
118+
119+
# Per-call timeout in seconds for every service API check.
api_time_out_limit=10
# Services whose check timed out (treated as unreachable).
unreachable_services=()
# Scratch directory: one result file per service, written by the
# parallel checks below.
results_dir=$(mktemp -d)

# Fan out: run every service check concurrently, each recording its
# outcome (OK / TIMEOUT / ERROR) in its own file.
for svc in "${!SERVICE_COMMANDS[@]}"; do
  {
    if timeout "${api_time_out_limit}s" bash -c "${SERVICE_COMMANDS[$svc]}" > /dev/null 2>&1; then
      echo "OK" > "$results_dir/$svc"
    else
      rc=$?
      if [ "$rc" -eq 124 ]; then
        # timeout(1) exits 124 when the command overran its limit
        echo "TIMEOUT" > "$results_dir/$svc"
      else
        # API answered with an error (e.g. permission denied)
        echo "ERROR" > "$results_dir/$svc"
      fi
    fi
  } &
done

# Barrier: let every background check finish.
wait

# Fan in: only TIMEOUT outcomes count as unreachable; fast API errors
# still prove the network path works.
for svc in "${!SERVICE_COMMANDS[@]}"; do
  outcome_file="$results_dir/$svc"
  if [ ! -f "$outcome_file" ]; then
    echo "$svc check did not produce a result file. Skipping."
    continue
  fi
  case "$(<"$outcome_file")" in
    TIMEOUT)
      echo "$svc API did NOT resolve within ${api_time_out_limit}s. Marking as unreachable."
      unreachable_services+=("$svc")
      ;;
    OK)
      echo "$svc API is reachable."
      ;;
    *)
      echo "$svc API returned an error (but not a timeout). Ignored for network check."
      ;;
  esac
done

# Remove the scratch directory.
rm -rf "$results_dir"

# Publish the outcome: comma+space separated list, or "" when all good.
if (( ${#unreachable_services[@]} > 0 )); then
  joined_services=$(IFS=','; echo "${unreachable_services[*]}")
  # Add spaces after commas for readability
  joined_services_with_spaces=${joined_services//,/,\ }
  write_unreachable_services_to_file "$joined_services_with_spaces"
  echo "Unreachable AWS Services: ${joined_services_with_spaces}"
else
  write_unreachable_services_to_file ""
  echo "All required AWS services reachable within ${api_time_out_limit}s"
fi

template/v2/dirs/etc/sagemaker-ui/sagemaker_ui_post_startup.sh

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,4 +204,23 @@ if [ "${SAGEMAKER_APP_TYPE_LOWERCASE}" = "jupyterlab" ]; then
204204
bash /etc/sagemaker-ui/workflows/sm-spark-cli-install.sh
205205
fi
206206

207+
# Execute network validation script, to check if any required AWS Services are unreachable
echo "Starting network validation script..."

network_validation_file="/tmp/.network_validation.json"

# Run the validation script; only if it succeeds, check unreachable services
if ! bash /etc/sagemaker-ui/network_validation.sh "$is_s3_storage_flag" "$network_validation_file"; then
  echo "Warning: network_validation.sh failed, skipping unreachable services check."
else
  # Pull the comma-separated list out of the JSON file; treat a missing
  # key or unreadable file as "nothing unreachable".
  failed_services=$(jq -r '.UnreachableServices // empty' "$network_validation_file" || echo "")
  if [[ -n "$failed_services" ]]; then
    # Example error message: Redshift Clusters, Athena, STS, Glue are unreachable. Please contact your admin.
    error_message="$failed_services are unreachable. Please contact your admin."
    write_status_to_file "error" "$error_message"
    echo "$error_message"
  fi
fi
225+
207226
write_status_to_file_on_script_complete
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
#!/bin/bash
#
# network_validation.sh — verify that required AWS services are reachable.
#
# Issues read-only API calls with a per-call timeout. A call that exceeds
# the timeout is inferred to indicate a broken network path (e.g. no
# internet access / missing VPC endpoint); a call that resolves within
# the timeout — successfully or with an API error — is inferred to
# indicate a working path. The set of services checked is derived from
# the project's DataZone connections. Unreachable services are written
# to a JSON file consumed by the post-startup status reporting.
#
# Usage: network_validation.sh [is_s3_storage] [output_file]
#   is_s3_storage : "1" (default, Git storage) adds the CodeConnections check
#   output_file   : JSON file receiving {"UnreachableServices": "..."}
set -eux

is_s3_storage=${1:-"1"}
network_validation_file=${2:-"/tmp/.network_validation.json"}

# Write {"UnreachableServices": "<value>"} to the output file.
# Best-effort: failures are logged to stderr and swallowed (return 0)
# so that reporting can never abort the script.
write_unreachable_services_to_file() {
  local value="$1"
  local file="$network_validation_file"

  # Create the file if it doesn't exist
  if [ ! -f "$file" ]; then
    touch "$file" || {
      echo "Failed to create $file" >&2
      return 0
    }
  fi

  # Check file is writable
  if [ ! -w "$file" ]; then
    echo "Error: $file is not writable" >&2
    return 0
  fi

  # Guard the write so a jq failure cannot kill the script under `set -e`.
  jq -n --arg value "$value" '{"UnreachableServices": $value}' > "$file" \
    || echo "Error: failed to write $file" >&2
}

# Configure AWS CLI region using environment variable REGION_NAME
aws configure set region "${REGION_NAME}"
echo "Successfully configured region to ${REGION_NAME}"

# Metadata file location containing DataZone info
sourceMetaData=/opt/ml/metadata/resource-metadata.json

# Extract necessary DataZone metadata fields via jq.
# BUGFIX: `// empty` maps missing/null fields to "" — plain `jq -r`
# emits the literal string "null", and the endpoint check below would
# then pass `--endpoint-url null` to the CLI.
dataZoneDomainId=$(jq -r '.AdditionalMetadata.DataZoneDomainId // empty' < "$sourceMetaData")
dataZoneProjectId=$(jq -r '.AdditionalMetadata.DataZoneProjectId // empty' < "$sourceMetaData")
dataZoneEndPoint=$(jq -r '.AdditionalMetadata.DataZoneEndpoint // empty' < "$sourceMetaData")
dataZoneDomainRegion=$(jq -r '.AdditionalMetadata.DataZoneDomainRegion // empty' < "$sourceMetaData")
s3Path=$(jq -r '.AdditionalMetadata.ProjectS3Path // empty' < "$sourceMetaData")

# Extract bucket name, fallback to empty string if not found
s3ValidationBucket=$(echo "${s3Path:-}" | sed -E 's#s3://([^/]+).*#\1#')

# Call AWS CLI list-connections, including endpoint if specified
if [ -n "$dataZoneEndPoint" ]; then
  response=$(aws datazone list-connections \
    --endpoint-url "$dataZoneEndPoint" \
    --domain-identifier "$dataZoneDomainId" \
    --project-identifier "$dataZoneProjectId" \
    --region "$dataZoneDomainRegion")
else
  response=$(aws datazone list-connections \
    --domain-identifier "$dataZoneDomainId" \
    --project-identifier "$dataZoneProjectId" \
    --region "$dataZoneDomainRegion")
fi

# Extract each connection item as a compact JSON string
connection_items=$(echo "$response" | jq -c '.items[]')

# Required AWS services for compute connections and Git.
# STS and S3 are always checked.
declare -A SERVICE_COMMANDS=(
  ["STS"]="aws sts get-caller-identity"
  ["S3"]="aws s3api list-objects --bucket \"$s3ValidationBucket\" --max-items 1"
)

# Track connection types found for conditional checks
declare -A seen_types=()

# Iterate over each connection to populate service commands conditionally
while IFS= read -r item; do
  # <<< always yields at least one line; skip it when there are no items.
  [ -n "$item" ] || continue

  type=$(echo "$item" | jq -r '.type')
  seen_types["$type"]=1

  # For SPARK connections, check for Glue and EMR properties
  if [[ "$type" == "SPARK" ]]; then
    # If sparkGlueProperties present, add Glue check
    if echo "$item" | jq -e '.props.sparkGlueProperties' > /dev/null; then
      SERVICE_COMMANDS["Glue"]="aws glue get-databases --max-items 1"
    fi

    # EMR Serverless: derive the application ID from the compute ARN
    emr_arn=$(echo "$item" | jq -r '.props.sparkEmrProperties.computeArn // empty')
    if [[ "$emr_arn" == *"emr-serverless"* && "$emr_arn" == *"/applications/"* ]]; then
      emr_app_id=$(echo "$emr_arn" | sed -E 's#.*/applications/([^/]+)#\1#')

      # Only set the service command if the application ID is valid
      if [[ -n "$emr_app_id" ]]; then
        SERVICE_COMMANDS["EMR Serverless"]="aws emr-serverless get-application --application-id \"$emr_app_id\""
      fi
    fi
  fi
done <<< "$connection_items"

# BUGFIX: use ${...:-} — under `set -u` referencing an unset associative
# array key is a fatal "unbound variable" error, which crashed the
# script whenever a connection type was absent.
if [[ -n "${seen_types["ATHENA"]:-}" ]]; then
  SERVICE_COMMANDS["Athena"]="aws athena list-data-catalogs --max-items 1"
fi

if [[ -n "${seen_types["REDSHIFT"]:-}" ]]; then
  SERVICE_COMMANDS["Redshift Clusters"]="aws redshift describe-clusters --max-records 20"
  SERVICE_COMMANDS["Redshift Serverless"]="aws redshift-serverless list-namespaces --max-results 1"
fi

# Git storage (flag == 1): the Domain Execution role carries the
# CodeConnections permissions, hence the dedicated profile.
if [[ "$is_s3_storage" == "1" ]]; then
  SERVICE_COMMANDS["CodeConnections"]="aws codeconnections list-connections --max-results 1 --profile DomainExecutionRoleCreds"
fi

# Timeout (seconds) for each API call
api_time_out_limit=10
# Array to accumulate unreachable services
unreachable_services=()
# Temporary directory holding one result file per service
temp_dir=$(mktemp -d)

# Launch all service API checks in parallel background jobs
for service in "${!SERVICE_COMMANDS[@]}"; do
  {
    if timeout "${api_time_out_limit}s" bash -c "${SERVICE_COMMANDS[$service]}" > /dev/null 2>&1; then
      echo "OK" > "$temp_dir/$service"
    else
      exit_code=$?
      if [ "$exit_code" -eq 124 ]; then
        # 124 is the `timeout` utility's "command timed out" status
        echo "TIMEOUT" > "$temp_dir/$service"
      else
        # Other errors (e.g. permission denied) still prove reachability
        echo "ERROR" > "$temp_dir/$service"
      fi
    fi
  } &
done

# Wait for all background jobs to complete before continuing
wait

# Process each service's result; only timeouts count as unreachable
for service in "${!SERVICE_COMMANDS[@]}"; do
  result_file="$temp_dir/$service"
  if [ -f "$result_file" ]; then
    result=$(<"$result_file")
    if [[ "$result" == "TIMEOUT" ]]; then
      echo "$service API did NOT resolve within ${api_time_out_limit}s. Marking as unreachable."
      unreachable_services+=("$service")
    elif [[ "$result" == "OK" ]]; then
      echo "$service API is reachable."
    else
      echo "$service API returned an error (but not a timeout). Ignored for network check."
    fi
  else
    echo "$service check did not produce a result file. Skipping."
  fi
done

# Cleanup temporary directory
rm -rf "$temp_dir"

# Write unreachable services to file if any, else write empty string
if (( ${#unreachable_services[@]} > 0 )); then
  joined_services=$(IFS=','; echo "${unreachable_services[*]}")
  # Add spaces after commas for readability
  joined_services_with_spaces=${joined_services//,/,\ }
  write_unreachable_services_to_file "$joined_services_with_spaces"
  echo "Unreachable AWS Services: ${joined_services_with_spaces}"
else
  write_unreachable_services_to_file ""
  echo "All required AWS services reachable within ${api_time_out_limit}s"
fi

0 commit comments

Comments
 (0)