mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-21 14:50:53 -08:00
feat: Adding --module-dir and --enable-modules (#3677)
Co-authored-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
committed by
GitHub
parent
34120f4201
commit
302c8ae24c
@@ -1,18 +1,14 @@
|
||||
//go:build module_integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
||||
"github.com/aquasecurity/trivy/pkg/module"
|
||||
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
|
||||
"github.com/aquasecurity/trivy/pkg/scanner/post"
|
||||
)
|
||||
|
||||
func TestModule(t *testing.T) {
|
||||
@@ -36,17 +32,6 @@ func TestModule(t *testing.T) {
|
||||
// Set up testing DB
|
||||
cacheDir := initDB(t)
|
||||
|
||||
// Set up module dir
|
||||
moduleDir := filepath.Join(cacheDir, module.RelativeDir)
|
||||
err := os.MkdirAll(moduleDir, 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up Spring4Shell module
|
||||
t.Setenv("XDG_DATA_HOME", cacheDir)
|
||||
_, err = fsutils.CopyFile(filepath.Join("../", "examples", "module", "spring4shell", "spring4shell.wasm"),
|
||||
filepath.Join(moduleDir, "spring4shell.wasm"))
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
osArgs := []string{
|
||||
@@ -56,8 +41,11 @@ func TestModule(t *testing.T) {
|
||||
"--ignore-unfixed",
|
||||
"--format",
|
||||
"json",
|
||||
"--skip-update",
|
||||
"--skip-db-update",
|
||||
"--offline-scan",
|
||||
"--quiet",
|
||||
"--module-dir",
|
||||
filepath.Join("../", "examples", "module", "spring4shell"),
|
||||
"--input",
|
||||
tt.input,
|
||||
}
|
||||
@@ -74,9 +62,12 @@ func TestModule(t *testing.T) {
|
||||
}...)
|
||||
|
||||
// Run Trivy
|
||||
err = execute(osArgs)
|
||||
assert.NoError(t, err)
|
||||
defer analyzer.DeregisterAnalyzer("spring4shell")
|
||||
err := execute(osArgs)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
analyzer.DeregisterAnalyzer("spring4shell")
|
||||
post.DeregisterPostScanner("spring4shell")
|
||||
}()
|
||||
|
||||
// Compare want and got
|
||||
compareReports(t, tt.golden, outputFile)
|
||||
|
||||
@@ -161,7 +161,7 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co
|
||||
// viper.BindPFlag cannot be called in init().
|
||||
// cf. https://github.com/spf13/cobra/issues/875
|
||||
// https://github.com/spf13/viper/issues/233
|
||||
if err := globalFlags.Bind(cmd.Root()); err != nil {
|
||||
if err := globalFlags.Bind(cmd); err != nil {
|
||||
return xerrors.Errorf("flag bind error: %w", err)
|
||||
}
|
||||
|
||||
@@ -222,6 +222,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
|
||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||
ReportFlagGroup: reportFlagGroup,
|
||||
@@ -297,6 +298,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||
ReportFlagGroup: reportFlagGroup,
|
||||
@@ -351,6 +353,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||
ReportFlagGroup: reportFlagGroup,
|
||||
@@ -407,6 +410,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||
ReportFlagGroup: reportFlagGroup,
|
||||
@@ -507,6 +511,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
serverFlags := &flag.Flags{
|
||||
CacheFlagGroup: flag.NewCacheFlagGroup(),
|
||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
RemoteFlagGroup: flag.NewServerFlags(),
|
||||
}
|
||||
|
||||
@@ -560,6 +565,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
configFlags := &flag.Flags{
|
||||
CacheFlagGroup: flag.NewCacheFlagGroup(),
|
||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||
K8sFlagGroup: &flag.K8sFlagGroup{
|
||||
// disable unneeded flags
|
||||
@@ -708,6 +714,10 @@ func NewPluginCommand() *cobra.Command {
|
||||
}
|
||||
|
||||
func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
moduleFlags := &flag.Flags{
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "module subcommand",
|
||||
Aliases: []string{"m"},
|
||||
@@ -723,14 +733,23 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
Aliases: []string{"i"},
|
||||
Short: "Install a module",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := moduleFlags.Bind(cmd); err != nil {
|
||||
return xerrors.Errorf("flag bind error: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return cmd.Help()
|
||||
}
|
||||
|
||||
repo := args[0]
|
||||
opts := globalFlags.ToOptions()
|
||||
return module.Install(cmd.Context(), repo, opts.Quiet, opts.Insecure)
|
||||
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("flag error: %w", err)
|
||||
}
|
||||
return module.Install(cmd.Context(), opts.ModuleDir, repo, opts.Quiet, opts.Insecure)
|
||||
},
|
||||
},
|
||||
&cobra.Command{
|
||||
@@ -738,16 +757,27 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
Aliases: []string{"u"},
|
||||
Short: "Uninstall a module",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := moduleFlags.Bind(cmd); err != nil {
|
||||
return xerrors.Errorf("flag bind error: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return cmd.Help()
|
||||
}
|
||||
|
||||
repo := args[0]
|
||||
return module.Uninstall(cmd.Context(), repo)
|
||||
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("flag error: %w", err)
|
||||
}
|
||||
return module.Uninstall(cmd.Context(), opts.ModuleDir, repo)
|
||||
},
|
||||
},
|
||||
)
|
||||
moduleFlags.AddFlags(cmd)
|
||||
cmd.SetFlagErrorFunc(flagErrorFunc)
|
||||
return cmd
|
||||
}
|
||||
@@ -901,6 +931,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||
ReportFlagGroup: reportFlagGroup,
|
||||
ScanFlagGroup: flag.NewScanFlagGroup(),
|
||||
|
||||
@@ -131,7 +131,10 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...runnerOptio
|
||||
}
|
||||
|
||||
// Initialize WASM modules
|
||||
m, err := module.NewManager(ctx)
|
||||
m, err := module.NewManager(ctx, module.Options{
|
||||
Dir: cliOptions.ModuleDir,
|
||||
EnabledModules: cliOptions.EnabledModules,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("WASM module error: %w", err)
|
||||
}
|
||||
|
||||
@@ -48,7 +48,10 @@ func Run(ctx context.Context, opts flag.Options) (err error) {
|
||||
}
|
||||
|
||||
// Initialize WASM modules
|
||||
m, err := module.NewManager(ctx)
|
||||
m, err := module.NewManager(ctx, module.Options{
|
||||
Dir: opts.ModuleDir,
|
||||
EnabledModules: opts.EnabledModules,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("WASM module error: %w", err)
|
||||
}
|
||||
|
||||
64
pkg/flag/module_flags.go
Normal file
64
pkg/flag/module_flags.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package flag
|
||||
|
||||
import (
|
||||
"github.com/aquasecurity/trivy/pkg/module"
|
||||
)
|
||||
|
||||
// e.g. config yaml
|
||||
// module:
|
||||
// dir: "/path/to/my_modules"
|
||||
// enable-modules:
|
||||
// - spring4shell
|
||||
|
||||
var (
|
||||
ModuleDirFlag = Flag{
|
||||
Name: "module-dir",
|
||||
ConfigName: "module.dir",
|
||||
Value: module.DefaultDir,
|
||||
Usage: "specify directory to the wasm modules that will be loaded",
|
||||
Persistent: true,
|
||||
}
|
||||
EnableModulesFlag = Flag{
|
||||
Name: "enable-modules",
|
||||
ConfigName: "module.enable-modules",
|
||||
Value: []string{},
|
||||
Usage: "[EXPERIMENTAL] module names to enable",
|
||||
Persistent: true,
|
||||
}
|
||||
)
|
||||
|
||||
// ModuleFlagGroup defines flags for modules
|
||||
type ModuleFlagGroup struct {
|
||||
Dir *Flag
|
||||
EnabledModules *Flag
|
||||
}
|
||||
|
||||
type ModuleOptions struct {
|
||||
ModuleDir string
|
||||
EnabledModules []string
|
||||
}
|
||||
|
||||
func NewModuleFlagGroup() *ModuleFlagGroup {
|
||||
return &ModuleFlagGroup{
|
||||
Dir: &ModuleDirFlag,
|
||||
EnabledModules: &EnableModulesFlag,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *ModuleFlagGroup) Name() string {
|
||||
return "Module"
|
||||
}
|
||||
|
||||
func (f *ModuleFlagGroup) Flags() []*Flag {
|
||||
return []*Flag{
|
||||
f.Dir,
|
||||
f.EnabledModules,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *ModuleFlagGroup) ToOptions() ModuleOptions {
|
||||
return ModuleOptions{
|
||||
ModuleDir: getString(f.Dir),
|
||||
EnabledModules: getStringSlice(f.EnabledModules),
|
||||
}
|
||||
}
|
||||
@@ -66,6 +66,7 @@ type Flags struct {
|
||||
K8sFlagGroup *K8sFlagGroup
|
||||
LicenseFlagGroup *LicenseFlagGroup
|
||||
MisconfFlagGroup *MisconfFlagGroup
|
||||
ModuleFlagGroup *ModuleFlagGroup
|
||||
RemoteFlagGroup *RemoteFlagGroup
|
||||
RegoFlagGroup *RegoFlagGroup
|
||||
RepoFlagGroup *RepoFlagGroup
|
||||
@@ -87,6 +88,7 @@ type Options struct {
|
||||
K8sOptions
|
||||
LicenseOptions
|
||||
MisconfOptions
|
||||
ModuleOptions
|
||||
RegoOptions
|
||||
RemoteOptions
|
||||
RepoOptions
|
||||
@@ -156,14 +158,13 @@ func bind(cmd *cobra.Command, flag *Flag) error {
|
||||
}
|
||||
|
||||
// Bind CLI flags
|
||||
if flag.Persistent {
|
||||
if err := viper.BindPFlag(flag.ConfigName, cmd.PersistentFlags().Lookup(flag.Name)); err != nil {
|
||||
return xerrors.Errorf("bind flag error: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := viper.BindPFlag(flag.ConfigName, cmd.Flags().Lookup(flag.Name)); err != nil {
|
||||
return xerrors.Errorf("bind flag error: %w", err)
|
||||
}
|
||||
f := cmd.Flags().Lookup(flag.Name)
|
||||
if f == nil {
|
||||
// Lookup local persistent flags
|
||||
f = cmd.PersistentFlags().Lookup(flag.Name)
|
||||
}
|
||||
if err := viper.BindPFlag(flag.ConfigName, f); err != nil {
|
||||
return xerrors.Errorf("bind flag error: %w", err)
|
||||
}
|
||||
|
||||
// Bind environmental variable
|
||||
@@ -275,6 +276,9 @@ func (f *Flags) groups() []FlagGroup {
|
||||
if f.MisconfFlagGroup != nil {
|
||||
groups = append(groups, f.MisconfFlagGroup)
|
||||
}
|
||||
if f.ModuleFlagGroup != nil {
|
||||
groups = append(groups, f.ModuleFlagGroup)
|
||||
}
|
||||
if f.SecretFlagGroup != nil {
|
||||
groups = append(groups, f.SecretFlagGroup)
|
||||
}
|
||||
@@ -404,6 +408,10 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
|
||||
}
|
||||
}
|
||||
|
||||
if f.ModuleFlagGroup != nil {
|
||||
opts.ModuleOptions = f.ModuleFlagGroup.ToOptions()
|
||||
}
|
||||
|
||||
if f.RegoFlagGroup != nil {
|
||||
opts.RegoOptions, err = f.RegoFlagGroup.ToOptions()
|
||||
if err != nil {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm"
|
||||
|
||||
// Install installs a module
|
||||
func Install(ctx context.Context, repo string, quiet, insecure bool) error {
|
||||
func Install(ctx context.Context, dir, repo string, quiet, insecure bool) error {
|
||||
ref, err := name.ParseReference(repo)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("repository parse error: %w", err)
|
||||
@@ -27,7 +27,7 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
|
||||
return xerrors.Errorf("module initialize error: %w", err)
|
||||
}
|
||||
|
||||
dst := filepath.Join(dir(), ref.Context().Name())
|
||||
dst := filepath.Join(dir, ref.Context().Name())
|
||||
log.Logger.Debugf("Installing the module to %s...", dst)
|
||||
|
||||
if err = artifact.Download(ctx, dst); err != nil {
|
||||
@@ -38,14 +38,14 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
|
||||
}
|
||||
|
||||
// Uninstall uninstalls a module
|
||||
func Uninstall(_ context.Context, repo string) error {
|
||||
func Uninstall(_ context.Context, dir, repo string) error {
|
||||
ref, err := name.ParseReference(repo)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("repository parse error: %w", err)
|
||||
}
|
||||
|
||||
log.Logger.Infof("Uninstalling %s ...", repo)
|
||||
dst := filepath.Join(dir(), ref.Context().Name())
|
||||
dst := filepath.Join(dir, ref.Context().Name())
|
||||
if err = os.RemoveAll(dst); err != nil {
|
||||
return xerrors.Errorf("remove error: %w", err)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,8 @@ var (
|
||||
}
|
||||
|
||||
RelativeDir = filepath.Join(".trivy", "modules")
|
||||
|
||||
DefaultDir = dir()
|
||||
)
|
||||
|
||||
// logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection.
|
||||
@@ -94,13 +96,23 @@ func readMemory(mem api.Memory, offset, size uint32) []byte {
|
||||
return buf
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
cache wazero.CompilationCache
|
||||
modules []*wasmModule
|
||||
type Options struct {
|
||||
Dir string
|
||||
EnabledModules []string
|
||||
}
|
||||
|
||||
func NewManager(ctx context.Context) (*Manager, error) {
|
||||
m := &Manager{}
|
||||
type Manager struct {
|
||||
cache wazero.CompilationCache
|
||||
modules []*wasmModule
|
||||
dir string
|
||||
enabledModules []string
|
||||
}
|
||||
|
||||
func NewManager(ctx context.Context, opts Options) (*Manager, error) {
|
||||
m := &Manager{
|
||||
dir: opts.Dir,
|
||||
enabledModules: opts.EnabledModules,
|
||||
}
|
||||
|
||||
// Create a new WebAssembly Runtime.
|
||||
m.cache = wazero.NewCompilationCache()
|
||||
@@ -114,26 +126,25 @@ func NewManager(ctx context.Context) (*Manager, error) {
|
||||
}
|
||||
|
||||
func (m *Manager) loadModules(ctx context.Context) error {
|
||||
moduleDir := dir()
|
||||
_, err := os.Stat(moduleDir)
|
||||
_, err := os.Stat(m.dir)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
log.Logger.Debugf("Module dir: %s", moduleDir)
|
||||
log.Logger.Debugf("Module dir: %s", m.dir)
|
||||
|
||||
err = filepath.Walk(moduleDir, func(path string, info fs.FileInfo, err error) error {
|
||||
err = filepath.Walk(m.dir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
} else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" {
|
||||
return nil
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(moduleDir, path)
|
||||
rel, err := filepath.Rel(m.dir, path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get a relative path: %w", err)
|
||||
}
|
||||
|
||||
log.Logger.Infof("Loading %s...", rel)
|
||||
log.Logger.Infof("Reading %s...", rel)
|
||||
wasmCode, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("file read error: %w", err)
|
||||
@@ -144,6 +155,12 @@ func (m *Manager) loadModules(ctx context.Context) error {
|
||||
return xerrors.Errorf("WASM module init error %s: %w", rel, err)
|
||||
}
|
||||
|
||||
// Skip Loading WASM modules if not in the list of enable modules flag.
|
||||
if len(m.enabledModules) > 0 && !slices.Contains(m.enabledModules, p.Name()) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Logger.Infof("%s loaded", rel)
|
||||
m.modules = append(m.modules, p)
|
||||
|
||||
return nil
|
||||
@@ -161,6 +178,13 @@ func (m *Manager) Register() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Deregister() {
|
||||
for _, mod := range m.modules {
|
||||
analyzer.DeregisterAnalyzer(analyzer.Type(mod.Name()))
|
||||
post.DeregisterPostScanner(mod.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Close(ctx context.Context) error {
|
||||
return m.cache.Close(ctx)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package module_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
||||
"github.com/aquasecurity/trivy/pkg/module"
|
||||
"github.com/aquasecurity/trivy/pkg/scanner/post"
|
||||
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
|
||||
)
|
||||
|
||||
func TestManager_Register(t *testing.T) {
|
||||
@@ -23,15 +22,15 @@ func TestManager_Register(t *testing.T) {
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
noModuleDir bool
|
||||
moduleName string
|
||||
moduleDir string
|
||||
enabledModules []string
|
||||
wantAnalyzerVersions analyzer.Versions
|
||||
wantPostScannerVersions map[string]int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
moduleName: "happy",
|
||||
name: "happy path",
|
||||
moduleDir: "testdata/happy",
|
||||
wantAnalyzerVersions: analyzer.Versions{
|
||||
Analyzers: map[string]int{
|
||||
"happy": 1,
|
||||
@@ -43,8 +42,8 @@ func TestManager_Register(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "only analyzer",
|
||||
moduleName: "analyzer",
|
||||
name: "only analyzer",
|
||||
moduleDir: "testdata/analyzer",
|
||||
wantAnalyzerVersions: analyzer.Versions{
|
||||
Analyzers: map[string]int{
|
||||
"analyzer": 1,
|
||||
@@ -54,8 +53,8 @@ func TestManager_Register(t *testing.T) {
|
||||
wantPostScannerVersions: map[string]int{},
|
||||
},
|
||||
{
|
||||
name: "only post scanner",
|
||||
moduleName: "scanner",
|
||||
name: "only post scanner",
|
||||
moduleDir: "testdata/scanner",
|
||||
wantAnalyzerVersions: analyzer.Versions{
|
||||
Analyzers: map[string]int{},
|
||||
PostAnalyzers: map[string]int{},
|
||||
@@ -65,48 +64,59 @@ func TestManager_Register(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no module dir",
|
||||
noModuleDir: true,
|
||||
moduleName: "happy",
|
||||
name: "no module dir",
|
||||
moduleDir: "no-such-dir",
|
||||
wantAnalyzerVersions: analyzer.Versions{
|
||||
Analyzers: map[string]int{},
|
||||
PostAnalyzers: map[string]int{},
|
||||
},
|
||||
wantPostScannerVersions: map[string]int{},
|
||||
},
|
||||
{
|
||||
name: "pass enabled modules",
|
||||
moduleDir: "testdata",
|
||||
enabledModules: []string{
|
||||
"happy",
|
||||
"analyzer",
|
||||
},
|
||||
wantAnalyzerVersions: analyzer.Versions{
|
||||
Analyzers: map[string]int{
|
||||
"happy": 1,
|
||||
"analyzer": 1,
|
||||
},
|
||||
PostAnalyzers: map[string]int{},
|
||||
},
|
||||
wantPostScannerVersions: map[string]int{
|
||||
"happy": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Confirm that wasm modules are generated beforehand
|
||||
var count int
|
||||
err := filepath.WalkDir("testdata", func(path string, d fs.DirEntry, err error) error {
|
||||
if filepath.Ext(path) == ".wasm" {
|
||||
count++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// WASM modules must be generated before running the tests.
|
||||
require.Equal(t, count, 3, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
modulePath := filepath.Join("testdata", tt.moduleName, tt.moduleName+".wasm")
|
||||
|
||||
// WASM modules must be generated before running this test.
|
||||
if _, err := os.Stat(modulePath); os.IsNotExist(err) {
|
||||
require.Fail(t, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
|
||||
}
|
||||
|
||||
// Set up a temp dir for modules
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("XDG_DATA_HOME", tmpDir)
|
||||
moduleDir := filepath.Join(tmpDir, module.RelativeDir)
|
||||
|
||||
if !tt.noModuleDir {
|
||||
err := os.MkdirAll(moduleDir, 0777)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Copy the wasm module for testing
|
||||
_, err = fsutils.CopyFile(modulePath, filepath.Join(moduleDir, filepath.Base(modulePath)))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
m, err := module.NewManager(context.Background())
|
||||
m, err := module.NewManager(context.Background(), module.Options{
|
||||
Dir: tt.moduleDir,
|
||||
EnabledModules: tt.enabledModules,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register analyzer and post scanner from WASM module
|
||||
m.Register()
|
||||
defer func() {
|
||||
analyzer.DeregisterAnalyzer(analyzer.Type(tt.moduleName))
|
||||
post.DeregisterPostScanner(tt.moduleName)
|
||||
}()
|
||||
|
||||
// Remove registered analyzers and post scanners so that it will not affect other tests.
|
||||
defer m.Deregister()
|
||||
|
||||
// Confirm the analyzer is registered
|
||||
a, err := analyzer.NewAnalyzerGroup(analyzer.AnalyzerOptions{})
|
||||
|
||||
Reference in New Issue
Block a user