mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-23 07:29:00 -08:00
feat: support client/server mode (#295)
* chore(app): change dir * feat(rpc): add a proto file and auto-generated files * chore(dep): add dependencies * fix(app): fix import path * fix(integration): fix import path * fix(protoc): use enum for severity * chore(Makefile): add fmt andd protoc * chore(clang): add .clang-format * refactor: split functions for client/server (#296) * refactor(db): split db.Download * refactor(standalone): create a different package * refactor(vulnerability): split FillAndFilter * fix(protoc): use enum for severity * chore(Makefile): add fmt andd protoc * chore(clang): add .clang-format * fix(db): remove an unused variable * fix(db): expose the github client as an argument of constructor * refactor(vulnerability): add the detail message * feat(rpc): add rpc client (#302) * fix(protoc): use enum for severity * chore(Makefile): add fmt andd protoc * chore(clang): add .clang-format * feat(rpc): convert types * feat(rpc): add rpc client * token: Refactor to handle bad headers being set Signed-off-by: Simarpreet Singh <simar@linux.com> * feat(rpc): add rpc server (#303) * feat(rpc): add rpc server * feat(utils): add CopyFile * feat(server/config): add config struct * feat(detector): add detector * feat(scanner): delegate procedures to detector * fix(scanner): fix the interface * test(mock): add mocks * test(rpc/server): add tests * test(rpc/ospkg/server): add tests * tets(os/detector): add tests * refactor(library): move directories * chore(dependency): add google/wire * refactor(library): introduce google/wire * refactor(ospkg/detector): move directory * feat(rpc): add eosl * refactor(ospkg): introduce google/wire * refactor(wire): bind an interface * refactor(client): use wire.Struct * chore(Makefile): fix wire * test(server): add AssertExpectations * test(server): add AssertExpectations * refactor(server): remove debug log * refactor(error): add more context messages * test(server): fix error message * refactor(test): create a constructor of mock * refactor(config): remove an unused variable * test(config): add an assertion to test the config struct * feat(client/server): add sub commands (#304) * feat(rpc): add rpc server * feat(utils): add CopyFile * feat(server/config): add config struct * feat(detector): add detector * feat(scanner): delegate procedures to detector * fix(scanner): fix the interface * feat(client/server): add sub commands * merge(server3) * test(scan): remove an unused mock * refactor(client): generate the constructor by wire * fix(cli): change the default port * fix(server): use auto-generated constructor * feat(ospkg): return eosl * test(integration): add integration tests for client/server (#306) * fix(server): remove unnecessary options * test(integration): add integration tests for client/server * fix(server): wrap an error * fix(server): change the update interval * fix(server): display the error detail * test(config): add an assertion to test the config struct * fix(client): returns an error when failing to initizlie a logger * test(ospkg/server): add eosl * Squashed commit of the following: * test(server): refactor and add tests (#307) * test(github): create a mock * test(db): create a mock * test(server): add tests for DB hot update * chore(db): add a log message * refactor(db): introduce google/wire * refactor(rpc): move directory * refactor(injector): fix import name * refactor(import): remove new lines * fix(server): display the error detail * fix(server): change the update interval * fix(server): wrap an error * test(integration): add integration tests for client/server * fix(server): remove unnecessary options * refactor(server): return an error when failing to initialize a logger * refactor(server): remove unused error * fix(client/server): fix default port * chore(README): add client/server * chore(README): update
This commit is contained in:
137
pkg/rpc/server/server.go
Normal file
137
pkg/rpc/server/server.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/wire"
|
||||
|
||||
"github.com/twitchtv/twirp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/aquasecurity/trivy-db/pkg/db"
|
||||
"github.com/aquasecurity/trivy/internal/server/config"
|
||||
dbFile "github.com/aquasecurity/trivy/pkg/db"
|
||||
"github.com/aquasecurity/trivy/pkg/log"
|
||||
"github.com/aquasecurity/trivy/pkg/utils"
|
||||
rpc "github.com/aquasecurity/trivy/rpc/detector"
|
||||
)
|
||||
|
||||
var SuperSet = wire.NewSet(
|
||||
dbFile.SuperSet,
|
||||
newDBWorker,
|
||||
)
|
||||
|
||||
func ListenAndServe(addr string, c config.Config) error {
|
||||
requestWg := &sync.WaitGroup{}
|
||||
dbUpdateWg := &sync.WaitGroup{}
|
||||
|
||||
withWaitGroup := func(base http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Stop processing requests during DB update
|
||||
dbUpdateWg.Wait()
|
||||
|
||||
// Wait for all requests to be processed before DB update
|
||||
requestWg.Add(1)
|
||||
defer requestWg.Done()
|
||||
|
||||
base.ServeHTTP(w, r)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
worker := initializeDBWorker()
|
||||
ctx := context.Background()
|
||||
for {
|
||||
time.Sleep(1 * time.Hour)
|
||||
if err := worker.update(ctx, c.AppVersion, c.CacheDir, dbUpdateWg, requestWg); err != nil {
|
||||
log.Logger.Errorf("%+v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
osHandler := rpc.NewOSDetectorServer(initializeOspkgServer(), nil)
|
||||
mux.Handle(rpc.OSDetectorPathPrefix, withToken(withWaitGroup(osHandler), c.Token))
|
||||
|
||||
libHandler := rpc.NewLibDetectorServer(initializeLibServer(), nil)
|
||||
mux.Handle(rpc.LibDetectorPathPrefix, withToken(withWaitGroup(libHandler), c.Token))
|
||||
|
||||
log.Logger.Infof("Listening %s...", addr)
|
||||
|
||||
return http.ListenAndServe(addr, mux)
|
||||
}
|
||||
|
||||
func withToken(base http.Handler, token string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if token != "" && token != r.Header.Get("Trivy-Token") {
|
||||
rpc.WriteError(w, twirp.NewError(twirp.Unauthenticated, "invalid token"))
|
||||
return
|
||||
}
|
||||
base.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
type dbWorker struct {
|
||||
dbClient dbFile.Operation
|
||||
}
|
||||
|
||||
func newDBWorker(dbClient dbFile.Operation) dbWorker {
|
||||
return dbWorker{dbClient: dbClient}
|
||||
}
|
||||
|
||||
func (w dbWorker) update(ctx context.Context, appVersion, cacheDir string,
|
||||
dbUpdateWg, requestWg *sync.WaitGroup) error {
|
||||
needsUpdate, err := w.dbClient.NeedsUpdate(ctx, appVersion, false, false)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to check if db needs an update")
|
||||
} else if !needsUpdate {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Logger.Info("Updating DB...")
|
||||
if err = w.hotUpdate(ctx, cacheDir, dbUpdateWg, requestWg); err != nil {
|
||||
return xerrors.Errorf("failed DB hot update")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w dbWorker) hotUpdate(ctx context.Context, cacheDir string, dbUpdateWg, requestWg *sync.WaitGroup) error {
|
||||
tmpDir, err := ioutil.TempDir("", "db")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
if err := w.dbClient.Download(ctx, tmpDir, false); err != nil {
|
||||
return xerrors.Errorf("failed to download vulnerability DB: %w", err)
|
||||
}
|
||||
|
||||
log.Logger.Info("Suspending all requests during DB update")
|
||||
dbUpdateWg.Add(1)
|
||||
defer dbUpdateWg.Done()
|
||||
|
||||
log.Logger.Info("Waiting for all requests to be processed before DB update...")
|
||||
requestWg.Wait()
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
return xerrors.Errorf("failed to close DB: %w", err)
|
||||
}
|
||||
|
||||
if _, err = utils.CopyFile(db.Path(tmpDir), db.Path(cacheDir)); err != nil {
|
||||
return xerrors.Errorf("failed to copy the database file: %w", err)
|
||||
}
|
||||
|
||||
log.Logger.Info("Reopening DB...")
|
||||
if err = db.Init(cacheDir); err != nil {
|
||||
return xerrors.Errorf("failed to open DB: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user