fix(report): close the file (#4842)

* fix(report): close the file

* refactor: add the format type

* fix: return errors in version printing

* fix: lint issues

* fix: do not fail on bogus cache dir

---------

Co-authored-by: DmitriyLewen <dmitriy.lewen@smartforce.io>
This commit is contained in:
Teppei Fukuda
2023-07-23 16:37:18 +03:00
committed by GitHub
parent 24a3e547d9
commit 20c2246a61
38 changed files with 352 additions and 362 deletions

View File

@@ -5,10 +5,9 @@ import (
"errors" "errors"
"strings" "strings"
"golang.org/x/exp/slices"
"github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go-v2/service/sts"
"golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/aquasecurity/defsec/pkg/errs" "github.com/aquasecurity/defsec/pkg/errs"
@@ -17,10 +16,8 @@ import (
"github.com/aquasecurity/trivy/pkg/cloud/aws/scanner" "github.com/aquasecurity/trivy/pkg/cloud/aws/scanner"
"github.com/aquasecurity/trivy/pkg/cloud/report" "github.com/aquasecurity/trivy/pkg/cloud/report"
"github.com/aquasecurity/trivy/pkg/commands/operation" "github.com/aquasecurity/trivy/pkg/commands/operation"
cr "github.com/aquasecurity/trivy/pkg/compliance/report"
"github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/types"
) )
var allSupportedServicesFunc = awsScanner.AllSupportedServices var allSupportedServicesFunc = awsScanner.AllSupportedServices
@@ -166,24 +163,6 @@ func Run(ctx context.Context, opt flag.Options) error {
} }
log.Logger.Debug("Writing report to output...") log.Logger.Debug("Writing report to output...")
if opt.Compliance.Spec.ID != "" {
convertedResults := report.ConvertResults(results, cloud.ProviderAWS, opt.Services)
var crr []types.Results
for _, r := range convertedResults {
crr = append(crr, r.Results)
}
complianceReport, err := cr.BuildComplianceReport(crr, opt.Compliance)
if err != nil {
return xerrors.Errorf("compliance report build error: %w", err)
}
return cr.Write(complianceReport, cr.Option{
Format: opt.Format,
Report: opt.ReportFormat,
Output: opt.Output,
})
}
res := results.GetFailed() res := results.GetFailed()
if opt.MisconfOptions.IncludeNonFailures { if opt.MisconfOptions.IncludeNonFailures {

View File

@@ -1,7 +1,6 @@
package commands package commands
import ( import (
"bytes"
"context" "context"
"os" "os"
"path/filepath" "path/filepath"
@@ -1243,8 +1242,8 @@ Summary Report for compliance: my-custom-spec
}() }()
} }
buffer := new(bytes.Buffer) output := filepath.Join(t.TempDir(), "output")
test.options.Output = buffer test.options.Output = output
test.options.Debug = true test.options.Debug = true
test.options.GlobalOptions.Timeout = time.Minute test.options.GlobalOptions.Timeout = time.Minute
if test.options.Format == "" { if test.options.Format == "" {
@@ -1283,10 +1282,13 @@ Summary Report for compliance: my-custom-spec
err := Run(context.Background(), test.options) err := Run(context.Background(), test.options)
if test.expectErr { if test.expectErr {
assert.Error(t, err) assert.Error(t, err)
} else { return
assert.NoError(t, err)
assert.Equal(t, test.want, buffer.String())
} }
assert.NoError(t, err)
b, err := os.ReadFile(output)
require.NoError(t, err)
assert.Equal(t, test.want, string(b))
}) })
} }
} }

View File

@@ -2,12 +2,16 @@ package report
import ( import (
"context" "context"
"io"
"os" "os"
"sort" "sort"
"time" "time"
"golang.org/x/xerrors"
"github.com/aquasecurity/defsec/pkg/scan" "github.com/aquasecurity/defsec/pkg/scan"
"github.com/aquasecurity/tml" "github.com/aquasecurity/tml"
cr "github.com/aquasecurity/trivy/pkg/compliance/report"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/flag"
pkgReport "github.com/aquasecurity/trivy/pkg/report" pkgReport "github.com/aquasecurity/trivy/pkg/report"
@@ -55,6 +59,15 @@ func (r *Report) Failed() bool {
// Write writes the results in the give format // Write writes the results in the give format
func Write(rep *Report, opt flag.Options, fromCache bool) error { func Write(rep *Report, opt flag.Options, fromCache bool) error {
output, err := opt.OutputWriter()
if err != nil {
return xerrors.Errorf("failed to create output file: %w", err)
}
defer output.Close()
if opt.Compliance.Spec.ID != "" {
return writeCompliance(rep, opt, output)
}
var filtered []types.Result var filtered []types.Result
@@ -91,7 +104,7 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error {
// ensure color/formatting is disabled for pipes/non-pty // ensure color/formatting is disabled for pipes/non-pty
var useANSI bool var useANSI bool
if opt.Output == os.Stdout { if opt.Output == "" {
if o, err := os.Stdout.Stat(); err == nil { if o, err := os.Stdout.Stat(); err == nil {
useANSI = (o.Mode() & os.ModeCharDevice) == os.ModeCharDevice useANSI = (o.Mode() & os.ModeCharDevice) == os.ModeCharDevice
} }
@@ -102,33 +115,44 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error {
switch { switch {
case len(opt.Services) == 1 && opt.ARN == "": case len(opt.Services) == 1 && opt.ARN == "":
if err := writeResourceTable(rep, filtered, opt.Output, opt.Services[0]); err != nil { if err := writeResourceTable(rep, filtered, output, opt.Services[0]); err != nil {
return err return err
} }
case len(opt.Services) == 1 && opt.ARN != "": case len(opt.Services) == 1 && opt.ARN != "":
if err := writeResultsForARN(rep, filtered, opt.Output, opt.Services[0], opt.ARN, opt.Severities); err != nil { if err := writeResultsForARN(rep, filtered, output, opt.Services[0], opt.ARN, opt.Severities); err != nil {
return err return err
} }
default: default:
if err := writeServiceTable(rep, filtered, opt.Output); err != nil { if err := writeServiceTable(rep, filtered, output); err != nil {
return err return err
} }
} }
// render cache info // render cache info
if fromCache { if fromCache {
_ = tml.Fprintf(opt.Output, "\n<blue>This scan report was loaded from cached results. If you'd like to run a fresh scan, use --update-cache.</blue>\n") _ = tml.Fprintf(output, "\n<blue>This scan report was loaded from cached results. If you'd like to run a fresh scan, use --update-cache.</blue>\n")
} }
return nil return nil
default: default:
return pkgReport.Write(base, pkgReport.Option{ return pkgReport.Write(base, opt)
Format: opt.Format,
Output: opt.Output,
Severities: opt.Severities,
OutputTemplate: opt.Template,
IncludeNonFailures: opt.IncludeNonFailures,
Trace: opt.Trace,
})
} }
} }
func writeCompliance(rep *Report, opt flag.Options, output io.Writer) error {
var crr []types.Results
for _, r := range rep.Results {
crr = append(crr, r.Results)
}
complianceReport, err := cr.BuildComplianceReport(crr, opt.Compliance)
if err != nil {
return xerrors.Errorf("compliance report build error: %w", err)
}
return cr.Write(complianceReport, cr.Option{
Format: opt.Format,
Report: opt.ReportFormat,
Output: output,
})
}

View File

@@ -1,7 +1,8 @@
package report package report
import ( import (
"bytes" "os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -109,15 +110,18 @@ No problems detected.
tt.options.AWSOptions.Services, tt.options.AWSOptions.Services,
) )
buffer := bytes.NewBuffer([]byte{}) output := filepath.Join(t.TempDir(), "output")
tt.options.Output = buffer tt.options.Output = output
require.NoError(t, Write(report, tt.options, tt.fromCache)) require.NoError(t, Write(report, tt.options, tt.fromCache))
assert.Equal(t, "AWS", report.Provider) assert.Equal(t, "AWS", report.Provider)
assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID)
assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.Equal(t, tt.options.AWSOptions.Region, report.Region)
assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope)
assert.Equal(t, tt.expected, buffer.String())
b, err := os.ReadFile(output)
require.NoError(t, err)
assert.Equal(t, tt.expected, string(b))
}) })
} }
} }

View File

@@ -1,7 +1,8 @@
package report package report
import ( import (
"bytes" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
@@ -68,15 +69,18 @@ See https://avd.aquasec.com/misconfig/avd-aws-9999
tt.options.AWSOptions.Services, tt.options.AWSOptions.Services,
) )
buffer := bytes.NewBuffer([]byte{}) output := filepath.Join(t.TempDir(), "output")
tt.options.Output = buffer tt.options.Output = output
require.NoError(t, Write(report, tt.options, tt.fromCache)) require.NoError(t, Write(report, tt.options, tt.fromCache))
b, err := os.ReadFile(output)
require.NoError(t, err)
assert.Equal(t, "AWS", report.Provider) assert.Equal(t, "AWS", report.Provider)
assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID)
assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.Equal(t, tt.options.AWSOptions.Region, report.Region)
assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope)
assert.Equal(t, tt.expected, strings.ReplaceAll(buffer.String(), "\r\n", "\n")) assert.Equal(t, tt.expected, strings.ReplaceAll(string(b), "\r\n", "\n"))
}) })
} }
} }

View File

