fix(terraform): fix root module search (#6160)

Co-authored-by: simar7 <1254783+simar7@users.noreply.github.com>
This commit is contained in:
Nikita Pivkin
2024-02-28 06:31:03 +03:00
committed by GitHub
parent e1ea02c7b8
commit 1dfece89d0
8 changed files with 334 additions and 60 deletions

View File

@@ -191,18 +191,10 @@ func (e *evaluator) EvaluateAll(ctx context.Context) (terraform.Modules, map[str
e.debug.Log("Module evaluation complete.")
parseDuration += time.Since(start)
rootModule := terraform.NewModule(e.projectRootPath, e.modulePath, e.blocks, e.ignores, e.isModuleLocal())
for _, m := range modules {
m.SetParent(rootModule)
}
rootModule := terraform.NewModule(e.projectRootPath, e.modulePath, e.blocks, e.ignores)
return append(terraform.Modules{rootModule}, modules...), fsMap, parseDuration
}
func (e *evaluator) isModuleLocal() bool {
// the module source is empty only for local modules
return e.parentParser.moduleSource == ""
}
func (e *evaluator) expandBlocks(blocks terraform.Blocks) terraform.Blocks {
return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks), false)...)
}

View File

@@ -0,0 +1,78 @@
package parser
import (
"context"
"path"
"sort"
"strings"
"github.com/samber/lo"
"github.com/zclconf/go-cty/cty"
"github.com/aquasecurity/trivy/pkg/iac/terraform"
)
// FindRootModules takes a list of module paths and identifies the root local modules.
// It builds a graph based on the module dependencies and determines the modules that have no incoming dependencies,
// considering them as root modules.
func (p *Parser) FindRootModules(ctx context.Context, dirs []string) ([]string, error) {
for _, dir := range dirs {
if err := p.ParseFS(ctx, dir); err != nil {
return nil, err
}
}
blocks, _, err := p.readBlocks(p.files)
if err != nil {
return nil, err
}
g := buildGraph(blocks, dirs)
rootModules := g.rootModules()
sort.Strings(rootModules)
return rootModules, nil
}
type modulesGraph map[string][]string
func buildGraph(blocks terraform.Blocks, paths []string) modulesGraph {
moduleBlocks := blocks.OfType("module")
graph := lo.SliceToMap(paths, func(p string) (string, []string) {
return p, nil
})
for _, block := range moduleBlocks {
sourceVal := block.GetAttribute("source").Value()
if sourceVal.Type() != cty.String {
continue
}
source := sourceVal.AsString()
if strings.HasPrefix(source, ".") {
filename := block.GetMetadata().Range().GetFilename()
dir := path.Dir(filename)
graph[dir] = append(graph[dir], path.Join(dir, source))
}
}
return graph
}
func (g modulesGraph) rootModules() []string {
incomingEdges := make(map[string]int)
for _, neighbors := range g {
for _, neighbor := range neighbors {
incomingEdges[neighbor]++
}
}
var roots []string
for module := range g {
if incomingEdges[module] == 0 {
roots = append(roots, module)
}
}
return roots
}

View File

