diff --git a/integration/module_test.go b/integration/module_test.go index c264670c48..0f4f650e8a 100644 --- a/integration/module_test.go +++ b/integration/module_test.go @@ -1,18 +1,14 @@ //go:build module_integration - package integration import ( - "os" "path/filepath" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" - "github.com/aquasecurity/trivy/pkg/module" - "github.com/aquasecurity/trivy/pkg/utils/fsutils" + "github.com/aquasecurity/trivy/pkg/scanner/post" ) func TestModule(t *testing.T) { @@ -36,17 +32,6 @@ func TestModule(t *testing.T) { // Set up testing DB cacheDir := initDB(t) - // Set up module dir - moduleDir := filepath.Join(cacheDir, module.RelativeDir) - err := os.MkdirAll(moduleDir, 0700) - require.NoError(t, err) - - // Set up Spring4Shell module - t.Setenv("XDG_DATA_HOME", cacheDir) - _, err = fsutils.CopyFile(filepath.Join("../", "examples", "module", "spring4shell", "spring4shell.wasm"), - filepath.Join(moduleDir, "spring4shell.wasm")) - require.NoError(t, err) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { osArgs := []string{ @@ -56,8 +41,11 @@ func TestModule(t *testing.T) { "--ignore-unfixed", "--format", "json", - "--skip-update", + "--skip-db-update", "--offline-scan", + "--quiet", + "--module-dir", + filepath.Join("../", "examples", "module", "spring4shell"), "--input", tt.input, } @@ -74,9 +62,12 @@ func TestModule(t *testing.T) { }...) // Run Trivy - err = execute(osArgs) - assert.NoError(t, err) - defer analyzer.DeregisterAnalyzer("spring4shell") + err := execute(osArgs) + require.NoError(t, err) + defer func() { + analyzer.DeregisterAnalyzer("spring4shell") + post.DeregisterPostScanner("spring4shell") + }() // Compare want and got compareReports(t, tt.golden, outputFile) diff --git a/pkg/commands/app.go b/pkg/commands/app.go index a75c138c8d..b737e10a76 100644 --- a/pkg/commands/app.go +++ b/pkg/commands/app.go @@ -161,7 +161,7 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co // viper.BindPFlag cannot be called in init(). // cf. https://github.com/spf13/cobra/issues/875 // https://github.com/spf13/viper/issues/233 - if err := globalFlags.Bind(cmd.Root()); err != nil { + if err := globalFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } @@ -222,6 +222,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific LicenseFlagGroup: flag.NewLicenseFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), + ModuleFlagGroup: flag.NewModuleFlagGroup(), RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RegoFlagGroup: flag.NewRegoFlagGroup(), ReportFlagGroup: reportFlagGroup, @@ -297,6 +298,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { DBFlagGroup: flag.NewDBFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), + ModuleFlagGroup: flag.NewModuleFlagGroup(), RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RegoFlagGroup: flag.NewRegoFlagGroup(), ReportFlagGroup: reportFlagGroup, @@ -351,6 +353,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { DBFlagGroup: flag.NewDBFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), + ModuleFlagGroup: flag.NewModuleFlagGroup(), RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RegoFlagGroup: flag.NewRegoFlagGroup(), ReportFlagGroup: reportFlagGroup, @@ -407,6 +410,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { DBFlagGroup: flag.NewDBFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), + ModuleFlagGroup: flag.NewModuleFlagGroup(), RegoFlagGroup: flag.NewRegoFlagGroup(), RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode ReportFlagGroup: reportFlagGroup, @@ -507,6 +511,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { serverFlags := &flag.Flags{ CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), + ModuleFlagGroup: flag.NewModuleFlagGroup(), RemoteFlagGroup: flag.NewServerFlags(), } @@ -560,6 +565,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { configFlags := &flag.Flags{ CacheFlagGroup: flag.NewCacheFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), + ModuleFlagGroup: flag.NewModuleFlagGroup(), RegoFlagGroup: flag.NewRegoFlagGroup(), K8sFlagGroup: &flag.K8sFlagGroup{ // disable unneeded flags @@ -708,6 +714,10 @@ func NewPluginCommand() *cobra.Command { } func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { + moduleFlags := &flag.Flags{ + ModuleFlagGroup: flag.NewModuleFlagGroup(), + } + cmd := &cobra.Command{ Use: "module subcommand", Aliases: []string{"m"}, @@ -723,14 +733,23 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { Aliases: []string{"i"}, Short: "Install a module", Args: cobra.ExactArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + if err := moduleFlags.Bind(cmd); err != nil { + return xerrors.Errorf("flag bind error: %w", err) + } + return nil + }, RunE: func(cmd *cobra.Command, args []string) error { if len(args) != 1 { return cmd.Help() } repo := args[0] - opts := globalFlags.ToOptions() - return module.Install(cmd.Context(), repo, opts.Quiet, opts.Insecure) + opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + if err != nil { + return xerrors.Errorf("flag error: %w", err) + } + return module.Install(cmd.Context(), opts.ModuleDir, repo, opts.Quiet, opts.Insecure) }, }, &cobra.Command{ @@ -738,16 +757,27 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { Aliases: []string{"u"}, Short: "Uninstall a module", Args: cobra.ExactArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + if err := moduleFlags.Bind(cmd); err != nil { + return xerrors.Errorf("flag bind error: %w", err) + } + return nil + }, RunE: func(cmd *cobra.Command, args []string) error { if len(args) != 1 { return cmd.Help() } repo := args[0] - return module.Uninstall(cmd.Context(), repo) + opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + if err != nil { + return xerrors.Errorf("flag error: %w", err) + } + return module.Uninstall(cmd.Context(), opts.ModuleDir, repo) }, }, ) + moduleFlags.AddFlags(cmd) cmd.SetFlagErrorFunc(flagErrorFunc) return cmd } @@ -901,6 +931,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { DBFlagGroup: flag.NewDBFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), + ModuleFlagGroup: flag.NewModuleFlagGroup(), RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode ReportFlagGroup: reportFlagGroup, ScanFlagGroup: flag.NewScanFlagGroup(), diff --git a/pkg/commands/artifact/run.go b/pkg/commands/artifact/run.go index bfc6bfb2ea..aadd88e346 100644 --- a/pkg/commands/artifact/run.go +++ b/pkg/commands/artifact/run.go @@ -131,7 +131,10 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...runnerOptio } // Initialize WASM modules - m, err := module.NewManager(ctx) + m, err := module.NewManager(ctx, module.Options{ + Dir: cliOptions.ModuleDir, + EnabledModules: cliOptions.EnabledModules, + }) if err != nil { return nil, xerrors.Errorf("WASM module error: %w", err) } diff --git a/pkg/commands/server/run.go b/pkg/commands/server/run.go index 83067c284e..881be8902c 100644 --- a/pkg/commands/server/run.go +++ b/pkg/commands/server/run.go @@ -48,7 +48,10 @@ func Run(ctx context.Context, opts flag.Options) (err error) { } // Initialize WASM modules - m, err := module.NewManager(ctx) + m, err := module.NewManager(ctx, module.Options{ + Dir: opts.ModuleDir, + EnabledModules: opts.EnabledModules, + }) if err != nil { return xerrors.Errorf("WASM module error: %w", err) } diff --git a/pkg/flag/module_flags.go b/pkg/flag/module_flags.go new file mode 100644 index 0000000000..c7304d5595 --- /dev/null +++ b/pkg/flag/module_flags.go @@ -0,0 +1,64 @@ +package flag + +import ( + "github.com/aquasecurity/trivy/pkg/module" +) + +// e.g. config yaml +// module: +// dir: "/path/to/my_modules" +// enable-modules: +// - spring4shell + +var ( + ModuleDirFlag = Flag{ + Name: "module-dir", + ConfigName: "module.dir", + Value: module.DefaultDir, + Usage: "specify directory to the wasm modules that will be loaded", + Persistent: true, + } + EnableModulesFlag = Flag{ + Name: "enable-modules", + ConfigName: "module.enable-modules", + Value: []string{}, + Usage: "[EXPERIMENTAL] module names to enable", + Persistent: true, + } +) + +// ModuleFlagGroup defines flags for modules +type ModuleFlagGroup struct { + Dir *Flag + EnabledModules *Flag +} + +type ModuleOptions struct { + ModuleDir string + EnabledModules []string +} + +func NewModuleFlagGroup() *ModuleFlagGroup { + return &ModuleFlagGroup{ + Dir: &ModuleDirFlag, + EnabledModules: &EnableModulesFlag, + } +} + +func (f *ModuleFlagGroup) Name() string { + return "Module" +} + +func (f *ModuleFlagGroup) Flags() []*Flag { + return []*Flag{ + f.Dir, + f.EnabledModules, + } +} + +func (f *ModuleFlagGroup) ToOptions() ModuleOptions { + return ModuleOptions{ + ModuleDir: getString(f.Dir), + EnabledModules: getStringSlice(f.EnabledModules), + } +} diff --git a/pkg/flag/options.go b/pkg/flag/options.go index c5d22ddf93..c0cdd7ee44 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -66,6 +66,7 @@ type Flags struct { K8sFlagGroup *K8sFlagGroup LicenseFlagGroup *LicenseFlagGroup MisconfFlagGroup *MisconfFlagGroup + ModuleFlagGroup *ModuleFlagGroup RemoteFlagGroup *RemoteFlagGroup RegoFlagGroup *RegoFlagGroup RepoFlagGroup *RepoFlagGroup @@ -87,6 +88,7 @@ type Options struct { K8sOptions LicenseOptions MisconfOptions + ModuleOptions RegoOptions RemoteOptions RepoOptions @@ -156,14 +158,13 @@ func bind(cmd *cobra.Command, flag *Flag) error { } // Bind CLI flags - if flag.Persistent { - if err := viper.BindPFlag(flag.ConfigName, cmd.PersistentFlags().Lookup(flag.Name)); err != nil { - return xerrors.Errorf("bind flag error: %w", err) - } - } else { - if err := viper.BindPFlag(flag.ConfigName, cmd.Flags().Lookup(flag.Name)); err != nil { - return xerrors.Errorf("bind flag error: %w", err) - } + f := cmd.Flags().Lookup(flag.Name) + if f == nil { + // Lookup local persistent flags + f = cmd.PersistentFlags().Lookup(flag.Name) + } + if err := viper.BindPFlag(flag.ConfigName, f); err != nil { + return xerrors.Errorf("bind flag error: %w", err) } // Bind environmental variable @@ -275,6 +276,9 @@ func (f *Flags) groups() []FlagGroup { if f.MisconfFlagGroup != nil { groups = append(groups, f.MisconfFlagGroup) } + if f.ModuleFlagGroup != nil { + groups = append(groups, f.ModuleFlagGroup) + } if f.SecretFlagGroup != nil { groups = append(groups, f.SecretFlagGroup) } @@ -404,6 +408,10 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF } } + if f.ModuleFlagGroup != nil { + opts.ModuleOptions = f.ModuleFlagGroup.ToOptions() + } + if f.RegoFlagGroup != nil { opts.RegoOptions, err = f.RegoFlagGroup.ToOptions() if err != nil { diff --git a/pkg/module/command.go b/pkg/module/command.go index a10e575749..6ac36370c5 100644 --- a/pkg/module/command.go +++ b/pkg/module/command.go @@ -15,7 +15,7 @@ import ( const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm" // Install installs a module -func Install(ctx context.Context, repo string, quiet, insecure bool) error { +func Install(ctx context.Context, dir, repo string, quiet, insecure bool) error { ref, err := name.ParseReference(repo) if err != nil { return xerrors.Errorf("repository parse error: %w", err) @@ -27,7 +27,7 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error { return xerrors.Errorf("module initialize error: %w", err) } - dst := filepath.Join(dir(), ref.Context().Name()) + dst := filepath.Join(dir, ref.Context().Name()) log.Logger.Debugf("Installing the module to %s...", dst) if err = artifact.Download(ctx, dst); err != nil { @@ -38,14 +38,14 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error { } // Uninstall uninstalls a module -func Uninstall(_ context.Context, repo string) error { +func Uninstall(_ context.Context, dir, repo string) error { ref, err := name.ParseReference(repo) if err != nil { return xerrors.Errorf("repository parse error: %w", err) } log.Logger.Infof("Uninstalling %s ...", repo) - dst := filepath.Join(dir(), ref.Context().Name()) + dst := filepath.Join(dir, ref.Context().Name()) if err = os.RemoveAll(dst); err != nil { return xerrors.Errorf("remove error: %w", err) } diff --git a/pkg/module/module.go b/pkg/module/module.go index 2588358b52..4395feedde 100644 --- a/pkg/module/module.go +++ b/pkg/module/module.go @@ -35,6 +35,8 @@ var ( } RelativeDir = filepath.Join(".trivy", "modules") + + DefaultDir = dir() ) // logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection. @@ -94,13 +96,23 @@ func readMemory(mem api.Memory, offset, size uint32) []byte { return buf } -type Manager struct { - cache wazero.CompilationCache - modules []*wasmModule +type Options struct { + Dir string + EnabledModules []string } -func NewManager(ctx context.Context) (*Manager, error) { - m := &Manager{} +type Manager struct { + cache wazero.CompilationCache + modules []*wasmModule + dir string + enabledModules []string +} + +func NewManager(ctx context.Context, opts Options) (*Manager, error) { + m := &Manager{ + dir: opts.Dir, + enabledModules: opts.EnabledModules, + } // Create a new WebAssembly Runtime. m.cache = wazero.NewCompilationCache() @@ -114,26 +126,25 @@ func NewManager(ctx context.Context) (*Manager, error) { } func (m *Manager) loadModules(ctx context.Context) error { - moduleDir := dir() - _, err := os.Stat(moduleDir) + _, err := os.Stat(m.dir) if os.IsNotExist(err) { return nil } - log.Logger.Debugf("Module dir: %s", moduleDir) + log.Logger.Debugf("Module dir: %s", m.dir) - err = filepath.Walk(moduleDir, func(path string, info fs.FileInfo, err error) error { + err = filepath.Walk(m.dir, func(path string, info fs.FileInfo, err error) error { if err != nil { return err } else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" { return nil } - rel, err := filepath.Rel(moduleDir, path) + rel, err := filepath.Rel(m.dir, path) if err != nil { return xerrors.Errorf("failed to get a relative path: %w", err) } - log.Logger.Infof("Loading %s...", rel) + log.Logger.Infof("Reading %s...", rel) wasmCode, err := os.ReadFile(path) if err != nil { return xerrors.Errorf("file read error: %w", err) @@ -144,6 +155,12 @@ func (m *Manager) loadModules(ctx context.Context) error { return xerrors.Errorf("WASM module init error %s: %w", rel, err) } + // Skip Loading WASM modules if not in the list of enable modules flag. + if len(m.enabledModules) > 0 && !slices.Contains(m.enabledModules, p.Name()) { + return nil + } + + log.Logger.Infof("%s loaded", rel) m.modules = append(m.modules, p) return nil @@ -161,6 +178,13 @@ func (m *Manager) Register() { } } +func (m *Manager) Deregister() { + for _, mod := range m.modules { + analyzer.DeregisterAnalyzer(analyzer.Type(mod.Name())) + post.DeregisterPostScanner(mod.Name()) + } +} + func (m *Manager) Close(ctx context.Context) error { return m.cache.Close(ctx) } diff --git a/pkg/module/module_test.go b/pkg/module/module_test.go index f989d22005..0964d3ef73 100644 --- a/pkg/module/module_test.go +++ b/pkg/module/module_test.go @@ -2,7 +2,7 @@ package module_test import ( "context" - "os" + "io/fs" "path/filepath" "runtime" "testing" @@ -13,7 +13,6 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/module" "github.com/aquasecurity/trivy/pkg/scanner/post" - "github.com/aquasecurity/trivy/pkg/utils/fsutils" ) func TestManager_Register(t *testing.T) { @@ -23,15 +22,15 @@ func TestManager_Register(t *testing.T) { } tests := []struct { name string - noModuleDir bool - moduleName string + moduleDir string + enabledModules []string wantAnalyzerVersions analyzer.Versions wantPostScannerVersions map[string]int wantErr bool }{ { - name: "happy path", - moduleName: "happy", + name: "happy path", + moduleDir: "testdata/happy", wantAnalyzerVersions: analyzer.Versions{ Analyzers: map[string]int{ "happy": 1, @@ -43,8 +42,8 @@ func TestManager_Register(t *testing.T) { }, }, { - name: "only analyzer", - moduleName: "analyzer", + name: "only analyzer", + moduleDir: "testdata/analyzer", wantAnalyzerVersions: analyzer.Versions{ Analyzers: map[string]int{ "analyzer": 1, @@ -54,8 +53,8 @@ func TestManager_Register(t *testing.T) { wantPostScannerVersions: map[string]int{}, }, { - name: "only post scanner", - moduleName: "scanner", + name: "only post scanner", + moduleDir: "testdata/scanner", wantAnalyzerVersions: analyzer.Versions{ Analyzers: map[string]int{}, PostAnalyzers: map[string]int{}, @@ -65,48 +64,59 @@ func TestManager_Register(t *testing.T) { }, }, { - name: "no module dir", - noModuleDir: true, - moduleName: "happy", + name: "no module dir", + moduleDir: "no-such-dir", wantAnalyzerVersions: analyzer.Versions{ Analyzers: map[string]int{}, PostAnalyzers: map[string]int{}, }, wantPostScannerVersions: map[string]int{}, }, + { + name: "pass enabled modules", + moduleDir: "testdata", + enabledModules: []string{ + "happy", + "analyzer", + }, + wantAnalyzerVersions: analyzer.Versions{ + Analyzers: map[string]int{ + "happy": 1, + "analyzer": 1, + }, + PostAnalyzers: map[string]int{}, + }, + wantPostScannerVersions: map[string]int{ + "happy": 1, + }, + }, } + + // Confirm that wasm modules are generated beforehand + var count int + err := filepath.WalkDir("testdata", func(path string, d fs.DirEntry, err error) error { + if filepath.Ext(path) == ".wasm" { + count++ + } + return nil + }) + require.NoError(t, err) + // WASM modules must be generated before running the tests. + require.Equal(t, count, 3, "missing WASM modules, try 'make test' or 'make generate-test-modules'") + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - modulePath := filepath.Join("testdata", tt.moduleName, tt.moduleName+".wasm") - - // WASM modules must be generated before running this test. - if _, err := os.Stat(modulePath); os.IsNotExist(err) { - require.Fail(t, "missing WASM modules, try 'make test' or 'make generate-test-modules'") - } - - // Set up a temp dir for modules - tmpDir := t.TempDir() - t.Setenv("XDG_DATA_HOME", tmpDir) - moduleDir := filepath.Join(tmpDir, module.RelativeDir) - - if !tt.noModuleDir { - err := os.MkdirAll(moduleDir, 0777) - require.NoError(t, err) - - // Copy the wasm module for testing - _, err = fsutils.CopyFile(modulePath, filepath.Join(moduleDir, filepath.Base(modulePath))) - require.NoError(t, err) - } - - m, err := module.NewManager(context.Background()) + m, err := module.NewManager(context.Background(), module.Options{ + Dir: tt.moduleDir, + EnabledModules: tt.enabledModules, + }) require.NoError(t, err) // Register analyzer and post scanner from WASM module m.Register() - defer func() { - analyzer.DeregisterAnalyzer(analyzer.Type(tt.moduleName)) - post.DeregisterPostScanner(tt.moduleName) - }() + + // Remove registered analyzers and post scanners so that it will not affect other tests. + defer m.Deregister() // Confirm the analyzer is registered a, err := analyzer.NewAnalyzerGroup(analyzer.AnalyzerOptions{})