refactor: add generic Set implementation (#8149)

Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Teppei Fukuda
2024-12-24 13:47:21 +09:00
committed by GitHub
parent e6d0ba5cc9
commit b5859d3fb5
34 changed files with 968 additions and 270 deletions

View File

@@ -30,3 +30,8 @@ func errorsJoin(m dsl.Matcher) {
m.Match(`errors.Join($*args)`). m.Match(`errors.Join($*args)`).
Report("use github.com/hashicorp/go-multierror.Append instead of errors.Join.") Report("use github.com/hashicorp/go-multierror.Append instead of errors.Join.")
} }
func mapSet(m dsl.Matcher) {
m.Match(`map[$x]struct{}`).
Report("use github.com/aquasecurity/trivy/pkg/set.Set instead of map.")
}

View File

@@ -6,13 +6,13 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/samber/lo"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
sp "github.com/aquasecurity/trivy-checks/pkg/spec" sp "github.com/aquasecurity/trivy-checks/pkg/spec"
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types" iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
) )
@@ -31,17 +31,17 @@ const (
// Scanners reads spec control and determines the scanners by check ID prefix // Scanners reads spec control and determines the scanners by check ID prefix
func (cs *ComplianceSpec) Scanners() (types.Scanners, error) { func (cs *ComplianceSpec) Scanners() (types.Scanners, error) {
scannerTypes := make(map[types.Scanner]struct{}) scannerTypes := set.New[types.Scanner]()
for _, control := range cs.Spec.Controls { for _, control := range cs.Spec.Controls {
for _, check := range control.Checks { for _, check := range control.Checks {
scannerType := scannerByCheckID(check.ID) scannerType := scannerByCheckID(check.ID)
if scannerType == types.UnknownScanner { if scannerType == types.UnknownScanner {
return nil, xerrors.Errorf("unsupported check ID: %s", check.ID) return nil, xerrors.Errorf("unsupported check ID: %s", check.ID)
} }
scannerTypes[scannerType] = struct{}{} scannerTypes.Append(scannerType)
} }
} }
return lo.Keys(scannerTypes), nil return scannerTypes.Items(), nil
} }
// CheckIDs return list of compliance check IDs // CheckIDs return list of compliance check IDs

View File

@@ -12,6 +12,7 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/version/doc" "github.com/aquasecurity/trivy/pkg/version/doc"
) )
@@ -30,7 +31,7 @@ type artifact struct {
Version version Version version
Licenses []string Licenses []string
Exclusions map[string]struct{} Exclusions set.Set[string]
Module bool Module bool
Relationship ftypes.Relationship Relationship ftypes.Relationship

View File

@@ -22,6 +22,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils" "github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io" xio "github.com/aquasecurity/trivy/pkg/x/io"
) )
@@ -118,11 +119,11 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
rootArt := root.artifact() rootArt := root.artifact()
rootArt.Relationship = ftypes.RelationshipRoot rootArt.Relationship = ftypes.RelationshipRoot
return p.parseRoot(rootArt, make(map[string]struct{})) return p.parseRoot(rootArt, set.New[string]())
} }
// nolint: gocyclo // nolint: gocyclo
func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ftypes.Package, []ftypes.Dependency, error) { func (p *Parser) parseRoot(root artifact, uniqModules set.Set[string]) ([]ftypes.Package, []ftypes.Dependency, error) {
// Prepare a queue for dependencies // Prepare a queue for dependencies
queue := newArtifactQueue() queue := newArtifactQueue()
@@ -145,10 +146,10 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// Modules should be handled separately so that they can have independent dependencies. // Modules should be handled separately so that they can have independent dependencies.
// It means multi-module allows for duplicate dependencies. // It means multi-module allows for duplicate dependencies.
if art.Module { if art.Module {
if _, ok := uniqModules[art.String()]; ok { if uniqModules.Contains(art.String()) {
continue continue
} }
uniqModules[art.String()] = struct{}{} uniqModules.Append(art.String())
modulePkgs, moduleDeps, err := p.parseRoot(art, uniqModules) modulePkgs, moduleDeps, err := p.parseRoot(art, uniqModules)
if err != nil { if err != nil {
@@ -251,7 +252,7 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// `mvn` shows modules separately from the root package and does not show module nesting. // `mvn` shows modules separately from the root package and does not show module nesting.
// So we can add all modules as dependencies of root package. // So we can add all modules as dependencies of root package.
if art.Relationship == ftypes.RelationshipRoot { if art.Relationship == ftypes.RelationshipRoot {
dependsOn = append(dependsOn, lo.Keys(uniqModules)...) dependsOn = append(dependsOn, uniqModules.Items()...)
} }
sort.Strings(dependsOn) sort.Strings(dependsOn)
@@ -340,7 +341,7 @@ type analysisResult struct {
} }
type analysisOptions struct { type analysisOptions struct {
exclusions map[string]struct{} exclusions set.Set[string]
depManagement []pomDependency // from the root POM depManagement []pomDependency // from the root POM
} }
@@ -348,6 +349,9 @@ func (p *Parser) analyze(pom *pom, opts analysisOptions) (analysisResult, error)
if pom.nil() { if pom.nil() {
return analysisResult{}, nil return analysisResult{}, nil
} }
if opts.exclusions == nil {
opts.exclusions = set.New[string]()
}
// Update remoteRepositories // Update remoteRepositories
pomReleaseRemoteRepos, pomSnapshotRemoteRepos := pom.repositories(p.servers) pomReleaseRemoteRepos, pomSnapshotRemoteRepos := pom.repositories(p.servers)
p.releaseRemoteRepos = lo.Uniq(append(pomReleaseRemoteRepos, p.releaseRemoteRepos...)) p.releaseRemoteRepos = lo.Uniq(append(pomReleaseRemoteRepos, p.releaseRemoteRepos...))
@@ -408,16 +412,16 @@ func (p *Parser) resolveParent(pom *pom) error {
} }
func (p *Parser) mergeDependencyManagements(depManagements ...[]pomDependency) []pomDependency { func (p *Parser) mergeDependencyManagements(depManagements ...[]pomDependency) []pomDependency {
uniq := make(map[string]struct{}) uniq := set.New[string]()
var depManagement []pomDependency var depManagement []pomDependency
// The preceding argument takes precedence. // The preceding argument takes precedence.
for _, dm := range depManagements { for _, dm := range depManagements {
for _, dep := range dm { for _, dep := range dm {
if _, ok := uniq[dep.Name()]; ok { if uniq.Contains(dep.Name()) {
continue continue
} }
depManagement = append(depManagement, dep) depManagement = append(depManagement, dep)
uniq[dep.Name()] = struct{}{} uniq.Append(dep.Name())
} }
} }
return depManagement return depManagement
@@ -492,19 +496,19 @@ func (p *Parser) mergeDependencies(child, parent []pomDependency) []pomDependenc
}) })
} }
func (p *Parser) filterDependencies(artifacts []artifact, exclusions map[string]struct{}) []artifact { func (p *Parser) filterDependencies(artifacts []artifact, exclusions set.Set[string]) []artifact {
return lo.Filter(artifacts, func(art artifact, _ int) bool { return lo.Filter(artifacts, func(art artifact, _ int) bool {
return !excludeDep(exclusions, art) return !excludeDep(exclusions, art)
}) })
} }
func excludeDep(exclusions map[string]struct{}, art artifact) bool { func excludeDep(exclusions set.Set[string], art artifact) bool {
if _, ok := exclusions[art.Name()]; ok { if exclusions.Contains(art.Name()) {
return true return true
} }
// Maven can use "*" in GroupID and ArtifactID fields to exclude dependencies // Maven can use "*" in GroupID and ArtifactID fields to exclude dependencies
// https://maven.apache.org/pom.html#exclusions // https://maven.apache.org/pom.html#exclusions
for exlusion := range exclusions { for exlusion := range exclusions.Iter() {
// exclusion format - "<groupID>:<artifactID>" // exclusion format - "<groupID>:<artifactID>"
e := strings.Split(exlusion, ":") e := strings.Split(exlusion, ":")
if (e[0] == art.GroupID || e[0] == "*") && (e[1] == art.ArtifactID || e[1] == "*") { if (e[0] == art.GroupID || e[0] == "*") && (e[1] == art.ArtifactID || e[1] == "*") {

View File

@@ -4,7 +4,6 @@ import (
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io" "io"
"maps"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
@@ -15,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils" "github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/x/slices" "github.com/aquasecurity/trivy/pkg/x/slices"
) )
@@ -287,12 +287,12 @@ func (d pomDependency) ToArtifact(opts analysisOptions) artifact {
// To avoid shadow adding exclusions to top pom's, // To avoid shadow adding exclusions to top pom's,
// we need to initialize a new map for each new artifact // we need to initialize a new map for each new artifact
// See `exclusions in child` test for more information // See `exclusions in child` test for more information
exclusions := make(map[string]struct{}) exclusions := set.New[string]()
if opts.exclusions != nil { if opts.exclusions != nil {
exclusions = maps.Clone(opts.exclusions) exclusions = opts.exclusions.Clone()
} }
for _, e := range d.Exclusions.Exclusion { for _, e := range d.Exclusions.Exclusion {
exclusions[fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID)] = struct{}{} exclusions.Append(fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID))
} }
var locations ftypes.Locations var locations ftypes.Locations

View File

@@ -17,6 +17,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils" "github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io" xio "github.com/aquasecurity/trivy/pkg/x/io"
) )
@@ -91,7 +92,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
// https://docs.npmjs.com/cli/v9/configuring-npm/package-lock-json#packages // https://docs.npmjs.com/cli/v9/configuring-npm/package-lock-json#packages
p.resolveLinks(packages) p.resolveLinks(packages)
directDeps := make(map[string]struct{}) directDeps := set.New[string]()
for name, version := range lo.Assign(packages[""].Dependencies, packages[""].OptionalDependencies, packages[""].DevDependencies, packages[""].PeerDependencies) { for name, version := range lo.Assign(packages[""].Dependencies, packages[""].OptionalDependencies, packages[""].DevDependencies, packages[""].PeerDependencies) {
pkgPath := joinPaths(nodeModulesDir, name) pkgPath := joinPaths(nodeModulesDir, name)
if _, ok := packages[pkgPath]; !ok { if _, ok := packages[pkgPath]; !ok {
@@ -101,7 +102,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
} }
// Store the package paths of direct dependencies // Store the package paths of direct dependencies
// e.g. node_modules/body-parser // e.g. node_modules/body-parser
directDeps[pkgPath] = struct{}{} directDeps.Append(pkgPath)
} }
for pkgPath, pkg := range packages { for pkgPath, pkg := range packages {
@@ -366,13 +367,13 @@ func (p *Parser) pkgNameFromPath(pkgPath string) string {
func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency { func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
var uniqDeps ftypes.Dependencies var uniqDeps ftypes.Dependencies
unique := make(map[string]struct{}) unique := set.New[string]()
for _, dep := range deps { for _, dep := range deps {
sort.Strings(dep.DependsOn) sort.Strings(dep.DependsOn)
depKey := fmt.Sprintf("%s:%s", dep.ID, strings.Join(dep.DependsOn, ",")) depKey := fmt.Sprintf("%s:%s", dep.ID, strings.Join(dep.DependsOn, ","))
if _, ok := unique[depKey]; !ok { if !unique.Contains(depKey) {
unique[depKey] = struct{}{} unique.Append(depKey)
uniqDeps = append(uniqDeps, dep) uniqDeps = append(uniqDeps, dep)
} }
} }
@@ -381,11 +382,11 @@ func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
return uniqDeps return uniqDeps
} }
func isIndirectPkg(pkgPath string, directDeps map[string]struct{}) bool { func isIndirectPkg(pkgPath string, directDeps set.Set[string]) bool {
// A project can contain 2 different versions of the same dependency. // A project can contain 2 different versions of the same dependency.
// e.g. `node_modules/string-width/node_modules/strip-ansi` and `node_modules/string-ansi` // e.g. `node_modules/string-width/node_modules/strip-ansi` and `node_modules/string-ansi`
// direct dependencies always have root path (`node_modules/<pkg_name>`) // direct dependencies always have root path (`node_modules/<pkg_name>`)
if _, ok := directDeps[pkgPath]; ok { if directDeps.Contains(pkgPath) {
return false return false
} }
return true return true

View File

@@ -14,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency" "github.com/aquasecurity/trivy/pkg/dependency"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io" xio "github.com/aquasecurity/trivy/pkg/x/io"
) )
@@ -215,7 +216,7 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
} }
} }
visited := make(map[string]struct{}) visited := set.New[string]()
// Overwrite the `Dev` field for dev deps and their child dependencies. // Overwrite the `Dev` field for dev deps and their child dependencies.
for _, pkg := range resolvedPkgs { for _, pkg := range resolvedPkgs {
if !pkg.Dev { if !pkg.Dev {
@@ -227,8 +228,8 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
} }
// markRootPkgs sets `Dev` to false for non dev dependency. // markRootPkgs sets `Dev` to false for non dev dependency.
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited map[string]struct{}) { func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited set.Set[string]) {
if _, ok := visited[id]; ok { if visited.Contains(id) {
return return
} }
pkg, ok := pkgs[id] pkg, ok := pkgs[id]
@@ -238,7 +239,7 @@ func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps ma
pkg.Dev = false pkg.Dev = false
pkgs[id] = pkg pkgs[id] = pkg
visited[id] = struct{}{} visited.Append(id)
// Update child deps // Update child deps
for _, depID := range deps[id].DependsOn { for _, depID := range deps[id].DependsOn {

View File

@@ -76,7 +76,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
} }
if savedDependsOn, ok := depsMap[depId]; ok { if savedDependsOn, ok := depsMap[depId]; ok {
dependsOn = utils.UniqueStrings(append(dependsOn, savedDependsOn...)) dependsOn = lo.Uniq(append(dependsOn, savedDependsOn...))
} }
if len(dependsOn) > 0 { if len(dependsOn) > 0 {

View File

@@ -8,6 +8,7 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/dependency/parser/python" "github.com/aquasecurity/trivy/pkg/dependency/parser/python"
"github.com/aquasecurity/trivy/pkg/set"
) )
type PyProject struct { type PyProject struct {
@@ -19,25 +20,27 @@ type Tool struct {
} }
type Poetry struct { type Poetry struct {
Dependencies dependencies `toml:"dependencies"` Dependencies Dependencies `toml:"dependencies"`
Groups map[string]Group `toml:"group"` Groups map[string]Group `toml:"group"`
} }
type Group struct { type Group struct {
Dependencies dependencies `toml:"dependencies"` Dependencies Dependencies `toml:"dependencies"`
} }
type dependencies map[string]struct{} type Dependencies struct {
set.Set[string]
}
func (d *dependencies) UnmarshalTOML(data any) error { func (d *Dependencies) UnmarshalTOML(data any) error {
m, ok := data.(map[string]any) m, ok := data.(map[string]any)
if !ok { if !ok {
return xerrors.Errorf("dependencies must be map, but got: %T", data) return xerrors.Errorf("dependencies must be map, but got: %T", data)
} }
*d = lo.MapEntries(m, func(pkgName string, _ any) (string, struct{}) { d.Set = set.New[string](lo.MapToSlice(m, func(pkgName string, _ any) string {
return python.NormalizePkgName(pkgName), struct{}{} return python.NormalizePkgName(pkgName)
}) })...)
return nil return nil
} }

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/dependency/parser/python/pyproject" "github.com/aquasecurity/trivy/pkg/dependency/parser/python/pyproject"
"github.com/aquasecurity/trivy/pkg/set"
) )
func TestParser_Parse(t *testing.T) { func TestParser_Parse(t *testing.T) {
@@ -24,21 +25,18 @@ func TestParser_Parse(t *testing.T) {
want: pyproject.PyProject{ want: pyproject.PyProject{
Tool: pyproject.Tool{ Tool: pyproject.Tool{
Poetry: pyproject.Poetry{ Poetry: pyproject.Poetry{
Dependencies: map[string]struct{}{ Dependencies: pyproject.Dependencies{
"flask": {}, Set: set.New[string]("flask", "python", "requests", "virtualenv"),
"python": {},
"requests": {},
"virtualenv": {},
}, },
Groups: map[string]pyproject.Group{ Groups: map[string]pyproject.Group{
"dev": { "dev": {
Dependencies: map[string]struct{}{ Dependencies: pyproject.Dependencies{
"pytest": {}, Set: set.New[string]("pytest"),
}, },
}, },
"lint": { "lint": {
Dependencies: map[string]struct{}{ Dependencies: pyproject.Dependencies{
"ruff": {}, Set: set.New[string]("ruff"),
}, },
}, },
}, },

View File

@@ -9,6 +9,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency" "github.com/aquasecurity/trivy/pkg/dependency"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io" xio "github.com/aquasecurity/trivy/pkg/x/io"
) )
@@ -22,25 +23,25 @@ func (l Lock) packages() map[string]Package {
}) })
} }
func (l Lock) directDeps(root Package) map[string]struct{} { func (l Lock) directDeps(root Package) set.Set[string] {
deps := make(map[string]struct{}) deps := set.New[string]()
for _, dep := range root.Dependencies { for _, dep := range root.Dependencies {
deps[dep.Name] = struct{}{} deps.Append(dep.Name)
} }
return deps return deps
} }
func prodDeps(root Package, packages map[string]Package) map[string]struct{} { func prodDeps(root Package, packages map[string]Package) set.Set[string] {
visited := make(map[string]struct{}) visited := set.New[string]()
walkPackageDeps(root, packages, visited) walkPackageDeps(root, packages, visited)
return visited return visited
} }
func walkPackageDeps(pkg Package, packages map[string]Package, visited map[string]struct{}) { func walkPackageDeps(pkg Package, packages map[string]Package, visited set.Set[string]) {
if _, ok := visited[pkg.Name]; ok { if visited.Contains(pkg.Name) {
return return
} }
visited[pkg.Name] = struct{}{} visited.Append(pkg.Name)
for _, dep := range pkg.Dependencies { for _, dep := range pkg.Dependencies {
depPkg, exists := packages[dep.Name] depPkg, exists := packages[dep.Name]
if !exists { if !exists {
@@ -119,7 +120,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
) )
for _, pkg := range lock.Packages { for _, pkg := range lock.Packages {
if _, ok := prodDeps[pkg.Name]; !ok { if !prodDeps.Contains(pkg.Name) {
continue continue
} }
@@ -127,7 +128,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
relationship := ftypes.RelationshipIndirect relationship := ftypes.RelationshipIndirect
if pkg.isRoot() { if pkg.isRoot() {
relationship = ftypes.RelationshipRoot relationship = ftypes.RelationshipRoot
} else if _, ok := directDeps[pkg.Name]; ok { } else if directDeps.Contains(pkg.Name) {
relationship = ftypes.RelationshipDirect relationship = ftypes.RelationshipDirect
} }

View File

@@ -10,19 +10,6 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
) )
func UniqueStrings(ss []string) []string {
var results []string
uniq := make(map[string]struct{})
for _, s := range ss {
if _, ok := uniq[s]; ok {
continue
}
results = append(results, s)
uniq[s] = struct{}{}
}
return results
}
func UniquePackages(pkgs []ftypes.Package) []ftypes.Package { func UniquePackages(pkgs []ftypes.Package) []ftypes.Package {
if len(pkgs) == 0 { if len(pkgs) == 0 {
return nil return nil

View File

@@ -19,6 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency" "github.com/aquasecurity/trivy/pkg/dependency"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
) )
const ( const (
@@ -179,33 +180,30 @@ func (a alpineCmdAnalyzer) parseCommand(command string, envs map[string]string)
return pkgs return pkgs
} }
func (a alpineCmdAnalyzer) resolveDependencies(apkIndexArchive *apkIndex, originalPkgs []string) (pkgs []string) { func (a alpineCmdAnalyzer) resolveDependencies(apkIndexArchive *apkIndex, originalPkgs []string) (pkgs []string) {
uniqPkgs := make(map[string]struct{}) uniqPkgs := set.New[string]()
for _, pkgName := range originalPkgs { for _, pkgName := range originalPkgs {
if _, ok := uniqPkgs[pkgName]; ok { if uniqPkgs.Contains(pkgName) {
continue continue
} }
seenPkgs := make(map[string]struct{}) seenPkgs := set.New[string]()
for _, p := range a.resolveDependency(apkIndexArchive, pkgName, seenPkgs) { for _, p := range a.resolveDependency(apkIndexArchive, pkgName, seenPkgs) {
uniqPkgs[p] = struct{}{} uniqPkgs.Append(p)
} }
} }
for pkg := range uniqPkgs { return uniqPkgs.Items()
pkgs = append(pkgs, pkg)
}
return pkgs
} }
func (a alpineCmdAnalyzer) resolveDependency(apkIndexArchive *apkIndex, pkgName string, func (a alpineCmdAnalyzer) resolveDependency(apkIndexArchive *apkIndex, pkgName string,
seenPkgs map[string]struct{}) (pkgNames []string) { seenPkgs set.Set[string]) (pkgNames []string) {
pkg, ok := apkIndexArchive.Package[pkgName] pkg, ok := apkIndexArchive.Package[pkgName]
if !ok { if !ok {
return nil return nil
} }
if _, ok = seenPkgs[pkgName]; ok { if seenPkgs.Contains(pkgName) {
return nil return nil
} }
seenPkgs[pkgName] = struct{}{} seenPkgs.Append(pkgName)
pkgNames = append(pkgNames, pkgName) pkgNames = append(pkgNames, pkgName)
for _, dependency := range pkg.Dependencies { for _, dependency := range pkg.Dependencies {

View File

@@ -19,6 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
) )
var ( var (
@@ -1508,86 +1509,86 @@ func TestResolveDependency(t *testing.T) {
var tests = map[string]struct { var tests = map[string]struct {
pkgName string pkgName string
apkIndexArchivePath string apkIndexArchivePath string
expected map[string]struct{} want set.Set[string]
}{ }{
"low": { "low": {
pkgName: "libblkid", pkgName: "libblkid",
apkIndexArchivePath: "testdata/history_v3.9.json", apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{ want: set.New(
"libblkid": {}, "libblkid",
"libuuid": {}, "libuuid",
"musl": {}, "musl",
}, ),
}, },
"medium": { "medium": {
pkgName: "libgcab", pkgName: "libgcab",
apkIndexArchivePath: "testdata/history_v3.9.json", apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{ want: set.New(
"busybox": {}, "busybox",
"libblkid": {}, "libblkid",
"libuuid": {}, "libuuid",
"musl": {}, "musl",
"libmount": {}, "libmount",
"pcre": {}, "pcre",
"glib": {}, "glib",
"libgcab": {}, "libgcab",
"libintl": {}, "libintl",
"zlib": {}, "zlib",
"libffi": {}, "libffi",
}, ),
}, },
"high": { "high": {
pkgName: "postgresql", pkgName: "postgresql",
apkIndexArchivePath: "testdata/history_v3.9.json", apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{ want: set.New(
"busybox": {}, "busybox",
"ncurses-terminfo-base": {}, "ncurses-terminfo-base",
"ncurses-terminfo": {}, "ncurses-terminfo",
"libedit": {}, "libedit",
"db": {}, "db",
"libsasl": {}, "libsasl",
"libldap": {}, "libldap",
"libpq": {}, "libpq",
"postgresql-client": {}, "postgresql-client",
"tzdata": {}, "tzdata",
"libxml2": {}, "libxml2",
"postgresql": {}, "postgresql",
"musl": {}, "musl",
"libcrypto1.1": {}, "libcrypto1.1",
"libssl1.1": {}, "libssl1.1",
"ncurses-libs": {}, "ncurses-libs",
"zlib": {}, "zlib",
}, ),
}, },
"package alias": { "package alias": {
pkgName: "sqlite-dev", pkgName: "sqlite-dev",
apkIndexArchivePath: "testdata/history_v3.9.json", apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{ want: set.New(
"sqlite-dev": {}, "sqlite-dev",
"sqlite-libs": {}, "sqlite-libs",
"pkgconf": {}, // pkgconfig => pkgconf "pkgconf", // pkgconfig => pkgconf
"musl": {}, "musl",
}, ),
}, },
"circular dependencies": { "circular dependencies": {
pkgName: "nodejs", pkgName: "nodejs",
apkIndexArchivePath: "testdata/history_v3.7.json", apkIndexArchivePath: "testdata/history_v3.7.json",
expected: map[string]struct{}{ want: set.New(
"busybox": {}, "busybox",
"c-ares": {}, "c-ares",
"ca-certificates": {}, "ca-certificates",
"http-parser": {}, "http-parser",
"libcrypto1.0": {}, "libcrypto1.0",
"libgcc": {}, "libgcc",
"libressl2.6-libcrypto": {}, "libressl2.6-libcrypto",
"libssl1.0": {}, "libssl1.0",
"libstdc++": {}, "libstdc++",
"libuv": {}, "libuv",
"musl": {}, "musl",
"nodejs": {}, "nodejs",
"nodejs-npm": {}, "nodejs-npm",
"zlib": {}, "zlib",
}, ),
}, },
} }
analyzer := alpineCmdAnalyzer{} analyzer := alpineCmdAnalyzer{}
@@ -1600,15 +1601,10 @@ func TestResolveDependency(t *testing.T) {
if err = json.NewDecoder(f).Decode(&apkIndexArchive); err != nil { if err = json.NewDecoder(f).Decode(&apkIndexArchive); err != nil {
t.Fatalf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
} }
circularDependencyCheck := make(map[string]struct{}) circularDependencyCheck := set.New[string]()
pkgs := analyzer.resolveDependency(apkIndexArchive, v.pkgName, circularDependencyCheck) pkgs := analyzer.resolveDependency(apkIndexArchive, v.pkgName, circularDependencyCheck)
actual := make(map[string]struct{}) got := set.New(pkgs...)
for _, pkg := range pkgs { assert.Equal(t, v.want, got, testName)
actual[pkg] = struct{}{}
}
if !reflect.DeepEqual(v.expected, actual) {
t.Errorf("[%s]\n%s", testName, pretty.Compare(v.expected, actual))
}
} }
} }

View File

@@ -17,6 +17,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/analyzer/language" "github.com/aquasecurity/trivy/pkg/fanal/analyzer/language"
"github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/utils/fsutils" "github.com/aquasecurity/trivy/pkg/utils/fsutils"
) )
@@ -105,7 +106,7 @@ func (a poetryAnalyzer) mergePyProject(fsys fs.FS, dir string, app *types.Applic
// Identify the direct/transitive dependencies // Identify the direct/transitive dependencies
for i, pkg := range app.Packages { for i, pkg := range app.Packages {
if _, ok := project.Tool.Poetry.Dependencies[pkg.Name]; ok { if project.Tool.Poetry.Dependencies.Contains(pkg.Name) {
app.Packages[i].Relationship = types.RelationshipDirect app.Packages[i].Relationship = types.RelationshipDirect
} else { } else {
app.Packages[i].Indirect = true app.Packages[i].Indirect = true
@@ -122,34 +123,33 @@ func filterProdPackages(project pyproject.PyProject, app *types.Application) {
return pkg.ID, pkg return pkg.ID, pkg
}) })
visited := make(map[string]struct{}) visited := set.New[string]()
deps := project.Tool.Poetry.Dependencies deps := project.Tool.Poetry.Dependencies
for group, groupDeps := range project.Tool.Poetry.Groups { for group, groupDeps := range project.Tool.Poetry.Groups {
if group == "dev" { if group == "dev" {
continue continue
} }
deps = lo.Assign(deps, groupDeps.Dependencies) deps.Set = deps.Union(groupDeps.Dependencies)
} }
for _, pkg := range packages { for _, pkg := range packages {
if _, prodDep := deps[pkg.Name]; !prodDep { if !deps.Contains(pkg.Name) {
continue continue
} }
walkPackageDeps(pkg.ID, packages, visited) walkPackageDeps(pkg.ID, packages, visited)
} }
app.Packages = lo.Filter(app.Packages, func(pkg types.Package, _ int) bool { app.Packages = lo.Filter(app.Packages, func(pkg types.Package, _ int) bool {
_, ok := visited[pkg.ID] return visited.Contains(pkg.ID)
return ok
}) })
} }
func walkPackageDeps(pkgID string, packages map[string]types.Package, visited map[string]struct{}) { func walkPackageDeps(pkgID string, packages map[string]types.Package, visited set.Set[string]) {
if _, ok := visited[pkgID]; ok { if visited.Contains(pkgID) {
return return
} }
visited[pkgID] = struct{}{} visited.Append(pkgID)
for _, dep := range packages[pkgID].DependsOn { for _, dep := range packages[pkgID].DependsOn {
walkPackageDeps(dep, packages, visited) walkPackageDeps(dep, packages, visited)
} }

View File

@@ -20,6 +20,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/licensing" "github.com/aquasecurity/trivy/pkg/licensing"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
) )
func init() { func init() {
@@ -185,13 +186,13 @@ func (a alpinePkgAnalyzer) consolidateDependencies(pkgs []types.Package, provide
} }
func (a alpinePkgAnalyzer) uniquePkgs(pkgs []types.Package) (uniqPkgs []types.Package) { func (a alpinePkgAnalyzer) uniquePkgs(pkgs []types.Package) (uniqPkgs []types.Package) {
uniq := make(map[string]struct{}) uniq := set.New[string]()
for _, pkg := range pkgs { for _, pkg := range pkgs {
if _, ok := uniq[pkg.Name]; ok { if uniq.Contains(pkg.Name) {
continue continue
} }
uniqPkgs = append(uniqPkgs, pkg) uniqPkgs = append(uniqPkgs, pkg)
uniq[pkg.Name] = struct{}{} uniq.Append(pkg.Name)
} }
return uniqPkgs return uniqPkgs
} }

View File

@@ -226,9 +226,9 @@ func (img *image) imageConfig(config *container.Config) v1.Config {
} }
if len(config.ExposedPorts) > 0 { if len(config.ExposedPorts) > 0 {
c.ExposedPorts = make(map[string]struct{}) c.ExposedPorts = make(map[string]struct{}) //nolint: gocritic
for port := range c.ExposedPorts { for port := range config.ExposedPorts {
c.ExposedPorts[port] = struct{}{} c.ExposedPorts[port.Port()] = struct{}{}
} }
} }

View File

@@ -56,14 +56,6 @@ func IsGzip(f *bufio.Reader) bool {
return buf[0] == 0x1F && buf[1] == 0x8B && buf[2] == 0x8 return buf[0] == 0x1F && buf[1] == 0x8B && buf[2] == 0x8
} }
func Keys(m map[string]struct{}) []string {
var keys []string
for k := range m {
keys = append(keys, k)
}
return keys
}
func IsExecutable(fileInfo os.FileInfo) bool { func IsExecutable(fileInfo os.FileInfo) bool {
// For Windows // For Windows
if filepath.Ext(fileInfo.Name()) == ".exe" { if filepath.Ext(fileInfo.Name()) == ".exe" {

View File

@@ -13,6 +13,7 @@ import (
checks "github.com/aquasecurity/trivy-checks" checks "github.com/aquasecurity/trivy-checks"
"github.com/aquasecurity/trivy/pkg/iac/rules" "github.com/aquasecurity/trivy/pkg/iac/rules"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
) )
var LoadAndRegister = sync.OnceFunc(func() { var LoadAndRegister = sync.OnceFunc(func() {
@@ -49,7 +50,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
} }
retriever := NewMetadataRetriever(compiler) retriever := NewMetadataRetriever(compiler)
regoCheckIDs := make(map[string]struct{}) regoCheckIDs := set.New[string]()
for _, module := range modules { for _, module := range modules {
metadata, err := retriever.RetrieveMetadata(ctx, module) metadata, err := retriever.RetrieveMetadata(ctx, module)
@@ -66,7 +67,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
} }
if !metadata.Deprecated { if !metadata.Deprecated {
regoCheckIDs[metadata.AVDID] = struct{}{} regoCheckIDs.Append(metadata.AVDID)
} }
rules.Register(metadata.ToRule()) rules.Register(metadata.ToRule())

View File

@@ -12,16 +12,13 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
) )
var builtinNamespaces = map[string]struct{}{ var builtinNamespaces = set.New("builtin", "defsec", "appshield")
"builtin": {},
"defsec": {},
"appshield": {},
}
func BuiltinNamespaces() []string { func BuiltinNamespaces() []string {
return lo.Keys(builtinNamespaces) return builtinNamespaces.Items()
} }
func IsBuiltinNamespace(namespace string) bool { func IsBuiltinNamespace(namespace string) bool {
@@ -122,15 +119,12 @@ func (s *Scanner) LoadPolicies(srcFS fs.FS) error {
} }
// gather namespaces // gather namespaces
uniq := make(map[string]struct{}) uniq := set.New[string]()
for _, module := range s.policies { for _, module := range s.policies {
namespace := getModuleNamespace(module) namespace := getModuleNamespace(module)
uniq[namespace] = struct{}{} uniq.Append(namespace)
}
var namespaces []string
for namespace := range uniq {
namespaces = append(namespaces, namespace)
} }
namespaces := uniq.Items()
dataFS := srcFS dataFS := srcFS
if s.dataFS != nil { if s.dataFS != nil {
@@ -296,7 +290,7 @@ func (s *Scanner) filterModules(retriever *MetadataRetriever) error {
} }
if IsBuiltinNamespace(getModuleNamespace(module)) { if IsBuiltinNamespace(getModuleNamespace(module)) {
if _, disabled := s.disabledCheckIDs[meta.ID]; disabled { // ignore builtin disabled checks if s.disabledCheckIDs.Contains(meta.ID) { // ignore builtin disabled checks
continue continue
} }
} }

View File

@@ -69,9 +69,7 @@ func WithDataDirs(paths ...string) options.ScannerOption {
func WithPolicyNamespaces(namespaces ...string) options.ScannerOption { func WithPolicyNamespaces(namespaces ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) { return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok { if ss, ok := s.(*Scanner); ok {
for _, namespace := range namespaces { ss.ruleNamespaces.Append(namespaces...)
ss.ruleNamespaces[namespace] = struct{}{}
}
} }
} }
} }
@@ -112,9 +110,7 @@ func WithCustomSchemas(schemas map[string][]byte) options.ScannerOption {
func WithDisabledCheckIDs(ids ...string) options.ScannerOption { func WithDisabledCheckIDs(ids ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) { return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok { if ss, ok := s.(*Scanner); ok {
for _, id := range ids { ss.disabledCheckIDs.Append(ids...)
ss.disabledCheckIDs[id] = struct{}{}
}
} }
} }
} }

View File

@@ -121,7 +121,7 @@ func parseLineNumber(raw any) int {
return n return n
} }
func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results { func (s *Scanner) convertResults(resultSet rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results {
var results scan.Results var results scan.Results
offset := 0 offset := 0
@@ -136,7 +136,7 @@ func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rul
} }
} }
} }
for _, result := range set { for _, result := range resultSet {
for _, expression := range result.Expressions { for _, expression := range result.Expressions {
values, ok := expression.Value.([]any) values, ok := expression.Value.([]any)
if !ok { if !ok {

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"maps"
"strings" "strings"
"github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/ast"
@@ -22,29 +21,26 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/scanners/options" "github.com/aquasecurity/trivy/pkg/iac/scanners/options"
"github.com/aquasecurity/trivy/pkg/iac/types" "github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
) )
var checkTypesWithSubtype = map[types.Source]struct{}{ var checkTypesWithSubtype = set.New[types.Source](types.SourceCloud, types.SourceDefsec, types.SourceKubernetes)
types.SourceCloud: {},
types.SourceDefsec: {},
types.SourceKubernetes: {},
}
var supportedProviders = makeSupportedProviders() var supportedProviders = makeSupportedProviders()
func makeSupportedProviders() map[string]struct{} { func makeSupportedProviders() set.Set[string] {
m := make(map[string]struct{}) m := set.New[string]()
for _, p := range providers.AllProviders() { for _, p := range providers.AllProviders() {
m[string(p)] = struct{}{} m.Append(string(p))
} }
m["kind"] = struct{}{} // kubernetes m.Append("kind") // kubernetes
return m return m
} }
var _ options.ConfigurableScanner = (*Scanner)(nil) var _ options.ConfigurableScanner = (*Scanner)(nil)
type Scanner struct { type Scanner struct {
ruleNamespaces map[string]struct{} ruleNamespaces set.Set[string]
policies map[string]*ast.Module policies map[string]*ast.Module
store storage.Store store storage.Store
runtimeValues *ast.Term runtimeValues *ast.Term
@@ -70,7 +66,7 @@ type Scanner struct {
embeddedChecks map[string]*ast.Module embeddedChecks map[string]*ast.Module
customSchemas map[string][]byte customSchemas map[string][]byte
disabledCheckIDs map[string]struct{} disabledCheckIDs set.Set[string]
} }
func (s *Scanner) trace(heading string, input any) { func (s *Scanner) trace(heading string, input any) {
@@ -103,15 +99,13 @@ func NewScanner(source types.Source, opts ...options.ScannerOption) *Scanner {
s := &Scanner{ s := &Scanner{
regoErrorLimit: ast.CompileErrorLimitDefault, regoErrorLimit: ast.CompileErrorLimitDefault,
sourceType: source, sourceType: source,
ruleNamespaces: make(map[string]struct{}), ruleNamespaces: builtinNamespaces.Clone(),
runtimeValues: addRuntimeValues(), runtimeValues: addRuntimeValues(),
logger: log.WithPrefix("rego"), logger: log.WithPrefix("rego"),
customSchemas: make(map[string][]byte), customSchemas: make(map[string][]byte),
disabledCheckIDs: make(map[string]struct{}), disabledCheckIDs: set.New[string](),
} }
maps.Copy(s.ruleNamespaces, builtinNamespaces)
for _, opt := range opts { for _, opt := range opts {
opt(s) opt(s)
} }
@@ -147,7 +141,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d
} }
instance := rego.New(regoOptions...) instance := rego.New(regoOptions...)
set, err := instance.Eval(ctx) resultSet, err := instance.Eval(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -165,7 +159,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d
traces = strings.Split(traceBuffer.String(), "\n") traces = strings.Split(traceBuffer.String(), "\n")
} }
} }
return set, traces, nil return resultSet, traces, nil
} }
type Input struct { type Input struct {
@@ -198,7 +192,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
namespace := getModuleNamespace(module) namespace := getModuleNamespace(module)
topLevel := strings.Split(namespace, ".")[0] topLevel := strings.Split(namespace, ".")[0]
if _, ok := s.ruleNamespaces[topLevel]; !ok { if !s.ruleNamespaces.Contains(topLevel) {
continue continue
} }
@@ -227,15 +221,15 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
continue continue
} }
usedRules := make(map[string]struct{}) usedRules := set.New[string]()
// all rules // all rules
for _, rule := range module.Rules { for _, rule := range module.Rules {
ruleName := rule.Head.Name.String() ruleName := rule.Head.Name.String()
if _, ok := usedRules[ruleName]; ok { if usedRules.Contains(ruleName) {
continue continue
} }
usedRules[ruleName] = struct{}{} usedRules.Append(ruleName)
if isEnforcedRule(ruleName) { if isEnforcedRule(ruleName) {
ruleResults, err := s.applyRule(ctx, namespace, ruleName, inputs) ruleResults, err := s.applyRule(ctx, namespace, ruleName, inputs)
if err != nil { if err != nil {
@@ -257,8 +251,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
} }
func isPolicyWithSubtype(sourceType types.Source) bool { func isPolicyWithSubtype(sourceType types.Source) bool {
_, exists := checkTypesWithSubtype[sourceType] return checkTypesWithSubtype.Contains(sourceType)
return exists
} }
func checkSubtype(ii map[string]any, provider string, subTypes []SubType) bool { func checkSubtype(ii map[string]any, provider string, subTypes []SubType) bool {
@@ -290,7 +283,7 @@ func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool {
for _, input := range inputs { for _, input := range inputs {
if ii, ok := input.Contents.(map[string]any); ok { if ii, ok := input.Contents.(map[string]any); ok {
for provider := range ii { for provider := range ii {
if _, exists := supportedProviders[provider]; !exists { if !supportedProviders.Contains(provider) {
continue continue
} }
@@ -329,12 +322,12 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs
continue continue
} }
set, traces, err := s.runQuery(ctx, qualified, parsedInput, false) resultSet, traces, err := s.runQuery(ctx, qualified, parsedInput, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.trace("RESULTSET", set) s.trace("RESULTSET", resultSet)
ruleResults := s.convertResults(set, input, namespace, rule, traces) ruleResults := s.convertResults(resultSet, input, namespace, rule, traces)
if len(ruleResults) == 0 { // It passed because we didn't find anything wrong (NOT because it didn't exist) if len(ruleResults) == 0 { // It passed because we didn't find anything wrong (NOT because it didn't exist)
var result regoResult var result regoResult
result.FS = input.FS result.FS = input.FS

View File

@@ -10,6 +10,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/scan" "github.com/aquasecurity/trivy/pkg/iac/scan"
dftypes "github.com/aquasecurity/trivy/pkg/iac/types" dftypes "github.com/aquasecurity/trivy/pkg/iac/types"
ruleTypes "github.com/aquasecurity/trivy/pkg/iac/types/rules" ruleTypes "github.com/aquasecurity/trivy/pkg/iac/types/rules"
"github.com/aquasecurity/trivy/pkg/set"
) )
type registry struct { type registry struct {
@@ -74,14 +75,14 @@ func (r *registry) getFrameworkRules(fw ...framework.Framework) []ruleTypes.Regi
if len(fw) == 0 { if len(fw) == 0 {
fw = []framework.Framework{framework.Default} fw = []framework.Framework{framework.Default}
} }
unique := make(map[int]struct{}) unique := set.New[int]()
for _, f := range fw { for _, f := range fw {
for _, rule := range r.frameworks[f] { for _, rule := range r.frameworks[f] {
if _, ok := unique[rule.Number]; ok { if unique.Contains(rule.Number) {
continue continue
} }
registered = append(registered, rule) registered = append(registered, rule)
unique[rule.Number] = struct{}{} unique.Append(rule.Number)
} }
} }
return registered return registered

View File

@@ -19,6 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/terraform" "github.com/aquasecurity/trivy/pkg/iac/terraform"
"github.com/aquasecurity/trivy/pkg/iac/types" "github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
) )
var _ scanners.FSScanner = (*Scanner)(nil) var _ scanners.FSScanner = (*Scanner)(nil)
@@ -31,7 +32,7 @@ type Scanner struct {
options []options.ScannerOption options []options.ScannerOption
parserOpt []parser.Option parserOpt []parser.Option
executorOpt []executor.Option executorOpt []executor.Option
dirs map[string]struct{} dirs set.Set[string]
forceAllDirs bool forceAllDirs bool
regoScanner *rego.Scanner regoScanner *rego.Scanner
execLock sync.RWMutex execLock sync.RWMutex
@@ -55,7 +56,7 @@ func (s *Scanner) AddExecutorOptions(opts ...executor.Option) {
func New(opts ...options.ScannerOption) *Scanner { func New(opts ...options.ScannerOption) *Scanner {
s := &Scanner{ s := &Scanner{
dirs: make(map[string]struct{}), dirs: set.New[string](),
options: opts, options: opts,
logger: log.WithPrefix("terraform scanner"), logger: log.WithPrefix("terraform scanner"),
} }

View File

@@ -7,6 +7,8 @@ import (
"github.com/liamg/memoryfs" "github.com/liamg/memoryfs"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/aquasecurity/trivy/pkg/set"
) )
func Test_FSKey(t *testing.T) { func Test_FSKey(t *testing.T) {
@@ -18,22 +20,20 @@ func Test_FSKey(t *testing.T) {
memoryfs.New(), memoryfs.New(),
} }
keys := make(map[string]struct{}) keys := set.New[string]()
t.Run("uniqueness", func(t *testing.T) { t.Run("uniqueness", func(t *testing.T) {
for _, system := range systems { for _, system := range systems {
key := CreateFSKey(system) key := CreateFSKey(system)
_, ok := keys[key] assert.False(t, keys.Contains(key), "filesystem keys should be unique")
assert.False(t, ok, "filesystem keys should be unique") keys.Append(key)
keys[key] = struct{}{}
} }
}) })
t.Run("reproducible", func(t *testing.T) { t.Run("reproducible", func(t *testing.T) {
for _, system := range systems { for _, system := range systems {
key := CreateFSKey(system) key := CreateFSKey(system)
_, ok := keys[key] assert.True(t, keys.Contains(key), "filesystem keys should be reproducible")
assert.True(t, ok, "filesystem keys should be reproducible")
} }
}) })
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
) )
var ( var (
@@ -43,7 +44,7 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic
var findings types.LicenseFindings var findings types.LicenseFindings
var matchType types.LicenseType var matchType types.LicenseType
seen := make(map[string]struct{}) seen := set.New[string]()
// cf.Match is not thread safe // cf.Match is not thread safe
m.Lock() m.Lock()
@@ -57,11 +58,11 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic
if match.Confidence <= confidenceLevel { if match.Confidence <= confidenceLevel {
continue continue
} }
if _, ok := seen[match.Name]; ok { if seen.Contains(match.Name) {
continue continue
} }
seen[match.Name] = struct{}{} seen.Append(match.Name)
switch match.MatchType { switch match.MatchType {
case "Header": case "Header":

View File

@@ -20,6 +20,7 @@ import (
"github.com/aquasecurity/testdocker/registry" "github.com/aquasecurity/testdocker/registry"
"github.com/aquasecurity/testdocker/tarfile" "github.com/aquasecurity/testdocker/tarfile"
"github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/version/app" "github.com/aquasecurity/trivy/pkg/version/app"
) )
@@ -216,13 +217,13 @@ type userAgentsTrackingHandler struct {
hr http.Handler hr http.Handler
mu sync.Mutex mu sync.Mutex
agents map[string]struct{} agents set.Set[string]
} }
func newUserAgentsTrackingHandler(hr http.Handler) *userAgentsTrackingHandler { func newUserAgentsTrackingHandler(hr http.Handler) *userAgentsTrackingHandler {
return &userAgentsTrackingHandler{ return &userAgentsTrackingHandler{
hr: hr, hr: hr,
agents: make(map[string]struct{}), agents: set.New[string](),
} }
} }
@@ -230,7 +231,7 @@ func (uh *userAgentsTrackingHandler) ServeHTTP(rw http.ResponseWriter, r *http.R
for _, agent := range r.Header["User-Agent"] { for _, agent := range r.Header["User-Agent"] {
// Skip test framework user agent // Skip test framework user agent
if agent != "Go-http-client/1.1" { if agent != "Go-http-client/1.1" {
uh.agents[agent] = struct{}{} uh.agents.Append(agent)
} }
} }
uh.hr.ServeHTTP(rw, r) uh.hr.ServeHTTP(rw, r)
@@ -271,7 +272,7 @@ func TestUserAgents(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, tracker.agents, 1) require.Len(t, tracker.agents, 1)
_, ok := tracker.agents[fmt.Sprintf("trivy/%s go-containerregistry", app.Version())] ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s go-containerregistry", app.Version()))
require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`) require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`)
} }

View File

@@ -19,6 +19,7 @@ import (
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/version/doc" "github.com/aquasecurity/trivy/pkg/version/doc"
) )
@@ -279,7 +280,7 @@ Dependency Origin Tree (Reversed)
topLvlID := tml.Sprintf("<red>%s, (%s)</red>", vulnPkg.ID, strings.Join(summaries, ", ")) topLvlID := tml.Sprintf("<red>%s, (%s)</red>", vulnPkg.ID, strings.Join(summaries, ", "))
branch := root.AddBranch(topLvlID) branch := root.AddBranch(topLvlID)
addParents(branch, vulnPkg, parents, ancestors, map[string]struct{}{vulnPkg.ID: {}}, 1) addParents(branch, vulnPkg, parents, ancestors, set.New(vulnPkg.ID), 1)
} }
r.printf(root.String()) r.printf(root.String())
@@ -291,17 +292,17 @@ func (r *vulnerabilityRenderer) printf(format string, args ...any) {
} }
func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string]ftypes.Packages, ancestors map[string][]string, func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string]ftypes.Packages, ancestors map[string][]string,
seen map[string]struct{}, depth int) { seen set.Set[string], depth int) {
if pkg.Relationship == ftypes.RelationshipDirect { if pkg.Relationship == ftypes.RelationshipDirect {
return return
} }
roots := make(map[string]struct{}) roots := set.New[string]()
for _, parent := range parentMap[pkg.ID] { for _, parent := range parentMap[pkg.ID] {
if _, ok := seen[parent.ID]; ok { if seen.Contains(parent.ID) {
continue continue
} }
seen[parent.ID] = struct{}{} // to avoid infinite loops seen.Append(parent.ID) // to avoid infinite loops
if depth == 1 && parent.Relationship == ftypes.RelationshipDirect { if depth == 1 && parent.Relationship == ftypes.RelationshipDirect {
topItem.AddBranch(parent.ID) topItem.AddBranch(parent.ID)
@@ -309,16 +310,13 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string
// We omit intermediate dependencies and show only direct dependencies // We omit intermediate dependencies and show only direct dependencies
// as this could make the dependency tree huge. // as this could make the dependency tree huge.
for _, ancestor := range ancestors[parent.ID] { for _, ancestor := range ancestors[parent.ID] {
roots[ancestor] = struct{}{} roots.Append(ancestor)
} }
} }
} }
// Omitted // Omitted
rootIDs := lo.Filter(lo.Keys(roots), func(pkgID string, _ int) bool { rootIDs := roots.Difference(seen).Items()
_, ok := seen[pkgID]
return !ok
})
sort.Strings(rootIDs) sort.Strings(rootIDs)
if len(rootIDs) > 0 { if len(rootIDs) > 0 {
branch := topItem.AddBranch("...(omitted)...") branch := topItem.AddBranch("...(omitted)...")
@@ -331,21 +329,21 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string
func traverseAncestors(pkgs []ftypes.Package, parentMap map[string]ftypes.Packages) map[string][]string { func traverseAncestors(pkgs []ftypes.Package, parentMap map[string]ftypes.Packages) map[string][]string {
ancestors := make(map[string][]string) ancestors := make(map[string][]string)
for _, pkg := range pkgs { for _, pkg := range pkgs {
ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, make(map[string]struct{})) ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, set.New[string]())
} }
return ancestors return ancestors
} }
func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[string]struct{}) []string { func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen set.Set[string]) []string {
ancestors := make(map[string]struct{}) ancestors := set.New[string]()
seen[pkgID] = struct{}{} seen.Append(pkgID)
for _, parent := range parentMap[pkgID] { for _, parent := range parentMap[pkgID] {
if _, ok := seen[parent.ID]; ok { if seen.Contains(parent.ID) {
continue continue
} }
switch { switch {
case parent.Relationship == ftypes.RelationshipDirect: case parent.Relationship == ftypes.RelationshipDirect:
ancestors[parent.ID] = struct{}{} ancestors.Append(parent.ID)
case len(parentMap[parent.ID]) == 0: case len(parentMap[parent.ID]) == 0:
// Some package managers, such as "package-lock.json" v1, can retrieve package dependencies but not relationships. // Some package managers, such as "package-lock.json" v1, can retrieve package dependencies but not relationships.
// We try to guess direct dependencies in this case. A dependency with no parents must be a direct dependency. // We try to guess direct dependencies in this case. A dependency with no parents must be a direct dependency.
@@ -358,14 +356,14 @@ func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[s
// //
// Even if `styled-components` is not marked as a direct dependency, it must be a direct dependency // Even if `styled-components` is not marked as a direct dependency, it must be a direct dependency
// as it has no parents. Note that it doesn't mean `fbjs` is an indirect dependency. // as it has no parents. Note that it doesn't mean `fbjs` is an indirect dependency.
ancestors[parent.ID] = struct{}{} ancestors.Append(parent.ID)
default: default:
for _, a := range findAncestor(parent.ID, parentMap, seen) { for _, a := range findAncestor(parent.ID, parentMap, seen) {
ancestors[a] = struct{}{} ancestors.Append(a)
} }
} }
} }
return lo.Keys(ancestors) return ancestors.Items()
} }
var jarExtensions = []string{ var jarExtensions = []string{

View File

@@ -9,6 +9,7 @@ import (
"github.com/aquasecurity/trivy/pkg/detector/library" "github.com/aquasecurity/trivy/pkg/detector/library"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
) )
@@ -41,7 +42,7 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.
} }
var results types.Results var results types.Results
printedTypes := make(map[ftypes.LangType]struct{}) printedTypes := set.New[ftypes.LangType]()
for _, app := range apps { for _, app := range apps {
if len(app.Packages) == 0 { if len(app.Packages) == 0 {
continue continue
@@ -76,13 +77,13 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.
return results, nil return results, nil
} }
func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes map[ftypes.LangType]struct{}) ( func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes set.Set[ftypes.LangType]) (
[]types.DetectedVulnerability, error) { []types.DetectedVulnerability, error) {
// Prevent the same log messages from being displayed many times for the same type. // Prevent the same log messages from being displayed many times for the same type.
if _, ok := printedTypes[app.Type]; !ok { if !printedTypes.Contains(app.Type) {
log.InfoContext(ctx, "Detecting vulnerabilities...") log.InfoContext(ctx, "Detecting vulnerabilities...")
printedTypes[app.Type] = struct{}{} printedTypes.Append(app.Type)
} }
log.DebugContext(ctx, "Scanning packages for vulnerabilities", log.FilePath(app.FilePath)) log.DebugContext(ctx, "Scanning packages for vulnerabilities", log.FilePath(app.FilePath))

View File

@@ -24,6 +24,7 @@ import (
"github.com/aquasecurity/trivy/pkg/scanner/langpkg" "github.com/aquasecurity/trivy/pkg/scanner/langpkg"
"github.com/aquasecurity/trivy/pkg/scanner/ospkg" "github.com/aquasecurity/trivy/pkg/scanner/ospkg"
"github.com/aquasecurity/trivy/pkg/scanner/post" "github.com/aquasecurity/trivy/pkg/scanner/post"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/vulnerability" "github.com/aquasecurity/trivy/pkg/vulnerability"
@@ -458,12 +459,12 @@ func mergePkgs(pkgs, pkgsFromCommands []ftypes.Package, options types.ScanOption
} }
// pkg has priority over pkgsFromCommands // pkg has priority over pkgsFromCommands
uniqPkgs := make(map[string]struct{}) uniqPkgs := set.New[string]()
for _, pkg := range pkgs { for _, pkg := range pkgs {
uniqPkgs[pkg.Name] = struct{}{} uniqPkgs.Append(pkg.Name)
} }
for _, pkg := range pkgsFromCommands { for _, pkg := range pkgsFromCommands {
if _, ok := uniqPkgs[pkg.Name]; ok { if uniqPkgs.Contains(pkg.Name) {
continue continue
} }
pkgs = append(pkgs, pkg) pkgs = append(pkgs, pkg)

39
pkg/set/set.go Normal file
View File

@@ -0,0 +1,39 @@
package set
import "iter"
// Set defines the interface for set operations
type Set[T comparable] interface {
// Append adds multiple items to the set and returns the new size
Append(val ...T) int
// Remove removes an item from the set
Remove(item T)
// Contains checks if an item exists in the set
Contains(item T) bool
// Size returns the number of items in the set
Size() int
// Clear removes all items from the set
Clear()
// Clone returns a new set with a copy of all items
Clone() Set[T]
// Items returns all items in the set as a slice
Items() []T
// Iter returns an iterator over the set
Iter() iter.Seq[T]
// Union returns a new set containing all items from both sets
Union(other Set[T]) Set[T]
// Intersection returns a new set containing items present in both sets
Intersection(other Set[T]) Set[T]
// Difference returns a new set containing items present in this set but not in the other
Difference(other Set[T]) Set[T]
}

100
pkg/set/unsafe.go Normal file
View File

@@ -0,0 +1,100 @@
package set
import (
"iter"
"maps"
"slices"
)
// unsafeSet represents a non-thread-safe set implementation
// WARNING: This implementation is not thread-safe
type unsafeSet[T comparable] map[T]struct{} //nolint: gocritic
// New creates a new empty non-thread-safe set with optional initial values
func New[T comparable](values ...T) Set[T] {
s := make(unsafeSet[T])
for _, v := range values {
s[v] = struct{}{}
}
return s
}
// Append adds multiple items to the set and returns the new size
func (s unsafeSet[T]) Append(val ...T) int {
for _, item := range val {
s[item] = struct{}{}
}
return len(s)
}
// Remove removes an item from the set
func (s unsafeSet[T]) Remove(item T) {
delete(s, item)
}
// Contains checks if an item exists in the set
func (s unsafeSet[T]) Contains(item T) bool {
_, exists := s[item]
return exists
}
// Size returns the number of items in the set
func (s unsafeSet[T]) Size() int {
return len(s)
}
// Clear removes all items from the set
func (s unsafeSet[T]) Clear() {
for k := range s {
delete(s, k)
}
}
// Clone returns a new set with a copy of all items
func (s unsafeSet[T]) Clone() Set[T] {
return maps.Clone(s)
}
// Items returns all items in the set as a slice
func (s unsafeSet[T]) Items() []T {
return slices.Collect(s.Iter())
}
// Iter returns an iterator over the set
func (s unsafeSet[T]) Iter() iter.Seq[T] {
return maps.Keys(s)
}
// Union returns a new set containing all items from both sets
func (s unsafeSet[T]) Union(other Set[T]) Set[T] {
result := make(unsafeSet[T])
for k := range s {
result[k] = struct{}{}
}
for _, item := range other.Items() {
result[item] = struct{}{}
}
return result
}
// Intersection returns a new set containing items present in both sets
func (s unsafeSet[T]) Intersection(other Set[T]) Set[T] {
result := make(unsafeSet[T])
for k := range s {
if other.Contains(k) {
result[k] = struct{}{}
}
}
return result
}
// Difference returns a new set containing items present in this set but not in the other
func (s unsafeSet[T]) Difference(other Set[T]) Set[T] {
result := make(unsafeSet[T])
for k := range s {
if !other.Contains(k) {
result[k] = struct{}{}
}
}
return result
}

583
pkg/set/unsafe_test.go Normal file
View File

@@ -0,0 +1,583 @@
package set_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aquasecurity/trivy/pkg/set"
)
func Test_New(t *testing.T) {
tests := []struct {
name string
values []int
wantSize int
wantAll bool
desc string
}{
{
name: "new empty set",
values: []int{},
wantSize: 0,
wantAll: true,
desc: "should create empty set when no values provided",
},
{
name: "new set with single value",
values: []int{1},
wantSize: 1,
wantAll: true,
desc: "should create set with single value",
},
{
name: "new set with multiple values",
values: []int{
1,
2,
3,
},
wantSize: 3,
wantAll: true,
desc: "should create set with multiple values",
},
{
name: "new set with duplicate values",
values: []int{
1,
2,
2,
3,
3,
3,
},
wantSize: 3,
wantAll: true,
desc: "should create set with unique values only",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New(tt.values...)
assert.Equal(t, tt.wantSize, s.Size(), "unexpected set size")
})
}
}
func Test_unsafeSet_Add(t *testing.T) {
// Define custom type for struct test cases
type custom struct {
id int
name string
}
tests := []struct {
name string
prepare func(s set.Set[any])
input any
wantSize int
}{
{
name: "add integer",
prepare: nil,
input: 1,
wantSize: 1,
},
{
name: "add duplicate integer",
prepare: func(s set.Set[any]) {
s.Append(1)
},
input: 1,
wantSize: 1,
},
{
name: "add string",
prepare: nil,
input: "test",
wantSize: 1,
},
{
name: "add empty string",
prepare: nil,
input: "",
wantSize: 1,
},
{
name: "add custom struct",
prepare: nil,
input: custom{
id: 1,
name: "test1",
},
wantSize: 1,
},
{
name: "add nil pointer",
prepare: nil,
input: (*int)(nil),
wantSize: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[any]()
if tt.prepare != nil {
tt.prepare(s)
}
s.Append(tt.input)
got := s.Size()
assert.Equal(t, tt.wantSize, got, "unexpected set size")
assert.True(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input)
})
}
}
func Test_unsafeSet_Append(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
input []int
wantSize int
}{
{
name: "append to empty set",
prepare: nil,
input: []int{
1,
2,
3,
},
wantSize: 3,
},
{
name: "append with duplicates",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: []int{
1,
2,
1,
3,
2,
},
wantSize: 3,
},
{
name: "append empty slice",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: []int{},
wantSize: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
got := s.Append(tt.input...)
assert.Equal(t, tt.wantSize, got, "unexpected returned size")
assert.Equal(t, tt.wantSize, s.Size(), "unexpected actual size")
for _, item := range tt.input {
assert.True(t, s.Contains(item), "set should contain appended item: %v", item)
}
})
}
}
func Test_unsafeSet_Remove(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
input int
wantSize int
}{
{
name: "remove existing element",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: 1,
wantSize: 0,
},
{
name: "remove non-existing element",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: 2,
wantSize: 1,
},
{
name: "remove from empty set",
prepare: nil,
input: 1,
wantSize: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
s.Remove(tt.input)
got := s.Size()
assert.Equal(t, tt.wantSize, got, "unexpected set size")
assert.False(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input)
})
}
}
func Test_unsafeSet_Clear(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
}{
{
name: "clear non-empty set",
prepare: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
},
{
name: "clear empty set",
prepare: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
s.Clear()
got := s.Size()
assert.Zero(t, got, "unexpected set size")
assert.Empty(t, s.Items(), "items should be empty")
})
}
}
func Test_unsafeSet_Clone(t *testing.T) {
t.Run("empty set", func(t *testing.T) {
original := set.New[string]()
cloned := original.Clone()
assert.Equal(t, 0, cloned.Size(), "cloned set should be empty")
// Verify independence
original.Append("test")
assert.False(t, cloned.Contains("test"), "cloned set should not be affected by original")
})
t.Run("basic types", func(t *testing.T) {
original := set.New[any](1, "test", true)
cloned := original.Clone()
assert.Equal(t, original.Size(), cloned.Size(), "sizes should match")
assert.True(t, cloned.Contains(1), "should contain integer")
assert.True(t, cloned.Contains("test"), "should contain string")
assert.True(t, cloned.Contains(true), "should contain boolean")
// Verify independence
original.Append("new")
assert.False(t, cloned.Contains("new"), "cloned set should not be affected by original")
cloned.Append("another")
assert.False(t, original.Contains("another"), "original set should not be affected by clone")
})
// Test nil pointer
t.Run("nil pointer", func(t *testing.T) {
original := set.New[*int]()
original.Append(nil)
cloned := original.Clone()
assert.Equal(t, original.Size(), cloned.Size(), "sizes should match")
assert.True(t, cloned.Contains((*int)(nil)), "should contain nil pointer")
})
}
func Test_unsafeSet_Items(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
want []int
}{
{
name: "get items from non-empty set",
prepare: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
want: []int{
1,
2,
3,
},
},
{
name: "get items from empty set",
prepare: nil,
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
got := s.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected items in set")
})
}
}
func Test_unsafeSet_Union(t *testing.T) {
tests := []struct {
name string
prepare1 func(s set.Set[int])
prepare2 func(s set.Set[int])
want []int
}{
{
name: "union of non-overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: func(s set.Set[int]) {
s.Append(3)
s.Append(4)
},
want: []int{
1,
2,
3,
4,
},
},
{
name: "union of overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
prepare2: func(s set.Set[int]) {
s.Append(2)
s.Append(3)
s.Append(4)
},
want: []int{
1,
2,
3,
4,
},
},
{
name: "union with empty set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: nil,
want: []int{
1,
2,
},
},
{
name: "union of empty sets",
prepare1: nil,
prepare2: nil,
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s1 := set.New[int]()
s2 := set.New[int]()
if tt.prepare1 != nil {
tt.prepare1(s1)
}
if tt.prepare2 != nil {
tt.prepare2(s2)
}
result := s1.Union(s2)
got := result.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected union result")
})
}
}
func Test_unsafeSet_Intersection(t *testing.T) {
tests := []struct {
name string
prepare1 func(s set.Set[int])
prepare2 func(s set.Set[int])
want []int
}{
{
name: "intersection of overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
prepare2: func(s set.Set[int]) {
s.Append(2)
s.Append(3)
s.Append(4)
},
want: []int{
2,
3,
},
},
{
name: "intersection of non-overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: func(s set.Set[int]) {
s.Append(3)
s.Append(4)
},
want: []int{},
},
{
name: "intersection with empty set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: nil,
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s1 := set.New[int]()
s2 := set.New[int]()
if tt.prepare1 != nil {
tt.prepare1(s1)
}
if tt.prepare2 != nil {
tt.prepare2(s2)
}
result := s1.Intersection(s2)
got := result.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected intersection result")
})
}
}
func Test_unsafeSet_Difference(t *testing.T) {
tests := []struct {
name string
prepare1 func(s set.Set[int])
prepare2 func(s set.Set[int])
want []int
}{
{
name: "difference of overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
prepare2: func(s set.Set[int]) {
s.Append(2)
s.Append(3)
s.Append(4)
},
want: []int{1},
},
{
name: "difference with non-overlapping set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: func(s set.Set[int]) {
s.Append(3)
s.Append(4)
},
want: []int{
1,
2,
},
},
{
name: "difference with empty set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: nil,
want: []int{
1,
2,
},
},
{
name: "difference of empty set",
prepare1: nil,
prepare2: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s1 := set.New[int]()
s2 := set.New[int]()
if tt.prepare1 != nil {
tt.prepare1(s1)
}
if tt.prepare2 != nil {
tt.prepare2(s2)
}
result := s1.Difference(s2)
got := result.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected difference result")
})
}
}