mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-23 07:29:00 -08:00
refactor: add generic Set implementation (#8149)
Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
@@ -30,3 +30,8 @@ func errorsJoin(m dsl.Matcher) {
|
||||
m.Match(`errors.Join($*args)`).
|
||||
Report("use github.com/hashicorp/go-multierror.Append instead of errors.Join.")
|
||||
}
|
||||
|
||||
func mapSet(m dsl.Matcher) {
|
||||
m.Match(`map[$x]struct{}`).
|
||||
Report("use github.com/aquasecurity/trivy/pkg/set.Set instead of map.")
|
||||
}
|
||||
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/xerrors"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
sp "github.com/aquasecurity/trivy-checks/pkg/spec"
|
||||
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
)
|
||||
|
||||
@@ -31,17 +31,17 @@ const (
|
||||
|
||||
// Scanners reads spec control and determines the scanners by check ID prefix
|
||||
func (cs *ComplianceSpec) Scanners() (types.Scanners, error) {
|
||||
scannerTypes := make(map[types.Scanner]struct{})
|
||||
scannerTypes := set.New[types.Scanner]()
|
||||
for _, control := range cs.Spec.Controls {
|
||||
for _, check := range control.Checks {
|
||||
scannerType := scannerByCheckID(check.ID)
|
||||
if scannerType == types.UnknownScanner {
|
||||
return nil, xerrors.Errorf("unsupported check ID: %s", check.ID)
|
||||
}
|
||||
scannerTypes[scannerType] = struct{}{}
|
||||
scannerTypes.Append(scannerType)
|
||||
}
|
||||
}
|
||||
return lo.Keys(scannerTypes), nil
|
||||
return scannerTypes.Items(), nil
|
||||
}
|
||||
|
||||
// CheckIDs return list of compliance check IDs
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/version/doc"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ type artifact struct {
|
||||
Version version
|
||||
Licenses []string
|
||||
|
||||
Exclusions map[string]struct{}
|
||||
Exclusions set.Set[string]
|
||||
|
||||
Module bool
|
||||
Relationship ftypes.Relationship
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
xio "github.com/aquasecurity/trivy/pkg/x/io"
|
||||
)
|
||||
|
||||
@@ -118,11 +119,11 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
|
||||
rootArt := root.artifact()
|
||||
rootArt.Relationship = ftypes.RelationshipRoot
|
||||
|
||||
return p.parseRoot(rootArt, make(map[string]struct{}))
|
||||
return p.parseRoot(rootArt, set.New[string]())
|
||||
}
|
||||
|
||||
// nolint: gocyclo
|
||||
func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ftypes.Package, []ftypes.Dependency, error) {
|
||||
func (p *Parser) parseRoot(root artifact, uniqModules set.Set[string]) ([]ftypes.Package, []ftypes.Dependency, error) {
|
||||
// Prepare a queue for dependencies
|
||||
queue := newArtifactQueue()
|
||||
|
||||
@@ -145,10 +146,10 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
|
||||
// Modules should be handled separately so that they can have independent dependencies.
|
||||
// It means multi-module allows for duplicate dependencies.
|
||||
if art.Module {
|
||||
if _, ok := uniqModules[art.String()]; ok {
|
||||
if uniqModules.Contains(art.String()) {
|
||||
continue
|
||||
}
|
||||
uniqModules[art.String()] = struct{}{}
|
||||
uniqModules.Append(art.String())
|
||||
|
||||
modulePkgs, moduleDeps, err := p.parseRoot(art, uniqModules)
|
||||
if err != nil {
|
||||
@@ -251,7 +252,7 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
|
||||
// `mvn` shows modules separately from the root package and does not show module nesting.
|
||||
// So we can add all modules as dependencies of root package.
|
||||
if art.Relationship == ftypes.RelationshipRoot {
|
||||
dependsOn = append(dependsOn, lo.Keys(uniqModules)...)
|
||||
dependsOn = append(dependsOn, uniqModules.Items()...)
|
||||
}
|
||||
|
||||
sort.Strings(dependsOn)
|
||||
@@ -340,7 +341,7 @@ type analysisResult struct {
|
||||
}
|
||||
|
||||
type analysisOptions struct {
|
||||
exclusions map[string]struct{}
|
||||
exclusions set.Set[string]
|
||||
depManagement []pomDependency // from the root POM
|
||||
}
|
||||
|
||||
@@ -348,6 +349,9 @@ func (p *Parser) analyze(pom *pom, opts analysisOptions) (analysisResult, error)
|
||||
if pom.nil() {
|
||||
return analysisResult{}, nil
|
||||
}
|
||||
if opts.exclusions == nil {
|
||||
opts.exclusions = set.New[string]()
|
||||
}
|
||||
// Update remoteRepositories
|
||||
pomReleaseRemoteRepos, pomSnapshotRemoteRepos := pom.repositories(p.servers)
|
||||
p.releaseRemoteRepos = lo.Uniq(append(pomReleaseRemoteRepos, p.releaseRemoteRepos...))
|
||||
@@ -408,16 +412,16 @@ func (p *Parser) resolveParent(pom *pom) error {
|
||||
}
|
||||
|
||||
func (p *Parser) mergeDependencyManagements(depManagements ...[]pomDependency) []pomDependency {
|
||||
uniq := make(map[string]struct{})
|
||||
uniq := set.New[string]()
|
||||
var depManagement []pomDependency
|
||||
// The preceding argument takes precedence.
|
||||
for _, dm := range depManagements {
|
||||
for _, dep := range dm {
|
||||
if _, ok := uniq[dep.Name()]; ok {
|
||||
if uniq.Contains(dep.Name()) {
|
||||
continue
|
||||
}
|
||||
depManagement = append(depManagement, dep)
|
||||
uniq[dep.Name()] = struct{}{}
|
||||
uniq.Append(dep.Name())
|
||||
}
|
||||
}
|
||||
return depManagement
|
||||
@@ -492,19 +496,19 @@ func (p *Parser) mergeDependencies(child, parent []pomDependency) []pomDependenc
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Parser) filterDependencies(artifacts []artifact, exclusions map[string]struct{}) []artifact {
|
||||
func (p *Parser) filterDependencies(artifacts []artifact, exclusions set.Set[string]) []artifact {
|
||||
return lo.Filter(artifacts, func(art artifact, _ int) bool {
|
||||
return !excludeDep(exclusions, art)
|
||||
})
|
||||
}
|
||||
|
||||
func excludeDep(exclusions map[string]struct{}, art artifact) bool {
|
||||
if _, ok := exclusions[art.Name()]; ok {
|
||||
func excludeDep(exclusions set.Set[string], art artifact) bool {
|
||||
if exclusions.Contains(art.Name()) {
|
||||
return true
|
||||
}
|
||||
// Maven can use "*" in GroupID and ArtifactID fields to exclude dependencies
|
||||
// https://maven.apache.org/pom.html#exclusions
|
||||
for exlusion := range exclusions {
|
||||
for exlusion := range exclusions.Iter() {
|
||||
// exclusion format - "<groupID>:<artifactID>"
|
||||
e := strings.Split(exlusion, ":")
|
||||
if (e[0] == art.GroupID || e[0] == "*") && (e[1] == art.ArtifactID || e[1] == "*") {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -15,6 +14,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/x/slices"
|
||||
)
|
||||
|
||||
@@ -287,12 +287,12 @@ func (d pomDependency) ToArtifact(opts analysisOptions) artifact {
|
||||
// To avoid shadow adding exclusions to top pom's,
|
||||
// we need to initialize a new map for each new artifact
|
||||
// See `exclusions in child` test for more information
|
||||
exclusions := make(map[string]struct{})
|
||||
exclusions := set.New[string]()
|
||||
if opts.exclusions != nil {
|
||||
exclusions = maps.Clone(opts.exclusions)
|
||||
exclusions = opts.exclusions.Clone()
|
||||
}
|
||||
for _, e := range d.Exclusions.Exclusion {
|
||||
exclusions[fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID)] = struct{}{}
|
||||
exclusions.Append(fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID))
|
||||
}
|
||||
|
||||
var locations ftypes.Locations
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
xio "github.com/aquasecurity/trivy/pkg/x/io"
|
||||
)
|
||||
|
||||
@@ -91,7 +92,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
|
||||
// https://docs.npmjs.com/cli/v9/configuring-npm/package-lock-json#packages
|
||||
p.resolveLinks(packages)
|
||||
|
||||
directDeps := make(map[string]struct{})
|
||||
directDeps := set.New[string]()
|
||||
for name, version := range lo.Assign(packages[""].Dependencies, packages[""].OptionalDependencies, packages[""].DevDependencies, packages[""].PeerDependencies) {
|
||||
pkgPath := joinPaths(nodeModulesDir, name)
|
||||
if _, ok := packages[pkgPath]; !ok {
|
||||
@@ -101,7 +102,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
|
||||
}
|
||||
// Store the package paths of direct dependencies
|
||||
// e.g. node_modules/body-parser
|
||||
directDeps[pkgPath] = struct{}{}
|
||||
directDeps.Append(pkgPath)
|
||||
}
|
||||
|
||||
for pkgPath, pkg := range packages {
|
||||
@@ -366,13 +367,13 @@ func (p *Parser) pkgNameFromPath(pkgPath string) string {
|
||||
|
||||
func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
|
||||
var uniqDeps ftypes.Dependencies
|
||||
unique := make(map[string]struct{})
|
||||
unique := set.New[string]()
|
||||
|
||||
for _, dep := range deps {
|
||||
sort.Strings(dep.DependsOn)
|
||||
depKey := fmt.Sprintf("%s:%s", dep.ID, strings.Join(dep.DependsOn, ","))
|
||||
if _, ok := unique[depKey]; !ok {
|
||||
unique[depKey] = struct{}{}
|
||||
if !unique.Contains(depKey) {
|
||||
unique.Append(depKey)
|
||||
uniqDeps = append(uniqDeps, dep)
|
||||
}
|
||||
}
|
||||
@@ -381,11 +382,11 @@ func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
|
||||
return uniqDeps
|
||||
}
|
||||
|
||||
func isIndirectPkg(pkgPath string, directDeps map[string]struct{}) bool {
|
||||
func isIndirectPkg(pkgPath string, directDeps set.Set[string]) bool {
|
||||
// A project can contain 2 different versions of the same dependency.
|
||||
// e.g. `node_modules/string-width/node_modules/strip-ansi` and `node_modules/string-ansi`
|
||||
// direct dependencies always have root path (`node_modules/<pkg_name>`)
|
||||
if _, ok := directDeps[pkgPath]; ok {
|
||||
if directDeps.Contains(pkgPath) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/dependency"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
xio "github.com/aquasecurity/trivy/pkg/x/io"
|
||||
)
|
||||
|
||||
@@ -215,7 +216,7 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
|
||||
}
|
||||
}
|
||||
|
||||
visited := make(map[string]struct{})
|
||||
visited := set.New[string]()
|
||||
// Overwrite the `Dev` field for dev deps and their child dependencies.
|
||||
for _, pkg := range resolvedPkgs {
|
||||
if !pkg.Dev {
|
||||
@@ -227,8 +228,8 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
|
||||
}
|
||||
|
||||
// markRootPkgs sets `Dev` to false for non dev dependency.
|
||||
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited map[string]struct{}) {
|
||||
if _, ok := visited[id]; ok {
|
||||
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited set.Set[string]) {
|
||||
if visited.Contains(id) {
|
||||
return
|
||||
}
|
||||
pkg, ok := pkgs[id]
|
||||
@@ -238,7 +239,7 @@ func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps ma
|
||||
|
||||
pkg.Dev = false
|
||||
pkgs[id] = pkg
|
||||
visited[id] = struct{}{}
|
||||
visited.Append(id)
|
||||
|
||||
// Update child deps
|
||||
for _, depID := range deps[id].DependsOn {
|
||||
|
||||
@@ -76,7 +76,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
|
||||
}
|
||||
|
||||
if savedDependsOn, ok := depsMap[depId]; ok {
|
||||
dependsOn = utils.UniqueStrings(append(dependsOn, savedDependsOn...))
|
||||
dependsOn = lo.Uniq(append(dependsOn, savedDependsOn...))
|
||||
}
|
||||
|
||||
if len(dependsOn) > 0 {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/dependency/parser/python"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
type PyProject struct {
|
||||
@@ -19,25 +20,27 @@ type Tool struct {
|
||||
}
|
||||
|
||||
type Poetry struct {
|
||||
Dependencies dependencies `toml:"dependencies"`
|
||||
Dependencies Dependencies `toml:"dependencies"`
|
||||
Groups map[string]Group `toml:"group"`
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
Dependencies dependencies `toml:"dependencies"`
|
||||
Dependencies Dependencies `toml:"dependencies"`
|
||||
}
|
||||
|
||||
type dependencies map[string]struct{}
|
||||
type Dependencies struct {
|
||||
set.Set[string]
|
||||
}
|
||||
|
||||
func (d *dependencies) UnmarshalTOML(data any) error {
|
||||
func (d *Dependencies) UnmarshalTOML(data any) error {
|
||||
m, ok := data.(map[string]any)
|
||||
if !ok {
|
||||
return xerrors.Errorf("dependencies must be map, but got: %T", data)
|
||||
}
|
||||
|
||||
*d = lo.MapEntries(m, func(pkgName string, _ any) (string, struct{}) {
|
||||
return python.NormalizePkgName(pkgName), struct{}{}
|
||||
})
|
||||
d.Set = set.New[string](lo.MapToSlice(m, func(pkgName string, _ any) string {
|
||||
return python.NormalizePkgName(pkgName)
|
||||
})...)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/dependency/parser/python/pyproject"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
func TestParser_Parse(t *testing.T) {
|
||||
@@ -24,21 +25,18 @@ func TestParser_Parse(t *testing.T) {
|
||||
want: pyproject.PyProject{
|
||||
Tool: pyproject.Tool{
|
||||
Poetry: pyproject.Poetry{
|
||||
Dependencies: map[string]struct{}{
|
||||
"flask": {},
|
||||
"python": {},
|
||||
"requests": {},
|
||||
"virtualenv": {},
|
||||
Dependencies: pyproject.Dependencies{
|
||||
Set: set.New[string]("flask", "python", "requests", "virtualenv"),
|
||||
},
|
||||
Groups: map[string]pyproject.Group{
|
||||
"dev": {
|
||||
Dependencies: map[string]struct{}{
|
||||
"pytest": {},
|
||||
Dependencies: pyproject.Dependencies{
|
||||
Set: set.New[string]("pytest"),
|
||||
},
|
||||
},
|
||||
"lint": {
|
||||
Dependencies: map[string]struct{}{
|
||||
"ruff": {},
|
||||
Dependencies: pyproject.Dependencies{
|
||||
Set: set.New[string]("ruff"),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/dependency"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
xio "github.com/aquasecurity/trivy/pkg/x/io"
|
||||
)
|
||||
|
||||
@@ -22,25 +23,25 @@ func (l Lock) packages() map[string]Package {
|
||||
})
|
||||
}
|
||||
|
||||
func (l Lock) directDeps(root Package) map[string]struct{} {
|
||||
deps := make(map[string]struct{})
|
||||
func (l Lock) directDeps(root Package) set.Set[string] {
|
||||
deps := set.New[string]()
|
||||
for _, dep := range root.Dependencies {
|
||||
deps[dep.Name] = struct{}{}
|
||||
deps.Append(dep.Name)
|
||||
}
|
||||
return deps
|
||||
}
|
||||
|
||||
func prodDeps(root Package, packages map[string]Package) map[string]struct{} {
|
||||
visited := make(map[string]struct{})
|
||||
func prodDeps(root Package, packages map[string]Package) set.Set[string] {
|
||||
visited := set.New[string]()
|
||||
walkPackageDeps(root, packages, visited)
|
||||
return visited
|
||||
}
|
||||
|
||||
func walkPackageDeps(pkg Package, packages map[string]Package, visited map[string]struct{}) {
|
||||
if _, ok := visited[pkg.Name]; ok {
|
||||
func walkPackageDeps(pkg Package, packages map[string]Package, visited set.Set[string]) {
|
||||
if visited.Contains(pkg.Name) {
|
||||
return
|
||||
}
|
||||
visited[pkg.Name] = struct{}{}
|
||||
visited.Append(pkg.Name)
|
||||
for _, dep := range pkg.Dependencies {
|
||||
depPkg, exists := packages[dep.Name]
|
||||
if !exists {
|
||||
@@ -119,7 +120,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
|
||||
)
|
||||
|
||||
for _, pkg := range lock.Packages {
|
||||
if _, ok := prodDeps[pkg.Name]; !ok {
|
||||
if !prodDeps.Contains(pkg.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -127,7 +128,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
|
||||
relationship := ftypes.RelationshipIndirect
|
||||
if pkg.isRoot() {
|
||||
relationship = ftypes.RelationshipRoot
|
||||
} else if _, ok := directDeps[pkg.Name]; ok {
|
||||
} else if directDeps.Contains(pkg.Name) {
|
||||
relationship = ftypes.RelationshipDirect
|
||||
}
|
||||
|
||||
|
||||
@@ -10,19 +10,6 @@ import (
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
)
|
||||
|
||||
func UniqueStrings(ss []string) []string {
|
||||
var results []string
|
||||
uniq := make(map[string]struct{})
|
||||
for _, s := range ss {
|
||||
if _, ok := uniq[s]; ok {
|
||||
continue
|
||||
}
|
||||
results = append(results, s)
|
||||
uniq[s] = struct{}{}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func UniquePackages(pkgs []ftypes.Package) []ftypes.Package {
|
||||
if len(pkgs) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/dependency"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -179,33 +180,30 @@ func (a alpineCmdAnalyzer) parseCommand(command string, envs map[string]string)
|
||||
return pkgs
|
||||
}
|
||||
func (a alpineCmdAnalyzer) resolveDependencies(apkIndexArchive *apkIndex, originalPkgs []string) (pkgs []string) {
|
||||
uniqPkgs := make(map[string]struct{})
|
||||
uniqPkgs := set.New[string]()
|
||||
for _, pkgName := range originalPkgs {
|
||||
if _, ok := uniqPkgs[pkgName]; ok {
|
||||
if uniqPkgs.Contains(pkgName) {
|
||||
continue
|
||||
}
|
||||
|
||||
seenPkgs := make(map[string]struct{})
|
||||
seenPkgs := set.New[string]()
|
||||
for _, p := range a.resolveDependency(apkIndexArchive, pkgName, seenPkgs) {
|
||||
uniqPkgs[p] = struct{}{}
|
||||
uniqPkgs.Append(p)
|
||||
}
|
||||
}
|
||||
for pkg := range uniqPkgs {
|
||||
pkgs = append(pkgs, pkg)
|
||||
}
|
||||
return pkgs
|
||||
return uniqPkgs.Items()
|
||||
}
|
||||
|
||||
func (a alpineCmdAnalyzer) resolveDependency(apkIndexArchive *apkIndex, pkgName string,
|
||||
seenPkgs map[string]struct{}) (pkgNames []string) {
|
||||
seenPkgs set.Set[string]) (pkgNames []string) {
|
||||
pkg, ok := apkIndexArchive.Package[pkgName]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if _, ok = seenPkgs[pkgName]; ok {
|
||||
if seenPkgs.Contains(pkgName) {
|
||||
return nil
|
||||
}
|
||||
seenPkgs[pkgName] = struct{}{}
|
||||
seenPkgs.Append(pkgName)
|
||||
|
||||
pkgNames = append(pkgNames, pkgName)
|
||||
for _, dependency := range pkg.Dependencies {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -1508,86 +1509,86 @@ func TestResolveDependency(t *testing.T) {
|
||||
var tests = map[string]struct {
|
||||
pkgName string
|
||||
apkIndexArchivePath string
|
||||
expected map[string]struct{}
|
||||
want set.Set[string]
|
||||
}{
|
||||
"low": {
|
||||
pkgName: "libblkid",
|
||||
apkIndexArchivePath: "testdata/history_v3.9.json",
|
||||
expected: map[string]struct{}{
|
||||
"libblkid": {},
|
||||
"libuuid": {},
|
||||
"musl": {},
|
||||
},
|
||||
want: set.New(
|
||||
"libblkid",
|
||||
"libuuid",
|
||||
"musl",
|
||||
),
|
||||
},
|
||||
"medium": {
|
||||
pkgName: "libgcab",
|
||||
apkIndexArchivePath: "testdata/history_v3.9.json",
|
||||
expected: map[string]struct{}{
|
||||
"busybox": {},
|
||||
"libblkid": {},
|
||||
"libuuid": {},
|
||||
"musl": {},
|
||||
"libmount": {},
|
||||
"pcre": {},
|
||||
"glib": {},
|
||||
"libgcab": {},
|
||||
"libintl": {},
|
||||
"zlib": {},
|
||||
"libffi": {},
|
||||
},
|
||||
want: set.New(
|
||||
"busybox",
|
||||
"libblkid",
|
||||
"libuuid",
|
||||
"musl",
|
||||
"libmount",
|
||||
"pcre",
|
||||
"glib",
|
||||
"libgcab",
|
||||
"libintl",
|
||||
"zlib",
|
||||
"libffi",
|
||||
),
|
||||
},
|
||||
"high": {
|
||||
pkgName: "postgresql",
|
||||
apkIndexArchivePath: "testdata/history_v3.9.json",
|
||||
expected: map[string]struct{}{
|
||||
"busybox": {},
|
||||
"ncurses-terminfo-base": {},
|
||||
"ncurses-terminfo": {},
|
||||
"libedit": {},
|
||||
"db": {},
|
||||
"libsasl": {},
|
||||
"libldap": {},
|
||||
"libpq": {},
|
||||
"postgresql-client": {},
|
||||
"tzdata": {},
|
||||
"libxml2": {},
|
||||
"postgresql": {},
|
||||
"musl": {},
|
||||
"libcrypto1.1": {},
|
||||
"libssl1.1": {},
|
||||
"ncurses-libs": {},
|
||||
"zlib": {},
|
||||
},
|
||||
want: set.New(
|
||||
"busybox",
|
||||
"ncurses-terminfo-base",
|
||||
"ncurses-terminfo",
|
||||
"libedit",
|
||||
"db",
|
||||
"libsasl",
|
||||
"libldap",
|
||||
"libpq",
|
||||
"postgresql-client",
|
||||
"tzdata",
|
||||
"libxml2",
|
||||
"postgresql",
|
||||
"musl",
|
||||
"libcrypto1.1",
|
||||
"libssl1.1",
|
||||
"ncurses-libs",
|
||||
"zlib",
|
||||
),
|
||||
},
|
||||
"package alias": {
|
||||
pkgName: "sqlite-dev",
|
||||
apkIndexArchivePath: "testdata/history_v3.9.json",
|
||||
expected: map[string]struct{}{
|
||||
"sqlite-dev": {},
|
||||
"sqlite-libs": {},
|
||||
"pkgconf": {}, // pkgconfig => pkgconf
|
||||
"musl": {},
|
||||
},
|
||||
want: set.New(
|
||||
"sqlite-dev",
|
||||
"sqlite-libs",
|
||||
"pkgconf", // pkgconfig => pkgconf
|
||||
"musl",
|
||||
),
|
||||
},
|
||||
"circular dependencies": {
|
||||
pkgName: "nodejs",
|
||||
apkIndexArchivePath: "testdata/history_v3.7.json",
|
||||
expected: map[string]struct{}{
|
||||
"busybox": {},
|
||||
"c-ares": {},
|
||||
"ca-certificates": {},
|
||||
"http-parser": {},
|
||||
"libcrypto1.0": {},
|
||||
"libgcc": {},
|
||||
"libressl2.6-libcrypto": {},
|
||||
"libssl1.0": {},
|
||||
"libstdc++": {},
|
||||
"libuv": {},
|
||||
"musl": {},
|
||||
"nodejs": {},
|
||||
"nodejs-npm": {},
|
||||
"zlib": {},
|
||||
},
|
||||
want: set.New(
|
||||
"busybox",
|
||||
"c-ares",
|
||||
"ca-certificates",
|
||||
"http-parser",
|
||||
"libcrypto1.0",
|
||||
"libgcc",
|
||||
"libressl2.6-libcrypto",
|
||||
"libssl1.0",
|
||||
"libstdc++",
|
||||
"libuv",
|
||||
"musl",
|
||||
"nodejs",
|
||||
"nodejs-npm",
|
||||
"zlib",
|
||||
),
|
||||
},
|
||||
}
|
||||
analyzer := alpineCmdAnalyzer{}
|
||||
@@ -1600,15 +1601,10 @@ func TestResolveDependency(t *testing.T) {
|
||||
if err = json.NewDecoder(f).Decode(&apkIndexArchive); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
circularDependencyCheck := make(map[string]struct{})
|
||||
circularDependencyCheck := set.New[string]()
|
||||
pkgs := analyzer.resolveDependency(apkIndexArchive, v.pkgName, circularDependencyCheck)
|
||||
actual := make(map[string]struct{})
|
||||
for _, pkg := range pkgs {
|
||||
actual[pkg] = struct{}{}
|
||||
}
|
||||
if !reflect.DeepEqual(v.expected, actual) {
|
||||
t.Errorf("[%s]\n%s", testName, pretty.Compare(v.expected, actual))
|
||||
}
|
||||
got := set.New(pkgs...)
|
||||
assert.Equal(t, v.want, got, testName)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/analyzer/language"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
|
||||
)
|
||||
|
||||
@@ -105,7 +106,7 @@ func (a poetryAnalyzer) mergePyProject(fsys fs.FS, dir string, app *types.Applic
|
||||
|
||||
// Identify the direct/transitive dependencies
|
||||
for i, pkg := range app.Packages {
|
||||
if _, ok := project.Tool.Poetry.Dependencies[pkg.Name]; ok {
|
||||
if project.Tool.Poetry.Dependencies.Contains(pkg.Name) {
|
||||
app.Packages[i].Relationship = types.RelationshipDirect
|
||||
} else {
|
||||
app.Packages[i].Indirect = true
|
||||
@@ -122,34 +123,33 @@ func filterProdPackages(project pyproject.PyProject, app *types.Application) {
|
||||
return pkg.ID, pkg
|
||||
})
|
||||
|
||||
visited := make(map[string]struct{})
|
||||
visited := set.New[string]()
|
||||
deps := project.Tool.Poetry.Dependencies
|
||||
|
||||
for group, groupDeps := range project.Tool.Poetry.Groups {
|
||||
if group == "dev" {
|
||||
continue
|
||||
}
|
||||
deps = lo.Assign(deps, groupDeps.Dependencies)
|
||||
deps.Set = deps.Union(groupDeps.Dependencies)
|
||||
}
|
||||
|
||||
for _, pkg := range packages {
|
||||
if _, prodDep := deps[pkg.Name]; !prodDep {
|
||||
if !deps.Contains(pkg.Name) {
|
||||
continue
|
||||
}
|
||||
walkPackageDeps(pkg.ID, packages, visited)
|
||||
}
|
||||
|
||||
app.Packages = lo.Filter(app.Packages, func(pkg types.Package, _ int) bool {
|
||||
_, ok := visited[pkg.ID]
|
||||
return ok
|
||||
return visited.Contains(pkg.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func walkPackageDeps(pkgID string, packages map[string]types.Package, visited map[string]struct{}) {
|
||||
if _, ok := visited[pkgID]; ok {
|
||||
func walkPackageDeps(pkgID string, packages map[string]types.Package, visited set.Set[string]) {
|
||||
if visited.Contains(pkgID) {
|
||||
return
|
||||
}
|
||||
visited[pkgID] = struct{}{}
|
||||
visited.Append(pkgID)
|
||||
for _, dep := range packages[pkgID].DependsOn {
|
||||
walkPackageDeps(dep, packages, visited)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/licensing"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -185,13 +186,13 @@ func (a alpinePkgAnalyzer) consolidateDependencies(pkgs []types.Package, provide
|
||||
}
|
||||
|
||||
func (a alpinePkgAnalyzer) uniquePkgs(pkgs []types.Package) (uniqPkgs []types.Package) {
|
||||
uniq := make(map[string]struct{})
|
||||
uniq := set.New[string]()
|
||||
for _, pkg := range pkgs {
|
||||
if _, ok := uniq[pkg.Name]; ok {
|
||||
if uniq.Contains(pkg.Name) {
|
||||
continue
|
||||
}
|
||||
uniqPkgs = append(uniqPkgs, pkg)
|
||||
uniq[pkg.Name] = struct{}{}
|
||||
uniq.Append(pkg.Name)
|
||||
}
|
||||
return uniqPkgs
|
||||
}
|
||||
|
||||
@@ -226,9 +226,9 @@ func (img *image) imageConfig(config *container.Config) v1.Config {
|
||||
}
|
||||
|
||||
if len(config.ExposedPorts) > 0 {
|
||||
c.ExposedPorts = make(map[string]struct{})
|
||||
for port := range c.ExposedPorts {
|
||||
c.ExposedPorts[port] = struct{}{}
|
||||
c.ExposedPorts = make(map[string]struct{}) //nolint: gocritic
|
||||
for port := range config.ExposedPorts {
|
||||
c.ExposedPorts[port.Port()] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -56,14 +56,6 @@ func IsGzip(f *bufio.Reader) bool {
|
||||
return buf[0] == 0x1F && buf[1] == 0x8B && buf[2] == 0x8
|
||||
}
|
||||
|
||||
func Keys(m map[string]struct{}) []string {
|
||||
var keys []string
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func IsExecutable(fileInfo os.FileInfo) bool {
|
||||
// For Windows
|
||||
if filepath.Ext(fileInfo.Name()) == ".exe" {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
checks "github.com/aquasecurity/trivy-checks"
|
||||
"github.com/aquasecurity/trivy/pkg/iac/rules"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
var LoadAndRegister = sync.OnceFunc(func() {
|
||||
@@ -49,7 +50,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
|
||||
}
|
||||
|
||||
retriever := NewMetadataRetriever(compiler)
|
||||
regoCheckIDs := make(map[string]struct{})
|
||||
regoCheckIDs := set.New[string]()
|
||||
|
||||
for _, module := range modules {
|
||||
metadata, err := retriever.RetrieveMetadata(ctx, module)
|
||||
@@ -66,7 +67,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
|
||||
}
|
||||
|
||||
if !metadata.Deprecated {
|
||||
regoCheckIDs[metadata.AVDID] = struct{}{}
|
||||
regoCheckIDs.Append(metadata.AVDID)
|
||||
}
|
||||
|
||||
rules.Register(metadata.ToRule())
|
||||
|
||||
@@ -12,16 +12,13 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
var builtinNamespaces = map[string]struct{}{
|
||||
"builtin": {},
|
||||
"defsec": {},
|
||||
"appshield": {},
|
||||
}
|
||||
var builtinNamespaces = set.New("builtin", "defsec", "appshield")
|
||||
|
||||
func BuiltinNamespaces() []string {
|
||||
return lo.Keys(builtinNamespaces)
|
||||
return builtinNamespaces.Items()
|
||||
}
|
||||
|
||||
func IsBuiltinNamespace(namespace string) bool {
|
||||
@@ -122,15 +119,12 @@ func (s *Scanner) LoadPolicies(srcFS fs.FS) error {
|
||||
}
|
||||
|
||||
// gather namespaces
|
||||
uniq := make(map[string]struct{})
|
||||
uniq := set.New[string]()
|
||||
for _, module := range s.policies {
|
||||
namespace := getModuleNamespace(module)
|
||||
uniq[namespace] = struct{}{}
|
||||
}
|
||||
var namespaces []string
|
||||
for namespace := range uniq {
|
||||
namespaces = append(namespaces, namespace)
|
||||
uniq.Append(namespace)
|
||||
}
|
||||
namespaces := uniq.Items()
|
||||
|
||||
dataFS := srcFS
|
||||
if s.dataFS != nil {
|
||||
@@ -296,7 +290,7 @@ func (s *Scanner) filterModules(retriever *MetadataRetriever) error {
|
||||
}
|
||||
|
||||
if IsBuiltinNamespace(getModuleNamespace(module)) {
|
||||
if _, disabled := s.disabledCheckIDs[meta.ID]; disabled { // ignore builtin disabled checks
|
||||
if s.disabledCheckIDs.Contains(meta.ID) { // ignore builtin disabled checks
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,9 +69,7 @@ func WithDataDirs(paths ...string) options.ScannerOption {
|
||||
func WithPolicyNamespaces(namespaces ...string) options.ScannerOption {
|
||||
return func(s options.ConfigurableScanner) {
|
||||
if ss, ok := s.(*Scanner); ok {
|
||||
for _, namespace := range namespaces {
|
||||
ss.ruleNamespaces[namespace] = struct{}{}
|
||||
}
|
||||
ss.ruleNamespaces.Append(namespaces...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -112,9 +110,7 @@ func WithCustomSchemas(schemas map[string][]byte) options.ScannerOption {
|
||||
func WithDisabledCheckIDs(ids ...string) options.ScannerOption {
|
||||
return func(s options.ConfigurableScanner) {
|
||||
if ss, ok := s.(*Scanner); ok {
|
||||
for _, id := range ids {
|
||||
ss.disabledCheckIDs[id] = struct{}{}
|
||||
}
|
||||
ss.disabledCheckIDs.Append(ids...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func parseLineNumber(raw any) int {
|
||||
return n
|
||||
}
|
||||
|
||||
func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results {
|
||||
func (s *Scanner) convertResults(resultSet rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results {
|
||||
var results scan.Results
|
||||
|
||||
offset := 0
|
||||
@@ -136,7 +136,7 @@ func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rul
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, result := range set {
|
||||
for _, result := range resultSet {
|
||||
for _, expression := range result.Expressions {
|
||||
values, ok := expression.Value.([]any)
|
||||
if !ok {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
@@ -22,29 +21,26 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
|
||||
"github.com/aquasecurity/trivy/pkg/iac/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
var checkTypesWithSubtype = map[types.Source]struct{}{
|
||||
types.SourceCloud: {},
|
||||
types.SourceDefsec: {},
|
||||
types.SourceKubernetes: {},
|
||||
}
|
||||
var checkTypesWithSubtype = set.New[types.Source](types.SourceCloud, types.SourceDefsec, types.SourceKubernetes)
|
||||
|
||||
var supportedProviders = makeSupportedProviders()
|
||||
|
||||
func makeSupportedProviders() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
func makeSupportedProviders() set.Set[string] {
|
||||
m := set.New[string]()
|
||||
for _, p := range providers.AllProviders() {
|
||||
m[string(p)] = struct{}{}
|
||||
m.Append(string(p))
|
||||
}
|
||||
m["kind"] = struct{}{} // kubernetes
|
||||
m.Append("kind") // kubernetes
|
||||
return m
|
||||
}
|
||||
|
||||
var _ options.ConfigurableScanner = (*Scanner)(nil)
|
||||
|
||||
type Scanner struct {
|
||||
ruleNamespaces map[string]struct{}
|
||||
ruleNamespaces set.Set[string]
|
||||
policies map[string]*ast.Module
|
||||
store storage.Store
|
||||
runtimeValues *ast.Term
|
||||
@@ -70,7 +66,7 @@ type Scanner struct {
|
||||
embeddedChecks map[string]*ast.Module
|
||||
customSchemas map[string][]byte
|
||||
|
||||
disabledCheckIDs map[string]struct{}
|
||||
disabledCheckIDs set.Set[string]
|
||||
}
|
||||
|
||||
func (s *Scanner) trace(heading string, input any) {
|
||||
@@ -103,15 +99,13 @@ func NewScanner(source types.Source, opts ...options.ScannerOption) *Scanner {
|
||||
s := &Scanner{
|
||||
regoErrorLimit: ast.CompileErrorLimitDefault,
|
||||
sourceType: source,
|
||||
ruleNamespaces: make(map[string]struct{}),
|
||||
ruleNamespaces: builtinNamespaces.Clone(),
|
||||
runtimeValues: addRuntimeValues(),
|
||||
logger: log.WithPrefix("rego"),
|
||||
customSchemas: make(map[string][]byte),
|
||||
disabledCheckIDs: make(map[string]struct{}),
|
||||
disabledCheckIDs: set.New[string](),
|
||||
}
|
||||
|
||||
maps.Copy(s.ruleNamespaces, builtinNamespaces)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
@@ -147,7 +141,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d
|
||||
}
|
||||
|
||||
instance := rego.New(regoOptions...)
|
||||
set, err := instance.Eval(ctx)
|
||||
resultSet, err := instance.Eval(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -165,7 +159,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d
|
||||
traces = strings.Split(traceBuffer.String(), "\n")
|
||||
}
|
||||
}
|
||||
return set, traces, nil
|
||||
return resultSet, traces, nil
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
@@ -198,7 +192,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
|
||||
|
||||
namespace := getModuleNamespace(module)
|
||||
topLevel := strings.Split(namespace, ".")[0]
|
||||
if _, ok := s.ruleNamespaces[topLevel]; !ok {
|
||||
if !s.ruleNamespaces.Contains(topLevel) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -227,15 +221,15 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
|
||||
continue
|
||||
}
|
||||
|
||||
usedRules := make(map[string]struct{})
|
||||
usedRules := set.New[string]()
|
||||
|
||||
// all rules
|
||||
for _, rule := range module.Rules {
|
||||
ruleName := rule.Head.Name.String()
|
||||
if _, ok := usedRules[ruleName]; ok {
|
||||
if usedRules.Contains(ruleName) {
|
||||
continue
|
||||
}
|
||||
usedRules[ruleName] = struct{}{}
|
||||
usedRules.Append(ruleName)
|
||||
if isEnforcedRule(ruleName) {
|
||||
ruleResults, err := s.applyRule(ctx, namespace, ruleName, inputs)
|
||||
if err != nil {
|
||||
@@ -257,8 +251,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
|
||||
}
|
||||
|
||||
func isPolicyWithSubtype(sourceType types.Source) bool {
|
||||
_, exists := checkTypesWithSubtype[sourceType]
|
||||
return exists
|
||||
return checkTypesWithSubtype.Contains(sourceType)
|
||||
}
|
||||
|
||||
func checkSubtype(ii map[string]any, provider string, subTypes []SubType) bool {
|
||||
@@ -290,7 +283,7 @@ func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool {
|
||||
for _, input := range inputs {
|
||||
if ii, ok := input.Contents.(map[string]any); ok {
|
||||
for provider := range ii {
|
||||
if _, exists := supportedProviders[provider]; !exists {
|
||||
if !supportedProviders.Contains(provider) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -329,12 +322,12 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs
|
||||
continue
|
||||
}
|
||||
|
||||
set, traces, err := s.runQuery(ctx, qualified, parsedInput, false)
|
||||
resultSet, traces, err := s.runQuery(ctx, qualified, parsedInput, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.trace("RESULTSET", set)
|
||||
ruleResults := s.convertResults(set, input, namespace, rule, traces)
|
||||
s.trace("RESULTSET", resultSet)
|
||||
ruleResults := s.convertResults(resultSet, input, namespace, rule, traces)
|
||||
if len(ruleResults) == 0 { // It passed because we didn't find anything wrong (NOT because it didn't exist)
|
||||
var result regoResult
|
||||
result.FS = input.FS
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/iac/scan"
|
||||
dftypes "github.com/aquasecurity/trivy/pkg/iac/types"
|
||||
ruleTypes "github.com/aquasecurity/trivy/pkg/iac/types/rules"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
type registry struct {
|
||||
@@ -74,14 +75,14 @@ func (r *registry) getFrameworkRules(fw ...framework.Framework) []ruleTypes.Regi
|
||||
if len(fw) == 0 {
|
||||
fw = []framework.Framework{framework.Default}
|
||||
}
|
||||
unique := make(map[int]struct{})
|
||||
unique := set.New[int]()
|
||||
for _, f := range fw {
|
||||
for _, rule := range r.frameworks[f] {
|
||||
if _, ok := unique[rule.Number]; ok {
|
||||
if unique.Contains(rule.Number) {
|
||||
continue
|
||||
}
|
||||
registered = append(registered, rule)
|
||||
unique[rule.Number] = struct{}{}
|
||||
unique.Append(rule.Number)
|
||||
}
|
||||
}
|
||||
return registered
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/iac/terraform"
|
||||
"github.com/aquasecurity/trivy/pkg/iac/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
var _ scanners.FSScanner = (*Scanner)(nil)
|
||||
@@ -31,7 +32,7 @@ type Scanner struct {
|
||||
options []options.ScannerOption
|
||||
parserOpt []parser.Option
|
||||
executorOpt []executor.Option
|
||||
dirs map[string]struct{}
|
||||
dirs set.Set[string]
|
||||
forceAllDirs bool
|
||||
regoScanner *rego.Scanner
|
||||
execLock sync.RWMutex
|
||||
@@ -55,7 +56,7 @@ func (s *Scanner) AddExecutorOptions(opts ...executor.Option) {
|
||||
|
||||
func New(opts ...options.ScannerOption) *Scanner {
|
||||
s := &Scanner{
|
||||
dirs: make(map[string]struct{}),
|
||||
dirs: set.New[string](),
|
||||
options: opts,
|
||||
logger: log.WithPrefix("terraform scanner"),
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
"github.com/liamg/memoryfs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
func Test_FSKey(t *testing.T) {
|
||||
@@ -18,22 +20,20 @@ func Test_FSKey(t *testing.T) {
|
||||
memoryfs.New(),
|
||||
}
|
||||
|
||||
keys := make(map[string]struct{})
|
||||
keys := set.New[string]()
|
||||
|
||||
t.Run("uniqueness", func(t *testing.T) {
|
||||
for _, system := range systems {
|
||||
key := CreateFSKey(system)
|
||||
_, ok := keys[key]
|
||||
assert.False(t, ok, "filesystem keys should be unique")
|
||||
keys[key] = struct{}{}
|
||||
assert.False(t, keys.Contains(key), "filesystem keys should be unique")
|
||||
keys.Append(key)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reproducible", func(t *testing.T) {
|
||||
for _, system := range systems {
|
||||
key := CreateFSKey(system)
|
||||
_, ok := keys[key]
|
||||
assert.True(t, ok, "filesystem keys should be reproducible")
|
||||
assert.True(t, keys.Contains(key), "filesystem keys should be reproducible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -43,7 +44,7 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic
|
||||
|
||||
var findings types.LicenseFindings
|
||||
var matchType types.LicenseType
|
||||
seen := make(map[string]struct{})
|
||||
seen := set.New[string]()
|
||||
|
||||
// cf.Match is not thread safe
|
||||
m.Lock()
|
||||
@@ -57,11 +58,11 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic
|
||||
if match.Confidence <= confidenceLevel {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[match.Name]; ok {
|
||||
if seen.Contains(match.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[match.Name] = struct{}{}
|
||||
seen.Append(match.Name)
|
||||
|
||||
switch match.MatchType {
|
||||
case "Header":
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/aquasecurity/testdocker/registry"
|
||||
"github.com/aquasecurity/testdocker/tarfile"
|
||||
"github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/version/app"
|
||||
)
|
||||
|
||||
@@ -216,13 +217,13 @@ type userAgentsTrackingHandler struct {
|
||||
hr http.Handler
|
||||
|
||||
mu sync.Mutex
|
||||
agents map[string]struct{}
|
||||
agents set.Set[string]
|
||||
}
|
||||
|
||||
func newUserAgentsTrackingHandler(hr http.Handler) *userAgentsTrackingHandler {
|
||||
return &userAgentsTrackingHandler{
|
||||
hr: hr,
|
||||
agents: make(map[string]struct{}),
|
||||
agents: set.New[string](),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,7 +231,7 @@ func (uh *userAgentsTrackingHandler) ServeHTTP(rw http.ResponseWriter, r *http.R
|
||||
for _, agent := range r.Header["User-Agent"] {
|
||||
// Skip test framework user agent
|
||||
if agent != "Go-http-client/1.1" {
|
||||
uh.agents[agent] = struct{}{}
|
||||
uh.agents.Append(agent)
|
||||
}
|
||||
}
|
||||
uh.hr.ServeHTTP(rw, r)
|
||||
@@ -271,7 +272,7 @@ func TestUserAgents(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, tracker.agents, 1)
|
||||
_, ok := tracker.agents[fmt.Sprintf("trivy/%s go-containerregistry", app.Version())]
|
||||
ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s go-containerregistry", app.Version()))
|
||||
require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
"github.com/aquasecurity/trivy/pkg/version/doc"
|
||||
)
|
||||
@@ -279,7 +280,7 @@ Dependency Origin Tree (Reversed)
|
||||
topLvlID := tml.Sprintf("<red>%s, (%s)</red>", vulnPkg.ID, strings.Join(summaries, ", "))
|
||||
|
||||
branch := root.AddBranch(topLvlID)
|
||||
addParents(branch, vulnPkg, parents, ancestors, map[string]struct{}{vulnPkg.ID: {}}, 1)
|
||||
addParents(branch, vulnPkg, parents, ancestors, set.New(vulnPkg.ID), 1)
|
||||
|
||||
}
|
||||
r.printf(root.String())
|
||||
@@ -291,17 +292,17 @@ func (r *vulnerabilityRenderer) printf(format string, args ...any) {
|
||||
}
|
||||
|
||||
func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string]ftypes.Packages, ancestors map[string][]string,
|
||||
seen map[string]struct{}, depth int) {
|
||||
seen set.Set[string], depth int) {
|
||||
if pkg.Relationship == ftypes.RelationshipDirect {
|
||||
return
|
||||
}
|
||||
|
||||
roots := make(map[string]struct{})
|
||||
roots := set.New[string]()
|
||||
for _, parent := range parentMap[pkg.ID] {
|
||||
if _, ok := seen[parent.ID]; ok {
|
||||
if seen.Contains(parent.ID) {
|
||||
continue
|
||||
}
|
||||
seen[parent.ID] = struct{}{} // to avoid infinite loops
|
||||
seen.Append(parent.ID) // to avoid infinite loops
|
||||
|
||||
if depth == 1 && parent.Relationship == ftypes.RelationshipDirect {
|
||||
topItem.AddBranch(parent.ID)
|
||||
@@ -309,16 +310,13 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string
|
||||
// We omit intermediate dependencies and show only direct dependencies
|
||||
// as this could make the dependency tree huge.
|
||||
for _, ancestor := range ancestors[parent.ID] {
|
||||
roots[ancestor] = struct{}{}
|
||||
roots.Append(ancestor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Omitted
|
||||
rootIDs := lo.Filter(lo.Keys(roots), func(pkgID string, _ int) bool {
|
||||
_, ok := seen[pkgID]
|
||||
return !ok
|
||||
})
|
||||
rootIDs := roots.Difference(seen).Items()
|
||||
sort.Strings(rootIDs)
|
||||
if len(rootIDs) > 0 {
|
||||
branch := topItem.AddBranch("...(omitted)...")
|
||||
@@ -331,21 +329,21 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string
|
||||
func traverseAncestors(pkgs []ftypes.Package, parentMap map[string]ftypes.Packages) map[string][]string {
|
||||
ancestors := make(map[string][]string)
|
||||
for _, pkg := range pkgs {
|
||||
ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, make(map[string]struct{}))
|
||||
ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, set.New[string]())
|
||||
}
|
||||
return ancestors
|
||||
}
|
||||
|
||||
func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[string]struct{}) []string {
|
||||
ancestors := make(map[string]struct{})
|
||||
seen[pkgID] = struct{}{}
|
||||
func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen set.Set[string]) []string {
|
||||
ancestors := set.New[string]()
|
||||
seen.Append(pkgID)
|
||||
for _, parent := range parentMap[pkgID] {
|
||||
if _, ok := seen[parent.ID]; ok {
|
||||
if seen.Contains(parent.ID) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case parent.Relationship == ftypes.RelationshipDirect:
|
||||
ancestors[parent.ID] = struct{}{}
|
||||
ancestors.Append(parent.ID)
|
||||
case len(parentMap[parent.ID]) == 0:
|
||||
// Some package managers, such as "package-lock.json" v1, can retrieve package dependencies but not relationships.
|
||||
// We try to guess direct dependencies in this case. A dependency with no parents must be a direct dependency.
|
||||
@@ -358,14 +356,14 @@ func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[s
|
||||
//
|
||||
// Even if `styled-components` is not marked as a direct dependency, it must be a direct dependency
|
||||
// as it has no parents. Note that it doesn't mean `fbjs` is an indirect dependency.
|
||||
ancestors[parent.ID] = struct{}{}
|
||||
ancestors.Append(parent.ID)
|
||||
default:
|
||||
for _, a := range findAncestor(parent.ID, parentMap, seen) {
|
||||
ancestors[a] = struct{}{}
|
||||
ancestors.Append(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
return lo.Keys(ancestors)
|
||||
return ancestors.Items()
|
||||
}
|
||||
|
||||
var jarExtensions = []string{
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/detector/library"
|
||||
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
)
|
||||
|
||||
@@ -41,7 +42,7 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.
|
||||
}
|
||||
|
||||
var results types.Results
|
||||
printedTypes := make(map[ftypes.LangType]struct{})
|
||||
printedTypes := set.New[ftypes.LangType]()
|
||||
for _, app := range apps {
|
||||
if len(app.Packages) == 0 {
|
||||
continue
|
||||
@@ -76,13 +77,13 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes map[ftypes.LangType]struct{}) (
|
||||
func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes set.Set[ftypes.LangType]) (
|
||||
[]types.DetectedVulnerability, error) {
|
||||
|
||||
// Prevent the same log messages from being displayed many times for the same type.
|
||||
if _, ok := printedTypes[app.Type]; !ok {
|
||||
if !printedTypes.Contains(app.Type) {
|
||||
log.InfoContext(ctx, "Detecting vulnerabilities...")
|
||||
printedTypes[app.Type] = struct{}{}
|
||||
printedTypes.Append(app.Type)
|
||||
}
|
||||
|
||||
log.DebugContext(ctx, "Scanning packages for vulnerabilities", log.FilePath(app.FilePath))
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/aquasecurity/trivy/pkg/scanner/langpkg"
|
||||
"github.com/aquasecurity/trivy/pkg/scanner/ospkg"
|
||||
"github.com/aquasecurity/trivy/pkg/scanner/post"
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
"github.com/aquasecurity/trivy/pkg/vulnerability"
|
||||
|
||||
@@ -458,12 +459,12 @@ func mergePkgs(pkgs, pkgsFromCommands []ftypes.Package, options types.ScanOption
|
||||
}
|
||||
|
||||
// pkg has priority over pkgsFromCommands
|
||||
uniqPkgs := make(map[string]struct{})
|
||||
uniqPkgs := set.New[string]()
|
||||
for _, pkg := range pkgs {
|
||||
uniqPkgs[pkg.Name] = struct{}{}
|
||||
uniqPkgs.Append(pkg.Name)
|
||||
}
|
||||
for _, pkg := range pkgsFromCommands {
|
||||
if _, ok := uniqPkgs[pkg.Name]; ok {
|
||||
if uniqPkgs.Contains(pkg.Name) {
|
||||
continue
|
||||
}
|
||||
pkgs = append(pkgs, pkg)
|
||||
|
||||
39
pkg/set/set.go
Normal file
39
pkg/set/set.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package set
|
||||
|
||||
import "iter"
|
||||
|
||||
// Set defines the interface for set operations
|
||||
type Set[T comparable] interface {
|
||||
// Append adds multiple items to the set and returns the new size
|
||||
Append(val ...T) int
|
||||
|
||||
// Remove removes an item from the set
|
||||
Remove(item T)
|
||||
|
||||
// Contains checks if an item exists in the set
|
||||
Contains(item T) bool
|
||||
|
||||
// Size returns the number of items in the set
|
||||
Size() int
|
||||
|
||||
// Clear removes all items from the set
|
||||
Clear()
|
||||
|
||||
// Clone returns a new set with a copy of all items
|
||||
Clone() Set[T]
|
||||
|
||||
// Items returns all items in the set as a slice
|
||||
Items() []T
|
||||
|
||||
// Iter returns an iterator over the set
|
||||
Iter() iter.Seq[T]
|
||||
|
||||
// Union returns a new set containing all items from both sets
|
||||
Union(other Set[T]) Set[T]
|
||||
|
||||
// Intersection returns a new set containing items present in both sets
|
||||
Intersection(other Set[T]) Set[T]
|
||||
|
||||
// Difference returns a new set containing items present in this set but not in the other
|
||||
Difference(other Set[T]) Set[T]
|
||||
}
|
||||
100
pkg/set/unsafe.go
Normal file
100
pkg/set/unsafe.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package set
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"maps"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// unsafeSet represents a non-thread-safe set implementation
|
||||
// WARNING: This implementation is not thread-safe
|
||||
type unsafeSet[T comparable] map[T]struct{} //nolint: gocritic
|
||||
|
||||
// New creates a new empty non-thread-safe set with optional initial values
|
||||
func New[T comparable](values ...T) Set[T] {
|
||||
s := make(unsafeSet[T])
|
||||
for _, v := range values {
|
||||
s[v] = struct{}{}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Append adds multiple items to the set and returns the new size
|
||||
func (s unsafeSet[T]) Append(val ...T) int {
|
||||
for _, item := range val {
|
||||
s[item] = struct{}{}
|
||||
}
|
||||
return len(s)
|
||||
}
|
||||
|
||||
// Remove removes an item from the set
|
||||
func (s unsafeSet[T]) Remove(item T) {
|
||||
delete(s, item)
|
||||
}
|
||||
|
||||
// Contains checks if an item exists in the set
|
||||
func (s unsafeSet[T]) Contains(item T) bool {
|
||||
_, exists := s[item]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Size returns the number of items in the set
|
||||
func (s unsafeSet[T]) Size() int {
|
||||
return len(s)
|
||||
}
|
||||
|
||||
// Clear removes all items from the set
|
||||
func (s unsafeSet[T]) Clear() {
|
||||
for k := range s {
|
||||
delete(s, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Clone returns a new set with a copy of all items
|
||||
func (s unsafeSet[T]) Clone() Set[T] {
|
||||
return maps.Clone(s)
|
||||
}
|
||||
|
||||
// Items returns all items in the set as a slice
|
||||
func (s unsafeSet[T]) Items() []T {
|
||||
return slices.Collect(s.Iter())
|
||||
}
|
||||
|
||||
// Iter returns an iterator over the set
|
||||
func (s unsafeSet[T]) Iter() iter.Seq[T] {
|
||||
return maps.Keys(s)
|
||||
}
|
||||
|
||||
// Union returns a new set containing all items from both sets
|
||||
func (s unsafeSet[T]) Union(other Set[T]) Set[T] {
|
||||
result := make(unsafeSet[T])
|
||||
for k := range s {
|
||||
result[k] = struct{}{}
|
||||
}
|
||||
for _, item := range other.Items() {
|
||||
result[item] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Intersection returns a new set containing items present in both sets
|
||||
func (s unsafeSet[T]) Intersection(other Set[T]) Set[T] {
|
||||
result := make(unsafeSet[T])
|
||||
for k := range s {
|
||||
if other.Contains(k) {
|
||||
result[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Difference returns a new set containing items present in this set but not in the other
|
||||
func (s unsafeSet[T]) Difference(other Set[T]) Set[T] {
|
||||
result := make(unsafeSet[T])
|
||||
for k := range s {
|
||||
if !other.Contains(k) {
|
||||
result[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
583
pkg/set/unsafe_test.go
Normal file
583
pkg/set/unsafe_test.go
Normal file
@@ -0,0 +1,583 @@
|
||||
package set_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/set"
|
||||
)
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
values []int
|
||||
wantSize int
|
||||
wantAll bool
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
name: "new empty set",
|
||||
values: []int{},
|
||||
wantSize: 0,
|
||||
wantAll: true,
|
||||
desc: "should create empty set when no values provided",
|
||||
},
|
||||
{
|
||||
name: "new set with single value",
|
||||
values: []int{1},
|
||||
wantSize: 1,
|
||||
wantAll: true,
|
||||
desc: "should create set with single value",
|
||||
},
|
||||
{
|
||||
name: "new set with multiple values",
|
||||
values: []int{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
},
|
||||
wantSize: 3,
|
||||
wantAll: true,
|
||||
desc: "should create set with multiple values",
|
||||
},
|
||||
{
|
||||
name: "new set with duplicate values",
|
||||
values: []int{
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
},
|
||||
wantSize: 3,
|
||||
wantAll: true,
|
||||
desc: "should create set with unique values only",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := set.New(tt.values...)
|
||||
assert.Equal(t, tt.wantSize, s.Size(), "unexpected set size")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Add(t *testing.T) {
|
||||
// Define custom type for struct test cases
|
||||
type custom struct {
|
||||
id int
|
||||
name string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare func(s set.Set[any])
|
||||
input any
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "add integer",
|
||||
prepare: nil,
|
||||
input: 1,
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "add duplicate integer",
|
||||
prepare: func(s set.Set[any]) {
|
||||
s.Append(1)
|
||||
},
|
||||
input: 1,
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "add string",
|
||||
prepare: nil,
|
||||
input: "test",
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "add empty string",
|
||||
prepare: nil,
|
||||
input: "",
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "add custom struct",
|
||||
prepare: nil,
|
||||
input: custom{
|
||||
id: 1,
|
||||
name: "test1",
|
||||
},
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "add nil pointer",
|
||||
prepare: nil,
|
||||
input: (*int)(nil),
|
||||
wantSize: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := set.New[any]()
|
||||
if tt.prepare != nil {
|
||||
tt.prepare(s)
|
||||
}
|
||||
s.Append(tt.input)
|
||||
|
||||
got := s.Size()
|
||||
assert.Equal(t, tt.wantSize, got, "unexpected set size")
|
||||
assert.True(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Append(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare func(s set.Set[int])
|
||||
input []int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "append to empty set",
|
||||
prepare: nil,
|
||||
input: []int{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
},
|
||||
wantSize: 3,
|
||||
},
|
||||
{
|
||||
name: "append with duplicates",
|
||||
prepare: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
},
|
||||
input: []int{
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
3,
|
||||
2,
|
||||
},
|
||||
wantSize: 3,
|
||||
},
|
||||
{
|
||||
name: "append empty slice",
|
||||
prepare: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
},
|
||||
input: []int{},
|
||||
wantSize: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := set.New[int]()
|
||||
if tt.prepare != nil {
|
||||
tt.prepare(s)
|
||||
}
|
||||
got := s.Append(tt.input...)
|
||||
|
||||
assert.Equal(t, tt.wantSize, got, "unexpected returned size")
|
||||
assert.Equal(t, tt.wantSize, s.Size(), "unexpected actual size")
|
||||
|
||||
for _, item := range tt.input {
|
||||
assert.True(t, s.Contains(item), "set should contain appended item: %v", item)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Remove(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare func(s set.Set[int])
|
||||
input int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "remove existing element",
|
||||
prepare: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
},
|
||||
input: 1,
|
||||
wantSize: 0,
|
||||
},
|
||||
{
|
||||
name: "remove non-existing element",
|
||||
prepare: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
},
|
||||
input: 2,
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "remove from empty set",
|
||||
prepare: nil,
|
||||
input: 1,
|
||||
wantSize: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := set.New[int]()
|
||||
if tt.prepare != nil {
|
||||
tt.prepare(s)
|
||||
}
|
||||
s.Remove(tt.input)
|
||||
|
||||
got := s.Size()
|
||||
assert.Equal(t, tt.wantSize, got, "unexpected set size")
|
||||
assert.False(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Clear(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare func(s set.Set[int])
|
||||
}{
|
||||
{
|
||||
name: "clear non-empty set",
|
||||
prepare: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "clear empty set",
|
||||
prepare: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := set.New[int]()
|
||||
if tt.prepare != nil {
|
||||
tt.prepare(s)
|
||||
}
|
||||
s.Clear()
|
||||
|
||||
got := s.Size()
|
||||
assert.Zero(t, got, "unexpected set size")
|
||||
assert.Empty(t, s.Items(), "items should be empty")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Clone(t *testing.T) {
|
||||
t.Run("empty set", func(t *testing.T) {
|
||||
original := set.New[string]()
|
||||
cloned := original.Clone()
|
||||
|
||||
assert.Equal(t, 0, cloned.Size(), "cloned set should be empty")
|
||||
|
||||
// Verify independence
|
||||
original.Append("test")
|
||||
assert.False(t, cloned.Contains("test"), "cloned set should not be affected by original")
|
||||
})
|
||||
|
||||
t.Run("basic types", func(t *testing.T) {
|
||||
original := set.New[any](1, "test", true)
|
||||
cloned := original.Clone()
|
||||
|
||||
assert.Equal(t, original.Size(), cloned.Size(), "sizes should match")
|
||||
assert.True(t, cloned.Contains(1), "should contain integer")
|
||||
assert.True(t, cloned.Contains("test"), "should contain string")
|
||||
assert.True(t, cloned.Contains(true), "should contain boolean")
|
||||
|
||||
// Verify independence
|
||||
original.Append("new")
|
||||
assert.False(t, cloned.Contains("new"), "cloned set should not be affected by original")
|
||||
cloned.Append("another")
|
||||
assert.False(t, original.Contains("another"), "original set should not be affected by clone")
|
||||
})
|
||||
|
||||
// Test nil pointer
|
||||
t.Run("nil pointer", func(t *testing.T) {
|
||||
original := set.New[*int]()
|
||||
original.Append(nil)
|
||||
|
||||
cloned := original.Clone()
|
||||
|
||||
assert.Equal(t, original.Size(), cloned.Size(), "sizes should match")
|
||||
assert.True(t, cloned.Contains((*int)(nil)), "should contain nil pointer")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Items(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare func(s set.Set[int])
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "get items from non-empty set",
|
||||
prepare: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
},
|
||||
want: []int{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get items from empty set",
|
||||
prepare: nil,
|
||||
want: []int{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := set.New[int]()
|
||||
if tt.prepare != nil {
|
||||
tt.prepare(s)
|
||||
}
|
||||
got := s.Items()
|
||||
|
||||
assert.ElementsMatch(t, tt.want, got, "unexpected items in set")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Union(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare1 func(s set.Set[int])
|
||||
prepare2 func(s set.Set[int])
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "union of non-overlapping sets",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
},
|
||||
prepare2: func(s set.Set[int]) {
|
||||
s.Append(3)
|
||||
s.Append(4)
|
||||
},
|
||||
want: []int{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "union of overlapping sets",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
},
|
||||
prepare2: func(s set.Set[int]) {
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
s.Append(4)
|
||||
},
|
||||
want: []int{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "union with empty set",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
},
|
||||
prepare2: nil,
|
||||
want: []int{
|
||||
1,
|
||||
2,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "union of empty sets",
|
||||
prepare1: nil,
|
||||
prepare2: nil,
|
||||
want: []int{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s1 := set.New[int]()
|
||||
s2 := set.New[int]()
|
||||
|
||||
if tt.prepare1 != nil {
|
||||
tt.prepare1(s1)
|
||||
}
|
||||
if tt.prepare2 != nil {
|
||||
tt.prepare2(s2)
|
||||
}
|
||||
|
||||
result := s1.Union(s2)
|
||||
got := result.Items()
|
||||
|
||||
assert.ElementsMatch(t, tt.want, got, "unexpected union result")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Intersection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare1 func(s set.Set[int])
|
||||
prepare2 func(s set.Set[int])
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "intersection of overlapping sets",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
},
|
||||
prepare2: func(s set.Set[int]) {
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
s.Append(4)
|
||||
},
|
||||
want: []int{
|
||||
2,
|
||||
3,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "intersection of non-overlapping sets",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
},
|
||||
prepare2: func(s set.Set[int]) {
|
||||
s.Append(3)
|
||||
s.Append(4)
|
||||
},
|
||||
want: []int{},
|
||||
},
|
||||
{
|
||||
name: "intersection with empty set",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
},
|
||||
prepare2: nil,
|
||||
want: []int{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s1 := set.New[int]()
|
||||
s2 := set.New[int]()
|
||||
|
||||
if tt.prepare1 != nil {
|
||||
tt.prepare1(s1)
|
||||
}
|
||||
if tt.prepare2 != nil {
|
||||
tt.prepare2(s2)
|
||||
}
|
||||
|
||||
result := s1.Intersection(s2)
|
||||
got := result.Items()
|
||||
|
||||
assert.ElementsMatch(t, tt.want, got, "unexpected intersection result")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unsafeSet_Difference(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare1 func(s set.Set[int])
|
||||
prepare2 func(s set.Set[int])
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "difference of overlapping sets",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
},
|
||||
prepare2: func(s set.Set[int]) {
|
||||
s.Append(2)
|
||||
s.Append(3)
|
||||
s.Append(4)
|
||||
},
|
||||
want: []int{1},
|
||||
},
|
||||
{
|
||||
name: "difference with non-overlapping set",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
},
|
||||
prepare2: func(s set.Set[int]) {
|
||||
s.Append(3)
|
||||
s.Append(4)
|
||||
},
|
||||
want: []int{
|
||||
1,
|
||||
2,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "difference with empty set",
|
||||
prepare1: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
},
|
||||
prepare2: nil,
|
||||
want: []int{
|
||||
1,
|
||||
2,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "difference of empty set",
|
||||
prepare1: nil,
|
||||
prepare2: func(s set.Set[int]) {
|
||||
s.Append(1)
|
||||
s.Append(2)
|
||||
},
|
||||
want: []int{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s1 := set.New[int]()
|
||||
s2 := set.New[int]()
|
||||
if tt.prepare1 != nil {
|
||||
tt.prepare1(s1)
|
||||
}
|
||||
if tt.prepare2 != nil {
|
||||
tt.prepare2(s2)
|
||||
}
|
||||
|
||||
result := s1.Difference(s2)
|
||||
got := result.Items()
|
||||
|
||||
assert.ElementsMatch(t, tt.want, got, "unexpected difference result")
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user