feat(server): add internal --path-prefix flag for client/server mode (#7321)

Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Teppei Fukuda
2024-08-21 09:26:11 +04:00
committed by GitHub
parent 3f0e7ebe0d
commit 24a4563686
11 changed files with 130 additions and 77 deletions

View File

@@ -33,9 +33,6 @@ cache:
# Same as '--cache-backend'
backend: "fs"
# Same as '--clear-cache'
clear: false
redis:
# Same as '--redis-ca'
ca: ""
@@ -112,9 +109,6 @@ db:
# Same as '--skip-java-db-update'
java-skip-update: false
# Same as '--light'
light: false
# Same as '--no-progress'
no-progress: false
@@ -124,9 +118,6 @@ db:
# Same as '--skip-db-update'
skip-update: false
# Same as '--reset'
reset: false
```
## Image options
@@ -411,9 +402,6 @@ misconfiguration:
# Same as '--include-non-failures'
include-non-failures: false
# Same as '--reset-checks-bundle'
reset-checks-bundle: false
# Same as '--misconfig-scanners'
scanners:
- azure-arm
@@ -580,9 +568,6 @@ scan:
# Same as '--skip-files'
skip-files: []
# Same as '--slow'
slow: false
```
## Secret options

View File

@@ -32,6 +32,7 @@ type csArgs struct {
Input string
ClientToken string
ClientTokenHeader string
PathPrefix string
ListAllPackages bool
Target string
secretConfig string
@@ -443,7 +444,11 @@ func TestClientServerWithCycloneDX(t *testing.T) {
}
}
func TestClientServerWithToken(t *testing.T) {
func TestClientServerWithCustomOptions(t *testing.T) {
token := "token"
tokenHeader := "Trivy-Token"
pathPrefix := "prefix"
tests := []struct {
name string
args csArgs
@@ -451,11 +456,12 @@ func TestClientServerWithToken(t *testing.T) {
wantErr string
}{
{
name: "alpine 3.9 with token",
name: "alpine 3.9 with token and prefix",
args: csArgs{
Input: "testdata/fixtures/images/alpine-39.tar.gz",
ClientToken: "token",
ClientTokenHeader: "Trivy-Token",
ClientToken: token,
ClientTokenHeader: tokenHeader,
PathPrefix: pathPrefix,
},
golden: "testdata/alpine-39.json.golden",
},
@@ -464,7 +470,8 @@ func TestClientServerWithToken(t *testing.T) {
args: csArgs{
Input: "testdata/fixtures/images/distroless-base.tar.gz",
ClientToken: "invalidtoken",
ClientTokenHeader: "Trivy-Token",
ClientTokenHeader: tokenHeader,
PathPrefix: pathPrefix,
},
wantErr: "twirp error unauthenticated: invalid token",
},
@@ -472,18 +479,28 @@ func TestClientServerWithToken(t *testing.T) {
name: "invalid token header",
args: csArgs{
Input: "testdata/fixtures/images/distroless-base.tar.gz",
ClientToken: "token",
ClientToken: token,
ClientTokenHeader: "Unknown-Header",
PathPrefix: pathPrefix,
},
wantErr: "twirp error unauthenticated: invalid token",
},
{
name: "wrong path prefix",
args: csArgs{
Input: "testdata/fixtures/images/distroless-base.tar.gz",
ClientToken: token,
ClientTokenHeader: tokenHeader,
PathPrefix: "wrong",
},
wantErr: "HTTP status code 404",
},
}
serverToken := "token"
serverTokenHeader := "Trivy-Token"
addr, cacheDir := setup(t, setupOptions{
token: serverToken,
tokenHeader: serverTokenHeader,
token: token,
tokenHeader: tokenHeader,
pathPrefix: pathPrefix,
})
for _, tt := range tests {
@@ -539,6 +556,7 @@ func TestClientServerWithRedis(t *testing.T) {
type setupOptions struct {
token string
tokenHeader string
pathPrefix string
cacheBackend string
}
@@ -556,7 +574,7 @@ func setup(t *testing.T, options setupOptions) (string, string) {
addr := fmt.Sprintf("localhost:%d", port)
go func() {
osArgs := setupServer(addr, options.token, options.tokenHeader, cacheDir, options.cacheBackend)
osArgs := setupServer(addr, options.token, options.tokenHeader, options.pathPrefix, cacheDir, options.cacheBackend)
// Run Trivy server
require.NoError(t, execute(osArgs))
@@ -569,22 +587,20 @@ func setup(t *testing.T, options setupOptions) (string, string) {
return addr, cacheDir
}
func setupServer(addr, token, tokenHeader, cacheDir, cacheBackend string) []string {
func setupServer(addr, token, tokenHeader, pathPrefix, cacheDir, cacheBackend string) []string {
osArgs := []string{
"--cache-dir",
cacheDir,
"server",
"--skip-update",
"--skip-db-update",
"--listen",
addr,
}
if token != "" {
osArgs = append(osArgs, []string{
"--token",
token,
"--token-header",
tokenHeader,
}...)
osArgs = append(osArgs, "--token", token, "--token-header", tokenHeader)
}
if pathPrefix != "" {
osArgs = append(osArgs, "--path-prefix", pathPrefix)
}
if cacheBackend != "" {
osArgs = append(osArgs, "--cache-backend", cacheBackend)
@@ -593,13 +609,13 @@ func setupServer(addr, token, tokenHeader, cacheDir, cacheBackend string) []stri
}
func setupClient(t *testing.T, c csArgs, addr string, cacheDir string) []string {
t.Helper()
if c.Command == "" {
c.Command = "image"
}
if c.RemoteAddrOption == "" {
c.RemoteAddrOption = "--server"
}
t.Helper()
osArgs := []string{
"--cache-dir",
cacheDir,
@@ -639,6 +655,9 @@ func setupClient(t *testing.T, c csArgs, addr string, cacheDir string) []string
if c.ClientToken != "" {
osArgs = append(osArgs, "--token", c.ClientToken, "--token-header", c.ClientTokenHeader)
}
if c.PathPrefix != "" {
osArgs = append(osArgs, "--path-prefix", c.PathPrefix)
}
if c.Input != "" {
osArgs = append(osArgs, "--input", c.Input)
}

View File

@@ -105,7 +105,7 @@ func writeFlags(group flag.FlagGroup, w *os.File) {
var lastParts []string
for _, flg := range flags {
if flg.GetConfigName() == "" {
if flg.GetConfigName() == "" || flg.Hidden() {
continue
}
// We need to split the config name on `.` to make the indentations needed in yaml.

9
pkg/cache/remote.go vendored
View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"net/http"
"github.com/twitchtv/twirp"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/fanal/types"
@@ -19,6 +20,7 @@ type RemoteOptions struct {
ServerAddr string
CustomHeaders http.Header
Insecure bool
PathPrefix string
}
// RemoteCache implements remote cache
@@ -39,7 +41,12 @@ func NewRemoteCache(opts RemoteOptions) *RemoteCache {
},
},
}
c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, httpClient)
var twirpOpts []twirp.ClientOption
if opts.PathPrefix != "" {
twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(opts.PathPrefix))
}
c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, httpClient, twirpOpts...)
return &RemoteCache{
ctx: ctx,
client: c,

View File

@@ -545,11 +545,7 @@ func (r *runner) initScannerConfig(opts flag.Options) (ScannerConfig, types.Scan
Target: target,
CacheOptions: opts.CacheOpts(),
RemoteCacheOptions: opts.RemoteCacheOpts(),
ServerOption: client.ScannerOption{
RemoteURL: opts.ServerAddr,
CustomHeaders: opts.CustomHeaders,
Insecure: opts.Insecure,
},
ServerOption: opts.ClientScannerOpts(),
ArtifactOption: artifact.Option{
DisabledAnalyzers: disabledAnalyzers(opts),
DisabledHandlers: disabledHandlers,

View File

@@ -50,6 +50,6 @@ func Run(ctx context.Context, opts flag.Options) (err error) {
m.Register()
server := rpcServer.NewServer(opts.AppVersion, opts.Listen, opts.CacheDir, opts.Token, opts.TokenHeader,
opts.DBRepository, opts.RegistryOpts())
opts.PathPrefix, opts.DBRepository, opts.RegistryOpts())
return server.ListenAndServe(ctx, cacheClient, opts.SkipDBUpdate)
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/plugin"
"github.com/aquasecurity/trivy/pkg/result"
"github.com/aquasecurity/trivy/pkg/rpc/client"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/version/app"
)
@@ -56,15 +57,21 @@ type Flag[T FlagType] struct {
// Usage explains how to use the flag.
Usage string
// Persistent represents if the flag is persistent
// Persistent represents if the flag is persistent.
Persistent bool
// Deprecated represents if the flag is deprecated
// Deprecated represents if the flag is deprecated.
// It shows a warning message when the flag is used.
Deprecated string
// Removed represents if the flag is removed and no longer works
// Removed represents if the flag is removed and no longer works.
// It shows an error message when the flag is used.
Removed string
// Internal represents if the flag is for internal use only.
// It is not shown in the usage message.
Internal bool
// Aliases represents aliases
Aliases []Alias
@@ -208,6 +215,10 @@ func (f *Flag[T]) GetAliases() []Alias {
return f.Aliases
}
func (f *Flag[T]) Hidden() bool {
return f.Deprecated != "" || f.Removed != "" || f.Internal
}
func (f *Flag[T]) Value() (t T) {
if f == nil {
return t
@@ -249,7 +260,7 @@ func (f *Flag[T]) Add(cmd *cobra.Command) {
flags.Float64P(f.Name, f.Shorthand, v, f.Usage)
}
if f.Deprecated != "" || f.Removed != "" {
if f.Hidden() {
_ = flags.MarkHidden(f.Name)
}
}
@@ -313,6 +324,7 @@ type Flagger interface {
GetConfigName() string
GetDefaultValue() any
GetAliases() []Alias
Hidden() bool
Parse() error
Add(cmd *cobra.Command)
@@ -480,6 +492,16 @@ func (o *Options) RemoteCacheOpts() cache.RemoteOptions {
ServerAddr: o.ServerAddr,
CustomHeaders: o.CustomHeaders,
Insecure: o.Insecure,
PathPrefix: o.PathPrefix,
}
}
func (o *Options) ClientScannerOpts() client.ScannerOption {
return client.ScannerOption{
RemoteURL: o.ServerAddr,
CustomHeaders: o.CustomHeaders,
Insecure: o.Insecure,
PathPrefix: o.PathPrefix,
}
}

View File

@@ -39,6 +39,12 @@ var (
Default: "localhost:4954",
Usage: "listen address in server mode",
}
ServerPathPrefixFlag = Flag[string]{
Name: "path-prefix",
ConfigName: "server.path-prefix",
Usage: "prefix for the server endpoint",
Internal: true, // Internal use
}
)
// RemoteFlagGroup composes common printer flag structs
@@ -47,6 +53,7 @@ type RemoteFlagGroup struct {
// for client/server
Token *Flag[string]
TokenHeader *Flag[string]
PathPrefix *Flag[string]
// for client
ServerAddr *Flag[string]
@@ -63,12 +70,17 @@ type RemoteOptions struct {
ServerAddr string
Listen string
CustomHeaders http.Header
// Server endpoint: <baseURL>[<prefix>]/<package>.<Service>/<Method> (default prefix: /twirp)
// e.g., http://localhost:4954/twirp/trivy.scanner.v1.Scanner/Scan
PathPrefix string
}
func NewClientFlags() *RemoteFlagGroup {
return &RemoteFlagGroup{
Token: ServerTokenFlag.Clone(),
TokenHeader: ServerTokenHeaderFlag.Clone(),
PathPrefix: ServerPathPrefixFlag.Clone(),
ServerAddr: ServerAddrFlag.Clone(),
CustomHeaders: ServerCustomHeadersFlag.Clone(),
}
@@ -76,9 +88,10 @@ func NewClientFlags() *RemoteFlagGroup {
func NewServerFlags() *RemoteFlagGroup {
return &RemoteFlagGroup{
Token: &ServerTokenFlag,
TokenHeader: &ServerTokenHeaderFlag,
Listen: &ServerListenFlag,
Token: ServerTokenFlag.Clone(),
TokenHeader: ServerTokenHeaderFlag.Clone(),
PathPrefix: ServerPathPrefixFlag.Clone(),
Listen: ServerListenFlag.Clone(),
}
}
@@ -90,6 +103,7 @@ func (f *RemoteFlagGroup) Flags() []Flagger {
return []Flagger{
f.Token,
f.TokenHeader,
f.PathPrefix,
f.ServerAddr,
f.CustomHeaders,
f.Listen,
@@ -129,6 +143,7 @@ func (f *RemoteFlagGroup) ToOptions() (RemoteOptions, error) {
return RemoteOptions{
Token: token,
TokenHeader: tokenHeader,
PathPrefix: f.PathPrefix.Value(),
ServerAddr: serverAddr,
CustomHeaders: customHeaders,
Listen: listen,

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"net/http"
"github.com/twitchtv/twirp"
"golang.org/x/xerrors"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
@@ -32,6 +33,7 @@ type ScannerOption struct {
RemoteURL string
Insecure bool
CustomHeaders http.Header
PathPrefix string
}
// Scanner implements the RPC scanner
@@ -42,16 +44,15 @@ type Scanner struct {
// NewScanner is the factory method to return RPC Scanner
func NewScanner(scannerOptions ScannerOption, opts ...Option) Scanner {
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: scannerOptions.Insecure,
},
},
}
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: scannerOptions.Insecure}
httpClient := &http.Client{Transport: tr}
c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, httpClient)
var twirpOpts []twirp.ClientOption
if scannerOptions.PathPrefix != "" {
twirpOpts = append(twirpOpts, twirp.WithClientPathPrefix(scannerOptions.PathPrefix))
}
c := rpc.NewScannerProtobufClient(scannerOptions.RemoteURL, httpClient, twirpOpts...)
o := &options{rpcClient: c}
for _, opt := range opts {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"os"
"strings"
"sync"
"time"
@@ -30,9 +31,11 @@ const updateInterval = 1 * time.Hour
type Server struct {
appVersion string
addr string
cacheDir string
dbDir string
token string
tokenHeader string
pathPrefix string
dbRepository name.Reference
// For OCI registries
@@ -40,13 +43,15 @@ type Server struct {
}
// NewServer returns an instance of Server
func NewServer(appVersion, addr, cacheDir, token, tokenHeader string, dbRepository name.Reference, opt types.RegistryOptions) Server {
func NewServer(appVersion, addr, cacheDir, token, tokenHeader, pathPrefix string, dbRepository name.Reference, opt types.RegistryOptions) Server {
return Server{
appVersion: appVersion,
addr: addr,
cacheDir: cacheDir,
dbDir: db.Dir(cacheDir),
token: token,
tokenHeader: tokenHeader,
pathPrefix: pathPrefix,
dbRepository: dbRepository,
RegistryOptions: opt,
}
@@ -67,14 +72,13 @@ func (s Server) ListenAndServe(ctx context.Context, serverCache cache.Cache, ski
}
}()
mux := newServeMux(ctx, serverCache, dbUpdateWg, requestWg, s.token, s.tokenHeader, s.dbDir)
mux := s.newServeMux(ctx, serverCache, dbUpdateWg, requestWg)
log.Infof("Listening %s...", s.addr)
return http.ListenAndServe(s.addr, mux)
}
func newServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, requestWg *sync.WaitGroup,
token, tokenHeader, cacheDir string) *http.ServeMux {
func (s Server) newServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, requestWg *sync.WaitGroup) *http.ServeMux {
withWaitGroup := func(base http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Stop processing requests during DB update
@@ -91,13 +95,19 @@ func newServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, reque
mux := http.NewServeMux()
scanServer := rpcScanner.NewScannerServer(initializeScanServer(serverCache), nil)
scanHandler := withToken(withWaitGroup(scanServer), token, tokenHeader)
mux.Handle(rpcScanner.ScannerPathPrefix, gziphandler.GzipHandler(scanHandler))
var twirpOpts []any
if s.pathPrefix != "" {
pathPrefix := "/" + strings.TrimPrefix(s.pathPrefix, "/") // Twirp requires the leading slash
twirpOpts = append(twirpOpts, twirp.WithServerPathPrefix(pathPrefix))
}
layerServer := rpcCache.NewCacheServer(NewCacheServer(serverCache), nil)
layerHandler := withToken(withWaitGroup(layerServer), token, tokenHeader)
mux.Handle(rpcCache.CachePathPrefix, gziphandler.GzipHandler(layerHandler))
scanServer := rpcScanner.NewScannerServer(initializeScanServer(serverCache), twirpOpts...)
scanHandler := withToken(withWaitGroup(scanServer), s.token, s.tokenHeader)
mux.Handle(scanServer.PathPrefix(), gziphandler.GzipHandler(scanHandler))
cacheServer := rpcCache.NewCacheServer(NewCacheServer(serverCache), twirpOpts...)
layerHandler := withToken(withWaitGroup(cacheServer), s.token, s.tokenHeader)
mux.Handle(cacheServer.PathPrefix(), gziphandler.GzipHandler(layerHandler))
mux.HandleFunc("/healthz", func(rw http.ResponseWriter, r *http.Request) {
if _, err := rw.Write([]byte("ok")); err != nil {
@@ -108,7 +118,7 @@ func newServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, reque
mux.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(version.NewVersionInfo(cacheDir)); err != nil {
if err := json.NewEncoder(w).Encode(version.NewVersionInfo(s.cacheDir)); err != nil {
log.Error("Version error", log.Err(err))
}
})

View File

@@ -115,7 +115,7 @@ func Test_dbWorker_update(t *testing.T) {
}
}
func Test_newServeMux(t *testing.T) {
func TestServer_newServeMux(t *testing.T) {
type args struct {
token string
tokenHeader string
@@ -182,9 +182,8 @@ func Test_newServeMux(t *testing.T) {
require.NoError(t, err)
defer func() { _ = c.Close() }()
ts := httptest.NewServer(newServeMux(context.Background(), c, dbUpdateWg, requestWg, tt.args.token,
tt.args.tokenHeader, ""),
)
s := NewServer("", "", "", tt.args.token, tt.args.tokenHeader, "", nil, ftypes.RegistryOptions{})
ts := httptest.NewServer(s.newServeMux(context.Background(), c, dbUpdateWg, requestWg))
defer ts.Close()
var resp *http.Response
@@ -214,9 +213,8 @@ func Test_VersionEndpoint(t *testing.T) {
require.NoError(t, err)
defer func() { _ = c.Close() }()
ts := httptest.NewServer(newServeMux(context.Background(), c, dbUpdateWg, requestWg, "", "",
"testdata/testcache"),
)
s := NewServer("", "", "testdata/testcache", "", "", "", nil, ftypes.RegistryOptions{})
ts := httptest.NewServer(s.newServeMux(context.Background(), c, dbUpdateWg, requestWg))
defer ts.Close()
resp, err := http.Get(ts.URL + "/version")