feat: Adding --module-dir and --enable-modules (#3677)

Co-authored-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Kalyana Krishna Varanasi
2023-03-01 15:39:53 +05:30
committed by GitHub
parent 34120f4201
commit 302c8ae24c
9 changed files with 222 additions and 88 deletions

View File

@@ -1,18 +1,14 @@
//go:build module_integration //go:build module_integration
package integration package integration
import ( import (
"os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/module" "github.com/aquasecurity/trivy/pkg/scanner/post"
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
) )
func TestModule(t *testing.T) { func TestModule(t *testing.T) {
@@ -36,17 +32,6 @@ func TestModule(t *testing.T) {
// Set up testing DB // Set up testing DB
cacheDir := initDB(t) cacheDir := initDB(t)
// Set up module dir
moduleDir := filepath.Join(cacheDir, module.RelativeDir)
err := os.MkdirAll(moduleDir, 0700)
require.NoError(t, err)
// Set up Spring4Shell module
t.Setenv("XDG_DATA_HOME", cacheDir)
_, err = fsutils.CopyFile(filepath.Join("../", "examples", "module", "spring4shell", "spring4shell.wasm"),
filepath.Join(moduleDir, "spring4shell.wasm"))
require.NoError(t, err)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
osArgs := []string{ osArgs := []string{
@@ -56,8 +41,11 @@ func TestModule(t *testing.T) {
"--ignore-unfixed", "--ignore-unfixed",
"--format", "--format",
"json", "json",
"--skip-update", "--skip-db-update",
"--offline-scan", "--offline-scan",
"--quiet",
"--module-dir",
filepath.Join("../", "examples", "module", "spring4shell"),
"--input", "--input",
tt.input, tt.input,
} }
@@ -74,9 +62,12 @@ func TestModule(t *testing.T) {
}...) }...)
// Run Trivy // Run Trivy
err = execute(osArgs) err := execute(osArgs)
assert.NoError(t, err) require.NoError(t, err)
defer analyzer.DeregisterAnalyzer("spring4shell") defer func() {
analyzer.DeregisterAnalyzer("spring4shell")
post.DeregisterPostScanner("spring4shell")
}()
// Compare want and got // Compare want and got
compareReports(t, tt.golden, outputFile) compareReports(t, tt.golden, outputFile)

View File

@@ -161,7 +161,7 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co
// viper.BindPFlag cannot be called in init(). // viper.BindPFlag cannot be called in init().
// cf. https://github.com/spf13/cobra/issues/875 // cf. https://github.com/spf13/cobra/issues/875
// https://github.com/spf13/viper/issues/233 // https://github.com/spf13/viper/issues/233
if err := globalFlags.Bind(cmd.Root()); err != nil { if err := globalFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err) return xerrors.Errorf("flag bind error: %w", err)
} }
@@ -222,6 +222,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
LicenseFlagGroup: flag.NewLicenseFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
RegoFlagGroup: flag.NewRegoFlagGroup(), RegoFlagGroup: flag.NewRegoFlagGroup(),
ReportFlagGroup: reportFlagGroup, ReportFlagGroup: reportFlagGroup,
@@ -297,6 +298,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
RegoFlagGroup: flag.NewRegoFlagGroup(), RegoFlagGroup: flag.NewRegoFlagGroup(),
ReportFlagGroup: reportFlagGroup, ReportFlagGroup: reportFlagGroup,
@@ -351,6 +353,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
RegoFlagGroup: flag.NewRegoFlagGroup(), RegoFlagGroup: flag.NewRegoFlagGroup(),
ReportFlagGroup: reportFlagGroup, ReportFlagGroup: reportFlagGroup,
@@ -407,6 +410,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RegoFlagGroup: flag.NewRegoFlagGroup(), RegoFlagGroup: flag.NewRegoFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup, ReportFlagGroup: reportFlagGroup,
@@ -507,6 +511,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
serverFlags := &flag.Flags{ serverFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(), CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewServerFlags(), RemoteFlagGroup: flag.NewServerFlags(),
} }
@@ -560,6 +565,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
configFlags := &flag.Flags{ configFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(), CacheFlagGroup: flag.NewCacheFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RegoFlagGroup: flag.NewRegoFlagGroup(), RegoFlagGroup: flag.NewRegoFlagGroup(),
K8sFlagGroup: &flag.K8sFlagGroup{ K8sFlagGroup: &flag.K8sFlagGroup{
// disable unneeded flags // disable unneeded flags
@@ -708,6 +714,10 @@ func NewPluginCommand() *cobra.Command {
} }
func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
moduleFlags := &flag.Flags{
ModuleFlagGroup: flag.NewModuleFlagGroup(),
}
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "module subcommand", Use: "module subcommand",
Aliases: []string{"m"}, Aliases: []string{"m"},
@@ -723,14 +733,23 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
Aliases: []string{"i"}, Aliases: []string{"i"},
Short: "Install a module", Short: "Install a module",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
if err := moduleFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if len(args) != 1 { if len(args) != 1 {
return cmd.Help() return cmd.Help()
} }
repo := args[0] repo := args[0]
opts := globalFlags.ToOptions() opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
return module.Install(cmd.Context(), repo, opts.Quiet, opts.Insecure) if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
return module.Install(cmd.Context(), opts.ModuleDir, repo, opts.Quiet, opts.Insecure)
}, },
}, },
&cobra.Command{ &cobra.Command{
@@ -738,16 +757,27 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
Aliases: []string{"u"}, Aliases: []string{"u"},
Short: "Uninstall a module", Short: "Uninstall a module",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
if err := moduleFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if len(args) != 1 { if len(args) != 1 {
return cmd.Help() return cmd.Help()
} }
repo := args[0] repo := args[0]
return module.Uninstall(cmd.Context(), repo) opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
return module.Uninstall(cmd.Context(), opts.ModuleDir, repo)
}, },
}, },
) )
moduleFlags.AddFlags(cmd)
cmd.SetFlagErrorFunc(flagErrorFunc) cmd.SetFlagErrorFunc(flagErrorFunc)
return cmd return cmd
} }
@@ -901,6 +931,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup, ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(), ScanFlagGroup: flag.NewScanFlagGroup(),

