File: //opt/go/pkg/mod/github.com/aws/
[email protected]/aws/session/client_tls_cert_test.go
//go:build go1.9
// +build go1.9
package session
import (
"crypto/x509"
"io"
"net/http"
"os"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestNewSession_WithClientTLSCert(t *testing.T) {
type testCase struct {
// Params
setup func(certFilename, keyFilename string) (Options, func(), error)
ExpectErr string
}
cases := map[string]testCase{
"env": {
setup: func(certFilename, keyFilename string) (Options, func(), error) {
os.Setenv(useClientTLSCert[0], certFilename)
os.Setenv(useClientTLSKey[0], keyFilename)
return Options{}, func() {}, nil
},
},
"env file not found": {
setup: func(certFilename, keyFilename string) (Options, func(), error) {
os.Setenv(useClientTLSCert[0], "some-cert-file-not-exists")
os.Setenv(useClientTLSKey[0], "some-key-file-not-exists")
return Options{}, func() {}, nil
},
ExpectErr: "LoadClientTLSCertError",
},
"env cert file only": {
setup: func(certFilename, keyFilename string) (Options, func(), error) {
os.Setenv(useClientTLSCert[0], certFilename)
return Options{}, func() {}, nil
},
ExpectErr: "must both be provided",
},
"env key file only": {
setup: func(certFilename, keyFilename string) (Options, func(), error) {
os.Setenv(useClientTLSKey[0], keyFilename)
return Options{}, func() {}, nil
},
ExpectErr: "must both be provided",
},
"session options": {
setup: func(certFilename, keyFilename string) (Options, func(), error) {
certFile, err := os.Open(certFilename)
if err != nil {
return Options{}, nil, err
}
keyFile, err := os.Open(keyFilename)
if err != nil {
return Options{}, nil, err
}
return Options{
ClientTLSCert: certFile,
ClientTLSKey: keyFile,
}, func() {
certFile.Close()
keyFile.Close()
}, nil
},
},
"session cert load error": {
setup: func(certFilename, keyFilename string) (Options, func(), error) {
certFile, err := os.Open(certFilename)
if err != nil {
return Options{}, nil, err
}
keyFile, err := os.Open(keyFilename)
if err != nil {
return Options{}, nil, err
}
stat, _ := certFile.Stat()
return Options{
ClientTLSCert: io.LimitReader(certFile, stat.Size()/2),
ClientTLSKey: keyFile,
}, func() {
certFile.Close()
keyFile.Close()
}, nil
},
ExpectErr: "unable to load x509 key pair",
},
"session key load error": {
setup: func(certFilename, keyFilename string) (Options, func(), error) {
certFile, err := os.Open(certFilename)
if err != nil {
return Options{}, nil, err
}
keyFile, err := os.Open(keyFilename)
if err != nil {
return Options{}, nil, err
}
stat, _ := keyFile.Stat()
return Options{
ClientTLSCert: certFile,
ClientTLSKey: io.LimitReader(keyFile, stat.Size()/2),
}, func() {
certFile.Close()
keyFile.Close()
}, nil
},
ExpectErr: "unable to load x509 key pair",
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
// Asserts
restoreEnvFn := initSessionTestEnv()
defer restoreEnvFn()
certFilename, keyFilename, err := awstesting.CreateClientTLSCertFiles()
if err != nil {
t.Fatalf("failed to create client certificate files, %v", err)
}
defer func() {
if err := awstesting.CleanupTLSBundleFiles(certFilename, keyFilename); err != nil {
t.Errorf("failed to cleanup client TLS cert files, %v", err)
}
}()
opts, cleanup, err := c.setup(certFilename, keyFilename)
if err != nil {
t.Fatalf("test case failed setup, %v", err)
}
if cleanup != nil {
defer cleanup()
}
server, err := awstesting.NewTLSClientCertServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
if err != nil {
t.Fatalf("failed to load session, %v", err)
}
server.StartTLS()
defer server.Close()
// Give server change to start
time.Sleep(time.Second)
// Load SDK session with options configured.
sess, err := NewSessionWithOptions(opts)
if len(c.ExpectErr) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %v, got %v", e, a)
}
return
}
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
// Clients need to add ca bundle for test service.
p := x509.NewCertPool()
p.AddCert(server.Certificate())
client := sess.Config.HTTPClient
client.Transport.(*http.Transport).TLSClientConfig.RootCAs = p
// Send request
req, _ := http.NewRequest("GET", server.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to send request, %v", err)
}
if e, a := 200, resp.StatusCode; e != a {
t.Errorf("expect %v status code, got %v", e, a)
}
})
}
}