mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-23 15:37:50 -08:00
cache: Define an interface for cache, remove global state
Signed-off-by: Simarpreet Singh <simar@linux.com>
This commit is contained in:
31
cache/cache.go
vendored
31
cache/cache.go
vendored
@@ -6,18 +6,29 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/aquasecurity/fanal/utils"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
cacheDir = utils.CacheDir()
|
||||
replacer = strings.NewReplacer("/", "_")
|
||||
)
|
||||
|
||||
func Get(key string) io.Reader {
|
||||
filePath := filepath.Join(cacheDir, replacer.Replace(key))
|
||||
type Cache interface {
|
||||
Get(key string) io.Reader
|
||||
Set(key string, file io.Reader) (io.Reader, error)
|
||||
Clear() error
|
||||
}
|
||||
|
||||
type FSCache struct {
|
||||
Directory string
|
||||
}
|
||||
|
||||
func Initialize(cacheDir string) Cache {
|
||||
return &FSCache{Directory: cacheDir}
|
||||
}
|
||||
|
||||
func (fs FSCache) Get(key string) io.Reader {
|
||||
filePath := filepath.Join(fs.Directory, replacer.Replace(key))
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -25,9 +36,9 @@ func Get(key string) io.Reader {
|
||||
return f
|
||||
}
|
||||
|
||||
func Set(key string, file io.Reader) (io.Reader, error) {
|
||||
filePath := filepath.Join(cacheDir, replacer.Replace(key))
|
||||
if err := os.MkdirAll(cacheDir, os.ModePerm); err != nil {
|
||||
func (fs FSCache) Set(key string, file io.Reader) (io.Reader, error) {
|
||||
filePath := filepath.Join(fs.Directory, replacer.Replace(key))
|
||||
if err := os.MkdirAll(fs.Directory, os.ModePerm); err != nil {
|
||||
return nil, xerrors.Errorf("failed to mkdir all: %w", err)
|
||||
}
|
||||
cacheFile, err := os.Create(filePath)
|
||||
@@ -39,8 +50,8 @@ func Set(key string, file io.Reader) (io.Reader, error) {
|
||||
return tee, nil
|
||||
}
|
||||
|
||||
func Clear() error {
|
||||
if err := os.RemoveAll(cacheDir); err != nil {
|
||||
func (fs FSCache) Clear() error {
|
||||
if err := os.RemoveAll(fs.Directory); err != nil {
|
||||
return xerrors.New("failed to remove cache")
|
||||
}
|
||||
return nil
|
||||
|
||||
19
cache/cache_test.go
vendored
19
cache/cache_test.go
vendored
@@ -10,36 +10,31 @@ import (
|
||||
)
|
||||
|
||||
func TestSetAndGetAndClear(t *testing.T) {
|
||||
d, _ := ioutil.TempDir("", "TestCacheDir-*")
|
||||
f, _ := ioutil.TempFile(d, "foo.bar.baz-*")
|
||||
tempCacheDir, _ := ioutil.TempDir("", "TestCacheDir-*")
|
||||
f, _ := ioutil.TempFile(tempCacheDir, "foo.bar.baz-*")
|
||||
|
||||
oldCacheDir := cacheDir
|
||||
defer func() {
|
||||
cacheDir = oldCacheDir
|
||||
_ = os.RemoveAll(d)
|
||||
}()
|
||||
cacheDir = d
|
||||
c := Initialize(tempCacheDir)
|
||||
|
||||
// set
|
||||
expectedCacheContents := "foo bar baz"
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte(expectedCacheContents))
|
||||
|
||||
r, err := Set(f.Name(), &buf)
|
||||
r, err := c.Set(f.Name(), &buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
b, _ := ioutil.ReadAll(r)
|
||||
assert.Equal(t, expectedCacheContents, string(b))
|
||||
|
||||
// get
|
||||
actualFile := Get(f.Name())
|
||||
actualFile := c.Get(f.Name())
|
||||
actualBytes, _ := ioutil.ReadAll(actualFile)
|
||||
assert.Equal(t, expectedCacheContents, string(actualBytes))
|
||||
|
||||
// clear
|
||||
assert.NoError(t, Clear())
|
||||
assert.NoError(t, c.Clear())
|
||||
|
||||
// confirm that no cachedir remains
|
||||
_, err = os.Stat(d)
|
||||
_, err = os.Stat(tempCacheDir)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/aquasecurity/fanal/utils"
|
||||
|
||||
"github.com/aquasecurity/fanal/types"
|
||||
|
||||
"github.com/aquasecurity/fanal/extractor/docker"
|
||||
@@ -49,8 +51,10 @@ func run() (err error) {
|
||||
clearCache := flag.Bool("clear", false, "clear cache")
|
||||
flag.Parse()
|
||||
|
||||
c := cache.Initialize(utils.CacheDir())
|
||||
|
||||
if *clearCache {
|
||||
if err = cache.Clear(); err != nil {
|
||||
if err = c.Clear(); err != nil {
|
||||
return xerrors.Errorf("error in cache clear: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ type layer struct {
|
||||
|
||||
type DockerExtractor struct {
|
||||
Client *client.Client
|
||||
Cache cache.Cache
|
||||
Option types.DockerOption
|
||||
}
|
||||
|
||||
@@ -74,7 +75,11 @@ func NewDockerExtractor(option types.DockerOption) (extractor.Extractor, error)
|
||||
return nil, xerrors.Errorf("error initializing docker extractor: %w", err)
|
||||
}
|
||||
|
||||
return DockerExtractor{Option: option, Client: cli}, nil
|
||||
return DockerExtractor{
|
||||
Option: option,
|
||||
Client: cli,
|
||||
Cache: cache.Initialize(utils.CacheDir()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func applyLayers(layerPaths []string, filesInLayers map[string]extractor.FileMap, opqInLayers map[string]extractor.OPQDirs) (extractor.FileMap, error) {
|
||||
@@ -141,14 +146,14 @@ func (d DockerExtractor) createRegistryClient(ctx context.Context, domain string
|
||||
|
||||
func (d DockerExtractor) SaveLocalImage(ctx context.Context, imageName string) (io.Reader, error) {
|
||||
var err error
|
||||
r := cache.Get(imageName)
|
||||
r := d.Cache.Get(imageName)
|
||||
if r == nil {
|
||||
// Save the image
|
||||
r, err = d.saveLocalImage(ctx, imageName)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to save the image: %w", err)
|
||||
}
|
||||
r, err = cache.Set(imageName, r)
|
||||
r, err = d.Cache.Set(imageName, r)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
@@ -194,18 +199,18 @@ func (d DockerExtractor) Extract(ctx context.Context, imageName string, filename
|
||||
|
||||
for _, ref := range m.Manifest.Layers {
|
||||
layerIDs = append(layerIDs, string(ref.Digest))
|
||||
go func(d digest.Digest) {
|
||||
go func(dig digest.Digest) {
|
||||
// Use cache
|
||||
rc := cache.Get(string(d))
|
||||
rc := d.Cache.Get(string(dig))
|
||||
if rc == nil {
|
||||
// Download the layer.
|
||||
rc, err = r.DownloadLayer(ctx, image.Path, d)
|
||||
rc, err = r.DownloadLayer(ctx, image.Path, dig)
|
||||
if err != nil {
|
||||
errCh <- xerrors.Errorf("failed to download the layer(%s): %w", d, err)
|
||||
errCh <- xerrors.Errorf("failed to download the layer(%s): %w", dig, err)
|
||||
return
|
||||
}
|
||||
|
||||
rc, err = cache.Set(string(d), rc)
|
||||
rc, err = d.Cache.Set(string(dig), rc)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
@@ -215,7 +220,7 @@ func (d DockerExtractor) Extract(ctx context.Context, imageName string, filename
|
||||
errCh <- xerrors.Errorf("invalid gzip: %w", err)
|
||||
return
|
||||
}
|
||||
ch <- layer{ID: d, Content: gzipReader}
|
||||
ch <- layer{ID: dig, Content: gzipReader}
|
||||
}(ref.Digest)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package docker
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/aquasecurity/fanal/cache"
|
||||
|
||||
"github.com/docker/docker/client"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -201,12 +204,19 @@ func TestDockerExtractor_SaveLocalImage(t *testing.T) {
|
||||
c, err := client.NewClientWithOpts(client.WithHost(ts.URL))
|
||||
assert.NoError(t, err)
|
||||
|
||||
d := DockerExtractor{
|
||||
// setup cache
|
||||
tempCacheDir, _ := ioutil.TempDir("", "TestDockerExtractor_SaveLocalImage-*")
|
||||
defer func() {
|
||||
_ = os.RemoveAll(tempCacheDir)
|
||||
}()
|
||||
|
||||
de := DockerExtractor{
|
||||
Option: types.DockerOption{},
|
||||
Client: c,
|
||||
Cache: cache.Initialize(tempCacheDir),
|
||||
}
|
||||
|
||||
r, err := d.SaveLocalImage(context.TODO(), "fooimage")
|
||||
r, err := de.SaveLocalImage(context.TODO(), "fooimage")
|
||||
assert.NotNil(t, r)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user