Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for AWS Launch Template Configuration #2668

Merged
merged 19 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
c0b0c87
Add launch_template config options to aws cluster
viniciusdc Aug 28, 2024
245db3b
add ami_type options to pydantic schema to reduce HCL conditionals
viniciusdc Aug 28, 2024
38b7d0d
add dynamic launch_template to eks_node_group
viniciusdc Aug 28, 2024
1578953
Merge branch 'develop' into 2603-aws-node-launch-template
viniciusdc Aug 30, 2024
49f69ab
Merge branch 'develop' into 2603-aws-node-launch-template
viniciusdc Sep 2, 2024
b93361b
small cleanup refactoring of launch_template model
viniciusdc Sep 6, 2024
c43d8db
Merge branch 'develop' into 2603-aws-node-launch-template
viniciusdc Sep 6, 2024
4afb503
use exclude for ami_type instead of private method
viniciusdc Sep 6, 2024
4c2aee7
Merge branch 'develop' into 2603-aws-node-launch-template
viniciusdc Sep 11, 2024
7981ba9
fix missing var name & fix deployment bug & rm validation restrictions
viniciusdc Sep 11, 2024
6aafcdc
fixes
viniciusdc Sep 16, 2024
c211fa6
fixes on ami_id
viniciusdc Sep 17, 2024
1f392e8
add try to assert block to inspect error
viniciusdc Sep 17, 2024
4952589
Merge branch 'develop' into 2603-aws-node-launch-template
viniciusdc Sep 17, 2024
7f909d1
Merge branch 'develop' into 2603-aws-node-launch-template
viniciusdc Sep 17, 2024
50c6a5f
fix user_data and CUSTOM ami_type logic
viniciusdc Sep 18, 2024
749ae38
[pre-commit.ci] Apply automatic pre-commit fixes
pre-commit-ci[bot] Sep 18, 2024
7df8ca2
Merge branch 'develop' into 2603-aws-node-launch-template
viniciusdc Sep 18, 2024
5a6bda3
rm aux aws.amis method
viniciusdc Sep 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ class AzureInputVars(schema.Base):
workload_identity_enabled: bool = False


class AWSAmiTypes(enum.Enum):
    # Closed set of values accepted for a node group's `ami_type`; the value
    # is forwarded to the `ami_type` argument of the EKS node group resource.
    AL2_x86_64 = "AL2_x86_64"  # inferred when no ami_type is given and gpu is off
    AL2_x86_64_GPU = "AL2_x86_64_GPU"  # inferred when no ami_type is given and gpu is on
    CUSTOM = "CUSTOM"  # inferred when the launch template supplies its own ami_id


class AWSNodeLaunchTemplate(schema.Base):
    # Optional per-node-group launch template settings; both fields default to
    # None, i.e. "not configured".

    # Rendered as its own shell-script part of the instance user data
    # (`node_pre_bootstrap_command` in the user-data template) when not None.
    pre_bootstrap_command: Optional[str] = None
    # Used as the launch template `image_id`; when set, the EKS bootstrap
    # command is also written into the user data.
    ami_id: Optional[str] = None


class AWSNodeGroupInputVars(schema.Base):
name: str
instance_type: str
Expand All @@ -137,6 +148,28 @@ class AWSNodeGroupInputVars(schema.Base):
max_size: int
single_subnet: bool
permissions_boundary: Optional[str] = None
ami_type: Optional[AWSAmiTypes] = None
launch_template: Optional[AWSNodeLaunchTemplate] = None

