feat: add generic worker pool for concurrent task processing

implement channel-based work distribution with generics for type-safe
concurrent processing, includes run, runwithfilter, and foreach methods
with comprehensive test coverage
This commit is contained in:
vmfunc
2026-01-02 22:59:45 -08:00
parent aba8c410a6
commit 7ab5cfc18c
2 changed files with 353 additions and 0 deletions
+170
View File
@@ -0,0 +1,170 @@
/*
·━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━·
: :
: █▀ █ █▀▀ · Blazing-fast pentesting suite :
: ▄█ █ █▀ · BSD 3-Clause License :
: :
: (c) 2022-2025 vmfunc, xyzeva, :
: lunchcat alumni & contributors :
: :
·━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━·
*/
// Package worker provides a generic worker pool for concurrent task processing.
package worker
import "sync"
// Pool manages a pool of workers that process items concurrently.
// It uses channel-based distribution for efficient load balancing.
type Pool[T any, R any] struct {
workers int
fn func(T) R
}
// New creates a new worker pool with the specified number of workers
// and a processing function.
func New[T any, R any](workers int, fn func(T) R) *Pool[T, R] {
if workers < 1 {
workers = 1
}
return &Pool[T, R]{
workers: workers,
fn: fn,
}
}
// Run processes all items concurrently and returns the results.
// Items are distributed via a channel for optimal load balancing.
func (p *Pool[T, R]) Run(items []T) []R {
if len(items) == 0 {
return nil
}
input := make(chan T, len(items))
output := make(chan R, len(items))
var wg sync.WaitGroup
wg.Add(p.workers)
// Start workers
for i := 0; i < p.workers; i++ {
go func() {
defer wg.Done()
for item := range input {
output <- p.fn(item)
}
}()
}
// Feed items to workers
for _, item := range items {
input <- item
}
close(input)
// Wait for all workers to finish, then close output
go func() {
wg.Wait()
close(output)
}()
// Collect results
results := make([]R, 0, len(items))
for r := range output {
results = append(results, r)
}
return results
}
// RunWithFilter processes items concurrently and returns only non-zero results.
// Useful when the processing function may return zero values for filtered items.
func (p *Pool[T, R]) RunWithFilter(items []T, filter func(R) bool) []R {
if len(items) == 0 {
return nil
}
input := make(chan T, len(items))
output := make(chan R, len(items))
var wg sync.WaitGroup
wg.Add(p.workers)
// Start workers
for i := 0; i < p.workers; i++ {
go func() {
defer wg.Done()
for item := range input {
result := p.fn(item)
if filter(result) {
output <- result
}
}
}()
}
// Feed items to workers
for _, item := range items {
input <- item
}
close(input)
// Wait for all workers to finish, then close output
go func() {
wg.Wait()
close(output)
}()
// Collect results
results := make([]R, 0, len(items)/2) // Estimate half will pass filter
for r := range output {
results = append(results, r)
}
return results
}
// ForEach processes items concurrently without collecting results.
// Useful for side-effect operations like logging or writing to external stores.
func (p *Pool[T, R]) ForEach(items []T, callback func(R)) {
if len(items) == 0 {
return
}
input := make(chan T, len(items))
output := make(chan R, len(items))
var wg sync.WaitGroup
wg.Add(p.workers)
// Start workers
for i := 0; i < p.workers; i++ {
go func() {
defer wg.Done()
for item := range input {
output <- p.fn(item)
}
}()
}
// Feed items to workers
for _, item := range items {
input <- item
}
close(input)
// Process results as they come in
var outputWg sync.WaitGroup
outputWg.Add(1)
go func() {
defer outputWg.Done()
for r := range output {
callback(r)
}
}()
wg.Wait()
close(output)
outputWg.Wait()
}
+183
View File
@@ -0,0 +1,183 @@
/*
·━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━·
: :
: █▀ █ █▀▀ · Blazing-fast pentesting suite :
: ▄█ █ █▀ · BSD 3-Clause License :
: :
: (c) 2022-2025 vmfunc, xyzeva, :
: lunchcat alumni & contributors :
: :
·━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━·
*/
package worker
import (
"sort"
"sync/atomic"
"testing"
)
func TestPoolRun(t *testing.T) {
pool := New(4, func(x int) int {
return x * 2
})
items := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
results := pool.Run(items)
if len(results) != len(items) {
t.Errorf("Expected %d results, got %d", len(items), len(results))
}
// Sort results since order is not guaranteed
sort.Ints(results)
expected := []int{2, 4, 6, 8, 10, 12, 14, 16, 18, 20}
for i, v := range results {
if v != expected[i] {
t.Errorf("Expected results[%d] = %d, got %d", i, expected[i], v)
}
}
}
func TestPoolRunEmpty(t *testing.T) {
pool := New(4, func(x int) int {
return x * 2
})
results := pool.Run(nil)
if results != nil {
t.Errorf("Expected nil for empty input, got %v", results)
}
results = pool.Run([]int{})
if results != nil {
t.Errorf("Expected nil for empty slice, got %v", results)
}
}
func TestPoolRunWithFilter(t *testing.T) {
pool := New(4, func(x int) int {
return x * 2
})
items := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
results := pool.RunWithFilter(items, func(r int) bool {
return r > 10 // Only keep results > 10
})
// Should have 5 results: 12, 14, 16, 18, 20
if len(results) != 5 {
t.Errorf("Expected 5 results, got %d", len(results))
}
sort.Ints(results)
expected := []int{12, 14, 16, 18, 20}
for i, v := range results {
if v != expected[i] {
t.Errorf("Expected results[%d] = %d, got %d", i, expected[i], v)
}
}
}
func TestPoolForEach(t *testing.T) {
var sum atomic.Int64
pool := New(4, func(x int) int {
return x * 2
})
items := []int{1, 2, 3, 4, 5}
pool.ForEach(items, func(r int) {
sum.Add(int64(r))
})
// Sum should be 2+4+6+8+10 = 30
if sum.Load() != 30 {
t.Errorf("Expected sum = 30, got %d", sum.Load())
}
}
func TestPoolSingleWorker(t *testing.T) {
pool := New(1, func(x int) int {
return x + 1
})
items := []int{1, 2, 3}
results := pool.Run(items)
if len(results) != 3 {
t.Errorf("Expected 3 results, got %d", len(results))
}
sort.Ints(results)
expected := []int{2, 3, 4}
for i, v := range results {
if v != expected[i] {
t.Errorf("Expected results[%d] = %d, got %d", i, expected[i], v)
}
}
}
func TestPoolZeroWorkers(t *testing.T) {
// Zero workers should default to 1
pool := New(0, func(x int) int {
return x
})
if pool.workers != 1 {
t.Errorf("Expected workers = 1, got %d", pool.workers)
}
}
func TestPoolStringProcessing(t *testing.T) {
pool := New(2, func(s string) int {
return len(s)
})
items := []string{"a", "bb", "ccc", "dddd"}
results := pool.Run(items)
sort.Ints(results)
expected := []int{1, 2, 3, 4}
for i, v := range results {
if v != expected[i] {
t.Errorf("Expected results[%d] = %d, got %d", i, expected[i], v)
}
}
}
func TestPoolStructProcessing(t *testing.T) {
type input struct {
a int
b int
}
type output struct {
sum int
prod int
}
pool := New(3, func(in input) output {
return output{sum: in.a + in.b, prod: in.a * in.b}
})
items := []input{{1, 2}, {3, 4}, {5, 6}}
results := pool.Run(items)
if len(results) != 3 {
t.Errorf("Expected 3 results, got %d", len(results))
}
// Verify all expected outputs are present
found := make(map[output]bool)
for _, r := range results {
found[r] = true
}
expectedOutputs := []output{{3, 2}, {7, 12}, {11, 30}}
for _, exp := range expectedOutputs {
if !found[exp] {
t.Errorf("Expected output %v not found in results", exp)
}
}
}