From 672d4edd921663aeff02da93940ade6e6811296a Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Mon, 9 Aug 2021 22:46:40 +0600 Subject: [PATCH 01/61] feat(api): add API draft --- postee/api.go | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 postee/api.go diff --git a/postee/api.go b/postee/api.go new file mode 100644 index 00000000..809920ec --- /dev/null +++ b/postee/api.go @@ -0,0 +1,40 @@ +package postee + +import ( + "github.com/aquasecurity/postee/router" + "github.com/aquasecurity/postee/routes" +) + +func SetAquaServerUrl(aquaServerUrl string) { //optional +} + +func SetDBMaxSize(dbMaxSize int) { //optional +} + +func SetDBTestInterval(dbTestInterval int) { //optional +} + +func SetDBRemoveOldData(dbRemoveOldData int) { //optional +} + +/* do we need bolt db at all in API mode? */ + +func AddOutput(output *router.OutputSettings) error { //is ok to pass structure as input? + return nil +} + +func AddRoute(route *routes.InputRoute) error { //same question as above + return nil +} + +func AddTemplate(template *router.Template) error { //same question as above + return nil +} + +func Send(data []byte) { + //just put data into queue. No error returned +} + +func ResetConfig() error { + return nil +} From 1a6b8dd93007549ed4d5d0e2cea76d25ba5b5477 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Thu, 12 Aug 2021 16:02:06 +0600 Subject: [PATCH 02/61] feat(api): add API draft more cleanup --- main.go | 2 +- postee/api.go | 40 -------------- router/api.go | 103 +++++++++++++++++++++++++++++++++++ router/loads_test.go | 4 +- router/routehandling_test.go | 6 +- router/router.go | 42 +++++++++----- 6 files changed, 138 insertions(+), 59 deletions(-) delete mode 100644 postee/api.go create mode 100644 router/api.go diff --git a/main.go b/main.go index c75d7d32..63b20c62 100644 --- a/main.go +++ b/main.go @@ -73,7 +73,7 @@ func main() { cfgfile = os.Getenv("POSTEE_CFG") } - err := router.Instance().Start(cfgfile) + err := router.Instance().ApplyFileCfg(cfgfile) if err != nil { log.Printf("Can't start alert manager %v", err) return diff --git a/postee/api.go b/postee/api.go deleted file mode 100644 index 809920ec..00000000 --- a/postee/api.go +++ /dev/null @@ -1,40 +0,0 @@ -package postee - -import ( - "github.com/aquasecurity/postee/router" - "github.com/aquasecurity/postee/routes" -) - -func SetAquaServerUrl(aquaServerUrl string) { //optional -} - -func SetDBMaxSize(dbMaxSize int) { //optional -} - -func SetDBTestInterval(dbTestInterval int) { //optional -} - -func SetDBRemoveOldData(dbRemoveOldData int) { //optional -} - -/* do we need bolt db at all in API mode? */ - -func AddOutput(output *router.OutputSettings) error { //is ok to pass structure as input? - return nil -} - -func AddRoute(route *routes.InputRoute) error { //same question as above - return nil -} - -func AddTemplate(template *router.Template) error { //same question as above - return nil -} - -func Send(data []byte) { - //just put data into queue. No error returned -} - -func ResetConfig() error { - return nil -} diff --git a/router/api.go b/router/api.go new file mode 100644 index 00000000..db546c0d --- /dev/null +++ b/router/api.go @@ -0,0 +1,103 @@ +package router + +import ( + "github.com/aquasecurity/postee/routes" +) + +const ( + defaultConfigPath = "config/cfg.yaml" +) + +/* +Is it possible to add a callback func to the route "input", this callback func, will be called when evaluating the input "rego" +and if the callback func returns "false" then the evaluation will fail and the message is not sent. + +we want to add this as we want the consumer to be able to add a code for extending the "input" evaluation. +when adding a route, the callback function will be part of each route +func InputCallBack(inputMessage) (bool, error) +*/ + +func WithDefaultConfig() error { + return WithFileConfig(defaultConfigPath) +} +func WithFileConfig(path string) error { + Instance().Terminate() + return Instance().ApplyFileCfg(path) +} +func WithNewConfig(name string) { //tenant name + Instance().Terminate() + Instance().NewConfig() + +} + +func AquaServerUrl(aquaServerUrl string) { //optional + Instance().setAquaServerUrl(aquaServerUrl) +} + +func DBMaxSize(dbMaxSize int) { //optional +} + +func DBTestInterval(dbTestInterval int) { //optional +} + +func DBRemoveOldData(dbRemoveOldData int) { //optional +} + +//------------------Outputs------------------- +func AddOutput(output *OutputSettings) error { + return nil +} +func UpdateOutput(output *OutputSettings) error { + return nil +} +func ListOutputs() ([]OutputSettings, error) { + //should return clones of objects + return make([]OutputSettings, 0), nil +} + +func DeleteOutput(name string) error { + return nil +} + +//----------------------------------------------- + +//------------------Routes-------------------- +func AddRoute(route *routes.InputRoute) error { + return nil +} + +func DeleteRoute(name string) error { + return nil +} +func ListRoutes() ([]routes.InputRoute, error) { + //should return clones of objects + return make([]routes.InputRoute, 0), nil +} +func UpdateRoute(*routes.InputRoute) error { + return nil +} + +//----------------------------------------------- + +//-------------------Templates------------------- +func AddTemplate(template *Template) error { + return nil +} +func UpdateTemplate(template *Template) error { + return nil +} + +func DeleteTemplate(name string) error { + return nil +} + +func ListTemplates() ([]Template, error) { + //should return clones of objects + return make([]Template, 0), nil +} + +//----------------------------------------------- + +func Send(data []byte) { + //just put data into queue. No error returned +} diff --git a/router/loads_test.go b/router/loads_test.go index 2a892d5d..7ca1dc51 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -141,7 +141,7 @@ func TestLoads(t *testing.T) { defer wrap.teardown() demoCtx := wrap.instance - demoCtx.Start(wrap.cfgPath) + demoCtx.ApplyFileCfg(wrap.cfgPath) expectedOutputsCnt := 2 if len(demoCtx.outputs) != expectedOutputsCnt { @@ -179,7 +179,7 @@ func TestReload(t *testing.T) { defer wrap.teardown() demoCtx := wrap.instance - demoCtx.Start(wrap.cfgPath) + demoCtx.ApplyFileCfg(wrap.cfgPath) expectedOutputsCnt := 2 if len(demoCtx.outputs) != expectedOutputsCnt { diff --git a/router/routehandling_test.go b/router/routehandling_test.go index 6080dfb6..879146e3 100644 --- a/router/routehandling_test.go +++ b/router/routehandling_test.go @@ -257,7 +257,7 @@ func runTestRouteHandlingCase(t *testing.T, caseDesc string, cfg string, expctdI defer wrap.teardown() - err := wrap.instance.Start(wrap.cfgPath) + err := wrap.instance.ApplyFileCfg(wrap.cfgPath) if err != nil { t.Fatalf("[%s] Unexpected error %v", caseDesc, err) } @@ -308,7 +308,7 @@ func TestInvalidRouteName(t *testing.T) { defer wrap.teardown() - err := wrap.instance.Start(wrap.cfgPath) + err := wrap.instance.ApplyFileCfg(wrap.cfgPath) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -337,7 +337,7 @@ func TestSend(t *testing.T) { defer wrap.teardown() - err := wrap.instance.Start(wrap.cfgPath) + err := wrap.instance.ApplyFileCfg(wrap.cfgPath) if err != nil { t.Fatalf("Unexpected error %v", err) } diff --git a/router/router.go b/router/router.go index aa6c8be7..fe9f1924 100644 --- a/router/router.go +++ b/router/router.go @@ -70,21 +70,29 @@ func Instance() *Router { } func (ctx *Router) ReloadConfig() { ctx.Terminate() - err := ctx.Start(ctx.cfgfile) + err := ctx.ApplyFileCfg(ctx.cfgfile) if err != nil { log.Printf("Unable to start router: %s", err) } } - -func (ctx *Router) Start(cfgfile string) error { - log.Printf("Starting Router....") - - ctx.cfgfile = cfgfile +func (ctx *Router) resetCfg() { ctx.outputs = map[string]outputs.Output{} ctx.inputRoutes = map[string]*routes.InputRoute{} ctx.templates = map[string]data.Inpteval{} ctx.ticker = nil +} +func (ctx *Router) NewConfig() { + ctx.resetCfg() + go ctx.listen() +} + +func (ctx *Router) ApplyFileCfg(cfgfile string) error { + log.Printf("Starting Router....") + + ctx.cfgfile = cfgfile + + ctx.resetCfg() err := ctx.load() if err != nil { @@ -180,6 +188,16 @@ func (ctx *Router) initTemplate(template *Template) error { } return nil } +func (ctx *Router) setAquaServerUrl(url string) { + if len(url) > 0 { + var slash string + if !strings.HasSuffix(url, "/") { + slash = "/" + } + ctx.aquaServer = fmt.Sprintf("%s%s#/images/", url, slash) + } + +} func (ctx *Router) load() error { ctx.mutexScan.Lock() @@ -191,13 +209,9 @@ func (ctx *Router) load() error { return err } - if len(tenant.AquaServer) > 0 { - var slash string - if !strings.HasSuffix(tenant.AquaServer, "/") { - slash = "/" - } - ctx.aquaServer = fmt.Sprintf("%s%s#/images/", tenant.AquaServer, slash) - } + ctx.setAquaServerUrl(tenant.AquaServer) + //---------------------------------------------------- + // TODO there should be some other way of doing that dbservice.DbSizeLimit = tenant.DBMaxSize dbservice.DbDueDate = tenant.DBRemoveOldData @@ -219,6 +233,8 @@ func (ctx *Router) load() error { }() } + //---------------------------------------------------- + for i, r := range tenant.InputRoutes { ctx.inputRoutes[r.Name] = routes.ConfigureAggrTimeout(&tenant.InputRoutes[i]) } From 7964586b455b5adaa2fdebe0fad4076c560febea Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Sat, 14 Aug 2021 17:23:57 +0600 Subject: [PATCH 03/61] added output api implementation --- {router => data}/integrations.go | 2 +- {router => data}/template.go | 2 +- {router => data}/tenants.go | 2 +- data/utils.go | 6 +++ data/utils_test.go | 10 ++++ go.mod | 2 + go.sum | 5 +- msgservice/msgservice_mocks_test.go | 4 ++ outputs/email.go | 16 ++++++ outputs/jira.go | 23 +++++++++ outputs/plugin.go | 2 + outputs/servicenow.go | 13 +++++ outputs/slack.go | 9 ++++ outputs/splunk.go | 11 +++++ outputs/teams.go | 10 ++++ outputs/webhook.go | 10 ++++ router/anonymizeSettings_test.go | 26 +++++----- router/anonymizer.go | 8 ++- router/api.go | 24 ++++----- router/apit_test.go | 71 +++++++++++++++++++++++++++ router/builders.go | 15 +++--- router/initoutputs_test.go | 26 +++++----- router/inittemplate_test.go | 18 ++++--- router/loads_test.go | 14 ++++-- router/parsecfg.go | 11 +++-- router/router.go | 76 +++++++++++++++++++++++------ 26 files changed, 335 insertions(+), 81 deletions(-) rename {router => data}/integrations.go (99%) rename {router => data}/template.go (94%) rename {router => data}/tenants.go (97%) create mode 100644 router/apit_test.go diff --git a/router/integrations.go b/data/integrations.go similarity index 99% rename from router/integrations.go rename to data/integrations.go index 10b2a477..323bb319 100644 --- a/router/integrations.go +++ b/data/integrations.go @@ -1,4 +1,4 @@ -package router +package data type OutputSettings struct { Name string `json:"name,omitempty"` diff --git a/router/template.go b/data/template.go similarity index 94% rename from router/template.go rename to data/template.go index 3e3a4c46..5a412a7e 100644 --- a/router/template.go +++ b/data/template.go @@ -1,4 +1,4 @@ -package router +package data type Template struct { Name string `json:"name"` diff --git a/router/tenants.go b/data/tenants.go similarity index 97% rename from router/tenants.go rename to data/tenants.go index 1c2881ca..ca6f696c 100644 --- a/router/tenants.go +++ b/data/tenants.go @@ -1,4 +1,4 @@ -package router +package data import ( "github.com/aquasecurity/postee/routes" diff --git a/data/utils.go b/data/utils.go index 718dc331..36d3ad6e 100644 --- a/data/utils.go +++ b/data/utils.go @@ -8,3 +8,9 @@ func ClearField(source string) string { re := regexp.MustCompile(`[[:cntrl:]]|[\x{FFFD}]`) return re.ReplaceAllString(source, "") } + +func CopyStringArray(src []string) []string { + dst := make([]string, len(src)) + copy(dst, src) + return dst +} diff --git a/data/utils_test.go b/data/utils_test.go index 0b507172..804adc0f 100644 --- a/data/utils_test.go +++ b/data/utils_test.go @@ -2,6 +2,8 @@ package data import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestClearField(t *testing.T) { @@ -19,3 +21,11 @@ func TestClearField(t *testing.T) { } } } + +func TestCopyStringArray(t *testing.T) { + src := []string{"a", "b", "c"} + dst := CopyStringArray(src) + dst[0] = "x" + assert.Equal(t, "a", src[0], "TestCopyStringArray") + assert.Equal(t, "x", dst[0], "TestCopyStringArray") +} diff --git a/go.mod b/go.mod index f0dc9d9e..b041d616 100644 --- a/go.mod +++ b/go.mod @@ -8,5 +8,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/open-policy-agent/opa v0.27.1 github.com/spf13/cobra v1.1.3 + github.com/stretchr/testify v1.6.1 go.etcd.io/bbolt v1.3.5 + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 ) diff --git a/go.sum b/go.sum index 210793ff..19cde75a 100644 --- a/go.sum +++ b/go.sum @@ -333,8 +333,9 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= @@ -552,6 +553,8 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/msgservice/msgservice_mocks_test.go b/msgservice/msgservice_mocks_test.go index 1d57dec4..2be598eb 100644 --- a/msgservice/msgservice_mocks_test.go +++ b/msgservice/msgservice_mocks_test.go @@ -5,6 +5,7 @@ import ( "strings" "sync" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" ) @@ -71,6 +72,9 @@ type DemoEmailOutput struct { func (plg *DemoEmailOutput) GetName() string { return "demo" } +func (plg *DemoEmailOutput) CloneSettings() *data.OutputSettings { + return nil +} func (plg *DemoEmailOutput) getEmailsCount() int { plg.mu.Lock() diff --git a/outputs/email.go b/outputs/email.go index 211e04d0..98ecdb7e 100644 --- a/outputs/email.go +++ b/outputs/email.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" ) @@ -32,6 +33,21 @@ func (email *EmailOutput) GetName() string { return email.Name } +func (email *EmailOutput) CloneSettings() *data.OutputSettings { + return &data.OutputSettings{ + Name: email.Name, + User: email.User, + //password is omitted + Host: email.Host, + Port: email.Port, + Sender: email.Sender, + UseMX: email.UseMX, + Recipients: data.CopyStringArray(email.Recipients), + Enable: true, + Type: "email", + } +} + func (email *EmailOutput) Init() error { log.Printf("Starting Email output %q...", email.Name) if email.Sender == "" { diff --git a/outputs/jira.go b/outputs/jira.go index dae6f7e2..e738aff9 100644 --- a/outputs/jira.go +++ b/outputs/jira.go @@ -8,6 +8,7 @@ import ( "log" "strconv" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" @@ -45,6 +46,28 @@ func (ctx *JiraAPI) GetName() string { return ctx.Name } +func (ctx *JiraAPI) CloneSettings() *data.OutputSettings { + return &data.OutputSettings{ + Name: ctx.Name, + Url: ctx.Url, + User: ctx.User, + //password is omitted + TlsVerify: ctx.TlsVerify, + IssueType: ctx.Issuetype, + ProjectKey: ctx.ProjectKey, + Priority: ctx.Priority, + Assignee: data.CopyStringArray(ctx.Assignee), + Summary: ctx.Summary, + Sprint: ctx.SprintName, + FixVersions: data.CopyStringArray(ctx.FixVersions), + AffectsVersions: data.CopyStringArray(ctx.AffectsVersions), + Labels: data.CopyStringArray(ctx.Labels), + //TODO Unknowns + Enable: true, + Type: "Jira", + } +} + func (ctx *JiraAPI) fetchBoardId(boardName string) { client, err := ctx.createClient() if err != nil { diff --git a/outputs/plugin.go b/outputs/plugin.go index 01532010..b5c8d2d4 100644 --- a/outputs/plugin.go +++ b/outputs/plugin.go @@ -5,6 +5,7 @@ import ( "log" "strings" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/layout" ) @@ -18,6 +19,7 @@ type Output interface { Send(map[string]string) error Terminate() error GetLayoutProvider() layout.LayoutProvider + CloneSettings() *data.OutputSettings //TODO shouldn't return reference } func getHandledRecipients(recipients []string, content *map[string]string, outputName string) []string { diff --git a/outputs/servicenow.go b/outputs/servicenow.go index 2cddefc3..81a21e6a 100644 --- a/outputs/servicenow.go +++ b/outputs/servicenow.go @@ -4,6 +4,7 @@ import ( "encoding/json" "log" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" servicenow "github.com/aquasecurity/postee/servicenow" @@ -22,6 +23,18 @@ func (sn *ServiceNowOutput) GetName() string { return sn.Name } +func (sn *ServiceNowOutput) CloneSettings() *data.OutputSettings { + return &data.OutputSettings{ + Name: sn.Name, + User: sn.User, + //password + InstanceName: sn.Instance, + BoardName: sn.Table, + Enable: true, + Type: "serviceNow", + } +} + func (sn *ServiceNowOutput) Init() error { log.Printf("Starting ServiceNow output %q....", sn.Name) log.Printf("Your ServiceNow Table is %q on '%s.%s'", sn.Table, sn.Instance, servicenow.BaseServer) diff --git a/outputs/slack.go b/outputs/slack.go index b60ef2d5..b745c62d 100644 --- a/outputs/slack.go +++ b/outputs/slack.go @@ -28,6 +28,15 @@ func (slack *SlackOutput) GetName() string { return slack.Name } +func (slack *SlackOutput) CloneSettings() *data.OutputSettings { + return &data.OutputSettings{ + Name: slack.Name, + Url: slack.Url, + Enable: true, + Type: "slack", + } +} + func (slack *SlackOutput) Init() error { slack.slackLayout = new(formatting.SlackMrkdwnProvider) log.Printf("Starting Slack output %q....", slack.Name) diff --git a/outputs/splunk.go b/outputs/splunk.go index 0bf49d04..667534d1 100644 --- a/outputs/splunk.go +++ b/outputs/splunk.go @@ -29,6 +29,17 @@ func (splunk *SplunkOutput) GetName() string { return splunk.Name } +func (splunk *SplunkOutput) CloneSettings() *data.OutputSettings { + return &data.OutputSettings{ + Name: splunk.Name, + Url: splunk.Url, + Token: splunk.Token, + SizeLimit: splunk.EventLimit, + Enable: true, + Type: "splunk", + } +} + func (splunk *SplunkOutput) Init() error { splunk.splunkLayout = new(formatting.HtmlProvider) log.Printf("Starting Splunk output %q....", splunk.Name) diff --git a/outputs/teams.go b/outputs/teams.go index d10fd561..c9fa2acb 100644 --- a/outputs/teams.go +++ b/outputs/teams.go @@ -4,6 +4,7 @@ import ( "encoding/json" "log" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" "github.com/aquasecurity/postee/utils" @@ -26,6 +27,15 @@ func (teams *TeamsOutput) GetName() string { return teams.Name } +func (teams *TeamsOutput) CloneSettings() *data.OutputSettings { + return &data.OutputSettings{ + Name: teams.Name, + Url: teams.Webhook, + Enable: true, + Type: "teams", + } +} + func (teams *TeamsOutput) Init() error { log.Printf("Starting MS Teams output %q....", teams.Name) teams.teamsLayout = new(formatting.HtmlProvider) diff --git a/outputs/webhook.go b/outputs/webhook.go index 0e8a1ea7..ecd856a7 100644 --- a/outputs/webhook.go +++ b/outputs/webhook.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" ) @@ -20,6 +21,15 @@ func (webhook *WebhookOutput) GetName() string { return webhook.Name } +func (webhook *WebhookOutput) CloneSettings() *data.OutputSettings { + return &data.OutputSettings{ + Name: webhook.Name, + Url: webhook.Url, + Enable: true, + Type: "webhook", + } +} + func (webhook *WebhookOutput) Init() error { log.Printf("Starting Webhook output %q, for sending to %q", webhook.Name, webhook.Url) diff --git a/router/anonymizeSettings_test.go b/router/anonymizeSettings_test.go index eb18e48d..4bea1092 100644 --- a/router/anonymizeSettings_test.go +++ b/router/anonymizeSettings_test.go @@ -1,37 +1,41 @@ package router -import "testing" +import ( + "testing" + + "github.com/aquasecurity/postee/data" +) func TestAnonymizeSettings(t *testing.T) { tests := []struct { - original *OutputSettings - expected *OutputSettings + original *data.OutputSettings + expected *data.OutputSettings }{{ - &OutputSettings{ + &data.OutputSettings{ User: "admin", }, - &OutputSettings{ + &data.OutputSettings{ User: "", }, }, { - &OutputSettings{ + &data.OutputSettings{ User: "", }, - &OutputSettings{ + &data.OutputSettings{ User: "", }, }, { - &OutputSettings{ + &data.OutputSettings{ Password: "secret", }, - &OutputSettings{ + &data.OutputSettings{ Password: "", }, }, { - &OutputSettings{ + &data.OutputSettings{ Url: "http://localhost", }, - &OutputSettings{ + &data.OutputSettings{ Url: "", }, }, diff --git a/router/anonymizer.go b/router/anonymizer.go index d8f09772..1e8cd24c 100644 --- a/router/anonymizer.go +++ b/router/anonymizer.go @@ -1,8 +1,12 @@ package router -import "reflect" +import ( + "reflect" -func anonymizeSettings(settings *OutputSettings) *OutputSettings { + "github.com/aquasecurity/postee/data" +) + +func anonymizeSettings(settings *data.OutputSettings) *data.OutputSettings { fieldsToAnonymize := [...]string{ "User", "Password", diff --git a/router/api.go b/router/api.go index db546c0d..933ee169 100644 --- a/router/api.go +++ b/router/api.go @@ -1,6 +1,7 @@ package router import ( + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/routes" ) @@ -44,19 +45,20 @@ func DBRemoveOldData(dbRemoveOldData int) { //optional } //------------------Outputs------------------- -func AddOutput(output *OutputSettings) error { - return nil +func AddOutput(output *data.OutputSettings) error { + return Instance().addOutput(output) } -func UpdateOutput(output *OutputSettings) error { - return nil +func UpdateOutput(output *data.OutputSettings) error { + Instance().deleteOutput(output.Name, false) + return Instance().addOutput(output) } -func ListOutputs() ([]OutputSettings, error) { +func ListOutputs() []data.OutputSettings { //should return clones of objects - return make([]OutputSettings, 0), nil + return Instance().listOutputs() } func DeleteOutput(name string) error { - return nil + return Instance().deleteOutput(name, true) } //----------------------------------------------- @@ -80,10 +82,10 @@ func UpdateRoute(*routes.InputRoute) error { //----------------------------------------------- //-------------------Templates------------------- -func AddTemplate(template *Template) error { +func AddTemplate(template *data.Template) error { return nil } -func UpdateTemplate(template *Template) error { +func UpdateTemplate(template *data.Template) error { return nil } @@ -91,9 +93,9 @@ func DeleteTemplate(name string) error { return nil } -func ListTemplates() ([]Template, error) { +func ListTemplates() ([]data.Template, error) { //should return clones of objects - return make([]Template, 0), nil + return make([]data.Template, 0), nil } //----------------------------------------------- diff --git a/router/apit_test.go b/router/apit_test.go new file mode 100644 index 00000000..47ae2582 --- /dev/null +++ b/router/apit_test.go @@ -0,0 +1,71 @@ +package router + +import ( + "fmt" + "testing" + + "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/outputs" + "github.com/stretchr/testify/assert" +) + +func TestAquaServerUrl(t *testing.T) { + AquaServerUrl("http://localhost:8080") + assert.Equal(t, "http://localhost:8080/#/images/", Instance().aquaServer, "AquaServerUrl") + +} + +var outputSettings = &data.OutputSettings{ + Type: "slack", + Name: "my-slack", + Url: "https://hooks.slack.com/services/TAAAA/BBB/", + Enable: true, +} + +func TestAddOutput(t *testing.T) { + AddOutput(outputSettings) + assert.Equal(t, 1, len(Instance().outputs), "one output expected") + assert.Contains(t, Instance().outputs, "my-slack") + assert.Equal(t, "my-slack", Instance().outputs["my-slack"].GetName(), "check name failed") + assert.Equal(t, "*outputs.SlackOutput", fmt.Sprintf("%T", Instance().outputs["my-slack"]), "check name failed") + +} + +func TestDeleteOutput(t *testing.T) { + AddOutput(outputSettings) + assert.Equal(t, 1, len(Instance().outputs), "one output expected") + + DeleteOutput("my-slack") + assert.Equal(t, 0, len(Instance().outputs), "no outputs expected") + +} +func TestEditOutput(t *testing.T) { + modifiedUrl := "https://hooks.slack.com/services/TAAAA/XXX/" + AddOutput(outputSettings) + assert.Equal(t, 1, len(Instance().outputs), "one output expected") + + s := Instance().outputs["my-slack"].CloneSettings() + + s.Url = modifiedUrl + + UpdateOutput(s) + + assert.Equal(t, 1, len(Instance().outputs), "one output expected") + assert.Equal(t, modifiedUrl, Instance().outputs["my-slack"].(*outputs.SlackOutput).Url, "url is updated") + +} +func TestListOutput(t *testing.T) { + AddOutput(outputSettings) + assert.Equal(t, 1, len(Instance().outputs), "one output expected") + + outputs := ListOutputs() + + assert.Equal(t, 1, len(outputs), "one output expected") + + r := outputs[0] + + assert.Equal(t, "my-slack", r.Name, "check name failed") + assert.Equal(t, "slack", r.Type, "check type failed") + assert.True(t, r.Enable, "output must be enabled") + +} diff --git a/router/builders.go b/router/builders.go index e533f003..94f0b24b 100644 --- a/router/builders.go +++ b/router/builders.go @@ -3,10 +3,11 @@ package router import ( "strings" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/outputs" ) -func buildSplunkOutput(sourceSettings *OutputSettings) *outputs.SplunkOutput { +func buildSplunkOutput(sourceSettings *data.OutputSettings) *outputs.SplunkOutput { return &outputs.SplunkOutput{ Name: sourceSettings.Name, Url: sourceSettings.Url, @@ -15,14 +16,14 @@ func buildSplunkOutput(sourceSettings *OutputSettings) *outputs.SplunkOutput { } } -func buildWebhookOutput(sourceSettings *OutputSettings) *outputs.WebhookOutput { +func buildWebhookOutput(sourceSettings *data.OutputSettings) *outputs.WebhookOutput { return &outputs.WebhookOutput{ Name: sourceSettings.Name, Url: sourceSettings.Url, } } -func buildTeamsOutput(sourceSettings *OutputSettings, aquaServer string) *outputs.TeamsOutput { +func buildTeamsOutput(sourceSettings *data.OutputSettings, aquaServer string) *outputs.TeamsOutput { return &outputs.TeamsOutput{ Name: sourceSettings.Name, AquaServer: aquaServer, @@ -30,7 +31,7 @@ func buildTeamsOutput(sourceSettings *OutputSettings, aquaServer string) *output } } -func buildServiceNow(sourceSettings *OutputSettings) *outputs.ServiceNowOutput { +func buildServiceNow(sourceSettings *data.OutputSettings) *outputs.ServiceNowOutput { serviceNow := &outputs.ServiceNowOutput{ Name: sourceSettings.Name, User: sourceSettings.User, @@ -44,7 +45,7 @@ func buildServiceNow(sourceSettings *OutputSettings) *outputs.ServiceNowOutput { return serviceNow } -func buildSlackOutput(sourceSettings *OutputSettings, aqua string) *outputs.SlackOutput { +func buildSlackOutput(sourceSettings *data.OutputSettings, aqua string) *outputs.SlackOutput { return &outputs.SlackOutput{ Name: sourceSettings.Name, AquaServer: aqua, @@ -52,7 +53,7 @@ func buildSlackOutput(sourceSettings *OutputSettings, aqua string) *outputs.Slac } } -func buildEmailOutput(sourceSettings *OutputSettings) *outputs.EmailOutput { +func buildEmailOutput(sourceSettings *data.OutputSettings) *outputs.EmailOutput { return &outputs.EmailOutput{ Name: sourceSettings.Name, User: sourceSettings.User, @@ -65,7 +66,7 @@ func buildEmailOutput(sourceSettings *OutputSettings) *outputs.EmailOutput { } } -func buildJiraOutput(sourceSettings *OutputSettings) *outputs.JiraAPI { +func buildJiraOutput(sourceSettings *data.OutputSettings) *outputs.JiraAPI { jiraApi := &outputs.JiraAPI{ Name: sourceSettings.Name, Url: sourceSettings.Url, diff --git a/router/initoutputs_test.go b/router/initoutputs_test.go index c8991f2d..92be152d 100644 --- a/router/initoutputs_test.go +++ b/router/initoutputs_test.go @@ -4,19 +4,21 @@ import ( "fmt" "reflect" "testing" + + "github.com/aquasecurity/postee/data" ) func TestBuildAndInitOtpt(t *testing.T) { tests := []struct { caseDesc string - outputSettings OutputSettings + outputSettings data.OutputSettings expctdProps map[string]interface{} shouldFail bool expectedOutputClass string }{ { "Simple Slack", - OutputSettings{ + data.OutputSettings{ Name: "my-slack", Type: "slack", Enable: true, @@ -31,7 +33,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Simple Email output", - OutputSettings{ + data.OutputSettings{ User: "EmailUser", Password: "pAsSw0rD", Host: "smtp.gmail.com", @@ -54,7 +56,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Simple Jira output", - OutputSettings{ + data.OutputSettings{ Url: "localhost:2990", User: "admin", Password: "admin", @@ -79,7 +81,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Jira output without credentials", - OutputSettings{ + data.OutputSettings{ Url: "localhost:2990", Name: "my-jira", Type: "jira", @@ -94,7 +96,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Jira output without password", - OutputSettings{ + data.OutputSettings{ Url: "localhost:2990", User: "admin", Name: "my-jira", @@ -110,7 +112,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Jira output with missed type", - OutputSettings{ + data.OutputSettings{ Url: "localhost:2990", User: "admin", Name: "my-jira", @@ -125,7 +127,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Jira Output with some default values", - OutputSettings{ + data.OutputSettings{ Url: "localhost:2990", Name: "my-jira-with-defaults", Type: "jira", @@ -149,7 +151,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Simple webhook output", - OutputSettings{ + data.OutputSettings{ Url: "http://localhost:8080", Name: "my-webhook", Type: "webhook", @@ -162,7 +164,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Simple ServiceNow output", - OutputSettings{ + data.OutputSettings{ Name: "my-servicenow", Type: "serviceNow", User: "admin", @@ -181,7 +183,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, { "Simple Teams output", - OutputSettings{ + data.OutputSettings{ Url: "https://outlook.office.com/webhook/ABCD", Name: "my-teams", Type: "teams", @@ -194,7 +196,7 @@ func TestBuildAndInitOtpt(t *testing.T) { }, } for _, test := range tests { - o := BuildAndInitOtpt(&test.outputSettings, "") + o, _ := buildAndInitOtpt(&test.outputSettings, "") //TODO handle error if test.shouldFail && o != nil { t.Fatalf("No output expected for %s test case", test.caseDesc) } else if !test.shouldFail && o == nil { diff --git a/router/inittemplate_test.go b/router/inittemplate_test.go index 3b0b3dde..c39eb512 100644 --- a/router/inittemplate_test.go +++ b/router/inittemplate_test.go @@ -8,6 +8,8 @@ import ( "net/http" "os" "testing" + + "github.com/aquasecurity/postee/data" ) var ( @@ -43,13 +45,13 @@ func TestInitTemplate(t *testing.T) { }() tests := []struct { - template *Template + template *data.Template caseDesc string expectedCls string shouldReturnError bool }{ { - template: &Template{ + template: &data.Template{ Name: "legacy-html", LegacyScanRenderer: "html", }, @@ -57,7 +59,7 @@ func TestInitTemplate(t *testing.T) { expectedCls: "*formatting.legacyScnEvaluator", }, { - template: &Template{ + template: &data.Template{ Name: "built-in", RegoPackage: "postee.slack", }, @@ -65,7 +67,7 @@ func TestInitTemplate(t *testing.T) { expectedCls: "*regoservice.regoEvaluator", }, { - template: &Template{ + template: &data.Template{ Name: "from-url", Url: "http://localhost/slack.rego", }, @@ -73,7 +75,7 @@ func TestInitTemplate(t *testing.T) { expectedCls: "*regoservice.regoEvaluator", }, { - template: &Template{ + template: &data.Template{ Name: "not-found", Url: "http://localhost/wrong.rego", }, @@ -82,7 +84,7 @@ func TestInitTemplate(t *testing.T) { shouldReturnError: true, }, { - template: &Template{ + template: &data.Template{ Name: "from-invalid-url", Url: "invalid-url", }, @@ -91,7 +93,7 @@ func TestInitTemplate(t *testing.T) { shouldReturnError: true, }, { - template: &Template{ + template: &data.Template{ Name: "inline", Body: "package postee.inline", }, @@ -104,7 +106,7 @@ func TestInitTemplate(t *testing.T) { } } -func doInitTemplate(t *testing.T, caseDesc string, template *Template, expectedCls string, shouldReturnError bool) { +func doInitTemplate(t *testing.T, caseDesc string, template *data.Template, expectedCls string, shouldReturnError bool) { demoCtx := Instance() err := demoCtx.initTemplate(template) if err != nil && !shouldReturnError { diff --git a/router/loads_test.go b/router/loads_test.go index 7ca1dc51..b1aaa550 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -90,9 +90,17 @@ func (ctx *ctxWrapper) MsgHandling(input []byte, output outputs.Output, route *r } func (ctxWrapper *ctxWrapper) setup(cfg string) { + ctxWrapper.init() + + ctxWrapper.cfgPath = "cfg_test.yaml" + err := ioutil.WriteFile(ctxWrapper.cfgPath, []byte(cfg), 0644) + if err != nil { + log.Printf("Can't write to %s", ctxWrapper.cfgPath) + } +} +func (ctxWrapper *ctxWrapper) init() { ctxWrapper.savedDBPath = dbservice.DbPath ctxWrapper.savedBaseForTicker = baseForTicker - ctxWrapper.cfgPath = "cfg_test.yaml" ctxWrapper.savedGetService = getScanService ctxWrapper.buff = make(chan invctn) @@ -113,10 +121,6 @@ func (ctxWrapper *ctxWrapper) setup(cfg string) { return ctxWrapper } - err = ioutil.WriteFile(ctxWrapper.cfgPath, []byte(cfg), 0644) - if err != nil { - log.Printf("Can't write to %s", ctxWrapper.cfgPath) - } ctxWrapper.instance = Instance() } diff --git a/router/parsecfg.go b/router/parsecfg.go index 575cc784..a3a6954a 100644 --- a/router/parsecfg.go +++ b/router/parsecfg.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "log" + "github.com/aquasecurity/postee/data" "github.com/ghodss/yaml" ) @@ -24,17 +25,17 @@ const ( ` ) -func Parsev2cfg(cfgpath string) (*TenantSettings, error) { - data, err := ioutil.ReadFile(cfgpath) +func Parsev2cfg(cfgpath string) (*data.TenantSettings, error) { + b, err := ioutil.ReadFile(cfgpath) if err != nil { log.Printf("Failed to open file %s, %s", cfgpath, err) return nil, err } - checkV1Cfg(data, cfgpath) + checkV1Cfg(b, cfgpath) - tenant := &TenantSettings{} - err = yaml.Unmarshal(data, tenant) + tenant := &data.TenantSettings{} + err = yaml.Unmarshal(b, tenant) if err != nil { log.Printf("Failed yaml.Unmarshal, %s", err) diff --git a/router/router.go b/router/router.go index fe9f1924..0d456cf3 100644 --- a/router/router.go +++ b/router/router.go @@ -20,6 +20,7 @@ import ( "github.com/aquasecurity/postee/regoservice" "github.com/aquasecurity/postee/routes" "github.com/aquasecurity/postee/utils" + "golang.org/x/xerrors" ) const ( @@ -128,7 +129,7 @@ func (ctx *Router) Send(data []byte) { ctx.queue <- data } -func (ctx *Router) initTemplate(template *Template) error { +func (ctx *Router) initTemplate(template *data.Template) error { log.Printf("Configuring template %s \n", template.Name) if template.LegacyScanRenderer != "" { @@ -248,16 +249,63 @@ func (ctx *Router) load() error { for _, settings := range tenant.Outputs { utils.Debug("%#v\n", anonymizeSettings(&settings)) - if settings.Enable { - plg := BuildAndInitOtpt(&settings, ctx.aquaServer) - if plg != nil { - log.Printf("Output %s is configured", settings.Name) - ctx.outputs[settings.Name] = plg - } + err = ctx.addOutput(&settings) + + if err != nil { + log.Printf("Can not initialize output %s: %v \n", settings.Name, err) + } else { + log.Printf("Output %s is configured", settings.Name) + } + + } + return nil +} + +func (ctx *Router) addOutput(settings *data.OutputSettings) error { + if settings.Enable { + plg, err := buildAndInitOtpt(settings, ctx.aquaServer) + + if err != nil { + return err + } + + ctx.outputs[settings.Name] = plg + + } + return nil +} +func (ctx *Router) deleteOutput(outputName string, removeFromRoutes bool) error { + output, ok := ctx.outputs[outputName] + if !ok { + return xerrors.Errorf("output %s is not found", outputName) + } + output.Terminate() + delete(ctx.outputs, outputName) + + if removeFromRoutes { + for _, route := range ctx.inputRoutes { + removeOutputFromRoute(route, outputName) } } + return nil } +func (ctx *Router) listOutputs() []data.OutputSettings { + r := make([]data.OutputSettings, 0) + for _, output := range ctx.outputs { + r = append(r, *output.CloneSettings()) + } + return r +} +func removeOutputFromRoute(r *routes.InputRoute, outputName string) { + filtered := make([]string, 0) + for _, n := range r.Outputs { + if n != outputName { + filtered = append(filtered, n) + } + } + r.Outputs = filtered +} type service interface { MsgHandling(input []byte, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) @@ -303,16 +351,14 @@ func (ctx *Router) handle(in []byte) { ctx.HandleRoute(routeName, in) } } -func BuildAndInitOtpt(settings *OutputSettings, aquaServerUrl string) outputs.Output { +func buildAndInitOtpt(settings *data.OutputSettings, aquaServerUrl string) (outputs.Output, error) { settings.User = utils.GetEnvironmentVarOrPlain(settings.User) if len(settings.User) == 0 && requireAuthorization[settings.Type] { - log.Printf("User for %q is empty", settings.Name) - return nil + return nil, xerrors.Errorf("user for %q is empty", settings.Name) } settings.Password = utils.GetEnvironmentVarOrPlain(settings.Password) if len(settings.Password) == 0 && requireAuthorization[settings.Type] { - log.Printf("Password for %q is empty", settings.Name) - return nil + return nil, xerrors.Errorf("password for %q is empty", settings.Name) } utils.Debug("Starting Output %q: %q\n", settings.Type, settings.Name) @@ -335,13 +381,11 @@ func BuildAndInitOtpt(settings *OutputSettings, aquaServerUrl string) outputs.Ou case "splunk": plg = buildSplunkOutput(settings) default: - log.Printf("Output type %q is undefined or empty. Output name is %q.", - settings.Type, settings.Name) - return nil + return nil, xerrors.Errorf("output %s has undefined or empty type: %q", settings.Name, settings.Type) } plg.Init() - return plg + return plg, nil } func (ctx *Router) listen() { From cf1389d4f52b776d6d9f14f44f5f93e2790b2399 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Sat, 14 Aug 2021 17:29:10 +0600 Subject: [PATCH 04/61] misc clean up --- outputs/jira.go | 15 ++++----------- outputs/webhook.go | 2 +- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/outputs/jira.go b/outputs/jira.go index e738aff9..5fd65682 100644 --- a/outputs/jira.go +++ b/outputs/jira.go @@ -171,12 +171,12 @@ func (ctx *JiraAPI) Send(content map[string]string) error { metaProject, err := createMetaProject(client, ctx.ProjectKey) if err != nil { - return fmt.Errorf("Failed to create meta project: %s\n", err) + return fmt.Errorf("failed to create meta project: %s", err) } metaIssueType, err := createMetaIssueType(metaProject, ctx.Issuetype) if err != nil { - return fmt.Errorf("Failed to create meta issue type: %s", err) + return fmt.Errorf("failed to create meta issue type: %s", err) } ctx.Summary = content["title"] @@ -221,9 +221,7 @@ func (ctx *JiraAPI) Send(content map[string]string) error { } if len(ctx.Labels) > 0 { - for _, l := range ctx.Labels { - issue.Fields.Labels = append(issue.Fields.Labels, l) - } + issue.Fields.Labels = append(issue.Fields.Labels, ctx.Labels...) } if len(ctx.FixVersions) > 0 { @@ -254,11 +252,6 @@ func (ctx *JiraAPI) Send(content map[string]string) error { return nil } -func (ctx *JiraAPI) login(client *jira.Client) error { - _, err := client.Authentication.AcquireSessionCookie(ctx.User, ctx.Password) - return err -} - func (ctx *JiraAPI) openIssue(client *jira.Client, issue *jira.Issue) (*jira.Issue, error) { i, res, err := client.Issue.Create(issue) @@ -403,7 +396,7 @@ func InitIssue(c *jira.Client, metaProject *jira.MetaProject, metaIssuetype *jir } default: - return nil, fmt.Errorf("Unknown issue type encountered: %s for %s", valueType, key) + return nil, fmt.Errorf("unknown issue type encountered: %s for %s", valueType, key) } } issue.Fields = issueFields diff --git a/outputs/webhook.go b/outputs/webhook.go index ecd856a7..3e2fdcb3 100644 --- a/outputs/webhook.go +++ b/outputs/webhook.go @@ -52,7 +52,7 @@ func (webhook *WebhookOutput) Send(content map[string]string) error { } if resp.StatusCode != http.StatusOK { - msg := "Sending webhook wrong status: %q. Body: %s" + msg := "sending webhook wrong status: %q. Body: %s" log.Printf(msg, resp.StatusCode, body) return fmt.Errorf(msg, resp.StatusCode, body) } From 1358020c7c11a76318027a1de0fc5882400f924b Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Sat, 14 Aug 2021 18:50:31 +0600 Subject: [PATCH 05/61] added routes and templates --- router/api.go | 56 +++++++++++++++++++-------- router/{apit_test.go => api_test.go} | 3 ++ router/router.go | 57 ++++++++++++++++++++++++++-- 3 files changed, 96 insertions(+), 20 deletions(-) rename router/{apit_test.go => api_test.go} (98%) diff --git a/router/api.go b/router/api.go index 933ee169..de59ce73 100644 --- a/router/api.go +++ b/router/api.go @@ -17,6 +17,13 @@ we want to add this as we want the consumer to be able to add a code for extendi when adding a route, the callback function will be part of each route func InputCallBack(inputMessage) (bool, error) */ +type InputCallbackFunc func(InputMessage interface{}) bool + +//SetInputCallbackFunc The call back func will be called as the last evaluation method of the input rego, +//it will be added to the rego with && operator and the entire input evaluation will pass through only if the callback retuns true +func SetInputCallbackFunc(routeName string, callaback InputCallbackFunc) { + +} func WithDefaultConfig() error { return WithFileConfig(defaultConfigPath) @@ -53,7 +60,6 @@ func UpdateOutput(output *data.OutputSettings) error { return Instance().addOutput(output) } func ListOutputs() []data.OutputSettings { - //should return clones of objects return Instance().listOutputs() } @@ -64,18 +70,22 @@ func DeleteOutput(name string) error { //----------------------------------------------- //------------------Routes-------------------- -func AddRoute(route *routes.InputRoute) error { - return nil +func AddRoute(route *routes.InputRoute) { + Instance().addRoute(route) } func DeleteRoute(name string) error { - return nil + return Instance().deleteRoute(name) } -func ListRoutes() ([]routes.InputRoute, error) { - //should return clones of objects - return make([]routes.InputRoute, 0), nil +func ListRoutes() []routes.InputRoute { + return Instance().listRoutes() } -func UpdateRoute(*routes.InputRoute) error { +func UpdateRoute(route *routes.InputRoute) error { + err := Instance().deleteRoute(route.Name) + if err != nil { + return err + } + Instance().addRoute(route) return nil } @@ -83,23 +93,37 @@ func UpdateRoute(*routes.InputRoute) error { //-------------------Templates------------------- func AddTemplate(template *data.Template) error { - return nil + return Instance().initTemplate(template) } func UpdateTemplate(template *data.Template) error { - return nil + err := Instance().deleteTemplate(template.Name, true) + + if err != nil { + return err + } + + return Instance().initTemplate(template) } func DeleteTemplate(name string) error { - return nil + return Instance().deleteTemplate(name, true) } -func ListTemplates() ([]data.Template, error) { - //should return clones of objects - return make([]data.Template, 0), nil +func ListTemplates() []string { + /* + There is nothing to update (as only one property defines template). + So only list of template names returned + */ + templates := Instance().templates + names := make([]string, 0, len(templates)) + for n := range templates { + names = append(names, n) + } + return names } //----------------------------------------------- -func Send(data []byte) { - //just put data into queue. No error returned +func Send(b []byte) { + Instance().Send(b) } diff --git a/router/apit_test.go b/router/api_test.go similarity index 98% rename from router/apit_test.go rename to router/api_test.go index 47ae2582..99614129 100644 --- a/router/apit_test.go +++ b/router/api_test.go @@ -69,3 +69,6 @@ func TestListOutput(t *testing.T) { assert.True(t, r.Enable, "output must be enabled") } + +//TODO templates +//TODO routes diff --git a/router/router.go b/router/router.go index 0d456cf3..b43d3031 100644 --- a/router/router.go +++ b/router/router.go @@ -2,7 +2,6 @@ package router import ( "bytes" - "errors" "fmt" "io/ioutil" "log" @@ -128,6 +127,23 @@ func (ctx *Router) Terminate() { func (ctx *Router) Send(data []byte) { ctx.queue <- data } +func (ctx *Router) deleteTemplate(name string, removeFromRoutes bool) error { + _, ok := ctx.outputs[name] + if !ok { + return xerrors.Errorf("template %s is not found", name) + } + delete(ctx.outputs, name) + + if removeFromRoutes { + for _, route := range ctx.inputRoutes { + if route.Template == name { + route.Template = "" + } + } + } + + return nil +} func (ctx *Router) initTemplate(template *data.Template) error { log.Printf("Configuring template %s \n", template.Name) @@ -163,7 +179,7 @@ func (ctx *Router) initTemplate(template *data.Template) error { } if resp.StatusCode > 399 { - return errors.New(fmt.Sprintf("can not connect to %s, response status is %d", template.Url, resp.StatusCode)) + return xerrors.Errorf("can not connect to %s, response status is %d", template.Url, resp.StatusCode) } b, err := ioutil.ReadAll(resp.Body) @@ -236,8 +252,8 @@ func (ctx *Router) load() error { //---------------------------------------------------- - for i, r := range tenant.InputRoutes { - ctx.inputRoutes[r.Name] = routes.ConfigureAggrTimeout(&tenant.InputRoutes[i]) + for _, r := range tenant.InputRoutes { + ctx.addRoute(&r) } for _, t := range tenant.Templates { err := ctx.initTemplate(&t) @@ -261,6 +277,39 @@ func (ctx *Router) load() error { return nil } +func (ctx *Router) addRoute(r *routes.InputRoute) { + ctx.inputRoutes[r.Name] = routes.ConfigureAggrTimeout(r) +} +func (ctx *Router) deleteRoute(name string) error { + r, ok := ctx.inputRoutes[name] + if !ok { + return xerrors.Errorf("output %s is not found", name) + } + r.StopScheduler() + delete(ctx.inputRoutes, name) + + return nil +} + +func (ctx *Router) listRoutes() []routes.InputRoute { + list := make([]routes.InputRoute, 0, len(ctx.inputRoutes)) + for _, r := range ctx.inputRoutes { + list = append(list, routes.InputRoute{ + Name: r.Name, + Input: r.Input, + Outputs: data.CopyStringArray(r.Outputs), + Plugins: routes.Plugins{ + AggregateIssuesNumber: r.Plugins.AggregateIssuesNumber, + AggregateIssuesTimeout: r.Plugins.AggregateIssuesTimeout, + PolicyShowAll: r.Plugins.PolicyShowAll, + AggregateTimeoutSeconds: r.Plugins.AggregateTimeoutSeconds, + }, + Template: r.Template, + }) + } + return list +} + func (ctx *Router) addOutput(settings *data.OutputSettings) error { if settings.Enable { plg, err := buildAndInitOtpt(settings, ctx.aquaServer) From 581bbb9cfe8b84ea5029e0bcfc5f9a2027ed9694 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Wed, 18 Aug 2021 15:50:40 +0600 Subject: [PATCH 06/61] switched API to synchronous mode --- router/api.go | 5 ++++- router/router.go | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/router/api.go b/router/api.go index de59ce73..c0e73386 100644 --- a/router/api.go +++ b/router/api.go @@ -1,6 +1,8 @@ package router import ( + "bytes" + "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/routes" ) @@ -125,5 +127,6 @@ func ListTemplates() []string { //----------------------------------------------- func Send(b []byte) { - Instance().Send(b) + //Instance().Send(b) + Instance().handle(bytes.ReplaceAll(b, []byte{'`'}, []byte{'\''})) } diff --git a/router/router.go b/router/router.go index b43d3031..f2927ba9 100644 --- a/router/router.go +++ b/router/router.go @@ -84,7 +84,6 @@ func (ctx *Router) resetCfg() { } func (ctx *Router) NewConfig() { ctx.resetCfg() - go ctx.listen() } func (ctx *Router) ApplyFileCfg(cfgfile string) error { From 4c74bffbaf6bffbfbb3f4838549e3cdbcfed02a5 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Thu, 19 Aug 2021 13:20:08 +0600 Subject: [PATCH 07/61] fixed typo --- router/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/api.go b/router/api.go index c0e73386..791c8549 100644 --- a/router/api.go +++ b/router/api.go @@ -22,7 +22,7 @@ func InputCallBack(inputMessage) (bool, error) type InputCallbackFunc func(InputMessage interface{}) bool //SetInputCallbackFunc The call back func will be called as the last evaluation method of the input rego, -//it will be added to the rego with && operator and the entire input evaluation will pass through only if the callback retuns true +//it will be added to the rego with && operator and the entire input evaluation will pass through only if the callback returns true func SetInputCallbackFunc(routeName string, callaback InputCallbackFunc) { } From a2d28de1c2e0b64ef7f83279ae6470ecfdbb6c1a Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Thu, 19 Aug 2021 19:02:44 +0600 Subject: [PATCH 08/61] added simple integration test --- regoservice/eval.go | 20 ++++++--- router/api.go | 23 ++++++++-- router/api_integration_test.go | 82 ++++++++++++++++++++++++++++++++++ router/loads_test.go | 4 +- router/routehandling_test.go | 6 +-- router/router.go | 42 +++++++++++------ 6 files changed, 151 insertions(+), 26 deletions(-) create mode 100644 router/api_integration_test.go diff --git a/regoservice/eval.go b/regoservice/eval.go index 4335e165..51e46c4c 100644 --- a/regoservice/eval.go +++ b/regoservice/eval.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log" + "os" "github.com/aquasecurity/postee/data" "github.com/open-policy-agent/opa/rego" @@ -97,7 +98,7 @@ func getFirstElement(context map[string]interface{}, key string) interface{} { func asStringOrJson(data map[string]interface{}, prop string) (string, error) { expr, ok := data[prop] if !ok { - return "", errors.New(fmt.Sprintf("property %s is not found", prop)) + return "", fmt.Errorf(fmt.Sprintf("property %s is not found", prop)) } fmt.Printf("value: %q", expr) switch v := expr.(type) { @@ -112,7 +113,7 @@ func asStringOrJson(data map[string]interface{}, prop string) (string, error) { } } func (regoEvaluator *regoEvaluator) BuildAggregatedContent(scans []map[string]string) (map[string]string, error) { - aggregatedJson := make([]map[string]interface{}, len(scans), len(scans)) + aggregatedJson := make([]map[string]interface{}, len(scans)) for _, scan := range scans { desc := scan["description"] @@ -203,7 +204,7 @@ func buildAggregatedRego(query *rego.PreparedEvalQuery) (*rego.PreparedEvalQuery ctx := context.Background() //execute query with empty input and check if aggregation package is defined - rs, err := query.Eval(ctx, rego.EvalInput(make(map[string]interface{}))) + rs, _ := query.Eval(ctx, rego.EvalInput(make(map[string]interface{}))) if len(rs) == 0 || len(rs[0].Expressions) == 0 { return nil, errors.New("no results") //TODO error definition @@ -214,10 +215,12 @@ func buildAggregatedRego(query *rego.PreparedEvalQuery) (*rego.PreparedEvalQuery aggregation_pkg_val := expr[aggregation_pkg_prop] var aggrQuery *rego.PreparedEvalQuery - if aggregation_pkg_val != nil { aggregation_pkg := aggregation_pkg_val.(string) + var err error + aggrQuery, err = buildBundledRegoForPackage(aggregation_pkg) + if err != nil { return nil, err } @@ -230,11 +233,18 @@ func buildAggregatedRego(query *rego.PreparedEvalQuery) (*rego.PreparedEvalQuery func BuildExternalRegoEvaluator(filename string, body string) (data.Inpteval, error) { ctx := context.Background() + foundPaths := make([]string, 0) + + for _, path := range commonRegoTemplates { + if _, err := os.Stat(commonRegoTemplates[0]); !os.IsNotExist(err) { + foundPaths = append(foundPaths, path) + } + } r, err := rego.New( rego.Query("data"), jsonFmtFunc(), - rego.Load(commonRegoTemplates, nil), //only common modules + rego.Load(foundPaths, nil), //only common modules rego.Module(filename, body), ).PrepareForEval(ctx) diff --git a/router/api.go b/router/api.go index 791c8549..333cc835 100644 --- a/router/api.go +++ b/router/api.go @@ -2,6 +2,7 @@ package router import ( "bytes" + "os" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/routes" @@ -12,6 +13,7 @@ const ( ) /* +TODO Is it possible to add a callback func to the route "input", this callback func, will be called when evaluating the input "rego" and if the callback func returns "false" then the evaluation will fail and the message is not sent. @@ -32,12 +34,11 @@ func WithDefaultConfig() error { } func WithFileConfig(path string) error { Instance().Terminate() - return Instance().ApplyFileCfg(path) + return Instance().ApplyFileCfg(path, true) } func WithNewConfig(name string) { //tenant name Instance().Terminate() - Instance().NewConfig() - + Instance().resetCfg(true) } func AquaServerUrl(aquaServerUrl string) { //optional @@ -97,6 +98,22 @@ func UpdateRoute(route *routes.InputRoute) error { func AddTemplate(template *data.Template) error { return Instance().initTemplate(template) } + +//helper method +func AddRegoTemplateFromFile(name, filename string) error { + b, err := os.ReadFile(filename) + + if err != nil { + return err + } + + return AddTemplate(&data.Template{ + Name: name, + Body: string(b), + }) + +} + func UpdateTemplate(template *data.Template) error { err := Instance().deleteTemplate(template.Name, true) diff --git a/router/api_integration_test.go b/router/api_integration_test.go new file mode 100644 index 00000000..024cb871 --- /dev/null +++ b/router/api_integration_test.go @@ -0,0 +1,82 @@ +package router_test + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/router" + "github.com/aquasecurity/postee/routes" + "github.com/stretchr/testify/assert" +) + +const ( + msg = ` +{ + "action": "Login", + "adjective": "demolab.aquasec.com", + "category": "User", + "date": 1618409998039, + "description": "Roles: Administrator", + "id": 0, + "result": 1, + "source_ip": "172.18.0.9", + "time": 1618409998, + "type": "Administration", + "user": "upwork" +}` + rego = `package example.audit.html +title:="Audit event received" +result:=[{"type":"section","text":{"type":"mrkdwn","text": input.user}}] +` + want = `[{"text":{"text":"upwork","type":"mrkdwn"},"type":"section"}]` +) + +func TestAudit(t *testing.T) { + received := make(chan ([]byte)) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed ioutil.ReadAll: %s\n", err) + received <- []byte{} + return + } + + received <- body + + defer r.Body.Close() + })) + defer ts.Close() + + router.WithNewConfig("test") + + err := router.AddTemplate(&data.Template{ + Name: "audit-json-template", + Body: rego, + }) + if err != nil { + t.Logf("Error: %v", err) + return + } + router.AddOutput(&data.OutputSettings{ + Name: "test-webhook", + Type: "webhook", + Enable: true, + Url: ts.URL, + }) + + router.AddRoute(&routes.InputRoute{ + Name: "test", + Outputs: []string{"test-webhook"}, + Template: "audit-json-template", + Plugins: routes.Plugins{ + PolicyShowAll: true, + }, + }) + router.Send([]byte(msg)) + got := <-received + assert.Equal(t, string(got), want, "unexpected response") +} diff --git a/router/loads_test.go b/router/loads_test.go index b1aaa550..5dd95f45 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -145,7 +145,7 @@ func TestLoads(t *testing.T) { defer wrap.teardown() demoCtx := wrap.instance - demoCtx.ApplyFileCfg(wrap.cfgPath) + demoCtx.ApplyFileCfg(wrap.cfgPath, false) expectedOutputsCnt := 2 if len(demoCtx.outputs) != expectedOutputsCnt { @@ -183,7 +183,7 @@ func TestReload(t *testing.T) { defer wrap.teardown() demoCtx := wrap.instance - demoCtx.ApplyFileCfg(wrap.cfgPath) + demoCtx.ApplyFileCfg(wrap.cfgPath, false) expectedOutputsCnt := 2 if len(demoCtx.outputs) != expectedOutputsCnt { diff --git a/router/routehandling_test.go b/router/routehandling_test.go index 879146e3..8602c55c 100644 --- a/router/routehandling_test.go +++ b/router/routehandling_test.go @@ -257,7 +257,7 @@ func runTestRouteHandlingCase(t *testing.T, caseDesc string, cfg string, expctdI defer wrap.teardown() - err := wrap.instance.ApplyFileCfg(wrap.cfgPath) + err := wrap.instance.ApplyFileCfg(wrap.cfgPath, false) if err != nil { t.Fatalf("[%s] Unexpected error %v", caseDesc, err) } @@ -308,7 +308,7 @@ func TestInvalidRouteName(t *testing.T) { defer wrap.teardown() - err := wrap.instance.ApplyFileCfg(wrap.cfgPath) + err := wrap.instance.ApplyFileCfg(wrap.cfgPath, false) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -337,7 +337,7 @@ func TestSend(t *testing.T) { defer wrap.teardown() - err := wrap.instance.ApplyFileCfg(wrap.cfgPath) + err := wrap.instance.ApplyFileCfg(wrap.cfgPath, false) if err != nil { t.Fatalf("Unexpected error %v", err) } diff --git a/router/router.go b/router/router.go index f2927ba9..2824294e 100644 --- a/router/router.go +++ b/router/router.go @@ -41,6 +41,7 @@ type Router struct { outputs map[string]outputs.Output inputRoutes map[string]*routes.InputRoute templates map[string]data.Inpteval + synchronous bool } var ( @@ -58,46 +59,53 @@ func Instance() *Router { initCtx.Do(func() { routerCtx = &Router{ mutexScan: sync.Mutex{}, - quit: make(chan struct{}), - queue: make(chan []byte, 1000), outputs: make(map[string]outputs.Output), inputRoutes: make(map[string]*routes.InputRoute), templates: make(map[string]data.Inpteval), - stopTicker: make(chan struct{}), + synchronous: false, } }) return routerCtx } func (ctx *Router) ReloadConfig() { ctx.Terminate() - err := ctx.ApplyFileCfg(ctx.cfgfile) + err := ctx.ApplyFileCfg(ctx.cfgfile, ctx.synchronous) if err != nil { log.Printf("Unable to start router: %s", err) } } -func (ctx *Router) resetCfg() { + +func (ctx *Router) resetCfg(synchronous bool) { ctx.outputs = map[string]outputs.Output{} ctx.inputRoutes = map[string]*routes.InputRoute{} ctx.templates = map[string]data.Inpteval{} ctx.ticker = nil -} -func (ctx *Router) NewConfig() { - ctx.resetCfg() + ctx.synchronous = synchronous + + if ctx.synchronous { + ctx.quit = make(chan struct{}) + ctx.queue = make(chan []byte, 1000) + ctx.stopTicker = make(chan struct{}) + } } -func (ctx *Router) ApplyFileCfg(cfgfile string) error { +func (ctx *Router) ApplyFileCfg(cfgfile string, synchronous bool) error { log.Printf("Starting Router....") ctx.cfgfile = cfgfile - ctx.resetCfg() + ctx.resetCfg(synchronous) err := ctx.load() if err != nil { return err } - go ctx.listen() + + if !ctx.synchronous { + go ctx.listen() + } + return nil } @@ -114,7 +122,10 @@ func (ctx *Router) Terminate() { } log.Printf("Route schedulers stopped") - ctx.quit <- struct{}{} + if ctx.quit != nil { + ctx.quit <- struct{}{} + } + log.Printf("quit notified") if ctx.ticker != nil { ctx.stopTicker <- struct{}{} @@ -390,7 +401,12 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { continue } log.Printf("route %q is associated with template %q", routeName, r.Template) - go getScanService().MsgHandling(in, pl, r, tmpl, &ctx.aquaServer) + + if ctx.synchronous { + getScanService().MsgHandling(in, pl, r, tmpl, &ctx.aquaServer) + } else { + go getScanService().MsgHandling(in, pl, r, tmpl, &ctx.aquaServer) + } } } From da3dae1e5d1a3db386396fd5463d0b1d305f7040 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Thu, 19 Aug 2021 19:08:20 +0600 Subject: [PATCH 09/61] updated go version --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 62a49d76..5fae1449 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Go 1.x uses: actions/setup-go@v2 with: - go-version: ^1.13 + go-version: ^1.16 id: go - name: Check out code into the Go module directory From e47211bb16a68f348fc4781b814572b1ee1b9400 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Thu, 19 Aug 2021 20:44:04 +0600 Subject: [PATCH 10/61] fixed tests --- main.go | 2 +- router/api.go | 2 +- router/router.go | 28 ++++++++++++++++++++-------- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/main.go b/main.go index 63b20c62..bf3ed4ac 100644 --- a/main.go +++ b/main.go @@ -73,7 +73,7 @@ func main() { cfgfile = os.Getenv("POSTEE_CFG") } - err := router.Instance().ApplyFileCfg(cfgfile) + err := router.Instance().ApplyFileCfg(cfgfile, false) if err != nil { log.Printf("Can't start alert manager %v", err) return diff --git a/router/api.go b/router/api.go index 333cc835..a44ab1a4 100644 --- a/router/api.go +++ b/router/api.go @@ -38,7 +38,7 @@ func WithFileConfig(path string) error { } func WithNewConfig(name string) { //tenant name Instance().Terminate() - Instance().resetCfg(true) + Instance().initCfg(true) } func AquaServerUrl(aquaServerUrl string) { //optional diff --git a/router/router.go b/router/router.go index 2824294e..5bb0df1c 100644 --- a/router/router.go +++ b/router/router.go @@ -76,17 +76,19 @@ func (ctx *Router) ReloadConfig() { } } -func (ctx *Router) resetCfg(synchronous bool) { - ctx.outputs = map[string]outputs.Output{} - ctx.inputRoutes = map[string]*routes.InputRoute{} - ctx.templates = map[string]data.Inpteval{} - ctx.ticker = nil +func (ctx *Router) initCfg(synchronous bool) { + ctx.cleanInstance() + ctx.synchronous = synchronous - if ctx.synchronous { + if !ctx.synchronous { ctx.quit = make(chan struct{}) ctx.queue = make(chan []byte, 1000) ctx.stopTicker = make(chan struct{}) + } else { + ctx.quit = nil + ctx.queue = nil + ctx.stopTicker = nil } } @@ -95,7 +97,7 @@ func (ctx *Router) ApplyFileCfg(cfgfile string, synchronous bool) error { ctx.cfgfile = cfgfile - ctx.resetCfg(synchronous) + ctx.initCfg(synchronous) err := ctx.load() if err != nil { @@ -122,16 +124,27 @@ func (ctx *Router) Terminate() { } log.Printf("Route schedulers stopped") + log.Printf("ctx.quit %v\n", ctx.quit) + if ctx.quit != nil { ctx.quit <- struct{}{} } log.Printf("quit notified") + if ctx.ticker != nil { ctx.stopTicker <- struct{}{} log.Printf("stopTicker notified") } + ctx.cleanInstance() +} +func (ctx *Router) cleanInstance() { + ctx.outputs = map[string]outputs.Output{} + ctx.inputRoutes = map[string]*routes.InputRoute{} + ctx.templates = map[string]data.Inpteval{} + ctx.ticker = nil + ctx.quit = nil } func (ctx *Router) Send(data []byte) { @@ -259,7 +272,6 @@ func (ctx *Router) load() error { } }() } - //---------------------------------------------------- for _, r := range tenant.InputRoutes { From 26ba2db9582030de1d4ec79527401a6231f1fa04 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Mon, 23 Aug 2021 21:07:17 +0600 Subject: [PATCH 11/61] added fix to set dbpath through API --- router/api.go | 13 ++++++++----- router/api_integration_test.go | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/router/api.go b/router/api.go index a44ab1a4..518a15ae 100644 --- a/router/api.go +++ b/router/api.go @@ -5,6 +5,7 @@ import ( "os" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/routes" ) @@ -29,15 +30,17 @@ func SetInputCallbackFunc(routeName string, callaback InputCallbackFunc) { } -func WithDefaultConfig() error { - return WithFileConfig(defaultConfigPath) +func WithDefaultConfig(dbPath string) error { + return WithFileConfig(defaultConfigPath, dbPath) } -func WithFileConfig(path string) error { +func WithFileConfig(cfgPath, dbPath string) error { Instance().Terminate() - return Instance().ApplyFileCfg(path, true) + dbservice.DbPath = dbPath + return Instance().ApplyFileCfg(cfgPath, true) } -func WithNewConfig(name string) { //tenant name +func WithNewConfig(tenantName, dbPath string) { //tenant name Instance().Terminate() + dbservice.DbPath = dbPath Instance().initCfg(true) } diff --git a/router/api_integration_test.go b/router/api_integration_test.go index 024cb871..7b37a72d 100644 --- a/router/api_integration_test.go +++ b/router/api_integration_test.go @@ -51,7 +51,7 @@ func TestAudit(t *testing.T) { })) defer ts.Close() - router.WithNewConfig("test") + router.WithNewConfig("test", "./webhook.db") err := router.AddTemplate(&data.Template{ Name: "audit-json-template", From 31c55cacb0e2a615785dea42b310ec69dd78e4d1 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 24 Aug 2021 19:19:33 +0600 Subject: [PATCH 12/61] added input callbacks & tests --- msgservice/aggregatebytime_test.go | 6 +- msgservice/aggregatescan_test.go | 4 +- msgservice/applicationscopeowner_test.go | 16 +- msgservice/getuniqueid_test.go | 69 ++-- msgservice/logs.go | 16 - msgservice/msghandling.go | 19 +- msgservice/msgservice_mocks_test.go | 9 +- msgservice/msgservice_test.go | 88 +---- router/api.go | 5 +- router/loads_test.go | 2 +- {msgservice => router}/regocriteria_test.go | 20 +- router/routehandling_test.go | 322 ++++++++---------- router/router.go | 60 +++- router/testdata/configs/invalid-output.yaml | 20 ++ router/testdata/configs/invalid-template.yaml | 20 ++ .../configs/no-associated-output.yaml | 19 ++ router/testdata/configs/no-outputs.yaml | 14 + router/testdata/configs/no-templates.yaml | 18 + router/testdata/configs/single-route.yaml | 20 ++ router/testdata/configs/two-outputs.yaml | 24 ++ router/testdata/configs/two-routes.yaml | 26 ++ .../configs/with-input-filter-empty.yaml | 21 ++ .../configs/with-input-filter-invalid.yaml | 26 ++ .../configs/with-input-filter-no-match.yaml | 21 ++ .../testdata/configs/with-input-filter.yaml | 21 ++ utils/utils.go | 13 + 26 files changed, 511 insertions(+), 388 deletions(-) delete mode 100644 msgservice/logs.go rename {msgservice => router}/regocriteria_test.go (69%) create mode 100644 router/testdata/configs/invalid-output.yaml create mode 100644 router/testdata/configs/invalid-template.yaml create mode 100644 router/testdata/configs/no-associated-output.yaml create mode 100644 router/testdata/configs/no-outputs.yaml create mode 100644 router/testdata/configs/no-templates.yaml create mode 100644 router/testdata/configs/single-route.yaml create mode 100644 router/testdata/configs/two-outputs.yaml create mode 100644 router/testdata/configs/two-routes.yaml create mode 100644 router/testdata/configs/with-input-filter-empty.yaml create mode 100644 router/testdata/configs/with-input-filter-invalid.yaml create mode 100644 router/testdata/configs/with-input-filter-no-match.yaml create mode 100644 router/testdata/configs/with-input-filter.yaml diff --git a/msgservice/aggregatebytime_test.go b/msgservice/aggregatebytime_test.go index 5497ac0d..c813f764 100644 --- a/msgservice/aggregatebytime_test.go +++ b/msgservice/aggregatebytime_test.go @@ -53,9 +53,9 @@ func TestAggregateByTimeout(t *testing.T) { srvUrl := "" srv1 := new(MsgService) - srv1.MsgHandling([]byte(mockScan1), demoEmailPlg, demoRoute, demoInptEval, &srvUrl) - srv1.MsgHandling([]byte(mockScan2), demoEmailPlg, demoRoute, demoInptEval, &srvUrl) - srv1.MsgHandling([]byte(mockScan3), demoEmailPlg, demoRoute, demoInptEval, &srvUrl) + srv1.MsgHandling(mockScan1, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) + srv1.MsgHandling(mockScan2, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) + srv1.MsgHandling(mockScan3, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) expectedSchedulerInvctCnt := 1 diff --git a/msgservice/aggregatescan_test.go b/msgservice/aggregatescan_test.go index ca4cc5cd..fc92a735 100644 --- a/msgservice/aggregatescan_test.go +++ b/msgservice/aggregatescan_test.go @@ -49,7 +49,7 @@ func doAggregate(t *testing.T, caseDesc string, expectedSntCnt int, expectedRend emailCounts: 0, } - scans := []string{mockScan1, mockScan2, mockScan3, mockScan4} + scans := []map[string]interface{}{mockScan1, mockScan2, mockScan3, mockScan4} srvUrl := "" demoRoute := &routes.InputRoute{} @@ -67,7 +67,7 @@ func doAggregate(t *testing.T, caseDesc string, expectedSntCnt int, expectedRend for _, scan := range scans { srv := new(MsgService) - srv.MsgHandling([]byte(scan), demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + srv.MsgHandling(scan, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) } demoEmailOutput.wg.Wait() diff --git a/msgservice/applicationscopeowner_test.go b/msgservice/applicationscopeowner_test.go index ba4f96cb..f5c2fb14 100644 --- a/msgservice/applicationscopeowner_test.go +++ b/msgservice/applicationscopeowner_test.go @@ -11,13 +11,13 @@ import ( ) var ( - scnWithOwners = `{ - "image":"Demo mock image1", - "registry":"registry1", - "vulnerability_summary":{"critical":0,"high":1,"medium":3,"low":4,"negligible":5}, - "image_assurance_results":{"disallowed":true}, - "application_scope_owners": ["recipient1@aquasec.com", "recipient1@aquasec.com"] - }` + scnWithOwners = map[string]interface{}{ + "image": "Demo mock image1", + "registry": "registry1", + "vulnerability_summary": map[string]int{"critical": 0, "high": 1, "medium": 3, "low": 4, "negligible": 5}, + "image_assurance_results": map[string]interface{}{"disallowed": true}, + "application_scope_owners": []string{"recipient1@aquasec.com", "recipient1@aquasec.com"}, + } ) func TestApplicationScopeOwner(t *testing.T) { @@ -44,7 +44,7 @@ func TestApplicationScopeOwner(t *testing.T) { demoEmailOutput.wg.Add(1) srv := new(MsgService) - srv.MsgHandling([]byte(scnWithOwners), demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + srv.MsgHandling(scnWithOwners, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) demoEmailOutput.wg.Wait() diff --git a/msgservice/getuniqueid_test.go b/msgservice/getuniqueid_test.go index 3071c853..44bfd60c 100644 --- a/msgservice/getuniqueid_test.go +++ b/msgservice/getuniqueid_test.go @@ -10,76 +10,75 @@ import ( ) var ( - unique_scan1 = `{ - "image":"Demo mock image1", - "registry":"registry1", - "digest":"abc", - "vulnerability_summary":{"critical":0,"high":1,"medium":3,"low":4,"negligible":5}, - "image_assurance_results":{"disallowed":true} -}` - unique_scan2 = `{ - "image":"Demo mock image2", - "registry":"registry2", - "digest":"def", - "vulnerability_summary":{"critical":0,"high":1,"medium":3,"low":4,"negligible":5}, - "image_assurance_results":{"disallowed":true} -}` - non_unique_payload = `{ - "action": "some", + unique_scan1 = map[string]interface{}{ + "image": "Demo mock image1", + "registry": "registry1", + "digest": "abc", + "vulnerability_summary": map[string]int{"critical": 0, "high": 1, "medium": 3, "low": 4, "negligible": 5}, + "image_assurance_results": map[string]interface{}{"disallowed": true}, + } + unique_scan2 = map[string]interface{}{ + "image": "Demo mock image2", + "registry": "registry2", + "digest": "def", + "vulnerability_summary": map[string]int{"critical": 0, "high": 1, "medium": 3, "low": 4, "negligible": 5}, + "image_assurance_results": map[string]interface{}{"disallowed": true}, + } + non_unique_payload = map[string]interface{}{ + "action": "some", "adjective": "nice", - "category" : "", - "date": 123, - "id": 8, - "result": 200, + "category": "", + "date": 123, + "id": 8, + "result": 200, "source_ip": "192.168.0.1", - "time": 45, - "type": "one", - "user": "admin", - "version": "2.0.1" - -}` + "time": 45, + "type": "one", + "user": "admin", + "version": "2.0.1", + } ) func TestScanUniqueId(t *testing.T) { tests := []struct { - inputs []string + inputs []map[string]interface{} caseDesc string policyShowAll bool expctdInvc int }{ { - inputs: []string{unique_scan1, unique_scan1}, + inputs: []map[string]interface{}{unique_scan1, unique_scan1}, caseDesc: "Same scan twice with PolicyShowAll: false", policyShowAll: false, expctdInvc: 1, }, { - inputs: []string{unique_scan1, unique_scan1}, + inputs: []map[string]interface{}{unique_scan1, unique_scan1}, caseDesc: "Same scan twice with PolicyShowAll: true", policyShowAll: true, expctdInvc: 2, }, { - inputs: []string{unique_scan1, unique_scan2}, + inputs: []map[string]interface{}{unique_scan1, unique_scan2}, caseDesc: "2 unique scan with PolicyShowAll: true", policyShowAll: true, expctdInvc: 2, }, { - inputs: []string{unique_scan1, unique_scan2}, + inputs: []map[string]interface{}{unique_scan1, unique_scan2}, caseDesc: "2 unique scan with PolicyShowAll: false", policyShowAll: false, expctdInvc: 2, }, { - inputs: []string{non_unique_payload, non_unique_payload}, + inputs: []map[string]interface{}{non_unique_payload, non_unique_payload}, caseDesc: "2 non-scan inputs with PolicyShowAll: true", policyShowAll: true, expctdInvc: 2, }, { caseDesc: "2 non-scan inputs with PolicyShowAll: false", - inputs: []string{non_unique_payload, non_unique_payload}, + inputs: []map[string]interface{}{non_unique_payload, non_unique_payload}, policyShowAll: false, expctdInvc: 2, }, @@ -91,7 +90,7 @@ func TestScanUniqueId(t *testing.T) { } -func sendInputs(t *testing.T, caseDesc string, inputs []string, policyShowAll bool, expected int) { +func sendInputs(t *testing.T, caseDesc string, inputs []map[string]interface{}, policyShowAll bool, expected int) { dbPathReal := dbservice.DbPath defer func() { os.Remove(dbservice.DbPath) @@ -116,7 +115,7 @@ func sendInputs(t *testing.T, caseDesc string, inputs []string, policyShowAll bo for _, inp := range inputs { srv := new(MsgService) - srv.MsgHandling([]byte(inp), demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + srv.MsgHandling(inp, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) } demoEmailOutput.wg.Wait() diff --git a/msgservice/logs.go b/msgservice/logs.go deleted file mode 100644 index ea6de6e8..00000000 --- a/msgservice/logs.go +++ /dev/null @@ -1,16 +0,0 @@ -package msgservice - -import "log" - -func prnInputLogs(msg string, v ...interface{}) { - maxLen := 20 - for idx, e := range v { - b, ok := e.([]byte) - if ok { - if l := len(b); l > maxLen { - v[idx] = string(b[:maxLen]) - } - } - } - log.Printf(msg, v...) -} diff --git a/msgservice/msghandling.go b/msgservice/msghandling.go index c89ba539..c89ad318 100644 --- a/msgservice/msghandling.go +++ b/msgservice/msghandling.go @@ -8,7 +8,6 @@ import ( "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/outputs" - "github.com/aquasecurity/postee/regoservice" "github.com/aquasecurity/postee/routes" ) @@ -18,25 +17,13 @@ type MsgService struct { isNew bool } -func (scan *MsgService) MsgHandling(input []byte, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, AquaServer *string) { +func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, AquaServer *string) { if output == nil { return } - in := map[string]interface{}{} - if err := json.Unmarshal(input, &in); err != nil { - prnInputLogs("json.Unmarshal error for %q: %v", input, err) - return - } - - if ok, err := regoservice.DoesMatchRegoCriteria(in, route.Input); err != nil { - prnInputLogs("Error while evaluating rego rule %s :%v for the input %s", route.Input, err, input) - return - } else if !ok { - prnInputLogs("Input %s... doesn't match a REGO rule: %s", input, route.Input) - return - } - + //TODO marshalling message back to bytes, change after merge with https://github.com/aquasecurity/postee/pull/150 + input, _ := json.Marshal(in) if err := scan.init(input); err != nil { log.Println("ScanService.Init Error: Can't init service with data:", input, "\nError:", err) return diff --git a/msgservice/msgservice_mocks_test.go b/msgservice/msgservice_mocks_test.go index 2be598eb..3307a4c5 100644 --- a/msgservice/msgservice_mocks_test.go +++ b/msgservice/msgservice_mocks_test.go @@ -11,11 +11,10 @@ import ( ) var ( - mockScan1 = `{"image":"Demo mock image1","registry":"registry1","vulnerability_summary":{"critical":0,"high":1,"medium":3,"low":4,"negligible":5},"image_assurance_results":{"disallowed":true}}` - mockScan2 = `{"image":"Demo mock Image2","registry":"registry2","vulnerability_summary":{"critical":0,"high":0,"medium":3,"low":4,"negligible":5},"image_assurance_results":{"disallowed":false}}` - mockScan3 = `{"image":"Demo mock Image3","registry":"Registry3","vulnerability_summary":{"critical":0,"high":0,"medium":0,"low":4,"negligible":5},"image_assurance_results":{"disallowed":true}}` - mockScan4 = `{"image":"Demo mock image4","registry":"registry4","vulnerability_summary":{"critical":0,"high":0,"medium":0,"low":0,"negligible":5},"image_assurance_results":{"disallowed":true}}` - mockScan5 = `{"image":"Demo mock image5","registry":"registry5","vulnerability_summary":{"critical":1,"high":2,"medium":3,"low":4,"negligible":5},"image_assurance_results":{"disallowed":true}}` + mockScan1 = map[string]interface{}{"image": "Demo mock image1", "registry": "registry1", "vulnerability_summary": map[string]int{"critical": 0, "high": 1, "medium": 3, "low": 4, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": true}} + mockScan2 = map[string]interface{}{"image": "Demo mock Image2", "registry": "registry2", "vulnerability_summary": map[string]int{"critical": 0, "high": 0, "medium": 3, "low": 4, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": false}} + mockScan3 = map[string]interface{}{"image": "Demo mock Image3", "registry": "Registry3", "vulnerability_summary": map[string]int{"critical": 0, "high": 0, "medium": 0, "low": 4, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": true}} + mockScan4 = map[string]interface{}{"image": "Demo mock image4", "registry": "registry4", "vulnerability_summary": map[string]int{"critical": 0, "high": 0, "medium": 0, "low": 0, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": true}} ) type DemoInptEval struct { diff --git a/msgservice/msgservice_test.go b/msgservice/msgservice_test.go index 5f2c40f4..281e09cb 100644 --- a/msgservice/msgservice_test.go +++ b/msgservice/msgservice_test.go @@ -3,7 +3,6 @@ package msgservice import ( "errors" "os" - "sync" "testing" "github.com/aquasecurity/postee/dbservice" @@ -38,66 +37,6 @@ func (inptEval *FailingInptEval) BuildAggregatedContent(items []map[string]strin func (inptEval *FailingInptEval) IsAggregationSupported() bool { return inptEval.expectedAggrError != nil } - -func TestInputs(t *testing.T) { - tests := []struct { - input []byte - caseDesc string - shouldPass bool - }{ - { - input: nil, - caseDesc: "Empty input", - shouldPass: false, - }, - { - input: []byte(invalidJson), - caseDesc: "Invalid Json", - shouldPass: false, - }, - } - for _, test := range tests { - validateInputValue(t, test.caseDesc, test.input, test.shouldPass) - } - -} -func validateInputValue(t *testing.T, caseDesc string, input []byte, shouldPass bool) { - dbPathReal := dbservice.DbPath - defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal - }() - dbservice.DbPath = "test_webhooks.db" - - demoEmailOutput := &DemoEmailOutput{ - emailCounts: 0, - } - - srvUrl := "" - expected := 0 - if shouldPass { - expected = 1 - } - - demoRoute := &routes.InputRoute{} - - demoRoute.Name = "demo-route" - - demoInptEval := &DemoInptEval{} - - demoEmailOutput.wg = &sync.WaitGroup{} - demoEmailOutput.wg.Add(expected) - - srv := new(MsgService) - srv.MsgHandling([]byte(input), demoEmailOutput, demoRoute, demoInptEval, &srvUrl) - - demoEmailOutput.wg.Wait() - - if demoEmailOutput.getEmailsCount() != expected { - t.Errorf("[%s] Wrong number of Send method calls: expected %d, got %d", caseDesc, expected, demoEmailOutput.getEmailsCount()) - } - -} func TestEvalError(t *testing.T) { dbPathReal := dbservice.DbPath defer func() { @@ -122,7 +61,7 @@ func TestEvalError(t *testing.T) { } srv := new(MsgService) - srv.MsgHandling([]byte(mockScan1), demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + srv.MsgHandling(mockScan1, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) if demoEmailOutput.getEmailsCount() > 0 { t.Errorf("Output shouldn't be called when evaluation is failed") @@ -157,33 +96,10 @@ func TestAggrEvalError(t *testing.T) { for i := 0; i < 2; i++ { srv := new(MsgService) - srv.MsgHandling([]byte(mockScan1), demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + srv.MsgHandling(mockScan1, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) } if demoEmailOutput.getEmailsCount() > 0 { t.Errorf("Output shouldn't be called when evaluation is failed") } } -func TestEmptyInput(t *testing.T) { - dbPathReal := dbservice.DbPath - defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal - }() - dbservice.DbPath = "test_webhooks.db" - - srvUrl := "" - - demoRoute := &routes.InputRoute{} - - demoRoute.Name = "demo-route" - - demoInptEval := &DemoInptEval{} - - srv := new(MsgService) - srv.MsgHandling([]byte("{}"), nil, demoRoute, demoInptEval, &srvUrl) - - if demoInptEval.renderCnt != 0 { - t.Errorf("Eval() shouldn't be called if no output is passed to ResultHandling()") - } -} diff --git a/router/api.go b/router/api.go index 518a15ae..f4c47eed 100644 --- a/router/api.go +++ b/router/api.go @@ -22,12 +22,13 @@ we want to add this as we want the consumer to be able to add a code for extendi when adding a route, the callback function will be part of each route func InputCallBack(inputMessage) (bool, error) */ -type InputCallbackFunc func(InputMessage interface{}) bool +type InputCallbackFunc func(inputMessage map[string]interface{}) bool //SetInputCallbackFunc The call back func will be called as the last evaluation method of the input rego, //it will be added to the rego with && operator and the entire input evaluation will pass through only if the callback returns true -func SetInputCallbackFunc(routeName string, callaback InputCallbackFunc) { +func SetInputCallbackFunc(routeName string, callback InputCallbackFunc) { + Instance().setInputCallbackFunc(routeName, callback) } func WithDefaultConfig(dbPath string) error { diff --git a/router/loads_test.go b/router/loads_test.go index 5dd95f45..c502baeb 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -80,7 +80,7 @@ type invctn struct { routeName string } -func (ctx *ctxWrapper) MsgHandling(input []byte, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) { +func (ctx *ctxWrapper) MsgHandling(input map[string]interface{}, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) { i := invctn{ fmt.Sprintf("%T", output), fmt.Sprintf("%T", inpteval), diff --git a/msgservice/regocriteria_test.go b/router/regocriteria_test.go similarity index 69% rename from msgservice/regocriteria_test.go rename to router/regocriteria_test.go index cd273c1d..1901ea55 100644 --- a/msgservice/regocriteria_test.go +++ b/router/regocriteria_test.go @@ -1,13 +1,4 @@ -package msgservice - -import ( - "os" - "sync" - "testing" - - "github.com/aquasecurity/postee/dbservice" - "github.com/aquasecurity/postee/routes" -) +package router var ( badRego = ` @@ -18,11 +9,15 @@ var ( m == "world" } ` + mockScan1 = map[string]interface{}{"image": "Demo mock image1", "registry": "registry1", "vulnerability_summary": map[string]int{"critical": 0, "high": 1, "medium": 3, "low": 4, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": true}} + mockScan2 = map[string]interface{}{"image": "Demo mock Image2", "registry": "registry2", "vulnerability_summary": map[string]int{"critical": 0, "high": 0, "medium": 3, "low": 4, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": false}} ) +//TODO re-implement +/* func TestRegoCriteria(t *testing.T) { tests := []struct { - input string + input map[string]interface{} caseDesc string regoCriteria string shouldPass bool @@ -57,7 +52,7 @@ func TestRegoCriteria(t *testing.T) { } } -func validateRegoInput(t *testing.T, caseDesc string, input string, regoCriteria string, shouldPass bool) { +func validateRegoInput(t *testing.T, caseDesc string, input map[string]interface{}, regoCriteria string, shouldPass bool) { dbPathReal := dbservice.DbPath defer func() { os.Remove(dbservice.DbPath) @@ -95,3 +90,4 @@ func validateRegoInput(t *testing.T, caseDesc string, input string, regoCriteria } } +*/ diff --git a/router/routehandling_test.go b/router/routehandling_test.go index 8602c55c..c41056cc 100644 --- a/router/routehandling_test.go +++ b/router/routehandling_test.go @@ -1,194 +1,25 @@ package router import ( + "io/ioutil" + "path/filepath" "testing" "time" ) var ( - singleRoute string = ` -Name: tenant - -routes: -- name: route1 - outputs: ["my-slack"] - template: raw - plugins: - Policy-Show-All: true - -templates: -- name: raw - body: | - package postee - result:=input - -outputs: -- name: my-slack - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/TTT` - noAssociatedOutput string = ` -Name: tenant - -routes: -- name: route1 - template: raw - plugins: - Policy-Show-All: true - -templates: -- name: raw - body: | - package postee - result:=input - -outputs: -- name: my-slack - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/TTT` - twoRoutes string = ` -Name: tenant - -routes: -- name: route1 - outputs: ["my-slack"] - template: raw - plugins: - Policy-Show-All: true - -- name: route2 - outputs: ["my-slack"] - template: raw - plugins: - Policy-Show-All: true - -templates: -- name: raw - body: | - package postee - result:=input - -outputs: -- name: my-slack - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/TTT` - - twoOutputs string = ` -Name: tenant - -routes: -- name: route1 - outputs: ["my-slack", "my-slack2"] - template: raw - plugins: - Policy-Show-All: true - -templates: -- name: raw - body: | - package postee - result:=input - -outputs: -- name: my-slack - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/XXX -- name: my-slack2 - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/TTT` - noOutputs string = ` -Name: tenant - -routes: -- name: route1 - outputs: ["my-slack3"] - template: raw - plugins: - Policy-Show-All: true - -templates: -- name: raw - body: | - package postee - result:=input` - noTemplates string = ` -Name: tenant - -routes: -- name: route1 - outputs: ["my-slack", "my-slack2"] - template: raw - plugins: - Policy-Show-All: true - -outputs: -- name: my-slack - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/XXX -- name: my-slack2 - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/TTT` - invalidTemplate string = ` -Name: tenant - -routes: -- name: route1 - outputs: ["my-slack"] - template: rawx - plugins: - Policy-Show-All: true - -templates: -- name: raw - body: | - package postee - result:=input - -outputs: -- name: my-slack - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/TTT` - invalidOutput string = ` -Name: tenant - -routes: -- name: route1 - outputs: ["x-slack"] - template: raw - plugins: - Policy-Show-All: true - -templates: -- name: raw - body: | - package postee - result:=input - -outputs: -- name: my-slack - type: slack - enable: true - url: https://hooks.slack.com/services/ABCDF/1234/TTT` - payload = `{"image" : "alpine"}` ) func TestHandling(t *testing.T) { tests := []struct { caseDesc string - cfg string + cfgPath string expctdInvctns []invctn }{ { "Single Route", - singleRoute, + "single-route.yaml", []invctn{ { "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", @@ -197,7 +28,7 @@ func TestHandling(t *testing.T) { }, { "2 Routes", - twoRoutes, + "two-routes.yaml", []invctn{ { "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", @@ -209,7 +40,7 @@ func TestHandling(t *testing.T) { }, { "2 Outputs per single route", - twoOutputs, + "two-outputs.yaml", []invctn{ { "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", @@ -221,43 +52,78 @@ func TestHandling(t *testing.T) { }, { "No Outputs configured", - noOutputs, + "no-outputs.yaml", []invctn{}, }, { "No Template configured", - noTemplates, + "no-templates.yaml", []invctn{}, }, { "Invalid Output reference", - invalidOutput, + "invalid-output.yaml", []invctn{}, }, { "Invalid Template reference", - invalidTemplate, + "invalid-template.yaml", []invctn{}, }, { "No outputs associated with route", - noAssociatedOutput, + "no-associated-output.yaml", + []invctn{}, + }, + { + "Route with input filter", + "with-input-filter.yaml", + []invctn{ + { + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + }, + }, + }, + { + "Route with input filter - no match", + "with-input-filter-no-match.yaml", + []invctn{}, + }, + { + "Route with input filter (empty)", + "with-input-filter-empty.yaml", + []invctn{ + { + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + }, + }, + }, + { + "Route with input filter - invalid", + "with-input-filter-invalid.yaml", []invctn{}, }, } for _, test := range tests { - runTestRouteHandlingCase(t, test.caseDesc, test.cfg, test.expctdInvctns) + runTestRouteHandlingCase(t, test.caseDesc, test.cfgPath, test.expctdInvctns) } } -func runTestRouteHandlingCase(t *testing.T, caseDesc string, cfg string, expctdInvctns []invctn) { +func runTestRouteHandlingCase(t *testing.T, caseDesc string, cfgPath string, expctdInvctns []invctn) { actualInvctCnt := 0 t.Logf("Case: %s\n", caseDesc) wrap := ctxWrapper{} - wrap.setup(cfg) + + b, err := ioutil.ReadFile(filepath.Join("testdata/configs", cfgPath)) + if err != nil { + t.Errorf("Failed to open file %s, %s", cfgPath, err) + } + + wrap.setup(string(b)) defer wrap.teardown() - err := wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + if err != nil { t.Fatalf("[%s] Unexpected error %v", caseDesc, err) } @@ -304,11 +170,17 @@ func TestInvalidRouteName(t *testing.T) { expctdInvctns := 0 actualInvctCnt := 0 wrap := ctxWrapper{} - wrap.setup(singleRoute) + + b, err := ioutil.ReadFile("testdata/configs/single-route.yaml") + if err != nil { + t.Errorf("Failed to open file %s, %s", "single-route.yaml", err) + } + + wrap.setup(string(b)) defer wrap.teardown() - err := wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -333,11 +205,17 @@ func TestSend(t *testing.T) { expctdInvctns := 1 actualInvctCnt := 0 wrap := ctxWrapper{} - wrap.setup(singleRoute) + + b, err := ioutil.ReadFile("testdata/configs/single-route.yaml") + if err != nil { + t.Errorf("Failed to open file %s, %s", "single-route.yaml", err) + } + + wrap.setup(string(b)) defer wrap.teardown() - err := wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -357,3 +235,71 @@ func TestSend(t *testing.T) { } } } + +func TestCallBack(t *testing.T) { + tests := []struct { + name string + callback InputCallbackFunc + expctdInvctns int + }{ + { + name: "negative response", + callback: func(inputMessage map[string]interface{}) bool { + return false + }, + expctdInvctns: 0, + }, + { + name: "positive response", + callback: func(inputMessage map[string]interface{}) bool { + return true + }, + expctdInvctns: 1, + }, + { + name: "no callback", + callback: nil, + expctdInvctns: 1, + }, + } + b, err := ioutil.ReadFile("testdata/configs/single-route.yaml") + + if err != nil { + t.Errorf("Failed to open file %s, %s", "single-route.yaml", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualInvctCnt := 0 + wrap := ctxWrapper{} + + wrap.setup(string(b)) + + defer wrap.teardown() + + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + if tt.callback != nil { + wrap.instance.setInputCallbackFunc("route1", tt.callback) + } + + wrap.instance.Send([]byte(payload)) + timeout := time.After(1 * time.Second) + for { + select { + case <-timeout: + return + case <-wrap.buff: + actualInvctCnt++ + if actualInvctCnt != tt.expctdInvctns { + t.Errorf("Incorrect number of invocations! expected %d, got %d \n", tt.expctdInvctns, actualInvctCnt) + return + } + } + } + }) + } +} diff --git a/router/router.go b/router/router.go index 5bb0df1c..cc7e003f 100644 --- a/router/router.go +++ b/router/router.go @@ -2,6 +2,7 @@ package router import ( "bytes" + "encoding/json" "fmt" "io/ioutil" "log" @@ -31,17 +32,18 @@ const ( ) type Router struct { - mutexScan sync.Mutex - quit chan struct{} - queue chan []byte - ticker *time.Ticker - stopTicker chan struct{} - cfgfile string - aquaServer string - outputs map[string]outputs.Output - inputRoutes map[string]*routes.InputRoute - templates map[string]data.Inpteval - synchronous bool + mutexScan sync.Mutex + quit chan struct{} + queue chan []byte + ticker *time.Ticker + stopTicker chan struct{} + cfgfile string + aquaServer string + outputs map[string]outputs.Output + inputRoutes map[string]*routes.InputRoute + templates map[string]data.Inpteval + synchronous bool + inputCallBacks map[string][]InputCallbackFunc } var ( @@ -143,6 +145,8 @@ func (ctx *Router) cleanInstance() { ctx.outputs = map[string]outputs.Output{} ctx.inputRoutes = map[string]*routes.InputRoute{} ctx.templates = map[string]data.Inpteval{} + ctx.inputCallBacks = map[string][]InputCallbackFunc{} + ctx.ticker = nil ctx.quit = nil } @@ -298,10 +302,17 @@ func (ctx *Router) load() error { } return nil } +func (ctx *Router) setInputCallbackFunc(routeName string, callback InputCallbackFunc) { + inputCallBacks := ctx.inputCallBacks[routeName] + inputCallBacks = append(inputCallBacks, callback) + + ctx.inputCallBacks[routeName] = inputCallBacks +} func (ctx *Router) addRoute(r *routes.InputRoute) { ctx.inputRoutes[r.Name] = routes.ConfigureAggrTimeout(r) } + func (ctx *Router) deleteRoute(name string) error { r, ok := ctx.inputRoutes[name] if !ok { @@ -379,7 +390,7 @@ func removeOutputFromRoute(r *routes.InputRoute, outputName string) { } type service interface { - MsgHandling(input []byte, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) + MsgHandling(input map[string]interface{}, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) } var getScanService = func() service { @@ -400,6 +411,27 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { log.Printf("route %q has no outputs", routeName) return } + inMsg := map[string]interface{}{} + if err := json.Unmarshal(in, &inMsg); err != nil { + utils.PrnInputLogs("json.Unmarshal error for %q: %v", in, err) + return + } + + if ok, err := regoservice.DoesMatchRegoCriteria(inMsg, r.Input); err != nil { + utils.PrnInputLogs("Error while evaluating rego rule %s :%v for the input %s", r.Input, err, in) + return + } else if !ok { + utils.PrnInputLogs("Input %s... doesn't match a REGO rule: %s", in, r.Input) + return + } + + inputCallbacks := ctx.inputCallBacks[routeName] + + for _, callback := range inputCallbacks { + if !callback(inMsg) { + return + } + } for _, outputName := range r.Outputs { pl, ok := ctx.outputs[outputName] if !ok { @@ -415,9 +447,9 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { log.Printf("route %q is associated with template %q", routeName, r.Template) if ctx.synchronous { - getScanService().MsgHandling(in, pl, r, tmpl, &ctx.aquaServer) + getScanService().MsgHandling(inMsg, pl, r, tmpl, &ctx.aquaServer) } else { - go getScanService().MsgHandling(in, pl, r, tmpl, &ctx.aquaServer) + go getScanService().MsgHandling(inMsg, pl, r, tmpl, &ctx.aquaServer) } } } diff --git a/router/testdata/configs/invalid-output.yaml b/router/testdata/configs/invalid-output.yaml new file mode 100644 index 00000000..f8b82eb7 --- /dev/null +++ b/router/testdata/configs/invalid-output.yaml @@ -0,0 +1,20 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["x-slack"] + template: raw + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/invalid-template.yaml b/router/testdata/configs/invalid-template.yaml new file mode 100644 index 00000000..008224ea --- /dev/null +++ b/router/testdata/configs/invalid-template.yaml @@ -0,0 +1,20 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: rawx + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/no-associated-output.yaml b/router/testdata/configs/no-associated-output.yaml new file mode 100644 index 00000000..04780d0d --- /dev/null +++ b/router/testdata/configs/no-associated-output.yaml @@ -0,0 +1,19 @@ +Name: tenant + +routes: +- name: route1 + template: raw + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/no-outputs.yaml b/router/testdata/configs/no-outputs.yaml new file mode 100644 index 00000000..b1c7ee11 --- /dev/null +++ b/router/testdata/configs/no-outputs.yaml @@ -0,0 +1,14 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack3"] + template: raw + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input \ No newline at end of file diff --git a/router/testdata/configs/no-templates.yaml b/router/testdata/configs/no-templates.yaml new file mode 100644 index 00000000..3fbb55d2 --- /dev/null +++ b/router/testdata/configs/no-templates.yaml @@ -0,0 +1,18 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack", "my-slack2"] + template: raw + plugins: + Policy-Show-All: true + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/XXX +- name: my-slack2 + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/single-route.yaml b/router/testdata/configs/single-route.yaml new file mode 100644 index 00000000..8ca1f4ec --- /dev/null +++ b/router/testdata/configs/single-route.yaml @@ -0,0 +1,20 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: raw + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/two-outputs.yaml b/router/testdata/configs/two-outputs.yaml new file mode 100644 index 00000000..3a076c11 --- /dev/null +++ b/router/testdata/configs/two-outputs.yaml @@ -0,0 +1,24 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack", "my-slack2"] + template: raw + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/XXX +- name: my-slack2 + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/two-routes.yaml b/router/testdata/configs/two-routes.yaml new file mode 100644 index 00000000..89d2d4eb --- /dev/null +++ b/router/testdata/configs/two-routes.yaml @@ -0,0 +1,26 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: raw + plugins: + Policy-Show-All: true + +- name: route2 + outputs: ["my-slack"] + template: raw + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/with-input-filter-empty.yaml b/router/testdata/configs/with-input-filter-empty.yaml new file mode 100644 index 00000000..7ac73f41 --- /dev/null +++ b/router/testdata/configs/with-input-filter-empty.yaml @@ -0,0 +1,21 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: raw + input: + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/with-input-filter-invalid.yaml b/router/testdata/configs/with-input-filter-invalid.yaml new file mode 100644 index 00000000..23f76f96 --- /dev/null +++ b/router/testdata/configs/with-input-filter-invalid.yaml @@ -0,0 +1,26 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: raw + input: | + default input = false + + hello { + m := input.message + m == "world" } + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/with-input-filter-no-match.yaml b/router/testdata/configs/with-input-filter-no-match.yaml new file mode 100644 index 00000000..97852c28 --- /dev/null +++ b/router/testdata/configs/with-input-filter-no-match.yaml @@ -0,0 +1,21 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: raw + input: contains(input.image, "image2") + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/router/testdata/configs/with-input-filter.yaml b/router/testdata/configs/with-input-filter.yaml new file mode 100644 index 00000000..a2ddd08d --- /dev/null +++ b/router/testdata/configs/with-input-filter.yaml @@ -0,0 +1,21 @@ +Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: raw + input: contains(input.image, "alpine") + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT \ No newline at end of file diff --git a/utils/utils.go b/utils/utils.go index 10a4345f..d67ca730 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -54,3 +54,16 @@ func PathExists(name string) bool { _, err := os.Stat(name) return !os.IsNotExist(err) } + +func PrnInputLogs(msg string, v ...interface{}) { + maxLen := 20 + for idx, e := range v { + b, ok := e.([]byte) + if ok { + if l := len(b); l > maxLen { + v[idx] = string(b[:maxLen]) + } + } + } + log.Printf(msg, v...) +} From 3088306550478071fe2d875a2ef8ecdf65816de1 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 24 Aug 2021 20:28:03 +0600 Subject: [PATCH 13/61] fixed race condition --- router/routehandling_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/routehandling_test.go b/router/routehandling_test.go index c41056cc..043be0d1 100644 --- a/router/routehandling_test.go +++ b/router/routehandling_test.go @@ -286,7 +286,7 @@ func TestCallBack(t *testing.T) { wrap.instance.setInputCallbackFunc("route1", tt.callback) } - wrap.instance.Send([]byte(payload)) + wrap.instance.handle([]byte(payload)) timeout := time.After(1 * time.Second) for { select { From b75ef77a49f15d0bffed9386c090151b38e0b035 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 24 Aug 2021 20:29:03 +0600 Subject: [PATCH 14/61] removed not needed test --- router/regocriteria_test.go | 93 ------------------------------------- 1 file changed, 93 deletions(-) delete mode 100644 router/regocriteria_test.go diff --git a/router/regocriteria_test.go b/router/regocriteria_test.go deleted file mode 100644 index 1901ea55..00000000 --- a/router/regocriteria_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package router - -var ( - badRego = ` - default input = false - - hello { - m := input.message - m == "world" - } -` - mockScan1 = map[string]interface{}{"image": "Demo mock image1", "registry": "registry1", "vulnerability_summary": map[string]int{"critical": 0, "high": 1, "medium": 3, "low": 4, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": true}} - mockScan2 = map[string]interface{}{"image": "Demo mock Image2", "registry": "registry2", "vulnerability_summary": map[string]int{"critical": 0, "high": 0, "medium": 3, "low": 4, "negligible": 5}, "image_assurance_results": map[string]interface{}{"disallowed": false}} -) - -//TODO re-implement -/* -func TestRegoCriteria(t *testing.T) { - tests := []struct { - input map[string]interface{} - caseDesc string - regoCriteria string - shouldPass bool - }{ - { - input: mockScan1, - caseDesc: "Empty rule should allow", - regoCriteria: "", - shouldPass: true, - }, - { - input: mockScan1, - caseDesc: "Matching rule", - regoCriteria: `contains(input.image, "image1")`, - shouldPass: true, - }, - { - input: mockScan2, - caseDesc: "Not matching rule", - regoCriteria: `contains(input.image, "image1")`, - shouldPass: false, - }, - { - input: mockScan1, - caseDesc: "Invalid rule", - regoCriteria: badRego, - shouldPass: false, - }, - } - for _, test := range tests { - validateRegoInput(t, test.caseDesc, test.input, test.regoCriteria, test.shouldPass) - } - -} -func validateRegoInput(t *testing.T, caseDesc string, input map[string]interface{}, regoCriteria string, shouldPass bool) { - dbPathReal := dbservice.DbPath - defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal - }() - dbservice.DbPath = "test_webhooks.db" - - demoEmailOutput := &DemoEmailOutput{ - emailCounts: 0, - } - - srvUrl := "" - expected := 0 - if shouldPass { - expected = 1 - } - - demoRoute := &routes.InputRoute{} - - demoRoute.Name = "demo-route" - demoRoute.Input = regoCriteria - - demoInptEval := &DemoInptEval{} - - demoEmailOutput.wg = &sync.WaitGroup{} - demoEmailOutput.wg.Add(expected) - - srv := new(MsgService) - srv.MsgHandling([]byte(input), demoEmailOutput, demoRoute, demoInptEval, &srvUrl) - - demoEmailOutput.wg.Wait() - - if demoEmailOutput.getEmailsCount() != expected { - t.Errorf("[%s] Wrong number of Send method calls: expected %d, got %d", caseDesc, expected, demoEmailOutput.getEmailsCount()) - } - -} -*/ From d4a68d7d21b0455c39b9ec0ccffcf230d63df305 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 24 Aug 2021 20:41:39 +0600 Subject: [PATCH 15/61] added misc fix to clean up input callbacks --- router/router.go | 1 + 1 file changed, 1 insertion(+) diff --git a/router/router.go b/router/router.go index cc7e003f..33ed19e7 100644 --- a/router/router.go +++ b/router/router.go @@ -320,6 +320,7 @@ func (ctx *Router) deleteRoute(name string) error { } r.StopScheduler() delete(ctx.inputRoutes, name) + delete(ctx.inputCallBacks, name) return nil } From 3256a6fe834371cce2eff78e0d410a7e386736bf Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 30 Nov 2021 13:58:10 +0600 Subject: [PATCH 16/61] fixed pointer reference error in loop --- router/loads_test.go | 2 ++ router/routehandling_test.go | 21 ++++++++++++--------- router/router.go | 4 ++-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/router/loads_test.go b/router/loads_test.go index c502baeb..34fde398 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -78,6 +78,7 @@ type invctn struct { outputCls string templateCls string routeName string + found bool } func (ctx *ctxWrapper) MsgHandling(input map[string]interface{}, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) { @@ -85,6 +86,7 @@ func (ctx *ctxWrapper) MsgHandling(input map[string]interface{}, output outputs. fmt.Sprintf("%T", output), fmt.Sprintf("%T", inpteval), route.Name, + false, } ctx.buff <- i } diff --git a/router/routehandling_test.go b/router/routehandling_test.go index 043be0d1..046733c4 100644 --- a/router/routehandling_test.go +++ b/router/routehandling_test.go @@ -22,7 +22,7 @@ func TestHandling(t *testing.T) { "single-route.yaml", []invctn{ { - "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", false, }, }, }, @@ -31,10 +31,10 @@ func TestHandling(t *testing.T) { "two-routes.yaml", []invctn{ { - "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", false, }, { - "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route2", + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route2", false, }, }, }, @@ -43,10 +43,10 @@ func TestHandling(t *testing.T) { "two-outputs.yaml", []invctn{ { - "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", false, }, { - "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", false, }, }, }, @@ -80,7 +80,7 @@ func TestHandling(t *testing.T) { "with-input-filter.yaml", []invctn{ { - "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", false, }, }, }, @@ -94,7 +94,7 @@ func TestHandling(t *testing.T) { "with-input-filter-empty.yaml", []invctn{ { - "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", + "*outputs.SlackOutput", "*regoservice.regoEvaluator", "route1", false, }, }, }, @@ -143,11 +143,14 @@ func runTestRouteHandlingCase(t *testing.T, caseDesc string, cfgPath string, exp return case r := <-wrap.buff: t.Logf("[%s] received invocation (%s, %s, %s)", caseDesc, r.routeName, r.outputCls, r.templateCls) - actualInvctCnt++ found := false - for _, expect := range expctdInvctns { + for i, expect := range expctdInvctns { if r == expect { + actualInvctCnt++ + t.Logf("r: %v\n", r) + t.Logf("expect: %v\n", expect) found = true + expctdInvctns[i].found = true // won't be matched anymore break } } diff --git a/router/router.go b/router/router.go index 33ed19e7..f53671c5 100644 --- a/router/router.go +++ b/router/router.go @@ -278,8 +278,8 @@ func (ctx *Router) load() error { } //---------------------------------------------------- - for _, r := range tenant.InputRoutes { - ctx.addRoute(&r) + for i := range tenant.InputRoutes { + ctx.addRoute(&tenant.InputRoutes[i]) } for _, t := range tenant.Templates { err := ctx.initTemplate(&t) From aea25072bfbe11576482a90b7cb5241c0272f998 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 30 Nov 2021 18:04:49 +0600 Subject: [PATCH 17/61] added default db path --- regoservice/eval.go | 6 +++++- router/api.go | 38 +++++++++++++++++++++------------- router/api_integration_test.go | 2 +- router/initoutputs_test.go | 12 +++++++---- 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/regoservice/eval.go b/regoservice/eval.go index f65164f4..dc4e9431 100644 --- a/regoservice/eval.go +++ b/regoservice/eval.go @@ -203,7 +203,11 @@ func buildAggregatedRego(query *rego.PreparedEvalQuery) (*rego.PreparedEvalQuery ctx := context.Background() //execute query with empty input and check if aggregation package is defined - rs, _ := query.Eval(ctx, rego.EvalInput(make(map[string]interface{}))) + rs, err := query.Eval(ctx, rego.EvalInput(make(map[string]interface{}))) + + if err != nil { + return nil, err + } if len(rs) == 0 || len(rs[0].Expressions) == 0 { return nil, errors.New("no results") //TODO error definition diff --git a/router/api.go b/router/api.go index f4c47eed..a8d52511 100644 --- a/router/api.go +++ b/router/api.go @@ -11,17 +11,9 @@ import ( const ( defaultConfigPath = "config/cfg.yaml" + defaultDbPath = "./webhooks.db" ) -/* -TODO -Is it possible to add a callback func to the route "input", this callback func, will be called when evaluating the input "rego" -and if the callback func returns "false" then the evaluation will fail and the message is not sent. - -we want to add this as we want the consumer to be able to add a code for extending the "input" evaluation. -when adding a route, the callback function will be part of each route -func InputCallBack(inputMessage) (bool, error) -*/ type InputCallbackFunc func(inputMessage map[string]interface{}) bool //SetInputCallbackFunc The call back func will be called as the last evaluation method of the input rego, @@ -31,20 +23,38 @@ func SetInputCallbackFunc(routeName string, callback InputCallbackFunc) { Instance().setInputCallbackFunc(routeName, callback) } -func WithDefaultConfig(dbPath string) error { - return WithFileConfig(defaultConfigPath, dbPath) +func WithDefaultConfig() error { + return WithFileConfig(defaultConfigPath) } -func WithFileConfig(cfgPath, dbPath string) error { +func WithFileConfig(cfgPath string) error { Instance().Terminate() - dbservice.DbPath = dbPath + dbservice.DbPath = defaultDbPath return Instance().ApplyFileCfg(cfgPath, true) } -func WithNewConfig(tenantName, dbPath string) { //tenant name + +func WithNewConfig(tenantName string) { //tenant name + Instance().Terminate() + dbservice.DbPath = defaultDbPath + Instance().initCfg(true) +} + +//initialize instance with custom db location +func WithNewConfigAndDbPath(tenantName, dbPath string) { //tenant name Instance().Terminate() dbservice.DbPath = dbPath Instance().initCfg(true) } +func WithDefaultConfigAndDbPath(dbPath string) error { + return WithFileConfigAndDbPath(defaultConfigPath, dbPath) +} + +func WithFileConfigAndDbPath(cfgPath, dbPath string) error { + Instance().Terminate() + dbservice.DbPath = dbPath + return Instance().ApplyFileCfg(cfgPath, true) +} + func AquaServerUrl(aquaServerUrl string) { //optional Instance().setAquaServerUrl(aquaServerUrl) } diff --git a/router/api_integration_test.go b/router/api_integration_test.go index 9942a47a..1b35ba91 100644 --- a/router/api_integration_test.go +++ b/router/api_integration_test.go @@ -51,7 +51,7 @@ func TestAudit(t *testing.T) { })) defer ts.Close() - router.WithNewConfig("test", "./webhook.db") + router.WithNewConfig("test") err := router.AddTemplate(&data.Template{ Name: "audit-json-template", diff --git a/router/initoutputs_test.go b/router/initoutputs_test.go index 9b39644c..2e44cc31 100644 --- a/router/initoutputs_test.go +++ b/router/initoutputs_test.go @@ -209,12 +209,16 @@ func TestBuildAndInitOtpt(t *testing.T) { }, } for _, test := range tests { - o, _ := buildAndInitOtpt(&test.outputSettings, "") //TODO handle error - if test.shouldFail && o != nil { + o, err := buildAndInitOtpt(&test.outputSettings, "") + + if !test.shouldFail && err != nil { + t.Fatalf("Unexpected error %v", err) + } + + if test.shouldFail && o != nil && err == nil { t.Fatalf("No output expected for %s test case", test.caseDesc) - } else if !test.shouldFail && o == nil { - t.Fatalf("Not expected output returned for %s test case", test.caseDesc) } + actualOutputCls := fmt.Sprintf("%T", o) if actualOutputCls != test.expectedOutputClass { t.Errorf("[%s] Incorrect output type, expected %s, got %s", test.caseDesc, test.expectedOutputClass, actualOutputCls) From 75895b17a96d5c22cb584cde65a518b8ee095749 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 30 Nov 2021 18:31:08 +0600 Subject: [PATCH 18/61] fixed TODO in Jira outputs --- outputs/jira.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/outputs/jira.go b/outputs/jira.go index 47de2955..d5324745 100644 --- a/outputs/jira.go +++ b/outputs/jira.go @@ -64,9 +64,9 @@ func (ctx *JiraAPI) CloneSettings() *data.OutputSettings { FixVersions: data.CopyStringArray(ctx.FixVersions), AffectsVersions: data.CopyStringArray(ctx.AffectsVersions), Labels: data.CopyStringArray(ctx.Labels), - //TODO Unknowns - Enable: true, - Type: "Jira", + Unknowns: cpyUnknowns(ctx.Unknowns), + Enable: true, + Type: "Jira", } } @@ -460,3 +460,11 @@ func isServerJira(rawUrl string) bool { return false } + +func cpyUnknowns(source map[string]string) map[string]string { + dst := make(map[string]string) + for k, v := range source { + dst[k] = v + } + return dst +} From 01e893dbdd446c2db4b70c19ceeaf1671d6ba239 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Tue, 30 Nov 2021 19:26:28 +0600 Subject: [PATCH 19/61] feat(postgres): added support postgres --- dbservice/{ => boltdb}/actions.go | 6 +- dbservice/boltdb/changedbpath_test.go | 18 ++ dbservice/{ => boltdb}/checker.go | 20 +- dbservice/{ => boltdb}/checker_test.go | 77 +++--- dbservice/{ => boltdb}/dbaggregator.go | 6 +- dbservice/{ => boltdb}/dbaggregator_test.go | 15 +- dbservice/{ => boltdb}/dbparam.go | 35 ++- dbservice/{ => boltdb}/dbparam_test.go | 13 +- dbservice/{ => boltdb}/dbservice_test.go | 52 ++-- dbservice/{ => boltdb}/delete.go | 2 +- dbservice/{ => boltdb}/init.go | 2 +- dbservice/{ => boltdb}/insert.go | 2 +- dbservice/{ => boltdb}/invalidinit_test.go | 31 ++- dbservice/{ => boltdb}/plgnstats.go | 10 +- dbservice/{ => boltdb}/plgnstats_test.go | 21 +- dbservice/{ => boltdb}/select.go | 2 +- dbservice/{ => boltdb}/sharedcfg.go | 16 +- dbservice/{ => boltdb}/sharedcfg_test.go | 39 +-- dbservice/changedbpath_test.go | 15 - dbservice/dbservice.go | 63 +++++ dbservice/postgresdb/actions.go | 47 ++++ dbservice/postgresdb/actions_test.go | 64 +++++ dbservice/postgresdb/checker.go | 68 +++++ dbservice/postgresdb/checker_test.go | 92 +++++++ dbservice/postgresdb/dbaggregator.go | 59 ++++ dbservice/postgresdb/dbaggregator_test.go | 114 ++++++++ dbservice/postgresdb/dbparam.go | 91 +++++++ dbservice/postgresdb/dbparam_test.go | 43 +++ dbservice/postgresdb/dbservice_test.go | 288 ++++++++++++++++++++ dbservice/postgresdb/delete.go | 21 ++ dbservice/postgresdb/init.go | 25 ++ dbservice/postgresdb/insert.go | 44 +++ dbservice/postgresdb/invalidinit_test.go | 80 ++++++ dbservice/postgresdb/plgstats.go | 33 +++ dbservice/postgresdb/plgstats_test.go | 90 ++++++ dbservice/postgresdb/sharedcfg.go | 58 ++++ dbservice/postgresdb/sharedcfg_test.go | 106 +++++++ go.mod | 3 + go.sum | 15 + main.go | 5 - msgservice/aggregatebytime_test.go | 14 +- msgservice/aggregatescan_test.go | 13 +- msgservice/applicationscopeowner_test.go | 13 +- msgservice/getuniqueid_test.go | 13 +- msgservice/msghandling.go | 6 +- msgservice/msgservice_test.go | 37 +-- msgservice/regocriteria_test.go | 14 +- router/loads_test.go | 11 +- router/router.go | 12 +- router/tenants.go | 14 +- webserver/webserver.go | 4 +- 51 files changed, 1703 insertions(+), 239 deletions(-) rename dbservice/{ => boltdb}/actions.go (80%) create mode 100644 dbservice/boltdb/changedbpath_test.go rename dbservice/{ => boltdb}/checker.go (75%) rename dbservice/{ => boltdb}/checker_test.go (64%) rename dbservice/{ => boltdb}/dbaggregator.go (91%) rename dbservice/{ => boltdb}/dbaggregator_test.go (86%) rename dbservice/{ => boltdb}/dbparam.go (57%) rename dbservice/{ => boltdb}/dbparam_test.go (85%) rename dbservice/{ => boltdb}/dbservice_test.go (76%) rename dbservice/{ => boltdb}/delete.go (94%) rename dbservice/{ => boltdb}/init.go (91%) rename dbservice/{ => boltdb}/insert.go (94%) rename dbservice/{ => boltdb}/invalidinit_test.go (73%) rename dbservice/{ => boltdb}/plgnstats.go (67%) rename dbservice/{ => boltdb}/plgnstats_test.go (69%) rename dbservice/{ => boltdb}/select.go (94%) rename dbservice/{ => boltdb}/sharedcfg.go (69%) rename dbservice/{ => boltdb}/sharedcfg_test.go (59%) delete mode 100644 dbservice/changedbpath_test.go create mode 100644 dbservice/dbservice.go create mode 100644 dbservice/postgresdb/actions.go create mode 100644 dbservice/postgresdb/actions_test.go create mode 100644 dbservice/postgresdb/checker.go create mode 100644 dbservice/postgresdb/checker_test.go create mode 100644 dbservice/postgresdb/dbaggregator.go create mode 100644 dbservice/postgresdb/dbaggregator_test.go create mode 100644 dbservice/postgresdb/dbparam.go create mode 100644 dbservice/postgresdb/dbparam_test.go create mode 100644 dbservice/postgresdb/dbservice_test.go create mode 100644 dbservice/postgresdb/delete.go create mode 100644 dbservice/postgresdb/init.go create mode 100644 dbservice/postgresdb/insert.go create mode 100644 dbservice/postgresdb/invalidinit_test.go create mode 100644 dbservice/postgresdb/plgstats.go create mode 100644 dbservice/postgresdb/plgstats_test.go create mode 100644 dbservice/postgresdb/sharedcfg.go create mode 100644 dbservice/postgresdb/sharedcfg_test.go diff --git a/dbservice/actions.go b/dbservice/boltdb/actions.go similarity index 80% rename from dbservice/actions.go rename to dbservice/boltdb/actions.go index 999f484d..8a533899 100644 --- a/dbservice/actions.go +++ b/dbservice/boltdb/actions.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "time" @@ -6,11 +6,11 @@ import ( bolt "go.etcd.io/bbolt" ) -func MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) { +func (boltDb *BoltDb) MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) { mutex.Lock() defer mutex.Unlock() - db, err := bolt.Open(DbPath, 0666, nil) + db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { return false, err } diff --git a/dbservice/boltdb/changedbpath_test.go b/dbservice/boltdb/changedbpath_test.go new file mode 100644 index 00000000..d04c5bc5 --- /dev/null +++ b/dbservice/boltdb/changedbpath_test.go @@ -0,0 +1,18 @@ +package boltdb + +import ( + "testing" +) + +func TestChangeDbPath(t *testing.T) { + boltDb := NewBoltDb() + testPath := "/tmp/test.db" + storedPath := boltDb.DbPath + boltDb.ChangeDbPath(testPath) + defer func() { + boltDb.ChangeDbPath(storedPath) + }() + if boltDb.DbPath != testPath { + t.Errorf("path is not configured correctly, expected: %s, got %s", testPath, boltDb.DbPath) + } +} diff --git a/dbservice/checker.go b/dbservice/boltdb/checker.go similarity index 75% rename from dbservice/checker.go rename to dbservice/boltdb/checker.go index 66299f53..d92a8443 100644 --- a/dbservice/checker.go +++ b/dbservice/boltdb/checker.go @@ -1,23 +1,24 @@ -package dbservice +package boltdb import ( "bytes" + "fmt" "log" "time" bolt "go.etcd.io/bbolt" ) -func CheckSizeLimit() { +func (boltDb *BoltDb) CheckSizeLimit() { if DbSizeLimit == 0 { return } mutex.Lock() defer mutex.Unlock() - db, err := bolt.Open(DbPath, 0666, nil) + db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { - log.Println("CheckSizeLimit: Can't open db:", DbPath) + log.Println("CheckSizeLimit: Can't open db:", boltDb.DbPath) return } defer db.Close() @@ -30,6 +31,7 @@ func CheckSizeLimit() { c := b.Cursor() size := 0 for k, v := c.First(); k != nil; k, v = c.Next() { + fmt.Println(string(v)) size += len(v) } if size > DbSizeLimit { @@ -42,18 +44,18 @@ func CheckSizeLimit() { } } -func CheckExpiredData() { +func (boltDb *BoltDb) CheckExpiredData() { mutex.Lock() defer mutex.Unlock() - db, err := bolt.Open(DbPath, 0666, nil) + db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { - log.Println("CheckExpiredData: Can't open db:", DbPath) + log.Println("CheckExpiredData: Can't open db:", boltDb.DbPath) return } defer db.Close() - expired, err := getExpired(db) + expired, err := boltDb.getExpired(db) if err != nil { log.Println("Can't select expired data: ", err) return @@ -64,7 +66,7 @@ func CheckExpiredData() { } } -func getExpired(db *bolt.DB) (keys [][]byte, err error) { +func (boltDb *BoltDb) getExpired(db *bolt.DB) (keys [][]byte, err error) { keys = [][]byte{} ttlKeys := [][]byte{} diff --git a/dbservice/checker_test.go b/dbservice/boltdb/checker_test.go similarity index 64% rename from dbservice/checker_test.go rename to dbservice/boltdb/checker_test.go index 3160bd9f..baf41e10 100644 --- a/dbservice/checker_test.go +++ b/dbservice/boltdb/checker_test.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "os" @@ -9,15 +9,16 @@ import ( ) func TestExpiredDates(t *testing.T) { - dbPathReal := DbPath + boltDb := NewBoltDb() + dbPathReal := boltDb.DbPath realDueTimeBase := dueTimeBase defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(boltDb.DbPath) + boltDb.DbPath = dbPathReal dueTimeBase = realDueTimeBase }() dueTimeBase = time.Nanosecond - DbPath = "test_webhooks.db" + boltDb.DbPath = "test_webhooks.db" tests := []struct { title string delay int @@ -34,12 +35,12 @@ func TestExpiredDates(t *testing.T) { t.Log(test.title) if test.needRun { time.Sleep(time.Duration(test.delay) * time.Second) - CheckExpiredData() + boltDb.CheckExpiredData() } timeToExpire := time.Duration(test.uniqueMessageTimeoutSeconds) * time.Second expired := time.Now().UTC().Add(timeToExpire) - wasStored, err := MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, &expired) + wasStored, err := boltDb.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, &expired) if err != nil { t.Fatal("First Add AlpineImageResult Error", err) @@ -52,14 +53,15 @@ func TestExpiredDates(t *testing.T) { } func TestDbSizeLimnit(t *testing.T) { - dbPathReal := DbPath + boltDb := NewBoltDb() + dbPathReal := boltDb.DbPath realSizeLimit := DbSizeLimit defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(boltDb.DbPath) + boltDb.DbPath = dbPathReal DbSizeLimit = realSizeLimit }() - DbPath = "test_webhooks.db" + boltDb.DbPath = "test_webhooks.db" tests := []struct { title string @@ -73,16 +75,16 @@ func TestDbSizeLimnit(t *testing.T) { } DbSizeLimit = 1 - CheckSizeLimit() + boltDb.CheckSizeLimit() for _, test := range tests { t.Log(test.title) DbSizeLimit = test.limit if test.needRun { - CheckSizeLimit() + boltDb.CheckSizeLimit() } - isNew, err := MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) + isNew, err := boltDb.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if err != nil { t.Fatal("First Add AlpineImageResult Error", err) } @@ -94,18 +96,19 @@ func TestDbSizeLimnit(t *testing.T) { } func TestWrongBuckets(t *testing.T) { + boltDb := NewBoltDb() savedDbBucketName := dbBucketName savedDbBucketExpiryDates := dbBucketExpiryDates - dbPathReal := DbPath + dbPathReal := boltDb.DbPath defer func() { dbBucketName = savedDbBucketName dbBucketExpiryDates = savedDbBucketExpiryDates - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(boltDb.DbPath) + boltDb.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" + boltDb.DbPath = "test_webhooks.db" - _, err := MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) + _, err := boltDb.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if err != nil { t.Fatal(err) } @@ -113,32 +116,33 @@ func TestWrongBuckets(t *testing.T) { DbSizeLimit = 1 dbBucketName = "" dbBucketExpiryDates = "" - CheckSizeLimit() + boltDb.CheckSizeLimit() dbBucketName = "dbBucketName" - _, err = MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) + _, err = boltDb.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if err == nil { t.Error("No error for empty dbBucketExpiryDates") } dbBucketExpiryDates = "dbBucketExpiryDates" dbBucketName = "" - _, err = MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) + _, err = boltDb.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if err == nil { t.Error("No error for empty dbBucketName") } } func TestDbDelete(t *testing.T) { - dbPathReal := DbPath + boltDb := NewBoltDb() + dbPathReal := boltDb.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(boltDb.DbPath) + boltDb.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" + boltDb.DbPath = "test_webhooks.db" - db, err := bolt.Open(DbPath, 0666, nil) + db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { - t.Fatal("Can't open db:", DbPath) + t.Fatal("Can't open db:", boltDb.DbPath) return } defer db.Close() @@ -156,19 +160,20 @@ func TestDbDelete(t *testing.T) { } func TestWithoutAccessToDb(t *testing.T) { - dbPathReal := DbPath + boltDb := NewBoltDb() + dbPathReal := boltDb.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(boltDb.DbPath) + boltDb.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" - db, err := bolt.Open(DbPath, 0220, nil) + boltDb.DbPath = "test_webhooks.db" + db, err := bolt.Open(boltDb.DbPath, 0220, nil) if err != nil { - t.Fatal("Can't open db:", DbPath) + t.Fatal("Can't open db:", boltDb.DbPath) return } db.Close() DbSizeLimit = 1 - CheckSizeLimit() - CheckExpiredData() + boltDb.CheckSizeLimit() + boltDb.CheckExpiredData() } diff --git a/dbservice/dbaggregator.go b/dbservice/boltdb/dbaggregator.go similarity index 91% rename from dbservice/dbaggregator.go rename to dbservice/boltdb/dbaggregator.go index 4db42558..82d0b764 100644 --- a/dbservice/dbaggregator.go +++ b/dbservice/boltdb/dbaggregator.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "encoding/json" @@ -6,14 +6,14 @@ import ( bolt "go.etcd.io/bbolt" ) -func AggregateScans(output string, +func (boltDb *BoltDb) AggregateScans(output string, currentScan map[string]string, scansPerTicket int, ignoreTheQuantity bool) ([]map[string]string, error) { mutex.Lock() defer mutex.Unlock() - db, err := bolt.Open(DbPath, 0666, nil) + db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { return nil, err } diff --git a/dbservice/dbaggregator_test.go b/dbservice/boltdb/dbaggregator_test.go similarity index 86% rename from dbservice/dbaggregator_test.go rename to dbservice/boltdb/dbaggregator_test.go index b8f48131..eb98f489 100644 --- a/dbservice/dbaggregator_test.go +++ b/dbservice/boltdb/dbaggregator_test.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "os" @@ -6,6 +6,7 @@ import ( ) func TestAggregateScans(t *testing.T) { + db := NewBoltDb() var ( scan1 = map[string]string{"title": "t1", "description": "d1"} scan2 = map[string]string{"title": "t2", "description": "d2"} @@ -45,16 +46,16 @@ func TestAggregateScans(t *testing.T) { }, } - dbPathReal := DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" for i := 0; i < len(tests); i++ { test := tests[i] - aggregated, err := AggregateScans(test.output, test.currentScan, test.scansPerTicket, false) + aggregated, err := db.AggregateScans(test.output, test.currentScan, test.scansPerTicket, false) if err != nil { t.Errorf("AggregateScans Error: %v", err) continue @@ -76,7 +77,7 @@ func TestAggregateScans(t *testing.T) { } // Test of existence last scan in DB - lastScan, err := AggregateScans("jira", nil, 0, false) + lastScan, err := db.AggregateScans("jira", nil, 0, false) if err != nil { t.Fatalf("AggregateScans Error: %v", err) } diff --git a/dbservice/dbparam.go b/dbservice/boltdb/dbparam.go similarity index 57% rename from dbservice/dbparam.go rename to dbservice/boltdb/dbparam.go index f0f7df15..eae8544a 100644 --- a/dbservice/dbparam.go +++ b/dbservice/boltdb/dbparam.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "log" @@ -11,25 +11,34 @@ import ( var ( dbBucketName = "WebhookBucket" dbBucketAggregator = "WebhookAggregator" - dbBucketExpiryDates = "WebookExpiryDates" - DbBucketOutputStats = "WebhookOutputStats" - DbBucketSharedConfig = "WebhookSharedConfig" + dbBucketExpiryDates = "WebhookExpiryDates" + dbBucketOutputStats = "WebhookOutputStats" + dbBucketSharedConfig = "WebhookSharedConfig" DbSizeLimit = 0 - dueTimeBase = time.Hour * time.Duration(24) DateFmt = time.RFC3339Nano + dueTimeBase = time.Hour * time.Duration(24) - DbPath = "/server/database/webhooks.db" - mutex sync.Mutex + mutex sync.Mutex ) -func ChangeDbPath(newPath string) { +type BoltDb struct { + DbPath string +} + +func NewBoltDb() *BoltDb { + return &BoltDb{ + DbPath: "/server/database/webhooks.db", + } +} + +func (boltDb *BoltDb) ChangeDbPath(newPath string) { mutex.Lock() - DbPath = newPath + boltDb.DbPath = newPath mutex.Unlock() } -func SetNewDbPathFromEnv() { +func (boltDb *BoltDb) SetNewDbPathFromEnv() { newPath := os.Getenv("PATH_TO_DB") if newPath != "" { if _, err := os.Stat(newPath); err != nil { @@ -44,6 +53,10 @@ func SetNewDbPathFromEnv() { return } } - ChangeDbPath(newPath) + boltDb.ChangeDbPath(newPath) } } + +func (boltDb *BoltDb) SetDbSizeLimit(limit int) { + DbSizeLimit = limit +} diff --git a/dbservice/dbparam_test.go b/dbservice/boltdb/dbparam_test.go similarity index 85% rename from dbservice/dbparam_test.go rename to dbservice/boltdb/dbparam_test.go index 4c131825..a2350d0f 100644 --- a/dbservice/dbparam_test.go +++ b/dbservice/boltdb/dbparam_test.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "os" @@ -8,9 +8,10 @@ import ( ) func TestSetNewDbPathFromEnv(t *testing.T) { + db := NewBoltDb() envPathToDbOld := os.Getenv("PATH_TO_DB") defer os.Setenv("PATH_TO_DB", envPathToDbOld) - dbPathOld := DbPath + dbPathOld := db.DbPath defaultDbPath := "/server/database/webhooks.db" var tests = []struct { @@ -36,12 +37,12 @@ func TestSetNewDbPathFromEnv(t *testing.T) { } os.Chmod(baseDir, 0) } - SetNewDbPathFromEnv() + db.SetNewDbPathFromEnv() defer os.RemoveAll(baseDir) - defer ChangeDbPath(dbPathOld) + defer db.ChangeDbPath(dbPathOld) - if test.expectedDBPath != DbPath { - t.Errorf("[%s] Paths is not equals, expected: %s, got: %s", test.name, test.expectedDBPath, DbPath) + if test.expectedDBPath != db.DbPath { + t.Errorf("[%s] Paths is not equals, expected: %s, got: %s", test.name, test.expectedDBPath, db.DbPath) } }) diff --git a/dbservice/dbservice_test.go b/dbservice/boltdb/dbservice_test.go similarity index 76% rename from dbservice/dbservice_test.go rename to dbservice/boltdb/dbservice_test.go index 85cfd187..4e51b212 100644 --- a/dbservice/dbservice_test.go +++ b/dbservice/boltdb/dbservice_test.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "errors" @@ -76,23 +76,24 @@ var ( ) func TestStoreMessage(t *testing.T) { + db := NewBoltDb() var tests = []struct { input *string }{ {&AlpineImageResult}, } - dbPathReal := DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" for _, test := range tests { // Handling of first scan - isNew, err := MayBeStoreMessage([]byte(*test.input), AlpineImageKey, nil) + isNew, err := db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, nil) if err != nil { t.Errorf("Error: %s\n", err) } @@ -101,7 +102,7 @@ func TestStoreMessage(t *testing.T) { } // Handling of second scan with the same data - isNew, err = MayBeStoreMessage([]byte(*test.input), AlpineImageKey, nil) + isNew, err = db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, nil) if err != nil { t.Errorf("Error: %s\n", err) } @@ -112,11 +113,12 @@ func TestStoreMessage(t *testing.T) { } func TestInitError(t *testing.T) { + db := NewBoltDb() originalInit := Init - originalDbPath := DbPath + originalDbPath := db.DbPath initErr := errors.New("init error") - DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" Init = func(db *bbolt.DB, bucket string) error { return initErr @@ -124,10 +126,10 @@ func TestInitError(t *testing.T) { defer func() { Init = originalInit - os.Remove(DbPath) - DbPath = originalDbPath + os.Remove(db.DbPath) + db.DbPath = originalDbPath }() - isNew, err := MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) + isNew, err := db.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if isNew { t.Errorf("Scan shouldn't be marked as new\n") @@ -139,11 +141,12 @@ func TestInitError(t *testing.T) { } func TestSelectError(t *testing.T) { + db := NewBoltDb() originalDbSelect := dbSelect - originalDbPath := DbPath + originalDbPath := db.DbPath selectErr := errors.New("select error") - DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" dbSelect = func(db *bbolt.DB, bucket, key string) (result []byte, err error) { return nil, selectErr @@ -151,10 +154,10 @@ func TestSelectError(t *testing.T) { defer func() { dbSelect = originalDbSelect - os.Remove(DbPath) - DbPath = originalDbPath + os.Remove(db.DbPath) + db.DbPath = originalDbPath }() - isNew, err := MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) + isNew, err := db.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if isNew { t.Errorf("Scan shouldn't be marked as new\n") @@ -170,7 +173,7 @@ func TestInsertError(t *testing.T) { bucket string }{ {"WebhookBucket"}, - {"WebookExpiryDates"}, + {"WebhookExpiryDates"}, } for _, test := range tests { testBucketInsert(t, test.bucket) @@ -178,11 +181,12 @@ func TestInsertError(t *testing.T) { } func testBucketInsert(t *testing.T, testBucket string) { + db := NewBoltDb() originalDbInsert := dbInsert - originalDbPath := DbPath + originalDbPath := db.DbPath insertErr := errors.New("insert error") - DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" dbInsert = func(db *bbolt.DB, bucket string, key, value []byte) error { if bucket == testBucket { @@ -193,14 +197,14 @@ func testBucketInsert(t *testing.T, testBucket string) { defer func() { dbInsert = originalDbInsert - os.Remove(DbPath) - DbPath = originalDbPath + os.Remove(db.DbPath) + db.DbPath = originalDbPath }() - //expired shouldn't be null to cause insert to 'WebookExpiryDates' bucket + //expired shouldn't be null to cause insert to 'WebhookExpiryDates' bucket timeToExpire := time.Duration(1) * time.Second expired := time.Now().UTC().Add(timeToExpire) - isNew, err := MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, &expired) + isNew, err := db.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, &expired) if isNew { t.Errorf("Scan shouldn't be marked as new\n") diff --git a/dbservice/delete.go b/dbservice/boltdb/delete.go similarity index 94% rename from dbservice/delete.go rename to dbservice/boltdb/delete.go index 0156ab60..0e221685 100644 --- a/dbservice/delete.go +++ b/dbservice/boltdb/delete.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import bolt "go.etcd.io/bbolt" diff --git a/dbservice/init.go b/dbservice/boltdb/init.go similarity index 91% rename from dbservice/init.go rename to dbservice/boltdb/init.go index e1a60057..dbb87f57 100644 --- a/dbservice/init.go +++ b/dbservice/boltdb/init.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import "go.etcd.io/bbolt" diff --git a/dbservice/insert.go b/dbservice/boltdb/insert.go similarity index 94% rename from dbservice/insert.go rename to dbservice/boltdb/insert.go index 80119c15..4cedbfe1 100644 --- a/dbservice/insert.go +++ b/dbservice/boltdb/insert.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import bolt "go.etcd.io/bbolt" diff --git a/dbservice/invalidinit_test.go b/dbservice/boltdb/invalidinit_test.go similarity index 73% rename from dbservice/invalidinit_test.go rename to dbservice/boltdb/invalidinit_test.go index 93db6b10..1d1967a1 100644 --- a/dbservice/invalidinit_test.go +++ b/dbservice/boltdb/invalidinit_test.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "errors" @@ -8,6 +8,8 @@ import ( "go.etcd.io/bbolt" ) +var db = NewBoltDb() + var tests = []struct { caseDesc string errPrvdr func() error @@ -16,13 +18,13 @@ var tests = []struct { { caseDesc: "EnsureApiKey", errPrvdr: func() error { - return EnsureApiKey() + return db.EnsureApiKey() }, }, { caseDesc: "GetApiKey", errPrvdr: func() error { - _, err := GetApiKey() + _, err := db.GetApiKey() return err }, initIsNotCalled: true, @@ -30,33 +32,32 @@ var tests = []struct { { caseDesc: "RegisterPlgnInvctn", errPrvdr: func() error { - return RegisterPlgnInvctn("some-key") + return db.RegisterPlgnInvctn("some-key") }, }, { caseDesc: "MayBeStoreMessage", errPrvdr: func() error { - _, err := MayBeStoreMessage(nil, "a-b-c", nil) + _, err := db.MayBeStoreMessage(nil, "a-b-c", nil) return err }, }, { caseDesc: "AggregateScans", errPrvdr: func() error { - _, err := AggregateScans("", map[string]string{}, 1, false) + _, err := db.AggregateScans("", map[string]string{}, 1, false) return err }, }, } func TestInvalidDbPath(t *testing.T) { - - dbPathReal := DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - DbPath = "/tmp" + db.DbPath = "/tmp" for _, test := range tests { err := test.errPrvdr() @@ -68,13 +69,13 @@ func TestInvalidDbPath(t *testing.T) { } func TestBucketInitialization(t *testing.T) { savedInit := Init - dbPathReal := DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(DbPath) + os.Remove(db.DbPath) Init = savedInit - DbPath = dbPathReal + db.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" expectedError := errors.New("weird error") Init = func(db *bbolt.DB, bucket string) error { return expectedError diff --git a/dbservice/plgnstats.go b/dbservice/boltdb/plgnstats.go similarity index 67% rename from dbservice/plgnstats.go rename to dbservice/boltdb/plgnstats.go index 74b63143..ec407d44 100644 --- a/dbservice/plgnstats.go +++ b/dbservice/boltdb/plgnstats.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "strconv" @@ -6,22 +6,22 @@ import ( bolt "go.etcd.io/bbolt" ) -func RegisterPlgnInvctn(name string) error { +func (boltDb *BoltDb) RegisterPlgnInvctn(name string) error { mutex.Lock() defer mutex.Unlock() - db, err := bolt.Open(DbPath, 0666, nil) + db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { return err } defer db.Close() - err = Init(db, DbBucketOutputStats) + err = Init(db, dbBucketOutputStats) if err != nil { return err } err = db.Update(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(DbBucketOutputStats)) + bucket := tx.Bucket([]byte(dbBucketOutputStats)) var i int v := bucket.Get([]byte(name)) diff --git a/dbservice/plgnstats_test.go b/dbservice/boltdb/plgnstats_test.go similarity index 69% rename from dbservice/plgnstats_test.go rename to dbservice/boltdb/plgnstats_test.go index 7c330ac7..ae90888a 100644 --- a/dbservice/plgnstats_test.go +++ b/dbservice/boltdb/plgnstats_test.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "os" @@ -9,18 +9,19 @@ import ( ) func TestRegisterPlgnInvctn(t *testing.T) { - dbPathReal := DbPath + dbBolt := NewBoltDb() + dbPathReal := dbBolt.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(dbBolt.DbPath) + dbBolt.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" + dbBolt.DbPath = "test_webhooks.db" expectedCnt := 3 keyToTest := "test" for i := 0; i < expectedCnt; i++ { - RegisterPlgnInvctn(keyToTest) + dbBolt.RegisterPlgnInvctn(keyToTest) } - r, err := getPlgnStats() + r, err := getPlgnStats(dbBolt) if err != nil { t.Fatal("error while getting value of API key") } @@ -30,16 +31,16 @@ func TestRegisterPlgnInvctn(t *testing.T) { } -func getPlgnStats() (r map[string]int, err error) { +func getPlgnStats(dbBolt *BoltDb) (r map[string]int, err error) { r = make(map[string]int) - db, err := bolt.Open(DbPath, 0444, nil) + db, err := bolt.Open(dbBolt.DbPath, 0444, nil) if err != nil { return nil, err } defer db.Close() err = db.View(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(DbBucketOutputStats)) + bucket := tx.Bucket([]byte(dbBucketOutputStats)) if bucket == nil { return nil //no bucket - empty stats will be returned } diff --git a/dbservice/select.go b/dbservice/boltdb/select.go similarity index 94% rename from dbservice/select.go rename to dbservice/boltdb/select.go index 3cf935e5..b6893a78 100644 --- a/dbservice/select.go +++ b/dbservice/boltdb/select.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( bolt "go.etcd.io/bbolt" diff --git a/dbservice/sharedcfg.go b/dbservice/boltdb/sharedcfg.go similarity index 69% rename from dbservice/sharedcfg.go rename to dbservice/boltdb/sharedcfg.go index 6fca7d29..dd2c1cbd 100644 --- a/dbservice/sharedcfg.go +++ b/dbservice/boltdb/sharedcfg.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "crypto/rand" @@ -13,17 +13,17 @@ const ( apiKeyName = "POSTEE_API_KEY" ) -func EnsureApiKey() error { +func (boltDb *BoltDb) EnsureApiKey() error { mutex.Lock() defer mutex.Unlock() - db, err := bolt.Open(DbPath, 0666, nil) + db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { return err } defer db.Close() - err = Init(db, DbBucketOutputStats) + err = Init(db, dbBucketOutputStats) if err != nil { return err } @@ -33,19 +33,19 @@ func EnsureApiKey() error { return err } - err = dbInsert(db, DbBucketSharedConfig, []byte(apiKeyName), []byte(newApiKey)) + err = dbInsert(db, dbBucketSharedConfig, []byte(apiKeyName), []byte(newApiKey)) return err } -func GetApiKey() (string, error) { +func (boltDb *BoltDb) GetApiKey() (string, error) { var apiKey string = "" - db, err := bolt.Open(DbPath, 0444, nil) //should be enough + db, err := bolt.Open(boltDb.DbPath, 0444, nil) //should be enough if err != nil { return "", err } defer db.Close() err = db.View(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(DbBucketSharedConfig)) + bucket := tx.Bucket([]byte(dbBucketSharedConfig)) if bucket == nil { return errors.New("no bucket") //no bucket } diff --git a/dbservice/sharedcfg_test.go b/dbservice/boltdb/sharedcfg_test.go similarity index 59% rename from dbservice/sharedcfg_test.go rename to dbservice/boltdb/sharedcfg_test.go index 55b76835..5befd4f6 100644 --- a/dbservice/sharedcfg_test.go +++ b/dbservice/boltdb/sharedcfg_test.go @@ -1,4 +1,4 @@ -package dbservice +package boltdb import ( "os" @@ -6,14 +6,15 @@ import ( ) func TestApiKey(t *testing.T) { - dbPathReal := DbPath + db := NewBoltDb() + dbPathReal := db.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" - EnsureApiKey() - key, err := GetApiKey() + db.DbPath = "test_webhooks.db" + db.EnsureApiKey() + key, err := db.GetApiKey() if err != nil { t.Fatal("error while getting value of API key") } @@ -22,13 +23,14 @@ func TestApiKey(t *testing.T) { } } func TestApiKeyWithoutInit(t *testing.T) { - dbPathReal := DbPath + db := NewBoltDb() + dbPathReal := db.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" - key, err := GetApiKey() + db.DbPath = "test_webhooks.db" + key, err := db.GetApiKey() if err == nil { t.Fatal("Error is expected") } @@ -37,16 +39,17 @@ func TestApiKeyWithoutInit(t *testing.T) { } } func TestApiKeyRenewal(t *testing.T) { - dbPathReal := DbPath + db := NewBoltDb() + dbPathReal := db.DbPath defer func() { - os.Remove(DbPath) - DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" var keys [2]string for i := 0; i < 2; i++ { - EnsureApiKey() - key, err := GetApiKey() + db.EnsureApiKey() + key, err := db.GetApiKey() if err != nil { t.Fatal("error while getting value of API key") } diff --git a/dbservice/changedbpath_test.go b/dbservice/changedbpath_test.go deleted file mode 100644 index cf63ee7b..00000000 --- a/dbservice/changedbpath_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package dbservice - -import "testing" - -func TestChangeDbPath(t *testing.T) { - testPath := "/tmp/test.db" - storedPath := DbPath - ChangeDbPath(testPath) - defer func() { - ChangeDbPath(storedPath) - }() - if DbPath != testPath { - t.Errorf("path is not configured correctly, expected: %s, got %s", testPath, DbPath) - } -} diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go new file mode 100644 index 00000000..d8c822ad --- /dev/null +++ b/dbservice/dbservice.go @@ -0,0 +1,63 @@ +package dbservice + +import ( + "os" + "time" + + "github.com/aquasecurity/postee/dbservice/boltdb" + "github.com/aquasecurity/postee/dbservice/postgresdb" +) + +var ( + Db DbProvider +) + +type DbSettings struct { + DBMaxSize int `json:"max-db-size,omitempty"` + DBRemoveOldData int `json:"delete-old-data,omitempty"` + DBTestInterval int `json:"db-verify-interval,omitempty"` + + //PostgresDb + DbName string + DbHostName string + DbPort string + DbUser string + DbPassword string + DbSslMode string + + //BoltDb + DbPath string +} +type DbProvider interface { + MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) + CheckSizeLimit() + CheckExpiredData() + AggregateScans(output string, currentScan map[string]string, scansPerTicket int, ignoreTheQuantity bool) ([]map[string]string, error) + RegisterPlgnInvctn(name string) error + EnsureApiKey() error + GetApiKey() (string, error) + SetDbSizeLimit(limit int) +} + +func ConfigureDb(settings *DbSettings, id string) error { + if settings.DBTestInterval == 0 { + settings.DBTestInterval = 1 + } + + if settings.DbName == "" && settings.DbHostName == "" && settings.DbUser == "" { + boltdb := boltdb.NewBoltDb() + if os.Getenv("PATH_TO_DB") != "" { + boltdb.SetNewDbPathFromEnv() + } + Db = boltdb + return nil + } else { + db, err := postgresdb.NewPostgresDb(id, settings.DbName, settings.DbHostName, settings.DbPort, settings.DbUser, settings.DbPassword, settings.DbSslMode) + if err != nil { + return err + } + Db = db + } + Db.SetDbSizeLimit(settings.DBMaxSize) + return nil +} diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go new file mode 100644 index 00000000..55cc8418 --- /dev/null +++ b/dbservice/postgresdb/actions.go @@ -0,0 +1,47 @@ +package postgresdb + +import ( + "database/sql" + "fmt" + "time" +) + +func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) { + db, err := psqlConnect(postgresDb.psqlInfo) + if err != nil { + return false, err + } + defer db.Close() + + if err = initTable(db, dbTableName); err != nil { + return false, err + } + + if err = initTable(db, dbTableExpiryDates); err != nil { + return false, err + } + + currentValue := "" + if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "messageValue", dbTableName, "id", "messageKey"), postgresDb.id, messageKey); err != nil { + if err != sql.ErrNoRows { + return false, err + } + } + + if currentValue != "" { + return false, nil + } else { + + if err = insert(db, dbTableName, postgresDb.id, "messagekey", messageKey, "messagevalue", string(message)); err != nil { + return false, err + } + if expired != nil { + + if err = insert(db, dbTableExpiryDates, postgresDb.id, "date", expired.Format(DateFmt), "messagekey", messageKey); err != nil { + return false, err + } + } + return true, nil + } + +} diff --git a/dbservice/postgresdb/actions_test.go b/dbservice/postgresdb/actions_test.go new file mode 100644 index 00000000..a139e39e --- /dev/null +++ b/dbservice/postgresdb/actions_test.go @@ -0,0 +1,64 @@ +package postgresdb + +import ( + "log" + "testing" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +func TestStoreMessage(t *testing.T) { + currentValueStoreMessage := "" + + savedInitTable := initTable + initTable = func(db *sqlx.DB, tableName string) error { return nil } + savedInsert := insert + insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { + currentValueStoreMessage = value3 + return nil + } + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"messagevalue"}).AddRow(currentValueStoreMessage) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + defer func() { + initTable = savedInitTable + insert = savedInsert + psqlConnect = savedPsqlConnect + }() + + var tests = []struct { + input *string + }{ + {&AlpineImageResult}, + } + + for _, test := range tests { + + // Handling of first scan + isNew, err := db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, nil) + if err != nil { + t.Errorf("Error: %s\n", err) + } + if !isNew { + t.Errorf("A first scan was found!\n") + } + + // Handling of second scan with the same data + isNew, err = db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, nil) + if err != nil { + t.Errorf("Error: %s\n", err) + } + if isNew { + t.Errorf("A old scan wasn't found!\n") + } + } + +} diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go new file mode 100644 index 00000000..da4bcd2d --- /dev/null +++ b/dbservice/postgresdb/checker.go @@ -0,0 +1,68 @@ +package postgresdb + +import ( + "fmt" + "log" + "time" +) + +func (postgresDb *PostgresDb) CheckSizeLimit() { + if DbSizeLimit == 0 { + return + } + + psqlInfo := postgresDb.psqlInfo + db, err := psqlConnect(psqlInfo) + if err != nil { + log.Println("CheckSizeLimit: Can't open db, psqlInfo: ", psqlInfo) + return + } + defer db.Close() + + size := 0 + if err = db.Get(&size, fmt.Sprintf("SELECT pg_total_relation_size('%s');", dbTableName)); err != nil { + log.Printf("CheckSizeLimit: Can't get db size") + return + } + if size > DbSizeLimit { + if err = deleteRowsById(db, dbTableName, postgresDb.id); err != nil { + log.Printf("CheckSizeLimit: Can't delete id's: %s from table: %s", postgresDb.id, dbTableName) + return + } + } +} + +func (postgresDb *PostgresDb) CheckExpiredData() { + psqlInfo := postgresDb.psqlInfo + db, err := psqlConnect(psqlInfo) + if err != nil { + log.Println("CheckExpiredData: Can't open db, psqlInfo: ", psqlInfo) + return + } + defer db.Close() + + var scanStructs []struct { + Date string `db:"date"` + TtlKey string `db:"messagekey"` + } + if err := db.Select(&scanStructs, fmt.Sprintf("SELECT (key AND ttlkey) FROM %s WHERE %s=$1", dbTableExpiryDates, "id"), postgresDb.id); err != nil { + log.Printf("CheckExpiredData: Can't get %s table: %s", dbTableExpiryDates, err) + return + } + + max := time.Now().UTC().Format(DateFmt) //remove expired records + for _, scanStruct := range scanStructs { + if scanStruct.Date <= max { + + if err = deleteRow(db, dbTableExpiryDates, postgresDb.id, "messagekey", scanStruct.TtlKey); err != nil { + log.Printf("CheckExpiredData: Can't delete %s from table:%s", scanStruct.TtlKey, dbTableExpiryDates) + return + } + + if err = deleteRow(db, dbTableName, postgresDb.id, "messagekey", scanStruct.TtlKey); err != nil { + log.Printf("CheckExpiredData: Can't delete %s from table:%s", scanStruct.TtlKey, dbTableName) + return + } + } + } +} diff --git a/dbservice/postgresdb/checker_test.go b/dbservice/postgresdb/checker_test.go new file mode 100644 index 00000000..fcfb61d8 --- /dev/null +++ b/dbservice/postgresdb/checker_test.go @@ -0,0 +1,92 @@ +package postgresdb + +import ( + "log" + "testing" + "time" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +func TestExpiredDates(t *testing.T) { + tests := []struct { + name string + time time.Time + wasDeleted bool + }{ + {"Time before Now", time.Now().UTC().Add(time.Duration(1) * time.Hour), false}, + {"Time after Now", time.Now().UTC().Add(time.Duration(-1) * time.Hour), true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + deleted := false + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"date", "messagekey"}).AddRow(test.time, "ttlKeyTest") + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + savedDeleteRow := deleteRow + deleteRow = func(db *sqlx.DB, table, id, columnName, value string) error { + deleted = true + return nil + } + defer func() { + psqlConnect = savedPsqlConnect + deleteRow = savedDeleteRow + }() + db.CheckExpiredData() + if deleted != test.wasDeleted { + t.Errorf("error deleted rows") + } + }) + } +} + +func TestSizeLimit(t *testing.T) { + tests := []struct { + name string + sizeLimit int + size int + wasDeleted bool + }{ + {"No size limit", 0, 10, false}, + {"Size less then limit", 5, 10, true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + deleted := false + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"size"}).AddRow(test.size) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + savedDeleteRowsById := deleteRowsById + deleteRowsById = func(db *sqlx.DB, table, id string) error { + deleted = true + return nil + } + defer func() { + psqlConnect = savedPsqlConnect + deleteRowsById = savedDeleteRowsById + }() + db.SetDbSizeLimit(test.sizeLimit) + db.CheckSizeLimit() + if deleted != test.wasDeleted { + t.Errorf("error deleted rows") + } + }) + } +} diff --git a/dbservice/postgresdb/dbaggregator.go b/dbservice/postgresdb/dbaggregator.go new file mode 100644 index 00000000..dc6ad702 --- /dev/null +++ b/dbservice/postgresdb/dbaggregator.go @@ -0,0 +1,59 @@ +package postgresdb + +import ( + "database/sql" + "encoding/json" + "fmt" +) + +func (postgresDb *PostgresDb) AggregateScans(output string, + currentScan map[string]string, + scansPerTicket int, + ignoreTheQuantity bool) ([]map[string]string, error) { + + db, err := psqlConnect(postgresDb.psqlInfo) + if err != nil { + return nil, err + } + defer db.Close() + + if err = initTable(db, dbTableAggregator); err != nil { + return nil, err + } + + aggregatedScans := make([]map[string]string, 0, scansPerTicket) + if len(currentScan) > 0 { + aggregatedScans = append(aggregatedScans, currentScan) + } + currentValue := "" + if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "saving", dbTableAggregator, "id", "output"), postgresDb.id, output); err != nil { + if err != sql.ErrNoRows { + return nil, err + } + } + + if currentValue != "" { + var savedScans []map[string]string + err = json.Unmarshal([]byte(currentValue), &savedScans) + if err != nil { + return nil, err + } + aggregatedScans = append(aggregatedScans, savedScans...) + } + + if ignoreTheQuantity || len(aggregatedScans) < scansPerTicket { + saving, err := json.Marshal(aggregatedScans) + if err != nil { + return nil, err + } + if err = insert(db, dbTableAggregator, postgresDb.id, "output", output, "saving", string(saving)); err != nil { + + return nil, err + } + return nil, nil + } + if err = insert(db, dbTableAggregator, postgresDb.id, "output", output, "saving", ""); err != nil { + return nil, err + } + return aggregatedScans, nil +} diff --git a/dbservice/postgresdb/dbaggregator_test.go b/dbservice/postgresdb/dbaggregator_test.go new file mode 100644 index 00000000..869797d2 --- /dev/null +++ b/dbservice/postgresdb/dbaggregator_test.go @@ -0,0 +1,114 @@ +package postgresdb + +import ( + "log" + "testing" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +func TestAggregateScans(t *testing.T) { + var ( + scan1 = map[string]string{"title": "t1", "description": "d1"} + scan2 = map[string]string{"title": "t2", "description": "d2"} + scan3 = map[string]string{"title": "t3", "description": "d3"} + scan4 = map[string]string{"title": "t4", "description": "d4"} + ) + + var tests = []struct { + output string + currentScan map[string]string + scansPerTicket int + want []map[string]string + }{ + { + "jira", + scan1, + 3, + nil, + }, + { + "jira", + scan2, + 3, + nil, + }, + { + "jira", + scan3, + 3, + []map[string]string{scan3, scan2, scan1}, + }, + { + "jira", + scan4, + 3, + nil, + }, + } + + saving := "" + for i := 0; i < len(tests); i++ { + savedInitTable := initTable + initTable = func(db *sqlx.DB, tableName string) error { return nil } + savedInsert := insert + insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { + saving = value3 + return nil + } + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"saving"}).AddRow(saving) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + defer func() { + initTable = savedInitTable + insert = savedInsert + psqlConnect = savedPsqlConnect + }() + + test := tests[i] + aggregated, err := db.AggregateScans(test.output, test.currentScan, test.scansPerTicket, false) + if err != nil { + t.Errorf("AggregateScans Error: %v", err) + continue + } + + if len(aggregated) != len(test.want) { + t.Errorf("Wrong result size\nResult: %v\nWaited: %v", aggregated, test.want) + continue + } + + for i := 0; i < len(aggregated); i++ { + if aggregated[i]["title"] != test.want[i]["title"] { + t.Errorf("Wrong title\nResult: %q\nWaited: %q", aggregated[i]["title"], test.want[i]["title"]) + } + if aggregated[i]["description"] != test.want[i]["description"] { + t.Errorf("Wrong Description\nResult: %q\nWaited: %q", aggregated[i]["description"], test.want[i]["description"]) + } + } + } + + // Test of existence last scan in DB + lastScan, err := db.AggregateScans("jira", nil, 0, false) + if err != nil { + t.Fatalf("AggregateScans Error: %v", err) + } + + if len(lastScan) != 1 { + t.Fatalf("Db don't contain last scan") + } + + if lastScan[0]["title"] != scan4["title"] { + t.Errorf("Wrong title\nResult: %q\nWaited: %q", lastScan[0]["title"], scan4["title"]) + } + if lastScan[0]["description"] != scan4["description"] { + t.Errorf("Wrong Description\nResult: %q\nWaited: %q", lastScan[0]["description"], scan4["description"]) + } +} diff --git a/dbservice/postgresdb/dbparam.go b/dbservice/postgresdb/dbparam.go new file mode 100644 index 00000000..6ebea8dd --- /dev/null +++ b/dbservice/postgresdb/dbparam.go @@ -0,0 +1,91 @@ +package postgresdb + +import ( + "errors" + "log" + "strings" + "time" + + "github.com/aquasecurity/postee/utils" + "github.com/jmoiron/sqlx" +) + +var ( + dbTableName = "WebhookTable" + dbTableAggregator = "WebhookAggregator" + dbTableExpiryDates = "WebhookExpiryDates" + dbTableOutputStats = "WebhookOutputStats" + dbTableSharedConfig = "WebhookSharedConfig" + + DbSizeLimit = 0 + DateFmt = time.RFC3339Nano + dueTimeBase = time.Hour * time.Duration(24) +) + +type PostgresDb struct { + psqlInfo string + id string +} + +func NewPostgresDb(id, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) (*PostgresDb, error) { + info, err := buildPsqlInfo(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode) + if err != nil { + return nil, err + } + return &PostgresDb{ + psqlInfo: info, + id: id, + }, nil +} + +func buildPsqlInfo(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) (string, error) { + psqlInfo := []string{} + + if dbHostName != "" { + dbHostName = utils.GetEnvironmentVarOrPlain(dbHostName) + psqlInfo = append(psqlInfo, "host="+dbHostName) + } else { + log.Printf("dbHostName is empty, for psqlInfo is used dbHostName=localhost") + } + if dbPort != "" { + dbPort = utils.GetEnvironmentVarOrPlain(dbPort) + psqlInfo = append(psqlInfo, "port="+dbPort) + } else { + log.Printf("dbPort is empty, for psqlInfo is used dbPort=5432") + } + if dbName != "" { + dbName = utils.GetEnvironmentVarOrPlain(dbName) + psqlInfo = append(psqlInfo, "dbname="+dbName) + } else { + return "", errors.New("can't build psqlInfo, dbName is empty") + } + if dbUser != "" { + dbUser = utils.GetEnvironmentVarOrPlain(dbUser) + psqlInfo = append(psqlInfo, "user="+dbUser) + } else { + return "", errors.New("can't build psqlInfo, dbUser is empty") + } + if dbPassword != "" { + dbPassword = utils.GetEnvironmentVarOrPlain(dbPassword) + psqlInfo = append(psqlInfo, "password="+dbPassword) + } + if dbSslMode != "" { + psqlInfo = append(psqlInfo, "sslmode="+dbSslMode) + } else { + log.Printf("dbSslMode is empty, for psqlInfo is used sslmode=disable") + psqlInfo = append(psqlInfo, "sslmode="+"disable") + } + return strings.Join(psqlInfo[:], " "), nil +} + +func (postgresDb *PostgresDb) SetDbSizeLimit(limit int) { + DbSizeLimit = limit +} + +var psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, err := sqlx.Connect("postgres", psqlInfo) + if err != nil { + return nil, err + } + return db, nil +} diff --git a/dbservice/postgresdb/dbparam_test.go b/dbservice/postgresdb/dbparam_test.go new file mode 100644 index 00000000..75977592 --- /dev/null +++ b/dbservice/postgresdb/dbparam_test.go @@ -0,0 +1,43 @@ +package postgresdb + +import ( + "testing" +) + +func TestBuildPsqlInfo(t *testing.T) { + var tests = []struct { + name string + dbName string + dbHostName string + dbPort string + dbUser string + dbPassword string + dbSslMode string + expectedPsqlInfo string + expectedError string + }{ + {"empty dbName", "", "dbHostName", "dbPort", "dbUser", "dbPassword", "dbSslMode", + "", "can't build psqlInfo, dbName is empty"}, + {"empty dbHostName", "dbName", "", "dbPort", "dbUser", "dbPassword", "dbSslMode", + "port=dbPort dbname=dbName user=dbUser password=dbPassword sslmode=dbSslMode", ""}, + {"empty dbPort", "dbName", "dbHostName", "", "dbUser", "dbPassword", "dbSslMode", + "host=dbHostName dbname=dbName user=dbUser password=dbPassword sslmode=dbSslMode", ""}, + {"empty dbUser", "dbName", "dbHostName", "dbPort", "", "dbPassword", "dbSslMode", + "", "can't build psqlInfo, dbUser is empty"}, + {"empty dbPassword", "dbName", "dbHostName", "dbPort", "dbUser", "", "dbSslMode", + "host=dbHostName port=dbPort dbname=dbName user=dbUser sslmode=dbSslMode", ""}, + {"empty dbSslMode", "dbName", "dbHostName", "dbPort", "dbUser", "dbPassword", "", + "host=dbHostName port=dbPort dbname=dbName user=dbUser password=dbPassword sslmode=disable", ""}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + psqlInfo, err := buildPsqlInfo(test.dbName, test.dbHostName, test.dbPort, test.dbUser, test.dbPassword, test.dbSslMode) + if err != nil && err.Error() != test.expectedError { + t.Errorf("Unexpected error for %s, expected %v, got %v", test.name, test.expectedError, err) + } + if test.expectedPsqlInfo != psqlInfo { + t.Errorf("error getting psqlInfo, expected:%s, got:%s", test.expectedPsqlInfo, psqlInfo) + } + }) + } +} diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go new file mode 100644 index 00000000..bebd9327 --- /dev/null +++ b/dbservice/postgresdb/dbservice_test.go @@ -0,0 +1,288 @@ +package postgresdb + +import ( + "errors" + "log" + "testing" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +var ( + AlpineImageKey = "sha256:c8bccc0af9571ec0d006a43acb5a8d08c4ce42b6cc7194dd6eb167976f501ef1-alpine:3.8-Docker Hub" + AlpineImageResult = `{ + "image": "alpine:3.8", + "registry": "Docker Hub", + "digest": "sha256:c8bccc0af9571ec0d006a43acb5a8d08c4ce42b6cc7194dd6eb167976f501ef1", + "previous_digest": "sha256:c8bccc0af9571ec0d006a43acb5a8d08c4ce42b6cc7194dd6eb167976f501ef1", + "image_assurance_results": { + "disallowed": true, + "checks_performed": [ + { + "control": "max_severity", + "policy_name": "Default", + "failed": false + }, + { + "control": "trusted_base_images", + "policy_name": "Default", + "failed": true + }, + { + "control": "max_score", + "policy_name": "Default", + "failed": false + } + ] + }, + "vulnerability_summary": { + "total": 2, + "critical": 0, + "high": 0, + "medium": 2, + "low": 0, + "negligible": 0, + "sensitive": 0, + "malware": 0 + }, + "scan_options": { + "scan_sensitive_data": true, + "scan_malware": true + }, + "resources": [ + { + "vulnerabilities": [ + { + "name": "CVE-2018-20679", + "version": "", + "fix_version": "", + "aqua_severity": "medium" + }, + { + "name": "CVE-2019-5747", + "version": "", + "fix_version": "", + "aqua_severity": "medium" + } + ], + "resource": { + "name": "busybox", + "version": "1.28.4-r3" + } + } + ] + }` + + db, _ = NewPostgresDb("id", "dbName", "", "", "user", "password", "disable") +) + +func TestInitError(t *testing.T) { + initErr := errors.New("init error") + + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectExec("CREATE").WillReturnError(initErr) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + isNew, err := db.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) + + if isNew { + t.Errorf("Scan shouldn't be marked as new\n") + } + + if err.Error() != initErr.Error() { + t.Errorf("Unexpected error: expected %s, got %s \n", initErr, err) + } +} + +func TestDeleteRow(t *testing.T) { + t.Log("happy delete row") + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + savedDeleteRow := deleteRow + defer func() { + deleteRow = savedDeleteRow + }() + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectExec("DELETE").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.psqlInfo) + err := deleteRow(psqlDb, "table", "id", "column", "value") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + t.Log("bad delete row") + deleteError := errors.New("delete - error") + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + savedDeleteRow := deleteRow + defer func() { + deleteRow = savedDeleteRow + }() + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectExec("DELETE").WillReturnError(deleteError) + return db, err + } + psqlDb, _ = psqlConnect(db.psqlInfo) + err = deleteRow(psqlDb, "table", "id", "column", "value") + if deleteError != err { + t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) + } +} +func TestDeleteRowsById(t *testing.T) { + t.Log("happy delete rows by id") + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + savedDeleteRowsById := deleteRowsById + defer func() { + deleteRowsById = savedDeleteRowsById + }() + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectExec("DELETE").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.psqlInfo) + err := deleteRowsById(psqlDb, "table", "id") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + t.Log("bad delete row") + deleteError := errors.New("delete - error") + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + savedDeleteRowsById := deleteRowsById + defer func() { + deleteRowsById = savedDeleteRowsById + }() + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectExec("DELETE").WillReturnError(deleteError) + return db, err + } + psqlDb, _ = psqlConnect(db.psqlInfo) + err = deleteRowsById(psqlDb, "table", "id") + if deleteError != err { + t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) + } +} + +func TestInsert(t *testing.T) { + t.Log("happy insert") + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.psqlInfo) + err := insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + t.Log("happy update") + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + psqlDb, _ = psqlConnect(db.psqlInfo) + err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + t.Log("bad select") + badSelectError := errors.New("bad select") + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(badSelectError) + return db, err + } + psqlDb, _ = psqlConnect(db.psqlInfo) + err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + if err != badSelectError { + t.Errorf("Unexpected error, expected: %v, got: %v", badSelectError, err) + } + + t.Log("bad insert") + badInsertError := errors.New("bad insert") + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnError(badInsertError) + return db, err + } + psqlDb, _ = psqlConnect(db.psqlInfo) + err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + if err != badInsertError { + t.Errorf("Unexpected error, expected: %v, got: %v", badInsertError, err) + } + + t.Log("bad update") + badUpdateError := errors.New("bad update") + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) + return db, err + } + psqlDb, _ = psqlConnect(db.psqlInfo) + err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + if err != badUpdateError { + t.Errorf("Unexpected error, expected: %v, got: %v", badUpdateError, err) + } +} diff --git a/dbservice/postgresdb/delete.go b/dbservice/postgresdb/delete.go new file mode 100644 index 00000000..26a20ff3 --- /dev/null +++ b/dbservice/postgresdb/delete.go @@ -0,0 +1,21 @@ +package postgresdb + +import ( + "fmt" + + "github.com/jmoiron/sqlx" +) + +var deleteRow = func(db *sqlx.DB, table, id, columnName, value string) error { + if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE (%s=$1 AND %s=$2);", table, "id", columnName), id, value); err != nil { + return err + } + return nil +} + +var deleteRowsById = func(db *sqlx.DB, table, id string) error { + if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s=$1", table, "id"), id); err != nil { + return err + } + return nil +} diff --git a/dbservice/postgresdb/init.go b/dbservice/postgresdb/init.go new file mode 100644 index 00000000..44e087f2 --- /dev/null +++ b/dbservice/postgresdb/init.go @@ -0,0 +1,25 @@ +package postgresdb + +import ( + "fmt" + + "github.com/jmoiron/sqlx" +) + +var ( + tableSchemas = map[string]string{ + dbTableName: "CREATE TABLE IF NOT EXISTS %s (id text, messagekey text,messagevalue text);", + dbTableAggregator: "CREATE TABLE IF NOT EXISTS %s (id text, output text,saving text);", + dbTableExpiryDates: "CREATE TABLE IF NOT EXISTS %s (id text, date text,messageKey text);", + dbTableOutputStats: "CREATE TABLE IF NOT EXISTS %s (id text, outputname text,amount integer);", + dbTableSharedConfig: "CREATE TABLE IF NOT EXISTS %s (id text, apikeyname text,value text);", + } +) + +var initTable = func(db *sqlx.DB, tableName string) error { + _, err := db.Exec(fmt.Sprintf(tableSchemas[tableName], tableName)) + if err != nil { + return err + } + return nil +} diff --git a/dbservice/postgresdb/insert.go b/dbservice/postgresdb/insert.go new file mode 100644 index 00000000..e7e46abe --- /dev/null +++ b/dbservice/postgresdb/insert.go @@ -0,0 +1,44 @@ +package postgresdb + +import ( + "fmt" + + "github.com/jmoiron/sqlx" +) + +var insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { + var i int + if err := db.Get(&i, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (%s=$1 AND %s=$2)", table, "id", columnName2), id, value2); err != nil { + return err + } + if i == 0 { + if _, err := db.Exec(fmt.Sprintf("INSERT INTO %s (%s, %s, %s) VALUES ($1, $2, $3)", table, "id", columnName2, columnName3), id, value2, value3); err != nil { + return err + } + } else { + if _, err := db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (%s=$2 AND %s=$3);", table, columnName3, "id", columnName2), value3, id, value2); err != nil { + return err + } + } + return nil +} + +var insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { + var i int + err := db.Get(&i, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (%s=$1 AND %s=$2)", dbTableOutputStats, "id", "outputName"), id, outputName) + if err != nil { + return err + } + if i == 0 { + _, err := db.Exec(fmt.Sprintf("INSERT INTO %s (%s, %s, %s) VALUES ($1, $2, $3);", dbTableOutputStats, "id", "outputName", "amount"), id, outputName, amount) + if err != nil { + return err + } + } else { + _, err = db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (%s=$2 AND %s=$3);", dbTableOutputStats, "amount", "id", "outputName"), amount, id, outputName) + if err != nil { + return err + } + } + return nil +} diff --git a/dbservice/postgresdb/invalidinit_test.go b/dbservice/postgresdb/invalidinit_test.go new file mode 100644 index 00000000..14db5710 --- /dev/null +++ b/dbservice/postgresdb/invalidinit_test.go @@ -0,0 +1,80 @@ +package postgresdb + +import ( + "errors" + "log" + "testing" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +var tests = []struct { + caseDesc string + errPrvdr func() error + initIsNotCalled bool +}{ + { + caseDesc: "EnsureApiKey", + errPrvdr: func() error { + return db.EnsureApiKey() + }, + }, + { + caseDesc: "GetApiKey", + errPrvdr: func() error { + _, err := db.GetApiKey() + return err + }, + initIsNotCalled: true, + }, + { + caseDesc: "RegisterPlgnInvctn", + errPrvdr: func() error { + return db.RegisterPlgnInvctn("some-key") + }, + }, + { + caseDesc: "MayBeStoreMessage", + errPrvdr: func() error { + _, err := db.MayBeStoreMessage(nil, "a-b-c", nil) + return err + }, + }, + { + caseDesc: "AggregateScans", + errPrvdr: func() error { + _, err := db.AggregateScans("", map[string]string{}, 1, false) + return err + }, + }, +} + +func TestBucketInitialization(t *testing.T) { + expectedError := errors.New("weird error") + savedInitTable := initTable + savedPsqlConnect := psqlConnect + initTable = func(db *sqlx.DB, tableName string) error { return expectedError } + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, _, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + return db, err + } + defer func() { + initTable = savedInitTable + psqlConnect = savedPsqlConnect + }() + + for _, test := range tests { + if test.initIsNotCalled { + continue + } + err := test.errPrvdr() + if err != expectedError { + t.Errorf("Unexpected error for %s call, expected %v, got %v", test.caseDesc, expectedError, err) + } + + } +} diff --git a/dbservice/postgresdb/plgstats.go b/dbservice/postgresdb/plgstats.go new file mode 100644 index 00000000..a152242e --- /dev/null +++ b/dbservice/postgresdb/plgstats.go @@ -0,0 +1,33 @@ +package postgresdb + +import ( + "database/sql" + "fmt" + + _ "github.com/lib/pq" +) + +func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { + db, err := psqlConnect(postgresDb.psqlInfo) + if err != nil { + return err + } + defer db.Close() + + err = initTable(db, dbTableOutputStats) + if err != nil { + return err + } + amount := 0 + err = db.Get(&amount, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "amount", dbTableOutputStats, "id", "outputName"), postgresDb.id, name) + if err != nil && err != sql.ErrNoRows { + return err + } + amount += 1 + err = insertOutputStats(db, postgresDb.id, name, amount) + if err != nil { + return err + } + + return nil +} diff --git a/dbservice/postgresdb/plgstats_test.go b/dbservice/postgresdb/plgstats_test.go new file mode 100644 index 00000000..c269fd88 --- /dev/null +++ b/dbservice/postgresdb/plgstats_test.go @@ -0,0 +1,90 @@ +package postgresdb + +import ( + "database/sql" + "log" + "testing" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +func TestRegisterPlgnInvctn(t *testing.T) { + receivedKey := 0 + savedInitTable := initTable + initTable = func(db *sqlx.DB, tableName string) error { return nil } + savedInsertOutputStats := insertOutputStats + insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { + receivedKey = amount + return nil + } + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"amount"}).AddRow(receivedKey) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + defer func() { + initTable = savedInitTable + insertOutputStats = savedInsertOutputStats + psqlConnect = savedPsqlConnect + }() + + expectedCnt := 3 + keyToTest := "test" + for i := 0; i < expectedCnt; i++ { + db.RegisterPlgnInvctn(keyToTest) + } + if receivedKey != expectedCnt { + t.Errorf("Persisted count doesn't match expected. Expected %d, got %d\n", receivedKey, expectedCnt) + } +} + +func TestRegisterPlgnInvctnErrors(t *testing.T) { + var tests = []struct { + name string + errIn error + expectedErr error + }{ + {"No result rows error", sql.ErrNoRows, nil}, + {"Other errors", sql.ErrConnDone, sql.ErrConnDone}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + savedInitTable := initTable + initTable = func(db *sqlx.DB, tableName string) error { return nil } + savedInsertOutputStats := insertOutputStats + insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { return nil } + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(test.errIn) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + initTable = savedInitTable + insertOutputStats = savedInsertOutputStats + }() + err := db.RegisterPlgnInvctn("testName") + if err != test.expectedErr { + t.Errorf("Errors no contains: expected: %v, got: %v", test.expectedErr, err) + } + }) + } + + key, err := db.GetApiKey() + if err == nil { + t.Fatal("Error is expected") + } + if key != "" { + t.Fatal("Empty key is expected") + } +} diff --git a/dbservice/postgresdb/sharedcfg.go b/dbservice/postgresdb/sharedcfg.go new file mode 100644 index 00000000..7ea16c44 --- /dev/null +++ b/dbservice/postgresdb/sharedcfg.go @@ -0,0 +1,58 @@ +package postgresdb + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "io" + + _ "github.com/lib/pq" +) + +var apiKeyName = "POSTEE_API_KEY" + +func (postgresDb *PostgresDb) EnsureApiKey() error { + db, err := psqlConnect(postgresDb.psqlInfo) + if err != nil { + return err + } + defer db.Close() + err = initTable(db, dbTableSharedConfig) + if err != nil { + return err + } + + apiKey, err := generateApiKey(32) + if err != nil { + return err + } + + if err = insert(db, dbTableSharedConfig, postgresDb.id, "apikeyname", apiKeyName, "value", apiKey); err != nil { + return err + } + + return nil +} + +func (postgresDb *PostgresDb) GetApiKey() (string, error) { + db, err := psqlConnect(postgresDb.psqlInfo) + if err != nil { + return "", err + } + defer db.Close() + value := "" + err = db.Get(&value, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "value", dbTableSharedConfig, "id", "apikeyname"), postgresDb.id, apiKeyName) + if err != nil { + return "", err + } + return value, nil + +} + +func generateApiKey(length int) (string, error) { + k := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return "", err + } + return hex.EncodeToString(k), nil +} diff --git a/dbservice/postgresdb/sharedcfg_test.go b/dbservice/postgresdb/sharedcfg_test.go new file mode 100644 index 00000000..8442d2e0 --- /dev/null +++ b/dbservice/postgresdb/sharedcfg_test.go @@ -0,0 +1,106 @@ +package postgresdb + +import ( + "database/sql" + "log" + "testing" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +func TestApiKey(t *testing.T) { + savedInitTable := initTable + initTable = func(db *sqlx.DB, tableName string) error { return nil } + savedInsert := insert + insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { return nil } + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"value"}).AddRow("key") + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + defer func() { + initTable = savedInitTable + insert = savedInsert + psqlConnect = savedPsqlConnect + }() + + db.EnsureApiKey() + + key, err := db.GetApiKey() + if err != nil { + t.Fatal("error while getting value of API key") + } + if key == "" { + t.Fatal("empty key received") + } +} + +func TestApiKeyWithoutInit(t *testing.T) { + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(sql.ErrNoRows) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + key, err := db.GetApiKey() + if err == nil { + t.Fatal("Error is expected") + } + if key != "" { + t.Fatal("Empty key is expected") + } +} + +func TestApiKeyRenewal(t *testing.T) { + receivedKey := "" + savedInitTable := initTable + initTable = func(db *sqlx.DB, tableName string) error { return nil } + savedInsert := insert + insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { + receivedKey = value3 + return nil + } + savedPsqlConnect := psqlConnect + psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"value"}).AddRow(receivedKey) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + defer func() { + initTable = savedInitTable + insert = savedInsert + psqlConnect = savedPsqlConnect + }() + + var keys [2]string + for i := 0; i < 2; i++ { + db.EnsureApiKey() + key, err := db.GetApiKey() + if err != nil { + t.Fatal("error while getting value of API key") + } + if key == "" { + t.Fatal("empty key received") + } + keys[i] = key + } + if keys[0] == keys[1] { + t.Errorf("Key is not updated. (before: %s and after update: %s)", keys[0], keys[1]) + } +} diff --git a/go.mod b/go.mod index 2ad5e3d0..1a688965 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,11 @@ require ( github.com/aquasecurity/go-jira v0.0.0-20211103111421-b62ce48827be github.com/ghodss/yaml v1.0.0 github.com/gorilla/mux v1.8.0 + github.com/jmoiron/sqlx v1.3.4 + github.com/lib/pq v1.2.0 github.com/open-policy-agent/opa v0.34.2 github.com/spf13/cobra v1.2.1 github.com/stretchr/testify v1.7.0 + github.com/zhashkevych/go-sqlxmock v1.5.2-0.20201023121933-f973d0041cfc go.etcd.io/bbolt v1.3.5 ) diff --git a/go.sum b/go.sum index 3e2ad87e..d8812714 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,9 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= @@ -231,6 +234,9 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1: github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= +github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w= +github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -242,6 +248,7 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kisielk/sqlstruct v0.0.0-20150923205031-648daed35d49/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.5 h1:9O69jUPDcsT9fEm74W92rZL9FQY7rCdaXVneq+yyzl4= github.com/klauspost/compress v1.13.5/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= @@ -254,11 +261,17 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= @@ -366,6 +379,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zhashkevych/go-sqlxmock v1.5.2-0.20201023121933-f973d0041cfc h1:z6oWvrg2brc98tlcDChukX4BKc3t0Ayz9dSBtJRYw9w= +github.com/zhashkevych/go-sqlxmock v1.5.2-0.20201023121933-f973d0041cfc/go.mod h1:kgQytrOB1XCQEsf5P1GpvvmjRkJhrORDtR/jvxKEQBw= go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= diff --git a/main.go b/main.go index 8d7e9da4..c75d7d32 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "runtime" "syscall" - "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/router" "github.com/aquasecurity/postee/utils" "github.com/aquasecurity/postee/webserver" @@ -74,10 +73,6 @@ func main() { cfgfile = os.Getenv("POSTEE_CFG") } - if os.Getenv("PATH_TO_DB") != "" { - dbservice.SetNewDbPathFromEnv() - } - err := router.Instance().Start(cfgfile) if err != nil { log.Printf("Can't start alert manager %v", err) diff --git a/msgservice/aggregatebytime_test.go b/msgservice/aggregatebytime_test.go index e5d37faa..5f8738cd 100644 --- a/msgservice/aggregatebytime_test.go +++ b/msgservice/aggregatebytime_test.go @@ -7,19 +7,25 @@ import ( "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" "github.com/aquasecurity/postee/outputs" "github.com/aquasecurity/postee/routes" ) func TestAggregateByTimeout(t *testing.T) { + db = boltdb.NewBoltDb() + oldDb := dbservice.Db + dbservice.Db = db + defer func() { dbservice.Db = oldDb }() + const aggregationSeconds = 3 - dbPathReal := dbservice.DbPath + dbPathReal := db.DbPath savedRunScheduler := RunScheduler schedulerInvctCnt := 0 defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal RunScheduler = savedRunScheduler }() RunScheduler = func( @@ -36,7 +42,7 @@ func TestAggregateByTimeout(t *testing.T) { schedulerInvctCnt++ } - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoRoute := &routes.InputRoute{ Name: "demo-route1", diff --git a/msgservice/aggregatescan_test.go b/msgservice/aggregatescan_test.go index ea36f221..56a91a71 100644 --- a/msgservice/aggregatescan_test.go +++ b/msgservice/aggregatescan_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" "github.com/aquasecurity/postee/routes" ) @@ -38,12 +39,16 @@ func TestAggregateIssuesPerTicket(t *testing.T) { } func doAggregate(t *testing.T, caseDesc string, expectedSntCnt int, expectedRenderCnt int, expectedAggrRenderCnt int, skipAggrSpprt bool) { - dbPathReal := dbservice.DbPath + db = boltdb.NewBoltDb() + oldDb := dbservice.Db + dbservice.Db = db + defer func() { dbservice.Db = oldDb }() + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoEmailOutput := &DemoEmailOutput{ emailCounts: 0, diff --git a/msgservice/applicationscopeowner_test.go b/msgservice/applicationscopeowner_test.go index f36f4c04..3a1b54ef 100644 --- a/msgservice/applicationscopeowner_test.go +++ b/msgservice/applicationscopeowner_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" "github.com/aquasecurity/postee/routes" ) @@ -21,12 +22,16 @@ var ( ) func TestApplicationScopeOwner(t *testing.T) { - dbPathReal := dbservice.DbPath + db = boltdb.NewBoltDb() + oldDb := dbservice.Db + dbservice.Db = db + defer func() { dbservice.Db = oldDb }() + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoEmailOutput := &DemoEmailOutput{ emailCounts: 0, diff --git a/msgservice/getuniqueid_test.go b/msgservice/getuniqueid_test.go index f8a20231..0ca815cd 100644 --- a/msgservice/getuniqueid_test.go +++ b/msgservice/getuniqueid_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" "github.com/aquasecurity/postee/routes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -88,12 +89,16 @@ func TestScanUniqueId(t *testing.T) { } func sendInputs(t *testing.T, caseDesc string, inputs []string, uniqueMessageProps []string, expected int) { - dbPathReal := dbservice.DbPath + db = boltdb.NewBoltDb() + oldDb := dbservice.Db + dbservice.Db = db + defer func() { dbservice.Db = oldDb }() + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoEmailOutput := &DemoEmailOutput{ emailCounts: 0, diff --git a/msgservice/msghandling.go b/msgservice/msghandling.go index db726e1e..1f163310 100644 --- a/msgservice/msghandling.go +++ b/msgservice/msghandling.go @@ -62,7 +62,7 @@ func (scan *MsgService) MsgHandling(input []byte, output outputs.Output, route * msgKey := GetMessageUniqueId(in, route.Plugins.UniqueMessageProps) expired := calculateExpired(route.Plugins.UniqueMessageTimeoutSeconds) - wasStored, err := dbservice.MayBeStoreMessage(input, msgKey, expired) + wasStored, err := dbservice.Db.MayBeStoreMessage(input, msgKey, expired) if err != nil { log.Printf("Error while storing input: %v", err) return @@ -123,7 +123,7 @@ func send(otpt outputs.Output, cnt map[string]string) { } }() - err := dbservice.RegisterPlgnInvctn(otpt.GetName()) + err := dbservice.Db.RegisterPlgnInvctn(otpt.GetName()) if err != nil { log.Printf("Error while building aggregated content: %v", err) return @@ -140,7 +140,7 @@ func calculateExpired(UniqueMessageTimeoutSeconds int) *time.Time { } var AggregateScanAndGetQueue = func(outputName string, currentContent map[string]string, counts int, ignoreLength bool) []map[string]string { - aggregatedScans, err := dbservice.AggregateScans(outputName, currentContent, counts, ignoreLength) + aggregatedScans, err := dbservice.Db.AggregateScans(outputName, currentContent, counts, ignoreLength) if err != nil { log.Printf("AggregateScans Error: %v", err) return aggregatedScans diff --git a/msgservice/msgservice_test.go b/msgservice/msgservice_test.go index 8695de07..680aae69 100644 --- a/msgservice/msgservice_test.go +++ b/msgservice/msgservice_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" "github.com/aquasecurity/postee/routes" ) @@ -14,6 +15,7 @@ var ( invalidJson = `{ image : "My Image" }` + db = boltdb.NewBoltDb() ) type FailingInptEval struct { @@ -62,12 +64,12 @@ func TestInputs(t *testing.T) { } func validateInputValue(t *testing.T, caseDesc string, input []byte, shouldPass bool) { - dbPathReal := dbservice.DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoEmailOutput := &DemoEmailOutput{ emailCounts: 0, @@ -99,12 +101,12 @@ func validateInputValue(t *testing.T, caseDesc string, input []byte, shouldPass } func TestEvalError(t *testing.T) { - dbPathReal := dbservice.DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoEmailOutput := &DemoEmailOutput{ emailCounts: 0, @@ -130,12 +132,15 @@ func TestEvalError(t *testing.T) { } func TestAggrEvalError(t *testing.T) { - dbPathReal := dbservice.DbPath + oldDb := dbservice.Db + dbservice.Db = db + defer func() { dbservice.Db = oldDb }() + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoEmailOutput := &DemoEmailOutput{ emailCounts: 0, @@ -164,12 +169,12 @@ func TestAggrEvalError(t *testing.T) { } } func TestEmptyInput(t *testing.T) { - dbPathReal := dbservice.DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" srvUrl := "" diff --git a/msgservice/regocriteria_test.go b/msgservice/regocriteria_test.go index c1a322d9..d5acd996 100644 --- a/msgservice/regocriteria_test.go +++ b/msgservice/regocriteria_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" "github.com/aquasecurity/postee/routes" ) @@ -93,6 +94,11 @@ func TestRegoCriteria(t *testing.T) { } func validateRegoInput(t *testing.T, caseDesc string, input string, regoCriteria string, regoFilePath string, shouldPass bool) { + db := boltdb.NewBoltDb() + oldDb := dbservice.Db + dbservice.Db = db + defer func() { dbservice.Db = oldDb }() + regoFile, err := os.Create("regoFile.rego") if err != nil { t.Error("Can't create regoFile.rego file") @@ -101,12 +107,12 @@ func validateRegoInput(t *testing.T, caseDesc string, input string, regoCriteria defer os.Remove("regoFile.rego") defer regoFile.Close() - dbPathReal := dbservice.DbPath + dbPathReal := db.DbPath defer func() { - os.Remove(dbservice.DbPath) - dbservice.DbPath = dbPathReal + os.Remove(db.DbPath) + db.DbPath = dbPathReal }() - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" demoEmailOutput := &DemoEmailOutput{ emailCounts: 0, diff --git a/router/loads_test.go b/router/loads_test.go index 2a892d5d..630a0bec 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/aquasecurity/postee/data" - "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" "github.com/aquasecurity/postee/msgservice" "github.com/aquasecurity/postee/outputs" "github.com/aquasecurity/postee/routes" @@ -62,6 +62,7 @@ outputs: password: admin tls-verify: false project-key: kcv` + db = boltdb.NewBoltDb() ) type ctxWrapper struct { @@ -90,13 +91,13 @@ func (ctx *ctxWrapper) MsgHandling(input []byte, output outputs.Output, route *r } func (ctxWrapper *ctxWrapper) setup(cfg string) { - ctxWrapper.savedDBPath = dbservice.DbPath + ctxWrapper.savedDBPath = db.DbPath ctxWrapper.savedBaseForTicker = baseForTicker ctxWrapper.cfgPath = "cfg_test.yaml" ctxWrapper.savedGetService = getScanService ctxWrapper.buff = make(chan invctn) - dbservice.DbPath = "test_webhooks.db" + db.DbPath = "test_webhooks.db" baseForTicker = time.Microsecond ctxWrapper.defaultRegoFolder = "rego-templates" ctxWrapper.commonRegoFolder = ctxWrapper.defaultRegoFolder + "/common" @@ -125,11 +126,11 @@ func (ctxWrapper *ctxWrapper) teardown() { baseForTicker = ctxWrapper.savedBaseForTicker os.Remove(ctxWrapper.cfgPath) - os.Remove(dbservice.DbPath) + os.Remove(db.DbPath) os.Remove(ctxWrapper.commonRegoFolder) os.Remove(ctxWrapper.defaultRegoFolder) - dbservice.ChangeDbPath(ctxWrapper.savedDBPath) + db.ChangeDbPath(ctxWrapper.savedDBPath) getScanService = ctxWrapper.savedGetService close(ctxWrapper.buff) } diff --git a/router/router.go b/router/router.go index 390d2546..1400471d 100644 --- a/router/router.go +++ b/router/router.go @@ -198,19 +198,17 @@ func (ctx *Router) load() error { ctx.aquaServer = fmt.Sprintf("%s%s#/images/", tenant.AquaServer, slash) } - dbservice.DbSizeLimit = tenant.DBMaxSize - if tenant.DBTestInterval == 0 { - tenant.DBTestInterval = 1 - } - ctx.ticker = time.NewTicker(baseForTicker * time.Duration(tenant.DBTestInterval)) + dbservice.ConfigureDb(&tenant.DbSettings, tenant.Name) + + ctx.ticker = time.NewTicker(baseForTicker * time.Duration(tenant.DbSettings.DBTestInterval)) go func() { for { select { case <-ctx.stopTicker: return case <-ctx.ticker.C: - dbservice.CheckSizeLimit() - dbservice.CheckExpiredData() + dbservice.Db.CheckSizeLimit() + dbservice.Db.CheckExpiredData() } } }() diff --git a/router/tenants.go b/router/tenants.go index 1c2881ca..ad653662 100644 --- a/router/tenants.go +++ b/router/tenants.go @@ -1,15 +1,15 @@ package router import ( + "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/routes" ) type TenantSettings struct { - AquaServer string `json:"aqua-server,omitempty"` - DBMaxSize int `json:"max-db-size,omitempty"` - DBRemoveOldData int `json:"delete-old-data,omitempty"` - DBTestInterval int `json:"db-verify-interval,omitempty"` - Outputs []OutputSettings `json:"outputs"` - InputRoutes []routes.InputRoute `json:"routes"` - Templates []Template `json:"templates"` + Name string `json:"name,omitempty"` + AquaServer string `json:"aqua-server,omitempty"` + DbSettings dbservice.DbSettings `json:"dbsettings"` + Outputs []OutputSettings `json:"outputs"` + InputRoutes []routes.InputRoute `json:"routes"` + Templates []Template `json:"templates"` } diff --git a/webserver/webserver.go b/webserver/webserver.go index b9d69f0a..447feba4 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -34,7 +34,7 @@ func Instance() *WebServer { } func (ctx *WebServer) withApiKey(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - correctKey, err := dbservice.GetApiKey() + correctKey, err := dbservice.Db.GetApiKey() if err != nil || correctKey == "" { log.Printf("reload API key is either empty or there is an error: %s \n", err) @@ -69,7 +69,7 @@ func (ctx *WebServer) Start(host, tlshost string) { if os.Getenv("AQUAALERT_KEY_PEM") != "" { keyPem = os.Getenv("AQUAALERT_KEY_PEM") } - dbservice.EnsureApiKey() + dbservice.Db.EnsureApiKey() ctx.router.HandleFunc("/", ctx.sessionHandler(ctx.scanHandler)).Methods("POST") ctx.router.HandleFunc("/tenant/{route}", ctx.sessionHandler(ctx.tenantHandler)).Methods("POST") From c0099b0f97e4407b7facface57de830d871137fb Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 1 Dec 2021 11:56:49 +0600 Subject: [PATCH 20/61] feat(postgresDb): add test connect psql, change cfg.yaml --- cfg.yaml | 17 +++++++++++++++-- dbservice/dbservice.go | 17 ++++++++++------- dbservice/postgresdb/dbparam.go | 5 +++++ router/router.go | 4 +++- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/cfg.yaml b/cfg.yaml index 8f04cbda..e2c2214f 100644 --- a/cfg.yaml +++ b/cfg.yaml @@ -3,8 +3,21 @@ name: tenant # The tenant name aqua-server: # URL of Aqua Server for links. E.g. https://myserver.aquasec.com -max-db-size: 1000 # Max size of DB in MB. if empty then unlimited -db-verify-interval: 1 # How often to check the DB size. By default, Postee checks every 1 hour + +dbsettings: + max-db-size: 1000 # Max size of DB in MB. if empty then unlimited + db-verify-interval: 1 # How often to check the DB size. By default, Postee checks every 1 hour + + #PostgresDb settings. + dbname: database #Database PostgreSQL name. Must be filled for PostgreSQL. + dbhostname: #Database PostgreSQL hostname. Default dbhostname='localhost' + dbport: #Database PostgreSQL port. Default dbhostname=5432 + dbuser: user #Database PostgreSQL user. Must be filled for PostgreSQL. + dbpassword: user-password #Database PostgreSQL user password. + dbsslmode: #Database PostgreSQL sslmode. Default dbsslmode='disable' + + #BoltDb setting. + # dbpath: #Database BoltDb path. Default dbpath='/server/database/webhooks.db' # Routes are used to define how to handle an incoming message routes: diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index d8c822ad..14f12f05 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -18,15 +18,15 @@ type DbSettings struct { DBTestInterval int `json:"db-verify-interval,omitempty"` //PostgresDb - DbName string - DbHostName string - DbPort string - DbUser string - DbPassword string - DbSslMode string + DbName string `json:"dbname,omitempty"` + DbHostName string `json:"dbhostname,omitempty"` + DbPort string `json:"dbport,omitempty"` + DbUser string `json:"dbuser,omitempty"` + DbPassword string `json:"dbpassword,omitempty"` + DbSslMode string `json:"dbsslmode,omitempty"` //BoltDb - DbPath string + DbPath string `json:"dbpath,omitempty"` } type DbProvider interface { MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) @@ -56,6 +56,9 @@ func ConfigureDb(settings *DbSettings, id string) error { if err != nil { return err } + if err = db.TestConnect(); err != nil { + return err + } Db = db } Db.SetDbSizeLimit(settings.DBMaxSize) diff --git a/dbservice/postgresdb/dbparam.go b/dbservice/postgresdb/dbparam.go index 6ebea8dd..d6fc69a5 100644 --- a/dbservice/postgresdb/dbparam.go +++ b/dbservice/postgresdb/dbparam.go @@ -82,6 +82,11 @@ func (postgresDb *PostgresDb) SetDbSizeLimit(limit int) { DbSizeLimit = limit } +func (postgresDb *PostgresDb) TestConnect() error { + _, err := psqlConnect(postgresDb.psqlInfo) + return errors.New("Error postgresDb test connect: " + err.Error()) +} + var psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { db, err := sqlx.Connect("postgres", psqlInfo) if err != nil { diff --git a/router/router.go b/router/router.go index 1400471d..21d1718c 100644 --- a/router/router.go +++ b/router/router.go @@ -198,7 +198,9 @@ func (ctx *Router) load() error { ctx.aquaServer = fmt.Sprintf("%s%s#/images/", tenant.AquaServer, slash) } - dbservice.ConfigureDb(&tenant.DbSettings, tenant.Name) + if err = dbservice.ConfigureDb(&tenant.DbSettings, tenant.Name); err != nil { + return err + } ctx.ticker = time.NewTicker(baseForTicker * time.Duration(tenant.DbSettings.DBTestInterval)) go func() { From f267246941374d544a79589d9b537ebd1c1fff1c Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 1 Dec 2021 17:00:18 +0600 Subject: [PATCH 21/61] feat(psql) changed configurate psql settings are now taken from env 'POSTGRES_URL' now,tests for configurate psql added --- README.md | 2 +- cfg.yaml | 17 +--- dbservice/boltdb/dbparam.go | 12 ++- dbservice/boltdb/dbparam_test.go | 8 +- dbservice/dbservice.go | 50 +++++------- dbservice/dbservice_test.go | 94 +++++++++++++++++++++++ dbservice/postgresdb/actions.go | 8 +- dbservice/postgresdb/actions_test.go | 6 +- dbservice/postgresdb/checker.go | 22 +++--- dbservice/postgresdb/checker_test.go | 4 +- dbservice/postgresdb/dbaggregator.go | 8 +- dbservice/postgresdb/dbaggregator_test.go | 2 +- dbservice/postgresdb/dbparam.go | 71 ++++------------- dbservice/postgresdb/dbparam_test.go | 43 ----------- dbservice/postgresdb/dbservice_test.go | 77 ++++++++++++------- dbservice/postgresdb/invalidinit_test.go | 2 +- dbservice/postgresdb/plgstats.go | 6 +- dbservice/postgresdb/plgstats_test.go | 4 +- dbservice/postgresdb/sharedcfg.go | 8 +- dbservice/postgresdb/sharedcfg_test.go | 6 +- router/router.go | 4 +- router/tenants.go | 15 ++-- 22 files changed, 238 insertions(+), 231 deletions(-) create mode 100644 dbservice/dbservice_test.go delete mode 100644 dbservice/postgresdb/dbparam_test.go diff --git a/README.md b/README.md index 1b2c3f29..325b9cad 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,7 @@ The Postee container uses BoltDB to store information about previously scanned i This is used to prevent resending messages that were already sent before. The size of the database can grow over time. Every image that is saved in the database uses 20K of storage. -Postee supports ‘PATH_TO_DB’ environment variable to change the database directory. To use, set the ‘PATH_TO_DB’ environment variable to point to the database file, for example: PATH_TO_DB="./database/webhook.db". By default, the directory for the database file is “/server/database/webhook.db”. +Postee supports ‘PATH_TO_BOLTDB’ environment variable to change the bolt database directory. To use, set the ‘PATH_TO_BOLTDB’ environment variable to point to the bolt database file, for example: PATH_TO_BOLTDB="./database/webhook.db". By default, the directory for the bolt database file is “/server/database/webhook.db”. If you would like to persist the database file between restarts of the Postee container, then you should use a persistent storage option to mount the "/server/database" directory of the container. diff --git a/cfg.yaml b/cfg.yaml index e2c2214f..8f04cbda 100644 --- a/cfg.yaml +++ b/cfg.yaml @@ -3,21 +3,8 @@ name: tenant # The tenant name aqua-server: # URL of Aqua Server for links. E.g. https://myserver.aquasec.com - -dbsettings: - max-db-size: 1000 # Max size of DB in MB. if empty then unlimited - db-verify-interval: 1 # How often to check the DB size. By default, Postee checks every 1 hour - - #PostgresDb settings. - dbname: database #Database PostgreSQL name. Must be filled for PostgreSQL. - dbhostname: #Database PostgreSQL hostname. Default dbhostname='localhost' - dbport: #Database PostgreSQL port. Default dbhostname=5432 - dbuser: user #Database PostgreSQL user. Must be filled for PostgreSQL. - dbpassword: user-password #Database PostgreSQL user password. - dbsslmode: #Database PostgreSQL sslmode. Default dbsslmode='disable' - - #BoltDb setting. - # dbpath: #Database BoltDb path. Default dbpath='/server/database/webhooks.db' +max-db-size: 1000 # Max size of DB in MB. if empty then unlimited +db-verify-interval: 1 # How often to check the DB size. By default, Postee checks every 1 hour # Routes are used to define how to handle an incoming message routes: diff --git a/dbservice/boltdb/dbparam.go b/dbservice/boltdb/dbparam.go index eae8544a..1fbf249a 100644 --- a/dbservice/boltdb/dbparam.go +++ b/dbservice/boltdb/dbparam.go @@ -1,7 +1,6 @@ package boltdb import ( - "log" "os" "path/filepath" "sync" @@ -38,23 +37,22 @@ func (boltDb *BoltDb) ChangeDbPath(newPath string) { mutex.Unlock() } -func (boltDb *BoltDb) SetNewDbPathFromEnv() { - newPath := os.Getenv("PATH_TO_DB") +func (boltDb *BoltDb) SetNewDbPathFromEnv() error { + newPath := os.Getenv("PATH_TO_BOLTDB") if newPath != "" { if _, err := os.Stat(newPath); err != nil { if os.IsNotExist(err) { err = os.MkdirAll(filepath.Dir(newPath), os.ModePerm) if err != nil { - log.Printf("Can't create DateBase directory: %v, the default path is used", err) - return + return err } } else { - log.Printf("Can't check DateBase directory: %v, the default path is used", err) - return + return err } } boltDb.ChangeDbPath(newPath) } + return nil } func (boltDb *BoltDb) SetDbSizeLimit(limit int) { diff --git a/dbservice/boltdb/dbparam_test.go b/dbservice/boltdb/dbparam_test.go index a2350d0f..2ad6f037 100644 --- a/dbservice/boltdb/dbparam_test.go +++ b/dbservice/boltdb/dbparam_test.go @@ -9,8 +9,8 @@ import ( func TestSetNewDbPathFromEnv(t *testing.T) { db := NewBoltDb() - envPathToDbOld := os.Getenv("PATH_TO_DB") - defer os.Setenv("PATH_TO_DB", envPathToDbOld) + envPathToDbOld := os.Getenv("PATH_TO_BOLTDB") + defer os.Setenv("PATH_TO_BOLTDB", envPathToDbOld) dbPathOld := db.DbPath defaultDbPath := "/server/database/webhooks.db" @@ -20,7 +20,7 @@ func TestSetNewDbPathFromEnv(t *testing.T) { changePermission bool expectedDBPath string }{ - {"Empty PATH_TO_DB", "", false, defaultDbPath}, + {"Empty PATH_TO_BOLTDB", "", false, defaultDbPath}, {"Permission denied to create directory(default DbPath is used)", "/database/database.db", false, defaultDbPath}, {"New DbPath", "./base/base.db", false, "./base/base.db"}, {"Permission denied to check directory(default DbPath is used)", "webhook/database/webhooks.db", true, defaultDbPath}, @@ -28,7 +28,7 @@ func TestSetNewDbPathFromEnv(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - os.Setenv("PATH_TO_DB", test.envPathToDb) + os.Setenv("PATH_TO_BOLTDB", test.envPathToDb) baseDir := strings.Split(filepath.Dir(test.envPathToDb), "/")[0] if test.changePermission { err := os.Mkdir(baseDir, os.ModeDir) diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index 14f12f05..c3631d2d 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -1,6 +1,7 @@ package dbservice import ( + "errors" "os" "time" @@ -12,22 +13,6 @@ var ( Db DbProvider ) -type DbSettings struct { - DBMaxSize int `json:"max-db-size,omitempty"` - DBRemoveOldData int `json:"delete-old-data,omitempty"` - DBTestInterval int `json:"db-verify-interval,omitempty"` - - //PostgresDb - DbName string `json:"dbname,omitempty"` - DbHostName string `json:"dbhostname,omitempty"` - DbPort string `json:"dbport,omitempty"` - DbUser string `json:"dbuser,omitempty"` - DbPassword string `json:"dbpassword,omitempty"` - DbSslMode string `json:"dbsslmode,omitempty"` - - //BoltDb - DbPath string `json:"dbpath,omitempty"` -} type DbProvider interface { MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) CheckSizeLimit() @@ -39,28 +24,29 @@ type DbProvider interface { SetDbSizeLimit(limit int) } -func ConfigureDb(settings *DbSettings, id string) error { - if settings.DBTestInterval == 0 { - settings.DBTestInterval = 1 +func ConfigurateDb(id string, dBTestInterval *int, dbMaxSize int) error { + if *dBTestInterval == 0 { + *dBTestInterval = 1 } - if settings.DbName == "" && settings.DbHostName == "" && settings.DbUser == "" { - boltdb := boltdb.NewBoltDb() - if os.Getenv("PATH_TO_DB") != "" { - boltdb.SetNewDbPathFromEnv() + if os.Getenv("POSTGRES_URL") != "" { + if id == "" { + return errors.New("error configurate postgresDb: 'id' is empty") } - Db = boltdb - return nil - } else { - db, err := postgresdb.NewPostgresDb(id, settings.DbName, settings.DbHostName, settings.DbPort, settings.DbUser, settings.DbPassword, settings.DbSslMode) - if err != nil { + postgresDb := postgresdb.NewPostgresDb(id, os.Getenv("POSTGRES_URL")) + if err := postgresdb.TestConnect(postgresDb.ConnectUrl); err != nil { return err } - if err = db.TestConnect(); err != nil { - return err + Db = postgresDb + } else { + boltdb := boltdb.NewBoltDb() + if os.Getenv("PATH_TO_BOLTDB") != "" { + if err := boltdb.SetNewDbPathFromEnv(); err != nil { + return err + } } - Db = db + Db = boltdb } - Db.SetDbSizeLimit(settings.DBMaxSize) + Db.SetDbSizeLimit(dbMaxSize) return nil } diff --git a/dbservice/dbservice_test.go b/dbservice/dbservice_test.go new file mode 100644 index 00000000..898b91b5 --- /dev/null +++ b/dbservice/dbservice_test.go @@ -0,0 +1,94 @@ +package dbservice + +import ( + "errors" + "os" + "reflect" + "testing" + + "github.com/aquasecurity/postee/dbservice/postgresdb" +) + +func TestConfigurateBoltDbPath(t *testing.T) { + tests := []struct { + name string + dbPath string + expectedPath string + }{ + {"happy configuration BoltDB with dbPath", "database/webhooks.db", "database/webhooks.db"}, + {"happy configuration BoltDB with empty dbPath", "", "/server/database/webhooks.db"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + oldPathEnv := os.Getenv("PATH_TO_BOLTDB") + defer func() { + os.Setenv("PATH_TO_BOLTDB", oldPathEnv) + }() + os.Setenv("PATH_TO_BOLTDB", test.dbPath) + + testInterval := 2 + if err := ConfigurateDb("id", &testInterval, 1); err != nil { + t.Errorf("Unexpected error: %v", err) + } + if testInterval != 2 { + t.Error("test interval error, expected: 2, got: ", testInterval) + } + if test.expectedPath != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("DbPath").Interface() { + t.Errorf("paths do not match, expected: %s, got: %s", test.expectedPath, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("DbPath").Interface()) + } + + }) + } +} + +func TestConfiguratePostgresDbUrlAndId(t *testing.T) { + tests := []struct { + name string + url string + id string + expectedError error + }{ + {"happy configuration", "postgresql://user:secret@localhost", "test-id", nil}, + {"bad id", "postgresql://user:secret@localhost", "", errors.New("error configurate postgresDb: 'id' is empty")}, + {"bad url", "badUrl", "test-id", errors.New("badUrl error")}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testConnectSaved := postgresdb.TestConnect + postgresdb.TestConnect = func(connectUrl string) error { + if connectUrl == "badUrl" { + return errors.New("badUrl error") + } + return nil + } + defer func() { + postgresdb.TestConnect = testConnectSaved + }() + oldUrlEnv := os.Getenv("POSTGRES_URL") + defer func() { + os.Setenv("POSTGRES_URL", oldUrlEnv) + }() + os.Setenv("POSTGRES_URL", test.url) + + testInterval := 0 + err := ConfigurateDb(test.id, &testInterval, 1) + if err != nil { + if err.Error() != test.expectedError.Error() { + t.Errorf("Unexpected error, expected: %s, got: %s", test.expectedError, err) + } + } else { + if testInterval != 1 { + t.Error("test interval error, expected: 1, got: ", testInterval) + } + if test.url != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface() { + t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) + } + if test.id != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("Id").Interface() { + t.Errorf("id's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("Id").Interface()) + } + } + }) + } +} diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go index 55cc8418..a9d7b1df 100644 --- a/dbservice/postgresdb/actions.go +++ b/dbservice/postgresdb/actions.go @@ -7,7 +7,7 @@ import ( ) func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) { - db, err := psqlConnect(postgresDb.psqlInfo) + db, err := psqlConnect(postgresDb.ConnectUrl) if err != nil { return false, err } @@ -22,7 +22,7 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin } currentValue := "" - if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "messageValue", dbTableName, "id", "messageKey"), postgresDb.id, messageKey); err != nil { + if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "messageValue", dbTableName, "id", "messageKey"), postgresDb.Id, messageKey); err != nil { if err != sql.ErrNoRows { return false, err } @@ -32,12 +32,12 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin return false, nil } else { - if err = insert(db, dbTableName, postgresDb.id, "messagekey", messageKey, "messagevalue", string(message)); err != nil { + if err = insert(db, dbTableName, postgresDb.Id, "messagekey", messageKey, "messagevalue", string(message)); err != nil { return false, err } if expired != nil { - if err = insert(db, dbTableExpiryDates, postgresDb.id, "date", expired.Format(DateFmt), "messagekey", messageKey); err != nil { + if err = insert(db, dbTableExpiryDates, postgresDb.Id, "date", expired.Format(DateFmt), "messagekey", messageKey); err != nil { return false, err } } diff --git a/dbservice/postgresdb/actions_test.go b/dbservice/postgresdb/actions_test.go index a139e39e..06f73f36 100644 --- a/dbservice/postgresdb/actions_test.go +++ b/dbservice/postgresdb/actions_test.go @@ -3,6 +3,7 @@ package postgresdb import ( "log" "testing" + "time" "github.com/jmoiron/sqlx" sqlxmock "github.com/zhashkevych/go-sqlxmock" @@ -10,6 +11,7 @@ import ( func TestStoreMessage(t *testing.T) { currentValueStoreMessage := "" + time := time.Now() savedInitTable := initTable initTable = func(db *sqlx.DB, tableName string) error { return nil } @@ -19,7 +21,7 @@ func TestStoreMessage(t *testing.T) { return nil } savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -43,7 +45,7 @@ func TestStoreMessage(t *testing.T) { for _, test := range tests { // Handling of first scan - isNew, err := db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, nil) + isNew, err := db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, &time) if err != nil { t.Errorf("Error: %s\n", err) } diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go index da4bcd2d..94b2e65f 100644 --- a/dbservice/postgresdb/checker.go +++ b/dbservice/postgresdb/checker.go @@ -11,10 +11,10 @@ func (postgresDb *PostgresDb) CheckSizeLimit() { return } - psqlInfo := postgresDb.psqlInfo - db, err := psqlConnect(psqlInfo) + connectUrl := postgresDb.ConnectUrl + db, err := psqlConnect(connectUrl) if err != nil { - log.Println("CheckSizeLimit: Can't open db, psqlInfo: ", psqlInfo) + log.Println("CheckSizeLimit: Can't open db, connectUrl: ", connectUrl) return } defer db.Close() @@ -25,18 +25,18 @@ func (postgresDb *PostgresDb) CheckSizeLimit() { return } if size > DbSizeLimit { - if err = deleteRowsById(db, dbTableName, postgresDb.id); err != nil { - log.Printf("CheckSizeLimit: Can't delete id's: %s from table: %s", postgresDb.id, dbTableName) + if err = deleteRowsById(db, dbTableName, postgresDb.Id); err != nil { + log.Printf("CheckSizeLimit: Can't delete id's: %s from table: %s", postgresDb.Id, dbTableName) return } } } func (postgresDb *PostgresDb) CheckExpiredData() { - psqlInfo := postgresDb.psqlInfo - db, err := psqlConnect(psqlInfo) + connectUrl := postgresDb.ConnectUrl + db, err := psqlConnect(connectUrl) if err != nil { - log.Println("CheckExpiredData: Can't open db, psqlInfo: ", psqlInfo) + log.Println("CheckExpiredData: Can't open db, connectUrl: ", connectUrl) return } defer db.Close() @@ -45,7 +45,7 @@ func (postgresDb *PostgresDb) CheckExpiredData() { Date string `db:"date"` TtlKey string `db:"messagekey"` } - if err := db.Select(&scanStructs, fmt.Sprintf("SELECT (key AND ttlkey) FROM %s WHERE %s=$1", dbTableExpiryDates, "id"), postgresDb.id); err != nil { + if err := db.Select(&scanStructs, fmt.Sprintf("SELECT (key AND ttlkey) FROM %s WHERE %s=$1", dbTableExpiryDates, "id"), postgresDb.Id); err != nil { log.Printf("CheckExpiredData: Can't get %s table: %s", dbTableExpiryDates, err) return } @@ -54,12 +54,12 @@ func (postgresDb *PostgresDb) CheckExpiredData() { for _, scanStruct := range scanStructs { if scanStruct.Date <= max { - if err = deleteRow(db, dbTableExpiryDates, postgresDb.id, "messagekey", scanStruct.TtlKey); err != nil { + if err = deleteRow(db, dbTableExpiryDates, postgresDb.Id, "messagekey", scanStruct.TtlKey); err != nil { log.Printf("CheckExpiredData: Can't delete %s from table:%s", scanStruct.TtlKey, dbTableExpiryDates) return } - if err = deleteRow(db, dbTableName, postgresDb.id, "messagekey", scanStruct.TtlKey); err != nil { + if err = deleteRow(db, dbTableName, postgresDb.Id, "messagekey", scanStruct.TtlKey); err != nil { log.Printf("CheckExpiredData: Can't delete %s from table:%s", scanStruct.TtlKey, dbTableName) return } diff --git a/dbservice/postgresdb/checker_test.go b/dbservice/postgresdb/checker_test.go index fcfb61d8..a39ed301 100644 --- a/dbservice/postgresdb/checker_test.go +++ b/dbservice/postgresdb/checker_test.go @@ -23,7 +23,7 @@ func TestExpiredDates(t *testing.T) { t.Run(test.name, func(t *testing.T) { deleted := false savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -64,7 +64,7 @@ func TestSizeLimit(t *testing.T) { t.Run(test.name, func(t *testing.T) { deleted := false savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) diff --git a/dbservice/postgresdb/dbaggregator.go b/dbservice/postgresdb/dbaggregator.go index dc6ad702..25b5b3af 100644 --- a/dbservice/postgresdb/dbaggregator.go +++ b/dbservice/postgresdb/dbaggregator.go @@ -11,7 +11,7 @@ func (postgresDb *PostgresDb) AggregateScans(output string, scansPerTicket int, ignoreTheQuantity bool) ([]map[string]string, error) { - db, err := psqlConnect(postgresDb.psqlInfo) + db, err := psqlConnect(postgresDb.ConnectUrl) if err != nil { return nil, err } @@ -26,7 +26,7 @@ func (postgresDb *PostgresDb) AggregateScans(output string, aggregatedScans = append(aggregatedScans, currentScan) } currentValue := "" - if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "saving", dbTableAggregator, "id", "output"), postgresDb.id, output); err != nil { + if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "saving", dbTableAggregator, "id", "output"), postgresDb.Id, output); err != nil { if err != sql.ErrNoRows { return nil, err } @@ -46,13 +46,13 @@ func (postgresDb *PostgresDb) AggregateScans(output string, if err != nil { return nil, err } - if err = insert(db, dbTableAggregator, postgresDb.id, "output", output, "saving", string(saving)); err != nil { + if err = insert(db, dbTableAggregator, postgresDb.Id, "output", output, "saving", string(saving)); err != nil { return nil, err } return nil, nil } - if err = insert(db, dbTableAggregator, postgresDb.id, "output", output, "saving", ""); err != nil { + if err = insert(db, dbTableAggregator, postgresDb.Id, "output", output, "saving", ""); err != nil { return nil, err } return aggregatedScans, nil diff --git a/dbservice/postgresdb/dbaggregator_test.go b/dbservice/postgresdb/dbaggregator_test.go index 869797d2..730d7f14 100644 --- a/dbservice/postgresdb/dbaggregator_test.go +++ b/dbservice/postgresdb/dbaggregator_test.go @@ -58,7 +58,7 @@ func TestAggregateScans(t *testing.T) { return nil } savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) diff --git a/dbservice/postgresdb/dbparam.go b/dbservice/postgresdb/dbparam.go index d6fc69a5..7babc6ef 100644 --- a/dbservice/postgresdb/dbparam.go +++ b/dbservice/postgresdb/dbparam.go @@ -2,11 +2,8 @@ package postgresdb import ( "errors" - "log" - "strings" "time" - "github.com/aquasecurity/postee/utils" "github.com/jmoiron/sqlx" ) @@ -23,72 +20,32 @@ var ( ) type PostgresDb struct { - psqlInfo string - id string + ConnectUrl string + Id string } -func NewPostgresDb(id, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) (*PostgresDb, error) { - info, err := buildPsqlInfo(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode) - if err != nil { - return nil, err - } +func NewPostgresDb(id, connectUrl string) *PostgresDb { return &PostgresDb{ - psqlInfo: info, - id: id, - }, nil -} - -func buildPsqlInfo(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) (string, error) { - psqlInfo := []string{} - - if dbHostName != "" { - dbHostName = utils.GetEnvironmentVarOrPlain(dbHostName) - psqlInfo = append(psqlInfo, "host="+dbHostName) - } else { - log.Printf("dbHostName is empty, for psqlInfo is used dbHostName=localhost") - } - if dbPort != "" { - dbPort = utils.GetEnvironmentVarOrPlain(dbPort) - psqlInfo = append(psqlInfo, "port="+dbPort) - } else { - log.Printf("dbPort is empty, for psqlInfo is used dbPort=5432") + ConnectUrl: connectUrl, + Id: id, } - if dbName != "" { - dbName = utils.GetEnvironmentVarOrPlain(dbName) - psqlInfo = append(psqlInfo, "dbname="+dbName) - } else { - return "", errors.New("can't build psqlInfo, dbName is empty") - } - if dbUser != "" { - dbUser = utils.GetEnvironmentVarOrPlain(dbUser) - psqlInfo = append(psqlInfo, "user="+dbUser) - } else { - return "", errors.New("can't build psqlInfo, dbUser is empty") - } - if dbPassword != "" { - dbPassword = utils.GetEnvironmentVarOrPlain(dbPassword) - psqlInfo = append(psqlInfo, "password="+dbPassword) - } - if dbSslMode != "" { - psqlInfo = append(psqlInfo, "sslmode="+dbSslMode) - } else { - log.Printf("dbSslMode is empty, for psqlInfo is used sslmode=disable") - psqlInfo = append(psqlInfo, "sslmode="+"disable") - } - return strings.Join(psqlInfo[:], " "), nil } func (postgresDb *PostgresDb) SetDbSizeLimit(limit int) { DbSizeLimit = limit } -func (postgresDb *PostgresDb) TestConnect() error { - _, err := psqlConnect(postgresDb.psqlInfo) - return errors.New("Error postgresDb test connect: " + err.Error()) +var TestConnect = func(connectUrl string) error { + db, err := psqlConnect(connectUrl) + if err != nil { + return errors.New("Error postgresDb test connect: " + err.Error()) + } + defer db.Close() + return nil } -var psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { - db, err := sqlx.Connect("postgres", psqlInfo) +var psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, err := sqlx.Connect("postgres", connectUrl) if err != nil { return nil, err } diff --git a/dbservice/postgresdb/dbparam_test.go b/dbservice/postgresdb/dbparam_test.go deleted file mode 100644 index 75977592..00000000 --- a/dbservice/postgresdb/dbparam_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package postgresdb - -import ( - "testing" -) - -func TestBuildPsqlInfo(t *testing.T) { - var tests = []struct { - name string - dbName string - dbHostName string - dbPort string - dbUser string - dbPassword string - dbSslMode string - expectedPsqlInfo string - expectedError string - }{ - {"empty dbName", "", "dbHostName", "dbPort", "dbUser", "dbPassword", "dbSslMode", - "", "can't build psqlInfo, dbName is empty"}, - {"empty dbHostName", "dbName", "", "dbPort", "dbUser", "dbPassword", "dbSslMode", - "port=dbPort dbname=dbName user=dbUser password=dbPassword sslmode=dbSslMode", ""}, - {"empty dbPort", "dbName", "dbHostName", "", "dbUser", "dbPassword", "dbSslMode", - "host=dbHostName dbname=dbName user=dbUser password=dbPassword sslmode=dbSslMode", ""}, - {"empty dbUser", "dbName", "dbHostName", "dbPort", "", "dbPassword", "dbSslMode", - "", "can't build psqlInfo, dbUser is empty"}, - {"empty dbPassword", "dbName", "dbHostName", "dbPort", "dbUser", "", "dbSslMode", - "host=dbHostName port=dbPort dbname=dbName user=dbUser sslmode=dbSslMode", ""}, - {"empty dbSslMode", "dbName", "dbHostName", "dbPort", "dbUser", "dbPassword", "", - "host=dbHostName port=dbPort dbname=dbName user=dbUser password=dbPassword sslmode=disable", ""}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - psqlInfo, err := buildPsqlInfo(test.dbName, test.dbHostName, test.dbPort, test.dbUser, test.dbPassword, test.dbSslMode) - if err != nil && err.Error() != test.expectedError { - t.Errorf("Unexpected error for %s, expected %v, got %v", test.name, test.expectedError, err) - } - if test.expectedPsqlInfo != psqlInfo { - t.Errorf("error getting psqlInfo, expected:%s, got:%s", test.expectedPsqlInfo, psqlInfo) - } - }) - } -} diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index bebd9327..49f7939c 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -74,14 +74,14 @@ var ( ] }` - db, _ = NewPostgresDb("id", "dbName", "", "", "user", "password", "disable") + db = NewPostgresDb("id", "postgresql://user:secret@localhost/dbname?sslmode=disable") ) func TestInitError(t *testing.T) { initErr := errors.New("init error") savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -106,7 +106,7 @@ func TestInitError(t *testing.T) { func TestDeleteRow(t *testing.T) { t.Log("happy delete row") savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { savedDeleteRow := deleteRow defer func() { deleteRow = savedDeleteRow @@ -122,7 +122,7 @@ func TestDeleteRow(t *testing.T) { psqlConnect = savedPsqlConnect }() - psqlDb, _ := psqlConnect(db.psqlInfo) + psqlDb, _ := psqlConnect(db.ConnectUrl) err := deleteRow(psqlDb, "table", "id", "column", "value") if err != nil { t.Errorf("Unexpected error: %v", err) @@ -130,7 +130,7 @@ func TestDeleteRow(t *testing.T) { t.Log("bad delete row") deleteError := errors.New("delete - error") - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { savedDeleteRow := deleteRow defer func() { deleteRow = savedDeleteRow @@ -142,7 +142,7 @@ func TestDeleteRow(t *testing.T) { mock.ExpectExec("DELETE").WillReturnError(deleteError) return db, err } - psqlDb, _ = psqlConnect(db.psqlInfo) + psqlDb, _ = psqlConnect(db.ConnectUrl) err = deleteRow(psqlDb, "table", "id", "column", "value") if deleteError != err { t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) @@ -151,7 +151,7 @@ func TestDeleteRow(t *testing.T) { func TestDeleteRowsById(t *testing.T) { t.Log("happy delete rows by id") savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { savedDeleteRowsById := deleteRowsById defer func() { deleteRowsById = savedDeleteRowsById @@ -167,7 +167,7 @@ func TestDeleteRowsById(t *testing.T) { psqlConnect = savedPsqlConnect }() - psqlDb, _ := psqlConnect(db.psqlInfo) + psqlDb, _ := psqlConnect(db.ConnectUrl) err := deleteRowsById(psqlDb, "table", "id") if err != nil { t.Errorf("Unexpected error: %v", err) @@ -175,7 +175,7 @@ func TestDeleteRowsById(t *testing.T) { t.Log("bad delete row") deleteError := errors.New("delete - error") - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { savedDeleteRowsById := deleteRowsById defer func() { deleteRowsById = savedDeleteRowsById @@ -187,17 +187,17 @@ func TestDeleteRowsById(t *testing.T) { mock.ExpectExec("DELETE").WillReturnError(deleteError) return db, err } - psqlDb, _ = psqlConnect(db.psqlInfo) + psqlDb, _ = psqlConnect(db.ConnectUrl) err = deleteRowsById(psqlDb, "table", "id") if deleteError != err { t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) } } -func TestInsert(t *testing.T) { +func TestInsertAndInsertOutputStats(t *testing.T) { t.Log("happy insert") savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -211,14 +211,19 @@ func TestInsert(t *testing.T) { psqlConnect = savedPsqlConnect }() - psqlDb, _ := psqlConnect(db.psqlInfo) + psqlDb, _ := psqlConnect(db.ConnectUrl) err := insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") if err != nil { - t.Errorf("Unexpected error: %v", err) + t.Errorf("Unexpected error in 'insert': %v", err) + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err != nil { + t.Errorf("Unexpected error in 'insertOutputStats': %v", err) } t.Log("happy update") - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -228,15 +233,20 @@ func TestInsert(t *testing.T) { mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) return db, err } - psqlDb, _ = psqlConnect(db.psqlInfo) + psqlDb, _ = psqlConnect(db.ConnectUrl) err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") if err != nil { - t.Errorf("Unexpected error: %v", err) + t.Errorf("Unexpected error in 'insert': %v", err) + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err != nil { + t.Errorf("Unexpected error in 'insertOutputStats': %v", err) } t.Log("bad select") badSelectError := errors.New("bad select") - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -244,15 +254,20 @@ func TestInsert(t *testing.T) { mock.ExpectQuery("SELECT").WillReturnError(badSelectError) return db, err } - psqlDb, _ = psqlConnect(db.psqlInfo) + psqlDb, _ = psqlConnect(db.ConnectUrl) err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") if err != badSelectError { - t.Errorf("Unexpected error, expected: %v, got: %v", badSelectError, err) + t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badSelectError, err) + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err != badSelectError { + t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badSelectError, err) } t.Log("bad insert") badInsertError := errors.New("bad insert") - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -262,15 +277,20 @@ func TestInsert(t *testing.T) { mock.ExpectExec("INSERT").WillReturnError(badInsertError) return db, err } - psqlDb, _ = psqlConnect(db.psqlInfo) + psqlDb, _ = psqlConnect(db.ConnectUrl) err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") if err != badInsertError { - t.Errorf("Unexpected error, expected: %v, got: %v", badInsertError, err) + t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badInsertError, err) + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err != badInsertError { + t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badInsertError, err) } t.Log("bad update") badUpdateError := errors.New("bad update") - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -280,9 +300,14 @@ func TestInsert(t *testing.T) { mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) return db, err } - psqlDb, _ = psqlConnect(db.psqlInfo) + psqlDb, _ = psqlConnect(db.ConnectUrl) err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") if err != badUpdateError { - t.Errorf("Unexpected error, expected: %v, got: %v", badUpdateError, err) + t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badUpdateError, err) + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err != badUpdateError { + t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badUpdateError, err) } } diff --git a/dbservice/postgresdb/invalidinit_test.go b/dbservice/postgresdb/invalidinit_test.go index 14db5710..8018b421 100644 --- a/dbservice/postgresdb/invalidinit_test.go +++ b/dbservice/postgresdb/invalidinit_test.go @@ -55,7 +55,7 @@ func TestBucketInitialization(t *testing.T) { savedInitTable := initTable savedPsqlConnect := psqlConnect initTable = func(db *sqlx.DB, tableName string) error { return expectedError } - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, _, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) diff --git a/dbservice/postgresdb/plgstats.go b/dbservice/postgresdb/plgstats.go index a152242e..e109abf2 100644 --- a/dbservice/postgresdb/plgstats.go +++ b/dbservice/postgresdb/plgstats.go @@ -8,7 +8,7 @@ import ( ) func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { - db, err := psqlConnect(postgresDb.psqlInfo) + db, err := psqlConnect(postgresDb.ConnectUrl) if err != nil { return err } @@ -19,12 +19,12 @@ func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { return err } amount := 0 - err = db.Get(&amount, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "amount", dbTableOutputStats, "id", "outputName"), postgresDb.id, name) + err = db.Get(&amount, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "amount", dbTableOutputStats, "id", "outputName"), postgresDb.Id, name) if err != nil && err != sql.ErrNoRows { return err } amount += 1 - err = insertOutputStats(db, postgresDb.id, name, amount) + err = insertOutputStats(db, postgresDb.Id, name, amount) if err != nil { return err } diff --git a/dbservice/postgresdb/plgstats_test.go b/dbservice/postgresdb/plgstats_test.go index c269fd88..b419bc78 100644 --- a/dbservice/postgresdb/plgstats_test.go +++ b/dbservice/postgresdb/plgstats_test.go @@ -19,7 +19,7 @@ func TestRegisterPlgnInvctn(t *testing.T) { return nil } savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -60,7 +60,7 @@ func TestRegisterPlgnInvctnErrors(t *testing.T) { savedInsertOutputStats := insertOutputStats insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { return nil } savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) diff --git a/dbservice/postgresdb/sharedcfg.go b/dbservice/postgresdb/sharedcfg.go index 7ea16c44..2e5941f0 100644 --- a/dbservice/postgresdb/sharedcfg.go +++ b/dbservice/postgresdb/sharedcfg.go @@ -12,7 +12,7 @@ import ( var apiKeyName = "POSTEE_API_KEY" func (postgresDb *PostgresDb) EnsureApiKey() error { - db, err := psqlConnect(postgresDb.psqlInfo) + db, err := psqlConnect(postgresDb.ConnectUrl) if err != nil { return err } @@ -27,7 +27,7 @@ func (postgresDb *PostgresDb) EnsureApiKey() error { return err } - if err = insert(db, dbTableSharedConfig, postgresDb.id, "apikeyname", apiKeyName, "value", apiKey); err != nil { + if err = insert(db, dbTableSharedConfig, postgresDb.Id, "apikeyname", apiKeyName, "value", apiKey); err != nil { return err } @@ -35,13 +35,13 @@ func (postgresDb *PostgresDb) EnsureApiKey() error { } func (postgresDb *PostgresDb) GetApiKey() (string, error) { - db, err := psqlConnect(postgresDb.psqlInfo) + db, err := psqlConnect(postgresDb.ConnectUrl) if err != nil { return "", err } defer db.Close() value := "" - err = db.Get(&value, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "value", dbTableSharedConfig, "id", "apikeyname"), postgresDb.id, apiKeyName) + err = db.Get(&value, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "value", dbTableSharedConfig, "id", "apikeyname"), postgresDb.Id, apiKeyName) if err != nil { return "", err } diff --git a/dbservice/postgresdb/sharedcfg_test.go b/dbservice/postgresdb/sharedcfg_test.go index 8442d2e0..9136f6ec 100644 --- a/dbservice/postgresdb/sharedcfg_test.go +++ b/dbservice/postgresdb/sharedcfg_test.go @@ -15,7 +15,7 @@ func TestApiKey(t *testing.T) { savedInsert := insert insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { return nil } savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -43,7 +43,7 @@ func TestApiKey(t *testing.T) { func TestApiKeyWithoutInit(t *testing.T) { savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) @@ -73,7 +73,7 @@ func TestApiKeyRenewal(t *testing.T) { return nil } savedPsqlConnect := psqlConnect - psqlConnect = func(psqlInfo string) (*sqlx.DB, error) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) diff --git a/router/router.go b/router/router.go index 21d1718c..f72f7f57 100644 --- a/router/router.go +++ b/router/router.go @@ -198,11 +198,11 @@ func (ctx *Router) load() error { ctx.aquaServer = fmt.Sprintf("%s%s#/images/", tenant.AquaServer, slash) } - if err = dbservice.ConfigureDb(&tenant.DbSettings, tenant.Name); err != nil { + if err = dbservice.ConfigurateDb(tenant.Name, &tenant.DBTestInterval, tenant.DBMaxSize); err != nil { return err } - ctx.ticker = time.NewTicker(baseForTicker * time.Duration(tenant.DbSettings.DBTestInterval)) + ctx.ticker = time.NewTicker(baseForTicker * time.Duration(tenant.DBTestInterval)) go func() { for { select { diff --git a/router/tenants.go b/router/tenants.go index ad653662..85de94bd 100644 --- a/router/tenants.go +++ b/router/tenants.go @@ -1,15 +1,16 @@ package router import ( - "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/routes" ) type TenantSettings struct { - Name string `json:"name,omitempty"` - AquaServer string `json:"aqua-server,omitempty"` - DbSettings dbservice.DbSettings `json:"dbsettings"` - Outputs []OutputSettings `json:"outputs"` - InputRoutes []routes.InputRoute `json:"routes"` - Templates []Template `json:"templates"` + Name string `json:"name,omitempty"` + AquaServer string `json:"aqua-server,omitempty"` + DBMaxSize int `json:"max-db-size,omitempty"` + DBRemoveOldData int `json:"delete-old-data,omitempty"` + DBTestInterval int `json:"db-verify-interval,omitempty"` + Outputs []OutputSettings `json:"outputs"` + InputRoutes []routes.InputRoute `json:"routes"` + Templates []Template `json:"templates"` } From 0a36fb0ae4c0106b2f79e5745d68d4ec6c17c6b9 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Fri, 3 Dec 2021 13:07:20 +0600 Subject: [PATCH 22/61] combined Webhooktable and WebhookExpiryDates --- dbservice/postgresdb/actions.go | 15 +- dbservice/postgresdb/actions_test.go | 8 +- dbservice/postgresdb/checker.go | 25 +-- dbservice/postgresdb/checker_test.go | 2 +- dbservice/postgresdb/dbparam.go | 1 - dbservice/postgresdb/dbparam_test.go | 23 ++ dbservice/postgresdb/dbservice_test.go | 258 +++++++++++++++++++++-- dbservice/postgresdb/init.go | 3 +- dbservice/postgresdb/insert.go | 30 ++- dbservice/postgresdb/invalidinit_test.go | 2 +- 10 files changed, 306 insertions(+), 61 deletions(-) create mode 100644 dbservice/postgresdb/dbparam_test.go diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go index a9d7b1df..d69d366f 100644 --- a/dbservice/postgresdb/actions.go +++ b/dbservice/postgresdb/actions.go @@ -17,10 +17,6 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin return false, err } - if err = initTable(db, dbTableExpiryDates); err != nil { - return false, err - } - currentValue := "" if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "messageValue", dbTableName, "id", "messageKey"), postgresDb.Id, messageKey); err != nil { if err != sql.ErrNoRows { @@ -31,13 +27,12 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin if currentValue != "" { return false, nil } else { - - if err = insert(db, dbTableName, postgresDb.Id, "messagekey", messageKey, "messagevalue", string(message)); err != nil { - return false, err - } if expired != nil { - - if err = insert(db, dbTableExpiryDates, postgresDb.Id, "date", expired.Format(DateFmt), "messagekey", messageKey); err != nil { + if err = insertInTableName(db, postgresDb.Id, expired.Format(DateFmt), messageKey, string(message)); err != nil { + return false, err + } + } else { + if err = insertInTableName(db, postgresDb.Id, "", messageKey, string(message)); err != nil { return false, err } } diff --git a/dbservice/postgresdb/actions_test.go b/dbservice/postgresdb/actions_test.go index 06f73f36..cc834f1c 100644 --- a/dbservice/postgresdb/actions_test.go +++ b/dbservice/postgresdb/actions_test.go @@ -15,9 +15,9 @@ func TestStoreMessage(t *testing.T) { savedInitTable := initTable initTable = func(db *sqlx.DB, tableName string) error { return nil } - savedInsert := insert - insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { - currentValueStoreMessage = value3 + savedinsertInTableName := insertInTableName + insertInTableName = func(db *sqlx.DB, id, date, messageKey, messageValue string) error { + currentValueStoreMessage = messageValue return nil } savedPsqlConnect := psqlConnect @@ -32,7 +32,7 @@ func TestStoreMessage(t *testing.T) { } defer func() { initTable = savedInitTable - insert = savedInsert + insertInTableName = savedinsertInTableName psqlConnect = savedPsqlConnect }() diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go index 94b2e65f..86e8ab76 100644 --- a/dbservice/postgresdb/checker.go +++ b/dbservice/postgresdb/checker.go @@ -36,31 +36,22 @@ func (postgresDb *PostgresDb) CheckExpiredData() { connectUrl := postgresDb.ConnectUrl db, err := psqlConnect(connectUrl) if err != nil { - log.Println("CheckExpiredData: Can't open db, connectUrl: ", connectUrl) + log.Printf("CheckExpiredData: Can't open postgresDb: %v", err) return } defer db.Close() - var scanStructs []struct { - Date string `db:"date"` - TtlKey string `db:"messagekey"` - } - if err := db.Select(&scanStructs, fmt.Sprintf("SELECT (key AND ttlkey) FROM %s WHERE %s=$1", dbTableExpiryDates, "id"), postgresDb.Id); err != nil { - log.Printf("CheckExpiredData: Can't get %s table: %s", dbTableExpiryDates, err) + dates := []string{} + if err := db.Select(&dates, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 and %s != '')", "date", dbTableName, "id", "date"), postgresDb.Id); err != nil { + log.Printf("CheckExpiredData: Can't get dates from table: %s, err: %v", dbTableName, err) return } max := time.Now().UTC().Format(DateFmt) //remove expired records - for _, scanStruct := range scanStructs { - if scanStruct.Date <= max { - - if err = deleteRow(db, dbTableExpiryDates, postgresDb.Id, "messagekey", scanStruct.TtlKey); err != nil { - log.Printf("CheckExpiredData: Can't delete %s from table:%s", scanStruct.TtlKey, dbTableExpiryDates) - return - } - - if err = deleteRow(db, dbTableName, postgresDb.Id, "messagekey", scanStruct.TtlKey); err != nil { - log.Printf("CheckExpiredData: Can't delete %s from table:%s", scanStruct.TtlKey, dbTableName) + for _, date := range dates { + if date <= max { + if err = deleteRow(db, dbTableName, postgresDb.Id, "date", date); err != nil { + log.Printf("CheckExpiredData: Can't delete %s from table:%s", date, dbTableName) return } } diff --git a/dbservice/postgresdb/checker_test.go b/dbservice/postgresdb/checker_test.go index a39ed301..f27c4814 100644 --- a/dbservice/postgresdb/checker_test.go +++ b/dbservice/postgresdb/checker_test.go @@ -28,7 +28,7 @@ func TestExpiredDates(t *testing.T) { if err != nil { log.Println("failed to open sqlmock database:", err) } - rows := sqlxmock.NewRows([]string{"date", "messagekey"}).AddRow(test.time, "ttlKeyTest") + rows := sqlxmock.NewRows([]string{"date"}).AddRow(test.time) mock.ExpectQuery("SELECT").WillReturnRows(rows) return db, err } diff --git a/dbservice/postgresdb/dbparam.go b/dbservice/postgresdb/dbparam.go index 7babc6ef..8dbbbb76 100644 --- a/dbservice/postgresdb/dbparam.go +++ b/dbservice/postgresdb/dbparam.go @@ -10,7 +10,6 @@ import ( var ( dbTableName = "WebhookTable" dbTableAggregator = "WebhookAggregator" - dbTableExpiryDates = "WebhookExpiryDates" dbTableOutputStats = "WebhookOutputStats" dbTableSharedConfig = "WebhookSharedConfig" diff --git a/dbservice/postgresdb/dbparam_test.go b/dbservice/postgresdb/dbparam_test.go new file mode 100644 index 00000000..efa04496 --- /dev/null +++ b/dbservice/postgresdb/dbparam_test.go @@ -0,0 +1,23 @@ +package postgresdb + +import ( + "errors" + "testing" + + "github.com/jmoiron/sqlx" +) + +func TestConnectFunc(t *testing.T) { + expectedError := "Error postgresDb test connect: connect error" + savedpsqlConnect := psqlConnect + defer func() { + psqlConnect = savedpsqlConnect + }() + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + return nil, errors.New("connect error") + } + err := TestConnect("url") + if err.Error() != expectedError { + t.Errorf("error text connect, expectedError: %v, got: %v", expectedError, err) + } +} diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 49f7939c..0dfca0bb 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -194,7 +194,7 @@ func TestDeleteRowsById(t *testing.T) { } } -func TestInsertAndInsertOutputStats(t *testing.T) { +func TestInsert(t *testing.T) { t.Log("happy insert") savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { @@ -216,11 +216,6 @@ func TestInsertAndInsertOutputStats(t *testing.T) { if err != nil { t.Errorf("Unexpected error in 'insert': %v", err) } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) - if err != nil { - t.Errorf("Unexpected error in 'insertOutputStats': %v", err) - } t.Log("happy update") psqlConnect = func(connectUrl string) (*sqlx.DB, error) { @@ -238,31 +233,38 @@ func TestInsertAndInsertOutputStats(t *testing.T) { if err != nil { t.Errorf("Unexpected error in 'insert': %v", err) } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) - if err != nil { - t.Errorf("Unexpected error in 'insertOutputStats': %v", err) - } - t.Log("bad select") - badSelectError := errors.New("bad select") + t.Log("select error") + selectError := errors.New("select error") psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) } - mock.ExpectQuery("SELECT").WillReturnError(badSelectError) + mock.ExpectQuery("SELECT").WillReturnError(selectError) return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") - if err != badSelectError { - t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badSelectError, err) + if err != selectError { + t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", selectError, err) + } + + t.Log("select 2 rows") + select2RowsError := "error insert in postgresDb. Table:table where id=id, column2=value2, have 2 rows" + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) - if err != badSelectError { - t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badSelectError, err) + err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + if err.Error() != select2RowsError { + t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", select2RowsError, err) } t.Log("bad insert") @@ -282,6 +284,111 @@ func TestInsertAndInsertOutputStats(t *testing.T) { if err != badInsertError { t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badInsertError, err) } + + t.Log("bad update") + badUpdateError := errors.New("bad update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + if err != badUpdateError { + t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badUpdateError, err) + } +} + +func TestInsertOutputStats(t *testing.T) { + t.Log("happy insert") + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := insertOutputStats(psqlDb, "id", "outputName", 1) + if err != nil { + t.Errorf("Unexpected error in 'insertOutputStats': %v", err) + } + + t.Log("happy update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err != nil { + t.Errorf("Unexpected error in 'insertOutputStats': %v", err) + } + + t.Log("select error") + selectError := errors.New("select error") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(selectError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err != selectError { + t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", selectError, err) + } + + t.Log("select 2 rows") + select2RowsError := "error insert in postgresDb. Table:WebhookOutputStats where id=id, outputName=outputName, have 2 rows" + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertOutputStats(psqlDb, "id", "outputName", 1) + if err.Error() != select2RowsError { + t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", select2RowsError, err) + } + + t.Log("bad insert") + badInsertError := errors.New("bad insert") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnError(badInsertError) + return db, err + } psqlDb, _ = psqlConnect(db.ConnectUrl) err = insertOutputStats(psqlDb, "id", "outputName", 1) if err != badInsertError { @@ -301,13 +408,118 @@ func TestInsertAndInsertOutputStats(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + err = insertOutputStats(psqlDb, "id", "outputName", 1) if err != badUpdateError { - t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badUpdateError, err) + t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badUpdateError, err) + } +} + +func TestInsertInTableName(t *testing.T) { + t.Log("happy insert") + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + if err != nil { + t.Errorf("Unexpected error in 'insertInTableName': %v", err) + } + + t.Log("happy update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) + err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + if err != nil { + t.Errorf("Unexpected error in 'insertInTableName': %v", err) + } + + t.Log("select error") + selectError := errors.New("select error") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(selectError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + if err != selectError { + t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", selectError, err) + } + + t.Log("select 2 rows") + select2RowsError := "error insert in postgresDb. Table:WebhookTable where id=id, messageKey=messageKey, have 2 rows" + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + if err.Error() != select2RowsError { + t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", select2RowsError, err) + } + + t.Log("bad insert") + badInsertError := errors.New("bad insert") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnError(badInsertError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + if err != badInsertError { + t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badInsertError, err) + } + + t.Log("bad update") + badUpdateError := errors.New("bad update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") if err != badUpdateError { - t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badUpdateError, err) + t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badUpdateError, err) } } diff --git a/dbservice/postgresdb/init.go b/dbservice/postgresdb/init.go index 44e087f2..5fc0c401 100644 --- a/dbservice/postgresdb/init.go +++ b/dbservice/postgresdb/init.go @@ -8,9 +8,8 @@ import ( var ( tableSchemas = map[string]string{ - dbTableName: "CREATE TABLE IF NOT EXISTS %s (id text, messagekey text,messagevalue text);", + dbTableName: "CREATE TABLE IF NOT EXISTS %s (id text, date text, messagekey text,messagevalue text);", dbTableAggregator: "CREATE TABLE IF NOT EXISTS %s (id text, output text,saving text);", - dbTableExpiryDates: "CREATE TABLE IF NOT EXISTS %s (id text, date text,messageKey text);", dbTableOutputStats: "CREATE TABLE IF NOT EXISTS %s (id text, outputname text,amount integer);", dbTableSharedConfig: "CREATE TABLE IF NOT EXISTS %s (id text, apikeyname text,value text);", } diff --git a/dbservice/postgresdb/insert.go b/dbservice/postgresdb/insert.go index e7e46abe..f1113100 100644 --- a/dbservice/postgresdb/insert.go +++ b/dbservice/postgresdb/insert.go @@ -15,10 +15,33 @@ var insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, valu if _, err := db.Exec(fmt.Sprintf("INSERT INTO %s (%s, %s, %s) VALUES ($1, $2, $3)", table, "id", columnName2, columnName3), id, value2, value3); err != nil { return err } - } else { + } else if i == 1 { if _, err := db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (%s=$2 AND %s=$3);", table, columnName3, "id", columnName2), value3, id, value2); err != nil { return err } + } else { + return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, %s=%s, have %d rows", table, id, columnName2, value2, i) + } + return nil +} + +var insertInTableName = func(db *sqlx.DB, id, date, messageKey, messageValue string) error { + var i int + if err := db.Get(&i, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (%s=$1 AND %s=$2)", dbTableName, "id", "messageKey"), id, messageKey); err != nil { + return err + } + if i == 0 { + if _, err := db.Exec(fmt.Sprintf("INSERT INTO %s (%s, %s, %s, %s) VALUES ($1, $2, $3, $4)", dbTableName, "id", "date", "messagekey", "messagevalue"), + id, date, messageKey, messageValue); err != nil { + return err + } + } else if i == 1 { + if _, err := db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1, %s=$2 WHERE (%s=$3 AND %s=$4);", dbTableName, "date", "messagevalue", "id", "messagekey"), + date, messageValue, id, messageKey); err != nil { + return err + } + } else { + return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, messageKey=%s, have %d rows", dbTableName, id, messageKey, i) } return nil } @@ -34,11 +57,14 @@ var insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) err if err != nil { return err } - } else { + } else if i == 1 { _, err = db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (%s=$2 AND %s=$3);", dbTableOutputStats, "amount", "id", "outputName"), amount, id, outputName) if err != nil { return err } + } else { + return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, outputName=%s, have %d rows", dbTableOutputStats, id, outputName, i) } + return nil } diff --git a/dbservice/postgresdb/invalidinit_test.go b/dbservice/postgresdb/invalidinit_test.go index 8018b421..24bf8d9e 100644 --- a/dbservice/postgresdb/invalidinit_test.go +++ b/dbservice/postgresdb/invalidinit_test.go @@ -50,7 +50,7 @@ var tests = []struct { }, } -func TestBucketInitialization(t *testing.T) { +func TestTableInitialization(t *testing.T) { expectedError := errors.New("weird error") savedInitTable := initTable savedPsqlConnect := psqlConnect From bc3573a159ed1c19e618d7141f8222bc024e28c5 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Fri, 3 Dec 2021 18:01:29 +0600 Subject: [PATCH 23/61] Refactor: code review notes are corrected --- README.md | 2 +- dbservice/boltdb/checker.go | 2 - dbservice/dbservice.go | 10 +-- dbservice/dbservice_test.go | 14 ++-- dbservice/postgresdb/actions.go | 7 +- dbservice/postgresdb/actions_test.go | 11 ++- dbservice/postgresdb/connect.go | 23 ++++++ .../{dbparam_test.go => connect_test.go} | 12 ++- dbservice/postgresdb/dbaggregator.go | 4 - dbservice/postgresdb/dbaggregator_test.go | 3 - dbservice/postgresdb/dbparam.go | 20 ----- dbservice/postgresdb/dbservice_test.go | 25 +++--- dbservice/postgresdb/init.go | 32 +++++--- dbservice/postgresdb/invalidinit_test.go | 80 ------------------- dbservice/postgresdb/plgstats.go | 4 - dbservice/postgresdb/plgstats_test.go | 6 -- dbservice/postgresdb/sharedcfg.go | 4 - dbservice/postgresdb/sharedcfg_test.go | 6 -- msgservice/applicationscopeowner_test.go | 2 +- router/router.go | 2 +- 20 files changed, 92 insertions(+), 177 deletions(-) create mode 100644 dbservice/postgresdb/connect.go rename dbservice/postgresdb/{dbparam_test.go => connect_test.go} (54%) delete mode 100644 dbservice/postgresdb/invalidinit_test.go diff --git a/README.md b/README.md index 325b9cad..9b7bf72e 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,7 @@ The Postee container uses BoltDB to store information about previously scanned i This is used to prevent resending messages that were already sent before. The size of the database can grow over time. Every image that is saved in the database uses 20K of storage. -Postee supports ‘PATH_TO_BOLTDB’ environment variable to change the bolt database directory. To use, set the ‘PATH_TO_BOLTDB’ environment variable to point to the bolt database file, for example: PATH_TO_BOLTDB="./database/webhook.db". By default, the directory for the bolt database file is “/server/database/webhook.db”. +Postee supports ‘PATH_TO_DB’ environment variable to change the bolt database directory. To use, set the ‘PATH_TO_DB’ environment variable to point to the bolt database file, for example: PATH_TO_DB="./database/webhook.db". By default, the directory for the bolt database file is “/server/database/webhook.db”. If you would like to persist the database file between restarts of the Postee container, then you should use a persistent storage option to mount the "/server/database" directory of the container. diff --git a/dbservice/boltdb/checker.go b/dbservice/boltdb/checker.go index d92a8443..6bb0c4fc 100644 --- a/dbservice/boltdb/checker.go +++ b/dbservice/boltdb/checker.go @@ -2,7 +2,6 @@ package boltdb import ( "bytes" - "fmt" "log" "time" @@ -31,7 +30,6 @@ func (boltDb *BoltDb) CheckSizeLimit() { c := b.Cursor() size := 0 for k, v := c.First(); k != nil; k, v = c.Next() { - fmt.Println(string(v)) size += len(v) } if size > DbSizeLimit { diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index c3631d2d..9adb9a76 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -24,17 +24,17 @@ type DbProvider interface { SetDbSizeLimit(limit int) } -func ConfigurateDb(id string, dBTestInterval *int, dbMaxSize int) error { +func ConfigureDb(tenantId string, dBTestInterval *int, dbMaxSize int) error { if *dBTestInterval == 0 { *dBTestInterval = 1 } if os.Getenv("POSTGRES_URL") != "" { - if id == "" { - return errors.New("error configurate postgresDb: 'id' is empty") + if tenantId == "" { + return errors.New("error configurate postgresDb: 'tenantId' is empty") } - postgresDb := postgresdb.NewPostgresDb(id, os.Getenv("POSTGRES_URL")) - if err := postgresdb.TestConnect(postgresDb.ConnectUrl); err != nil { + postgresDb := postgresdb.NewPostgresDb(tenantId, os.Getenv("POSTGRES_URL")) + if err := postgresdb.InitPostgresDb(postgresDb.ConnectUrl); err != nil { return err } Db = postgresDb diff --git a/dbservice/dbservice_test.go b/dbservice/dbservice_test.go index 898b91b5..e9526e78 100644 --- a/dbservice/dbservice_test.go +++ b/dbservice/dbservice_test.go @@ -28,7 +28,7 @@ func TestConfigurateBoltDbPath(t *testing.T) { os.Setenv("PATH_TO_BOLTDB", test.dbPath) testInterval := 2 - if err := ConfigurateDb("id", &testInterval, 1); err != nil { + if err := ConfigureDb("id", &testInterval, 1); err != nil { t.Errorf("Unexpected error: %v", err) } if testInterval != 2 { @@ -50,30 +50,28 @@ func TestConfiguratePostgresDbUrlAndId(t *testing.T) { expectedError error }{ {"happy configuration", "postgresql://user:secret@localhost", "test-id", nil}, - {"bad id", "postgresql://user:secret@localhost", "", errors.New("error configurate postgresDb: 'id' is empty")}, + {"bad id", "postgresql://user:secret@localhost", "", errors.New("error configurate postgresDb: 'tenantId' is empty")}, {"bad url", "badUrl", "test-id", errors.New("badUrl error")}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - testConnectSaved := postgresdb.TestConnect - postgresdb.TestConnect = func(connectUrl string) error { + initPostgresDbSaved := postgresdb.InitPostgresDb + postgresdb.InitPostgresDb = func(connectUrl string) error { if connectUrl == "badUrl" { return errors.New("badUrl error") } return nil } - defer func() { - postgresdb.TestConnect = testConnectSaved - }() oldUrlEnv := os.Getenv("POSTGRES_URL") defer func() { + postgresdb.InitPostgresDb = initPostgresDbSaved os.Setenv("POSTGRES_URL", oldUrlEnv) }() os.Setenv("POSTGRES_URL", test.url) testInterval := 0 - err := ConfigurateDb(test.id, &testInterval, 1) + err := ConfigureDb(test.id, &testInterval, 1) if err != nil { if err.Error() != test.expectedError.Error() { t.Errorf("Unexpected error, expected: %s, got: %s", test.expectedError, err) diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go index d69d366f..932dea97 100644 --- a/dbservice/postgresdb/actions.go +++ b/dbservice/postgresdb/actions.go @@ -13,12 +13,9 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin } defer db.Close() - if err = initTable(db, dbTableName); err != nil { - return false, err - } - currentValue := "" - if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "messageValue", dbTableName, "id", "messageKey"), postgresDb.Id, messageKey); err != nil { + sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "messageValue", dbTableName, "id", "messageKey") + if err = db.Get(¤tValue, sqlQuery, postgresDb.Id, messageKey); err != nil { if err != sql.ErrNoRows { return false, err } diff --git a/dbservice/postgresdb/actions_test.go b/dbservice/postgresdb/actions_test.go index cc834f1c..d44fada9 100644 --- a/dbservice/postgresdb/actions_test.go +++ b/dbservice/postgresdb/actions_test.go @@ -11,10 +11,7 @@ import ( func TestStoreMessage(t *testing.T) { currentValueStoreMessage := "" - time := time.Now() - savedInitTable := initTable - initTable = func(db *sqlx.DB, tableName string) error { return nil } savedinsertInTableName := insertInTableName insertInTableName = func(db *sqlx.DB, id, date, messageKey, messageValue string) error { currentValueStoreMessage = messageValue @@ -31,21 +28,22 @@ func TestStoreMessage(t *testing.T) { return db, err } defer func() { - initTable = savedInitTable insertInTableName = savedinsertInTableName psqlConnect = savedPsqlConnect }() var tests = []struct { input *string + t *time.Time }{ - {&AlpineImageResult}, + {&AlpineImageResult, nil}, + {&AlpineImageResult, &time.Time{}}, } for _, test := range tests { // Handling of first scan - isNew, err := db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, &time) + isNew, err := db.MayBeStoreMessage([]byte(*test.input), AlpineImageKey, test.t) if err != nil { t.Errorf("Error: %s\n", err) } @@ -61,6 +59,7 @@ func TestStoreMessage(t *testing.T) { if isNew { t.Errorf("A old scan wasn't found!\n") } + currentValueStoreMessage = "" } } diff --git a/dbservice/postgresdb/connect.go b/dbservice/postgresdb/connect.go new file mode 100644 index 00000000..dbd54ef7 --- /dev/null +++ b/dbservice/postgresdb/connect.go @@ -0,0 +1,23 @@ +package postgresdb + +import ( + "errors" + + "github.com/jmoiron/sqlx" +) + +var psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, err := sqlx.Connect("postgres", connectUrl) + if err != nil { + return nil, err + } + return db, nil +} + +var testConnect = func(connectUrl string) (*sqlx.DB, error) { + db, err := psqlConnect(connectUrl) + if err != nil { + return nil, errors.New("Error postgresDb test connect: " + err.Error()) + } + return db, nil +} diff --git a/dbservice/postgresdb/dbparam_test.go b/dbservice/postgresdb/connect_test.go similarity index 54% rename from dbservice/postgresdb/dbparam_test.go rename to dbservice/postgresdb/connect_test.go index efa04496..7004d801 100644 --- a/dbservice/postgresdb/dbparam_test.go +++ b/dbservice/postgresdb/connect_test.go @@ -7,7 +7,7 @@ import ( "github.com/jmoiron/sqlx" ) -func TestConnectFunc(t *testing.T) { +func TestConnectFuncError(t *testing.T) { expectedError := "Error postgresDb test connect: connect error" savedpsqlConnect := psqlConnect defer func() { @@ -16,8 +16,16 @@ func TestConnectFunc(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { return nil, errors.New("connect error") } - err := TestConnect("url") + _, err := testConnect("url") if err.Error() != expectedError { t.Errorf("error text connect, expectedError: %v, got: %v", expectedError, err) } } + +func TestPsqlConnectError(t *testing.T) { + expectedError := `missing "=" after "test_trivy_psql_connect_dbName" in connection info string"` + _, err := psqlConnect("test_trivy_psql_connect_dbName") + if err.Error() != expectedError { + t.Errorf("Unexpected error, expected: '%v', got: '%v'", expectedError, err) + } +} diff --git a/dbservice/postgresdb/dbaggregator.go b/dbservice/postgresdb/dbaggregator.go index 25b5b3af..7de7f6ba 100644 --- a/dbservice/postgresdb/dbaggregator.go +++ b/dbservice/postgresdb/dbaggregator.go @@ -17,10 +17,6 @@ func (postgresDb *PostgresDb) AggregateScans(output string, } defer db.Close() - if err = initTable(db, dbTableAggregator); err != nil { - return nil, err - } - aggregatedScans := make([]map[string]string, 0, scansPerTicket) if len(currentScan) > 0 { aggregatedScans = append(aggregatedScans, currentScan) diff --git a/dbservice/postgresdb/dbaggregator_test.go b/dbservice/postgresdb/dbaggregator_test.go index 730d7f14..d0838084 100644 --- a/dbservice/postgresdb/dbaggregator_test.go +++ b/dbservice/postgresdb/dbaggregator_test.go @@ -50,8 +50,6 @@ func TestAggregateScans(t *testing.T) { saving := "" for i := 0; i < len(tests); i++ { - savedInitTable := initTable - initTable = func(db *sqlx.DB, tableName string) error { return nil } savedInsert := insert insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { saving = value3 @@ -68,7 +66,6 @@ func TestAggregateScans(t *testing.T) { return db, err } defer func() { - initTable = savedInitTable insert = savedInsert psqlConnect = savedPsqlConnect }() diff --git a/dbservice/postgresdb/dbparam.go b/dbservice/postgresdb/dbparam.go index 8dbbbb76..0a2b1ed5 100644 --- a/dbservice/postgresdb/dbparam.go +++ b/dbservice/postgresdb/dbparam.go @@ -1,10 +1,7 @@ package postgresdb import ( - "errors" "time" - - "github.com/jmoiron/sqlx" ) var ( @@ -33,20 +30,3 @@ func NewPostgresDb(id, connectUrl string) *PostgresDb { func (postgresDb *PostgresDb) SetDbSizeLimit(limit int) { DbSizeLimit = limit } - -var TestConnect = func(connectUrl string) error { - db, err := psqlConnect(connectUrl) - if err != nil { - return errors.New("Error postgresDb test connect: " + err.Error()) - } - defer db.Close() - return nil -} - -var psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, err := sqlx.Connect("postgres", connectUrl) - if err != nil { - return nil, err - } - return db, nil -} diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 0dfca0bb..83895376 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -78,7 +78,8 @@ var ( ) func TestInitError(t *testing.T) { - initErr := errors.New("init error") + initTablesErr := errors.New("init tables error") + testConnectErr := errors.New("test connect error") savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { @@ -86,20 +87,24 @@ func TestInitError(t *testing.T) { if err != nil { log.Println("failed to open sqlmock database:", err) } - mock.ExpectExec("CREATE").WillReturnError(initErr) + mock.ExpectExec("CREATE").WillReturnError(initTablesErr) return db, err } - defer func() { - psqlConnect = savedPsqlConnect - }() - isNew, err := db.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) - if isNew { - t.Errorf("Scan shouldn't be marked as new\n") + err := InitPostgresDb("connectUrl") + if err.Error() != initTablesErr.Error() { + t.Errorf("Unexpected error: expected %s, got %s \n", initTablesErr, err) } - if err.Error() != initErr.Error() { - t.Errorf("Unexpected error: expected %s, got %s \n", initErr, err) + savedTestConnect := testConnect + testConnect = func(connectUrl string) (*sqlx.DB, error) { return nil, testConnectErr } + defer func() { + psqlConnect = savedPsqlConnect + testConnect = savedTestConnect + }() + err = InitPostgresDb("ConnectUrl") + if err.Error() != testConnectErr.Error() { + t.Errorf("Unexpected error: expected %s, got %s \n", testConnectErr, err) } } diff --git a/dbservice/postgresdb/init.go b/dbservice/postgresdb/init.go index 5fc0c401..076d50cb 100644 --- a/dbservice/postgresdb/init.go +++ b/dbservice/postgresdb/init.go @@ -1,22 +1,36 @@ package postgresdb import ( - "fmt" - "github.com/jmoiron/sqlx" ) var ( - tableSchemas = map[string]string{ - dbTableName: "CREATE TABLE IF NOT EXISTS %s (id text, date text, messagekey text,messagevalue text);", - dbTableAggregator: "CREATE TABLE IF NOT EXISTS %s (id text, output text,saving text);", - dbTableOutputStats: "CREATE TABLE IF NOT EXISTS %s (id text, outputname text,amount integer);", - dbTableSharedConfig: "CREATE TABLE IF NOT EXISTS %s (id text, apikeyname text,value text);", + tableSchemas = []string{ + "CREATE TABLE IF NOT EXISTS webhooktable (id varchar(32), date varchar(32), messagekey varchar(256),messagevalue text);", + "CREATE TABLE IF NOT EXISTS webhookaggregator (id varchar(32), output varchar(32), saving text);", + "CREATE TABLE IF NOT EXISTS webhookoutputstats (id varchar(32), outputname varchar(32), amount integer);", + "CREATE TABLE IF NOT EXISTS webhooksharedconfig (id varchar(32), apikeyname varchar(14),value varchar(64));", } ) -var initTable = func(db *sqlx.DB, tableName string) error { - _, err := db.Exec(fmt.Sprintf(tableSchemas[tableName], tableName)) +var initAllTables = func(db *sqlx.DB) error { + for _, schema := range tableSchemas { + _, err := db.Exec(schema) + if err != nil { + return err + } + } + return nil +} + +var InitPostgresDb = func(connectUrl string) error { + db, err := testConnect(connectUrl) + if err != nil { + return err + } + defer db.Close() + + err = initAllTables(db) if err != nil { return err } diff --git a/dbservice/postgresdb/invalidinit_test.go b/dbservice/postgresdb/invalidinit_test.go deleted file mode 100644 index 24bf8d9e..00000000 --- a/dbservice/postgresdb/invalidinit_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package postgresdb - -import ( - "errors" - "log" - "testing" - - "github.com/jmoiron/sqlx" - sqlxmock "github.com/zhashkevych/go-sqlxmock" -) - -var tests = []struct { - caseDesc string - errPrvdr func() error - initIsNotCalled bool -}{ - { - caseDesc: "EnsureApiKey", - errPrvdr: func() error { - return db.EnsureApiKey() - }, - }, - { - caseDesc: "GetApiKey", - errPrvdr: func() error { - _, err := db.GetApiKey() - return err - }, - initIsNotCalled: true, - }, - { - caseDesc: "RegisterPlgnInvctn", - errPrvdr: func() error { - return db.RegisterPlgnInvctn("some-key") - }, - }, - { - caseDesc: "MayBeStoreMessage", - errPrvdr: func() error { - _, err := db.MayBeStoreMessage(nil, "a-b-c", nil) - return err - }, - }, - { - caseDesc: "AggregateScans", - errPrvdr: func() error { - _, err := db.AggregateScans("", map[string]string{}, 1, false) - return err - }, - }, -} - -func TestTableInitialization(t *testing.T) { - expectedError := errors.New("weird error") - savedInitTable := initTable - savedPsqlConnect := psqlConnect - initTable = func(db *sqlx.DB, tableName string) error { return expectedError } - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, _, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - return db, err - } - defer func() { - initTable = savedInitTable - psqlConnect = savedPsqlConnect - }() - - for _, test := range tests { - if test.initIsNotCalled { - continue - } - err := test.errPrvdr() - if err != expectedError { - t.Errorf("Unexpected error for %s call, expected %v, got %v", test.caseDesc, expectedError, err) - } - - } -} diff --git a/dbservice/postgresdb/plgstats.go b/dbservice/postgresdb/plgstats.go index e109abf2..d27470e6 100644 --- a/dbservice/postgresdb/plgstats.go +++ b/dbservice/postgresdb/plgstats.go @@ -14,10 +14,6 @@ func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { } defer db.Close() - err = initTable(db, dbTableOutputStats) - if err != nil { - return err - } amount := 0 err = db.Get(&amount, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "amount", dbTableOutputStats, "id", "outputName"), postgresDb.Id, name) if err != nil && err != sql.ErrNoRows { diff --git a/dbservice/postgresdb/plgstats_test.go b/dbservice/postgresdb/plgstats_test.go index b419bc78..d522af21 100644 --- a/dbservice/postgresdb/plgstats_test.go +++ b/dbservice/postgresdb/plgstats_test.go @@ -11,8 +11,6 @@ import ( func TestRegisterPlgnInvctn(t *testing.T) { receivedKey := 0 - savedInitTable := initTable - initTable = func(db *sqlx.DB, tableName string) error { return nil } savedInsertOutputStats := insertOutputStats insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { receivedKey = amount @@ -29,7 +27,6 @@ func TestRegisterPlgnInvctn(t *testing.T) { return db, err } defer func() { - initTable = savedInitTable insertOutputStats = savedInsertOutputStats psqlConnect = savedPsqlConnect }() @@ -55,8 +52,6 @@ func TestRegisterPlgnInvctnErrors(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - savedInitTable := initTable - initTable = func(db *sqlx.DB, tableName string) error { return nil } savedInsertOutputStats := insertOutputStats insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { return nil } savedPsqlConnect := psqlConnect @@ -70,7 +65,6 @@ func TestRegisterPlgnInvctnErrors(t *testing.T) { } defer func() { psqlConnect = savedPsqlConnect - initTable = savedInitTable insertOutputStats = savedInsertOutputStats }() err := db.RegisterPlgnInvctn("testName") diff --git a/dbservice/postgresdb/sharedcfg.go b/dbservice/postgresdb/sharedcfg.go index 2e5941f0..5ef11dbd 100644 --- a/dbservice/postgresdb/sharedcfg.go +++ b/dbservice/postgresdb/sharedcfg.go @@ -17,10 +17,6 @@ func (postgresDb *PostgresDb) EnsureApiKey() error { return err } defer db.Close() - err = initTable(db, dbTableSharedConfig) - if err != nil { - return err - } apiKey, err := generateApiKey(32) if err != nil { diff --git a/dbservice/postgresdb/sharedcfg_test.go b/dbservice/postgresdb/sharedcfg_test.go index 9136f6ec..343b54a0 100644 --- a/dbservice/postgresdb/sharedcfg_test.go +++ b/dbservice/postgresdb/sharedcfg_test.go @@ -10,8 +10,6 @@ import ( ) func TestApiKey(t *testing.T) { - savedInitTable := initTable - initTable = func(db *sqlx.DB, tableName string) error { return nil } savedInsert := insert insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { return nil } savedPsqlConnect := psqlConnect @@ -25,7 +23,6 @@ func TestApiKey(t *testing.T) { return db, err } defer func() { - initTable = savedInitTable insert = savedInsert psqlConnect = savedPsqlConnect }() @@ -65,8 +62,6 @@ func TestApiKeyWithoutInit(t *testing.T) { func TestApiKeyRenewal(t *testing.T) { receivedKey := "" - savedInitTable := initTable - initTable = func(db *sqlx.DB, tableName string) error { return nil } savedInsert := insert insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { receivedKey = value3 @@ -83,7 +78,6 @@ func TestApiKeyRenewal(t *testing.T) { return db, err } defer func() { - initTable = savedInitTable insert = savedInsert psqlConnect = savedPsqlConnect }() diff --git a/msgservice/applicationscopeowner_test.go b/msgservice/applicationscopeowner_test.go index 3a1b54ef..4ac02840 100644 --- a/msgservice/applicationscopeowner_test.go +++ b/msgservice/applicationscopeowner_test.go @@ -25,11 +25,11 @@ func TestApplicationScopeOwner(t *testing.T) { db = boltdb.NewBoltDb() oldDb := dbservice.Db dbservice.Db = db - defer func() { dbservice.Db = oldDb }() dbPathReal := db.DbPath defer func() { os.Remove(db.DbPath) db.DbPath = dbPathReal + dbservice.Db = oldDb }() db.DbPath = "test_webhooks.db" diff --git a/router/router.go b/router/router.go index f72f7f57..37e50f82 100644 --- a/router/router.go +++ b/router/router.go @@ -198,7 +198,7 @@ func (ctx *Router) load() error { ctx.aquaServer = fmt.Sprintf("%s%s#/images/", tenant.AquaServer, slash) } - if err = dbservice.ConfigurateDb(tenant.Name, &tenant.DBTestInterval, tenant.DBMaxSize); err != nil { + if err = dbservice.ConfigureDb(tenant.Name, &tenant.DBTestInterval, tenant.DBMaxSize); err != nil { return err } From 04f8f21bcab86dc8f1bc7765549b3f1539171e61 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Mon, 6 Dec 2021 17:00:21 +0600 Subject: [PATCH 24/61] Refactor: code review notes are corrected --- dbservice/boltdb/actions.go | 11 +- dbservice/boltdb/{dbparam.go => boltdb.go} | 18 +-- dbservice/boltdb/checker.go | 17 ++- dbservice/boltdb/checker_test.go | 37 ++--- dbservice/boltdb/dbaggregator.go | 9 +- dbservice/boltdb/dbparam_test.go | 11 +- dbservice/boltdb/plgnstats.go | 5 +- dbservice/boltdb/plgnstats_test.go | 3 +- dbservice/boltdb/sharedcfg.go | 21 +-- dbservice/dbparam/dbparam.go | 28 ++++ dbservice/dbservice.go | 19 ++- dbservice/dbservice_test.go | 45 ++++-- dbservice/postgresdb/actions.go | 8 +- dbservice/postgresdb/actions_test.go | 6 +- dbservice/postgresdb/checker.go | 29 ++-- dbservice/postgresdb/checker_test.go | 27 ++-- dbservice/postgresdb/dbaggregator.go | 13 +- dbservice/postgresdb/dbaggregator_test.go | 12 +- dbservice/postgresdb/dbparam.go | 32 ---- dbservice/postgresdb/dbservice_test.go | 165 +++++++++++++++++---- dbservice/postgresdb/delete.go | 9 +- dbservice/postgresdb/init.go | 11 +- dbservice/postgresdb/insert.go | 63 ++++++-- dbservice/postgresdb/plgstats.go | 4 +- dbservice/postgresdb/postgresdb.go | 13 ++ dbservice/postgresdb/sharedcfg.go | 19 +-- dbservice/postgresdb/sharedcfg_test.go | 14 +- router/router.go | 2 +- 28 files changed, 395 insertions(+), 256 deletions(-) rename dbservice/boltdb/{dbparam.go => boltdb.go} (55%) create mode 100644 dbservice/dbparam/dbparam.go delete mode 100644 dbservice/postgresdb/dbparam.go create mode 100644 dbservice/postgresdb/postgresdb.go diff --git a/dbservice/boltdb/actions.go b/dbservice/boltdb/actions.go index 8a533899..4bb70a19 100644 --- a/dbservice/boltdb/actions.go +++ b/dbservice/boltdb/actions.go @@ -3,6 +3,7 @@ package boltdb import ( "time" + "github.com/aquasecurity/postee/dbservice/dbparam" bolt "go.etcd.io/bbolt" ) @@ -16,14 +17,14 @@ func (boltDb *BoltDb) MayBeStoreMessage(message []byte, messageKey string, expir } defer db.Close() - if err = Init(db, dbBucketName); err != nil { + if err = Init(db, dbparam.DbBucketName); err != nil { return false, err } - if err = Init(db, dbBucketExpiryDates); err != nil { + if err = Init(db, dbparam.DbBucketExpiryDates); err != nil { return false, err } - currentValue, err := dbSelect(db, dbBucketName, messageKey) + currentValue, err := dbSelect(db, dbparam.DbBucketName, messageKey) if err != nil { return false, err } @@ -32,12 +33,12 @@ func (boltDb *BoltDb) MayBeStoreMessage(message []byte, messageKey string, expir return false, nil } else { bMessageKey := []byte(messageKey) - err = dbInsert(db, dbBucketName, bMessageKey, message) + err = dbInsert(db, dbparam.DbBucketName, bMessageKey, message) if err != nil { return false, err } if expired != nil { - err = dbInsert(db, dbBucketExpiryDates, []byte(expired.Format(DateFmt)), bMessageKey) + err = dbInsert(db, dbparam.DbBucketExpiryDates, []byte(expired.Format(dbparam.DateFmt)), bMessageKey) if err != nil { return false, err } diff --git a/dbservice/boltdb/dbparam.go b/dbservice/boltdb/boltdb.go similarity index 55% rename from dbservice/boltdb/dbparam.go rename to dbservice/boltdb/boltdb.go index 1fbf249a..cea58c95 100644 --- a/dbservice/boltdb/dbparam.go +++ b/dbservice/boltdb/boltdb.go @@ -4,20 +4,9 @@ import ( "os" "path/filepath" "sync" - "time" ) var ( - dbBucketName = "WebhookBucket" - dbBucketAggregator = "WebhookAggregator" - dbBucketExpiryDates = "WebhookExpiryDates" - dbBucketOutputStats = "WebhookOutputStats" - dbBucketSharedConfig = "WebhookSharedConfig" - - DbSizeLimit = 0 - DateFmt = time.RFC3339Nano - dueTimeBase = time.Hour * time.Duration(24) - mutex sync.Mutex ) @@ -37,8 +26,7 @@ func (boltDb *BoltDb) ChangeDbPath(newPath string) { mutex.Unlock() } -func (boltDb *BoltDb) SetNewDbPathFromEnv() error { - newPath := os.Getenv("PATH_TO_BOLTDB") +func (boltDb *BoltDb) SetNewDbPath(newPath string) error { if newPath != "" { if _, err := os.Stat(newPath); err != nil { if os.IsNotExist(err) { @@ -54,7 +42,3 @@ func (boltDb *BoltDb) SetNewDbPathFromEnv() error { } return nil } - -func (boltDb *BoltDb) SetDbSizeLimit(limit int) { - DbSizeLimit = limit -} diff --git a/dbservice/boltdb/checker.go b/dbservice/boltdb/checker.go index 6bb0c4fc..1cfea8bd 100644 --- a/dbservice/boltdb/checker.go +++ b/dbservice/boltdb/checker.go @@ -5,11 +5,12 @@ import ( "log" "time" + "github.com/aquasecurity/postee/dbservice/dbparam" bolt "go.etcd.io/bbolt" ) func (boltDb *BoltDb) CheckSizeLimit() { - if DbSizeLimit == 0 { + if dbparam.DbSizeLimit == 0 { return } mutex.Lock() @@ -23,7 +24,7 @@ func (boltDb *BoltDb) CheckSizeLimit() { defer db.Close() if err := db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte(dbBucketName)) + b := tx.Bucket([]byte(dbparam.DbBucketName)) if b == nil { return nil } @@ -32,8 +33,8 @@ func (boltDb *BoltDb) CheckSizeLimit() { for k, v := c.First(); k != nil; k, v = c.Next() { size += len(v) } - if size > DbSizeLimit { - return tx.DeleteBucket([]byte(dbBucketName)) + if size > dbparam.DbSizeLimit { + return tx.DeleteBucket([]byte(dbparam.DbBucketName)) } return nil }); err != nil { @@ -59,7 +60,7 @@ func (boltDb *BoltDb) CheckExpiredData() { return } - if err := dbDelete(db, dbBucketName, expired); err != nil { + if err := dbDelete(db, dbparam.DbBucketName, expired); err != nil { log.Println("Can't remove expired data: ", err) } } @@ -69,13 +70,13 @@ func (boltDb *BoltDb) getExpired(db *bolt.DB) (keys [][]byte, err error) { ttlKeys := [][]byte{} if err = db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte(dbBucketExpiryDates)) + b := tx.Bucket([]byte(dbparam.DbBucketExpiryDates)) if b == nil { return nil } c := b.Cursor() - max := []byte(time.Now().UTC().Format(DateFmt)) //remove expired records + max := []byte(time.Now().UTC().Format(dbparam.DateFmt)) //remove expired records for k, v := c.First(); k != nil && bytes.Compare(k, max) <= 0; k, v = c.Next() { keys = append(keys, v) ttlKeys = append(ttlKeys, k) @@ -85,7 +86,7 @@ func (boltDb *BoltDb) getExpired(db *bolt.DB) (keys [][]byte, err error) { return nil, err } - if err = dbDelete(db, dbBucketExpiryDates, ttlKeys); err != nil { + if err = dbDelete(db, dbparam.DbBucketExpiryDates, ttlKeys); err != nil { return nil, err } diff --git a/dbservice/boltdb/checker_test.go b/dbservice/boltdb/checker_test.go index baf41e10..d53e6457 100644 --- a/dbservice/boltdb/checker_test.go +++ b/dbservice/boltdb/checker_test.go @@ -5,19 +5,20 @@ import ( "testing" "time" + "github.com/aquasecurity/postee/dbservice/dbparam" bolt "go.etcd.io/bbolt" ) func TestExpiredDates(t *testing.T) { boltDb := NewBoltDb() dbPathReal := boltDb.DbPath - realDueTimeBase := dueTimeBase + realDueTimeBase := dbparam.DueTimeBase defer func() { os.Remove(boltDb.DbPath) boltDb.DbPath = dbPathReal - dueTimeBase = realDueTimeBase + dbparam.DueTimeBase = realDueTimeBase }() - dueTimeBase = time.Nanosecond + dbparam.DueTimeBase = time.Nanosecond boltDb.DbPath = "test_webhooks.db" tests := []struct { title string @@ -55,11 +56,11 @@ func TestExpiredDates(t *testing.T) { func TestDbSizeLimnit(t *testing.T) { boltDb := NewBoltDb() dbPathReal := boltDb.DbPath - realSizeLimit := DbSizeLimit + realSizeLimit := dbparam.DbSizeLimit defer func() { os.Remove(boltDb.DbPath) boltDb.DbPath = dbPathReal - DbSizeLimit = realSizeLimit + dbparam.DbSizeLimit = realSizeLimit }() boltDb.DbPath = "test_webhooks.db" @@ -74,12 +75,12 @@ func TestDbSizeLimnit(t *testing.T) { {"Third scan", 1, true, true}, } - DbSizeLimit = 1 + dbparam.DbSizeLimit = 1 boltDb.CheckSizeLimit() for _, test := range tests { t.Log(test.title) - DbSizeLimit = test.limit + dbparam.DbSizeLimit = test.limit if test.needRun { boltDb.CheckSizeLimit() } @@ -97,12 +98,12 @@ func TestDbSizeLimnit(t *testing.T) { func TestWrongBuckets(t *testing.T) { boltDb := NewBoltDb() - savedDbBucketName := dbBucketName - savedDbBucketExpiryDates := dbBucketExpiryDates + savedDbBucketName := dbparam.DbBucketName + savedDbBucketExpiryDates := dbparam.DbBucketExpiryDates dbPathReal := boltDb.DbPath defer func() { - dbBucketName = savedDbBucketName - dbBucketExpiryDates = savedDbBucketExpiryDates + dbparam.DbBucketName = savedDbBucketName + dbparam.DbBucketExpiryDates = savedDbBucketExpiryDates os.Remove(boltDb.DbPath) boltDb.DbPath = dbPathReal }() @@ -113,18 +114,18 @@ func TestWrongBuckets(t *testing.T) { t.Fatal(err) } - DbSizeLimit = 1 - dbBucketName = "" - dbBucketExpiryDates = "" + dbparam.DbSizeLimit = 1 + dbparam.DbBucketName = "" + dbparam.DbBucketExpiryDates = "" boltDb.CheckSizeLimit() - dbBucketName = "dbBucketName" + dbparam.DbBucketName = "dbBucketName" _, err = boltDb.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if err == nil { t.Error("No error for empty dbBucketExpiryDates") } - dbBucketExpiryDates = "dbBucketExpiryDates" - dbBucketName = "" + dbparam.DbBucketExpiryDates = "dbBucketExpiryDates" + dbparam.DbBucketName = "" _, err = boltDb.MayBeStoreMessage([]byte(AlpineImageResult), AlpineImageKey, nil) if err == nil { t.Error("No error for empty dbBucketName") @@ -173,7 +174,7 @@ func TestWithoutAccessToDb(t *testing.T) { return } db.Close() - DbSizeLimit = 1 + dbparam.DbSizeLimit = 1 boltDb.CheckSizeLimit() boltDb.CheckExpiredData() } diff --git a/dbservice/boltdb/dbaggregator.go b/dbservice/boltdb/dbaggregator.go index 82d0b764..ad61eac7 100644 --- a/dbservice/boltdb/dbaggregator.go +++ b/dbservice/boltdb/dbaggregator.go @@ -3,6 +3,7 @@ package boltdb import ( "encoding/json" + "github.com/aquasecurity/postee/dbservice/dbparam" bolt "go.etcd.io/bbolt" ) @@ -19,7 +20,7 @@ func (boltDb *BoltDb) AggregateScans(output string, } defer db.Close() - err = Init(db, dbBucketAggregator) + err = Init(db, dbparam.DbBucketAggregator) if err != nil { return nil, err } @@ -28,7 +29,7 @@ func (boltDb *BoltDb) AggregateScans(output string, if len(currentScan) > 0 { aggregatedScans = append(aggregatedScans, currentScan) } - currentValue, err := dbSelect(db, dbBucketAggregator, output) + currentValue, err := dbSelect(db, dbparam.DbBucketAggregator, output) if err != nil { return nil, err } @@ -48,12 +49,12 @@ func (boltDb *BoltDb) AggregateScans(output string, return nil, err } - err = dbInsert(db, dbBucketAggregator, []byte(output), saving) + err = dbInsert(db, dbparam.DbBucketAggregator, []byte(output), saving) if err != nil { return nil, err } return nil, nil } - dbInsert(db, dbBucketAggregator, []byte(output), nil) + dbInsert(db, dbparam.DbBucketAggregator, []byte(output), nil) return aggregatedScans, nil } diff --git a/dbservice/boltdb/dbparam_test.go b/dbservice/boltdb/dbparam_test.go index 2ad6f037..867db148 100644 --- a/dbservice/boltdb/dbparam_test.go +++ b/dbservice/boltdb/dbparam_test.go @@ -9,18 +9,16 @@ import ( func TestSetNewDbPathFromEnv(t *testing.T) { db := NewBoltDb() - envPathToDbOld := os.Getenv("PATH_TO_BOLTDB") - defer os.Setenv("PATH_TO_BOLTDB", envPathToDbOld) dbPathOld := db.DbPath defaultDbPath := "/server/database/webhooks.db" var tests = []struct { name string - envPathToDb string + pathToDb string changePermission bool expectedDBPath string }{ - {"Empty PATH_TO_BOLTDB", "", false, defaultDbPath}, + {"Empty pathToDb", "", false, defaultDbPath}, {"Permission denied to create directory(default DbPath is used)", "/database/database.db", false, defaultDbPath}, {"New DbPath", "./base/base.db", false, "./base/base.db"}, {"Permission denied to check directory(default DbPath is used)", "webhook/database/webhooks.db", true, defaultDbPath}, @@ -28,8 +26,7 @@ func TestSetNewDbPathFromEnv(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - os.Setenv("PATH_TO_BOLTDB", test.envPathToDb) - baseDir := strings.Split(filepath.Dir(test.envPathToDb), "/")[0] + baseDir := strings.Split(filepath.Dir(test.pathToDb), "/")[0] if test.changePermission { err := os.Mkdir(baseDir, os.ModeDir) if err != nil { @@ -37,7 +34,7 @@ func TestSetNewDbPathFromEnv(t *testing.T) { } os.Chmod(baseDir, 0) } - db.SetNewDbPathFromEnv() + db.SetNewDbPath(test.pathToDb) defer os.RemoveAll(baseDir) defer db.ChangeDbPath(dbPathOld) diff --git a/dbservice/boltdb/plgnstats.go b/dbservice/boltdb/plgnstats.go index ec407d44..2b43fea8 100644 --- a/dbservice/boltdb/plgnstats.go +++ b/dbservice/boltdb/plgnstats.go @@ -3,6 +3,7 @@ package boltdb import ( "strconv" + "github.com/aquasecurity/postee/dbservice/dbparam" bolt "go.etcd.io/bbolt" ) @@ -15,13 +16,13 @@ func (boltDb *BoltDb) RegisterPlgnInvctn(name string) error { return err } defer db.Close() - err = Init(db, dbBucketOutputStats) + err = Init(db, dbparam.DbBucketOutputStats) if err != nil { return err } err = db.Update(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(dbBucketOutputStats)) + bucket := tx.Bucket([]byte(dbparam.DbBucketOutputStats)) var i int v := bucket.Get([]byte(name)) diff --git a/dbservice/boltdb/plgnstats_test.go b/dbservice/boltdb/plgnstats_test.go index ae90888a..41d85f5f 100644 --- a/dbservice/boltdb/plgnstats_test.go +++ b/dbservice/boltdb/plgnstats_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "github.com/aquasecurity/postee/dbservice/dbparam" bolt "go.etcd.io/bbolt" ) @@ -40,7 +41,7 @@ func getPlgnStats(dbBolt *BoltDb) (r map[string]int, err error) { } defer db.Close() err = db.View(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(dbBucketOutputStats)) + bucket := tx.Bucket([]byte(dbparam.DbBucketOutputStats)) if bucket == nil { return nil //no bucket - empty stats will be returned } diff --git a/dbservice/boltdb/sharedcfg.go b/dbservice/boltdb/sharedcfg.go index dd2c1cbd..fc3ee383 100644 --- a/dbservice/boltdb/sharedcfg.go +++ b/dbservice/boltdb/sharedcfg.go @@ -1,11 +1,9 @@ package boltdb import ( - "crypto/rand" - "encoding/hex" "errors" - "io" + "github.com/aquasecurity/postee/dbservice/dbparam" bolt "go.etcd.io/bbolt" ) @@ -23,17 +21,17 @@ func (boltDb *BoltDb) EnsureApiKey() error { } defer db.Close() - err = Init(db, dbBucketOutputStats) + err = Init(db, dbparam.DbBucketOutputStats) if err != nil { return err } - newApiKey, err := generateApiKey(32) + newApiKey, err := dbparam.GenerateApiKey(32) if err != nil { return err } - err = dbInsert(db, dbBucketSharedConfig, []byte(apiKeyName), []byte(newApiKey)) + err = dbInsert(db, dbparam.DbBucketSharedConfig, []byte(apiKeyName), []byte(newApiKey)) return err } @@ -45,7 +43,7 @@ func (boltDb *BoltDb) GetApiKey() (string, error) { } defer db.Close() err = db.View(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(dbBucketSharedConfig)) + bucket := tx.Bucket([]byte(dbparam.DbBucketSharedConfig)) if bucket == nil { return errors.New("no bucket") //no bucket } @@ -58,14 +56,5 @@ func (boltDb *BoltDb) GetApiKey() (string, error) { if err != nil { return "", err } - return apiKey, nil - -} -func generateApiKey(length int) (string, error) { - k := make([]byte, length) - if _, err := io.ReadFull(rand.Reader, k); err != nil { - return "", err - } - return hex.EncodeToString(k), nil } diff --git a/dbservice/dbparam/dbparam.go b/dbservice/dbparam/dbparam.go new file mode 100644 index 00000000..2ff87429 --- /dev/null +++ b/dbservice/dbparam/dbparam.go @@ -0,0 +1,28 @@ +package dbparam + +import ( + "crypto/rand" + "encoding/hex" + "io" + "time" +) + +var ( + DbBucketName = "WebhookBucket" + DbBucketAggregator = "WebhookAggregator" + DbBucketExpiryDates = "WebhookExpiryDates" + DbBucketOutputStats = "WebhookOutputStats" + DbBucketSharedConfig = "WebhookSharedConfig" + + DbSizeLimit = 0 + DateFmt = time.RFC3339Nano + DueTimeBase = time.Hour * time.Duration(24) +) + +func GenerateApiKey(length int) (string, error) { + k := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return "", err + } + return hex.EncodeToString(k), nil +} diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index 9adb9a76..6b0c3c7d 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -2,11 +2,12 @@ package dbservice import ( "errors" - "os" "time" "github.com/aquasecurity/postee/dbservice/boltdb" + "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/aquasecurity/postee/dbservice/postgresdb" + "github.com/aquasecurity/postee/utils" ) var ( @@ -21,32 +22,34 @@ type DbProvider interface { RegisterPlgnInvctn(name string) error EnsureApiKey() error GetApiKey() (string, error) - SetDbSizeLimit(limit int) } -func ConfigureDb(tenantId string, dBTestInterval *int, dbMaxSize int) error { +func ConfigureDb(pathToDb, postgresUrl, tenantId string, dBTestInterval *int, dbMaxSize int) error { if *dBTestInterval == 0 { *dBTestInterval = 1 } - if os.Getenv("POSTGRES_URL") != "" { + postgresUrl = utils.GetEnvironmentVarOrPlain(postgresUrl) + pathToDb = utils.GetEnvironmentVarOrPlain(pathToDb) + + if postgresUrl != "" { if tenantId == "" { return errors.New("error configurate postgresDb: 'tenantId' is empty") } - postgresDb := postgresdb.NewPostgresDb(tenantId, os.Getenv("POSTGRES_URL")) + postgresDb := postgresdb.NewPostgresDb(tenantId, postgresUrl) if err := postgresdb.InitPostgresDb(postgresDb.ConnectUrl); err != nil { return err } Db = postgresDb } else { boltdb := boltdb.NewBoltDb() - if os.Getenv("PATH_TO_BOLTDB") != "" { - if err := boltdb.SetNewDbPathFromEnv(); err != nil { + if pathToDb != "" { + if err := boltdb.SetNewDbPath(pathToDb); err != nil { return err } } Db = boltdb } - Db.SetDbSizeLimit(dbMaxSize) + dbparam.DbSizeLimit = dbMaxSize return nil } diff --git a/dbservice/dbservice_test.go b/dbservice/dbservice_test.go index e9526e78..4d673a2e 100644 --- a/dbservice/dbservice_test.go +++ b/dbservice/dbservice_test.go @@ -4,31 +4,36 @@ import ( "errors" "os" "reflect" + "strings" "testing" "github.com/aquasecurity/postee/dbservice/postgresdb" ) -func TestConfigurateBoltDbPath(t *testing.T) { +func TestConfigurateBoltDbPathUsedEnv(t *testing.T) { tests := []struct { name string dbPath string + dbPathInEnv string expectedPath string }{ - {"happy configuration BoltDB with dbPath", "database/webhooks.db", "database/webhooks.db"}, - {"happy configuration BoltDB with empty dbPath", "", "/server/database/webhooks.db"}, + {"happy configuration BoltDB with dbPath", "database/webhooks.db", "", "database/webhooks.db"}, + {"happy configuration BoltDB with env", "$PATH_TO_DB", "database/envPath/webhooks.db", "database/envPath/webhooks.db"}, + {"happy configuration BoltDB with empty dbPath", "", "", "/server/database/webhooks.db"}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - oldPathEnv := os.Getenv("PATH_TO_BOLTDB") - defer func() { - os.Setenv("PATH_TO_BOLTDB", oldPathEnv) - }() - os.Setenv("PATH_TO_BOLTDB", test.dbPath) + if test.dbPathInEnv != "" { + oldPathEnv := os.Getenv("PATH_TO_DB") + defer func() { + os.Setenv("PATH_TO_DB", oldPathEnv) + }() + os.Setenv("PATH_TO_DB", test.dbPathInEnv) + } testInterval := 2 - if err := ConfigureDb("id", &testInterval, 1); err != nil { + if err := ConfigureDb(test.dbPath, "", "", &testInterval, 1); err != nil { t.Errorf("Unexpected error: %v", err) } if testInterval != 2 { @@ -46,12 +51,14 @@ func TestConfiguratePostgresDbUrlAndId(t *testing.T) { tests := []struct { name string url string + urlInEnv string id string expectedError error }{ - {"happy configuration", "postgresql://user:secret@localhost", "test-id", nil}, - {"bad id", "postgresql://user:secret@localhost", "", errors.New("error configurate postgresDb: 'tenantId' is empty")}, - {"bad url", "badUrl", "test-id", errors.New("badUrl error")}, + {"happy configuration postgres with url", "postgresql://user:secret@localhost", "", "test-id", nil}, + {"happy configuration postgres with env", "$POSTGRES_URL", "postgresql://user:secret@localhost", "test-id", nil}, + {"bad id", "postgresql://user:secret@localhost", "", "", errors.New("error configurate postgresDb: 'tenantId' is empty")}, + {"bad url", "badUrl", "", "test-id", errors.New("badUrl error")}, } for _, test := range tests { @@ -64,14 +71,14 @@ func TestConfiguratePostgresDbUrlAndId(t *testing.T) { return nil } oldUrlEnv := os.Getenv("POSTGRES_URL") + os.Setenv("POSTGRES_URL", test.urlInEnv) defer func() { postgresdb.InitPostgresDb = initPostgresDbSaved os.Setenv("POSTGRES_URL", oldUrlEnv) }() - os.Setenv("POSTGRES_URL", test.url) testInterval := 0 - err := ConfigureDb(test.id, &testInterval, 1) + err := ConfigureDb("", test.url, test.id, &testInterval, 1) if err != nil { if err.Error() != test.expectedError.Error() { t.Errorf("Unexpected error, expected: %s, got: %s", test.expectedError, err) @@ -80,8 +87,14 @@ func TestConfiguratePostgresDbUrlAndId(t *testing.T) { if testInterval != 1 { t.Error("test interval error, expected: 1, got: ", testInterval) } - if test.url != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface() { - t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) + if strings.HasPrefix(test.url, "$") { + if test.urlInEnv != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface() { + t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) + } + } else { + if test.url != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface() { + t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) + } } if test.id != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("Id").Interface() { t.Errorf("id's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("Id").Interface()) diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go index 932dea97..6029791c 100644 --- a/dbservice/postgresdb/actions.go +++ b/dbservice/postgresdb/actions.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "time" + + "github.com/aquasecurity/postee/dbservice/dbparam" ) func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey string, expired *time.Time) (wasStored bool, err error) { @@ -14,7 +16,7 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin defer db.Close() currentValue := "" - sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "messageValue", dbTableName, "id", "messageKey") + sqlQuery := fmt.Sprintf("SELECT messageValue FROM %s WHERE (id=$1 AND messageKey=$2)", dbparam.DbBucketName) if err = db.Get(¤tValue, sqlQuery, postgresDb.Id, messageKey); err != nil { if err != sql.ErrNoRows { return false, err @@ -25,11 +27,11 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin return false, nil } else { if expired != nil { - if err = insertInTableName(db, postgresDb.Id, expired.Format(DateFmt), messageKey, string(message)); err != nil { + if err = insertInTableName(db, postgresDb.Id, messageKey, message, expired); err != nil { return false, err } } else { - if err = insertInTableName(db, postgresDb.Id, "", messageKey, string(message)); err != nil { + if err = insertInTableName(db, postgresDb.Id, messageKey, message, nil); err != nil { return false, err } } diff --git a/dbservice/postgresdb/actions_test.go b/dbservice/postgresdb/actions_test.go index d44fada9..265dfe3e 100644 --- a/dbservice/postgresdb/actions_test.go +++ b/dbservice/postgresdb/actions_test.go @@ -10,10 +10,10 @@ import ( ) func TestStoreMessage(t *testing.T) { - currentValueStoreMessage := "" + currentValueStoreMessage := []byte{} savedinsertInTableName := insertInTableName - insertInTableName = func(db *sqlx.DB, id, date, messageKey, messageValue string) error { + insertInTableName = func(db *sqlx.DB, id, messageKey string, messageValue []byte, date *time.Time) error { currentValueStoreMessage = messageValue return nil } @@ -59,7 +59,7 @@ func TestStoreMessage(t *testing.T) { if isNew { t.Errorf("A old scan wasn't found!\n") } - currentValueStoreMessage = "" + currentValueStoreMessage = []byte{} } } diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go index 86e8ab76..2a961a03 100644 --- a/dbservice/postgresdb/checker.go +++ b/dbservice/postgresdb/checker.go @@ -4,10 +4,12 @@ import ( "fmt" "log" "time" + + "github.com/aquasecurity/postee/dbservice/dbparam" ) func (postgresDb *PostgresDb) CheckSizeLimit() { - if DbSizeLimit == 0 { + if dbparam.DbSizeLimit == 0 { return } @@ -20,13 +22,13 @@ func (postgresDb *PostgresDb) CheckSizeLimit() { defer db.Close() size := 0 - if err = db.Get(&size, fmt.Sprintf("SELECT pg_total_relation_size('%s');", dbTableName)); err != nil { + if err = db.Get(&size, fmt.Sprintf("SELECT pg_total_relation_size('%s');", dbparam.DbBucketName)); err != nil { log.Printf("CheckSizeLimit: Can't get db size") return } - if size > DbSizeLimit { - if err = deleteRowsById(db, dbTableName, postgresDb.Id); err != nil { - log.Printf("CheckSizeLimit: Can't delete id's: %s from table: %s", postgresDb.Id, dbTableName) + if size > dbparam.DbSizeLimit { + if err = deleteRowsById(db, dbparam.DbBucketName, postgresDb.Id); err != nil { + log.Printf("CheckSizeLimit: Can't delete id's: %s from table: %s", postgresDb.Id, dbparam.DbBucketName) return } } @@ -41,19 +43,8 @@ func (postgresDb *PostgresDb) CheckExpiredData() { } defer db.Close() - dates := []string{} - if err := db.Select(&dates, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 and %s != '')", "date", dbTableName, "id", "date"), postgresDb.Id); err != nil { - log.Printf("CheckExpiredData: Can't get dates from table: %s, err: %v", dbTableName, err) - return - } - - max := time.Now().UTC().Format(DateFmt) //remove expired records - for _, date := range dates { - if date <= max { - if err = deleteRow(db, dbTableName, postgresDb.Id, "date", date); err != nil { - log.Printf("CheckExpiredData: Can't delete %s from table:%s", date, dbTableName) - return - } - } + max := time.Now().UTC() //remove expired records + if err = deleteRowsByIdAndTime(db, dbparam.DbBucketName, postgresDb.Id, max); err != nil { + log.Printf("CheckExpiredData: Can't delete dates from table:%s, err: %v", dbparam.DbBucketName, err) } } diff --git a/dbservice/postgresdb/checker_test.go b/dbservice/postgresdb/checker_test.go index f27c4814..722fbccf 100644 --- a/dbservice/postgresdb/checker_test.go +++ b/dbservice/postgresdb/checker_test.go @@ -5,18 +5,19 @@ import ( "testing" "time" + "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/jmoiron/sqlx" sqlxmock "github.com/zhashkevych/go-sqlxmock" ) func TestExpiredDates(t *testing.T) { tests := []struct { - name string - time time.Time - wasDeleted bool + name string + deleteError bool + wasDeleted bool }{ - {"Time before Now", time.Now().UTC().Add(time.Duration(1) * time.Hour), false}, - {"Time after Now", time.Now().UTC().Add(time.Duration(-1) * time.Hour), true}, + {"happy delete rows", false, true}, + {"bad delete rows", true, false}, } for _, test := range tests { @@ -24,22 +25,22 @@ func TestExpiredDates(t *testing.T) { deleted := false savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() + db, _, err := sqlxmock.Newx() if err != nil { log.Println("failed to open sqlmock database:", err) } - rows := sqlxmock.NewRows([]string{"date"}).AddRow(test.time) - mock.ExpectQuery("SELECT").WillReturnRows(rows) return db, err } - savedDeleteRow := deleteRow - deleteRow = func(db *sqlx.DB, table, id, columnName, value string) error { - deleted = true + savedDeleteRow := deleteRowsByIdAndTime + deleteRowsByIdAndTime = func(db *sqlx.DB, table, id string, t time.Time) error { + if !test.deleteError { + deleted = true + } return nil } defer func() { psqlConnect = savedPsqlConnect - deleteRow = savedDeleteRow + deleteRowsByIdAndTime = savedDeleteRow }() db.CheckExpiredData() if deleted != test.wasDeleted { @@ -82,7 +83,7 @@ func TestSizeLimit(t *testing.T) { psqlConnect = savedPsqlConnect deleteRowsById = savedDeleteRowsById }() - db.SetDbSizeLimit(test.sizeLimit) + dbparam.DbSizeLimit = test.sizeLimit db.CheckSizeLimit() if deleted != test.wasDeleted { t.Errorf("error deleted rows") diff --git a/dbservice/postgresdb/dbaggregator.go b/dbservice/postgresdb/dbaggregator.go index 7de7f6ba..06243d1e 100644 --- a/dbservice/postgresdb/dbaggregator.go +++ b/dbservice/postgresdb/dbaggregator.go @@ -4,6 +4,8 @@ import ( "database/sql" "encoding/json" "fmt" + + "github.com/aquasecurity/postee/dbservice/dbparam" ) func (postgresDb *PostgresDb) AggregateScans(output string, @@ -21,14 +23,15 @@ func (postgresDb *PostgresDb) AggregateScans(output string, if len(currentScan) > 0 { aggregatedScans = append(aggregatedScans, currentScan) } - currentValue := "" - if err = db.Get(¤tValue, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "saving", dbTableAggregator, "id", "output"), postgresDb.Id, output); err != nil { + currentValue := []byte{} + sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (id=$1 AND %s=$2)", "saving", dbparam.DbBucketAggregator, "output") + if err = db.Get(¤tValue, sqlQuery, postgresDb.Id, output); err != nil { if err != sql.ErrNoRows { return nil, err } } - if currentValue != "" { + if len(currentValue) > 0 { var savedScans []map[string]string err = json.Unmarshal([]byte(currentValue), &savedScans) if err != nil { @@ -42,13 +45,13 @@ func (postgresDb *PostgresDb) AggregateScans(output string, if err != nil { return nil, err } - if err = insert(db, dbTableAggregator, postgresDb.Id, "output", output, "saving", string(saving)); err != nil { + if err = insertInTableAggregator(db, postgresDb.Id, output, saving); err != nil { return nil, err } return nil, nil } - if err = insert(db, dbTableAggregator, postgresDb.Id, "output", output, "saving", ""); err != nil { + if err = insertInTableAggregator(db, postgresDb.Id, output, nil); err != nil { return nil, err } return aggregatedScans, nil diff --git a/dbservice/postgresdb/dbaggregator_test.go b/dbservice/postgresdb/dbaggregator_test.go index d0838084..e68a6906 100644 --- a/dbservice/postgresdb/dbaggregator_test.go +++ b/dbservice/postgresdb/dbaggregator_test.go @@ -48,11 +48,11 @@ func TestAggregateScans(t *testing.T) { }, } - saving := "" + savingTest := []byte{} for i := 0; i < len(tests); i++ { - savedInsert := insert - insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { - saving = value3 + savedInsertInTableAggregator := insertInTableAggregator + insertInTableAggregator = func(db *sqlx.DB, id, output string, saving []byte) error { + savingTest = saving return nil } savedPsqlConnect := psqlConnect @@ -61,12 +61,12 @@ func TestAggregateScans(t *testing.T) { if err != nil { log.Println("failed to open sqlmock database:", err) } - rows := sqlxmock.NewRows([]string{"saving"}).AddRow(saving) + rows := sqlxmock.NewRows([]string{"saving"}).AddRow(savingTest) mock.ExpectQuery("SELECT").WillReturnRows(rows) return db, err } defer func() { - insert = savedInsert + insertInTableAggregator = savedInsertInTableAggregator psqlConnect = savedPsqlConnect }() diff --git a/dbservice/postgresdb/dbparam.go b/dbservice/postgresdb/dbparam.go deleted file mode 100644 index 0a2b1ed5..00000000 --- a/dbservice/postgresdb/dbparam.go +++ /dev/null @@ -1,32 +0,0 @@ -package postgresdb - -import ( - "time" -) - -var ( - dbTableName = "WebhookTable" - dbTableAggregator = "WebhookAggregator" - dbTableOutputStats = "WebhookOutputStats" - dbTableSharedConfig = "WebhookSharedConfig" - - DbSizeLimit = 0 - DateFmt = time.RFC3339Nano - dueTimeBase = time.Hour * time.Duration(24) -) - -type PostgresDb struct { - ConnectUrl string - Id string -} - -func NewPostgresDb(id, connectUrl string) *PostgresDb { - return &PostgresDb{ - ConnectUrl: connectUrl, - Id: id, - } -} - -func (postgresDb *PostgresDb) SetDbSizeLimit(limit int) { - DbSizeLimit = limit -} diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 83895376..b8a8c41b 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -4,6 +4,7 @@ import ( "errors" "log" "testing" + "time" "github.com/jmoiron/sqlx" sqlxmock "github.com/zhashkevych/go-sqlxmock" @@ -108,13 +109,13 @@ func TestInitError(t *testing.T) { } } -func TestDeleteRow(t *testing.T) { +func TestDeleteRowsByIdAndTime(t *testing.T) { t.Log("happy delete row") savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRow := deleteRow + savedDeleteRow := deleteRowsByIdAndTime defer func() { - deleteRow = savedDeleteRow + deleteRowsByIdAndTime = savedDeleteRow }() db, mock, err := sqlxmock.Newx() if err != nil { @@ -128,7 +129,7 @@ func TestDeleteRow(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := deleteRow(psqlDb, "table", "id", "column", "value") + err := deleteRowsByIdAndTime(psqlDb, "table", "id", time.Now()) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -136,9 +137,9 @@ func TestDeleteRow(t *testing.T) { t.Log("bad delete row") deleteError := errors.New("delete - error") psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRow := deleteRow + savedDeleteRow := deleteRowsByIdAndTime defer func() { - deleteRow = savedDeleteRow + deleteRowsByIdAndTime = savedDeleteRow }() db, mock, err := sqlxmock.Newx() if err != nil { @@ -148,7 +149,7 @@ func TestDeleteRow(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = deleteRow(psqlDb, "table", "id", "column", "value") + err = deleteRowsByIdAndTime(psqlDb, "table", "id", time.Now()) if deleteError != err { t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) } @@ -199,7 +200,7 @@ func TestDeleteRowsById(t *testing.T) { } } -func TestInsert(t *testing.T) { +func TestInsertInTableSharedConfig(t *testing.T) { t.Log("happy insert") savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { @@ -217,9 +218,9 @@ func TestInsert(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + err := insertInTableSharedConfig(psqlDb, "id", "value2", "value3") if err != nil { - t.Errorf("Unexpected error in 'insert': %v", err) + t.Errorf("Unexpected error in 'insertInTableSharedConfig': %v", err) } t.Log("happy update") @@ -234,11 +235,121 @@ func TestInsert(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + if err != nil { + t.Errorf("Unexpected error in 'insertInTableSharedConfig': %v", err) + } + + t.Log("select error") + selectError := errors.New("select error") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(selectError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + if err != selectError { + t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", selectError, err) + } + + t.Log("select 2 rows") + select2RowsError := "error insert in postgresDb. Table:WebhookSharedConfig where id=id, apikeyname=value2, have 2 rows" + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + if err.Error() != select2RowsError { + t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", select2RowsError, err) + } + + t.Log("bad insert") + badInsertError := errors.New("bad insert") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnError(badInsertError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + if err != badInsertError { + t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", badInsertError, err) + } + + t.Log("bad update") + badUpdateError := errors.New("bad update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + if err != badUpdateError { + t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", badUpdateError, err) + } +} + +func TestInsertInTableAggregator(t *testing.T) { + t.Log("happy insert") + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) if err != nil { t.Errorf("Unexpected error in 'insert': %v", err) } + t.Log("happy update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) + if err != nil { + t.Errorf("Unexpected error in 'insertInTableAggregator': %v", err) + } + t.Log("select error") selectError := errors.New("select error") psqlConnect = func(connectUrl string) (*sqlx.DB, error) { @@ -250,13 +361,13 @@ func TestInsert(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) if err != selectError { - t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", selectError, err) + t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", selectError, err) } t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:table where id=id, column2=value2, have 2 rows" + select2RowsError := "error insert in postgresDb. Table:WebhookAggregator where id=id, output=value2, have 2 rows" psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { @@ -267,9 +378,9 @@ func TestInsert(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) if err.Error() != select2RowsError { - t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", select2RowsError, err) + t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", select2RowsError, err) } t.Log("bad insert") @@ -285,9 +396,9 @@ func TestInsert(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) if err != badInsertError { - t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badInsertError, err) + t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", badInsertError, err) } t.Log("bad update") @@ -303,9 +414,9 @@ func TestInsert(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insert(psqlDb, "table", "id", "column2", "value2", "column3", "value3") + err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) if err != badUpdateError { - t.Errorf("Unexpected error in 'insert', expected: %v, got: %v", badUpdateError, err) + t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", badUpdateError, err) } } @@ -437,7 +548,7 @@ func TestInsertInTableName(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + err := insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) if err != nil { t.Errorf("Unexpected error in 'insertInTableName': %v", err) } @@ -454,7 +565,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) if err != nil { t.Errorf("Unexpected error in 'insertInTableName': %v", err) } @@ -470,13 +581,13 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) if err != selectError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", selectError, err) } t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookTable where id=id, messageKey=messageKey, have 2 rows" + select2RowsError := "error insert in postgresDb. Table:WebhookBucket where id=id, messageKey=messageKey, have 2 rows" psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { @@ -487,7 +598,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) if err.Error() != select2RowsError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", select2RowsError, err) } @@ -505,7 +616,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) if err != badInsertError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badInsertError, err) } @@ -523,7 +634,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "date", "messageKey", "messageValue") + err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) if err != badUpdateError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badUpdateError, err) } diff --git a/dbservice/postgresdb/delete.go b/dbservice/postgresdb/delete.go index 26a20ff3..f89ab053 100644 --- a/dbservice/postgresdb/delete.go +++ b/dbservice/postgresdb/delete.go @@ -2,19 +2,22 @@ package postgresdb import ( "fmt" + "time" "github.com/jmoiron/sqlx" ) -var deleteRow = func(db *sqlx.DB, table, id, columnName, value string) error { - if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE (%s=$1 AND %s=$2);", table, "id", columnName), id, value); err != nil { +var deleteRowsByIdAndTime = func(db *sqlx.DB, table, id string, t time.Time) error { + sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE (id=$1 and date < $2)", table) + if _, err := db.Exec(sqlQuery, id, t); err != nil { return err } return nil } var deleteRowsById = func(db *sqlx.DB, table, id string) error { - if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s=$1", table, "id"), id); err != nil { + sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE id=$1", table) + if _, err := db.Exec(sqlQuery, id); err != nil { return err } return nil diff --git a/dbservice/postgresdb/init.go b/dbservice/postgresdb/init.go index 076d50cb..d062e15e 100644 --- a/dbservice/postgresdb/init.go +++ b/dbservice/postgresdb/init.go @@ -1,15 +1,18 @@ package postgresdb import ( + "fmt" + + "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/jmoiron/sqlx" ) var ( tableSchemas = []string{ - "CREATE TABLE IF NOT EXISTS webhooktable (id varchar(32), date varchar(32), messagekey varchar(256),messagevalue text);", - "CREATE TABLE IF NOT EXISTS webhookaggregator (id varchar(32), output varchar(32), saving text);", - "CREATE TABLE IF NOT EXISTS webhookoutputstats (id varchar(32), outputname varchar(32), amount integer);", - "CREATE TABLE IF NOT EXISTS webhooksharedconfig (id varchar(32), apikeyname varchar(14),value varchar(64));", + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), date timestamp, messagekey varchar(256), messagevalue bytea);", dbparam.DbBucketName), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), output varchar(32), saving bytea);", dbparam.DbBucketAggregator), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), outputname varchar(32), amount integer);", dbparam.DbBucketOutputStats), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), apikeyname varchar(14),value varchar(64));", dbparam.DbBucketSharedConfig), } ) diff --git a/dbservice/postgresdb/insert.go b/dbservice/postgresdb/insert.go index f1113100..e04ee463 100644 --- a/dbservice/postgresdb/insert.go +++ b/dbservice/postgresdb/insert.go @@ -2,68 +2,99 @@ package postgresdb import ( "fmt" + "time" + "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/jmoiron/sqlx" ) -var insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { +var insertInTableSharedConfig = func(db *sqlx.DB, id, apikeyname, value string) error { var i int - if err := db.Get(&i, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (%s=$1 AND %s=$2)", table, "id", columnName2), id, value2); err != nil { + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND apikeyname=$2)", dbparam.DbBucketSharedConfig) + if err := db.Get(&i, sqlQuery, id, apikeyname); err != nil { return err } if i == 0 { - if _, err := db.Exec(fmt.Sprintf("INSERT INTO %s (%s, %s, %s) VALUES ($1, $2, $3)", table, "id", columnName2, columnName3), id, value2, value3); err != nil { + sqlQuery = fmt.Sprintf("INSERT INTO %s (id, apikeyname, value) VALUES ($1, $2, $3)", dbparam.DbBucketSharedConfig) + if _, err := db.Exec(sqlQuery, id, apikeyname, value); err != nil { return err } } else if i == 1 { - if _, err := db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (%s=$2 AND %s=$3);", table, columnName3, "id", columnName2), value3, id, value2); err != nil { + sqlQuery = fmt.Sprintf("UPDATE %s SET value=$1 WHERE (id=$2 AND apikeyname=$3);", dbparam.DbBucketSharedConfig) + if _, err := db.Exec(sqlQuery, value, id, apikeyname); err != nil { return err } } else { - return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, %s=%s, have %d rows", table, id, columnName2, value2, i) + return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, apikeyname=%s, have %d rows", dbparam.DbBucketSharedConfig, id, apikeyname, i) } return nil } -var insertInTableName = func(db *sqlx.DB, id, date, messageKey, messageValue string) error { +var insertInTableAggregator = func(db *sqlx.DB, id, output string, saving []byte) error { var i int - if err := db.Get(&i, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (%s=$1 AND %s=$2)", dbTableName, "id", "messageKey"), id, messageKey); err != nil { + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND output=$2)", dbparam.DbBucketAggregator) + if err := db.Get(&i, sqlQuery, id, output); err != nil { return err } if i == 0 { - if _, err := db.Exec(fmt.Sprintf("INSERT INTO %s (%s, %s, %s, %s) VALUES ($1, $2, $3, $4)", dbTableName, "id", "date", "messagekey", "messagevalue"), - id, date, messageKey, messageValue); err != nil { + sqlQuery = fmt.Sprintf("INSERT INTO %s (id, output, saving) VALUES ($1, $2, $3)", dbparam.DbBucketAggregator) + if _, err := db.Exec(sqlQuery, id, output, saving); err != nil { return err } } else if i == 1 { - if _, err := db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1, %s=$2 WHERE (%s=$3 AND %s=$4);", dbTableName, "date", "messagevalue", "id", "messagekey"), - date, messageValue, id, messageKey); err != nil { + sqlQuery = fmt.Sprintf("UPDATE %s SET saving=$1 WHERE (id=$2 AND output=$3);", dbparam.DbBucketAggregator) + if _, err := db.Exec(sqlQuery, saving, id, output); err != nil { return err } } else { - return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, messageKey=%s, have %d rows", dbTableName, id, messageKey, i) + return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, output=%s, have %d rows", dbparam.DbBucketAggregator, id, output, i) + } + return nil +} + +var insertInTableName = func(db *sqlx.DB, id, messageKey string, messageValue []byte, date *time.Time) error { + var i int + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND %s=$2)", dbparam.DbBucketName, "messageKey") + if err := db.Get(&i, sqlQuery, id, messageKey); err != nil { + return err + } + if i == 0 { + sqlQuery = fmt.Sprintf("INSERT INTO %s (id, %s, %s, %s) VALUES ($1, $2, $3, $4)", dbparam.DbBucketName, "date", "messagekey", "messagevalue") + if _, err := db.Exec(sqlQuery, id, date, messageKey, messageValue); err != nil { + return err + } + } else if i == 1 { + sqlQuery = fmt.Sprintf("UPDATE %s SET %s=$1, %s=$2 WHERE (id=$3 AND %s=$4);", dbparam.DbBucketName, "date", "messagevalue", "messagekey") + if _, err := db.Exec(sqlQuery, date, messageValue, id, messageKey); err != nil { + return err + } + } else { + return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, messageKey=%s, have %d rows", dbparam.DbBucketName, id, messageKey, i) } return nil } var insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { var i int - err := db.Get(&i, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (%s=$1 AND %s=$2)", dbTableOutputStats, "id", "outputName"), id, outputName) + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND %s=$2)", dbparam.DbBucketOutputStats, "outputName") + err := db.Get(&i, sqlQuery, id, outputName) if err != nil { return err } if i == 0 { - _, err := db.Exec(fmt.Sprintf("INSERT INTO %s (%s, %s, %s) VALUES ($1, $2, $3);", dbTableOutputStats, "id", "outputName", "amount"), id, outputName, amount) + sqlQuery = fmt.Sprintf("INSERT INTO %s (id, %s, %s) VALUES ($1, $2, $3);", dbparam.DbBucketOutputStats, "outputName", "amount") + _, err := db.Exec(sqlQuery, id, outputName, amount) if err != nil { return err } } else if i == 1 { - _, err = db.Exec(fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (%s=$2 AND %s=$3);", dbTableOutputStats, "amount", "id", "outputName"), amount, id, outputName) + sqlQuery = fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (id=$2 AND %s=$3);", dbparam.DbBucketOutputStats, "amount", "outputName") + _, err = db.Exec(sqlQuery, amount, id, outputName) if err != nil { return err } } else { - return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, outputName=%s, have %d rows", dbTableOutputStats, id, outputName, i) + return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, outputName=%s, have %d rows", dbparam.DbBucketOutputStats, id, outputName, i) } return nil diff --git a/dbservice/postgresdb/plgstats.go b/dbservice/postgresdb/plgstats.go index d27470e6..283b4a80 100644 --- a/dbservice/postgresdb/plgstats.go +++ b/dbservice/postgresdb/plgstats.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" + "github.com/aquasecurity/postee/dbservice/dbparam" _ "github.com/lib/pq" ) @@ -15,7 +16,8 @@ func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { defer db.Close() amount := 0 - err = db.Get(&amount, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "amount", dbTableOutputStats, "id", "outputName"), postgresDb.Id, name) + sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (id=$1 AND %s=$2)", "amount", dbparam.DbBucketOutputStats, "outputName") + err = db.Get(&amount, sqlQuery, postgresDb.Id, name) if err != nil && err != sql.ErrNoRows { return err } diff --git a/dbservice/postgresdb/postgresdb.go b/dbservice/postgresdb/postgresdb.go new file mode 100644 index 00000000..d77c5466 --- /dev/null +++ b/dbservice/postgresdb/postgresdb.go @@ -0,0 +1,13 @@ +package postgresdb + +type PostgresDb struct { + ConnectUrl string + Id string +} + +func NewPostgresDb(id, connectUrl string) *PostgresDb { + return &PostgresDb{ + ConnectUrl: connectUrl, + Id: id, + } +} diff --git a/dbservice/postgresdb/sharedcfg.go b/dbservice/postgresdb/sharedcfg.go index 5ef11dbd..38ea810d 100644 --- a/dbservice/postgresdb/sharedcfg.go +++ b/dbservice/postgresdb/sharedcfg.go @@ -1,11 +1,9 @@ package postgresdb import ( - "crypto/rand" - "encoding/hex" "fmt" - "io" + "github.com/aquasecurity/postee/dbservice/dbparam" _ "github.com/lib/pq" ) @@ -18,12 +16,12 @@ func (postgresDb *PostgresDb) EnsureApiKey() error { } defer db.Close() - apiKey, err := generateApiKey(32) + apiKey, err := dbparam.GenerateApiKey(32) if err != nil { return err } - if err = insert(db, dbTableSharedConfig, postgresDb.Id, "apikeyname", apiKeyName, "value", apiKey); err != nil { + if err = insertInTableSharedConfig(db, postgresDb.Id, apiKeyName, apiKey); err != nil { return err } @@ -37,18 +35,11 @@ func (postgresDb *PostgresDb) GetApiKey() (string, error) { } defer db.Close() value := "" - err = db.Get(&value, fmt.Sprintf("SELECT %s FROM %s WHERE (%s=$1 AND %s=$2)", "value", dbTableSharedConfig, "id", "apikeyname"), postgresDb.Id, apiKeyName) + sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (id=$1 AND %s=$2)", "value", dbparam.DbBucketSharedConfig, "apikeyname") + err = db.Get(&value, sqlQuery, postgresDb.Id, apiKeyName) if err != nil { return "", err } return value, nil } - -func generateApiKey(length int) (string, error) { - k := make([]byte, length) - if _, err := io.ReadFull(rand.Reader, k); err != nil { - return "", err - } - return hex.EncodeToString(k), nil -} diff --git a/dbservice/postgresdb/sharedcfg_test.go b/dbservice/postgresdb/sharedcfg_test.go index 343b54a0..4317a199 100644 --- a/dbservice/postgresdb/sharedcfg_test.go +++ b/dbservice/postgresdb/sharedcfg_test.go @@ -10,8 +10,8 @@ import ( ) func TestApiKey(t *testing.T) { - savedInsert := insert - insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { return nil } + savedInsertInTableSharedConfig := insertInTableSharedConfig + insertInTableSharedConfig = func(db *sqlx.DB, id, apikeyname, value string) error { return nil } savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() @@ -23,7 +23,7 @@ func TestApiKey(t *testing.T) { return db, err } defer func() { - insert = savedInsert + insertInTableSharedConfig = savedInsertInTableSharedConfig psqlConnect = savedPsqlConnect }() @@ -62,9 +62,9 @@ func TestApiKeyWithoutInit(t *testing.T) { func TestApiKeyRenewal(t *testing.T) { receivedKey := "" - savedInsert := insert - insert = func(db *sqlx.DB, table, id, columnName2, value2, columnName3, value3 string) error { - receivedKey = value3 + savedInsertInTableSharedConfig := insertInTableSharedConfig + insertInTableSharedConfig = func(db *sqlx.DB, id, apikeyname, value string) error { + receivedKey = value return nil } savedPsqlConnect := psqlConnect @@ -78,7 +78,7 @@ func TestApiKeyRenewal(t *testing.T) { return db, err } defer func() { - insert = savedInsert + insertInTableSharedConfig = savedInsertInTableSharedConfig psqlConnect = savedPsqlConnect }() diff --git a/router/router.go b/router/router.go index 37e50f82..8bc52ace 100644 --- a/router/router.go +++ b/router/router.go @@ -198,7 +198,7 @@ func (ctx *Router) load() error { ctx.aquaServer = fmt.Sprintf("%s%s#/images/", tenant.AquaServer, slash) } - if err = dbservice.ConfigureDb(tenant.Name, &tenant.DBTestInterval, tenant.DBMaxSize); err != nil { + if err = dbservice.ConfigureDb("$PATH_TO_DB", "$POSTGRES_URL", tenant.Name, &tenant.DBTestInterval, tenant.DBMaxSize); err != nil { return err } From 90c2311cc2fb7d0a27687eedce805c8335f10027 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Mon, 6 Dec 2021 17:56:08 +0600 Subject: [PATCH 25/61] =?UTF-8?q?Refactor=20:=20added=20=D0=B2=D1=83=D0=B4?= =?UTF-8?q?=D1=83=D0=B5=D1=83=D0=B2=20temp=20dir=20in=20configureDb=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dbservice/dbservice_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbservice/dbservice_test.go b/dbservice/dbservice_test.go index 4d673a2e..c89546a6 100644 --- a/dbservice/dbservice_test.go +++ b/dbservice/dbservice_test.go @@ -42,8 +42,8 @@ func TestConfigurateBoltDbPathUsedEnv(t *testing.T) { if test.expectedPath != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("DbPath").Interface() { t.Errorf("paths do not match, expected: %s, got: %s", test.expectedPath, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("DbPath").Interface()) } - }) + defer os.RemoveAll("database/") } } From ccfec3617a897af14f5e587b048de96306308d34 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Mon, 6 Dec 2021 18:15:41 +0600 Subject: [PATCH 26/61] Docs(psql): added psql env info to readme --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9b7bf72e..147e33bd 100644 --- a/README.md +++ b/README.md @@ -336,15 +336,18 @@ See [Postee UI](PosteeUI.md) for details how to setup the Postee UI. ## Misc ### Data Persistency -The Postee container uses BoltDB to store information about previously scanned images. +The Postee container uses BoltDB or PostgreSQL to store information about previously scanned images. This is used to prevent resending messages that were already sent before. -The size of the database can grow over time. Every image that is saved in the database uses 20K of storage. - +The size of the database can grow over time. Every image that is saved in the Bolt database uses 20K of storage. +The default Postee Database is BoltDb. + Postee supports ‘PATH_TO_DB’ environment variable to change the bolt database directory. To use, set the ‘PATH_TO_DB’ environment variable to point to the bolt database file, for example: PATH_TO_DB="./database/webhook.db". By default, the directory for the bolt database file is “/server/database/webhook.db”. If you would like to persist the database file between restarts of the Postee container, then you should use a persistent storage option to mount the "/server/database" directory of the container. The "deploy/kubernetes" directory in this project contains an example deployment that includes a basic Host Persistency. + +To use PostgreSQL set the 'POSTGRES_URL' environment variable to your [connection URI](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING). If you would like to connect 2 or more Postee to 1 PostgreSQL to use unique tenant name in postee config file. ### Using environment variables in Postee Configuration File Postee supports use of environment variables for *Output* fields: **User**, **Password** and **Token**. Add preffix `$` to the environment variable name in the configuration file, for example: From f9473c30932e2279c9c267e954e3184a737b44db Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Tue, 7 Dec 2021 11:23:40 +0600 Subject: [PATCH 27/61] refactor: change parametrs for deleteRowsByIdAndTime --- dbservice/postgresdb/checker.go | 2 +- dbservice/postgresdb/checker_test.go | 2 +- dbservice/postgresdb/dbservice_test.go | 4 ++-- dbservice/postgresdb/delete.go | 13 +++++++++++-- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go index 2a961a03..50d2a8f6 100644 --- a/dbservice/postgresdb/checker.go +++ b/dbservice/postgresdb/checker.go @@ -44,7 +44,7 @@ func (postgresDb *PostgresDb) CheckExpiredData() { defer db.Close() max := time.Now().UTC() //remove expired records - if err = deleteRowsByIdAndTime(db, dbparam.DbBucketName, postgresDb.Id, max); err != nil { + if err = deleteRowsByIdAndTime(db, postgresDb.Id, max); err != nil { log.Printf("CheckExpiredData: Can't delete dates from table:%s, err: %v", dbparam.DbBucketName, err) } } diff --git a/dbservice/postgresdb/checker_test.go b/dbservice/postgresdb/checker_test.go index 722fbccf..2e040f7d 100644 --- a/dbservice/postgresdb/checker_test.go +++ b/dbservice/postgresdb/checker_test.go @@ -32,7 +32,7 @@ func TestExpiredDates(t *testing.T) { return db, err } savedDeleteRow := deleteRowsByIdAndTime - deleteRowsByIdAndTime = func(db *sqlx.DB, table, id string, t time.Time) error { + deleteRowsByIdAndTime = func(db *sqlx.DB, id string, t time.Time) error { if !test.deleteError { deleted = true } diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index b8a8c41b..086c1b50 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -129,7 +129,7 @@ func TestDeleteRowsByIdAndTime(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := deleteRowsByIdAndTime(psqlDb, "table", "id", time.Now()) + err := deleteRowsByIdAndTime(psqlDb, "id", time.Now()) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -149,7 +149,7 @@ func TestDeleteRowsByIdAndTime(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = deleteRowsByIdAndTime(psqlDb, "table", "id", time.Now()) + err = deleteRowsByIdAndTime(psqlDb, "id", time.Now()) if deleteError != err { t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) } diff --git a/dbservice/postgresdb/delete.go b/dbservice/postgresdb/delete.go index f89ab053..794a718b 100644 --- a/dbservice/postgresdb/delete.go +++ b/dbservice/postgresdb/delete.go @@ -4,11 +4,12 @@ import ( "fmt" "time" + "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/jmoiron/sqlx" ) -var deleteRowsByIdAndTime = func(db *sqlx.DB, table, id string, t time.Time) error { - sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE (id=$1 and date < $2)", table) +var deleteRowsByIdAndTime = func(db *sqlx.DB, id string, t time.Time) error { + sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE (id=$1 AND date < $2)", dbparam.DbBucketName) if _, err := db.Exec(sqlQuery, id, t); err != nil { return err } @@ -22,3 +23,11 @@ var deleteRowsById = func(db *sqlx.DB, table, id string) error { } return nil } + +// var deleteRowsByIdAndOutput = func(db *sqlx.DB, id, output string) error { +// sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE (id=$1 AND output=$2)", dbparam.DbBucketAggregator) +// if _, err := db.Exec(sqlQuery, id, output); err != nil { +// return err +// } +// return nil +// } From ec36e3bc951d13fba5f13d77acd82d78588ba918 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Tue, 7 Dec 2021 12:30:33 +0600 Subject: [PATCH 28/61] Refactor: code review notes are corrected --- dbservice/dbservice.go | 8 +-- dbservice/dbservice_test.go | 18 ++--- dbservice/postgresdb/actions.go | 15 ++-- dbservice/postgresdb/actions_test.go | 3 +- dbservice/postgresdb/checker.go | 6 +- dbservice/postgresdb/checker_test.go | 12 ++-- dbservice/postgresdb/dbaggregator.go | 8 +-- dbservice/postgresdb/dbaggregator_test.go | 2 +- dbservice/postgresdb/dbservice_test.go | 88 +++++++++++------------ dbservice/postgresdb/delete.go | 20 ++---- dbservice/postgresdb/init.go | 8 +-- dbservice/postgresdb/insert.go | 65 +++++++++-------- dbservice/postgresdb/plgstats.go | 7 +- dbservice/postgresdb/plgstats_test.go | 4 +- dbservice/postgresdb/postgresdb.go | 6 +- dbservice/postgresdb/sharedcfg.go | 8 +-- dbservice/postgresdb/sharedcfg_test.go | 4 +- 17 files changed, 131 insertions(+), 151 deletions(-) diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index 6b0c3c7d..6ef67be2 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -24,7 +24,7 @@ type DbProvider interface { GetApiKey() (string, error) } -func ConfigureDb(pathToDb, postgresUrl, tenantId string, dBTestInterval *int, dbMaxSize int) error { +func ConfigureDb(pathToDb, postgresUrl, tenantName string, dBTestInterval *int, dbMaxSize int) error { if *dBTestInterval == 0 { *dBTestInterval = 1 } @@ -33,10 +33,10 @@ func ConfigureDb(pathToDb, postgresUrl, tenantId string, dBTestInterval *int, db pathToDb = utils.GetEnvironmentVarOrPlain(pathToDb) if postgresUrl != "" { - if tenantId == "" { - return errors.New("error configurate postgresDb: 'tenantId' is empty") + if tenantName == "" { + return errors.New("error configurate postgresDb: 'tenantName' is empty") } - postgresDb := postgresdb.NewPostgresDb(tenantId, postgresUrl) + postgresDb := postgresdb.NewPostgresDb(tenantName, postgresUrl) if err := postgresdb.InitPostgresDb(postgresDb.ConnectUrl); err != nil { return err } diff --git a/dbservice/dbservice_test.go b/dbservice/dbservice_test.go index c89546a6..5a2df17d 100644 --- a/dbservice/dbservice_test.go +++ b/dbservice/dbservice_test.go @@ -47,18 +47,18 @@ func TestConfigurateBoltDbPathUsedEnv(t *testing.T) { } } -func TestConfiguratePostgresDbUrlAndId(t *testing.T) { +func TestConfiguratePostgresDbUrlAndTenantName(t *testing.T) { tests := []struct { name string url string urlInEnv string - id string + tenantName string expectedError error }{ - {"happy configuration postgres with url", "postgresql://user:secret@localhost", "", "test-id", nil}, - {"happy configuration postgres with env", "$POSTGRES_URL", "postgresql://user:secret@localhost", "test-id", nil}, - {"bad id", "postgresql://user:secret@localhost", "", "", errors.New("error configurate postgresDb: 'tenantId' is empty")}, - {"bad url", "badUrl", "", "test-id", errors.New("badUrl error")}, + {"happy configuration postgres with url", "postgresql://user:secret@localhost", "", "test-tenantName", nil}, + {"happy configuration postgres with env", "$POSTGRES_URL", "postgresql://user:secret@localhost", "test-tenantName", nil}, + {"bad tenantName", "postgresql://user:secret@localhost", "", "", errors.New("error configurate postgresDb: 'tenantName' is empty")}, + {"bad url", "badUrl", "", "test-tenantName", errors.New("badUrl error")}, } for _, test := range tests { @@ -78,7 +78,7 @@ func TestConfiguratePostgresDbUrlAndId(t *testing.T) { }() testInterval := 0 - err := ConfigureDb("", test.url, test.id, &testInterval, 1) + err := ConfigureDb("", test.url, test.tenantName, &testInterval, 1) if err != nil { if err.Error() != test.expectedError.Error() { t.Errorf("Unexpected error, expected: %s, got: %s", test.expectedError, err) @@ -96,8 +96,8 @@ func TestConfiguratePostgresDbUrlAndId(t *testing.T) { t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) } } - if test.id != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("Id").Interface() { - t.Errorf("id's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("Id").Interface()) + if test.tenantName != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("TenantName").Interface() { + t.Errorf("tenantName's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("TenantName").Interface()) } } }) diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go index 6029791c..7ae23614 100644 --- a/dbservice/postgresdb/actions.go +++ b/dbservice/postgresdb/actions.go @@ -16,8 +16,8 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin defer db.Close() currentValue := "" - sqlQuery := fmt.Sprintf("SELECT messageValue FROM %s WHERE (id=$1 AND messageKey=$2)", dbparam.DbBucketName) - if err = db.Get(¤tValue, sqlQuery, postgresDb.Id, messageKey); err != nil { + sqlQuery := fmt.Sprintf("SELECT messageValue FROM %s WHERE (tenantName=$1 AND messageKey=$2)", dbparam.DbBucketName) + if err = db.Get(¤tValue, sqlQuery, postgresDb.TenantName, messageKey); err != nil { if err != sql.ErrNoRows { return false, err } @@ -26,16 +26,9 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin if currentValue != "" { return false, nil } else { - if expired != nil { - if err = insertInTableName(db, postgresDb.Id, messageKey, message, expired); err != nil { - return false, err - } - } else { - if err = insertInTableName(db, postgresDb.Id, messageKey, message, nil); err != nil { - return false, err - } + if err = insertInTableName(db, postgresDb.TenantName, messageKey, message, expired); err != nil { + return false, err } return true, nil } - } diff --git a/dbservice/postgresdb/actions_test.go b/dbservice/postgresdb/actions_test.go index 265dfe3e..485f370b 100644 --- a/dbservice/postgresdb/actions_test.go +++ b/dbservice/postgresdb/actions_test.go @@ -13,7 +13,7 @@ func TestStoreMessage(t *testing.T) { currentValueStoreMessage := []byte{} savedinsertInTableName := insertInTableName - insertInTableName = func(db *sqlx.DB, id, messageKey string, messageValue []byte, date *time.Time) error { + insertInTableName = func(db *sqlx.DB, tenantName, messageKey string, messageValue []byte, date *time.Time) error { currentValueStoreMessage = messageValue return nil } @@ -61,5 +61,4 @@ func TestStoreMessage(t *testing.T) { } currentValueStoreMessage = []byte{} } - } diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go index 50d2a8f6..049df045 100644 --- a/dbservice/postgresdb/checker.go +++ b/dbservice/postgresdb/checker.go @@ -27,8 +27,8 @@ func (postgresDb *PostgresDb) CheckSizeLimit() { return } if size > dbparam.DbSizeLimit { - if err = deleteRowsById(db, dbparam.DbBucketName, postgresDb.Id); err != nil { - log.Printf("CheckSizeLimit: Can't delete id's: %s from table: %s", postgresDb.Id, dbparam.DbBucketName) + if err = deleteRowsByTenantName(db, dbparam.DbBucketName, postgresDb.TenantName); err != nil { + log.Printf("CheckSizeLimit: Can't delete tenantName's: %s from table: %s", postgresDb.TenantName, dbparam.DbBucketName) return } } @@ -44,7 +44,7 @@ func (postgresDb *PostgresDb) CheckExpiredData() { defer db.Close() max := time.Now().UTC() //remove expired records - if err = deleteRowsByIdAndTime(db, postgresDb.Id, max); err != nil { + if err = deleteRowsByTenantNameAndTime(db, postgresDb.TenantName, max); err != nil { log.Printf("CheckExpiredData: Can't delete dates from table:%s, err: %v", dbparam.DbBucketName, err) } } diff --git a/dbservice/postgresdb/checker_test.go b/dbservice/postgresdb/checker_test.go index 2e040f7d..dc86c844 100644 --- a/dbservice/postgresdb/checker_test.go +++ b/dbservice/postgresdb/checker_test.go @@ -31,8 +31,8 @@ func TestExpiredDates(t *testing.T) { } return db, err } - savedDeleteRow := deleteRowsByIdAndTime - deleteRowsByIdAndTime = func(db *sqlx.DB, id string, t time.Time) error { + savedDeleteRow := deleteRowsByTenantNameAndTime + deleteRowsByTenantNameAndTime = func(db *sqlx.DB, tenantName string, t time.Time) error { if !test.deleteError { deleted = true } @@ -40,7 +40,7 @@ func TestExpiredDates(t *testing.T) { } defer func() { psqlConnect = savedPsqlConnect - deleteRowsByIdAndTime = savedDeleteRow + deleteRowsByTenantNameAndTime = savedDeleteRow }() db.CheckExpiredData() if deleted != test.wasDeleted { @@ -74,14 +74,14 @@ func TestSizeLimit(t *testing.T) { mock.ExpectQuery("SELECT").WillReturnRows(rows) return db, err } - savedDeleteRowsById := deleteRowsById - deleteRowsById = func(db *sqlx.DB, table, id string) error { + savedDeleteRowsByTenantName := deleteRowsByTenantName + deleteRowsByTenantName = func(db *sqlx.DB, table, tenantName string) error { deleted = true return nil } defer func() { psqlConnect = savedPsqlConnect - deleteRowsById = savedDeleteRowsById + deleteRowsByTenantName = savedDeleteRowsByTenantName }() dbparam.DbSizeLimit = test.sizeLimit db.CheckSizeLimit() diff --git a/dbservice/postgresdb/dbaggregator.go b/dbservice/postgresdb/dbaggregator.go index 06243d1e..5117f28e 100644 --- a/dbservice/postgresdb/dbaggregator.go +++ b/dbservice/postgresdb/dbaggregator.go @@ -24,8 +24,8 @@ func (postgresDb *PostgresDb) AggregateScans(output string, aggregatedScans = append(aggregatedScans, currentScan) } currentValue := []byte{} - sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (id=$1 AND %s=$2)", "saving", dbparam.DbBucketAggregator, "output") - if err = db.Get(¤tValue, sqlQuery, postgresDb.Id, output); err != nil { + sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (tenantName=$1 AND %s=$2)", "saving", dbparam.DbBucketAggregator, "output") + if err = db.Get(¤tValue, sqlQuery, postgresDb.TenantName, output); err != nil { if err != sql.ErrNoRows { return nil, err } @@ -45,13 +45,13 @@ func (postgresDb *PostgresDb) AggregateScans(output string, if err != nil { return nil, err } - if err = insertInTableAggregator(db, postgresDb.Id, output, saving); err != nil { + if err = insertInTableAggregator(db, postgresDb.TenantName, output, saving); err != nil { return nil, err } return nil, nil } - if err = insertInTableAggregator(db, postgresDb.Id, output, nil); err != nil { + if err = insertInTableAggregator(db, postgresDb.TenantName, output, nil); err != nil { return nil, err } return aggregatedScans, nil diff --git a/dbservice/postgresdb/dbaggregator_test.go b/dbservice/postgresdb/dbaggregator_test.go index e68a6906..cf1fca30 100644 --- a/dbservice/postgresdb/dbaggregator_test.go +++ b/dbservice/postgresdb/dbaggregator_test.go @@ -51,7 +51,7 @@ func TestAggregateScans(t *testing.T) { savingTest := []byte{} for i := 0; i < len(tests); i++ { savedInsertInTableAggregator := insertInTableAggregator - insertInTableAggregator = func(db *sqlx.DB, id, output string, saving []byte) error { + insertInTableAggregator = func(db *sqlx.DB, tenantName, output string, saving []byte) error { savingTest = saving return nil } diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 086c1b50..c1abaede 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -75,7 +75,7 @@ var ( ] }` - db = NewPostgresDb("id", "postgresql://user:secret@localhost/dbname?sslmode=disable") + db = NewPostgresDb("tenantName", "postgresql://user:secret@localhost/dbname?sslmode=disable") ) func TestInitError(t *testing.T) { @@ -109,13 +109,13 @@ func TestInitError(t *testing.T) { } } -func TestDeleteRowsByIdAndTime(t *testing.T) { +func TestDeleteRowsByTenantNameAndTime(t *testing.T) { t.Log("happy delete row") savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRow := deleteRowsByIdAndTime + savedDeleteRow := deleteRowsByTenantNameAndTime defer func() { - deleteRowsByIdAndTime = savedDeleteRow + deleteRowsByTenantNameAndTime = savedDeleteRow }() db, mock, err := sqlxmock.Newx() if err != nil { @@ -129,7 +129,7 @@ func TestDeleteRowsByIdAndTime(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := deleteRowsByIdAndTime(psqlDb, "id", time.Now()) + err := deleteRowsByTenantNameAndTime(psqlDb, "tenantName", time.Now()) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -137,9 +137,9 @@ func TestDeleteRowsByIdAndTime(t *testing.T) { t.Log("bad delete row") deleteError := errors.New("delete - error") psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRow := deleteRowsByIdAndTime + savedDeleteRow := deleteRowsByTenantNameAndTime defer func() { - deleteRowsByIdAndTime = savedDeleteRow + deleteRowsByTenantNameAndTime = savedDeleteRow }() db, mock, err := sqlxmock.Newx() if err != nil { @@ -149,18 +149,18 @@ func TestDeleteRowsByIdAndTime(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = deleteRowsByIdAndTime(psqlDb, "id", time.Now()) + err = deleteRowsByTenantNameAndTime(psqlDb, "tenantName", time.Now()) if deleteError != err { t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) } } -func TestDeleteRowsById(t *testing.T) { - t.Log("happy delete rows by id") +func TestDeleteRowsByTenantName(t *testing.T) { + t.Log("happy delete rows by tenantName") savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRowsById := deleteRowsById + savedDeleteRowsByTenantName := deleteRowsByTenantName defer func() { - deleteRowsById = savedDeleteRowsById + deleteRowsByTenantName = savedDeleteRowsByTenantName }() db, mock, err := sqlxmock.Newx() if err != nil { @@ -174,7 +174,7 @@ func TestDeleteRowsById(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := deleteRowsById(psqlDb, "table", "id") + err := deleteRowsByTenantName(psqlDb, "table", "tenantName") if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -182,9 +182,9 @@ func TestDeleteRowsById(t *testing.T) { t.Log("bad delete row") deleteError := errors.New("delete - error") psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRowsById := deleteRowsById + savedDeleteRowsByTenantName := deleteRowsByTenantName defer func() { - deleteRowsById = savedDeleteRowsById + deleteRowsByTenantName = savedDeleteRowsByTenantName }() db, mock, err := sqlxmock.Newx() if err != nil { @@ -194,7 +194,7 @@ func TestDeleteRowsById(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = deleteRowsById(psqlDb, "table", "id") + err = deleteRowsByTenantName(psqlDb, "table", "tenantName") if deleteError != err { t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) } @@ -218,7 +218,7 @@ func TestInsertInTableSharedConfig(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + err := insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") if err != nil { t.Errorf("Unexpected error in 'insertInTableSharedConfig': %v", err) } @@ -235,7 +235,7 @@ func TestInsertInTableSharedConfig(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") if err != nil { t.Errorf("Unexpected error in 'insertInTableSharedConfig': %v", err) } @@ -251,13 +251,13 @@ func TestInsertInTableSharedConfig(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") if err != selectError { t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", selectError, err) } t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookSharedConfig where id=id, apikeyname=value2, have 2 rows" + select2RowsError := "error insert in postgresDb. Table:WebhookSharedConfig where tenantName=tenantName, apikeyname=value2, have 2 rows" psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { @@ -268,7 +268,7 @@ func TestInsertInTableSharedConfig(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") if err.Error() != select2RowsError { t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", select2RowsError, err) } @@ -286,7 +286,7 @@ func TestInsertInTableSharedConfig(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") if err != badInsertError { t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", badInsertError, err) } @@ -304,7 +304,7 @@ func TestInsertInTableSharedConfig(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "id", "value2", "value3") + err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") if err != badUpdateError { t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", badUpdateError, err) } @@ -328,7 +328,7 @@ func TestInsertInTableAggregator(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) + err := insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) if err != nil { t.Errorf("Unexpected error in 'insert': %v", err) } @@ -345,7 +345,7 @@ func TestInsertInTableAggregator(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) + err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) if err != nil { t.Errorf("Unexpected error in 'insertInTableAggregator': %v", err) } @@ -361,13 +361,13 @@ func TestInsertInTableAggregator(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) + err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) if err != selectError { t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", selectError, err) } t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookAggregator where id=id, output=value2, have 2 rows" + select2RowsError := "error insert in postgresDb. Table:WebhookAggregator where tenantName=tenantName, output=value2, have 2 rows" psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { @@ -378,7 +378,7 @@ func TestInsertInTableAggregator(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) + err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) if err.Error() != select2RowsError { t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", select2RowsError, err) } @@ -396,7 +396,7 @@ func TestInsertInTableAggregator(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) + err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) if err != badInsertError { t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", badInsertError, err) } @@ -414,7 +414,7 @@ func TestInsertInTableAggregator(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "id", "value2", []byte("value3")) + err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) if err != badUpdateError { t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", badUpdateError, err) } @@ -438,7 +438,7 @@ func TestInsertOutputStats(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertOutputStats(psqlDb, "id", "outputName", 1) + err := insertOutputStats(psqlDb, "tenantName", "outputName", 1) if err != nil { t.Errorf("Unexpected error in 'insertOutputStats': %v", err) } @@ -455,7 +455,7 @@ func TestInsertOutputStats(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) + err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) if err != nil { t.Errorf("Unexpected error in 'insertOutputStats': %v", err) } @@ -471,13 +471,13 @@ func TestInsertOutputStats(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) + err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) if err != selectError { t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", selectError, err) } t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookOutputStats where id=id, outputName=outputName, have 2 rows" + select2RowsError := "error insert in postgresDb. Table:WebhookOutputStats where tenantName=tenantName, outputName=outputName, have 2 rows" psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { @@ -488,7 +488,7 @@ func TestInsertOutputStats(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) + err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) if err.Error() != select2RowsError { t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", select2RowsError, err) } @@ -506,7 +506,7 @@ func TestInsertOutputStats(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) + err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) if err != badInsertError { t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badInsertError, err) } @@ -524,7 +524,7 @@ func TestInsertOutputStats(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "id", "outputName", 1) + err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) if err != badUpdateError { t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badUpdateError, err) } @@ -548,7 +548,7 @@ func TestInsertInTableName(t *testing.T) { }() psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) + err := insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) if err != nil { t.Errorf("Unexpected error in 'insertInTableName': %v", err) } @@ -565,7 +565,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) + err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) if err != nil { t.Errorf("Unexpected error in 'insertInTableName': %v", err) } @@ -581,13 +581,13 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) + err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) if err != selectError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", selectError, err) } t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookBucket where id=id, messageKey=messageKey, have 2 rows" + select2RowsError := "error insert in postgresDb. Table:WebhookBucket where tenantName=tenantName, messageKey=messageKey, have 2 rows" psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { @@ -598,7 +598,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) + err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) if err.Error() != select2RowsError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", select2RowsError, err) } @@ -616,7 +616,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) + err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) if err != badInsertError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badInsertError, err) } @@ -634,7 +634,7 @@ func TestInsertInTableName(t *testing.T) { return db, err } psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "id", "messageKey", []byte("messageValue"), nil) + err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) if err != badUpdateError { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badUpdateError, err) } diff --git a/dbservice/postgresdb/delete.go b/dbservice/postgresdb/delete.go index 794a718b..cd45c309 100644 --- a/dbservice/postgresdb/delete.go +++ b/dbservice/postgresdb/delete.go @@ -8,26 +8,18 @@ import ( "github.com/jmoiron/sqlx" ) -var deleteRowsByIdAndTime = func(db *sqlx.DB, id string, t time.Time) error { - sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE (id=$1 AND date < $2)", dbparam.DbBucketName) - if _, err := db.Exec(sqlQuery, id, t); err != nil { +var deleteRowsByTenantNameAndTime = func(db *sqlx.DB, tenantName string, t time.Time) error { + sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE (tenantName=$1 AND date < $2)", dbparam.DbBucketName) + if _, err := db.Exec(sqlQuery, tenantName, t); err != nil { return err } return nil } -var deleteRowsById = func(db *sqlx.DB, table, id string) error { - sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE id=$1", table) - if _, err := db.Exec(sqlQuery, id); err != nil { +var deleteRowsByTenantName = func(db *sqlx.DB, table, tenantName string) error { + sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE tenantName=$1", table) + if _, err := db.Exec(sqlQuery, tenantName); err != nil { return err } return nil } - -// var deleteRowsByIdAndOutput = func(db *sqlx.DB, id, output string) error { -// sqlQuery := fmt.Sprintf("DELETE FROM %s WHERE (id=$1 AND output=$2)", dbparam.DbBucketAggregator) -// if _, err := db.Exec(sqlQuery, id, output); err != nil { -// return err -// } -// return nil -// } diff --git a/dbservice/postgresdb/init.go b/dbservice/postgresdb/init.go index d062e15e..cae93800 100644 --- a/dbservice/postgresdb/init.go +++ b/dbservice/postgresdb/init.go @@ -9,10 +9,10 @@ import ( var ( tableSchemas = []string{ - fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), date timestamp, messagekey varchar(256), messagevalue bytea);", dbparam.DbBucketName), - fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), output varchar(32), saving bytea);", dbparam.DbBucketAggregator), - fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), outputname varchar(32), amount integer);", dbparam.DbBucketOutputStats), - fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id varchar(32), apikeyname varchar(14),value varchar(64));", dbparam.DbBucketSharedConfig), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), date timestamp, messagekey varchar(256), messagevalue bytea);", dbparam.DbBucketName), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), output varchar(32), saving bytea);", dbparam.DbBucketAggregator), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), outputname varchar(32), amount integer);", dbparam.DbBucketOutputStats), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), apikeyname varchar(14),value varchar(64));", dbparam.DbBucketSharedConfig), } ) diff --git a/dbservice/postgresdb/insert.go b/dbservice/postgresdb/insert.go index e04ee463..57882a55 100644 --- a/dbservice/postgresdb/insert.go +++ b/dbservice/postgresdb/insert.go @@ -8,94 +8,93 @@ import ( "github.com/jmoiron/sqlx" ) -var insertInTableSharedConfig = func(db *sqlx.DB, id, apikeyname, value string) error { +var insertInTableSharedConfig = func(db *sqlx.DB, tenantName, apikeyname, value string) error { var i int - sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND apikeyname=$2)", dbparam.DbBucketSharedConfig) - if err := db.Get(&i, sqlQuery, id, apikeyname); err != nil { + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (tenantName=$1 AND apikeyname=$2)", dbparam.DbBucketSharedConfig) + if err := db.Get(&i, sqlQuery, tenantName, apikeyname); err != nil { return err } if i == 0 { - sqlQuery = fmt.Sprintf("INSERT INTO %s (id, apikeyname, value) VALUES ($1, $2, $3)", dbparam.DbBucketSharedConfig) - if _, err := db.Exec(sqlQuery, id, apikeyname, value); err != nil { + sqlQuery = fmt.Sprintf("INSERT INTO %s (tenantName, apikeyname, value) VALUES ($1, $2, $3)", dbparam.DbBucketSharedConfig) + if _, err := db.Exec(sqlQuery, tenantName, apikeyname, value); err != nil { return err } } else if i == 1 { - sqlQuery = fmt.Sprintf("UPDATE %s SET value=$1 WHERE (id=$2 AND apikeyname=$3);", dbparam.DbBucketSharedConfig) - if _, err := db.Exec(sqlQuery, value, id, apikeyname); err != nil { + sqlQuery = fmt.Sprintf("UPDATE %s SET value=$1 WHERE (tenantName=$2 AND apikeyname=$3);", dbparam.DbBucketSharedConfig) + if _, err := db.Exec(sqlQuery, value, tenantName, apikeyname); err != nil { return err } } else { - return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, apikeyname=%s, have %d rows", dbparam.DbBucketSharedConfig, id, apikeyname, i) + return fmt.Errorf("error insert in postgresDb. Table:%s where tenantName=%s, apikeyname=%s, have %d rows", dbparam.DbBucketSharedConfig, tenantName, apikeyname, i) } return nil } -var insertInTableAggregator = func(db *sqlx.DB, id, output string, saving []byte) error { +var insertInTableAggregator = func(db *sqlx.DB, tenantName, output string, saving []byte) error { var i int - sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND output=$2)", dbparam.DbBucketAggregator) - if err := db.Get(&i, sqlQuery, id, output); err != nil { + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (tenantName=$1 AND output=$2)", dbparam.DbBucketAggregator) + if err := db.Get(&i, sqlQuery, tenantName, output); err != nil { return err } if i == 0 { - sqlQuery = fmt.Sprintf("INSERT INTO %s (id, output, saving) VALUES ($1, $2, $3)", dbparam.DbBucketAggregator) - if _, err := db.Exec(sqlQuery, id, output, saving); err != nil { + sqlQuery = fmt.Sprintf("INSERT INTO %s (tenantName, output, saving) VALUES ($1, $2, $3)", dbparam.DbBucketAggregator) + if _, err := db.Exec(sqlQuery, tenantName, output, saving); err != nil { return err } } else if i == 1 { - sqlQuery = fmt.Sprintf("UPDATE %s SET saving=$1 WHERE (id=$2 AND output=$3);", dbparam.DbBucketAggregator) - if _, err := db.Exec(sqlQuery, saving, id, output); err != nil { + sqlQuery = fmt.Sprintf("UPDATE %s SET saving=$1 WHERE (tenantName=$2 AND output=$3);", dbparam.DbBucketAggregator) + if _, err := db.Exec(sqlQuery, saving, tenantName, output); err != nil { return err } } else { - return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, output=%s, have %d rows", dbparam.DbBucketAggregator, id, output, i) + return fmt.Errorf("error insert in postgresDb. Table:%s where tenantName=%s, output=%s, have %d rows", dbparam.DbBucketAggregator, tenantName, output, i) } return nil } -var insertInTableName = func(db *sqlx.DB, id, messageKey string, messageValue []byte, date *time.Time) error { +var insertInTableName = func(db *sqlx.DB, tenantName, messageKey string, messageValue []byte, date *time.Time) error { var i int - sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND %s=$2)", dbparam.DbBucketName, "messageKey") - if err := db.Get(&i, sqlQuery, id, messageKey); err != nil { + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (tenantName=$1 AND %s=$2)", dbparam.DbBucketName, "messageKey") + if err := db.Get(&i, sqlQuery, tenantName, messageKey); err != nil { return err } if i == 0 { - sqlQuery = fmt.Sprintf("INSERT INTO %s (id, %s, %s, %s) VALUES ($1, $2, $3, $4)", dbparam.DbBucketName, "date", "messagekey", "messagevalue") - if _, err := db.Exec(sqlQuery, id, date, messageKey, messageValue); err != nil { + sqlQuery = fmt.Sprintf("INSERT INTO %s (tenantName, %s, %s, %s) VALUES ($1, $2, $3, $4)", dbparam.DbBucketName, "date", "messagekey", "messagevalue") + if _, err := db.Exec(sqlQuery, tenantName, date, messageKey, messageValue); err != nil { return err } } else if i == 1 { - sqlQuery = fmt.Sprintf("UPDATE %s SET %s=$1, %s=$2 WHERE (id=$3 AND %s=$4);", dbparam.DbBucketName, "date", "messagevalue", "messagekey") - if _, err := db.Exec(sqlQuery, date, messageValue, id, messageKey); err != nil { + sqlQuery = fmt.Sprintf("UPDATE %s SET %s=$1, %s=$2 WHERE (tenantName=$3 AND %s=$4);", dbparam.DbBucketName, "date", "messagevalue", "messagekey") + if _, err := db.Exec(sqlQuery, date, messageValue, tenantName, messageKey); err != nil { return err } } else { - return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, messageKey=%s, have %d rows", dbparam.DbBucketName, id, messageKey, i) + return fmt.Errorf("error insert in postgresDb. Table:%s where tenantName=%s, messageKey=%s, have %d rows", dbparam.DbBucketName, tenantName, messageKey, i) } return nil } -var insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { +var insertOutputStats = func(db *sqlx.DB, tenantName, outputName string, amount int) error { var i int - sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (id=$1 AND %s=$2)", dbparam.DbBucketOutputStats, "outputName") - err := db.Get(&i, sqlQuery, id, outputName) + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE (tenantName=$1 AND %s=$2)", dbparam.DbBucketOutputStats, "outputName") + err := db.Get(&i, sqlQuery, tenantName, outputName) if err != nil { return err } if i == 0 { - sqlQuery = fmt.Sprintf("INSERT INTO %s (id, %s, %s) VALUES ($1, $2, $3);", dbparam.DbBucketOutputStats, "outputName", "amount") - _, err := db.Exec(sqlQuery, id, outputName, amount) + sqlQuery = fmt.Sprintf("INSERT INTO %s (tenantName, %s, %s) VALUES ($1, $2, $3);", dbparam.DbBucketOutputStats, "outputName", "amount") + _, err := db.Exec(sqlQuery, tenantName, outputName, amount) if err != nil { return err } } else if i == 1 { - sqlQuery = fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (id=$2 AND %s=$3);", dbparam.DbBucketOutputStats, "amount", "outputName") - _, err = db.Exec(sqlQuery, amount, id, outputName) + sqlQuery = fmt.Sprintf("UPDATE %s SET %s=$1 WHERE (tenantName=$2 AND %s=$3);", dbparam.DbBucketOutputStats, "amount", "outputName") + _, err = db.Exec(sqlQuery, amount, tenantName, outputName) if err != nil { return err } } else { - return fmt.Errorf("error insert in postgresDb. Table:%s where id=%s, outputName=%s, have %d rows", dbparam.DbBucketOutputStats, id, outputName, i) + return fmt.Errorf("error insert in postgresDb. Table:%s where tenantName=%s, outputName=%s, have %d rows", dbparam.DbBucketOutputStats, tenantName, outputName, i) } - return nil } diff --git a/dbservice/postgresdb/plgstats.go b/dbservice/postgresdb/plgstats.go index 283b4a80..fc29229d 100644 --- a/dbservice/postgresdb/plgstats.go +++ b/dbservice/postgresdb/plgstats.go @@ -16,16 +16,15 @@ func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { defer db.Close() amount := 0 - sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (id=$1 AND %s=$2)", "amount", dbparam.DbBucketOutputStats, "outputName") - err = db.Get(&amount, sqlQuery, postgresDb.Id, name) + sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (tenantName=$1 AND %s=$2)", "amount", dbparam.DbBucketOutputStats, "outputName") + err = db.Get(&amount, sqlQuery, postgresDb.TenantName, name) if err != nil && err != sql.ErrNoRows { return err } amount += 1 - err = insertOutputStats(db, postgresDb.Id, name, amount) + err = insertOutputStats(db, postgresDb.TenantName, name, amount) if err != nil { return err } - return nil } diff --git a/dbservice/postgresdb/plgstats_test.go b/dbservice/postgresdb/plgstats_test.go index d522af21..3c17962d 100644 --- a/dbservice/postgresdb/plgstats_test.go +++ b/dbservice/postgresdb/plgstats_test.go @@ -12,7 +12,7 @@ import ( func TestRegisterPlgnInvctn(t *testing.T) { receivedKey := 0 savedInsertOutputStats := insertOutputStats - insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { + insertOutputStats = func(db *sqlx.DB, tenantName, outputName string, amount int) error { receivedKey = amount return nil } @@ -53,7 +53,7 @@ func TestRegisterPlgnInvctnErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { savedInsertOutputStats := insertOutputStats - insertOutputStats = func(db *sqlx.DB, id, outputName string, amount int) error { return nil } + insertOutputStats = func(db *sqlx.DB, tenantName, outputName string, amount int) error { return nil } savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() diff --git a/dbservice/postgresdb/postgresdb.go b/dbservice/postgresdb/postgresdb.go index d77c5466..1a8490fd 100644 --- a/dbservice/postgresdb/postgresdb.go +++ b/dbservice/postgresdb/postgresdb.go @@ -2,12 +2,12 @@ package postgresdb type PostgresDb struct { ConnectUrl string - Id string + TenantName string } -func NewPostgresDb(id, connectUrl string) *PostgresDb { +func NewPostgresDb(tenantName, connectUrl string) *PostgresDb { return &PostgresDb{ ConnectUrl: connectUrl, - Id: id, + TenantName: tenantName, } } diff --git a/dbservice/postgresdb/sharedcfg.go b/dbservice/postgresdb/sharedcfg.go index 38ea810d..33e14741 100644 --- a/dbservice/postgresdb/sharedcfg.go +++ b/dbservice/postgresdb/sharedcfg.go @@ -21,10 +21,9 @@ func (postgresDb *PostgresDb) EnsureApiKey() error { return err } - if err = insertInTableSharedConfig(db, postgresDb.Id, apiKeyName, apiKey); err != nil { + if err = insertInTableSharedConfig(db, postgresDb.TenantName, apiKeyName, apiKey); err != nil { return err } - return nil } @@ -35,11 +34,10 @@ func (postgresDb *PostgresDb) GetApiKey() (string, error) { } defer db.Close() value := "" - sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (id=$1 AND %s=$2)", "value", dbparam.DbBucketSharedConfig, "apikeyname") - err = db.Get(&value, sqlQuery, postgresDb.Id, apiKeyName) + sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (tenantName=$1 AND %s=$2)", "value", dbparam.DbBucketSharedConfig, "apikeyname") + err = db.Get(&value, sqlQuery, postgresDb.TenantName, apiKeyName) if err != nil { return "", err } return value, nil - } diff --git a/dbservice/postgresdb/sharedcfg_test.go b/dbservice/postgresdb/sharedcfg_test.go index 4317a199..45246f08 100644 --- a/dbservice/postgresdb/sharedcfg_test.go +++ b/dbservice/postgresdb/sharedcfg_test.go @@ -11,7 +11,7 @@ import ( func TestApiKey(t *testing.T) { savedInsertInTableSharedConfig := insertInTableSharedConfig - insertInTableSharedConfig = func(db *sqlx.DB, id, apikeyname, value string) error { return nil } + insertInTableSharedConfig = func(db *sqlx.DB, tenantName, apikeyname, value string) error { return nil } savedPsqlConnect := psqlConnect psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() @@ -63,7 +63,7 @@ func TestApiKeyWithoutInit(t *testing.T) { func TestApiKeyRenewal(t *testing.T) { receivedKey := "" savedInsertInTableSharedConfig := insertInTableSharedConfig - insertInTableSharedConfig = func(db *sqlx.DB, id, apikeyname, value string) error { + insertInTableSharedConfig = func(db *sqlx.DB, tenantName, apikeyname, value string) error { receivedKey = value return nil } From 38d1929fd11679452ee1cf59fcdadba579db4f32 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 7 Dec 2021 21:13:14 +0600 Subject: [PATCH 29/61] clean up of db connection config --- dbservice/dbservice.go | 13 ++----------- router/api.go | 9 +++++---- router/router.go | 19 +++++++++++++++---- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index 6ef67be2..5d2562d9 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -5,9 +5,7 @@ import ( "time" "github.com/aquasecurity/postee/dbservice/boltdb" - "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/aquasecurity/postee/dbservice/postgresdb" - "github.com/aquasecurity/postee/utils" ) var ( @@ -24,17 +22,11 @@ type DbProvider interface { GetApiKey() (string, error) } -func ConfigureDb(pathToDb, postgresUrl, tenantName string, dBTestInterval *int, dbMaxSize int) error { - if *dBTestInterval == 0 { - *dBTestInterval = 1 - } - - postgresUrl = utils.GetEnvironmentVarOrPlain(postgresUrl) - pathToDb = utils.GetEnvironmentVarOrPlain(pathToDb) +func ConfigureDb(pathToDb, postgresUrl, tenantName string) error { if postgresUrl != "" { if tenantName == "" { - return errors.New("error configurate postgresDb: 'tenantName' is empty") + return errors.New("error configuring postgres: 'tenantName' is empty") } postgresDb := postgresdb.NewPostgresDb(tenantName, postgresUrl) if err := postgresdb.InitPostgresDb(postgresDb.ConnectUrl); err != nil { @@ -50,6 +42,5 @@ func ConfigureDb(pathToDb, postgresUrl, tenantName string, dBTestInterval *int, } Db = boltdb } - dbparam.DbSizeLimit = dbMaxSize return nil } diff --git a/router/api.go b/router/api.go index e59d236b..045008b0 100644 --- a/router/api.go +++ b/router/api.go @@ -5,6 +5,7 @@ import ( "os" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/routes" ) @@ -27,20 +28,20 @@ func WithDefaultConfig() error { } func WithFileConfig(cfgPath string) error { Instance().Terminate() - // dbservice.DbPath = defaultDbPath + dbservice.ConfigureDb(defaultDbPath, "", "") return Instance().ApplyFileCfg(cfgPath, true) } func WithNewConfig(tenantName string) { //tenant name Instance().Terminate() - // dbservice.DbPath = defaultDbPath + dbservice.ConfigureDb(defaultDbPath, "", "") Instance().initCfg(true) } //initialize instance with custom db location func WithNewConfigAndDbPath(tenantName, dbPath string) { //tenant name Instance().Terminate() - // dbservice.DbPath = dbPath + dbservice.ConfigureDb(defaultDbPath, "", "") Instance().initCfg(true) } @@ -50,7 +51,7 @@ func WithDefaultConfigAndDbPath(dbPath string) error { func WithFileConfigAndDbPath(cfgPath, dbPath string) error { Instance().Terminate() - // dbservice.DbPath = dbPath + dbservice.ConfigureDb(dbPath, "", "") return Instance().ApplyFileCfg(cfgPath, true) } diff --git a/router/router.go b/router/router.go index 4e80d181..6c82b2b6 100644 --- a/router/router.go +++ b/router/router.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "log" "net/http" + "os" "path" "strings" "sync" @@ -14,6 +15,7 @@ import ( "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/msgservice" "github.com/aquasecurity/postee/outputs" @@ -253,14 +255,23 @@ func (ctx *Router) load() error { } ctx.setAquaServerUrl(tenant.AquaServer) - //---------------------------------------------------- - // TODO there should be some other way of doing that - if err = dbservice.ConfigureDb("$PATH_TO_DB", "$POSTGRES_URL", tenant.Name, &tenant.DBTestInterval, tenant.DBMaxSize); err != nil { + postgresUrl := os.Getenv("$POSTGRES_URL") + pathToDb := os.Getenv("$PATH_TO_DB") + + if err = dbservice.ConfigureDb(pathToDb, postgresUrl, tenant.Name); err != nil { return err } - ctx.ticker = time.NewTicker(baseForTicker * time.Duration(tenant.DBTestInterval)) + dbparam.DbSizeLimit = tenant.DBMaxSize + + actualDbTestInterval := tenant.DBTestInterval + + if tenant.DBTestInterval == 0 { + actualDbTestInterval = 1 + } + + ctx.ticker = time.NewTicker(baseForTicker * time.Duration(actualDbTestInterval)) go func() { for { select { From ac3f8360fbadf4999c9f425cb646259ea0f47b02 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Wed, 8 Dec 2021 20:40:08 +0600 Subject: [PATCH 30/61] exposed API to use postgres --- dbservice/dbservice_test.go | 43 +++++++------------------------ router/api.go | 38 ++++++++++++++++++++++++++++ router/api_integration_test.go | 46 ++++++++++++++++++++++++++++++++++ router/api_test.go | 42 +++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 34 deletions(-) diff --git a/dbservice/dbservice_test.go b/dbservice/dbservice_test.go index 5a2df17d..638fe6a5 100644 --- a/dbservice/dbservice_test.go +++ b/dbservice/dbservice_test.go @@ -4,7 +4,6 @@ import ( "errors" "os" "reflect" - "strings" "testing" "github.com/aquasecurity/postee/dbservice/postgresdb" @@ -14,26 +13,17 @@ func TestConfigurateBoltDbPathUsedEnv(t *testing.T) { tests := []struct { name string dbPath string - dbPathInEnv string expectedPath string }{ - {"happy configuration BoltDB with dbPath", "database/webhooks.db", "", "database/webhooks.db"}, - {"happy configuration BoltDB with env", "$PATH_TO_DB", "database/envPath/webhooks.db", "database/envPath/webhooks.db"}, - {"happy configuration BoltDB with empty dbPath", "", "", "/server/database/webhooks.db"}, + {"happy configuration BoltDB with dbPath", "database/webhooks.db", "database/webhooks.db"}, + {"happy configuration BoltDB with empty dbPath", "", "/server/database/webhooks.db"}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if test.dbPathInEnv != "" { - oldPathEnv := os.Getenv("PATH_TO_DB") - defer func() { - os.Setenv("PATH_TO_DB", oldPathEnv) - }() - os.Setenv("PATH_TO_DB", test.dbPathInEnv) - } testInterval := 2 - if err := ConfigureDb(test.dbPath, "", "", &testInterval, 1); err != nil { + if err := ConfigureDb(test.dbPath, "", ""); err != nil { t.Errorf("Unexpected error: %v", err) } if testInterval != 2 { @@ -51,14 +41,12 @@ func TestConfiguratePostgresDbUrlAndTenantName(t *testing.T) { tests := []struct { name string url string - urlInEnv string tenantName string expectedError error }{ - {"happy configuration postgres with url", "postgresql://user:secret@localhost", "", "test-tenantName", nil}, - {"happy configuration postgres with env", "$POSTGRES_URL", "postgresql://user:secret@localhost", "test-tenantName", nil}, - {"bad tenantName", "postgresql://user:secret@localhost", "", "", errors.New("error configurate postgresDb: 'tenantName' is empty")}, - {"bad url", "badUrl", "", "test-tenantName", errors.New("badUrl error")}, + {"happy configuration postgres with url", "postgresql://user:secret@localhost", "test-tenantName", nil}, + {"bad tenantName", "postgresql://user:secret@localhost", "", errors.New("error configuring postgres: 'tenantName' is empty")}, + {"bad url", "badUrl", "test-tenantName", errors.New("badUrl error")}, } for _, test := range tests { @@ -70,31 +58,18 @@ func TestConfiguratePostgresDbUrlAndTenantName(t *testing.T) { } return nil } - oldUrlEnv := os.Getenv("POSTGRES_URL") - os.Setenv("POSTGRES_URL", test.urlInEnv) defer func() { postgresdb.InitPostgresDb = initPostgresDbSaved - os.Setenv("POSTGRES_URL", oldUrlEnv) }() - testInterval := 0 - err := ConfigureDb("", test.url, test.tenantName, &testInterval, 1) + err := ConfigureDb("", test.url, test.tenantName) if err != nil { if err.Error() != test.expectedError.Error() { t.Errorf("Unexpected error, expected: %s, got: %s", test.expectedError, err) } } else { - if testInterval != 1 { - t.Error("test interval error, expected: 1, got: ", testInterval) - } - if strings.HasPrefix(test.url, "$") { - if test.urlInEnv != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface() { - t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) - } - } else { - if test.url != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface() { - t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) - } + if test.url != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface() { + t.Errorf("url's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("ConnectUrl").Interface()) } if test.tenantName != reflect.Indirect(reflect.ValueOf(Db)).FieldByName("TenantName").Interface() { t.Errorf("tenantName's do not match, expected: %s, got: %s", test.url, reflect.Indirect(reflect.ValueOf(Db)).FieldByName("TenantName").Interface()) diff --git a/router/api.go b/router/api.go index 045008b0..5d36750a 100644 --- a/router/api.go +++ b/router/api.go @@ -2,6 +2,8 @@ package router import ( "bytes" + "fmt" + "net/url" "os" "github.com/aquasecurity/postee/data" @@ -55,6 +57,19 @@ func WithFileConfigAndDbPath(cfgPath, dbPath string) error { return Instance().ApplyFileCfg(cfgPath, true) } +func WithPostgresParams(tenantName, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) { + postgresUrl := buildPostgresUrl(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode) + Instance().Terminate() + dbservice.ConfigureDb("", postgresUrl, tenantName) + Instance().initCfg(true) +} + +func WithPostgresUrl(tenantName, postgresUrl string) { + Instance().Terminate() + dbservice.ConfigureDb("", postgresUrl, tenantName) + Instance().initCfg(true) +} + func AquaServerUrl(aquaServerUrl string) { //optional Instance().setAquaServerUrl(aquaServerUrl) } @@ -161,3 +176,26 @@ func Send(b []byte) { //Instance().Send(b) Instance().handle(bytes.ReplaceAll(b, []byte{'`'}, []byte{'\''})) } + +func buildPostgresUrl(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) string { + hostname := dbHostName + + if dbPort != "" { + hostname += fmt.Sprintf(":%s", dbPort) + } + + rawQuery := "" + + if dbSslMode != "" { + rawQuery = fmt.Sprintf("sslmode=%s", dbSslMode) + } + + url := url.URL{ + Scheme: "postgres", + Host: hostname, + Path: dbName, + User: url.UserPassword(dbUser, dbPassword), + RawQuery: rawQuery, + } + return url.String() +} diff --git a/router/api_integration_test.go b/router/api_integration_test.go index 1b35ba91..77bae892 100644 --- a/router/api_integration_test.go +++ b/router/api_integration_test.go @@ -77,3 +77,49 @@ func TestAudit(t *testing.T) { got := <-received assert.Equal(t, string(got), want, "unexpected response") } + +/* +TODO figure out how to run integration test with Postgres DB +func TestAuditWithPostgres(t *testing.T) { + received := make(chan ([]byte)) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed ioutil.ReadAll: %s\n", err) + received <- []byte{} + return + } + + received <- body + + defer r.Body.Close() + })) + defer ts.Close() + + router.WithPostgresParams("my-postee", "posteedb", "localhost", "", "postee", "postee123", "") + + err := router.AddTemplate(&data.Template{ + Name: "audit-json-template", + Body: rego, + }) + if err != nil { + t.Logf("Error: %v", err) + return + } + router.AddOutput(&data.OutputSettings{ + Name: "test-webhook", + Type: "webhook", + Enable: true, + Url: ts.URL, + }) + + router.AddRoute(&routes.InputRoute{ + Name: "test", + Outputs: []string{"test-webhook"}, + Template: "audit-json-template", + }) + router.Send([]byte(msg)) + got := <-received + assert.Equal(t, string(got), want, "unexpected response") +}*/ diff --git a/router/api_test.go b/router/api_test.go index 99614129..a24441bb 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -72,3 +72,45 @@ func TestListOutput(t *testing.T) { //TODO templates //TODO routes + +func TestBuildPosgresUrl(t *testing.T) { + tests := []struct { + caseDesc string + username string + password string + port string + dbName string + dbHostName string + dbSslMode string + expectedUrl string + }{ + { + "all parameters specified", + "admin", + "admin", + "5433", + "postee", + "localhost", + "prefer", + "postgres://admin:admin@localhost:5433/postee?sslmode=prefer", + }, + { + "minimal parameters", + "admin", + "admin", + "", + "postee", + "localhost", + "", + "postgres://admin:admin@localhost/postee", + }, + } + + for _, test := range tests { + t.Run(test.caseDesc, func(t *testing.T) { + url := buildPostgresUrl(test.dbName, test.dbHostName, test.port, test.username, test.password, test.dbSslMode) + assert.Equal(t, test.expectedUrl, url) + }) + } + +} From de297df4c945ea195837843904777c3ae1f6e9de Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Mon, 13 Dec 2021 13:47:42 +0600 Subject: [PATCH 31/61] feat:added copy of cfgFile in psql --- dbservice/dbparam/dbparam.go | 11 +- dbservice/postgresdb/cfgcachesource.go | 35 +++++ dbservice/postgresdb/cfgcachesource_test.go | 39 +++++ dbservice/postgresdb/dbservice_test.go | 110 +++++++++++++ dbservice/postgresdb/init.go | 1 + dbservice/postgresdb/insert.go | 25 +++ router/api.go | 6 +- router/api_test.go | 44 ++++++ router/router.go | 164 ++++++++++++++++---- routes/routes.go | 14 +- 10 files changed, 408 insertions(+), 41 deletions(-) create mode 100644 dbservice/postgresdb/cfgcachesource.go create mode 100644 dbservice/postgresdb/cfgcachesource_test.go diff --git a/dbservice/dbparam/dbparam.go b/dbservice/dbparam/dbparam.go index 2ff87429..0cd02044 100644 --- a/dbservice/dbparam/dbparam.go +++ b/dbservice/dbparam/dbparam.go @@ -8,11 +8,12 @@ import ( ) var ( - DbBucketName = "WebhookBucket" - DbBucketAggregator = "WebhookAggregator" - DbBucketExpiryDates = "WebhookExpiryDates" - DbBucketOutputStats = "WebhookOutputStats" - DbBucketSharedConfig = "WebhookSharedConfig" + DbBucketName = "WebhookBucket" + DbBucketAggregator = "WebhookAggregator" + DbBucketExpiryDates = "WebhookExpiryDates" + DbBucketOutputStats = "WebhookOutputStats" + DbBucketSharedConfig = "WebhookSharedConfig" + DbTableCfgCacheSource = "WebhookCfgCacheSource" DbSizeLimit = 0 DateFmt = time.RFC3339Nano diff --git a/dbservice/postgresdb/cfgcachesource.go b/dbservice/postgresdb/cfgcachesource.go new file mode 100644 index 00000000..8ab32bcf --- /dev/null +++ b/dbservice/postgresdb/cfgcachesource.go @@ -0,0 +1,35 @@ +package postgresdb + +import ( + "fmt" + + "github.com/aquasecurity/postee/dbservice/dbparam" +) + +var UpdateCfgCacheSource = func(postgresDb *PostgresDb, cfgfile string) error { + connectUrl := postgresDb.ConnectUrl + db, err := psqlConnect(connectUrl) + if err != nil { + return err + } + defer db.Close() + if err := insertCfgCacheSource(db, postgresDb.TenantName, cfgfile); err != nil { + return err + } + return nil +} + +var GetCfgCacheSource = func(postgresDb *PostgresDb) (string, error) { + connectUrl := postgresDb.ConnectUrl + db, err := psqlConnect(connectUrl) + if err != nil { + return "", err + } + defer db.Close() + cfgFile := "" + sqlQuery := fmt.Sprintf("SELECT configfile FROM %s WHERE tenantName=$1", dbparam.DbTableCfgCacheSource) + if err = db.Get(&cfgFile, sqlQuery, postgresDb.TenantName); err != nil { + return "", err + } + return cfgFile, nil +} diff --git a/dbservice/postgresdb/cfgcachesource_test.go b/dbservice/postgresdb/cfgcachesource_test.go new file mode 100644 index 00000000..91298f0b --- /dev/null +++ b/dbservice/postgresdb/cfgcachesource_test.go @@ -0,0 +1,39 @@ +package postgresdb + +import ( + "log" + "testing" + + "github.com/jmoiron/sqlx" + sqlxmock "github.com/zhashkevych/go-sqlxmock" +) + +func TestUpdateCfgCacheSource(t *testing.T) { + cfgFile := `{"name": "tenant", "aqua-server": "https://myserver.aquasec.com"}` + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"cfgFile"}).AddRow(cfgFile) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + savedInsertCfgCacheSource := insertCfgCacheSource + insertCfgCacheSource = func(db *sqlx.DB, tenantName, cfgFile string) error { return nil } + defer func() { + psqlConnect = savedPsqlConnect + insertCfgCacheSource = savedInsertCfgCacheSource + }() + + UpdateCfgCacheSource(db, "cfgFile") + + cfg, err := GetCfgCacheSource(db) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if cfgFile != cfg { + t.Errorf("CfgFiles not equals, expected: %s, got: %s", cfgFile, cfg) + } +} diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index c1abaede..73559a94 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -639,3 +639,113 @@ func TestInsertInTableName(t *testing.T) { t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badUpdateError, err) } } + +func TestInsertCfgCacheSource(t *testing.T) { + t.Log("happy insert") + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") + if err != nil { + t.Errorf("Unexpected error in 'insertCfgCacheSource': %v", err) + } + + t.Log("happy update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") + if err != nil { + t.Errorf("Unexpected error in 'insertCfgCacheSource': %v", err) + } + + t.Log("select error") + selectError := errors.New("select error") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(selectError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") + if err != selectError { + t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", selectError, err) + } + + t.Log("select 2 rows") + select2RowsError := "error insert in postgresDb. Table:WebhookCfgCacheSource where tenantName=tenantName, have 2 rows" + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") + if err.Error() != select2RowsError { + t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", select2RowsError, err) + } + + t.Log("bad insert") + badInsertError := errors.New("bad insert") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("INSERT").WillReturnError(badInsertError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") + if err != badInsertError { + t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", badInsertError, err) + } + + t.Log("bad update") + badUpdateError := errors.New("bad update") + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) + return db, err + } + psqlDb, _ = psqlConnect(db.ConnectUrl) + err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") + if err != badUpdateError { + t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", badUpdateError, err) + } +} diff --git a/dbservice/postgresdb/init.go b/dbservice/postgresdb/init.go index cae93800..cc374050 100644 --- a/dbservice/postgresdb/init.go +++ b/dbservice/postgresdb/init.go @@ -13,6 +13,7 @@ var ( fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), output varchar(32), saving bytea);", dbparam.DbBucketAggregator), fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), outputname varchar(32), amount integer);", dbparam.DbBucketOutputStats), fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), apikeyname varchar(14),value varchar(64));", dbparam.DbBucketSharedConfig), + fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (tenantName varchar(32), configfile text);", dbparam.DbTableCfgCacheSource), } ) diff --git a/dbservice/postgresdb/insert.go b/dbservice/postgresdb/insert.go index 57882a55..ace1d8be 100644 --- a/dbservice/postgresdb/insert.go +++ b/dbservice/postgresdb/insert.go @@ -98,3 +98,28 @@ var insertOutputStats = func(db *sqlx.DB, tenantName, outputName string, amount } return nil } + +var insertCfgCacheSource = func(db *sqlx.DB, tenantName, cfgFile string) error { + var i int + sqlQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE tenantName=$1", dbparam.DbTableCfgCacheSource) + err := db.Get(&i, sqlQuery, tenantName) + if err != nil { + return err + } + if i == 0 { + sqlQuery = fmt.Sprintf("INSERT INTO %s (tenantName, configfile) VALUES ($1, $2);", dbparam.DbTableCfgCacheSource) + _, err := db.Exec(sqlQuery, tenantName, cfgFile) + if err != nil { + return err + } + } else if i == 1 { + sqlQuery = fmt.Sprintf("UPDATE %s SET configfile=$1 WHERE tenantName=$2;", dbparam.DbTableCfgCacheSource) + _, err = db.Exec(sqlQuery, cfgFile, tenantName) + if err != nil { + return err + } + } else { + return fmt.Errorf("error insert in postgresDb. Table:%s where tenantName=%s, have %d rows", dbparam.DbTableCfgCacheSource, tenantName, i) + } + return nil +} diff --git a/router/api.go b/router/api.go index 5d36750a..2cb80cd3 100644 --- a/router/api.go +++ b/router/api.go @@ -62,12 +62,14 @@ func WithPostgresParams(tenantName, dbName, dbHostName, dbPort, dbUser, dbPasswo Instance().Terminate() dbservice.ConfigureDb("", postgresUrl, tenantName) Instance().initCfg(true) + Instance().load(true) } func WithPostgresUrl(tenantName, postgresUrl string) { Instance().Terminate() dbservice.ConfigureDb("", postgresUrl, tenantName) Instance().initCfg(true) + Instance().load(true) } func AquaServerUrl(aquaServerUrl string) { //optional @@ -125,7 +127,7 @@ func UpdateRoute(route *routes.InputRoute) error { //-------------------Templates------------------- func AddTemplate(template *data.Template) error { - return Instance().initTemplate(template) + return Instance().addTemplate(template) } //helper method @@ -150,7 +152,7 @@ func UpdateTemplate(template *data.Template) error { return err } - return Instance().initTemplate(template) + return Instance().addTemplate(template) } func DeleteTemplate(name string) error { diff --git a/router/api_test.go b/router/api_test.go index a24441bb..21c69587 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/postgresdb" "github.com/aquasecurity/postee/outputs" "github.com/stretchr/testify/assert" ) @@ -114,3 +116,45 @@ func TestBuildPosgresUrl(t *testing.T) { } } + +func TestSaveLoadCfgInPostgres(t *testing.T) { + savedCfgInPsql := "" + expectedCfgJson := `{"name":"tenantName","aqua-server":"https://myserver.aquasec.com","outputs":null,"routes":null,"templates":null}` + router := Router{ + databaseCfgCacheSource: &data.TenantSettings{ + Name: "tenantName", + AquaServer: "https://myserver.aquasec.com", + }, + } + dbservice.Db = postgresdb.NewPostgresDb("tenantName", "connectUrl") + savedUpdateCfgCacheSource := postgresdb.UpdateCfgCacheSource + postgresdb.UpdateCfgCacheSource = func(postgresDb *postgresdb.PostgresDb, cfgfile string) error { + savedCfgInPsql = cfgfile + return nil + } + savedGetCfgCacheSource := postgresdb.GetCfgCacheSource + postgresdb.GetCfgCacheSource = func(postgresDb *postgresdb.PostgresDb) (string, error) { + return savedCfgInPsql, nil + } + defer func() { + postgresdb.UpdateCfgCacheSource = savedUpdateCfgCacheSource + postgresdb.GetCfgCacheSource = savedGetCfgCacheSource + }() + + if err := router.saveCfgCacheSourceInPostgres(); err != nil { + t.Errorf("Unexpected error: %v", err) + } + if expectedCfgJson != savedCfgInPsql { + t.Errorf("cfg marshal error, expected: %s, got: %s", expectedCfgJson, savedCfgInPsql) + } + tenant, err := router.loadCfgCacheSourceFromPostgres() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if router.databaseCfgCacheSource.Name != tenant.Name { + t.Errorf("names are not equals, expected: %s, got: %s", router.databaseCfgCacheSource.Name, tenant.Name) + } + if router.databaseCfgCacheSource.AquaServer != tenant.AquaServer { + t.Errorf("AquaServers are not equals, expected: %s, got: %s", router.databaseCfgCacheSource.AquaServer, tenant.AquaServer) + } +} diff --git a/router/router.go b/router/router.go index 6c82b2b6..8328d742 100644 --- a/router/router.go +++ b/router/router.go @@ -16,6 +16,7 @@ import ( "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/dbservice/dbparam" + "github.com/aquasecurity/postee/dbservice/postgresdb" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/msgservice" "github.com/aquasecurity/postee/outputs" @@ -34,18 +35,19 @@ const ( ) type Router struct { - mutexScan sync.Mutex - quit chan struct{} - queue chan []byte - ticker *time.Ticker - stopTicker chan struct{} - cfgfile string - aquaServer string - outputs map[string]outputs.Output - inputRoutes map[string]*routes.InputRoute - templates map[string]data.Inpteval - synchronous bool - inputCallBacks map[string][]InputCallbackFunc + mutexScan sync.Mutex + quit chan struct{} + queue chan []byte + ticker *time.Ticker + stopTicker chan struct{} + cfgfile string + aquaServer string + outputs map[string]outputs.Output + inputRoutes map[string]*routes.InputRoute + templates map[string]data.Inpteval + synchronous bool + inputCallBacks map[string][]InputCallbackFunc + databaseCfgCacheSource *data.TenantSettings } var ( @@ -61,11 +63,12 @@ var ( func Instance() *Router { initCtx.Do(func() { routerCtx = &Router{ - mutexScan: sync.Mutex{}, - outputs: make(map[string]outputs.Output), - inputRoutes: make(map[string]*routes.InputRoute), - templates: make(map[string]data.Inpteval), - synchronous: false, + mutexScan: sync.Mutex{}, + outputs: make(map[string]outputs.Output), + inputRoutes: make(map[string]*routes.InputRoute), + templates: make(map[string]data.Inpteval), + synchronous: false, + databaseCfgCacheSource: &data.TenantSettings{}, } }) return routerCtx @@ -102,7 +105,7 @@ func (ctx *Router) ApplyFileCfg(cfgfile string, synchronous bool) error { ctx.initCfg(synchronous) - err := ctx.load() + err := ctx.load(false) if err != nil { return err } @@ -155,12 +158,25 @@ func (ctx *Router) cleanInstance() { func (ctx *Router) Send(data []byte) { ctx.queue <- data } + +func (ctx *Router) addTemplate(template *data.Template) error { + if err := ctx.initTemplate(template); err != nil { + return err + } + + ctx.databaseCfgCacheSource.Templates = append(ctx.databaseCfgCacheSource.Templates, *template) + if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { + return err + } + return nil +} + func (ctx *Router) deleteTemplate(name string, removeFromRoutes bool) error { - _, ok := ctx.outputs[name] + _, ok := ctx.templates[name] if !ok { return xerrors.Errorf("template %s is not found", name) } - delete(ctx.outputs, name) + delete(ctx.templates, name) if removeFromRoutes { for _, route := range ctx.inputRoutes { @@ -170,9 +186,23 @@ func (ctx *Router) deleteTemplate(name string, removeFromRoutes bool) error { } } + removeTemplateFromCfgCacheSource(ctx.databaseCfgCacheSource, name) + if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { + return err + } return nil } +func removeTemplateFromCfgCacheSource(outputs *data.TenantSettings, templateName string) { + filtered := make([]data.Template, 0) + for _, template := range outputs.Templates { + if template.Name != templateName { + filtered = append(filtered, template) + } + } + outputs.Templates = filtered +} + func (ctx *Router) initTemplate(template *data.Template) error { log.Printf("Configuring template %s \n", template.Name) @@ -241,23 +271,36 @@ func (ctx *Router) setAquaServerUrl(url string) { } ctx.aquaServer = fmt.Sprintf("%s%s#/images/", url, slash) } - + ctx.databaseCfgCacheSource.AquaServer = url + if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { + log.Printf("Can't save cfgSource Source: %v", err) + } } -func (ctx *Router) load() error { +func (ctx *Router) load(loadCfgFromPostgres bool) error { ctx.mutexScan.Lock() defer ctx.mutexScan.Unlock() log.Printf("Loading alerts configuration file %s ....\n", ctx.cfgfile) - tenant, err := Parsev2cfg(ctx.cfgfile) - if err != nil { - return err + var tenant = &data.TenantSettings{} + var err error + if loadCfgFromPostgres { + tenant, err = ctx.loadCfgCacheSourceFromPostgres() + if err != nil { + return err + } + } else { + tenant, err = Parsev2cfg(ctx.cfgfile) + if err != nil { + return err + } + ctx.databaseCfgCacheSource = tenant } ctx.setAquaServerUrl(tenant.AquaServer) - postgresUrl := os.Getenv("$POSTGRES_URL") - pathToDb := os.Getenv("$PATH_TO_DB") + postgresUrl := os.Getenv("POSTGRES_URL") + pathToDb := os.Getenv("PATH_TO_DB") if err = dbservice.ConfigureDb(pathToDb, postgresUrl, tenant.Name); err != nil { return err @@ -318,6 +361,10 @@ func (ctx *Router) setInputCallbackFunc(routeName string, callback InputCallback func (ctx *Router) addRoute(r *routes.InputRoute) { ctx.inputRoutes[r.Name] = routes.ConfigureTimeouts(r) + ctx.databaseCfgCacheSource.InputRoutes = append(ctx.databaseCfgCacheSource.InputRoutes, *r) + if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { + log.Printf("Can't save cfgSource Source: %v", err) + } } func (ctx *Router) deleteRoute(name string) error { @@ -329,6 +376,11 @@ func (ctx *Router) deleteRoute(name string) error { delete(ctx.inputRoutes, name) delete(ctx.inputCallBacks, name) + removeRouteFromCfgCacheSource(ctx.databaseCfgCacheSource, name) + if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { + return err + } + return nil } @@ -353,6 +405,16 @@ func (ctx *Router) listRoutes() []routes.InputRoute { return list } +func removeRouteFromCfgCacheSource(outputs *data.TenantSettings, routeName string) { + filtered := make([]routes.InputRoute, 0) + for _, route := range outputs.InputRoutes { + if route.Name != routeName { + filtered = append(filtered, route) + } + } + outputs.InputRoutes = filtered +} + func (ctx *Router) addOutput(settings *data.OutputSettings) error { if settings.Enable { plg, err := buildAndInitOtpt(settings, ctx.aquaServer) @@ -364,6 +426,11 @@ func (ctx *Router) addOutput(settings *data.OutputSettings) error { ctx.outputs[settings.Name] = plg } + + ctx.databaseCfgCacheSource.Outputs = append(ctx.databaseCfgCacheSource.Outputs, *settings) + if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { + return err + } return nil } func (ctx *Router) deleteOutput(outputName string, removeFromRoutes bool) error { @@ -379,6 +446,10 @@ func (ctx *Router) deleteOutput(outputName string, removeFromRoutes bool) error removeOutputFromRoute(route, outputName) } } + removeOutputFromCfgCacheSource(ctx.databaseCfgCacheSource, outputName) + if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { + return err + } return nil } @@ -399,6 +470,45 @@ func removeOutputFromRoute(r *routes.InputRoute, outputName string) { r.Outputs = filtered } +func removeOutputFromCfgCacheSource(outputs *data.TenantSettings, outputName string) { + filtered := make([]data.OutputSettings, 0) + for _, output := range outputs.Outputs { + if output.Name != outputName { + filtered = append(filtered, output) + } + } + outputs.Outputs = filtered +} + +func (ctx *Router) saveCfgCacheSourceInPostgres() error { + cfg := ctx.databaseCfgCacheSource + if postgresDb, ok := dbservice.Db.(*postgresdb.PostgresDb); ok { + cfgFile, err := json.Marshal(cfg) + if err != nil { + return err + } + if err = postgresdb.UpdateCfgCacheSource(postgresDb, string(cfgFile)); err != nil { + return err + } + } + return nil +} + +func (ctx *Router) loadCfgCacheSourceFromPostgres() (*data.TenantSettings, error) { + cfg := &data.TenantSettings{} + if postgresDb, ok := dbservice.Db.(*postgresdb.PostgresDb); ok { + cfgFile, err := postgresdb.GetCfgCacheSource(postgresDb) + if err != nil { + return cfg, err + } + err = json.Unmarshal([]byte(cfgFile), &cfg) + if err != nil { + return cfg, err + } + } + return cfg, nil +} + type service interface { MsgHandling(input map[string]interface{}, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) } diff --git a/routes/routes.go b/routes/routes.go index 513fbad0..ee991287 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -1,13 +1,13 @@ package routes type InputRoute struct { - Name string `json:"name"` - Input string `json:"input"` - InputFiles []string `json:"input-files"` - Outputs []string `json:"outputs"` - Plugins Plugins `json:"plugins"` - Template string `json:"template"` - Scheduling chan struct{} + Name string `json:"name"` + Input string `json:"input"` + InputFiles []string `json:"input-files"` + Outputs []string `json:"outputs"` + Plugins Plugins `json:"plugins"` + Template string `json:"template"` + Scheduling chan struct{} `json:"-"` } type Plugins struct { From f40b5b70f3ba5ba362f252714e9198ad60a31d5a Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Mon, 13 Dec 2021 17:39:00 +0600 Subject: [PATCH 32/61] test: changed tests for insert into psql --- dbservice/postgresdb/dbservice_test.go | 785 ++++++------------------- 1 file changed, 164 insertions(+), 621 deletions(-) diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 73559a94..5794da2a 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -110,642 +110,185 @@ func TestInitError(t *testing.T) { } func TestDeleteRowsByTenantNameAndTime(t *testing.T) { - t.Log("happy delete row") - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRow := deleteRowsByTenantNameAndTime - defer func() { - deleteRowsByTenantNameAndTime = savedDeleteRow - }() - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectExec("DELETE").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - defer func() { - psqlConnect = savedPsqlConnect - }() - - psqlDb, _ := psqlConnect(db.ConnectUrl) - err := deleteRowsByTenantNameAndTime(psqlDb, "tenantName", time.Now()) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - t.Log("bad delete row") - deleteError := errors.New("delete - error") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRow := deleteRowsByTenantNameAndTime - defer func() { - deleteRowsByTenantNameAndTime = savedDeleteRow - }() - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectExec("DELETE").WillReturnError(deleteError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = deleteRowsByTenantNameAndTime(psqlDb, "tenantName", time.Now()) - if deleteError != err { - t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) + tests := []struct { + name string + wasError bool + expectedError error + }{ + {"happy delete rows by tenantName", false, nil}, + {"bad delete row by tenantName", true, errors.New("delete rows error")}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + savedDeleteRowsByTenantName := deleteRowsByTenantName + defer func() { + deleteRowsByTenantName = savedDeleteRowsByTenantName + }() + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + if test.wasError { + mock.ExpectExec("DELETE").WillReturnError(test.expectedError) + } else { + mock.ExpectExec("DELETE").WillReturnResult(sqlxmock.NewResult(1, 1)) + } + return db, err + } + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := deleteRowsByTenantNameAndTime(psqlDb, "tenantName", time.Now()) + if test.expectedError != err { + t.Errorf("Unexpected error, expected: %v, got: %v", test.expectedError, err) + } + }) } } func TestDeleteRowsByTenantName(t *testing.T) { - t.Log("happy delete rows by tenantName") - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRowsByTenantName := deleteRowsByTenantName - defer func() { - deleteRowsByTenantName = savedDeleteRowsByTenantName - }() - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectExec("DELETE").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - defer func() { - psqlConnect = savedPsqlConnect - }() - - psqlDb, _ := psqlConnect(db.ConnectUrl) - err := deleteRowsByTenantName(psqlDb, "table", "tenantName") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - t.Log("bad delete row") - deleteError := errors.New("delete - error") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - savedDeleteRowsByTenantName := deleteRowsByTenantName - defer func() { - deleteRowsByTenantName = savedDeleteRowsByTenantName - }() - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectExec("DELETE").WillReturnError(deleteError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = deleteRowsByTenantName(psqlDb, "table", "tenantName") - if deleteError != err { - t.Errorf("Unexpected error, expected: %v, got: %v", deleteError, err) - } -} - -func TestInsertInTableSharedConfig(t *testing.T) { - t.Log("happy insert") - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - defer func() { - psqlConnect = savedPsqlConnect - }() - - psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") - if err != nil { - t.Errorf("Unexpected error in 'insertInTableSharedConfig': %v", err) - } - - t.Log("happy update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") - if err != nil { - t.Errorf("Unexpected error in 'insertInTableSharedConfig': %v", err) - } - - t.Log("select error") - selectError := errors.New("select error") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectQuery("SELECT").WillReturnError(selectError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") - if err != selectError { - t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", selectError, err) - } - - t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookSharedConfig where tenantName=tenantName, apikeyname=value2, have 2 rows" - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") - if err.Error() != select2RowsError { - t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", select2RowsError, err) - } - - t.Log("bad insert") - badInsertError := errors.New("bad insert") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) + tests := []struct { + name string + wasError bool + expectedError error + }{ + {"happy delete rows by tenantName", false, nil}, + {"bad delete row by tenantName", true, errors.New("delete rows error")}, + } + for _, test := range tests { + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + savedDeleteRowsByTenantName := deleteRowsByTenantName + defer func() { + deleteRowsByTenantName = savedDeleteRowsByTenantName + }() + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + if test.wasError { + mock.ExpectExec("DELETE").WillReturnError(test.expectedError) + } else { + mock.ExpectExec("DELETE").WillReturnResult(sqlxmock.NewResult(1, 1)) + } + return db, err } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnError(badInsertError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") - if err != badInsertError { - t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", badInsertError, err) - } - - t.Log("bad update") - badUpdateError := errors.New("bad update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := deleteRowsByTenantName(psqlDb, "table", "tenantName") + if test.expectedError != err { + t.Errorf("Unexpected error, expected: %v, got: %v", test.expectedError, err) } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableSharedConfig(psqlDb, "tenantName", "value2", "value3") - if err != badUpdateError { - t.Errorf("Unexpected error in 'insertInTableSharedConfig', expected: %v, got: %v", badUpdateError, err) } } -func TestInsertInTableAggregator(t *testing.T) { - t.Log("happy insert") - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - defer func() { - psqlConnect = savedPsqlConnect - }() - - psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) - if err != nil { - t.Errorf("Unexpected error in 'insert': %v", err) - } - - t.Log("happy update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) - if err != nil { - t.Errorf("Unexpected error in 'insertInTableAggregator': %v", err) - } - - t.Log("select error") - selectError := errors.New("select error") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectQuery("SELECT").WillReturnError(selectError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) - if err != selectError { - t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", selectError, err) - } - - t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookAggregator where tenantName=tenantName, output=value2, have 2 rows" - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) - if err.Error() != select2RowsError { - t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", select2RowsError, err) - } - - t.Log("bad insert") - badInsertError := errors.New("bad insert") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnError(badInsertError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) - if err != badInsertError { - t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", badInsertError, err) - } - - t.Log("bad update") - badUpdateError := errors.New("bad update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableAggregator(psqlDb, "tenantName", "value2", []byte("value3")) - if err != badUpdateError { - t.Errorf("Unexpected error in 'insertInTableAggregator', expected: %v, got: %v", badUpdateError, err) - } +var insertFuncs = []string{ + "insertInTableSharedConfig", + "insertInTableAggregator", + "insertInTableName", + "insertOutputStats", + "insertCfgCacheSource", } -func TestInsertOutputStats(t *testing.T) { - t.Log("happy insert") - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - defer func() { - psqlConnect = savedPsqlConnect - }() - - psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertOutputStats(psqlDb, "tenantName", "outputName", 1) - if err != nil { - t.Errorf("Unexpected error in 'insertOutputStats': %v", err) - } - - t.Log("happy update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) - if err != nil { - t.Errorf("Unexpected error in 'insertOutputStats': %v", err) - } - - t.Log("select error") - selectError := errors.New("select error") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectQuery("SELECT").WillReturnError(selectError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) - if err != selectError { - t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", selectError, err) - } - - t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookOutputStats where tenantName=tenantName, outputName=outputName, have 2 rows" - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) - if err.Error() != select2RowsError { - t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", select2RowsError, err) - } - - t.Log("bad insert") - badInsertError := errors.New("bad insert") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnError(badInsertError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) - if err != badInsertError { - t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badInsertError, err) - } - - t.Log("bad update") - badUpdateError := errors.New("bad update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) +func TestInsert(t *testing.T) { + tests := []struct { + name string + wasQuaryError bool + quaryRows int + exec string + wasExecError bool + expectedError error + }{ + {" happy insert", false, 0, "INSERT", false, nil}, + {" happy update", false, 1, "UPDATE", false, nil}, + {" select error", true, 0, "INSERT", false, errors.New("select error")}, + //{" select 2 rows", false, 2, "INSERT", false, errors.New("error insert in postgresDb. Table:WebhookCfgCacheSource where tenantName=tenantName, have 2 rows")}, + {" bad insert", false, 0, "INSERT", true, errors.New("bad insert error")}, + {" bad update", false, 1, "UPDATE", true, errors.New("bad update error")}, + } + for _, insertFunc := range insertFuncs { + for _, test := range tests { + t.Run(insertFunc+test.name, func(t *testing.T) { + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + if test.wasQuaryError { + mock.ExpectQuery("SELECT").WillReturnError(test.expectedError) + } else { + rows := sqlxmock.NewRows([]string{"count"}).AddRow(test.quaryRows) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + } + if test.wasExecError { + mock.ExpectExec(test.exec).WillReturnError(test.expectedError) + } else { + mock.ExpectExec(test.exec).WillReturnResult(sqlxmock.NewResult(1, 1)) + } + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := runInsertFunc(psqlDb, insertFunc) + if err != nil { + if test.expectedError == nil || err.Error() != test.expectedError.Error() { + t.Errorf("Unexpected error in %s, expected: %v, got: %v", insertFunc, test.expectedError, err) + } + } + }) } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertOutputStats(psqlDb, "tenantName", "outputName", 1) - if err != badUpdateError { - t.Errorf("Unexpected error in 'insertOutputStats', expected: %v, got: %v", badUpdateError, err) } } -func TestInsertInTableName(t *testing.T) { - t.Log("happy insert") - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - defer func() { - psqlConnect = savedPsqlConnect - }() - - psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) - if err != nil { - t.Errorf("Unexpected error in 'insertInTableName': %v", err) - } - - t.Log("happy update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) - if err != nil { - t.Errorf("Unexpected error in 'insertInTableName': %v", err) - } - - t.Log("select error") - selectError := errors.New("select error") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectQuery("SELECT").WillReturnError(selectError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) - if err != selectError { - t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", selectError, err) - } - - t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookBucket where tenantName=tenantName, messageKey=messageKey, have 2 rows" - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) - if err.Error() != select2RowsError { - t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", select2RowsError, err) - } - - t.Log("bad insert") - badInsertError := errors.New("bad insert") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnError(badInsertError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) - if err != badInsertError { - t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badInsertError, err) - } - - t.Log("bad update") - badUpdateError := errors.New("bad update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertInTableName(psqlDb, "tenantName", "messageKey", []byte("messageValue"), nil) - if err != badUpdateError { - t.Errorf("Unexpected error in 'insertInTableName', expected: %v, got: %v", badUpdateError, err) +func TestInsertErrorSelect2Rows(t *testing.T) { + tests := []struct { + f string + expectedError string + }{ + {"insertInTableSharedConfig", "error insert in postgresDb. Table:WebhookSharedConfig where tenantName=tenantName, apikeyname=apiKeyName, have 2 rows"}, + {"insertInTableAggregator", "error insert in postgresDb. Table:WebhookAggregator where tenantName=tenantName, output=output, have 2 rows"}, + {"insertInTableName", "error insert in postgresDb. Table:WebhookBucket where tenantName=tenantName, messageKey=messagekey, have 2 rows"}, + {"insertOutputStats", "error insert in postgresDb. Table:WebhookOutputStats where tenantName=tenantName, outputName=outputName, have 2 rows"}, + {"insertCfgCacheSource", "error insert in postgresDb. Table:WebhookCfgCacheSource where tenantName=tenantName, have 2 rows"}, + } + for _, test := range tests { + t.Run(test.f, func(t *testing.T) { + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) + mock.ExpectQuery("SELECT").WillReturnRows(rows) + return db, err + } + defer func() { + psqlConnect = savedPsqlConnect + }() + psqlDb, _ := psqlConnect(db.ConnectUrl) + err := runInsertFunc(psqlDb, test.f) + if err == nil { + t.Errorf("no error, expected: %s", test.expectedError) + } else if err.Error() != test.expectedError { + t.Errorf("unexpected error, expected: %s got: %v", test.expectedError, err) + } + }) } } -func TestInsertCfgCacheSource(t *testing.T) { - t.Log("happy insert") - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - defer func() { - psqlConnect = savedPsqlConnect - }() - - psqlDb, _ := psqlConnect(db.ConnectUrl) - err := insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") - if err != nil { - t.Errorf("Unexpected error in 'insertCfgCacheSource': %v", err) - } - - t.Log("happy update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnResult(sqlxmock.NewResult(1, 1)) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") - if err != nil { - t.Errorf("Unexpected error in 'insertCfgCacheSource': %v", err) - } - - t.Log("select error") - selectError := errors.New("select error") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - mock.ExpectQuery("SELECT").WillReturnError(selectError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") - if err != selectError { - t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", selectError, err) - } - - t.Log("select 2 rows") - select2RowsError := "error insert in postgresDb. Table:WebhookCfgCacheSource where tenantName=tenantName, have 2 rows" - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") - if err.Error() != select2RowsError { - t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", select2RowsError, err) - } - - t.Log("bad insert") - badInsertError := errors.New("bad insert") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(0) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("INSERT").WillReturnError(badInsertError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") - if err != badInsertError { - t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", badInsertError, err) - } - - t.Log("bad update") - badUpdateError := errors.New("bad update") - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) - } - rows := sqlxmock.NewRows([]string{"count"}).AddRow(1) - mock.ExpectQuery("SELECT").WillReturnRows(rows) - mock.ExpectExec("UPDATE").WillReturnError(badUpdateError) - return db, err - } - psqlDb, _ = psqlConnect(db.ConnectUrl) - err = insertCfgCacheSource(psqlDb, "tenantName", "cfgFile") - if err != badUpdateError { - t.Errorf("Unexpected error in 'insertCfgCacheSource', expected: %v, got: %v", badUpdateError, err) - } +func runInsertFunc(db *sqlx.DB, funcName string) error { + switch funcName { + case "insertInTableSharedConfig": + return insertInTableSharedConfig(db, "tenantName", "apiKeyName", "value") + case "insertInTableAggregator": + return insertInTableAggregator(db, "tenantName", "output", []byte("saving")) + case "insertInTableName": + return insertInTableName(db, "tenantName", "messagekey", []byte("messageValue"), &time.Time{}) + case "insertOutputStats": + return insertOutputStats(db, "tenantName", "outputName", 1) + case "insertCfgCacheSource": + return insertCfgCacheSource(db, "tenantName", "cfgfile") + } + return nil } From 47b581bbc2067065abd0119a1bc1c9684c5c87c3 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Tue, 14 Dec 2021 19:37:00 +0600 Subject: [PATCH 33/61] test: added tests for api --- router/api.go | 5 +- router/api_test.go | 357 ++++++++++++++++++++++++++++++++++++++++++++- router/router.go | 8 +- 3 files changed, 363 insertions(+), 7 deletions(-) diff --git a/router/api.go b/router/api.go index 5d36750a..96c0ce28 100644 --- a/router/api.go +++ b/router/api.go @@ -88,7 +88,10 @@ func AddOutput(output *data.OutputSettings) error { return Instance().addOutput(output) } func UpdateOutput(output *data.OutputSettings) error { - Instance().deleteOutput(output.Name, false) + err := Instance().deleteOutput(output.Name, false) + if err != nil { + return err + } return Instance().addOutput(output) } func ListOutputs() []data.OutputSettings { diff --git a/router/api_test.go b/router/api_test.go index a24441bb..13f2327f 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -1,11 +1,19 @@ package router import ( + "errors" "fmt" + "io/ioutil" + "os" + "path/filepath" "testing" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/dbservice/boltdb" + "github.com/aquasecurity/postee/dbservice/postgresdb" "github.com/aquasecurity/postee/outputs" + "github.com/aquasecurity/postee/routes" "github.com/stretchr/testify/assert" ) @@ -22,7 +30,19 @@ var outputSettings = &data.OutputSettings{ Enable: true, } +var inputRoute = &routes.InputRoute{ + Name: "my-route", + Outputs: []string{"my-slack"}, + Template: "legacy-slack", +} + +var template = &data.Template{ + Name: "legacy", + LegacyScanRenderer: "html", +} + func TestAddOutput(t *testing.T) { + Instance().cleanInstance() AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") assert.Contains(t, Instance().outputs, "my-slack") @@ -32,15 +52,21 @@ func TestAddOutput(t *testing.T) { } func TestDeleteOutput(t *testing.T) { + Instance().cleanInstance() AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") + AddRoute(&routes.InputRoute{Name: "my-route", Outputs: []string{"my-slack", "my-jira"}}) + assert.Equal(t, 2, len(Instance().inputRoutes["my-route"].Outputs), "two output expected") DeleteOutput("my-slack") assert.Equal(t, 0, len(Instance().outputs), "no outputs expected") + assert.Equal(t, 1, len(Instance().inputRoutes["my-route"].Outputs), "one output expected") } func TestEditOutput(t *testing.T) { + Instance().cleanInstance() modifiedUrl := "https://hooks.slack.com/services/TAAAA/XXX/" + expectedError := errors.New("output badName is not found") AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") @@ -53,8 +79,13 @@ func TestEditOutput(t *testing.T) { assert.Equal(t, 1, len(Instance().outputs), "one output expected") assert.Equal(t, modifiedUrl, Instance().outputs["my-slack"].(*outputs.SlackOutput).Url, "url is updated") + err := UpdateOutput(&data.OutputSettings{Name: "badName"}) + if err != nil && errors.Is(err, expectedError) { + t.Errorf("unexpected error, expected: %v, got: %v", expectedError, err) + } } func TestListOutput(t *testing.T) { + Instance().cleanInstance() AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") @@ -70,8 +101,330 @@ func TestListOutput(t *testing.T) { } -//TODO templates -//TODO routes +func TestAddRoute(t *testing.T) { + Instance().cleanInstance() + AddRoute(inputRoute) + assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") + assert.Contains(t, Instance().inputRoutes, "my-route") + assert.Equal(t, "my-route", Instance().inputRoutes["my-route"].Name, "check name failed") + assert.Equal(t, "*routes.InputRoute", fmt.Sprintf("%T", Instance().inputRoutes["my-route"]), "check name failed") +} + +func TestDeleteRoute(t *testing.T) { + AddRoute(inputRoute) + assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") + + DeleteRoute("my-route") + assert.Equal(t, 0, len(Instance().inputRoutes), "no routes expected") +} + +func TestEditRoute(t *testing.T) { + Instance().cleanInstance() + modifiedTemplate := "vuls-slack" + expectedError := errors.New("output badName is not found") + AddRoute(inputRoute) + assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") + + savedTempalate := *Instance().inputRoutes["my-route"] + r := Instance().inputRoutes["my-route"] + r.Template = modifiedTemplate + defer func() { + *Instance().inputRoutes["my-route"] = savedTempalate + }() + + UpdateRoute(r) + + assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") + assert.Equal(t, modifiedTemplate, Instance().inputRoutes["my-route"].Template, "template is updated") + + err := UpdateRoute(&routes.InputRoute{Name: "badName"}) + if err != nil && errors.Is(err, expectedError) { + t.Errorf("unexpected error, expected: %v, got: %v", expectedError, err) + } +} + +func TestListRoute(t *testing.T) { + Instance().cleanInstance() + AddRoute(inputRoute) + assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") + + routes := ListRoutes() + + assert.Equal(t, 1, len(routes), "one route expected") + + r := routes[0] + + assert.Equal(t, "my-route", r.Name, "check name failed") + assert.Equal(t, "my-slack", r.Outputs[0], "check output failed") + assert.Equal(t, "legacy-slack", r.Template, "check template failed") +} + +func TestAddTemplate(t *testing.T) { + Instance().cleanInstance() + AddTemplate(template) + assert.Equal(t, 1, len(Instance().templates), "one template expected") + assert.Contains(t, Instance().templates, "legacy") + assert.Equal(t, "*formatting.legacyScnEvaluator", fmt.Sprintf("%T", Instance().templates["legacy"]), "check name failed") +} + +func TestAddTemplateFromFile(t *testing.T) { + Instance().cleanInstance() + regoString := `package postee + default hello = false + +hello { + m := input.message + m == "world" +}` + err := ioutil.WriteFile("./testFile", []byte(regoString), 0644) + if err != nil { + t.Errorf("error write file: %v", err) + } + defer os.Remove("./testFile") + err = AddRegoTemplateFromFile("rego-template", "testFile") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + assert.Equal(t, 1, len(Instance().templates), "one template expected") + assert.Contains(t, Instance().templates, "rego-template") + assert.Equal(t, "*regoservice.regoEvaluator", fmt.Sprintf("%T", Instance().templates["rego-template"]), "check evaluator failed") +} + +func TestDeleteTemplate(t *testing.T) { + Instance().cleanInstance() + AddTemplate(template) + assert.Equal(t, 1, len(Instance().templates), "one template expected") + AddRoute(&routes.InputRoute{Name: "my-route", Template: "legacy"}) + assert.Equal(t, "legacy", Instance().inputRoutes["my-route"].Template, "one template expected") + + DeleteTemplate("legacy") + assert.Equal(t, 0, len(Instance().templates), "no templates expected") + assert.Equal(t, "", Instance().inputRoutes["my-route"].Template, "no template expected") +} + +func TestEditTemplate(t *testing.T) { + Instance().cleanInstance() + expectedError := errors.New("template badName is not found") + AddTemplate(template) + assert.Equal(t, 1, len(Instance().templates), "one template expected") + assert.Equal(t, "*formatting.legacyScnEvaluator", fmt.Sprintf("%T", Instance().templates["legacy"]), "legacyScnEvaluator expected") + + templ := template + + templ.LegacyScanRenderer = "" + templ.Body = `package postee` + + err := UpdateTemplate(templ) + if err != nil { + t.Errorf("unexpected errpr: %v", err) + } + + assert.Equal(t, 1, len(Instance().templates), "one template expected") + assert.Equal(t, "*regoservice.regoEvaluator", fmt.Sprintf("%T", Instance().templates["legacy"]), "ScanRenderer is updated") + + err = UpdateTemplate(&data.Template{Name: "badName"}) + if err != nil && errors.Is(err, expectedError) { + t.Errorf("unexpected error, expected: %v, got: %v", expectedError, err) + } +} + +func TestListTemplate(t *testing.T) { + Instance().cleanInstance() + AddTemplate(template) + assert.Equal(t, 1, len(Instance().templates), "one route expected") + + templates := ListTemplates() + + assert.Equal(t, 1, len(templates), "one route expected") + + templ := templates[0] + + assert.Equal(t, "legacy", templ, "check name failed") +} + +func TestNewConfig(t *testing.T) { + Instance().cleanInstance() + + fmt.Println(Instance()) + + tests := []struct { + funcName string + tenantName string + dbPath string + expectedDbPath string + }{ + {}, + } + WithNewConfig("tenantName") + //WithNewConfigAndDbPath() + for _, test := range tests { + t.Run("test "+test.funcName, func(t *testing.T) { + savedPathToDb := os.Getenv("PATH_TO_DB") + os.Setenv("PATH_TO_DB", test.dbPath) + defer os.Setenv("PATH_TO_DB", savedPathToDb) + runFunc(test.funcName, "", test.dbPath, test.tenantName) + + assert.Equal(t, 0, len(Instance().templates), "one template expected") + assert.Equal(t, 0, len(Instance().outputs), "one output expected") + assert.Equal(t, 0, len(Instance().inputRoutes), "one route expected") + }) + + } + +} + +func TestConfigFuncs(t *testing.T) { + tests := []struct { + funcName string + cfgPath string + tenantName string + clearCfg bool + templateName string + outputName string + routeName string + dbPath string + psqlUrl string + }{ + {"WithDefaultConfig", "", "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, + {"WithFileConfig", "test/cfg.yaml", "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, + {"WithDefaultConfigAndDbPath", "", "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, + {"WithFileConfigAndDbPath", "test/cfg.yaml", "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, + {"WithNewConfig", "", "", true, "", "", "", "./webhooks.db", ""}, + {"WithNewConfigAndDbPath", "test/cfg.yaml", "", true, "", "", "", "./webhooks.db", ""}, + } + for _, test := range tests { + t.Run("test "+test.funcName, func(t *testing.T) { + Instance().cleanInstance() + + savedPathToDb := os.Getenv("PATH_TO_DB") + savedPostgresUrl := os.Getenv("POSTGRES_URL") + os.Setenv("PATH_TO_DB", test.dbPath) + os.Setenv("POSTGRES_URL", test.psqlUrl) + defer func() { + os.Setenv("PATH_TO_DB", savedPathToDb) + os.Setenv("POSTGRES_URL", savedPostgresUrl) + }() + + err := runFunc(test.funcName, test.cfgPath, test.dbPath, test.tenantName) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if test.clearCfg { + assert.Equal(t, 0, len(Instance().templates), "no template expected") + assert.Equal(t, 0, len(Instance().outputs), "no output expected") + assert.Equal(t, 0, len(Instance().inputRoutes), "no route expected") + } else { + assert.Equal(t, 1, len(Instance().templates), "one template expected") + assert.Contains(t, Instance().templates, test.templateName) + + assert.Equal(t, 1, len(Instance().outputs), "one output expected") + assert.Contains(t, Instance().outputs, test.outputName) + assert.Equal(t, test.outputName, Instance().outputs[test.outputName].GetName(), "check name failed") + + assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") + assert.Contains(t, Instance().inputRoutes, test.routeName) + assert.Contains(t, Instance().inputRoutes[test.routeName].Outputs, test.outputName) + assert.Equal(t, test.templateName, Instance().inputRoutes[test.routeName].Template, "one template expected") + } + if postgresDb, ok := dbservice.Db.(*postgresdb.PostgresDb); ok { + assert.Equal(t, test.psqlUrl, postgresDb.ConnectUrl, "url configured") + assert.Equal(t, test.tenantName, postgresDb.TenantName, "tenantName configured") + } + + if boltDb, ok := dbservice.Db.(*boltdb.BoltDb); ok { + assert.Equal(t, test.dbPath, boltDb.DbPath, "dbPath configured") + } + }) + } +} + +var cfg = `Name: tenant + +routes: +- name: route1 + outputs: ["my-slack"] + template: raw + plugins: + Policy-Show-All: true + +templates: +- name: raw + body: | + package postee + result:=input + +outputs: +- name: my-slack + type: slack + enable: true + url: https://hooks.slack.com/services/ABCDF/1234/TTT` + +func runFunc(funcName, cfgPath, dbPath, tenantName string) error { + switch funcName { + case "WithFileConfig": + createTestCfg(cfgPath) + WithFileConfig(cfgPath) + defer func() { + os.Remove(defaultDbPath) + os.RemoveAll(filepath.Dir(cfgPath)) + }() + return nil + case "WithDefaultConfig": + createTestCfg(defaultConfigPath) + WithDefaultConfig() + defer func() { + os.Remove(defaultDbPath) + os.RemoveAll(filepath.Dir(defaultConfigPath)) + }() + return nil + case "WithNewConfig": + WithNewConfig(tenantName) + os.Remove(defaultDbPath) + return nil + case "WithNewConfigAndDbPath": + WithNewConfigAndDbPath(tenantName, dbPath) + os.Remove(defaultDbPath) + return nil + case "WithFileConfigAndDbPath": + createTestCfg(cfgPath) + WithFileConfigAndDbPath(cfgPath, dbPath) + defer func() { + os.RemoveAll(filepath.Dir(dbPath)) + os.RemoveAll(filepath.Dir(cfgPath)) + }() + return nil + case "WithDefaultConfigAndDbPath": + createTestCfg(defaultConfigPath) + WithDefaultConfigAndDbPath(dbPath) + defer func() { + os.RemoveAll(filepath.Dir(dbPath)) + os.RemoveAll(filepath.Dir(defaultConfigPath)) + }() + return nil + } + + return errors.New("don't have func: " + funcName) +} + +func createTestCfg(cfgPath string) error { + _, err := os.Stat(filepath.Dir(cfgPath)) + if err != nil { + if os.IsNotExist(err) { + err := os.Mkdir(filepath.Dir(cfgPath), os.ModePerm) + if err != nil { + return err + } + } else { + return err + } + } + + err = ioutil.WriteFile(cfgPath, []byte(cfg), 0644) + if err != nil { + return err + } + return nil +} func TestBuildPosgresUrl(t *testing.T) { tests := []struct { diff --git a/router/router.go b/router/router.go index 6c82b2b6..8756d090 100644 --- a/router/router.go +++ b/router/router.go @@ -156,11 +156,11 @@ func (ctx *Router) Send(data []byte) { ctx.queue <- data } func (ctx *Router) deleteTemplate(name string, removeFromRoutes bool) error { - _, ok := ctx.outputs[name] + _, ok := ctx.templates[name] if !ok { return xerrors.Errorf("template %s is not found", name) } - delete(ctx.outputs, name) + delete(ctx.templates, name) if removeFromRoutes { for _, route := range ctx.inputRoutes { @@ -256,8 +256,8 @@ func (ctx *Router) load() error { ctx.setAquaServerUrl(tenant.AquaServer) - postgresUrl := os.Getenv("$POSTGRES_URL") - pathToDb := os.Getenv("$PATH_TO_DB") + postgresUrl := os.Getenv("POSTGRES_URL") + pathToDb := os.Getenv("PATH_TO_DB") if err = dbservice.ConfigureDb(pathToDb, postgresUrl, tenant.Name); err != nil { return err From e04c940909fb445489c35670641e2900872be092 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 15 Dec 2021 11:12:37 +0600 Subject: [PATCH 34/61] test: added tests for configure psql --- router/api_test.go | 85 ++++++++++++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/router/api_test.go b/router/api_test.go index 13f2327f..c139f611 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -43,6 +43,8 @@ var template = &data.Template{ func TestAddOutput(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") assert.Contains(t, Instance().outputs, "my-slack") @@ -53,6 +55,8 @@ func TestAddOutput(t *testing.T) { func TestDeleteOutput(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") AddRoute(&routes.InputRoute{Name: "my-route", Outputs: []string{"my-slack", "my-jira"}}) @@ -65,8 +69,10 @@ func TestDeleteOutput(t *testing.T) { } func TestEditOutput(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() modifiedUrl := "https://hooks.slack.com/services/TAAAA/XXX/" expectedError := errors.New("output badName is not found") + AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") @@ -86,6 +92,8 @@ func TestEditOutput(t *testing.T) { } func TestListOutput(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddOutput(outputSettings) assert.Equal(t, 1, len(Instance().outputs), "one output expected") @@ -98,11 +106,12 @@ func TestListOutput(t *testing.T) { assert.Equal(t, "my-slack", r.Name, "check name failed") assert.Equal(t, "slack", r.Type, "check type failed") assert.True(t, r.Enable, "output must be enabled") - } func TestAddRoute(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddRoute(inputRoute) assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") assert.Contains(t, Instance().inputRoutes, "my-route") @@ -111,6 +120,9 @@ func TestAddRoute(t *testing.T) { } func TestDeleteRoute(t *testing.T) { + Instance().cleanInstance() + defer Instance().cleanInstance() + AddRoute(inputRoute) assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") @@ -120,8 +132,10 @@ func TestDeleteRoute(t *testing.T) { func TestEditRoute(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() modifiedTemplate := "vuls-slack" expectedError := errors.New("output badName is not found") + AddRoute(inputRoute) assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") @@ -145,6 +159,8 @@ func TestEditRoute(t *testing.T) { func TestListRoute(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddRoute(inputRoute) assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") @@ -161,6 +177,8 @@ func TestListRoute(t *testing.T) { func TestAddTemplate(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddTemplate(template) assert.Equal(t, 1, len(Instance().templates), "one template expected") assert.Contains(t, Instance().templates, "legacy") @@ -169,6 +187,7 @@ func TestAddTemplate(t *testing.T) { func TestAddTemplateFromFile(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() regoString := `package postee default hello = false @@ -192,6 +211,8 @@ hello { func TestDeleteTemplate(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddTemplate(template) assert.Equal(t, 1, len(Instance().templates), "one template expected") AddRoute(&routes.InputRoute{Name: "my-route", Template: "legacy"}) @@ -204,7 +225,9 @@ func TestDeleteTemplate(t *testing.T) { func TestEditTemplate(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() expectedError := errors.New("template badName is not found") + AddTemplate(template) assert.Equal(t, 1, len(Instance().templates), "one template expected") assert.Equal(t, "*formatting.legacyScnEvaluator", fmt.Sprintf("%T", Instance().templates["legacy"]), "legacyScnEvaluator expected") @@ -230,6 +253,8 @@ func TestEditTemplate(t *testing.T) { func TestListTemplate(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() + AddTemplate(template) assert.Equal(t, 1, len(Instance().templates), "one route expected") @@ -242,38 +267,21 @@ func TestListTemplate(t *testing.T) { assert.Equal(t, "legacy", templ, "check name failed") } -func TestNewConfig(t *testing.T) { +func TestSetInputCallbackFunc(t *testing.T) { Instance().cleanInstance() + defer Instance().cleanInstance() - fmt.Println(Instance()) - - tests := []struct { - funcName string - tenantName string - dbPath string - expectedDbPath string - }{ - {}, - } - WithNewConfig("tenantName") - //WithNewConfigAndDbPath() - for _, test := range tests { - t.Run("test "+test.funcName, func(t *testing.T) { - savedPathToDb := os.Getenv("PATH_TO_DB") - os.Setenv("PATH_TO_DB", test.dbPath) - defer os.Setenv("PATH_TO_DB", savedPathToDb) - runFunc(test.funcName, "", test.dbPath, test.tenantName) - - assert.Equal(t, 0, len(Instance().templates), "one template expected") - assert.Equal(t, 0, len(Instance().outputs), "one output expected") - assert.Equal(t, 0, len(Instance().inputRoutes), "one route expected") - }) + inputCallbackFunc := InputCallbackFunc(func(inputMessage map[string]interface{}) bool { return false }) - } + AddRoute(inputRoute) + assert.Equal(t, 0, len(Instance().inputCallBacks), "no inputCallBack expected") + SetInputCallbackFunc("my-route", inputCallbackFunc) + assert.Equal(t, 1, len(Instance().inputCallBacks), "one inputCallBack expected") } func TestConfigFuncs(t *testing.T) { + Instance().cleanInstance() tests := []struct { funcName string cfgPath string @@ -291,11 +299,12 @@ func TestConfigFuncs(t *testing.T) { {"WithFileConfigAndDbPath", "test/cfg.yaml", "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, {"WithNewConfig", "", "", true, "", "", "", "./webhooks.db", ""}, {"WithNewConfigAndDbPath", "test/cfg.yaml", "", true, "", "", "", "./webhooks.db", ""}, + {"WithPostgresParams", "", "ParamsName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, + {"WithPostgresUrl", "", "ParamsName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, } for _, test := range tests { t.Run("test "+test.funcName, func(t *testing.T) { - Instance().cleanInstance() - + defer Instance().cleanInstance() savedPathToDb := os.Getenv("PATH_TO_DB") savedPostgresUrl := os.Getenv("POSTGRES_URL") os.Setenv("PATH_TO_DB", test.dbPath) @@ -305,7 +314,7 @@ func TestConfigFuncs(t *testing.T) { os.Setenv("POSTGRES_URL", savedPostgresUrl) }() - err := runFunc(test.funcName, test.cfgPath, test.dbPath, test.tenantName) + err := runFunc(test.funcName, test.cfgPath, test.dbPath, test.tenantName, test.psqlUrl) if err != nil { t.Errorf("unexpected error: %v", err) } @@ -359,7 +368,7 @@ outputs: enable: true url: https://hooks.slack.com/services/ABCDF/1234/TTT` -func runFunc(funcName, cfgPath, dbPath, tenantName string) error { +func runFunc(funcName, cfgPath, dbPath, tenantName, psqlUrl string) error { switch funcName { case "WithFileConfig": createTestCfg(cfgPath) @@ -401,6 +410,22 @@ func runFunc(funcName, cfgPath, dbPath, tenantName string) error { os.RemoveAll(filepath.Dir(defaultConfigPath)) }() return nil + case "WithPostgresParams": + savedInitPostgresDb := postgresdb.InitPostgresDb + postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } + defer func() { + postgresdb.InitPostgresDb = savedInitPostgresDb + }() + WithPostgresParams(tenantName, "ParamsDbName", "ParamsDbHostName", "ParamsPort", "ParamsUser", "ParamsPassword", "ParamsSslMode") + return nil + case "WithPostgresUrl": + savedInitPostgresDb := postgresdb.InitPostgresDb + postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } + defer func() { + postgresdb.InitPostgresDb = savedInitPostgresDb + }() + WithPostgresUrl(tenantName, psqlUrl) + return nil } return errors.New("don't have func: " + funcName) From aca7bf35993655d3c61449731ebdbbd3a3fa8627 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 15 Dec 2021 11:22:45 +0600 Subject: [PATCH 35/61] Refactor: code review notes are corrected --- dbservice/postgresdb/dbservice_test.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 5794da2a..78c47eb3 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -189,8 +189,8 @@ var insertFuncs = []string{ func TestInsert(t *testing.T) { tests := []struct { name string - wasQuaryError bool - quaryRows int + wasQueryError bool + queryRows int exec string wasExecError bool expectedError error @@ -198,7 +198,6 @@ func TestInsert(t *testing.T) { {" happy insert", false, 0, "INSERT", false, nil}, {" happy update", false, 1, "UPDATE", false, nil}, {" select error", true, 0, "INSERT", false, errors.New("select error")}, - //{" select 2 rows", false, 2, "INSERT", false, errors.New("error insert in postgresDb. Table:WebhookCfgCacheSource where tenantName=tenantName, have 2 rows")}, {" bad insert", false, 0, "INSERT", true, errors.New("bad insert error")}, {" bad update", false, 1, "UPDATE", true, errors.New("bad update error")}, } @@ -211,10 +210,10 @@ func TestInsert(t *testing.T) { if err != nil { log.Println("failed to open sqlmock database:", err) } - if test.wasQuaryError { + if test.wasQueryError { mock.ExpectQuery("SELECT").WillReturnError(test.expectedError) } else { - rows := sqlxmock.NewRows([]string{"count"}).AddRow(test.quaryRows) + rows := sqlxmock.NewRows([]string{"count"}).AddRow(test.queryRows) mock.ExpectQuery("SELECT").WillReturnRows(rows) } if test.wasExecError { @@ -231,7 +230,7 @@ func TestInsert(t *testing.T) { psqlDb, _ := psqlConnect(db.ConnectUrl) err := runInsertFunc(psqlDb, insertFunc) if err != nil { - if test.expectedError == nil || err.Error() != test.expectedError.Error() { + if !errors.Is(err, test.expectedError) { t.Errorf("Unexpected error in %s, expected: %v, got: %v", insertFunc, test.expectedError, err) } } From 164c3f747fbc1188fd1b22394cf57951e46b48d3 Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Wed, 15 Dec 2021 16:41:39 +0600 Subject: [PATCH 36/61] added make target for golang-lint --- Makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Makefile b/Makefile index a392da90..c66a3898 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,9 @@ fmt : test : go test -race -coverprofile=coverage.txt -covermode=atomic ./router ./msgservice ./dbservice ./formatting ./data ./regoservice ./routes +lint : + golangci-lint run + cover : go test ./msgservice ./dbservice ./router ./formatting ./data ./regoservice ./routes -v -coverprofile=cover.out go tool cover -html=cover.out From c3e2a86c8a1abc8ff23676323db810ed294011a4 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Thu, 16 Dec 2021 18:32:12 +0600 Subject: [PATCH 37/61] Refactor: code review notes are corrected --- dbservice/boltdb/checker_test.go | 25 +- dbservice/boltdb/dbaggregator.go | 5 +- dbservice/boltdb/dbparam_test.go | 8 +- dbservice/boltdb/plgnstats_test.go | 5 +- dbservice/boltdb/sharedcfg_test.go | 8 +- dbservice/postgresdb/plgstats_test.go | 4 +- dbservice/postgresdb/sharedcfg_test.go | 8 +- main.go | 12 +- msgservice/msgservice_test.go | 3 - msgservice/regocriteria_test.go | 4 +- msgservice/scheduler_test.go | 4 +- regoservice/aggregation_test.go | 12 +- router/api_integration_test.go | 5 +- router/api_test.go | 315 +++++++++++++++++-------- utils/cert.go | 8 +- utils/utils.go | 2 +- webserver/webserver.go | 12 +- 17 files changed, 307 insertions(+), 133 deletions(-) diff --git a/dbservice/boltdb/checker_test.go b/dbservice/boltdb/checker_test.go index d53e6457..a23cb33d 100644 --- a/dbservice/boltdb/checker_test.go +++ b/dbservice/boltdb/checker_test.go @@ -152,19 +152,34 @@ func TestDbDelete(t *testing.T) { value := []byte("value") bucket := "b" - dbInsert(db, bucket, key, value) - dbDelete(db, bucket, [][]byte{key}) - dbDelete(db, bucket, [][]byte{key}) + err = dbInsert(db, bucket, key, value) + if err != nil { + t.Errorf("Can't insert in db: %v", err) + } + err = dbDelete(db, bucket, [][]byte{key}) + if err != nil { + t.Errorf("Can't delete from db: %v", err) + } + err = dbDelete(db, bucket, [][]byte{key}) + if err != nil { + t.Errorf("Can't delete from db: %v", err) + } bucket = "" - dbInsert(db, bucket, key, value) + err = dbInsert(db, bucket, key, value) + if err != nil { + t.Errorf("Can't insert in db: %v", err) + } } func TestWithoutAccessToDb(t *testing.T) { boltDb := NewBoltDb() dbPathReal := boltDb.DbPath defer func() { - os.Remove(boltDb.DbPath) + err := os.Remove(boltDb.DbPath) + if err != nil { + t.Errorf("Can't remove db: %v", err) + } boltDb.DbPath = dbPathReal }() boltDb.DbPath = "test_webhooks.db" diff --git a/dbservice/boltdb/dbaggregator.go b/dbservice/boltdb/dbaggregator.go index ad61eac7..315dc026 100644 --- a/dbservice/boltdb/dbaggregator.go +++ b/dbservice/boltdb/dbaggregator.go @@ -55,6 +55,9 @@ func (boltDb *BoltDb) AggregateScans(output string, } return nil, nil } - dbInsert(db, dbparam.DbBucketAggregator, []byte(output), nil) + err = dbInsert(db, dbparam.DbBucketAggregator, []byte(output), nil) + if err != nil { + return nil, err + } return aggregatedScans, nil } diff --git a/dbservice/boltdb/dbparam_test.go b/dbservice/boltdb/dbparam_test.go index 867db148..6b57a866 100644 --- a/dbservice/boltdb/dbparam_test.go +++ b/dbservice/boltdb/dbparam_test.go @@ -32,9 +32,13 @@ func TestSetNewDbPathFromEnv(t *testing.T) { if err != nil { t.Errorf("Can't create dir: %s", baseDir) } - os.Chmod(baseDir, 0) + if err := os.Chmod(baseDir, 0); err != nil { + t.Errorf("Can't change permission: %v", err) + } + } + if err := db.SetNewDbPath(test.pathToDb); err != nil { + t.Errorf("Can't set new dbPath: %v", err) } - db.SetNewDbPath(test.pathToDb) defer os.RemoveAll(baseDir) defer db.ChangeDbPath(dbPathOld) diff --git a/dbservice/boltdb/plgnstats_test.go b/dbservice/boltdb/plgnstats_test.go index 41d85f5f..c856c2ee 100644 --- a/dbservice/boltdb/plgnstats_test.go +++ b/dbservice/boltdb/plgnstats_test.go @@ -20,7 +20,10 @@ func TestRegisterPlgnInvctn(t *testing.T) { expectedCnt := 3 keyToTest := "test" for i := 0; i < expectedCnt; i++ { - dbBolt.RegisterPlgnInvctn(keyToTest) + err := dbBolt.RegisterPlgnInvctn(keyToTest) + if err != nil { + t.Errorf("Unexpected RegisterPlgnInvctn error: %v", err) + } } r, err := getPlgnStats(dbBolt) if err != nil { diff --git a/dbservice/boltdb/sharedcfg_test.go b/dbservice/boltdb/sharedcfg_test.go index 5befd4f6..49a712f0 100644 --- a/dbservice/boltdb/sharedcfg_test.go +++ b/dbservice/boltdb/sharedcfg_test.go @@ -13,7 +13,9 @@ func TestApiKey(t *testing.T) { db.DbPath = dbPathReal }() db.DbPath = "test_webhooks.db" - db.EnsureApiKey() + if err := db.EnsureApiKey(); err != nil { + t.Errorf("Unexpected EnsureApiKey error: %v", err) + } key, err := db.GetApiKey() if err != nil { t.Fatal("error while getting value of API key") @@ -48,7 +50,9 @@ func TestApiKeyRenewal(t *testing.T) { db.DbPath = "test_webhooks.db" var keys [2]string for i := 0; i < 2; i++ { - db.EnsureApiKey() + if err := db.EnsureApiKey(); err != nil { + t.Errorf("Unexpected error EnsureApiKey: %v", err) + } key, err := db.GetApiKey() if err != nil { t.Fatal("error while getting value of API key") diff --git a/dbservice/postgresdb/plgstats_test.go b/dbservice/postgresdb/plgstats_test.go index 3c17962d..b9a0a709 100644 --- a/dbservice/postgresdb/plgstats_test.go +++ b/dbservice/postgresdb/plgstats_test.go @@ -34,7 +34,9 @@ func TestRegisterPlgnInvctn(t *testing.T) { expectedCnt := 3 keyToTest := "test" for i := 0; i < expectedCnt; i++ { - db.RegisterPlgnInvctn(keyToTest) + if err := db.RegisterPlgnInvctn(keyToTest); err != nil { + t.Errorf("unexpected error: %v", err) + } } if receivedKey != expectedCnt { t.Errorf("Persisted count doesn't match expected. Expected %d, got %d\n", receivedKey, expectedCnt) diff --git a/dbservice/postgresdb/sharedcfg_test.go b/dbservice/postgresdb/sharedcfg_test.go index 45246f08..74f79cdc 100644 --- a/dbservice/postgresdb/sharedcfg_test.go +++ b/dbservice/postgresdb/sharedcfg_test.go @@ -27,7 +27,9 @@ func TestApiKey(t *testing.T) { psqlConnect = savedPsqlConnect }() - db.EnsureApiKey() + if err := db.EnsureApiKey(); err != nil { + t.Errorf("Unexpected EnsureApiKey error: %v", err) + } key, err := db.GetApiKey() if err != nil { @@ -84,7 +86,9 @@ func TestApiKeyRenewal(t *testing.T) { var keys [2]string for i := 0; i < 2; i++ { - db.EnsureApiKey() + if err := db.EnsureApiKey(); err != nil { + t.Errorf("Unexpected EnsureApiKey error: %v", err) + } key, err := db.GetApiKey() if err != nil { t.Fatal("error while getting value of API key") diff --git a/main.go b/main.go index e52d49db..09158318 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "log" "os" "os/signal" @@ -33,8 +32,8 @@ var ( var rootCmd = &cobra.Command{ Use: "webhooksrv", - Short: fmt.Sprintf("Aqua Container Security Webhook server\n"), - Long: fmt.Sprintf("Aqua Container Security Webhook server\n"), + Short: "Aqua Container Security Webhook server\n", + Long: "Aqua Container Security Webhook server\n", } func init() { @@ -87,7 +86,12 @@ func main() { Daemonize() } - rootCmd.Execute() + err := rootCmd.Execute() + if err != nil { + log.Printf("Execute error %v", err) + return + } + } func Daemonize() { diff --git a/msgservice/msgservice_test.go b/msgservice/msgservice_test.go index 8ac4ecad..b883b3b9 100644 --- a/msgservice/msgservice_test.go +++ b/msgservice/msgservice_test.go @@ -11,9 +11,6 @@ import ( ) var ( - invalidJson = `{ - image : "My Image" - }` db = boltdb.NewBoltDb() ) diff --git a/msgservice/regocriteria_test.go b/msgservice/regocriteria_test.go index 8c10cf2f..2179baac 100644 --- a/msgservice/regocriteria_test.go +++ b/msgservice/regocriteria_test.go @@ -103,7 +103,9 @@ func validateRegoInput(t *testing.T, caseDesc string, input map[string]interface if err != nil { t.Error("Can't create regoFile.rego file") } - regoFile.WriteString(regoCriteria) + if _, err := regoFile.WriteString(regoCriteria); err != nil { + t.Errorf("Can't write string: %v", err) + } defer os.Remove("regoFile.rego") defer regoFile.Close() diff --git a/msgservice/scheduler_test.go b/msgservice/scheduler_test.go index c8f053df..a5b66d9c 100644 --- a/msgservice/scheduler_test.go +++ b/msgservice/scheduler_test.go @@ -16,7 +16,9 @@ func TestSheduler(t *testing.T) { demoRoute.Plugins.AggregateTimeoutSeconds = 3 demoSend := func(plg outputs.Output, cnt map[string]string) { - plg.Send(cnt) + if err := plg.Send(cnt); err != nil { + t.Errorf("Unexpected send error: %v", err) + } } demoAggregate := func(outputName string, currentContent map[string]string, counts int, ignoreLength bool) []map[string]string { return []map[string]string{ diff --git a/regoservice/aggregation_test.go b/regoservice/aggregation_test.go index bdc76fee..f826be0c 100644 --- a/regoservice/aggregation_test.go +++ b/regoservice/aggregation_test.go @@ -73,9 +73,15 @@ func aggregateBuildinRego(t *testing.T, caseDesc string, regoRule *string, aggre commonRegoFilename := "common.rego" buildinRegoTemplates = []string{commonRegoFilename, testRego, aggrRego} //common part goes in single bundle - ioutil.WriteFile(commonRegoFilename, []byte(commonRego), 0644) - ioutil.WriteFile(testRego, []byte(*regoRule), 0644) - ioutil.WriteFile(aggrRego, []byte(*aggregationRegoRule), 0644) + if err := ioutil.WriteFile(commonRegoFilename, []byte(commonRego), 0644); err != nil { + t.Errorf("Can't write file: %v", err) + } + if err := ioutil.WriteFile(testRego, []byte(*regoRule), 0644); err != nil { + t.Errorf("Can't write file: %v", err) + } + if err := ioutil.WriteFile(aggrRego, []byte(*aggregationRegoRule), 0644); err != nil { + t.Errorf("Can't write file: %v", err) + } defer func() { buildinRegoTemplates = buildinRegoTemplatesSaved diff --git a/router/api_integration_test.go b/router/api_integration_test.go index 77bae892..3be44e2d 100644 --- a/router/api_integration_test.go +++ b/router/api_integration_test.go @@ -61,12 +61,15 @@ func TestAudit(t *testing.T) { t.Logf("Error: %v", err) return } - router.AddOutput(&data.OutputSettings{ + err = router.AddOutput(&data.OutputSettings{ Name: "test-webhook", Type: "webhook", Enable: true, Url: ts.URL, }) + if err != nil { + return + } router.AddRoute(&routes.InputRoute{ Name: "test", diff --git a/router/api_test.go b/router/api_test.go index c139f611..5d3f933d 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -30,22 +30,50 @@ var outputSettings = &data.OutputSettings{ Enable: true, } +var outputSettingsTeams = &data.OutputSettings{ + Type: "teams", + Name: "ms-teams", + Url: "https://outlook.office.com/webhook/", + Enable: true, +} + var inputRoute = &routes.InputRoute{ Name: "my-route", Outputs: []string{"my-slack"}, Template: "legacy-slack", } +var inputRouteJira = &routes.InputRoute{ + Name: "my-jira", + Outputs: []string{"my-jira"}, + Template: "legacy-jira", +} + +var inputRouteHtml = &routes.InputRoute{ + Name: "my-html", + Outputs: []string{"my-html"}, + Template: "legacy", +} + var template = &data.Template{ Name: "legacy", LegacyScanRenderer: "html", } +var templateSlack = &data.Template{ + Name: "legacy-slack", + LegacyScanRenderer: "slack", +} + func TestAddOutput(t *testing.T) { - Instance().cleanInstance() + if len(Instance().outputs) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() - AddOutput(outputSettings) + if err := AddOutput(outputSettings); err != nil { + t.Errorf("Can't add output: %v", err) + } assert.Equal(t, 1, len(Instance().outputs), "one output expected") assert.Contains(t, Instance().outputs, "my-slack") assert.Equal(t, "my-slack", Instance().outputs["my-slack"].GetName(), "check name failed") @@ -54,33 +82,52 @@ func TestAddOutput(t *testing.T) { } func TestDeleteOutput(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputRoutes) > 0 || len(Instance().outputs) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() - AddOutput(outputSettings) - assert.Equal(t, 1, len(Instance().outputs), "one output expected") - AddRoute(&routes.InputRoute{Name: "my-route", Outputs: []string{"my-slack", "my-jira"}}) + if err := AddOutput(outputSettings); err != nil { + t.Errorf("Can't add output: %v", err) + } + if err := AddOutput(outputSettingsTeams); err != nil { + t.Errorf("Can't add output: %v", err) + } + assert.Equal(t, 2, len(Instance().outputs), "two output expected") + + AddRoute(&routes.InputRoute{Name: "my-route", Outputs: []string{"my-slack", "ms-teams"}}) assert.Equal(t, 2, len(Instance().inputRoutes["my-route"].Outputs), "two output expected") - DeleteOutput("my-slack") - assert.Equal(t, 0, len(Instance().outputs), "no outputs expected") - assert.Equal(t, 1, len(Instance().inputRoutes["my-route"].Outputs), "one output expected") + if err := DeleteOutput("my-slack"); err != nil { + t.Errorf("Can't delte output: %v", err) + } + assert.Equal(t, 1, len(Instance().outputs), "one outputs expected") + assert.NotContains(t, Instance().outputs, "my-slack") + + assert.Equal(t, 1, len(Instance().inputRoutes["my-route"].Outputs), "one output in inputRoute expected") + assert.NotContains(t, Instance().inputRoutes["my-route"].Outputs, "my-slack") } func TestEditOutput(t *testing.T) { - Instance().cleanInstance() + if len(Instance().outputs) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() modifiedUrl := "https://hooks.slack.com/services/TAAAA/XXX/" expectedError := errors.New("output badName is not found") - AddOutput(outputSettings) + if err := AddOutput(outputSettings); err != nil { + t.Errorf("Can't add output: %v", err) + } assert.Equal(t, 1, len(Instance().outputs), "one output expected") s := Instance().outputs["my-slack"].CloneSettings() s.Url = modifiedUrl - UpdateOutput(s) + if err := UpdateOutput(s); err != nil { + t.Errorf("Can't update output: %v", err) + } assert.Equal(t, 1, len(Instance().outputs), "one output expected") assert.Equal(t, modifiedUrl, Instance().outputs["my-slack"].(*outputs.SlackOutput).Url, "url is updated") @@ -91,7 +138,9 @@ func TestEditOutput(t *testing.T) { } } func TestListOutput(t *testing.T) { - Instance().cleanInstance() + if len(Instance().outputs) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() AddOutput(outputSettings) @@ -109,7 +158,9 @@ func TestListOutput(t *testing.T) { } func TestAddRoute(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputRoutes) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() AddRoute(inputRoute) @@ -120,18 +171,25 @@ func TestAddRoute(t *testing.T) { } func TestDeleteRoute(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputRoutes) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() AddRoute(inputRoute) - assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") + AddRoute(inputRouteJira) + AddRoute(inputRouteHtml) + assert.Equal(t, 3, len(Instance().inputRoutes), "three route expected") DeleteRoute("my-route") - assert.Equal(t, 0, len(Instance().inputRoutes), "no routes expected") + assert.Equal(t, 2, len(Instance().inputRoutes), "two routes expected") + assert.NotContains(t, Instance().inputRoutes, "my-route") } func TestEditRoute(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputRoutes) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() modifiedTemplate := "vuls-slack" expectedError := errors.New("output badName is not found") @@ -158,13 +216,18 @@ func TestEditRoute(t *testing.T) { } func TestListRoute(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputRoutes) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() + routes := ListRoutes() + assert.Equal(t, 0, len(routes), "no route expected") + AddRoute(inputRoute) assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") - routes := ListRoutes() + routes = ListRoutes() assert.Equal(t, 1, len(routes), "one route expected") @@ -176,7 +239,9 @@ func TestListRoute(t *testing.T) { } func TestAddTemplate(t *testing.T) { - Instance().cleanInstance() + if len(Instance().templates) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() AddTemplate(template) @@ -186,7 +251,9 @@ func TestAddTemplate(t *testing.T) { } func TestAddTemplateFromFile(t *testing.T) { - Instance().cleanInstance() + if len(Instance().templates) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() regoString := `package postee default hello = false @@ -210,21 +277,27 @@ hello { } func TestDeleteTemplate(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputRoutes) > 0 || len(Instance().templates) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() AddTemplate(template) - assert.Equal(t, 1, len(Instance().templates), "one template expected") + AddTemplate(templateSlack) + assert.Equal(t, 2, len(Instance().templates), "two template expected") AddRoute(&routes.InputRoute{Name: "my-route", Template: "legacy"}) assert.Equal(t, "legacy", Instance().inputRoutes["my-route"].Template, "one template expected") DeleteTemplate("legacy") - assert.Equal(t, 0, len(Instance().templates), "no templates expected") + assert.Equal(t, 1, len(Instance().templates), "one templates expected") + assert.NotContains(t, Instance().templates, "legacy") assert.Equal(t, "", Instance().inputRoutes["my-route"].Template, "no template expected") } func TestEditTemplate(t *testing.T) { - Instance().cleanInstance() + if len(Instance().templates) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() expectedError := errors.New("template badName is not found") @@ -252,7 +325,9 @@ func TestEditTemplate(t *testing.T) { } func TestListTemplate(t *testing.T) { - Instance().cleanInstance() + if len(Instance().templates) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() AddTemplate(template) @@ -268,7 +343,9 @@ func TestListTemplate(t *testing.T) { } func TestSetInputCallbackFunc(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputCallBacks) > 0 { + Instance().cleanInstance() + } defer Instance().cleanInstance() inputCallbackFunc := InputCallbackFunc(func(inputMessage map[string]interface{}) bool { return false }) @@ -281,10 +358,12 @@ func TestSetInputCallbackFunc(t *testing.T) { } func TestConfigFuncs(t *testing.T) { - Instance().cleanInstance() + if len(Instance().inputRoutes) > 0 || len(Instance().outputs) > 0 || len(Instance().templates) > 0 { + Instance().cleanInstance() + } tests := []struct { funcName string - cfgPath string + f func() error tenantName string clearCfg bool templateName string @@ -293,18 +372,22 @@ func TestConfigFuncs(t *testing.T) { dbPath string psqlUrl string }{ - {"WithDefaultConfig", "", "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, - {"WithFileConfig", "test/cfg.yaml", "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, - {"WithDefaultConfigAndDbPath", "", "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, - {"WithFileConfigAndDbPath", "test/cfg.yaml", "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, - {"WithNewConfig", "", "", true, "", "", "", "./webhooks.db", ""}, - {"WithNewConfigAndDbPath", "test/cfg.yaml", "", true, "", "", "", "./webhooks.db", ""}, - {"WithPostgresParams", "", "ParamsName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, - {"WithPostgresUrl", "", "ParamsName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, + {"WithDefaultConfig", withDefaultConfigTest, "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, + {"WithFileConfig", withFileConfigTest, "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, + {"WithDefaultConfigAndDbPath", withDefaultConfigAndDbPathTest, "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, + {"WithFileConfigAndDbPath", withFileConfigAndDbPathTest, "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, + {"WithNewConfig", withNewConfigTest, "", true, "", "", "", "./webhooks.db", ""}, + {"WithNewConfigAndDbPath", withNewConfigAndDbPathTest, "", true, "", "", "", "./webhooks.db", ""}, + {"WithPostgresParams", withPostgresParamsTest, "ParamsTenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, + {"WithPostgresUrl", withPostgresUrlTest, "tenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, } for _, test := range tests { t.Run("test "+test.funcName, func(t *testing.T) { - defer Instance().cleanInstance() + defer func() { + Instance().cleanInstance() + dbservice.Db = nil + }() + savedPathToDb := os.Getenv("PATH_TO_DB") savedPostgresUrl := os.Getenv("POSTGRES_URL") os.Setenv("PATH_TO_DB", test.dbPath) @@ -314,7 +397,7 @@ func TestConfigFuncs(t *testing.T) { os.Setenv("POSTGRES_URL", savedPostgresUrl) }() - err := runFunc(test.funcName, test.cfgPath, test.dbPath, test.tenantName, test.psqlUrl) + err := test.f() if err != nil { t.Errorf("unexpected error: %v", err) } @@ -347,7 +430,12 @@ func TestConfigFuncs(t *testing.T) { } } -var cfg = `Name: tenant +var ( + cfgPath = "test/cfg.yaml" + tenantName = "tenantName" + dbPath = "test/cfg.yaml" + + cfg = `Name: tenant routes: - name: route1 @@ -367,68 +455,95 @@ outputs: type: slack enable: true url: https://hooks.slack.com/services/ABCDF/1234/TTT` +) -func runFunc(funcName, cfgPath, dbPath, tenantName, psqlUrl string) error { - switch funcName { - case "WithFileConfig": - createTestCfg(cfgPath) - WithFileConfig(cfgPath) - defer func() { - os.Remove(defaultDbPath) - os.RemoveAll(filepath.Dir(cfgPath)) - }() - return nil - case "WithDefaultConfig": - createTestCfg(defaultConfigPath) - WithDefaultConfig() - defer func() { - os.Remove(defaultDbPath) - os.RemoveAll(filepath.Dir(defaultConfigPath)) - }() - return nil - case "WithNewConfig": - WithNewConfig(tenantName) +var withDefaultConfigTest = func() error { + if err := createTestCfg(defaultConfigPath); err != nil { + return err + } + if err := WithDefaultConfig(); err != nil { + return err + } + defer func() { os.Remove(defaultDbPath) - return nil - case "WithNewConfigAndDbPath": - WithNewConfigAndDbPath(tenantName, dbPath) + os.RemoveAll(filepath.Dir(defaultConfigPath)) + }() + return nil +} + +var withFileConfigTest = func() error { + if err := createTestCfg(cfgPath); err != nil { + return err + } + if err := WithFileConfig(cfgPath); err != nil { + return err + } + defer func() { os.Remove(defaultDbPath) - return nil - case "WithFileConfigAndDbPath": - createTestCfg(cfgPath) - WithFileConfigAndDbPath(cfgPath, dbPath) - defer func() { - os.RemoveAll(filepath.Dir(dbPath)) - os.RemoveAll(filepath.Dir(cfgPath)) - }() - return nil - case "WithDefaultConfigAndDbPath": - createTestCfg(defaultConfigPath) - WithDefaultConfigAndDbPath(dbPath) - defer func() { - os.RemoveAll(filepath.Dir(dbPath)) - os.RemoveAll(filepath.Dir(defaultConfigPath)) - }() - return nil - case "WithPostgresParams": - savedInitPostgresDb := postgresdb.InitPostgresDb - postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } - defer func() { - postgresdb.InitPostgresDb = savedInitPostgresDb - }() - WithPostgresParams(tenantName, "ParamsDbName", "ParamsDbHostName", "ParamsPort", "ParamsUser", "ParamsPassword", "ParamsSslMode") - return nil - case "WithPostgresUrl": - savedInitPostgresDb := postgresdb.InitPostgresDb - postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } - defer func() { - postgresdb.InitPostgresDb = savedInitPostgresDb - }() - WithPostgresUrl(tenantName, psqlUrl) - return nil - } - - return errors.New("don't have func: " + funcName) + os.RemoveAll(filepath.Dir(cfgPath)) + }() + return nil +} + +var withNewConfigTest = func() error { + WithNewConfig(tenantName) + os.Remove(defaultDbPath) + return nil +} + +var withNewConfigAndDbPathTest = func() error { + WithNewConfigAndDbPath(tenantName, dbPath) + os.Remove(dbPath) + return nil +} + +var withFileConfigAndDbPathTest = func() error { + if err := createTestCfg(cfgPath); err != nil { + return err + } + if err := WithFileConfigAndDbPath(cfgPath, dbPath); err != nil { + return err + } + defer func() { + os.RemoveAll(filepath.Dir(dbPath)) + os.RemoveAll(filepath.Dir(cfgPath)) + }() + return nil +} + +var withDefaultConfigAndDbPathTest = func() error { + if err := createTestCfg(defaultConfigPath); err != nil { + return err + } + if err := WithDefaultConfigAndDbPath(dbPath); err != nil { + return err + } + defer func() { + os.RemoveAll(filepath.Dir(dbPath)) + os.RemoveAll(filepath.Dir(defaultConfigPath)) + }() + return nil +} + +var withPostgresParamsTest = func() error { + savedInitPostgresDb := postgresdb.InitPostgresDb + postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } + defer func() { + postgresdb.InitPostgresDb = savedInitPostgresDb + }() + WithPostgresParams("ParamsTenantName", "ParamsDbName", "ParamsDbHostName", "ParamsPort", "ParamsUser", "ParamsPassword", "ParamsSslMode") + return nil +} + +var withPostgresUrlTest = func() error { + savedInitPostgresDb := postgresdb.InitPostgresDb + postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } + defer func() { + postgresdb.InitPostgresDb = savedInitPostgresDb + }() + psqlUrl := "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode" + WithPostgresUrl(tenantName, psqlUrl) + return nil } func createTestCfg(cfgPath string) error { diff --git a/utils/cert.go b/utils/cert.go index d1953723..909841c6 100644 --- a/utils/cert.go +++ b/utils/cert.go @@ -94,7 +94,9 @@ func generateCertificate(hosts []string, keyFile string, certFile string) error return err } - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + return err + } certOut.Close() keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) @@ -102,7 +104,9 @@ func generateCertificate(hosts []string, keyFile string, certFile string) error return err } - pem.Encode(keyOut, pemBlockForKey(priv)) + if err := pem.Encode(keyOut, pemBlockForKey(priv)); err != nil { + return err + } keyOut.Close() return nil } diff --git a/utils/utils.go b/utils/utils.go index d67ca730..3e50e47a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -30,7 +30,7 @@ func InitDebug() { } func Debug(format string, v ...interface{}) { - if dbg != false { + if !dbg { log.Printf(format, v...) } } diff --git a/webserver/webserver.go b/webserver/webserver.go index 447feba4..1b422cd6 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -58,7 +58,7 @@ func (ctx *WebServer) Start(host, tlshost string) { certPem := filepath.Join(rootDir, "cert.pem") keyPem := filepath.Join(rootDir, "key.pem") - if ok := utils.PathExists(keyPem); ok != true { + if ok := utils.PathExists(keyPem); !ok { utils.GenerateCertificate(keyPem, certPem) } @@ -122,12 +122,18 @@ func (ctx *WebServer) writeResponse(w http.ResponseWriter, httpStatus int, v int w.WriteHeader(httpStatus) if v != nil { result, _ := json.Marshal(v) - w.Write(result) + _, err := w.Write(result) + if err != nil { + log.Printf("Can't write response: %v", err) + } } } func (ctx *WebServer) writeResponseError(w http.ResponseWriter, httpError int, err error) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(httpError) - json.NewEncoder(w).Encode(err) + err = json.NewEncoder(w).Encode(err) + if err != nil { + log.Printf("Can't json Encode error: %v", err) + } } From 401146648119cd62b0da753e9d14972c24b9691e Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Fri, 17 Dec 2021 11:54:05 +0600 Subject: [PATCH 38/61] Refactor: code review notes are corrected --- dbservice/boltdb/checker_test.go | 41 +++++++++++++++++++++++++++--- dbservice/boltdb/dbparam_test.go | 14 +++++----- dbservice/boltdb/dbservice_test.go | 6 ++--- outputs/jira.go | 2 +- router/api_test.go | 5 ++-- utils/utils.go | 2 +- 6 files changed, 52 insertions(+), 18 deletions(-) diff --git a/dbservice/boltdb/checker_test.go b/dbservice/boltdb/checker_test.go index a23cb33d..c1333081 100644 --- a/dbservice/boltdb/checker_test.go +++ b/dbservice/boltdb/checker_test.go @@ -1,6 +1,7 @@ package boltdb import ( + "bytes" "os" "testing" "time" @@ -152,24 +153,56 @@ func TestDbDelete(t *testing.T) { value := []byte("value") bucket := "b" + key1 := []byte("key1") + value1 := []byte("value1") + bucket1 := "b1" + err = dbInsert(db, bucket, key, value) if err != nil { t.Errorf("Can't insert in db: %v", err) } - err = dbDelete(db, bucket, [][]byte{key}) + err = dbInsert(db, bucket1, key1, value1) + if err != nil { + t.Errorf("Can't insert in db: %v", err) + } + + selectValue, err := dbSelect(db, bucket, string(key)) if err != nil { t.Errorf("Can't delete from db: %v", err) } + if !bytes.Equal(value, selectValue) { + t.Errorf("bad insert/select, expected: %s, got: %s", value, selectValue) + } + + selectValue1, err := dbSelect(db, bucket1, string(key1)) + if err != nil { + t.Errorf("Can't delete from db: %v", err) + } + if !bytes.Equal(value1, selectValue1) { + t.Errorf("bad insert/select, expected: %s, got: %s", value1, selectValue1) + } + err = dbDelete(db, bucket, [][]byte{key}) if err != nil { t.Errorf("Can't delete from db: %v", err) } - bucket = "" - err = dbInsert(db, bucket, key, value) + selectValueAfterDel, err := dbSelect(db, bucket, string(key)) if err != nil { - t.Errorf("Can't insert in db: %v", err) + t.Errorf("Can't delete from db: %v", err) + } + if len(selectValueAfterDel) > 0 { + t.Errorf("bad delete/select, selectValueAfterDel= %s", selectValueAfterDel) } + + selectValue1AfterDel, err := dbSelect(db, bucket1, string(key1)) + if err != nil { + t.Errorf("Can't delete from db: %v", err) + } + if !bytes.Equal(value1, selectValue1AfterDel) { + t.Errorf("bad insert/select, expected: %s, got: %s", value1, selectValue1AfterDel) + } + } func TestWithoutAccessToDb(t *testing.T) { diff --git a/dbservice/boltdb/dbparam_test.go b/dbservice/boltdb/dbparam_test.go index 92e8236d..eff07252 100644 --- a/dbservice/boltdb/dbparam_test.go +++ b/dbservice/boltdb/dbparam_test.go @@ -1,6 +1,7 @@ package boltdb import ( + "errors" "os" "path/filepath" "strings" @@ -17,11 +18,12 @@ func TestSetNewDbPathFromEnv(t *testing.T) { pathToDb string changePermission bool expectedDBPath string + expectedErr error }{ - {"Empty pathToDb", "", false, defaultDbPath}, - {"Permission denied to create directory(default DbPath is used)", "/database/database.db", false, defaultDbPath}, - {"New DbPath", "./base/base.db", false, "./base/base.db"}, - {"Permission denied to check directory(default DbPath is used)", "webhook/database/webhooks.db", true, defaultDbPath}, + {"Empty pathToDb", "", false, defaultDbPath, nil}, + {"Permission denied to create directory(default DbPath is used)", "/database/database.db", false, defaultDbPath, errors.New("mkdir /database: permission denied")}, + {"New DbPath", "./base/base.db", false, "./base/base.db", nil}, + {"Permission denied to check directory(default DbPath is used)", "webhook/database/webhooks.db", true, defaultDbPath, errors.New("stat webhook/database/webhooks.db: permission denied")}, } for _, test := range tests { @@ -36,8 +38,8 @@ func TestSetNewDbPathFromEnv(t *testing.T) { t.Errorf("Can't change the mode dir in %s: %s", baseDir, err) } } - if err := db.SetNewDbPath(test.pathToDb); err != nil { - t.Errorf("Can't set new dbPath: %v", err) + if err := db.SetNewDbPath(test.pathToDb); err != nil && errors.Is(err, test.expectedErr) { + t.Errorf("unexpected error setNewDbPath, expected: %v, got: %v", test.expectedErr, err) } defer os.RemoveAll(baseDir) defer db.ChangeDbPath(dbPathOld) diff --git a/dbservice/boltdb/dbservice_test.go b/dbservice/boltdb/dbservice_test.go index 4e51b212..fab3ed08 100644 --- a/dbservice/boltdb/dbservice_test.go +++ b/dbservice/boltdb/dbservice_test.go @@ -135,7 +135,7 @@ func TestInitError(t *testing.T) { t.Errorf("Scan shouldn't be marked as new\n") } - if err != initErr { + if !errors.Is(err, initErr) { t.Errorf("Unexpected error: expected %s, got %s \n", initErr, err) } @@ -163,7 +163,7 @@ func TestSelectError(t *testing.T) { t.Errorf("Scan shouldn't be marked as new\n") } - if err != selectErr { + if !errors.Is(err, selectErr) { t.Errorf("Unexpected error: expected %s, got %s \n", selectErr, err) } @@ -210,7 +210,7 @@ func testBucketInsert(t *testing.T, testBucket string) { t.Errorf("Scan shouldn't be marked as new\n") } - if err != insertErr { + if !errors.Is(err, insertErr) { t.Errorf("Unexpected error: expected %s, got %s \n", insertErr, err) } } diff --git a/outputs/jira.go b/outputs/jira.go index 0fc1fcf2..03607a56 100644 --- a/outputs/jira.go +++ b/outputs/jira.go @@ -17,7 +17,7 @@ import ( "os" "strings" - "github.com/aquasecurity/go-jira" + jira "github.com/aquasecurity/go-jira" ) type JiraAPI struct { diff --git a/router/api_test.go b/router/api_test.go index d2ed18d8..62f25a98 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -10,9 +10,9 @@ import ( "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/dbservice" - "github.com/aquasecurity/postee/dbservice/boltdb" + "github.com/aquasecurity/postee/dbservice/boltdb" //nolint - used to get db type in TestConfigFuncs "github.com/aquasecurity/postee/dbservice/postgresdb" - "github.com/aquasecurity/postee/outputs" + "github.com/aquasecurity/postee/outputs" //nolint - used to get Output type in TestEditOutput "github.com/aquasecurity/postee/routes" "github.com/stretchr/testify/assert" ) @@ -422,7 +422,6 @@ func TestConfigFuncs(t *testing.T) { assert.Equal(t, test.psqlUrl, postgresDb.ConnectUrl, "url configured") assert.Equal(t, test.tenantName, postgresDb.TenantName, "tenantName configured") } - if boltDb, ok := dbservice.Db.(*boltdb.BoltDb); ok { assert.Equal(t, test.dbPath, boltDb.DbPath, "dbPath configured") } diff --git a/utils/utils.go b/utils/utils.go index 3e50e47a..1d5dcb33 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -30,7 +30,7 @@ func InitDebug() { } func Debug(format string, v ...interface{}) { - if !dbg { + if dbg { log.Printf(format, v...) } } From b3877b154a38290942a728ab510402274a1798c2 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Fri, 17 Dec 2021 16:21:31 +0600 Subject: [PATCH 39/61] refactor: changed func load --- router/api.go | 20 ++++++----- router/loads_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++ router/router.go | 70 +++++++++++++++++++++++-------------- 3 files changed, 138 insertions(+), 34 deletions(-) diff --git a/router/api.go b/router/api.go index 2cb80cd3..0e447d9c 100644 --- a/router/api.go +++ b/router/api.go @@ -30,7 +30,8 @@ func WithDefaultConfig() error { } func WithFileConfig(cfgPath string) error { Instance().Terminate() - dbservice.ConfigureDb(defaultDbPath, "", "") + os.Setenv("POSTGRES_URL", "") + os.Setenv("PATH_TO_DB", defaultDbPath) return Instance().ApplyFileCfg(cfgPath, true) } @@ -43,7 +44,7 @@ func WithNewConfig(tenantName string) { //tenant name //initialize instance with custom db location func WithNewConfigAndDbPath(tenantName, dbPath string) { //tenant name Instance().Terminate() - dbservice.ConfigureDb(defaultDbPath, "", "") + dbservice.ConfigureDb(dbPath, "", "") Instance().initCfg(true) } @@ -53,23 +54,24 @@ func WithDefaultConfigAndDbPath(dbPath string) error { func WithFileConfigAndDbPath(cfgPath, dbPath string) error { Instance().Terminate() - dbservice.ConfigureDb(dbPath, "", "") + os.Setenv("POSTGRES_URL", "") + os.Setenv("PATH_TO_DB", dbPath) return Instance().ApplyFileCfg(cfgPath, true) } func WithPostgresParams(tenantName, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) { postgresUrl := buildPostgresUrl(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode) Instance().Terminate() - dbservice.ConfigureDb("", postgresUrl, tenantName) - Instance().initCfg(true) - Instance().load(true) + os.Setenv("POSTGRES_URL", postgresUrl) + os.Setenv("PATH_TO_DB", "") + Instance().ApplyPostgresCfg(tenantName, true) } func WithPostgresUrl(tenantName, postgresUrl string) { Instance().Terminate() - dbservice.ConfigureDb("", postgresUrl, tenantName) - Instance().initCfg(true) - Instance().load(true) + os.Setenv("POSTGRES_URL", postgresUrl) + os.Setenv("PATH_TO_DB", "") + Instance().ApplyPostgresCfg(tenantName, true) } func AquaServerUrl(aquaServerUrl string) { //optional diff --git a/router/loads_test.go b/router/loads_test.go index 6b4b598e..23d66985 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -1,6 +1,7 @@ package router import ( + "encoding/json" "fmt" "io/ioutil" "log" @@ -9,7 +10,9 @@ import ( "time" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/dbservice" "github.com/aquasecurity/postee/dbservice/boltdb" + "github.com/aquasecurity/postee/dbservice/postgresdb" "github.com/aquasecurity/postee/msgservice" "github.com/aquasecurity/postee/outputs" "github.com/aquasecurity/postee/routes" @@ -223,3 +226,82 @@ func TestServiceGetters(t *testing.T) { t.Error("getScanService() doesn't return an instance of scanservice.ScanService") } } + +func TestApplyPostgresCfg(t *testing.T) { + testTenantSerrings := data.TenantSettings{ + Name: "TenantName", + AquaServer: "https://demolab.aquasec.com", + DBMaxSize: 13, + DBRemoveOldData: 7, + InputRoutes: []routes.InputRoute{ + { + Name: "route", + Outputs: []string{"slack", "teams"}, + Template: "legacy", + }, + }, + Outputs: []data.OutputSettings{ + { + Name: "slack", + Type: "slack", + Url: "https://hooks.slack.com/services/TAAAA/BBB/", + Enable: true, + }, + { + Name: "teams", + Type: "teams", + Url: "https://outlook.office.com/webhook/", + Enable: true, + }, + }, + Templates: []data.Template{ + { + Name: "legacy", + LegacyScanRenderer: "html", + }, + }, + } + wrap := ctxWrapper{} + wrap.init() + demoCtx := wrap.instance + + savedDb := dbservice.Db + dbservice.Db = &postgresdb.PostgresDb{} + + savedPostgresUrl := os.Getenv("POSTGRES_URL") + os.Setenv("POSTGRES_URL", "postgres://User:Password@DbHostName:Port/DbName?sslmode=SslMode") + + savedGetCfgCacheSource := postgresdb.GetCfgCacheSource + postgresdb.GetCfgCacheSource = func(postgresDb *postgresdb.PostgresDb) (string, error) { + cfg, _ := json.Marshal(testTenantSerrings) + return string(cfg), nil + } + + savedUpdateCfgCacheSource := postgresdb.UpdateCfgCacheSource + postgresdb.UpdateCfgCacheSource = func(postgresDb *postgresdb.PostgresDb, cfgfile string) error { return nil } + + savedInitPostgresDb := postgresdb.InitPostgresDb + postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } + defer func() { + wrap.teardown() + dbservice.Db = savedDb + os.Setenv("POSTGRES_URL", savedPostgresUrl) + postgresdb.GetCfgCacheSource = savedGetCfgCacheSource + postgresdb.InitPostgresDb = savedInitPostgresDb + postgresdb.UpdateCfgCacheSource = savedUpdateCfgCacheSource + }() + + err := demoCtx.ApplyPostgresCfg("tenantName", false) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expectedOutputsCnt := 2 + if len(demoCtx.outputs) != expectedOutputsCnt { + t.Errorf("There are stopped outputs\nWaited: %d\nResult: %d", expectedOutputsCnt, len(Instance().outputs)) + } + + if testTenantSerrings.Outputs[0].Name != Instance().databaseCfgCacheSource.Outputs[0].Name { + t.Errorf("Output names are not equals, expected: %s, got: %s", testTenantSerrings.Outputs[0].Name, Instance().databaseCfgCacheSource.Outputs[0].Name) + } +} diff --git a/router/router.go b/router/router.go index db9aca70..d4dfee6e 100644 --- a/router/router.go +++ b/router/router.go @@ -103,9 +103,50 @@ func (ctx *Router) ApplyFileCfg(cfgfile string, synchronous bool) error { ctx.cfgfile = cfgfile + tenant, err := Parsev2cfg(ctx.cfgfile) + if err != nil { + return err + } + + postgresUrl := os.Getenv("POSTGRES_URL") + pathToDb := os.Getenv("PATH_TO_DB") + + if err = dbservice.ConfigureDb(pathToDb, postgresUrl, tenant.Name); err != nil { + return err + } + + ctx.initCfg(synchronous) + + err = ctx.initTenantSettings(tenant) + if err != nil { + return err + } + + if !ctx.synchronous { + go ctx.listen() + } + + return nil +} + +func (ctx *Router) ApplyPostgresCfg(tenantName string, synchronous bool) error { + log.Printf("Starting Router....") + + postgresUrl := os.Getenv("POSTGRES_URL") + pathToDb := os.Getenv("PATH_TO_DB") + + if err := dbservice.ConfigureDb(pathToDb, postgresUrl, tenantName); err != nil { + return err + } + ctx.initCfg(synchronous) - err := ctx.load(false) + tenant, err := ctx.loadCfgCacheSourceFromPostgres() + if err != nil { + return err + } + + err = ctx.initTenantSettings(tenant) if err != nil { return err } @@ -280,35 +321,13 @@ func (ctx *Router) setAquaServerUrl(url string) { } } -func (ctx *Router) load(loadCfgFromPostgres bool) error { +func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { ctx.mutexScan.Lock() defer ctx.mutexScan.Unlock() log.Printf("Loading alerts configuration file %s ....\n", ctx.cfgfile) - var tenant = &data.TenantSettings{} - var err error - if loadCfgFromPostgres { - tenant, err = ctx.loadCfgCacheSourceFromPostgres() - if err != nil { - return err - } - } else { - tenant, err = Parsev2cfg(ctx.cfgfile) - if err != nil { - return err - } - ctx.databaseCfgCacheSource = tenant - } - ctx.setAquaServerUrl(tenant.AquaServer) - postgresUrl := os.Getenv("POSTGRES_URL") - pathToDb := os.Getenv("PATH_TO_DB") - - if err = dbservice.ConfigureDb(pathToDb, postgresUrl, tenant.Name); err != nil { - return err - } - dbparam.DbSizeLimit = tenant.DBMaxSize actualDbTestInterval := tenant.DBTestInterval @@ -344,7 +363,7 @@ func (ctx *Router) load(loadCfgFromPostgres bool) error { for _, settings := range tenant.Outputs { utils.Debug("%#v\n", anonymizeSettings(&settings)) - err = ctx.addOutput(&settings) + err := ctx.addOutput(&settings) if err != nil { log.Printf("Can not initialize output %s: %v \n", settings.Name, err) @@ -353,6 +372,7 @@ func (ctx *Router) load(loadCfgFromPostgres bool) error { } } + ctx.databaseCfgCacheSource = tenant return nil } func (ctx *Router) setInputCallbackFunc(routeName string, callback InputCallbackFunc) { From b726ea1b796636ba1a335375590ccbdd376502aa Mon Sep 17 00:00:00 2001 From: Andrey Levchenko Date: Tue, 21 Dec 2021 13:36:40 +0600 Subject: [PATCH 40/61] misc api cleanup --- main.go | 5 +++- router/api.go | 22 +++++++----------- router/loads_test.go | 10 ++++---- router/routehandling_test.go | 8 +++---- router/router.go | 44 ++++++++++++++++-------------------- 5 files changed, 39 insertions(+), 50 deletions(-) diff --git a/main.go b/main.go index ed3523d4..c5df1696 100644 --- a/main.go +++ b/main.go @@ -73,7 +73,10 @@ func main() { cfgfile = os.Getenv("POSTEE_CFG") } - err := router.Instance().ApplyFileCfg(cfgfile, false) + postgresUrl := os.Getenv("POSTGRES_URL") + pathToDb := os.Getenv("PATH_TO_DB") + + err := router.Instance().ApplyFileCfg(cfgfile, postgresUrl, pathToDb, false) if err != nil { log.Printf("Can't start alert manager %v", err) diff --git a/router/api.go b/router/api.go index 0e447d9c..707bde89 100644 --- a/router/api.go +++ b/router/api.go @@ -30,22 +30,22 @@ func WithDefaultConfig() error { } func WithFileConfig(cfgPath string) error { Instance().Terminate() - os.Setenv("POSTGRES_URL", "") - os.Setenv("PATH_TO_DB", defaultDbPath) - return Instance().ApplyFileCfg(cfgPath, true) + return Instance().ApplyFileCfg(cfgPath, "", defaultDbPath, true) } func WithNewConfig(tenantName string) { //tenant name Instance().Terminate() dbservice.ConfigureDb(defaultDbPath, "", "") - Instance().initCfg(true) + Instance().cleanInstance() + Instance().cleanChannels(true) } //initialize instance with custom db location func WithNewConfigAndDbPath(tenantName, dbPath string) { //tenant name Instance().Terminate() dbservice.ConfigureDb(dbPath, "", "") - Instance().initCfg(true) + Instance().cleanInstance() + Instance().cleanChannels(true) } func WithDefaultConfigAndDbPath(dbPath string) error { @@ -54,24 +54,18 @@ func WithDefaultConfigAndDbPath(dbPath string) error { func WithFileConfigAndDbPath(cfgPath, dbPath string) error { Instance().Terminate() - os.Setenv("POSTGRES_URL", "") - os.Setenv("PATH_TO_DB", dbPath) - return Instance().ApplyFileCfg(cfgPath, true) + return Instance().ApplyFileCfg(cfgPath, "", defaultDbPath, true) } func WithPostgresParams(tenantName, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) { postgresUrl := buildPostgresUrl(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode) Instance().Terminate() - os.Setenv("POSTGRES_URL", postgresUrl) - os.Setenv("PATH_TO_DB", "") - Instance().ApplyPostgresCfg(tenantName, true) + Instance().ApplyPostgresCfg(tenantName, postgresUrl, true) } func WithPostgresUrl(tenantName, postgresUrl string) { Instance().Terminate() - os.Setenv("POSTGRES_URL", postgresUrl) - os.Setenv("PATH_TO_DB", "") - Instance().ApplyPostgresCfg(tenantName, true) + Instance().ApplyPostgresCfg(tenantName, postgresUrl, true) } func AquaServerUrl(aquaServerUrl string) { //optional diff --git a/router/loads_test.go b/router/loads_test.go index 23d66985..5120ce42 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -151,7 +151,7 @@ func TestLoads(t *testing.T) { defer wrap.teardown() demoCtx := wrap.instance - err := demoCtx.ApplyFileCfg(wrap.cfgPath, false) + err := demoCtx.ApplyFileCfg(wrap.cfgPath, "", "", false) if err != nil { t.Fatal(err) } @@ -193,7 +193,7 @@ func TestReload(t *testing.T) { demoCtx := wrap.instance - errStart := demoCtx.ApplyFileCfg(wrap.cfgPath, false) + errStart := demoCtx.ApplyFileCfg(wrap.cfgPath, "", "", false) if errStart != nil { t.Fatal(errStart) } @@ -268,8 +268,7 @@ func TestApplyPostgresCfg(t *testing.T) { savedDb := dbservice.Db dbservice.Db = &postgresdb.PostgresDb{} - savedPostgresUrl := os.Getenv("POSTGRES_URL") - os.Setenv("POSTGRES_URL", "postgres://User:Password@DbHostName:Port/DbName?sslmode=SslMode") + postgresUrl := "postgres://User:Password@DbHostName:Port/DbName?sslmode=SslMode" savedGetCfgCacheSource := postgresdb.GetCfgCacheSource postgresdb.GetCfgCacheSource = func(postgresDb *postgresdb.PostgresDb) (string, error) { @@ -285,13 +284,12 @@ func TestApplyPostgresCfg(t *testing.T) { defer func() { wrap.teardown() dbservice.Db = savedDb - os.Setenv("POSTGRES_URL", savedPostgresUrl) postgresdb.GetCfgCacheSource = savedGetCfgCacheSource postgresdb.InitPostgresDb = savedInitPostgresDb postgresdb.UpdateCfgCacheSource = savedUpdateCfgCacheSource }() - err := demoCtx.ApplyPostgresCfg("tenantName", false) + err := demoCtx.ApplyPostgresCfg("tenantName", postgresUrl, false) if err != nil { t.Errorf("Unexpected error: %v", err) } diff --git a/router/routehandling_test.go b/router/routehandling_test.go index 046733c4..9d8bce99 100644 --- a/router/routehandling_test.go +++ b/router/routehandling_test.go @@ -122,7 +122,7 @@ func runTestRouteHandlingCase(t *testing.T, caseDesc string, cfgPath string, exp defer wrap.teardown() - err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, "", "", false) if err != nil { t.Fatalf("[%s] Unexpected error %v", caseDesc, err) @@ -183,7 +183,7 @@ func TestInvalidRouteName(t *testing.T) { defer wrap.teardown() - err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, "", "", false) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -218,7 +218,7 @@ func TestSend(t *testing.T) { defer wrap.teardown() - err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, "", "", false) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -280,7 +280,7 @@ func TestCallBack(t *testing.T) { defer wrap.teardown() - err = wrap.instance.ApplyFileCfg(wrap.cfgPath, false) + err = wrap.instance.ApplyFileCfg(wrap.cfgPath, "", "", false) if err != nil { t.Fatalf("Unexpected error %v", err) } diff --git a/router/router.go b/router/router.go index d4dfee6e..755ce52a 100644 --- a/router/router.go +++ b/router/router.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "log" "net/http" - "os" "path" "strings" "sync" @@ -75,16 +74,15 @@ func Instance() *Router { } func (ctx *Router) ReloadConfig() { ctx.Terminate() - err := ctx.ApplyFileCfg(ctx.cfgfile, ctx.synchronous) + + err := ctx.ApplyFileCfg(ctx.cfgfile, "", ctx.cfgfile, ctx.synchronous) if err != nil { log.Printf("Unable to start router: %s", err) } } -func (ctx *Router) initCfg(synchronous bool) { - ctx.cleanInstance() - +func (ctx *Router) cleanChannels(synchronous bool) { ctx.synchronous = synchronous if !ctx.synchronous { @@ -98,7 +96,7 @@ func (ctx *Router) initCfg(synchronous bool) { } } -func (ctx *Router) ApplyFileCfg(cfgfile string, synchronous bool) error { +func (ctx *Router) ApplyFileCfg(cfgfile, postgresUrl, pathToDb string, synchronous bool) error { log.Printf("Starting Router....") ctx.cfgfile = cfgfile @@ -108,45 +106,40 @@ func (ctx *Router) ApplyFileCfg(cfgfile string, synchronous bool) error { return err } - postgresUrl := os.Getenv("POSTGRES_URL") - pathToDb := os.Getenv("PATH_TO_DB") - if err = dbservice.ConfigureDb(pathToDb, postgresUrl, tenant.Name); err != nil { return err } - ctx.initCfg(synchronous) - - err = ctx.initTenantSettings(tenant) + err = ctx.applyTenantCfg(tenant, synchronous) if err != nil { return err } - - if !ctx.synchronous { - go ctx.listen() - } - return nil } -func (ctx *Router) ApplyPostgresCfg(tenantName string, synchronous bool) error { +func (ctx *Router) ApplyPostgresCfg(tenantName, postgresUrl string, synchronous bool) error { log.Printf("Starting Router....") - postgresUrl := os.Getenv("POSTGRES_URL") - pathToDb := os.Getenv("PATH_TO_DB") - - if err := dbservice.ConfigureDb(pathToDb, postgresUrl, tenantName); err != nil { + if err := dbservice.ConfigureDb("", postgresUrl, tenantName); err != nil { return err } - ctx.initCfg(synchronous) - tenant, err := ctx.loadCfgCacheSourceFromPostgres() if err != nil { return err } + err = ctx.applyTenantCfg(tenant, synchronous) + if err != nil { + return err + } + return nil - err = ctx.initTenantSettings(tenant) +} +func (ctx *Router) applyTenantCfg(tenant *data.TenantSettings, synchronous bool) error { + ctx.cleanInstance() + ctx.cleanChannels(synchronous) + + err := ctx.initTenantSettings(tenant) if err != nil { return err } @@ -156,6 +149,7 @@ func (ctx *Router) ApplyPostgresCfg(tenantName string, synchronous bool) error { } return nil + } func (ctx *Router) Terminate() { From 77e762c4345532c6208f72f6b1757367924fcee4 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Tue, 21 Dec 2021 16:07:06 +0600 Subject: [PATCH 41/61] test: fixed api tests --- router/api.go | 2 +- router/api_integration_test.go | 2 ++ router/api_test.go | 29 ++++++++++------------------- router/router.go | 8 +++++++- 4 files changed, 20 insertions(+), 21 deletions(-) diff --git a/router/api.go b/router/api.go index d48631a5..73b901be 100644 --- a/router/api.go +++ b/router/api.go @@ -54,7 +54,7 @@ func WithDefaultConfigAndDbPath(dbPath string) error { func WithFileConfigAndDbPath(cfgPath, dbPath string) error { Instance().Terminate() - return Instance().ApplyFileCfg(cfgPath, "", defaultDbPath, true) + return Instance().ApplyFileCfg(cfgPath, "", dbPath, true) } func WithPostgresParams(tenantName, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) { diff --git a/router/api_integration_test.go b/router/api_integration_test.go index 3be44e2d..d2f5d858 100644 --- a/router/api_integration_test.go +++ b/router/api_integration_test.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" "github.com/aquasecurity/postee/data" @@ -77,6 +78,7 @@ func TestAudit(t *testing.T) { Template: "audit-json-template", }) router.Send([]byte(msg)) + defer os.Remove("webhooks.db") got := <-received assert.Equal(t, string(got), want, "unexpected response") } diff --git a/router/api_test.go b/router/api_test.go index 62f25a98..b3d4869f 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -325,7 +325,7 @@ func TestEditTemplate(t *testing.T) { } func TestListTemplate(t *testing.T) { - if len(Instance().templates) > 0 { + if len(Instance().templates) > 0 || Instance().templates == nil { Instance().cleanInstance() } defer Instance().cleanInstance() @@ -343,7 +343,7 @@ func TestListTemplate(t *testing.T) { } func TestSetInputCallbackFunc(t *testing.T) { - if len(Instance().inputCallBacks) > 0 { + if len(Instance().inputCallBacks) > 0 || Instance().inputCallBacks == nil { Instance().cleanInstance() } defer Instance().cleanInstance() @@ -372,12 +372,12 @@ func TestConfigFuncs(t *testing.T) { dbPath string psqlUrl string }{ - {"WithDefaultConfig", withDefaultConfigTest, "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, - {"WithFileConfig", withFileConfigTest, "", false, "raw", "my-slack", "route1", "/server/database/webhooks.db", ""}, - {"WithDefaultConfigAndDbPath", withDefaultConfigAndDbPathTest, "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, - {"WithFileConfigAndDbPath", withFileConfigAndDbPathTest, "", false, "raw", "my-slack", "route1", "database/webhooks.db", ""}, + {"WithDefaultConfig", withDefaultConfigTest, "", false, "raw", "my-slack", "route1", "./webhooks.db", ""}, + {"WithFileConfig", withFileConfigTest, "", false, "raw", "my-slack", "route1", "./webhooks.db", ""}, + {"WithDefaultConfigAndDbPath", withDefaultConfigAndDbPathTest, "", false, "raw", "my-slack", "route1", "test/webhooks.db", ""}, + {"WithFileConfigAndDbPath", withFileConfigAndDbPathTest, "", false, "raw", "my-slack", "route1", "test/webhooks.db", ""}, {"WithNewConfig", withNewConfigTest, "", true, "", "", "", "./webhooks.db", ""}, - {"WithNewConfigAndDbPath", withNewConfigAndDbPathTest, "", true, "", "", "", "./webhooks.db", ""}, + {"WithNewConfigAndDbPath", withNewConfigAndDbPathTest, "", true, "", "", "", "test/webhooks.db", ""}, {"WithPostgresParams", withPostgresParamsTest, "ParamsTenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, {"WithPostgresUrl", withPostgresUrlTest, "tenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, } @@ -388,15 +388,6 @@ func TestConfigFuncs(t *testing.T) { dbservice.Db = nil }() - savedPathToDb := os.Getenv("PATH_TO_DB") - savedPostgresUrl := os.Getenv("POSTGRES_URL") - os.Setenv("PATH_TO_DB", test.dbPath) - os.Setenv("POSTGRES_URL", test.psqlUrl) - defer func() { - os.Setenv("PATH_TO_DB", savedPathToDb) - os.Setenv("POSTGRES_URL", savedPostgresUrl) - }() - err := test.f() if err != nil { t.Errorf("unexpected error: %v", err) @@ -432,7 +423,7 @@ func TestConfigFuncs(t *testing.T) { var ( cfgPath = "test/cfg.yaml" tenantName = "tenantName" - dbPath = "test/cfg.yaml" + dbPath = "test/webhooks.db" cfg = `Name: tenant @@ -486,13 +477,13 @@ var withFileConfigTest = func() error { var withNewConfigTest = func() error { WithNewConfig(tenantName) - os.Remove(defaultDbPath) + defer os.RemoveAll(filepath.Dir(dbPath)) return nil } var withNewConfigAndDbPathTest = func() error { WithNewConfigAndDbPath(tenantName, dbPath) - os.Remove(dbPath) + defer os.RemoveAll(filepath.Dir(dbPath)) return nil } diff --git a/router/router.go b/router/router.go index 755ce52a..9e3d9c3a 100644 --- a/router/router.go +++ b/router/router.go @@ -75,7 +75,13 @@ func Instance() *Router { func (ctx *Router) ReloadConfig() { ctx.Terminate() - err := ctx.ApplyFileCfg(ctx.cfgfile, "", ctx.cfgfile, ctx.synchronous) + tenant, err := Parsev2cfg(ctx.cfgfile) + if err != nil { + log.Printf("Failed to parse cfg file %s", err) + return + } + + err = ctx.applyTenantCfg(tenant, ctx.synchronous) if err != nil { log.Printf("Unable to start router: %s", err) From 46716637ee4904e713210da289b184f4382545da Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Tue, 21 Dec 2021 16:09:07 +0600 Subject: [PATCH 42/61] test: fixed api tests --- router/api_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/api_test.go b/router/api_test.go index b3d4869f..d3979284 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -325,7 +325,7 @@ func TestEditTemplate(t *testing.T) { } func TestListTemplate(t *testing.T) { - if len(Instance().templates) > 0 || Instance().templates == nil { + if len(Instance().templates) > 0 { Instance().cleanInstance() } defer Instance().cleanInstance() From 2e2c32c50ee9fffc402f33453c4178d7a71f66e9 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Fri, 24 Dec 2021 14:52:44 +0600 Subject: [PATCH 43/61] feat(logger): added default logger --- dbservice/boltdb/checker.go | 12 +-- dbservice/postgresdb/actions_test.go | 3 +- dbservice/postgresdb/cfgcachesource_test.go | 3 +- dbservice/postgresdb/checker.go | 12 +-- dbservice/postgresdb/checker_test.go | 5 +- dbservice/postgresdb/dbaggregator_test.go | 3 +- dbservice/postgresdb/dbservice_test.go | 11 ++- dbservice/postgresdb/plgstats_test.go | 5 +- dbservice/postgresdb/sharedcfg_test.go | 7 +- formatting/slackmrkdwnprovider.go | 9 ++- log/logger.go | 27 +++++++ log/stdoutlogger/stdoutlogger.go | 84 +++++++++++++++++++++ main.go | 9 ++- msgservice/msghandling.go | 22 +++--- msgservice/scheduler.go | 10 +-- outputs/email.go | 22 +++--- outputs/jira.go | 46 +++++------ outputs/plugin.go | 4 +- outputs/servicenow.go | 16 ++-- outputs/slack.go | 16 ++-- outputs/splunk.go | 20 ++--- outputs/teams.go | 14 ++-- outputs/webhook.go | 15 ++-- regoservice/eval.go | 6 +- regoservice/jsonformat.go | 6 +- router/parsecfg.go | 6 +- router/router.go | 56 +++++++------- routes/aggrtimeout.go | 7 +- slack/sendtoslack.go | 5 +- utils/utils.go | 8 +- webserver/tenant.go | 6 +- webserver/webserver.go | 30 ++++---- 32 files changed, 306 insertions(+), 199 deletions(-) create mode 100644 log/logger.go create mode 100644 log/stdoutlogger/stdoutlogger.go diff --git a/dbservice/boltdb/checker.go b/dbservice/boltdb/checker.go index 1cfea8bd..6d185aa4 100644 --- a/dbservice/boltdb/checker.go +++ b/dbservice/boltdb/checker.go @@ -2,10 +2,10 @@ package boltdb import ( "bytes" - "log" "time" "github.com/aquasecurity/postee/dbservice/dbparam" + "github.com/aquasecurity/postee/log" bolt "go.etcd.io/bbolt" ) @@ -18,7 +18,7 @@ func (boltDb *BoltDb) CheckSizeLimit() { db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { - log.Println("CheckSizeLimit: Can't open db:", boltDb.DbPath) + log.Logger.Errorf("CheckSizeLimit: Can't open db: %s", boltDb.DbPath) return } defer db.Close() @@ -38,7 +38,7 @@ func (boltDb *BoltDb) CheckSizeLimit() { } return nil }); err != nil { - log.Println("Error a check of db size:", err) + log.Logger.Errorf("Error a check of db size: %v", err) return } } @@ -49,19 +49,19 @@ func (boltDb *BoltDb) CheckExpiredData() { db, err := bolt.Open(boltDb.DbPath, 0666, nil) if err != nil { - log.Println("CheckExpiredData: Can't open db:", boltDb.DbPath) + log.Logger.Errorf("CheckExpiredData: Can't open db: %s", boltDb.DbPath) return } defer db.Close() expired, err := boltDb.getExpired(db) if err != nil { - log.Println("Can't select expired data: ", err) + log.Logger.Errorf("Can't select expired data: %v", err) return } if err := dbDelete(db, dbparam.DbBucketName, expired); err != nil { - log.Println("Can't remove expired data: ", err) + log.Logger.Errorf("Can't remove expired data: %v", err) } } diff --git a/dbservice/postgresdb/actions_test.go b/dbservice/postgresdb/actions_test.go index 485f370b..b7d350a7 100644 --- a/dbservice/postgresdb/actions_test.go +++ b/dbservice/postgresdb/actions_test.go @@ -1,7 +1,6 @@ package postgresdb import ( - "log" "testing" "time" @@ -21,7 +20,7 @@ func TestStoreMessage(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"messagevalue"}).AddRow(currentValueStoreMessage) mock.ExpectQuery("SELECT").WillReturnRows(rows) diff --git a/dbservice/postgresdb/cfgcachesource_test.go b/dbservice/postgresdb/cfgcachesource_test.go index 91298f0b..24e2d063 100644 --- a/dbservice/postgresdb/cfgcachesource_test.go +++ b/dbservice/postgresdb/cfgcachesource_test.go @@ -1,7 +1,6 @@ package postgresdb import ( - "log" "testing" "github.com/jmoiron/sqlx" @@ -14,7 +13,7 @@ func TestUpdateCfgCacheSource(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"cfgFile"}).AddRow(cfgFile) mock.ExpectQuery("SELECT").WillReturnRows(rows) diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go index 049df045..bff01a75 100644 --- a/dbservice/postgresdb/checker.go +++ b/dbservice/postgresdb/checker.go @@ -2,10 +2,10 @@ package postgresdb import ( "fmt" - "log" "time" "github.com/aquasecurity/postee/dbservice/dbparam" + "github.com/aquasecurity/postee/log" ) func (postgresDb *PostgresDb) CheckSizeLimit() { @@ -16,19 +16,19 @@ func (postgresDb *PostgresDb) CheckSizeLimit() { connectUrl := postgresDb.ConnectUrl db, err := psqlConnect(connectUrl) if err != nil { - log.Println("CheckSizeLimit: Can't open db, connectUrl: ", connectUrl) + log.Logger.Errorf("CheckSizeLimit: Can't open db, connectUrl: %s", connectUrl) return } defer db.Close() size := 0 if err = db.Get(&size, fmt.Sprintf("SELECT pg_total_relation_size('%s');", dbparam.DbBucketName)); err != nil { - log.Printf("CheckSizeLimit: Can't get db size") + log.Logger.Error("CheckSizeLimit: Can't get db size") return } if size > dbparam.DbSizeLimit { if err = deleteRowsByTenantName(db, dbparam.DbBucketName, postgresDb.TenantName); err != nil { - log.Printf("CheckSizeLimit: Can't delete tenantName's: %s from table: %s", postgresDb.TenantName, dbparam.DbBucketName) + log.Logger.Errorf("CheckSizeLimit: Can't delete tenantName's: %s from table: %s", postgresDb.TenantName, dbparam.DbBucketName) return } } @@ -38,13 +38,13 @@ func (postgresDb *PostgresDb) CheckExpiredData() { connectUrl := postgresDb.ConnectUrl db, err := psqlConnect(connectUrl) if err != nil { - log.Printf("CheckExpiredData: Can't open postgresDb: %v", err) + log.Logger.Errorf("CheckExpiredData: Can't open postgresDb: %v", err) return } defer db.Close() max := time.Now().UTC() //remove expired records if err = deleteRowsByTenantNameAndTime(db, postgresDb.TenantName, max); err != nil { - log.Printf("CheckExpiredData: Can't delete dates from table:%s, err: %v", dbparam.DbBucketName, err) + log.Logger.Errorf("CheckExpiredData: Can't delete dates from table:%s, err: %v", dbparam.DbBucketName, err) } } diff --git a/dbservice/postgresdb/checker_test.go b/dbservice/postgresdb/checker_test.go index dc86c844..f8baa5b4 100644 --- a/dbservice/postgresdb/checker_test.go +++ b/dbservice/postgresdb/checker_test.go @@ -1,7 +1,6 @@ package postgresdb import ( - "log" "testing" "time" @@ -27,7 +26,7 @@ func TestExpiredDates(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, _, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } return db, err } @@ -68,7 +67,7 @@ func TestSizeLimit(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"size"}).AddRow(test.size) mock.ExpectQuery("SELECT").WillReturnRows(rows) diff --git a/dbservice/postgresdb/dbaggregator_test.go b/dbservice/postgresdb/dbaggregator_test.go index cf1fca30..d72e8917 100644 --- a/dbservice/postgresdb/dbaggregator_test.go +++ b/dbservice/postgresdb/dbaggregator_test.go @@ -1,7 +1,6 @@ package postgresdb import ( - "log" "testing" "github.com/jmoiron/sqlx" @@ -59,7 +58,7 @@ func TestAggregateScans(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"saving"}).AddRow(savingTest) mock.ExpectQuery("SELECT").WillReturnRows(rows) diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 78c47eb3..e53d4585 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -2,7 +2,6 @@ package postgresdb import ( "errors" - "log" "testing" "time" @@ -86,7 +85,7 @@ func TestInitError(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } mock.ExpectExec("CREATE").WillReturnError(initTablesErr) return db, err @@ -127,7 +126,7 @@ func TestDeleteRowsByTenantNameAndTime(t *testing.T) { }() db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } if test.wasError { mock.ExpectExec("DELETE").WillReturnError(test.expectedError) @@ -161,7 +160,7 @@ func TestDeleteRowsByTenantName(t *testing.T) { }() db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } if test.wasError { mock.ExpectExec("DELETE").WillReturnError(test.expectedError) @@ -208,7 +207,7 @@ func TestInsert(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } if test.wasQueryError { mock.ExpectQuery("SELECT").WillReturnError(test.expectedError) @@ -256,7 +255,7 @@ func TestInsertErrorSelect2Rows(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"count"}).AddRow(2) mock.ExpectQuery("SELECT").WillReturnRows(rows) diff --git a/dbservice/postgresdb/plgstats_test.go b/dbservice/postgresdb/plgstats_test.go index b9a0a709..3f86c048 100644 --- a/dbservice/postgresdb/plgstats_test.go +++ b/dbservice/postgresdb/plgstats_test.go @@ -2,7 +2,6 @@ package postgresdb import ( "database/sql" - "log" "testing" "github.com/jmoiron/sqlx" @@ -20,7 +19,7 @@ func TestRegisterPlgnInvctn(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"amount"}).AddRow(receivedKey) mock.ExpectQuery("SELECT").WillReturnRows(rows) @@ -60,7 +59,7 @@ func TestRegisterPlgnInvctnErrors(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } mock.ExpectQuery("SELECT").WillReturnError(test.errIn) return db, err diff --git a/dbservice/postgresdb/sharedcfg_test.go b/dbservice/postgresdb/sharedcfg_test.go index 74f79cdc..802c26ed 100644 --- a/dbservice/postgresdb/sharedcfg_test.go +++ b/dbservice/postgresdb/sharedcfg_test.go @@ -2,7 +2,6 @@ package postgresdb import ( "database/sql" - "log" "testing" "github.com/jmoiron/sqlx" @@ -16,7 +15,7 @@ func TestApiKey(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"value"}).AddRow("key") mock.ExpectQuery("SELECT").WillReturnRows(rows) @@ -45,7 +44,7 @@ func TestApiKeyWithoutInit(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } mock.ExpectQuery("SELECT").WillReturnError(sql.ErrNoRows) return db, err @@ -73,7 +72,7 @@ func TestApiKeyRenewal(t *testing.T) { psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, mock, err := sqlxmock.Newx() if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Errorf("failed to open sqlmock database: %v", err) } rows := sqlxmock.NewRows([]string{"value"}).AddRow(receivedKey) mock.ExpectQuery("SELECT").WillReturnRows(rows) diff --git a/formatting/slackmrkdwnprovider.go b/formatting/slackmrkdwnprovider.go index 7ebf1371..29a26cc3 100644 --- a/formatting/slackmrkdwnprovider.go +++ b/formatting/slackmrkdwnprovider.go @@ -4,8 +4,9 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/aquasecurity/postee/data" - "log" + "github.com/aquasecurity/postee/log" ) func getMrkdwnText(text string) string { @@ -18,7 +19,7 @@ func getMrkdwnText(text string) string { } result, err := json.Marshal(block) if err != nil { - log.Printf("SlackMrkdwnProvider Error: %v", err) + log.Logger.Errorf("SlackMrkdwnProvider Error: %v", err) return "" } result = append(result, ',') @@ -76,7 +77,7 @@ func (mrkdwn *SlackMrkdwnProvider) Table(rows [][]string) string { if fields.Fields != nil { block, err := json.Marshal(fields) if err != nil { - log.Printf("SlackMrkdwnProvider Error: %v", err) + log.Logger.Errorf("SlackMrkdwnProvider Error: %v", err) return "" } builder.Write(block) @@ -127,7 +128,7 @@ func (mrkdwn *SlackMrkdwnProvider) Table(rows [][]string) string { } result, err := json.Marshal(fields) if err != nil { - log.Printf("SlackMrkdwnProvider Error: %v", err) + log.Logger.Errorf("SlackMrkdwnProvider Error: %v", err) return "" } builder.Write(result) diff --git a/log/logger.go b/log/logger.go new file mode 100644 index 00000000..cf14df83 --- /dev/null +++ b/log/logger.go @@ -0,0 +1,27 @@ +package log + +import "github.com/aquasecurity/postee/log/stdoutlogger" + +var Logger LoggerType = stdoutlogger.NewLogger() + +type LoggerType interface { + Info(args ...interface{}) + Error(args ...interface{}) + Infof(template string, args ...interface{}) + Errorf(template string, args ...interface{}) + Warn(args ...interface{}) + Warnf(template string, args ...interface{}) + Debug(args ...interface{}) + DebugF(template string, args ...interface{}) + Fatal(args ...interface{}) + Fatalf(template string, args ...interface{}) +} + +func InitDefaultLogger() { + logType := stdoutlogger.NewLogger() + Logger = logType +} + +func SetLogger(loggerType LoggerType) { + Logger = loggerType +} diff --git a/log/stdoutlogger/stdoutlogger.go b/log/stdoutlogger/stdoutlogger.go new file mode 100644 index 00000000..f916c529 --- /dev/null +++ b/log/stdoutlogger/stdoutlogger.go @@ -0,0 +1,84 @@ +package stdoutlogger + +import ( + "fmt" + "log" + "os" +) + +const ( + colorReset = "\033[0m" + colorRed = "\033[31m" + colorBlue = "\033[34m" + colorYellow = "\033[33m" + colorPurple = "\033[35m" + infoLevel = colorBlue + " [INFO] " + colorReset + warnLevel = colorYellow + " [WARN] " + colorReset + errorLevel = colorRed + " [ERROR] " + colorReset + debugLevel = colorPurple + " [DEBUG] " + colorReset +) + +type StdOutLogger struct { + logger log.Logger +} + +func NewLogger() StdOutLogger { + logger := log.New(os.Stdout, "", log.Ldate|log.Ltime) + return StdOutLogger{logger: *logger} +} + +func (stdOutLogger StdOutLogger) Info(args ...interface{}) { + stdOutLogger.logger.Print(infoLevel + getMessage("", args)) +} + +func (stdOutLogger StdOutLogger) Error(args ...interface{}) { + stdOutLogger.logger.Print(errorLevel + getMessage("", args)) +} + +func (stdOutLogger StdOutLogger) Warn(args ...interface{}) { + stdOutLogger.logger.Print(warnLevel + getMessage("", args)) +} +func (stdOutLogger StdOutLogger) Debug(args ...interface{}) { + stdOutLogger.logger.Print(debugLevel + getMessage("", args)) +} + +func (stdOutLogger StdOutLogger) Infof(template string, args ...interface{}) { + stdOutLogger.logger.Print(infoLevel + getMessage(template, args)) +} + +func (stdOutLogger StdOutLogger) Errorf(template string, args ...interface{}) { + stdOutLogger.logger.Print(errorLevel + getMessage(template, args)) +} + +func (stdOutLogger StdOutLogger) Warnf(template string, args ...interface{}) { + stdOutLogger.logger.Print(warnLevel + getMessage(template, args)) +} + +func (stdOutLogger StdOutLogger) DebugF(template string, args ...interface{}) { + stdOutLogger.logger.Print(debugLevel + getMessage(template, args)) +} + +func (stdOutLogger StdOutLogger) Fatal(args ...interface{}) { + stdOutLogger.logger.Fatal(args...) +} + +func (stdOutLogger StdOutLogger) Fatalf(template string, args ...interface{}) { + stdOutLogger.logger.Fatalf(template, args...) +} + +func getMessage(template string, fmtArgs []interface{}) string { + if len(fmtArgs) == 0 { + return template + } + + if template != "" { + return fmt.Sprintf(template, fmtArgs...) + } + + if len(fmtArgs) == 1 { + if str, ok := fmtArgs[0].(string); ok { + return str + } + } + return fmt.Sprint(fmtArgs...) +} diff --git a/main.go b/main.go index 6240e9b7..7d24f66e 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "log" "os" "os/signal" "runtime" @@ -11,6 +10,8 @@ import ( "github.com/aquasecurity/postee/utils" "github.com/aquasecurity/postee/webserver" "github.com/spf13/cobra" + + "github.com/aquasecurity/postee/log" ) const ( @@ -78,7 +79,7 @@ func main() { err := router.Instance().ApplyFileCfg(cfgfile, postgresUrl, pathToDb, false) if err != nil { - log.Printf("Can't start alert manager %v", err) + log.Logger.Errorf("Can't start alert manager %v", err) return } @@ -91,7 +92,7 @@ func main() { } err := rootCmd.Execute() if err != nil { - log.Printf("Can't start command %v", err) + log.Logger.Errorf("Can't start command %v", err) return } } @@ -103,7 +104,7 @@ func Daemonize() { go func() { sig := <-sigs - log.Println(sig) + log.Logger.Info(sig) done <- true }() diff --git a/msgservice/msghandling.go b/msgservice/msghandling.go index 9a95ad77..161e96c3 100644 --- a/msgservice/msghandling.go +++ b/msgservice/msghandling.go @@ -2,12 +2,12 @@ package msgservice import ( "encoding/json" - "log" "strings" "time" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/log" "github.com/aquasecurity/postee/outputs" "github.com/aquasecurity/postee/regoservice" "github.com/aquasecurity/postee/routes" @@ -60,11 +60,11 @@ func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Ou wasStored, err := dbservice.Db.MayBeStoreMessage(input, msgKey, expired) if err != nil { - log.Printf("Error while storing input: %v", err) + log.Logger.Errorf("Error while storing input: %v", err) return } if !wasStored { - log.Printf("The same message was received before: %s", msgKey) + log.Logger.Infof("The same message was received before: %s", msgKey) return } @@ -78,7 +78,7 @@ func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Ou content, err := inpteval.Eval(in, *AquaServer) if err != nil { - log.Printf("Error while evaluating input: %v", err) + log.Logger.Errorf("Error while evaluating input: %v", err) return } @@ -91,7 +91,7 @@ func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Ou if len(aggregated) > 0 { content, err = inpteval.BuildAggregatedContent(aggregated) if err != nil { - log.Printf("Error while building aggregated content: %v", err) + log.Logger.Errorf("Error while building aggregated content: %v", err) return } send(output, content) @@ -100,10 +100,10 @@ func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Ou AggregateScanAndGetQueue(route.Name, content, 0, true) if !route.IsSchedulerRun() { //TODO route shouldn't have any associated logic - log.Printf("about to schedule %s\n", route.Name) + log.Logger.Infof("about to schedule %s\n", route.Name) RunScheduler(route, send, AggregateScanAndGetQueue, inpteval, &route.Name, output) } else { - log.Printf("%s is already scheduled\n", route.Name) + log.Logger.Infof("%s is already scheduled\n", route.Name) } } else { send(output, content) @@ -115,13 +115,13 @@ func send(otpt outputs.Output, cnt map[string]string) { go func() { err := otpt.Send(cnt) if err != nil { - log.Printf("Error while sending event: %v", err) + log.Logger.Errorf("Error while sending event: %v", err) } }() err := dbservice.Db.RegisterPlgnInvctn(otpt.GetName()) if err != nil { - log.Printf("Error while building aggregated content: %v", err) + log.Logger.Errorf("Error while building aggregated content: %v", err) return } @@ -138,11 +138,11 @@ func calculateExpired(UniqueMessageTimeoutSeconds int) *time.Time { var AggregateScanAndGetQueue = func(outputName string, currentContent map[string]string, counts int, ignoreLength bool) []map[string]string { aggregatedScans, err := dbservice.Db.AggregateScans(outputName, currentContent, counts, ignoreLength) if err != nil { - log.Printf("AggregateScans Error: %v", err) + log.Logger.Errorf("AggregateScans Error: %v", err) return aggregatedScans } if len(currentContent) != 0 && len(aggregatedScans) == 0 { - log.Printf("New scan was added to the queue of %q without sending.", outputName) + log.Logger.Infof("New scan was added to the queue of %q without sending.", outputName) return nil } return aggregatedScans diff --git a/msgservice/scheduler.go b/msgservice/scheduler.go index 89618fdf..6546f781 100644 --- a/msgservice/scheduler.go +++ b/msgservice/scheduler.go @@ -1,10 +1,10 @@ package msgservice import ( - "log" "time" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/log" "github.com/aquasecurity/postee/outputs" "github.com/aquasecurity/postee/routes" ) @@ -20,7 +20,7 @@ var RunScheduler = func( name *string, output outputs.Output, ) { - log.Printf("Scheduler is activated for route %q. Period: %d sec", route.Name, route.Plugins.AggregateTimeoutSeconds) + log.Logger.Infof("Scheduler is activated for route %q. Period: %d sec", route.Name, route.Plugins.AggregateTimeoutSeconds) ticker := getTicker(route.Plugins.AggregateTimeoutSeconds) route.StartScheduler() @@ -30,15 +30,15 @@ var RunScheduler = func( select { case <-done: currentTicker.Stop() - log.Printf("Scheduler for %q was stopped", route.Name) + log.Logger.Infof("Scheduler for %q was stopped", route.Name) return case <-currentTicker.C: - log.Printf("Scheduler triggered for %q", route.Name) + log.Logger.Infof("Scheduler triggered for %q", route.Name) queue := fnAggregate(route.Name, nil, 0, false) if len(queue) > 0 { aggregated, err := inpteval.BuildAggregatedContent(queue) if err != nil { - log.Printf("Unable to build aggregated contents %v\n", err) + log.Logger.Errorf("Unable to build aggregated contents %v\n", err) } fnSend(output, aggregated) } diff --git a/outputs/email.go b/outputs/email.go index 98ecdb7e..e15e1703 100644 --- a/outputs/email.go +++ b/outputs/email.go @@ -3,7 +3,6 @@ package outputs import ( "errors" "fmt" - "log" "net" "net/smtp" "strconv" @@ -12,6 +11,7 @@ import ( "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" ) var ( @@ -49,7 +49,7 @@ func (email *EmailOutput) CloneSettings() *data.OutputSettings { } func (email *EmailOutput) Init() error { - log.Printf("Starting Email output %q...", email.Name) + log.Logger.Infof("Starting Email output %q...", email.Name) if email.Sender == "" { email.Sender = email.User } @@ -57,7 +57,7 @@ func (email *EmailOutput) Init() error { } func (email *EmailOutput) Terminate() error { - log.Printf("Email output terminated\n") + log.Logger.Infof("Email output terminated\n") return nil } @@ -88,11 +88,11 @@ func (email *EmailOutput) Send(content map[string]string) error { auth := smtp.PlainAuth("", email.User, email.Password, email.Host) err := smtp.SendMail(email.Host+":"+strconv.Itoa(email.Port), auth, email.Sender, recipients, []byte(msg)) if err != nil { - log.Println("SendMail Error:", err) - log.Printf("From: %q, to %v via %q", email.Sender, email.Recipients, email.Host) + log.Logger.Error("SendMail Error:", err) + log.Logger.Errorf("From: %q, to %v via %q", email.Sender, email.Recipients, email.Host) return err } - log.Println("Email was sent successfully!") + log.Logger.Infof("Email was sent successfully!") return nil } @@ -100,13 +100,13 @@ func sendViaMxServers(from, subj, msg string, recipients []string) { for _, rcpt := range recipients { at := strings.LastIndex(rcpt, "@") if at < 0 { - log.Printf("%q isn't email", rcpt) + log.Logger.Errorf("%q isn't email", rcpt) continue } host := rcpt[at+1:] mxs, err := net.LookupMX(host) if err != nil { - log.Print(err) + log.Logger.Error(err) continue } for _, mx := range mxs { @@ -117,11 +117,11 @@ func sendViaMxServers(from, subj, msg string, recipients []string) { "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", rcpt, from, subj, msg) if err := smtp.SendMail(mx.Host+":25", nil, from, []string{rcpt}, []byte(message)); err != nil { - log.Printf("SendMail error to %q via %q", rcpt, mx.Host) - log.Print(err) + log.Logger.Errorf("SendMail error to %q via %q", rcpt, mx.Host) + log.Logger.Error(err) continue } - log.Printf("The message to %q was sent successful via %q!", rcpt, mx.Host) + log.Logger.Infof("The message to %q was sent successful via %q!", rcpt, mx.Host) break } } diff --git a/outputs/jira.go b/outputs/jira.go index 03607a56..c613e5b4 100644 --- a/outputs/jira.go +++ b/outputs/jira.go @@ -5,12 +5,12 @@ import ( "errors" "fmt" "io/ioutil" - "log" "strconv" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" "net/http" "net/url" @@ -73,13 +73,13 @@ func (ctx *JiraAPI) CloneSettings() *data.OutputSettings { func (ctx *JiraAPI) fetchBoardId(boardName string) { client, err := ctx.createClient() if err != nil { - log.Printf("unable to create Jira client: %s, please check your credentials.", err) + log.Logger.Errorf("unable to create Jira client: %s, please check your credentials.", err) return } boardlist, _, err := client.Board.GetAllBoards(&jira.BoardListOptions{ProjectKeyOrID: ctx.ProjectKey}) if err != nil { - log.Printf("failed to get boards from Jira API GetAllBoards with ProjectID %s. %s", ctx.ProjectKey, err) + log.Logger.Errorf("failed to get boards from Jira API GetAllBoards with ProjectID %s. %s", ctx.ProjectKey, err) return } var matches int @@ -92,36 +92,36 @@ func (ctx *JiraAPI) fetchBoardId(boardName string) { } if matches > 1 { - log.Printf("found more than one boards with name %q, working with board id %d", boardName, ctx.boardId) + log.Logger.Infof("found more than one boards with name %q, working with board id %d", boardName, ctx.boardId) } else if matches == 0 { - log.Printf("no boards found with name %s when getting all boards for User", boardName) + log.Logger.Infof("no boards found with name %s when getting all boards for User", boardName) return } else { - log.Printf("using board ID %d with Name %q", ctx.boardId, boardName) + log.Logger.Infof("using board ID %d with Name %q", ctx.boardId, boardName) } } func (ctx *JiraAPI) fetchSprintId(client jira.Client) { sprints, _, err := client.Board.GetAllSprintsWithOptions(ctx.boardId, &jira.GetAllSprintsOptions{State: "active"}) if err != nil { - log.Printf("failed to get active sprint for board ID %d from Jira API. %s", ctx.boardId, err) + log.Logger.Errorf("failed to get active sprint for board ID %d from Jira API. %s", ctx.boardId, err) return } if len(sprints.Values) > 1 { ctx.SprintId = len(sprints.Values) - 1 - log.Printf("Found more than one active sprint, using sprint id %d as the active sprint", ctx.SprintId) + log.Logger.Infof("Found more than one active sprint, using sprint id %d as the active sprint", ctx.SprintId) } else if len(sprints.Values) == 1 { if sprints.Values[0].ID != ctx.SprintId { ctx.SprintId = sprints.Values[0].ID - log.Printf("using sprint id %d as the active sprint", ctx.SprintId) + log.Logger.Infof("using sprint id %d as the active sprint", ctx.SprintId) } } else { - log.Printf("no active sprints exist in board ID %d Name %s", ctx.boardId, ctx.ProjectKey) + log.Logger.Infof("no active sprints exist in board ID %d Name %s", ctx.boardId, ctx.ProjectKey) } } func (ctx *JiraAPI) Terminate() error { - log.Printf("Jira output terminated\n") + log.Logger.Infof("Jira output terminated\n") return nil } @@ -131,7 +131,7 @@ func (ctx *JiraAPI) Init() error { } ctx.fetchBoardId(ctx.BoardName) - log.Printf("Starting Jira output %q....", ctx.Name) + log.Logger.Infof("Starting Jira output %q....", ctx.Name) if len(ctx.Password) == 0 { ctx.Password = os.Getenv("JIRA_PASSWORD") } @@ -148,7 +148,7 @@ func (ctx *JiraAPI) buildTransportClient() (*http.Client, error) { return nil, errors.New("Jira Cloud can't work with PAT") } if ctx.Password != "" { - log.Printf("Found both Password and PAT, using PAT to authenticate.") + log.Logger.Warn("Found both Password and PAT, using PAT to authenticate.") } tp := jira.BearerTokenAuthTransport{ Token: ctx.Token, @@ -188,7 +188,7 @@ func (ctx *JiraAPI) createClient() (*jira.Client, error) { func (ctx *JiraAPI) Send(content map[string]string) error { client, err := ctx.createClient() if err != nil { - log.Printf("unable to create Jira client: %s", err) + log.Logger.Errorf("unable to create Jira client: %s", err) return err } @@ -234,7 +234,7 @@ func (ctx *JiraAPI) Send(content map[string]string) error { fieldsConfig[k] = v } if len(ctx.Unknowns) > 0 { - log.Printf("added %d custom fields to issue.", len(ctx.Unknowns)) + log.Logger.Infof("added %d custom fields to issue.", len(ctx.Unknowns)) } type Version struct { @@ -244,7 +244,7 @@ func (ctx *JiraAPI) Send(content map[string]string) error { issue, err := InitIssue(client, metaProject, metaIssueType, fieldsConfig, isServerJira(ctx.Url)) if err != nil { - log.Printf("Failed to init issue: %s\n", err) + log.Logger.Errorf("Failed to init issue: %s\n", err) return err } @@ -268,15 +268,15 @@ func (ctx *JiraAPI) Send(content map[string]string) error { }) } issue.Fields.Unknowns["versions"] = affectsVersions - log.Printf("added %d affected versions into Versions field", len(ctx.AffectsVersions)) + log.Logger.Infof("added %d affected versions into Versions field", len(ctx.AffectsVersions)) } i, err := ctx.openIssue(client, issue) if err != nil { - log.Printf("Failed to open jira issue, %s\n", err) + log.Logger.Errorf("Failed to open jira issue, %s\n", err) return err } - log.Printf("Created new jira issue %s", i.ID) + log.Logger.Infof("Created new jira issue %s", i.ID) return nil } @@ -411,15 +411,15 @@ func InitIssue(c *jira.Client, metaProject *jira.MetaProject, metaIssuetype *jir } if err != nil { - log.Printf("Get Jira User info error: %v", err) + log.Logger.Errorf("Get Jira User info error: %v", err) continue } if resp.StatusCode != http.StatusOK { - log.Printf("http response failed: %q", resp.Status) + log.Logger.Errorf("http response failed: %q", resp.Status) continue } if len(users) == 0 { - log.Printf("There is no user for %q", value) + log.Logger.Errorf("There is no user for %q", value) continue } issueFields.Unknowns[jiraKey] = users[0] @@ -446,7 +446,7 @@ func findUserOnJiraServer(c *jira.Client, email string) ([]jira.User, *jira.Resp resp, err := c.Do(req, &users) if err != nil { - log.Printf("%v", err) + log.Logger.Errorf("%v", err) return nil, resp, err } return users, resp, nil diff --git a/outputs/plugin.go b/outputs/plugin.go index b5c8d2d4..de2e3710 100644 --- a/outputs/plugin.go +++ b/outputs/plugin.go @@ -2,11 +2,11 @@ package outputs import ( "fmt" - "log" "strings" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" ) const ( @@ -28,7 +28,7 @@ func getHandledRecipients(recipients []string, content *map[string]string, outpu if r == ApplicationScopeOwner { owners, err := getAppScopeOwners(content) if err != nil { - log.Printf("get application scope owners error for %q: %v", outputName, err) + log.Logger.Errorf("get application scope owners error for %q: %v", outputName, err) continue } result = append(result, owners...) diff --git a/outputs/servicenow.go b/outputs/servicenow.go index 81a21e6a..2226170c 100644 --- a/outputs/servicenow.go +++ b/outputs/servicenow.go @@ -2,11 +2,11 @@ package outputs import ( "encoding/json" - "log" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" servicenow "github.com/aquasecurity/postee/servicenow" ) @@ -36,34 +36,34 @@ func (sn *ServiceNowOutput) CloneSettings() *data.OutputSettings { } func (sn *ServiceNowOutput) Init() error { - log.Printf("Starting ServiceNow output %q....", sn.Name) - log.Printf("Your ServiceNow Table is %q on '%s.%s'", sn.Table, sn.Instance, servicenow.BaseServer) + log.Logger.Infof("Starting ServiceNow output %q....", sn.Name) + log.Logger.Infof("Your ServiceNow Table is %q on '%s.%s'", sn.Table, sn.Instance, servicenow.BaseServer) sn.layoutProvider = new(formatting.HtmlProvider) return nil } func (sn *ServiceNowOutput) Send(content map[string]string) error { - log.Printf("Sending via ServiceNow %q", sn.Name) + log.Logger.Infof("Sending via ServiceNow %q", sn.Name) d := &servicenow.ServiceNowData{ ShortDescription: content["title"], WorkNotes: "[code]" + content["description"] + "[/code]", } body, err := json.Marshal(d) if err != nil { - log.Println("ServiceNow Error:", err) + log.Logger.Error("ServiceNow Error:", err) return err } err = servicenow.InsertRecordToTable(sn.User, sn.Password, sn.Instance, sn.Table, body) if err != nil { - log.Println("ServiceNow Error:", err) + log.Logger.Error("ServiceNow Error:", err) return err } - log.Printf("Sending via ServiceNow %q was successful!", sn.Name) + log.Logger.Infof("Sending via ServiceNow %q was successful!", sn.Name) return nil } func (sn *ServiceNowOutput) Terminate() error { - log.Printf("ServiceNow output %q terminated", sn.Name) + log.Logger.Infof("ServiceNow output %q terminated", sn.Name) return nil } diff --git a/outputs/slack.go b/outputs/slack.go index b745c62d..d2423a41 100644 --- a/outputs/slack.go +++ b/outputs/slack.go @@ -3,12 +3,12 @@ package outputs import ( "bytes" "encoding/json" - "log" "strings" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" slackAPI "github.com/aquasecurity/postee/slack" ) @@ -39,7 +39,7 @@ func (slack *SlackOutput) CloneSettings() *data.OutputSettings { func (slack *SlackOutput) Init() error { slack.slackLayout = new(formatting.SlackMrkdwnProvider) - log.Printf("Starting Slack output %q....", slack.Name) + log.Logger.Infof("Starting Slack output %q....", slack.Name) return nil } @@ -63,7 +63,7 @@ func buildSlackBlock(title string, data []byte) []byte { } func (slack *SlackOutput) Send(input map[string]string) error { - log.Printf("Sending via Slack %q", slack.Name) + log.Logger.Infof("Sending via Slack %q", slack.Name) title := clearSlackText(slack.slackLayout.TitleH2(input["title"])) var body string if strings.HasSuffix(input["description"], ",") { @@ -78,7 +78,7 @@ func (slack *SlackOutput) Send(input map[string]string) error { rawBlock := make([]data.SlackBlock, 0) err := json.Unmarshal([]byte(body), &rawBlock) if err != nil { - log.Printf("Unmarshal slack sending error: %v", err) + log.Logger.Errorf("Unmarshal slack sending error: %v", err) return err } @@ -89,7 +89,7 @@ func (slack *SlackOutput) Send(input map[string]string) error { if err := slackAPI.SendToUrl(slack.Url, buildSlackBlock(title, []byte(message))); err != nil { return err } - log.Printf("Sending via Slack %q was successful!", slack.Name) + log.Logger.Infof("Sending via Slack %q was successful!", slack.Name) } else { for n := 0; n < length; { d := length - n @@ -99,10 +99,10 @@ func (slack *SlackOutput) Send(input map[string]string) error { cutData, _ := json.Marshal(rawBlock[n : n+d]) cutData = cutData[1 : len(cutData)-1] if err := slackAPI.SendToUrl(slack.Url, buildSlackBlock(title, cutData)); err != nil { - log.Printf("Sending to %q was finished with error: %v", slack.Name, err) + log.Logger.Errorf("Sending to %q was finished with error: %v", slack.Name, err) return err } else { - log.Printf("Sending [%d/%d part] to %q was successful!", + log.Logger.Infof("Sending [%d/%d part] to %q was successful!", int(n/49)+1, int(length/49)+1, slack.Name) } @@ -113,7 +113,7 @@ func (slack *SlackOutput) Send(input map[string]string) error { } func (slack *SlackOutput) Terminate() error { - log.Printf("Slack output %q terminated", slack.Name) + log.Logger.Infof("Slack output %q terminated", slack.Name) return nil } diff --git a/outputs/splunk.go b/outputs/splunk.go index 667534d1..d8b32ced 100644 --- a/outputs/splunk.go +++ b/outputs/splunk.go @@ -6,13 +6,13 @@ import ( "errors" "fmt" "io/ioutil" - "log" "net/http" "strings" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" ) const defaultSizeLimit = 10000 @@ -42,18 +42,18 @@ func (splunk *SplunkOutput) CloneSettings() *data.OutputSettings { func (splunk *SplunkOutput) Init() error { splunk.splunkLayout = new(formatting.HtmlProvider) - log.Printf("Starting Splunk output %q....", splunk.Name) + log.Logger.Infof("Starting Splunk output %q....", splunk.Name) return nil } func (splunk *SplunkOutput) Send(d map[string]string) error { - log.Printf("Sending a message to %q", splunk.Name) + log.Logger.Infof("Sending a message to %q", splunk.Name) if splunk.EventLimit == 0 { splunk.EventLimit = defaultSizeLimit } if splunk.EventLimit < defaultSizeLimit { - log.Printf("[WARNING] %q has a short limit %d (default %d)", + log.Logger.Warnf("%q has a short limit %d (default %d)", splunk.Name, splunk.EventLimit, defaultSizeLimit) } @@ -64,7 +64,7 @@ func (splunk *SplunkOutput) Send(d map[string]string) error { scanInfo := new(data.ScanImageInfo) err := json.Unmarshal([]byte(d["src"]), scanInfo) if err != nil { - log.Printf("sending to %q error: %v", splunk.Name, err) + log.Logger.Errorf("sending to %q error: %v", splunk.Name, err) return err } @@ -76,7 +76,7 @@ func (splunk *SplunkOutput) Send(d map[string]string) error { for { fields, err = json.Marshal(scanInfo) if err != nil { - log.Printf("sending to %q error: %v", splunk.Name, err) + log.Logger.Errorf("sending to %q error: %v", splunk.Name, err) return err } if len(fields) < splunk.EventLimit-constLimit { @@ -95,7 +95,7 @@ func (splunk *SplunkOutput) Send(d map[string]string) error { default: msg := fmt.Sprintf("Scan result for %q is large for %q , its size if %d (limit %d)", scanInfo.Image, splunk.Name, len(fields), splunk.EventLimit) - log.Print(msg) + log.Logger.Infof(msg) return errors.New(msg) } } @@ -119,15 +119,15 @@ func (splunk *SplunkOutput) Send(d map[string]string) error { if resp.StatusCode != http.StatusOK { defer resp.Body.Close() b, _ := ioutil.ReadAll(resp.Body) - log.Printf("Splunk sending error: failed response status %q. Body: %q", resp.Status, string(b)) + log.Logger.Errorf("Splunk sending error: failed response status %q. Body: %q", resp.Status, string(b)) return errors.New("failed response status for Splunk sending") } - log.Printf("Sending a message to %q was successful!", splunk.Name) + log.Logger.Infof("Sending a message to %q was successful!", splunk.Name) return nil } func (splunk *SplunkOutput) Terminate() error { - log.Printf("Splunk output %q terminated", splunk.Name) + log.Logger.Infof("Splunk output %q terminated", splunk.Name) return nil } diff --git a/outputs/teams.go b/outputs/teams.go index c9fa2acb..1c04b408 100644 --- a/outputs/teams.go +++ b/outputs/teams.go @@ -2,11 +2,11 @@ package outputs import ( "encoding/json" - "log" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" "github.com/aquasecurity/postee/utils" msteams "github.com/aquasecurity/postee/teams" @@ -37,13 +37,13 @@ func (teams *TeamsOutput) CloneSettings() *data.OutputSettings { } func (teams *TeamsOutput) Init() error { - log.Printf("Starting MS Teams output %q....", teams.Name) + log.Logger.Infof("Starting MS Teams output %q....", teams.Name) teams.teamsLayout = new(formatting.HtmlProvider) return nil } func (teams *TeamsOutput) Send(input map[string]string) error { - log.Printf("Sending to MS Teams via %q...", teams.Name) + log.Logger.Infof("Sending to MS Teams via %q...", teams.Name) utils.Debug("Title for %q: %q\n", teams.Name, input["title"]) utils.Debug("Url(s) for %q: %q\n", teams.Name, input["url"]) utils.Debug("Webhook for %q: %q\n", teams.Name, teams.Webhook) @@ -62,23 +62,23 @@ func (teams *TeamsOutput) Send(input map[string]string) error { escaped, err := escapeJSON(body) if err != nil { - log.Printf("Error while escaping payload: %v", err) + log.Logger.Errorf("Error while escaping payload: %v", err) return err } err = msteams.CreateMessageByWebhook(teams.Webhook, teams.teamsLayout.TitleH2(input["title"])+escaped) if err != nil { - log.Printf("TeamsOutput Send Error: %v", err) + log.Logger.Errorf("TeamsOutput Send Error: %v", err) return err } - log.Printf("Sending to MS Teams via %q was successful!", teams.Name) + log.Logger.Infof("Sending to MS Teams via %q was successful!", teams.Name) return nil } func (teams *TeamsOutput) Terminate() error { - log.Printf("MS Teams output %q terminated", teams.Name) + log.Logger.Infof("MS Teams output %q terminated", teams.Name) return nil } diff --git a/outputs/webhook.go b/outputs/webhook.go index 3e2fdcb3..bbc97309 100644 --- a/outputs/webhook.go +++ b/outputs/webhook.go @@ -3,13 +3,13 @@ package outputs import ( "fmt" "io/ioutil" - "log" "net/http" "strings" "github.com/aquasecurity/postee/data" "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" + "github.com/aquasecurity/postee/log" ) type WebhookOutput struct { @@ -31,37 +31,36 @@ func (webhook *WebhookOutput) CloneSettings() *data.OutputSettings { } func (webhook *WebhookOutput) Init() error { - log.Printf("Starting Webhook output %q, for sending to %q", + log.Logger.Infof("Starting Webhook output %q, for sending to %q", webhook.Name, webhook.Url) return nil } func (webhook *WebhookOutput) Send(content map[string]string) error { - log.Printf("Sending webhook to %q", webhook.Url) + log.Logger.Infof("Sending webhook to %q", webhook.Url) data := content["description"] //it's not supposed to work with legacy renderer resp, err := http.Post(webhook.Url, "application/json", strings.NewReader(data)) if err != nil { - log.Printf("Sending webhook Error: %v", err) + log.Logger.Errorf("Sending webhook Error: %v", err) return err } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - log.Printf("Sending %q Error: %v", webhook.Name, err) + log.Logger.Errorf("Sending %q Error: %v", webhook.Name, err) return err } if resp.StatusCode != http.StatusOK { msg := "sending webhook wrong status: %q. Body: %s" - log.Printf(msg, resp.StatusCode, body) return fmt.Errorf(msg, resp.StatusCode, body) } - log.Printf("Sending Webhook to %q was successful!", webhook.Name) + log.Logger.Infof("Sending Webhook to %q was successful!", webhook.Name) return nil } func (webhook *WebhookOutput) Terminate() error { - log.Printf("Webhook output %q terminated.", webhook.Name) + log.Logger.Infof("Webhook output %q terminated.", webhook.Name) return nil } diff --git a/regoservice/eval.go b/regoservice/eval.go index dc4e9431..b3b6ebbe 100644 --- a/regoservice/eval.go +++ b/regoservice/eval.go @@ -5,10 +5,10 @@ import ( "encoding/json" "errors" "fmt" - "log" "os" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/log" "github.com/open-policy-agent/opa/rego" ) @@ -78,7 +78,7 @@ func (regoEvaluator *regoEvaluator) Eval(in map[string]interface{}, serverUrl st func getFirstElement(context map[string]interface{}, key string) interface{} { for _, v := range context { - log.Printf("checking: %s ...\n", key) + log.Logger.Infof("checking: %s ...\n", key) childCtx, ok := v.(map[string]interface{}) if !ok { return nil @@ -229,7 +229,7 @@ func buildAggregatedRego(query *rego.PreparedEvalQuery) (*rego.PreparedEvalQuery } } else { //it's ok skip aggregation package - no aggregation features will be available - log.Printf("No aggregation package configured!!!") + log.Logger.Infof("No aggregation package configured!!!") } return aggrQuery, nil } diff --git a/regoservice/jsonformat.go b/regoservice/jsonformat.go index bc7f4018..f6f897e7 100644 --- a/regoservice/jsonformat.go +++ b/regoservice/jsonformat.go @@ -2,8 +2,8 @@ package regoservice import ( "encoding/json" - "log" + "github.com/aquasecurity/postee/log" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/types" @@ -21,13 +21,13 @@ func jsonFmtFunc() func(r *rego.Rego) { err := ast.As(a.Value, &obj) if err != nil { //Rego doesn't show errors - log.Printf("Can't convert OPA object: %v\n", err) + log.Logger.Errorf("Can't convert OPA object: %v\n", err) return nil, err } b, err := json.MarshalIndent(obj, "", " ") if err != nil { //Rego doesn't show errors - log.Printf("Error while json format: %v\n", err) + log.Logger.Errorf("Error while json format: %v\n", err) return nil, err } return ast.StringTerm(string(b)), nil diff --git a/router/parsecfg.go b/router/parsecfg.go index a3a6954a..4eaa4671 100644 --- a/router/parsecfg.go +++ b/router/parsecfg.go @@ -3,9 +3,9 @@ package router import ( "bytes" "io/ioutil" - "log" "github.com/aquasecurity/postee/data" + "github.com/aquasecurity/postee/log" "github.com/ghodss/yaml" ) @@ -28,7 +28,6 @@ const ( func Parsev2cfg(cfgpath string) (*data.TenantSettings, error) { b, err := ioutil.ReadFile(cfgpath) if err != nil { - log.Printf("Failed to open file %s, %s", cfgpath, err) return nil, err } @@ -38,7 +37,6 @@ func Parsev2cfg(cfgpath string) (*data.TenantSettings, error) { err = yaml.Unmarshal(b, tenant) if err != nil { - log.Printf("Failed yaml.Unmarshal, %s", err) return nil, err } @@ -47,6 +45,6 @@ func Parsev2cfg(cfgpath string) (*data.TenantSettings, error) { } func checkV1Cfg(data []byte, cfgpath string) { if bytes.Index(data, []byte(v1Marker)) > -1 { - log.Printf(v1Warning, cfgpath) + log.Logger.Warnf(v1Warning, cfgpath) } } diff --git a/router/router.go b/router/router.go index 9e3d9c3a..db197a34 100644 --- a/router/router.go +++ b/router/router.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io/ioutil" - "log" "net/http" "path" "strings" @@ -17,6 +16,7 @@ import ( "github.com/aquasecurity/postee/dbservice/dbparam" "github.com/aquasecurity/postee/dbservice/postgresdb" "github.com/aquasecurity/postee/formatting" + "github.com/aquasecurity/postee/log" "github.com/aquasecurity/postee/msgservice" "github.com/aquasecurity/postee/outputs" "github.com/aquasecurity/postee/regoservice" @@ -77,14 +77,14 @@ func (ctx *Router) ReloadConfig() { tenant, err := Parsev2cfg(ctx.cfgfile) if err != nil { - log.Printf("Failed to parse cfg file %s", err) + log.Logger.Errorf("Failed to parse cfg file %s", err) return } err = ctx.applyTenantCfg(tenant, ctx.synchronous) if err != nil { - log.Printf("Unable to start router: %s", err) + log.Logger.Errorf("Unable to start router: %s", err) } } @@ -103,7 +103,7 @@ func (ctx *Router) cleanChannels(synchronous bool) { } func (ctx *Router) ApplyFileCfg(cfgfile, postgresUrl, pathToDb string, synchronous bool) error { - log.Printf("Starting Router....") + log.Logger.Info("Starting Router....") ctx.cfgfile = cfgfile @@ -124,7 +124,7 @@ func (ctx *Router) ApplyFileCfg(cfgfile, postgresUrl, pathToDb string, synchrono } func (ctx *Router) ApplyPostgresCfg(tenantName, postgresUrl string, synchronous bool) error { - log.Printf("Starting Router....") + log.Logger.Info("Starting Router....") if err := dbservice.ConfigureDb("", postgresUrl, tenantName); err != nil { return err @@ -159,32 +159,32 @@ func (ctx *Router) applyTenantCfg(tenant *data.TenantSettings, synchronous bool) } func (ctx *Router) Terminate() { - log.Printf("Terminating Router....") + log.Logger.Info("Terminating Router....") for _, pl := range ctx.outputs { err := pl.Terminate() if err != nil { - log.Printf("failed to terminate output: %v", err) + log.Logger.Errorf("failed to terminate output: %v", err) } } - log.Printf("Outputs terminated") + log.Logger.Info("Outputs terminated") for _, route := range ctx.inputRoutes { route.StopScheduler() } - log.Printf("Route schedulers stopped") + log.Logger.Info("Route schedulers stopped") - log.Printf("ctx.quit %v\n", ctx.quit) + log.Logger.Infof("ctx.quit %v\n", ctx.quit) if ctx.quit != nil { ctx.quit <- struct{}{} } - log.Printf("quit notified") + log.Logger.Info("quit notified") if ctx.ticker != nil { ctx.stopTicker <- struct{}{} - log.Printf("stopTicker notified") + log.Logger.Info("stopTicker notified") } ctx.cleanInstance() @@ -248,7 +248,7 @@ func removeTemplateFromCfgCacheSource(outputs *data.TenantSettings, templateName } func (ctx *Router) initTemplate(template *data.Template) error { - log.Printf("Configuring template %s \n", template.Name) + log.Logger.Infof("Configuring template %s \n", template.Name) if template.LegacyScanRenderer != "" { inpteval, err := formatting.BuildLegacyScnEvaluator(template.LegacyScanRenderer) @@ -256,7 +256,7 @@ func (ctx *Router) initTemplate(template *data.Template) error { return err } ctx.templates[template.Name] = inpteval - log.Printf("Configured with legacy renderer %s \n", template.LegacyScanRenderer) + log.Logger.Infof("Configured with legacy renderer %s \n", template.LegacyScanRenderer) } if template.RegoPackage != "" { @@ -265,10 +265,10 @@ func (ctx *Router) initTemplate(template *data.Template) error { return err } ctx.templates[template.Name] = inpteval - log.Printf("Configured with Rego package %s\n", template.RegoPackage) + log.Logger.Infof("Configured with Rego package %s\n", template.RegoPackage) } if template.Url != "" { - log.Printf("Configured with url: %s\n", template.Url) + log.Logger.Infof("Configured with url: %s\n", template.Url) r, err := http.NewRequest("GET", template.Url, nil) if err != nil { @@ -317,14 +317,14 @@ func (ctx *Router) setAquaServerUrl(url string) { } ctx.databaseCfgCacheSource.AquaServer = url if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { - log.Printf("Can't save cfgSource Source: %v", err) + log.Logger.Errorf("Can't save cfgSource Source: %v", err) } } func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { ctx.mutexScan.Lock() defer ctx.mutexScan.Unlock() - log.Printf("Loading alerts configuration file %s ....\n", ctx.cfgfile) + log.Logger.Infof("Loading alerts configuration file %s ....\n", ctx.cfgfile) ctx.setAquaServerUrl(tenant.AquaServer) @@ -356,7 +356,7 @@ func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { for _, t := range tenant.Templates { err := ctx.initTemplate(&t) if err != nil { - log.Printf("Can not initialize template %s: %v \n", t.Name, err) + log.Logger.Errorf("Can not initialize template %s: %v \n", t.Name, err) } } @@ -366,9 +366,9 @@ func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { err := ctx.addOutput(&settings) if err != nil { - log.Printf("Can not initialize output %s: %v \n", settings.Name, err) + log.Logger.Errorf("Can not initialize output %s: %v \n", settings.Name, err) } else { - log.Printf("Output %s is configured", settings.Name) + log.Logger.Infof("Output %s is configured", settings.Name) } } @@ -386,7 +386,7 @@ func (ctx *Router) addRoute(r *routes.InputRoute) { ctx.inputRoutes[r.Name] = routes.ConfigureTimeouts(r) ctx.databaseCfgCacheSource.InputRoutes = append(ctx.databaseCfgCacheSource.InputRoutes, *r) if err := ctx.saveCfgCacheSourceInPostgres(); err != nil { - log.Printf("Can't save cfgSource Source: %v", err) + log.Logger.Errorf("Can't save cfgSource Source: %v", err) } } @@ -547,11 +547,11 @@ var getHttpClient = func() *http.Client { func (ctx *Router) HandleRoute(routeName string, in []byte) { r, ok := ctx.inputRoutes[routeName] if !ok || r == nil { - log.Printf("There isn't route %q", routeName) + log.Logger.Errorf("There isn't route %q", routeName) return } if len(r.Outputs) == 0 { - log.Printf("route %q has no outputs", routeName) + log.Logger.Errorf("route %q has no outputs", routeName) return } inMsg := map[string]interface{}{} @@ -578,16 +578,16 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { for _, outputName := range r.Outputs { pl, ok := ctx.outputs[outputName] if !ok { - log.Printf("route %q contains an output %q, which doesn't enable now.", routeName, outputName) + log.Logger.Errorf("route %q contains an output %q, which doesn't enable now.", routeName, outputName) continue } tmpl, ok := ctx.templates[r.Template] if !ok { - log.Printf("route %q contains reference to undefined or misconfigured template %q.", + log.Logger.Errorf("route %q contains reference to undefined or misconfigured template %q.", routeName, r.Template) continue } - log.Printf("route %q is associated with template %q", routeName, r.Template) + log.Logger.Infof("route %q is associated with template %q", routeName, r.Template) if ctx.synchronous { getScanService().MsgHandling(inMsg, pl, r, tmpl, &ctx.aquaServer) @@ -647,7 +647,7 @@ func buildAndInitOtpt(settings *data.OutputSettings, aquaServerUrl string) (outp } err := plg.Init() if err != nil { - log.Printf("failed to Init : %v", err) + log.Logger.Errorf("failed to Init : %v", err) } return plg, nil diff --git a/routes/aggrtimeout.go b/routes/aggrtimeout.go index 593ba649..70c27ebc 100644 --- a/routes/aggrtimeout.go +++ b/routes/aggrtimeout.go @@ -1,9 +1,10 @@ package routes import ( - "log" "strconv" "strings" + + "github.com/aquasecurity/postee/log" ) func parseTimeouts(v string) (int, error) { @@ -41,7 +42,7 @@ func parseTimeouts(v string) (int, error) { func ConfigureTimeouts(route *InputRoute) *InputRoute { aggregateTimeoutSeconds, err := parseTimeouts(route.Plugins.AggregateMessageTimeout) if err != nil { - log.Printf("%q settings: Can't convert 'aggregate-message-timeout'(%q) to seconds.", + log.Logger.Errorf("%q settings: Can't convert 'aggregate-message-timeout'(%q) to seconds.", route.Name, route.Plugins.AggregateMessageTimeout) } @@ -49,7 +50,7 @@ func ConfigureTimeouts(route *InputRoute) *InputRoute { uniqueMessageTimeoutSeconds, err := parseTimeouts(route.Plugins.UniqueMessageTimeout) if err != nil { - log.Printf("%q settings: Can't convert 'unique-message-timeout'(%q) to seconds.", + log.Logger.Errorf("%q settings: Can't convert 'unique-message-timeout'(%q) to seconds.", route.Name, route.Plugins.UniqueMessageTimeout) } diff --git a/slack/sendtoslack.go b/slack/sendtoslack.go index e770c705..2c8e13d0 100644 --- a/slack/sendtoslack.go +++ b/slack/sendtoslack.go @@ -4,15 +4,16 @@ import ( "bytes" "fmt" "io/ioutil" - "log" "net/http" + + "github.com/aquasecurity/postee/log" ) func SendToUrl(url string, data []byte) error { r := bytes.NewReader(data) resp, err := http.Post(url, "application/json", r) if err != nil { - log.Printf("Slack API error: %v", err) + log.Logger.Errorf("Slack API error: %v", err) return err } if resp.StatusCode != http.StatusOK { diff --git a/utils/utils.go b/utils/utils.go index 1d5dcb33..de3c3954 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,10 +2,12 @@ package utils import ( "errors" - "log" + "os" "path/filepath" "strings" + + "github.com/aquasecurity/postee/log" ) var ( @@ -31,7 +33,7 @@ func InitDebug() { func Debug(format string, v ...interface{}) { if dbg { - log.Printf(format, v...) + log.Logger.DebugF(format, v...) } } @@ -65,5 +67,5 @@ func PrnInputLogs(msg string, v ...interface{}) { } } } - log.Printf(msg, v...) + log.Logger.Errorf(msg, v...) } diff --git a/webserver/tenant.go b/webserver/tenant.go index 42958594..ae92bc0d 100644 --- a/webserver/tenant.go +++ b/webserver/tenant.go @@ -2,9 +2,9 @@ package webserver import ( "io/ioutil" - "log" "net/http" + "github.com/aquasecurity/postee/log" "github.com/aquasecurity/postee/router" "github.com/aquasecurity/postee/utils" "github.com/gorilla/mux" @@ -13,14 +13,14 @@ import ( func (ctx *WebServer) tenantHandler(w http.ResponseWriter, r *http.Request) { route, ok := mux.Vars(r)["route"] if !ok || len(route) == 0 { - log.Printf("Failed route: %q", route) + log.Logger.Errorf("Failed route: %q", route) ctx.writeResponse(w, http.StatusBadRequest, "failed route") return } body, err := ioutil.ReadAll(r.Body) if err != nil { - log.Printf("Failed ioutil.ReadAll: %s", err) + log.Logger.Errorf("Failed ioutil.ReadAll: %s", err) ctx.writeResponseError(w, http.StatusInternalServerError, err) return } diff --git a/webserver/webserver.go b/webserver/webserver.go index d7d14164..5ac8d184 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -3,13 +3,13 @@ package webserver import ( "encoding/json" "io/ioutil" - "log" "net/http" "os" "path/filepath" "sync" "github.com/aquasecurity/postee/dbservice" + "github.com/aquasecurity/postee/log" "github.com/aquasecurity/postee/router" "github.com/aquasecurity/postee/utils" "github.com/gorilla/mux" @@ -37,12 +37,12 @@ func (ctx *WebServer) withApiKey(next http.HandlerFunc) http.HandlerFunc { correctKey, err := dbservice.Db.GetApiKey() if err != nil || correctKey == "" { - log.Printf("reload API key is either empty or there is an error: %s \n", err) + log.Logger.Errorf("reload API key is either empty or there is an error: %s \n", err) http.Error(w, "Unauthorized", http.StatusUnauthorized) } if key := r.URL.Query().Get("key"); key != correctKey { - log.Printf("reload API received an incorrect key %q", key) + log.Logger.Errorf("reload API received an incorrect key %q", key) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -52,16 +52,16 @@ func (ctx *WebServer) withApiKey(next http.HandlerFunc) http.HandlerFunc { } func (ctx *WebServer) Start(host, tlshost string) { - log.Printf("Starting WebServer....") + log.Logger.Info("Starting WebServer....") rootDir, _ := utils.GetRootDir() certPem := filepath.Join(rootDir, "cert.pem") keyPem := filepath.Join(rootDir, "key.pem") - if ok := utils.PathExists(keyPem); ok != true { + if ok := utils.PathExists(keyPem); !ok { err := utils.GenerateCertificate(keyPem, certPem) if err != nil { - log.Printf("GenerateCertificate error: %v \n", err) + log.Logger.Errorf("GenerateCertificate error: %v \n", err) } } @@ -74,7 +74,7 @@ func (ctx *WebServer) Start(host, tlshost string) { } err := dbservice.Db.EnsureApiKey() if err != nil { - log.Printf("EnsureApiKey error: %v \n", err) + log.Logger.Errorf("EnsureApiKey error: %v \n", err) } ctx.router.HandleFunc("/", ctx.sessionHandler(ctx.scanHandler)).Methods("POST") @@ -85,17 +85,17 @@ func (ctx *WebServer) Start(host, tlshost string) { ctx.router.HandleFunc("/reload", ctx.withApiKey(ctx.reload)).Methods("GET") go func() { - log.Printf("Listening for HTTP on %s ", host) - log.Fatal(http.ListenAndServe(host, ctx.router)) + log.Logger.Infof("Listening for HTTP on %s ", host) + log.Logger.Fatal(http.ListenAndServe(host, ctx.router)) }() go func() { - log.Printf("Listening for HTTPS on %s", tlshost) - log.Fatal(http.ListenAndServeTLS(tlshost, certPem, keyPem, ctx.router)) + log.Logger.Infof("Listening for HTTPS on %s", tlshost) + log.Logger.Fatal(http.ListenAndServeTLS(tlshost, certPem, keyPem, ctx.router)) }() } func (ctx *WebServer) Terminate() { - log.Printf("Terminating WebServer....") + log.Logger.Info("Terminating WebServer....") close(ctx.quit) } @@ -108,7 +108,7 @@ func (ctx *WebServer) sessionHandler(f func(http.ResponseWriter, *http.Request)) func (ctx *WebServer) scanHandler(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - log.Printf("Failed ioutil.ReadAll: %s\n", err) + log.Logger.Errorf("Failed ioutil.ReadAll: %s\n", err) ctx.writeResponseError(w, http.StatusInternalServerError, err) return } @@ -130,7 +130,7 @@ func (ctx *WebServer) writeResponse(w http.ResponseWriter, httpStatus int, v int result, _ := json.Marshal(v) _, err := w.Write(result) if err != nil { - log.Printf("Write error: %s \n", err) + log.Logger.Errorf("Write error: %s \n", err) } } } @@ -140,6 +140,6 @@ func (ctx *WebServer) writeResponseError(w http.ResponseWriter, httpError int, e w.WriteHeader(httpError) errEncode := json.NewEncoder(w).Encode(err) if errEncode != nil { - log.Printf("Encode error: %s \n", errEncode) + log.Logger.Errorf("Encode error: %s \n", errEncode) } } From 2b73ecc5013b6d37523c43a5d9f57e455a9acea7 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Fri, 24 Dec 2021 16:47:43 +0600 Subject: [PATCH 44/61] fix: fixed golangci-lint test errors --- dbservice/dbservice.go | 6 ++++-- dbservice/dbservice_test.go | 6 +++--- dbservice/postgresdb/dbservice_test.go | 4 ++-- router/api.go | 28 ++++++++++++++++++-------- router/api_integration_test.go | 4 +++- router/api_test.go | 8 ++++++-- router/router.go | 4 +++- 7 files changed, 41 insertions(+), 19 deletions(-) diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index 5d2562d9..318d87ae 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -1,7 +1,7 @@ package dbservice import ( - "errors" + "fmt" "time" "github.com/aquasecurity/postee/dbservice/boltdb" @@ -10,6 +10,8 @@ import ( var ( Db DbProvider + + errConfigPsqlEmptyTenantName = fmt.Errorf("error configuring postgres: 'tenantName' is empty") ) type DbProvider interface { @@ -26,7 +28,7 @@ func ConfigureDb(pathToDb, postgresUrl, tenantName string) error { if postgresUrl != "" { if tenantName == "" { - return errors.New("error configuring postgres: 'tenantName' is empty") + return errConfigPsqlEmptyTenantName } postgresDb := postgresdb.NewPostgresDb(tenantName, postgresUrl) if err := postgresdb.InitPostgresDb(postgresDb.ConnectUrl); err != nil { diff --git a/dbservice/dbservice_test.go b/dbservice/dbservice_test.go index 638fe6a5..b86f9e19 100644 --- a/dbservice/dbservice_test.go +++ b/dbservice/dbservice_test.go @@ -45,7 +45,7 @@ func TestConfiguratePostgresDbUrlAndTenantName(t *testing.T) { expectedError error }{ {"happy configuration postgres with url", "postgresql://user:secret@localhost", "test-tenantName", nil}, - {"bad tenantName", "postgresql://user:secret@localhost", "", errors.New("error configuring postgres: 'tenantName' is empty")}, + {"bad tenantName", "postgresql://user:secret@localhost", "", errConfigPsqlEmptyTenantName}, {"bad url", "badUrl", "test-tenantName", errors.New("badUrl error")}, } @@ -54,7 +54,7 @@ func TestConfiguratePostgresDbUrlAndTenantName(t *testing.T) { initPostgresDbSaved := postgresdb.InitPostgresDb postgresdb.InitPostgresDb = func(connectUrl string) error { if connectUrl == "badUrl" { - return errors.New("badUrl error") + return test.expectedError } return nil } @@ -64,7 +64,7 @@ func TestConfiguratePostgresDbUrlAndTenantName(t *testing.T) { err := ConfigureDb("", test.url, test.tenantName) if err != nil { - if err.Error() != test.expectedError.Error() { + if !errors.Is(err, test.expectedError) { t.Errorf("Unexpected error, expected: %s, got: %s", test.expectedError, err) } } else { diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index 78c47eb3..b0c079e0 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -93,7 +93,7 @@ func TestInitError(t *testing.T) { } err := InitPostgresDb("connectUrl") - if err.Error() != initTablesErr.Error() { + if !errors.Is(err, initTablesErr) { t.Errorf("Unexpected error: expected %s, got %s \n", initTablesErr, err) } @@ -104,7 +104,7 @@ func TestInitError(t *testing.T) { testConnect = savedTestConnect }() err = InitPostgresDb("ConnectUrl") - if err.Error() != testConnectErr.Error() { + if !errors.Is(err, testConnectErr) { t.Errorf("Unexpected error: expected %s, got %s \n", testConnectErr, err) } } diff --git a/router/api.go b/router/api.go index 73b901be..61582489 100644 --- a/router/api.go +++ b/router/api.go @@ -33,19 +33,25 @@ func WithFileConfig(cfgPath string) error { return Instance().ApplyFileCfg(cfgPath, "", defaultDbPath, true) } -func WithNewConfig(tenantName string) { //tenant name +func WithNewConfig(tenantName string) error { //tenant name Instance().Terminate() - dbservice.ConfigureDb(defaultDbPath, "", "") + if err := dbservice.ConfigureDb(defaultDbPath, "", ""); err != nil { + return err + } Instance().cleanInstance() Instance().cleanChannels(true) + return nil } //initialize instance with custom db location -func WithNewConfigAndDbPath(tenantName, dbPath string) { //tenant name +func WithNewConfigAndDbPath(tenantName, dbPath string) error { //tenant name Instance().Terminate() - dbservice.ConfigureDb(dbPath, "", "") + if err := dbservice.ConfigureDb(dbPath, "", ""); err != nil { + return err + } Instance().cleanInstance() Instance().cleanChannels(true) + return nil } func WithDefaultConfigAndDbPath(dbPath string) error { @@ -57,15 +63,21 @@ func WithFileConfigAndDbPath(cfgPath, dbPath string) error { return Instance().ApplyFileCfg(cfgPath, "", dbPath, true) } -func WithPostgresParams(tenantName, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) { +func WithPostgresParams(tenantName, dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode string) error { postgresUrl := buildPostgresUrl(dbName, dbHostName, dbPort, dbUser, dbPassword, dbSslMode) Instance().Terminate() - Instance().ApplyPostgresCfg(tenantName, postgresUrl, true) + if err := Instance().ApplyPostgresCfg(tenantName, postgresUrl, true); err != nil { + return err + } + return nil } -func WithPostgresUrl(tenantName, postgresUrl string) { +func WithPostgresUrl(tenantName, postgresUrl string) error { Instance().Terminate() - Instance().ApplyPostgresCfg(tenantName, postgresUrl, true) + if err := Instance().ApplyPostgresCfg(tenantName, postgresUrl, true); err != nil { + return nil + } + return nil } func AquaServerUrl(aquaServerUrl string) { //optional diff --git a/router/api_integration_test.go b/router/api_integration_test.go index d2f5d858..638e2809 100644 --- a/router/api_integration_test.go +++ b/router/api_integration_test.go @@ -52,7 +52,9 @@ func TestAudit(t *testing.T) { })) defer ts.Close() - router.WithNewConfig("test") + if err := router.WithNewConfig("test"); err != nil { + t.Errorf("Unexpected WithNewConfig error: %v", err) + } err := router.AddTemplate(&data.Template{ Name: "audit-json-template", diff --git a/router/api_test.go b/router/api_test.go index d3979284..833c107e 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -476,13 +476,17 @@ var withFileConfigTest = func() error { } var withNewConfigTest = func() error { - WithNewConfig(tenantName) + if err := WithNewConfig(tenantName); err != nil { + return err + } defer os.RemoveAll(filepath.Dir(dbPath)) return nil } var withNewConfigAndDbPathTest = func() error { - WithNewConfigAndDbPath(tenantName, dbPath) + if err := WithNewConfigAndDbPath(tenantName, dbPath); err != nil { + return err + } defer os.RemoveAll(filepath.Dir(dbPath)) return nil } diff --git a/router/router.go b/router/router.go index 9e3d9c3a..5de5458a 100644 --- a/router/router.go +++ b/router/router.go @@ -461,7 +461,9 @@ func (ctx *Router) deleteOutput(outputName string, removeFromRoutes bool) error if !ok { return xerrors.Errorf("output %s is not found", outputName) } - output.Terminate() + if err := output.Terminate(); err != nil { + return err + } delete(ctx.outputs, outputName) if removeFromRoutes { From 983832ad2ab3112d05f38e5efd450d379b1cf65a Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Fri, 24 Dec 2021 19:12:30 +0600 Subject: [PATCH 45/61] fix: fixed golangci-lint test errors --- dbservice/boltdb/dbparam_test.go | 13 ++-- dbservice/postgresdb/actions.go | 3 +- dbservice/postgresdb/cfgcachesource_test.go | 4 +- dbservice/postgresdb/dbaggregator.go | 3 +- dbservice/postgresdb/dbservice_test.go | 4 +- dbservice/postgresdb/plgstats.go | 3 +- dbservice/postgresdb/plgstats_test.go | 3 +- router/api_test.go | 75 +++++++++++++++------ 8 files changed, 73 insertions(+), 35 deletions(-) diff --git a/dbservice/boltdb/dbparam_test.go b/dbservice/boltdb/dbparam_test.go index eff07252..add8a146 100644 --- a/dbservice/boltdb/dbparam_test.go +++ b/dbservice/boltdb/dbparam_test.go @@ -1,7 +1,6 @@ package boltdb import ( - "errors" "os" "path/filepath" "strings" @@ -18,12 +17,12 @@ func TestSetNewDbPathFromEnv(t *testing.T) { pathToDb string changePermission bool expectedDBPath string - expectedErr error + expectedErr string }{ - {"Empty pathToDb", "", false, defaultDbPath, nil}, - {"Permission denied to create directory(default DbPath is used)", "/database/database.db", false, defaultDbPath, errors.New("mkdir /database: permission denied")}, - {"New DbPath", "./base/base.db", false, "./base/base.db", nil}, - {"Permission denied to check directory(default DbPath is used)", "webhook/database/webhooks.db", true, defaultDbPath, errors.New("stat webhook/database/webhooks.db: permission denied")}, + {"Empty pathToDb", "", false, defaultDbPath, ""}, + {"Permission denied to create directory(default DbPath is used)", "/database/database.db", false, defaultDbPath, "mkdir /database: permission denied"}, + {"New DbPath", "./base/base.db", false, "./base/base.db", ""}, + {"Permission denied to check directory(default DbPath is used)", "webhook/database/webhooks.db", true, defaultDbPath, "stat webhook/database/webhooks.db: permission denied"}, } for _, test := range tests { @@ -38,7 +37,7 @@ func TestSetNewDbPathFromEnv(t *testing.T) { t.Errorf("Can't change the mode dir in %s: %s", baseDir, err) } } - if err := db.SetNewDbPath(test.pathToDb); err != nil && errors.Is(err, test.expectedErr) { + if err := db.SetNewDbPath(test.pathToDb); err != nil && err.Error() != test.expectedErr { t.Errorf("unexpected error setNewDbPath, expected: %v, got: %v", test.expectedErr, err) } defer os.RemoveAll(baseDir) diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go index 7ae23614..3e5e0641 100644 --- a/dbservice/postgresdb/actions.go +++ b/dbservice/postgresdb/actions.go @@ -2,6 +2,7 @@ package postgresdb import ( "database/sql" + "errors" "fmt" "time" @@ -18,7 +19,7 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin currentValue := "" sqlQuery := fmt.Sprintf("SELECT messageValue FROM %s WHERE (tenantName=$1 AND messageKey=$2)", dbparam.DbBucketName) if err = db.Get(¤tValue, sqlQuery, postgresDb.TenantName, messageKey); err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { return false, err } } diff --git a/dbservice/postgresdb/cfgcachesource_test.go b/dbservice/postgresdb/cfgcachesource_test.go index 91298f0b..22ee3df2 100644 --- a/dbservice/postgresdb/cfgcachesource_test.go +++ b/dbservice/postgresdb/cfgcachesource_test.go @@ -27,7 +27,9 @@ func TestUpdateCfgCacheSource(t *testing.T) { insertCfgCacheSource = savedInsertCfgCacheSource }() - UpdateCfgCacheSource(db, "cfgFile") + if err := UpdateCfgCacheSource(db, "cfgFile"); err != nil { + t.Errorf("Unexpected error: %v", err) + } cfg, err := GetCfgCacheSource(db) if err != nil { diff --git a/dbservice/postgresdb/dbaggregator.go b/dbservice/postgresdb/dbaggregator.go index 5117f28e..cd96e4fd 100644 --- a/dbservice/postgresdb/dbaggregator.go +++ b/dbservice/postgresdb/dbaggregator.go @@ -3,6 +3,7 @@ package postgresdb import ( "database/sql" "encoding/json" + "errors" "fmt" "github.com/aquasecurity/postee/dbservice/dbparam" @@ -26,7 +27,7 @@ func (postgresDb *PostgresDb) AggregateScans(output string, currentValue := []byte{} sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (tenantName=$1 AND %s=$2)", "saving", dbparam.DbBucketAggregator, "output") if err = db.Get(¤tValue, sqlQuery, postgresDb.TenantName, output); err != nil { - if err != sql.ErrNoRows { + if !errors.Is(err, sql.ErrNoRows) { return nil, err } } diff --git a/dbservice/postgresdb/dbservice_test.go b/dbservice/postgresdb/dbservice_test.go index b0c079e0..7e5ebc89 100644 --- a/dbservice/postgresdb/dbservice_test.go +++ b/dbservice/postgresdb/dbservice_test.go @@ -138,7 +138,7 @@ func TestDeleteRowsByTenantNameAndTime(t *testing.T) { } psqlDb, _ := psqlConnect(db.ConnectUrl) err := deleteRowsByTenantNameAndTime(psqlDb, "tenantName", time.Now()) - if test.expectedError != err { + if !errors.Is(test.expectedError, err) { t.Errorf("Unexpected error, expected: %v, got: %v", test.expectedError, err) } }) @@ -172,7 +172,7 @@ func TestDeleteRowsByTenantName(t *testing.T) { } psqlDb, _ := psqlConnect(db.ConnectUrl) err := deleteRowsByTenantName(psqlDb, "table", "tenantName") - if test.expectedError != err { + if !errors.Is(test.expectedError, err) { t.Errorf("Unexpected error, expected: %v, got: %v", test.expectedError, err) } } diff --git a/dbservice/postgresdb/plgstats.go b/dbservice/postgresdb/plgstats.go index fc29229d..867a6d31 100644 --- a/dbservice/postgresdb/plgstats.go +++ b/dbservice/postgresdb/plgstats.go @@ -2,6 +2,7 @@ package postgresdb import ( "database/sql" + "errors" "fmt" "github.com/aquasecurity/postee/dbservice/dbparam" @@ -18,7 +19,7 @@ func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { amount := 0 sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (tenantName=$1 AND %s=$2)", "amount", dbparam.DbBucketOutputStats, "outputName") err = db.Get(&amount, sqlQuery, postgresDb.TenantName, name) - if err != nil && err != sql.ErrNoRows { + if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } amount += 1 diff --git a/dbservice/postgresdb/plgstats_test.go b/dbservice/postgresdb/plgstats_test.go index b9a0a709..d990d719 100644 --- a/dbservice/postgresdb/plgstats_test.go +++ b/dbservice/postgresdb/plgstats_test.go @@ -2,6 +2,7 @@ package postgresdb import ( "database/sql" + "errors" "log" "testing" @@ -70,7 +71,7 @@ func TestRegisterPlgnInvctnErrors(t *testing.T) { insertOutputStats = savedInsertOutputStats }() err := db.RegisterPlgnInvctn("testName") - if err != test.expectedErr { + if !errors.Is(err, test.expectedErr) { t.Errorf("Errors no contains: expected: %v, got: %v", test.expectedErr, err) } }) diff --git a/router/api_test.go b/router/api_test.go index 833c107e..f24ad7d9 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -1,7 +1,7 @@ package router import ( - "errors" + "encoding/json" "fmt" "io/ioutil" "os" @@ -114,7 +114,7 @@ func TestEditOutput(t *testing.T) { } defer Instance().cleanInstance() modifiedUrl := "https://hooks.slack.com/services/TAAAA/XXX/" - expectedError := errors.New("output badName is not found") + expectedError := "output badName is not found" if err := AddOutput(outputSettings); err != nil { t.Errorf("Can't add output: %v", err) @@ -133,7 +133,7 @@ func TestEditOutput(t *testing.T) { assert.Equal(t, modifiedUrl, Instance().outputs["my-slack"].(*outputs.SlackOutput).Url, "url is updated") err := UpdateOutput(&data.OutputSettings{Name: "badName"}) - if err != nil && errors.Is(err, expectedError) { + if err != nil && err.Error() != expectedError { t.Errorf("unexpected error, expected: %v, got: %v", expectedError, err) } } @@ -143,7 +143,9 @@ func TestListOutput(t *testing.T) { } defer Instance().cleanInstance() - AddOutput(outputSettings) + if err := AddOutput(outputSettings); err != nil { + t.Errorf("Unexpected AddOutput error: %v", err) + } assert.Equal(t, 1, len(Instance().outputs), "one output expected") outputs := ListOutputs() @@ -181,7 +183,9 @@ func TestDeleteRoute(t *testing.T) { AddRoute(inputRouteHtml) assert.Equal(t, 3, len(Instance().inputRoutes), "three route expected") - DeleteRoute("my-route") + if err := DeleteRoute("my-route"); err != nil { + t.Errorf("Unexpected DeleteRoute error: %v", err) + } assert.Equal(t, 2, len(Instance().inputRoutes), "two routes expected") assert.NotContains(t, Instance().inputRoutes, "my-route") } @@ -192,7 +196,7 @@ func TestEditRoute(t *testing.T) { } defer Instance().cleanInstance() modifiedTemplate := "vuls-slack" - expectedError := errors.New("output badName is not found") + expectedError := "output badName is not found" AddRoute(inputRoute) assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") @@ -204,13 +208,15 @@ func TestEditRoute(t *testing.T) { *Instance().inputRoutes["my-route"] = savedTempalate }() - UpdateRoute(r) + if err := UpdateRoute(r); err != nil { + t.Errorf("Unexpected AddTemplate error: %v", err) + } assert.Equal(t, 1, len(Instance().inputRoutes), "one route expected") assert.Equal(t, modifiedTemplate, Instance().inputRoutes["my-route"].Template, "template is updated") err := UpdateRoute(&routes.InputRoute{Name: "badName"}) - if err != nil && errors.Is(err, expectedError) { + if err != nil && err.Error() != expectedError { t.Errorf("unexpected error, expected: %v, got: %v", expectedError, err) } } @@ -244,7 +250,9 @@ func TestAddTemplate(t *testing.T) { } defer Instance().cleanInstance() - AddTemplate(template) + if err := AddTemplate(template); err != nil { + t.Errorf("Unexpected AddTemplate error: %v", err) + } assert.Equal(t, 1, len(Instance().templates), "one template expected") assert.Contains(t, Instance().templates, "legacy") assert.Equal(t, "*formatting.legacyScnEvaluator", fmt.Sprintf("%T", Instance().templates["legacy"]), "check name failed") @@ -282,13 +290,20 @@ func TestDeleteTemplate(t *testing.T) { } defer Instance().cleanInstance() - AddTemplate(template) - AddTemplate(templateSlack) + if err := AddTemplate(template); err != nil { + t.Errorf("Unexpected AddTemplate error: %v", err) + } + + if err := AddTemplate(templateSlack); err != nil { + t.Errorf("Unexpected AddTemplate error: %v", err) + } assert.Equal(t, 2, len(Instance().templates), "two template expected") AddRoute(&routes.InputRoute{Name: "my-route", Template: "legacy"}) assert.Equal(t, "legacy", Instance().inputRoutes["my-route"].Template, "one template expected") - DeleteTemplate("legacy") + if err := DeleteTemplate("legacy"); err != nil { + t.Errorf("Unexpected DeleteTemplate error: %v", err) + } assert.Equal(t, 1, len(Instance().templates), "one templates expected") assert.NotContains(t, Instance().templates, "legacy") assert.Equal(t, "", Instance().inputRoutes["my-route"].Template, "no template expected") @@ -299,9 +314,11 @@ func TestEditTemplate(t *testing.T) { Instance().cleanInstance() } defer Instance().cleanInstance() - expectedError := errors.New("template badName is not found") + expectedError := "template badName is not found" - AddTemplate(template) + if err := AddTemplate(template); err != nil { + t.Errorf("Unexpected AddTemplate error: %v", err) + } assert.Equal(t, 1, len(Instance().templates), "one template expected") assert.Equal(t, "*formatting.legacyScnEvaluator", fmt.Sprintf("%T", Instance().templates["legacy"]), "legacyScnEvaluator expected") @@ -319,7 +336,7 @@ func TestEditTemplate(t *testing.T) { assert.Equal(t, "*regoservice.regoEvaluator", fmt.Sprintf("%T", Instance().templates["legacy"]), "ScanRenderer is updated") err = UpdateTemplate(&data.Template{Name: "badName"}) - if err != nil && errors.Is(err, expectedError) { + if err != nil && err.Error() != expectedError { t.Errorf("unexpected error, expected: %v, got: %v", expectedError, err) } } @@ -330,7 +347,9 @@ func TestListTemplate(t *testing.T) { } defer Instance().cleanInstance() - AddTemplate(template) + if err := AddTemplate(template); err != nil { + t.Errorf("Unexpected AddTemplate error: %v", err) + } assert.Equal(t, 1, len(Instance().templates), "one route expected") templates := ListTemplates() @@ -378,8 +397,8 @@ func TestConfigFuncs(t *testing.T) { {"WithFileConfigAndDbPath", withFileConfigAndDbPathTest, "", false, "raw", "my-slack", "route1", "test/webhooks.db", ""}, {"WithNewConfig", withNewConfigTest, "", true, "", "", "", "./webhooks.db", ""}, {"WithNewConfigAndDbPath", withNewConfigAndDbPathTest, "", true, "", "", "", "test/webhooks.db", ""}, - {"WithPostgresParams", withPostgresParamsTest, "ParamsTenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, - {"WithPostgresUrl", withPostgresUrlTest, "tenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode"}, + {"WithPostgresParams", withPostgresParamsTest, "ParamsTenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:543/ParamsDbName?sslmode=ParamsSslMode"}, + {"WithPostgresUrl", withPostgresUrlTest, "tenantName", true, "", "", "", "", "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:543/ParamsDbName?sslmode=ParamsSslMode"}, } for _, test := range tests { t.Run("test "+test.funcName, func(t *testing.T) { @@ -522,21 +541,35 @@ var withDefaultConfigAndDbPathTest = func() error { var withPostgresParamsTest = func() error { savedInitPostgresDb := postgresdb.InitPostgresDb postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } + savedGetCfgCacheSource := postgresdb.GetCfgCacheSource + postgresdb.GetCfgCacheSource = func(postgresDb *postgresdb.PostgresDb) (string, error) { + j, _ := json.Marshal(outputSettings) + return string(j), nil + } defer func() { postgresdb.InitPostgresDb = savedInitPostgresDb + postgresdb.GetCfgCacheSource = savedGetCfgCacheSource }() - WithPostgresParams("ParamsTenantName", "ParamsDbName", "ParamsDbHostName", "ParamsPort", "ParamsUser", "ParamsPassword", "ParamsSslMode") + err := WithPostgresParams("ParamsTenantName", "ParamsDbName", "ParamsDbHostName", "543", "ParamsUser", "ParamsPassword", "ParamsSslMode") + if err != nil { + return err + } return nil } var withPostgresUrlTest = func() error { savedInitPostgresDb := postgresdb.InitPostgresDb postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } + savedGetCfgCacheSource := postgresdb.GetCfgCacheSource + postgresdb.GetCfgCacheSource = func(postgresDb *postgresdb.PostgresDb) (string, error) { return "", nil } defer func() { postgresdb.InitPostgresDb = savedInitPostgresDb + postgresdb.GetCfgCacheSource = savedGetCfgCacheSource }() - psqlUrl := "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:ParamsPort/ParamsDbName?sslmode=ParamsSslMode" - WithPostgresUrl(tenantName, psqlUrl) + psqlUrl := "postgres://ParamsUser:ParamsPassword@ParamsDbHostName:543/ParamsDbName?sslmode=ParamsSslMode" + if err := WithPostgresUrl(tenantName, psqlUrl); err != nil { + return err + } return nil } From 15f4b9d786752c582e85722b142e9d1552ec4ce2 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Tue, 28 Dec 2021 12:12:24 +0600 Subject: [PATCH 46/61] feat: added custom logger --- log/logger.go | 9 ++++----- log/stdoutlogger/stdoutlogger.go | 13 +++++++------ main.go | 4 ++-- utils/utils.go | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/log/logger.go b/log/logger.go index cf14df83..5205e87f 100644 --- a/log/logger.go +++ b/log/logger.go @@ -2,7 +2,7 @@ package log import "github.com/aquasecurity/postee/log/stdoutlogger" -var Logger LoggerType = stdoutlogger.NewLogger() +var Logger LoggerType = initDefaultLogger() type LoggerType interface { Info(args ...interface{}) @@ -12,14 +12,13 @@ type LoggerType interface { Warn(args ...interface{}) Warnf(template string, args ...interface{}) Debug(args ...interface{}) - DebugF(template string, args ...interface{}) + Debugf(template string, args ...interface{}) Fatal(args ...interface{}) Fatalf(template string, args ...interface{}) } -func InitDefaultLogger() { - logType := stdoutlogger.NewLogger() - Logger = logType +func initDefaultLogger() LoggerType { + return stdoutlogger.NewLogger() } func SetLogger(loggerType LoggerType) { diff --git a/log/stdoutlogger/stdoutlogger.go b/log/stdoutlogger/stdoutlogger.go index f916c529..ee04d61c 100644 --- a/log/stdoutlogger/stdoutlogger.go +++ b/log/stdoutlogger/stdoutlogger.go @@ -16,6 +16,7 @@ const ( warnLevel = colorYellow + " [WARN] " + colorReset errorLevel = colorRed + " [ERROR] " + colorReset debugLevel = colorPurple + " [DEBUG] " + colorReset + fatalLevel = colorRed + " [FATAL] " + colorReset ) type StdOutLogger struct { @@ -42,6 +43,10 @@ func (stdOutLogger StdOutLogger) Debug(args ...interface{}) { stdOutLogger.logger.Print(debugLevel + getMessage("", args)) } +func (stdOutLogger StdOutLogger) Fatal(args ...interface{}) { + stdOutLogger.logger.Fatal(fatalLevel + getMessage("", args)) +} + func (stdOutLogger StdOutLogger) Infof(template string, args ...interface{}) { stdOutLogger.logger.Print(infoLevel + getMessage(template, args)) } @@ -54,16 +59,12 @@ func (stdOutLogger StdOutLogger) Warnf(template string, args ...interface{}) { stdOutLogger.logger.Print(warnLevel + getMessage(template, args)) } -func (stdOutLogger StdOutLogger) DebugF(template string, args ...interface{}) { +func (stdOutLogger StdOutLogger) Debugf(template string, args ...interface{}) { stdOutLogger.logger.Print(debugLevel + getMessage(template, args)) } -func (stdOutLogger StdOutLogger) Fatal(args ...interface{}) { - stdOutLogger.logger.Fatal(args...) -} - func (stdOutLogger StdOutLogger) Fatalf(template string, args ...interface{}) { - stdOutLogger.logger.Fatalf(template, args...) + stdOutLogger.logger.Fatal(fatalLevel + getMessage(template, args)) } func getMessage(template string, fmtArgs []interface{}) string { diff --git a/main.go b/main.go index 7d24f66e..334d3cf5 100644 --- a/main.go +++ b/main.go @@ -79,7 +79,7 @@ func main() { err := router.Instance().ApplyFileCfg(cfgfile, postgresUrl, pathToDb, false) if err != nil { - log.Logger.Errorf("Can't start alert manager %v", err) + log.Logger.Fatalf("Can't start alert manager: %v", err) return } @@ -92,7 +92,7 @@ func main() { } err := rootCmd.Execute() if err != nil { - log.Logger.Errorf("Can't start command %v", err) + log.Logger.Fatalf("Can't start command: %v", err) return } } diff --git a/utils/utils.go b/utils/utils.go index de3c3954..df037bbb 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -33,7 +33,7 @@ func InitDebug() { func Debug(format string, v ...interface{}) { if dbg { - log.Logger.DebugF(format, v...) + log.Logger.Debugf(format, v...) } } From ff2528fceb1145b128ad41319f52185c3457a8d9 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 29 Dec 2021 16:18:40 +0600 Subject: [PATCH 47/61] fix: fixed load psql cfg with empty table --- dbservice/postgresdb/cfgcachesource.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/dbservice/postgresdb/cfgcachesource.go b/dbservice/postgresdb/cfgcachesource.go index 8ab32bcf..a790e2e7 100644 --- a/dbservice/postgresdb/cfgcachesource.go +++ b/dbservice/postgresdb/cfgcachesource.go @@ -1,11 +1,16 @@ package postgresdb import ( + "database/sql" + "errors" "fmt" + "log" "github.com/aquasecurity/postee/dbservice/dbparam" ) +const emptyCfg = `{"outputs":[],"routes":[],"templates":[]}` + var UpdateCfgCacheSource = func(postgresDb *PostgresDb, cfgfile string) error { connectUrl := postgresDb.ConnectUrl db, err := psqlConnect(connectUrl) @@ -29,7 +34,11 @@ var GetCfgCacheSource = func(postgresDb *PostgresDb) (string, error) { cfgFile := "" sqlQuery := fmt.Sprintf("SELECT configfile FROM %s WHERE tenantName=$1", dbparam.DbTableCfgCacheSource) if err = db.Get(&cfgFile, sqlQuery, postgresDb.TenantName); err != nil { - return "", err + if errors.Is(err, sql.ErrNoRows) { + log.Printf("WARNING: %s doesn't include tenantName: %s, empty cfg is used", dbparam.DbTableCfgCacheSource, postgresDb.TenantName) + return emptyCfg, nil + } + return "", fmt.Errorf("error getting cfg cache source: %v", err) } return cfgFile, nil } From cc9e31dc86c5224fbbba73c9c14ffbed3d3d7943 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 29 Dec 2021 16:31:32 +0600 Subject: [PATCH 48/61] fix: fixed load psql cfg with empty table --- dbservice/postgresdb/cfgcachesource.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dbservice/postgresdb/cfgcachesource.go b/dbservice/postgresdb/cfgcachesource.go index a790e2e7..93f11f24 100644 --- a/dbservice/postgresdb/cfgcachesource.go +++ b/dbservice/postgresdb/cfgcachesource.go @@ -4,12 +4,11 @@ import ( "database/sql" "errors" "fmt" - "log" "github.com/aquasecurity/postee/dbservice/dbparam" ) -const emptyCfg = `{"outputs":[],"routes":[],"templates":[]}` +const emptyCfg = `{}` var UpdateCfgCacheSource = func(postgresDb *PostgresDb, cfgfile string) error { connectUrl := postgresDb.ConnectUrl @@ -35,7 +34,6 @@ var GetCfgCacheSource = func(postgresDb *PostgresDb) (string, error) { sqlQuery := fmt.Sprintf("SELECT configfile FROM %s WHERE tenantName=$1", dbparam.DbTableCfgCacheSource) if err = db.Get(&cfgFile, sqlQuery, postgresDb.TenantName); err != nil { if errors.Is(err, sql.ErrNoRows) { - log.Printf("WARNING: %s doesn't include tenantName: %s, empty cfg is used", dbparam.DbTableCfgCacheSource, postgresDb.TenantName) return emptyCfg, nil } return "", fmt.Errorf("error getting cfg cache source: %v", err) From 9e40f3d60e5fc7ddcd9604372cb696f87a4685fa Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 29 Dec 2021 16:51:43 +0600 Subject: [PATCH 49/61] test: added test for errors to getCfgCacheSource --- dbservice/postgresdb/cfgcachesource_test.go | 43 +++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/dbservice/postgresdb/cfgcachesource_test.go b/dbservice/postgresdb/cfgcachesource_test.go index 22ee3df2..b6526c69 100644 --- a/dbservice/postgresdb/cfgcachesource_test.go +++ b/dbservice/postgresdb/cfgcachesource_test.go @@ -1,6 +1,8 @@ package postgresdb import ( + "database/sql" + "errors" "log" "testing" @@ -39,3 +41,44 @@ func TestUpdateCfgCacheSource(t *testing.T) { t.Errorf("CfgFiles not equals, expected: %s, got: %s", cfgFile, cfg) } } + +func TestGetCfgCacheSourceErrors(t *testing.T) { + tests := []struct { + name string + err error + expectedCfg string + expectedErr string + }{ + {"Norows error", sql.ErrNoRows, "{}", ""}, + {"select error", errors.New("select error"), "", "error getting cfg cache source: select error"}, + } + for _, test := range tests { + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(test.err) + return db, err + } + savedInsertCfgCacheSource := insertCfgCacheSource + insertCfgCacheSource = func(db *sqlx.DB, tenantName, cfgFile string) error { return nil } + defer func() { + psqlConnect = savedPsqlConnect + insertCfgCacheSource = savedInsertCfgCacheSource + }() + + cfg, err := GetCfgCacheSource(db) + if test.expectedErr != "" || err != nil { + if err.Error() != test.expectedErr { + t.Errorf("Unexpected err, expected: %v, got: %v", test.expectedErr, err) + } + } + + if cfg != test.expectedCfg { + t.Errorf("Bad cfg, expected: %s, got: %s", test.expectedCfg, cfg) + } + } + +} From 0d663fc8fe23aa535c1a0b6bd3a3cd86730187f5 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 29 Dec 2021 16:54:43 +0600 Subject: [PATCH 50/61] test: added test for errors to getCfgCacheSource --- dbservice/postgresdb/cfgcachesource_test.go | 47 +++++++++++---------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/dbservice/postgresdb/cfgcachesource_test.go b/dbservice/postgresdb/cfgcachesource_test.go index b6526c69..935f39c8 100644 --- a/dbservice/postgresdb/cfgcachesource_test.go +++ b/dbservice/postgresdb/cfgcachesource_test.go @@ -53,32 +53,35 @@ func TestGetCfgCacheSourceErrors(t *testing.T) { {"select error", errors.New("select error"), "", "error getting cfg cache source: select error"}, } for _, test := range tests { - savedPsqlConnect := psqlConnect - psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, mock, err := sqlxmock.Newx() - if err != nil { - log.Println("failed to open sqlmock database:", err) + t.Run(test.name, func(t *testing.T) { + savedPsqlConnect := psqlConnect + psqlConnect = func(connectUrl string) (*sqlx.DB, error) { + db, mock, err := sqlxmock.Newx() + if err != nil { + log.Println("failed to open sqlmock database:", err) + } + mock.ExpectQuery("SELECT").WillReturnError(test.err) + return db, err } - mock.ExpectQuery("SELECT").WillReturnError(test.err) - return db, err - } - savedInsertCfgCacheSource := insertCfgCacheSource - insertCfgCacheSource = func(db *sqlx.DB, tenantName, cfgFile string) error { return nil } - defer func() { - psqlConnect = savedPsqlConnect - insertCfgCacheSource = savedInsertCfgCacheSource - }() + savedInsertCfgCacheSource := insertCfgCacheSource + insertCfgCacheSource = func(db *sqlx.DB, tenantName, cfgFile string) error { return nil } + defer func() { + psqlConnect = savedPsqlConnect + insertCfgCacheSource = savedInsertCfgCacheSource + }() - cfg, err := GetCfgCacheSource(db) - if test.expectedErr != "" || err != nil { - if err.Error() != test.expectedErr { - t.Errorf("Unexpected err, expected: %v, got: %v", test.expectedErr, err) + cfg, err := GetCfgCacheSource(db) + if test.expectedErr != "" || err != nil { + if err.Error() != test.expectedErr { + t.Errorf("Unexpected err, expected: %v, got: %v", test.expectedErr, err) + } } - } - if cfg != test.expectedCfg { - t.Errorf("Bad cfg, expected: %s, got: %s", test.expectedCfg, cfg) - } + if cfg != test.expectedCfg { + t.Errorf("Bad cfg, expected: %s, got: %s", test.expectedCfg, cfg) + } + }) + } } From 57a0e797b424484f55df9d2ec92326a81bd276cb Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Thu, 30 Dec 2021 11:23:18 +0600 Subject: [PATCH 51/61] fix: fixed lint error --- dbservice/postgresdb/cfgcachesource.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbservice/postgresdb/cfgcachesource.go b/dbservice/postgresdb/cfgcachesource.go index 93f11f24..f9056ad4 100644 --- a/dbservice/postgresdb/cfgcachesource.go +++ b/dbservice/postgresdb/cfgcachesource.go @@ -36,7 +36,7 @@ var GetCfgCacheSource = func(postgresDb *PostgresDb) (string, error) { if errors.Is(err, sql.ErrNoRows) { return emptyCfg, nil } - return "", fmt.Errorf("error getting cfg cache source: %v", err) + return "", fmt.Errorf("error getting cfg cache source: %w", err) } return cfgFile, nil } From 17864ab7604e5a7e99c03030266658a89357a70e Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Thu, 30 Dec 2021 13:25:14 +0600 Subject: [PATCH 52/61] Refactor: code review notes are corrected --- dbservice/boltdb/checker.go | 2 +- outputs/email.go | 8 ++++---- outputs/slack.go | 2 +- regoservice/eval.go | 2 +- router/router.go | 8 ++++---- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/dbservice/boltdb/checker.go b/dbservice/boltdb/checker.go index 6d185aa4..c9d1f856 100644 --- a/dbservice/boltdb/checker.go +++ b/dbservice/boltdb/checker.go @@ -38,7 +38,7 @@ func (boltDb *BoltDb) CheckSizeLimit() { } return nil }); err != nil { - log.Logger.Errorf("Error a check of db size: %v", err) + log.Logger.Errorf("Unable to delete bucket: %v", err) return } } diff --git a/outputs/email.go b/outputs/email.go index e15e1703..d2e7624f 100644 --- a/outputs/email.go +++ b/outputs/email.go @@ -88,11 +88,11 @@ func (email *EmailOutput) Send(content map[string]string) error { auth := smtp.PlainAuth("", email.User, email.Password, email.Host) err := smtp.SendMail(email.Host+":"+strconv.Itoa(email.Port), auth, email.Sender, recipients, []byte(msg)) if err != nil { - log.Logger.Error("SendMail Error:", err) + log.Logger.Error("Placeholder is missed: ", err) log.Logger.Errorf("From: %q, to %v via %q", email.Sender, email.Recipients, email.Host) return err } - log.Logger.Infof("Email was sent successfully!") + log.Logger.Debug("Email was sent successfully!") return nil } @@ -100,7 +100,7 @@ func sendViaMxServers(from, subj, msg string, recipients []string) { for _, rcpt := range recipients { at := strings.LastIndex(rcpt, "@") if at < 0 { - log.Logger.Errorf("%q isn't email", rcpt) + log.Logger.Errorf("%q isn't valid email", rcpt) continue } host := rcpt[at+1:] @@ -121,7 +121,7 @@ func sendViaMxServers(from, subj, msg string, recipients []string) { log.Logger.Error(err) continue } - log.Logger.Infof("The message to %q was sent successful via %q!", rcpt, mx.Host) + log.Logger.Debugf("The message to %q was sent successful via %q!", rcpt, mx.Host) break } } diff --git a/outputs/slack.go b/outputs/slack.go index d2423a41..0e9e0316 100644 --- a/outputs/slack.go +++ b/outputs/slack.go @@ -78,7 +78,7 @@ func (slack *SlackOutput) Send(input map[string]string) error { rawBlock := make([]data.SlackBlock, 0) err := json.Unmarshal([]byte(body), &rawBlock) if err != nil { - log.Logger.Errorf("Unmarshal slack sending error: %v", err) + log.Logger.Errorf("Unable to parse json: %v", err) return err } diff --git a/regoservice/eval.go b/regoservice/eval.go index b3b6ebbe..a2966f98 100644 --- a/regoservice/eval.go +++ b/regoservice/eval.go @@ -78,7 +78,7 @@ func (regoEvaluator *regoEvaluator) Eval(in map[string]interface{}, serverUrl st func getFirstElement(context map[string]interface{}, key string) interface{} { for _, v := range context { - log.Logger.Infof("checking: %s ...\n", key) + log.Logger.Debugf("checking: %s ...\n", key) childCtx, ok := v.(map[string]interface{}) if !ok { return nil diff --git a/router/router.go b/router/router.go index 8ddb8c76..a59411d7 100644 --- a/router/router.go +++ b/router/router.go @@ -180,11 +180,11 @@ func (ctx *Router) Terminate() { ctx.quit <- struct{}{} } - log.Logger.Info("quit notified") + log.Logger.Debug("quit notified") if ctx.ticker != nil { ctx.stopTicker <- struct{}{} - log.Logger.Info("stopTicker notified") + log.Logger.Debug("stopTicker notified") } ctx.cleanInstance() @@ -580,12 +580,12 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { for _, outputName := range r.Outputs { pl, ok := ctx.outputs[outputName] if !ok { - log.Logger.Errorf("route %q contains an output %q, which doesn't enable now.", routeName, outputName) + log.Logger.Errorf("Route %q contains reference to not enabled output %q.", routeName, outputName) continue } tmpl, ok := ctx.templates[r.Template] if !ok { - log.Logger.Errorf("route %q contains reference to undefined or misconfigured template %q.", + log.Logger.Errorf("Route %q contains reference to undefined or misconfigured template %q.", routeName, r.Template) continue } From dc501c63d8aeeb4177e4a1c6dce93d0e767c9388 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Thu, 30 Dec 2021 14:21:57 +0600 Subject: [PATCH 53/61] Refactor: code review notes are corrected --- log/logger.go | 28 +++++++++-- log/stdoutlogger/stdoutlogger.go | 85 -------------------------------- log/zaplogger/zaplogger.go | 63 +++++++++++++++++++++++ msgservice/msghandling.go | 4 +- msgservice/scheduler.go | 2 +- outputs/email.go | 2 +- outputs/jira.go | 8 +-- router/router.go | 20 ++++---- teams/teams_requests.go | 7 +-- webserver/webserver.go | 12 ++--- 10 files changed, 115 insertions(+), 116 deletions(-) delete mode 100644 log/stdoutlogger/stdoutlogger.go create mode 100644 log/zaplogger/zaplogger.go diff --git a/log/logger.go b/log/logger.go index 5205e87f..2d8777a3 100644 --- a/log/logger.go +++ b/log/logger.go @@ -1,6 +1,11 @@ package log -import "github.com/aquasecurity/postee/log/stdoutlogger" +import ( + "log" + "os" + + "github.com/aquasecurity/postee/log/zaplogger" +) var Logger LoggerType = initDefaultLogger() @@ -18,9 +23,24 @@ type LoggerType interface { } func initDefaultLogger() LoggerType { - return stdoutlogger.NewLogger() + debug := false + disable := false + + if os.Getenv("POSTEE_DEBUG") != "" { + debug = true + } + + if os.Getenv("POSTEE_QUIET") != "" { + disable = true + } + + logger, err := zaplogger.NewLogger(debug, disable) + if err != nil { + log.Fatalf("failed to initialize a logger: %v", err) + } + return logger } -func SetLogger(loggerType LoggerType) { - Logger = loggerType +func SetLogger(logger LoggerType) { + Logger = logger } diff --git a/log/stdoutlogger/stdoutlogger.go b/log/stdoutlogger/stdoutlogger.go deleted file mode 100644 index ee04d61c..00000000 --- a/log/stdoutlogger/stdoutlogger.go +++ /dev/null @@ -1,85 +0,0 @@ -package stdoutlogger - -import ( - "fmt" - "log" - "os" -) - -const ( - colorReset = "\033[0m" - colorRed = "\033[31m" - colorBlue = "\033[34m" - colorYellow = "\033[33m" - colorPurple = "\033[35m" - infoLevel = colorBlue + " [INFO] " + colorReset - warnLevel = colorYellow + " [WARN] " + colorReset - errorLevel = colorRed + " [ERROR] " + colorReset - debugLevel = colorPurple + " [DEBUG] " + colorReset - fatalLevel = colorRed + " [FATAL] " + colorReset -) - -type StdOutLogger struct { - logger log.Logger -} - -func NewLogger() StdOutLogger { - logger := log.New(os.Stdout, "", log.Ldate|log.Ltime) - return StdOutLogger{logger: *logger} -} - -func (stdOutLogger StdOutLogger) Info(args ...interface{}) { - stdOutLogger.logger.Print(infoLevel + getMessage("", args)) -} - -func (stdOutLogger StdOutLogger) Error(args ...interface{}) { - stdOutLogger.logger.Print(errorLevel + getMessage("", args)) -} - -func (stdOutLogger StdOutLogger) Warn(args ...interface{}) { - stdOutLogger.logger.Print(warnLevel + getMessage("", args)) -} -func (stdOutLogger StdOutLogger) Debug(args ...interface{}) { - stdOutLogger.logger.Print(debugLevel + getMessage("", args)) -} - -func (stdOutLogger StdOutLogger) Fatal(args ...interface{}) { - stdOutLogger.logger.Fatal(fatalLevel + getMessage("", args)) -} - -func (stdOutLogger StdOutLogger) Infof(template string, args ...interface{}) { - stdOutLogger.logger.Print(infoLevel + getMessage(template, args)) -} - -func (stdOutLogger StdOutLogger) Errorf(template string, args ...interface{}) { - stdOutLogger.logger.Print(errorLevel + getMessage(template, args)) -} - -func (stdOutLogger StdOutLogger) Warnf(template string, args ...interface{}) { - stdOutLogger.logger.Print(warnLevel + getMessage(template, args)) -} - -func (stdOutLogger StdOutLogger) Debugf(template string, args ...interface{}) { - stdOutLogger.logger.Print(debugLevel + getMessage(template, args)) -} - -func (stdOutLogger StdOutLogger) Fatalf(template string, args ...interface{}) { - stdOutLogger.logger.Fatal(fatalLevel + getMessage(template, args)) -} - -func getMessage(template string, fmtArgs []interface{}) string { - if len(fmtArgs) == 0 { - return template - } - - if template != "" { - return fmt.Sprintf(template, fmtArgs...) - } - - if len(fmtArgs) == 1 { - if str, ok := fmtArgs[0].(string); ok { - return str - } - } - return fmt.Sprint(fmtArgs...) -} diff --git a/log/zaplogger/zaplogger.go b/log/zaplogger/zaplogger.go new file mode 100644 index 00000000..feff48a9 --- /dev/null +++ b/log/zaplogger/zaplogger.go @@ -0,0 +1,63 @@ +package zaplogger + +import ( + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func NewLogger(debug, disable bool) (*zap.SugaredLogger, error) { + // First, define our level-handling logic. + errorPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= zapcore.ErrorLevel + }) + logPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + if debug { + return lvl < zapcore.ErrorLevel + } + // Not enable debug level + return zapcore.DebugLevel < lvl && lvl < zapcore.ErrorLevel + }) + + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "Time", + LevelKey: "Level", + NameKey: "Name", + CallerKey: "Caller", + MessageKey: "Msg", + StacktraceKey: "St", + EncodeLevel: zapcore.CapitalColorLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.StringDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + consoleEncoder := zapcore.NewConsoleEncoder(encoderConfig) + + // High-priority output should also go to standard error, and low-priority + // output should also go to standard out. + consoleLogs := zapcore.Lock(os.Stdout) + consoleErrors := zapcore.Lock(os.Stderr) + if disable { + devNull, err := os.Create(os.DevNull) + if err != nil { + return nil, err + } + // Discard low-priority output + consoleLogs = zapcore.Lock(devNull) + } + + core := zapcore.NewTee( + zapcore.NewCore(consoleEncoder, consoleErrors, errorPriority), + zapcore.NewCore(consoleEncoder, consoleLogs, logPriority), + ) + + opts := []zap.Option{zap.ErrorOutput(zapcore.Lock(os.Stderr))} + if debug { + opts = append(opts, zap.Development()) + } + logger := zap.New(core, opts...) + + return logger.Sugar(), nil +} diff --git a/msgservice/msghandling.go b/msgservice/msghandling.go index 161e96c3..00672ed8 100644 --- a/msgservice/msghandling.go +++ b/msgservice/msghandling.go @@ -100,10 +100,10 @@ func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Ou AggregateScanAndGetQueue(route.Name, content, 0, true) if !route.IsSchedulerRun() { //TODO route shouldn't have any associated logic - log.Logger.Infof("about to schedule %s\n", route.Name) + log.Logger.Infof("about to schedule %s", route.Name) RunScheduler(route, send, AggregateScanAndGetQueue, inpteval, &route.Name, output) } else { - log.Logger.Infof("%s is already scheduled\n", route.Name) + log.Logger.Infof("%s is already scheduled", route.Name) } } else { send(output, content) diff --git a/msgservice/scheduler.go b/msgservice/scheduler.go index 6546f781..bb27da78 100644 --- a/msgservice/scheduler.go +++ b/msgservice/scheduler.go @@ -38,7 +38,7 @@ var RunScheduler = func( if len(queue) > 0 { aggregated, err := inpteval.BuildAggregatedContent(queue) if err != nil { - log.Logger.Errorf("Unable to build aggregated contents %v\n", err) + log.Logger.Errorf("Unable to build aggregated contents %v", err) } fnSend(output, aggregated) } diff --git a/outputs/email.go b/outputs/email.go index d2e7624f..2ea5c816 100644 --- a/outputs/email.go +++ b/outputs/email.go @@ -57,7 +57,7 @@ func (email *EmailOutput) Init() error { } func (email *EmailOutput) Terminate() error { - log.Logger.Infof("Email output terminated\n") + log.Logger.Infof("Email output terminated") return nil } diff --git a/outputs/jira.go b/outputs/jira.go index c613e5b4..b76ee4cd 100644 --- a/outputs/jira.go +++ b/outputs/jira.go @@ -121,7 +121,7 @@ func (ctx *JiraAPI) fetchSprintId(client jira.Client) { } func (ctx *JiraAPI) Terminate() error { - log.Logger.Infof("Jira output terminated\n") + log.Logger.Infof("Jira output terminated") return nil } @@ -244,7 +244,7 @@ func (ctx *JiraAPI) Send(content map[string]string) error { issue, err := InitIssue(client, metaProject, metaIssueType, fieldsConfig, isServerJira(ctx.Url)) if err != nil { - log.Logger.Errorf("Failed to init issue: %s\n", err) + log.Logger.Errorf("Failed to init issue: %s", err) return err } @@ -273,7 +273,7 @@ func (ctx *JiraAPI) Send(content map[string]string) error { i, err := ctx.openIssue(client, issue) if err != nil { - log.Logger.Errorf("Failed to open jira issue, %s\n", err) + log.Logger.Errorf("Failed to open jira issue, %s", err) return err } log.Logger.Infof("Created new jira issue %s", i.ID) @@ -367,7 +367,7 @@ func InitIssue(c *jira.Client, metaProject *jira.MetaProject, metaIssuetype *jir case "number": val, err := strconv.Atoi(value) if err != nil { - fmt.Printf("Failed convert value(string) to int: %s\n", err) + fmt.Printf("Failed convert value(string) to int: %s", err) } issueFields.Unknowns[jiraKey] = val diff --git a/router/router.go b/router/router.go index a59411d7..fcccd4f2 100644 --- a/router/router.go +++ b/router/router.go @@ -174,7 +174,7 @@ func (ctx *Router) Terminate() { } log.Logger.Info("Route schedulers stopped") - log.Logger.Infof("ctx.quit %v\n", ctx.quit) + log.Logger.Infof("ctx.quit %v", ctx.quit) if ctx.quit != nil { ctx.quit <- struct{}{} @@ -248,7 +248,7 @@ func removeTemplateFromCfgCacheSource(outputs *data.TenantSettings, templateName } func (ctx *Router) initTemplate(template *data.Template) error { - log.Logger.Infof("Configuring template %s \n", template.Name) + log.Logger.Infof("Configuring template %s", template.Name) if template.LegacyScanRenderer != "" { inpteval, err := formatting.BuildLegacyScnEvaluator(template.LegacyScanRenderer) @@ -256,7 +256,7 @@ func (ctx *Router) initTemplate(template *data.Template) error { return err } ctx.templates[template.Name] = inpteval - log.Logger.Infof("Configured with legacy renderer %s \n", template.LegacyScanRenderer) + log.Logger.Infof("Configured with legacy renderer %s", template.LegacyScanRenderer) } if template.RegoPackage != "" { @@ -265,10 +265,10 @@ func (ctx *Router) initTemplate(template *data.Template) error { return err } ctx.templates[template.Name] = inpteval - log.Logger.Infof("Configured with Rego package %s\n", template.RegoPackage) + log.Logger.Infof("Configured with Rego package %s", template.RegoPackage) } if template.Url != "" { - log.Logger.Infof("Configured with url: %s\n", template.Url) + log.Logger.Infof("Configured with url: %s", template.Url) r, err := http.NewRequest("GET", template.Url, nil) if err != nil { @@ -324,7 +324,7 @@ func (ctx *Router) setAquaServerUrl(url string) { func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { ctx.mutexScan.Lock() defer ctx.mutexScan.Unlock() - log.Logger.Infof("Loading alerts configuration file %s ....\n", ctx.cfgfile) + log.Logger.Infof("Loading alerts configuration file %s ....", ctx.cfgfile) ctx.setAquaServerUrl(tenant.AquaServer) @@ -356,17 +356,17 @@ func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { for _, t := range tenant.Templates { err := ctx.initTemplate(&t) if err != nil { - log.Logger.Errorf("Can not initialize template %s: %v \n", t.Name, err) + log.Logger.Errorf("Can not initialize template %s: %v", t.Name, err) } } for _, settings := range tenant.Outputs { - utils.Debug("%#v\n", anonymizeSettings(&settings)) + utils.Debug("%#v", anonymizeSettings(&settings)) err := ctx.addOutput(&settings) if err != nil { - log.Logger.Errorf("Can not initialize output %s: %v \n", settings.Name, err) + log.Logger.Errorf("Can not initialize output %s: %v", settings.Name, err) } else { log.Logger.Infof("Output %s is configured", settings.Name) } @@ -623,7 +623,7 @@ func buildAndInitOtpt(settings *data.OutputSettings, aquaServerUrl string) (outp } } - utils.Debug("Starting Output %q: %q\n", settings.Type, settings.Name) + utils.Debug("Starting Output %q: %q", settings.Type, settings.Name) var plg outputs.Output diff --git a/teams/teams_requests.go b/teams/teams_requests.go index 54f21bb5..c9d7aa5a 100644 --- a/teams/teams_requests.go +++ b/teams/teams_requests.go @@ -3,16 +3,17 @@ package teams_api import ( "bytes" "fmt" - "github.com/aquasecurity/postee/utils" "io/ioutil" "net/http" + + "github.com/aquasecurity/postee/utils" ) func CreateMessageByWebhook(webhook, content string) error { var message bytes.Buffer fmt.Fprintf(&message, "{\"text\":\"%s\"}", content) - utils.Debug("Data for sending to %q: %q\n", webhook, message.String()) + utils.Debug("Data for sending to %q: %q", webhook, message.String()) r := bytes.NewReader(message.Bytes()) client := http.DefaultClient reg, err := http.NewRequest("POST", webhook, r) @@ -32,7 +33,7 @@ func CreateMessageByWebhook(webhook, content string) error { if message[0] != '1' { return fmt.Errorf("Teams Body Error: %q", string(message)) } - utils.Debug("Response body: %q\n", message) + utils.Debug("Response body: %q", message) } return nil } diff --git a/webserver/webserver.go b/webserver/webserver.go index 5ac8d184..46e71bb4 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -37,7 +37,7 @@ func (ctx *WebServer) withApiKey(next http.HandlerFunc) http.HandlerFunc { correctKey, err := dbservice.Db.GetApiKey() if err != nil || correctKey == "" { - log.Logger.Errorf("reload API key is either empty or there is an error: %s \n", err) + log.Logger.Errorf("reload API key is either empty or there is an error: %s", err) http.Error(w, "Unauthorized", http.StatusUnauthorized) } @@ -61,7 +61,7 @@ func (ctx *WebServer) Start(host, tlshost string) { if ok := utils.PathExists(keyPem); !ok { err := utils.GenerateCertificate(keyPem, certPem) if err != nil { - log.Logger.Errorf("GenerateCertificate error: %v \n", err) + log.Logger.Errorf("GenerateCertificate error: %v", err) } } @@ -74,7 +74,7 @@ func (ctx *WebServer) Start(host, tlshost string) { } err := dbservice.Db.EnsureApiKey() if err != nil { - log.Logger.Errorf("EnsureApiKey error: %v \n", err) + log.Logger.Errorf("EnsureApiKey error: %v", err) } ctx.router.HandleFunc("/", ctx.sessionHandler(ctx.scanHandler)).Methods("POST") @@ -108,7 +108,7 @@ func (ctx *WebServer) sessionHandler(f func(http.ResponseWriter, *http.Request)) func (ctx *WebServer) scanHandler(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - log.Logger.Errorf("Failed ioutil.ReadAll: %s\n", err) + log.Logger.Errorf("Failed ioutil.ReadAll: %s", err) ctx.writeResponseError(w, http.StatusInternalServerError, err) return } @@ -130,7 +130,7 @@ func (ctx *WebServer) writeResponse(w http.ResponseWriter, httpStatus int, v int result, _ := json.Marshal(v) _, err := w.Write(result) if err != nil { - log.Logger.Errorf("Write error: %s \n", err) + log.Logger.Errorf("Write error: %s", err) } } } @@ -140,6 +140,6 @@ func (ctx *WebServer) writeResponseError(w http.ResponseWriter, httpError int, e w.WriteHeader(httpError) errEncode := json.NewEncoder(w).Encode(err) if errEncode != nil { - log.Logger.Errorf("Encode error: %s \n", errEncode) + log.Logger.Errorf("Encode error: %s", errEncode) } } From 31032c27f7980bee3510fb1424452c038517f277 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Thu, 30 Dec 2021 14:55:53 +0600 Subject: [PATCH 54/61] refactor: removed func debug --- log/logger.go | 2 +- main.go | 2 -- outputs/teams.go | 15 +++++++-------- router/router.go | 4 ++-- teams/teams_requests.go | 6 +++--- utils/utils.go | 15 --------------- webserver/tenant.go | 3 +-- webserver/webserver.go | 2 +- 8 files changed, 15 insertions(+), 34 deletions(-) diff --git a/log/logger.go b/log/logger.go index 2d8777a3..be29a1f7 100644 --- a/log/logger.go +++ b/log/logger.go @@ -26,7 +26,7 @@ func initDefaultLogger() LoggerType { debug := false disable := false - if os.Getenv("POSTEE_DEBUG") != "" { + if os.Getenv("POSTEE_DEBUG") != "" || os.Getenv("AQUAALERT_DEBUG") != "" { debug = true } diff --git a/main.go b/main.go index 334d3cf5..4f558e0f 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "syscall" "github.com/aquasecurity/postee/router" - "github.com/aquasecurity/postee/utils" "github.com/aquasecurity/postee/webserver" "github.com/spf13/cobra" @@ -45,7 +44,6 @@ func init() { func main() { runtime.GOMAXPROCS(runtime.NumCPU()) - utils.InitDebug() rootCmd.Run = func(cmd *cobra.Command, args []string) { diff --git a/outputs/teams.go b/outputs/teams.go index 1c04b408..8b3b668a 100644 --- a/outputs/teams.go +++ b/outputs/teams.go @@ -7,7 +7,6 @@ import ( "github.com/aquasecurity/postee/formatting" "github.com/aquasecurity/postee/layout" "github.com/aquasecurity/postee/log" - "github.com/aquasecurity/postee/utils" msteams "github.com/aquasecurity/postee/teams" ) @@ -44,21 +43,21 @@ func (teams *TeamsOutput) Init() error { func (teams *TeamsOutput) Send(input map[string]string) error { log.Logger.Infof("Sending to MS Teams via %q...", teams.Name) - utils.Debug("Title for %q: %q\n", teams.Name, input["title"]) - utils.Debug("Url(s) for %q: %q\n", teams.Name, input["url"]) - utils.Debug("Webhook for %q: %q\n", teams.Name, teams.Webhook) - utils.Debug("Length of Description for %q: %d/%d\n", + log.Logger.Debugf("Title for %q: %q", teams.Name, input["title"]) + log.Logger.Debugf("Url(s) for %q: %q", teams.Name, input["url"]) + log.Logger.Debugf("Webhook for %q: %q", teams.Name, teams.Webhook) + log.Logger.Debugf("Length of Description for %q: %d/%d", teams.Name, len(input["description"]), teamsSizeLimit) var body string if len(input["description"]) > teamsSizeLimit { - utils.Debug("MS Team output will send SHORT message\n") + log.Logger.Debugf("MS Team output will send SHORT message") body = buildShortMessage(teams.AquaServer, input["url"], teams.teamsLayout) } else { - utils.Debug("MS Team output will send LONG message\n") + log.Logger.Debugf("MS Team output will send LONG message") body = input["description"] } - utils.Debug("Message is: %q\n", body) + log.Logger.Debugf("Message is: %q", body) escaped, err := escapeJSON(body) if err != nil { diff --git a/router/router.go b/router/router.go index fcccd4f2..61f6a0c7 100644 --- a/router/router.go +++ b/router/router.go @@ -361,7 +361,7 @@ func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { } for _, settings := range tenant.Outputs { - utils.Debug("%#v", anonymizeSettings(&settings)) + log.Logger.Debugf("%#v", anonymizeSettings(&settings)) err := ctx.addOutput(&settings) @@ -623,7 +623,7 @@ func buildAndInitOtpt(settings *data.OutputSettings, aquaServerUrl string) (outp } } - utils.Debug("Starting Output %q: %q", settings.Type, settings.Name) + log.Logger.Debugf("Starting Output %q: %q", settings.Type, settings.Name) var plg outputs.Output diff --git a/teams/teams_requests.go b/teams/teams_requests.go index c9d7aa5a..7a700ec7 100644 --- a/teams/teams_requests.go +++ b/teams/teams_requests.go @@ -6,14 +6,14 @@ import ( "io/ioutil" "net/http" - "github.com/aquasecurity/postee/utils" + "github.com/aquasecurity/postee/log" ) func CreateMessageByWebhook(webhook, content string) error { var message bytes.Buffer fmt.Fprintf(&message, "{\"text\":\"%s\"}", content) - utils.Debug("Data for sending to %q: %q", webhook, message.String()) + log.Logger.Debugf("Data for sending to %q: %q", webhook, message.String()) r := bytes.NewReader(message.Bytes()) client := http.DefaultClient reg, err := http.NewRequest("POST", webhook, r) @@ -33,7 +33,7 @@ func CreateMessageByWebhook(webhook, content string) error { if message[0] != '1' { return fmt.Errorf("Teams Body Error: %q", string(message)) } - utils.Debug("Response body: %q", message) + log.Logger.Debugf("Response body: %q", message) } return nil } diff --git a/utils/utils.go b/utils/utils.go index df037bbb..0b3ee158 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -22,21 +22,6 @@ func GetEnvironmentVarOrPlain(value string) string { return value } -func InitDebug() { - if os.Getenv("AQUAALERT_DEBUG") != "" { - dbg = true - } - if os.Getenv("POSTEE_DEBUG") != "" { - dbg = true - } -} - -func Debug(format string, v ...interface{}) { - if dbg { - log.Logger.Debugf(format, v...) - } -} - func GetEnv(name string) (string, error) { value := os.Getenv(name) if len(value) > 0 { diff --git a/webserver/tenant.go b/webserver/tenant.go index ae92bc0d..c8c003d2 100644 --- a/webserver/tenant.go +++ b/webserver/tenant.go @@ -6,7 +6,6 @@ import ( "github.com/aquasecurity/postee/log" "github.com/aquasecurity/postee/router" - "github.com/aquasecurity/postee/utils" "github.com/gorilla/mux" ) @@ -26,7 +25,7 @@ func (ctx *WebServer) tenantHandler(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - utils.Debug("%s\n\n", string(body)) + log.Logger.Debugf("%s\n\n", string(body)) router.Instance().HandleRoute(route, body) ctx.writeResponse(w, http.StatusOK, "") } diff --git a/webserver/webserver.go b/webserver/webserver.go index 46e71bb4..44caf9de 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -114,7 +114,7 @@ func (ctx *WebServer) scanHandler(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - utils.Debug("%s\n\n", string(body)) + log.Logger.Debugf("%s\n\n", string(body)) router.Instance().Send(body) ctx.writeResponse(w, http.StatusOK, "") } From 58bfd4a8d88a7634e025c5d13799f65e1df80ab7 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Thu, 30 Dec 2021 14:58:10 +0600 Subject: [PATCH 55/61] refactor: removed func debug --- utils/utils.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index 0b3ee158..ccaa2e3d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -10,10 +10,6 @@ import ( "github.com/aquasecurity/postee/log" ) -var ( - dbg = false -) - func GetEnvironmentVarOrPlain(value string) string { const VarPrefix = "$" if strings.HasPrefix(value, VarPrefix) { From ca2bad4773373f7979a29e045b22ab90614cb928 Mon Sep 17 00:00:00 2001 From: DmitriyLewen Date: Wed, 12 Jan 2022 16:31:18 +0600 Subject: [PATCH 56/61] chore(deps): added zap dependency --- go.mod | 1 + go.sum | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/go.mod b/go.mod index a9c026bc..953c7c07 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/stretchr/testify v1.7.0 github.com/zhashkevych/go-sqlxmock v1.5.2-0.20201023121933-f973d0041cfc go.etcd.io/bbolt v1.3.6 + go.uber.org/zap v1.19.1 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 ) diff --git a/go.sum b/go.sum index f4500fba..6ad9ba6f 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,7 @@ github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hC github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= @@ -399,10 +400,15 @@ go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= +go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= +go.uber.org/zap v1.19.1 h1:ue41HOKd1vGURxrmeKIgELGb3jPW9DMUDGtsinblHwI= +go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -641,6 +647,7 @@ golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 50f31c98a1ee61a14390f73554963734480a0b13 Mon Sep 17 00:00:00 2001 From: DmitriyLewen <91113035+DmitriyLewen@users.noreply.github.com> Date: Thu, 27 Jan 2022 03:38:02 +0600 Subject: [PATCH 57/61] fix: fixed test errors (#228) --- router/api.go | 2 -- router/api_test.go | 3 +-- router/router.go | 30 ++++++++++++++++-------------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/router/api.go b/router/api.go index 61582489..012d43e6 100644 --- a/router/api.go +++ b/router/api.go @@ -38,7 +38,6 @@ func WithNewConfig(tenantName string) error { //tenant name if err := dbservice.ConfigureDb(defaultDbPath, "", ""); err != nil { return err } - Instance().cleanInstance() Instance().cleanChannels(true) return nil } @@ -49,7 +48,6 @@ func WithNewConfigAndDbPath(tenantName, dbPath string) error { //tenant name if err := dbservice.ConfigureDb(dbPath, "", ""); err != nil { return err } - Instance().cleanInstance() Instance().cleanChannels(true) return nil } diff --git a/router/api_test.go b/router/api_test.go index f24ad7d9..423c1291 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -99,7 +99,7 @@ func TestDeleteOutput(t *testing.T) { assert.Equal(t, 2, len(Instance().inputRoutes["my-route"].Outputs), "two output expected") if err := DeleteOutput("my-slack"); err != nil { - t.Errorf("Can't delte output: %v", err) + t.Errorf("Can't delete output: %v", err) } assert.Equal(t, 1, len(Instance().outputs), "one outputs expected") assert.NotContains(t, Instance().outputs, "my-slack") @@ -403,7 +403,6 @@ func TestConfigFuncs(t *testing.T) { for _, test := range tests { t.Run("test "+test.funcName, func(t *testing.T) { defer func() { - Instance().cleanInstance() dbservice.Db = nil }() diff --git a/router/router.go b/router/router.go index 61f6a0c7..7f9d3d08 100644 --- a/router/router.go +++ b/router/router.go @@ -145,7 +145,7 @@ func (ctx *Router) applyTenantCfg(tenant *data.TenantSettings, synchronous bool) ctx.cleanInstance() ctx.cleanChannels(synchronous) - err := ctx.initTenantSettings(tenant) + err := ctx.initTenantSettings(tenant, synchronous) if err != nil { return err } @@ -182,7 +182,7 @@ func (ctx *Router) Terminate() { log.Logger.Debug("quit notified") - if ctx.ticker != nil { + if ctx.ticker != nil && ctx.stopTicker != nil { ctx.stopTicker <- struct{}{} log.Logger.Debug("stopTicker notified") } @@ -321,7 +321,7 @@ func (ctx *Router) setAquaServerUrl(url string) { } } -func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { +func (ctx *Router) initTenantSettings(tenant *data.TenantSettings, synchronous bool) error { ctx.mutexScan.Lock() defer ctx.mutexScan.Unlock() log.Logger.Infof("Loading alerts configuration file %s ....", ctx.cfgfile) @@ -336,18 +336,20 @@ func (ctx *Router) initTenantSettings(tenant *data.TenantSettings) error { actualDbTestInterval = 1 } - ctx.ticker = time.NewTicker(baseForTicker * time.Duration(actualDbTestInterval)) - go func() { - for { - select { - case <-ctx.stopTicker: - return - case <-ctx.ticker.C: - dbservice.Db.CheckSizeLimit() - dbservice.Db.CheckExpiredData() + if !synchronous { + ctx.ticker = time.NewTicker(baseForTicker * time.Duration(actualDbTestInterval)) + go func() { + for { + select { + case <-ctx.stopTicker: + return + case <-ctx.ticker.C: + dbservice.Db.CheckSizeLimit() + dbservice.Db.CheckExpiredData() + } } - } - }() + }() + } for i, r := range tenant.InputRoutes { ctx.inputRoutes[r.Name] = routes.ConfigureTimeouts(&tenant.InputRoutes[i]) From b6d63003d34a60fc332a74f5e404281672eea6f7 Mon Sep 17 00:00:00 2001 From: Elad Zada Date: Sun, 30 Jan 2022 13:39:03 +0200 Subject: [PATCH 58/61] reuse postgres db conncetion instance --- dbservice/postgresdb/actions.go | 1 - dbservice/postgresdb/cfgcachesource.go | 4 ++-- dbservice/postgresdb/checker.go | 2 -- dbservice/postgresdb/connect.go | 31 ++++++++++++++++++++++---- dbservice/postgresdb/dbaggregator.go | 1 - dbservice/postgresdb/init.go | 1 - dbservice/postgresdb/plgstats.go | 1 - dbservice/postgresdb/sharedcfg.go | 3 +-- 8 files changed, 30 insertions(+), 14 deletions(-) diff --git a/dbservice/postgresdb/actions.go b/dbservice/postgresdb/actions.go index 3e5e0641..a7f634c8 100644 --- a/dbservice/postgresdb/actions.go +++ b/dbservice/postgresdb/actions.go @@ -14,7 +14,6 @@ func (postgresDb *PostgresDb) MayBeStoreMessage(message []byte, messageKey strin if err != nil { return false, err } - defer db.Close() currentValue := "" sqlQuery := fmt.Sprintf("SELECT messageValue FROM %s WHERE (tenantName=$1 AND messageKey=$2)", dbparam.DbBucketName) diff --git a/dbservice/postgresdb/cfgcachesource.go b/dbservice/postgresdb/cfgcachesource.go index f9056ad4..244c58b2 100644 --- a/dbservice/postgresdb/cfgcachesource.go +++ b/dbservice/postgresdb/cfgcachesource.go @@ -16,7 +16,7 @@ var UpdateCfgCacheSource = func(postgresDb *PostgresDb, cfgfile string) error { if err != nil { return err } - defer db.Close() + if err := insertCfgCacheSource(db, postgresDb.TenantName, cfgfile); err != nil { return err } @@ -29,7 +29,7 @@ var GetCfgCacheSource = func(postgresDb *PostgresDb) (string, error) { if err != nil { return "", err } - defer db.Close() + cfgFile := "" sqlQuery := fmt.Sprintf("SELECT configfile FROM %s WHERE tenantName=$1", dbparam.DbTableCfgCacheSource) if err = db.Get(&cfgFile, sqlQuery, postgresDb.TenantName); err != nil { diff --git a/dbservice/postgresdb/checker.go b/dbservice/postgresdb/checker.go index bff01a75..3767ae15 100644 --- a/dbservice/postgresdb/checker.go +++ b/dbservice/postgresdb/checker.go @@ -19,7 +19,6 @@ func (postgresDb *PostgresDb) CheckSizeLimit() { log.Logger.Errorf("CheckSizeLimit: Can't open db, connectUrl: %s", connectUrl) return } - defer db.Close() size := 0 if err = db.Get(&size, fmt.Sprintf("SELECT pg_total_relation_size('%s');", dbparam.DbBucketName)); err != nil { @@ -41,7 +40,6 @@ func (postgresDb *PostgresDb) CheckExpiredData() { log.Logger.Errorf("CheckExpiredData: Can't open postgresDb: %v", err) return } - defer db.Close() max := time.Now().UTC() //remove expired records if err = deleteRowsByTenantNameAndTime(db, postgresDb.TenantName, max); err != nil { diff --git a/dbservice/postgresdb/connect.go b/dbservice/postgresdb/connect.go index dbd54ef7..e6971e7a 100644 --- a/dbservice/postgresdb/connect.go +++ b/dbservice/postgresdb/connect.go @@ -2,16 +2,39 @@ package postgresdb import ( "errors" + "time" + "github.com/aquasecurity/postee/log" "github.com/jmoiron/sqlx" ) +const ( + CONN_RETRIES = 10 +) + +var dbConn *sqlx.DB + var psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - db, err := sqlx.Connect("postgres", connectUrl) - if err != nil { - return nil, err + if dbConn == nil { + retries := CONN_RETRIES + db, err := sqlx.Connect("postgres", connectUrl) + for err != nil { + if log.Logger != nil { + log.Logger.Errorf("failed to connect to postgres db (%d): %s", retries, err.Error()) + } + + if retries > 1 { + retries-- + time.Sleep(5 * time.Second) + db, err = sqlx.Connect("postgres", connectUrl) + continue + } + return nil, err + } + dbConn = db } - return db, nil + + return dbConn, nil } var testConnect = func(connectUrl string) (*sqlx.DB, error) { diff --git a/dbservice/postgresdb/dbaggregator.go b/dbservice/postgresdb/dbaggregator.go index cd96e4fd..26f32d00 100644 --- a/dbservice/postgresdb/dbaggregator.go +++ b/dbservice/postgresdb/dbaggregator.go @@ -18,7 +18,6 @@ func (postgresDb *PostgresDb) AggregateScans(output string, if err != nil { return nil, err } - defer db.Close() aggregatedScans := make([]map[string]string, 0, scansPerTicket) if len(currentScan) > 0 { diff --git a/dbservice/postgresdb/init.go b/dbservice/postgresdb/init.go index cc374050..0025307f 100644 --- a/dbservice/postgresdb/init.go +++ b/dbservice/postgresdb/init.go @@ -32,7 +32,6 @@ var InitPostgresDb = func(connectUrl string) error { if err != nil { return err } - defer db.Close() err = initAllTables(db) if err != nil { diff --git a/dbservice/postgresdb/plgstats.go b/dbservice/postgresdb/plgstats.go index 867a6d31..56fcecfd 100644 --- a/dbservice/postgresdb/plgstats.go +++ b/dbservice/postgresdb/plgstats.go @@ -14,7 +14,6 @@ func (postgresDb *PostgresDb) RegisterPlgnInvctn(name string) error { if err != nil { return err } - defer db.Close() amount := 0 sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (tenantName=$1 AND %s=$2)", "amount", dbparam.DbBucketOutputStats, "outputName") diff --git a/dbservice/postgresdb/sharedcfg.go b/dbservice/postgresdb/sharedcfg.go index 33e14741..8acf447b 100644 --- a/dbservice/postgresdb/sharedcfg.go +++ b/dbservice/postgresdb/sharedcfg.go @@ -14,7 +14,6 @@ func (postgresDb *PostgresDb) EnsureApiKey() error { if err != nil { return err } - defer db.Close() apiKey, err := dbparam.GenerateApiKey(32) if err != nil { @@ -32,7 +31,7 @@ func (postgresDb *PostgresDb) GetApiKey() (string, error) { if err != nil { return "", err } - defer db.Close() + value := "" sqlQuery := fmt.Sprintf("SELECT %s FROM %s WHERE (tenantName=$1 AND %s=$2)", "value", dbparam.DbBucketSharedConfig, "apikeyname") err = db.Get(&value, sqlQuery, postgresDb.TenantName, apiKeyName) From 17027ccefaea0e2d8f35278fbb1b84e9c57ac4da Mon Sep 17 00:00:00 2001 From: Elad Zada Date: Sun, 30 Jan 2022 19:07:27 +0200 Subject: [PATCH 59/61] adjust unit test for Postgres singleton and retrying functionality --- dbservice/boltdb/boltdb.go | 5 +++++ dbservice/dbservice.go | 1 + dbservice/postgresdb/close.go | 10 ++++++++++ dbservice/postgresdb/connect.go | 16 +++++++++------- dbservice/postgresdb/connect_test.go | 8 -------- router/api_test.go | 8 ++++++++ router/inittemplate_test.go | 18 +++++++++++------- router/loads_test.go | 2 +- 8 files changed, 45 insertions(+), 23 deletions(-) create mode 100644 dbservice/postgresdb/close.go diff --git a/dbservice/boltdb/boltdb.go b/dbservice/boltdb/boltdb.go index cea58c95..5cde909a 100644 --- a/dbservice/boltdb/boltdb.go +++ b/dbservice/boltdb/boltdb.go @@ -42,3 +42,8 @@ func (boltDb *BoltDb) SetNewDbPath(newPath string) error { } return nil } + +// unimplemented +func (boltDb *BoltDb) Close() error { + return nil +} diff --git a/dbservice/dbservice.go b/dbservice/dbservice.go index 318d87ae..3c54ff39 100644 --- a/dbservice/dbservice.go +++ b/dbservice/dbservice.go @@ -22,6 +22,7 @@ type DbProvider interface { RegisterPlgnInvctn(name string) error EnsureApiKey() error GetApiKey() (string, error) + Close() error } func ConfigureDb(pathToDb, postgresUrl, tenantName string) error { diff --git a/dbservice/postgresdb/close.go b/dbservice/postgresdb/close.go new file mode 100644 index 00000000..182b90d2 --- /dev/null +++ b/dbservice/postgresdb/close.go @@ -0,0 +1,10 @@ +package postgresdb + +func (postgresDb *PostgresDb) Close() error { + db, err := psqlConnect(postgresDb.ConnectUrl) + if err != nil { + return err + } + + return db.Close() +} diff --git a/dbservice/postgresdb/connect.go b/dbservice/postgresdb/connect.go index e6971e7a..e9ad3845 100644 --- a/dbservice/postgresdb/connect.go +++ b/dbservice/postgresdb/connect.go @@ -2,6 +2,7 @@ package postgresdb import ( "errors" + "sync" "time" "github.com/aquasecurity/postee/log" @@ -12,16 +13,17 @@ const ( CONN_RETRIES = 10 ) -var dbConn *sqlx.DB +var ( + once sync.Once + dbConn *sqlx.DB +) var psqlConnect = func(connectUrl string) (*sqlx.DB, error) { - if dbConn == nil { + once.Do(func() { retries := CONN_RETRIES db, err := sqlx.Connect("postgres", connectUrl) for err != nil { - if log.Logger != nil { - log.Logger.Errorf("failed to connect to postgres db (%d): %s", retries, err.Error()) - } + log.Logger.Errorf("failed to connect to postgres db (%d): %s", retries, err.Error()) if retries > 1 { retries-- @@ -29,10 +31,10 @@ var psqlConnect = func(connectUrl string) (*sqlx.DB, error) { db, err = sqlx.Connect("postgres", connectUrl) continue } - return nil, err + log.Logger.Fatal(err) } dbConn = db - } + }) return dbConn, nil } diff --git a/dbservice/postgresdb/connect_test.go b/dbservice/postgresdb/connect_test.go index 7004d801..fa9b970c 100644 --- a/dbservice/postgresdb/connect_test.go +++ b/dbservice/postgresdb/connect_test.go @@ -21,11 +21,3 @@ func TestConnectFuncError(t *testing.T) { t.Errorf("error text connect, expectedError: %v, got: %v", expectedError, err) } } - -func TestPsqlConnectError(t *testing.T) { - expectedError := `missing "=" after "test_trivy_psql_connect_dbName" in connection info string"` - _, err := psqlConnect("test_trivy_psql_connect_dbName") - if err.Error() != expectedError { - t.Errorf("Unexpected error, expected: '%v', got: '%v'", expectedError, err) - } -} diff --git a/router/api_test.go b/router/api_test.go index 423c1291..82c07a5b 100644 --- a/router/api_test.go +++ b/router/api_test.go @@ -545,7 +545,12 @@ var withPostgresParamsTest = func() error { j, _ := json.Marshal(outputSettings) return string(j), nil } + + savedUpdateCfgCache := postgresdb.UpdateCfgCacheSource + postgresdb.UpdateCfgCacheSource = func(postgresDb *postgresdb.PostgresDb, cfgfile string) error { return nil } + defer func() { + postgresdb.UpdateCfgCacheSource = savedUpdateCfgCache postgresdb.InitPostgresDb = savedInitPostgresDb postgresdb.GetCfgCacheSource = savedGetCfgCacheSource }() @@ -561,7 +566,10 @@ var withPostgresUrlTest = func() error { postgresdb.InitPostgresDb = func(connectUrl string) error { return nil } savedGetCfgCacheSource := postgresdb.GetCfgCacheSource postgresdb.GetCfgCacheSource = func(postgresDb *postgresdb.PostgresDb) (string, error) { return "", nil } + savedUpdateCfgCache := postgresdb.UpdateCfgCacheSource + postgresdb.UpdateCfgCacheSource = func(postgresDb *postgresdb.PostgresDb, cfgfile string) error { return nil } defer func() { + postgresdb.UpdateCfgCacheSource = savedUpdateCfgCache postgresdb.InitPostgresDb = savedInitPostgresDb postgresdb.GetCfgCacheSource = savedGetCfgCacheSource }() diff --git a/router/inittemplate_test.go b/router/inittemplate_test.go index c39eb512..b0fe48ed 100644 --- a/router/inittemplate_test.go +++ b/router/inittemplate_test.go @@ -22,16 +22,20 @@ func TestInitTemplate(t *testing.T) { defaultRegoFolder := "rego-templates" commonRegoFolder := defaultRegoFolder + "/common" testRego := defaultRegoFolder + "/rego1.rego" - err := os.Mkdir(defaultRegoFolder, 0777) - if err != nil { - t.Fatalf("Can't create rego folder: %v", err) + if _, err := os.Stat(defaultRegoFolder); os.IsNotExist(err) { + err = os.Mkdir(defaultRegoFolder, 0777) + if err != nil { + t.Fatalf("Can't create rego folder: %v", err) + } } - err = os.Mkdir(commonRegoFolder, 0777) - if err != nil { - t.Fatalf("Can't create rego folder: %v", err) + if _, err := os.Stat(commonRegoFolder); os.IsNotExist(err) { + err = os.Mkdir(commonRegoFolder, 0777) + if err != nil { + t.Fatalf("Can't create rego folder: %v", err) + } } - err = ioutil.WriteFile(testRego, []byte(regoRule), 0644) + err := ioutil.WriteFile(testRego, []byte(regoRule), 0644) if err != nil { t.Fatalf("Can't write rego: %v", err) diff --git a/router/loads_test.go b/router/loads_test.go index 5120ce42..45b507b2 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -289,7 +289,7 @@ func TestApplyPostgresCfg(t *testing.T) { postgresdb.UpdateCfgCacheSource = savedUpdateCfgCacheSource }() - err := demoCtx.ApplyPostgresCfg("tenantName", postgresUrl, false) + err := demoCtx.ApplyPostgresCfg("tenantName", postgresUrl, true) if err != nil { t.Errorf("Unexpected error: %v", err) } From e1fae4c5f08916d485fe21802fa0cea6341017f8 Mon Sep 17 00:00:00 2001 From: Tom Weiss Date: Sun, 30 Jan 2022 19:49:42 +0200 Subject: [PATCH 60/61] Router | remove redundant rego match check --- router/router.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/router/router.go b/router/router.go index 7f9d3d08..d25acd14 100644 --- a/router/router.go +++ b/router/router.go @@ -564,14 +564,6 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { return } - if ok, err := regoservice.DoesMatchRegoCriteria(inMsg, r.InputFiles, r.Input); err != nil { - utils.PrnInputLogs("Error while evaluating rego rule %s :%v for the input %s", r.Input, err, in) - return - } else if !ok { - utils.PrnInputLogs("Input %s... doesn't match a REGO rule: %s", in, r.Input) - return - } - inputCallbacks := ctx.inputCallBacks[routeName] for _, callback := range inputCallbacks { @@ -579,6 +571,7 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { return } } + for _, outputName := range r.Outputs { pl, ok := ctx.outputs[outputName] if !ok { From a788cf74ee4031ea09acc0d1a99ca1daae8b284d Mon Sep 17 00:00:00 2001 From: Tom Weiss Date: Mon, 31 Jan 2022 15:11:19 +0200 Subject: [PATCH 61/61] Evaluate rego rule once for every route This commit contains code that moves the rego evaluation code to the route handling process. At its current implementation, the output handler checks for rego match - this check is redundant since the output does not effect the matching process. --- msgservice/aggregatebytime_test.go | 12 ++++++-- msgservice/aggregatescan_test.go | 4 ++- msgservice/applicationscopeowner_test.go | 4 ++- msgservice/getuniqueid_test.go | 4 ++- msgservice/msghandling.go | 37 ++++++++++++++---------- msgservice/msgservice_test.go | 35 ++++++++++++++++++++-- msgservice/regocriteria_test.go | 4 ++- router/loads_test.go | 8 +++++ router/routehandling_test.go | 6 ++-- router/router.go | 8 ++++- 10 files changed, 92 insertions(+), 30 deletions(-) diff --git a/msgservice/aggregatebytime_test.go b/msgservice/aggregatebytime_test.go index e811d083..023db839 100644 --- a/msgservice/aggregatebytime_test.go +++ b/msgservice/aggregatebytime_test.go @@ -58,9 +58,15 @@ func TestAggregateByTimeout(t *testing.T) { srvUrl := "" srv1 := new(MsgService) - srv1.MsgHandling(mockScan1, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) - srv1.MsgHandling(mockScan2, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) - srv1.MsgHandling(mockScan3, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) + if srv1.EvaluateRegoRule(demoRoute, mockScan1) { + srv1.MsgHandling(mockScan1, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) + } + if srv1.EvaluateRegoRule(demoRoute, mockScan2) { + srv1.MsgHandling(mockScan2, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) + } + if srv1.EvaluateRegoRule(demoRoute, mockScan3) { + srv1.MsgHandling(mockScan3, demoEmailPlg, demoRoute, demoInptEval, &srvUrl) + } expectedSchedulerInvctCnt := 1 diff --git a/msgservice/aggregatescan_test.go b/msgservice/aggregatescan_test.go index d016e54a..7db37d38 100644 --- a/msgservice/aggregatescan_test.go +++ b/msgservice/aggregatescan_test.go @@ -71,7 +71,9 @@ func doAggregate(t *testing.T, caseDesc string, expectedSntCnt int, expectedRend for _, scan := range scans { srv := new(MsgService) - srv.MsgHandling(scan, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + if srv.EvaluateRegoRule(demoRoute, scan) { + srv.MsgHandling(scan, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + } } demoEmailOutput.wg.Wait() diff --git a/msgservice/applicationscopeowner_test.go b/msgservice/applicationscopeowner_test.go index a8ff3fed..fd96910c 100644 --- a/msgservice/applicationscopeowner_test.go +++ b/msgservice/applicationscopeowner_test.go @@ -48,7 +48,9 @@ func TestApplicationScopeOwner(t *testing.T) { demoEmailOutput.wg.Add(1) srv := new(MsgService) - srv.MsgHandling(scnWithOwners, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + if srv.EvaluateRegoRule(demoRoute, scnWithOwners) { + srv.MsgHandling(scnWithOwners, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + } demoEmailOutput.wg.Wait() diff --git a/msgservice/getuniqueid_test.go b/msgservice/getuniqueid_test.go index cfc6df35..8f586439 100644 --- a/msgservice/getuniqueid_test.go +++ b/msgservice/getuniqueid_test.go @@ -115,7 +115,9 @@ func sendInputs(t *testing.T, caseDesc string, inputs []map[string]interface{}, for _, inp := range inputs { srv := new(MsgService) - srv.MsgHandling(inp, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + if srv.EvaluateRegoRule(demoRoute, inp) { + srv.MsgHandling(inp, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + } } demoEmailOutput.wg.Wait() diff --git a/msgservice/msghandling.go b/msgservice/msghandling.go index 00672ed8..bae81abf 100644 --- a/msgservice/msghandling.go +++ b/msgservice/msghandling.go @@ -25,22 +25,6 @@ func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Ou //TODO marshalling message back to bytes, change after merge with https://github.com/aquasecurity/postee/pull/150 input, _ := json.Marshal(in) - if ok, err := regoservice.DoesMatchRegoCriteria(in, route.InputFiles, route.Input); err != nil { - if !regoservice.IsUsedRegoFiles(route.InputFiles) { - utils.PrnInputLogs("Error while evaluating rego rule %s :%v for the input %s", route.Input, err, input) - } else { - utils.PrnInputLogs("Error while evaluating rego rule for input files :%v for the input %s", err, input) - } - return - } else if !ok { - if !regoservice.IsUsedRegoFiles(route.InputFiles) { - utils.PrnInputLogs("Input %s... doesn't match a REGO rule: %s", input, route.Input) - } else { - utils.PrnInputLogs("Input %s... doesn't match a REGO input files rule", input) - } - return - } - //TODO move logic below somewhere close to Jira output implementation owners := "" applicationScopeOwnersObj, ok := in["application_scope_owners"] @@ -111,6 +95,27 @@ func (scan *MsgService) MsgHandling(in map[string]interface{}, output outputs.Ou } } +// EvaluateRegoRule returns true in case the given input ([]byte) matches the input of the given route +func (scan *MsgService) EvaluateRegoRule(r *routes.InputRoute, input map[string]interface{}) bool { + if ok, err := regoservice.DoesMatchRegoCriteria(input, r.InputFiles, r.Input); err != nil { + if !regoservice.IsUsedRegoFiles(r.InputFiles) { + utils.PrnInputLogs("Error while evaluating rego rule %s :%v for the input %s", r.Input, err, input) + } else { + utils.PrnInputLogs("Error while evaluating rego rule for input files :%v for the input %s", err, input) + } + return false + } else if !ok { + if !regoservice.IsUsedRegoFiles(r.InputFiles) { + utils.PrnInputLogs("Input %s... doesn't match a REGO rule: %s", input, r.Input) + } else { + utils.PrnInputLogs("Input %s... doesn't match a REGO input files rule", input) + } + return false + } + + return true +} + func send(otpt outputs.Output, cnt map[string]string) { go func() { err := otpt.Send(cnt) diff --git a/msgservice/msgservice_test.go b/msgservice/msgservice_test.go index b883b3b9..be358984 100644 --- a/msgservice/msgservice_test.go +++ b/msgservice/msgservice_test.go @@ -36,6 +36,7 @@ func (inptEval *FailingInptEval) BuildAggregatedContent(items []map[string]strin func (inptEval *FailingInptEval) IsAggregationSupported() bool { return inptEval.expectedAggrError != nil } + func TestEvalError(t *testing.T) { dbPathReal := db.DbPath defer func() { @@ -60,7 +61,9 @@ func TestEvalError(t *testing.T) { } srv := new(MsgService) - srv.MsgHandling(mockScan1, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + if srv.EvaluateRegoRule(demoRoute, mockScan1) { + srv.MsgHandling(mockScan1, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + } if demoEmailOutput.getEmailsCount() > 0 { t.Errorf("Output shouldn't be called when evaluation is failed") @@ -97,10 +100,38 @@ func TestAggrEvalError(t *testing.T) { for i := 0; i < 2; i++ { srv := new(MsgService) - srv.MsgHandling(mockScan1, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + if srv.EvaluateRegoRule(demoRoute, mockScan1) { + srv.MsgHandling(mockScan1, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + } } if demoEmailOutput.getEmailsCount() > 0 { t.Errorf("Output shouldn't be called when evaluation is failed") } } + +func TestEmptyInput(t *testing.T) { + dbPathReal := db.DbPath + defer func() { + os.Remove(db.DbPath) + db.DbPath = dbPathReal + }() + db.DbPath = "test_webhooks.db" + + srvUrl := "" + + demoRoute := &routes.InputRoute{} + + demoRoute.Name = "demo-route" + + demoInptEval := &DemoInptEval{} + + srv := new(MsgService) + if srv.EvaluateRegoRule(demoRoute, map[string]interface{}{}) { + srv.MsgHandling(map[string]interface{}{}, nil, demoRoute, demoInptEval, &srvUrl) + } + + if demoInptEval.renderCnt != 0 { + t.Errorf("Eval() shouldn't be called if no output is passed to ResultHandling()") + } +} diff --git a/msgservice/regocriteria_test.go b/msgservice/regocriteria_test.go index fc86ba3f..cbba4d49 100644 --- a/msgservice/regocriteria_test.go +++ b/msgservice/regocriteria_test.go @@ -139,7 +139,9 @@ func validateRegoInput(t *testing.T, caseDesc string, input map[string]interface demoEmailOutput.wg.Add(expected) srv := new(MsgService) - srv.MsgHandling(input, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + if srv.EvaluateRegoRule(demoRoute, input) { + srv.MsgHandling(input, demoEmailOutput, demoRoute, demoInptEval, &srvUrl) + } demoEmailOutput.wg.Wait() diff --git a/router/loads_test.go b/router/loads_test.go index 45b507b2..2ad2a7e3 100644 --- a/router/loads_test.go +++ b/router/loads_test.go @@ -78,6 +78,7 @@ type ctxWrapper struct { commonRegoFolder string buff chan invctn } + type invctn struct { outputCls string templateCls string @@ -144,6 +145,13 @@ func (ctxWrapper *ctxWrapper) teardown() { close(ctxWrapper.buff) } +func (ctx *ctxWrapper) EvaluateRegoRule(r *routes.InputRoute, _ map[string]interface{}) bool { + if r.Name == "fail_evaluation" { + return false + } + return true +} + func TestLoads(t *testing.T) { wrap := ctxWrapper{} wrap.setup(cfgData) diff --git a/router/routehandling_test.go b/router/routehandling_test.go index 9d8bce99..114b7ccf 100644 --- a/router/routehandling_test.go +++ b/router/routehandling_test.go @@ -7,9 +7,7 @@ import ( "time" ) -var ( - payload = `{"image" : "alpine"}` -) +var payload = `{"image" : "alpine"}` func TestHandling(t *testing.T) { tests := []struct { @@ -202,8 +200,8 @@ func TestInvalidRouteName(t *testing.T) { } } } - } + func TestSend(t *testing.T) { expctdInvctns := 1 actualInvctCnt := 0 diff --git a/router/router.go b/router/router.go index d25acd14..fa181aaf 100644 --- a/router/router.go +++ b/router/router.go @@ -538,6 +538,7 @@ func (ctx *Router) loadCfgCacheSourceFromPostgres() (*data.TenantSettings, error type service interface { MsgHandling(input map[string]interface{}, output outputs.Output, route *routes.InputRoute, inpteval data.Inpteval, aquaServer *string) + EvaluateRegoRule(input *routes.InputRoute, in map[string]interface{}) bool } var getScanService = func() service { @@ -571,7 +572,12 @@ func (ctx *Router) HandleRoute(routeName string, in []byte) { return } } - + + if !getScanService().EvaluateRegoRule(r, inMsg) { + log.Logger.Infof("Rego match was not found for route %s", routeName) + return + } + for _, outputName := range r.Outputs { pl, ok := ctx.outputs[outputName] if !ok {