refactor: add generic Set implementation (#8149)

Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Teppei Fukuda
2024-12-24 13:47:21 +09:00
committed by GitHub
parent e6d0ba5cc9
commit b5859d3fb5
34 changed files with 968 additions and 270 deletions

View File

@@ -30,3 +30,8 @@ func errorsJoin(m dsl.Matcher) {
m.Match(`errors.Join($*args)`).
Report("use github.com/hashicorp/go-multierror.Append instead of errors.Join.")
}
func mapSet(m dsl.Matcher) {
m.Match(`map[$x]struct{}`).
Report("use github.com/aquasecurity/trivy/pkg/set.Set instead of map.")
}

View File

@@ -6,13 +6,13 @@ import (
"path/filepath"
"strings"
"github.com/samber/lo"
"golang.org/x/xerrors"
"gopkg.in/yaml.v3"
sp "github.com/aquasecurity/trivy-checks/pkg/spec"
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types"
)
@@ -31,17 +31,17 @@ const (
// Scanners reads spec control and determines the scanners by check ID prefix
func (cs *ComplianceSpec) Scanners() (types.Scanners, error) {
scannerTypes := make(map[types.Scanner]struct{})
scannerTypes := set.New[types.Scanner]()
for _, control := range cs.Spec.Controls {
for _, check := range control.Checks {
scannerType := scannerByCheckID(check.ID)
if scannerType == types.UnknownScanner {
return nil, xerrors.Errorf("unsupported check ID: %s", check.ID)
}
scannerTypes[scannerType] = struct{}{}
scannerTypes.Append(scannerType)
}
}
return lo.Keys(scannerTypes), nil
return scannerTypes.Items(), nil
}
// CheckIDs return list of compliance check IDs

View File

@@ -12,6 +12,7 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/version/doc"
)
@@ -30,7 +31,7 @@ type artifact struct {
Version version
Licenses []string
Exclusions map[string]struct{}
Exclusions set.Set[string]
Module bool
Relationship ftypes.Relationship

View File

@@ -22,6 +22,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)
@@ -118,11 +119,11 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
rootArt := root.artifact()
rootArt.Relationship = ftypes.RelationshipRoot
return p.parseRoot(rootArt, make(map[string]struct{}))
return p.parseRoot(rootArt, set.New[string]())
}
// nolint: gocyclo
func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ftypes.Package, []ftypes.Dependency, error) {
func (p *Parser) parseRoot(root artifact, uniqModules set.Set[string]) ([]ftypes.Package, []ftypes.Dependency, error) {
// Prepare a queue for dependencies
queue := newArtifactQueue()
@@ -145,10 +146,10 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// Modules should be handled separately so that they can have independent dependencies.
// It means multi-module allows for duplicate dependencies.
if art.Module {
if _, ok := uniqModules[art.String()]; ok {
if uniqModules.Contains(art.String()) {
continue
}
uniqModules[art.String()] = struct{}{}
uniqModules.Append(art.String())
modulePkgs, moduleDeps, err := p.parseRoot(art, uniqModules)
if err != nil {
@@ -251,7 +252,7 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// `mvn` shows modules separately from the root package and does not show module nesting.
// So we can add all modules as dependencies of root package.
if art.Relationship == ftypes.RelationshipRoot {
dependsOn = append(dependsOn, lo.Keys(uniqModules)...)
dependsOn = append(dependsOn, uniqModules.Items()...)
}
sort.Strings(dependsOn)
@@ -340,7 +341,7 @@ type analysisResult struct {
}
type analysisOptions struct {
exclusions map[string]struct{}
exclusions set.Set[string]
depManagement []pomDependency // from the root POM
}
@@ -348,6 +349,9 @@ func (p *Parser) analyze(pom *pom, opts analysisOptions) (analysisResult, error)
if pom.nil() {
return analysisResult{}, nil
}
if opts.exclusions == nil {
opts.exclusions = set.New[string]()
}
// Update remoteRepositories
pomReleaseRemoteRepos, pomSnapshotRemoteRepos := pom.repositories(p.servers)
p.releaseRemoteRepos = lo.Uniq(append(pomReleaseRemoteRepos, p.releaseRemoteRepos...))
@@ -408,16 +412,16 @@ func (p *Parser) resolveParent(pom *pom) error {
}
func (p *Parser) mergeDependencyManagements(depManagements ...[]pomDependency) []pomDependency {
uniq := make(map[string]struct{})
uniq := set.New[string]()
var depManagement []pomDependency
// The preceding argument takes precedence.
for _, dm := range depManagements {
for _, dep := range dm {
if _, ok := uniq[dep.Name()]; ok {
if uniq.Contains(dep.Name()) {
continue
}
depManagement = append(depManagement, dep)
uniq[dep.Name()] = struct{}{}
uniq.Append(dep.Name())
}
}
return depManagement
@@ -492,19 +496,19 @@ func (p *Parser) mergeDependencies(child, parent []pomDependency) []pomDependenc
})
}
func (p *Parser) filterDependencies(artifacts []artifact, exclusions map[string]struct{}) []artifact {
func (p *Parser) filterDependencies(artifacts []artifact, exclusions set.Set[string]) []artifact {
return lo.Filter(artifacts, func(art artifact, _ int) bool {
return !excludeDep(exclusions, art)
})
}
func excludeDep(exclusions map[string]struct{}, art artifact) bool {
if _, ok := exclusions[art.Name()]; ok {
func excludeDep(exclusions set.Set[string], art artifact) bool {
if exclusions.Contains(art.Name()) {
return true
}
// Maven can use "*" in GroupID and ArtifactID fields to exclude dependencies
// https://maven.apache.org/pom.html#exclusions
for exlusion := range exclusions {
for exlusion := range exclusions.Iter() {
// exclusion format - "<groupID>:<artifactID>"
e := strings.Split(exlusion, ":")
if (e[0] == art.GroupID || e[0] == "*") && (e[1] == art.ArtifactID || e[1] == "*") {

View File

@@ -4,7 +4,6 @@ import (
"encoding/xml"
"fmt"
"io"
"maps"
"net/url"
"reflect"
"strings"
@@ -15,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/x/slices"
)
@@ -287,12 +287,12 @@ func (d pomDependency) ToArtifact(opts analysisOptions) artifact {
// To avoid shadow adding exclusions to top pom's,
// we need to initialize a new map for each new artifact
// See `exclusions in child` test for more information
exclusions := make(map[string]struct{})
exclusions := set.New[string]()
if opts.exclusions != nil {
exclusions = maps.Clone(opts.exclusions)
exclusions = opts.exclusions.Clone()
}
for _, e := range d.Exclusions.Exclusion {
exclusions[fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID)] = struct{}{}
exclusions.Append(fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID))
}
var locations ftypes.Locations

View File

@@ -17,6 +17,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)
@@ -91,7 +92,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
// https://docs.npmjs.com/cli/v9/configuring-npm/package-lock-json#packages
p.resolveLinks(packages)
directDeps := make(map[string]struct{})
directDeps := set.New[string]()
for name, version := range lo.Assign(packages[""].Dependencies, packages[""].OptionalDependencies, packages[""].DevDependencies, packages[""].PeerDependencies) {
pkgPath := joinPaths(nodeModulesDir, name)
if _, ok := packages[pkgPath]; !ok {
@@ -101,7 +102,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
}
// Store the package paths of direct dependencies
// e.g. node_modules/body-parser
directDeps[pkgPath] = struct{}{}
directDeps.Append(pkgPath)
}
for pkgPath, pkg := range packages {
@@ -366,13 +367,13 @@ func (p *Parser) pkgNameFromPath(pkgPath string) string {
func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
var uniqDeps ftypes.Dependencies
unique := make(map[string]struct{})
unique := set.New[string]()
for _, dep := range deps {
sort.Strings(dep.DependsOn)
depKey := fmt.Sprintf("%s:%s", dep.ID, strings.Join(dep.DependsOn, ","))
if _, ok := unique[depKey]; !ok {
unique[depKey] = struct{}{}
if !unique.Contains(depKey) {
unique.Append(depKey)
uniqDeps = append(uniqDeps, dep)
}
}
@@ -381,11 +382,11 @@ func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
return uniqDeps
}
func isIndirectPkg(pkgPath string, directDeps map[string]struct{}) bool {
func isIndirectPkg(pkgPath string, directDeps set.Set[string]) bool {
// A project can contain 2 different versions of the same dependency.
// e.g. `node_modules/string-width/node_modules/strip-ansi` and `node_modules/string-ansi`
// direct dependencies always have root path (`node_modules/<pkg_name>`)
if _, ok := directDeps[pkgPath]; ok {
if directDeps.Contains(pkgPath) {
return false
}
return true

View File

@@ -14,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)
@@ -215,7 +216,7 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
}
}
visited := make(map[string]struct{})
visited := set.New[string]()
// Overwrite the `Dev` field for dev deps and their child dependencies.
for _, pkg := range resolvedPkgs {
if !pkg.Dev {
@@ -227,8 +228,8 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
}
// markRootPkgs sets `Dev` to false for non dev dependency.
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited map[string]struct{}) {
if _, ok := visited[id]; ok {
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited set.Set[string]) {
if visited.Contains(id) {
return
}
pkg, ok := pkgs[id]
@@ -238,7 +239,7 @@ func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps ma
pkg.Dev = false
pkgs[id] = pkg
visited[id] = struct{}{}
visited.Append(id)
// Update child deps
for _, depID := range deps[id].DependsOn {

View File

@@ -76,7 +76,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
}
if savedDependsOn, ok := depsMap[depId]; ok {
dependsOn = utils.UniqueStrings(append(dependsOn, savedDependsOn...))
dependsOn = lo.Uniq(append(dependsOn, savedDependsOn...))
}
if len(dependsOn) > 0 {

View File

@@ -8,6 +8,7 @@ import (
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/dependency/parser/python"
"github.com/aquasecurity/trivy/pkg/set"
)
type PyProject struct {
@@ -19,25 +20,27 @@ type Tool struct {
}
type Poetry struct {
Dependencies dependencies `toml:"dependencies"`
Dependencies Dependencies `toml:"dependencies"`
Groups map[string]Group `toml:"group"`
}
type Group struct {
Dependencies dependencies `toml:"dependencies"`
Dependencies Dependencies `toml:"dependencies"`
}
type dependencies map[string]struct{}
type Dependencies struct {
set.Set[string]
}
func (d *dependencies) UnmarshalTOML(data any) error {
func (d *Dependencies) UnmarshalTOML(data any) error {
m, ok := data.(map[string]any)
if !ok {
return xerrors.Errorf("dependencies must be map, but got: %T", data)
}
*d = lo.MapEntries(m, func(pkgName string, _ any) (string, struct{}) {
return python.NormalizePkgName(pkgName), struct{}{}
})
d.Set = set.New[string](lo.MapToSlice(m, func(pkgName string, _ any) string {
return python.NormalizePkgName(pkgName)
})...)
return nil
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/dependency/parser/python/pyproject"
"github.com/aquasecurity/trivy/pkg/set"
)
func TestParser_Parse(t *testing.T) {
@@ -24,21 +25,18 @@ func TestParser_Parse(t *testing.T) {
want: pyproject.PyProject{
Tool: pyproject.Tool{
Poetry: pyproject.Poetry{
Dependencies: map[string]struct{}{
"flask": {},
"python": {},
"requests": {},
"virtualenv": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("flask", "python", "requests", "virtualenv"),
},
Groups: map[string]pyproject.Group{
"dev": {
Dependencies: map[string]struct{}{
"pytest": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("pytest"),
},
},
"lint": {
Dependencies: map[string]struct{}{
"ruff": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("ruff"),
},
},
},

View File

@@ -9,6 +9,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)
@@ -22,25 +23,25 @@ func (l Lock) packages() map[string]Package {
})
}
func (l Lock) directDeps(root Package) map[string]struct{} {
deps := make(map[string]struct{})
func (l Lock) directDeps(root Package) set.Set[string] {
deps := set.New[string]()
for _, dep := range root.Dependencies {
deps[dep.Name] = struct{}{}
deps.Append(dep.Name)
}
return deps
}
func prodDeps(root Package, packages map[string]Package) map[string]struct{} {
visited := make(map[string]struct{})
func prodDeps(root Package, packages map[string]Package) set.Set[string] {
visited := set.New[string]()
walkPackageDeps(root, packages, visited)
return visited
}
func walkPackageDeps(pkg Package, packages map[string]Package, visited map[string]struct{}) {
if _, ok := visited[pkg.Name]; ok {
func walkPackageDeps(pkg Package, packages map[string]Package, visited set.Set[string]) {
if visited.Contains(pkg.Name) {
return
}
visited[pkg.Name] = struct{}{}
visited.Append(pkg.Name)
for _, dep := range pkg.Dependencies {
depPkg, exists := packages[dep.Name]
if !exists {
@@ -119,7 +120,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
)
for _, pkg := range lock.Packages {
if _, ok := prodDeps[pkg.Name]; !ok {
if !prodDeps.Contains(pkg.Name) {
continue
}
@@ -127,7 +128,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
relationship := ftypes.RelationshipIndirect
if pkg.isRoot() {
relationship = ftypes.RelationshipRoot
} else if _, ok := directDeps[pkg.Name]; ok {
} else if directDeps.Contains(pkg.Name) {
relationship = ftypes.RelationshipDirect
}

View File

@@ -10,19 +10,6 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
)
func UniqueStrings(ss []string) []string {
var results []string
uniq := make(map[string]struct{})
for _, s := range ss {
if _, ok := uniq[s]; ok {
continue
}
results = append(results, s)
uniq[s] = struct{}{}
}
return results
}
func UniquePackages(pkgs []ftypes.Package) []ftypes.Package {
if len(pkgs) == 0 {
return nil

View File

@@ -19,6 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
)
const (
@@ -179,33 +180,30 @@ func (a alpineCmdAnalyzer) parseCommand(command string, envs map[string]string)
return pkgs
}
func (a alpineCmdAnalyzer) resolveDependencies(apkIndexArchive *apkIndex, originalPkgs []string) (pkgs []string) {
uniqPkgs := make(map[string]struct{})
uniqPkgs := set.New[string]()
for _, pkgName := range originalPkgs {
if _, ok := uniqPkgs[pkgName]; ok {
if uniqPkgs.Contains(pkgName) {
continue
}
seenPkgs := make(map[string]struct{})
seenPkgs := set.New[string]()
for _, p := range a.resolveDependency(apkIndexArchive, pkgName, seenPkgs) {
uniqPkgs[p] = struct{}{}
uniqPkgs.Append(p)
}
}
for pkg := range uniqPkgs {
pkgs = append(pkgs, pkg)
}
return pkgs
return uniqPkgs.Items()
}
func (a alpineCmdAnalyzer) resolveDependency(apkIndexArchive *apkIndex, pkgName string,
seenPkgs map[string]struct{}) (pkgNames []string) {
seenPkgs set.Set[string]) (pkgNames []string) {
pkg, ok := apkIndexArchive.Package[pkgName]
if !ok {
return nil
}
if _, ok = seenPkgs[pkgName]; ok {
if seenPkgs.Contains(pkgName) {
return nil
}
seenPkgs[pkgName] = struct{}{}
seenPkgs.Append(pkgName)
pkgNames = append(pkgNames, pkgName)
for _, dependency := range pkg.Dependencies {

View File

@@ -19,6 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
)
var (
@@ -1508,86 +1509,86 @@ func TestResolveDependency(t *testing.T) {
var tests = map[string]struct {
pkgName string
apkIndexArchivePath string
expected map[string]struct{}
want set.Set[string]
}{
"low": {
pkgName: "libblkid",
apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{
"libblkid": {},
"libuuid": {},
"musl": {},
},
want: set.New(
"libblkid",
"libuuid",
"musl",
),
},
"medium": {
pkgName: "libgcab",
apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{
"busybox": {},
"libblkid": {},
"libuuid": {},
"musl": {},
"libmount": {},
"pcre": {},
"glib": {},
"libgcab": {},
"libintl": {},
"zlib": {},
"libffi": {},
},
want: set.New(
"busybox",
"libblkid",
"libuuid",
"musl",
"libmount",
"pcre",
"glib",
"libgcab",
"libintl",
"zlib",
"libffi",
),
},
"high": {
pkgName: "postgresql",
apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{
"busybox": {},
"ncurses-terminfo-base": {},
"ncurses-terminfo": {},
"libedit": {},
"db": {},
"libsasl": {},
"libldap": {},
"libpq": {},
"postgresql-client": {},
"tzdata": {},
"libxml2": {},
"postgresql": {},
"musl": {},
"libcrypto1.1": {},
"libssl1.1": {},
"ncurses-libs": {},
"zlib": {},
},
want: set.New(
"busybox",
"ncurses-terminfo-base",
"ncurses-terminfo",
"libedit",
"db",
"libsasl",
"libldap",
"libpq",
"postgresql-client",
"tzdata",
"libxml2",
"postgresql",
"musl",
"libcrypto1.1",
"libssl1.1",
"ncurses-libs",
"zlib",
),
},
"package alias": {
pkgName: "sqlite-dev",
apkIndexArchivePath: "testdata/history_v3.9.json",
expected: map[string]struct{}{
"sqlite-dev": {},
"sqlite-libs": {},
"pkgconf": {}, // pkgconfig => pkgconf
"musl": {},
},
want: set.New(
"sqlite-dev",
"sqlite-libs",
"pkgconf", // pkgconfig => pkgconf
"musl",
),
},
"circular dependencies": {
pkgName: "nodejs",
apkIndexArchivePath: "testdata/history_v3.7.json",
expected: map[string]struct{}{
"busybox": {},
"c-ares": {},
"ca-certificates": {},
"http-parser": {},
"libcrypto1.0": {},
"libgcc": {},
"libressl2.6-libcrypto": {},
"libssl1.0": {},
"libstdc++": {},
"libuv": {},
"musl": {},
"nodejs": {},
"nodejs-npm": {},
"zlib": {},
},
want: set.New(
"busybox",
"c-ares",
"ca-certificates",
"http-parser",
"libcrypto1.0",
"libgcc",
"libressl2.6-libcrypto",
"libssl1.0",
"libstdc++",
"libuv",
"musl",
"nodejs",
"nodejs-npm",
"zlib",
),
},
}
analyzer := alpineCmdAnalyzer{}
@@ -1600,15 +1601,10 @@ func TestResolveDependency(t *testing.T) {
if err = json.NewDecoder(f).Decode(&apkIndexArchive); err != nil {
t.Fatalf("unexpected error: %s", err)
}
circularDependencyCheck := make(map[string]struct{})
circularDependencyCheck := set.New[string]()
pkgs := analyzer.resolveDependency(apkIndexArchive, v.pkgName, circularDependencyCheck)
actual := make(map[string]struct{})
for _, pkg := range pkgs {
actual[pkg] = struct{}{}
}
if !reflect.DeepEqual(v.expected, actual) {
t.Errorf("[%s]\n%s", testName, pretty.Compare(v.expected, actual))
}
got := set.New(pkgs...)
assert.Equal(t, v.want, got, testName)
}
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/analyzer/language"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
)
@@ -105,7 +106,7 @@ func (a poetryAnalyzer) mergePyProject(fsys fs.FS, dir string, app *types.Applic
// Identify the direct/transitive dependencies
for i, pkg := range app.Packages {
if _, ok := project.Tool.Poetry.Dependencies[pkg.Name]; ok {
if project.Tool.Poetry.Dependencies.Contains(pkg.Name) {
app.Packages[i].Relationship = types.RelationshipDirect
} else {
app.Packages[i].Indirect = true
@@ -122,34 +123,33 @@ func filterProdPackages(project pyproject.PyProject, app *types.Application) {
return pkg.ID, pkg
})
visited := make(map[string]struct{})
visited := set.New[string]()
deps := project.Tool.Poetry.Dependencies
for group, groupDeps := range project.Tool.Poetry.Groups {
if group == "dev" {
continue
}
deps = lo.Assign(deps, groupDeps.Dependencies)
deps.Set = deps.Union(groupDeps.Dependencies)
}
for _, pkg := range packages {
if _, prodDep := deps[pkg.Name]; !prodDep {
if !deps.Contains(pkg.Name) {
continue
}
walkPackageDeps(pkg.ID, packages, visited)
}
app.Packages = lo.Filter(app.Packages, func(pkg types.Package, _ int) bool {
_, ok := visited[pkg.ID]
return ok
return visited.Contains(pkg.ID)
})
}
func walkPackageDeps(pkgID string, packages map[string]types.Package, visited map[string]struct{}) {
if _, ok := visited[pkgID]; ok {
func walkPackageDeps(pkgID string, packages map[string]types.Package, visited set.Set[string]) {
if visited.Contains(pkgID) {
return
}
visited[pkgID] = struct{}{}
visited.Append(pkgID)
for _, dep := range packages[pkgID].DependsOn {
walkPackageDeps(dep, packages, visited)
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/licensing"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
)
func init() {
@@ -185,13 +186,13 @@ func (a alpinePkgAnalyzer) consolidateDependencies(pkgs []types.Package, provide
}
func (a alpinePkgAnalyzer) uniquePkgs(pkgs []types.Package) (uniqPkgs []types.Package) {
uniq := make(map[string]struct{})
uniq := set.New[string]()
for _, pkg := range pkgs {
if _, ok := uniq[pkg.Name]; ok {
if uniq.Contains(pkg.Name) {
continue
}
uniqPkgs = append(uniqPkgs, pkg)
uniq[pkg.Name] = struct{}{}
uniq.Append(pkg.Name)
}
return uniqPkgs
}

View File

@@ -226,9 +226,9 @@ func (img *image) imageConfig(config *container.Config) v1.Config {
}
if len(config.ExposedPorts) > 0 {
c.ExposedPorts = make(map[string]struct{})
for port := range c.ExposedPorts {
c.ExposedPorts[port] = struct{}{}
c.ExposedPorts = make(map[string]struct{}) //nolint: gocritic
for port := range config.ExposedPorts {
c.ExposedPorts[port.Port()] = struct{}{}
}
}

View File

@@ -56,14 +56,6 @@ func IsGzip(f *bufio.Reader) bool {
return buf[0] == 0x1F && buf[1] == 0x8B && buf[2] == 0x8
}
func Keys(m map[string]struct{}) []string {
var keys []string
for k := range m {
keys = append(keys, k)
}
return keys
}
func IsExecutable(fileInfo os.FileInfo) bool {
// For Windows
if filepath.Ext(fileInfo.Name()) == ".exe" {

View File

@@ -13,6 +13,7 @@ import (
checks "github.com/aquasecurity/trivy-checks"
"github.com/aquasecurity/trivy/pkg/iac/rules"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
)
var LoadAndRegister = sync.OnceFunc(func() {
@@ -49,7 +50,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
}
retriever := NewMetadataRetriever(compiler)
regoCheckIDs := make(map[string]struct{})
regoCheckIDs := set.New[string]()
for _, module := range modules {
metadata, err := retriever.RetrieveMetadata(ctx, module)
@@ -66,7 +67,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
}
if !metadata.Deprecated {
regoCheckIDs[metadata.AVDID] = struct{}{}
regoCheckIDs.Append(metadata.AVDID)
}
rules.Register(metadata.ToRule())

View File

@@ -12,16 +12,13 @@ import (
"github.com/samber/lo"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
)
var builtinNamespaces = map[string]struct{}{
"builtin": {},
"defsec": {},
"appshield": {},
}
var builtinNamespaces = set.New("builtin", "defsec", "appshield")
func BuiltinNamespaces() []string {
return lo.Keys(builtinNamespaces)
return builtinNamespaces.Items()
}
func IsBuiltinNamespace(namespace string) bool {
@@ -122,15 +119,12 @@ func (s *Scanner) LoadPolicies(srcFS fs.FS) error {
}
// gather namespaces
uniq := make(map[string]struct{})
uniq := set.New[string]()
for _, module := range s.policies {
namespace := getModuleNamespace(module)
uniq[namespace] = struct{}{}
}
var namespaces []string
for namespace := range uniq {
namespaces = append(namespaces, namespace)
uniq.Append(namespace)
}
namespaces := uniq.Items()
dataFS := srcFS
if s.dataFS != nil {
@@ -296,7 +290,7 @@ func (s *Scanner) filterModules(retriever *MetadataRetriever) error {
}
if IsBuiltinNamespace(getModuleNamespace(module)) {
if _, disabled := s.disabledCheckIDs[meta.ID]; disabled { // ignore builtin disabled checks
if s.disabledCheckIDs.Contains(meta.ID) { // ignore builtin disabled checks
continue
}
}

View File

@@ -69,9 +69,7 @@ func WithDataDirs(paths ...string) options.ScannerOption {
func WithPolicyNamespaces(namespaces ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
for _, namespace := range namespaces {
ss.ruleNamespaces[namespace] = struct{}{}
}
ss.ruleNamespaces.Append(namespaces...)
}
}
}
@@ -112,9 +110,7 @@ func WithCustomSchemas(schemas map[string][]byte) options.ScannerOption {
func WithDisabledCheckIDs(ids ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
for _, id := range ids {
ss.disabledCheckIDs[id] = struct{}{}
}
ss.disabledCheckIDs.Append(ids...)
}
}
}

View File

@@ -121,7 +121,7 @@ func parseLineNumber(raw any) int {
return n
}
func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results {
func (s *Scanner) convertResults(resultSet rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results {
var results scan.Results
offset := 0
@@ -136,7 +136,7 @@ func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rul
}
}
}
for _, result := range set {
for _, result := range resultSet {
for _, expression := range result.Expressions {
values, ok := expression.Value.([]any)
if !ok {

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"io"
"io/fs"
"maps"
"strings"
"github.com/open-policy-agent/opa/ast"
@@ -22,29 +21,26 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
"github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
)
var checkTypesWithSubtype = map[types.Source]struct{}{
types.SourceCloud: {},
types.SourceDefsec: {},
types.SourceKubernetes: {},
}
var checkTypesWithSubtype = set.New[types.Source](types.SourceCloud, types.SourceDefsec, types.SourceKubernetes)
var supportedProviders = makeSupportedProviders()
func makeSupportedProviders() map[string]struct{} {
m := make(map[string]struct{})
func makeSupportedProviders() set.Set[string] {
m := set.New[string]()
for _, p := range providers.AllProviders() {
m[string(p)] = struct{}{}
m.Append(string(p))
}
m["kind"] = struct{}{} // kubernetes
m.Append("kind") // kubernetes
return m
}
var _ options.ConfigurableScanner = (*Scanner)(nil)
type Scanner struct {
ruleNamespaces map[string]struct{}
ruleNamespaces set.Set[string]
policies map[string]*ast.Module
store storage.Store
runtimeValues *ast.Term
@@ -70,7 +66,7 @@ type Scanner struct {
embeddedChecks map[string]*ast.Module
customSchemas map[string][]byte
disabledCheckIDs map[string]struct{}
disabledCheckIDs set.Set[string]
}
func (s *Scanner) trace(heading string, input any) {
@@ -103,15 +99,13 @@ func NewScanner(source types.Source, opts ...options.ScannerOption) *Scanner {
s := &Scanner{
regoErrorLimit: ast.CompileErrorLimitDefault,
sourceType: source,
ruleNamespaces: make(map[string]struct{}),
ruleNamespaces: builtinNamespaces.Clone(),
runtimeValues: addRuntimeValues(),
logger: log.WithPrefix("rego"),
customSchemas: make(map[string][]byte),
disabledCheckIDs: make(map[string]struct{}),
disabledCheckIDs: set.New[string](),
}
maps.Copy(s.ruleNamespaces, builtinNamespaces)
for _, opt := range opts {
opt(s)
}
@@ -147,7 +141,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d
}
instance := rego.New(regoOptions...)
set, err := instance.Eval(ctx)
resultSet, err := instance.Eval(ctx)
if err != nil {
return nil, nil, err
}
@@ -165,7 +159,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d
traces = strings.Split(traceBuffer.String(), "\n")
}
}
return set, traces, nil
return resultSet, traces, nil
}
type Input struct {
@@ -198,7 +192,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
namespace := getModuleNamespace(module)
topLevel := strings.Split(namespace, ".")[0]
if _, ok := s.ruleNamespaces[topLevel]; !ok {
if !s.ruleNamespaces.Contains(topLevel) {
continue
}
@@ -227,15 +221,15 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
continue
}
usedRules := make(map[string]struct{})
usedRules := set.New[string]()
// all rules
for _, rule := range module.Rules {
ruleName := rule.Head.Name.String()
if _, ok := usedRules[ruleName]; ok {
if usedRules.Contains(ruleName) {
continue
}
usedRules[ruleName] = struct{}{}
usedRules.Append(ruleName)
if isEnforcedRule(ruleName) {
ruleResults, err := s.applyRule(ctx, namespace, ruleName, inputs)
if err != nil {
@@ -257,8 +251,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
}
func isPolicyWithSubtype(sourceType types.Source) bool {
_, exists := checkTypesWithSubtype[sourceType]
return exists
return checkTypesWithSubtype.Contains(sourceType)
}
func checkSubtype(ii map[string]any, provider string, subTypes []SubType) bool {
@@ -290,7 +283,7 @@ func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool {
for _, input := range inputs {
if ii, ok := input.Contents.(map[string]any); ok {
for provider := range ii {
if _, exists := supportedProviders[provider]; !exists {
if !supportedProviders.Contains(provider) {
continue
}
@@ -329,12 +322,12 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs
continue
}
set, traces, err := s.runQuery(ctx, qualified, parsedInput, false)
resultSet, traces, err := s.runQuery(ctx, qualified, parsedInput, false)
if err != nil {
return nil, err
}
s.trace("RESULTSET", set)
ruleResults := s.convertResults(set, input, namespace, rule, traces)
s.trace("RESULTSET", resultSet)
ruleResults := s.convertResults(resultSet, input, namespace, rule, traces)
if len(ruleResults) == 0 { // It passed because we didn't find anything wrong (NOT because it didn't exist)
var result regoResult
result.FS = input.FS

View File

@@ -10,6 +10,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/scan"
dftypes "github.com/aquasecurity/trivy/pkg/iac/types"
ruleTypes "github.com/aquasecurity/trivy/pkg/iac/types/rules"
"github.com/aquasecurity/trivy/pkg/set"
)
type registry struct {
@@ -74,14 +75,14 @@ func (r *registry) getFrameworkRules(fw ...framework.Framework) []ruleTypes.Regi
if len(fw) == 0 {
fw = []framework.Framework{framework.Default}
}
unique := make(map[int]struct{})
unique := set.New[int]()
for _, f := range fw {
for _, rule := range r.frameworks[f] {
if _, ok := unique[rule.Number]; ok {
if unique.Contains(rule.Number) {
continue
}
registered = append(registered, rule)
unique[rule.Number] = struct{}{}
unique.Append(rule.Number)
}
}
return registered

View File

@@ -19,6 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/terraform"
"github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
)
var _ scanners.FSScanner = (*Scanner)(nil)
@@ -31,7 +32,7 @@ type Scanner struct {
options []options.ScannerOption
parserOpt []parser.Option
executorOpt []executor.Option
dirs map[string]struct{}
dirs set.Set[string]
forceAllDirs bool
regoScanner *rego.Scanner
execLock sync.RWMutex
@@ -55,7 +56,7 @@ func (s *Scanner) AddExecutorOptions(opts ...executor.Option) {
func New(opts ...options.ScannerOption) *Scanner {
s := &Scanner{
dirs: make(map[string]struct{}),
dirs: set.New[string](),
options: opts,
logger: log.WithPrefix("terraform scanner"),
}

View File

@@ -7,6 +7,8 @@ import (
"github.com/liamg/memoryfs"
"github.com/stretchr/testify/assert"
"github.com/aquasecurity/trivy/pkg/set"
)
func Test_FSKey(t *testing.T) {
@@ -18,22 +20,20 @@ func Test_FSKey(t *testing.T) {
memoryfs.New(),
}
keys := make(map[string]struct{})
keys := set.New[string]()
t.Run("uniqueness", func(t *testing.T) {
for _, system := range systems {
key := CreateFSKey(system)
_, ok := keys[key]
assert.False(t, ok, "filesystem keys should be unique")
keys[key] = struct{}{}
assert.False(t, keys.Contains(key), "filesystem keys should be unique")
keys.Append(key)
}
})
t.Run("reproducible", func(t *testing.T) {
for _, system := range systems {
key := CreateFSKey(system)
_, ok := keys[key]
assert.True(t, ok, "filesystem keys should be reproducible")
assert.True(t, keys.Contains(key), "filesystem keys should be reproducible")
}
})
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
)
var (
@@ -43,7 +44,7 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic
var findings types.LicenseFindings
var matchType types.LicenseType
seen := make(map[string]struct{})
seen := set.New[string]()
// cf.Match is not thread safe
m.Lock()
@@ -57,11 +58,11 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic
if match.Confidence <= confidenceLevel {
continue
}
if _, ok := seen[match.Name]; ok {
if seen.Contains(match.Name) {
continue
}
seen[match.Name] = struct{}{}
seen.Append(match.Name)
switch match.MatchType {
case "Header":

View File

@@ -20,6 +20,7 @@ import (
"github.com/aquasecurity/testdocker/registry"
"github.com/aquasecurity/testdocker/tarfile"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/version/app"
)
@@ -216,13 +217,13 @@ type userAgentsTrackingHandler struct {
hr http.Handler
mu sync.Mutex
agents map[string]struct{}
agents set.Set[string]
}
func newUserAgentsTrackingHandler(hr http.Handler) *userAgentsTrackingHandler {
return &userAgentsTrackingHandler{
hr: hr,
agents: make(map[string]struct{}),
agents: set.New[string](),
}
}
@@ -230,7 +231,7 @@ func (uh *userAgentsTrackingHandler) ServeHTTP(rw http.ResponseWriter, r *http.R
for _, agent := range r.Header["User-Agent"] {
// Skip test framework user agent
if agent != "Go-http-client/1.1" {
uh.agents[agent] = struct{}{}
uh.agents.Append(agent)
}
}
uh.hr.ServeHTTP(rw, r)
@@ -271,7 +272,7 @@ func TestUserAgents(t *testing.T) {
require.NoError(t, err)
require.Len(t, tracker.agents, 1)
_, ok := tracker.agents[fmt.Sprintf("trivy/%s go-containerregistry", app.Version())]
ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s go-containerregistry", app.Version()))
require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`)
}

View File

@@ -19,6 +19,7 @@ import (
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/version/doc"
)
@@ -279,7 +280,7 @@ Dependency Origin Tree (Reversed)
topLvlID := tml.Sprintf("<red>%s, (%s)</red>", vulnPkg.ID, strings.Join(summaries, ", "))
branch := root.AddBranch(topLvlID)
addParents(branch, vulnPkg, parents, ancestors, map[string]struct{}{vulnPkg.ID: {}}, 1)
addParents(branch, vulnPkg, parents, ancestors, set.New(vulnPkg.ID), 1)
}
r.printf(root.String())
@@ -291,17 +292,17 @@ func (r *vulnerabilityRenderer) printf(format string, args ...any) {
}
func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string]ftypes.Packages, ancestors map[string][]string,
seen map[string]struct{}, depth int) {
seen set.Set[string], depth int) {
if pkg.Relationship == ftypes.RelationshipDirect {
return
}
roots := make(map[string]struct{})
roots := set.New[string]()
for _, parent := range parentMap[pkg.ID] {
if _, ok := seen[parent.ID]; ok {
if seen.Contains(parent.ID) {
continue
}
seen[parent.ID] = struct{}{} // to avoid infinite loops
seen.Append(parent.ID) // to avoid infinite loops
if depth == 1 && parent.Relationship == ftypes.RelationshipDirect {
topItem.AddBranch(parent.ID)
@@ -309,16 +310,13 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string
// We omit intermediate dependencies and show only direct dependencies
// as this could make the dependency tree huge.
for _, ancestor := range ancestors[parent.ID] {
roots[ancestor] = struct{}{}
roots.Append(ancestor)
}
}
}
// Omitted
rootIDs := lo.Filter(lo.Keys(roots), func(pkgID string, _ int) bool {
_, ok := seen[pkgID]
return !ok
})
rootIDs := roots.Difference(seen).Items()
sort.Strings(rootIDs)
if len(rootIDs) > 0 {
branch := topItem.AddBranch("...(omitted)...")
@@ -331,21 +329,21 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string
func traverseAncestors(pkgs []ftypes.Package, parentMap map[string]ftypes.Packages) map[string][]string {
ancestors := make(map[string][]string)
for _, pkg := range pkgs {
ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, make(map[string]struct{}))
ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, set.New[string]())
}
return ancestors
}
func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[string]struct{}) []string {
ancestors := make(map[string]struct{})
seen[pkgID] = struct{}{}
func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen set.Set[string]) []string {
ancestors := set.New[string]()
seen.Append(pkgID)
for _, parent := range parentMap[pkgID] {
if _, ok := seen[parent.ID]; ok {
if seen.Contains(parent.ID) {
continue
}
switch {
case parent.Relationship == ftypes.RelationshipDirect:
ancestors[parent.ID] = struct{}{}
ancestors.Append(parent.ID)
case len(parentMap[parent.ID]) == 0:
// Some package managers, such as "package-lock.json" v1, can retrieve package dependencies but not relationships.
// We try to guess direct dependencies in this case. A dependency with no parents must be a direct dependency.
@@ -358,14 +356,14 @@ func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[s
//
// Even if `styled-components` is not marked as a direct dependency, it must be a direct dependency
// as it has no parents. Note that it doesn't mean `fbjs` is an indirect dependency.
ancestors[parent.ID] = struct{}{}
ancestors.Append(parent.ID)
default:
for _, a := range findAncestor(parent.ID, parentMap, seen) {
ancestors[a] = struct{}{}
ancestors.Append(a)
}
}
}
return lo.Keys(ancestors)
return ancestors.Items()
}
var jarExtensions = []string{

View File

@@ -9,6 +9,7 @@ import (
"github.com/aquasecurity/trivy/pkg/detector/library"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types"
)
@@ -41,7 +42,7 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.
}
var results types.Results
printedTypes := make(map[ftypes.LangType]struct{})
printedTypes := set.New[ftypes.LangType]()
for _, app := range apps {
if len(app.Packages) == 0 {
continue
@@ -76,13 +77,13 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.
return results, nil
}
func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes map[ftypes.LangType]struct{}) (
func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes set.Set[ftypes.LangType]) (
[]types.DetectedVulnerability, error) {
// Prevent the same log messages from being displayed many times for the same type.
if _, ok := printedTypes[app.Type]; !ok {
if !printedTypes.Contains(app.Type) {
log.InfoContext(ctx, "Detecting vulnerabilities...")
printedTypes[app.Type] = struct{}{}
printedTypes.Append(app.Type)
}
log.DebugContext(ctx, "Scanning packages for vulnerabilities", log.FilePath(app.FilePath))

View File

@@ -24,6 +24,7 @@ import (
"github.com/aquasecurity/trivy/pkg/scanner/langpkg"
"github.com/aquasecurity/trivy/pkg/scanner/ospkg"
"github.com/aquasecurity/trivy/pkg/scanner/post"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/vulnerability"
@@ -458,12 +459,12 @@ func mergePkgs(pkgs, pkgsFromCommands []ftypes.Package, options types.ScanOption
}
// pkg has priority over pkgsFromCommands
uniqPkgs := make(map[string]struct{})
uniqPkgs := set.New[string]()
for _, pkg := range pkgs {
uniqPkgs[pkg.Name] = struct{}{}
uniqPkgs.Append(pkg.Name)
}
for _, pkg := range pkgsFromCommands {
if _, ok := uniqPkgs[pkg.Name]; ok {
if uniqPkgs.Contains(pkg.Name) {
continue
}
pkgs = append(pkgs, pkg)

39
pkg/set/set.go Normal file
View File

@@ -0,0 +1,39 @@
package set
import "iter"
// Set defines the interface for set operations
type Set[T comparable] interface {
// Append adds multiple items to the set and returns the new size
Append(val ...T) int
// Remove removes an item from the set
Remove(item T)
// Contains checks if an item exists in the set
Contains(item T) bool
// Size returns the number of items in the set
Size() int
// Clear removes all items from the set
Clear()
// Clone returns a new set with a copy of all items
Clone() Set[T]
// Items returns all items in the set as a slice
Items() []T
// Iter returns an iterator over the set
Iter() iter.Seq[T]
// Union returns a new set containing all items from both sets
Union(other Set[T]) Set[T]
// Intersection returns a new set containing items present in both sets
Intersection(other Set[T]) Set[T]
// Difference returns a new set containing items present in this set but not in the other
Difference(other Set[T]) Set[T]
}

100
pkg/set/unsafe.go Normal file
View File

@@ -0,0 +1,100 @@
package set
import (
"iter"
"maps"
"slices"
)
// unsafeSet represents a non-thread-safe set implementation
// WARNING: This implementation is not thread-safe
type unsafeSet[T comparable] map[T]struct{} //nolint: gocritic
// New creates a new empty non-thread-safe set with optional initial values
func New[T comparable](values ...T) Set[T] {
s := make(unsafeSet[T])
for _, v := range values {
s[v] = struct{}{}
}
return s
}
// Append adds multiple items to the set and returns the new size
func (s unsafeSet[T]) Append(val ...T) int {
for _, item := range val {
s[item] = struct{}{}
}
return len(s)
}
// Remove removes an item from the set
func (s unsafeSet[T]) Remove(item T) {
delete(s, item)
}
// Contains checks if an item exists in the set
func (s unsafeSet[T]) Contains(item T) bool {
_, exists := s[item]
return exists
}
// Size returns the number of items in the set
func (s unsafeSet[T]) Size() int {
return len(s)
}
// Clear removes all items from the set
func (s unsafeSet[T]) Clear() {
for k := range s {
delete(s, k)
}
}
// Clone returns a new set with a copy of all items
func (s unsafeSet[T]) Clone() Set[T] {
return maps.Clone(s)
}
// Items returns all items in the set as a slice
func (s unsafeSet[T]) Items() []T {
return slices.Collect(s.Iter())
}
// Iter returns an iterator over the set
func (s unsafeSet[T]) Iter() iter.Seq[T] {
return maps.Keys(s)
}
// Union returns a new set containing all items from both sets
func (s unsafeSet[T]) Union(other Set[T]) Set[T] {
result := make(unsafeSet[T])
for k := range s {
result[k] = struct{}{}
}
for _, item := range other.Items() {
result[item] = struct{}{}
}
return result
}
// Intersection returns a new set containing items present in both sets
func (s unsafeSet[T]) Intersection(other Set[T]) Set[T] {
result := make(unsafeSet[T])
for k := range s {
if other.Contains(k) {
result[k] = struct{}{}
}
}
return result
}
// Difference returns a new set containing items present in this set but not in the other
func (s unsafeSet[T]) Difference(other Set[T]) Set[T] {
result := make(unsafeSet[T])
for k := range s {
if !other.Contains(k) {
result[k] = struct{}{}
}
}
return result
}

583
pkg/set/unsafe_test.go Normal file
View File

@@ -0,0 +1,583 @@
package set_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aquasecurity/trivy/pkg/set"
)
func Test_New(t *testing.T) {
tests := []struct {
name string
values []int
wantSize int
wantAll bool
desc string
}{
{
name: "new empty set",
values: []int{},
wantSize: 0,
wantAll: true,
desc: "should create empty set when no values provided",
},
{
name: "new set with single value",
values: []int{1},
wantSize: 1,
wantAll: true,
desc: "should create set with single value",
},
{
name: "new set with multiple values",
values: []int{
1,
2,
3,
},
wantSize: 3,
wantAll: true,
desc: "should create set with multiple values",
},
{
name: "new set with duplicate values",
values: []int{
1,
2,
2,
3,
3,
3,
},
wantSize: 3,
wantAll: true,
desc: "should create set with unique values only",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New(tt.values...)
assert.Equal(t, tt.wantSize, s.Size(), "unexpected set size")
})
}
}
func Test_unsafeSet_Add(t *testing.T) {
// Define custom type for struct test cases
type custom struct {
id int
name string
}
tests := []struct {
name string
prepare func(s set.Set[any])
input any
wantSize int
}{
{
name: "add integer",
prepare: nil,
input: 1,
wantSize: 1,
},
{
name: "add duplicate integer",
prepare: func(s set.Set[any]) {
s.Append(1)
},
input: 1,
wantSize: 1,
},
{
name: "add string",
prepare: nil,
input: "test",
wantSize: 1,
},
{
name: "add empty string",
prepare: nil,
input: "",
wantSize: 1,
},
{
name: "add custom struct",
prepare: nil,
input: custom{
id: 1,
name: "test1",
},
wantSize: 1,
},
{
name: "add nil pointer",
prepare: nil,
input: (*int)(nil),
wantSize: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[any]()
if tt.prepare != nil {
tt.prepare(s)
}
s.Append(tt.input)
got := s.Size()
assert.Equal(t, tt.wantSize, got, "unexpected set size")
assert.True(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input)
})
}
}
func Test_unsafeSet_Append(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
input []int
wantSize int
}{
{
name: "append to empty set",
prepare: nil,
input: []int{
1,
2,
3,
},
wantSize: 3,
},
{
name: "append with duplicates",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: []int{
1,
2,
1,
3,
2,
},
wantSize: 3,
},
{
name: "append empty slice",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: []int{},
wantSize: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
got := s.Append(tt.input...)
assert.Equal(t, tt.wantSize, got, "unexpected returned size")
assert.Equal(t, tt.wantSize, s.Size(), "unexpected actual size")
for _, item := range tt.input {
assert.True(t, s.Contains(item), "set should contain appended item: %v", item)
}
})
}
}
func Test_unsafeSet_Remove(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
input int
wantSize int
}{
{
name: "remove existing element",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: 1,
wantSize: 0,
},
{
name: "remove non-existing element",
prepare: func(s set.Set[int]) {
s.Append(1)
},
input: 2,
wantSize: 1,
},
{
name: "remove from empty set",
prepare: nil,
input: 1,
wantSize: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
s.Remove(tt.input)
got := s.Size()
assert.Equal(t, tt.wantSize, got, "unexpected set size")
assert.False(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input)
})
}
}
func Test_unsafeSet_Clear(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
}{
{
name: "clear non-empty set",
prepare: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
},
{
name: "clear empty set",
prepare: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
s.Clear()
got := s.Size()
assert.Zero(t, got, "unexpected set size")
assert.Empty(t, s.Items(), "items should be empty")
})
}
}
func Test_unsafeSet_Clone(t *testing.T) {
t.Run("empty set", func(t *testing.T) {
original := set.New[string]()
cloned := original.Clone()
assert.Equal(t, 0, cloned.Size(), "cloned set should be empty")
// Verify independence
original.Append("test")
assert.False(t, cloned.Contains("test"), "cloned set should not be affected by original")
})
t.Run("basic types", func(t *testing.T) {
original := set.New[any](1, "test", true)
cloned := original.Clone()
assert.Equal(t, original.Size(), cloned.Size(), "sizes should match")
assert.True(t, cloned.Contains(1), "should contain integer")
assert.True(t, cloned.Contains("test"), "should contain string")
assert.True(t, cloned.Contains(true), "should contain boolean")
// Verify independence
original.Append("new")
assert.False(t, cloned.Contains("new"), "cloned set should not be affected by original")
cloned.Append("another")
assert.False(t, original.Contains("another"), "original set should not be affected by clone")
})
// Test nil pointer
t.Run("nil pointer", func(t *testing.T) {
original := set.New[*int]()
original.Append(nil)
cloned := original.Clone()
assert.Equal(t, original.Size(), cloned.Size(), "sizes should match")
assert.True(t, cloned.Contains((*int)(nil)), "should contain nil pointer")
})
}
func Test_unsafeSet_Items(t *testing.T) {
tests := []struct {
name string
prepare func(s set.Set[int])
want []int
}{
{
name: "get items from non-empty set",
prepare: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
want: []int{
1,
2,
3,
},
},
{
name: "get items from empty set",
prepare: nil,
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := set.New[int]()
if tt.prepare != nil {
tt.prepare(s)
}
got := s.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected items in set")
})
}
}
func Test_unsafeSet_Union(t *testing.T) {
tests := []struct {
name string
prepare1 func(s set.Set[int])
prepare2 func(s set.Set[int])
want []int
}{
{
name: "union of non-overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: func(s set.Set[int]) {
s.Append(3)
s.Append(4)
},
want: []int{
1,
2,
3,
4,
},
},
{
name: "union of overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
prepare2: func(s set.Set[int]) {
s.Append(2)
s.Append(3)
s.Append(4)
},
want: []int{
1,
2,
3,
4,
},
},
{
name: "union with empty set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: nil,
want: []int{
1,
2,
},
},
{
name: "union of empty sets",
prepare1: nil,
prepare2: nil,
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s1 := set.New[int]()
s2 := set.New[int]()
if tt.prepare1 != nil {
tt.prepare1(s1)
}
if tt.prepare2 != nil {
tt.prepare2(s2)
}
result := s1.Union(s2)
got := result.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected union result")
})
}
}
func Test_unsafeSet_Intersection(t *testing.T) {
tests := []struct {
name string
prepare1 func(s set.Set[int])
prepare2 func(s set.Set[int])
want []int
}{
{
name: "intersection of overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
prepare2: func(s set.Set[int]) {
s.Append(2)
s.Append(3)
s.Append(4)
},
want: []int{
2,
3,
},
},
{
name: "intersection of non-overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: func(s set.Set[int]) {
s.Append(3)
s.Append(4)
},
want: []int{},
},
{
name: "intersection with empty set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: nil,
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s1 := set.New[int]()
s2 := set.New[int]()
if tt.prepare1 != nil {
tt.prepare1(s1)
}
if tt.prepare2 != nil {
tt.prepare2(s2)
}
result := s1.Intersection(s2)
got := result.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected intersection result")
})
}
}
func Test_unsafeSet_Difference(t *testing.T) {
tests := []struct {
name string
prepare1 func(s set.Set[int])
prepare2 func(s set.Set[int])
want []int
}{
{
name: "difference of overlapping sets",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
s.Append(3)
},
prepare2: func(s set.Set[int]) {
s.Append(2)
s.Append(3)
s.Append(4)
},
want: []int{1},
},
{
name: "difference with non-overlapping set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: func(s set.Set[int]) {
s.Append(3)
s.Append(4)
},
want: []int{
1,
2,
},
},
{
name: "difference with empty set",
prepare1: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
prepare2: nil,
want: []int{
1,
2,
},
},
{
name: "difference of empty set",
prepare1: nil,
prepare2: func(s set.Set[int]) {
s.Append(1)
s.Append(2)
},
want: []int{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s1 := set.New[int]()
s2 := set.New[int]()
if tt.prepare1 != nil {
tt.prepare1(s1)
}
if tt.prepare2 != nil {
tt.prepare2(s2)
}
result := s1.Difference(s2)
got := result.Items()
assert.ElementsMatch(t, tt.want, got, "unexpected difference result")
})
}
}