@@ -1,7 +1,8 @@
package report package report
import ( import (
"bytes" "os"
"path/filepath"
"testing" "testing"
"github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy-db/pkg/types"
@@ -320,19 +321,22 @@ Scan Overview for AWS Account
tt.options.AWSOptions.Services, tt.options.AWSOptions.Services,
) )
buffer := bytes.NewBuffer([]byte{}) output := filepath.Join(t.TempDir(), "output")
tt.options.Output = buffer tt.options.Output = output
require.NoError(t, Write(report, tt.options, tt.fromCache)) require.NoError(t, Write(report, tt.options, tt.fromCache))
assert.Equal(t, "AWS", report.Provider) assert.Equal(t, "AWS", report.Provider)
assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID)
assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.Equal(t, tt.options.AWSOptions.Region, report.Region)
assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope)
b, err := os.ReadFile(output)
require.NoError(t, err)
if tt.options.Format == "json" { if tt.options.Format == "json" {
// json output can be formatted/ordered differently - we just care that the data matches // json output can be formatted/ordered differently - we just care that the data matches
assert.JSONEq(t, tt.expected, buffer.String()) assert.JSONEq(t, tt.expected, string(b))
} else { } else {
assert.Equal(t, tt.expected, buffer.String()) assert.Equal(t, tt.expected, string(b))
} }
}) })
} }

View File

