diff --git a/pkg/parallel/pipeline.go b/pkg/parallel/pipeline.go
new file mode 100644
index 0000000000..f19f064d61
--- /dev/null
+++ b/pkg/parallel/pipeline.go
@@ -0,0 +1,124 @@
+package parallel
+
+import (
+	"context"
+
+	"github.com/cheggaaa/pb/v3"
+
+	"golang.org/x/sync/errgroup"
+)
+
+// Pipeline represents a structure for performing parallel processing.
+// T represents the input element type and U represents the output element type.
+type Pipeline[T, U any] struct {
+	numWorkers int
+	items      []T
+	onItem     onItem[T, U]
+	onResult   onResult[U]
+	progress   bool
+}
+
+// onItem represents a function type that takes an input element and returns an output element.
+type onItem[T, U any] func(T) (U, error)
+
+// onResult represents a function type that takes an output element.
+type onResult[U any] func(U) error
+
+// NewPipeline returns a Pipeline that applies fn1 to every element of items
+// using numWorkers goroutines and passes each output to fn2.
+// fn2 is only called from the goroutine that runs Do, one result at a time.
+// If progress is true, Do starts a progress bar whose total is len(items).
+func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
+	fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
+	return Pipeline[T, U]{
+		numWorkers: numWorkers,
+		progress:   progress,
+		items:      items,
+		onItem:     fn1,
+		onResult:   fn2,
+	}
+}
+
+// Do executes pipeline processing.
+// It exits when any error occurs, returning the error reported by onItem or
+// onResult, or the context error if ctx is canceled.
+// With more than one worker, results may reach onResult in any order.
+func (p *Pipeline[T, U]) Do(ctx context.Context) error {
+	// Without at least one worker nothing would ever receive from itemCh and
+	// Do would block forever on a non-empty input.
+	numWorkers := p.numWorkers
+	if numWorkers < 1 {
+		numWorkers = 1
+	}
+
+	// progress bar
+	var bar *pb.ProgressBar
+	if p.progress {
+		bar = pb.StartNew(len(p.items))
+		defer bar.Finish()
+	}
+
+	// Canceling on return releases workers blocked on sending to results when
+	// onResult fails and the receive loop below stops early.
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	g, ctx := errgroup.WithContext(ctx)
+	itemCh := make(chan T)
+
+	// Start a goroutine to send input data
+	g.Go(func() error {
+		defer close(itemCh)
+		for _, item := range p.items {
+			if p.progress {
+				bar.Increment()
+			}
+			select {
+			case itemCh <- item:
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+		return nil
+	})
+
+	// Generate a channel for sending output data
+	results := make(chan U)
+
+	// Start a fixed number of goroutines to process items.
+	for i := 0; i < numWorkers; i++ {
+		g.Go(func() error {
+			for item := range itemCh {
+				res, err := p.onItem(item)
+				if err != nil {
+					return err
+				}
+				select {
+				case results <- res:
+				case <-ctx.Done():
+					return ctx.Err()
+				}
+			}
+			return nil
+		})
+	}
+
+	// Close results once the sender and all workers have finished so that the
+	// receive loop below terminates. The error is collected by the final Wait.
+	go func() {
+		_ = g.Wait()
+		close(results)
+	}()
+
+	// Process output data received from the channel
+	for res := range results {
+		if err := p.onResult(res); err != nil {
+			return err
+		}
+	}
+
+	// Check whether any of the goroutines failed. Since g is accumulating the
+	// errors, we don't need to send them (or check for them) in the individual
+	// results sent on the channel.
+	return g.Wait()
+}
diff --git a/pkg/parallel/pipeline_test.go b/pkg/parallel/pipeline_test.go
new file mode 100644
index 0000000000..c4a064e695
--- /dev/null
+++ b/pkg/parallel/pipeline_test.go
@@ -0,0 +1,117 @@
+package parallel_test
+
+import (
+	"context"
+	"errors"
+	"math"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/aquasecurity/trivy/pkg/parallel"
+)
+
+func TestPipeline_Do(t *testing.T) {
+	type field struct {
+		numWorkers int
+		items      []float64
+		onItem     func(float64) (float64, error)
+	}
+	type testCase struct {
+		name    string
+		field   field
+		want    float64
+		wantErr bool
+	}
+	tests := []testCase{
+		{
+			name: "pow",
+			field: field{
+				numWorkers: 5,
+				items: []float64{
+					1,
+					2,
+					3,
+					4,
+					5,
+					6,
+					7,
+					8,
+					9,
+					10,
+				},
+				onItem: func(f float64) (float64, error) {
+					return math.Pow(f, 2), nil
+				},
+			},
+			want: 385,
+		},
+		{
+			name: "round",
+			field: field{
+				numWorkers: 3,
+				items: []float64{
+					1.1,
+					2.2,
+					3.3,
+					4.4,
+					5.5,
+					-1.1,
+					-2.2,
+					-3.3,
+				},
+				onItem: func(f float64) (float64, error) {
+					return math.Round(f), nil
+				},
+			},
+			want: 10,
+		},
+		{
+			name: "zero workers",
+			field: field{
+				numWorkers: 0,
+				items: []float64{
+					1,
+					2,
+					3,
+				},
+				onItem: func(f float64) (float64, error) {
+					return f, nil
+				},
+			},
+			want: 6,
+		},
+		{
+			name: "error in onItem",
+			field: field{
+				numWorkers: 2,
+				items: []float64{
+					1,
+					2,
+					3,
+				},
+				onItem: func(f float64) (float64, error) {
+					return 0, errors.New("error")
+				},
+			},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var got float64
+			p := parallel.NewPipeline(tt.field.numWorkers, false, tt.field.items, tt.field.onItem, func(f float64) error {
+				got += f
+				return nil
+			})
+			err := p.Do(context.Background())
+			if tt.wantErr {
+				require.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
diff --git a/pkg/parallel/walk.go b/pkg/parallel/walk.go
index e1bd8796a7..4cf832381d 100644
--- a/pkg/parallel/walk.go
+++ b/pkg/parallel/walk.go
@@ -12,10 +12,10 @@ import (
 )
 
 type onFile[T any] func(string, fs.FileInfo, dio.ReadSeekerAt) (T, error)
-type onResult[T any] func(T) error
+type onWalkResult[T any] func(T) error
 
 func WalkDir[T any](ctx context.Context, fsys fs.FS, root string, slow bool,
-	onFile onFile[T], onResult onResult[T]) error {
+	onFile onFile[T], onResult onWalkResult[T]) error {
 
 	g, ctx := errgroup.WithContext(ctx)
 	paths := make(chan string)