feat: Adding --module-dir and --enable-modules (#3677)

Co-authored-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Kalyana Krishna Varanasi
2023-03-01 15:39:53 +05:30
committed by GitHub
parent 34120f4201
commit 302c8ae24c
9 changed files with 222 additions and 88 deletions

View File

@@ -1,18 +1,14 @@
//go:build module_integration
package integration
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/module"
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
"github.com/aquasecurity/trivy/pkg/scanner/post"
)
func TestModule(t *testing.T) {
@@ -36,17 +32,6 @@ func TestModule(t *testing.T) {
// Set up testing DB
cacheDir := initDB(t)
// Set up module dir
moduleDir := filepath.Join(cacheDir, module.RelativeDir)
err := os.MkdirAll(moduleDir, 0700)
require.NoError(t, err)
// Set up Spring4Shell module
t.Setenv("XDG_DATA_HOME", cacheDir)
_, err = fsutils.CopyFile(filepath.Join("../", "examples", "module", "spring4shell", "spring4shell.wasm"),
filepath.Join(moduleDir, "spring4shell.wasm"))
require.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
osArgs := []string{
@@ -56,8 +41,11 @@ func TestModule(t *testing.T) {
"--ignore-unfixed",
"--format",
"json",
"--skip-update",
"--skip-db-update",
"--offline-scan",
"--quiet",
"--module-dir",
filepath.Join("../", "examples", "module", "spring4shell"),
"--input",
tt.input,
}
@@ -74,9 +62,12 @@ func TestModule(t *testing.T) {
}...)
// Run Trivy
err = execute(osArgs)
assert.NoError(t, err)
defer analyzer.DeregisterAnalyzer("spring4shell")
err := execute(osArgs)
require.NoError(t, err)
defer func() {
analyzer.DeregisterAnalyzer("spring4shell")
post.DeregisterPostScanner("spring4shell")
}()
// Compare want and got
compareReports(t, tt.golden, outputFile)

View File

@@ -161,7 +161,7 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co
// viper.BindPFlag cannot be called in init().
// cf. https://github.com/spf13/cobra/issues/875
// https://github.com/spf13/viper/issues/233
if err := globalFlags.Bind(cmd.Root()); err != nil {
if err := globalFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
@@ -222,6 +222,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
RegoFlagGroup: flag.NewRegoFlagGroup(),
ReportFlagGroup: reportFlagGroup,
@@ -297,6 +298,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
RegoFlagGroup: flag.NewRegoFlagGroup(),
ReportFlagGroup: reportFlagGroup,
@@ -351,6 +353,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
RegoFlagGroup: flag.NewRegoFlagGroup(),
ReportFlagGroup: reportFlagGroup,
@@ -407,6 +410,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RegoFlagGroup: flag.NewRegoFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
@@ -507,6 +511,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
serverFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewServerFlags(),
}
@@ -560,6 +565,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
configFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RegoFlagGroup: flag.NewRegoFlagGroup(),
K8sFlagGroup: &flag.K8sFlagGroup{
// disable unneeded flags
@@ -708,6 +714,10 @@ func NewPluginCommand() *cobra.Command {
}
func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
moduleFlags := &flag.Flags{
ModuleFlagGroup: flag.NewModuleFlagGroup(),
}
cmd := &cobra.Command{
Use: "module subcommand",
Aliases: []string{"m"},
@@ -723,14 +733,23 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
Aliases: []string{"i"},
Short: "Install a module",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
if err := moduleFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return cmd.Help()
}
repo := args[0]
opts := globalFlags.ToOptions()
return module.Install(cmd.Context(), repo, opts.Quiet, opts.Insecure)
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
return module.Install(cmd.Context(), opts.ModuleDir, repo, opts.Quiet, opts.Insecure)
},
},
&cobra.Command{
@@ -738,16 +757,27 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
Aliases: []string{"u"},
Short: "Uninstall a module",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
if err := moduleFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return cmd.Help()
}
repo := args[0]
return module.Uninstall(cmd.Context(), repo)
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
return module.Uninstall(cmd.Context(), opts.ModuleDir, repo)
},
},
)
moduleFlags.AddFlags(cmd)
cmd.SetFlagErrorFunc(flagErrorFunc)
return cmd
}
@@ -901,6 +931,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),

View File

@@ -131,7 +131,10 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...runnerOptio
}
// Initialize WASM modules
m, err := module.NewManager(ctx)
m, err := module.NewManager(ctx, module.Options{
Dir: cliOptions.ModuleDir,
EnabledModules: cliOptions.EnabledModules,
})
if err != nil {
return nil, xerrors.Errorf("WASM module error: %w", err)
}

View File

@@ -48,7 +48,10 @@ func Run(ctx context.Context, opts flag.Options) (err error) {
}
// Initialize WASM modules
m, err := module.NewManager(ctx)
m, err := module.NewManager(ctx, module.Options{
Dir: opts.ModuleDir,
EnabledModules: opts.EnabledModules,
})
if err != nil {
return xerrors.Errorf("WASM module error: %w", err)
}

64
pkg/flag/module_flags.go Normal file
View File

@@ -0,0 +1,64 @@
package flag
import (
"github.com/aquasecurity/trivy/pkg/module"
)
// e.g. config yaml
// module:
// dir: "/path/to/my_modules"
// enable-modules:
// - spring4shell
var (
ModuleDirFlag = Flag{
Name: "module-dir",
ConfigName: "module.dir",
Value: module.DefaultDir,
Usage: "specify directory to the wasm modules that will be loaded",
Persistent: true,
}
EnableModulesFlag = Flag{
Name: "enable-modules",
ConfigName: "module.enable-modules",
Value: []string{},
Usage: "[EXPERIMENTAL] module names to enable",
Persistent: true,
}
)
// ModuleFlagGroup defines flags for modules
type ModuleFlagGroup struct {
Dir *Flag
EnabledModules *Flag
}
type ModuleOptions struct {
ModuleDir string
EnabledModules []string
}
func NewModuleFlagGroup() *ModuleFlagGroup {
return &ModuleFlagGroup{
Dir: &ModuleDirFlag,
EnabledModules: &EnableModulesFlag,
}
}
func (f *ModuleFlagGroup) Name() string {
return "Module"
}
func (f *ModuleFlagGroup) Flags() []*Flag {
return []*Flag{
f.Dir,
f.EnabledModules,
}
}
func (f *ModuleFlagGroup) ToOptions() ModuleOptions {
return ModuleOptions{
ModuleDir: getString(f.Dir),
EnabledModules: getStringSlice(f.EnabledModules),
}
}

View File

