diff --git a/pkg/flag/kubernetes_flags.go b/pkg/flag/kubernetes_flags.go index e96ca01119..1259ba353e 100644 --- a/pkg/flag/kubernetes_flags.go +++ b/pkg/flag/kubernetes_flags.go @@ -1,10 +1,13 @@ package flag import ( - "fmt" "strconv" + + "fmt" "strings" + "golang.org/x/xerrors" + "github.com/samber/lo" corev1 "k8s.io/api/core/v1" ) @@ -47,6 +50,12 @@ var ( Value: "", Usage: "specify k8s version to validate outdated api by it (example: 1.21.0)", } + ParallelFlag = Flag{ + Name: "parallel", + ConfigName: "kubernetes.parallel", + Value: 5, + Usage: "number (between 1-20) of goroutines enabled for parallel scanning", + } TolerationsFlag = Flag{ Name: "tolerations", ConfigName: "kubernetes.tolerations", @@ -61,6 +70,7 @@ type K8sFlagGroup struct { KubeConfig *Flag Components *Flag K8sVersion *Flag + Parallel *Flag Tolerations *Flag } @@ -70,6 +80,7 @@ type K8sOptions struct { KubeConfig string Components []string K8sVersion string + Parallel int Tolerations []corev1.Toleration } @@ -80,6 +91,7 @@ func NewK8sFlagGroup() *K8sFlagGroup { KubeConfig: &KubeConfigFlag, Components: &ComponentsFlag, K8sVersion: &K8sVersionFlag, + Parallel: &ParallelFlag, Tolerations: &TolerationsFlag, } } @@ -95,6 +107,7 @@ func (f *K8sFlagGroup) Flags() []*Flag { f.KubeConfig, f.Components, f.K8sVersion, + f.Parallel, f.Tolerations, } } @@ -104,12 +117,21 @@ func (f *K8sFlagGroup) ToOptions() (K8sOptions, error) { if err != nil { return K8sOptions{}, err } + var parallel int + if f.Parallel != nil { + parallel = getInt(f.Parallel) + // check parallel flag is a valid number between 1-20 + if parallel < 1 || parallel > 20 { + return K8sOptions{}, xerrors.Errorf("unable to parse parallel value, please ensure that the value entered is a valid number between 1-20.") + } + } return K8sOptions{ ClusterContext: getString(f.ClusterContext), Namespace: getString(f.Namespace), KubeConfig: getString(f.KubeConfig), Components: getStringSlice(f.Components), K8sVersion: getString(f.K8sVersion), + Parallel: parallel, Tolerations: tolerations, }, nil } diff --git a/pkg/k8s/scanner/scanner.go b/pkg/k8s/scanner/scanner.go index 0133db8982..2e10fa8baa 100644 --- a/pkg/k8s/scanner/scanner.go +++ b/pkg/k8s/scanner/scanner.go @@ -2,9 +2,7 @@ package scanner import ( "context" - "io" - "github.com/cheggaaa/pb/v3" "golang.org/x/xerrors" "github.com/aquasecurity/trivy-kubernetes/pkg/artifacts" @@ -12,6 +10,7 @@ import ( "github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/k8s/report" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/parallel" "github.com/aquasecurity/trivy/pkg/scanner/local" "github.com/aquasecurity/trivy/pkg/types" ) @@ -30,15 +29,7 @@ func NewScanner(cluster string, runner cmd.Runner, opts flag.Options) *Scanner { } } -func (s *Scanner) Scan(ctx context.Context, artifacts []*artifacts.Artifact) (report.Report, error) { - // progress bar - bar := pb.StartNew(len(artifacts)) - if s.opts.NoProgress { - bar.SetWriter(io.Discard) - } - defer bar.Finish() - - var vulns, misconfigs []report.Resource +func (s *Scanner) Scan(ctx context.Context, artifactsData []*artifacts.Artifact) (report.Report, error) { // disable logs before scanning err := log.InitLogger(s.opts.Debug, true) @@ -56,27 +47,42 @@ func (s *Scanner) Scan(ctx context.Context, artifacts []*artifacts.Artifact) (re log.Fatal(xerrors.Errorf("can't enable logger error: %w", err)) } }() + var vulns, misconfigs []report.Resource - // Loops once over all artifacts, and execute scanners as necessary. Not every artifacts has an image, - // so image scanner is not always executed. - for _, artifact := range artifacts { - bar.Increment() + type scanResult struct { + vulns []report.Resource + misconfig report.Resource + } + onItem := func(artifact *artifacts.Artifact) (scanResult, error) { + scanResults := scanResult{} if s.opts.Scanners.AnyEnabled(types.VulnerabilityScanner, types.SecretScanner) { - resources, err := s.scanVulns(ctx, artifact) + vulns, err := s.scanVulns(ctx, artifact) if err != nil { - return report.Report{}, xerrors.Errorf("scanning vulnerabilities error: %w", err) + return scanResult{}, xerrors.Errorf("scanning vulnerabilities error: %w", err) } - vulns = append(vulns, resources...) + scanResults.vulns = vulns } - if local.ShouldScanMisconfigOrRbac(s.opts.Scanners) { - resource, err := s.scanMisconfigs(ctx, artifact) + misconfig, err := s.scanMisconfigs(ctx, artifact) if err != nil { - return report.Report{}, xerrors.Errorf("scanning misconfigurations error: %w", err) + return scanResult{}, xerrors.Errorf("scanning misconfigurations error: %w", err) } - misconfigs = append(misconfigs, resource) + scanResults.misconfig = misconfig } + return scanResults, nil + } + + onResult := func(result scanResult) error { + vulns = append(vulns, result.vulns...) + misconfigs = append(misconfigs, result.misconfig) + return nil + } + + p := parallel.NewPipeline(s.opts.Parallel, true, artifactsData, onItem, onResult) + err = p.Do(ctx) + if err != nil { + return report.Report{}, err } return report.Report{