mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-22 07:10:41 -08:00
feat(db): show progress when downloading the DB (#317)
* fix(github): return db size * fix(github_mock): add size * feat(indicator): add progress bar * refactor(config): remove global Quiet * fix(db): take progress bar as an argument * fix(progress): inject progress bar
This commit is contained in:
4
go.mod
4
go.mod
@@ -6,8 +6,8 @@ require (
|
||||
github.com/aquasecurity/fanal v0.0.0-20191205044128-99e4876e56b0
|
||||
github.com/aquasecurity/go-dep-parser v0.0.0-20190819075924-ea223f0ef24b
|
||||
github.com/aquasecurity/trivy-db v0.0.0-20191120190201-a6645984b409
|
||||
github.com/briandowns/spinner v0.0.0-20190319032542-ac46072a5a91
|
||||
github.com/caarlos0/env/v6 v6.0.0
|
||||
github.com/cheggaaa/pb/v3 v3.0.3
|
||||
github.com/genuinetools/reg v0.16.0
|
||||
github.com/golang/protobuf v1.3.1
|
||||
github.com/google/go-github/v28 v28.1.1
|
||||
@@ -27,10 +27,8 @@ require (
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529
|
||||
golang.org/x/net v0.0.0-20191014212845-da9a3fd4c582 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421
|
||||
golang.org/x/sys v0.0.0-20191020152052-9984515f0562 // indirect
|
||||
golang.org/x/tools v0.0.0-20191121040551-947d4aa89328 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898
|
||||
gopkg.in/cheggaaa/pb.v1 v1.0.28
|
||||
gopkg.in/yaml.v2 v2.2.4 // indirect
|
||||
k8s.io/utils v0.0.0-20191010214722-8d271d903fe4
|
||||
)
|
||||
|
||||
13
go.sum
13
go.sum
@@ -15,6 +15,8 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
|
||||
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||
github.com/VividCortex/ewma v1.1.1 h1:MnEK4VOv6n0RSY4vtRe3h11qjxL3+t0B8yOL8iMXdcM=
|
||||
github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA=
|
||||
github.com/alcortesm/tgz v0.0.0-20161220082320-9c5fe88206d7 h1:uSoVVbwJiQipAclBbw+8quDsfcvFjOpI5iCf4p/cqCs=
|
||||
github.com/alcortesm/tgz v0.0.0-20161220082320-9c5fe88206d7/go.mod h1:6zEj6s6u/ghQa61ZWa/C2Aw3RkjiTBOix7dkqa1VLIs=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc h1:cAKDfWh5VpdgMhJosfJnn5/FoN2SRZ4p7fJNX58YPaU=
|
||||
@@ -47,6 +49,8 @@ github.com/briandowns/spinner v0.0.0-20190319032542-ac46072a5a91/go.mod h1:hw/JE
|
||||
github.com/caarlos0/env/v6 v6.0.0 h1:NZt6FAoB8ieKO5lEwRdwCzYxWFx7ZYF2R7UcoyaWtyc=
|
||||
github.com/caarlos0/env/v6 v6.0.0/go.mod h1:+wdyOmtjoZIW2GJOc2OYa5NoOFuWD/bIpWqm30NgtRk=
|
||||
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
|
||||
github.com/cheggaaa/pb/v3 v3.0.3 h1:8WApbyUmgMOz7WIxJVNK0IRDcRfAmTxcEdi0TuxjdP4=
|
||||
github.com/cheggaaa/pb/v3 v3.0.3/go.mod h1:Pp35CDuiEpHa/ZLGCtBbM6CBwMstv1bJlG884V+73Yc=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc h1:TP+534wVlf61smEIq1nwLLAjQVEK2EADoW3CX9AuT+8=
|
||||
github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y=
|
||||
@@ -189,10 +193,14 @@ github.com/mattn/go-isatty v0.0.5 h1:tHXDdz1cpzGaovsTB+TVB8q90WEokoVmfMqoVcrLUgw
|
||||
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
|
||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10=
|
||||
github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84=
|
||||
github.com/mattn/go-jsonpointer v0.0.0-20180225143300-37667080efed h1:fCWISZq4YN4ulCJx7x0KB15rqxLEe3mtNJL8cSOGKZU=
|
||||
github.com/mattn/go-jsonpointer v0.0.0-20180225143300-37667080efed/go.mod h1:SDJ4hurDYyQ9/7nc+eCYtXqdufgK4Cq9TJlwPklqEYA=
|
||||
github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y=
|
||||
github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
|
||||
github.com/mattn/go-runewidth v0.0.6 h1:V2iyH+aX9C5fsYCpK60U8BYIvmhqxuOL3JZcqc1NB7k=
|
||||
github.com/mattn/go-runewidth v0.0.6/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
@@ -356,8 +364,9 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190506115046-ca7f33d4116e h1:bq5BY1tGuaK8HxuwN6pT6kWgTVLeJ5KwuyBpsl1CZL4=
|
||||
golang.org/x/sys v0.0.0-20190506115046-ca7f33d4116e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191020152052-9984515f0562 h1:wOweSabW7qssfcg63CEDHHA4zyoqRlGU6eYV7IUMCq0=
|
||||
golang.org/x/sys v0.0.0-20191020152052-9984515f0562/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 h1:ZBzSG/7F4eNKz2L3GE9o300RX0Az1Bw5HF7PDraD+qU=
|
||||
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/utils"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@@ -87,10 +86,6 @@ func (c *Config) Init() (err error) {
|
||||
c.VulnType = strings.Split(c.vulnType, ",")
|
||||
c.AppVersion = c.context.App.Version
|
||||
|
||||
if c.Quiet {
|
||||
utils.Quiet = true
|
||||
}
|
||||
|
||||
// --clear-cache doesn't conduct the scan
|
||||
if c.ClearCache {
|
||||
return nil
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
func initializeDBClient() db.Client {
|
||||
func initializeDBClient(quiet bool) db.Client {
|
||||
wire.Build(db.SuperSet)
|
||||
return db.Client{}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,8 @@ func ClearCache() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownloadDB(appVersion, cacheDir string, light, skipUpdate bool) error {
|
||||
client := initializeDBClient()
|
||||
func DownloadDB(appVersion, cacheDir string, quiet, light, skipUpdate bool) error {
|
||||
client := initializeDBClient(quiet)
|
||||
ctx := context.Background()
|
||||
needsUpdate, err := client.NeedsUpdate(ctx, appVersion, light, skipUpdate)
|
||||
if err != nil {
|
||||
@@ -46,6 +46,7 @@ func DownloadDB(appVersion, cacheDir string, light, skipUpdate bool) error {
|
||||
if err = db.Close(); err != nil {
|
||||
return xerrors.Errorf("failed db close: %w", err)
|
||||
}
|
||||
log.Logger.Info("Downloading DB...")
|
||||
if err := client.Download(ctx, cacheDir, light); err != nil {
|
||||
return xerrors.Errorf("failed to download vulnerability DB: %w", err)
|
||||
}
|
||||
|
||||
@@ -9,15 +9,17 @@ import (
|
||||
db2 "github.com/aquasecurity/trivy-db/pkg/db"
|
||||
"github.com/aquasecurity/trivy/pkg/db"
|
||||
"github.com/aquasecurity/trivy/pkg/github"
|
||||
"github.com/aquasecurity/trivy/pkg/indicator"
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
// Injectors from inject.go:
|
||||
|
||||
func initializeDBClient() db.Client {
|
||||
func initializeDBClient(quiet bool) db.Client {
|
||||
config := db2.Config{}
|
||||
client := github.NewClient()
|
||||
progressBar := indicator.NewProgressBar(quiet)
|
||||
realClock := clock.RealClock{}
|
||||
dbClient := db.NewClient(config, client, realClock)
|
||||
dbClient := db.NewClient(config, client, progressBar, realClock)
|
||||
return dbClient
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package config
|
||||
import (
|
||||
"github.com/urfave/cli"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/utils"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@@ -48,8 +46,5 @@ func (c *Config) Init() (err error) {
|
||||
|
||||
c.AppVersion = c.context.App.Version
|
||||
|
||||
// A server always suppresses a progress bar
|
||||
utils.Quiet = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func run(c config.Config) (err error) {
|
||||
}
|
||||
|
||||
// download the database file
|
||||
if err = operation.DownloadDB(c.AppVersion, c.CacheDir, false, c.SkipUpdate); err != nil {
|
||||
if err = operation.DownloadDB(c.AppVersion, c.CacheDir, true, false, c.SkipUpdate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/utils"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@@ -109,10 +108,6 @@ func (c *Config) Init() (err error) {
|
||||
c.VulnType = strings.Split(c.vulnType, ",")
|
||||
c.AppVersion = c.context.App.Version
|
||||
|
||||
if c.Quiet || c.NoProgress {
|
||||
utils.Quiet = true
|
||||
}
|
||||
|
||||
// --clear-cache, --download-db-only and --reset don't conduct the scan
|
||||
if c.ClearCache || c.DownloadDBOnly || c.Reset {
|
||||
return nil
|
||||
|
||||
@@ -50,7 +50,8 @@ func run(c config.Config) (err error) {
|
||||
}
|
||||
|
||||
// download the database file
|
||||
if err = operation.DownloadDB(c.AppVersion, c.CacheDir, c.Light, c.SkipUpdate); err != nil {
|
||||
noProgress := c.Quiet || c.NoProgress
|
||||
if err = operation.DownloadDB(c.AppVersion, c.CacheDir, noProgress, c.Light, c.SkipUpdate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
30
pkg/db/db.go
30
pkg/db/db.go
@@ -13,8 +13,8 @@ import (
|
||||
|
||||
"github.com/aquasecurity/trivy-db/pkg/db"
|
||||
"github.com/aquasecurity/trivy/pkg/github"
|
||||
"github.com/aquasecurity/trivy/pkg/indicator"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,11 +23,21 @@ const (
|
||||
)
|
||||
|
||||
var SuperSet = wire.NewSet(
|
||||
// indicator.ProgressBar
|
||||
indicator.NewProgressBar,
|
||||
|
||||
// clock.Clock
|
||||
wire.Struct(new(clock.RealClock)),
|
||||
wire.Bind(new(clock.Clock), new(clock.RealClock)),
|
||||
|
||||
// db.Config
|
||||
wire.Struct(new(db.Config)),
|
||||
|
||||
// github.Client
|
||||
github.NewClient,
|
||||
wire.Bind(new(github.Operation), new(github.Client)),
|
||||
|
||||
// db.Client
|
||||
NewClient,
|
||||
wire.Bind(new(Operation), new(Client)),
|
||||
)
|
||||
@@ -44,13 +54,15 @@ type dbOperation interface {
|
||||
type Client struct {
|
||||
dbc dbOperation
|
||||
githubClient github.Operation
|
||||
pb indicator.ProgressBar
|
||||
clock clock.Clock
|
||||
}
|
||||
|
||||
func NewClient(dbc db.Config, githubClient github.Operation, clock clock.Clock) Client {
|
||||
func NewClient(dbc db.Config, githubClient github.Operation, pb indicator.ProgressBar, clock clock.Clock) Client {
|
||||
return Client{
|
||||
dbc: dbc,
|
||||
githubClient: githubClient,
|
||||
pb: pb,
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
@@ -102,23 +114,21 @@ func (c Client) NeedsUpdate(ctx context.Context, cliVersion string, light, skip
|
||||
|
||||
func (c Client) Download(ctx context.Context, cacheDir string, light bool) error {
|
||||
dbFile := fullDB
|
||||
message := " Downloading Full DB file..."
|
||||
if light {
|
||||
dbFile = lightDB
|
||||
message = " Downloading Lightweight DB file..."
|
||||
}
|
||||
|
||||
spinner := utils.NewSpinner(message)
|
||||
spinner.Start()
|
||||
defer spinner.Stop()
|
||||
|
||||
rc, err := c.githubClient.DownloadDB(ctx, dbFile)
|
||||
rc, size, err := c.githubClient.DownloadDB(ctx, dbFile)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to download vulnerability DB: %w", err)
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
gr, err := gzip.NewReader(rc)
|
||||
bar := c.pb.Start(int64(size))
|
||||
barReader := bar.NewProxyReader(rc)
|
||||
defer bar.Finish()
|
||||
|
||||
gr, err := gzip.NewReader(barReader)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("invalid gzip file: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/indicator"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -208,30 +209,21 @@ func TestClient_NeedsUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_Download(t *testing.T) {
|
||||
type downloadDBOutput struct {
|
||||
fileName string
|
||||
err error
|
||||
}
|
||||
type downloadDB struct {
|
||||
input string
|
||||
output downloadDBOutput
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
light bool
|
||||
downloadDB []downloadDB
|
||||
downloadDB []github.DownloadDBExpectation
|
||||
expectedContent []byte
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
light: false,
|
||||
downloadDB: []downloadDB{
|
||||
downloadDB: []github.DownloadDBExpectation{
|
||||
{
|
||||
input: fullDB,
|
||||
output: downloadDBOutput{
|
||||
fileName: "testdata/test.db.gz",
|
||||
Args: github.DownloadDBInput{FileName: fullDB},
|
||||
ReturnArgs: github.DownloadDBOutput{
|
||||
FileName: "testdata/test.db.gz",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -239,11 +231,11 @@ func TestClient_Download(t *testing.T) {
|
||||
{
|
||||
name: "DownloadDB returns an error",
|
||||
light: false,
|
||||
downloadDB: []downloadDB{
|
||||
downloadDB: []github.DownloadDBExpectation{
|
||||
{
|
||||
input: fullDB,
|
||||
output: downloadDBOutput{
|
||||
err: xerrors.New("download failed"),
|
||||
Args: github.DownloadDBInput{FileName: fullDB},
|
||||
ReturnArgs: github.DownloadDBOutput{
|
||||
Err: xerrors.New("download failed"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -252,11 +244,11 @@ func TestClient_Download(t *testing.T) {
|
||||
{
|
||||
name: "invalid gzip",
|
||||
light: false,
|
||||
downloadDB: []downloadDB{
|
||||
downloadDB: []github.DownloadDBExpectation{
|
||||
{
|
||||
input: fullDB,
|
||||
output: downloadDBOutput{
|
||||
fileName: "testdata/invalid.db.gz",
|
||||
Args: github.DownloadDBInput{FileName: fullDB},
|
||||
ReturnArgs: github.DownloadDBOutput{
|
||||
FileName: "testdata/invalid.db.gz",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -269,19 +261,8 @@ func TestClient_Download(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockGitHubClient := new(github.MockClient)
|
||||
for _, dd := range tc.downloadDB {
|
||||
var rc io.ReadCloser
|
||||
if dd.output.fileName != "" {
|
||||
f, err := os.Open(dd.output.fileName)
|
||||
assert.NoError(t, err, tc.name)
|
||||
rc = f
|
||||
}
|
||||
|
||||
mockGitHubClient.On("DownloadDB", mock.Anything, dd.input).Return(
|
||||
rc, dd.output.err,
|
||||
)
|
||||
}
|
||||
mockGitHubClient, err := github.NewMockClient(tc.downloadDB)
|
||||
require.NoError(t, err, tc.name)
|
||||
|
||||
dir, err := ioutil.TempDir("", "db")
|
||||
require.NoError(t, err, tc.name)
|
||||
@@ -290,7 +271,8 @@ func TestClient_Download(t *testing.T) {
|
||||
err = db.Init(dir)
|
||||
require.NoError(t, err, tc.name)
|
||||
|
||||
client := NewClient(db.Config{}, mockGitHubClient, nil)
|
||||
pb := indicator.NewProgressBar(true)
|
||||
client := NewClient(db.Config{}, mockGitHubClient, pb, nil)
|
||||
ctx := context.Background()
|
||||
err = client.Download(ctx, dir, tc.light)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/aquasecurity/trivy-db/pkg/db"
|
||||
@@ -42,7 +43,7 @@ func (r Repository) DownloadAsset(ctx context.Context, id int64) (io.ReadCloser,
|
||||
}
|
||||
|
||||
type Operation interface {
|
||||
DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error)
|
||||
DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error)
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
@@ -72,11 +73,11 @@ func NewClient() Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error) {
|
||||
func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error) {
|
||||
options := github.ListOptions{}
|
||||
releases, _, err := c.Repository.ListReleases(ctx, &options)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to list releases: %w", err)
|
||||
return nil, 0, xerrors.Errorf("failed to list releases: %w", err)
|
||||
}
|
||||
|
||||
sort.Slice(releases, func(i, j int) bool {
|
||||
@@ -91,37 +92,42 @@ func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser,
|
||||
}
|
||||
|
||||
for _, asset := range release.Assets {
|
||||
rc, err := c.downloadAsset(ctx, asset, fileName)
|
||||
rc, size, err := c.downloadAsset(ctx, asset, fileName)
|
||||
if err != nil {
|
||||
log.Logger.Debug(err)
|
||||
continue
|
||||
}
|
||||
return rc, nil
|
||||
return rc, size, nil
|
||||
}
|
||||
|
||||
}
|
||||
return nil, xerrors.New("DB file not found")
|
||||
return nil, 0, xerrors.New("DB file not found")
|
||||
}
|
||||
|
||||
func (c Client) downloadAsset(ctx context.Context, asset github.ReleaseAsset, fileName string) (io.ReadCloser, error) {
|
||||
func (c Client) downloadAsset(ctx context.Context, asset github.ReleaseAsset, fileName string) (io.ReadCloser, int, error) {
|
||||
log.Logger.Debugf("asset name: %s", asset.GetName())
|
||||
if asset.GetName() != fileName {
|
||||
return nil, xerrors.New("file name doesn't match")
|
||||
return nil, 0, xerrors.New("file name doesn't match")
|
||||
}
|
||||
|
||||
rc, url, err := c.Repository.DownloadAsset(ctx, asset.GetID())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unable to download the asset: %w", err)
|
||||
return nil, 0, xerrors.Errorf("unable to download the asset: %w", err)
|
||||
}
|
||||
|
||||
if rc != nil {
|
||||
return rc, nil
|
||||
return rc, asset.GetSize(), nil
|
||||
}
|
||||
|
||||
log.Logger.Debugf("asset URL: %s", url)
|
||||
resp, err := http.Get(url)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
return nil, xerrors.Errorf("unable to download the asset via URL: %w", err)
|
||||
return nil, 0, xerrors.Errorf("unable to download the asset via URL: %w", err)
|
||||
}
|
||||
return resp.Body, nil
|
||||
|
||||
size, err := strconv.Atoi(resp.Header.Get("Content-Length"))
|
||||
if err != nil {
|
||||
return nil, 0, xerrors.Errorf("invalid size: %w", err)
|
||||
}
|
||||
return resp.Body, size, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package github
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
@@ -11,15 +12,46 @@ type MockClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockClient) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error) {
|
||||
type DownloadDBInput struct {
|
||||
FileName string
|
||||
}
|
||||
type DownloadDBOutput struct {
|
||||
FileName string
|
||||
Size int
|
||||
Err error
|
||||
}
|
||||
type DownloadDBExpectation struct {
|
||||
Args DownloadDBInput
|
||||
ReturnArgs DownloadDBOutput
|
||||
}
|
||||
|
||||
func NewMockClient(downloadDBExpectations []DownloadDBExpectation) (*MockClient, error) {
|
||||
mockDetector := new(MockClient)
|
||||
for _, e := range downloadDBExpectations {
|
||||
var rc io.ReadCloser
|
||||
if e.ReturnArgs.FileName != "" {
|
||||
f, err := os.Open(e.ReturnArgs.FileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rc = f
|
||||
}
|
||||
|
||||
mockDetector.On("DownloadDB", mock.Anything, e.Args.FileName).Return(
|
||||
rc, e.ReturnArgs.Size, e.ReturnArgs.Err)
|
||||
}
|
||||
return mockDetector, nil
|
||||
}
|
||||
|
||||
func (_m *MockClient) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error) {
|
||||
ret := _m.Called(ctx, fileName)
|
||||
ret0 := ret.Get(0)
|
||||
if ret0 == nil {
|
||||
return nil, ret.Error(1)
|
||||
return nil, ret.Int(1), ret.Error(2)
|
||||
}
|
||||
rc, ok := ret0.(io.ReadCloser)
|
||||
if !ok {
|
||||
return nil, ret.Error(1)
|
||||
return nil, ret.Int(1), ret.Error(2)
|
||||
}
|
||||
return rc, ret.Error(1)
|
||||
return rc, ret.Int(1), ret.Error(2)
|
||||
}
|
||||
|
||||
@@ -451,7 +451,7 @@ func TestClient_DownloadDB(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
rc, err := client.DownloadDB(ctx, tc.fileName)
|
||||
rc, _, err := client.DownloadDB(ctx, tc.fileName)
|
||||
|
||||
switch {
|
||||
case tc.expectedError != nil:
|
||||
|
||||
40
pkg/indicator/progress.go
Normal file
40
pkg/indicator/progress.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package indicator
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/cheggaaa/pb/v3"
|
||||
)
|
||||
|
||||
type ProgressBar struct {
|
||||
quiet bool
|
||||
}
|
||||
|
||||
func NewProgressBar(quiet bool) ProgressBar {
|
||||
return ProgressBar{quiet: quiet}
|
||||
}
|
||||
|
||||
func (p ProgressBar) Start(total int64) Bar {
|
||||
if p.quiet {
|
||||
return Bar{}
|
||||
}
|
||||
bar := pb.Full.Start64(total)
|
||||
return Bar{bar: bar}
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
bar *pb.ProgressBar
|
||||
}
|
||||
|
||||
func (b Bar) NewProxyReader(r io.Reader) io.Reader {
|
||||
if b.bar == nil {
|
||||
return r
|
||||
}
|
||||
return b.bar.NewProxyReader(r)
|
||||
}
|
||||
func (b Bar) Finish() {
|
||||
if b.bar == nil {
|
||||
return
|
||||
}
|
||||
b.bar.Finish()
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func initializeLibServer() *library.Server {
|
||||
return &library.Server{}
|
||||
}
|
||||
|
||||
func initializeDBWorker() dbWorker {
|
||||
func initializeDBWorker(quiet bool) dbWorker {
|
||||
wire.Build(SuperSet)
|
||||
return dbWorker{}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func ListenAndServe(addr string, c config.Config) error {
|
||||
}
|
||||
|
||||
go func() {
|
||||
worker := initializeDBWorker()
|
||||
worker := initializeDBWorker(true)
|
||||
ctx := context.Background()
|
||||
for {
|
||||
time.Sleep(1 * time.Hour)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
library2 "github.com/aquasecurity/trivy/pkg/detector/library"
|
||||
ospkg2 "github.com/aquasecurity/trivy/pkg/detector/ospkg"
|
||||
"github.com/aquasecurity/trivy/pkg/github"
|
||||
"github.com/aquasecurity/trivy/pkg/indicator"
|
||||
"github.com/aquasecurity/trivy/pkg/rpc/server/library"
|
||||
"github.com/aquasecurity/trivy/pkg/rpc/server/ospkg"
|
||||
"github.com/aquasecurity/trivy/pkg/vulnerability"
|
||||
@@ -36,11 +37,12 @@ func initializeLibServer() *library.Server {
|
||||
return server
|
||||
}
|
||||
|
||||
func initializeDBWorker() dbWorker {
|
||||
func initializeDBWorker(quiet bool) dbWorker {
|
||||
config := db.Config{}
|
||||
client := github.NewClient()
|
||||
progressBar := indicator.NewProgressBar(quiet)
|
||||
realClock := clock.RealClock{}
|
||||
dbClient := db2.NewClient(config, client, realClock)
|
||||
dbClient := db2.NewClient(config, client, progressBar, realClock)
|
||||
serverDbWorker := newDBWorker(dbClient)
|
||||
return serverDbWorker
|
||||
}
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
pb "gopkg.in/cheggaaa/pb.v1"
|
||||
)
|
||||
|
||||
var (
|
||||
Quiet = false
|
||||
)
|
||||
|
||||
type Spinner struct {
|
||||
client *spinner.Spinner
|
||||
}
|
||||
|
||||
func NewSpinner(suffix string) *Spinner {
|
||||
if Quiet {
|
||||
return &Spinner{}
|
||||
}
|
||||
s := spinner.New(spinner.CharSets[36], 100*time.Millisecond)
|
||||
s.Suffix = suffix
|
||||
return &Spinner{client: s}
|
||||
}
|
||||
|
||||
func (s *Spinner) Start() {
|
||||
if s.client == nil {
|
||||
return
|
||||
}
|
||||
s.client.Start()
|
||||
}
|
||||
func (s *Spinner) Stop() {
|
||||
if s.client == nil {
|
||||
return
|
||||
}
|
||||
s.client.Stop()
|
||||
}
|
||||
|
||||
// TODO: Expose an interface for progressbar
|
||||
type ProgressBar struct {
|
||||
client *pb.ProgressBar
|
||||
}
|
||||
|
||||
func PbStartNew(total int) *ProgressBar {
|
||||
if Quiet {
|
||||
return &ProgressBar{}
|
||||
}
|
||||
bar := pb.StartNew(total)
|
||||
return &ProgressBar{client: bar}
|
||||
}
|
||||
|
||||
func (p *ProgressBar) Increment() {
|
||||
if p.client == nil {
|
||||
return
|
||||
}
|
||||
p.client.Increment()
|
||||
}
|
||||
func (p *ProgressBar) Finish() {
|
||||
if p.client == nil {
|
||||
return
|
||||
}
|
||||
p.client.Finish()
|
||||
}
|
||||
Reference in New Issue
Block a user