diff --git a/pkg/cache/remote.go b/pkg/cache/remote.go index 19d748bf79..6a3b8a8eaa 100644 --- a/pkg/cache/remote.go +++ b/pkg/cache/remote.go @@ -2,7 +2,6 @@ package cache import ( "context" - "crypto/tls" "net/http" "github.com/twitchtv/twirp" @@ -11,6 +10,7 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/rpc" "github.com/aquasecurity/trivy/pkg/rpc/client" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" rpcCache "github.com/aquasecurity/trivy/rpc/cache" ) @@ -19,7 +19,6 @@ var _ ArtifactCache = (*RemoteCache)(nil) type RemoteOptions struct { ServerAddr string CustomHeaders http.Header - Insecure bool PathPrefix string } @@ -30,23 +29,14 @@ type RemoteCache struct { } // NewRemoteCache is the factory method for RemoteCache -func NewRemoteCache(opts RemoteOptions) *RemoteCache { - ctx := client.WithCustomHeaders(context.Background(), opts.CustomHeaders) - - httpClient := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: opts.Insecure, - }, - }, - } +func NewRemoteCache(ctx context.Context, opts RemoteOptions) *RemoteCache { + ctx = client.WithCustomHeaders(ctx, opts.CustomHeaders) var twirpOpts []twirp.ClientOption if opts.PathPrefix != "" { twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(opts.PathPrefix)) } - c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, httpClient, twirpOpts...) + c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, xhttp.ClientWithContext(ctx), twirpOpts...) return &RemoteCache{ ctx: ctx, client: c, diff --git a/pkg/cache/remote_test.go b/pkg/cache/remote_test.go index 7c464c6593..a6c1e9dcba 100644 --- a/pkg/cache/remote_test.go +++ b/pkg/cache/remote_test.go @@ -16,6 +16,7 @@ import ( "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/fanal/types" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" rpcCache "github.com/aquasecurity/trivy/rpc/cache" rpcScanner "github.com/aquasecurity/trivy/rpc/scanner" ) @@ -145,10 +146,9 @@ func TestRemoteCache_PutArtifact(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(cache.RemoteOptions{ + c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{ ServerAddr: ts.URL, CustomHeaders: tt.args.customHeaders, - Insecure: false, }) err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo) if tt.wantErr != "" { @@ -208,10 +208,9 @@ func TestRemoteCache_PutBlob(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(cache.RemoteOptions{ + c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{ ServerAddr: ts.URL, CustomHeaders: tt.args.customHeaders, - Insecure: false, }) err := c.PutBlob(tt.args.diffID, tt.args.layerInfo) if tt.wantErr != "" { @@ -288,10 +287,9 @@ func TestRemoteCache_MissingBlobs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(cache.RemoteOptions{ + c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{ ServerAddr: ts.URL, CustomHeaders: tt.args.customHeaders, - Insecure: false, }) gotMissingImage, gotMissingLayerIDs, err := c.MissingBlobs(tt.args.imageID, tt.args.layerIDs) if tt.wantErr != "" { @@ -339,10 +337,12 @@ func TestRemoteCache_PutArtifactInsecure(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(cache.RemoteOptions{ + ctx := xhttp.WithTransport(t.Context(), xhttp.NewTransport(xhttp.Options{ + Insecure: tt.args.insecure, + })) + c := cache.NewRemoteCache(ctx, cache.RemoteOptions{ ServerAddr: ts.URL, CustomHeaders: nil, - Insecure: tt.args.insecure, }) err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo) if tt.wantErr != "" { diff --git a/pkg/commands/artifact/run.go b/pkg/commands/artifact/run.go index 40aba95b85..350fb70e72 100644 --- a/pkg/commands/artifact/run.go +++ b/pkg/commands/artifact/run.go @@ -34,6 +34,7 @@ import ( "github.com/aquasecurity/trivy/pkg/scan" "github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/version/doc" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) // TargetKind represents what kind of artifact Trivy scans @@ -118,6 +119,12 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...RunnerOptio opt(r) } + // Set the default HTTP transport + xhttp.SetDefaultTransport(xhttp.NewTransport(xhttp.Options{ + Insecure: cliOptions.Insecure, + Timeout: cliOptions.Timeout, + })) + // If the user has not disabled notices or is running in quiet mode r.versionChecker = notification.NewVersionChecker( notification.WithSkipVersionCheck(cliOptions.SkipVersionCheck), diff --git a/pkg/commands/artifact/wire_gen.go b/pkg/commands/artifact/wire_gen.go index 8c80120589..5e197e3d97 100644 --- a/pkg/commands/artifact/wire_gen.go +++ b/pkg/commands/artifact/wire_gen.go @@ -195,7 +195,7 @@ func initializeRemoteImageScanService(ctx context.Context, imageName string, rem if err != nil { return scan.Service{}, nil, err } - remoteCache := cache.NewRemoteCache(remoteCacheOptions) + remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions) artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption) if err != nil { cleanup() @@ -220,7 +220,7 @@ func initializeRemoteArchiveScanService(ctx context.Context, filePath string, re if err != nil { return scan.Service{}, nil, err } - remoteCache := cache.NewRemoteCache(remoteCacheOptions) + remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions) artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption) if err != nil { return scan.Service{}, nil, err @@ -234,7 +234,7 @@ func initializeRemoteArchiveScanService(ctx context.Context, filePath string, re func initializeRemoteFilesystemScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) { v := _wireValue service := client.NewService(remoteScanOptions, v...) - remoteCache := cache.NewRemoteCache(remoteCacheOptions) + remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions) fs := walker.NewFS() artifactArtifact, err := local2.NewArtifact(path, remoteCache, fs, artifactOption) if err != nil { @@ -249,7 +249,7 @@ func initializeRemoteFilesystemScanService(ctx context.Context, path string, rem func initializeRemoteRepositoryScanService(ctx context.Context, url string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) { v := _wireValue service := client.NewService(remoteScanOptions, v...) - remoteCache := cache.NewRemoteCache(remoteCacheOptions) + remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions) fs := walker.NewFS() artifactArtifact, cleanup, err := repo.NewArtifact(url, remoteCache, fs, artifactOption) if err != nil { @@ -265,7 +265,7 @@ func initializeRemoteRepositoryScanService(ctx context.Context, url string, remo func initializeRemoteSBOMScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) { v := _wireValue service := client.NewService(remoteScanOptions, v...) - remoteCache := cache.NewRemoteCache(remoteCacheOptions) + remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions) artifactArtifact, err := sbom.NewArtifact(path, remoteCache, artifactOption) if err != nil { return scan.Service{}, nil, err @@ -279,7 +279,7 @@ func initializeRemoteSBOMScanService(ctx context.Context, path string, remoteCac func initializeRemoteVMScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) { v := _wireValue service := client.NewService(remoteScanOptions, v...) - remoteCache := cache.NewRemoteCache(remoteCacheOptions) + remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions) walkerVM := walker.NewVM() artifactArtifact, err := vm.NewArtifact(path, remoteCache, walkerVM, artifactOption) if err != nil { diff --git a/pkg/commands/auth/run.go b/pkg/commands/auth/run.go index 6baa4d0d48..d22de4bdf5 100644 --- a/pkg/commands/auth/run.go +++ b/pkg/commands/auth/run.go @@ -2,19 +2,18 @@ package auth import ( "context" - "net/http" "os" "github.com/docker/cli/cli/config" "github.com/docker/cli/cli/config/types" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" - "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/remote/transport" "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/log" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) func Login(ctx context.Context, registry string, opts flag.Options) error { @@ -34,7 +33,7 @@ func Login(ctx context.Context, registry string, opts flag.Options) error { _, err = transport.NewWithContext(ctx, reg, &authn.Basic{ Username: opts.Credentials[0].Username, Password: opts.Credentials[0].Password, - }, httpTransport(opts), nil) + }, xhttp.Transport(ctx), nil) if err != nil { return xerrors.Errorf("failed to authenticate: %w", err) } @@ -99,11 +98,3 @@ func parseRegistry(registry string, opts flag.Options) (name.Registry, error) { } return reg, nil } - -func httpTransport(opts flag.Options) *http.Transport { - tr := remote.DefaultTransport.(*http.Transport).Clone() - if opts.Insecure { - tr.TLSClientConfig.InsecureSkipVerify = true - } - return tr -} diff --git a/pkg/dependency/parser/java/pom/parse.go b/pkg/dependency/parser/java/pom/parse.go index 946898cfab..91a02530ad 100644 --- a/pkg/dependency/parser/java/pom/parse.go +++ b/pkg/dependency/parser/java/pom/parse.go @@ -23,6 +23,7 @@ import ( ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/set" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" xio "github.com/aquasecurity/trivy/pkg/x/io" ) @@ -742,7 +743,7 @@ func (p *Parser) fetchPomFileNameFromMavenMetadata(repo string, paths []string) return "", nil } - client := &http.Client{} + client := xhttp.Client() resp, err := client.Do(req) if err != nil { p.logger.Debug("Failed to fetch", log.String("url", req.URL.String()), log.Err(err)) @@ -776,7 +777,7 @@ func (p *Parser) fetchPOMFromRemoteRepository(repo string, paths []string) (*pom return nil, nil } - client := &http.Client{} + client := xhttp.Client() resp, err := client.Do(req) if err != nil { p.logger.Debug("Failed to fetch", log.String("url", req.URL.String()), log.Err(err)) diff --git a/pkg/downloader/download.go b/pkg/downloader/download.go index 96fd3ce490..d1e37b68fe 100644 --- a/pkg/downloader/download.go +++ b/pkg/downloader/download.go @@ -3,7 +3,6 @@ package downloader import ( "cmp" "context" - "crypto/tls" "errors" "maps" "net/http" @@ -13,9 +12,11 @@ import ( "time" "github.com/google/go-github/v62/github" - getter "github.com/hashicorp/go-getter" + "github.com/hashicorp/go-getter" "github.com/samber/lo" "golang.org/x/xerrors" + + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) var ErrSkipDownload = errors.New("skip download") @@ -106,14 +107,12 @@ type CustomTransport struct { auth Auth cachedETag string newETag string - insecure bool } func NewCustomTransport(opts Options) *CustomTransport { return &CustomTransport{ auth: opts.Auth, cachedETag: opts.ETag, - insecure: opts.Insecure, } } @@ -127,12 +126,9 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.SetBasicAuth(t.auth.Username, t.auth.Password) } - var transport http.RoundTripper + transport := xhttp.Transport(req.Context()) if req.URL.Host == "github.com" { - transport = NewGitHubTransport(req.URL, t.insecure, t.auth.Token) - } - if transport == nil { - transport = httpTransport(t.insecure) + transport = NewGitHubTransport(req.URL, t.auth.Token) } res, err := transport.RoundTrip(req) @@ -151,8 +147,8 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) { return res, nil } -func NewGitHubTransport(u *url.URL, insecure bool, token string) http.RoundTripper { - client := newGitHubClient(insecure, token) +func NewGitHubTransport(u *url.URL, token string) http.RoundTripper { + client := newGitHubClient(token) ss := strings.SplitN(u.Path, "/", 4) if len(ss) < 4 || strings.HasPrefix(ss[3], "archive/") || strings.HasPrefix(ss[3], "releases/") || strings.HasPrefix(ss[3], "tags/") { @@ -185,17 +181,11 @@ func (t *GitHubContentTransport) RoundTrip(req *http.Request) (*http.Response, e return res.Response, nil } -func newGitHubClient(insecure bool, token string) *github.Client { - client := github.NewClient(&http.Client{Transport: httpTransport(insecure)}) +func newGitHubClient(token string) *github.Client { + client := github.NewClient(xhttp.Client()) token = cmp.Or(token, os.Getenv("GITHUB_TOKEN")) if token != "" { client = client.WithAuthToken(token) } return client } - -func httpTransport(insecure bool) *http.Transport { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: insecure} - return tr -} diff --git a/pkg/downloader/downloader_test.go b/pkg/downloader/downloader_test.go index 67eb6428e3..14eb23caf9 100644 --- a/pkg/downloader/downloader_test.go +++ b/pkg/downloader/downloader_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/aquasecurity/trivy/pkg/downloader" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) func TestDownload(t *testing.T) { @@ -42,8 +43,13 @@ func TestDownload(t *testing.T) { // Set up the destination path dst := t.TempDir() + // Configure the HTTP transport with the insecure option + ctx := xhttp.WithTransport(t.Context(), xhttp.NewTransport(xhttp.Options{ + Insecure: tt.insecure, + })) + // Execute the download - _, err := downloader.Download(t.Context(), server.URL, dst, "", downloader.Options{ + _, err := downloader.Download(ctx, server.URL, dst, "", downloader.Options{ Insecure: tt.insecure, }) diff --git a/pkg/fanal/test/integration/registry_test.go b/pkg/fanal/test/integration/registry_test.go index 9ac682471e..2d803dc5d1 100644 --- a/pkg/fanal/test/integration/registry_test.go +++ b/pkg/fanal/test/integration/registry_test.go @@ -14,7 +14,7 @@ import ( "github.com/docker/go-connections/nat" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - testcontainers "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" "github.com/aquasecurity/trivy/internal/testutil" @@ -26,6 +26,7 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/image" testdocker "github.com/aquasecurity/trivy/pkg/fanal/test/integration/docker" "github.com/aquasecurity/trivy/pkg/fanal/types" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" _ "github.com/aquasecurity/trivy/pkg/fanal/analyzer/all" ) @@ -228,6 +229,11 @@ func analyze(t *testing.T, ctx context.Context, imageRef string, opt types.Image } cli.NegotiateAPIVersion(ctx) + // Configure custom transport with insecure option + ctx = xhttp.WithTransport(ctx, xhttp.NewTransport(xhttp.Options{ + Insecure: opt.RegistryOptions.Insecure, + })) + img, cleanup, err := image.NewContainerImage(ctx, imageRef, opt) if err != nil { return nil, err diff --git a/pkg/fanal/types/image.go b/pkg/fanal/types/image.go index 12c45de3ab..c648cd1206 100644 --- a/pkg/fanal/types/image.go +++ b/pkg/fanal/types/image.go @@ -88,10 +88,6 @@ type RegistryOptions struct { // SSL/TLS Insecure bool - // For internal use. Needed for mTLS authentication. - ClientCert []byte - ClientKey []byte - // Architecture Platform Platform diff --git a/pkg/flag/options.go b/pkg/flag/options.go index f0f8276a56..26eb5c7022 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -517,7 +517,6 @@ func (o *Options) RemoteCacheOpts() cache.RemoteOptions { return cache.RemoteOptions{ ServerAddr: o.ServerAddr, CustomHeaders: o.CustomHeaders, - Insecure: o.Insecure, PathPrefix: o.PathPrefix, } } diff --git a/pkg/iac/scanners/terraform/parser/resolvers/cache_integration_test.go b/pkg/iac/scanners/terraform/parser/resolvers/cache_integration_test.go index 213a896372..b6c4badfcd 100644 --- a/pkg/iac/scanners/terraform/parser/resolvers/cache_integration_test.go +++ b/pkg/iac/scanners/terraform/parser/resolvers/cache_integration_test.go @@ -4,7 +4,6 @@ package resolvers_test import ( "context" - "crypto/tls" "io/fs" "net/http" "net/http/httptest" @@ -18,6 +17,7 @@ import ( "github.com/aquasecurity/trivy/internal/gittest" "github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/parser/resolvers" "github.com/aquasecurity/trivy/pkg/log" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) type moduleResolver interface { @@ -73,9 +73,7 @@ func TestResolveModuleFromCache(t *testing.T) { opts: resolvers.Options{ Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws", Client: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, + Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}), }, }, firstResolver: resolvers.Registry, @@ -87,9 +85,7 @@ func TestResolveModuleFromCache(t *testing.T) { opts: resolvers.Options{ Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws//modules/object", Client: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, + Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}), }, }, firstResolver: resolvers.Registry, diff --git a/pkg/iac/scanners/terraform/parser/resolvers/registry.go b/pkg/iac/scanners/terraform/parser/resolvers/registry.go index 3b37134b30..10a0d3973b 100644 --- a/pkg/iac/scanners/terraform/parser/resolvers/registry.go +++ b/pkg/iac/scanners/terraform/parser/resolvers/registry.go @@ -16,6 +16,7 @@ import ( "github.com/aquasecurity/go-version/pkg/version" "github.com/aquasecurity/trivy/pkg/log" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) type registryResolver struct { @@ -23,10 +24,8 @@ type registryResolver struct { } var Registry = ®istryResolver{ - client: &http.Client{ - // give it a maximum 5 seconds to resolve the module - Timeout: time.Second * 5, - }, + // give it a maximum 5 seconds to resolve the module + client: xhttp.Client(xhttp.WithTimeout(5 * time.Second)), } type moduleVersions struct { diff --git a/pkg/notification/notice.go b/pkg/notification/notice.go index 491ae0cb52..6d1dcf01e9 100644 --- a/pkg/notification/notice.go +++ b/pkg/notification/notice.go @@ -14,6 +14,7 @@ import ( "github.com/aquasecurity/go-version/pkg/semver" "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/version/app" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) type VersionChecker struct { @@ -60,9 +61,7 @@ func (v *VersionChecker) RunUpdateCheck(ctx context.Context, args []string) { go func() { logger.Debug("Running version check") args = getFlags(args) - client := &http.Client{ - Timeout: 3 * time.Second, - } + client := xhttp.ClientWithContext(ctx, xhttp.WithTimeout(3*time.Second)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, v.updatesApi, http.NoBody) if err != nil { diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index c9f75e8349..a713fbe6cc 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -2,20 +2,15 @@ package remote import ( "context" - "crypto/tls" "errors" - "fmt" - "net" "net/http" "strings" - "time" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/authn/github" "github.com/google/go-containerregistry/pkg/name" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/remote" - "github.com/google/go-containerregistry/pkg/v1/remote/transport" v1types "github.com/google/go-containerregistry/pkg/v1/types" "github.com/hashicorp/go-multierror" "github.com/samber/lo" @@ -24,7 +19,7 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/image/registry" "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/version/app" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" ) type Descriptor = remote.Descriptor @@ -32,13 +27,8 @@ type Descriptor = remote.Descriptor // Get is a wrapper of google/go-containerregistry/pkg/v1/remote.Get // so that it can try multiple authentication methods. func Get(ctx context.Context, ref name.Reference, option types.RegistryOptions) (*Descriptor, error) { - tr, err := httpTransport(option) - if err != nil { - return nil, xerrors.Errorf("failed to create http transport: %w", err) - } - return tryWithMirrors(ref, option, func(r name.Reference) (*Descriptor, error) { - return tryGet(ctx, tr, r, option) + return tryGet(ctx, xhttp.Transport(ctx), r, option) }) } @@ -81,13 +71,8 @@ func tryGet(ctx context.Context, tr http.RoundTripper, ref name.Reference, optio // Image is a wrapper of google/go-containerregistry/pkg/v1/remote.Image // so that it can try multiple authentication methods. func Image(ctx context.Context, ref name.Reference, option types.RegistryOptions) (v1.Image, error) { - tr, err := httpTransport(option) - if err != nil { - return nil, xerrors.Errorf("failed to create http transport: %w", err) - } - return tryWithMirrors(ref, option, func(r name.Reference) (v1.Image, error) { - return tryImage(ctx, tr, r, option) + return tryImage(ctx, xhttp.Transport(ctx), r, option) }) } @@ -148,16 +133,11 @@ func tryImage(ctx context.Context, tr http.RoundTripper, ref name.Reference, opt // Referrers is a wrapper of google/go-containerregistry/pkg/v1/remote.Referrers // so that it can try multiple authentication methods. func Referrers(ctx context.Context, d name.Digest, option types.RegistryOptions) (v1.ImageIndex, error) { - tr, err := httpTransport(option) - if err != nil { - return nil, xerrors.Errorf("failed to create http transport: %w", err) - } - var errs error // Try each authentication method until it succeeds for _, authOpt := range authOptions(ctx, d, option) { remoteOpts := []remote.Option{ - remote.WithTransport(tr), + remote.WithTransport(xhttp.Transport(ctx)), authOpt, } index, err := remote.Referrers(d, remoteOpts...) @@ -197,26 +177,6 @@ func registryMirrors(hostRef name.Reference, option types.RegistryOptions) ([]na return mirrors, nil } -func httpTransport(option types.RegistryOptions) (http.RoundTripper, error) { - d := &net.Dialer{ - Timeout: 10 * time.Minute, - } - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.DialContext = d.DialContext - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: option.Insecure} - - if len(option.ClientCert) != 0 && len(option.ClientKey) != 0 { - cert, err := tls.X509KeyPair(option.ClientCert, option.ClientKey) - if err != nil { - return nil, err - } - tr.TLSClientConfig.Certificates = []tls.Certificate{cert} - } - - tripper := transport.NewUserAgent(tr, fmt.Sprintf("trivy/%s", app.Version())) - return tripper, nil -} - func authOptions(ctx context.Context, ref name.Reference, option types.RegistryOptions) []remote.Option { var opts []remote.Option for _, cred := range option.Credentials { diff --git a/pkg/remote/remote_test.go b/pkg/remote/remote_test.go index a2efab1d72..be2541f2eb 100644 --- a/pkg/remote/remote_test.go +++ b/pkg/remote/remote_test.go @@ -368,8 +368,8 @@ func TestUserAgents(t *testing.T) { require.NoError(t, err) require.Len(t, tracker.agents, 1) - ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s go-containerregistry", app.Version())) - require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`) + ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s", app.Version())) + require.True(t, ok, `user-agent header equals to "trivy/dev"`) } func localImage(t *testing.T) v1.Image { diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 80a82a12d9..e178dc49fb 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "net/http" "github.com/samber/lo" @@ -12,6 +11,7 @@ import ( ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" r "github.com/aquasecurity/trivy/pkg/rpc" "github.com/aquasecurity/trivy/pkg/types" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" "github.com/aquasecurity/trivy/pkg/x/slices" xstrings "github.com/aquasecurity/trivy/pkg/x/strings" "github.com/aquasecurity/trivy/rpc/common" @@ -47,15 +47,11 @@ type Service struct { // NewService is the factory method to return RPC Service func NewService(scannerOptions ServiceOption, opts ...Option) Service { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: scannerOptions.Insecure} - httpClient := &http.Client{Transport: tr} - var twirpOpts []twirp.ClientOption if scannerOptions.PathPrefix != "" { twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(scannerOptions.PathPrefix)) } - c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, httpClient, twirpOpts...) + c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, xhttp.Client(), twirpOpts...) o := &options{rpcClient: c} for _, opt := range opts { diff --git a/pkg/rpc/client/client_test.go b/pkg/rpc/client/client_test.go index 0eb19fd874..0da1cad218 100644 --- a/pkg/rpc/client/client_test.go +++ b/pkg/rpc/client/client_test.go @@ -1,7 +1,6 @@ package client import ( - "crypto/tls" "encoding/json" "fmt" "net/http" @@ -18,6 +17,7 @@ import ( "github.com/aquasecurity/trivy-db/pkg/vulnsrc/vulnerability" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/types" + xhttp "github.com/aquasecurity/trivy/pkg/x/http" "github.com/aquasecurity/trivy/rpc/common" rpc "github.com/aquasecurity/trivy/rpc/scanner" ) @@ -236,11 +236,7 @@ func TestScanner_ScanServerInsecure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := rpc.NewScannerProtobufClient(ts.URL, &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: tt.insecure, - }, - }, + Transport: xhttp.NewTransport(xhttp.Options{Insecure: tt.insecure}), }) s := NewService(ServiceOption{Insecure: tt.insecure}, WithRPCClient(c)) _, err := s.Scan(t.Context(), "dummy", "", nil, types.ScanOptions{}) diff --git a/pkg/x/http/client.go b/pkg/x/http/client.go new file mode 100644 index 0000000000..e97fefb340 --- /dev/null +++ b/pkg/x/http/client.go @@ -0,0 +1,30 @@ +package http + +import ( + "context" + "net/http" + "time" +) + +type ClientOption func(client *http.Client) + +func WithTimeout(timeout time.Duration) ClientOption { + return func(client *http.Client) { + client.Timeout = timeout + } +} + +func Client(opts ...ClientOption) *http.Client { + return ClientWithContext(context.Background(), opts...) +} + +// ClientWithContext returns an HTTP client with the specified context and options. +func ClientWithContext(ctx context.Context, opts ...ClientOption) *http.Client { + c := &http.Client{ + Transport: Transport(ctx), + } + for _, opt := range opts { + opt(c) + } + return c +} diff --git a/pkg/x/http/transport.go b/pkg/x/http/transport.go new file mode 100644 index 0000000000..01eb424a1b --- /dev/null +++ b/pkg/x/http/transport.go @@ -0,0 +1,79 @@ +package http + +import ( + "cmp" + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/aquasecurity/trivy/pkg/version/app" +) + +var ( + defaultTransport = NewTransport(Options{}) + mu sync.RWMutex +) + +type transportKey struct{} + +// WithTransport returns a new context with the given transport. +// This is mainly for testing when a different HTTP transport needs to be used. +func WithTransport(ctx context.Context, tr http.RoundTripper) context.Context { + return context.WithValue(ctx, transportKey{}, tr) +} + +// Options configures the transport settings +type Options struct { + Insecure bool + Timeout time.Duration + UserAgent string +} + +// SetDefaultTransport sets the default transport configuration +func SetDefaultTransport(tr http.RoundTripper) { + mu.Lock() + defer mu.Unlock() + defaultTransport = tr +} + +// Transport returns the transport from the context, or the default transport if none is set. +func Transport(ctx context.Context) http.RoundTripper { + t, ok := ctx.Value(transportKey{}).(http.RoundTripper) + if ok { + // If the transport is already set in the context, return it. + return t + } + + mu.RLock() + defer mu.RUnlock() + + return defaultTransport +} + +// NewTransport creates a new HTTP transport with the specified options. +// It should be used to initialize the default transport. +// In most cases, you should use the `Transport` function to get the default transport. +func NewTransport(opts Options) http.RoundTripper { + tr := http.DefaultTransport.(*http.Transport).Clone() + + // Set timeout (default to 5 minutes) + timeout := cmp.Or(opts.Timeout, 5*time.Minute) + d := &net.Dialer{ + Timeout: timeout, + } + tr.DialContext = d.DialContext + + // Configure TLS + if opts.Insecure { + tr.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: opts.Insecure, + } + } + + userAgent := cmp.Or(opts.UserAgent, fmt.Sprintf("trivy/%s", app.Version())) + return NewUserAgent(tr, userAgent) +} diff --git a/pkg/x/http/useragent.go b/pkg/x/http/useragent.go new file mode 100644 index 0000000000..0601682a01 --- /dev/null +++ b/pkg/x/http/useragent.go @@ -0,0 +1,28 @@ +package http + +import ( + "net/http" +) + +type userAgentTransport struct { + inner http.RoundTripper + ua string +} + +// NewUserAgent returns an http.Roundtripper that sets the user agent +// +// User-Agent: trivy/v0.64.0 +func NewUserAgent(inner http.RoundTripper, ua string) http.RoundTripper { + return &userAgentTransport{ + inner: inner, + ua: ua, + } +} + +// RoundTrip implements http.RoundTripper +func (ut *userAgentTransport) RoundTrip(in *http.Request) (*http.Response, error) { + if ut.ua != "" { + in.Header.Set("User-Agent", ut.ua) + } + return ut.inner.RoundTrip(in) +} diff --git a/pkg/x/http/useragent_test.go b/pkg/x/http/useragent_test.go new file mode 100644 index 0000000000..0486d237f9 --- /dev/null +++ b/pkg/x/http/useragent_test.go @@ -0,0 +1,102 @@ +package http_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + xhttp "github.com/aquasecurity/trivy/pkg/x/http" +) + +func TestUserAgentTransport_RoundTrip(t *testing.T) { + tests := []struct { + name string + userAgent string + existingHeaders map[string]string + existingUA string + wantUA string + wantHeaders map[string]string + }{ + { + name: "default user agent", + userAgent: "", + wantUA: "trivy/dev", + }, + { + name: "custom user agent", + userAgent: "custom-scanner/2.1", + wantUA: "custom-scanner/2.1", + }, + { + name: "preserves existing headers", + userAgent: "test-agent/1.0", + existingHeaders: map[string]string{ + "Authorization": "Bearer token123", + "Content-Type": "application/json", + }, + wantUA: "test-agent/1.0", + wantHeaders: map[string]string{ + "Authorization": "Bearer token123", + "Content-Type": "application/json", + }, + }, + { + name: "overwrites existing user agent", + userAgent: "new-agent/2.0", + existingUA: "old-agent/1.0", + wantUA: "new-agent/2.0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check User-Agent + gotUA := r.Header.Get("User-Agent") + assert.Equal(t, tt.wantUA, gotUA) + + // Check other headers are preserved + for key, wantValue := range tt.wantHeaders { + gotValue := r.Header.Get(key) + assert.Equal(t, wantValue, gotValue, "header %s", key) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create transport with user agent + transport := xhttp.NewTransport(xhttp.Options{ + Insecure: true, + UserAgent: tt.userAgent, + }) + + client := &http.Client{Transport: transport} + + // Create request + req, err := http.NewRequest(http.MethodGet, server.URL, http.NoBody) + require.NoError(t, err) + + // Set existing headers + for key, value := range tt.existingHeaders { + req.Header.Set(key, value) + } + + // Set existing User-Agent if specified + if tt.existingUA != "" { + req.Header.Set("User-Agent", tt.existingUA) + } + + // Make request + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + } +}