fix(terraform): do not re-expand dynamic blocks (#6151)

This commit is contained in:
Nikita Pivkin
2024-02-27 10:02:29 +03:00
committed by GitHub
parent eb54bb5da5
commit 64926d8423
4 changed files with 335 additions and 155 deletions

View File

@@ -3,7 +3,6 @@ package parser
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"io/fs" "io/fs"
"reflect" "reflect"
"time" "time"
@@ -148,7 +147,8 @@ func (e *evaluator) EvaluateAll(ctx context.Context) (terraform.Modules, map[str
} }
} }
// expand out resources and modules via count (not a typo, we do this twice so every order is processed) // expand out resources and modules via count, for-each and dynamic
// (not a typo, we do this twice so every order is processed)
e.blocks = e.expandBlocks(e.blocks) e.blocks = e.expandBlocks(e.blocks)
e.blocks = e.expandBlocks(e.blocks) e.blocks = e.expandBlocks(e.blocks)
@@ -204,7 +204,7 @@ func (e *evaluator) isModuleLocal() bool {
} }
func (e *evaluator) expandBlocks(blocks terraform.Blocks) terraform.Blocks { func (e *evaluator) expandBlocks(blocks terraform.Blocks) terraform.Blocks {
return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks))...) return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks), false)...)
} }
func (e *evaluator) expandDynamicBlocks(blocks ...*terraform.Block) terraform.Blocks { func (e *evaluator) expandDynamicBlocks(blocks ...*terraform.Block) terraform.Blocks {
@@ -219,80 +219,49 @@ func (e *evaluator) expandDynamicBlock(b *terraform.Block) {
e.expandDynamicBlock(sub) e.expandDynamicBlock(sub)
} }
for _, sub := range b.AllBlocks().OfType("dynamic") { for _, sub := range b.AllBlocks().OfType("dynamic") {
if sub.IsExpanded() {
continue
}
blockName := sub.TypeLabel() blockName := sub.TypeLabel()
expanded := e.expandBlockForEaches(terraform.Blocks{sub}) expanded := e.expandBlockForEaches(terraform.Blocks{sub}, true)
for _, ex := range expanded { for _, ex := range expanded {
if content := ex.GetBlock("content"); content.IsNotNil() { if content := ex.GetBlock("content"); content.IsNotNil() {
_ = e.expandDynamicBlocks(content) _ = e.expandDynamicBlocks(content)
b.InjectBlock(content, blockName) b.InjectBlock(content, blockName)
} }
} }
sub.MarkExpanded()
} }
} }
func validateForEachArg(arg cty.Value) error {
if arg.IsNull() {
return errors.New("arg is null")
}
ty := arg.Type()
if !arg.IsKnown() || ty.Equals(cty.DynamicPseudoType) || arg.LengthInt() == 0 {
return nil
}
if !(ty.IsSetType() || ty.IsObjectType() || ty.IsMapType()) {
return fmt.Errorf("%s type is not supported: arg is not set or map", ty.FriendlyName())
}
if ty.IsSetType() {
if !ty.ElementType().Equals(cty.String) {
return errors.New("arg is not set of strings")
}
it := arg.ElementIterator()
for it.Next() {
key, _ := it.Element()
if key.IsNull() {
return errors.New("arg is set of strings, but contains null")
}
if !key.IsKnown() {
return errors.New("arg is set of strings, but contains unknown value")
}
}
}
return nil
}
func isBlockSupportsForEachMetaArgument(block *terraform.Block) bool { func isBlockSupportsForEachMetaArgument(block *terraform.Block) bool {
return slices.Contains([]string{"module", "resource", "data", "dynamic"}, block.Type()) return slices.Contains([]string{"module", "resource", "data", "dynamic"}, block.Type())
} }
func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks) terraform.Blocks { func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks, isDynamic bool) terraform.Blocks {
var forEachFiltered terraform.Blocks var forEachFiltered terraform.Blocks
for _, block := range blocks { for _, block := range blocks {
forEachAttr := block.GetAttribute("for_each") forEachAttr := block.GetAttribute("for_each")
if forEachAttr.IsNil() || block.IsCountExpanded() || !isBlockSupportsForEachMetaArgument(block) { if forEachAttr.IsNil() || block.IsExpanded() || !isBlockSupportsForEachMetaArgument(block) {
forEachFiltered = append(forEachFiltered, block) forEachFiltered = append(forEachFiltered, block)
continue continue
} }
forEachVal := forEachAttr.Value() forEachVal := forEachAttr.Value()
if err := validateForEachArg(forEachVal); err != nil { if forEachVal.IsNull() || !forEachVal.IsKnown() || !forEachAttr.IsIterable() {
e.debug.Log(`"for_each" argument is invalid: %s`, err.Error())
continue continue
} }
clones := make(map[string]cty.Value) clones := make(map[string]cty.Value)
_ = forEachAttr.Each(func(key cty.Value, val cty.Value) { _ = forEachAttr.Each(func(key cty.Value, val cty.Value) {
if !key.Type().Equals(cty.String) { // instances are identified by a map key (or set member) from the value provided to for_each
idx, err := convert.Convert(key, cty.String)
if err != nil {
e.debug.Log( e.debug.Log(
`Invalid "for-each" argument: map key (or set value) is not a string, but %s`, `Invalid "for-each" argument: map key (or set value) is not a string, but %s`,
key.Type().FriendlyName(), key.Type().FriendlyName(),
@@ -300,22 +269,34 @@ func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks) terraform.Bloc
return return
} }
clone := block.Clone(key) // if the argument is a collection but not a map, then the resource identifier
// is the value of the collection. The exception is the use of for-each inside a dynamic block,
// because in this case the collection element may not be a primitive value.
if (forEachVal.Type().IsCollectionType() || forEachVal.Type().IsTupleType()) &&
!forEachVal.Type().IsMapType() && !isDynamic {
stringVal, err := convert.Convert(val, cty.String)
if err != nil {
e.debug.Log("Failed to convert for-each arg %v to string", val)
return
}
idx = stringVal
}
clone := block.Clone(idx)
ctx := clone.Context() ctx := clone.Context()
e.copyVariables(block, clone) e.copyVariables(block, clone)
ctx.SetByDot(key, "each.key") ctx.SetByDot(idx, "each.key")
ctx.SetByDot(val, "each.value") ctx.SetByDot(val, "each.value")
ctx.Set(idx, block.TypeLabel(), "key")
ctx.Set(key, block.TypeLabel(), "key")
ctx.Set(val, block.TypeLabel(), "value") ctx.Set(val, block.TypeLabel(), "value")
forEachFiltered = append(forEachFiltered, clone) forEachFiltered = append(forEachFiltered, clone)
values := clone.Values() values := clone.Values()
clones[key.AsString()] = values clones[idx.AsString()] = values
e.ctx.SetByDot(values, clone.GetMetadata().Reference()) e.ctx.SetByDot(values, clone.GetMetadata().Reference())
}) })
@@ -341,7 +322,7 @@ func (e *evaluator) expandBlockCounts(blocks terraform.Blocks) terraform.Blocks
var countFiltered terraform.Blocks var countFiltered terraform.Blocks
for _, block := range blocks { for _, block := range blocks {
countAttr := block.GetAttribute("count") countAttr := block.GetAttribute("count")
if countAttr.IsNil() || block.IsCountExpanded() || !isBlockSupportsCountMetaArgument(block) { if countAttr.IsNil() || block.IsExpanded() || !isBlockSupportsCountMetaArgument(block) {
countFiltered = append(countFiltered, block) countFiltered = append(countFiltered, block)
continue continue
} }

View File

@@ -1,94 +0,0 @@
package parser
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zclconf/go-cty/cty"
)
func TestValidateForEachArg(t *testing.T) {
tests := []struct {
name string
arg cty.Value
expectedError string
}{
{
name: "empty set",
arg: cty.SetValEmpty(cty.String),
},
{
name: "set of strings",
arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.StringVal("val2")}),
},
{
name: "set of non-strings",
arg: cty.SetVal([]cty.Value{cty.NumberIntVal(1), cty.NumberIntVal(2)}),
expectedError: "is not set of strings",
},
{
name: "set with null",
arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.NullVal(cty.String)}),
expectedError: "arg is set of strings, but contains null",
},
{
name: "set with unknown",
arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.UnknownVal(cty.String)}),
expectedError: "arg is set of strings, but contains unknown",
},
{
name: "set with unknown",
arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.UnknownVal(cty.String)}),
expectedError: "arg is set of strings, but contains unknown",
},
{
name: "non empty map",
arg: cty.MapVal(map[string]cty.Value{
"val1": cty.StringVal("..."),
"val2": cty.StringVal("..."),
}),
},
{
name: "map with unknown",
arg: cty.MapVal(map[string]cty.Value{
"val1": cty.UnknownVal(cty.String),
"val2": cty.StringVal("..."),
}),
},
{
name: "empty obj",
arg: cty.EmptyObjectVal,
},
{
name: "obj with strings",
arg: cty.ObjectVal(map[string]cty.Value{
"val1": cty.StringVal("..."),
"val2": cty.StringVal("..."),
}),
},
{
name: "null",
arg: cty.NullVal(cty.Set(cty.String)),
expectedError: "arg is null",
},
{
name: "unknown",
arg: cty.UnknownVal(cty.Set(cty.String)),
},
{
name: "dynamic",
arg: cty.DynamicVal,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateForEachArg(tt.arg)
if tt.expectedError != "" && err != nil {
assert.ErrorContains(t, err, tt.expectedError)
return
}
assert.NoError(t, err)
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/aquasecurity/trivy/internal/testutil" "github.com/aquasecurity/trivy/internal/testutil"
"github.com/aquasecurity/trivy/pkg/iac/scanners/options" "github.com/aquasecurity/trivy/pkg/iac/scanners/options"
"github.com/aquasecurity/trivy/pkg/iac/terraform"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty"
@@ -904,6 +905,91 @@ data "http" "example" {
} }
func TestForEach(t *testing.T) { func TestForEach(t *testing.T) {
tests := []struct {
name string
src string
expectedBucketName string
expectedNameLabel string
}{
{
name: "arg is set and ref to each.key",
src: `locals {
buckets = ["bucket1"]
}
resource "aws_s3_bucket" "this" {
for_each = toset(local.buckets)
bucket = each.key
}`,
expectedBucketName: "bucket1",
expectedNameLabel: `this["bucket1"]`,
},
{
name: "arg is set and ref to each.value",
src: `locals {
buckets = ["bucket1"]
}
resource "aws_s3_bucket" "this" {
for_each = toset(local.buckets)
bucket = each.value
}`,
expectedBucketName: "bucket1",
expectedNameLabel: `this["bucket1"]`,
},
{
name: "arg is map and ref to each.key",
src: `locals {
buckets = {
bucket1key = "bucket1value"
}
}
resource "aws_s3_bucket" "this" {
for_each = local.buckets
bucket = each.key
}`,
expectedBucketName: "bucket1key",
expectedNameLabel: `this["bucket1key"]`,
},
{
name: "arg is map and ref to each.value",
src: `locals {
buckets = {
bucket1key = "bucket1value"
}
}
resource "aws_s3_bucket" "this" {
for_each = local.buckets
bucket = each.value
}`,
expectedBucketName: "bucket1value",
expectedNameLabel: `this["bucket1key"]`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
modules := parse(t, map[string]string{
"main.tf": tt.src,
})
require.Len(t, modules, 1)
buckets := modules.GetResourcesByType("aws_s3_bucket")
assert.Len(t, buckets, 1)
bucket := buckets[0]
bucketName := bucket.GetAttribute("bucket").Value().AsString()
assert.Equal(t, tt.expectedBucketName, bucketName)
assert.Equal(t, tt.expectedNameLabel, bucket.NameLabel())
})
}
}
func TestForEachCountExpanded(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -919,6 +1005,18 @@ func TestForEach(t *testing.T) {
resource "aws_s3_bucket" "this" { resource "aws_s3_bucket" "this" {
for_each = local.buckets for_each = local.buckets
bucket = each.key bucket = each.key
}`,
expectedCount: 2,
},
{
name: "arg is empty list",
source: `locals {
buckets = []
}
resource "aws_s3_bucket" "this" {
for_each = local.buckets
bucket = each.value
}`, }`,
expectedCount: 0, expectedCount: 0,
}, },
@@ -929,8 +1027,34 @@ resource "aws_s3_bucket" "this" {
} }
resource "aws_s3_bucket" "this" { resource "aws_s3_bucket" "this" {
for_each = loca.buckets for_each = local.buckets
bucket = each.key bucket = each.key
}`,
expectedCount: 0,
},
{
name: "argument set with the same values",
source: `locals {
buckets = ["true", "true"]
}
resource "aws_s3_bucket" "this" {
for_each = toset(local.buckets)
bucket = each.key
}`,
expectedCount: 1,
},
{
name: "arg is non-valid set",
source: `locals {
buckets = [{
bucket1key = "bucket1value"
}]
}
resource "aws_s3_bucket" "this" {
for_each = toset(local.buckets)
bucket = each.value
}`, }`,
expectedCount: 0, expectedCount: 0,
}, },
@@ -961,18 +1085,25 @@ resource "aws_s3_bucket" "this" {
}`, }`,
expectedCount: 2, expectedCount: 2,
}, },
{
name: "arg is empty map",
source: `locals {
buckets = {}
}
resource "aws_s3_bucket" "this" {
for_each = local.buckets
bucket = each.value
}
`,
expectedCount: 0,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
fs := testutil.CreateFS(t, map[string]string{ modules := parse(t, map[string]string{
"main.tf": tt.source, "main.tf": tt.source,
}) })
parser := New(fs, "", OptionStopOnHCLError(true))
require.NoError(t, parser.ParseFS(context.TODO(), "."))
modules, _, err := parser.EvaluateAll(context.TODO())
assert.NoError(t, err)
assert.Len(t, modules, 1) assert.Len(t, modules, 1)
bucketBlocks := modules.GetResourcesByType("aws_s3_bucket") bucketBlocks := modules.GetResourcesByType("aws_s3_bucket")
@@ -1139,3 +1270,165 @@ func TestForEachWithObjectsOfDifferentTypes(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, modules, 1) assert.Len(t, modules, 1)
} }
func TestDynamicBlocks(t *testing.T) {
t.Run("arg is list of int", func(t *testing.T) {
modules := parse(t, map[string]string{
"main.tf": `
resource "aws_security_group" "sg-webserver" {
vpc_id = "1111"
dynamic "ingress" {
for_each = [80, 443]
content {
from_port = ingress.value
to_port = ingress.value
protocol = "tcp"
cidr_blocks = ["0.0.0.0/0"]
}
}
}
`,
})
require.Len(t, modules, 1)
secGroups := modules.GetResourcesByType("aws_security_group")
assert.Len(t, secGroups, 1)
ingressBlocks := secGroups[0].GetBlocks("ingress")
assert.Len(t, ingressBlocks, 2)
var inboundPorts []int
for _, ingress := range ingressBlocks {
fromPort := ingress.GetAttribute("from_port").AsIntValueOrDefault(-1, ingress).Value()
inboundPorts = append(inboundPorts, fromPort)
}
assert.True(t, compareSets([]int{80, 443}, inboundPorts))
})
t.Run("empty for-each", func(t *testing.T) {
modules := parse(t, map[string]string{
"main.tf": `
resource "aws_lambda_function" "analyzer" {
dynamic "vpc_config" {
for_each = []
content {}
}
}
`,
})
require.Len(t, modules, 1)
functions := modules.GetResourcesByType("aws_lambda_function")
assert.Len(t, functions, 1)
vpcConfigs := functions[0].GetBlocks("vpc_config")
assert.Empty(t, vpcConfigs)
})
t.Run("arg is list of bool", func(t *testing.T) {
modules := parse(t, map[string]string{
"main.tf": `
resource "aws_lambda_function" "analyzer" {
dynamic "vpc_config" {
for_each = [true]
content {}
}
}
`,
})
require.Len(t, modules, 1)
functions := modules.GetResourcesByType("aws_lambda_function")
assert.Len(t, functions, 1)
vpcConfigs := functions[0].GetBlocks("vpc_config")
assert.Len(t, vpcConfigs, 1)
})
t.Run("arg is list of objects", func(t *testing.T) {
modules := parse(t, map[string]string{
"main.tf": `locals {
cluster_network_policy = [{
enabled = true
}]
}
resource "google_container_cluster" "primary" {
name = "test"
dynamic "network_policy" {
for_each = local.cluster_network_policy
content {
enabled = network_policy.value.enabled
}
}
}`,
})
require.Len(t, modules, 1)
clusters := modules.GetResourcesByType("google_container_cluster")
assert.Len(t, clusters, 1)
networkPolicies := clusters[0].GetBlocks("network_policy")
assert.Len(t, networkPolicies, 1)
enabled := networkPolicies[0].GetAttribute("enabled")
assert.True(t, enabled.Value().True())
})
t.Run("nested dynamic", func(t *testing.T) {
modules := parse(t, map[string]string{
"main.tf": `
resource "test_block" "this" {
name = "name"
location = "loc"
dynamic "env" {
for_each = ["1", "2"]
content {
dynamic "value_source" {
for_each = [true, true]
content {}
}
}
}
}`,
})
require.Len(t, modules, 1)
testResources := modules.GetResourcesByType("test_block")
assert.Len(t, testResources, 1)
envs := testResources[0].GetBlocks("env")
assert.Len(t, envs, 2)
var sources []*terraform.Block
for _, env := range envs {
sources = append(sources, env.GetBlocks("value_source")...)
}
assert.Len(t, sources, 4)
})
}
func parse(t *testing.T, files map[string]string) terraform.Modules {
fs := testutil.CreateFS(t, files)
parser := New(fs, "", OptionStopOnHCLError(true))
require.NoError(t, parser.ParseFS(context.TODO(), "."))
modules, _, err := parser.EvaluateAll(context.TODO())
require.NoError(t, err)
return modules
}
func compareSets(a []int, b []int) bool {
m := make(map[int]bool)
for _, el := range a {
m[el] = true
}
for _, el := range b {
if !m[el] {
return false
}
}
return true
}

View File

@@ -145,11 +145,11 @@ func (b *Block) InjectBlock(block *Block, name string) {
b.childBlocks = append(b.childBlocks, block) b.childBlocks = append(b.childBlocks, block)
} }
func (b *Block) MarkCountExpanded() { func (b *Block) MarkExpanded() {
b.expanded = true b.expanded = true
} }
func (b *Block) IsCountExpanded() bool { func (b *Block) IsExpanded() bool {
return b.expanded return b.expanded
} }
@@ -187,7 +187,7 @@ func (b *Block) Clone(index cty.Value) *Block {
} }
indexVal, _ := gocty.ToCtyValue(index, cty.Number) indexVal, _ := gocty.ToCtyValue(index, cty.Number)
clone.context.SetByDot(indexVal, "count.index") clone.context.SetByDot(indexVal, "count.index")
clone.MarkCountExpanded() clone.MarkExpanded()
b.cloneIndex++ b.cloneIndex++
return clone return clone
} }