feat: add support for flag groups (#2488)

This commit is contained in:
Teppei Fukuda
2022-07-10 15:03:57 +03:00
committed by GitHub
parent 5b7e0a858d
commit 736e3f11f7
16 changed files with 496 additions and 410 deletions

View File

@@ -30,6 +30,32 @@ type VersionInfo struct {
VulnerabilityDB *metadata.Metadata `json:",omitempty"`
}
const (
usageTemplate = `Usage:{{if .Runnable}}
{{.UseLine}}{{end}}{{if .HasAvailableSubCommands}}
{{.CommandPath}} [command]{{end}}{{if gt (len .Aliases) 0}}
Aliases:
{{.NameAndAliases}}{{end}}{{if .HasExample}}
Examples:
{{.Example}}{{end}}{{if .HasAvailableSubCommands}}
Available Commands:{{range .Commands}}{{if (or .IsAvailableCommand (eq .Name "help"))}}
{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}}
%s
Global Flags:
{{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasHelpSubCommands}}
Additional help topics:{{range .Commands}}{{if .IsAdditionalHelpTopicCommand}}
{{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableSubCommands}}
Use "{{.CommandPath}} [command] --help" for more information about a command.{{end}}
`
)
var (
outputWriter io.Writer = os.Stdout
)
@@ -199,6 +225,8 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
@@ -250,8 +278,9 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
SilenceUsage: true,
}
cmd.SetFlagErrorFunc(flagErrorFunc)
imageFlags.AddFlags(cmd)
cmd.SetFlagErrorFunc(flagErrorFunc)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, imageFlags.Usages(cmd)))
return cmd
}
@@ -267,6 +296,8 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
@@ -300,6 +331,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
cmd.SetFlagErrorFunc(flagErrorFunc)
fsFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, fsFlags.Usages(cmd)))
return cmd
}
@@ -314,6 +346,8 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
@@ -348,6 +382,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
rootfsFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, rootfsFlags.Usages(cmd)))
return cmd
}
@@ -363,6 +398,8 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
@@ -392,6 +429,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
repoFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, repoFlags.Usages(cmd)))
return cmd
}
@@ -415,6 +453,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
RemoteFlagGroup: remoteFlags,
ReportFlagGroup: flag.NewReportFlagGroup(),
ScanFlagGroup: flag.NewScanFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
@@ -444,6 +483,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
clientFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, clientFlags.Usages(cmd)))
return cmd
}
@@ -459,6 +499,12 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
Use: "server [flags]",
Aliases: []string{"s"},
Short: "Server mode",
Example: ` # Run a server
$ trivy server
# Listen on 0.0.0.0:10000
$ trivy server --listen 0.0.0.0:10000
`,
Args: cobra.ExactArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
if err := serverFlags.Bind(cmd); err != nil {
@@ -475,6 +521,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
serverFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, serverFlags.Usages(cmd)))
return cmd
}
@@ -528,6 +575,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
configFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, configFlags.Usages(cmd)))
return cmd
}
@@ -696,6 +744,8 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ReportFlagGroup: flag.NewReportFlagGroup(),
ScanFlagGroup: scanFlags,
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
Use: "kubernetes [flags] { cluster | all | specific resources like kubectl. eg: pods, pod/NAME }",
@@ -736,6 +786,7 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
k8sFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, k8sFlags.Usages(cmd)))
return cmd
}
@@ -754,6 +805,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SBOMFlagGroup: flag.NewSBOMFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
@@ -766,7 +818,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
$ trivy sbom --format cyclonedx /path/to/report.cdx
`,
PreRunE: func(cmd *cobra.Command, args []string) error {
if err := scanFlags.Bind(cmd); err != nil {
if err := sbomFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
return validateArgs(cmd, args)
@@ -790,6 +842,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
sbomFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, sbomFlags.Usages(cmd)))
return cmd
}

View File

@@ -6,7 +6,6 @@ import (
"time"
"github.com/samber/lo"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
)
@@ -85,40 +84,29 @@ type RedisOptions struct {
// NewCacheFlagGroup returns a default CacheFlagGroup
func NewCacheFlagGroup() *CacheFlagGroup {
return &CacheFlagGroup{
ClearCache: lo.ToPtr(ClearCacheFlag),
CacheBackend: lo.ToPtr(CacheBackendFlag),
CacheTTL: lo.ToPtr(CacheTTLFlag),
RedisCACert: lo.ToPtr(RedisCACertFlag),
RedisCert: lo.ToPtr(RedisCertFlag),
RedisKey: lo.ToPtr(RedisKeyFlag),
ClearCache: &ClearCacheFlag,
CacheBackend: &CacheBackendFlag,
CacheTTL: &CacheTTLFlag,
RedisCACert: &RedisCACertFlag,
RedisCert: &RedisCertFlag,
RedisKey: &RedisKeyFlag,
}
}
func (f *CacheFlagGroup) flags() []*Flag {
return []*Flag{f.ClearCache, f.CacheBackend, f.CacheTTL, f.RedisCACert, f.RedisCert, f.RedisKey}
func (fg *CacheFlagGroup) Name() string {
return "Cache"
}
func (f *CacheFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
func (fg *CacheFlagGroup) Flags() []*Flag {
return []*Flag{fg.ClearCache, fg.CacheBackend, fg.CacheTTL, fg.RedisCACert, fg.RedisCert, fg.RedisKey}
}
func (f *CacheFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
}
func (f *CacheFlagGroup) ToOptions() (CacheOptions, error) {
cacheBackend := getString(f.CacheBackend)
func (fg *CacheFlagGroup) ToOptions() (CacheOptions, error) {
cacheBackend := getString(fg.CacheBackend)
redisOptions := RedisOptions{
RedisCACert: getString(f.RedisCACert),
RedisCert: getString(f.RedisCert),
RedisKey: getString(f.RedisKey),
RedisCACert: getString(fg.RedisCACert),
RedisCert: getString(fg.RedisCert),
RedisKey: getString(fg.RedisKey),
}
// "redis://" or "fs" are allowed for now
@@ -135,9 +123,9 @@ func (f *CacheFlagGroup) ToOptions() (CacheOptions, error) {
}
return CacheOptions{
ClearCache: getBool(f.ClearCache),
ClearCache: getBool(fg.ClearCache),
CacheBackend: cacheBackend,
CacheTTL: getDuration(f.CacheTTL),
CacheTTL: getDuration(fg.CacheTTL),
RedisOptions: redisOptions,
}, nil
}

View File

@@ -1,8 +1,6 @@
package flag
import (
"github.com/samber/lo"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/log"
@@ -46,6 +44,7 @@ var (
ConfigName: "db.light",
Value: false,
Usage: "deprecated",
Deprecated: true,
}
)
@@ -71,34 +70,23 @@ type DBOptions struct {
// NewDBFlagGroup returns a default DBFlagGroup
func NewDBFlagGroup() *DBFlagGroup {
return &DBFlagGroup{
Reset: lo.ToPtr(ResetFlag),
DownloadDBOnly: lo.ToPtr(DownloadDBOnlyFlag),
SkipDBUpdate: lo.ToPtr(SkipDBUpdateFlag),
Light: lo.ToPtr(LightFlag),
NoProgress: lo.ToPtr(NoProgressFlag),
DBRepository: lo.ToPtr(DBRepositoryFlag),
Reset: &ResetFlag,
DownloadDBOnly: &DownloadDBOnlyFlag,
SkipDBUpdate: &SkipDBUpdateFlag,
Light: &LightFlag,
NoProgress: &NoProgressFlag,
DBRepository: &DBRepositoryFlag,
}
}
func (f *DBFlagGroup) flags() []*Flag {
func (f *DBFlagGroup) Name() string {
return "DB"
}
func (f *DBFlagGroup) Flags() []*Flag {
return []*Flag{f.Reset, f.DownloadDBOnly, f.SkipDBUpdate, f.NoProgress, f.DBRepository, f.Light}
}
func (f *DBFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
}
func (f *DBFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
}
func (f *DBFlagGroup) ToOptions() (DBOptions, error) {
skipDBUpdate := getBool(f.SkipDBUpdate)
downloadDBOnly := getBool(f.DownloadDBOnly)

View File

@@ -1,9 +1,5 @@
package flag
import (
"github.com/spf13/cobra"
)
// e.g. config yaml
// image:
// removed-pkgs: true
@@ -41,25 +37,14 @@ func NewImageFlagGroup() *ImageFlagGroup {
}
}
func (f *ImageFlagGroup) flags() []*Flag {
func (f *ImageFlagGroup) Name() string {
return "Image"
}
func (f *ImageFlagGroup) Flags() []*Flag {
return []*Flag{f.Input, f.ScanRemovedPkgs}
}
func (f *ImageFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
}
func (f *ImageFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
}
func (f *ImageFlagGroup) ToOptions() ImageOptions {
return ImageOptions{
Input: getString(f.Input),

View File

@@ -1,10 +1,5 @@
package flag
import (
"github.com/samber/lo"
"github.com/spf13/cobra"
)
var (
ClusterContextFlag = Flag{
Name: "context",
@@ -32,30 +27,19 @@ type K8sOptions struct {
func NewK8sFlagGroup() *K8sFlagGroup {
return &K8sFlagGroup{
ClusterContext: lo.ToPtr(ClusterContextFlag),
Namespace: lo.ToPtr(K8sNamespaceFlag),
ClusterContext: &ClusterContextFlag,
Namespace: &K8sNamespaceFlag,
}
}
func (f *K8sFlagGroup) flags() []*Flag {
func (f *K8sFlagGroup) Name() string {
return "Kubernetes"
}
func (f *K8sFlagGroup) Flags() []*Flag {
return []*Flag{f.ClusterContext, f.Namespace}
}
func (f *K8sFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
}
func (f *K8sFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
}
func (f *K8sFlagGroup) ToOptions() K8sOptions {
return K8sOptions{
ClusterContext: getString(f.ClusterContext),

View File

@@ -1,9 +1,6 @@
package flag
import (
"github.com/samber/lo"
"github.com/spf13/cobra"
"github.com/aquasecurity/trivy/pkg/log"
)
@@ -30,6 +27,7 @@ var (
ConfigName: "misconfiguration.skip-policy-update",
Value: false,
Usage: "deprecated",
Deprecated: true,
}
TraceFlag = Flag{
Name: "trace",
@@ -84,35 +82,24 @@ type MisconfOptions struct {
func NewMisconfFlagGroup() *MisconfFlagGroup {
return &MisconfFlagGroup{
FilePatterns: lo.ToPtr(FilePatternsFlag),
IncludeNonFailures: lo.ToPtr(IncludeNonFailuresFlag),
SkipPolicyUpdate: lo.ToPtr(SkipPolicyUpdateFlag),
Trace: lo.ToPtr(TraceFlag),
PolicyPaths: lo.ToPtr(ConfigPolicyFlag),
DataPaths: lo.ToPtr(ConfigDataFlag),
PolicyNamespaces: lo.ToPtr(PolicyNamespaceFlag),
FilePatterns: &FilePatternsFlag,
IncludeNonFailures: &IncludeNonFailuresFlag,
SkipPolicyUpdate: &SkipPolicyUpdateFlag,
Trace: &TraceFlag,
PolicyPaths: &ConfigPolicyFlag,
DataPaths: &ConfigDataFlag,
PolicyNamespaces: &PolicyNamespaceFlag,
}
}
func (f *MisconfFlagGroup) flags() []*Flag {
func (f *MisconfFlagGroup) Name() string {
return "Misconfiguration"
}
func (f *MisconfFlagGroup) Flags() []*Flag {
return []*Flag{f.FilePatterns, f.IncludeNonFailures, f.SkipPolicyUpdate, f.Trace, f.PolicyPaths, f.DataPaths, f.PolicyNamespaces}
}
func (f *MisconfFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
}
func (f *MisconfFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
}
func (f *MisconfFlagGroup) ToOptions() (MisconfOptions, error) {
skipPolicyUpdateFlag := getBool(f.SkipPolicyUpdate)
if skipPolicyUpdateFlag {

View File

@@ -1,6 +1,7 @@
package flag
import (
"fmt"
"io"
"strings"
"time"
@@ -32,11 +33,14 @@ type Flag struct {
// Persistent represents if the flag is persistent
Persistent bool
// Deprecated represents if the flag is deprecated
Deprecated bool
}
type FlagGroup interface {
AddFlags(cmd *cobra.Command)
Bind(cmd *cobra.Command) error
Name() string
Flags() []*Flag
}
type Flags struct {
@@ -49,6 +53,8 @@ type Flags struct {
ReportFlagGroup *ReportFlagGroup
SBOMFlagGroup *SBOMFlagGroup
ScanFlagGroup *ScanFlagGroup
SecretFlagGroup *SecretFlagGroup
VulnerabilityFlagGroup *VulnerabilityFlagGroup
}
// Options holds all the runtime configuration
@@ -63,6 +69,8 @@ type Options struct {
ReportOptions
SBOMOptions
ScanOptions
SecretOptions
VulnerabilityOptions
// Trivy's version, not populated via CLI flags
AppVersion string
@@ -75,37 +83,28 @@ func addFlag(cmd *cobra.Command, flag *Flag) {
if flag == nil || flag.Name == "" {
return
}
var flags *pflag.FlagSet
if flag.Persistent {
flags = cmd.PersistentFlags()
} else {
flags = cmd.Flags()
}
switch v := flag.Value.(type) {
case int:
if flag.Persistent {
cmd.PersistentFlags().IntP(flag.Name, flag.Shorthand, v, flag.Usage)
} else {
cmd.Flags().IntP(flag.Name, flag.Shorthand, v, flag.Usage)
}
flags.IntP(flag.Name, flag.Shorthand, v, flag.Usage)
case string:
if flag.Persistent {
cmd.PersistentFlags().StringP(flag.Name, flag.Shorthand, v, flag.Usage)
} else {
cmd.Flags().StringP(flag.Name, flag.Shorthand, v, flag.Usage)
}
flags.StringP(flag.Name, flag.Shorthand, v, flag.Usage)
case []string:
if flag.Persistent {
cmd.PersistentFlags().StringSliceP(flag.Name, flag.Shorthand, v, flag.Usage)
} else {
cmd.Flags().StringSliceP(flag.Name, flag.Shorthand, v, flag.Usage)
}
flags.StringSliceP(flag.Name, flag.Shorthand, v, flag.Usage)
case bool:
if flag.Persistent {
cmd.PersistentFlags().BoolP(flag.Name, flag.Shorthand, v, flag.Usage)
} else {
cmd.Flags().BoolP(flag.Name, flag.Shorthand, v, flag.Usage)
}
flags.BoolP(flag.Name, flag.Shorthand, v, flag.Usage)
case time.Duration:
if flag.Persistent {
cmd.PersistentFlags().DurationP(flag.Name, flag.Shorthand, v, flag.Usage)
} else {
cmd.PersistentFlags().DurationP(flag.Name, flag.Shorthand, v, flag.Usage)
flags.DurationP(flag.Name, flag.Shorthand, v, flag.Usage)
}
if flag.Deprecated {
flags.MarkHidden(flag.Name) // nolint: gosec
}
}
@@ -166,6 +165,13 @@ func getDuration(flag *Flag) time.Duration {
func (f *Flags) groups() []FlagGroup {
var groups []FlagGroup
// This order affects the usage message, so they are sorted by frequency of use.
if f.ScanFlagGroup != nil {
groups = append(groups, f.ScanFlagGroup)
}
if f.ReportFlagGroup != nil {
groups = append(groups, f.ReportFlagGroup)
}
if f.CacheFlagGroup != nil {
groups = append(groups, f.CacheFlagGroup)
}
@@ -175,47 +181,70 @@ func (f *Flags) groups() []FlagGroup {
if f.ImageFlagGroup != nil {
groups = append(groups, f.ImageFlagGroup)
}
if f.K8sFlagGroup != nil {
groups = append(groups, f.K8sFlagGroup)
if f.SBOMFlagGroup != nil {
groups = append(groups, f.SBOMFlagGroup)
}
if f.VulnerabilityFlagGroup != nil {
groups = append(groups, f.VulnerabilityFlagGroup)
}
if f.MisconfFlagGroup != nil {
groups = append(groups, f.MisconfFlagGroup)
}
if f.SecretFlagGroup != nil {
groups = append(groups, f.SecretFlagGroup)
}
if f.K8sFlagGroup != nil {
groups = append(groups, f.K8sFlagGroup)
}
if f.RemoteFlagGroup != nil {
groups = append(groups, f.RemoteFlagGroup)
}
if f.ReportFlagGroup != nil {
groups = append(groups, f.ReportFlagGroup)
}
if f.SBOMFlagGroup != nil {
groups = append(groups, f.SBOMFlagGroup)
}
if f.ScanFlagGroup != nil {
groups = append(groups, f.ScanFlagGroup)
}
return groups
}
func (f *Flags) AddFlags(cmd *cobra.Command) {
for _, group := range f.groups() {
if group == nil {
continue
for _, flag := range group.Flags() {
addFlag(cmd, flag)
}
group.AddFlags(cmd)
}
cmd.Flags().SetNormalizeFunc(flagNameNormalize)
}
func (f *Flags) Usages(cmd *cobra.Command) string {
var usages string
for _, group := range f.groups() {
flags := pflag.NewFlagSet(cmd.Name(), pflag.ContinueOnError)
lflags := cmd.LocalFlags()
for _, flag := range group.Flags() {
if flag == nil {
continue
}
flags.AddFlag(lflags.Lookup(flag.Name))
}
if !flags.HasAvailableFlags() {
continue
}
usages += fmt.Sprintf("%s Flags\n", group.Name())
usages += flags.FlagUsages() + "\n"
}
return strings.TrimSpace(usages)
}
func (f *Flags) Bind(cmd *cobra.Command) error {
for _, group := range f.groups() {
if group == nil {
continue
}
if err := group.Bind(cmd); err != nil {
for _, flag := range group.Flags() {
if err := bind(cmd, flag); err != nil {
return xerrors.Errorf("flag groups: %w", err)
}
}
}
return nil
}
@@ -277,6 +306,14 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
opts.ScanOptions = f.ScanFlagGroup.ToOptions(args)
}
if f.SecretFlagGroup != nil {
opts.SecretOptions = f.SecretFlagGroup.ToOptions()
}
if f.VulnerabilityFlagGroup != nil {
opts.VulnerabilityOptions = f.VulnerabilityFlagGroup.ToOptions()
}
return opts, nil
}

View File

@@ -4,8 +4,6 @@ import (
"net/http"
"strings"
"github.com/spf13/cobra"
"github.com/aquasecurity/trivy/pkg/log"
)
@@ -87,25 +85,14 @@ func NewServerFlags() *RemoteFlagGroup {
}
}
func (f *RemoteFlagGroup) flags() []*Flag {
func (f *RemoteFlagGroup) Name() string {
return "Client/Server"
}
func (f *RemoteFlagGroup) Flags() []*Flag {
return []*Flag{f.Token, f.TokenHeader, f.ServerAddr, f.CustomHeaders, f.Listen}
}
func (f *RemoteFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
}
func (f *RemoteFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
}
func (f *RemoteFlagGroup) ToOptions() RemoteOptions {
serverAddr := getString(f.ServerAddr)
customHeaders := splitCustomHeaders(getStringSlice(f.CustomHeaders))

View File

@@ -5,8 +5,6 @@ import (
"os"
"strings"
"github.com/samber/lo"
"github.com/spf13/cobra"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
@@ -87,14 +85,6 @@ var (
Value: strings.Join(dbTypes.SeverityNames, ","),
Usage: "severities of security issues to be displayed (comma separated)",
}
// Vulnerabilities
IgnoreUnfixedFlag = Flag{
Name: "ignore-unfixed",
ConfigName: "vulnerability.ignore-unfixed",
Value: false,
Usage: "display only fixed vulnerabilities",
}
)
// ReportFlagGroup composes common printer flag structs
@@ -105,7 +95,6 @@ type ReportFlagGroup struct {
Template *Flag
DependencyTree *Flag
ListAllPkgs *Flag
IgnoreUnfixed *Flag
IgnoreFile *Flag
IgnorePolicy *Flag
ExitCode *Flag
@@ -119,7 +108,6 @@ type ReportOptions struct {
Template string
DependencyTree bool
ListAllPkgs bool
IgnoreUnfixed bool
IgnoreFile string
ExitCode int
IgnorePolicy string
@@ -129,38 +117,26 @@ type ReportOptions struct {
func NewReportFlagGroup() *ReportFlagGroup {
return &ReportFlagGroup{
Format: lo.ToPtr(FormatFlag),
ReportFormat: lo.ToPtr(ReportFormatFlag),
Template: lo.ToPtr(TemplateFlag),
DependencyTree: lo.ToPtr(DependencyTreeFlag),
ListAllPkgs: lo.ToPtr(ListAllPkgsFlag),
IgnoreUnfixed: lo.ToPtr(IgnoreUnfixedFlag),
IgnoreFile: lo.ToPtr(IgnoreFileFlag),
IgnorePolicy: lo.ToPtr(IgnorePolicyFlag),
ExitCode: lo.ToPtr(ExitCodeFlag),
Output: lo.ToPtr(OutputFlag),
Severity: lo.ToPtr(SeverityFlag),
Format: &FormatFlag,
ReportFormat: &ReportFormatFlag,
Template: &TemplateFlag,
DependencyTree: &DependencyTreeFlag,
ListAllPkgs: &ListAllPkgsFlag,
IgnoreFile: &IgnoreFileFlag,
IgnorePolicy: &IgnorePolicyFlag,
ExitCode: &ExitCodeFlag,
Output: &OutputFlag,
Severity: &SeverityFlag,
}
}
func (f *ReportFlagGroup) flags() []*Flag {
return []*Flag{f.Format, f.ReportFormat, f.Template, f.DependencyTree, f.ListAllPkgs, f.IgnoreUnfixed, f.IgnoreFile, f.IgnorePolicy,
f.ExitCode, f.Output, f.Severity}
func (f *ReportFlagGroup) Name() string {
return "Report"
}
func (f *ReportFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
}
func (f *ReportFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
func (f *ReportFlagGroup) Flags() []*Flag {
return []*Flag{f.Format, f.ReportFormat, f.Template, f.DependencyTree, f.ListAllPkgs, f.IgnoreFile,
f.IgnorePolicy, f.ExitCode, f.Output, f.Severity}
}
func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) {
@@ -211,12 +187,11 @@ func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) {
Template: template,
DependencyTree: dependencyTree,
ListAllPkgs: listAllPkgs,
IgnoreUnfixed: getBool(f.IgnoreUnfixed),
IgnoreFile: getString(f.IgnoreFile),
ExitCode: getInt(f.ExitCode),
IgnorePolicy: getString(f.IgnorePolicy),
Output: out,
Severities: splitSeverity(getString(f.Severity)),
Severities: splitSeverity(getStringSlice(f.Severity)),
}, nil
}
@@ -232,13 +207,16 @@ func (f *ReportFlagGroup) forceListAllPkgs(format string, listAllPkgs, dependenc
return false
}
func splitSeverity(severity string) []dbTypes.Severity {
if severity == "" {
func splitSeverity(severity []string) []dbTypes.Severity {
switch {
case len(severity) == 0:
return nil
case len(severity) == 1 && strings.Contains(severity[0], ","): // get severities from flag
severity = strings.Split(severity[0], ",")
}
var severities []dbTypes.Severity
for _, s := range strings.Split(severity, ",") {
for _, s := range severity {
sev, err := dbTypes.NewSeverity(s)
if err != nil {
log.Logger.Warnf("unknown severity option: %s", err)

View File

@@ -186,7 +186,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
DependencyTree: &flag.DependencyTreeFlag,
ListAllPkgs: &flag.ListAllPkgsFlag,
IgnoreFile: &flag.IgnoreFileFlag,
IgnoreUnfixed: &flag.IgnoreUnfixedFlag,
IgnorePolicy: &flag.IgnorePolicyFlag,
ExitCode: &flag.ExitCodeFlag,
Output: &flag.OutputFlag,

View File

@@ -1,7 +1,6 @@
package flag
import (
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/log"
@@ -10,13 +9,17 @@ import (
var (
ArtifactTypeFlag = Flag{
Name: "artifact-type",
ConfigName: "sbom.artifact-type",
Value: "",
Usage: "deprecated",
Deprecated: true,
}
SBOMFormatFlag = Flag{
Name: "sbom-format",
ConfigName: "sbom.format",
Value: "",
Usage: "deprecated",
Deprecated: true,
}
)
@@ -37,20 +40,12 @@ func NewSBOMFlagGroup() *SBOMFlagGroup {
}
}
func (f *SBOMFlagGroup) AddFlags(cmd *cobra.Command) {
if f.ArtifactType != nil {
cmd.Flags().String(ArtifactTypeFlag.Name, "", "deprecated")
cmd.Flags().MarkHidden(ArtifactTypeFlag.Name) // nolint: gosec
}
if f.SBOMFormat != nil {
cmd.Flags().String(SBOMFormatFlag.Name, "", "deprecated")
cmd.Flags().MarkHidden(SBOMFormatFlag.Name) // nolint: gosec
}
func (f *SBOMFlagGroup) Name() string {
return "SBOM"
}
func (f *SBOMFlagGroup) Bind(cmd *cobra.Command) error {
// All the flags are deprecated
return nil
func (f *SBOMFlagGroup) Flags() []*Flag {
return []*Flag{f.ArtifactType, f.SBOMFormat}
}
func (f *SBOMFlagGroup) ToOptions() (SBOMOptions, error) {

View File

@@ -4,8 +4,6 @@ import (
"fmt"
"strings"
"github.com/samber/lo"
"github.com/spf13/cobra"
"golang.org/x/exp/slices"
"github.com/aquasecurity/trivy/pkg/log"
@@ -37,18 +35,6 @@ var (
Value: fmt.Sprintf("%s,%s", types.SecurityCheckVulnerability, types.SecurityCheckSecret),
Usage: "comma-separated list of what security issues to detect (vuln,config,secret)",
}
VulnTypeFlag = Flag{
Name: "vuln-type",
ConfigName: "vulnerability.type",
Value: strings.Join([]string{types.VulnTypeOS, types.VulnTypeLibrary}, ","),
Usage: "comma-separated list of vulnerability types (os,library)",
}
SecretConfigFlag = Flag{
Name: "secret-config",
ConfigName: "secret.config",
Value: "trivy-secret.yaml",
Usage: "specify a path to config file for secret scanning",
}
)
type ScanFlagGroup struct {
@@ -56,9 +42,6 @@ type ScanFlagGroup struct {
SkipFiles *Flag
OfflineScan *Flag
SecurityChecks *Flag
VulnType *Flag
SecretConfig *Flag
}
type ScanOptions struct {
@@ -67,42 +50,23 @@ type ScanOptions struct {
SkipFiles []string
OfflineScan bool
SecurityChecks []string
// Vulnerabilities
VulnType []string
// Secrets
SecretConfigPath string
}
func NewScanFlagGroup() *ScanFlagGroup {
return &ScanFlagGroup{
SkipDirs: lo.ToPtr(SkipDirsFlag),
SkipFiles: lo.ToPtr(SkipFilesFlag),
OfflineScan: lo.ToPtr(OfflineScanFlag),
SecurityChecks: lo.ToPtr(SecurityChecksFlag),
VulnType: lo.ToPtr(VulnTypeFlag),
SecretConfig: lo.ToPtr(SecretConfigFlag),
SkipDirs: &SkipDirsFlag,
SkipFiles: &SkipFilesFlag,
OfflineScan: &OfflineScanFlag,
SecurityChecks: &SecurityChecksFlag,
}
}
func (f *ScanFlagGroup) flags() []*Flag {
return []*Flag{f.SkipDirs, f.SkipFiles, f.OfflineScan, f.SecurityChecks, f.VulnType, f.SecretConfig}
func (f *ScanFlagGroup) Name() string {
return "Scan"
}
func (f *ScanFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
}
func (f *ScanFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
func (f *ScanFlagGroup) Flags() []*Flag {
return []*Flag{f.SkipDirs, f.SkipFiles, f.OfflineScan, f.SecurityChecks}
}
func (f *ScanFlagGroup) ToOptions(args []string) ScanOptions {
@@ -116,31 +80,10 @@ func (f *ScanFlagGroup) ToOptions(args []string) ScanOptions {
SkipDirs: getStringSlice(f.SkipDirs),
SkipFiles: getStringSlice(f.SkipFiles),
OfflineScan: getBool(f.OfflineScan),
VulnType: parseVulnType(getStringSlice(f.VulnType)),
SecurityChecks: parseSecurityCheck(getStringSlice(f.SecurityChecks)),
SecretConfigPath: getString(f.SecretConfig),
}
}
func parseVulnType(vulnType []string) []string {
switch {
case len(vulnType) == 0: // no types
return nil
case len(vulnType) == 1 && strings.Contains(vulnType[0], ","): // get checks from flag
vulnType = strings.Split(vulnType[0], ",")
}
var vulnTypes []string
for _, v := range vulnType {
if !slices.Contains(types.VulnTypes, v) {
log.Logger.Warnf("unknown vulnerability type: %s", v)
continue
}
vulnTypes = append(vulnTypes, v)
}
return vulnTypes
}
func parseSecurityCheck(securityCheck []string) []string {
switch {
case len(securityCheck) == 0: // no checks

View File

@@ -36,32 +36,6 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
Target: "alpine:latest",
},
},
{
name: "happy path for OS vulnerabilities",
args: []string{"alpine:latest"},
fields: fields{
vulnType: "os",
securityChecks: "vuln",
},
want: flag.ScanOptions{
Target: "alpine:latest",
VulnType: []string{types.VulnTypeOS},
SecurityChecks: []string{types.SecurityCheckVulnerability},
},
},
{
name: "happy path for library vulnerabilities",
args: []string{"alpine:latest"},
fields: fields{
vulnType: "library",
securityChecks: "vuln",
},
want: flag.ScanOptions{
Target: "alpine:latest",
VulnType: []string{types.VulnTypeLibrary},
SecurityChecks: []string{types.SecurityCheckVulnerability},
},
},
{
name: "happy path for configs",
args: []string{"alpine:latest"},
@@ -85,18 +59,6 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
`unknown security check: WRONG-CHECK`,
},
},
{
name: "with wrong vuln type",
fields: fields{
vulnType: "os,nonevuln",
},
want: flag.ScanOptions{
VulnType: []string{types.VulnTypeOS},
},
wantLogs: []string{
`unknown vulnerability type: nonevuln`,
},
},
{
name: "without target (args)",
args: []string{},
@@ -156,7 +118,6 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
SkipDirs: &flag.SkipDirsFlag,
SkipFiles: &flag.SkipFilesFlag,
OfflineScan: &flag.OfflineScanFlag,
VulnType: &flag.VulnTypeFlag,
SecurityChecks: &flag.SecurityChecksFlag,
}

38
pkg/flag/secret_flags.go Normal file
View File

@@ -0,0 +1,38 @@
package flag
var (
SecretConfigFlag = Flag{
Name: "secret-config",
ConfigName: "secret.config",
Value: "trivy-secret.yaml",
Usage: "specify a path to config file for secret scanning",
}
)
type SecretFlagGroup struct {
SecretConfig *Flag
}
type SecretOptions struct {
SecretConfigPath string
}
func NewSecretFlagGroup() *SecretFlagGroup {
return &SecretFlagGroup{
SecretConfig: &SecretConfigFlag,
}
}
func (f *SecretFlagGroup) Name() string {
return "Secret"
}
func (f *SecretFlagGroup) Flags() []*Flag {
return []*Flag{f.SecretConfig}
}
func (f *SecretFlagGroup) ToOptions() SecretOptions {
return SecretOptions{
SecretConfigPath: getString(f.SecretConfig),
}
}

View File

@@ -0,0 +1,76 @@
package flag
import (
"strings"
"golang.org/x/exp/slices"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/types"
)
var (
VulnTypeFlag = Flag{
Name: "vuln-type",
ConfigName: "vulnerability.type",
Value: strings.Join([]string{types.VulnTypeOS, types.VulnTypeLibrary}, ","),
Usage: "comma-separated list of vulnerability types (os,library)",
}
IgnoreUnfixedFlag = Flag{
Name: "ignore-unfixed",
ConfigName: "vulnerability.ignore-unfixed",
Value: false,
Usage: "display only fixed vulnerabilities",
}
)
type VulnerabilityFlagGroup struct {
VulnType *Flag
IgnoreUnfixed *Flag
}
type VulnerabilityOptions struct {
VulnType []string
IgnoreUnfixed bool
}
func NewVulnerabilityFlagGroup() *VulnerabilityFlagGroup {
return &VulnerabilityFlagGroup{
VulnType: &VulnTypeFlag,
IgnoreUnfixed: &IgnoreUnfixedFlag,
}
}
func (f *VulnerabilityFlagGroup) Name() string {
return "Vulnerability"
}
func (f *VulnerabilityFlagGroup) Flags() []*Flag {
return []*Flag{f.VulnType, f.IgnoreUnfixed}
}
func (f *VulnerabilityFlagGroup) ToOptions() VulnerabilityOptions {
return VulnerabilityOptions{
VulnType: parseVulnType(getStringSlice(f.VulnType)),
IgnoreUnfixed: getBool(f.IgnoreUnfixed),
}
}
func parseVulnType(vulnType []string) []string {
switch {
case len(vulnType) == 0: // no types
return nil
case len(vulnType) == 1 && strings.Contains(vulnType[0], ","): // get checks from flag
vulnType = strings.Split(vulnType[0], ",")
}
var vulnTypes []string
for _, v := range vulnType {
if !slices.Contains(types.VulnTypes, v) {
log.Logger.Warnf("unknown vulnerability type: %s", v)
continue
}
vulnTypes = append(vulnTypes, v)
}
return vulnTypes
}

View File

@@ -0,0 +1,87 @@
package flag_test
import (
"testing"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/types"
)
func TestVulnerabilityFlagGroup_ToOptions(t *testing.T) {
type fields struct {
vulnType string
}
tests := []struct {
name string
args []string
fields fields
want flag.VulnerabilityOptions
wantLogs []string
}{
{
name: "happy path for OS vulnerabilities",
args: []string{"alpine:latest"},
fields: fields{
vulnType: "os",
},
want: flag.VulnerabilityOptions{
VulnType: []string{types.VulnTypeOS},
},
},
{
name: "happy path for library vulnerabilities",
args: []string{"alpine:latest"},
fields: fields{
vulnType: "library",
},
want: flag.VulnerabilityOptions{
VulnType: []string{types.VulnTypeLibrary},
},
},
{
name: "wrong vuln type",
fields: fields{
vulnType: "os,nonevuln",
},
want: flag.VulnerabilityOptions{
VulnType: []string{types.VulnTypeOS},
},
wantLogs: []string{
`unknown vulnerability type: nonevuln`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
level := zap.WarnLevel
core, obs := observer.New(level)
log.Logger = zap.New(core).Sugar()
viper.Set(flag.VulnTypeFlag.ConfigName, tt.fields.vulnType)
// Assert options
f := &flag.VulnerabilityFlagGroup{
VulnType: &flag.VulnTypeFlag,
}
got := f.ToOptions()
assert.Equalf(t, tt.want, got, "ToOptions()")
// Assert log messages
var gotMessages []string
for _, entry := range obs.AllUntimed() {
gotMessages = append(gotMessages, entry.Message)
}
assert.Equal(t, tt.wantLogs, gotMessages, tt.name)
})
}
}