@field_validator("ami_type", mode="before")
@classmethod
def _infer_and_validate_ami_type(cls, value, values) -> str:
    """Infer ``ami_type`` when it is unset and reject GPU-incompatible values.

    Args:
        value: Raw ``ami_type`` input (string, ``AWSAmiTypes`` member or None).
        values: Validation context. pydantic v2 passes a ``ValidationInfo``
            whose ``.data`` mapping holds the previously validated fields; a
            plain mapping is accepted too.

    Returns:
        The inferred AMI type name when ``value`` is empty, otherwise
        ``value`` unchanged.

    Raises:
        ValueError: If ``AL2_x86_64`` is requested for a GPU node group.
    """
    # ``ValidationInfo`` has no ``.get``; the validated fields live in ``.data``.
    data = getattr(values, "data", values) or {}
    gpu_enabled = bool(data.get("gpu", False))

    # Auto-set ami_type if not provided
    if not value:
        # NOTE(review): ``data`` only contains fields declared *before*
        # ``ami_type``; ``launch_template`` is declared after it, so this
        # lookup is empty unless the field order changes. Also, this validator
        # does not run for an omitted ``ami_type`` unless the field enables
        # ``validate_default`` -- confirm both at the class level.
        launch_template = data.get("launch_template")
        if isinstance(launch_template, dict):
            ami_id = launch_template.get("ami_id")
        else:
            ami_id = getattr(launch_template, "ami_id", None)
        if ami_id:
            return "CUSTOM"
        return "AL2_x86_64_GPU" if gpu_enabled else "AL2_x86_64"

    # Explicit validation; compare on the raw string so that enum members are
    # checked as well as plain strings.
    requested = getattr(value, "value", value)
    if requested == "AL2_x86_64" and gpu_enabled:
        raise ValueError(
            "ami_type 'AL2_x86_64' cannot be used with GPU enabled (gpu=True)."
        )
    return value


class AWSInputVars(schema.Base):
Expand Down Expand Up @@ -449,6 +482,7 @@ class AWSNodeGroup(schema.Base):
gpu: bool = False
single_subnet: bool = False
permissions_boundary: Optional[str] = None
launch_template: Optional[AWSNodeLaunchTemplate] = None


DEFAULT_AWS_NODE_GROUPS = {
Expand Down Expand Up @@ -525,6 +559,7 @@ def _check_input(cls, data: Any) -> Any:
raise ValueError(
f"Amazon Web Services instance {node_group.instance} not one of available instance types={available_instances}"
)

return data


Expand Down Expand Up @@ -828,6 +863,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
max_size=node_group.max_nodes,
single_subnet=node_group.single_subnet,
permissions_boundary=node_group.permissions_boundary,
launch_template=node_group.launch_template,
)
for name, node_group in self.config.amazon_web_services.node_groups.items()
],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
MIME-Version: 1.0
Content-Type: multipart/mixed; boundary="//"

%{ if node_pre_bootstrap_command != null }
--//
Content-Type: text/x-shellscript; charset="us-ascii"

${node_pre_bootstrap_command}
%{ endif }

%{ if include_bootstrap_cmd }
--//
Content-Type: text/x-shellscript; charset="us-ascii"

#!/bin/bash
set -ex

/etc/eks/bootstrap.sh ${cluster_name} --b64-cluster-ca ${cluster_cert_authority} --apiserver-endpoint ${cluster_endpoint}
%{ endif }

--//--
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,52 @@ resource "aws_eks_cluster" "main" {
tags = merge({ Name = var.name }, var.tags)
}

## aws_launch_template user_data invocation
## When a custom AMI is used, the /etc/eks/bootstrap.sh command and its arguments must be included (or adapted) in the user data;
## with the default AWS EKS node AMI, the bootstrap command is appended automatically.
resource "aws_launch_template" "main" {
  # One launch template per node group that defines `launch_template`,
  # keyed by the node group name.
  for_each = {
    for ng in var.node_groups :
    ng.name => ng
    if ng.launch_template != null
  }

  # each.key is the node group name (see the for_each map above)
  name_prefix = "eks-${var.name}-${each.key}-"
  image_id    = each.value.launch_template.ami_id

  vpc_security_group_ids = var.cluster_security_groups

  # Same 50 GiB root volume size that node groups without a launch template
  # get through `disk_size`.
  block_device_mappings {
    device_name = "/dev/xvda"

    ebs {
      volume_size = 50
      volume_type = "gp2"
    }
  }

  metadata_options {
    http_endpoint          = "enabled"
    http_tokens            = "required"
    instance_metadata_tags = "enabled"
  }

  # https://docs.aws.amazon.com/eks/latest/userguide/launch-templates.html#launch-template-basics
  user_data = base64encode(
    templatefile(
      "${path.module}/files/user_data.tftpl",
      {
        node_pre_bootstrap_command = each.value.launch_template.pre_bootstrap_command
        # The bootstrap command that joins the node to the cluster is only
        # written into the user data when an AMI id is supplied.
        include_bootstrap_cmd  = each.value.launch_template.ami_id != null
        cluster_name           = aws_eks_cluster.main.name
        cluster_cert_authority = aws_eks_cluster.main.certificate_authority[0].data
        cluster_endpoint       = aws_eks_cluster.main.endpoint
      }
    )
  )
}


