diff --git a/pkg/db/db.go b/pkg/db/db.go index 4bf5006775..7811b5a8af 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strconv" "github.com/aquasecurity/trivy/pkg/log" @@ -14,13 +15,17 @@ import ( bolt "github.com/etcd-io/bbolt" ) +const ( + SchemaVersion = 1 +) + var ( db *bolt.DB dbDir string ) type Operations interface { - SetVersion(string) error + SetVersion(int) error Update(string, string, string, interface{}) error BatchUpdate(func(*bolt.Tx) error) error PutNestedBucket(*bolt.Tx, string, string, string, interface{}) error @@ -67,19 +72,22 @@ func Reset() error { return nil } -func GetVersion() string { - var version string +func GetVersion() int { value, err := Get("trivy", "metadata", "version") - if err != nil { - return "" + if err != nil || len(value) == 0 { + // initial run + return 0 } - if err = json.Unmarshal(value, &version); err != nil { - return "" + + version, err := strconv.Atoi(string(value)) + if err != nil { + // old trivy version + return 1 } return version } -func (dbc Config) SetVersion(version string) error { +func (dbc Config) SetVersion(version int) error { err := dbc.Update("trivy", "metadata", "version", version) if err != nil { return xerrors.Errorf("failed to save DB version: %w", err) diff --git a/pkg/db/db_mock.go b/pkg/db/db_mock.go index c6e0bbba66..2a3f770a64 100644 --- a/pkg/db/db_mock.go +++ b/pkg/db/db_mock.go @@ -9,7 +9,7 @@ type MockDBConfig struct { mock.Mock } -func (_m *MockDBConfig) SetVersion(version string) error { +func (_m *MockDBConfig) SetVersion(version int) error { ret := _m.Called(version) return ret.Error(0) } diff --git a/pkg/git/git.go b/pkg/git/git.go index e3d45c2839..14951b015b 100644 --- a/pkg/git/git.go +++ b/pkg/git/git.go @@ -57,7 +57,7 @@ func CloneOrPull(url, repoPath string) (map[string]struct{}, error) { } // Need to refresh all vulnerabilities - if db.GetVersion() == "" { + if db.GetVersion() == 0 { err = filepath.Walk(repoPath, func(path string, info os.FileInfo, err error) error { if err != nil { return err diff --git a/pkg/run.go b/pkg/run.go index 460741793d..93e41256a1 100644 --- a/pkg/run.go +++ b/pkg/run.go @@ -21,8 +21,6 @@ import ( ) func Run(c *cli.Context) (err error) { - cliVersion := c.App.Version - if c.Bool("quiet") || c.Bool("no-progress") { utils.Quiet = true } @@ -87,7 +85,7 @@ func Run(c *cli.Context) (err error) { needRefresh := false dbVersion := db.GetVersion() - if dbVersion != "" && dbVersion != cliVersion { + if 0 < dbVersion && dbVersion < db.SchemaVersion { if !refresh && !autoRefresh { return xerrors.New("Detected version update of trivy. Please try again with --refresh or --auto-refresh option") } @@ -114,7 +112,7 @@ func Run(c *cli.Context) (err error) { } dbc := db.Config{} - if err = dbc.SetVersion(cliVersion); err != nil { + if err = dbc.SetVersion(db.SchemaVersion); err != nil { return xerrors.Errorf("unexpected error: %w", err) } diff --git a/pkg/vulnsrc/vulnsrc_test.go b/pkg/vulnsrc/vulnsrc_test.go index 58517e5f8f..bea5e75dc1 100644 --- a/pkg/vulnsrc/vulnsrc_test.go +++ b/pkg/vulnsrc/vulnsrc_test.go @@ -28,7 +28,7 @@ func BenchmarkUpdate(b *testing.B) { b.Run("NVD", func(b *testing.B) { dbc := db.Config{} for i := 0; i < b.N; i++ { - if err := dbc.SetVersion(""); err != nil { + if err := dbc.SetVersion(db.SchemaVersion); err != nil { b.Fatal(err) } if err := Update([]string{vulnerability.Nvd}); err != nil {