mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-21 23:00:42 -08:00
feat: Adding --module-dir and --enable-modules (#3677)
Co-authored-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
committed by
GitHub
parent
34120f4201
commit
302c8ae24c
@@ -1,18 +1,14 @@
|
|||||||
//go:build module_integration
|
//go:build module_integration
|
||||||
|
|
||||||
package integration
|
package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
||||||
"github.com/aquasecurity/trivy/pkg/module"
|
"github.com/aquasecurity/trivy/pkg/scanner/post"
|
||||||
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestModule(t *testing.T) {
|
func TestModule(t *testing.T) {
|
||||||
@@ -36,17 +32,6 @@ func TestModule(t *testing.T) {
|
|||||||
// Set up testing DB
|
// Set up testing DB
|
||||||
cacheDir := initDB(t)
|
cacheDir := initDB(t)
|
||||||
|
|
||||||
// Set up module dir
|
|
||||||
moduleDir := filepath.Join(cacheDir, module.RelativeDir)
|
|
||||||
err := os.MkdirAll(moduleDir, 0700)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Set up Spring4Shell module
|
|
||||||
t.Setenv("XDG_DATA_HOME", cacheDir)
|
|
||||||
_, err = fsutils.CopyFile(filepath.Join("../", "examples", "module", "spring4shell", "spring4shell.wasm"),
|
|
||||||
filepath.Join(moduleDir, "spring4shell.wasm"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
osArgs := []string{
|
osArgs := []string{
|
||||||
@@ -56,8 +41,11 @@ func TestModule(t *testing.T) {
|
|||||||
"--ignore-unfixed",
|
"--ignore-unfixed",
|
||||||
"--format",
|
"--format",
|
||||||
"json",
|
"json",
|
||||||
"--skip-update",
|
"--skip-db-update",
|
||||||
"--offline-scan",
|
"--offline-scan",
|
||||||
|
"--quiet",
|
||||||
|
"--module-dir",
|
||||||
|
filepath.Join("../", "examples", "module", "spring4shell"),
|
||||||
"--input",
|
"--input",
|
||||||
tt.input,
|
tt.input,
|
||||||
}
|
}
|
||||||
@@ -74,9 +62,12 @@ func TestModule(t *testing.T) {
|
|||||||
}...)
|
}...)
|
||||||
|
|
||||||
// Run Trivy
|
// Run Trivy
|
||||||
err = execute(osArgs)
|
err := execute(osArgs)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer analyzer.DeregisterAnalyzer("spring4shell")
|
defer func() {
|
||||||
|
analyzer.DeregisterAnalyzer("spring4shell")
|
||||||
|
post.DeregisterPostScanner("spring4shell")
|
||||||
|
}()
|
||||||
|
|
||||||
// Compare want and got
|
// Compare want and got
|
||||||
compareReports(t, tt.golden, outputFile)
|
compareReports(t, tt.golden, outputFile)
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co
|
|||||||
// viper.BindPFlag cannot be called in init().
|
// viper.BindPFlag cannot be called in init().
|
||||||
// cf. https://github.com/spf13/cobra/issues/875
|
// cf. https://github.com/spf13/cobra/issues/875
|
||||||
// https://github.com/spf13/viper/issues/233
|
// https://github.com/spf13/viper/issues/233
|
||||||
if err := globalFlags.Bind(cmd.Root()); err != nil {
|
if err := globalFlags.Bind(cmd); err != nil {
|
||||||
return xerrors.Errorf("flag bind error: %w", err)
|
return xerrors.Errorf("flag bind error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +222,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
|
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
|
||||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||||
ReportFlagGroup: reportFlagGroup,
|
ReportFlagGroup: reportFlagGroup,
|
||||||
@@ -297,6 +298,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||||
ReportFlagGroup: reportFlagGroup,
|
ReportFlagGroup: reportFlagGroup,
|
||||||
@@ -351,6 +353,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||||
ReportFlagGroup: reportFlagGroup,
|
ReportFlagGroup: reportFlagGroup,
|
||||||
@@ -407,6 +410,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||||
ReportFlagGroup: reportFlagGroup,
|
ReportFlagGroup: reportFlagGroup,
|
||||||
@@ -507,6 +511,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
serverFlags := &flag.Flags{
|
serverFlags := &flag.Flags{
|
||||||
CacheFlagGroup: flag.NewCacheFlagGroup(),
|
CacheFlagGroup: flag.NewCacheFlagGroup(),
|
||||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
RemoteFlagGroup: flag.NewServerFlags(),
|
RemoteFlagGroup: flag.NewServerFlags(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -560,6 +565,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
configFlags := &flag.Flags{
|
configFlags := &flag.Flags{
|
||||||
CacheFlagGroup: flag.NewCacheFlagGroup(),
|
CacheFlagGroup: flag.NewCacheFlagGroup(),
|
||||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
RegoFlagGroup: flag.NewRegoFlagGroup(),
|
||||||
K8sFlagGroup: &flag.K8sFlagGroup{
|
K8sFlagGroup: &flag.K8sFlagGroup{
|
||||||
// disable unneeded flags
|
// disable unneeded flags
|
||||||
@@ -708,6 +714,10 @@ func NewPluginCommand() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
||||||
|
moduleFlags := &flag.Flags{
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
|
}
|
||||||
|
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "module subcommand",
|
Use: "module subcommand",
|
||||||
Aliases: []string{"m"},
|
Aliases: []string{"m"},
|
||||||
@@ -723,14 +733,23 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
Aliases: []string{"i"},
|
Aliases: []string{"i"},
|
||||||
Short: "Install a module",
|
Short: "Install a module",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
|
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := moduleFlags.Bind(cmd); err != nil {
|
||||||
|
return xerrors.Errorf("flag bind error: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return cmd.Help()
|
return cmd.Help()
|
||||||
}
|
}
|
||||||
|
|
||||||
repo := args[0]
|
repo := args[0]
|
||||||
opts := globalFlags.ToOptions()
|
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
|
||||||
return module.Install(cmd.Context(), repo, opts.Quiet, opts.Insecure)
|
if err != nil {
|
||||||
|
return xerrors.Errorf("flag error: %w", err)
|
||||||
|
}
|
||||||
|
return module.Install(cmd.Context(), opts.ModuleDir, repo, opts.Quiet, opts.Insecure)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&cobra.Command{
|
&cobra.Command{
|
||||||
@@ -738,16 +757,27 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
Aliases: []string{"u"},
|
Aliases: []string{"u"},
|
||||||
Short: "Uninstall a module",
|
Short: "Uninstall a module",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
|
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := moduleFlags.Bind(cmd); err != nil {
|
||||||
|
return xerrors.Errorf("flag bind error: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return cmd.Help()
|
return cmd.Help()
|
||||||
}
|
}
|
||||||
|
|
||||||
repo := args[0]
|
repo := args[0]
|
||||||
return module.Uninstall(cmd.Context(), repo)
|
opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("flag error: %w", err)
|
||||||
|
}
|
||||||
|
return module.Uninstall(cmd.Context(), opts.ModuleDir, repo)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
moduleFlags.AddFlags(cmd)
|
||||||
cmd.SetFlagErrorFunc(flagErrorFunc)
|
cmd.SetFlagErrorFunc(flagErrorFunc)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
@@ -901,6 +931,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
|
|||||||
DBFlagGroup: flag.NewDBFlagGroup(),
|
DBFlagGroup: flag.NewDBFlagGroup(),
|
||||||
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
|
||||||
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
|
||||||
|
ModuleFlagGroup: flag.NewModuleFlagGroup(),
|
||||||
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
|
||||||
ReportFlagGroup: reportFlagGroup,
|
ReportFlagGroup: reportFlagGroup,
|
||||||
ScanFlagGroup: flag.NewScanFlagGroup(),
|
ScanFlagGroup: flag.NewScanFlagGroup(),
|
||||||
|
|||||||
@@ -131,7 +131,10 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...runnerOptio
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize WASM modules
|
// Initialize WASM modules
|
||||||
m, err := module.NewManager(ctx)
|
m, err := module.NewManager(ctx, module.Options{
|
||||||
|
Dir: cliOptions.ModuleDir,
|
||||||
|
EnabledModules: cliOptions.EnabledModules,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("WASM module error: %w", err)
|
return nil, xerrors.Errorf("WASM module error: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,10 @@ func Run(ctx context.Context, opts flag.Options) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize WASM modules
|
// Initialize WASM modules
|
||||||
m, err := module.NewManager(ctx)
|
m, err := module.NewManager(ctx, module.Options{
|
||||||
|
Dir: opts.ModuleDir,
|
||||||
|
EnabledModules: opts.EnabledModules,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("WASM module error: %w", err)
|
return xerrors.Errorf("WASM module error: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
64
pkg/flag/module_flags.go
Normal file
64
pkg/flag/module_flags.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package flag
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/aquasecurity/trivy/pkg/module"
|
||||||
|
)
|
||||||
|
|
||||||
|
// e.g. config yaml
|
||||||
|
// module:
|
||||||
|
// dir: "/path/to/my_modules"
|
||||||
|
// enable-modules:
|
||||||
|
// - spring4shell
|
||||||
|
|
||||||
|
var (
|
||||||
|
ModuleDirFlag = Flag{
|
||||||
|
Name: "module-dir",
|
||||||
|
ConfigName: "module.dir",
|
||||||
|
Value: module.DefaultDir,
|
||||||
|
Usage: "specify directory to the wasm modules that will be loaded",
|
||||||
|
Persistent: true,
|
||||||
|
}
|
||||||
|
EnableModulesFlag = Flag{
|
||||||
|
Name: "enable-modules",
|
||||||
|
ConfigName: "module.enable-modules",
|
||||||
|
Value: []string{},
|
||||||
|
Usage: "[EXPERIMENTAL] module names to enable",
|
||||||
|
Persistent: true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModuleFlagGroup defines flags for modules
|
||||||
|
type ModuleFlagGroup struct {
|
||||||
|
Dir *Flag
|
||||||
|
EnabledModules *Flag
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModuleOptions struct {
|
||||||
|
ModuleDir string
|
||||||
|
EnabledModules []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewModuleFlagGroup() *ModuleFlagGroup {
|
||||||
|
return &ModuleFlagGroup{
|
||||||
|
Dir: &ModuleDirFlag,
|
||||||
|
EnabledModules: &EnableModulesFlag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *ModuleFlagGroup) Name() string {
|
||||||
|
return "Module"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *ModuleFlagGroup) Flags() []*Flag {
|
||||||
|
return []*Flag{
|
||||||
|
f.Dir,
|
||||||
|
f.EnabledModules,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *ModuleFlagGroup) ToOptions() ModuleOptions {
|
||||||
|
return ModuleOptions{
|
||||||
|
ModuleDir: getString(f.Dir),
|
||||||
|
EnabledModules: getStringSlice(f.EnabledModules),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -66,6 +66,7 @@ type Flags struct {
|
|||||||
K8sFlagGroup *K8sFlagGroup
|
K8sFlagGroup *K8sFlagGroup
|
||||||
LicenseFlagGroup *LicenseFlagGroup
|
LicenseFlagGroup *LicenseFlagGroup
|
||||||
MisconfFlagGroup *MisconfFlagGroup
|
MisconfFlagGroup *MisconfFlagGroup
|
||||||
|
ModuleFlagGroup *ModuleFlagGroup
|
||||||
RemoteFlagGroup *RemoteFlagGroup
|
RemoteFlagGroup *RemoteFlagGroup
|
||||||
RegoFlagGroup *RegoFlagGroup
|
RegoFlagGroup *RegoFlagGroup
|
||||||
RepoFlagGroup *RepoFlagGroup
|
RepoFlagGroup *RepoFlagGroup
|
||||||
@@ -87,6 +88,7 @@ type Options struct {
|
|||||||
K8sOptions
|
K8sOptions
|
||||||
LicenseOptions
|
LicenseOptions
|
||||||
MisconfOptions
|
MisconfOptions
|
||||||
|
ModuleOptions
|
||||||
RegoOptions
|
RegoOptions
|
||||||
RemoteOptions
|
RemoteOptions
|
||||||
RepoOptions
|
RepoOptions
|
||||||
@@ -156,15 +158,14 @@ func bind(cmd *cobra.Command, flag *Flag) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Bind CLI flags
|
// Bind CLI flags
|
||||||
if flag.Persistent {
|
f := cmd.Flags().Lookup(flag.Name)
|
||||||
if err := viper.BindPFlag(flag.ConfigName, cmd.PersistentFlags().Lookup(flag.Name)); err != nil {
|
if f == nil {
|
||||||
return xerrors.Errorf("bind flag error: %w", err)
|
// Lookup local persistent flags
|
||||||
|
f = cmd.PersistentFlags().Lookup(flag.Name)
|
||||||
}
|
}
|
||||||
} else {
|
if err := viper.BindPFlag(flag.ConfigName, f); err != nil {
|
||||||
if err := viper.BindPFlag(flag.ConfigName, cmd.Flags().Lookup(flag.Name)); err != nil {
|
|
||||||
return xerrors.Errorf("bind flag error: %w", err)
|
return xerrors.Errorf("bind flag error: %w", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Bind environmental variable
|
// Bind environmental variable
|
||||||
if err := bindEnv(flag); err != nil {
|
if err := bindEnv(flag); err != nil {
|
||||||
@@ -275,6 +276,9 @@ func (f *Flags) groups() []FlagGroup {
|
|||||||
if f.MisconfFlagGroup != nil {
|
if f.MisconfFlagGroup != nil {
|
||||||
groups = append(groups, f.MisconfFlagGroup)
|
groups = append(groups, f.MisconfFlagGroup)
|
||||||
}
|
}
|
||||||
|
if f.ModuleFlagGroup != nil {
|
||||||
|
groups = append(groups, f.ModuleFlagGroup)
|
||||||
|
}
|
||||||
if f.SecretFlagGroup != nil {
|
if f.SecretFlagGroup != nil {
|
||||||
groups = append(groups, f.SecretFlagGroup)
|
groups = append(groups, f.SecretFlagGroup)
|
||||||
}
|
}
|
||||||
@@ -404,6 +408,10 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if f.ModuleFlagGroup != nil {
|
||||||
|
opts.ModuleOptions = f.ModuleFlagGroup.ToOptions()
|
||||||
|
}
|
||||||
|
|
||||||
if f.RegoFlagGroup != nil {
|
if f.RegoFlagGroup != nil {
|
||||||
opts.RegoOptions, err = f.RegoFlagGroup.ToOptions()
|
opts.RegoOptions, err = f.RegoFlagGroup.ToOptions()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm"
|
const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm"
|
||||||
|
|
||||||
// Install installs a module
|
// Install installs a module
|
||||||
func Install(ctx context.Context, repo string, quiet, insecure bool) error {
|
func Install(ctx context.Context, dir, repo string, quiet, insecure bool) error {
|
||||||
ref, err := name.ParseReference(repo)
|
ref, err := name.ParseReference(repo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("repository parse error: %w", err)
|
return xerrors.Errorf("repository parse error: %w", err)
|
||||||
@@ -27,7 +27,7 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
|
|||||||
return xerrors.Errorf("module initialize error: %w", err)
|
return xerrors.Errorf("module initialize error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dst := filepath.Join(dir(), ref.Context().Name())
|
dst := filepath.Join(dir, ref.Context().Name())
|
||||||
log.Logger.Debugf("Installing the module to %s...", dst)
|
log.Logger.Debugf("Installing the module to %s...", dst)
|
||||||
|
|
||||||
if err = artifact.Download(ctx, dst); err != nil {
|
if err = artifact.Download(ctx, dst); err != nil {
|
||||||
@@ -38,14 +38,14 @@ func Install(ctx context.Context, repo string, quiet, insecure bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Uninstall uninstalls a module
|
// Uninstall uninstalls a module
|
||||||
func Uninstall(_ context.Context, repo string) error {
|
func Uninstall(_ context.Context, dir, repo string) error {
|
||||||
ref, err := name.ParseReference(repo)
|
ref, err := name.ParseReference(repo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("repository parse error: %w", err)
|
return xerrors.Errorf("repository parse error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Logger.Infof("Uninstalling %s ...", repo)
|
log.Logger.Infof("Uninstalling %s ...", repo)
|
||||||
dst := filepath.Join(dir(), ref.Context().Name())
|
dst := filepath.Join(dir, ref.Context().Name())
|
||||||
if err = os.RemoveAll(dst); err != nil {
|
if err = os.RemoveAll(dst); err != nil {
|
||||||
return xerrors.Errorf("remove error: %w", err)
|
return xerrors.Errorf("remove error: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ var (
|
|||||||
}
|
}
|
||||||
|
|
||||||
RelativeDir = filepath.Join(".trivy", "modules")
|
RelativeDir = filepath.Join(".trivy", "modules")
|
||||||
|
|
||||||
|
DefaultDir = dir()
|
||||||
)
|
)
|
||||||
|
|
||||||
// logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection.
|
// logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection.
|
||||||
@@ -94,13 +96,23 @@ func readMemory(mem api.Memory, offset, size uint32) []byte {
|
|||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Dir string
|
||||||
|
EnabledModules []string
|
||||||
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
cache wazero.CompilationCache
|
cache wazero.CompilationCache
|
||||||
modules []*wasmModule
|
modules []*wasmModule
|
||||||
|
dir string
|
||||||
|
enabledModules []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(ctx context.Context) (*Manager, error) {
|
func NewManager(ctx context.Context, opts Options) (*Manager, error) {
|
||||||
m := &Manager{}
|
m := &Manager{
|
||||||
|
dir: opts.Dir,
|
||||||
|
enabledModules: opts.EnabledModules,
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new WebAssembly Runtime.
|
// Create a new WebAssembly Runtime.
|
||||||
m.cache = wazero.NewCompilationCache()
|
m.cache = wazero.NewCompilationCache()
|
||||||
@@ -114,26 +126,25 @@ func NewManager(ctx context.Context) (*Manager, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) loadModules(ctx context.Context) error {
|
func (m *Manager) loadModules(ctx context.Context) error {
|
||||||
moduleDir := dir()
|
_, err := os.Stat(m.dir)
|
||||||
_, err := os.Stat(moduleDir)
|
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.Logger.Debugf("Module dir: %s", moduleDir)
|
log.Logger.Debugf("Module dir: %s", m.dir)
|
||||||
|
|
||||||
err = filepath.Walk(moduleDir, func(path string, info fs.FileInfo, err error) error {
|
err = filepath.Walk(m.dir, func(path string, info fs.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" {
|
} else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rel, err := filepath.Rel(moduleDir, path)
|
rel, err := filepath.Rel(m.dir, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("failed to get a relative path: %w", err)
|
return xerrors.Errorf("failed to get a relative path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Logger.Infof("Loading %s...", rel)
|
log.Logger.Infof("Reading %s...", rel)
|
||||||
wasmCode, err := os.ReadFile(path)
|
wasmCode, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("file read error: %w", err)
|
return xerrors.Errorf("file read error: %w", err)
|
||||||
@@ -144,6 +155,12 @@ func (m *Manager) loadModules(ctx context.Context) error {
|
|||||||
return xerrors.Errorf("WASM module init error %s: %w", rel, err)
|
return xerrors.Errorf("WASM module init error %s: %w", rel, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip Loading WASM modules if not in the list of enable modules flag.
|
||||||
|
if len(m.enabledModules) > 0 && !slices.Contains(m.enabledModules, p.Name()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Logger.Infof("%s loaded", rel)
|
||||||
m.modules = append(m.modules, p)
|
m.modules = append(m.modules, p)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -161,6 +178,13 @@ func (m *Manager) Register() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Deregister() {
|
||||||
|
for _, mod := range m.modules {
|
||||||
|
analyzer.DeregisterAnalyzer(analyzer.Type(mod.Name()))
|
||||||
|
post.DeregisterPostScanner(mod.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) Close(ctx context.Context) error {
|
func (m *Manager) Close(ctx context.Context) error {
|
||||||
return m.cache.Close(ctx)
|
return m.cache.Close(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package module_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"io/fs"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
||||||
"github.com/aquasecurity/trivy/pkg/module"
|
"github.com/aquasecurity/trivy/pkg/module"
|
||||||
"github.com/aquasecurity/trivy/pkg/scanner/post"
|
"github.com/aquasecurity/trivy/pkg/scanner/post"
|
||||||
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestManager_Register(t *testing.T) {
|
func TestManager_Register(t *testing.T) {
|
||||||
@@ -23,15 +22,15 @@ func TestManager_Register(t *testing.T) {
|
|||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
noModuleDir bool
|
moduleDir string
|
||||||
moduleName string
|
enabledModules []string
|
||||||
wantAnalyzerVersions analyzer.Versions
|
wantAnalyzerVersions analyzer.Versions
|
||||||
wantPostScannerVersions map[string]int
|
wantPostScannerVersions map[string]int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path",
|
name: "happy path",
|
||||||
moduleName: "happy",
|
moduleDir: "testdata/happy",
|
||||||
wantAnalyzerVersions: analyzer.Versions{
|
wantAnalyzerVersions: analyzer.Versions{
|
||||||
Analyzers: map[string]int{
|
Analyzers: map[string]int{
|
||||||
"happy": 1,
|
"happy": 1,
|
||||||
@@ -44,7 +43,7 @@ func TestManager_Register(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "only analyzer",
|
name: "only analyzer",
|
||||||
moduleName: "analyzer",
|
moduleDir: "testdata/analyzer",
|
||||||
wantAnalyzerVersions: analyzer.Versions{
|
wantAnalyzerVersions: analyzer.Versions{
|
||||||
Analyzers: map[string]int{
|
Analyzers: map[string]int{
|
||||||
"analyzer": 1,
|
"analyzer": 1,
|
||||||
@@ -55,7 +54,7 @@ func TestManager_Register(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "only post scanner",
|
name: "only post scanner",
|
||||||
moduleName: "scanner",
|
moduleDir: "testdata/scanner",
|
||||||
wantAnalyzerVersions: analyzer.Versions{
|
wantAnalyzerVersions: analyzer.Versions{
|
||||||
Analyzers: map[string]int{},
|
Analyzers: map[string]int{},
|
||||||
PostAnalyzers: map[string]int{},
|
PostAnalyzers: map[string]int{},
|
||||||
@@ -66,47 +65,58 @@ func TestManager_Register(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no module dir",
|
name: "no module dir",
|
||||||
noModuleDir: true,
|
moduleDir: "no-such-dir",
|
||||||
moduleName: "happy",
|
|
||||||
wantAnalyzerVersions: analyzer.Versions{
|
wantAnalyzerVersions: analyzer.Versions{
|
||||||
Analyzers: map[string]int{},
|
Analyzers: map[string]int{},
|
||||||
PostAnalyzers: map[string]int{},
|
PostAnalyzers: map[string]int{},
|
||||||
},
|
},
|
||||||
wantPostScannerVersions: map[string]int{},
|
wantPostScannerVersions: map[string]int{},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "pass enabled modules",
|
||||||
|
moduleDir: "testdata",
|
||||||
|
enabledModules: []string{
|
||||||
|
"happy",
|
||||||
|
"analyzer",
|
||||||
|
},
|
||||||
|
wantAnalyzerVersions: analyzer.Versions{
|
||||||
|
Analyzers: map[string]int{
|
||||||
|
"happy": 1,
|
||||||
|
"analyzer": 1,
|
||||||
|
},
|
||||||
|
PostAnalyzers: map[string]int{},
|
||||||
|
},
|
||||||
|
wantPostScannerVersions: map[string]int{
|
||||||
|
"happy": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Confirm that wasm modules are generated beforehand
|
||||||
|
var count int
|
||||||
|
err := filepath.WalkDir("testdata", func(path string, d fs.DirEntry, err error) error {
|
||||||
|
if filepath.Ext(path) == ".wasm" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
// WASM modules must be generated before running the tests.
|
||||||
|
require.Equal(t, count, 3, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
modulePath := filepath.Join("testdata", tt.moduleName, tt.moduleName+".wasm")
|
m, err := module.NewManager(context.Background(), module.Options{
|
||||||
|
Dir: tt.moduleDir,
|
||||||
// WASM modules must be generated before running this test.
|
EnabledModules: tt.enabledModules,
|
||||||
if _, err := os.Stat(modulePath); os.IsNotExist(err) {
|
})
|
||||||
require.Fail(t, "missing WASM modules, try 'make test' or 'make generate-test-modules'")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up a temp dir for modules
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
t.Setenv("XDG_DATA_HOME", tmpDir)
|
|
||||||
moduleDir := filepath.Join(tmpDir, module.RelativeDir)
|
|
||||||
|
|
||||||
if !tt.noModuleDir {
|
|
||||||
err := os.MkdirAll(moduleDir, 0777)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Copy the wasm module for testing
|
|
||||||
_, err = fsutils.CopyFile(modulePath, filepath.Join(moduleDir, filepath.Base(modulePath)))
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := module.NewManager(context.Background())
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Register analyzer and post scanner from WASM module
|
// Register analyzer and post scanner from WASM module
|
||||||
m.Register()
|
m.Register()
|
||||||
defer func() {
|
|
||||||
analyzer.DeregisterAnalyzer(analyzer.Type(tt.moduleName))
|
// Remove registered analyzers and post scanners so that it will not affect other tests.
|
||||||
post.DeregisterPostScanner(tt.moduleName)
|
defer m.Deregister()
|
||||||
}()
|
|
||||||
|
|
||||||
// Confirm the analyzer is registered
|
// Confirm the analyzer is registered
|
||||||
a, err := analyzer.NewAnalyzerGroup(analyzer.AnalyzerOptions{})
|
a, err := analyzer.NewAnalyzerGroup(analyzer.AnalyzerOptions{})
|
||||||
|
|||||||
Reference in New Issue
Block a user