@@ -28,8 +28,8 @@ import (
"github.com/aquasecurity/trivy/pkg/module" "github.com/aquasecurity/trivy/pkg/module"
"github.com/aquasecurity/trivy/pkg/plugin" "github.com/aquasecurity/trivy/pkg/plugin"
"github.com/aquasecurity/trivy/pkg/policy" "github.com/aquasecurity/trivy/pkg/policy"
r "github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
) )
// VersionInfo holds the trivy DB version Info // VersionInfo holds the trivy DB version Info
@@ -71,15 +71,6 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e
groupPlugin = "plugin" groupPlugin = "plugin"
) )
var (
outputWriter io.Writer = os.Stdout
)
// SetOut overrides the destination for messages
func SetOut(out io.Writer) {
outputWriter = out
}
// NewApp is the factory method to return Trivy CLI // NewApp is the factory method to return Trivy CLI
func NewApp(version string) *cobra.Command { func NewApp(version string) *cobra.Command {
globalFlags := flag.NewGlobalFlagGroup() globalFlags := flag.NewGlobalFlagGroup()
@@ -189,8 +180,6 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co
$ trivy server`, $ trivy server`,
Args: cobra.NoArgs, Args: cobra.NoArgs,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error { PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
cmd.SetOut(outputWriter)
// Set the Trivy version here so that we can override version printer. // Set the Trivy version here so that we can override version printer.
cmd.Version = version cmd.Version = version
@@ -224,11 +213,10 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co
globalOptions := globalFlags.ToOptions() globalOptions := globalFlags.ToOptions()
if globalOptions.ShowVersion { if globalOptions.ShowVersion {
// Customize version output // Customize version output
showVersion(globalOptions.CacheDir, versionFormat, version, outputWriter) return showVersion(globalOptions.CacheDir, versionFormat, version, cmd.OutOrStdout())
} else { } else {
return cmd.Help() return cmd.Help()
} }
return nil
}, },
} }
@@ -310,7 +298,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
return validateArgs(cmd, args) return validateArgs(cmd, args)
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
options, err := imageFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := imageFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -369,7 +357,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := fsFlags.Bind(cmd); err != nil { if err := fsFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := fsFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := fsFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -428,7 +416,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := rootfsFlags.Bind(cmd); err != nil { if err := rootfsFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := rootfsFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := rootfsFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -481,7 +469,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := repoFlags.Bind(cmd); err != nil { if err := repoFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := repoFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := repoFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -521,7 +509,7 @@ func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := convertFlags.Bind(cmd); err != nil { if err := convertFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
opts, err := convertFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) opts, err := convertFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -578,7 +566,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := clientFlags.Bind(cmd); err != nil { if err := clientFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := clientFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := clientFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -619,7 +607,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := serverFlags.Bind(cmd); err != nil { if err := serverFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := serverFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := serverFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -681,7 +669,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := configFlags.Bind(cmd); err != nil { if err := configFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := configFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := configFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -839,7 +827,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
} }
repo := args[0] repo := args[0]
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -863,7 +851,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
} }
repo := args[0] repo := args[0]
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -904,11 +892,11 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol'
formatFlag := flag.FormatFlag formatFlag := flag.FormatFlag
formatFlag.Values = []string{ formatFlag.Values = xstrings.ToStringSlice([]types.Format{
r.FormatTable, types.FormatTable,
r.FormatJSON, types.FormatJSON,
r.FormatCycloneDX, types.FormatCycloneDX,
} })
reportFlagGroup.Format = &formatFlag reportFlagGroup.Format = &formatFlag
k8sFlags := &flag.Flags{ k8sFlags := &flag.Flags{
@@ -952,7 +940,7 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := k8sFlags.Bind(cmd); err != nil { if err := k8sFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
opts, err := k8sFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) opts, err := k8sFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -972,7 +960,10 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
func NewAWSCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewAWSCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup := flag.NewReportFlagGroup() reportFlagGroup := flag.NewReportFlagGroup()
compliance := flag.ComplianceFlag compliance := flag.ComplianceFlag
compliance.Values = []string{types.ComplianceAWSCIS12, types.ComplianceAWSCIS14} compliance.Values = []string{
types.ComplianceAWSCIS12,
types.ComplianceAWSCIS14,
}
reportFlagGroup.Compliance = &compliance // override usage as the accepted values differ for each subcommand. reportFlagGroup.Compliance = &compliance // override usage as the accepted values differ for each subcommand.
reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol'
@@ -1016,7 +1007,7 @@ The following services are supported:
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
opts, err := awsFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) opts, err := awsFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -1080,7 +1071,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := vmFlags.Bind(cmd); err != nil { if err := vmFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := vmFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := vmFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -1139,7 +1130,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := sbomFlags.Bind(cmd); err != nil { if err := sbomFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
options, err := sbomFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) options, err := sbomFlags.ToOptions(cmd.Version, args, globalFlags)
if err != nil { if err != nil {
return xerrors.Errorf("flag error: %w", err) return xerrors.Errorf("flag error: %w", err)
} }
@@ -1168,9 +1159,7 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
Args: cobra.NoArgs, Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
options := globalFlags.ToOptions() options := globalFlags.ToOptions()
showVersion(options.CacheDir, versionFormat, cmd.Version, outputWriter) return showVersion(options.CacheDir, versionFormat, cmd.Version, cmd.OutOrStdout())
return nil
}, },
SilenceErrors: true, SilenceErrors: true,
SilenceUsage: true, SilenceUsage: true,
@@ -1183,12 +1172,15 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
return cmd return cmd
} }
func showVersion(cacheDir, outputFormat, version string, outputWriter io.Writer) { func showVersion(cacheDir, outputFormat, version string, w io.Writer) error {
var dbMeta *metadata.Metadata var dbMeta *metadata.Metadata
var javadbMeta *metadata.Metadata var javadbMeta *metadata.Metadata
mc := metadata.NewClient(cacheDir) mc := metadata.NewClient(cacheDir)
meta, _ := mc.Get() // nolint: errcheck meta, err := mc.Get()
if err != nil {
log.Logger.Debugw("Failed to get DB metadata", "error", err)
}
if !meta.UpdatedAt.IsZero() && !meta.NextUpdate.IsZero() && meta.Version != 0 { if !meta.UpdatedAt.IsZero() && !meta.NextUpdate.IsZero() && meta.Version != 0 {
dbMeta = &metadata.Metadata{ dbMeta = &metadata.Metadata{
Version: meta.Version, Version: meta.Version,
@@ -1199,7 +1191,10 @@ func showVersion(cacheDir, outputFormat, version string, outputWriter io.Writer)
} }
mcJava := javadb.NewMetadata(filepath.Join(cacheDir, "java-db")) mcJava := javadb.NewMetadata(filepath.Join(cacheDir, "java-db"))
metaJava, _ := mcJava.Get() // nolint: errcheck metaJava, err := mcJava.Get()
if err != nil {
log.Logger.Debugw("Failed to get Java DB metadata", "error", err)
}
if !metaJava.UpdatedAt.IsZero() && !metaJava.NextUpdate.IsZero() && metaJava.Version != 0 { if !metaJava.UpdatedAt.IsZero() && !metaJava.NextUpdate.IsZero() && metaJava.Version != 0 {
javadbMeta = &metadata.Metadata{ javadbMeta = &metadata.Metadata{
Version: metaJava.Version, Version: metaJava.Version,
@@ -1212,18 +1207,23 @@ func showVersion(cacheDir, outputFormat, version string, outputWriter io.Writer)
var pbMeta *policy.Metadata var pbMeta *policy.Metadata
pc, err := policy.NewClient(cacheDir, false) pc, err := policy.NewClient(cacheDir, false)
if pc != nil && err == nil { if pc != nil && err == nil {
pbMeta, _ = pc.GetMetadata() pbMeta, err = pc.GetMetadata()
if err != nil {
log.Logger.Debugw("Failed to get policy metadata", "error", err)
}
} }
switch outputFormat { switch outputFormat {
case "json": case "json":
b, _ := json.Marshal(VersionInfo{ err = json.NewEncoder(w).Encode(VersionInfo{
Version: version, Version: version,
VulnerabilityDB: dbMeta, VulnerabilityDB: dbMeta,
JavaDB: javadbMeta, JavaDB: javadbMeta,
PolicyBundle: pbMeta, PolicyBundle: pbMeta,
}) })
fmt.Fprintln(outputWriter, string(b)) if err != nil {
return xerrors.Errorf("json encode error: %w", err)
}
default: default:
output := fmt.Sprintf("Version: %s\n", version) output := fmt.Sprintf("Version: %s\n", version)
if dbMeta != nil { if dbMeta != nil {
@@ -1250,8 +1250,9 @@ func showVersion(cacheDir, outputFormat, version string, outputWriter io.Writer)
DownloadedAt: %s DownloadedAt: %s
`, pbMeta.Digest, pbMeta.DownloadedAt.UTC()) `, pbMeta.Digest, pbMeta.DownloadedAt.UTC())
} }
fmt.Fprintf(outputWriter, output) fmt.Fprintf(w, output)
} }
return nil
} }
func validateArgs(cmd *cobra.Command, args []string) error { func validateArgs(cmd *cobra.Command, args []string) error {

View File

@@ -11,7 +11,7 @@ import (
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/types"
) )
func Test_showVersion(t *testing.T) { func Test_showVersion(t *testing.T) {
@@ -158,7 +158,7 @@ Policy Bundle:
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := new(bytes.Buffer) got := new(bytes.Buffer)
app := NewApp("test") app := NewApp("test")
SetOut(got) app.SetOut(got)
app.SetArgs(test.arguments) app.SetArgs(test.arguments)
err := app.Execute() err := app.Execute()
@@ -170,7 +170,7 @@ Policy Bundle:
func TestFlags(t *testing.T) { func TestFlags(t *testing.T) {
type want struct { type want struct {
format string format types.Format
severities []dbTypes.Severity severities []dbTypes.Severity
} }
tests := []struct { tests := []struct {
@@ -185,7 +185,7 @@ func TestFlags(t *testing.T) {
"test", "test",
}, },
want: want{ want: want{
format: report.FormatTable, format: types.FormatTable,
severities: []dbTypes.Severity{ severities: []dbTypes.Severity{
dbTypes.SeverityUnknown, dbTypes.SeverityUnknown,
dbTypes.SeverityLow, dbTypes.SeverityLow,
@@ -203,7 +203,7 @@ func TestFlags(t *testing.T) {
"LOW,MEDIUM", "LOW,MEDIUM",
}, },
want: want{ want: want{
format: report.FormatTable, format: types.FormatTable,
severities: []dbTypes.Severity{ severities: []dbTypes.Severity{
dbTypes.SeverityLow, dbTypes.SeverityLow,
dbTypes.SeverityMedium, dbTypes.SeverityMedium,
@@ -220,7 +220,7 @@ func TestFlags(t *testing.T) {
"HIGH", "HIGH",
}, },
want: want{ want: want{
format: report.FormatTable, format: types.FormatTable,
severities: []dbTypes.Severity{ severities: []dbTypes.Severity{
dbTypes.SeverityLow, dbTypes.SeverityLow,
dbTypes.SeverityHigh, dbTypes.SeverityHigh,
@@ -237,7 +237,7 @@ func TestFlags(t *testing.T) {
"CRITICAL", "CRITICAL",
}, },
want: want{ want: want{
format: report.FormatJSON, format: types.FormatJSON,
severities: []dbTypes.Severity{ severities: []dbTypes.Severity{
dbTypes.SeverityCritical, dbTypes.SeverityCritical,
}, },
@@ -259,7 +259,7 @@ func TestFlags(t *testing.T) {
globalFlags := flag.NewGlobalFlagGroup() globalFlags := flag.NewGlobalFlagGroup()
rootCmd := NewRootCommand("dev", globalFlags) rootCmd := NewRootCommand("dev", globalFlags)
rootCmd.SetErr(io.Discard) rootCmd.SetErr(io.Discard)
SetOut(io.Discard) rootCmd.SetOut(io.Discard)
flags := &flag.Flags{ flags := &flag.Flags{
ReportFlagGroup: flag.NewReportFlagGroup(), ReportFlagGroup: flag.NewReportFlagGroup(),
@@ -270,7 +270,7 @@ func TestFlags(t *testing.T) {
// Bind // Bind
require.NoError(t, flags.Bind(cmd)) require.NoError(t, flags.Bind(cmd))
options, err := flags.ToOptions("dev", args, globalFlags, nil) options, err := flags.ToOptions("dev", args, globalFlags)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.want.format, options.Format) assert.Equal(t, tt.want.format, options.Format)

View File

@@ -282,7 +282,7 @@ func (r *runner) Filter(ctx context.Context, opts flag.Options, report types.Rep
} }
func (r *runner) Report(opts flag.Options, report types.Report) error { func (r *runner) Report(opts flag.Options, report types.Report) error {
if err := pkgReport.Write(report, opts.ReportOpts()); err != nil { if err := pkgReport.Write(report, opts); err != nil {
return xerrors.Errorf("unable to write results: %w", err) return xerrors.Errorf("unable to write results: %w", err)
} }
@@ -325,7 +325,7 @@ func (r *runner) initJavaDB(opts flag.Options) error {
// If vulnerability scanning and SBOM generation are disabled, it doesn't need to download the Java database. // If vulnerability scanning and SBOM generation are disabled, it doesn't need to download the Java database.
if !opts.Scanners.Enabled(types.VulnerabilityScanner) && if !opts.Scanners.Enabled(types.VulnerabilityScanner) &&
!slices.Contains(pkgReport.SupportedSBOMFormats, opts.Format) { !slices.Contains(types.SupportedSBOMFormats, opts.Format) {
return nil return nil
} }
@@ -497,7 +497,7 @@ func disabledAnalyzers(opts flag.Options) []analyzer.Type {
// But we don't create client if vulnerability analysis is disabled and SBOM format is not used // But we don't create client if vulnerability analysis is disabled and SBOM format is not used
// We need to disable jar analyzer to avoid errors // We need to disable jar analyzer to avoid errors
// TODO disable all languages that don't contain license information for this case // TODO disable all languages that don't contain license information for this case
if !opts.Scanners.Enabled(types.VulnerabilityScanner) && !slices.Contains(pkgReport.SupportedSBOMFormats, opts.Format) { if !opts.Scanners.Enabled(types.VulnerabilityScanner) && !slices.Contains(types.SupportedSBOMFormats, opts.Format) {
analyzers = append(analyzers, analyzer.TypeJar) analyzers = append(analyzers, analyzer.TypeJar)
} }
@@ -612,7 +612,7 @@ func initScannerConfig(opts flag.Options, cacheClient cache.Cache) (ScannerConfi
// SPDX needs to calculate digests for package files // SPDX needs to calculate digests for package files
var fileChecksum bool var fileChecksum bool
if opts.Format == pkgReport.FormatSPDXJSON || opts.Format == pkgReport.FormatSPDX { if opts.Format == types.FormatSPDXJSON || opts.Format == types.FormatSPDX {
fileChecksum = true fileChecksum = true
} }

View File

@@ -37,7 +37,7 @@ func Run(ctx context.Context, opts flag.Options) (err error) {
} }
log.Logger.Debug("Writing report to output...") log.Logger.Debug("Writing report to output...")
if err = report.Write(r, opts.ReportOpts()); err != nil { if err = report.Write(r, opts); err != nil {
return xerrors.Errorf("unable to write results: %w", err) return xerrors.Errorf("unable to write results: %w", err)
} }

View File

@@ -15,13 +15,10 @@ import (
const ( const (
allReport = "all" allReport = "all"
summaryReport = "summary" summaryReport = "summary"
tableFormat = "table"
jsonFormat = "json"
) )
type Option struct { type Option struct {
Format string Format types.Format
Report string Report string
Output io.Writer Output io.Writer
Severities []dbTypes.Severity Severities []dbTypes.Severity
@@ -70,10 +67,10 @@ type Writer interface {
// Write writes the results in the give format // Write writes the results in the give format
func Write(report *ComplianceReport, option Option) error { func Write(report *ComplianceReport, option Option) error {
switch option.Format { switch option.Format {
case jsonFormat: case types.FormatJSON:
jwriter := JSONWriter{Output: option.Output, Report: option.Report} jwriter := JSONWriter{Output: option.Output, Report: option.Report}
return jwriter.Write(report) return jwriter.Write(report)
case tableFormat: case types.FormatTable:
if !report.empty() { if !report.empty() {
complianceWriter := &TableWriter{ complianceWriter := &TableWriter{
Output: option.Output, Output: option.Output,

View File

@@ -10,14 +10,14 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/mapfs" "github.com/aquasecurity/trivy/pkg/mapfs"
"github.com/aquasecurity/trivy/pkg/syncx" "github.com/aquasecurity/trivy/pkg/x/sync"
) )
// CompositeFS contains multiple filesystems for post-analyzers // CompositeFS contains multiple filesystems for post-analyzers
type CompositeFS struct { type CompositeFS struct {
group AnalyzerGroup group AnalyzerGroup
dir string dir string
files *syncx.Map[Type, *mapfs.FS] files *sync.Map[Type, *mapfs.FS]
} }
func NewCompositeFS(group AnalyzerGroup) (*CompositeFS, error) { func NewCompositeFS(group AnalyzerGroup) (*CompositeFS, error) {
@@ -29,7 +29,7 @@ func NewCompositeFS(group AnalyzerGroup) (*CompositeFS, error) {
return &CompositeFS{ return &CompositeFS{
group: group, group: group,
dir: tmpDir, dir: tmpDir,
files: new(syncx.Map[Type, *mapfs.FS]), files: new(sync.Map[Type, *mapfs.FS]),
}, nil }, nil
} }

View File

@@ -2,7 +2,6 @@ package types
import ( import (
v1 "github.com/google/go-containerregistry/pkg/v1" v1 "github.com/google/go-containerregistry/pkg/v1"
"github.com/samber/lo"
) )
const ( const (
@@ -106,9 +105,3 @@ type Credential struct {
Username string Username string
Password string Password string
} }
func (runtimes ImageSources) StringSlice() []string {
return lo.Map(runtimes, func(r ImageSource, _ int) string {
return string(r)
})
}

View File

@@ -6,6 +6,7 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
) )
// e.g. config yaml // e.g. config yaml
@@ -18,10 +19,10 @@ var (
Name: "image-config-scanners", Name: "image-config-scanners",
ConfigName: "image.image-config-scanners", ConfigName: "image.image-config-scanners",
Default: []string{}, Default: []string{},
Values: types.Scanners{ Values: xstrings.ToStringSlice(types.Scanners{
types.MisconfigScanner, types.MisconfigScanner,
types.SecretScanner, types.SecretScanner,
}.StringSlice(), }),
Usage: "comma-separated list of what security issues to detect on container image configurations", Usage: "comma-separated list of what security issues to detect on container image configurations",
} }
ScanRemovedPkgsFlag = Flag{ ScanRemovedPkgsFlag = Flag{
@@ -51,8 +52,8 @@ var (
SourceFlag = Flag{ SourceFlag = Flag{
Name: "image-src", Name: "image-src",
ConfigName: "image.source", ConfigName: "image.source",
Default: ftypes.AllImageSources.StringSlice(), Default: xstrings.ToStringSlice(ftypes.AllImageSources),
Values: ftypes.AllImageSources.StringSlice(), Values: xstrings.ToStringSlice(ftypes.AllImageSources),
Usage: "image source(s) to use, in priority order", Usage: "image source(s) to use, in priority order",
} }
) )

View File

@@ -8,7 +8,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/samber/lo"
"github.com/spf13/cast" "github.com/spf13/cast"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
@@ -18,14 +17,12 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/analyzer"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/result" "github.com/aquasecurity/trivy/pkg/result"
"github.com/aquasecurity/trivy/pkg/types"
xio "github.com/aquasecurity/trivy/pkg/x/io"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
) )
type String interface {
~string
}
type Flag struct { type Flag struct {
// Name is for CLI flag and environment variable. // Name is for CLI flag and environment variable.
// If this field is empty, it will be available only in config file. // If this field is empty, it will be available only in config file.
@@ -120,18 +117,18 @@ type Options struct {
// Align takes consistency of options // Align takes consistency of options
func (o *Options) Align() { func (o *Options) Align() {
if o.Format == report.FormatSPDX || o.Format == report.FormatSPDXJSON { if o.Format == types.FormatSPDX || o.Format == types.FormatSPDXJSON {
log.Logger.Info(`"--format spdx" and "--format spdx-json" disable security scanning`) log.Logger.Info(`"--format spdx" and "--format spdx-json" disable security scanning`)
o.Scanners = nil o.Scanners = nil
} }
// Vulnerability scanning is disabled by default for CycloneDX. // Vulnerability scanning is disabled by default for CycloneDX.
if o.Format == report.FormatCycloneDX && !viper.IsSet(ScannersFlag.ConfigName) && len(o.K8sOptions.Components) == 0 { // remove K8sOptions.Components validation check when vuln scan is supported for k8s report with cycloneDX if o.Format == types.FormatCycloneDX && !viper.IsSet(ScannersFlag.ConfigName) && len(o.K8sOptions.Components) == 0 { // remove K8sOptions.Components validation check when vuln scan is supported for k8s report with cycloneDX
log.Logger.Info(`"--format cyclonedx" disables security scanning. Specify "--scanners vuln" explicitly if you want to include vulnerabilities in the CycloneDX report.`) log.Logger.Info(`"--format cyclonedx" disables security scanning. Specify "--scanners vuln" explicitly if you want to include vulnerabilities in the CycloneDX report.`)
o.Scanners = nil o.Scanners = nil
} }
if o.Format == report.FormatCycloneDX && len(o.K8sOptions.Components) > 0 { if o.Format == types.FormatCycloneDX && len(o.K8sOptions.Components) > 0 {
log.Logger.Info(`"k8s with --format cyclonedx" disable security scanning`) log.Logger.Info(`"k8s with --format cyclonedx" disable security scanning`)
o.Scanners = nil o.Scanners = nil
} }
@@ -161,19 +158,17 @@ func (o *Options) FilterOpts() result.FilterOption {
} }
} }
func (o *Options) ReportOpts() report.Option { // OutputWriter returns an output writer.
return report.Option{ // If the output file is not specified, it returns os.Stdout.
AppVersion: o.AppVersion, func (o *Options) OutputWriter() (io.WriteCloser, error) {
Format: o.Format, if o.Output != "" {
Output: o.Output, f, err := os.Create(o.Output)
Tree: o.DependencyTree, if err != nil {
Severities: o.Severities, return nil, xerrors.Errorf("failed to create output file: %w", err)
OutputTemplate: o.Template, }
IncludeNonFailures: o.IncludeNonFailures, return f, nil
Trace: o.Trace,
Report: o.ReportFormat,
Compliance: o.Compliance,
} }
return xio.NopCloser(os.Stdout), nil
} }
func addFlag(cmd *cobra.Command, flag *Flag) { func addFlag(cmd *cobra.Command, flag *Flag) {
@@ -268,6 +263,11 @@ func getString(flag *Flag) string {
return cast.ToString(getValue(flag)) return cast.ToString(getValue(flag))
} }
func getUnderlyingString[T xstrings.String](flag *Flag) T {
s := getString(flag)
return T(s)
}
func getStringSlice(flag *Flag) []string { func getStringSlice(flag *Flag) []string {
// viper always returns a string for ENV // viper always returns a string for ENV
// https://github.com/spf13/viper/blob/419fd86e49ef061d0d33f4d1d56d5e2a480df5bb/viper.go#L545-L553 // https://github.com/spf13/viper/blob/419fd86e49ef061d0d33f4d1d56d5e2a480df5bb/viper.go#L545-L553
@@ -283,14 +283,12 @@ func getStringSlice(flag *Flag) []string {
return v return v
} }
func getUnderlyingStringSlice[T String](flag *Flag) []T { func getUnderlyingStringSlice[T xstrings.String](flag *Flag) []T {
ss := getStringSlice(flag) ss := getStringSlice(flag)
if len(ss) == 0 { if len(ss) == 0 {
return nil return nil
} }
return lo.Map(ss, func(s string, _ int) T { return xstrings.ToTSlice[T](ss)
return T(s)
})
} }
func getInt(flag *Flag) int { func getInt(flag *Flag) int {
@@ -441,7 +439,7 @@ func (f *Flags) Bind(cmd *cobra.Command) error {
} }
// nolint: gocyclo // nolint: gocyclo
func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalFlagGroup, output io.Writer) (Options, error) { func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalFlagGroup) (Options, error) {
var err error var err error
opts := Options{ opts := Options{
AppVersion: appVersion, AppVersion: appVersion,
@@ -522,7 +520,7 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
} }
if f.ReportFlagGroup != nil { if f.ReportFlagGroup != nil {
opts.ReportOptions, err = f.ReportFlagGroup.ToOptions(output) opts.ReportOptions, err = f.ReportFlagGroup.ToOptions()
if err != nil { if err != nil {
return Options{}, xerrors.Errorf("report flag error: %w", err) return Options{}, xerrors.Errorf("report flag error: %w", err)
} }

View File

@@ -1,8 +1,6 @@
package flag package flag
import ( import (
"io"
"os"
"strings" "strings"
"github.com/samber/lo" "github.com/samber/lo"
@@ -12,9 +10,9 @@ import (
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/pkg/compliance/spec" "github.com/aquasecurity/trivy/pkg/compliance/spec"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/result" "github.com/aquasecurity/trivy/pkg/result"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
) )
// e.g. config yaml: // e.g. config yaml:
@@ -27,8 +25,8 @@ var (
Name: "format", Name: "format",
ConfigName: "format", ConfigName: "format",
Shorthand: "f", Shorthand: "f",
Default: report.FormatTable, Default: string(types.FormatTable),
Values: report.SupportedFormats, Values: xstrings.ToStringSlice(types.SupportedFormats),
Usage: "format", Usage: "format",
} }
ReportFormatFlag = Flag{ ReportFormatFlag = Flag{
@@ -122,7 +120,7 @@ type ReportFlagGroup struct {
} }
type ReportOptions struct { type ReportOptions struct {
Format string Format types.Format
ReportFormat string ReportFormat string
Template string Template string
DependencyTree bool DependencyTree bool
@@ -131,7 +129,7 @@ type ReportOptions struct {
ExitCode int ExitCode int
ExitOnEOL int ExitOnEOL int
IgnorePolicy string IgnorePolicy string
Output io.Writer Output string
Severities []dbTypes.Severity Severities []dbTypes.Severity
Compliance spec.ComplianceSpec Compliance spec.ComplianceSpec
} }
@@ -174,12 +172,11 @@ func (f *ReportFlagGroup) Flags() []*Flag {
} }
} }
func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) { func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) {
format := getString(f.Format) format := getUnderlyingString[types.Format](f.Format)
template := getString(f.Template) template := getString(f.Template)
dependencyTree := getBool(f.DependencyTree) dependencyTree := getBool(f.DependencyTree)
listAllPkgs := getBool(f.ListAllPkgs) listAllPkgs := getBool(f.ListAllPkgs)
output := getString(f.Output)
if template != "" { if template != "" {
if format == "" { if format == "" {
@@ -188,14 +185,14 @@ func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) {
log.Logger.Warnf("'--template' is ignored because '--format %s' is specified. Use '--template' option with '--format template' option.", format) log.Logger.Warnf("'--template' is ignored because '--format %s' is specified. Use '--template' option with '--format template' option.", format)
} }
} else { } else {
if format == report.FormatTemplate { if format == types.FormatTemplate {
log.Logger.Warn("'--format template' is ignored because '--template' is not specified. Specify '--template' option when you use '--format template'.") log.Logger.Warn("'--format template' is ignored because '--template' is not specified. Specify '--template' option when you use '--format template'.")
} }
} }
// "--list-all-pkgs" option is unavailable with "--format table". // "--list-all-pkgs" option is unavailable with "--format table".
// If user specifies "--list-all-pkgs" with "--format table", we should warn it. // If user specifies "--list-all-pkgs" with "--format table", we should warn it.
if listAllPkgs && format == report.FormatTable { if listAllPkgs && format == types.FormatTable {
log.Logger.Warn(`"--list-all-pkgs" cannot be used with "--format table". Try "--format json" or other formats.`) log.Logger.Warn(`"--list-all-pkgs" cannot be used with "--format table". Try "--format json" or other formats.`)
} }
@@ -204,7 +201,7 @@ func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) {
log.Logger.Infof(`"--dependency-tree" only shows the dependents of vulnerable packages. ` + log.Logger.Infof(`"--dependency-tree" only shows the dependents of vulnerable packages. ` +
`Note that it is the reverse of the usual dependency tree, which shows the packages that depend on the vulnerable package. ` + `Note that it is the reverse of the usual dependency tree, which shows the packages that depend on the vulnerable package. ` +
`It supports limited package managers. Please see the document for the detail.`) `It supports limited package managers. Please see the document for the detail.`)
if format != report.FormatTable { if format != types.FormatTable {
log.Logger.Warn(`"--dependency-tree" can be used only with "--format table".`) log.Logger.Warn(`"--dependency-tree" can be used only with "--format table".`)
} }
} }
@@ -214,13 +211,6 @@ func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) {
listAllPkgs = true listAllPkgs = true
} }
if output != "" {
var err error
if out, err = os.Create(output); err != nil {
return ReportOptions{}, xerrors.Errorf("failed to create an output file: %w", err)
}
}
cs, err := loadComplianceTypes(getString(f.Compliance)) cs, err := loadComplianceTypes(getString(f.Compliance))
if err != nil { if err != nil {
return ReportOptions{}, xerrors.Errorf("unable to load compliance spec: %w", err) return ReportOptions{}, xerrors.Errorf("unable to load compliance spec: %w", err)
@@ -236,14 +226,14 @@ func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) {
ExitCode: getInt(f.ExitCode), ExitCode: getInt(f.ExitCode),
ExitOnEOL: getInt(f.ExitOnEOL), ExitOnEOL: getInt(f.ExitOnEOL),
IgnorePolicy: getString(f.IgnorePolicy), IgnorePolicy: getString(f.IgnorePolicy),
Output: out, Output: getString(f.Output),
Severities: toSeverity(getStringSlice(f.Severity)), Severities: toSeverity(getStringSlice(f.Severity)),
Compliance: cs, Compliance: cs,
}, nil }, nil
} }
func loadComplianceTypes(compliance string) (spec.ComplianceSpec, error) { func loadComplianceTypes(compliance string) (spec.ComplianceSpec, error) {
if len(compliance) > 0 && !slices.Contains(types.Compliances, compliance) && !strings.HasPrefix(compliance, "@") { if len(compliance) > 0 && !slices.Contains(types.SupportedCompliances, compliance) && !strings.HasPrefix(compliance, "@") {
return spec.ComplianceSpec{}, xerrors.Errorf("unknown compliance : %v", compliance) return spec.ComplianceSpec{}, xerrors.Errorf("unknown compliance : %v", compliance)
} }
@@ -255,13 +245,13 @@ func loadComplianceTypes(compliance string) (spec.ComplianceSpec, error) {
return cs, nil return cs, nil
} }
func (f *ReportFlagGroup) forceListAllPkgs(format string, listAllPkgs, dependencyTree bool) bool { func (f *ReportFlagGroup) forceListAllPkgs(format types.Format, listAllPkgs, dependencyTree bool) bool {
if slices.Contains(report.SupportedSBOMFormats, format) && !listAllPkgs { if slices.Contains(types.SupportedSBOMFormats, format) && !listAllPkgs {
log.Logger.Debugf("%q automatically enables '--list-all-pkgs'.", report.SupportedSBOMFormats) log.Logger.Debugf("%q automatically enables '--list-all-pkgs'.", types.SupportedSBOMFormats)
return true return true
} }
// We need this flag to insert dependency locations into Sarif('Package' struct contains 'Locations') // We need this flag to insert dependency locations into Sarif('Package' struct contains 'Locations')
if format == report.FormatSarif && !listAllPkgs { if format == types.FormatSarif && !listAllPkgs {
log.Logger.Debugf("Sarif format automatically enables '--list-all-pkgs' to get locations") log.Logger.Debugf("Sarif format automatically enables '--list-all-pkgs' to get locations")
return true return true
} }

View File

@@ -1,26 +1,24 @@
package flag_test package flag_test
import ( import (
"os"
"testing" "testing"
defsecTypes "github.com/aquasecurity/defsec/pkg/types"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zaptest/observer" "go.uber.org/zap/zaptest/observer"
defsecTypes "github.com/aquasecurity/defsec/pkg/types"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/pkg/compliance/spec" "github.com/aquasecurity/trivy/pkg/compliance/spec"
"github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/types"
) )
func TestReportFlagGroup_ToOptions(t *testing.T) { func TestReportFlagGroup_ToOptions(t *testing.T) {
type fields struct { type fields struct {
format string format types.Format
template string template string
dependencyTree bool dependencyTree bool
listAllPkgs bool listAllPkgs bool
@@ -44,9 +42,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
{ {
name: "happy default (without flags)", name: "happy default (without flags)",
fields: fields{}, fields: fields{},
want: flag.ReportOptions{ want: flag.ReportOptions{},
Output: os.Stdout,
},
}, },
{ {
name: "happy path with an cyclonedx", name: "happy path with an cyclonedx",
@@ -56,9 +52,8 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
listAllPkgs: true, listAllPkgs: true,
}, },
want: flag.ReportOptions{ want: flag.ReportOptions{
Output: os.Stdout,
Severities: []dbTypes.Severity{dbTypes.SeverityCritical}, Severities: []dbTypes.Severity{dbTypes.SeverityCritical},
Format: report.FormatCycloneDX, Format: types.FormatCycloneDX,
ListAllPkgs: true, ListAllPkgs: true,
}, },
}, },
@@ -76,11 +71,10 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
`Severities: ["CRITICAL"]`, `Severities: ["CRITICAL"]`,
}, },
want: flag.ReportOptions{ want: flag.ReportOptions{
Output: os.Stdout,
Severities: []dbTypes.Severity{ Severities: []dbTypes.Severity{
dbTypes.SeverityCritical, dbTypes.SeverityCritical,
}, },
Format: report.FormatCycloneDX, Format: types.FormatCycloneDX,
ListAllPkgs: true, ListAllPkgs: true,
}, },
}, },
@@ -94,7 +88,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
"'--template' is ignored because '--format template' is not specified. Use '--template' option with '--format template' option.", "'--template' is ignored because '--format template' is not specified. Use '--template' option with '--format template' option.",
}, },
want: flag.ReportOptions{ want: flag.ReportOptions{
Output: os.Stdout,
Severities: []dbTypes.Severity{dbTypes.SeverityLow}, Severities: []dbTypes.Severity{dbTypes.SeverityLow},
Template: "@contrib/gitlab.tpl", Template: "@contrib/gitlab.tpl",
}, },
@@ -110,7 +103,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
"'--template' is ignored because '--format json' is specified. Use '--template' option with '--format template' option.", "'--template' is ignored because '--format json' is specified. Use '--template' option with '--format template' option.",
}, },
want: flag.ReportOptions{ want: flag.ReportOptions{
Output: os.Stdout,
Format: "json", Format: "json",
Severities: []dbTypes.Severity{dbTypes.SeverityLow}, Severities: []dbTypes.Severity{dbTypes.SeverityLow},
Template: "@contrib/gitlab.tpl", Template: "@contrib/gitlab.tpl",
@@ -126,7 +118,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
"'--format template' is ignored because '--template' is not specified. Specify '--template' option when you use '--format template'.", "'--format template' is ignored because '--template' is not specified. Specify '--template' option when you use '--format template'.",
}, },
want: flag.ReportOptions{ want: flag.ReportOptions{
Output: os.Stdout,
Format: "template", Format: "template",
Severities: []dbTypes.Severity{dbTypes.SeverityLow}, Severities: []dbTypes.Severity{dbTypes.SeverityLow},
}, },
@@ -143,7 +134,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
}, },
want: flag.ReportOptions{ want: flag.ReportOptions{
Format: "table", Format: "table",
Output: os.Stdout,
Severities: []dbTypes.Severity{dbTypes.SeverityLow}, Severities: []dbTypes.Severity{dbTypes.SeverityLow},
ListAllPkgs: true, ListAllPkgs: true,
}, },
@@ -155,7 +145,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
severities: dbTypes.SeverityLow.String(), severities: dbTypes.SeverityLow.String(),
}, },
want: flag.ReportOptions{ want: flag.ReportOptions{
Output: os.Stdout,
Compliance: spec.ComplianceSpec{ Compliance: spec.ComplianceSpec{
Spec: defsecTypes.Spec{ Spec: defsecTypes.Spec{
ID: "0001", ID: "0001",
@@ -188,7 +177,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
core, obs := observer.New(level) core, obs := observer.New(level)
log.Logger = zap.New(core).Sugar() log.Logger = zap.New(core).Sugar()
viper.Set(flag.FormatFlag.ConfigName, tt.fields.format) viper.Set(flag.FormatFlag.ConfigName, string(tt.fields.format))
viper.Set(flag.TemplateFlag.ConfigName, tt.fields.template) viper.Set(flag.TemplateFlag.ConfigName, tt.fields.template)
viper.Set(flag.DependencyTreeFlag.ConfigName, tt.fields.dependencyTree) viper.Set(flag.DependencyTreeFlag.ConfigName, tt.fields.dependencyTree)
viper.Set(flag.ListAllPkgsFlag.ConfigName, tt.fields.listAllPkgs) viper.Set(flag.ListAllPkgsFlag.ConfigName, tt.fields.listAllPkgs)
@@ -216,7 +205,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
Compliance: &flag.ComplianceFlag, Compliance: &flag.ComplianceFlag,
} }
got, err := f.ToOptions(os.Stdout) got, err := f.ToOptions()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equalf(t, tt.want, got, "ToOptions()") assert.Equalf(t, tt.want, got, "ToOptions()")

View File

@@ -2,6 +2,7 @@ package flag
import ( import (
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
) )
var ( var (
@@ -26,16 +27,16 @@ var (
ScannersFlag = Flag{ ScannersFlag = Flag{
Name: "scanners", Name: "scanners",
ConfigName: "scan.scanners", ConfigName: "scan.scanners",
Default: types.Scanners{ Default: xstrings.ToStringSlice(types.Scanners{
types.VulnerabilityScanner, types.VulnerabilityScanner,
types.SecretScanner, types.SecretScanner,
}.StringSlice(), }),
Values: types.Scanners{ Values: xstrings.ToStringSlice(types.Scanners{
types.VulnerabilityScanner, types.VulnerabilityScanner,
types.MisconfigScanner, types.MisconfigScanner,
types.SecretScanner, types.SecretScanner,
types.LicenseScanner, types.LicenseScanner,
}.StringSlice(), }),
Aliases: []Alias{ Aliases: []Alias{
{ {
Name: "security-checks", Name: "security-checks",

View File

@@ -11,7 +11,6 @@ import (
"github.com/aquasecurity/trivy-kubernetes/pkg/trivyk8s" "github.com/aquasecurity/trivy-kubernetes/pkg/trivyk8s"
"github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
) )
@@ -23,12 +22,12 @@ func clusterRun(ctx context.Context, opts flag.Options, cluster k8s.Cluster) err
var artifacts []*artifacts.Artifact var artifacts []*artifacts.Artifact
var err error var err error
switch opts.Format { switch opts.Format {
case report.FormatCycloneDX: case types.FormatCycloneDX:
artifacts, err = trivyk8s.New(cluster, log.Logger).ListBomInfo(ctx) artifacts, err = trivyk8s.New(cluster, log.Logger).ListBomInfo(ctx)
if err != nil { if err != nil {
return xerrors.Errorf("get k8s artifacts with node info error: %w", err) return xerrors.Errorf("get k8s artifacts with node info error: %w", err)
} }
case report.FormatJSON, report.FormatTable: case types.FormatJSON, types.FormatTable:
if opts.Scanners.AnyEnabled(types.MisconfigScanner) && slices.Contains(opts.Components, "infra") { if opts.Scanners.AnyEnabled(types.MisconfigScanner) && slices.Contains(opts.Components, "infra") {
artifacts, err = trivyk8s.New(cluster, log.Logger).ListArtifactAndNodeInfo(ctx, opts.NodeCollectorNamespace, opts.ExcludeNodes, opts.Tolerations...) artifacts, err = trivyk8s.New(cluster, log.Logger).ListArtifactAndNodeInfo(ctx, opts.NodeCollectorNamespace, opts.ExcludeNodes, opts.Tolerations...)
if err != nil { if err != nil {

View File

@@ -95,6 +95,12 @@ func (r *runner) run(ctx context.Context, artifacts []*artifacts.Artifact) error
return xerrors.Errorf("k8s scan error: %w", err) return xerrors.Errorf("k8s scan error: %w", err)
} }
output, err := r.flagOpts.OutputWriter()
if err != nil {
return xerrors.Errorf("failed to create output file: %w", err)
}
defer output.Close()
if r.flagOpts.Compliance.Spec.ID != "" { if r.flagOpts.Compliance.Spec.ID != "" {
var scanResults []types.Results var scanResults []types.Results
for _, rss := range rpt.Resources { for _, rss := range rpt.Resources {
@@ -107,14 +113,14 @@ func (r *runner) run(ctx context.Context, artifacts []*artifacts.Artifact) error
return cr.Write(complianceReport, cr.Option{ return cr.Write(complianceReport, cr.Option{
Format: r.flagOpts.Format, Format: r.flagOpts.Format,
Report: r.flagOpts.ReportFormat, Report: r.flagOpts.ReportFormat,
Output: r.flagOpts.Output, Output: output,
}) })
} }
if err := k8sRep.Write(rpt, report.Option{ if err := k8sRep.Write(rpt, report.Option{
Format: r.flagOpts.Format, Format: r.flagOpts.Format,
Report: r.flagOpts.ReportFormat, Report: r.flagOpts.ReportFormat,
Output: r.flagOpts.Output, Output: output,
Severities: r.flagOpts.Severities, Severities: r.flagOpts.Severities,
Components: r.flagOpts.Components, Components: r.flagOpts.Components,
Scanners: r.flagOpts.ScanOptions.Scanners, Scanners: r.flagOpts.ScanOptions.Scanners,

View File

@@ -25,7 +25,7 @@ const (
) )
type Option struct { type Option struct {
Format string Format types.Format
Report string Report string
Output io.Writer Output io.Writer
Severities []dbTypes.Severity Severities []dbTypes.Severity

View File

@@ -7,16 +7,13 @@ import (
"sort" "sort"
"strings" "strings"
"golang.org/x/xerrors" cdx "github.com/CycloneDX/cyclonedx-go"
ms "github.com/mitchellh/mapstructure" ms "github.com/mitchellh/mapstructure"
"github.com/package-url/packageurl-go" "github.com/package-url/packageurl-go"
"github.com/samber/lo" "github.com/samber/lo"
"golang.org/x/xerrors"
"github.com/aquasecurity/go-version/pkg/version" "github.com/aquasecurity/go-version/pkg/version"
cdx "github.com/CycloneDX/cyclonedx-go"
"github.com/aquasecurity/trivy-kubernetes/pkg/artifacts" "github.com/aquasecurity/trivy-kubernetes/pkg/artifacts"
"github.com/aquasecurity/trivy-kubernetes/pkg/bom" "github.com/aquasecurity/trivy-kubernetes/pkg/bom"
cmd "github.com/aquasecurity/trivy/pkg/commands/artifact" cmd "github.com/aquasecurity/trivy/pkg/commands/artifact"
@@ -27,7 +24,6 @@ import (
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/parallel" "github.com/aquasecurity/trivy/pkg/parallel"
"github.com/aquasecurity/trivy/pkg/purl" "github.com/aquasecurity/trivy/pkg/purl"
rep "github.com/aquasecurity/trivy/pkg/report"
cyc "github.com/aquasecurity/trivy/pkg/sbom/cyclonedx" cyc "github.com/aquasecurity/trivy/pkg/sbom/cyclonedx"
"github.com/aquasecurity/trivy/pkg/sbom/cyclonedx/core" "github.com/aquasecurity/trivy/pkg/sbom/cyclonedx/core"
"github.com/aquasecurity/trivy/pkg/scanner/local" "github.com/aquasecurity/trivy/pkg/scanner/local"
@@ -74,7 +70,7 @@ func (s *Scanner) Scan(ctx context.Context, artifactsData []*artifacts.Artifact)
} }
}() }()
if s.opts.Format == rep.FormatCycloneDX { if s.opts.Format == types.FormatCycloneDX {
rootComponent, err := clusterInfoToReportResources(artifactsData, s.cluster) rootComponent, err := clusterInfoToReportResources(artifactsData, s.cluster)
if err != nil { if err != nil {
return report.Report{}, err return report.Report{}, err

View File

@@ -3,12 +3,11 @@ package k8s
import ( import (
"fmt" "fmt"
"github.com/aquasecurity/trivy/pkg/k8s/report"
cdx "github.com/CycloneDX/cyclonedx-go" cdx "github.com/CycloneDX/cyclonedx-go"
rp "github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/k8s/report"
"github.com/aquasecurity/trivy/pkg/report/table" "github.com/aquasecurity/trivy/pkg/report/table"
"github.com/aquasecurity/trivy/pkg/types"
) )
type Writer interface { type Writer interface {
@@ -20,13 +19,13 @@ func Write(k8sreport report.Report, option report.Option) error {
k8sreport.PrintErrors() k8sreport.PrintErrors()
switch option.Format { switch option.Format {
case rp.FormatJSON: case types.FormatJSON:
jwriter := report.JSONWriter{ jwriter := report.JSONWriter{
Output: option.Output, Output: option.Output,
Report: option.Report, Report: option.Report,
} }
return jwriter.Write(k8sreport) return jwriter.Write(k8sreport)
case rp.FormatTable: case types.FormatTable:
separatedReports := report.SeparateMisconfigReports(k8sreport, option.Scanners, option.Components) separatedReports := report.SeparateMisconfigReports(k8sreport, option.Scanners, option.Components)
if option.Report == report.SummaryReport { if option.Report == report.SummaryReport {
@@ -48,7 +47,7 @@ func Write(k8sreport report.Report, option report.Option) error {
} }
return nil return nil
case rp.FormatCycloneDX: case types.FormatCycloneDX:
w := report.NewCycloneDXWriter(option.Output, cdx.BOMFileFormatJSON, option.APIVersion) w := report.NewCycloneDXWriter(option.Output, cdx.BOMFileFormatJSON, option.APIVersion)
return w.Write(k8sreport.RootComponent) return w.Write(k8sreport.RootComponent)
} }

View File

@@ -11,7 +11,7 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/syncx" xsync "github.com/aquasecurity/trivy/pkg/x/sync"
) )
var separator = "/" var separator = "/"
@@ -24,7 +24,7 @@ type file struct {
underlyingPath string // underlying file path underlyingPath string // underlying file path
data []byte // virtual file, only either of 'path' or 'data' has a value. data []byte // virtual file, only either of 'path' or 'data' has a value.
stat fileStat stat fileStat
files syncx.Map[string, *file] files xsync.Map[string, *file]
} }
func (f *file) isVirtual() bool { func (f *file) isVirtual() bool {
@@ -187,7 +187,7 @@ func (f *file) MkdirAll(path string, perm fs.FileMode) error {
modTime: time.Now(), modTime: time.Now(),
mode: perm, mode: perm,
}, },
files: syncx.Map[string, *file]{}, files: xsync.Map[string, *file]{},
} }
// Create the directory when the key is not present // Create the directory when the key is not present

View File

@@ -12,7 +12,7 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/syncx" xsync "github.com/aquasecurity/trivy/pkg/x/sync"
) )
type allFS interface { type allFS interface {
@@ -56,7 +56,7 @@ func New(opts ...Option) *FS {
modTime: time.Now(), modTime: time.Now(),
mode: 0o0700 | fs.ModeDir, mode: 0o0700 | fs.ModeDir,
}, },
files: syncx.Map[string, *file]{}, files: xsync.Map[string, *file]{},
}, },
} }
for _, opt := range opts { for _, opt := range opts {

View File

@@ -9,7 +9,6 @@ import (
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/report/github" "github.com/aquasecurity/trivy/pkg/report/github"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
) )
@@ -136,22 +135,19 @@ func TestWriter_Write(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
jw := github.Writer{} written := bytes.NewBuffer(nil)
written := bytes.Buffer{} w := github.Writer{
jw.Output = &written Output: written,
}
inputResults := tt.report inputResults := tt.report
err := report.Write(inputResults, report.Option{ err := w.Write(inputResults)
Format: "github",
Output: &written,
})
assert.NoError(t, err) assert.NoError(t, err)
var got github.DependencySnapshot var got github.DependencySnapshot
err = json.Unmarshal(written.Bytes(), &got) err = json.Unmarshal(written.Bytes(), &got)
assert.NoError(t, err, "invalid github written") assert.NoError(t, err, "invalid github written")
assert.Equal(t, tt.want, got.Manifests, tt.name) assert.Equal(t, tt.want, got.Manifests, tt.name)
}) })
} }

View File

@@ -66,9 +66,10 @@ func TestReportWriter_JSON(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
jw := report.JSONWriter{} jsonWritten := bytes.NewBuffer(nil)
jsonWritten := bytes.Buffer{} jw := report.JSONWriter{
jw.Output = &jsonWritten Output: jsonWritten,
}
inputResults := types.Report{ inputResults := types.Report{
SchemaVersion: 2, SchemaVersion: 2,
@@ -81,10 +82,7 @@ func TestReportWriter_JSON(t *testing.T) {
}, },
} }
err := report.Write(inputResults, report.Option{ err := jw.Write(inputResults)
Format: "json",
Output: &jsonWritten,
})
assert.NoError(t, err) assert.NoError(t, err)
var got types.Report var got types.Report

View File

@@ -456,11 +456,11 @@ func TestReportWriter_Sarif(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
sarifWritten := bytes.Buffer{} sarifWritten := bytes.NewBuffer(nil)
err := report.Write(tt.input, report.Option{ w := report.SarifWriter{
Format: "sarif", Output: sarifWritten,
Output: &sarifWritten, }
}) err := w.Write(tt.input)
assert.NoError(t, err) assert.NoError(t, err)
result := &sarif.Report{} result := &sarif.Report{}

View File

@@ -15,11 +15,11 @@ import (
type Writer struct { type Writer struct {
output io.Writer output io.Writer
version string version string
format string format types.Format
marshaler *spdx.Marshaler marshaler *spdx.Marshaler
} }
func NewWriter(output io.Writer, version string, spdxFormat string) Writer { func NewWriter(output io.Writer, version string, spdxFormat types.Format) Writer {
return Writer{ return Writer{
output: output, output: output,
version: version, version: version,

View File

@@ -8,7 +8,7 @@ import (
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/report/table"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
) )
@@ -339,8 +339,7 @@ package-lock.json
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
tableWritten := bytes.Buffer{} tableWritten := bytes.Buffer{}
err := report.Write(types.Report{Results: tc.results}, report.Option{ writer := table.Writer{
Format: report.FormatTable,
Output: &tableWritten, Output: &tableWritten,
Tree: true, Tree: true,
IncludeNonFailures: tc.includeNonFailures, IncludeNonFailures: tc.includeNonFailures,
@@ -348,7 +347,8 @@ package-lock.json
dbTypes.SeverityHigh, dbTypes.SeverityHigh,
dbTypes.SeverityMedium, dbTypes.SeverityMedium,
}, },
}) }
err := writer.Write(types.Report{Results: tc.results})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, tableWritten.String(), tc.name) assert.Equal(t, tc.expectedOutput, tableWritten.String(), tc.name)
}) })

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/pkg/clock" "github.com/aquasecurity/trivy/pkg/clock"
@@ -177,11 +178,9 @@ func TestReportWriter_Template(t *testing.T) {
}, },
} }
err := report.Write(inputReport, report.Option{ w, err := report.NewTemplateWriter(&got, tc.template)
Format: "template", require.NoError(t, err)
Output: &got, err = w.Write(inputReport)
OutputTemplate: tc.template,
})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.expected, got.String()) assert.Equal(t, tc.expected, got.String())
}) })

View File

@@ -7,9 +7,8 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
cr "github.com/aquasecurity/trivy/pkg/compliance/report" cr "github.com/aquasecurity/trivy/pkg/compliance/report"
"github.com/aquasecurity/trivy/pkg/compliance/spec" "github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/report/cyclonedx" "github.com/aquasecurity/trivy/pkg/report/cyclonedx"
"github.com/aquasecurity/trivy/pkg/report/github" "github.com/aquasecurity/trivy/pkg/report/github"
@@ -21,114 +20,67 @@ import (
const ( const (
SchemaVersion = 2 SchemaVersion = 2
FormatTable = "table"
FormatJSON = "json"
FormatTemplate = "template"
FormatSarif = "sarif"
FormatCycloneDX = "cyclonedx"
FormatSPDX = "spdx"
FormatSPDXJSON = "spdx-json"
FormatGitHub = "github"
FormatCosignVuln = "cosign-vuln"
) )
var (
SupportedFormats = []string{
FormatTable,
FormatJSON,
FormatTemplate,
FormatSarif,
FormatCycloneDX,
FormatSPDX,
FormatSPDXJSON,
FormatGitHub,
FormatCosignVuln,
}
)
var (
SupportedSBOMFormats = []string{
FormatCycloneDX,
FormatSPDX,
FormatSPDXJSON,
FormatGitHub,
}
)
type Option struct {
AppVersion string
Format string
Report string
Output io.Writer
Tree bool
Severities []dbTypes.Severity
OutputTemplate string
Compliance spec.ComplianceSpec
// For misconfigurations
IncludeNonFailures bool
Trace bool
// For licenses
LicenseRiskThreshold int
IgnoredLicenses []string
}
// Write writes the result to output, format as passed in argument // Write writes the result to output, format as passed in argument
func Write(report types.Report, option Option) error { func Write(report types.Report, option flag.Options) error {
output, err := option.OutputWriter()
if err != nil {
return xerrors.Errorf("failed to create a file: %w", err)
}
defer output.Close()
// Compliance report // Compliance report
if option.Compliance.Spec.ID != "" { if option.Compliance.Spec.ID != "" {
return complianceWrite(report, option) return complianceWrite(report, option, output)
} }
var writer Writer var writer Writer
switch option.Format { switch option.Format {
case FormatTable: case types.FormatTable:
writer = &table.Writer{ writer = &table.Writer{
Output: option.Output, Output: output,
Severities: option.Severities, Severities: option.Severities,
Tree: option.Tree, Tree: option.DependencyTree,
ShowMessageOnce: &sync.Once{}, ShowMessageOnce: &sync.Once{},
IncludeNonFailures: option.IncludeNonFailures, IncludeNonFailures: option.IncludeNonFailures,
Trace: option.Trace, Trace: option.Trace,
LicenseRiskThreshold: option.LicenseRiskThreshold, LicenseRiskThreshold: option.LicenseRiskThreshold,
IgnoredLicenses: option.IgnoredLicenses, IgnoredLicenses: option.IgnoredLicenses,
} }
case FormatJSON: case types.FormatJSON:
writer = &JSONWriter{Output: option.Output} writer = &JSONWriter{Output: output}
case FormatGitHub: case types.FormatGitHub:
writer = &github.Writer{ writer = &github.Writer{
Output: option.Output, Output: output,
Version: option.AppVersion, Version: option.AppVersion,
} }
case FormatCycloneDX: case types.FormatCycloneDX:
// TODO: support xml format option with cyclonedx writer // TODO: support xml format option with cyclonedx writer
writer = cyclonedx.NewWriter(option.Output, option.AppVersion) writer = cyclonedx.NewWriter(output, option.AppVersion)
case FormatSPDX, FormatSPDXJSON: case types.FormatSPDX, types.FormatSPDXJSON:
writer = spdx.NewWriter(option.Output, option.AppVersion, option.Format) writer = spdx.NewWriter(output, option.AppVersion, option.Format)
case FormatTemplate: case types.FormatTemplate:
// We keep `sarif.tpl` template working for backward compatibility for a while. // We keep `sarif.tpl` template working for backward compatibility for a while.
if strings.HasPrefix(option.OutputTemplate, "@") && strings.HasSuffix(option.OutputTemplate, "sarif.tpl") { if strings.HasPrefix(option.Template, "@") && strings.HasSuffix(option.Template, "sarif.tpl") {
log.Logger.Warn("Using `--template sarif.tpl` is deprecated. Please migrate to `--format sarif`. See https://github.com/aquasecurity/trivy/discussions/1571") log.Logger.Warn("Using `--template sarif.tpl` is deprecated. Please migrate to `--format sarif`. See https://github.com/aquasecurity/trivy/discussions/1571")
writer = &SarifWriter{ writer = &SarifWriter{
Output: option.Output, Output: output,
Version: option.AppVersion, Version: option.AppVersion,
} }
break break
} }
var err error var err error
if writer, err = NewTemplateWriter(option.Output, option.OutputTemplate); err != nil { if writer, err = NewTemplateWriter(output, option.Template); err != nil {
return xerrors.Errorf("failed to initialize template writer: %w", err) return xerrors.Errorf("failed to initialize template writer: %w", err)
} }
case FormatSarif: case types.FormatSarif:
writer = &SarifWriter{ writer = &SarifWriter{
Output: option.Output, Output: output,
Version: option.AppVersion, Version: option.AppVersion,
} }
case FormatCosignVuln: case types.FormatCosignVuln:
writer = predicate.NewVulnWriter(option.Output, option.AppVersion) writer = predicate.NewVulnWriter(output, option.AppVersion)
default: default:
return xerrors.Errorf("unknown format: %v", option.Format) return xerrors.Errorf("unknown format: %v", option.Format)
} }
@@ -139,15 +91,15 @@ func Write(report types.Report, option Option) error {
return nil return nil
} }
func complianceWrite(report types.Report, opt Option) error { func complianceWrite(report types.Report, opt flag.Options, output io.Writer) error {
complianceReport, err := cr.BuildComplianceReport([]types.Results{report.Results}, opt.Compliance) complianceReport, err := cr.BuildComplianceReport([]types.Results{report.Results}, opt.Compliance)
if err != nil { if err != nil {
return xerrors.Errorf("compliance report build error: %w", err) return xerrors.Errorf("compliance report build error: %w", err)
} }
return cr.Write(complianceReport, cr.Option{ return cr.Write(complianceReport, cr.Option{
Format: opt.Format, Format: opt.Format,
Report: opt.Report, Report: opt.ReportFormat,
Output: opt.Output, Output: output,
Severities: opt.Severities, Severities: opt.Severities,
}) })
} }

View File

@@ -10,6 +10,7 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
r "github.com/aquasecurity/trivy/pkg/rpc" r "github.com/aquasecurity/trivy/pkg/rpc"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
rpc "github.com/aquasecurity/trivy/rpc/scanner" rpc "github.com/aquasecurity/trivy/rpc/scanner"
) )
@@ -82,7 +83,7 @@ func (s Scanner) Scan(ctx context.Context, target, artifactKey string, blobKeys
BlobIds: blobKeys, BlobIds: blobKeys,
Options: &rpc.ScanOptions{ Options: &rpc.ScanOptions{
VulnType: opts.VulnType, VulnType: opts.VulnType,
Scanners: opts.Scanners.StringSlice(), Scanners: xstrings.ToStringSlice(opts.Scanners),
ListAllPackages: opts.ListAllPackages, ListAllPackages: opts.ListAllPackages,
LicenseCategories: licenseCategories, LicenseCategories: licenseCategories,
IncludeDevDeps: opts.IncludeDevDeps, IncludeDevDeps: opts.IncludeDevDeps,

View File

@@ -8,16 +8,6 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
) )
var Compliances = []string{
ComplianceK8sNsa,
ComplianceK8sCIS,
ComplianceK8sPSSBaseline,
ComplianceK8sPSSRestricted,
ComplianceAWSCIS12,
ComplianceAWSCIS14,
ComplianceDockerCIS,
}
// Report represents a scan result // Report represents a scan result
type Report struct { type Report struct {
SchemaVersion int `json:",omitempty"` SchemaVersion int `json:",omitempty"`
@@ -48,6 +38,7 @@ type Results []Result
type ResultClass string type ResultClass string
type Compliance = string type Compliance = string
type Format string
const ( const (
ClassOSPkg = "os-pkgs" // For detected packages and vulnerabilities in OS packages ClassOSPkg = "os-pkgs" // For detected packages and vulnerabilities in OS packages
@@ -65,6 +56,45 @@ const (
ComplianceAWSCIS12 = Compliance("aws-cis-1.2") ComplianceAWSCIS12 = Compliance("aws-cis-1.2")
ComplianceAWSCIS14 = Compliance("aws-cis-1.4") ComplianceAWSCIS14 = Compliance("aws-cis-1.4")
ComplianceDockerCIS = Compliance("docker-cis") ComplianceDockerCIS = Compliance("docker-cis")
FormatTable Format = "table"
FormatJSON Format = "json"
FormatTemplate Format = "template"
FormatSarif Format = "sarif"
FormatCycloneDX Format = "cyclonedx"
FormatSPDX Format = "spdx"
FormatSPDXJSON Format = "spdx-json"
FormatGitHub Format = "github"
FormatCosignVuln Format = "cosign-vuln"
)
var (
SupportedFormats = []Format{
FormatTable,
FormatJSON,
FormatTemplate,
FormatSarif,
FormatCycloneDX,
FormatSPDX,
FormatSPDXJSON,
FormatGitHub,
FormatCosignVuln,
}
SupportedSBOMFormats = []Format{
FormatCycloneDX,
FormatSPDX,
FormatSPDXJSON,
FormatGitHub,
}
SupportedCompliances = []string{
ComplianceK8sNsa,
ComplianceK8sCIS,
ComplianceK8sPSSBaseline,
ComplianceK8sPSSRestricted,
ComplianceAWSCIS12,
ComplianceAWSCIS14,
ComplianceDockerCIS,
}
) )
// Result holds a target and detected vulnerabilities // Result holds a target and detected vulnerabilities

View File

@@ -1,7 +1,6 @@
package types package types
import ( import (
"github.com/samber/lo"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@@ -84,9 +83,3 @@ func (scanners Scanners) AnyEnabled(ss ...Scanner) bool {
} }
return false return false
} }
func (scanners Scanners) StringSlice() []string {
return lo.Map(scanners, func(s Scanner, _ int) string {
return string(s)
})
}

15
pkg/x/io/io.go Normal file
View File

@@ -0,0 +1,15 @@
package io
import "io"
// NopCloser returns a WriteCloser with a no-op Close method wrapping
// the provided Writer w.
func NopCloser(w io.Writer) io.WriteCloser {
return nopCloser{w}
}
type nopCloser struct {
io.Writer
}
func (nopCloser) Close() error { return nil }

19
pkg/x/strings/strings.go Normal file
View File

@@ -0,0 +1,19 @@
package strings
import "github.com/samber/lo"
type String interface {
~string
}
func ToStringSlice[T String](ss []T) []string {
return lo.Map(ss, func(s T, _ int) string {
return string(s)
})
}
func ToTSlice[T String](ss []string) []T {
return lo.Map(ss, func(s string, _ int) T {
return T(s)
})
}

View File

@@ -1,4 +1,4 @@
package syncx package sync
import "sync" import "sync"