refactor: centralize HTTP transport configuration (#9058)

Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Teppei Fukuda
2025-06-24 21:43:58 +04:00
committed by GitHub
parent cd7c595e4a
commit 3adfd988d1
22 changed files with 310 additions and 139 deletions

18
pkg/cache/remote.go vendored
View File

@@ -2,7 +2,6 @@ package cache
import (
"context"
"crypto/tls"
"net/http"
"github.com/twitchtv/twirp"
@@ -11,6 +10,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/rpc"
"github.com/aquasecurity/trivy/pkg/rpc/client"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
rpcCache "github.com/aquasecurity/trivy/rpc/cache"
)
@@ -19,7 +19,6 @@ var _ ArtifactCache = (*RemoteCache)(nil)
type RemoteOptions struct {
ServerAddr string
CustomHeaders http.Header
Insecure bool
PathPrefix string
}
@@ -30,23 +29,14 @@ type RemoteCache struct {
}
// NewRemoteCache is the factory method for RemoteCache
func NewRemoteCache(opts RemoteOptions) *RemoteCache {
ctx := client.WithCustomHeaders(context.Background(), opts.CustomHeaders)
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: opts.Insecure,
},
},
}
func NewRemoteCache(ctx context.Context, opts RemoteOptions) *RemoteCache {
ctx = client.WithCustomHeaders(ctx, opts.CustomHeaders)
var twirpOpts []twirp.ClientOption
if opts.PathPrefix != "" {
twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(opts.PathPrefix))
}
c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, httpClient, twirpOpts...)
c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, xhttp.ClientWithContext(ctx), twirpOpts...)
return &RemoteCache{
ctx: ctx,
client: c,

View File

@@ -16,6 +16,7 @@ import (
"github.com/aquasecurity/trivy/pkg/cache"
"github.com/aquasecurity/trivy/pkg/fanal/types"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
rpcCache "github.com/aquasecurity/trivy/rpc/cache"
rpcScanner "github.com/aquasecurity/trivy/rpc/scanner"
)
@@ -145,10 +146,9 @@ func TestRemoteCache_PutArtifact(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cache.NewRemoteCache(cache.RemoteOptions{
c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{
ServerAddr: ts.URL,
CustomHeaders: tt.args.customHeaders,
Insecure: false,
})
err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo)
if tt.wantErr != "" {
@@ -208,10 +208,9 @@ func TestRemoteCache_PutBlob(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cache.NewRemoteCache(cache.RemoteOptions{
c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{
ServerAddr: ts.URL,
CustomHeaders: tt.args.customHeaders,
Insecure: false,
})
err := c.PutBlob(tt.args.diffID, tt.args.layerInfo)
if tt.wantErr != "" {
@@ -288,10 +287,9 @@ func TestRemoteCache_MissingBlobs(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cache.NewRemoteCache(cache.RemoteOptions{
c := cache.NewRemoteCache(t.Context(), cache.RemoteOptions{
ServerAddr: ts.URL,
CustomHeaders: tt.args.customHeaders,
Insecure: false,
})
gotMissingImage, gotMissingLayerIDs, err := c.MissingBlobs(tt.args.imageID, tt.args.layerIDs)
if tt.wantErr != "" {
@@ -339,10 +337,12 @@ func TestRemoteCache_PutArtifactInsecure(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cache.NewRemoteCache(cache.RemoteOptions{
ctx := xhttp.WithTransport(t.Context(), xhttp.NewTransport(xhttp.Options{
Insecure: tt.args.insecure,
}))
c := cache.NewRemoteCache(ctx, cache.RemoteOptions{
ServerAddr: ts.URL,
CustomHeaders: nil,
Insecure: tt.args.insecure,
})
err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo)
if tt.wantErr != "" {

View File

@@ -34,6 +34,7 @@ import (
"github.com/aquasecurity/trivy/pkg/scan"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/version/doc"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
// TargetKind represents what kind of artifact Trivy scans
@@ -118,6 +119,12 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...RunnerOptio
opt(r)
}
// Set the default HTTP transport
xhttp.SetDefaultTransport(xhttp.NewTransport(xhttp.Options{
Insecure: cliOptions.Insecure,
Timeout: cliOptions.Timeout,
}))
// If the user has not disabled notices or is running in quiet mode
r.versionChecker = notification.NewVersionChecker(
notification.WithSkipVersionCheck(cliOptions.SkipVersionCheck),

View File

@@ -195,7 +195,7 @@ func initializeRemoteImageScanService(ctx context.Context, imageName string, rem
if err != nil {
return scan.Service{}, nil, err
}
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption)
if err != nil {
cleanup()
@@ -220,7 +220,7 @@ func initializeRemoteArchiveScanService(ctx context.Context, filePath string, re
if err != nil {
return scan.Service{}, nil, err
}
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption)
if err != nil {
return scan.Service{}, nil, err
@@ -234,7 +234,7 @@ func initializeRemoteArchiveScanService(ctx context.Context, filePath string, re
func initializeRemoteFilesystemScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
v := _wireValue
service := client.NewService(remoteScanOptions, v...)
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
fs := walker.NewFS()
artifactArtifact, err := local2.NewArtifact(path, remoteCache, fs, artifactOption)
if err != nil {
@@ -249,7 +249,7 @@ func initializeRemoteFilesystemScanService(ctx context.Context, path string, rem
func initializeRemoteRepositoryScanService(ctx context.Context, url string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
v := _wireValue
service := client.NewService(remoteScanOptions, v...)
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
fs := walker.NewFS()
artifactArtifact, cleanup, err := repo.NewArtifact(url, remoteCache, fs, artifactOption)
if err != nil {
@@ -265,7 +265,7 @@ func initializeRemoteRepositoryScanService(ctx context.Context, url string, remo
func initializeRemoteSBOMScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
v := _wireValue
service := client.NewService(remoteScanOptions, v...)
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
artifactArtifact, err := sbom.NewArtifact(path, remoteCache, artifactOption)
if err != nil {
return scan.Service{}, nil, err
@@ -279,7 +279,7 @@ func initializeRemoteSBOMScanService(ctx context.Context, path string, remoteCac
func initializeRemoteVMScanService(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ServiceOption, artifactOption artifact.Option) (scan.Service, func(), error) {
v := _wireValue
service := client.NewService(remoteScanOptions, v...)
remoteCache := cache.NewRemoteCache(remoteCacheOptions)
remoteCache := cache.NewRemoteCache(ctx, remoteCacheOptions)
walkerVM := walker.NewVM()
artifactArtifact, err := vm.NewArtifact(path, remoteCache, walkerVM, artifactOption)
if err != nil {

View File

@@ -2,19 +2,18 @@ package auth
import (
"context"
"net/http"
"os"
"github.com/docker/cli/cli/config"
"github.com/docker/cli/cli/config/types"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
func Login(ctx context.Context, registry string, opts flag.Options) error {
@@ -34,7 +33,7 @@ func Login(ctx context.Context, registry string, opts flag.Options) error {
_, err = transport.NewWithContext(ctx, reg, &authn.Basic{
Username: opts.Credentials[0].Username,
Password: opts.Credentials[0].Password,
}, httpTransport(opts), nil)
}, xhttp.Transport(ctx), nil)
if err != nil {
return xerrors.Errorf("failed to authenticate: %w", err)
}
@@ -99,11 +98,3 @@ func parseRegistry(registry string, opts flag.Options) (name.Registry, error) {
}
return reg, nil
}
func httpTransport(opts flag.Options) *http.Transport {
tr := remote.DefaultTransport.(*http.Transport).Clone()
if opts.Insecure {
tr.TLSClientConfig.InsecureSkipVerify = true
}
return tr
}

View File

@@ -23,6 +23,7 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)
@@ -742,7 +743,7 @@ func (p *Parser) fetchPomFileNameFromMavenMetadata(repo string, paths []string)
return "", nil
}
client := &http.Client{}
client := xhttp.Client()
resp, err := client.Do(req)
if err != nil {
p.logger.Debug("Failed to fetch", log.String("url", req.URL.String()), log.Err(err))
@@ -776,7 +777,7 @@ func (p *Parser) fetchPOMFromRemoteRepository(repo string, paths []string) (*pom
return nil, nil
}
client := &http.Client{}
client := xhttp.Client()
resp, err := client.Do(req)
if err != nil {
p.logger.Debug("Failed to fetch", log.String("url", req.URL.String()), log.Err(err))

View File

@@ -3,7 +3,6 @@ package downloader
import (
"cmp"
"context"
"crypto/tls"
"errors"
"maps"
"net/http"
@@ -13,9 +12,11 @@ import (
"time"
"github.com/google/go-github/v62/github"
getter "github.com/hashicorp/go-getter"
"github.com/hashicorp/go-getter"
"github.com/samber/lo"
"golang.org/x/xerrors"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
var ErrSkipDownload = errors.New("skip download")
@@ -106,14 +107,12 @@ type CustomTransport struct {
auth Auth
cachedETag string
newETag string
insecure bool
}
func NewCustomTransport(opts Options) *CustomTransport {
return &CustomTransport{
auth: opts.Auth,
cachedETag: opts.ETag,
insecure: opts.Insecure,
}
}
@@ -127,12 +126,9 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.SetBasicAuth(t.auth.Username, t.auth.Password)
}
var transport http.RoundTripper
transport := xhttp.Transport(req.Context())
if req.URL.Host == "github.com" {
transport = NewGitHubTransport(req.URL, t.insecure, t.auth.Token)
}
if transport == nil {
transport = httpTransport(t.insecure)
transport = NewGitHubTransport(req.URL, t.auth.Token)
}
res, err := transport.RoundTrip(req)
@@ -151,8 +147,8 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return res, nil
}
func NewGitHubTransport(u *url.URL, insecure bool, token string) http.RoundTripper {
client := newGitHubClient(insecure, token)
func NewGitHubTransport(u *url.URL, token string) http.RoundTripper {
client := newGitHubClient(token)
ss := strings.SplitN(u.Path, "/", 4)
if len(ss) < 4 || strings.HasPrefix(ss[3], "archive/") || strings.HasPrefix(ss[3], "releases/") ||
strings.HasPrefix(ss[3], "tags/") {
@@ -185,17 +181,11 @@ func (t *GitHubContentTransport) RoundTrip(req *http.Request) (*http.Response, e
return res.Response, nil
}
func newGitHubClient(insecure bool, token string) *github.Client {
client := github.NewClient(&http.Client{Transport: httpTransport(insecure)})
func newGitHubClient(token string) *github.Client {
client := github.NewClient(xhttp.Client())
token = cmp.Or(token, os.Getenv("GITHUB_TOKEN"))
if token != "" {
client = client.WithAuthToken(token)
}
return client
}
func httpTransport(insecure bool) *http.Transport {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: insecure}
return tr
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/downloader"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
func TestDownload(t *testing.T) {
@@ -42,8 +43,13 @@ func TestDownload(t *testing.T) {
// Set up the destination path
dst := t.TempDir()
// Configure the HTTP transport with the insecure option
ctx := xhttp.WithTransport(t.Context(), xhttp.NewTransport(xhttp.Options{
Insecure: tt.insecure,
}))
// Execute the download
_, err := downloader.Download(t.Context(), server.URL, dst, "", downloader.Options{
_, err := downloader.Download(ctx, server.URL, dst, "", downloader.Options{
Insecure: tt.insecure,
})

View File

@@ -14,7 +14,7 @@ import (
"github.com/docker/go-connections/nat"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
testcontainers "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/aquasecurity/trivy/internal/testutil"
@@ -26,6 +26,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/image"
testdocker "github.com/aquasecurity/trivy/pkg/fanal/test/integration/docker"
"github.com/aquasecurity/trivy/pkg/fanal/types"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
_ "github.com/aquasecurity/trivy/pkg/fanal/analyzer/all"
)
@@ -228,6 +229,11 @@ func analyze(t *testing.T, ctx context.Context, imageRef string, opt types.Image
}
cli.NegotiateAPIVersion(ctx)
// Configure custom transport with insecure option
ctx = xhttp.WithTransport(ctx, xhttp.NewTransport(xhttp.Options{
Insecure: opt.RegistryOptions.Insecure,
}))
img, cleanup, err := image.NewContainerImage(ctx, imageRef, opt)
if err != nil {
return nil, err

View File

@@ -88,10 +88,6 @@ type RegistryOptions struct {
// SSL/TLS
Insecure bool
// For internal use. Needed for mTLS authentication.
ClientCert []byte
ClientKey []byte
// Architecture
Platform Platform

View File

@@ -517,7 +517,6 @@ func (o *Options) RemoteCacheOpts() cache.RemoteOptions {
return cache.RemoteOptions{
ServerAddr: o.ServerAddr,
CustomHeaders: o.CustomHeaders,
Insecure: o.Insecure,
PathPrefix: o.PathPrefix,
}
}

View File

@@ -4,7 +4,6 @@ package resolvers_test
import (
"context"
"crypto/tls"
"io/fs"
"net/http"
"net/http/httptest"
@@ -18,6 +17,7 @@ import (
"github.com/aquasecurity/trivy/internal/gittest"
"github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/parser/resolvers"
"github.com/aquasecurity/trivy/pkg/log"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
type moduleResolver interface {
@@ -73,9 +73,7 @@ func TestResolveModuleFromCache(t *testing.T) {
opts: resolvers.Options{
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws",
Client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
},
},
firstResolver: resolvers.Registry,
@@ -87,9 +85,7 @@ func TestResolveModuleFromCache(t *testing.T) {
opts: resolvers.Options{
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws//modules/object",
Client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
},
},
firstResolver: resolvers.Registry,

View File

@@ -16,6 +16,7 @@ import (
"github.com/aquasecurity/go-version/pkg/version"
"github.com/aquasecurity/trivy/pkg/log"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
type registryResolver struct {
@@ -23,10 +24,8 @@ type registryResolver struct {
}
var Registry = &registryResolver{
client: &http.Client{
// give it a maximum 5 seconds to resolve the module
Timeout: time.Second * 5,
},
client: xhttp.Client(xhttp.WithTimeout(5 * time.Second)),
}
type moduleVersions struct {

View File

@@ -14,6 +14,7 @@ import (
"github.com/aquasecurity/go-version/pkg/semver"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/version/app"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
type VersionChecker struct {
@@ -60,9 +61,7 @@ func (v *VersionChecker) RunUpdateCheck(ctx context.Context, args []string) {
go func() {
logger.Debug("Running version check")
args = getFlags(args)
client := &http.Client{
Timeout: 3 * time.Second,
}
client := xhttp.ClientWithContext(ctx, xhttp.WithTimeout(3*time.Second))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, v.updatesApi, http.NoBody)
if err != nil {

View File

@@ -2,20 +2,15 @@ package remote
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/authn/github"
"github.com/google/go-containerregistry/pkg/name"
v1 "github.com/google/go-containerregistry/pkg/v1"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
v1types "github.com/google/go-containerregistry/pkg/v1/types"
"github.com/hashicorp/go-multierror"
"github.com/samber/lo"
@@ -24,7 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/image/registry"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/version/app"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
type Descriptor = remote.Descriptor
@@ -32,13 +27,8 @@ type Descriptor = remote.Descriptor
// Get is a wrapper of google/go-containerregistry/pkg/v1/remote.Get
// so that it can try multiple authentication methods.
func Get(ctx context.Context, ref name.Reference, option types.RegistryOptions) (*Descriptor, error) {
tr, err := httpTransport(option)
if err != nil {
return nil, xerrors.Errorf("failed to create http transport: %w", err)
}
return tryWithMirrors(ref, option, func(r name.Reference) (*Descriptor, error) {
return tryGet(ctx, tr, r, option)
return tryGet(ctx, xhttp.Transport(ctx), r, option)
})
}
@@ -81,13 +71,8 @@ func tryGet(ctx context.Context, tr http.RoundTripper, ref name.Reference, optio
// Image is a wrapper of google/go-containerregistry/pkg/v1/remote.Image
// so that it can try multiple authentication methods.
func Image(ctx context.Context, ref name.Reference, option types.RegistryOptions) (v1.Image, error) {
tr, err := httpTransport(option)
if err != nil {
return nil, xerrors.Errorf("failed to create http transport: %w", err)
}
return tryWithMirrors(ref, option, func(r name.Reference) (v1.Image, error) {
return tryImage(ctx, tr, r, option)
return tryImage(ctx, xhttp.Transport(ctx), r, option)
})
}
@@ -148,16 +133,11 @@ func tryImage(ctx context.Context, tr http.RoundTripper, ref name.Reference, opt
// Referrers is a wrapper of google/go-containerregistry/pkg/v1/remote.Referrers
// so that it can try multiple authentication methods.
func Referrers(ctx context.Context, d name.Digest, option types.RegistryOptions) (v1.ImageIndex, error) {
tr, err := httpTransport(option)
if err != nil {
return nil, xerrors.Errorf("failed to create http transport: %w", err)
}
var errs error
// Try each authentication method until it succeeds
for _, authOpt := range authOptions(ctx, d, option) {
remoteOpts := []remote.Option{
remote.WithTransport(tr),
remote.WithTransport(xhttp.Transport(ctx)),
authOpt,
}
index, err := remote.Referrers(d, remoteOpts...)
@@ -197,26 +177,6 @@ func registryMirrors(hostRef name.Reference, option types.RegistryOptions) ([]na
return mirrors, nil
}
func httpTransport(option types.RegistryOptions) (http.RoundTripper, error) {
d := &net.Dialer{
Timeout: 10 * time.Minute,
}
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = d.DialContext
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: option.Insecure}
if len(option.ClientCert) != 0 && len(option.ClientKey) != 0 {
cert, err := tls.X509KeyPair(option.ClientCert, option.ClientKey)
if err != nil {
return nil, err
}
tr.TLSClientConfig.Certificates = []tls.Certificate{cert}
}
tripper := transport.NewUserAgent(tr, fmt.Sprintf("trivy/%s", app.Version()))
return tripper, nil
}
func authOptions(ctx context.Context, ref name.Reference, option types.RegistryOptions) []remote.Option {
var opts []remote.Option
for _, cred := range option.Credentials {

View File

@@ -368,8 +368,8 @@ func TestUserAgents(t *testing.T) {
require.NoError(t, err)
require.Len(t, tracker.agents, 1)
ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s go-containerregistry", app.Version()))
require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`)
ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s", app.Version()))
require.True(t, ok, `user-agent header equals to "trivy/dev"`)
}
func localImage(t *testing.T) v1.Image {

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"crypto/tls"
"net/http"
"github.com/samber/lo"
@@ -12,6 +11,7 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
r "github.com/aquasecurity/trivy/pkg/rpc"
"github.com/aquasecurity/trivy/pkg/types"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
"github.com/aquasecurity/trivy/pkg/x/slices"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
"github.com/aquasecurity/trivy/rpc/common"
@@ -47,15 +47,11 @@ type Service struct {
// NewService is the factory method to return RPC Service
func NewService(scannerOptions ServiceOption, opts ...Option) Service {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: scannerOptions.Insecure}
httpClient := &http.Client{Transport: tr}
var twirpOpts []twirp.ClientOption
if scannerOptions.PathPrefix != "" {
twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(scannerOptions.PathPrefix))
}
c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, httpClient, twirpOpts...)
c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, xhttp.Client(), twirpOpts...)
o := &options{rpcClient: c}
for _, opt := range opts {

View File

@@ -1,7 +1,6 @@
package client
import (
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
@@ -18,6 +17,7 @@ import (
"github.com/aquasecurity/trivy-db/pkg/vulnsrc/vulnerability"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/types"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
"github.com/aquasecurity/trivy/rpc/common"
rpc "github.com/aquasecurity/trivy/rpc/scanner"
)
@@ -236,11 +236,7 @@ func TestScanner_ScanServerInsecure(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := rpc.NewScannerProtobufClient(ts.URL, &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: tt.insecure,
},
},
Transport: xhttp.NewTransport(xhttp.Options{Insecure: tt.insecure}),
})
s := NewService(ServiceOption{Insecure: tt.insecure}, WithRPCClient(c))
_, err := s.Scan(t.Context(), "dummy", "", nil, types.ScanOptions{})

30
pkg/x/http/client.go Normal file
View File

@@ -0,0 +1,30 @@
package http
import (
"context"
"net/http"
"time"
)
type ClientOption func(client *http.Client)
func WithTimeout(timeout time.Duration) ClientOption {
return func(client *http.Client) {
client.Timeout = timeout
}
}
func Client(opts ...ClientOption) *http.Client {
return ClientWithContext(context.Background(), opts...)
}
// ClientWithContext returns an HTTP client with the specified context and options.
func ClientWithContext(ctx context.Context, opts ...ClientOption) *http.Client {
c := &http.Client{
Transport: Transport(ctx),
}
for _, opt := range opts {
opt(c)
}
return c
}

79
pkg/x/http/transport.go Normal file
View File

@@ -0,0 +1,79 @@
package http
import (
"cmp"
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/aquasecurity/trivy/pkg/version/app"
)
var (
defaultTransport = NewTransport(Options{})
mu sync.RWMutex
)
type transportKey struct{}
// WithTransport returns a new context with the given transport.
// This is mainly for testing when a different HTTP transport needs to be used.
func WithTransport(ctx context.Context, tr http.RoundTripper) context.Context {
return context.WithValue(ctx, transportKey{}, tr)
}
// Options configures the transport settings
type Options struct {
Insecure bool
Timeout time.Duration
UserAgent string
}
// SetDefaultTransport sets the default transport configuration
func SetDefaultTransport(tr http.RoundTripper) {
mu.Lock()
defer mu.Unlock()
defaultTransport = tr
}
// Transport returns the transport from the context, or the default transport if none is set.
func Transport(ctx context.Context) http.RoundTripper {
t, ok := ctx.Value(transportKey{}).(http.RoundTripper)
if ok {
// If the transport is already set in the context, return it.
return t
}
mu.RLock()
defer mu.RUnlock()
return defaultTransport
}
// NewTransport creates a new HTTP transport with the specified options.
// It should be used to initialize the default transport.
// In most cases, you should use the `Transport` function to get the default transport.
func NewTransport(opts Options) http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
// Set timeout (default to 5 minutes)
timeout := cmp.Or(opts.Timeout, 5*time.Minute)
d := &net.Dialer{
Timeout: timeout,
}
tr.DialContext = d.DialContext
// Configure TLS
if opts.Insecure {
tr.TLSClientConfig = &tls.Config{
InsecureSkipVerify: opts.Insecure,
}
}
userAgent := cmp.Or(opts.UserAgent, fmt.Sprintf("trivy/%s", app.Version()))
return NewUserAgent(tr, userAgent)
}

28
pkg/x/http/useragent.go Normal file
View File

@@ -0,0 +1,28 @@
package http
import (
"net/http"
)
type userAgentTransport struct {
inner http.RoundTripper
ua string
}
// NewUserAgent returns an http.Roundtripper that sets the user agent
//
// User-Agent: trivy/v0.64.0
func NewUserAgent(inner http.RoundTripper, ua string) http.RoundTripper {
return &userAgentTransport{
inner: inner,
ua: ua,
}
}
// RoundTrip implements http.RoundTripper
func (ut *userAgentTransport) RoundTrip(in *http.Request) (*http.Response, error) {
if ut.ua != "" {
in.Header.Set("User-Agent", ut.ua)
}
return ut.inner.RoundTrip(in)
}

View File

@@ -0,0 +1,102 @@
package http_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
func TestUserAgentTransport_RoundTrip(t *testing.T) {
tests := []struct {
name string
userAgent string
existingHeaders map[string]string
existingUA string
wantUA string
wantHeaders map[string]string
}{
{
name: "default user agent",
userAgent: "",
wantUA: "trivy/dev",
},
{
name: "custom user agent",
userAgent: "custom-scanner/2.1",
wantUA: "custom-scanner/2.1",
},
{
name: "preserves existing headers",
userAgent: "test-agent/1.0",
existingHeaders: map[string]string{
"Authorization": "Bearer token123",
"Content-Type": "application/json",
},
wantUA: "test-agent/1.0",
wantHeaders: map[string]string{
"Authorization": "Bearer token123",
"Content-Type": "application/json",
},
},
{
name: "overwrites existing user agent",
userAgent: "new-agent/2.0",
existingUA: "old-agent/1.0",
wantUA: "new-agent/2.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check User-Agent
gotUA := r.Header.Get("User-Agent")
assert.Equal(t, tt.wantUA, gotUA)
// Check other headers are preserved
for key, wantValue := range tt.wantHeaders {
gotValue := r.Header.Get(key)
assert.Equal(t, wantValue, gotValue, "header %s", key)
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Create transport with user agent
transport := xhttp.NewTransport(xhttp.Options{
Insecure: true,
UserAgent: tt.userAgent,
})
client := &http.Client{Transport: transport}
// Create request
req, err := http.NewRequest(http.MethodGet, server.URL, http.NoBody)
require.NoError(t, err)
// Set existing headers
for key, value := range tt.existingHeaders {
req.Header.Set(key, value)
}
// Set existing User-Agent if specified
if tt.existingUA != "" {
req.Header.Set("User-Agent", tt.existingUA)
}
// Make request
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
})
}
}