resource "aws_eks_node_group" "main" {
count = length(var.node_groups)
Expand All @@ -31,15 +77,24 @@ resource "aws_eks_node_group" "main" {
subnet_ids = var.node_groups[count.index].single_subnet ? [element(var.cluster_subnets, 0)] : var.cluster_subnets

instance_types = [var.node_groups[count.index].instance_type]
ami_type = var.node_groups[count.index].gpu == true ? "AL2_x86_64_GPU" : "AL2_x86_64"
disk_size = 50
ami_type = var.node_groups[count.index].ami_type
disk_size = var.node_groups[count.index].launch_template == null ? 50 : null

scaling_config {
min_size = var.node_groups[count.index].min_size
desired_size = var.node_groups[count.index].desired_size
max_size = var.node_groups[count.index].max_size
}

# Only set launch_template if its node_group counterpart parameter is not null
dynamic "launch_template" {
for_each = var.node_groups[count.index].launch_template != null ? [0] : []
content {
id = aws_launch_template.main[var.node_groups[count.index].name].id
version = aws_launch_template.main[var.node_groups[count.index].name].latest_version
}
}

labels = {
"dedicated" = var.node_groups[count.index].name
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ variable "node_group_additional_policies" {
variable "node_groups" {
description = "Node groups to add to EKS Cluster"
type = list(object({
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
launch_template = map(any)
ami_type = string
}))
}

Expand Down
16 changes: 9 additions & 7 deletions src/_nebari/stages/infrastructure/template/aws/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ variable "kubernetes_version" {
variable "node_groups" {
description = "AWS node groups"
type = list(object({
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
launch_template = map(any)
ami_type = string
}))
}

Expand Down
1 change: 0 additions & 1 deletion src/_nebari/stages/terraform_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def check_immutable_fields(self):
nebari_config_diff = utils.JsonDiff(
nebari_config_state.model_dump(), self.config.model_dump()
)

# check if any changed fields are immutable
for keys, old, new in nebari_config_diff.modified():
bottom_level_schema = self.config
Expand Down
22 changes: 14 additions & 8 deletions tests/tests_unit/test_cli_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,26 @@ def test_cli_validate_from_env():
["validate", "--config", tmp_file.resolve()],
env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.20"},
)

assert 0 == valid_result.exit_code
assert not valid_result.exception
assert "Successfully validated configuration" in valid_result.stdout
try:
assert 0 == valid_result.exit_code
assert not valid_result.exception
assert "Successfully validated configuration" in valid_result.stdout
except AssertionError:
print(valid_result.stdout)
raise

invalid_result = runner.invoke(
app,
["validate", "--config", tmp_file.resolve()],
env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.0"},
)

assert 1 == invalid_result.exit_code
assert invalid_result.exception
assert "Invalid `kubernetes-version`" in invalid_result.stdout
try:
assert 1 == invalid_result.exit_code
assert invalid_result.exception
assert "Invalid `kubernetes-version`" in invalid_result.stdout
except AssertionError:
print(invalid_result.stdout)
raise


@pytest.mark.parametrize(
Expand Down
Loading