mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-22 07:10:41 -08:00
refactor: enable cases where return values are not needed in pipeline (#4443)
This commit is contained in:
@@ -14,7 +14,6 @@ import (
|
|||||||
v1 "github.com/google/go-containerregistry/pkg/v1"
|
v1 "github.com/google/go-containerregistry/pkg/v1"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
||||||
@@ -25,6 +24,7 @@ import (
|
|||||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||||
"github.com/aquasecurity/trivy/pkg/fanal/walker"
|
"github.com/aquasecurity/trivy/pkg/fanal/walker"
|
||||||
"github.com/aquasecurity/trivy/pkg/mapfs"
|
"github.com/aquasecurity/trivy/pkg/mapfs"
|
||||||
|
"github.com/aquasecurity/trivy/pkg/parallel"
|
||||||
"github.com/aquasecurity/trivy/pkg/semaphore"
|
"github.com/aquasecurity/trivy/pkg/semaphore"
|
||||||
"github.com/aquasecurity/trivy/pkg/syncx"
|
"github.com/aquasecurity/trivy/pkg/syncx"
|
||||||
)
|
)
|
||||||
@@ -216,49 +216,33 @@ func (a Artifact) consolidateCreatedBy(diffIDs, layerKeys []string, configFile *
|
|||||||
func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, baseDiffIDs []string,
|
func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, baseDiffIDs []string,
|
||||||
layerKeyMap map[string]LayerInfo, configFile *v1.ConfigFile) error {
|
layerKeyMap map[string]LayerInfo, configFile *v1.ConfigFile) error {
|
||||||
|
|
||||||
group, groupCtx := errgroup.WithContext(ctx)
|
|
||||||
concurrencyLimit := 5
|
|
||||||
if a.artifactOption.Slow {
|
|
||||||
concurrencyLimit = 1
|
|
||||||
}
|
|
||||||
group.SetLimit(concurrencyLimit)
|
|
||||||
|
|
||||||
var osFound types.OS
|
var osFound types.OS
|
||||||
for _, k := range layerKeys {
|
workers := lo.Ternary(a.artifactOption.Slow, 1, 5)
|
||||||
layerKey := k
|
p := parallel.NewPipeline(workers, false, layerKeys, func(ctx context.Context, layerKey string) (any, error) {
|
||||||
ctx := groupCtx
|
layer := layerKeyMap[layerKey]
|
||||||
group.Go(func() error {
|
|
||||||
layer := layerKeyMap[layerKey]
|
|
||||||
|
|
||||||
// If it is a base layer, secret scanning should not be performed.
|
// If it is a base layer, secret scanning should not be performed.
|
||||||
var disabledAnalyzers []analyzer.Type
|
var disabledAnalyzers []analyzer.Type
|
||||||
if slices.Contains(baseDiffIDs, layer.DiffID) {
|
if slices.Contains(baseDiffIDs, layer.DiffID) {
|
||||||
disabledAnalyzers = append(disabledAnalyzers, analyzer.TypeSecret)
|
disabledAnalyzers = append(disabledAnalyzers, analyzer.TypeSecret)
|
||||||
}
|
|
||||||
|
|
||||||
layerInfo, err := a.inspectLayer(ctx, layer, disabledAnalyzers)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("failed to analyze layer (%s): %w", layer.DiffID, err)
|
|
||||||
}
|
|
||||||
if err = a.cache.PutBlob(layerKey, layerInfo); err != nil {
|
|
||||||
return xerrors.Errorf("failed to store layer: %s in cache: %w", layerKey, err)
|
|
||||||
}
|
|
||||||
if lo.IsNotEmpty(layerInfo.OS) {
|
|
||||||
osFound = layerInfo.OS
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := group.Wait(); err != nil {
|
layerInfo, err := a.inspectLayer(ctx, layer, disabledAnalyzers)
|
||||||
if ctx.Err() != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("timeout: %w", ctx.Err())
|
return nil, xerrors.Errorf("failed to analyze layer (%s): %w", layer.DiffID, err)
|
||||||
}
|
}
|
||||||
return err
|
if err = a.cache.PutBlob(layerKey, layerInfo); err != nil {
|
||||||
|
return nil, xerrors.Errorf("failed to store layer: %s in cache: %w", layerKey, err)
|
||||||
|
}
|
||||||
|
if lo.IsNotEmpty(layerInfo.OS) {
|
||||||
|
osFound = layerInfo.OS
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
if err := p.Do(ctx); err != nil {
|
||||||
|
return xerrors.Errorf("pipeline error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if missingImage != "" {
|
if missingImage != "" {
|
||||||
@@ -268,7 +252,6 @@ func (a Artifact) inspect(ctx context.Context, missingImage string, layerKeys, b
|
|||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a Artifact) inspectLayer(ctx context.Context, layerInfo LayerInfo, disabled []analyzer.Type) (types.BlobInfo, error) {
|
func (a Artifact) inspectLayer(ctx context.Context, layerInfo LayerInfo, disabled []analyzer.Type) (types.BlobInfo, error) {
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func (s *Scanner) Scan(ctx context.Context, artifactsData []*artifacts.Artifact)
|
|||||||
misconfig report.Resource
|
misconfig report.Resource
|
||||||
}
|
}
|
||||||
|
|
||||||
onItem := func(artifact *artifacts.Artifact) (scanResult, error) {
|
onItem := func(ctx context.Context, artifact *artifacts.Artifact) (scanResult, error) {
|
||||||
scanResults := scanResult{}
|
scanResults := scanResult{}
|
||||||
if s.opts.Scanners.AnyEnabled(types.VulnerabilityScanner, types.SecretScanner) {
|
if s.opts.Scanners.AnyEnabled(types.VulnerabilityScanner, types.SecretScanner) {
|
||||||
vulns, err := s.scanVulns(ctx, artifact)
|
vulns, err := s.scanVulns(ctx, artifact)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/cheggaaa/pb/v3"
|
"github.com/cheggaaa/pb/v3"
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,13 +18,17 @@ type Pipeline[T, U any] struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// onItem represents a function type that takes an input element and returns an output element.
|
// onItem represents a function type that takes an input element and returns an output element.
|
||||||
type onItem[T, U any] func(T) (U, error)
|
type onItem[T, U any] func(context.Context, T) (U, error)
|
||||||
|
|
||||||
// onResult represents a function type that takes an output element.
|
// onResult represents a function type that takes an output element.
|
||||||
type onResult[U any] func(U) error
|
type onResult[U any] func(U) error
|
||||||
|
|
||||||
func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
|
func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
|
||||||
fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
|
fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
|
||||||
|
if fn2 == nil {
|
||||||
|
// In case where there is no need to process the return values
|
||||||
|
fn2 = func(_ U) error { return nil }
|
||||||
|
}
|
||||||
return Pipeline[T, U]{
|
return Pipeline[T, U]{
|
||||||
numWorkers: numWorkers,
|
numWorkers: numWorkers,
|
||||||
progress: progress,
|
progress: progress,
|
||||||
@@ -71,7 +74,7 @@ func (p *Pipeline[T, U]) Do(ctx context.Context) error {
|
|||||||
for i := 0; i < p.numWorkers; i++ {
|
for i := 0; i < p.numWorkers; i++ {
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
for item := range itemCh {
|
for item := range itemCh {
|
||||||
res, err := p.onItem(item)
|
res, err := p.onItem(ctx, item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package parallel_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -15,13 +16,13 @@ func TestPipeline_Do(t *testing.T) {
|
|||||||
type field struct {
|
type field struct {
|
||||||
numWorkers int
|
numWorkers int
|
||||||
items []float64
|
items []float64
|
||||||
onItem func(float64) (float64, error)
|
onItem func(context.Context, float64) (float64, error)
|
||||||
}
|
}
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
name string
|
name string
|
||||||
field field
|
field field
|
||||||
want float64
|
want float64
|
||||||
wantErr bool
|
wantErr require.ErrorAssertionFunc
|
||||||
}
|
}
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
@@ -40,11 +41,12 @@ func TestPipeline_Do(t *testing.T) {
|
|||||||
9,
|
9,
|
||||||
10,
|
10,
|
||||||
},
|
},
|
||||||
onItem: func(f float64) (float64, error) {
|
onItem: func(_ context.Context, f float64) (float64, error) {
|
||||||
return math.Pow(f, 2), nil
|
return math.Pow(f, 2), nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: 385,
|
want: 385,
|
||||||
|
wantErr: require.NoError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ceil",
|
name: "ceil",
|
||||||
@@ -60,11 +62,41 @@ func TestPipeline_Do(t *testing.T) {
|
|||||||
-2.2,
|
-2.2,
|
||||||
-3.3,
|
-3.3,
|
||||||
},
|
},
|
||||||
onItem: func(f float64) (float64, error) {
|
onItem: func(_ context.Context, f float64) (float64, error) {
|
||||||
return math.Round(f), nil
|
return math.Round(f), nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: 10,
|
want: 10,
|
||||||
|
wantErr: require.NoError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error in series",
|
||||||
|
field: field{
|
||||||
|
numWorkers: 1,
|
||||||
|
items: []float64{
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
},
|
||||||
|
onItem: func(_ context.Context, f float64) (float64, error) {
|
||||||
|
return 0, fmt.Errorf("error")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: require.Error,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error in parallel",
|
||||||
|
field: field{
|
||||||
|
numWorkers: 3,
|
||||||
|
items: []float64{
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
},
|
||||||
|
onItem: func(_ context.Context, f float64) (float64, error) {
|
||||||
|
return 0, fmt.Errorf("error")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: require.Error,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -75,7 +107,7 @@ func TestPipeline_Do(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
err := p.Do(context.Background())
|
err := p.Do(context.Background())
|
||||||
require.NoError(t, err)
|
tt.wantErr(t, err)
|
||||||
assert.Equal(t, tt.want, got)
|
assert.Equal(t, tt.want, got)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user