feat(db): show progress when downloading the DB (#317)

* fix(github): return db size

* fix(github_mock): add size

* feat(indicator): add progress bar

* refactor(config): remove global Quiet

* fix(db): take progress bar as an argument

* fix(progress): inject progress bar
This commit is contained in:
Teppei Fukuda
2019-12-16 19:23:08 +02:00
committed by GitHub
parent bc8f613ba6
commit cee08c38f4
20 changed files with 163 additions and 159 deletions

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"os"
"sort"
"strconv"
"strings"
"github.com/aquasecurity/trivy-db/pkg/db"
@@ -42,7 +43,7 @@ func (r Repository) DownloadAsset(ctx context.Context, id int64) (io.ReadCloser,
}
type Operation interface {
DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error)
DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error)
}
type Client struct {
@@ -72,11 +73,11 @@ func NewClient() Client {
}
}
func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, error) {
func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser, int, error) {
options := github.ListOptions{}
releases, _, err := c.Repository.ListReleases(ctx, &options)
if err != nil {
return nil, xerrors.Errorf("failed to list releases: %w", err)
return nil, 0, xerrors.Errorf("failed to list releases: %w", err)
}
sort.Slice(releases, func(i, j int) bool {
@@ -91,37 +92,42 @@ func (c Client) DownloadDB(ctx context.Context, fileName string) (io.ReadCloser,
}
for _, asset := range release.Assets {
rc, err := c.downloadAsset(ctx, asset, fileName)
rc, size, err := c.downloadAsset(ctx, asset, fileName)
if err != nil {
log.Logger.Debug(err)
continue
}
return rc, nil
return rc, size, nil
}
}
return nil, xerrors.New("DB file not found")
return nil, 0, xerrors.New("DB file not found")
}
func (c Client) downloadAsset(ctx context.Context, asset github.ReleaseAsset, fileName string) (io.ReadCloser, error) {
func (c Client) downloadAsset(ctx context.Context, asset github.ReleaseAsset, fileName string) (io.ReadCloser, int, error) {
log.Logger.Debugf("asset name: %s", asset.GetName())
if asset.GetName() != fileName {
return nil, xerrors.New("file name doesn't match")
return nil, 0, xerrors.New("file name doesn't match")
}
rc, url, err := c.Repository.DownloadAsset(ctx, asset.GetID())
if err != nil {
return nil, xerrors.Errorf("unable to download the asset: %w", err)
return nil, 0, xerrors.Errorf("unable to download the asset: %w", err)
}
if rc != nil {
return rc, nil
return rc, asset.GetSize(), nil
}
log.Logger.Debugf("asset URL: %s", url)
resp, err := http.Get(url)
if err != nil || resp.StatusCode != http.StatusOK {
return nil, xerrors.Errorf("unable to download the asset via URL: %w", err)
return nil, 0, xerrors.Errorf("unable to download the asset via URL: %w", err)
}
return resp.Body, nil
size, err := strconv.Atoi(resp.Header.Get("Content-Length"))
if err != nil {
return nil, 0, xerrors.Errorf("invalid size: %w", err)
}
return resp.Body, size, nil
}