mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-22 23:26:39 -08:00
189 lines
5.3 KiB
Go
189 lines
5.3 KiB
Go
package commands
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/aquasecurity/defsec/pkg/errs"
|
|
|
|
cmd "github.com/aquasecurity/trivy/pkg/commands/artifact"
|
|
|
|
"github.com/aquasecurity/trivy/pkg/cloud"
|
|
|
|
"github.com/aquasecurity/trivy/pkg/cloud/cache"
|
|
|
|
"github.com/aquasecurity/trivy/pkg/flag"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/service/sts"
|
|
|
|
"github.com/aquasecurity/trivy/pkg/cloud/aws/scanner"
|
|
"github.com/aquasecurity/trivy/pkg/cloud/report"
|
|
|
|
"github.com/aquasecurity/trivy/pkg/log"
|
|
|
|
awsScanner "github.com/aquasecurity/defsec/pkg/scanners/cloud/aws"
|
|
)
|
|
|
|
func getAccountIDAndRegion(ctx context.Context, region string) (string, string, error) {
|
|
log.Logger.Debug("Looking for AWS credentials provider...")
|
|
|
|
cfg, err := config.LoadDefaultConfig(context.TODO())
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
if region != "" {
|
|
cfg.Region = region
|
|
}
|
|
|
|
svc := sts.NewFromConfig(cfg)
|
|
|
|
log.Logger.Debug("Looking up AWS caller identity...")
|
|
result, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to discover AWS caller identity: %w", err)
|
|
}
|
|
if result.Account == nil {
|
|
return "", "", fmt.Errorf("missing account id for aws account")
|
|
}
|
|
log.Logger.Debugf("Verified AWS credentials for account %s!", *result.Account)
|
|
return *result.Account, cfg.Region, nil
|
|
}
|
|
|
|
func processOptions(ctx context.Context, opt *flag.Options) error {
|
|
// support comma separated services too
|
|
var splitServices []string
|
|
for _, service := range opt.Services {
|
|
splitServices = append(splitServices, strings.Split(service, ",")...)
|
|
}
|
|
opt.Services = splitServices
|
|
|
|
if len(opt.Services) != 1 && opt.ARN != "" {
|
|
return fmt.Errorf("you must specify the single --service which the --arn relates to")
|
|
}
|
|
|
|
if opt.Account == "" || opt.Region == "" {
|
|
var err error
|
|
opt.Account, opt.Region, err = getAccountIDAndRegion(ctx, opt.Region)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if len(opt.Services) == 0 {
|
|
log.Logger.Debug("No service(s) specified, scanning all services...")
|
|
opt.Services = awsScanner.AllSupportedServices()
|
|
} else {
|
|
log.Logger.Debugf("Specific services were requested: [%s]...", strings.Join(opt.Services, ", "))
|
|
for _, service := range opt.Services {
|
|
var found bool
|
|
supported := awsScanner.AllSupportedServices()
|
|
for _, allowed := range supported {
|
|
if allowed == service {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("service '%s' is not currently supported - supported services are: %s", service, strings.Join(supported, ", "))
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func Run(ctx context.Context, opt flag.Options) error {
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, opt.GlobalOptions.Timeout)
|
|
defer cancel()
|
|
|
|
if err := log.InitLogger(opt.Debug, false); err != nil {
|
|
return fmt.Errorf("logger error: %w", err)
|
|
}
|
|
|
|
var err error
|
|
defer func() {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
log.Logger.Warn("Increase --timeout value")
|
|
}
|
|
}()
|
|
|
|
if err := processOptions(ctx, &opt); err != nil {
|
|
return err
|
|
}
|
|
|
|
cached := cache.New(opt.CacheDir, opt.MaxCacheAge, cloud.ProviderAWS, opt.Account, opt.Region)
|
|
servicesInCache := cached.ListAvailableServices(false)
|
|
var servicesToLoadFromCache []string
|
|
var servicesToScan []string
|
|
for _, service := range opt.Services {
|
|
if cached != nil {
|
|
var inCache bool
|
|
for _, cacheSvc := range servicesInCache {
|
|
if cacheSvc == service {
|
|
log.Logger.Debugf("Results for service '%s' found in cache.", service)
|
|
inCache = true
|
|
break
|
|
}
|
|
}
|
|
if inCache && !opt.UpdateCache {
|
|
servicesToLoadFromCache = append(servicesToLoadFromCache, service)
|
|
continue
|
|
}
|
|
}
|
|
servicesToScan = append(servicesToScan, service)
|
|
}
|
|
|
|
var r *report.Report
|
|
|
|
// if there is anything we need that wasn't in the cache, scan it now
|
|
if len(servicesToScan) > 0 {
|
|
log.Logger.Debugf("Scanning the following services using the AWS API: [%s]...", strings.Join(servicesToScan, ", "))
|
|
opt.Services = servicesToScan
|
|
results, err := scanner.NewScanner().Scan(ctx, opt)
|
|
if err != nil {
|
|
var aerr errs.AdapterError
|
|
if errors.As(err, &aerr) {
|
|
for _, e := range aerr.Errors() {
|
|
log.Logger.Warnf("Adapter error: %s", e)
|
|
}
|
|
} else {
|
|
return fmt.Errorf("aws scan error: %w", err)
|
|
}
|
|
}
|
|
r = report.New(cloud.ProviderAWS, opt.Account, opt.Region, results.GetFailed(), opt.Services)
|
|
} else {
|
|
log.Logger.Debug("No more services to scan - everything was found in the cache.")
|
|
r = report.New(cloud.ProviderAWS, opt.Account, opt.Region, nil, opt.Services)
|
|
}
|
|
if len(servicesToLoadFromCache) > 0 {
|
|
log.Logger.Debug("Loading cached results...")
|
|
cachedReport, err := cached.LoadReport(servicesToLoadFromCache...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for service, results := range cachedReport.Results {
|
|
log.Logger.Debugf("Adding cached results for '%s'...", service)
|
|
r.AddResultsForService(service, results.Results, results.CreationTime)
|
|
}
|
|
}
|
|
|
|
if len(servicesToScan) > 0 { // don't write cache if we didn't scan anything new
|
|
log.Logger.Debugf("Writing results to cache for services [%s]...", strings.Join(r.ServicesInScope, ", "))
|
|
if err := cached.Save(r); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
log.Logger.Debug("Writing report to output...")
|
|
if err := report.Write(r, opt, len(servicesToLoadFromCache) > 0); err != nil {
|
|
return fmt.Errorf("unable to write results: %w", err)
|
|
}
|
|
|
|
cmd.Exit(opt, r.Failed())
|
|
return nil
|
|
}
|