diff --git a/annotation/key.go b/annotation/key.go index 0ba82292..34f077d7 100644 --- a/annotation/key.go +++ b/annotation/key.go @@ -710,3 +710,42 @@ func (rk *RecvAnnotationKey) copy() Key { func (rk *RecvAnnotationKey) String() string { return fmt.Sprintf("Receiver of Method %s", rk.FuncDecl.Name()) } + +// FuncVarRetAnnotationKey allows the Lookup of a function variable's return in the Annotation Map. +// This key is used when the function being called is a variable (e.g., `s.f()` where f is a function field) +// rather than a declared function. +type FuncVarRetAnnotationKey struct { + // Location uniquely identifies the call site + Location token.Position + // RetNum is the index of the return value + RetNum int +} + +// Lookup looks this key up in the passed map, returning a Val. +// Function variable returns are not annotated, so we return the optimistic default. +func (fk *FuncVarRetAnnotationKey) Lookup(_ Map) (Val, bool) { + // Function variables can't be annotated, so return "might be nil" by default + // This allows the guard mechanism to work properly + return nonAnnotatedDefault, false +} + +// Object returns nil since function variable returns don't have a types.Object. +func (fk *FuncVarRetAnnotationKey) Object() types.Object { + return nil +} + +func (fk *FuncVarRetAnnotationKey) equals(other Key) bool { + if other, ok := other.(*FuncVarRetAnnotationKey); ok { + return fk.Location == other.Location && fk.RetNum == other.RetNum + } + return false +} + +func (fk *FuncVarRetAnnotationKey) copy() Key { + copyKey := *fk + return ©Key +} + +func (fk *FuncVarRetAnnotationKey) String() string { + return fmt.Sprintf("FuncVar Result %d at %s", fk.RetNum, fk.Location.String()) +} diff --git a/annotation/key_test.go b/annotation/key_test.go index 737378a7..7962b2d7 100644 --- a/annotation/key_test.go +++ b/annotation/key_test.go @@ -36,6 +36,7 @@ var initStructsKey = []any{ &EscapeFieldAnnotationKey{}, &ParamFieldAnnotationKey{}, &LocalVarAnnotationKey{}, + &FuncVarRetAnnotationKey{}, } // TestKeyEqualsSuite runs the test suite for the `equals` method of all the structs that implement diff --git a/annotation/produce_trigger.go b/annotation/produce_trigger.go index 753e6615..e3e0acb7 100644 --- a/annotation/produce_trigger.go +++ b/annotation/produce_trigger.go @@ -775,8 +775,10 @@ func (f *FuncReturn) Prestring() Prestring { return FuncReturnPrestring{key.RetNum, key.FuncDecl.Name(), ""} case *CallSiteRetAnnotationKey: return FuncReturnPrestring{key.RetNum, key.FuncDecl.Name(), key.Location.String()} + case *FuncVarRetAnnotationKey: + return FuncReturnPrestring{key.RetNum, "func variable", key.Location.String()} default: - panic(fmt.Sprintf("Expected RetAnnotationKey or CallSiteRetAnnotationKey but got: %T", key)) + panic(fmt.Sprintf("Expected RetAnnotationKey, CallSiteRetAnnotationKey, or FuncVarRetAnnotationKey but got: %T", key)) } } diff --git a/assertion/function/assertiontree/parse_expr_producer.go b/assertion/function/assertiontree/parse_expr_producer.go index 3a7db8a5..acda7a51 100644 --- a/assertion/function/assertiontree/parse_expr_producer.go +++ b/assertion/function/assertiontree/parse_expr_producer.go @@ -287,9 +287,12 @@ func (r *RootAssertionNode) ParseExprAsProducer(expr ast.Expr, doNotTrack bool) } // Check if the method is a function value, e.g., `f := func() {}` and then `f()`. - // TODO: this is a temporary fix to suppress false positives caused by function values. - // Remove this once we have have implemented the function value support. if r.isVariable(fun) { + if producers := r.getFuncVarReturnProducers(r.Pass().TypesInfo.TypeOf(fun), expr); producers != nil { + return nil, producers + } + // For non-error/ok-returning function variables, suppress false positives + // TODO: this is a temporary fix. Remove once function value support is complete. return nil, []producer.ParsedProducer{producer.ShallowParsedProducer{Producer: &annotation.ProduceTrigger{ Annotation: &annotation.TrustedFuncNonnil{ProduceTriggerNever: &annotation.ProduceTriggerNever{}}, Expr: expr, @@ -337,10 +340,13 @@ func (r *RootAssertionNode) ParseExprAsProducer(expr ast.Expr, doNotTrack bool) return nil, r.getFuncReturnProducers(fun, expr) case *ast.SelectorExpr: // method call - // Check if the method is a function value, e.g., `f := func() {}` and then `f()`. - // TODO: this is a temporary fix to handle the case of function values. - // Remove this once we have have implemented the function value support. + // Check if the method is a function value, e.g., `s.f()` where `f` is a function type field. if r.isVariable(fun.Sel) { + if producers := r.getFuncVarReturnProducers(r.Pass().TypesInfo.TypeOf(fun.Sel), expr); producers != nil { + return nil, producers + } + // For non-error/ok-returning function variables, suppress false positives + // TODO: this is a temporary fix. Remove once function value support is complete. return nil, []producer.ParsedProducer{producer.ShallowParsedProducer{Producer: &annotation.ProduceTrigger{ Annotation: &annotation.TrustedFuncNonnil{ProduceTriggerNever: &annotation.ProduceTriggerNever{}}, Expr: expr, @@ -582,7 +588,52 @@ func (r *RootAssertionNode) getFuncReturnProducers(ident *ast.Ident, expr *ast.C return producers } -// parseStructCreateExprAsProducer parses composite expressions used to initialize a struct e.g. A{f1: v1, f2: v2} +// getFuncVarReturnProducers returns producers for function variable calls (e.g., `f()` where f is a variable). +// Returns nil if the function should use the default TrustedFuncNonnil workaround. +func (r *RootAssertionNode) getFuncVarReturnProducers(funType types.Type, expr *ast.CallExpr) []producer.ParsedProducer { + sig := typeshelper.GetFuncSignature(funType) + if sig == nil { + return nil + } + + isErrReturning := typeshelper.FuncIsErrReturning(sig) + isOkReturning := typeshelper.FuncIsOkReturning(sig) + + // For non-rich-check-effect functions, use default workaround + if !isErrReturning && !isOkReturning { + return nil + } + + // Generate FuncReturn producers that integrate with rich check effects + numResults := sig.Results().Len() + producers := make([]producer.ParsedProducer, numResults) + callLocation := r.Pass().Fset.Position(expr.Pos()) + + for i := 0; i < numResults; i++ { + resultType := sig.Results().At(i).Type() + var shallowAnnotation annotation.ProducingAnnotationTrigger + + if typeshelper.TypeBarsNilness(resultType) { + shallowAnnotation = &annotation.TrustedFuncNonnil{ProduceTriggerNever: &annotation.ProduceTriggerNever{}} + } else { + shallowAnnotation = &annotation.FuncReturn{ + TriggerIfNilable: &annotation.TriggerIfNilable{ + Ann: &annotation.FuncVarRetAnnotationKey{Location: callLocation, RetNum: i}, + NeedsGuard: (isErrReturning || isOkReturning) && i != numResults-1, + }, + IsFromRichCheckEffectFunc: isErrReturning || isOkReturning, + } + } + + producers[i] = producer.ShallowParsedProducer{Producer: &annotation.ProduceTrigger{ + Annotation: shallowAnnotation, + Expr: expr, + }} + } + return producers +} + +// parseStructCreateExprAsProducer parsed composite expressions used to initialize a struct e.g. A{f1: v1, f2: v2} func (r *RootAssertionNode) parseStructCreateExprAsProducer(expr ast.Expr, fieldInitializations []ast.Expr) producer.ParsedProducer { exprType := r.Pass().TypesInfo.TypeOf(expr) diff --git a/inference/primitive.go b/inference/primitive.go index 98a8aef5..50ded1b8 100644 --- a/inference/primitive.go +++ b/inference/primitive.go @@ -191,7 +191,22 @@ func (p *primitivizer) fullTrigger(trigger annotation.FullTrigger) primitiveFull // site returns the primitive version of the annotation site. func (p *primitivizer) site(key annotation.Key, isDeep bool) primitiveSite { - objPath, err := p.objPathEncoder.For(key.Object()) + obj := key.Object() + + // Handle keys with nil Object (e.g., FuncVarRetAnnotationKey for function variable returns). + // These are local-only annotations that don't participate in cross-package inference. + if obj == nil { + return primitiveSite{ + PkgPath: p.pass.Pkg.Path(), + Repr: key.String(), + IsDeep: isDeep, + Exported: false, + ObjectPath: "", + Position: token.Position{}, + } + } + + objPath, err := p.objPathEncoder.For(obj) if err != nil { // An error will occur when trying to get object path for unexported objects, in which case // we simply assign an empty object path. @@ -199,13 +214,13 @@ func (p *primitivizer) site(key annotation.Key, isDeep bool) primitiveSite { } pkgRepr := "" - if pkg := key.Object().Pkg(); pkg != nil { + if pkg := obj.Pkg(); pkg != nil { pkgRepr = pkg.Path() } var position token.Position // For upstream objects, we need to look up the local position cache for correct positions. - if key.Object().Pkg() != p.pass.Pkg { + if obj.Pkg() != p.pass.Pkg { // Correct upstream information may not always be in the cache: we may not even have it // since we skipped analysis for standard and 3rd party libraries. if p, ok := p.upstreamObjPositions[pkgRepr+"."+string(objPath)]; ok { @@ -217,14 +232,14 @@ func (p *primitivizer) site(key annotation.Key, isDeep bool) primitiveSite { // their Object.Pos() and retrieve the position information. However, we must trim the possible // build-system sandbox prefix from the filenames for cross-package references. if !position.IsValid() { - position = p.toPosition(key.Object().Pos()) + position = p.toPosition(obj.Pos()) } return primitiveSite{ PkgPath: pkgRepr, Repr: key.String(), IsDeep: isDeep, - Exported: key.Object().Exported(), + Exported: obj.Exported(), ObjectPath: objPath, Position: position, } diff --git a/nilaway b/nilaway new file mode 100755 index 00000000..3afcf808 Binary files /dev/null and b/nilaway differ diff --git a/testdata/src/go.uber.org/trustedfunc/issue277_function_pointer.go b/testdata/src/go.uber.org/trustedfunc/issue277_function_pointer.go new file mode 100644 index 00000000..26f11bff --- /dev/null +++ b/testdata/src/go.uber.org/trustedfunc/issue277_function_pointer.go @@ -0,0 +1,73 @@ +package trustedfunc + +// Test for issue #277: function pointers as rich check effect functions + +type S2 struct { + f func() (*int, error) +} + +// This is the false positive from the issue - should NOT report a warning +func testFunctionPointerErrorFP(s *S2) { + v, err := s.f() + if err != nil { + return + } + _ = *v // no error - err was checked +} + +// This SHOULD report a warning because we don't check err +func testFunctionPointerErrorNoCheck(s *S2) { + v, _ := s.f() + _ = *v // want "deref" +} + +// Another checked example - should NOT report +func testFunctionPointerErrorChecked(s *S2) { + v, err := s.f() + if err != nil { + return + } + _ = *v // no error - err was checked +} + +// Test with local function variable - no check +func testLocalFunctionVarNoCheck() { + var f func() (*int, error) + v, _ := f() + _ = *v // want "deref" +} + +// Test with local function variable - checked +func testLocalFunctionVarChecked() { + var f func() (*int, error) + v, err := f() + if err != nil { + return + } + _ = *v // no error - err was checked +} + +// Test with parameter function variable - no check +func testParamFunctionVarNoCheck(f func() (*int, error)) { + v, _ := f() + _ = *v // want "deref" +} + +// Test with parameter function variable - checked +func testParamFunctionVarChecked(f func() (*int, error)) { + v, err := f() + if err != nil { + return + } + _ = *v // no error - err was checked +} + +// Test from the issue description - this is the false positive case +// Should NOT report a warning +func testme(s *S2) { + v, err := s.f() + if err != nil { + return + } + _ = *v // no error - err was checked +}