Files
trivy/pkg/cache/redis.go
2025-10-03 09:37:05 +00:00

249 lines
7.2 KiB
Go

package cache
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/hashicorp/go-multierror"
"github.com/samber/lo"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
)
// Compile-time assertion that *RedisCache satisfies the Cache interface.
var _ Cache = (*RedisCache)(nil)
// redisPrefix is the leading segment of every key this file reads or writes in Redis;
// keys are built as "<redisPrefix>::<bucket>::<id>".
const redisPrefix = "fanal"
// RedisOptions holds the settings used to connect to the Redis cache backend.
type RedisOptions struct {
// Backend is the Redis connection URL; NewRedisCache parses it with redis.ParseURL.
Backend string
// TLS requests a TLS connection even when no certificate files are configured.
TLS bool
// TLSOptions carries the CA, certificate and key file paths used for TLS.
TLSOptions RedisTLSOptions
}
// NewRedisOptions assembles a RedisOptions value from the backend URL, the TLS
// file paths and the TLS switch. It returns an error when the TLS file paths
// are rejected by NewRedisTLSOptions.
func NewRedisOptions(backend, caCert, cert, key string, enableTLS bool) (RedisOptions, error) {
	var result RedisOptions

	tlsOptions, err := NewRedisTLSOptions(caCert, cert, key)
	if err != nil {
		return result, xerrors.Errorf("redis TLS option error: %w", err)
	}

	result.Backend = backend
	result.TLS = enableTLS
	result.TLSOptions = tlsOptions
	return result, nil
}
// BackendMasked returns the Redis connection string with the credentials
// replaced by "****", so that it is safe to log.
//
// Everything between the scheme separator ("//") and the last "@" is treated
// as credentials. The last "@" is used because a password may itself contain
// an unescaped "@"; cutting at the first one would leak the password's tail.
// When no "//" precedes the "@", the whole part before the "@" is masked.
// A backend without "@" carries no credentials and is returned unchanged.
func (o *RedisOptions) BackendMasked() string {
	at := strings.LastIndex(o.Backend, "@")
	if at == -1 {
		return o.Backend
	}

	// Only a "//" located before the "@" marks the end of the scheme; one that
	// appears later belongs to the host/path part and must not be used as the
	// prefix boundary, otherwise the credentials would be copied to the output.
	prefixEnd := 0
	if sep := strings.Index(o.Backend[:at], "//"); sep != -1 {
		prefixEnd = sep + len("//")
	}

	return fmt.Sprintf("%s****%s", o.Backend[:prefixEnd], o.Backend[at:])
}
// RedisTLSOptions holds the file paths of the TLS material used when connecting to Redis.
type RedisTLSOptions struct {
// CACert is the path to the CA certificate file.
CACert string
// Cert is the path to the client certificate file.
Cert string
// Key is the path to the client private key file.
Key string
}
// NewRedisTLSOptions builds RedisTLSOptions from the given file paths.
//
// The three paths are all-or-nothing: leaving every one empty is valid and
// yields empty options, but supplying only some of them is an error.
func NewRedisTLSOptions(caCert, cert, key string) (RedisTLSOptions, error) {
	tlsOpts := RedisTLSOptions{CACert: caCert, Cert: cert, Key: key}

	// Nothing was configured, so there is nothing to validate.
	if lo.IsEmpty(tlsOpts) {
		return tlsOpts, nil
	}

	complete := tlsOpts.CACert != "" && tlsOpts.Cert != "" && tlsOpts.Key != ""
	if !complete {
		return RedisTLSOptions{}, xerrors.Errorf("you must provide Redis CA, cert and key file path when using TLS")
	}
	return tlsOpts, nil
}
// RedisCache stores artifact and blob information in Redis.
type RedisCache struct {
// client is the go-redis client used for every cache operation.
client *redis.Client
// expiration is the TTL passed to Redis whenever an entry is written.
expiration time.Duration
}
// NewRedisCache creates a RedisCache for the given backend URL.
//
// When CA, certificate and key paths are supplied they are loaded and used for
// the TLS configuration; otherwise, if enableTLS is set, a TLS configuration
// without client certificates is used. ttl is the expiration applied to every
// entry written to the cache. The backend URL is logged with its credentials masked.
func NewRedisCache(backend, caCertPath, certPath, keyPath string, enableTLS bool, ttl time.Duration) (RedisCache, error) {
	redisOpts, err := NewRedisOptions(backend, caCertPath, certPath, keyPath, enableTLS)
	if err != nil {
		return RedisCache{}, xerrors.Errorf("failed to create Redis options: %w", err)
	}

	log.Info("Redis scan cache", log.String("url", redisOpts.BackendMasked()))

	clientOpts, err := redis.ParseURL(redisOpts.Backend)
	if err != nil {
		return RedisCache{}, xerrors.Errorf("failed to parse Redis URL: %w", err)
	}

	switch tlsFiles := redisOpts.TLSOptions; {
	case !lo.IsEmpty(tlsFiles):
		// Explicit certificate files take precedence over the plain TLS switch.
		rootCAs, keyPair, err := GetTLSConfig(tlsFiles.CACert, tlsFiles.Cert, tlsFiles.Key)
		if err != nil {
			return RedisCache{}, xerrors.Errorf("failed to get TLS config: %w", err)
		}
		clientOpts.TLSConfig = &tls.Config{
			RootCAs:      rootCAs,
			Certificates: []tls.Certificate{keyPair},
			MinVersion:   tls.VersionTLS12,
		}
	case redisOpts.TLS:
		clientOpts.TLSConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
		}
	}

	return RedisCache{
		client:     redis.NewClient(clientOpts),
		expiration: ttl,
	}, nil
}
// PutArtifact serializes the artifact info to JSON and stores it under the
// artifact's key with the cache's configured expiration.
func (c RedisCache) PutArtifact(ctx context.Context, artifactID string, artifactConfig types.ArtifactInfo) error {
	redisKey := fmt.Sprintf("%s::%s::%s", redisPrefix, artifactBucket, artifactID)

	payload, err := json.Marshal(artifactConfig)
	if err != nil {
		return xerrors.Errorf("failed to marshal artifact JSON: %w", err)
	}

	status := c.client.Set(ctx, redisKey, string(payload), c.expiration)
	if setErr := status.Err(); setErr != nil {
		return xerrors.Errorf("unable to store artifact information in Redis cache (%s): %w", artifactID, setErr)
	}
	return nil
}
// PutBlob serializes the blob info to JSON and stores it under the blob's key
// with the cache's configured expiration.
func (c RedisCache) PutBlob(ctx context.Context, blobID string, blobInfo types.BlobInfo) error {
	payload, err := json.Marshal(blobInfo)
	if err != nil {
		return xerrors.Errorf("failed to marshal blob JSON: %w", err)
	}

	redisKey := fmt.Sprintf("%s::%s::%s", redisPrefix, blobBucket, blobID)
	err = c.client.Set(ctx, redisKey, string(payload), c.expiration).Err()
	if err != nil {
		return xerrors.Errorf("unable to store blob information in Redis cache (%s): %w", blobID, err)
	}
	return nil
}
// DeleteBlobs removes the given blobs from Redis one by one. A failed deletion
// does not stop the loop; all failures are collected and returned together.
// The result is nil when every deletion succeeds.
func (c RedisCache) DeleteBlobs(ctx context.Context, blobIDs []string) error {
	var combined error
	for i := range blobIDs {
		id := blobIDs[i]
		redisKey := fmt.Sprintf("%s::%s::%s", redisPrefix, blobBucket, id)

		delErr := c.client.Del(ctx, redisKey).Err()
		if delErr == nil {
			continue
		}
		combined = multierror.Append(combined, xerrors.Errorf("unable to delete blob %s: %w", id, delErr))
	}
	return combined
}
// GetArtifact fetches the artifact info stored under artifactID and decodes it
// from JSON. It returns an error when the key is absent, when Redis fails, or
// when the stored value cannot be decoded.
func (c RedisCache) GetArtifact(ctx context.Context, artifactID string) (types.ArtifactInfo, error) {
	redisKey := fmt.Sprintf("%s::%s::%s", redisPrefix, artifactBucket, artifactID)

	raw, err := c.client.Get(ctx, redisKey).Bytes()
	switch {
	case err == redis.Nil:
		// redis.Nil signals that the key does not exist.
		return types.ArtifactInfo{}, xerrors.Errorf("artifact (%s) is missing in Redis cache", artifactID)
	case err != nil:
		return types.ArtifactInfo{}, xerrors.Errorf("failed to get artifact from the Redis cache: %w", err)
	}

	var decoded types.ArtifactInfo
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return types.ArtifactInfo{}, xerrors.Errorf("failed to unmarshal artifact (%s) from Redis value: %w", artifactID, err)
	}
	return decoded, nil
}
// GetBlob fetches the blob info stored under blobID and decodes it from JSON.
// It returns an error when the key is absent, when Redis fails, or when the
// stored value cannot be decoded.
func (c RedisCache) GetBlob(ctx context.Context, blobID string) (types.BlobInfo, error) {
	redisKey := fmt.Sprintf("%s::%s::%s", redisPrefix, blobBucket, blobID)

	raw, err := c.client.Get(ctx, redisKey).Bytes()
	switch {
	case err == redis.Nil:
		// redis.Nil signals that the key does not exist.
		return types.BlobInfo{}, xerrors.Errorf("blob (%s) is missing in Redis cache", blobID)
	case err != nil:
		return types.BlobInfo{}, xerrors.Errorf("failed to get blob from the Redis cache: %w", err)
	}

	var decoded types.BlobInfo
	err = json.Unmarshal(raw, &decoded)
	if err != nil {
		return types.BlobInfo{}, xerrors.Errorf("failed to unmarshal blob (%s) from Redis value: %w", blobID, err)
	}
	return decoded, nil
}
// MissingBlobs reports which of the requested entries are not usable from the cache.
//
// The first result is true when the artifact cannot be fetched or has a schema
// version other than types.ArtifactJSONSchemaVersion. The second result lists
// the blob IDs that cannot be fetched or whose schema version differs from
// types.BlobJSONSchemaVersion. The returned error is always nil.
func (c RedisCache) MissingBlobs(ctx context.Context, artifactID string, blobIDs []string) (bool, []string, error) {
	var staleBlobs []string
	for _, id := range blobIDs {
		cached, err := c.GetBlob(ctx, id)
		// A lookup failure and a schema mismatch are both treated as a miss.
		if err != nil || cached.SchemaVersion != types.BlobJSONSchemaVersion {
			staleBlobs = append(staleBlobs, id)
		}
	}

	cachedArtifact, err := c.GetArtifact(ctx, artifactID)
	if err != nil {
		// Any failure to fetch the artifact counts as a miss.
		return true, staleBlobs, nil
	}

	artifactStale := cachedArtifact.SchemaVersion != types.ArtifactJSONSchemaVersion
	return artifactStale, staleBlobs, nil
}
// Close closes the underlying Redis client and returns the error it reports, if any.
func (c RedisCache) Close() error {
return c.client.Close()
}
// Clear removes every key that starts with the cache prefix from Redis.
//
// Keys are discovered incrementally with SCAN and removed with UNLINK. The
// function returns an error as soon as a scan or an unlink fails; keys that
// were already unlinked stay removed.
func (c RedisCache) Clear(ctx context.Context) error {
	// The cursor returned by each SCAN call must be fed into the next one;
	// restarting from 0 every time never advances the iteration and can loop
	// forever when a batch comes back empty with a non-zero cursor.
	var cursor uint64
	for {
		keys, next, err := c.client.Scan(ctx, cursor, redisPrefix+"::*", 100).Result()
		if err != nil {
			return xerrors.Errorf("failed to perform prefix scanning: %w", err)
		}

		// SCAN may legitimately return no keys for a batch, and UNLINK without
		// arguments is rejected by Redis, so only unlink non-empty batches.
		if len(keys) > 0 {
			if err = c.client.Unlink(ctx, keys...).Err(); err != nil {
				return xerrors.Errorf("failed to unlink redis keys: %w", err)
			}
		}

		if next == 0 { // A zero cursor means the full iteration is complete
			return nil
		}
		cursor = next
	}
}
// GetTLSConfig loads the TLS material needed to connect to Redis.
//
// certPath and keyPath must point to a PEM-encoded certificate and its private
// key; caCertPath must point to a PEM file containing at least one CA
// certificate. It returns the CA pool and the client certificate. An error is
// returned when the key pair cannot be loaded, when the CA file cannot be
// read, or when the CA file contains no usable certificate.
func GetTLSConfig(caCertPath, certPath, keyPath string) (*x509.CertPool, tls.Certificate, error) {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, tls.Certificate{}, err
	}

	caCert, err := os.ReadFile(caCertPath)
	if err != nil {
		return nil, tls.Certificate{}, err
	}

	caCertPool := x509.NewCertPool()
	// AppendCertsFromPEM reports whether any certificate was parsed. Ignoring
	// it would hand back an empty pool and defer the failure to an opaque
	// handshake error later on.
	if !caCertPool.AppendCertsFromPEM(caCert) {
		return nil, tls.Certificate{}, fmt.Errorf("no valid PEM certificate found in CA file %q", caCertPath)
	}
	return caCertPool, cert, nil
}