View File

@@ -131,7 +131,10 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...runnerOptio
} }
// Initialize WASM modules // Initialize WASM modules
m, err := module.NewManager(ctx) m, err := module.NewManager(ctx, module.Options{
Dir: cliOptions.ModuleDir,
EnabledModules: cliOptions.EnabledModules,
})
if err != nil { if err != nil {
return nil, xerrors.Errorf("WASM module error: %w", err) return nil, xerrors.Errorf("WASM module error: %w", err)
} }

View File

@@ -48,7 +48,10 @@ func Run(ctx context.Context, opts flag.Options) (err error) {
} }
// Initialize WASM modules // Initialize WASM modules
m, err := module.NewManager(ctx) m, err := module.NewManager(ctx, module.Options{
Dir: opts.ModuleDir,
EnabledModules: opts.EnabledModules,
})
if err != nil { if err != nil {
return xerrors.Errorf("WASM module error: %w", err) return xerrors.Errorf("WASM module error: %w", err)
} }

64
pkg/flag/module_flags.go Normal file
View File

@@ -0,0 +1,64 @@
package flag
import (
"github.com/aquasecurity/trivy/pkg/module"
)
// e.g. config yaml
// module:
// dir: "/path/to/my_modules"
// enable-modules:
// - spring4shell
var (
ModuleDirFlag = Flag{
Name: "module-dir",
ConfigName: "module.dir",
Value: module.DefaultDir,
Usage: "specify directory to the wasm modules that will be loaded",
Persistent: true,
}
EnableModulesFlag = Flag{
Name: "enable-modules",
ConfigName: "module.enable-modules",
Value: []string{},
Usage: "[EXPERIMENTAL] module names to enable",
Persistent: true,
}
)
// ModuleFlagGroup defines flags for modules
type ModuleFlagGroup struct {
Dir *Flag
EnabledModules *Flag
}
type ModuleOptions struct {
ModuleDir string
EnabledModules []string
}
func NewModuleFlagGroup() *ModuleFlagGroup {
return &ModuleFlagGroup{
Dir: &ModuleDirFlag,
EnabledModules: &EnableModulesFlag,
}
}
func (f *ModuleFlagGroup) Name() string {
return "Module"
}
func (f *ModuleFlagGroup) Flags() []*Flag {
return []*Flag{
f.Dir,
f.EnabledModules,
}
}
func (f *ModuleFlagGroup) ToOptions() ModuleOptions {
return ModuleOptions{
ModuleDir: getString(f.Dir),
EnabledModules: getStringSlice(f.EnabledModules),
}
}