@@ -0,0 +1,71 @@
package parser
import (
"context"
"path"
"testing"
"github.com/aquasecurity/trivy/internal/testutil"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
func TestFindRootModules(t *testing.T) {
tests := []struct {
name string
files map[string]string
expected []string
}{
{
name: "multiple root modules",
files: map[string]string{
"code/main.tf": `
module "this" {
count = 0
source = "./modules/s3"
}`,
"code/modules/s3/main.tf": `
module "this" {
source = "./modules/logging"
}
resource "aws_s3_bucket" "this" {
bucket = "test"
}`,
"code/modules/s3/modules/logging/main.tf": `
resource "aws_s3_bucket" "this" {
bucket = "test1"
}`,
"code/example/main.tf": `
module "this" {
source = "../modules/s3"
}`,
},
expected: []string{"code", "code/example"},
},
{
name: "without module block",
files: map[string]string{
"code/infra1/main.tf": `resource "test" "this" {}`,
"code/infra2/main.tf": `resource "test" "this" {}`,
},
expected: []string{"code/infra1", "code/infra2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fsys := testutil.CreateFS(t, tt.files)
parser := New(fsys, "", OptionStopOnHCLError(true))
modules := lo.Map(maps.Keys(tt.files), func(p string, _ int) string {
return path.Dir(p)
})
got, err := parser.FindRootModules(context.TODO(), modules)
require.NoError(t, err)
assert.Equal(t, tt.expected, got)
})
}
}

View File

@@ -1271,6 +1271,96 @@ func TestForEachWithObjectsOfDifferentTypes(t *testing.T) {
assert.Len(t, modules, 1)
}
func TestCountMetaArgument(t *testing.T) {
tests := []struct {
name string
src string
expected int
}{
{
name: "zero resources",
src: `resource "test" "this" {
count = 0
}`,
expected: 0,
},
{
name: "several resources",
src: `resource "test" "this" {
count = 2
}`,
expected: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fsys := testutil.CreateFS(t, map[string]string{
"main.tf": tt.src,
})
parser := New(fsys, "", OptionStopOnHCLError(true))
require.NoError(t, parser.ParseFS(context.TODO(), "."))
modules, _, err := parser.EvaluateAll(context.TODO())
require.NoError(t, err)
assert.Len(t, modules, 1)
resources := modules.GetResourcesByType("test")
assert.Len(t, resources, tt.expected)
})
}
}
func TestCountMetaArgumentInModule(t *testing.T) {
tests := []struct {
name string
files map[string]string
expectedCountModules int
expectedCountResources int
}{
{
name: "zero modules",
files: map[string]string{
"main.tf": `module "this" {
count = 0
source = "./modules/test"
}`,
"modules/test/main.tf": `resource "test" "this" {}`,
},
expectedCountModules: 1,
expectedCountResources: 0,
},
{
name: "several modules",
files: map[string]string{
"main.tf": `module "this" {
count = 2
source = "./modules/test"
}`,
"modules/test/main.tf": `resource "test" "this" {}`,
},
expectedCountModules: 3,
expectedCountResources: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fsys := testutil.CreateFS(t, tt.files)
parser := New(fsys, "", OptionStopOnHCLError(true))
require.NoError(t, parser.ParseFS(context.TODO(), "."))
modules, _, err := parser.EvaluateAll(context.TODO())
require.NoError(t, err)
assert.Len(t, modules, tt.expectedCountModules)
resources := modules.GetResourcesByType("test")
assert.Len(t, resources, tt.expectedCountResources)
})
}
}
func TestDynamicBlocks(t *testing.T) {
t.Run("arg is list of int", func(t *testing.T) {
modules := parse(t, map[string]string{

View File

@@ -2,6 +2,7 @@ package terraform
import (
"context"
"fmt"
"io"
"io/fs"
"path"
@@ -11,8 +12,6 @@ import (
"sync"
"time"
"golang.org/x/exp/slices"
"github.com/aquasecurity/trivy/pkg/extrafs"
"github.com/aquasecurity/trivy/pkg/iac/debug"
"github.com/aquasecurity/trivy/pkg/iac/framework"
@@ -20,7 +19,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/scan"
"github.com/aquasecurity/trivy/pkg/iac/scanners"
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
executor2 "github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/executor"
"github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/executor"
"github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/parser"
"github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/parser/resolvers"
"github.com/aquasecurity/trivy/pkg/iac/terraform"
@@ -35,7 +34,7 @@ type Scanner struct { // nolint: gocritic
sync.Mutex
options []options.ScannerOption
parserOpt []options.ParserOption
executorOpt []executor2.Option
executorOpt []executor.Option
dirs map[string]struct{}
forceAllDirs bool
policyDirs []string
@@ -54,7 +53,7 @@ func (s *Scanner) SetSpec(spec string) {
}
func (s *Scanner) SetRegoOnly(regoOnly bool) {
s.executorOpt = append(s.executorOpt, executor2.OptionWithRegoOnly(regoOnly))
s.executorOpt = append(s.executorOpt, executor.OptionWithRegoOnly(regoOnly))
}
func (s *Scanner) SetFrameworks(frameworks []framework.Framework) {
@@ -81,7 +80,7 @@ func (s *Scanner) AddParserOptions(opts ...options.ParserOption) {
s.parserOpt = append(s.parserOpt, opts...)
}
func (s *Scanner) AddExecutorOptions(opts ...executor2.Option) {
func (s *Scanner) AddExecutorOptions(opts ...executor.Option) {
s.executorOpt = append(s.executorOpt, opts...)
}
@@ -95,7 +94,7 @@ func (s *Scanner) SetSkipRequiredCheck(skip bool) {
func (s *Scanner) SetDebugWriter(writer io.Writer) {
s.parserOpt = append(s.parserOpt, options.ParserWithDebug(writer))
s.executorOpt = append(s.executorOpt, executor2.OptionWithDebugWriter(writer))
s.executorOpt = append(s.executorOpt, executor.OptionWithDebugWriter(writer))
s.debug = debug.New(writer, "terraform", "scanner")
}
@@ -123,7 +122,7 @@ func (s *Scanner) SetRegoErrorLimit(_ int) {}
type Metrics struct {
Parser parser.Metrics
Executor executor2.Metrics
Executor executor.Metrics
Timings struct {
Total time.Duration
}
@@ -168,36 +167,17 @@ type terraformRootModule struct {
fsMap map[string]fs.FS
}
func excludeNonRootModules(modules []terraformRootModule) []terraformRootModule {
var result []terraformRootModule
var childPaths []string
for _, module := range modules {
childPaths = append(childPaths, module.childs.ChildModulesPaths()...)
}
for _, module := range modules {
// if the path of the root module matches the path of the child module,
// then we should not scan it
if !slices.Contains(childPaths, module.rootPath) {
result = append(result, module)
}
}
return result
}
func (s *Scanner) ScanFSWithMetrics(ctx context.Context, target fs.FS, dir string) (scan.Results, Metrics, error) {
var metrics Metrics
s.debug.Log("Scanning [%s] at '%s'...", target, dir)
// find directories which directly contain tf files (and have no parent containing tf files)
rootDirs := s.findRootModules(target, dir, dir)
sort.Strings(rootDirs)
// find directories which directly contain tf files
modulePaths := s.findModules(target, dir, dir)
sort.Strings(modulePaths)
if len(rootDirs) == 0 {
s.debug.Log("no root modules found")
if len(modulePaths) == 0 {
s.debug.Log("no modules found")
return nil, metrics, nil
}
@@ -207,13 +187,20 @@ func (s *Scanner) ScanFSWithMetrics(ctx context.Context, target fs.FS, dir strin
}
s.execLock.Lock()
s.executorOpt = append(s.executorOpt, executor2.OptionWithRegoScanner(regoScanner), executor2.OptionWithFrameworks(s.frameworks...))
s.executorOpt = append(s.executorOpt, executor.OptionWithRegoScanner(regoScanner), executor.OptionWithFrameworks(s.frameworks...))
s.execLock.Unlock()
var allResults scan.Results
p := parser.New(target, "", s.parserOpt...)
rootDirs, err := p.FindRootModules(ctx, modulePaths)
if err != nil {
return nil, metrics, fmt.Errorf("failed to find root modules: %w", err)
}
rootModules := make([]terraformRootModule, 0, len(rootDirs))
// parse all root module directories
var rootModules []terraformRootModule
for _, dir := range rootDirs {
s.debug.Log("Scanning root module '%s'...", dir)
@@ -243,10 +230,9 @@ func (s *Scanner) ScanFSWithMetrics(ctx context.Context, target fs.FS, dir strin
})
}
rootModules = excludeNonRootModules(rootModules)
for _, module := range rootModules {
s.execLock.RLock()
e := executor2.New(s.executorOpt...)
e := executor.New(s.executorOpt...)
s.execLock.RUnlock()
results, execMetrics, err := e.Execute(module.childs)
if err != nil {
@@ -316,7 +302,7 @@ func (s *Scanner) removeNestedDirs(dirs []string) []string {
return clean
}
func (s *Scanner) findRootModules(target fs.FS, scanDir string, dirs ...string) []string {
func (s *Scanner) findModules(target fs.FS, scanDir string, dirs ...string) []string {
var roots []string
var others []string
@@ -358,7 +344,7 @@ func (s *Scanner) findRootModules(target fs.FS, scanDir string, dirs ...string)
}
if (len(roots) == 0 || s.forceAllDirs) && len(others) > 0 {
roots = append(roots, s.findRootModules(target, scanDir, others...)...)
roots = append(roots, s.findModules(target, scanDir, others...)...)
}
return s.removeNestedDirs(roots)

View File

@@ -1321,3 +1321,72 @@ deny[res] {
fmt.Printf("Debug logs:\n%s\n", debugLog.String())
}
}
func TestScanModuleWithCount(t *testing.T) {
fs := testutil.CreateFS(t, map[string]string{
"code/main.tf": `
module "this" {
count = 0
source = "./modules/s3"
}`,
"code/modules/s3/main.tf": `
module "this" {
source = "./modules/logging"
}
resource "aws_s3_bucket" "this" {
bucket = "test"
}`,
"code/modules/s3/modules/logging/main.tf": `
resource "aws_s3_bucket" "this" {
bucket = "test1"
}`,
"code/example/main.tf": `
module "this" {
source = "../modules/s3"
}`,
"rules/region.rego": `
# METADATA
# schemas:
# - input: schema.input
# custom:
# avd_id: AVD-AWS-0001
# input:
# selector:
# - type: cloud
# subtypes:
# - service: s3
# provider: aws
package user.test.aws1
deny[res] {
bucket := input.aws.s3.buckets[_]
bucket.name.value == "test"
res := result.new("bucket with test name is not allowed!", bucket)
}
`,
})
debugLog := bytes.NewBuffer([]byte{})
scanner := New(
options.ScannerWithDebug(debugLog),
options.ScannerWithPolicyDirs("rules"),
options.ScannerWithPolicyFilesystem(fs),
options.ScannerWithRegoOnly(true),
options.ScannerWithPolicyNamespaces("user"),
options.ScannerWithEmbeddedLibraries(false),
options.ScannerWithEmbeddedPolicies(false),
options.ScannerWithRegoErrorLimits(0),
ScannerWithAllDirectories(true),
)
results, err := scanner.ScanFS(context.TODO(), fs, "code")
require.NoError(t, err)
require.Len(t, results, 1)
failed := results.GetFailed()
assert.Len(t, failed, 1)
occurrences := failed[0].Occurrences()
assert.Equal(t, "code/example/main.tf", occurrences[0].Filename)
}

View File

@@ -12,10 +12,9 @@ type Module struct {
modulePath string
ignores Ignores
parent *Module
local bool
}
func NewModule(rootPath, modulePath string, blocks Blocks, ignores Ignores, local bool) *Module {
func NewModule(rootPath, modulePath string, blocks Blocks, ignores Ignores) *Module {
blockMap := make(map[string]Blocks)
@@ -31,7 +30,6 @@ func NewModule(rootPath, modulePath string, blocks Blocks, ignores Ignores, loca
blockMap: blockMap,
rootPath: rootPath,
modulePath: modulePath,
local: local,
}
}

View File

@@ -8,16 +8,6 @@ import (
type Modules []*Module
func (m Modules) ChildModulesPaths() []string {
var result []string
for _, module := range m {
if module.parent != nil && module.local {
result = append(result, module.modulePath)
}
}
return result
}
type ResourceIDResolutions map[string]bool
func (r ResourceIDResolutions) Resolve(id string) {