diff --git a/state/aws.go b/state/aws.go index 85dd66b3..91310380 100644 --- a/state/aws.go +++ b/state/aws.go @@ -135,25 +135,32 @@ func (a *AWS) GetLocks() (locks map[string]LockInfo, err error) { // GetStates returns a slice of State files in the S3 bucket func (a *AWS) GetStates() (states []string, err error) { + truncatedListing := true + var keys []string log.WithFields(log.Fields{ "bucket": a.bucket, "prefix": a.keyPrefix, }).Debug("Listing states from S3") - result, err := a.svc.ListObjects(&s3.ListObjectsInput{ + + params := s3.ListObjectsV2Input{ Bucket: aws_sdk.String(a.bucket), Prefix: &a.keyPrefix, - }) - if err != nil { - return states, err } + for truncatedListing { + result, err := a.svc.ListObjectsV2(¶ms) + if err != nil { + return states, err + } - var keys []string - for _, obj := range result.Contents { - for _, ext := range a.fileExtension { - if strings.HasSuffix(*obj.Key, ext) { - keys = append(keys, *obj.Key) + for _, obj := range result.Contents { + for _, ext := range a.fileExtension { + if strings.HasSuffix(*obj.Key, ext) { + keys = append(keys, *obj.Key) + } } } + params.ContinuationToken = result.NextContinuationToken + truncatedListing = *result.IsTruncated } states = keys log.WithFields(log.Fields{ diff --git a/state/aws_test.go b/state/aws_test.go index cbb26277..48a6f0c0 100644 --- a/state/aws_test.go +++ b/state/aws_test.go @@ -276,8 +276,12 @@ type s3Mock struct { s3iface.S3API } -func (s *s3Mock) ListObjects(_ *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { - return &s3.ListObjectsOutput{Contents: []*s3.Object{{Key: aws.String("test.tfstate")}}}, nil +func (s *s3Mock) ListObjectsV2(_ *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) { + return &s3.ListObjectsV2Output{Contents: []*s3.Object{ + {Key: aws.String("test.tfstate")}, {Key: aws.String("test2.tfstate")}, {Key: aws.String("test3.tfstate")}}, + IsTruncated: func() *bool { b := false; return &b }(), + KeyCount: func() *int64 { b := int64(3); return &b }(), + MaxKeys: func() *int64 { b := int64(1000); return &b }()}, nil } func (s *s3Mock) ListObjectVersions(_ *s3.ListObjectVersionsInput) (*s3.ListObjectVersionsOutput, error) { return &s3.ListObjectVersionsOutput{ @@ -316,8 +320,8 @@ func TestGetStates(t *testing.T) { states, err := awsInstance.GetStates() if err != nil { t.Error(err) - } else if len(states) != 1 { - t.Error("Expected one state") + } else if len(states) != 3 { + t.Error("Expected three states") } }