diff --git a/.golangci.yaml b/.golangci.yaml index 6e3ff8b..2fc1b71 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -97,6 +97,7 @@ linters: enable-all: true disable: - deadcode + - depguard - exhaustivestruct - exhaustruct - forbidigo diff --git a/cmd/root.go b/cmd/root.go index 66688e3..81b71f0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -13,6 +13,7 @@ func Execute() { filePath string outputPath string dryRun bool + recursive bool ) rootCmd := &cobra.Command{ @@ -34,6 +35,11 @@ func Execute() { i := tsort.NewIngestor() + if recursive { + // Ignore the outputPath when in recursive mode + return i.ParseAll(filePath, dryRun) + } + return i.Parse(filePath, outputPath, dryRun) }, } @@ -43,12 +49,17 @@ func Execute() { "out", "o", "", - "path to the output file") + "path to the output file. Ignored if --recursive is used.") rootCmd.PersistentFlags().BoolVarP( &dryRun, "dry-run", "d", false, "preview the changes without altering the original file.") + rootCmd.PersistentFlags().BoolVarP( + &recursive, + "recursive", + "r", false, + "parse all Terraform files within the provided directory and its subdirectories recursively") if err := rootCmd.Execute(); err != nil { log.Fatalf("error: %s", err) diff --git a/tsort/testdata/recursive/valid.tf b/tsort/testdata/recursive/valid.tf new file mode 100644 index 0000000..cdaa4a5 --- /dev/null +++ b/tsort/testdata/recursive/valid.tf @@ -0,0 +1,82 @@ +variable "kubernetes_pipeline_roles" { + description = "IAM roles for pipelines required access to EKS." + type = list(object({ + rolearn = string + namespaces = list(string) + })) + default = [] +} + + + + +variable "kubernetes_pipeline_users" { + description = "IAM users for pipelines required access to EKS." + type = list(object({ + userarn = string + namespaces = list(string) + })) + default = [] +} + + +variable "external_dns_additional_managed_zones" { + description = "Additional managed zones for external-dns." + type = list(string) + default = [] +} + +variable "aws-profile" { + description = "The aws profile name, used when creating the kubeconfig file." +} + +variable "additional_userdata" { + default = "" +} + + + + +terraform { + required_version = ">= 0.12" +} + + + + +variable "eks_shared_namespaces" { + description = "Namespaces to be shared between teams." + type = map(list(string)) + default = { + dns = ["external-dns"] + infra = ["infra-shared"] + logging = ["logging"] + monitoring = ["infra-monitoring"] + ingress = ["infra-ingress"] + argo = ["argo"] + newrelic = ["newrelic"] + } +} + +output "hardened-image-id" { + description = "The AMI ID of the hardened image." + value = data.aws_ami.hardened.id +} + + + +locals { + kubernetes_pipeline_roles = [ + for role in var.kubernetes_pipeline_roles : { + rolearn = role.rolearn + namespaces = role.namespaces + } + ] +} + + + +output "private_key_name" { + description = "The name of the private key made to share." + value = module.account.private_key_name +} diff --git a/tsort/testdata/recursive/valid1.tf b/tsort/testdata/recursive/valid1.tf new file mode 100644 index 0000000..cdaa4a5 --- /dev/null +++ b/tsort/testdata/recursive/valid1.tf @@ -0,0 +1,82 @@ +variable "kubernetes_pipeline_roles" { + description = "IAM roles for pipelines required access to EKS." + type = list(object({ + rolearn = string + namespaces = list(string) + })) + default = [] +} + + + + +variable "kubernetes_pipeline_users" { + description = "IAM users for pipelines required access to EKS." + type = list(object({ + userarn = string + namespaces = list(string) + })) + default = [] +} + + +variable "external_dns_additional_managed_zones" { + description = "Additional managed zones for external-dns." + type = list(string) + default = [] +} + +variable "aws-profile" { + description = "The aws profile name, used when creating the kubeconfig file." +} + +variable "additional_userdata" { + default = "" +} + + + + +terraform { + required_version = ">= 0.12" +} + + + + +variable "eks_shared_namespaces" { + description = "Namespaces to be shared between teams." + type = map(list(string)) + default = { + dns = ["external-dns"] + infra = ["infra-shared"] + logging = ["logging"] + monitoring = ["infra-monitoring"] + ingress = ["infra-ingress"] + argo = ["argo"] + newrelic = ["newrelic"] + } +} + +output "hardened-image-id" { + description = "The AMI ID of the hardened image." + value = data.aws_ami.hardened.id +} + + + +locals { + kubernetes_pipeline_roles = [ + for role in var.kubernetes_pipeline_roles : { + rolearn = role.rolearn + namespaces = role.namespaces + } + ] +} + + + +output "private_key_name" { + description = "The name of the private key made to share." + value = module.account.private_key_name +} diff --git a/tsort/testdata/recursive/valid2.tf b/tsort/testdata/recursive/valid2.tf new file mode 100644 index 0000000..cdaa4a5 --- /dev/null +++ b/tsort/testdata/recursive/valid2.tf @@ -0,0 +1,82 @@ +variable "kubernetes_pipeline_roles" { + description = "IAM roles for pipelines required access to EKS." + type = list(object({ + rolearn = string + namespaces = list(string) + })) + default = [] +} + + + + +variable "kubernetes_pipeline_users" { + description = "IAM users for pipelines required access to EKS." + type = list(object({ + userarn = string + namespaces = list(string) + })) + default = [] +} + + +variable "external_dns_additional_managed_zones" { + description = "Additional managed zones for external-dns." + type = list(string) + default = [] +} + +variable "aws-profile" { + description = "The aws profile name, used when creating the kubeconfig file." +} + +variable "additional_userdata" { + default = "" +} + + + + +terraform { + required_version = ">= 0.12" +} + + + + +variable "eks_shared_namespaces" { + description = "Namespaces to be shared between teams." + type = map(list(string)) + default = { + dns = ["external-dns"] + infra = ["infra-shared"] + logging = ["logging"] + monitoring = ["infra-monitoring"] + ingress = ["infra-ingress"] + argo = ["argo"] + newrelic = ["newrelic"] + } +} + +output "hardened-image-id" { + description = "The AMI ID of the hardened image." + value = data.aws_ami.hardened.id +} + + + +locals { + kubernetes_pipeline_roles = [ + for role in var.kubernetes_pipeline_roles : { + rolearn = role.rolearn + namespaces = role.namespaces + } + ] +} + + + +output "private_key_name" { + description = "The name of the private key made to share." + value = module.account.private_key_name +} diff --git a/tsort/tfsort_test.go b/tsort/tfsort_test.go index ac2143a..75f0d95 100644 --- a/tsort/tfsort_test.go +++ b/tsort/tfsort_test.go @@ -1,6 +1,7 @@ package tsort_test import ( + "fmt" "os" "testing" @@ -101,18 +102,53 @@ func TestParse(t *testing.T) { } }) - t.Run("Error writing to output file", func(t *testing.T) { - os.Remove(outputFile) - if err := os.WriteFile(outputFile, []byte("data"), 0o000); err != nil { + // cleanup + os.Remove(outputFile) +} + +func TestParseAll(t *testing.T) { + ingestor := tsort.NewIngestor() + + // Save original content of the files + originalContent, err := os.ReadFile("testdata/valid.tf") + if err != nil { + t.Fatalf("Failed to read original content: %v", err) + } + + t.Run("Valid Directory", func(t *testing.T) { + if err := ingestor.ParseAll("testdata/recursive", false); err != nil { + t.Errorf("Unexpected error: %v", err) + } + for _, file := range []string{"valid.tf", "valid1.tf", "valid2.tf"} { + filePath := fmt.Sprintf("testdata/recursive/%s", file) + expectedFile, _ := os.ReadFile("testdata/expected.tf") + outFile, _ := os.ReadFile(filePath) + + if string(outFile) != string(expectedFile) { + t.Errorf("Output file content in '%s' is not as expected", filePath) + } + } + }) + + t.Run("Write to stdout", func(t *testing.T) { + if err := ingestor.ParseAll("testdata/recursive", true); err != nil { t.Errorf("Unexpected error: %v", err) } - if err := ingestor.Parse(validFilePath, outputFile, false); err == nil { + }) + + t.Run("Error accessing file", func(t *testing.T) { + if err := ingestor.ParseAll("nonexistent_directory", false); err == nil { t.Errorf("Expected error but not occurred") } }) // cleanup - os.Remove(outputFile) + for _, file := range []string{"valid.tf", "valid1.tf", "valid2.tf"} { + filePath := fmt.Sprintf("testdata/recursive/%s", file) + if err := os.WriteFile(filePath, originalContent, 0o644); err != nil { + t.Errorf("Unexpected error: %v", err) + } + } } func TestValidateFilePath(t *testing.T) { @@ -128,12 +164,6 @@ func TestValidateFilePath(t *testing.T) { } }) - t.Run("File is directory", func(t *testing.T) { - if err := tsort.ValidateFilePath("testdata"); err == nil { - t.Errorf("Expected error but not occurred") - } - }) - t.Run("Valid File Path", func(t *testing.T) { if err := tsort.ValidateFilePath(validFilePath); err != nil { t.Errorf("Unexpected error: %v", err) diff --git a/tsort/tsort.go b/tsort/tsort.go index 8a075b2..be636e2 100644 --- a/tsort/tsort.go +++ b/tsort/tsort.go @@ -3,6 +3,7 @@ package tsort import ( "errors" "fmt" + "io/fs" "os" "path/filepath" "regexp" @@ -33,7 +34,8 @@ func NewIngestor() *Ingestor { // CanIngest reads the file at the given path and checks if it is a valid Terraform file // based on its extension and contents. func (i *Ingestor) CanIngest(path string) error { - if _, err := os.Stat(path); os.IsNotExist(err) { + info, err := os.Stat(path) + if err != nil || info.IsDir() { return fmt.Errorf("can't open file '%s': no such file or directory", path) } @@ -95,7 +97,22 @@ func (i *Ingestor) Parse(path string, outputPath string, dry bool) error { return nil } -// validateFilePath returns an error if the given path is empty, does not exist, or is a directory. +// Parses all files within the given path (including subdirectories). +func (i *Ingestor) ParseAll(path string, dry bool) error { + return filepath.Walk(path, func(path string, info fs.FileInfo, err error) error { + switch { + case err != nil: + return fmt.Errorf("error accessing file '%s': %w", path, err) + case info.IsDir(): + return nil + case filepath.Ext(path) == ".tf" || filepath.Ext(path) == ".hcl": + return i.Parse(path, "", dry) + default: + return nil + } + }) +} + func ValidateFilePath(path string) error { if path == "" { return errors.New("file path is required") @@ -109,7 +126,7 @@ func ValidateFilePath(path string) error { case err != nil: return fmt.Errorf("error accessing file '%s': %w", path, err) case info.IsDir(): - return errors.New("path is a directory, not a file") + return nil default: return nil }