refactor(misconf): get a block or attribute without calling HasChild (#8586)

Signed-off-by: nikpivkin <nikita.pivkin@smartforce.io>
This commit is contained in:
Nikita Pivkin
2025-03-22 08:48:34 +06:00
committed by GitHub
parent ba77dbe5f9
commit f07030daf2
13 changed files with 83 additions and 197 deletions

View File

@@ -73,12 +73,10 @@ func adaptCluster(resource *terraform.Block) msk.Cluster {
cluster.Logging.Metadata = logBlock.GetMetadata()
if brokerLogsBlock := logBlock.GetBlock("broker_logs"); brokerLogsBlock.IsNotNil() {
cluster.Logging.Broker.Metadata = brokerLogsBlock.GetMetadata()
if brokerLogsBlock.HasChild("s3") {
if s3Block := brokerLogsBlock.GetBlock("s3"); s3Block.IsNotNil() {
s3enabledAttr := s3Block.GetAttribute("enabled")
cluster.Logging.Broker.S3.Metadata = s3Block.GetMetadata()
cluster.Logging.Broker.S3.Enabled = s3enabledAttr.AsBoolValueOrDefault(false, s3Block)
}
if s3Block := brokerLogsBlock.GetBlock("s3"); s3Block.IsNotNil() {
s3enabledAttr := s3Block.GetAttribute("enabled")
cluster.Logging.Broker.S3.Metadata = s3Block.GetMetadata()
cluster.Logging.Broker.S3.Enabled = s3enabledAttr.AsBoolValueOrDefault(false, s3Block)
}
if cloudwatchBlock := brokerLogsBlock.GetBlock("cloudwatch_logs"); cloudwatchBlock.IsNotNil() {
cwEnabledAttr := cloudwatchBlock.GetAttribute("enabled")

View File

@@ -29,9 +29,9 @@ func adaptCompute(modules terraform.Modules) compute.Compute {
windowsVirtualMachines = append(windowsVirtualMachines, adaptWindowsVM(resource))
}
for _, resource := range module.GetResourcesByType(AzureVirtualMachine) {
if resource.HasChild("os_profile_linux_config") {
if linuxConfigBlock := resource.GetBlock("os_profile_linux_config"); linuxConfigBlock.IsNotNil() {
linuxVirtualMachines = append(linuxVirtualMachines, adaptLinuxVM(resource))
} else if resource.HasChild("os_profile_windows_config") {
} else if windowsConfigBlock := resource.GetBlock("os_profile_windows_config"); windowsConfigBlock.IsNotNil() {
windowsVirtualMachines = append(windowsVirtualMachines, adaptWindowsVM(resource))
}
}

View File

@@ -79,28 +79,27 @@ func adaptCluster(resource *terraform.Block) container.KubernetesCluster {
}
// azurerm < 2.99.0
if resource.HasChild("role_based_access_control") {
roleBasedAccessControlBlock := resource.GetBlock("role_based_access_control")
rbEnabledAttr := roleBasedAccessControlBlock.GetAttribute("enabled")
cluster.RoleBasedAccessControl.Metadata = roleBasedAccessControlBlock.GetMetadata()
cluster.RoleBasedAccessControl.Enabled = rbEnabledAttr.AsBoolValueOrDefault(false, roleBasedAccessControlBlock)
}
if resource.HasChild("role_based_access_control_enabled") {
// azurerm >= 2.99.0
roleBasedAccessControlEnabledAttr := resource.GetAttribute("role_based_access_control_enabled")
cluster.RoleBasedAccessControl.Metadata = roleBasedAccessControlEnabledAttr.GetMetadata()
cluster.RoleBasedAccessControl.Enabled = roleBasedAccessControlEnabledAttr.AsBoolValueOrDefault(false, resource)
if rbacBlock := resource.GetBlock("role_based_access_control"); rbacBlock.IsNotNil() {
rbEnabledAttr := rbacBlock.GetAttribute("enabled")
cluster.RoleBasedAccessControl.Metadata = rbacBlock.GetMetadata()
cluster.RoleBasedAccessControl.Enabled = rbEnabledAttr.AsBoolValueOrDefault(false, rbacBlock)
}
if resource.HasChild("azure_active_directory_role_based_access_control") {
azureRoleBasedAccessControl := resource.GetBlock("azure_active_directory_role_based_access_control")
if azureRoleBasedAccessControl.IsNotNil() {
enabledAttr := azureRoleBasedAccessControl.GetAttribute("azure_rbac_enabled")
if rbacEnabledAttr := resource.GetAttribute("role_based_access_control_enabled"); rbacEnabledAttr.IsNotNil() {
// azurerm >= 2.99.0
cluster.RoleBasedAccessControl.Metadata = rbacEnabledAttr.GetMetadata()
cluster.RoleBasedAccessControl.Enabled = rbacEnabledAttr.AsBoolValueOrDefault(false, resource)
}
if block := resource.GetBlock("azure_active_directory_role_based_access_control"); block.IsNotNil() {
enabledAttr := block.GetAttribute("azure_rbac_enabled")
if enabledAttr.IsNotNil() {
if !cluster.RoleBasedAccessControl.Enabled.IsTrue() {
cluster.RoleBasedAccessControl.Metadata = azureRoleBasedAccessControl.GetMetadata()
cluster.RoleBasedAccessControl.Enabled = enabledAttr.AsBoolValueOrDefault(false, azureRoleBasedAccessControl)
cluster.RoleBasedAccessControl.Metadata = block.GetMetadata()
cluster.RoleBasedAccessControl.Enabled = enabledAttr.AsBoolValueOrDefault(false, block)
}
}
}
return cluster
}

View File

@@ -244,9 +244,7 @@ func (a *mssqlAdapter) adaptMSSQLServer(resource *terraform.Block, module *terra
}
auditingPoliciesBlocks := module.GetReferencingResources(resource, "azurerm_mssql_server_extended_auditing_policy", "server_id")
if resource.HasChild("extended_auditing_policy") {
auditingPoliciesBlocks = append(auditingPoliciesBlocks, resource.GetBlocks("extended_auditing_policy")...)
}
auditingPoliciesBlocks = append(auditingPoliciesBlocks, resource.GetBlocks("extended_auditing_policy")...)
databasesRes := module.GetReferencingResources(resource, "azurerm_mssql_database", "server_id")
for _, databaseRes := range databasesRes {

View File

@@ -162,8 +162,7 @@ func adaptNetworkRule(resource *terraform.Block) storage.NetworkRule {
allowByDefault = iacTypes.BoolDefault(false, resource.GetMetadata())
}
if resource.HasChild("bypass") {
bypassAttr := resource.GetAttribute("bypass")
if bypassAttr := resource.GetAttribute("bypass"); bypassAttr.IsNotNil() {
bypass = bypassAttr.AsStringValues()
}

View File

@@ -176,18 +176,17 @@ func (a *adapter) adaptNodePool(resource *terraform.Block) {
EnableAutoUpgrade: iacTypes.BoolDefault(false, resource.GetMetadata()),
}
if resource.HasChild("management") {
management.Metadata = resource.GetBlock("management").GetMetadata()
if managementBlock := resource.GetBlock("management"); managementBlock.IsNotNil() {
management.Metadata = managementBlock.GetMetadata()
autoRepairAttr := managementBlock.GetAttribute("auto_repair")
management.EnableAutoRepair = autoRepairAttr.AsBoolValueOrDefault(false, managementBlock)
autoRepairAttr := resource.GetBlock("management").GetAttribute("auto_repair")
management.EnableAutoRepair = autoRepairAttr.AsBoolValueOrDefault(false, resource.GetBlock("management"))
autoUpgradeAttr := resource.GetBlock("management").GetAttribute("auto_upgrade")
management.EnableAutoUpgrade = autoUpgradeAttr.AsBoolValueOrDefault(false, resource.GetBlock("management"))
autoUpgradeAttr := managementBlock.GetAttribute("auto_upgrade")
management.EnableAutoUpgrade = autoUpgradeAttr.AsBoolValueOrDefault(false, managementBlock)
}
if resource.HasChild("node_config") {
nodeConfig = adaptNodeConfig(resource.GetBlock("node_config"))
if nodeConfigBlock := resource.GetBlock("node_config"); nodeConfigBlock.IsNotNil() {
nodeConfig = adaptNodeConfig(nodeConfigBlock)
}
nodePool := gke.NodePool{
@@ -296,10 +295,10 @@ func adaptMasterAuth(resource *terraform.Block) gke.MasterAuth {
IssueCertificate: iacTypes.BoolDefault(false, resource.GetMetadata()),
}
if resource.HasChild("client_certificate_config") {
clientCertAttr := resource.GetBlock("client_certificate_config").GetAttribute("issue_client_certificate")
clientCert.IssueCertificate = clientCertAttr.AsBoolValueOrDefault(false, resource.GetBlock("client_certificate_config"))
clientCert.Metadata = resource.GetBlock("client_certificate_config").GetMetadata()
if certConfigBlock := resource.GetBlock("client_certificate_config"); certConfigBlock.IsNotNil() {
clientCertAttr := certConfigBlock.GetAttribute("issue_client_certificate")
clientCert.IssueCertificate = clientCertAttr.AsBoolValueOrDefault(false, certConfigBlock)
clientCert.Metadata = certConfigBlock.GetMetadata()
}
username := resource.GetAttribute("username").AsStringValueOrDefault("", resource)

View File

@@ -73,8 +73,8 @@ func adaptInstance(resource *terraform.Block) sql.DatabaseInstance {
backupConfigEnabledAttr := backupBlock.GetAttribute("enabled")
instance.Settings.Backups.Enabled = backupConfigEnabledAttr.AsBoolValueOrDefault(false, backupBlock)
}
if settingsBlock.HasChild("ip_configuration") {
instance.Settings.IPConfiguration = adaptIPConfig(settingsBlock.GetBlock("ip_configuration"))
if ipConfBlock := settingsBlock.GetBlock("ip_configuration"); ipConfBlock.IsNotNil() {
instance.Settings.IPConfiguration = adaptIPConfig(ipConfBlock)
}
}
return instance

View File

@@ -84,6 +84,34 @@ func Test_Adapt(t *testing.T) {
},
},
},
{
name: "wrong ip_configuration",
terraform: `
resource "google_sql_database_instance" "test" {
settings {
ip_configuration = []
}
}
`,
expected: sql.SQL{
Instances: []sql.DatabaseInstance{
{
Settings: sql.Settings{
Flags: sql.Flags{
ContainedDatabaseAuthentication: iacTypes.BoolTest(true),
CrossDBOwnershipChaining: iacTypes.BoolTest(true),
LogMinDurationStatement: iacTypes.IntTest(-1),
LogTempFileSize: iacTypes.IntTest(-1),
},
IPConfiguration: sql.IPConfiguration{
EnableIPv4: iacTypes.BoolTest(true),
},
},
},
},
},
},
}
for _, test := range tests {

View File

@@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/iac/terraform"
)
@@ -59,10 +60,8 @@ resource "aws_s3_bucket" "my-bucket" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.StartsWith(test.checkValue))
}
}
@@ -121,10 +120,8 @@ resource "aws_s3_bucket" "my-bucket" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.EndsWith(test.checkValue))
}
}
@@ -277,10 +274,8 @@ resource "aws_security_group" "my-security_group" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, b := range module.GetBlocks() {
if !b.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := b.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
if test.ignoreCase {
assert.Equal(t, test.expectedResult, attr.Contains(test.checkValue, terraform.IgnoreCase))
} else {
@@ -339,10 +334,8 @@ resource "aws_security_group" "my-security_group" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.IsAny(test.checkValue...))
}
}
@@ -397,10 +390,8 @@ resource "aws_security_group" "my-security_group" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.IsNone(test.checkValue...))
}
}
@@ -509,10 +500,8 @@ resource "aws_security_group_rule" "example" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.IsEmpty())
}
}
@@ -554,10 +543,8 @@ resource "numerical_something" "my-bucket" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.LessThan(test.checkValue))
}
}
@@ -599,10 +586,8 @@ resource "numerical_something" "my-bucket" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.LessThanOrEqualTo(test.checkValue))
}
}
@@ -650,10 +635,8 @@ resource "boolean_something" "my-something" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.IsTrue())
}
}
@@ -701,10 +684,8 @@ resource "boolean_something" "my-something" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
if !block.HasChild(test.checkAttribute) {
t.FailNow()
}
attr := block.GetAttribute(test.checkAttribute)
require.NotNil(t, attr)
assert.Equal(t, test.expectedResult, attr.IsFalse())
}
}

