feat: k8s parallel processing (#3693)

Signed-off-by: chenk <hen.keinan@gmail.com>
Co-authored-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
chenk
2023-03-20 13:34:38 +02:00
committed by GitHub
parent b864b3b926
commit 234a360a7a
2 changed files with 51 additions and 23 deletions

View File

@@ -1,10 +1,13 @@
package flag package flag
import ( import (
"fmt"
"strconv" "strconv"
"fmt"
"strings" "strings"
"golang.org/x/xerrors"
"github.com/samber/lo" "github.com/samber/lo"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
) )
@@ -47,6 +50,12 @@ var (
Value: "", Value: "",
Usage: "specify k8s version to validate outdated api by it (example: 1.21.0)", Usage: "specify k8s version to validate outdated api by it (example: 1.21.0)",
} }
ParallelFlag = Flag{
Name: "parallel",
ConfigName: "kubernetes.parallel",
Value: 5,
Usage: "number (between 1-20) of goroutines enabled for parallel scanning",
}
TolerationsFlag = Flag{ TolerationsFlag = Flag{
Name: "tolerations", Name: "tolerations",
ConfigName: "kubernetes.tolerations", ConfigName: "kubernetes.tolerations",
@@ -61,6 +70,7 @@ type K8sFlagGroup struct {
KubeConfig *Flag KubeConfig *Flag
Components *Flag Components *Flag
K8sVersion *Flag K8sVersion *Flag
Parallel *Flag
Tolerations *Flag Tolerations *Flag
} }
@@ -70,6 +80,7 @@ type K8sOptions struct {
KubeConfig string KubeConfig string
Components []string Components []string
K8sVersion string K8sVersion string
Parallel int
Tolerations []corev1.Toleration Tolerations []corev1.Toleration
} }
@@ -80,6 +91,7 @@ func NewK8sFlagGroup() *K8sFlagGroup {
KubeConfig: &KubeConfigFlag, KubeConfig: &KubeConfigFlag,
Components: &ComponentsFlag, Components: &ComponentsFlag,
K8sVersion: &K8sVersionFlag, K8sVersion: &K8sVersionFlag,
Parallel: &ParallelFlag,
Tolerations: &TolerationsFlag, Tolerations: &TolerationsFlag,
} }
} }
@@ -95,6 +107,7 @@ func (f *K8sFlagGroup) Flags() []*Flag {
f.KubeConfig, f.KubeConfig,
f.Components, f.Components,
f.K8sVersion, f.K8sVersion,
f.Parallel,
f.Tolerations, f.Tolerations,
} }
} }
@@ -104,12 +117,21 @@ func (f *K8sFlagGroup) ToOptions() (K8sOptions, error) {
if err != nil { if err != nil {
return K8sOptions{}, err return K8sOptions{}, err
} }
var parallel int
if f.Parallel != nil {
parallel = getInt(f.Parallel)
// check parallel flag is a valid number between 1-20
if parallel < 1 || parallel > 20 {
return K8sOptions{}, xerrors.Errorf("unable to parse parallel value, please ensure that the value entered is a valid number between 1-20.")
}
}
return K8sOptions{ return K8sOptions{
ClusterContext: getString(f.ClusterContext), ClusterContext: getString(f.ClusterContext),
Namespace: getString(f.Namespace), Namespace: getString(f.Namespace),
KubeConfig: getString(f.KubeConfig), KubeConfig: getString(f.KubeConfig),
Components: getStringSlice(f.Components), Components: getStringSlice(f.Components),
K8sVersion: getString(f.K8sVersion), K8sVersion: getString(f.K8sVersion),
Parallel: parallel,
Tolerations: tolerations, Tolerations: tolerations,
}, nil }, nil
} }

View File

@@ -2,9 +2,7 @@ package scanner
import ( import (
"context" "context"
"io"
"github.com/cheggaaa/pb/v3"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/aquasecurity/trivy-kubernetes/pkg/artifacts" "github.com/aquasecurity/trivy-kubernetes/pkg/artifacts"
@@ -12,6 +10,7 @@ import (
"github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/k8s/report" "github.com/aquasecurity/trivy/pkg/k8s/report"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/parallel"
"github.com/aquasecurity/trivy/pkg/scanner/local" "github.com/aquasecurity/trivy/pkg/scanner/local"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
) )
@@ -30,15 +29,7 @@ func NewScanner(cluster string, runner cmd.Runner, opts flag.Options) *Scanner {
} }
} }
func (s *Scanner) Scan(ctx context.Context, artifacts []*artifacts.Artifact) (report.Report, error) { func (s *Scanner) Scan(ctx context.Context, artifactsData []*artifacts.Artifact) (report.Report, error) {
// progress bar
bar := pb.StartNew(len(artifacts))
if s.opts.NoProgress {
bar.SetWriter(io.Discard)
}
defer bar.Finish()
var vulns, misconfigs []report.Resource
// disable logs before scanning // disable logs before scanning
err := log.InitLogger(s.opts.Debug, true) err := log.InitLogger(s.opts.Debug, true)
@@ -56,27 +47,42 @@ func (s *Scanner) Scan(ctx context.Context, artifacts []*artifacts.Artifact) (re
log.Fatal(xerrors.Errorf("can't enable logger error: %w", err)) log.Fatal(xerrors.Errorf("can't enable logger error: %w", err))
} }
}() }()
var vulns, misconfigs []report.Resource
// Loops once over all artifacts, and execute scanners as necessary. Not every artifacts has an image, type scanResult struct {
// so image scanner is not always executed. vulns []report.Resource
for _, artifact := range artifacts { misconfig report.Resource
bar.Increment() }
onItem := func(artifact *artifacts.Artifact) (scanResult, error) {
scanResults := scanResult{}
if s.opts.Scanners.AnyEnabled(types.VulnerabilityScanner, types.SecretScanner) { if s.opts.Scanners.AnyEnabled(types.VulnerabilityScanner, types.SecretScanner) {
resources, err := s.scanVulns(ctx, artifact) vulns, err := s.scanVulns(ctx, artifact)
if err != nil { if err != nil {
return report.Report{}, xerrors.Errorf("scanning vulnerabilities error: %w", err) return scanResult{}, xerrors.Errorf("scanning vulnerabilities error: %w", err)
} }
vulns = append(vulns, resources...) scanResults.vulns = vulns
}
if local.ShouldScanMisconfigOrRbac(s.opts.Scanners) {
misconfig, err := s.scanMisconfigs(ctx, artifact)
if err != nil {
return scanResult{}, xerrors.Errorf("scanning misconfigurations error: %w", err)
}
scanResults.misconfig = misconfig
}
return scanResults, nil
} }
if local.ShouldScanMisconfigOrRbac(s.opts.Scanners) { onResult := func(result scanResult) error {
resource, err := s.scanMisconfigs(ctx, artifact) vulns = append(vulns, result.vulns...)
misconfigs = append(misconfigs, result.misconfig)
return nil
}
p := parallel.NewPipeline(s.opts.Parallel, true, artifactsData, onItem, onResult)
err = p.Do(ctx)
if err != nil { if err != nil {
return report.Report{}, xerrors.Errorf("scanning misconfigurations error: %w", err) return report.Report{}, err
}
misconfigs = append(misconfigs, resource)
}
} }
return report.Report{ return report.Report{