diff --git a/pkg/fanal/artifact/image/image.go b/pkg/fanal/artifact/image/image.go index 59a0d29562..7bd0510dac 100644 --- a/pkg/fanal/artifact/image/image.go +++ b/pkg/fanal/artifact/image/image.go @@ -14,7 +14,6 @@ import ( v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/samber/lo" "golang.org/x/exp/slices" - "golang.org/x/sync/errgroup" "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" @@ -25,6 +24,7 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/walker" "github.com/aquasecurity/trivy/pkg/mapfs" + "github.com/aquasecurity/trivy/pkg/parallel" "github.com/aquasecurity/trivy/pkg/semaphore" "github.com/aquasecurity/trivy/pkg/syncx" ) @@ -216,49 +216,33 @@ func (a Artifact) consolidateCreatedBy(diffIDs, layerKeys []string, configFile * func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, baseDiffIDs []string, layerKeyMap map[string]LayerInfo, configFile *v1.ConfigFile) error { - group, groupCtx := errgroup.WithContext(ctx) - concurrencyLimit := 5 - if a.artifactOption.Slow { - concurrencyLimit = 1 - } - group.SetLimit(concurrencyLimit) - var osFound types.OS - for _, k := range layerKeys { - layerKey := k - ctx := groupCtx - group.Go(func() error { - layer := layerKeyMap[layerKey] + workers := lo.Ternary(a.artifactOption.Slow, 1, 5) + p := parallel.NewPipeline(workers, false, layerKeys, func(ctx context.Context, layerKey string) (any, error) { + layer := layerKeyMap[layerKey] - // If it is a base layer, secret scanning should not be performed. - var disabledAnalyzers []analyzer.Type - if slices.Contains(baseDiffIDs, layer.DiffID) { - disabledAnalyzers = append(disabledAnalyzers, analyzer.TypeSecret) - } - - layerInfo, err := a.inspectLayer(ctx, layer, disabledAnalyzers) - if err != nil { - return xerrors.Errorf("failed to analyze layer (%s): %w", layer.DiffID, err) - } - if err = a.cache.PutBlob(layerKey, layerInfo); err != nil { - return xerrors.Errorf("failed to store layer: %s in cache: %w", layerKey, err) - } - if lo.IsNotEmpty(layerInfo.OS) { - osFound = layerInfo.OS - } - return nil - }) - - if ctx.Err() != nil { - break + // If it is a base layer, secret scanning should not be performed. + var disabledAnalyzers []analyzer.Type + if slices.Contains(baseDiffIDs, layer.DiffID) { + disabledAnalyzers = append(disabledAnalyzers, analyzer.TypeSecret) } - } - if err := group.Wait(); err != nil { - if ctx.Err() != nil { - return xerrors.Errorf("timeout: %w", ctx.Err()) + layerInfo, err := a.inspectLayer(ctx, layer, disabledAnalyzers) + if err != nil { + return nil, xerrors.Errorf("failed to analyze layer (%s): %w", layer.DiffID, err) } - return err + if err = a.cache.PutBlob(layerKey, layerInfo); err != nil { + return nil, xerrors.Errorf("failed to store layer: %s in cache: %w", layerKey, err) + } + if lo.IsNotEmpty(layerInfo.OS) { + osFound = layerInfo.OS + } + return nil, nil + + }, nil) + + if err := p.Do(ctx); err != nil { + return xerrors.Errorf("pipeline error: %w", err) } if missingImage != "" { @@ -268,7 +252,6 @@ func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, b } return nil - } func (a Artifact) inspectLayer(ctx context.Context, layerInfo LayerInfo, disabled []analyzer.Type) (types.BlobInfo, error) { diff --git a/pkg/k8s/scanner/scanner.go b/pkg/k8s/scanner/scanner.go index 07f6b8f8da..f5eae5abe9 100644 --- a/pkg/k8s/scanner/scanner.go +++ b/pkg/k8s/scanner/scanner.go @@ -54,7 +54,7 @@ func (s *Scanner) Scan(ctx context.Context, artifactsData []*artifacts.Artifact) misconfig report.Resource } - onItem := func(artifact *artifacts.Artifact) (scanResult, error) { + onItem := func(ctx context.Context, artifact *artifacts.Artifact) (scanResult, error) { scanResults := scanResult{} if s.opts.Scanners.AnyEnabled(types.VulnerabilityScanner, types.SecretScanner) { vulns, err := s.scanVulns(ctx, artifact) diff --git a/pkg/parallel/pipeline.go b/pkg/parallel/pipeline.go index f19f064d61..d6879d7a11 100644 --- a/pkg/parallel/pipeline.go +++ b/pkg/parallel/pipeline.go @@ -4,7 +4,6 @@ import ( "context" "github.com/cheggaaa/pb/v3" - "golang.org/x/sync/errgroup" ) @@ -19,13 +18,17 @@ type Pipeline[T, U any] struct { } // onItem represents a function type that takes an input element and returns an output element. -type onItem[T, U any] func(T) (U, error) +type onItem[T, U any] func(context.Context, T) (U, error) // onResult represents a function type that takes an output element. type onResult[U any] func(U) error func NewPipeline[T, U any](numWorkers int, progress bool, items []T, fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] { + if fn2 == nil { + // In case where there is no need to process the return values + fn2 = func(_ U) error { return nil } + } return Pipeline[T, U]{ numWorkers: numWorkers, progress: progress, @@ -71,7 +74,7 @@ func (p *Pipeline[T, U]) Do(ctx context.Context) error { for i := 0; i < p.numWorkers; i++ { g.Go(func() error { for item := range itemCh { - res, err := p.onItem(item) + res, err := p.onItem(ctx, item) if err != nil { return err } diff --git a/pkg/parallel/pipeline_test.go b/pkg/parallel/pipeline_test.go index c4a064e695..60b8cec100 100644 --- a/pkg/parallel/pipeline_test.go +++ b/pkg/parallel/pipeline_test.go @@ -2,6 +2,7 @@ package parallel_test import ( "context" + "fmt" "math" "testing" @@ -15,13 +16,13 @@ func TestPipeline_Do(t *testing.T) { type field struct { numWorkers int items []float64 - onItem func(float64) (float64, error) + onItem func(context.Context, float64) (float64, error) } type testCase struct { name string field field want float64 - wantErr bool + wantErr require.ErrorAssertionFunc } tests := []testCase{ { @@ -40,11 +41,12 @@ func TestPipeline_Do(t *testing.T) { 9, 10, }, - onItem: func(f float64) (float64, error) { + onItem: func(_ context.Context, f float64) (float64, error) { return math.Pow(f, 2), nil }, }, - want: 385, + want: 385, + wantErr: require.NoError, }, { name: "ceil", @@ -60,11 +62,41 @@ func TestPipeline_Do(t *testing.T) { -2.2, -3.3, }, - onItem: func(f float64) (float64, error) { + onItem: func(_ context.Context, f float64) (float64, error) { return math.Round(f), nil }, }, - want: 10, + want: 10, + wantErr: require.NoError, + }, + { + name: "error in series", + field: field{ + numWorkers: 1, + items: []float64{ + 1, + 2, + 3, + }, + onItem: func(_ context.Context, f float64) (float64, error) { + return 0, fmt.Errorf("error") + }, + }, + wantErr: require.Error, + }, + { + name: "error in parallel", + field: field{ + numWorkers: 3, + items: []float64{ + 1, + 2, + }, + onItem: func(_ context.Context, f float64) (float64, error) { + return 0, fmt.Errorf("error") + }, + }, + wantErr: require.Error, }, } for _, tt := range tests { @@ -75,7 +107,7 @@ func TestPipeline_Do(t *testing.T) { return nil }) err := p.Do(context.Background()) - require.NoError(t, err) + tt.wantErr(t, err) assert.Equal(t, tt.want, got) }) }