View File

@@ -29,19 +29,6 @@ resource "aws_s3_bucket" "my-bucket" {
}`,
expectedAttribute: "acl",
},
{
name: "expected acl attribute is present",
source: `
resource "aws_s3_bucket" "my-bucket" {
bucket_name = "bucketName"
acl = "public-read"
logging {
target_bucket = aws_s3_bucket.log_bucket.id
target_prefix = "log/"
}
}`,
expectedAttribute: "logging",
},
}
for _, test := range tests {
@@ -49,8 +36,7 @@ resource "aws_s3_bucket" "my-bucket" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
assert.True(t, block.HasChild(test.expectedAttribute))
assert.True(t, block.HasChild(test.expectedAttribute))
assert.NotNil(t, block.GetAttribute(test.expectedAttribute))
}
}
})
@@ -89,48 +75,7 @@ resource "aws_s3_bucket" "my-bucket" {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
assert.False(t, block.HasChild(test.expectedAttribute))
assert.False(t, block.HasChild(test.expectedAttribute))
}
}
})
}
}
func Test_MissingChildNotFoundOnBlock(t *testing.T) {
var tests = []struct {
name string
source string
expectedAttribute string
}{
{
name: "expected attribute is not present",
source: `
resource "aws_s3_bucket" "my-bucket" {
bucket_name = "bucketName"
}`,
expectedAttribute: "acl",
},
{
name: "expected acl attribute is not present",
source: `
resource "aws_s3_bucket" "my-bucket" {
bucket_name = "bucketName"
acl = "public-read"
}`,
expectedAttribute: "logging",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
modules := createModulesFromSource(t, test.source, ".tf")
for _, module := range modules {
for _, block := range module.GetBlocks() {
assert.True(t, block.MissingChild(test.expectedAttribute))
assert.False(t, block.HasChild(test.expectedAttribute))
assert.Nil(t, block.GetAttribute(test.expectedAttribute))
}
}
})

View File

@@ -963,8 +963,10 @@ func (a *Attribute) AllReferences(blocks ...*Block) []*Reference {
refs := a.extractReferences()
for _, block := range blocks {
for _, ref := range refs {
if ref.TypeLabel() == "each" && block.HasChild("for_each") {
refs = append(refs, block.GetAttribute("for_each").AllReferences()...)
if ref.TypeLabel() == "each" {
if forEachAttr := block.GetAttribute("for_each"); forEachAttr.IsNotNil() {
refs = append(refs, forEachAttr.AllReferences()...)
}
}
}
}

View File

@@ -517,39 +517,6 @@ func (b *Block) NameLabel() string {
return ""
}
func (b *Block) HasChild(childElement string) bool {
return b.GetAttribute(childElement).IsNotNil() || b.GetBlock(childElement).IsNotNil()
}
func (b *Block) MissingChild(childElement string) bool {
if b == nil {
return true
}
return !b.HasChild(childElement)
}
func (b *Block) MissingNestedChild(name string) bool {
if b == nil {
return true
}
parts := strings.Split(name, ".")
blocks := parts[:len(parts)-1]
last := parts[len(parts)-1]
working := b
for _, subBlock := range blocks {
if checkBlock := working.GetBlock(subBlock); checkBlock == nil {
return true
} else {
working = checkBlock
}
}
return !working.HasChild(last)
}
func (b *Block) InModule() bool {
if b == nil {
return false

View File

@@ -2,7 +2,6 @@ package terraform
import (
"fmt"
"strings"
"github.com/aquasecurity/trivy/pkg/iac/ignore"
)
@@ -107,23 +106,6 @@ func (c *Module) GetDatasByType(label string) Blocks {
return c.getBlocksByType("data", label)
}
func (c *Module) GetProviderBlocksByProvider(providerName, alias string) Blocks {
var results Blocks
for _, block := range c.blocks {
if block.Type() == "provider" && len(block.Labels()) > 0 && block.TypeLabel() == providerName {
if alias != "" {
if block.HasChild("alias") && block.GetAttribute("alias").Equals(strings.ReplaceAll(alias, fmt.Sprintf("%s.", providerName), "")) {
results = append(results, block)
}
} else if block.MissingChild("alias") {
results = append(results, block)
}
}
}
return results
}
func (c *Module) GetReferencedBlock(referringAttr *Attribute, parentBlock *Block) (*Block, error) {
for _, ref := range referringAttr.AllReferences() {
if ref.TypeLabel() == "each" {
@@ -159,18 +141,6 @@ func (c *Module) GetReferencingResources(originalBlock *Block, referencingLabel,
return c.GetReferencingBlocks(originalBlock, "resource", referencingLabel, referencingAttributeName)
}
func (c *Module) GetsModulesBySource(moduleSource string) (Blocks, error) {
var results Blocks
modules := c.getModuleBlocks()
for _, module := range modules {
if module.HasChild("source") && module.GetAttribute("source").Equals(moduleSource) {
results = append(results, module)
}
}
return results, nil
}
func (c *Module) GetReferencingBlocks(originalBlock *Block, referencingType, referencingLabel, referencingAttributeName string) Blocks {
blocks := c.getBlocksByType(referencingType, referencingLabel)
var results Blocks