refactor: add pipeline (#3868)

This commit is contained in:
Teppei Fukuda
2023-03-19 19:55:36 +02:00
committed by GitHub
parent 927acf9579
commit 7148de3252
3 changed files with 191 additions and 2 deletions

107
pkg/parallel/pipeline.go Normal file
View File

@@ -0,0 +1,107 @@
package parallel
import (
"context"
"github.com/cheggaaa/pb/v3"
"golang.org/x/sync/errgroup"
)
type (
	// Pipeline bundles everything needed for one parallel run: the input
	// items, the per-item callback, the per-result callback, the number of
	// workers and whether a progress bar is shown.
	// T is the type of an input element, U the type of an output element.
	Pipeline[T, U any] struct {
		numWorkers int
		items      []T
		onItem     onItem[T, U]
		onResult   onResult[U]
		progress   bool
	}

	// onItem is the callback applied to a single input element; it yields
	// the corresponding output element or an error.
	onItem[T, U any] func(T) (U, error)

	// onResult is the callback applied to a single output element.
	onResult[U any] func(U) error
)
// NewPipeline builds a Pipeline from its parts.
// numWorkers is the number of worker goroutines, progress toggles the
// progress bar, items are the inputs, fn1 is applied to every item and fn2 to
// every result produced by fn1.
func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
	fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
	var pipeline Pipeline[T, U]

	pipeline.items = items
	pipeline.onItem = fn1
	pipeline.onResult = fn2
	pipeline.numWorkers = numWorkers
	pipeline.progress = progress

	return pipeline
}
// Do executes pipeline processing.
//
// One goroutine feeds p.items into a channel, a fixed number of worker
// goroutines apply onItem to each item, and the outputs are passed to
// onResult one at a time on the calling goroutine.
//
// It exits when any error occurs: the returned error is the one from
// onResult, or otherwise the first error recorded by the worker group (an
// onItem failure or the cancellation of ctx). It returns nil when every item
// and every result has been processed successfully.
func (p *Pipeline[T, U]) Do(ctx context.Context) error {
	// With no workers nothing would ever read from itemCh, so the feeder
	// would block and Do would never return. Always run at least one worker.
	workers := p.numWorkers
	if workers < 1 {
		workers = 1
	}

	// Cancel on every return path. When onResult fails we stop reading from
	// the results channel; without this cancellation the workers blocked on
	// sending a result (and the feeder blocked on sending an item) would
	// never be released.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// progress bar
	var bar *pb.ProgressBar
	if p.progress {
		bar = pb.StartNew(len(p.items))
		defer bar.Finish()
	}

	g, ctx := errgroup.WithContext(ctx)
	itemCh := make(chan T)

	// Start a goroutine to send input data
	g.Go(func() error {
		defer close(itemCh)
		for _, item := range p.items {
			if p.progress {
				bar.Increment()
			}
			select {
			case itemCh <- item:
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		return nil
	})

	// Generate a channel for sending output data
	results := make(chan U)

	// Start a fixed number of goroutines to process items.
	for i := 0; i < workers; i++ {
		g.Go(func() error {
			for item := range itemCh {
				res, err := p.onItem(item)
				if err != nil {
					return err
				}
				select {
				case results <- res:
				case <-ctx.Done():
					return ctx.Err()
				}
			}
			return nil
		})
	}

	// Close the results channel once the feeder and all workers are done so
	// that the loop below terminates. The error is deliberately ignored here;
	// it is picked up by the g.Wait() call at the end of Do.
	go func() {
		_ = g.Wait()
		close(results)
	}()

	// Process output data received from the channel
	for res := range results {
		if err := p.onResult(res); err != nil {
			return err
		}
	}

	// Check whether any of the goroutines failed. Since g is accumulating the
	// errors, we don't need to send them (or check for them) in the individual
	// results sent on the channel.
	if err := g.Wait(); err != nil {
		return err
	}
	return nil
}

View File

@@ -0,0 +1,82 @@
package parallel_test
import (
"context"
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/parallel"
)
func TestPipeline_Do(t *testing.T) {
type field struct {
numWorkers int
items []float64
onItem func(float64) (float64, error)
}
type testCase struct {
name string
field field
want float64
wantErr bool
}
tests := []testCase{
{
name: "pow",
field: field{
numWorkers: 5,
items: []float64{
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
},
onItem: func(f float64) (float64, error) {
return math.Pow(f, 2), nil
},
},
want: 385,
},
{
name: "ceil",
field: field{
numWorkers: 3,
items: []float64{
1.1,
2.2,
3.3,
4.4,
5.5,
-1.1,
-2.2,
-3.3,
},
onItem: func(f float64) (float64, error) {
return math.Round(f), nil
},
},
want: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got float64
p := parallel.NewPipeline(tt.field.numWorkers, false, tt.field.items, tt.field.onItem, func(f float64) error {
got += f
return nil
})
err := p.Do(context.Background())
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -12,10 +12,10 @@ import (
)
type onFile[T any] func(string, fs.FileInfo, dio.ReadSeekerAt) (T, error)
type onResult[T any] func(T) error
type onWalkResult[T any] func(T) error
func WalkDir[T any](ctx context.Context, fsys fs.FS, root string, slow bool,
onFile onFile[T], onResult onResult[T]) error {
onFile onFile[T], onResult onWalkResult[T]) error {
g, ctx := errgroup.WithContext(ctx)
paths := make(chan string)