Skip to content

Commit

Permalink
fix: properly set credentials for all operations
Browse files Browse the repository at this point in the history
  • Loading branch information
dejanb committed Jan 31, 2024
1 parent 5dfbebb commit 455f567
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 58 deletions.
85 changes: 40 additions & 45 deletions pkg/handler/collector/s3/bucket/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,52 @@ func GetDefaultBucket(url string, region string) Bucket {
return &s3Bucket{url, region}
}

func (d *s3Bucket) ListFiles(ctx context.Context, bucket string, token *string, max int32) ([]string, *string, error) {
func (d *s3Bucket) getS3Client(ctx context.Context) (*s3.Client, error) {
s3Config := &viper.Viper{}
s3Config.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
s3Config.AutomaticEnv()

accessKey := s3Config.GetString("storage-access-key")
secretKey := s3Config.GetString("storage-secret-key")
region := s3Config.GetString("storage-region")
if region == "" {
region = d.region
}

cfg, err := config.LoadDefaultConfig(ctx)

if err != nil {
return nil, nil, fmt.Errorf("error loading AWS SDK config: %w", err)
return nil, fmt.Errorf("error loading AWS SDK config: %w", err)
}

client := s3.NewFromConfig(cfg, func(o *s3.Options) {
return s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
if d.url != "" {
o.BaseEndpoint = aws.String(d.url)
}

if d.region != "" {
o.Region = d.region
if region != "" {
o.Region = region
}

if accessKey != "" && secretKey != "" {
staticProvider := credentials.NewStaticCredentialsProvider(
accessKey,
secretKey,
"",
)
o.Credentials = staticProvider
}
})

}), nil

}

func (d *s3Bucket) ListFiles(ctx context.Context, bucket string, token *string, max int32) ([]string, *string, error) {
client, err := d.getS3Client(ctx)
if err != nil {
return nil, nil, fmt.Errorf("error creating S3 client: %w", err)
}

input := &s3.ListObjectsV2Input{
Bucket: &bucket,
Expand All @@ -94,36 +124,11 @@ func (d *s3Bucket) ListFiles(ctx context.Context, bucket string, token *string,
}

func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string) ([]byte, error) {
s3Config := &viper.Viper{}
s3Config.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
s3Config.AutomaticEnv()

accessKey := s3Config.GetString("storage-access-key")
secretKey := s3Config.GetString("storage-secret-key")
region := s3Config.GetString("storage-region")

staticProvider := credentials.NewStaticCredentialsProvider(
accessKey,
secretKey,
"",
)

cfg, err := config.LoadDefaultConfig(
ctx,
config.WithCredentialsProvider(staticProvider),
config.WithRegion(region),
)
client, err := d.getS3Client(ctx)
if err != nil {
return nil, fmt.Errorf("error loading AWS SDK config: %w", err)
return nil, fmt.Errorf("error creating S3 client: %w", err)
}

client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
if d.url != "" {
o.BaseEndpoint = aws.String(d.url)
}
})

// Create a GetObjectInput with the bucket name and object key.
input := &s3.GetObjectInput{
Bucket: aws.String(bucket),
Expand All @@ -147,21 +152,11 @@ func (d *s3Bucket) DownloadFile(ctx context.Context, bucket string, item string)

func (d *s3Bucket) GetEncoding(ctx context.Context, bucket string, item string) (string, error) {
logger := logging.FromContext(ctx)
cfg, err := config.LoadDefaultConfig(ctx)
client, err := d.getS3Client(ctx)
if err != nil {
return "", fmt.Errorf("error loading AWS SDK config: %w", err)
return "", fmt.Errorf("error creating S3 client: %w", err)
}

client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = true
if d.url != "" {
o.BaseEndpoint = aws.String(d.url)
}
if d.region != "" {
o.Region = d.region
}
})

logger.Infof("Downloading document %v from bucket %v", item, bucket)

headObject, err := client.HeadObject(context.Background(), &s3.HeadObjectInput{Bucket: aws.String(bucket), Key: aws.String(item)})
Expand Down
27 changes: 14 additions & 13 deletions pkg/handler/collector/s3/messaging/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,7 @@ func NewSqsProvider(mpConfig MessageProviderConfig) (SqsProvider, error) {
secretKey := sqsConfig.GetString("sqs-secret-key")
region := sqsConfig.GetString("sqs-region")

fmt.Printf("SQS CREDS %s %s %s\n", accessKey, secretKey, region)

staticProvider := credentials.NewStaticCredentialsProvider(
accessKey,
secretKey,
"",
)

cfg, err := config.LoadDefaultConfig(
context.TODO(),
config.WithCredentialsProvider(staticProvider),
config.WithRegion(region),
)
cfg, err := config.LoadDefaultConfig(context.TODO())
if err != nil {
return SqsProvider{}, fmt.Errorf("error loading AWS SDK config: %w", err)
}
Expand All @@ -114,6 +102,19 @@ func NewSqsProvider(mpConfig MessageProviderConfig) (SqsProvider, error) {
if mpConfig.Endpoint != "" {
o.EndpointResolver = sqs.EndpointResolverFromURL(mpConfig.Endpoint)
}

if region != "" {
o.Region = region
}

if accessKey != "" && secretKey != "" {
staticProvider := credentials.NewStaticCredentialsProvider(
accessKey,
secretKey,
"",
)
o.Credentials = staticProvider
}
})

sqsProvider.client = client
Expand Down

0 comments on commit 455f567

Please sign in to comment.