mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-06 04:41:18 -08:00
116 lines
2.5 KiB
Go
116 lines
2.5 KiB
Go
package parallel
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/cheggaaa/pb/v3"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// defaultWorkers is the number of workers NewPipeline uses when the caller passes 0.
const defaultWorkers = 5
|
|
|
|
// Pipeline represents a structure for performing parallel processing.
|
|
// T represents the input element type and U represents the output element type.
|
|
type Pipeline[T, U any] struct {
|
|
numWorkers int
|
|
items []T
|
|
onItem onItem[T, U]
|
|
onResult onResult[U]
|
|
progress bool
|
|
}
|
|
|
|
// onItem represents a function type that takes an input element and returns an output element.
|
|
// Do calls it with the context it derives for the pipeline; when it returns a
// non-nil error, the worker that made the call returns that error.
type onItem[T, U any] func(ctx context.Context, item T) (U, error)
|
|
|
|
// onResult represents a function type that takes an output element.
|
|
// Do calls it once per output it receives; a non-nil error makes Do return
// that error.
type onResult[U any] func(result U) error
|
|
|
|
func NewPipeline[T, U any](numWorkers int, progress bool, items []T,
|
|
fn1 onItem[T, U], fn2 onResult[U]) Pipeline[T, U] {
|
|
if fn2 == nil {
|
|
// In case where there is no need to process the return values
|
|
fn2 = func(_ U) error { return nil }
|
|
}
|
|
if numWorkers == 0 {
|
|
numWorkers = defaultWorkers
|
|
}
|
|
return Pipeline[T, U]{
|
|
numWorkers: numWorkers,
|
|
progress: progress,
|
|
items: items,
|
|
onItem: fn1,
|
|
onResult: fn2,
|
|
}
|
|
}
|
|
|
|
// Do executes pipeline processing.
|
|
// It exits when any error occurs.
|
|
func (p *Pipeline[T, U]) Do(ctx context.Context) error {
|
|
// progress bar
|
|
var bar *pb.ProgressBar
|
|
if p.progress {
|
|
bar = pb.StartNew(len(p.items))
|
|
defer bar.Finish()
|
|
}
|
|
|
|
g, ctx := errgroup.WithContext(ctx)
|
|
itemCh := make(chan T)
|
|
|
|
// Start a goroutine to send input data
|
|
g.Go(func() error {
|
|
defer close(itemCh)
|
|
for _, item := range p.items {
|
|
if p.progress {
|
|
bar.Increment()
|
|
}
|
|
select {
|
|
case itemCh <- item:
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
|
|
// Generate a channel for sending output data
|
|
results := make(chan U)
|
|
|
|
// Start a fixed number of goroutines to process items.
|
|
for i := 0; i < p.numWorkers; i++ {
|
|
g.Go(func() error {
|
|
for item := range itemCh {
|
|
res, err := p.onItem(ctx, item)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case results <- res:
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
go func() {
|
|
_ = g.Wait()
|
|
close(results)
|
|
}()
|
|
|
|
// Process output data received from the channel
|
|
for res := range results {
|
|
if err := p.onResult(res); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Check whether any of the goroutines failed. Since g is accumulating the
|
|
// errors, we don't need to send them (or check for them) in the individual
|
|
// results sent on the channel.
|
|
if err := g.Wait(); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|