Skip to content

Commit

Permalink
Simplify methods
Browse files Browse the repository at this point in the history
  • Loading branch information
seilagamo committed Jan 28, 2025
1 parent 0ebfd3f commit edef026
Showing 1 changed file with 23 additions and 52 deletions.
75 changes: 23 additions & 52 deletions cmd/vulcan-subdomain-takeover/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type awsConfigurator interface {
}

type route53Client interface {
ListHostedZonesByName(ctx context.Context, params *route53.ListHostedZonesByNameInput, optFns ...func(*route53.Options)) (*route53.ListHostedZonesByNameOutput, error)
ListHostedZones(context.Context, *route53.ListHostedZonesInput, ...func(*route53.Options)) (*route53.ListHostedZonesOutput, error)
ListResourceRecordSets(ctx context.Context, params *route53.ListResourceRecordSetsInput, optFns ...func(*route53.Options)) (*route53.ListResourceRecordSetsOutput, error)
}

Expand Down Expand Up @@ -100,9 +100,7 @@ func NewScan(logger *logrus.Entry, target string) (Scanner, error) {
}

func (s Scanner) Run() ([]string, error) {
var routeZoneRecordsSet = make(map[string]string)

dnsRecords, err := s.getRoute53ARecords(routeZoneRecordsSet)
dnsRecords, err := s.getRoute53ARecords()
if err != nil {
return nil, fmt.Errorf("get DNS records: %w", err)
}
Expand Down Expand Up @@ -179,17 +177,16 @@ type dnsRecord struct {
records []string
}

func (s Scanner) getRoute53ARecords(routeZoneRecordsSet map[string]string) ([]dnsRecord, error) {
func (s Scanner) getRoute53ARecords() ([]dnsRecord, error) {
var dnsRecords []dnsRecord

hz, err := s.getRoute53HostedZones(nil, nil)
hz, err := s.getRoute53HostedZones()
if err != nil {
return nil, fmt.Errorf("get hosted zones: %w", err)
}

for _, hostedZone := range hz {
var nextRecordType types.RRType
zr, err := s.getRoute53ZoneRecords(hostedZone.Id, nil, nextRecordType, routeZoneRecordsSet)
zr, err := s.getRoute53ZoneRecords(hostedZone.Id)
if err != nil {
return nil, fmt.Errorf("get zone records: %w", err)
}
Expand All @@ -216,59 +213,33 @@ func (s Scanner) getRoute53ARecords(routeZoneRecordsSet map[string]string) ([]dn
return dnsRecords, nil
}

func (s Scanner) getRoute53HostedZones(dnsName, hostedZoneId *string) ([]types.HostedZone, error) {
var listHostedZonesByNameOutput *route53.ListHostedZonesByNameOutput
var err error
var listParams route53.ListHostedZonesByNameInput
if dnsName != nil && *dnsName == "" {
listParams.DNSName = dnsName
listParams.HostedZoneId = hostedZoneId
}
listHostedZonesByNameOutput, err = s.route53Client.ListHostedZonesByName(context.Background(), &listParams)
if err != nil {
return nil, fmt.Errorf("list hosted zones: %w", err)
}

hostedZones := listHostedZonesByNameOutput.HostedZones

if listHostedZonesByNameOutput.IsTruncated {
hz, err := s.getRoute53HostedZones(listHostedZonesByNameOutput.NextDNSName, listHostedZonesByNameOutput.HostedZoneId)
// getRoute53HostedZones retrieves all the hosted zones.
func (s Scanner) getRoute53HostedZones() ([]types.HostedZone, error) {
paginator := route53.NewListHostedZonesPaginator(s.route53Client, nil)
var hostedZones []types.HostedZone
for paginator.HasMorePages() {
resp, err := paginator.NextPage(context.Background())
if err != nil {
return nil, fmt.Errorf("get hosted zones: %w", err)
return nil, fmt.Errorf("list hosted zones: %w", err)
}
hostedZones = append(hostedZones, hz...)
hostedZones = append(hostedZones, resp.HostedZones...)
}
return hostedZones, err
return hostedZones, nil
}

func (s Scanner) getRoute53ZoneRecords(
zoneId *string, nextRecordName *string, nextRecordType types.RRType,
routeZoneRecordsSet map[string]string) ([]types.ResourceRecordSet, error) {
var recordSetsOutput *route53.ListResourceRecordSetsOutput
var err error
// getRoute53ZoneRecords retrieves all the Zone Records for a ZoneId.
func (s Scanner) getRoute53ZoneRecords(zoneId *string) ([]types.ResourceRecordSet, error) {
listParams := &route53.ListResourceRecordSetsInput{
HostedZoneId: zoneId,
}
if nextRecordName != nil && *nextRecordName != "" {
listParams.StartRecordName = nextRecordName
listParams.StartRecordType = nextRecordType
}
recordSetsOutput, err = s.route53Client.ListResourceRecordSets(context.Background(), listParams)
if err != nil {
return nil, fmt.Errorf("list resource records: %w", err)
}

zoneRecords := recordSetsOutput.ResourceRecordSets

if recordSetsOutput.IsTruncated {
zoneSetRecordKey := *recordSetsOutput.NextRecordName + "_" + string(recordSetsOutput.NextRecordType)
if _, ok := routeZoneRecordsSet[zoneSetRecordKey]; !ok {
zr, err := s.getRoute53ZoneRecords(zoneId, recordSetsOutput.NextRecordName, recordSetsOutput.NextRecordType, routeZoneRecordsSet)
if err != nil {
return nil, fmt.Errorf("get zone records: %w", err)
}
zoneRecords = append(zoneRecords, zr...)
paginator := route53.NewListResourceRecordSetsPaginator(s.route53Client, listParams)
var zoneRecords []types.ResourceRecordSet
for paginator.HasMorePages() {
resp, err := paginator.NextPage(context.Background())
if err != nil {
return nil, fmt.Errorf("list resource record sets: %w", err)
}
zoneRecords = append(zoneRecords, resp.ResourceRecordSets...)
}
return zoneRecords, nil
}
Expand Down

0 comments on commit edef026

Please sign in to comment.