diff --git a/assertion/global/analyzer.go b/assertion/global/analyzer.go index 692365fd..a6b4e84c 100644 --- a/assertion/global/analyzer.go +++ b/assertion/global/analyzer.go @@ -50,6 +50,7 @@ func run(p *analysis.Pass) ([]annotation.FullTrigger, error) { if !conf.IsFileInScope(file) { continue } + initFuncDecls := getInitFuncDecls(file) for _, decl := range file.Decls { genDecl, ok := decl.(*ast.GenDecl) @@ -57,10 +58,57 @@ func run(p *analysis.Pass) ([]annotation.FullTrigger, error) { continue } for _, spec := range genDecl.Specs { - fullTriggers = append(fullTriggers, analyzeValueSpec(pass, spec.(*ast.ValueSpec))...) + fullTriggers = append(fullTriggers, analyzeValueSpec(pass, spec.(*ast.ValueSpec), initFuncDecls)...) } } } return fullTriggers, nil } + +// getInitFuncDecls searches for the init function declarations and all related functions in the given *ast.File. +// It returns a slice of *ast.FuncDecl representing the init functions and all functions called directly or indirectly from them. +// The function handles multiple init functions if present, and avoids infinite recursion in case of cyclic function calls. +// If the file is nil, it returns nil. +func getInitFuncDecls(file *ast.File) []*ast.FuncDecl { + if file == nil { + return nil + } + funcDecls := make(map[string]*ast.FuncDecl) + for _, decl := range file.Decls { + if funcDecl, ok := decl.(*ast.FuncDecl); ok { + funcDecls[funcDecl.Name.Name] = funcDecl + } + } + + var initFuncDecls []*ast.FuncDecl + // visitedFuncs tracks processed functions to prevent infinite recursion and duplicate processing + visitedFuncs := make(map[string]struct{}) + var findRelatedFuncs func(*ast.FuncDecl) + findRelatedFuncs = func(funcDecl *ast.FuncDecl) { + if _, visited := visitedFuncs[funcDecl.Name.Name]; visited { + return + } + initFuncDecls = append(initFuncDecls, funcDecl) + visitedFuncs[funcDecl.Name.Name] = struct{}{} + ast.Inspect(funcDecl.Body, func(n ast.Node) bool { + if callExpr, ok := n.(*ast.CallExpr); ok { + if ident, ok := callExpr.Fun.(*ast.Ident); ok { + if funcDecl, exists := funcDecls[ident.Name]; exists { + findRelatedFuncs(funcDecl) + } + } + } + return true + }) + } + + for _, decl := range file.Decls { + if funcDecl, ok := decl.(*ast.FuncDecl); ok && funcDecl.Name.Name == "init" { + findRelatedFuncs(funcDecl) + // Reset visitedFuncs for each init function to ensure all related functions are processed + visitedFuncs = make(map[string]struct{}) + } + } + return initFuncDecls +} diff --git a/assertion/global/globalvarinit.go b/assertion/global/globalvarinit.go index 20dfc68c..75683e3c 100644 --- a/assertion/global/globalvarinit.go +++ b/assertion/global/globalvarinit.go @@ -26,10 +26,10 @@ import ( ) // analyzeValueSpec returns full triggers corresponding to the declaration -func analyzeValueSpec(pass *analysishelper.EnhancedPass, spec *ast.ValueSpec) []annotation.FullTrigger { +func analyzeValueSpec(pass *analysishelper.EnhancedPass, spec *ast.ValueSpec, initFuncDecls []*ast.FuncDecl) []annotation.FullTrigger { var fullTriggers []annotation.FullTrigger - consumers := getGlobalConsumers(pass, spec) + consumers := getGlobalConsumers(pass, spec, initFuncDecls) for i, ident := range spec.Names { if consumers[i] == nil { @@ -65,12 +65,12 @@ func analyzeValueSpec(pass *analysishelper.EnhancedPass, spec *ast.ValueSpec) [] } // Returns a list of consumers corresponding to a global level variable declaration -func getGlobalConsumers(pass *analysishelper.EnhancedPass, valspec *ast.ValueSpec) []*annotation.ConsumeTrigger { +func getGlobalConsumers(pass *analysishelper.EnhancedPass, valspec *ast.ValueSpec, initFuncDecls []*ast.FuncDecl) []*annotation.ConsumeTrigger { consumers := make([]*annotation.ConsumeTrigger, len(valspec.Names)) for i, name := range valspec.Names { // Types that are not nilable are eliminated here - if !asthelper.IsEmptyExpr(name) && !typeshelper.TypeBarsNilness(pass.TypesInfo.TypeOf(name)) { + if !asthelper.IsEmptyExpr(name) && !typeshelper.TypeBarsNilness(pass.TypesInfo.TypeOf(name)) && !hasGlobalVarAssignInInitFunc(valspec, initFuncDecls) { v := pass.TypesInfo.ObjectOf(name).(*types.Var) consumers[i] = &annotation.ConsumeTrigger{ Annotation: &annotation.GlobalVarAssign{ @@ -87,6 +87,40 @@ func getGlobalConsumers(pass *analysishelper.EnhancedPass, valspec *ast.ValueSpe return consumers } +// Checks if all the global variables represented by spec are assigned values within the init function. +// It returns true if all variables are assigned, false otherwise. +// If initFuncDecl is nil, it returns false. +func hasGlobalVarAssignInInitFunc(spec *ast.ValueSpec, initFuncDecls []*ast.FuncDecl) bool { + if len(initFuncDecls) == 0 { + return false + } + assignedVars := make(map[string]bool) + for _, name := range spec.Names { + assignedVars[name.Name] = false + } + for _, initFuncDecl := range initFuncDecls { + ast.Inspect(initFuncDecl.Body, func(node ast.Node) bool { + if assign, ok := node.(*ast.AssignStmt); ok { + for _, lhs := range assign.Lhs { + if ident, ok := lhs.(*ast.Ident); ok { + if _, exists := assignedVars[ident.Name]; exists { + assignedVars[ident.Name] = true + } + } + } + } + return true + }) + } + + for _, assigned := range assignedVars { + if !assigned { + return false + } + } + return true +} + // Returns a producer in the cases: 1) func call 2) literal nil 3) another global var 4) struct field/method. // In all other cases, it returns nil. func getGlobalProducer(pass *analysishelper.EnhancedPass, valspec *ast.ValueSpec, lid int, rid int) *annotation.ProduceTrigger { diff --git a/testdata/src/go.uber.org/globalvars/globalvarinit.go b/testdata/src/go.uber.org/globalvars/globalvarinit.go index ee2aab21..b3b64e9a 100644 --- a/testdata/src/go.uber.org/globalvars/globalvarinit.go +++ b/testdata/src/go.uber.org/globalvars/globalvarinit.go @@ -25,6 +25,35 @@ var x = 3 // This should throw an error since it is not initialized var noInit *int //want "assigned into global variable" +var _init *int +var _initMult1, _initMult2 *int + +func init() { + _init = new(int) + _initMult1 = new(int) + _initMult2 = new(int) +} + +var _init2 *int + +func init() { + _init2 = new(int) +} + +var _init3, _init4 *int + +func init() { + init_next() +} + +func init_next() { + _init3 = new(int) + init_next_next() +} + +func init_next_next() { + _init4 = new(int) +} // nilable(nilableVar) var nilableVar *int