Skip to content

Commit

Permalink
Merge pull request #327 from barnybug/timeout-option
Browse files Browse the repository at this point in the history
Add `--timeout` option
  • Loading branch information
barnybug authored Feb 24, 2023
2 parents a2e8ac7 + 8f3d7a8 commit 1b29a32
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 76 deletions.
99 changes: 50 additions & 49 deletions commands.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli53

import (
"context"
"fmt"
"io"
"os"
Expand All @@ -14,7 +15,7 @@ import (

const ChangeBatchSize = 100

func createZone(name, comment, vpcId, vpcRegion, delegationSetId string) {
func createZone(ctx context.Context, name, comment, vpcId, vpcRegion, delegationSetId string) {
callerReference := uniqueReference()
req := route53.CreateHostedZoneInput{
CallerReference: &callerReference,
Expand All @@ -34,20 +35,20 @@ func createZone(name, comment, vpcId, vpcRegion, delegationSetId string) {
delegationSetId = strings.Replace(delegationSetId, "/delegationset/", "", 1)
req.DelegationSetId = aws.String(delegationSetId)
}
resp, err := r53.CreateHostedZone(&req)
resp, err := r53.CreateHostedZoneWithContext(ctx, &req)
fatalIfErr(err)
fmt.Printf("Created zone: '%s' ID: '%s'\n", *resp.HostedZone.Name, *resp.HostedZone.Id)
}

func createReusableDelegationSet(zoneId string) {
func createReusableDelegationSet(ctx context.Context, zoneId string) {
callerReference := uniqueReference()
req := route53.CreateReusableDelegationSetInput{
CallerReference: &callerReference,
}
if zoneId != "" {
req.HostedZoneId = &zoneId
}
resp, err := r53.CreateReusableDelegationSet(&req)
resp, err := r53.CreateReusableDelegationSetWithContext(ctx, &req)
fatalIfErr(err)
ds := resp.DelegationSet
fmt.Printf("Created reusable delegation set ID: '%s'\n", *ds.Id)
Expand All @@ -56,9 +57,9 @@ func createReusableDelegationSet(zoneId string) {
}
}

func listReusableDelegationSets() {
func listReusableDelegationSets(ctx context.Context) {
req := route53.ListReusableDelegationSetsInput{}
resp, err := r53.ListReusableDelegationSets(&req)
resp, err := r53.ListReusableDelegationSetsWithContext(ctx, &req)
fatalIfErr(err)
fmt.Printf("Reusable delegation sets:\n")
if len(resp.DelegationSets) == 0 {
Expand All @@ -74,19 +75,19 @@ func listReusableDelegationSets() {
}
}

func deleteReusableDelegationSet(id string) {
func deleteReusableDelegationSet(ctx context.Context, id string) {
if !strings.HasPrefix(id, "/delegationset/") {
id = "/delegationset/" + id
}
req := route53.DeleteReusableDelegationSetInput{
Id: &id,
}
_, err := r53.DeleteReusableDelegationSet(&req)
_, err := r53.DeleteReusableDelegationSetWithContext(ctx, &req)
fatalIfErr(err)
fmt.Printf("Deleted reusable delegation set\n")
}

func deleteRecordSets(zone *route53.HostedZone, rrsets []*route53.ResourceRecordSet, wait bool) (int, error) {
func deleteRecordSets(ctx context.Context, zone *route53.HostedZone, rrsets []*route53.ResourceRecordSet, wait bool) (int, error) {
// delete all non-default SOA/NS records
changes := []*route53.Change{}
for _, rrset := range rrsets {
Expand All @@ -106,21 +107,21 @@ func deleteRecordSets(zone *route53.HostedZone, rrsets []*route53.ResourceRecord
Changes: changes,
},
}
resp, err := r53.ChangeResourceRecordSets(&req)
resp, err := r53.ChangeResourceRecordSetsWithContext(ctx, &req)
if err != nil {
return 0, err
}
if wait {
waitForChange(resp.ChangeInfo)
waitForChange(ctx, resp.ChangeInfo)
}
}
return len(changes), nil
}

func purgeZoneRecords(zone *route53.HostedZone, wait bool) {
func purgeZoneRecords(ctx context.Context, zone *route53.HostedZone, wait bool) {
total := 0
err := batchListAllRecordSets(r53, *zone.Id, func(rrsets []*route53.ResourceRecordSet) {
n, err := deleteRecordSets(zone, rrsets, wait)
err := batchListAllRecordSets(ctx, r53, *zone.Id, func(rrsets []*route53.ResourceRecordSet) {
n, err := deleteRecordSets(ctx, zone, rrsets, wait)
fatalIfErr(err)
total += n
})
Expand All @@ -129,24 +130,24 @@ func purgeZoneRecords(zone *route53.HostedZone, wait bool) {
fmt.Printf("%d record sets deleted\n", total)
}

func deleteZone(name string, purge bool) {
zone := lookupZone(name)
func deleteZone(ctx context.Context, name string, purge bool) {
zone := lookupZone(ctx, name)
if purge {
purgeZoneRecords(zone, false)
purgeZoneRecords(ctx, zone, false)
}
req := route53.DeleteHostedZoneInput{Id: zone.Id}
_, err := r53.DeleteHostedZone(&req)
_, err := r53.DeleteHostedZoneWithContext(ctx, &req)
fatalIfErr(err)
fmt.Printf("Deleted zone: '%s' ID: '%s'\n", *zone.Name, *zone.Id)
}

func listZones(formatter Formatter) {
func listZones(ctx context.Context, formatter Formatter) {
zones := make(chan *route53.HostedZone)
go func() {
req := route53.ListHostedZonesInput{}
for {
// paginated
resp, err := r53.ListHostedZones(&req)
resp, err := r53.ListHostedZonesWithContext(ctx, &req)
fatalIfErr(err)
for _, zone := range resp.HostedZones {
zones <- zone
Expand Down Expand Up @@ -276,8 +277,8 @@ func validateBindFile(args importArgs) {
parseBindFile(reader, args.file, "validate.test")
}

func importBind(args importArgs) {
zone := lookupZone(args.name)
func importBind(ctx context.Context, args importArgs) {
zone := lookupZone(ctx, args.name)

var reader io.Reader
if args.file == "-" {
Expand All @@ -295,7 +296,7 @@ func importBind(args importArgs) {
grouped := groupRecords(records)
existing := map[string]*route53.ResourceRecordSet{}
if args.replace || args.upsert {
rrsets, err := ListAllRecordSets(r53, *zone.Id)
rrsets, err := ListAllRecordSets(ctx, r53, *zone.Id)
fatalIfErr(err)
for _, rrset := range rrsets {
if args.editauth || !isAuthRecord(zone, rrset) {
Expand Down Expand Up @@ -363,16 +364,16 @@ func importBind(args importArgs) {
}
}
} else {
resp := batchChanges(additions, deletions, zone)
resp := batchChanges(ctx, additions, deletions, zone)
fmt.Printf("%d records imported (%d changes / %d additions / %d deletions)\n", len(records), len(additions)+len(deletions), len(additions), len(deletions))

if args.wait && resp != nil {
waitForChange(resp.ChangeInfo)
waitForChange(ctx, resp.ChangeInfo)
}
}
}

func batchChanges(additions, deletions []*route53.Change, zone *route53.HostedZone) *route53.ChangeResourceRecordSetsOutput {
func batchChanges(ctx context.Context, additions, deletions []*route53.Change, zone *route53.HostedZone) *route53.ChangeResourceRecordSetsOutput {
// sort additions so aliases are last
sort.Sort(changeSorter{additions})

Expand All @@ -392,7 +393,7 @@ func batchChanges(additions, deletions []*route53.Change, zone *route53.HostedZo
ChangeBatch: &batch,
}
var err error
resp, err = r53.ChangeResourceRecordSets(&req)
resp, err = r53.ChangeResourceRecordSetsWithContext(ctx, &req)
fatalIfErr(err)
}
return resp
Expand All @@ -416,9 +417,9 @@ func UnexpandSelfAliases(records []dns.RR, zone *route53.HostedZone, full bool)
}
}

func exportBind(name string, full bool, writer io.Writer) {
zone := lookupZone(name)
ExportBindToWriter(r53, zone, full, writer)
func exportBind(ctx context.Context, name string, full bool, writer io.Writer) {
zone := lookupZone(ctx, name)
ExportBindToWriter(ctx, r53, zone, full, writer)
}

type exportSorter struct {
Expand Down Expand Up @@ -450,8 +451,8 @@ func (r exportSorter) Less(i, j int) bool {
return *r.rrsets[i].Name < *r.rrsets[j].Name
}

func ExportBindToWriter(r53 *route53.Route53, zone *route53.HostedZone, full bool, out io.Writer) {
rrsets, err := ListAllRecordSets(r53, *zone.Id)
func ExportBindToWriter(ctx context.Context, r53 *route53.Route53, zone *route53.HostedZone, full bool, out io.Writer) {
rrsets, err := ListAllRecordSets(ctx, r53, *zone.Id)
fatalIfErr(err)

sort.Sort(exportSorter{rrsets, *zone.Name})
Expand Down Expand Up @@ -607,8 +608,8 @@ func parseRecordList(args []string, zone *route53.HostedZone) []dns.RR {
return records
}

func createRecords(args createArgs) {
zone := lookupZone(args.name)
func createRecords(ctx context.Context, args createArgs) {
zone := lookupZone(ctx, args.name)
records := parseRecordList(args.records, zone)
expandSelfAliases(records, zone)

Expand All @@ -617,7 +618,7 @@ func createRecords(args createArgs) {
var existing []*route53.ResourceRecordSet
if args.replace || args.append {
var err error
existing, err = ListAllRecordSets(r53, *zone.Id)
existing, err = ListAllRecordSets(ctx, r53, *zone.Id)
fatalIfErr(err)
}

Expand Down Expand Up @@ -654,25 +655,25 @@ func createRecords(args createArgs) {
}
}

resp := batchChanges(additions, deletions, zone)
resp := batchChanges(ctx, additions, deletions, zone)

for _, record := range records {
txt := strings.Replace(record.String(), "\t", " ", -1)
fmt.Printf("Created record: '%s'\n", txt)
}

if args.wait {
waitForChange(resp.ChangeInfo)
waitForChange(ctx, resp.ChangeInfo)
}
}

func batchListAllRecordSets(r53 *route53.Route53, id string, callback func(rrsets []*route53.ResourceRecordSet)) error {
func batchListAllRecordSets(ctx context.Context, r53 *route53.Route53, id string, callback func(rrsets []*route53.ResourceRecordSet)) error {
req := route53.ListResourceRecordSetsInput{
HostedZoneId: &id,
}

for {
resp, err := r53.ListResourceRecordSets(&req)
resp, err := r53.ListResourceRecordSetsWithContext(ctx, &req)
if err != nil {
return err
} else {
Expand All @@ -690,8 +691,8 @@ func batchListAllRecordSets(r53 *route53.Route53, id string, callback func(rrset
}

// Paginate request to get all record sets.
func ListAllRecordSets(r53 *route53.Route53, id string) (rrsets []*route53.ResourceRecordSet, err error) {
err = batchListAllRecordSets(r53, id, func(results []*route53.ResourceRecordSet) {
func ListAllRecordSets(ctx context.Context, r53 *route53.Route53, id string) (rrsets []*route53.ResourceRecordSet, err error) {
err = batchListAllRecordSets(ctx, r53, id, func(results []*route53.ResourceRecordSet) {
rrsets = append(rrsets, results...)
})

Expand All @@ -703,9 +704,9 @@ func ListAllRecordSets(r53 *route53.Route53, id string) (rrsets []*route53.Resou
return
}

func deleteRecord(name string, match string, rtype string, wait bool, identifier string) {
zone := lookupZone(name)
rrsets, err := ListAllRecordSets(r53, *zone.Id)
func deleteRecord(ctx context.Context, name string, match string, rtype string, wait bool, identifier string) {
zone := lookupZone(ctx, name)
rrsets, err := ListAllRecordSets(ctx, r53, *zone.Id)
fatalIfErr(err)

match = qualifyName(match, *zone.Name)
Expand All @@ -727,18 +728,18 @@ func deleteRecord(name string, match string, rtype string, wait bool, identifier
Changes: changes,
},
}
resp, err := r53.ChangeResourceRecordSets(&req2)
resp, err := r53.ChangeResourceRecordSetsWithContext(ctx, &req2)
fatalIfErr(err)
fmt.Printf("%d record sets deleted\n", len(changes))
if wait {
waitForChange(resp.ChangeInfo)
waitForChange(ctx, resp.ChangeInfo)
}
} else {
fmt.Println("Warning: no records matched - nothing deleted")
}
}

func purgeRecords(name string, wait bool) {
zone := lookupZone(name)
purgeZoneRecords(zone, wait)
func purgeRecords(ctx context.Context, name string, wait bool) {
zone := lookupZone(ctx, name)
purgeZoneRecords(ctx, zone, wait)
}
9 changes: 5 additions & 4 deletions instances.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli53

import (
"context"
"fmt"
"regexp"
"strings"
Expand Down Expand Up @@ -28,8 +29,8 @@ type InstanceRecord struct {
value string
}

func instances(args instancesArgs, config *aws.Config) {
zone := lookupZone(args.name)
func instances(ctx context.Context, args instancesArgs, config *aws.Config) {
zone := lookupZone(ctx, args.name)
fmt.Println("Getting DNS records")

describeInstancesInput := ec2.DescribeInstancesInput{}
Expand Down Expand Up @@ -140,11 +141,11 @@ func instances(args instancesArgs, config *aws.Config) {
fmt.Printf("+ %s %s %v\n", *rr.Name, *rr.Type, *rr.ResourceRecords[0].Value)
}
} else {
resp := batchChanges(upserts, []*route53.Change{}, zone)
resp := batchChanges(ctx, upserts, []*route53.Change{}, zone)
fmt.Printf("%d records upserted\n", len(upserts))

if args.wait && resp != nil {
waitForChange(resp.ChangeInfo)
waitForChange(ctx, resp.ChangeInfo)
}
}
}
13 changes: 9 additions & 4 deletions internal/features/step_definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package features

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"log"
Expand Down Expand Up @@ -90,7 +91,8 @@ func uniqueReference() string {

func cleanupDomain(r53 *route53.Route53, id string) {
// delete all non-default SOA/NS records
rrsets, err := cli53.ListAllRecordSets(r53, id)
ctx := context.Background()
rrsets, err := cli53.ListAllRecordSets(ctx, r53, id)
fatalIfErr(err)
changes := []*route53.Change{}
for _, rrset := range rrsets {
Expand Down Expand Up @@ -309,7 +311,8 @@ func init() {
name = domain(name)
r53 := getService()
id := domainId(name)
rrsets, err := cli53.ListAllRecordSets(r53, id)
ctx := context.Background()
rrsets, err := cli53.ListAllRecordSets(ctx, r53, id)
fatalIfErr(err)
actual := len(rrsets)
if expected != actual {
Expand Down Expand Up @@ -338,7 +341,8 @@ func init() {
r53 := getService()
zone := domainZone(name)
out := new(bytes.Buffer)
cli53.ExportBindToWriter(r53, zone, false, out)
ctx := context.Background()
cli53.ExportBindToWriter(ctx, r53, zone, false, out)
actual := out.Bytes()
rfile, err := os.Open(filename)
fatalIfErr(err)
Expand Down Expand Up @@ -414,7 +418,8 @@ func init() {
func hasRecord(name, record string) bool {
r53 := getService()
zone := domainZone(name)
rrsets, err := cli53.ListAllRecordSets(r53, *zone.Id)
ctx := context.Background()
rrsets, err := cli53.ListAllRecordSets(ctx, r53, *zone.Id)
fatalIfErr(err)

for _, rrset := range rrsets {
Expand Down
Loading

0 comments on commit 1b29a32

Please sign in to comment.