refactor: enable cases where return values are not needed in pipeline (#4443)

Author: Teppei Fukuda
Date: 2023-05-22 08:11:24 +03:00
Committed by: GitHub
Parent: 29b5f7e8ec
Commit: e1361368a1
4 changed files with 69 additions and 51 deletions
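In short, parallel.Pipeline no longer forces callers to supply an onResult callback: when every side effect happens inside onItem, the final argument to NewPipeline may simply be nil, and onItem now receives the pipeline's context. A minimal sketch of that call shape, assuming only the pkg/parallel API shown in the diffs below (the item values and the per-item work are illustrative, not actual Trivy code):

// Sketch only: hypothetical items and work.
package main

import (
	"context"
	"log"

	"github.com/aquasecurity/trivy/pkg/parallel"
)

func main() {
	keys := []string{"layer-1", "layer-2", "layer-3"} // hypothetical items

	p := parallel.NewPipeline(5, false, keys,
		// onItem now receives the pipeline's context; all side effects
		// (cache writes, logging, ...) happen here, so nothing is returned.
		func(ctx context.Context, key string) (any, error) {
			log.Printf("processing %s", key)
			return nil, nil
		},
		nil, // onResult may be nil when the return values are not needed
	)
	if err := p.Do(context.Background()); err != nil {
		log.Fatal(err)
	}
}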

File 1 of 4

@@ -14,7 +14,6 @@ import (
     v1 "github.com/google/go-containerregistry/pkg/v1"
     "github.com/samber/lo"
     "golang.org/x/exp/slices"
-    "golang.org/x/sync/errgroup"
     "golang.org/x/xerrors"

     "github.com/aquasecurity/trivy/pkg/fanal/analyzer"
@@ -25,6 +24,7 @@ import (
     "github.com/aquasecurity/trivy/pkg/fanal/types"
     "github.com/aquasecurity/trivy/pkg/fanal/walker"
     "github.com/aquasecurity/trivy/pkg/mapfs"
+    "github.com/aquasecurity/trivy/pkg/parallel"
     "github.com/aquasecurity/trivy/pkg/semaphore"
     "github.com/aquasecurity/trivy/pkg/syncx"
 )
@@ -216,18 +216,9 @@ func (a Artifact) consolidateCreatedBy(diffIDs, layerKeys []string, configFile *
 func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, baseDiffIDs []string,
     layerKeyMap map[string]LayerInfo, configFile *v1.ConfigFile) error {
-    group, groupCtx := errgroup.WithContext(ctx)
-    concurrencyLimit := 5
-    if a.artifactOption.Slow {
-        concurrencyLimit = 1
-    }
-    group.SetLimit(concurrencyLimit)
-
     var osFound types.OS
-    for _, k := range layerKeys {
-        layerKey := k
-        ctx := groupCtx
-
-        group.Go(func() error {
-            layer := layerKeyMap[layerKey]
+    workers := lo.Ternary(a.artifactOption.Slow, 1, 5)
+    p := parallel.NewPipeline(workers, false, layerKeys, func(ctx context.Context, layerKey string) (any, error) {
+        layer := layerKeyMap[layerKey]

         // If it is a base layer, secret scanning should not be performed.
@@ -238,27 +229,20 @@ func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, b
         layerInfo, err := a.inspectLayer(ctx, layer, disabledAnalyzers)
         if err != nil {
-            return xerrors.Errorf("failed to analyze layer (%s): %w", layer.DiffID, err)
+            return nil, xerrors.Errorf("failed to analyze layer (%s): %w", layer.DiffID, err)
         }
         if err = a.cache.PutBlob(layerKey, layerInfo); err != nil {
-            return xerrors.Errorf("failed to store layer: %s in cache: %w", layerKey, err)
+            return nil, xerrors.Errorf("failed to store layer: %s in cache: %w", layerKey, err)
         }
         if lo.IsNotEmpty(layerInfo.OS) {
             osFound = layerInfo.OS
         }
-            return nil
-        })
-        if ctx.Err() != nil {
-            break
-        }
-    }
+        return nil, nil
+    }, nil)

-    if err := group.Wait(); err != nil {
-        if ctx.Err() != nil {
-            return xerrors.Errorf("timeout: %w", ctx.Err())
-        }
-        return err
+    if err := p.Do(ctx); err != nil {
+        return xerrors.Errorf("pipeline error: %w", err)
     }

     if missingImage != "" {
@@ -268,7 +252,6 @@ func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, b
     }
     return nil
 }

 func (a Artifact) inspectLayer(ctx context.Context, layerInfo LayerInfo, disabled []analyzer.Type) (types.BlobInfo, error) {

File 2 of 4

@@ -54,7 +54,7 @@ func (s *Scanner) Scan(ctx context.Context, artifactsData []*artifacts.Artifact)
         misconfig report.Resource
     }
-    onItem := func(artifact *artifacts.Artifact) (scanResult, error) {
+    onItem := func(ctx context.Context, artifact *artifacts.Artifact) (scanResult, error) {
         scanResults := scanResult{}
         if s.opts.Scanners.AnyEnabled(types.VulnerabilityScanner, types.SecretScanner) {
             vulns, err := s.scanVulns(ctx, artifact)
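Unlike the image artifact, the Kubernetes scanner keeps a non-nil onResult; its onItem only gains the context parameter required by the new signature. A self-contained sketch of the same pattern with simplified item and result types (the names and the cancellation check are illustrative, not the scanner's actual code):

// Sketch only: simplified types standing in for the scanner's artifacts and scanResult.
package main

import (
	"context"
	"fmt"

	"github.com/aquasecurity/trivy/pkg/parallel"
)

func main() {
	items := []string{"deployment/app", "daemonset/agent"} // hypothetical workloads

	// onItem can now observe cancellation via the context the pipeline passes in.
	onItem := func(ctx context.Context, name string) (string, error) {
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		default:
			return "scanned " + name, nil
		}
	}
	// onResult receives each value produced by onItem.
	onResult := func(res string) error {
		fmt.Println(res)
		return nil
	}

	p := parallel.NewPipeline(2, false, items, onItem, onResult)
	if err := p.Do(context.Background()); err != nil {
		fmt.Println("error:", err)
	}
}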

File 3 of 4

@@ -4,7 +4,6 @@ import (
     "context"
     "github.com/cheggaaa/pb/v3"
     "golang.org/x/sync/errgroup"
 )
@@ -19,13 +18,17 @@ type Pipeline[T, U any] struct {
 }

 // onItem represents a function type that takes an input element and returns an output element.
-type onItem[T, U any] func(T) (U, error)
+type onItem[T, U any] func(context.Context, T) (U, error)

 // onResult represents a function type that takes an output element.
 type onResult[U any] func(U) error

 func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
     fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
+    if fn2 == nil {
+        // In case where there is no need to process the return values
+        fn2 = func(_ U) error { return nil }
+    }
     return Pipeline[T, U]{
         numWorkers: numWorkers,
         progress:   progress,
@@ -71,7 +74,7 @@ func (p *Pipeline[T, U]) Do(ctx context.Context) error {
     for i := 0; i < p.numWorkers; i++ {
         g.Go(func() error {
             for item := range itemCh {
-                res, err := p.onItem(item)
+                res, err := p.onItem(ctx, item)
                 if err != nil {
                     return err
                 }
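Putting the two pieces together: NewPipeline substitutes a no-op when onResult is nil, and Do hands its context to every onItem call. A short sketch of the "results are collected" path, shaped like the test cases in the next file (the numbers and the expected sum are illustrative):

// Sketch only: square each item concurrently and sum the results via onResult.
package main

import (
	"context"
	"fmt"

	"github.com/aquasecurity/trivy/pkg/parallel"
)

func main() {
	var sum float64
	p := parallel.NewPipeline(3, false, []float64{1, 2, 3, 4, 5},
		func(_ context.Context, f float64) (float64, error) {
			return f * f, nil // per-item work runs on numWorkers goroutines
		},
		func(sq float64) error {
			sum += sq // results are collected here, as the tests below do with their running sum
			return nil
		},
	)
	if err := p.Do(context.Background()); err != nil {
		fmt.Println("pipeline error:", err) // an error returned from onItem surfaces here (see the new error test cases)
		return
	}
	fmt.Println(sum) // 55
}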

File 4 of 4

@@ -2,6 +2,7 @@ package parallel_test

 import (
     "context"
+    "fmt"
     "math"
     "testing"
@@ -15,13 +16,13 @@ func TestPipeline_Do(t *testing.T) {
     type field struct {
         numWorkers int
         items      []float64
-        onItem     func(float64) (float64, error)
+        onItem     func(context.Context, float64) (float64, error)
     }
     type testCase struct {
         name    string
         field   field
         want    float64
-        wantErr bool
+        wantErr require.ErrorAssertionFunc
     }
     tests := []testCase{
         {
@@ -40,11 +41,12 @@ func TestPipeline_Do(t *testing.T) {
                     9,
                     10,
                 },
-                onItem: func(f float64) (float64, error) {
+                onItem: func(_ context.Context, f float64) (float64, error) {
                     return math.Pow(f, 2), nil
                 },
             },
             want: 385,
+            wantErr: require.NoError,
         },
         {
             name: "ceil",
@@ -60,11 +62,41 @@ func TestPipeline_Do(t *testing.T) {
                     -2.2,
                     -3.3,
                 },
-                onItem: func(f float64) (float64, error) {
+                onItem: func(_ context.Context, f float64) (float64, error) {
                     return math.Round(f), nil
                 },
             },
             want: 10,
+            wantErr: require.NoError,
+        },
+        {
+            name: "error in series",
+            field: field{
+                numWorkers: 1,
+                items: []float64{
+                    1,
+                    2,
+                    3,
+                },
+                onItem: func(_ context.Context, f float64) (float64, error) {
+                    return 0, fmt.Errorf("error")
+                },
+            },
+            wantErr: require.Error,
+        },
+        {
+            name: "error in parallel",
+            field: field{
+                numWorkers: 3,
+                items: []float64{
+                    1,
+                    2,
+                },
+                onItem: func(_ context.Context, f float64) (float64, error) {
+                    return 0, fmt.Errorf("error")
+                },
+            },
+            wantErr: require.Error,
         },
     }
@@ -75,7 +107,7 @@ func TestPipeline_Do(t *testing.T) {
                 return nil
             })
             err := p.Do(context.Background())
-            require.NoError(t, err)
+            tt.wantErr(t, err)
             assert.Equal(t, tt.want, got)
         })
     }