View File

@@ -66,6 +66,7 @@ type Flags struct {
K8sFlagGroup *K8sFlagGroup K8sFlagGroup *K8sFlagGroup
LicenseFlagGroup *LicenseFlagGroup LicenseFlagGroup *LicenseFlagGroup
MisconfFlagGroup *MisconfFlagGroup MisconfFlagGroup *MisconfFlagGroup
ModuleFlagGroup *ModuleFlagGroup
RemoteFlagGroup *RemoteFlagGroup RemoteFlagGroup *RemoteFlagGroup
RegoFlagGroup *RegoFlagGroup RegoFlagGroup *RegoFlagGroup
RepoFlagGroup *RepoFlagGroup RepoFlagGroup *RepoFlagGroup
@@ -87,6 +88,7 @@ type Options struct {
K8sOptions K8sOptions
LicenseOptions LicenseOptions
MisconfOptions MisconfOptions
ModuleOptions
RegoOptions RegoOptions
RemoteOptions RemoteOptions
RepoOptions RepoOptions
@@ -156,14 +158,13 @@ func bind(cmd *cobra.Command, flag *Flag) error {
} }
// Bind CLI flags // Bind CLI flags
if flag.Persistent { f := cmd.Flags().Lookup(flag.Name)
if err := viper.BindPFlag(flag.ConfigName, cmd.PersistentFlags().Lookup(flag.Name)); err != nil { if f == nil {
return xerrors.Errorf("bind flag error: %w", err) // Lookup local persistent flags
} f = cmd.PersistentFlags().Lookup(flag.Name)
} else { }
if err := viper.BindPFlag(flag.ConfigName, cmd.Flags().Lookup(flag.Name)); err != nil { if err := viper.BindPFlag(flag.ConfigName, f); err != nil {
return xerrors.Errorf("bind flag error: %w", err) return xerrors.Errorf("bind flag error: %w", err)
}
} }
// Bind environmental variable // Bind environmental variable
@@ -275,6 +276,9 @@ func (f *Flags) groups() []FlagGroup {
if f.MisconfFlagGroup != nil { if f.MisconfFlagGroup != nil {
groups = append(groups, f.MisconfFlagGroup) groups = append(groups, f.MisconfFlagGroup)
} }
if f.ModuleFlagGroup != nil {
groups = append(groups, f.ModuleFlagGroup)
}
if f.SecretFlagGroup != nil { if f.SecretFlagGroup != nil {
groups = append(groups, f.SecretFlagGroup) groups = append(groups, f.SecretFlagGroup)
} }
@@ -404,6 +408,10 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
} }
} }
if f.ModuleFlagGroup != nil {
opts.ModuleOptions = f.ModuleFlagGroup.ToOptions()
}
if f.RegoFlagGroup != nil { if f.RegoFlagGroup != nil {
opts.RegoOptions, err = f.RegoFlagGroup.ToOptions() opts.RegoOptions, err = f.RegoFlagGroup.ToOptions()
if err != nil { if err != nil {

View File

@@ -15,7 +15,7 @@ import (
const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm" const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm"
// Install installs a module // Install installs a module
func Install(ctx context.Context, repo string, quiet, insecure bool) error { func Install(ctx context.Context, dir, repo string, quiet, insecure bool) error {
ref, err := name.ParseReference(repo) ref, err := name.ParseReference(repo)
if err != nil { if err != nil {
return xerrors.Errorf("repository parse error: %w", err) return xerrors.Errorf("repository parse error: %w", err)
@@ -27,7 +27,7 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
return xerrors.Errorf("module initialize error: %w", err) return xerrors.Errorf("module initialize error: %w", err)
} }
dst := filepath.Join(dir(), ref.Context().Name()) dst := filepath.Join(dir, ref.Context().Name())
log.Logger.Debugf("Installing the module to %s...", dst) log.Logger.Debugf("Installing the module to %s...", dst)
if err = artifact.Download(ctx, dst); err != nil { if err = artifact.Download(ctx, dst); err != nil {
@@ -38,14 +38,14 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
} }
// Uninstall uninstalls a module // Uninstall uninstalls a module
func Uninstall(_ context.Context, repo string) error { func Uninstall(_ context.Context, dir, repo string) error {
ref, err := name.ParseReference(repo) ref, err := name.ParseReference(repo)
if err != nil { if err != nil {
return xerrors.Errorf("repository parse error: %w", err) return xerrors.Errorf("repository parse error: %w", err)
} }
log.Logger.Infof("Uninstalling %s ...", repo) log.Logger.Infof("Uninstalling %s ...", repo)
dst := filepath.Join(dir(), ref.Context().Name()) dst := filepath.Join(dir, ref.Context().Name())
if err = os.RemoveAll(dst); err != nil { if err = os.RemoveAll(dst); err != nil {
return xerrors.Errorf("remove error: %w", err) return xerrors.Errorf("remove error: %w", err)
} }

View File

@@ -35,6 +35,8 @@ var (
} }
RelativeDir = filepath.Join(".trivy", "modules") RelativeDir = filepath.Join(".trivy", "modules")
DefaultDir = dir()
) )
// logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection. // logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection.
@@ -94,13 +96,23 @@ func readMemory(mem api.Memory, offset, size uint32) []byte {
return buf return buf
} }
type Manager struct { type Options struct {
cache wazero.CompilationCache Dir string
modules []*wasmModule EnabledModules []string
} }
func NewManager(ctx context.Context) (*Manager, error) { type Manager struct {
m := &Manager{} cache wazero.CompilationCache
modules []*wasmModule
dir string
enabledModules []string
}
func NewManager(ctx context.Context, opts Options) (*Manager, error) {
m := &Manager{
dir: opts.Dir,
enabledModules: opts.EnabledModules,
}
// Create a new WebAssembly Runtime. // Create a new WebAssembly Runtime.
m.cache = wazero.NewCompilationCache() m.cache = wazero.NewCompilationCache()
@@ -114,26 +126,25 @@ func NewManager(ctx context.Context) (*Manager, error) {
} }
func (m *Manager) loadModules(ctx context.Context) error { func (m *Manager) loadModules(ctx context.Context) error {
moduleDir := dir() _, err := os.Stat(m.dir)
_, err := os.Stat(moduleDir)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil return nil
} }
log.Logger.Debugf("Module dir: %s", moduleDir) log.Logger.Debugf("Module dir: %s", m.dir)
err = filepath.Walk(moduleDir, func(path string, info fs.FileInfo, err error) error { err = filepath.Walk(m.dir, func(path string, info fs.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err
} else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" { } else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" {
return nil return nil
} }
rel, err := filepath.Rel(moduleDir, path) rel, err := filepath.Rel(m.dir, path)
if err != nil { if err != nil {
return xerrors.Errorf("failed to get a relative path: %w", err) return xerrors.Errorf("failed to get a relative path: %w", err)
} }
log.Logger.Infof("Loading %s...", rel) log.Logger.Infof("Reading %s...", rel)
wasmCode, err := os.ReadFile(path) wasmCode, err := os.ReadFile(path)
if err != nil { if err != nil {
return xerrors.Errorf("file read error: %w", err) return xerrors.Errorf("file read error: %w", err)
@@ -144,6 +155,12 @@ func (m *Manager) loadModules(ctx context.Context) error {
return xerrors.Errorf("WASM module init error %s: %w", rel, err) return xerrors.Errorf("WASM module init error %s: %w", rel, err)
} }
// Skip Loading WASM modules if not in the list of enable modules flag.
if len(m.enabledModules) > 0 && !slices.Contains(m.enabledModules, p.Name()) {
return nil
}
log.Logger.Infof("%s loaded", rel)
m.modules = append(m.modules, p) m.modules = append(m.modules, p)
return nil return nil
@@ -161,6 +178,13 @@ func (m *Manager) Register() {
} }
} }
func (m *Manager) Deregister() {
for _, mod := range m.modules {
analyzer.DeregisterAnalyzer(analyzer.Type(mod.Name()))
post.DeregisterPostScanner(mod.Name())
}
}
func (m *Manager) Close(ctx context.Context) error { func (m *Manager) Close(ctx context.Context) error {
return m.cache.Close(ctx) return m.cache.Close(ctx)
} }

View File

@@ -2,7 +2,7 @@ package module_test
import ( import (
"context" "context"
"os" "io/fs"
"path/filepath" "path/filepath"
"runtime" "runtime"
"testing" "testing"
@@ -13,7 +13,6 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/module" "github.com/aquasecurity/trivy/pkg/module"
"github.com/aquasecurity/trivy/pkg/scanner/post" "github.com/aquasecurity/trivy/pkg/scanner/post"
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
) )
func TestManager_Register(t *testing.T) { func TestManager_Register(t *testing.T) {
@@ -23,15 +22,15 @@ func TestManager_Register(t *testing.T) {
} }
tests := []struct { tests := []struct {
name string name string
noModuleDir bool moduleDir string
moduleName string enabledModules []string
wantAnalyzerVersions analyzer.Versions wantAnalyzerVersions analyzer.Versions
wantPostScannerVersions map[string]int wantPostScannerVersions map[string]int
wantErr bool wantErr bool
}{ }{
{ {
name: "happy path", name: "happy path",
moduleName: "happy", moduleDir: "testdata/happy",
wantAnalyzerVersions: analyzer.Versions{ wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{ Analyzers: map[string]int{
"happy": 1, "happy": 1,
@@ -43,8 +42,8 @@ func TestManager_Register(t *testing.T) {
}, },
}, },
{ {
name: "only analyzer", name: "only analyzer",
moduleName: "analyzer", moduleDir: "testdata/analyzer",
wantAnalyzerVersions: analyzer.Versions{ wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{ Analyzers: map[string]int{
"analyzer": 1, "analyzer": 1,
@@ -54,8 +53,8 @@ func TestManager_Register(t *testing.T) {
wantPostScannerVersions: map[string]int{}, wantPostScannerVersions: map[string]int{},
}, },
{ {
name: "only post scanner", name: "only post scanner",
moduleName: "scanner", moduleDir: "testdata/scanner",
wantAnalyzerVersions: analyzer.Versions{ wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{}, Analyzers: map[string]int{},
PostAnalyzers: map[string]int{}, PostAnalyzers: map[string]int{},
@@ -65,48 +64,59 @@ func TestManager_Register(t *testing.T) {
}, },
}, },
{ {
name: "no module dir", name: "no module dir",
noModuleDir: true, moduleDir: "no-such-dir",
moduleName: "happy",
wantAnalyzerVersions: analyzer.Versions{ wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{}, Analyzers: map[string]int{},
PostAnalyzers: map[string]int{}, PostAnalyzers: map[string]int{},
}, },
wantPostScannerVersions: map[string]int{}, wantPostScannerVersions: map[string]int{},
}, },
{
name: "pass enabled modules",
moduleDir: "testdata",
enabledModules: []string{
"happy",
"analyzer",
},
wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{
"happy": 1,
"analyzer": 1,
},
PostAnalyzers: map[string]int{},
},
wantPostScannerVersions: map[string]int{
"happy": 1,
},
},
} }
// Confirm that wasm modules are generated beforehand
var count int
err := filepath.WalkDir("testdata", func(path string, d fs.DirEntry, err error) error {
if filepath.Ext(path) == ".wasm" {
count++
}
return nil
})
require.NoError(t, err)
// WASM modules must be generated before running the tests.
require.Equal(t, count, 3, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
modulePath := filepath.Join("testdata", tt.moduleName, tt.moduleName+".wasm") m, err := module.NewManager(context.Background(), module.Options{
Dir: tt.moduleDir,
// WASM modules must be generated before running this test. EnabledModules: tt.enabledModules,
if _, err := os.Stat(modulePath); os.IsNotExist(err) { })
require.Fail(t, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
}
// Set up a temp dir for modules
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
moduleDir := filepath.Join(tmpDir, module.RelativeDir)
if !tt.noModuleDir {
err := os.MkdirAll(moduleDir, 0777)
require.NoError(t, err)
// Copy the wasm module for testing
_, err = fsutils.CopyFile(modulePath, filepath.Join(moduleDir, filepath.Base(modulePath)))
require.NoError(t, err)
}
m, err := module.NewManager(context.Background())
require.NoError(t, err) require.NoError(t, err)
// Register analyzer and post scanner from WASM module // Register analyzer and post scanner from WASM module
m.Register() m.Register()
defer func() {
analyzer.DeregisterAnalyzer(analyzer.Type(tt.moduleName)) // Remove registered analyzers and post scanners so that it will not affect other tests.
post.DeregisterPostScanner(tt.moduleName) defer m.Deregister()
}()
// Confirm the analyzer is registered // Confirm the analyzer is registered
a, err := analyzer.NewAnalyzerGroup(analyzer.AnalyzerOptions{}) a, err := analyzer.NewAnalyzerGroup(analyzer.AnalyzerOptions{})