@@ -66,6 +66,7 @@ type Flags struct {
K8sFlagGroup *K8sFlagGroup
LicenseFlagGroup *LicenseFlagGroup
MisconfFlagGroup *MisconfFlagGroup
ModuleFlagGroup *ModuleFlagGroup
RemoteFlagGroup *RemoteFlagGroup
RegoFlagGroup *RegoFlagGroup
RepoFlagGroup *RepoFlagGroup
@@ -87,6 +88,7 @@ type Options struct {
K8sOptions
LicenseOptions
MisconfOptions
ModuleOptions
RegoOptions
RemoteOptions
RepoOptions
@@ -156,14 +158,13 @@ func bind(cmd *cobra.Command, flag *Flag) error {
}
// Bind CLI flags
if flag.Persistent {
if err := viper.BindPFlag(flag.ConfigName, cmd.PersistentFlags().Lookup(flag.Name)); err != nil {
return xerrors.Errorf("bind flag error: %w", err)
}
} else {
if err := viper.BindPFlag(flag.ConfigName, cmd.Flags().Lookup(flag.Name)); err != nil {
return xerrors.Errorf("bind flag error: %w", err)
}
f := cmd.Flags().Lookup(flag.Name)
if f == nil {
// Lookup local persistent flags
f = cmd.PersistentFlags().Lookup(flag.Name)
}
if err := viper.BindPFlag(flag.ConfigName, f); err != nil {
return xerrors.Errorf("bind flag error: %w", err)
}
// Bind environmental variable
@@ -275,6 +276,9 @@ func (f *Flags) groups() []FlagGroup {
if f.MisconfFlagGroup != nil {
groups = append(groups, f.MisconfFlagGroup)
}
if f.ModuleFlagGroup != nil {
groups = append(groups, f.ModuleFlagGroup)
}
if f.SecretFlagGroup != nil {
groups = append(groups, f.SecretFlagGroup)
}
@@ -404,6 +408,10 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
}
}
if f.ModuleFlagGroup != nil {
opts.ModuleOptions = f.ModuleFlagGroup.ToOptions()
}
if f.RegoFlagGroup != nil {
opts.RegoOptions, err = f.RegoFlagGroup.ToOptions()
if err != nil {

View File

@@ -15,7 +15,7 @@ import (
const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm"
// Install installs a module
func Install(ctx context.Context, repo string, quiet, insecure bool) error {
func Install(ctx context.Context, dir, repo string, quiet, insecure bool) error {
ref, err := name.ParseReference(repo)
if err != nil {
return xerrors.Errorf("repository parse error: %w", err)
@@ -27,7 +27,7 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
return xerrors.Errorf("module initialize error: %w", err)
}
dst := filepath.Join(dir(), ref.Context().Name())
dst := filepath.Join(dir, ref.Context().Name())
log.Logger.Debugf("Installing the module to %s...", dst)
if err = artifact.Download(ctx, dst); err != nil {
@@ -38,14 +38,14 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
}
// Uninstall uninstalls a module
func Uninstall(_ context.Context, repo string) error {
func Uninstall(_ context.Context, dir, repo string) error {
ref, err := name.ParseReference(repo)
if err != nil {
return xerrors.Errorf("repository parse error: %w", err)
}
log.Logger.Infof("Uninstalling %s ...", repo)
dst := filepath.Join(dir(), ref.Context().Name())
dst := filepath.Join(dir, ref.Context().Name())
if err = os.RemoveAll(dst); err != nil {
return xerrors.Errorf("remove error: %w", err)
}

View File

@@ -35,6 +35,8 @@ var (
}
RelativeDir = filepath.Join(".trivy", "modules")
DefaultDir = dir()
)
// logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection.
@@ -94,13 +96,23 @@ func readMemory(mem api.Memory, offset, size uint32) []byte {
return buf
}
type Manager struct {
cache wazero.CompilationCache
modules []*wasmModule
type Options struct {
Dir string
EnabledModules []string
}
func NewManager(ctx context.Context) (*Manager, error) {
m := &Manager{}
type Manager struct {
cache wazero.CompilationCache
modules []*wasmModule
dir string
enabledModules []string
}
func NewManager(ctx context.Context, opts Options) (*Manager, error) {
m := &Manager{
dir: opts.Dir,
enabledModules: opts.EnabledModules,
}
// Create a new WebAssembly Runtime.
m.cache = wazero.NewCompilationCache()
@@ -114,26 +126,25 @@ func NewManager(ctx context.Context) (*Manager, error) {
}
func (m *Manager) loadModules(ctx context.Context) error {
moduleDir := dir()
_, err := os.Stat(moduleDir)
_, err := os.Stat(m.dir)
if os.IsNotExist(err) {
return nil
}
log.Logger.Debugf("Module dir: %s", moduleDir)
log.Logger.Debugf("Module dir: %s", m.dir)
err = filepath.Walk(moduleDir, func(path string, info fs.FileInfo, err error) error {
err = filepath.Walk(m.dir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return err
} else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" {
return nil
}
rel, err := filepath.Rel(moduleDir, path)
rel, err := filepath.Rel(m.dir, path)
if err != nil {
return xerrors.Errorf("failed to get a relative path: %w", err)
}
log.Logger.Infof("Loading %s...", rel)
log.Logger.Infof("Reading %s...", rel)
wasmCode, err := os.ReadFile(path)
if err != nil {
return xerrors.Errorf("file read error: %w", err)
@@ -144,6 +155,12 @@ func (m *Manager) loadModules(ctx context.Context) error {
return xerrors.Errorf("WASM module init error %s: %w", rel, err)
}
// Skip Loading WASM modules if not in the list of enable modules flag.
if len(m.enabledModules) > 0 && !slices.Contains(m.enabledModules, p.Name()) {
return nil
}
log.Logger.Infof("%s loaded", rel)
m.modules = append(m.modules, p)
return nil
@@ -161,6 +178,13 @@ func (m *Manager) Register() {
}
}
func (m *Manager) Deregister() {
for _, mod := range m.modules {
analyzer.DeregisterAnalyzer(analyzer.Type(mod.Name()))
post.DeregisterPostScanner(mod.Name())
}
}
func (m *Manager) Close(ctx context.Context) error {
return m.cache.Close(ctx)
}

View File

@@ -2,7 +2,7 @@ package module_test
import (
"context"
"os"
"io/fs"
"path/filepath"
"runtime"
"testing"
@@ -13,7 +13,6 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/module"
"github.com/aquasecurity/trivy/pkg/scanner/post"
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
)
func TestManager_Register(t *testing.T) {
@@ -23,15 +22,15 @@ func TestManager_Register(t *testing.T) {
}
tests := []struct {
name string
noModuleDir bool
moduleName string
moduleDir string
enabledModules []string
wantAnalyzerVersions analyzer.Versions
wantPostScannerVersions map[string]int
wantErr bool
}{
{
name: "happy path",
moduleName: "happy",
name: "happy path",
moduleDir: "testdata/happy",
wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{
"happy": 1,
@@ -43,8 +42,8 @@ func TestManager_Register(t *testing.T) {
},
},
{
name: "only analyzer",
moduleName: "analyzer",
name: "only analyzer",
moduleDir: "testdata/analyzer",
wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{
"analyzer": 1,
@@ -54,8 +53,8 @@ func TestManager_Register(t *testing.T) {
wantPostScannerVersions: map[string]int{},
},
{
name: "only post scanner",
moduleName: "scanner",
name: "only post scanner",
moduleDir: "testdata/scanner",
wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{},
PostAnalyzers: map[string]int{},
@@ -65,48 +64,59 @@ func TestManager_Register(t *testing.T) {
},
},
{
name: "no module dir",
noModuleDir: true,
moduleName: "happy",
name: "no module dir",
moduleDir: "no-such-dir",
wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{},
PostAnalyzers: map[string]int{},
},
wantPostScannerVersions: map[string]int{},
},
{
name: "pass enabled modules",
moduleDir: "testdata",
enabledModules: []string{
"happy",
"analyzer",
},
wantAnalyzerVersions: analyzer.Versions{
Analyzers: map[string]int{
"happy": 1,
"analyzer": 1,
},
PostAnalyzers: map[string]int{},
},
wantPostScannerVersions: map[string]int{
"happy": 1,
},
},
}
// Confirm that wasm modules are generated beforehand
var count int
err := filepath.WalkDir("testdata", func(path string, d fs.DirEntry, err error) error {
if filepath.Ext(path) == ".wasm" {
count++
}
return nil
})
require.NoError(t, err)
// WASM modules must be generated before running the tests.
require.Equal(t, count, 3, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
modulePath := filepath.Join("testdata", tt.moduleName, tt.moduleName+".wasm")
// WASM modules must be generated before running this test.
if _, err := os.Stat(modulePath); os.IsNotExist(err) {
require.Fail(t, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
}
// Set up a temp dir for modules
tmpDir := t.TempDir()
t.Setenv("XDG_DATA_HOME", tmpDir)
moduleDir := filepath.Join(tmpDir, module.RelativeDir)
if !tt.noModuleDir {
err := os.MkdirAll(moduleDir, 0777)
require.NoError(t, err)
// Copy the wasm module for testing
_, err = fsutils.CopyFile(modulePath, filepath.Join(moduleDir, filepath.Base(modulePath)))
require.NoError(t, err)
}
m, err := module.NewManager(context.Background())
m, err := module.NewManager(context.Background(), module.Options{
Dir: tt.moduleDir,
EnabledModules: tt.enabledModules,
})
require.NoError(t, err)
// Register analyzer and post scanner from WASM module
m.Register()
defer func() {
analyzer.DeregisterAnalyzer(analyzer.Type(tt.moduleName))
post.DeregisterPostScanner(tt.moduleName)
}()
// Remove registered analyzers and post scanners so that it will not affect other tests.
defer m.Deregister()
// Confirm the analyzer is registered
a, err := analyzer.NewAnalyzerGroup(analyzer.AnalyzerOptions{})