File: //opt/go/pkg/mod/github.com/aws/
[email protected]/aws/credentials/credentials_test.go
package credentials
import (
"math/rand"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// stubProvider is a test double for a credentials provider. Retrieve hands
// back the canned creds and err, and IsExpired reports the expired flag.
type stubProvider struct {
// creds is the value returned by Retrieve.
creds Value
// retrievedCount is incremented on every call to Retrieve.
retrievedCount int
// expired is the value reported by IsExpired; Retrieve resets it to false.
expired bool
// err is the error returned by Retrieve alongside creds.
err error
}
// Retrieve counts the call, stamps the stored credentials with the stub's
// provider name, clears the expired flag, and returns the stored credentials
// together with the configured error.
func (s *stubProvider) Retrieve() (Value, error) {
	s.creds.ProviderName = "stubProvider"
	s.expired = false
	s.retrievedCount++

	return s.creds, s.err
}
// IsExpired reports the stub's current expired flag.
func (s *stubProvider) IsExpired() bool {
	expired := s.expired
	return expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.Get()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
// TestCredentialsGetWithError checks that an error produced by the provider's
// Retrieve is returned from Get with its awserr code intact.
func TestCredentialsGetWithError(t *testing.T) {
	c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})

	_, err := c.Get()

	// Use the comma-ok form: a bare err.(awserr.Error) panics when err is nil
	// or of another type, which would hide the real failure behind a panic.
	aerr, ok := err.(awserr.Error)
	if !ok {
		t.Fatalf("Expected awserr.Error, got %T: %v", err, err)
	}
	if e, a := "provider error", aerr.Code(); e != a {
		t.Errorf("Expected provider error, %v got %v", e, a)
	}
}
// TestCredentialsExpire walks a Credentials value through its expiry states:
// expired when freshly created, expired after Expire, not expired after a
// successful Get, and expired again once the provider reports expiry.
func TestCredentialsExpire(t *testing.T) {
	stub := &stubProvider{}
	c := NewCredentials(stub)

	// The provider itself is not expired, yet new credentials are still
	// expected to report expired before the first Get.
	stub.expired = false
	if !c.IsExpired() {
		t.Errorf("Expected to start out expired")
	}

	c.Expire()
	if !c.IsExpired() {
		t.Errorf("Expected to be expired")
	}

	// Report a retrieval failure instead of discarding it; otherwise the next
	// assertion would fail with a message that hides the actual cause.
	if _, err := c.Get(); err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if c.IsExpired() {
		t.Errorf("Expected not to be expired")
	}

	stub.expired = true
	if !c.IsExpired() {
		t.Errorf("Expected to be expired")
	}
}
// MockProvider is a minimal provider built on the embedded Expiry. Its
// Retrieve method returns an empty Value and no error.
type MockProvider struct {
Expiry
}
// Retrieve always succeeds and hands back a zero-valued Value.
func (*MockProvider) Retrieve() (Value, error) {
	var empty Value
	return empty, nil
}
// TestCredentialsGetWithProviderName checks that the Value returned by Get
// carries the provider name set by the stub's Retrieve.
func TestCredentialsGetWithProviderName(t *testing.T) {
	stub := &stubProvider{}
	c := NewCredentials(stub)

	creds, err := c.Get()
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}

	// e is the expected value and a the actual one, so the failure message
	// prints them in the order its wording implies.
	if e, a := "stubProvider", creds.ProviderName; e != a {
		t.Errorf("Expected provider name to match, %v got %v", e, a)
	}
}
// TestCredentialsIsExpired_Race calls IsExpired from several goroutines at
// once on chain credentials. It makes no assertions of its own; it is meant
// to be run under the race detector.
func TestCredentialsIsExpired_Race(t *testing.T) {
	const (
		workers = 10
		calls   = 100
	)

	creds := NewChainCredentials([]Provider{&MockProvider{}})

	// gate holds every worker back until all of them have been started.
	gate := make(chan struct{})

	var pending sync.WaitGroup
	pending.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer pending.Done()
			<-gate
			for n := 0; n < calls; n++ {
				creds.IsExpired()
			}
		}()
	}

	close(gate)
	pending.Wait()
}
// TestCredentialsExpiresAt_NoExpirer checks that ExpiresAt fails with the
// "ProviderNotExpirer" code when the provider has no ExpiresAt method.
func TestCredentialsExpiresAt_NoExpirer(t *testing.T) {
	stub := &stubProvider{}
	c := NewCredentials(stub)

	_, err := c.ExpiresAt()

	// Use the comma-ok form: a bare err.(awserr.Error) panics when err is nil
	// or of another type, which would hide the real failure behind a panic.
	aerr, ok := err.(awserr.Error)
	if !ok {
		t.Fatalf("Expected awserr.Error, got %T: %v", err, err)
	}
	if e, a := "ProviderNotExpirer", aerr.Code(); e != a {
		t.Errorf("Expected provider error, %v got %v", e, a)
	}
}
// stubProviderExpirer embeds stubProvider and additionally exposes an
// expiration time through its ExpiresAt method.
type stubProviderExpirer struct {
stubProvider
// expiration is the time returned by ExpiresAt.
expiration time.Time
}
// ExpiresAt returns the expiration time currently stored on the stub.
func (s *stubProviderExpirer) ExpiresAt() time.Time {
	when := s.expiration
	return when
}
// TestCredentialsExpiresAt_HasExpirer checks that ExpiresAt returns the
// provider's expiration once credentials have been fetched, and the zero time
// after the credentials are explicitly expired.
func TestCredentialsExpiresAt_HasExpirer(t *testing.T) {
	stub := &stubProviderExpirer{}
	c := NewCredentials(stub)

	// fetch initial credentials so that forceRefresh is set false
	_, err := c.Get()
	if err != nil {
		t.Errorf("Unexpected error: %v", err)
	}

	stub.expiration = time.Unix(rand.Int63(), 0)
	expiration, err := c.ExpiresAt()
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	// Compare instants with Equal rather than !=, which also compares the
	// Location pointer and monotonic clock reading.
	if !stub.expiration.Equal(expiration) {
		t.Errorf("Expected matching expiration, %v got %v", stub.expiration, expiration)
	}

	c.Expire()
	expiration, err = c.ExpiresAt()
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if !expiration.IsZero() {
		t.Errorf("Expected distant past expiration, got %v", expiration)
	}
}
// stubProviderConcurrent embeds stubProvider and overrides Retrieve so that
// each call first waits to receive from done.
type stubProviderConcurrent struct {
stubProvider
// done gates Retrieve: every call receives from it before delegating.
done chan struct{}
}
// Retrieve blocks until a value can be received from done, then delegates to
// the embedded stubProvider's Retrieve and returns its result unchanged.
func (s *stubProviderConcurrent) Retrieve() (Value, error) {
	<-s.done

	creds, err := s.stubProvider.Retrieve()
	return creds, err
}
// TestCredentialsGetConcurrent starts two concurrent Get calls against a
// provider whose Retrieve blocks on a channel, releases that channel exactly
// once, and waits for both calls to finish.
func TestCredentialsGetConcurrent(t *testing.T) {
	const callers = 2

	stub := &stubProviderConcurrent{done: make(chan struct{})}
	c := NewCredentials(stub)

	finished := make(chan struct{})
	for n := 0; n < callers; n++ {
		go func() {
			c.Get()
			finished <- struct{}{}
		}()
	}

	// Validates that a single call to Retrieve is shared between two calls to Get.
	// Only one value is ever sent, so a second Retrieve would never unblock.
	stub.done <- struct{}{}
	for n := 0; n < callers; n++ {
		<-finished
	}
}
// stubProviderRefreshable is a provider stub whose first Retrieve returns the
// credentials it was created with and marks itself expired; later calls
// replace them with refreshed credentials and clear the expired flag.
type stubProviderRefreshable struct {
// creds is the value returned by Retrieve; it is overwritten on every call
// after the first.
creds Value
// expired is the value reported by IsExpired.
expired bool
// hasRetrieved records whether Retrieve has been called at least once.
hasRetrieved bool
}
// Retrieve returns the provider's original credentials on the first call and
// flags them as expired. Every later call swaps in refreshed credentials,
// clears the expired flag, and pauses briefly before returning them.
func (s *stubProviderRefreshable) Retrieve() (Value, error) {
	// First retrieval: hand out the credentials this provider was created
	// with, and remember that they are now considered expired.
	if !s.hasRetrieved {
		s.hasRetrieved = true
		s.expired = true
		return s.creds, nil
	}

	// Subsequent retrievals: return new refreshed credentials.
	s.creds = Value{
		AccessKeyID:     "AKID",
		SecretAccessKey: "SECRET",
		SessionToken:    "NEW_SESSION",
	}
	s.expired = false
	time.Sleep(10 * time.Millisecond)

	return s.creds, nil
}
// IsExpired reports the provider's current expired flag.
func (s *stubProviderRefreshable) IsExpired() bool {
	expired := s.expired
	return expired
}
// TestCredentialsGet_RefreshableProviderRace checks that once the provider
// marks its initial credentials expired, many concurrent Get calls all see
// the refreshed credentials and never observe an expired state afterwards.
func TestCredentialsGet_RefreshableProviderRace(t *testing.T) {
	const getters = 100

	provider := &stubProviderRefreshable{
		creds: Value{
			AccessKeyID:     "AKID",
			SecretAccessKey: "SECRET",
			SessionToken:    "OLD_SESSION",
		},
	}
	c := NewCredentials(provider)

	// The first Get() causes stubProviderRefreshable to consider its
	// OLD_SESSION credentials expired on subsequent retrievals.
	initial, err := c.Get()
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if want, got := "OLD_SESSION", initial.SessionToken; want != got {
		t.Errorf("Expect session token to match, %v got %v", want, got)
	}

	// Since stubProviderRefreshable considers its OLD_SESSION credentials
	// expired, all subsequent calls to Get() should retrieve NEW_SESSION creds.
	var pending sync.WaitGroup
	pending.Add(getters)
	for n := 0; n < getters; n++ {
		go func() {
			defer pending.Done()

			refreshed, err := c.Get()
			if err != nil {
				t.Errorf("Expected no error, got %v", err)
			}
			if c.IsExpired() {
				t.Errorf("not expect expired")
			}
			if want, got := "NEW_SESSION", refreshed.SessionToken; want != got {
				t.Errorf("Expect session token to match, %v got %v", want, got)
			}
		}()
	}
	pending.Wait()
}