mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-21 23:00:42 -08:00
refactor: add pipeline (#3868)
This commit is contained in:
107
pkg/parallel/pipeline.go
Normal file
107
pkg/parallel/pipeline.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package parallel
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cheggaaa/pb/v3"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Pipeline represents a structure for performing parallel processing.
|
||||
// T represents the input element type and U represents the output element type.
|
||||
type Pipeline[T, U any] struct {
|
||||
numWorkers int
|
||||
items []T
|
||||
onItem onItem[T, U]
|
||||
onResult onResult[U]
|
||||
progress bool
|
||||
}
|
||||
|
||||
// onItem is the per-element callback: given one input of type T it returns
// the corresponding output of type U, or an error.
type onItem[T, U any] func(item T) (result U, err error)
|
||||
|
||||
// onResult is the callback that consumes one output element of type U; it
// reports failure by returning a non-nil error.
type onResult[U any] func(result U) error
|
||||
|
||||
func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
|
||||
fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
|
||||
return Pipeline[T, U]{
|
||||
numWorkers: numWorkers,
|
||||
progress: progress,
|
||||
items: items,
|
||||
onItem: fn1,
|
||||
onResult: fn2,
|
||||
}
|
||||
}
|
||||
|
||||
// Do executes pipeline processing.
|
||||
// It exits when any error occurs.
|
||||
func (p *Pipeline[T, U]) Do(ctx context.Context) error {
|
||||
// progress bar
|
||||
var bar *pb.ProgressBar
|
||||
if p.progress {
|
||||
bar = pb.StartNew(len(p.items))
|
||||
defer bar.Finish()
|
||||
}
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
itemCh := make(chan T)
|
||||
|
||||
// Start a goroutine to send input data
|
||||
g.Go(func() error {
|
||||
defer close(itemCh)
|
||||
for _, item := range p.items {
|
||||
if p.progress {
|
||||
bar.Increment()
|
||||
}
|
||||
select {
|
||||
case itemCh <- item:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Generate a channel for sending output data
|
||||
results := make(chan U)
|
||||
|
||||
// Start a fixed number of goroutines to process items.
|
||||
for i := 0; i < p.numWorkers; i++ {
|
||||
g.Go(func() error {
|
||||
for item := range itemCh {
|
||||
res, err := p.onItem(item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case results <- res:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = g.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
// Process output data received from the channel
|
||||
for res := range results {
|
||||
if err := p.onResult(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether any of the goroutines failed. Since g is accumulating the
|
||||
// errors, we don't need to send them (or check for them) in the individual
|
||||
// results sent on the channel.
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
82
pkg/parallel/pipeline_test.go
Normal file
82
pkg/parallel/pipeline_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package parallel_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/parallel"
|
||||
)
|
||||
|
||||
func TestPipeline_Do(t *testing.T) {
|
||||
type field struct {
|
||||
numWorkers int
|
||||
items []float64
|
||||
onItem func(float64) (float64, error)
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
field field
|
||||
want float64
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "pow",
|
||||
field: field{
|
||||
numWorkers: 5,
|
||||
items: []float64{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
},
|
||||
onItem: func(f float64) (float64, error) {
|
||||
return math.Pow(f, 2), nil
|
||||
},
|
||||
},
|
||||
want: 385,
|
||||
},
|
||||
{
|
||||
name: "ceil",
|
||||
field: field{
|
||||
numWorkers: 3,
|
||||
items: []float64{
|
||||
1.1,
|
||||
2.2,
|
||||
3.3,
|
||||
4.4,
|
||||
5.5,
|
||||
-1.1,
|
||||
-2.2,
|
||||
-3.3,
|
||||
},
|
||||
onItem: func(f float64) (float64, error) {
|
||||
return math.Round(f), nil
|
||||
},
|
||||
},
|
||||
want: 10,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got float64
|
||||
p := parallel.NewPipeline(tt.field.numWorkers, false, tt.field.items, tt.field.onItem, func(f float64) error {
|
||||
got += f
|
||||
return nil
|
||||
})
|
||||
err := p.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
)
|
||||
|
||||
type onFile[T any] func(string, fs.FileInfo, dio.ReadSeekerAt) (T, error)
|
||||
type onResult[T any] func(T) error
|
||||
type onWalkResult[T any] func(T) error
|
||||
|
||||
func WalkDir[T any](ctx context.Context, fsys fs.FS, root string, slow bool,
|
||||
onFile onFile[T], onResult onResult[T]) error {
|
||||
onFile onFile[T], onResult onWalkResult[T]) error {
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
paths := make(chan string)
|
||||
|
||||
Reference in New Issue
Block a user