Files
trivy/pkg/rpc/client/client.go
2025-03-23 23:47:03 +00:00

114 lines
3.1 KiB
Go

package client
import (
"context"
"crypto/tls"
"net/http"
"github.com/samber/lo"
"github.com/twitchtv/twirp"
"golang.org/x/xerrors"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
r "github.com/aquasecurity/trivy/pkg/rpc"
"github.com/aquasecurity/trivy/pkg/types"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
"github.com/aquasecurity/trivy/rpc/common"
rpc "github.com/aquasecurity/trivy/rpc/scanner"
)
// options collects the settings that Option functions may adjust while
// NewService builds a Service.
type options struct {
// rpcClient is the scanner client that NewService stores on the Service.
rpcClient rpc.Scanner
}
// Option is a functional option that mutates the internal options used by
// NewService.
type Option func(*options)
// WithRPCClient returns an Option that replaces the scanner client used by
// the Service with c. It exists mainly so tests can inject their own client.
func WithRPCClient(c rpc.Scanner) Option {
	inject := func(o *options) {
		o.rpcClient = c
	}
	return inject
}
// ServiceOption holds options for RPC client
type ServiceOption struct {
// RemoteURL is the server address handed to the scanner protobuf client.
RemoteURL string
// Insecure is copied into the TLS config's InsecureSkipVerify field.
Insecure bool
// CustomHeaders is stored on the Service and attached to the context of
// every Scan call.
CustomHeaders http.Header
// PathPrefix, when non-empty, is passed to twirp.WithClientPathPrefix.
PathPrefix string
}
// Service implements the RPC client for remote scanning
type Service struct {
// customHeaders is passed to WithCustomHeaders at the start of each Scan.
customHeaders http.Header
// client is the scanner RPC client that Scan calls.
client rpc.Scanner
}
// NewService is the factory method to return RPC Service.
//
// It builds a Twirp protobuf scanner client for scannerOptions.RemoteURL on
// top of a private copy of the default HTTP transport, then applies opts,
// which may replace that client (see WithRPCClient). The returned Service
// carries scannerOptions.CustomHeaders for use in Scan.
func NewService(scannerOptions ServiceOption, opts ...Option) Service {
	// http.DefaultTransport is a replaceable package variable; other code in
	// the process may have swapped it for a different RoundTripper. Use a
	// checked assertion so that case falls back to a fresh transport instead
	// of panicking.
	var tr *http.Transport
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		tr = dt.Clone()
	} else {
		tr = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}

	// Clone already copies TLSClientConfig, so it is safe to modify in place.
	// Keep whatever TLS settings the transport carried and only set the
	// verification flag requested by the caller.
	if tr.TLSClientConfig == nil {
		tr.TLSClientConfig = &tls.Config{}
	}
	tr.TLSClientConfig.InsecureSkipVerify = scannerOptions.Insecure

	httpClient := &http.Client{Transport: tr}

	var twirpOpts []twirp.ClientOption
	if scannerOptions.PathPrefix != "" {
		twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(scannerOptions.PathPrefix))
	}

	c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, httpClient, twirpOpts...)

	// Start from the default client and let the options override it.
	o := &options{rpcClient: c}
	for _, opt := range opts {
		opt(o)
	}

	return Service{
		customHeaders: scannerOptions.CustomHeaders,
		client:        o.rpcClient,
	}
}
// Scan sends a scan request for the given target, artifact key and blob keys
// to the remote scanner and returns the results and OS converted from the RPC
// response. The service's custom headers are attached to ctx, and the remote
// call is executed through r.Retry.
func (s Service) Scan(ctx context.Context, target, artifactKey string, blobKeys []string, opts types.ScanOptions) (types.Results, ftypes.OS, error) {
	ctx = WithCustomHeaders(ctx, s.customHeaders)

	// Translate the license categories into their RPC representation.
	rpcLicenses := make(map[string]*rpc.Licenses)
	for cat, licenseNames := range opts.LicenseCategories {
		rpcLicenses[string(cat)] = &rpc.Licenses{Names: licenseNames}
	}

	// The distro is left nil unless the caller supplied a non-empty one.
	var rpcDistro *common.OS
	if !lo.IsEmpty(opts.Distro) {
		rpcDistro = &common.OS{
			Family: string(opts.Distro.Family),
			Name:   opts.Distro.Name,
		}
	}

	// newRequest assembles the RPC request; it is invoked once per attempt.
	newRequest := func() *rpc.ScanRequest {
		return &rpc.ScanRequest{
			Target:     target,
			ArtifactId: artifactKey,
			BlobIds:    blobKeys,
			Options: &rpc.ScanOptions{
				PkgTypes:            opts.PkgTypes,
				PkgRelationships:    xstrings.ToStringSlice(opts.PkgRelationships),
				Scanners:            xstrings.ToStringSlice(opts.Scanners),
				LicenseCategories:   rpcLicenses,
				LicenseFull:         opts.LicenseFull,
				IncludeDevDeps:      opts.IncludeDevDeps,
				Distro:              rpcDistro,
				VulnSeveritySources: xstrings.ToStringSlice(opts.VulnSeveritySources),
			},
		}
	}

	var resp *rpc.ScanResponse
	attempt := func() error {
		var callErr error
		resp, callErr = s.client.Scan(ctx, newRequest())
		return callErr
	}
	if err := r.Retry(attempt); err != nil {
		return nil, ftypes.OS{}, xerrors.Errorf("failed to detect vulnerabilities via RPC: %w", err)
	}

	results := r.ConvertFromRPCResults(resp.Results)
	detectedOS := r.ConvertFromRPCOS(resp.Os)
	return results, detectedOS, nil
}