diff --git a/go.mod b/go.mod index 8238541641..c63ea2feb0 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ require ( github.com/aquasecurity/fanal v0.0.0-20191205044128-99e4876e56b0 github.com/aquasecurity/go-dep-parser v0.0.0-20190819075924-ea223f0ef24b github.com/aquasecurity/trivy-db v0.0.0-20191120190201-a6645984b409 - github.com/briandowns/spinner v0.0.0-20190319032542-ac46072a5a91 github.com/caarlos0/env/v6 v6.0.0 + github.com/cheggaaa/pb/v3 v3.0.3 github.com/genuinetools/reg v0.16.0 github.com/golang/protobuf v1.3.1 github.com/google/go-github/v28 v28.1.1 @@ -27,10 +27,8 @@ require ( golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529 golang.org/x/net v0.0.0-20191014212845-da9a3fd4c582 // indirect golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421 - golang.org/x/sys v0.0.0-20191020152052-9984515f0562 // indirect golang.org/x/tools v0.0.0-20191121040551-947d4aa89328 // indirect golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 - gopkg.in/cheggaaa/pb.v1 v1.0.28 gopkg.in/yaml.v2 v2.2.4 // indirect k8s.io/utils v0.0.0-20191010214722-8d271d903fe4 ) diff --git a/go.sum b/go.sum index 7a6de6c885..ce78273319 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8 github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/VividCortex/ewma v1.1.1 h1:MnEK4VOv6n0RSY4vtRe3h11qjxL3+t0B8yOL8iMXdcM= +github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/alcortesm/tgz v0.0.0-20161220082320-9c5fe88206d7 h1:uSoVVbwJiQipAclBbw+8quDsfcvFjOpI5iCf4p/cqCs= github.com/alcortesm/tgz v0.0.0-20161220082320-9c5fe88206d7/go.mod h1:6zEj6s6u/ghQa61ZWa/C2Aw3RkjiTBOix7dkqa1VLIs= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc h1:cAKDfWh5VpdgMhJosfJnn5/FoN2SRZ4p7fJNX58YPaU= @@ -47,6 +49,8 @@ github.com/briandowns/spinner v0.0.0-20190319032542-ac46072a5a91/go.mod h1:hw/JE github.com/caarlos0/env/v6 v6.0.0 h1:NZt6FAoB8ieKO5lEwRdwCzYxWFx7ZYF2R7UcoyaWtyc= github.com/caarlos0/env/v6 v6.0.0/go.mod h1:+wdyOmtjoZIW2GJOc2OYa5NoOFuWD/bIpWqm30NgtRk= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/cheggaaa/pb/v3 v3.0.3 h1:8WApbyUmgMOz7WIxJVNK0IRDcRfAmTxcEdi0TuxjdP4= +github.com/cheggaaa/pb/v3 v3.0.3/go.mod h1:Pp35CDuiEpHa/ZLGCtBbM6CBwMstv1bJlG884V+73Yc= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc h1:TP+534wVlf61smEIq1nwLLAjQVEK2EADoW3CX9AuT+8= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= @@ -189,10 +193,14 @@ github.com/mattn/go-isatty v0.0.5 h1:tHXDdz1cpzGaovsTB+TVB8q90WEokoVmfMqoVcrLUgw github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10= +github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-jsonpointer v0.0.0-20180225143300-37667080efed h1:fCWISZq4YN4ulCJx7x0KB15rqxLEe3mtNJL8cSOGKZU= github.com/mattn/go-jsonpointer v0.0.0-20180225143300-37667080efed/go.mod h1:SDJ4hurDYyQ9/7nc+eCYtXqdufgK4Cq9TJlwPklqEYA= github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y= github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-runewidth v0.0.6 h1:V2iyH+aX9C5fsYCpK60U8BYIvmhqxuOL3JZcqc1NB7k= +github.com/mattn/go-runewidth v0.0.6/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= @@ -356,8 +364,9 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190506115046-ca7f33d4116e h1:bq5BY1tGuaK8HxuwN6pT6kWgTVLeJ5KwuyBpsl1CZL4= golang.org/x/sys v0.0.0-20190506115046-ca7f33d4116e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191020152052-9984515f0562 h1:wOweSabW7qssfcg63CEDHHA4zyoqRlGU6eYV7IUMCq0= -golang.org/x/sys v0.0.0-20191020152052-9984515f0562/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 h1:ZBzSG/7F4eNKz2L3GE9o300RX0Az1Bw5HF7PDraD+qU= +golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/client/config/config.go b/internal/client/config/config.go index 9927c0052b..2adf757c0e 100644 --- a/internal/client/config/config.go +++ b/internal/client/config/config.go @@ -12,7 +12,6 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/utils" ) type Config struct { @@ -87,10 +86,6 @@ func (c *Config) Init() (err error) { c.VulnType = strings.Split(c.vulnType, ",") c.AppVersion = c.context.App.Version - if c.Quiet { - utils.Quiet = true - } - // --clear-cache doesn't conduct the scan if c.ClearCache { return nil diff --git a/internal/operation/inject.go b/internal/operation/inject.go index 7dcfe62357..fc6aebd126 100644 --- a/internal/operation/inject.go +++ b/internal/operation/inject.go @@ -7,7 +7,7 @@ import ( "github.com/google/wire" ) -func initializeDBClient() db.Client { +func initializeDBClient(quiet bool) db.Client { wire.Build(db.SuperSet) return db.Client{} } diff --git a/internal/operation/operation.go b/internal/operation/operation.go index 6661e1a345..9fe6d6a0fe 100644 --- a/internal/operation/operation.go +++ b/internal/operation/operation.go @@ -33,8 +33,8 @@ func ClearCache() error { return nil } -func DownloadDB(appVersion, cacheDir string, light, skipUpdate bool) error { - client := initializeDBClient() +func DownloadDB(appVersion, cacheDir string, quiet, light, skipUpdate bool) error { + client := initializeDBClient(quiet) ctx := context.Background() needsUpdate, err := client.NeedsUpdate(ctx, appVersion, light, skipUpdate) if err != nil { @@ -46,6 +46,7 @@ func DownloadDB(appVersion, cacheDir string, light, skipUpdate bool) error { if err = db.Close(); err != nil { return xerrors.Errorf("failed db close: %w", err) } + log.Logger.Info("Downloading DB...") if err := client.Download(ctx, cacheDir, light); err != nil { return xerrors.Errorf("failed to download vulnerability DB: %w", err) } diff --git a/internal/operation/wire_gen.go b/internal/operation/wire_gen.go index 3e6e1e4ac5..460ec57e29 100644 --- a/internal/operation/wire_gen.go +++ b/internal/operation/wire_gen.go @@ -9,15 +9,17 @@ import ( db2 "github.com/aquasecurity/trivy-db/pkg/db" "github.com/aquasecurity/trivy/pkg/db" "github.com/aquasecurity/trivy/pkg/github" + "github.com/aquasecurity/trivy/pkg/indicator" "k8s.io/utils/clock" ) // Injectors from inject.go: -func initializeDBClient() db.Client { +func initializeDBClient(quiet bool) db.Client { config := db2.Config{} client := github.NewClient() + progressBar := indicator.NewProgressBar(quiet) realClock := clock.RealClock{} - dbClient := db.NewClient(config, client, realClock) + dbClient := db.NewClient(config, client, progressBar, realClock) return dbClient } diff --git a/internal/server/config/config.go b/internal/server/config/config.go index 8dcda814a4..7e6e591cea 100644 --- a/internal/server/config/config.go +++ b/internal/server/config/config.go @@ -3,8 +3,6 @@ package config import ( "github.com/urfave/cli" "golang.org/x/xerrors" - - "github.com/aquasecurity/trivy/pkg/utils" ) type Config struct { @@ -48,8 +46,5 @@ func (c *Config) Init() (err error) { c.AppVersion = c.context.App.Version - // A server always suppresses a progress bar - utils.Quiet = true - return nil } diff --git a/internal/server/run.go b/internal/server/run.go index a485e2ca00..e75eb358da 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -39,7 +39,7 @@ func run(c config.Config) (err error) { } // download the database file - if err = operation.DownloadDB(c.AppVersion, c.CacheDir, false, c.SkipUpdate); err != nil { + if err = operation.DownloadDB(c.AppVersion, c.CacheDir, true, false, c.SkipUpdate); err != nil { return err } diff --git a/internal/standalone/config/config.go b/internal/standalone/config/config.go index 355bb2e837..48a62d1470 100644 --- a/internal/standalone/config/config.go +++ b/internal/standalone/config/config.go @@ -12,7 +12,6 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/utils" ) type Config struct { @@ -109,10 +108,6 @@ func (c *Config) Init() (err error) { c.VulnType = strings.Split(c.vulnType, ",") c.AppVersion = c.context.App.Version - if c.Quiet || c.NoProgress { - utils.Quiet = true - } - // --clear-cache, --download-db-only and --reset don't conduct the scan if c.ClearCache || c.DownloadDBOnly || c.Reset { return nil diff --git a/internal/standalone/run.go b/internal/standalone/run.go index 59a5bc9f69..b86ae98077 100644 --- a/internal/standalone/run.go +++ b/internal/standalone/run.go @@ -50,7 +50,8 @@ func run(c config.Config) (err error) { } // download the database file - if err = operation.DownloadDB(c.AppVersion, c.CacheDir, c.Light, c.SkipUpdate); err != nil { + noProgress := c.Quiet || c.NoProgress + if err = operation.DownloadDB(c.AppVersion, c.CacheDir, noProgress, c.Light, c.SkipUpdate); err != nil { return err } diff --git a/pkg/db/db.go b/pkg/db/db.go index bb5bb81c7a..4f03542d18 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -13,8 +13,8 @@ import ( "github.com/aquasecurity/trivy-db/pkg/db" "github.com/aquasecurity/trivy/pkg/github" + "github.com/aquasecurity/trivy/pkg/indicator" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/utils" ) const ( @@ -23,11 +23,21 @@ const ( ) var SuperSet = wire.NewSet( + // indicator.ProgressBar + indicator.NewProgressBar, + + // clock.Clock wire.Struct(new(clock.RealClock)), wire.Bind(new(clock.Clock), new(clock.RealClock)), + + // db.Config wire.Struct(new(db.Config)), + + // github.Client github.NewClient, wire.Bind(new(github.Operation), new(github.Client)), + + // db.Client NewClient, wire.Bind(new(Operation), new(Client)), ) @@ -44,13 +54,15 @@ type dbOperation interface { type Client struct { dbc dbOperation githubClient github.Operation + pb indicator.ProgressBar clock clock.Clock } -func NewClient(dbc db.Config, githubClient github.Operation, clock clock.Clock) Client { +func NewClient(dbc db.Config, githubClient github.Operation, pb indicator.ProgressBar, clock clock.Clock) Client { return Client{ dbc: dbc, githubClient: githubClient, + pb: pb, clock: clock, } } @@ -102,23 +114,21 @@ func (c Client) NeedsUpdate(ctx context.Context, cliVersion string, light, skip func (c Client) Download(ctx context.Context, cacheDir string, light bool) error { dbFile := fullDB - message := " Downloading Full DB file..." if light { dbFile = lightDB - message = " Downloading Lightweight DB file..." } - spinner := utils.NewSpinner(message) - spinner.Start() - defer spinner.Stop() - - rc, err := c.githubClient.DownloadDB(ctx, dbFile) + rc, size, err := c.githubClient.DownloadDB(ctx, dbFile) if err != nil { return xerrors.Errorf("failed to download vulnerability DB: %w", err) } defer rc.Close() - gr, err := gzip.NewReader(rc) + bar := c.pb.Start(int64(size)) + barReader := bar.NewProxyReader(rc) + defer bar.Finish() + + gr, err := gzip.NewReader(barReader) if err != nil { return xerrors.Errorf("invalid gzip file: %w", err) } diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go index 50dcf09191..a1913dc02c 100644 --- a/pkg/db/db_test.go +++ b/pkg/db/db_test.go @@ -3,12 +3,13 @@ package db import ( "context" "errors" - "io" "io/ioutil" "os" "testing" "time" + "github.com/aquasecurity/trivy/pkg/indicator" + "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -208,30 +209,21 @@ func TestClient_NeedsUpdate(t *testing.T) { } func TestClient_Download(t *testing.T) { - type downloadDBOutput struct { - fileName string - err error - } - type downloadDB struct { - input string - output downloadDBOutput - } - testCases := []struct { name string light bool - downloadDB []downloadDB + downloadDB []github.DownloadDBExpectation expectedContent []byte expectedError error }{ { name: "happy path", light: false, - downloadDB: []downloadDB{ + downloadDB: []github.DownloadDBExpectation{ { - input: fullDB, - output: downloadDBOutput{ - fileName: "testdata/test.db.gz", + Args: github.DownloadDBInput{FileName: fullDB}, + ReturnArgs: github.DownloadDBOutput{ + FileName: "testdata/test.db.gz", }, }, }, @@ -239,11 +231,11 @@ func TestClient_Download(t *testing.T) { { name: "DownloadDB returns an error", light: false, - downloadDB: []downloadDB{ + downloadDB: []github.DownloadDBExpectation{ { - input: fullDB, - output: downloadDBOutput{ - err: xerrors.New("download failed"), + Args: github.DownloadDBInput{FileName: fullDB}, + ReturnArgs: github.DownloadDBOutput{ + Err: xerrors.New("download failed"), }, }, }, @@ -252,11 +244,11 @@ func TestClient_Download(t *testing.T) { { name: "invalid gzip", light: false, - downloadDB: []downloadDB{ + downloadDB: []github.DownloadDBExpectation{ { - input: fullDB, - output: downloadDBOutput{ - fileName: "testdata/invalid.db.gz", + Args: github.DownloadDBInput{FileName: fullDB}, + ReturnArgs: github.DownloadDBOutput{ + FileName: "testdata/invalid.db.gz", }, }, }, @@ -269,19 +261,8 @@ func TestClient_Download(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mockGitHubClient := new(github.MockClient) - for _, dd := range tc.downloadDB { - var rc io.ReadCloser - if dd.output.fileName != "" { - f, err := os.Open(dd.output.fileName) - assert.NoError(t, err, tc.name) - rc = f - } - - mockGitHubClient.On("DownloadDB", mock.Anything, dd.input).Return( - rc, dd.output.err, - ) - } + mockGitHubClient, err := github.NewMockClient(tc.downloadDB) + require.NoError(t, err, tc.name) dir, err := ioutil.TempDir("", "db") require.NoError(t, err, tc.name) @@ -290,7 +271,8 @@ func TestClient_Download(t *testing.T) { err = db.Init(dir) require.NoError(t, err, tc.name) - client := NewClient(db.Config{}, mockGitHubClient, nil) + pb := indicator.NewProgressBar(true) + client := NewClient(db.Config{}, mockGitHubClient, pb, nil) ctx := context.Background() err = client.Download(ctx, dir, tc.light) diff --git a/pkg/github/github.go b/pkg/github/github.go index 2883319d6e..6711d7e0e3 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "sort" + "strconv" "strings" "github.com/aquasecurity/trivy-db/pkg/db" @@ -42,7 +43,7 @@ func (r Repository) DownloadAsset(ctx context.Context, id int64) (io.ReadCloser, } type Operation interface { - DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error) + DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error) } type Client struct { @@ -72,11 +73,11 @@ func NewClient() Client { } } -func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error) { +func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error) { options := github.ListOptions{} releases, _, err := c.Repository.ListReleases(ctx, &options) if err != nil { - return nil, xerrors.Errorf("failed to list releases: %w", err) + return nil, 0, xerrors.Errorf("failed to list releases: %w", err) } sort.Slice(releases, func(i, j int) bool { @@ -91,37 +92,42 @@ func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, } for _, asset := range release.Assets { - rc, err := c.downloadAsset(ctx, asset, fileName) + rc, size, err := c.downloadAsset(ctx, asset, fileName) if err != nil { log.Logger.Debug(err) continue } - return rc, nil + return rc, size, nil } } - return nil, xerrors.New("DB file not found") + return nil, 0, xerrors.New("DB file not found") } -func (c Client) downloadAsset(ctx context.Context, asset github.ReleaseAsset, fileName string) (io.ReadCloser, error) { +func (c Client) downloadAsset(ctx context.Context, asset github.ReleaseAsset, fileName string) (io.ReadCloser, int, error) { log.Logger.Debugf("asset name: %s", asset.GetName()) if asset.GetName() != fileName { - return nil, xerrors.New("file name doesn't match") + return nil, 0, xerrors.New("file name doesn't match") } rc, url, err := c.Repository.DownloadAsset(ctx, asset.GetID()) if err != nil { - return nil, xerrors.Errorf("unable to download the asset: %w", err) + return nil, 0, xerrors.Errorf("unable to download the asset: %w", err) } if rc != nil { - return rc, nil + return rc, asset.GetSize(), nil } log.Logger.Debugf("asset URL: %s", url) resp, err := http.Get(url) if err != nil || resp.StatusCode != http.StatusOK { - return nil, xerrors.Errorf("unable to download the asset via URL: %w", err) + return nil, 0, xerrors.Errorf("unable to download the asset via URL: %w", err) } - return resp.Body, nil + + size, err := strconv.Atoi(resp.Header.Get("Content-Length")) + if err != nil { + return nil, 0, xerrors.Errorf("invalid size: %w", err) + } + return resp.Body, size, nil } diff --git a/pkg/github/github_mock.go b/pkg/github/github_mock.go index 1bb1f4d41c..d9d01da873 100644 --- a/pkg/github/github_mock.go +++ b/pkg/github/github_mock.go @@ -3,6 +3,7 @@ package github import ( "context" "io" + "os" "github.com/stretchr/testify/mock" ) @@ -11,15 +12,46 @@ type MockClient struct { mock.Mock } -func (_m *MockClient) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error) { +type DownloadDBInput struct { + FileName string +} +type DownloadDBOutput struct { + FileName string + Size int + Err error +} +type DownloadDBExpectation struct { + Args DownloadDBInput + ReturnArgs DownloadDBOutput +} + +func NewMockClient(downloadDBExpectations []DownloadDBExpectation) (*MockClient, error) { + mockDetector := new(MockClient) + for _, e := range downloadDBExpectations { + var rc io.ReadCloser + if e.ReturnArgs.FileName != "" { + f, err := os.Open(e.ReturnArgs.FileName) + if err != nil { + return nil, err + } + rc = f + } + + mockDetector.On("DownloadDB", mock.Anything, e.Args.FileName).Return( + rc, e.ReturnArgs.Size, e.ReturnArgs.Err) + } + return mockDetector, nil +} + +func (_m *MockClient) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error) { ret := _m.Called(ctx, fileName) ret0 := ret.Get(0) if ret0 == nil { - return nil, ret.Error(1) + return nil, ret.Int(1), ret.Error(2) } rc, ok := ret0.(io.ReadCloser) if !ok { - return nil, ret.Error(1) + return nil, ret.Int(1), ret.Error(2) } - return rc, ret.Error(1) + return rc, ret.Int(1), ret.Error(2) } diff --git a/pkg/github/github_test.go b/pkg/github/github_test.go index 235e589529..e7c5d2e2b4 100644 --- a/pkg/github/github_test.go +++ b/pkg/github/github_test.go @@ -451,7 +451,7 @@ func TestClient_DownloadDB(t *testing.T) { } ctx := context.Background() - rc, err := client.DownloadDB(ctx, tc.fileName) + rc, _, err := client.DownloadDB(ctx, tc.fileName) switch { case tc.expectedError != nil: diff --git a/pkg/indicator/progress.go b/pkg/indicator/progress.go new file mode 100644 index 0000000000..9f50a4bfd4 --- /dev/null +++ b/pkg/indicator/progress.go @@ -0,0 +1,40 @@ +package indicator + +import ( + "io" + + "github.com/cheggaaa/pb/v3" +) + +type ProgressBar struct { + quiet bool +} + +func NewProgressBar(quiet bool) ProgressBar { + return ProgressBar{quiet: quiet} +} + +func (p ProgressBar) Start(total int64) Bar { + if p.quiet { + return Bar{} + } + bar := pb.Full.Start64(total) + return Bar{bar: bar} +} + +type Bar struct { + bar *pb.ProgressBar +} + +func (b Bar) NewProxyReader(r io.Reader) io.Reader { + if b.bar == nil { + return r + } + return b.bar.NewProxyReader(r) +} +func (b Bar) Finish() { + if b.bar == nil { + return + } + b.bar.Finish() +} diff --git a/pkg/rpc/server/inject.go b/pkg/rpc/server/inject.go index d5d989eeae..d525a53dfd 100644 --- a/pkg/rpc/server/inject.go +++ b/pkg/rpc/server/inject.go @@ -19,7 +19,7 @@ func initializeLibServer() *library.Server { return &library.Server{} } -func initializeDBWorker() dbWorker { +func initializeDBWorker(quiet bool) dbWorker { wire.Build(SuperSet) return dbWorker{} } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index bca8adfb9f..bd45634f54 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -45,7 +45,7 @@ func ListenAndServe(addr string, c config.Config) error { } go func() { - worker := initializeDBWorker() + worker := initializeDBWorker(true) ctx := context.Background() for { time.Sleep(1 * time.Hour) diff --git a/pkg/rpc/server/wire_gen.go b/pkg/rpc/server/wire_gen.go index 150c757734..22ce988ddf 100644 --- a/pkg/rpc/server/wire_gen.go +++ b/pkg/rpc/server/wire_gen.go @@ -11,6 +11,7 @@ import ( library2 "github.com/aquasecurity/trivy/pkg/detector/library" ospkg2 "github.com/aquasecurity/trivy/pkg/detector/ospkg" "github.com/aquasecurity/trivy/pkg/github" + "github.com/aquasecurity/trivy/pkg/indicator" "github.com/aquasecurity/trivy/pkg/rpc/server/library" "github.com/aquasecurity/trivy/pkg/rpc/server/ospkg" "github.com/aquasecurity/trivy/pkg/vulnerability" @@ -36,11 +37,12 @@ func initializeLibServer() *library.Server { return server } -func initializeDBWorker() dbWorker { +func initializeDBWorker(quiet bool) dbWorker { config := db.Config{} client := github.NewClient() + progressBar := indicator.NewProgressBar(quiet) realClock := clock.RealClock{} - dbClient := db2.NewClient(config, client, realClock) + dbClient := db2.NewClient(config, client, progressBar, realClock) serverDbWorker := newDBWorker(dbClient) return serverDbWorker } diff --git a/pkg/utils/progress.go b/pkg/utils/progress.go deleted file mode 100644 index 6790e1141a..0000000000 --- a/pkg/utils/progress.go +++ /dev/null @@ -1,64 +0,0 @@ -package utils - -import ( - "time" - - "github.com/briandowns/spinner" - pb "gopkg.in/cheggaaa/pb.v1" -) - -var ( - Quiet = false -) - -type Spinner struct { - client *spinner.Spinner -} - -func NewSpinner(suffix string) *Spinner { - if Quiet { - return &Spinner{} - } - s := spinner.New(spinner.CharSets[36], 100*time.Millisecond) - s.Suffix = suffix - return &Spinner{client: s} -} - -func (s *Spinner) Start() { - if s.client == nil { - return - } - s.client.Start() -} -func (s *Spinner) Stop() { - if s.client == nil { - return - } - s.client.Stop() -} - -// TODO: Expose an interface for progressbar -type ProgressBar struct { - client *pb.ProgressBar -} - -func PbStartNew(total int) *ProgressBar { - if Quiet { - return &ProgressBar{} - } - bar := pb.StartNew(total) - return &ProgressBar{client: bar} -} - -func (p *ProgressBar) Increment() { - if p.client == nil { - return - } - p.client.Increment() -} -func (p *ProgressBar) Finish() { - if p.client == nil { - return - } - p.client.Finish() -}