From a0dc3b688e87f66090fd9fd76454b40a20ff8957 Mon Sep 17 00:00:00 2001 From: Teppei Fukuda Date: Tue, 8 Apr 2025 15:49:16 +0400 Subject: [PATCH] refactor: add hook interface for extended functionality (#8585) --- integration/module_test.go | 6 +- internal/hooktest/hook.go | 96 +++++++++++ pkg/commands/artifact/run.go | 39 ++++- pkg/extension/hook.go | 162 +++++++++++++++++++ pkg/extension/hook_test.go | 278 ++++++++++++++++++++++++++++++++ pkg/flag/module_flags.go | 6 +- pkg/module/module.go | 24 ++- pkg/module/module_test.go | 38 ++--- pkg/report/writer.go | 11 ++ pkg/report/writer_test.go | 96 +++++++++++ pkg/scan/local/service.go | 14 +- pkg/scan/local/service_test.go | 75 +++++++++ pkg/scan/post/post_scan.go | 45 ------ pkg/scan/post/post_scan_test.go | 103 ------------ 14 files changed, 795 insertions(+), 198 deletions(-) create mode 100644 internal/hooktest/hook.go create mode 100644 pkg/extension/hook.go create mode 100644 pkg/extension/hook_test.go delete mode 100644 pkg/scan/post/post_scan.go delete mode 100644 pkg/scan/post/post_scan_test.go diff --git a/integration/module_test.go b/integration/module_test.go index 4dbf350779..45f9545e26 100644 --- a/integration/module_test.go +++ b/integration/module_test.go @@ -3,12 +3,12 @@ package integration import ( - "github.com/aquasecurity/trivy/pkg/types" "path/filepath" "testing" + "github.com/aquasecurity/trivy/pkg/extension" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" - "github.com/aquasecurity/trivy/pkg/scan/post" + "github.com/aquasecurity/trivy/pkg/types" ) func TestModule(t *testing.T) { @@ -52,7 +52,7 @@ func TestModule(t *testing.T) { t.Cleanup(func() { analyzer.DeregisterAnalyzer("spring4shell") - post.DeregisterPostScanner("spring4shell") + extension.DeregisterHook("spring4shell") }) // Run Trivy diff --git a/internal/hooktest/hook.go b/internal/hooktest/hook.go new file mode 100644 index 0000000000..1a40c1b1d1 --- /dev/null +++ b/internal/hooktest/hook.go @@ -0,0 +1,96 @@ +package hooktest + +import ( + "context" + "errors" + "testing" + + "github.com/aquasecurity/trivy/pkg/extension" + "github.com/aquasecurity/trivy/pkg/flag" + "github.com/aquasecurity/trivy/pkg/types" +) + +type testHook struct{} + +func (*testHook) Name() string { + return "test" +} + +func (*testHook) Version() int { + return 1 +} + +// RunHook implementation +func (*testHook) PreRun(ctx context.Context, opts flag.Options) error { + if opts.GlobalOptions.ConfigFile == "bad-config" { + return errors.New("bad pre-run") + } + return nil +} + +func (*testHook) PostRun(ctx context.Context, opts flag.Options) error { + if opts.GlobalOptions.ConfigFile == "bad-config" { + return errors.New("bad post-run") + } + return nil +} + +// ScanHook implementation +func (*testHook) PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error { + if target.Name == "bad-pre" { + return errors.New("bad pre-scan") + } + target.Name += " (pre-scan)" + return nil +} + +func (*testHook) PostScan(ctx context.Context, results types.Results) (types.Results, error) { + for i, r := range results { + if r.Target == "bad" { + return nil, errors.New("bad") + } + for j := range r.Vulnerabilities { + results[i].Vulnerabilities[j].References = []string{ + "https://example.com/post-scan", + } + } + } + return results, nil +} + +// ReportHook implementation +func (*testHook) PreReport(ctx context.Context, report *types.Report, opts flag.Options) error { + if report.ArtifactName == "bad-report" { + return errors.New("bad pre-report") + } + + // Modify the report + for i := range report.Results { + for j := range report.Results[i].Vulnerabilities { + report.Results[i].Vulnerabilities[j].Title = "Modified by pre-report hook" + } + } + return nil +} + +func (*testHook) PostReport(ctx context.Context, report *types.Report, opts flag.Options) error { + if report.ArtifactName == "bad-report" { + return errors.New("bad post-report") + } + + // Modify the report + for i := range report.Results { + for j := range report.Results[i].Vulnerabilities { + report.Results[i].Vulnerabilities[j].Description = "Modified by post-report hook" + } + } + return nil +} + +func Init(t *testing.T) { + h := &testHook{} + extension.RegisterHook(h) + t.Cleanup(func() { + extension.DeregisterHook(h.Name()) + }) +} diff --git a/pkg/commands/artifact/run.go b/pkg/commands/artifact/run.go index 68257ba533..c688ed5781 100644 --- a/pkg/commands/artifact/run.go +++ b/pkg/commands/artifact/run.go @@ -15,6 +15,7 @@ import ( "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/commands/operation" "github.com/aquasecurity/trivy/pkg/db" + "github.com/aquasecurity/trivy/pkg/extension" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/artifact" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" @@ -277,7 +278,6 @@ func (r *runner) Report(ctx context.Context, opts flag.Options, report types.Rep if err := pkgReport.Write(ctx, report, opts); err != nil { return xerrors.Errorf("unable to write results: %w", err) } - return nil } @@ -375,12 +375,32 @@ func Run(ctx context.Context, opts flag.Options, targetKind TargetKind) (err err return v.SafeWriteConfigAs("trivy-default.yaml") } + // Call pre-run hooks + if err := extension.PreRun(ctx, opts); err != nil { + return xerrors.Errorf("pre run error: %w", err) + } + + // Run the application + report, err := run(ctx, opts, targetKind) + if err != nil { + return xerrors.Errorf("run error: %w", err) + } + + // Call post-run hooks + if err := extension.PostRun(ctx, opts); err != nil { + return xerrors.Errorf("post run error: %w", err) + } + + return operation.Exit(opts, report.Results.Failed(), report.Metadata) +} + +func run(ctx context.Context, opts flag.Options, targetKind TargetKind) (types.Report, error) { r, err := NewRunner(ctx, opts) if err != nil { if errors.Is(err, SkipScan) { - return nil + return types.Report{}, nil } - return xerrors.Errorf("init error: %w", err) + return types.Report{}, xerrors.Errorf("init error: %w", err) } defer r.Close(ctx) @@ -395,24 +415,27 @@ func Run(ctx context.Context, opts flag.Options, targetKind TargetKind) (err err scanFunction, exists := scans[targetKind] if !exists { - return xerrors.Errorf("unknown target kind: %s", targetKind) + return types.Report{}, xerrors.Errorf("unknown target kind: %s", targetKind) } + // 1. Scan the artifact report, err := scanFunction(ctx, opts) if err != nil { - return xerrors.Errorf("%s scan error: %w", targetKind, err) + return types.Report{}, xerrors.Errorf("%s scan error: %w", targetKind, err) } + // 2. Filter the results report, err = r.Filter(ctx, opts, report) if err != nil { - return xerrors.Errorf("filter error: %w", err) + return types.Report{}, xerrors.Errorf("filter error: %w", err) } + // 3. Report the results if err = r.Report(ctx, opts, report); err != nil { - return xerrors.Errorf("report error: %w", err) + return types.Report{}, xerrors.Errorf("report error: %w", err) } - return operation.Exit(opts, report.Results.Failed(), report.Metadata) + return report, nil } func disabledAnalyzers(opts flag.Options) []analyzer.Type { diff --git a/pkg/extension/hook.go b/pkg/extension/hook.go new file mode 100644 index 0000000000..ebd72f07aa --- /dev/null +++ b/pkg/extension/hook.go @@ -0,0 +1,162 @@ +package extension + +import ( + "context" + "sort" + + "github.com/samber/lo" + "golang.org/x/xerrors" + + "github.com/aquasecurity/trivy/pkg/flag" + "github.com/aquasecurity/trivy/pkg/types" +) + +var hooks = make(map[string]Hook) + +func RegisterHook(s Hook) { + // Avoid duplication + hooks[s.Name()] = s +} + +func DeregisterHook(name string) { + delete(hooks, name) +} + +// Hook is an interface that defines the methods for a hook. +type Hook interface { + // Name returns the name of the extension. + Name() string +} + +// RunHook is a extension that is called before and after all the processes. +type RunHook interface { + Hook + + // PreRun is called before all the processes. + PreRun(ctx context.Context, opts flag.Options) error + + // PostRun is called after all the processes. + PostRun(ctx context.Context, opts flag.Options) error +} + +// ScanHook is a extension that is called before and after the scan. +type ScanHook interface { + Hook + + // PreScan is called before the scan. It can modify the scan target. + // It may be called on the server side in client/server mode. + PreScan(ctx context.Context, target *types.ScanTarget, opts types.ScanOptions) error + + // PostScan is called after the scan. It can modify the results. + // It may be called on the server side in client/server mode. + // NOTE: Wasm modules cannot directly modify the passed results, + // so it returns a copy of the results. + PostScan(ctx context.Context, results types.Results) (types.Results, error) +} + +// ReportHook is a extension that is called before and after the report is written. +type ReportHook interface { + Hook + + // PreReport is called before the report is written. + // It can modify the report. It is called on the client side. + PreReport(ctx context.Context, report *types.Report, opts flag.Options) error + + // PostReport is called after the report is written. + // It can modify the report. It is called on the client side. + PostReport(ctx context.Context, report *types.Report, opts flag.Options) error +} + +func PreRun(ctx context.Context, opts flag.Options) error { + for _, e := range Hooks() { + h, ok := e.(RunHook) + if !ok { + continue + } + if err := h.PreRun(ctx, opts); err != nil { + return xerrors.Errorf("%s pre run error: %w", e.Name(), err) + } + } + return nil +} + +// PostRun is a hook that is called after all the processes. +func PostRun(ctx context.Context, opts flag.Options) error { + for _, e := range Hooks() { + h, ok := e.(RunHook) + if !ok { + continue + } + if err := h.PostRun(ctx, opts); err != nil { + return xerrors.Errorf("%s post run error: %w", e.Name(), err) + } + } + return nil +} + +// PreScan is a hook that is called before the scan. +func PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error { + for _, e := range Hooks() { + h, ok := e.(ScanHook) + if !ok { + continue + } + if err := h.PreScan(ctx, target, options); err != nil { + return xerrors.Errorf("%s pre scan error: %w", e.Name(), err) + } + } + return nil +} + +// PostScan is a hook that is called after the scan. +func PostScan(ctx context.Context, results types.Results) (types.Results, error) { + var err error + for _, e := range Hooks() { + h, ok := e.(ScanHook) + if !ok { + continue + } + results, err = h.PostScan(ctx, results) + if err != nil { + return nil, xerrors.Errorf("%s post scan error: %w", e.Name(), err) + } + } + return results, nil +} + +// PreReport is a hook that is called before the report is written. +func PreReport(ctx context.Context, report *types.Report, opts flag.Options) error { + for _, e := range Hooks() { + h, ok := e.(ReportHook) + if !ok { + continue + } + if err := h.PreReport(ctx, report, opts); err != nil { + return xerrors.Errorf("%s pre report error: %w", e.Name(), err) + } + } + return nil +} + +// PostReport is a hook that is called after the report is written. +func PostReport(ctx context.Context, report *types.Report, opts flag.Options) error { + for _, e := range Hooks() { + h, ok := e.(ReportHook) + if !ok { + continue + } + if err := h.PostReport(ctx, report, opts); err != nil { + return xerrors.Errorf("%s post report error: %w", e.Name(), err) + } + } + return nil +} + +// Hooks returns the list of hooks. +func Hooks() []Hook { + hooks := lo.Values(hooks) + sort.Slice(hooks, func(i, j int) bool { + return hooks[i].Name() < hooks[j].Name() + }) + return hooks +} diff --git a/pkg/extension/hook_test.go b/pkg/extension/hook_test.go new file mode 100644 index 0000000000..c90828f585 --- /dev/null +++ b/pkg/extension/hook_test.go @@ -0,0 +1,278 @@ +package extension_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + dbTypes "github.com/aquasecurity/trivy-db/pkg/types" + "github.com/aquasecurity/trivy/internal/hooktest" + "github.com/aquasecurity/trivy/pkg/extension" + "github.com/aquasecurity/trivy/pkg/flag" + "github.com/aquasecurity/trivy/pkg/types" +) + +func TestPostScan(t *testing.T) { + tests := []struct { + name string + results types.Results + want types.Results + wantErr bool + }{ + { + name: "happy path", + results: types.Results{ + { + Target: "test", + Vulnerabilities: []types.DetectedVulnerability{ + { + VulnerabilityID: "CVE-2022-0001", + PkgName: "musl", + InstalledVersion: "1.2.3", + FixedVersion: "1.2.4", + Vulnerability: dbTypes.Vulnerability{ + Severity: "CRITICAL", + }, + }, + }, + }, + }, + want: types.Results{ + { + Target: "test", + Vulnerabilities: []types.DetectedVulnerability{ + { + VulnerabilityID: "CVE-2022-0001", + PkgName: "musl", + InstalledVersion: "1.2.3", + FixedVersion: "1.2.4", + Vulnerability: dbTypes.Vulnerability{ + Severity: "CRITICAL", + References: []string{ + "https://example.com/post-scan", + }, + }, + }, + }, + }, + }, + }, + { + name: "sad path", + results: types.Results{ + { + Target: "bad", + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize the test hook + hooktest.Init(t) + + results, err := extension.PostScan(t.Context(), tt.results) + require.Equal(t, tt.wantErr, err != nil) + assert.Equal(t, tt.want, results) + }) + } +} + +func TestPreScan(t *testing.T) { + tests := []struct { + name string + target *types.ScanTarget + options types.ScanOptions + wantErr bool + }{ + { + name: "happy path", + target: &types.ScanTarget{ + Name: "test", + }, + wantErr: false, + }, + { + name: "sad path", + target: &types.ScanTarget{ + Name: "bad-pre", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize the test hook + hooktest.Init(t) + + err := extension.PreScan(t.Context(), tt.target, tt.options) + require.Equal(t, tt.wantErr, err != nil) + }) + } +} + +func TestPreRun(t *testing.T) { + tests := []struct { + name string + opts flag.Options + wantErr bool + }{ + { + name: "happy path", + opts: flag.Options{}, + wantErr: false, + }, + { + name: "sad path", + opts: flag.Options{ + GlobalOptions: flag.GlobalOptions{ + ConfigFile: "bad-config", + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize the test hook + hooktest.Init(t) + + err := extension.PreRun(t.Context(), tt.opts) + require.Equal(t, tt.wantErr, err != nil) + }) + } +} + +func TestPostRun(t *testing.T) { + tests := []struct { + name string + opts flag.Options + wantErr bool + }{ + { + name: "happy path", + opts: flag.Options{}, + wantErr: false, + }, + { + name: "sad path", + opts: flag.Options{ + GlobalOptions: flag.GlobalOptions{ + ConfigFile: "bad-config", + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize the test extension + hooktest.Init(t) + + err := extension.PostRun(t.Context(), tt.opts) + require.Equal(t, tt.wantErr, err != nil) + }) + } +} + +func TestPreReport(t *testing.T) { + tests := []struct { + name string + report *types.Report + opts flag.Options + wantTitle string + wantErr bool + }{ + { + name: "happy path", + report: &types.Report{ + Results: types.Results{ + { + Vulnerabilities: []types.DetectedVulnerability{ + { + VulnerabilityID: "CVE-2022-0001", + }, + }, + }, + }, + }, + wantTitle: "Modified by pre-report hook", + wantErr: false, + }, + { + name: "sad path", + report: &types.Report{ + ArtifactName: "bad-report", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize the test hook + hooktest.Init(t) + + err := extension.PreReport(t.Context(), tt.report, tt.opts) + if tt.wantErr { + require.Error(t, err) + return + } + + require.Len(t, tt.report.Results, 1) + require.Len(t, tt.report.Results[0].Vulnerabilities, 1) + assert.Equal(t, tt.wantTitle, tt.report.Results[0].Vulnerabilities[0].Title) + }) + } +} + +func TestPostReport(t *testing.T) { + tests := []struct { + name string + report *types.Report + opts flag.Options + wantDescription string + wantErr bool + }{ + { + name: "happy path", + report: &types.Report{ + Results: types.Results{ + { + Vulnerabilities: []types.DetectedVulnerability{ + { + VulnerabilityID: "CVE-2022-0001", + }, + }, + }, + }, + }, + wantDescription: "Modified by post-report hook", + wantErr: false, + }, + { + name: "sad path", + report: &types.Report{ + ArtifactName: "bad-report", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize the test hook + hooktest.Init(t) + + err := extension.PostReport(t.Context(), tt.report, tt.opts) + if tt.wantErr { + require.Error(t, err) + return + } + + require.Len(t, tt.report.Results, 1) + require.Len(t, tt.report.Results[0].Vulnerabilities, 1) + assert.Equal(t, tt.wantDescription, tt.report.Results[0].Vulnerabilities[0].Description) + }) + } +} diff --git a/pkg/flag/module_flags.go b/pkg/flag/module_flags.go index a3fdca3082..d0f8bb8f4d 100644 --- a/pkg/flag/module_flags.go +++ b/pkg/flag/module_flags.go @@ -1,7 +1,9 @@ package flag import ( - "github.com/aquasecurity/trivy/pkg/module" + "path/filepath" + + "github.com/aquasecurity/trivy/pkg/utils/fsutils" ) // e.g. config yaml @@ -14,7 +16,7 @@ var ( ModuleDirFlag = Flag[string]{ Name: "module-dir", ConfigName: "module.dir", - Default: module.DefaultDir, + Default: filepath.Join(fsutils.HomeDir(), ".trivy", "modules"), Usage: "specify directory to the wasm modules that will be loaded", Persistent: true, } diff --git a/pkg/module/module.go b/pkg/module/module.go index d4a208ffc5..3a74f2c062 100644 --- a/pkg/module/module.go +++ b/pkg/module/module.go @@ -17,13 +17,12 @@ import ( wasi "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/extension" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/log" tapi "github.com/aquasecurity/trivy/pkg/module/api" "github.com/aquasecurity/trivy/pkg/module/serialize" - "github.com/aquasecurity/trivy/pkg/scan/post" "github.com/aquasecurity/trivy/pkg/types" - "github.com/aquasecurity/trivy/pkg/utils/fsutils" ) var ( @@ -33,10 +32,6 @@ var ( "warn": logWarn, "error": logError, } - - RelativeDir = filepath.Join(".trivy", "modules") - - DefaultDir = dir() ) // logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection. @@ -172,7 +167,7 @@ func (m *Manager) Register() { func (m *Manager) Deregister() { for _, mod := range m.modules { analyzer.DeregisterAnalyzer(analyzer.Type(mod.Name())) - post.DeregisterPostScanner(mod.Name()) + extension.DeregisterHook(mod.Name()) } } @@ -262,6 +257,8 @@ func marshal(ctx context.Context, m api.Module, malloc api.Function, v any) (uin return ptr, size, nil } +var _ extension.ScanHook = (*wasmModule)(nil) + type wasmModule struct { mod api.Module memFS *memFS @@ -416,7 +413,7 @@ func (m *wasmModule) Register() { } if m.isPostScanner { logger.Debug("Registering custom post scanner") - post.RegisterPostScanner(m) + extension.RegisterHook(m) } } @@ -486,8 +483,11 @@ func (m *wasmModule) Analyze(ctx context.Context, input analyzer.AnalysisInput) return &result, nil } -// PostScan performs post scanning -// e.g. Remove a vulnerability, change severity, etc. +func (m *wasmModule) PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error { + // TODO: Implement + return nil +} + func (m *wasmModule) PostScan(ctx context.Context, results types.Results) (types.Results, error) { // Find custom resources var custom types.Result @@ -746,10 +746,6 @@ func isType(ctx context.Context, mod api.Module, name string) (bool, error) { return isRes[0] > 0, nil } -func dir() string { - return filepath.Join(fsutils.HomeDir(), RelativeDir) -} - func modulePostScanSpec(ctx context.Context, mod api.Module, freeFn api.Function) (serialize.PostScanSpec, error) { postScanSpecFunc := mod.ExportedFunction("post_scan_spec") if postScanSpecFunc == nil { diff --git a/pkg/module/module_test.go b/pkg/module/module_test.go index 31b0c0ff07..802dec8b10 100644 --- a/pkg/module/module_test.go +++ b/pkg/module/module_test.go @@ -6,12 +6,13 @@ import ( "runtime" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/aquasecurity/trivy/pkg/extension" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/module" - "github.com/aquasecurity/trivy/pkg/scan/post" ) func TestManager_Register(t *testing.T) { @@ -20,12 +21,12 @@ func TestManager_Register(t *testing.T) { t.Skip("Test satisfied adequately by Linux tests") } tests := []struct { - name string - moduleDir string - enabledModules []string - wantAnalyzerVersions analyzer.Versions - wantPostScannerVersions map[string]int - wantErr bool + name string + moduleDir string + enabledModules []string + wantAnalyzerVersions analyzer.Versions + wantExtentions []string + wantErr bool }{ { name: "happy path", @@ -36,8 +37,8 @@ func TestManager_Register(t *testing.T) { }, PostAnalyzers: make(map[string]int), }, - wantPostScannerVersions: map[string]int{ - "happy": 1, + wantExtentions: []string{ + "happy", }, }, { @@ -49,7 +50,7 @@ func TestManager_Register(t *testing.T) { }, PostAnalyzers: make(map[string]int), }, - wantPostScannerVersions: make(map[string]int), + wantExtentions: []string{}, }, { name: "only post scanner", @@ -58,8 +59,8 @@ func TestManager_Register(t *testing.T) { Analyzers: make(map[string]int), PostAnalyzers: make(map[string]int), }, - wantPostScannerVersions: map[string]int{ - "scanner": 2, + wantExtentions: []string{ + "scanner", }, }, { @@ -69,7 +70,7 @@ func TestManager_Register(t *testing.T) { Analyzers: make(map[string]int), PostAnalyzers: make(map[string]int), }, - wantPostScannerVersions: make(map[string]int), + wantExtentions: []string{}, }, { name: "pass enabled modules", @@ -85,8 +86,8 @@ func TestManager_Register(t *testing.T) { }, PostAnalyzers: make(map[string]int), }, - wantPostScannerVersions: map[string]int{ - "happy": 1, + wantExtentions: []string{ + "happy", }, }, } @@ -124,9 +125,10 @@ func TestManager_Register(t *testing.T) { got := a.AnalyzerVersions() assert.Equal(t, tt.wantAnalyzerVersions, got) - // Confirm the post scanner is registered - gotScannerVersions := post.ScannerVersions() - assert.Equal(t, tt.wantPostScannerVersions, gotScannerVersions) + hookNames := lo.Map(extension.Hooks(), func(hook extension.Hook, _ int) string { + return hook.Name() + }) + assert.Equal(t, tt.wantExtentions, hookNames) }) } } diff --git a/pkg/report/writer.go b/pkg/report/writer.go index 023470ca97..cb6e5683f1 100644 --- a/pkg/report/writer.go +++ b/pkg/report/writer.go @@ -9,6 +9,7 @@ import ( "golang.org/x/xerrors" cr "github.com/aquasecurity/trivy/pkg/compliance/report" + "github.com/aquasecurity/trivy/pkg/extension" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/log" @@ -26,6 +27,11 @@ const ( // Write writes the result to output, format as passed in argument func Write(ctx context.Context, report types.Report, option flag.Options) (err error) { + // Call pre-report hooks + if err := extension.PreReport(ctx, &report, option); err != nil { + return xerrors.Errorf("pre report error: %w", err) + } + output, cleanup, err := option.OutputWriter(ctx) if err != nil { return xerrors.Errorf("failed to create a file: %w", err) @@ -106,6 +112,11 @@ func Write(ctx context.Context, report types.Report, option flag.Options) (err e return xerrors.Errorf("failed to write results: %w", err) } + // Call post-report hooks + if err := extension.PostReport(ctx, &report, option); err != nil { + return xerrors.Errorf("post report error: %w", err) + } + return nil } diff --git a/pkg/report/writer_test.go b/pkg/report/writer_test.go index e0e4274cb2..3433ab387e 100644 --- a/pkg/report/writer_test.go +++ b/pkg/report/writer_test.go @@ -1,10 +1,16 @@ package report_test import ( + "bytes" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dbTypes "github.com/aquasecurity/trivy-db/pkg/types" + "github.com/aquasecurity/trivy/internal/hooktest" + "github.com/aquasecurity/trivy/pkg/flag" + "github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/types" ) @@ -82,3 +88,93 @@ func TestResults_Failed(t *testing.T) { }) } } + +func TestWrite(t *testing.T) { + testReport := types.Report{ + SchemaVersion: report.SchemaVersion, + ArtifactName: "test-artifact", + Results: types.Results{ + { + Target: "test-target", + Vulnerabilities: []types.DetectedVulnerability{ + { + VulnerabilityID: "CVE-2021-0001", + PkgName: "test-pkg", + Vulnerability: dbTypes.Vulnerability{ + Title: "Test Vulnerability Title", + Description: "This is a test description of a vulnerability", + }, + }, + }, + }, + }, + } + testTemplate := "{{ range . }}{{ range .Vulnerabilities }}- {{ .VulnerabilityID }}: {{ .Title }}\n {{ .Description }}\n{{ end }}{{ end }}" + + tests := []struct { + name string + setUpHook bool + report types.Report + options flag.Options + wantOutput string + wantTitle string // Expected title after function call + wantDesc string // Expected description after function call + }{ + { + name: "template with title and description", + report: testReport, + options: flag.Options{ + ReportOptions: flag.ReportOptions{ + Format: types.FormatTemplate, + Template: testTemplate, + }, + }, + wantOutput: "- CVE-2021-0001: Test Vulnerability Title\n This is a test description of a vulnerability\n", + wantTitle: "Test Vulnerability Title", // Should remain unchanged + wantDesc: "This is a test description of a vulnerability", // Should remain unchanged + }, + { + name: "report modified by hooks", + setUpHook: true, + report: testReport, + options: flag.Options{ + ReportOptions: flag.ReportOptions{ + Format: types.FormatTemplate, + Template: testTemplate, + }, + }, + // The template output only reflects the pre-report hook changes because + // the post-report hook runs AFTER the output is written. + // However, the report object itself is modified by both pre and post hooks. + wantOutput: "- CVE-2021-0001: Modified by pre-report hook\n This is a test description of a vulnerability\n", + wantTitle: "Modified by pre-report hook", + wantDesc: "Modified by post-report hook", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setUpHook { + hooktest.Init(t) + } + + // Create a buffer to capture the output + output := new(bytes.Buffer) + tt.options.SetOutputWriter(output) + + // Execute the Write function + err := report.Write(t.Context(), tt.report, tt.options) + require.NoError(t, err) + + // Verify the output matches the expected template rendering + got := output.String() + assert.Equal(t, tt.wantOutput, got, "Template output does not match wanted value") + + // Verify that the title and description in the report match the expected values + require.Len(t, tt.report.Results, 1) + require.Len(t, tt.report.Results[0].Vulnerabilities, 1) + assert.Equal(t, tt.wantTitle, tt.report.Results[0].Vulnerabilities[0].Title) + assert.Equal(t, tt.wantDesc, tt.report.Results[0].Vulnerabilities[0].Description) + }) + } +} diff --git a/pkg/scan/local/service.go b/pkg/scan/local/service.go index 1b57124984..8226a59a67 100644 --- a/pkg/scan/local/service.go +++ b/pkg/scan/local/service.go @@ -15,6 +15,7 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" ospkgDetector "github.com/aquasecurity/trivy/pkg/detector/ospkg" + "github.com/aquasecurity/trivy/pkg/extension" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/applier" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" @@ -23,7 +24,6 @@ import ( "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/scan/langpkg" "github.com/aquasecurity/trivy/pkg/scan/ospkg" - "github.com/aquasecurity/trivy/pkg/scan/post" "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/vulnerability" @@ -49,7 +49,7 @@ type Service struct { vulnClient vulnerability.Client } -// NewService is the factory method for Scanner +// NewService is the factory method for scan service func NewService(a applier.Applier, osPkgScanner ospkg.Scanner, langPkgScanner langpkg.Scanner, vulnClient vulnerability.Client) Service { return Service{ @@ -113,6 +113,11 @@ func (s Service) Scan(ctx context.Context, targetName, artifactKey string, blobK } func (s Service) ScanTarget(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Results, ftypes.OS, error) { + // Call pre-scan hooks + if err := extension.PreScan(ctx, &target, options); err != nil { + return nil, ftypes.OS{}, xerrors.Errorf("pre scan error: %w", err) + } + var results types.Results // Filter packages according to the options @@ -148,9 +153,8 @@ func (s Service) ScanTarget(ctx context.Context, target types.ScanTarget, option s.vulnClient.FillInfo(results[i].Vulnerabilities, options.VulnSeveritySources) } - // Post scanning - results, err = post.Scan(ctx, results) - if err != nil { + // Call post-scan hooks + if results, err = extension.PostScan(ctx, results); err != nil { return nil, ftypes.OS{}, xerrors.Errorf("post scan error: %w", err) } diff --git a/pkg/scan/local/service_test.go b/pkg/scan/local/service_test.go index 876926b409..a8486533aa 100644 --- a/pkg/scan/local/service_test.go +++ b/pkg/scan/local/service_test.go @@ -12,6 +12,7 @@ import ( "github.com/aquasecurity/trivy-db/pkg/db" dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/internal/dbtest" + "github.com/aquasecurity/trivy/internal/hooktest" "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/fanal/applier" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" @@ -151,6 +152,7 @@ func TestScanner_Scan(t *testing.T) { name string args args fixtures []string + setUpHook bool setupCache func(t *testing.T) cache.Cache wantResults types.Results wantOS ftypes.OS @@ -909,6 +911,75 @@ func TestScanner_Scan(t *testing.T) { Name: "3.11", }, }, + { + name: "happy path with hooks", + args: args{ + target: "alpine:latest", + layerIDs: []string{"sha256:5216338b40a7b96416b8b9858974bbe4acc3096ee60acbc4dfb1ee02aecceb10"}, + options: types.ScanOptions{ + PkgTypes: []string{types.PkgTypeOS}, + PkgRelationships: ftypes.Relationships, + Scanners: types.Scanners{types.VulnerabilityScanner}, + VulnSeveritySources: []dbTypes.SourceID{"auto"}, + }, + }, + fixtures: []string{"testdata/fixtures/happy.yaml"}, + setUpHook: true, + setupCache: func(t *testing.T) cache.Cache { + c := cache.NewMemoryCache() + require.NoError(t, c.PutBlob("sha256:5216338b40a7b96416b8b9858974bbe4acc3096ee60acbc4dfb1ee02aecceb10", ftypes.BlobInfo{ + SchemaVersion: ftypes.BlobJSONSchemaVersion, + OS: ftypes.OS{ + Family: ftypes.Alpine, + Name: "3.11", + }, + PackageInfos: []ftypes.PackageInfo{ + { + FilePath: "lib/apk/db/installed", + Packages: []ftypes.Package{muslPkg}, + }, + }, + })) + return c + }, + wantResults: types.Results{ + { + Target: "alpine:latest (pre-scan) (alpine 3.11)", + Class: types.ClassOSPkg, + Type: ftypes.Alpine, + Packages: ftypes.Packages{ + muslPkg, + }, + Vulnerabilities: []types.DetectedVulnerability{ + { + VulnerabilityID: "CVE-2020-9999", + PkgName: muslPkg.Name, + PkgIdentifier: muslPkg.Identifier, + InstalledVersion: muslPkg.Version, + FixedVersion: "1.2.4", + Status: dbTypes.StatusFixed, + Layer: ftypes.Layer{ + DiffID: "sha256:ebf12965380b39889c99a9c02e82ba465f887b45975b6e389d42e9e6a3857888", + }, + PrimaryURL: "https://avd.aquasec.com/nvd/cve-2020-9999", + Vulnerability: dbTypes.Vulnerability{ + Title: "dos", + Description: "dos vulnerability", + Severity: "HIGH", + References: []string{ + "https://example.com/post-scan", // modified by post-scan hook + }, + }, + }, + }, + }, + }, + wantOS: ftypes.OS{ + Family: "alpine", + Name: "3.11", + Eosl: true, + }, + }, { name: "happy path with misconfigurations", args: args{ @@ -1242,6 +1313,10 @@ func TestScanner_Scan(t *testing.T) { _ = dbtest.InitDB(t, tt.fixtures) defer db.Close() + if tt.setUpHook { + hooktest.Init(t) + } + c := tt.setupCache(t) a := applier.NewApplier(c) s := NewService(a, ospkg.NewScanner(), langpkg.NewScanner(), vulnerability.NewClient(db.Config{})) diff --git a/pkg/scan/post/post_scan.go b/pkg/scan/post/post_scan.go deleted file mode 100644 index 53f72819b1..0000000000 --- a/pkg/scan/post/post_scan.go +++ /dev/null @@ -1,45 +0,0 @@ -package post - -import ( - "context" - - "golang.org/x/xerrors" - - "github.com/aquasecurity/trivy/pkg/types" -) - -type Scanner interface { - Name() string - Version() int - PostScan(ctx context.Context, results types.Results) (types.Results, error) -} - -func RegisterPostScanner(s Scanner) { - // Avoid duplication - postScanners[s.Name()] = s -} - -func DeregisterPostScanner(name string) { - delete(postScanners, name) -} - -func ScannerVersions() map[string]int { - versions := make(map[string]int) - for _, s := range postScanners { - versions[s.Name()] = s.Version() - } - return versions -} - -var postScanners = make(map[string]Scanner) - -func Scan(ctx context.Context, results types.Results) (types.Results, error) { - var err error - for _, s := range postScanners { - results, err = s.PostScan(ctx, results) - if err != nil { - return nil, xerrors.Errorf("%s post scan error: %w", s.Name(), err) - } - } - return results, nil -} diff --git a/pkg/scan/post/post_scan_test.go b/pkg/scan/post/post_scan_test.go deleted file mode 100644 index e6c8246940..0000000000 --- a/pkg/scan/post/post_scan_test.go +++ /dev/null @@ -1,103 +0,0 @@ -package post_test - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - dbTypes "github.com/aquasecurity/trivy-db/pkg/types" - "github.com/aquasecurity/trivy/pkg/scan/post" - "github.com/aquasecurity/trivy/pkg/types" -) - -type testPostScanner struct{} - -func (testPostScanner) Name() string { - return "test" -} - -func (testPostScanner) Version() int { - return 1 -} - -func (testPostScanner) PostScan(ctx context.Context, results types.Results) (types.Results, error) { - for i, r := range results { - if r.Target == "bad" { - return nil, errors.New("bad") - } - for j := range r.Vulnerabilities { - results[i].Vulnerabilities[j].Severity = "LOW" - } - } - return results, nil -} - -func TestScan(t *testing.T) { - tests := []struct { - name string - results types.Results - want types.Results - wantErr bool - }{ - { - name: "happy path", - results: types.Results{ - { - Target: "test", - Vulnerabilities: []types.DetectedVulnerability{ - { - VulnerabilityID: "CVE-2022-0001", - PkgName: "musl", - InstalledVersion: "1.2.3", - FixedVersion: "1.2.4", - Vulnerability: dbTypes.Vulnerability{ - Severity: "CRITICAL", - }, - }, - }, - }, - }, - want: types.Results{ - { - Target: "test", - Vulnerabilities: []types.DetectedVulnerability{ - { - VulnerabilityID: "CVE-2022-0001", - PkgName: "musl", - InstalledVersion: "1.2.3", - FixedVersion: "1.2.4", - Vulnerability: dbTypes.Vulnerability{ - Severity: "LOW", - }, - }, - }, - }, - }, - }, - { - name: "sad path", - results: types.Results{ - { - Target: "bad", - }, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := testPostScanner{} - post.RegisterPostScanner(s) - defer func() { - post.DeregisterPostScanner(s.Name()) - }() - - results, err := post.Scan(t.Context(), tt.results) - require.Equal(t, tt.wantErr, err != nil) - assert.Equal(t, tt.want, results) - }) - } -}