mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-05 20:40:16 -08:00
Signed-off-by: knqyf263 <knqyf263@gmail.com> Co-authored-by: knqyf263 <knqyf263@users.noreply.github.com>
193 lines
5.0 KiB
Go
193 lines
5.0 KiB
Go
package downloader
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"errors"
|
|
"maps"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/go-github/v62/github"
|
|
"github.com/hashicorp/go-getter"
|
|
"github.com/samber/lo"
|
|
"golang.org/x/xerrors"
|
|
|
|
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
|
|
xos "github.com/aquasecurity/trivy/pkg/x/os"
|
|
)
|
|
|
|
var (
	// ErrSkipDownload is returned by CustomTransport.RoundTrip when the server
	// answers 304 Not Modified to a request carrying the cached ETag.
	ErrSkipDownload = errors.New("skip download")
)
|
|
|
|
type Options struct {
|
|
Insecure bool
|
|
Auth Auth
|
|
ETag string
|
|
ClientMode getter.ClientMode
|
|
}
|
|
|
|
// Auth holds the credentials attached to HTTP(S) download requests.
// A non-empty Token takes precedence and is sent as a Bearer token;
// otherwise Username/Password are sent as HTTP basic auth when either is set.
type Auth struct {
	// Username and Password are used for HTTP basic auth.
	Username, Password string

	// Token is sent as "Authorization: Bearer <Token>".
	Token string
}
|
|
|
|
// DownloadToTempDir downloads the configured source to a temp dir.
|
|
func DownloadToTempDir(ctx context.Context, src string, opts Options) (string, error) {
|
|
tempDir, err := xos.MkdirTemp("", "download-")
|
|
if err != nil {
|
|
return "", xerrors.Errorf("failed to create a temp dir: %w", err)
|
|
}
|
|
|
|
pwd, err := os.Getwd()
|
|
if err != nil {
|
|
return "", xerrors.Errorf("unable to get the current dir: %w", err)
|
|
}
|
|
|
|
if _, err = Download(ctx, src, tempDir, pwd, opts); err != nil {
|
|
return "", xerrors.Errorf("download error: %w", err)
|
|
}
|
|
|
|
return tempDir, nil
|
|
}
|
|
|
|
// Download downloads the configured source to the destination.
|
|
func Download(ctx context.Context, src, dst, pwd string, opts Options) (string, error) {
|
|
// go-getter doesn't allow the dst directory already exists if the src is directory.
|
|
_ = os.RemoveAll(dst)
|
|
|
|
var clientOpts []getter.ClientOption
|
|
if opts.Insecure {
|
|
clientOpts = append(clientOpts, getter.WithInsecure())
|
|
}
|
|
|
|
// Clone the global map so that it will not be accessed concurrently.
|
|
getters := maps.Clone(getter.Getters)
|
|
|
|
// Overwrite the file getter so that a file will be copied
|
|
getters["file"] = &getter.FileGetter{Copy: true}
|
|
|
|
// Since "httpGetter" is a global pointer and the state is shared,
|
|
// once it is executed without "WithInsecure()",
|
|
// it cannot enable WithInsecure() afterwards because its state is preserved.
|
|
// Therefore, we need to create a new "HttpGetter" instance every time.
|
|
// cf. https://github.com/hashicorp/go-getter/blob/5a63fd9c0d5b8da8a6805e8c283f46f0dacb30b3/get.go#L63-L65
|
|
transport := NewCustomTransport(opts)
|
|
httpGetter := &getter.HttpGetter{
|
|
Netrc: true,
|
|
Client: &http.Client{
|
|
Transport: transport,
|
|
Timeout: time.Minute * 5,
|
|
},
|
|
}
|
|
getters["http"] = httpGetter
|
|
getters["https"] = httpGetter
|
|
|
|
// Build the client
|
|
client := &getter.Client{
|
|
Ctx: ctx,
|
|
Src: src,
|
|
Dst: dst,
|
|
Pwd: pwd,
|
|
Getters: getters,
|
|
Mode: lo.Ternary(opts.ClientMode == 0, getter.ClientModeAny, opts.ClientMode),
|
|
Options: clientOpts,
|
|
}
|
|
|
|
if err := client.Get(); err != nil {
|
|
return "", xerrors.Errorf("failed to download %s: %w", src, err)
|
|
}
|
|
|
|
return transport.newETag, nil
|
|
}
|
|
|
|
type CustomTransport struct {
|
|
auth Auth
|
|
cachedETag string
|
|
newETag string
|
|
}
|
|
|
|
func NewCustomTransport(opts Options) *CustomTransport {
|
|
return &CustomTransport{
|
|
auth: opts.Auth,
|
|
cachedETag: opts.ETag,
|
|
}
|
|
}
|
|
|
|
func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
if t.cachedETag != "" {
|
|
req.Header.Set("If-None-Match", t.cachedETag)
|
|
}
|
|
if t.auth.Token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+t.auth.Token)
|
|
} else if t.auth.Username != "" || t.auth.Password != "" {
|
|
req.SetBasicAuth(t.auth.Username, t.auth.Password)
|
|
}
|
|
|
|
transport := xhttp.Transport(req.Context())
|
|
if req.URL.Host == "github.com" {
|
|
transport = NewGitHubTransport(req.URL, t.auth.Token)
|
|
}
|
|
|
|
res, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to round trip: %w", err)
|
|
}
|
|
|
|
switch res.StatusCode {
|
|
case http.StatusOK, http.StatusPartialContent:
|
|
// Update the ETag
|
|
t.newETag = res.Header.Get("ETag")
|
|
case http.StatusNotModified:
|
|
return nil, ErrSkipDownload
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func NewGitHubTransport(u *url.URL, token string) http.RoundTripper {
|
|
client := newGitHubClient(token)
|
|
ss := strings.SplitN(u.Path, "/", 4)
|
|
if len(ss) < 4 || strings.HasPrefix(ss[3], "archive/") || strings.HasPrefix(ss[3], "releases/") ||
|
|
strings.HasPrefix(ss[3], "tags/") {
|
|
// Use the default transport from go-github for authentication
|
|
return client.Client().Transport
|
|
}
|
|
|
|
return &GitHubContentTransport{
|
|
owner: ss[1],
|
|
repo: ss[2],
|
|
filePath: ss[3],
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
// GitHubContentTransport is a round tripper for downloading the GitHub content.
|
|
type GitHubContentTransport struct {
|
|
owner string
|
|
repo string
|
|
filePath string
|
|
client *github.Client
|
|
}
|
|
|
|
// RoundTrip calls the GitHub API to download the content.
|
|
func (t *GitHubContentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
_, res, err := t.client.Repositories.DownloadContents(req.Context(), t.owner, t.repo, t.filePath, nil)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to get the file content: %w", err)
|
|
}
|
|
return res.Response, nil
|
|
}
|
|
|
|
func newGitHubClient(token string) *github.Client {
|
|
client := github.NewClient(xhttp.Client())
|
|
token = cmp.Or(token, os.Getenv("GITHUB_TOKEN"))
|
|
if token != "" {
|
|
client = client.WithAuthToken(token)
|
|
}
|
|
return client
|
|
}
|