Files
trivy/pkg/notification/notice.go
2025-08-01 08:27:30 +00:00

246 lines
6.5 KiB
Go

package notification
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"runtime"
"strconv"
"strings"
"time"
"github.com/samber/lo"
"github.com/aquasecurity/go-version/pkg/semver"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/version/app"
xhttp "github.com/aquasecurity/trivy/pkg/x/http"
)
// VersionChecker holds the configuration and the result of the update
// check that RunUpdateCheck performs against the updates API.
//
// NOTE(review): done, responseReceived and latestVersion are written by the
// goroutine started in RunUpdateCheck and read by the accessor methods; no
// synchronization is visible in this file — confirm callers order these.
type VersionChecker struct {
// updatesApi is the URL RunUpdateCheck sends its GET request to.
updatesApi string
// commandName is sent as the Trivy-Command header when telemetry is enabled.
commandName string
// cliOptions supplies SkipVersionCheck, DisableTelemetry, Quiet and the
// used flags consulted by the check.
cliOptions *flag.Options
// done is set to true by RunUpdateCheck's goroutine once the response has
// been decoded; it is not set on the early-return error paths.
done bool
// responseReceived is set to true only when a response was decoded and
// neither SkipVersionCheck nor Quiet is set; it gates PrintNotices,
// LatestVersion, Announcements and Warnings.
responseReceived bool
// currentVersion is the value of app.Version() captured at construction.
currentVersion string
// latestVersion is the decoded JSON body returned by the updates API.
latestVersion updateResponse
}
// NewVersionChecker builds a VersionChecker for the given command name and
// CLI options. The checker targets the default updates API URL and records
// the value of app.Version() as the current version.
func NewVersionChecker(commandName string, cliOptions *flag.Options) *VersionChecker {
	return &VersionChecker{
		updatesApi:     "https://check.trivy.dev/updates",
		commandName:    commandName,
		cliOptions:     cliOptions,
		currentVersion: app.Version(),
	}
}
// RunUpdateCheck makes a best-effort request to determine the latest
// version and any announcements. The request runs in a background goroutine,
// so this method returns immediately; the outcome is stored on v.
//
// Logic:
//  1. if SkipVersionCheck and DisableTelemetry are both true, skip the request
//  2. if SkipVersionCheck is true and DisableTelemetry is false, run the check
//     with metric details but suppress output
//  3. if SkipVersionCheck is false and DisableTelemetry is true, run the
//     update check but don't send any metric identifiers
func (v *VersionChecker) RunUpdateCheck(ctx context.Context) {
	logger := log.WithPrefix("notification")

	if v.cliOptions.SkipVersionCheck && v.cliOptions.DisableTelemetry {
		logger.Debug("Skipping update check and metric ping")
		return
	}

	go func() {
		logger.Debug("Running version check")
		commandParts := v.getFlags()

		client := xhttp.ClientWithContext(ctx, xhttp.WithTimeout(3*time.Second))
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, v.updatesApi, http.NoBody)
		if err != nil {
			logger.Warn("Failed to create a request for Trivy api", log.Err(err))
			return
		}

		// if the user hasn't disabled metrics, send the anonymous information as headers
		if !v.cliOptions.DisableTelemetry {
			req.Header.Set("Trivy-Identifier", uniqueIdentifier())
			req.Header.Set("Trivy-Command", v.commandName)
			req.Header.Set("Trivy-Flags", commandParts)
			req.Header.Set("Trivy-OS", runtime.GOOS)
			req.Header.Set("Trivy-Arch", runtime.GOARCH)
		}
		req.Header.Set("User-Agent", fmt.Sprintf("trivy/%s", v.currentVersion))

		resp, err := client.Do(req)
		if err != nil {
			// Cancellation and deadline errors are deliberately not logged.
			if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
				logger.Debug("Failed getting response from Trivy api", log.Err(err))
			}
			return
		}
		// Register the close before inspecting the status code so the body is
		// released on every return path below, including the non-200 one.
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			logger.Debug("Unexpected status code from Trivy api", log.Int("status_code", resp.StatusCode))
			return
		}

		if err := json.NewDecoder(resp.Body).Decode(&v.latestVersion); err != nil {
			logger.Debug("Failed to decode the Trivy response", log.Err(err))
			return
		}

		// enable printing if the version check is allowed and quiet mode is not set
		if !v.cliOptions.SkipVersionCheck && !v.cliOptions.Quiet {
			v.responseReceived = true
		}
		logger.Debug("Version check completed", log.String("latest_version", v.latestVersion.Trivy.LatestVersion))
		v.done = true
	}()
}
// PrintNotices writes any warnings, applicable announcements and the
// new-version notice to output, most likely stderr.
//
// Nothing is written unless responseReceived is set, both the current and
// the latest version can be parsed, and at least one notice exists. A
// version that cannot be parsed is reported at debug level rather than
// being dropped silently.
func (v *VersionChecker) PrintNotices(ctx context.Context, output io.Writer) {
	if !v.responseReceived {
		return
	}

	logger := log.WithPrefix("notification")

	cv, err := v.CurrentVersion()
	if err != nil {
		logger.Debug("Failed to determine the current version", log.Err(err))
		return
	}

	lv, err := v.LatestVersion()
	if err != nil {
		logger.Debug("Failed to determine the latest version", log.Err(err))
		return
	}

	// Order of notices: warnings, then announcements, then the upgrade hint.
	var notices []string
	notices = append(notices, v.Warnings()...)
	for _, announcement := range v.Announcements() {
		if announcement.shouldDisplay(ctx, cv) {
			notices = append(notices, announcement.Announcement)
		}
	}
	if cv.LessThan(lv) {
		notices = append(notices, fmt.Sprintf("Version %s of Trivy is now available, current version is %s", lv, cv))
	}

	if len(notices) == 0 {
		return
	}

	logger.Debug("Printing notices")
	fmt.Fprintf(output, "\n📣 \x1b[34mNotices:\x1b[0m\n")
	for _, notice := range notices {
		fmt.Fprintf(output, " - %s\n", notice)
	}
	fmt.Fprintln(output)
	fmt.Fprintln(output, "To suppress version checks, run Trivy scans with the --skip-version-check flag")
	fmt.Fprintln(output)
}
// CurrentVersion parses the stored current version string as a semantic
// version, ignoring a leading "v". It returns a wrapped error when the
// string cannot be parsed.
func (v *VersionChecker) CurrentVersion() (semver.Version, error) {
	raw := strings.TrimPrefix(v.currentVersion, "v")

	parsed, err := semver.Parse(raw)
	if err != nil {
		return semver.Version{}, fmt.Errorf("failed to parse current version: %w", err)
	}
	return parsed, nil
}
// LatestVersion parses the latest version taken from the update response as
// a semantic version, ignoring a leading "v". It returns an error when
// responseReceived is not set or when the version string cannot be parsed.
func (v *VersionChecker) LatestVersion() (semver.Version, error) {
	if !v.responseReceived {
		return semver.Version{}, errors.New("no response received from version check")
	}

	raw := strings.TrimPrefix(v.latestVersion.Trivy.LatestVersion, "v")
	parsed, err := semver.Parse(raw)
	if err != nil {
		return semver.Version{}, fmt.Errorf("failed to parse latest version: %w", err)
	}
	return parsed, nil
}
// Announcements returns the announcements from the update response, or nil
// when responseReceived is not set.
func (v *VersionChecker) Announcements() []announcement {
	if !v.responseReceived {
		return nil
	}
	return v.latestVersion.Announcements
}
// Warnings returns the warnings from the update response, or nil when
// responseReceived is not set.
func (v *VersionChecker) Warnings() []string {
	if !v.responseReceived {
		return nil
	}
	return v.latestVersion.Warnings
}
// UnmarshalJSON decodes a JSON date value into fd.
//
// A JSON null and an empty string both leave fd unchanged and return nil;
// treating null as a no-op follows the encoding/json convention and keeps an
// absent date from failing the decode of the enclosing document. Any other
// value is stripped of surrounding double quotes and parsed with the first
// matching layout among RFC 3339, date-only ("2006-01-02") and RFC 1123.
// An error is returned when no layout matches.
func (fd *flexibleTime) UnmarshalJSON(b []byte) error {
	// Check for the bare null literal before trimming quotes so that the
	// quoted string "null" is still rejected as an invalid date.
	if string(b) == "null" {
		return nil
	}

	s := strings.Trim(string(b), `"`)
	if s == "" {
		return nil
	}

	// Layouts are tried in order; the first successful parse wins.
	layouts := []string{
		time.RFC3339,
		"2006-01-02",
		time.RFC1123,
	}
	for _, layout := range layouts {
		if t, err := time.Parse(layout, s); err == nil {
			fd.Time = t
			return nil
		}
	}
	return fmt.Errorf("unable to parse date: %s", s)
}
// getFlags renders the flags reported by cliOptions.GetUsedFlags as a single
// space-separated string of "--name=value" pairs. Flags without a name are
// skipped, and the value of any flag that is not telemetry safe is replaced
// with "***".
func (v *VersionChecker) getFlags() string {
	var rendered []string
	for _, used := range v.cliOptions.GetUsedFlags() {
		flagName := used.GetName()
		if flagName == "" {
			// Nothing meaningful to report for an unnamed flag.
			continue
		}

		shown := lo.Ternary(used.IsTelemetrySafe(), getFlagValue(used), "***")
		rendered = append(rendered, fmt.Sprintf("--%s=%s", flagName, shown))
	}
	return strings.Join(rendered, " ")
}
// getFlagValue returns the string form of a flag's value. It handles flags
// whose Value method yields a string, int, float64, bool, time.Duration or
// []string; any other flag yields "***".
func getFlagValue(f flag.Flagger) string {
	// valuer matches a flag exposing a Value accessor of type T.
	type valuer[T flag.FlagType] interface {
		Value() T
	}

	switch typed := f.(type) {
	case valuer[bool]:
		return strconv.FormatBool(typed.Value())
	case valuer[int]:
		return strconv.Itoa(typed.Value())
	case valuer[float64]:
		return fmt.Sprintf("%f", typed.Value())
	case valuer[time.Duration]:
		return typed.Value().String()
	case valuer[[]string]:
		return strings.Join(typed.Value(), ",")
	case valuer[string]:
		return typed.Value()
	}

	// Unsupported value types are masked.
	return "***"
}