mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-05 20:40:16 -08:00
refactor: centralize HTTP transport configuration (#9058)
Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
18
pkg/cache/remote.go
vendored
18
pkg/cache/remote.go
vendored
@@ -2,7 +2,6 @@ package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"github.com/twitchtv/twirp"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/rpc"
|
||||
"github.com/aquasecurity/trivy/pkg/rpc/client"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
rpcCache "github.com/aquasecurity/trivy/rpc/cache"
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ var _ ArtifactCache = (*RemoteCache)(nil)
|
||||
type RemoteOptions struct {
|
||||
ServerAddr string
|
||||
CustomHeaders http.Header
|
||||
Insecure bool
|
||||
PathPrefix string
|
||||
}
|
||||
|
||||
@@ -30,23 +29,14 @@ type RemoteCache struct {
|
||||
}
|
||||
|
||||
// NewRemoteCache is the factory method for RemoteCache
|
||||
func NewRemoteCache(opts RemoteOptions) *RemoteCache {
|
||||
ctx := client.WithCustomHeaders(context.Background(), opts.CustomHeaders)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: opts.Insecure,
|
||||
},
|
||||
},
|
||||
}
|
||||
func NewRemoteCache(ctx context.Context, opts RemoteOptions) *RemoteCache {
|
||||
ctx = client.WithCustomHeaders(ctx, opts.CustomHeaders)
|
||||
|
||||
var twirpOpts []twirp.ClientOption
|
||||
if opts.PathPrefix != "" {
|
||||
twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(opts.PathPrefix))
|
||||
}
|
||||
c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, httpClient, twirpOpts...)
|
||||
c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, xhttp.ClientWithContext(ctx), twirpOpts...)
|
||||
return &RemoteCache{
|
||||
ctx: ctx,
|
||||
client: c,
|
||||
|
||||
16
pkg/cache/remote_test.go
vendored
16
pkg/cache/remote_test.go
vendored
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/cache"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
rpcCache "github.com/aquasecurity/trivy/rpc/cache"
|
||||
rpcScanner "github.com/aquasecurity/trivy/rpc/scanner"
|
||||
)
|
||||
@@ -145,10 +146,9 @@ func TestRemoteCache_PutArtifact(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := cache.NewRemoteCache(cache.RemoteOptions{
|
||||
c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{
|
||||
ServerAddr: ts.URL,
|
||||
CustomHeaders: tt.args.customHeaders,
|
||||
Insecure: false,
|
||||
})
|
||||
err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo)
|
||||
if tt.wantErr != "" {
|
||||
@@ -208,10 +208,9 @@ func TestRemoteCache_PutBlob(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := cache.NewRemoteCache(cache.RemoteOptions{
|
||||
c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{
|
||||
ServerAddr: ts.URL,
|
||||
CustomHeaders: tt.args.customHeaders,
|
||||
Insecure: false,
|
||||
})
|
||||
err := c.PutBlob(tt.args.diffID, tt.args.layerInfo)
|
||||
if tt.wantErr != "" {
|
||||
@@ -288,10 +287,9 @@ func TestRemoteCache_MissingBlobs(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := cache.NewRemoteCache(cache.RemoteOptions{
|
||||
c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{
|
||||
ServerAddr: ts.URL,
|
||||
CustomHeaders: tt.args.customHeaders,
|
||||
Insecure: false,
|
||||
})
|
||||
gotMissingImage, gotMissingLayerIDs, err := c.MissingBlobs(tt.args.imageID, tt.args.layerIDs)
|
||||
if tt.wantErr != "" {
|
||||
@@ -339,10 +337,12 @@ func TestRemoteCache_PutArtifactInsecure(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := cache.NewRemoteCache(cache.RemoteOptions{
|
||||
ctx := xhttp.WithTransport(t.Context(), xhttp.NewTransport(xhttp.Options{
|
||||
Insecure: tt.args.insecure,
|
||||
}))
|
||||
c := cache.NewRemoteCache(ctx, cache.RemoteOptions{
|
||||
ServerAddr: ts.URL,
|
||||
CustomHeaders: nil,
|
||||
Insecure: tt.args.insecure,
|
||||
})
|
||||
err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo)
|
||||
if tt.wantErr != "" {
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/scan"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
"github.com/aquasecurity/trivy/pkg/version/doc"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
// TargetKind represents what kind of artifact Trivy scans
|
||||
@@ -118,6 +119,12 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...RunnerOptio
|
||||
opt(r)
|
||||
}
|
||||
|
||||
// Set the default HTTP transport
|
||||
xhttp.SetDefaultTransport(xhttp.NewTransport(xhttp.Options{
|
||||
Insecure: cliOptions.Insecure,
|
||||
Timeout: cliOptions.Timeout,
|
||||
}))
|
||||
|
||||
// If the user has not disabled notices or is running in quiet mode
|
||||
r.versionChecker = notification.NewVersionChecker(
|
||||
notification.WithSkipVersionCheck(cliOptions.SkipVersionCheck),
|
||||
|
||||
@@ -195,7 +195,7 @@ func initializeRemoteImageScanService(ctx context.Context, imageName string, rem
|
||||
if err != nil {
|
||||
return scan.Service{}, nil, err
|
||||
}
|
||||
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
|
||||
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
|
||||
artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
@@ -220,7 +220,7 @@ func initializeRemoteArchiveScanService(ctx context.Context, filePath string, re
|
||||
if err != nil {
|
||||
return scan.Service{}, nil, err
|
||||
}
|
||||
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
|
||||
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
|
||||
artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption)
|
||||
if err != nil {
|
||||
return scan.Service{}, nil, err
|
||||
@@ -234,7 +234,7 @@ func initializeRemoteArchiveScanService(ctx context.Context, filePath string, re
|
||||
func initializeRemoteFilesystemScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
|
||||
v := _wireValue
|
||||
service := client.NewService(remoteScanOptions, v...)
|
||||
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
|
||||
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
|
||||
fs := walker.NewFS()
|
||||
artifactArtifact, err := local2.NewArtifact(path, remoteCache, fs, artifactOption)
|
||||
if err != nil {
|
||||
@@ -249,7 +249,7 @@ func initializeRemoteFilesystemScanService(ctx context.Context, path string, rem
|
||||
func initializeRemoteRepositoryScanService(ctx context.Context, url string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
|
||||
v := _wireValue
|
||||
service := client.NewService(remoteScanOptions, v...)
|
||||
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
|
||||
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
|
||||
fs := walker.NewFS()
|
||||
artifactArtifact, cleanup, err := repo.NewArtifact(url, remoteCache, fs, artifactOption)
|
||||
if err != nil {
|
||||
@@ -265,7 +265,7 @@ func initializeRemoteRepositoryScanService(ctx context.Context, url string, remo
|
||||
func initializeRemoteSBOMScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
|
||||
v := _wireValue
|
||||
service := client.NewService(remoteScanOptions, v...)
|
||||
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
|
||||
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
|
||||
artifactArtifact, err := sbom.NewArtifact(path, remoteCache, artifactOption)
|
||||
if err != nil {
|
||||
return scan.Service{}, nil, err
|
||||
@@ -279,7 +279,7 @@ func initializeRemoteSBOMScanService(ctx context.Context, path string, remoteCac
|
||||
func initializeRemoteVMScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
|
||||
v := _wireValue
|
||||
service := client.NewService(remoteScanOptions, v...)
|
||||
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
|
||||
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
|
||||
walkerVM := walker.NewVM()
|
||||
artifactArtifact, err := vm.NewArtifact(path, remoteCache, walkerVM, artifactOption)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,19 +2,18 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/docker/cli/cli/config"
|
||||
"github.com/docker/cli/cli/config/types"
|
||||
"github.com/google/go-containerregistry/pkg/authn"
|
||||
"github.com/google/go-containerregistry/pkg/name"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/flag"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
func Login(ctx context.Context, registry string, opts flag.Options) error {
|
||||
@@ -34,7 +33,7 @@ func Login(ctx context.Context, registry string, opts flag.Options) error {
|
||||
_, err = transport.NewWithContext(ctx, reg, &authn.Basic{
|
||||
Username: opts.Credentials[0].Username,
|
||||
Password: opts.Credentials[0].Password,
|
||||
}, httpTransport(opts), nil)
|
||||
}, xhttp.Transport(ctx), nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
@@ -99,11 +98,3 @@ func parseRegistry(registry string, opts flag.Options) (name.Registry, error) {
|
||||
}
|
||||
return reg, nil
|
||||
}
|
||||
|
||||
func httpTransport(opts flag.Options) *http.Transport {
|
||||
tr := remote.DefaultTransport.(*http.Transport).Clone()
|
||||
if opts.Insecure {
|
||||
tr.TLSClientConfig.InsecureSkipVerify = true
|
||||
}
|
||||
return tr
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
xio "github.com/aquasecurity/trivy/pkg/x/io"
|
||||
)
|
||||
|
||||
@@ -742,7 +743,7 @@ func (p *Parser) fetchPomFileNameFromMavenMetadata(repo string, paths []string)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
client := xhttp.Client()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
p.logger.Debug("Failed to fetch", log.String("url", req.URL.String()), log.Err(err))
|
||||
@@ -776,7 +777,7 @@ func (p *Parser) fetchPOMFromRemoteRepository(repo string, paths []string) (*pom
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
client := xhttp.Client()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
p.logger.Debug("Failed to fetch", log.String("url", req.URL.String()), log.Err(err))
|
||||
|
||||
@@ -3,7 +3,6 @@ package downloader
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"maps"
|
||||
"net/http"
|
||||
@@ -13,9 +12,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/go-github/v62/github"
|
||||
getter "github.com/hashicorp/go-getter"
|
||||
"github.com/hashicorp/go-getter"
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
var ErrSkipDownload = errors.New("skip download")
|
||||
@@ -106,14 +107,12 @@ type CustomTransport struct {
|
||||
auth Auth
|
||||
cachedETag string
|
||||
newETag string
|
||||
insecure bool
|
||||
}
|
||||
|
||||
func NewCustomTransport(opts Options) *CustomTransport {
|
||||
return &CustomTransport{
|
||||
auth: opts.Auth,
|
||||
cachedETag: opts.ETag,
|
||||
insecure: opts.Insecure,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,12 +126,9 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.SetBasicAuth(t.auth.Username, t.auth.Password)
|
||||
}
|
||||
|
||||
var transport http.RoundTripper
|
||||
transport := xhttp.Transport(req.Context())
|
||||
if req.URL.Host == "github.com" {
|
||||
transport = NewGitHubTransport(req.URL, t.insecure, t.auth.Token)
|
||||
}
|
||||
if transport == nil {
|
||||
transport = httpTransport(t.insecure)
|
||||
transport = NewGitHubTransport(req.URL, t.auth.Token)
|
||||
}
|
||||
|
||||
res, err := transport.RoundTrip(req)
|
||||
@@ -151,8 +147,8 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func NewGitHubTransport(u *url.URL, insecure bool, token string) http.RoundTripper {
|
||||
client := newGitHubClient(insecure, token)
|
||||
func NewGitHubTransport(u *url.URL, token string) http.RoundTripper {
|
||||
client := newGitHubClient(token)
|
||||
ss := strings.SplitN(u.Path, "/", 4)
|
||||
if len(ss) < 4 || strings.HasPrefix(ss[3], "archive/") || strings.HasPrefix(ss[3], "releases/") ||
|
||||
strings.HasPrefix(ss[3], "tags/") {
|
||||
@@ -185,17 +181,11 @@ func (t *GitHubContentTransport) RoundTrip(req *http.Request) (*http.Response, e
|
||||
return res.Response, nil
|
||||
}
|
||||
|
||||
func newGitHubClient(insecure bool, token string) *github.Client {
|
||||
client := github.NewClient(&http.Client{Transport: httpTransport(insecure)})
|
||||
func newGitHubClient(token string) *github.Client {
|
||||
client := github.NewClient(xhttp.Client())
|
||||
token = cmp.Or(token, os.Getenv("GITHUB_TOKEN"))
|
||||
if token != "" {
|
||||
client = client.WithAuthToken(token)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func httpTransport(insecure bool) *http.Transport {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: insecure}
|
||||
return tr
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/downloader"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
func TestDownload(t *testing.T) {
|
||||
@@ -42,8 +43,13 @@ func TestDownload(t *testing.T) {
|
||||
// Set up the destination path
|
||||
dst := t.TempDir()
|
||||
|
||||
// Configure the HTTP transport with the insecure option
|
||||
ctx := xhttp.WithTransport(t.Context(), xhttp.NewTransport(xhttp.Options{
|
||||
Insecure: tt.insecure,
|
||||
}))
|
||||
|
||||
// Execute the download
|
||||
_, err := downloader.Download(t.Context(), server.URL, dst, "", downloader.Options{
|
||||
_, err := downloader.Download(ctx, server.URL, dst, "", downloader.Options{
|
||||
Insecure: tt.insecure,
|
||||
})
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/docker/go-connections/nat"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
testcontainers "github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
"github.com/aquasecurity/trivy/internal/testutil"
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/image"
|
||||
testdocker "github.com/aquasecurity/trivy/pkg/fanal/test/integration/docker"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
|
||||
_ "github.com/aquasecurity/trivy/pkg/fanal/analyzer/all"
|
||||
)
|
||||
@@ -228,6 +229,11 @@ func analyze(t *testing.T, ctx context.Context, imageRef string, opt types.Image
|
||||
}
|
||||
cli.NegotiateAPIVersion(ctx)
|
||||
|
||||
// Configure custom transport with insecure option
|
||||
ctx = xhttp.WithTransport(ctx, xhttp.NewTransport(xhttp.Options{
|
||||
Insecure: opt.RegistryOptions.Insecure,
|
||||
}))
|
||||
|
||||
img, cleanup, err := image.NewContainerImage(ctx, imageRef, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -88,10 +88,6 @@ type RegistryOptions struct {
|
||||
// SSL/TLS
|
||||
Insecure bool
|
||||
|
||||
// For internal use. Needed for mTLS authentication.
|
||||
ClientCert []byte
|
||||
ClientKey []byte
|
||||
|
||||
// Architecture
|
||||
Platform Platform
|
||||
|
||||
|
||||
@@ -517,7 +517,6 @@ func (o *Options) RemoteCacheOpts() cache.RemoteOptions {
|
||||
return cache.RemoteOptions{
|
||||
ServerAddr: o.ServerAddr,
|
||||
CustomHeaders: o.CustomHeaders,
|
||||
Insecure: o.Insecure,
|
||||
PathPrefix: o.PathPrefix,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ package resolvers_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/internal/gittest"
|
||||
"github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/parser/resolvers"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
type moduleResolver interface {
|
||||
@@ -73,9 +73,7 @@ func TestResolveModuleFromCache(t *testing.T) {
|
||||
opts: resolvers.Options{
|
||||
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws",
|
||||
Client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
|
||||
},
|
||||
},
|
||||
firstResolver: resolvers.Registry,
|
||||
@@ -87,9 +85,7 @@ func TestResolveModuleFromCache(t *testing.T) {
|
||||
opts: resolvers.Options{
|
||||
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws//modules/object",
|
||||
Client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
|
||||
},
|
||||
},
|
||||
firstResolver: resolvers.Registry,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/aquasecurity/go-version/pkg/version"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
type registryResolver struct {
|
||||
@@ -23,10 +24,8 @@ type registryResolver struct {
|
||||
}
|
||||
|
||||
var Registry = ®istryResolver{
|
||||
client: &http.Client{
|
||||
// give it a maximum 5 seconds to resolve the module
|
||||
Timeout: time.Second * 5,
|
||||
},
|
||||
// give it a maximum 5 seconds to resolve the module
|
||||
client: xhttp.Client(xhttp.WithTimeout(5 * time.Second)),
|
||||
}
|
||||
|
||||
type moduleVersions struct {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/aquasecurity/go-version/pkg/semver"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/version/app"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
type VersionChecker struct {
|
||||
@@ -60,9 +61,7 @@ func (v *VersionChecker) RunUpdateCheck(ctx context.Context, args []string) {
|
||||
go func() {
|
||||
logger.Debug("Running version check")
|
||||
args = getFlags(args)
|
||||
client := &http.Client{
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
client := xhttp.ClientWithContext(ctx, xhttp.WithTimeout(3*time.Second))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, v.updatesApi, http.NoBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,20 +2,15 @@ package remote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-containerregistry/pkg/authn"
|
||||
"github.com/google/go-containerregistry/pkg/authn/github"
|
||||
"github.com/google/go-containerregistry/pkg/name"
|
||||
v1 "github.com/google/go-containerregistry/pkg/v1"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
|
||||
v1types "github.com/google/go-containerregistry/pkg/v1/types"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/samber/lo"
|
||||
@@ -24,7 +19,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/image/registry"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/version/app"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
type Descriptor = remote.Descriptor
|
||||
@@ -32,13 +27,8 @@ type Descriptor = remote.Descriptor
|
||||
// Get is a wrapper of google/go-containerregistry/pkg/v1/remote.Get
|
||||
// so that it can try multiple authentication methods.
|
||||
func Get(ctx context.Context, ref name.Reference, option types.RegistryOptions) (*Descriptor, error) {
|
||||
tr, err := httpTransport(option)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create http transport: %w", err)
|
||||
}
|
||||
|
||||
return tryWithMirrors(ref, option, func(r name.Reference) (*Descriptor, error) {
|
||||
return tryGet(ctx, tr, r, option)
|
||||
return tryGet(ctx, xhttp.Transport(ctx), r, option)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -81,13 +71,8 @@ func tryGet(ctx context.Context, tr http.RoundTripper, ref name.Reference, optio
|
||||
// Image is a wrapper of google/go-containerregistry/pkg/v1/remote.Image
|
||||
// so that it can try multiple authentication methods.
|
||||
func Image(ctx context.Context, ref name.Reference, option types.RegistryOptions) (v1.Image, error) {
|
||||
tr, err := httpTransport(option)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create http transport: %w", err)
|
||||
}
|
||||
|
||||
return tryWithMirrors(ref, option, func(r name.Reference) (v1.Image, error) {
|
||||
return tryImage(ctx, tr, r, option)
|
||||
return tryImage(ctx, xhttp.Transport(ctx), r, option)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -148,16 +133,11 @@ func tryImage(ctx context.Context, tr http.RoundTripper, ref name.Reference, opt
|
||||
// Referrers is a wrapper of google/go-containerregistry/pkg/v1/remote.Referrers
|
||||
// so that it can try multiple authentication methods.
|
||||
func Referrers(ctx context.Context, d name.Digest, option types.RegistryOptions) (v1.ImageIndex, error) {
|
||||
tr, err := httpTransport(option)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create http transport: %w", err)
|
||||
}
|
||||
|
||||
var errs error
|
||||
// Try each authentication method until it succeeds
|
||||
for _, authOpt := range authOptions(ctx, d, option) {
|
||||
remoteOpts := []remote.Option{
|
||||
remote.WithTransport(tr),
|
||||
remote.WithTransport(xhttp.Transport(ctx)),
|
||||
authOpt,
|
||||
}
|
||||
index, err := remote.Referrers(d, remoteOpts...)
|
||||
@@ -197,26 +177,6 @@ func registryMirrors(hostRef name.Reference, option types.RegistryOptions) ([]na
|
||||
return mirrors, nil
|
||||
}
|
||||
|
||||
func httpTransport(option types.RegistryOptions) (http.RoundTripper, error) {
|
||||
d := &net.Dialer{
|
||||
Timeout: 10 * time.Minute,
|
||||
}
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.DialContext = d.DialContext
|
||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: option.Insecure}
|
||||
|
||||
if len(option.ClientCert) != 0 && len(option.ClientKey) != 0 {
|
||||
cert, err := tls.X509KeyPair(option.ClientCert, option.ClientKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr.TLSClientConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
tripper := transport.NewUserAgent(tr, fmt.Sprintf("trivy/%s", app.Version()))
|
||||
return tripper, nil
|
||||
}
|
||||
|
||||
func authOptions(ctx context.Context, ref name.Reference, option types.RegistryOptions) []remote.Option {
|
||||
var opts []remote.Option
|
||||
for _, cred := range option.Credentials {
|
||||
|
||||
@@ -368,8 +368,8 @@ func TestUserAgents(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, tracker.agents, 1)
|
||||
ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s go-containerregistry", app.Version()))
|
||||
require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`)
|
||||
ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s", app.Version()))
|
||||
require.True(t, ok, `user-agent header equals to "trivy/dev"`)
|
||||
}
|
||||
|
||||
func localImage(t *testing.T) v1.Image {
|
||||
|
||||
@@ -2,7 +2,6 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
r "github.com/aquasecurity/trivy/pkg/rpc"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
"github.com/aquasecurity/trivy/pkg/x/slices"
|
||||
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
|
||||
"github.com/aquasecurity/trivy/rpc/common"
|
||||
@@ -47,15 +47,11 @@ type Service struct {
|
||||
|
||||
// NewService is the factory method to return RPC Service
|
||||
func NewService(scannerOptions ServiceOption, opts ...Option) Service {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: scannerOptions.Insecure}
|
||||
httpClient := &http.Client{Transport: tr}
|
||||
|
||||
var twirpOpts []twirp.ClientOption
|
||||
if scannerOptions.PathPrefix != "" {
|
||||
twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(scannerOptions.PathPrefix))
|
||||
}
|
||||
c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, httpClient, twirpOpts...)
|
||||
c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, xhttp.Client(), twirpOpts...)
|
||||
|
||||
o := &options{rpcClient: c}
|
||||
for _, opt := range opts {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
"github.com/aquasecurity/trivy-db/pkg/vulnsrc/vulnerability"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
"github.com/aquasecurity/trivy/rpc/common"
|
||||
rpc "github.com/aquasecurity/trivy/rpc/scanner"
|
||||
)
|
||||
@@ -236,11 +236,7 @@ func TestScanner_ScanServerInsecure(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := rpc.NewScannerProtobufClient(ts.URL, &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: tt.insecure,
|
||||
},
|
||||
},
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: tt.insecure}),
|
||||
})
|
||||
s := NewService(ServiceOption{Insecure: tt.insecure}, WithRPCClient(c))
|
||||
_, err := s.Scan(t.Context(), "dummy", "", nil, types.ScanOptions{})
|
||||
|
||||
30
pkg/x/http/client.go
Normal file
30
pkg/x/http/client.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ClientOption func(client *http.Client)
|
||||
|
||||
func WithTimeout(timeout time.Duration) ClientOption {
|
||||
return func(client *http.Client) {
|
||||
client.Timeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func Client(opts ...ClientOption) *http.Client {
|
||||
return ClientWithContext(context.Background(), opts...)
|
||||
}
|
||||
|
||||
// ClientWithContext returns an HTTP client with the specified context and options.
|
||||
func ClientWithContext(ctx context.Context, opts ...ClientOption) *http.Client {
|
||||
c := &http.Client{
|
||||
Transport: Transport(ctx),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
79
pkg/x/http/transport.go
Normal file
79
pkg/x/http/transport.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/version/app"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultTransport = NewTransport(Options{})
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
type transportKey struct{}
|
||||
|
||||
// WithTransport returns a new context with the given transport.
|
||||
// This is mainly for testing when a different HTTP transport needs to be used.
|
||||
func WithTransport(ctx context.Context, tr http.RoundTripper) context.Context {
|
||||
return context.WithValue(ctx, transportKey{}, tr)
|
||||
}
|
||||
|
||||
// Options configures the transport settings
|
||||
type Options struct {
|
||||
Insecure bool
|
||||
Timeout time.Duration
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
// SetDefaultTransport sets the default transport configuration
|
||||
func SetDefaultTransport(tr http.RoundTripper) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
defaultTransport = tr
|
||||
}
|
||||
|
||||
// Transport returns the transport from the context, or the default transport if none is set.
|
||||
func Transport(ctx context.Context) http.RoundTripper {
|
||||
t, ok := ctx.Value(transportKey{}).(http.RoundTripper)
|
||||
if ok {
|
||||
// If the transport is already set in the context, return it.
|
||||
return t
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
|
||||
return defaultTransport
|
||||
}
|
||||
|
||||
// NewTransport creates a new HTTP transport with the specified options.
|
||||
// It should be used to initialize the default transport.
|
||||
// In most cases, you should use the `Transport` function to get the default transport.
|
||||
func NewTransport(opts Options) http.RoundTripper {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
// Set timeout (default to 5 minutes)
|
||||
timeout := cmp.Or(opts.Timeout, 5*time.Minute)
|
||||
d := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
}
|
||||
tr.DialContext = d.DialContext
|
||||
|
||||
// Configure TLS
|
||||
if opts.Insecure {
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: opts.Insecure,
|
||||
}
|
||||
}
|
||||
|
||||
userAgent := cmp.Or(opts.UserAgent, fmt.Sprintf("trivy/%s", app.Version()))
|
||||
return NewUserAgent(tr, userAgent)
|
||||
}
|
||||
28
pkg/x/http/useragent.go
Normal file
28
pkg/x/http/useragent.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type userAgentTransport struct {
|
||||
inner http.RoundTripper
|
||||
ua string
|
||||
}
|
||||
|
||||
// NewUserAgent returns an http.Roundtripper that sets the user agent
|
||||
//
|
||||
// User-Agent: trivy/v0.64.0
|
||||
func NewUserAgent(inner http.RoundTripper, ua string) http.RoundTripper {
|
||||
return &userAgentTransport{
|
||||
inner: inner,
|
||||
ua: ua,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper
|
||||
func (ut *userAgentTransport) RoundTrip(in *http.Request) (*http.Response, error) {
|
||||
if ut.ua != "" {
|
||||
in.Header.Set("User-Agent", ut.ua)
|
||||
}
|
||||
return ut.inner.RoundTrip(in)
|
||||
}
|
||||
102
pkg/x/http/useragent_test.go
Normal file
102
pkg/x/http/useragent_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
||||
)
|
||||
|
||||
func TestUserAgentTransport_RoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
existingHeaders map[string]string
|
||||
existingUA string
|
||||
wantUA string
|
||||
wantHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "default user agent",
|
||||
userAgent: "",
|
||||
wantUA: "trivy/dev",
|
||||
},
|
||||
{
|
||||
name: "custom user agent",
|
||||
userAgent: "custom-scanner/2.1",
|
||||
wantUA: "custom-scanner/2.1",
|
||||
},
|
||||
{
|
||||
name: "preserves existing headers",
|
||||
userAgent: "test-agent/1.0",
|
||||
existingHeaders: map[string]string{
|
||||
"Authorization": "Bearer token123",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
wantUA: "test-agent/1.0",
|
||||
wantHeaders: map[string]string{
|
||||
"Authorization": "Bearer token123",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "overwrites existing user agent",
|
||||
userAgent: "new-agent/2.0",
|
||||
existingUA: "old-agent/1.0",
|
||||
wantUA: "new-agent/2.0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a test server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check User-Agent
|
||||
gotUA := r.Header.Get("User-Agent")
|
||||
assert.Equal(t, tt.wantUA, gotUA)
|
||||
|
||||
// Check other headers are preserved
|
||||
for key, wantValue := range tt.wantHeaders {
|
||||
gotValue := r.Header.Get(key)
|
||||
assert.Equal(t, wantValue, gotValue, "header %s", key)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create transport with user agent
|
||||
transport := xhttp.NewTransport(xhttp.Options{
|
||||
Insecure: true,
|
||||
UserAgent: tt.userAgent,
|
||||
})
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(http.MethodGet, server.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set existing headers
|
||||
for key, value := range tt.existingHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Set existing User-Agent if specified
|
||||
if tt.existingUA != "" {
|
||||
req.Header.Set("User-Agent", tt.existingUA)
|
||||
}
|
||||
|
||||
// Make request
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user