Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions annotation/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,42 @@ func (rk *RecvAnnotationKey) copy() Key {
func (rk *RecvAnnotationKey) String() string {
return fmt.Sprintf("Receiver of Method %s", rk.FuncDecl.Name())
}

// FuncVarRetAnnotationKey allows the Lookup of a function variable's return in the Annotation Map.
// This key is used when the function being called is a variable (e.g., `s.f()` where f is a function field)
// rather than a declared function.
type FuncVarRetAnnotationKey struct {
// Location uniquely identifies the call site
Location token.Position
// RetNum is the index of the return value
RetNum int
}

// Lookup looks this key up in the passed map, returning a Val.
// Function variable returns are not annotated, so we return the optimistic default.
func (fk *FuncVarRetAnnotationKey) Lookup(_ Map) (Val, bool) {
// Function variables can't be annotated, so return "might be nil" by default
// This allows the guard mechanism to work properly
return nonAnnotatedDefault, false
}

// Object returns nil since function variable returns don't have a types.Object.
func (fk *FuncVarRetAnnotationKey) Object() types.Object {
return nil
}

func (fk *FuncVarRetAnnotationKey) equals(other Key) bool {
if other, ok := other.(*FuncVarRetAnnotationKey); ok {
return fk.Location == other.Location && fk.RetNum == other.RetNum
}
return false
}

func (fk *FuncVarRetAnnotationKey) copy() Key {
copyKey := *fk
return &copyKey
}

func (fk *FuncVarRetAnnotationKey) String() string {
return fmt.Sprintf("FuncVar Result %d at %s", fk.RetNum, fk.Location.String())
}
1 change: 1 addition & 0 deletions annotation/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var initStructsKey = []any{
&EscapeFieldAnnotationKey{},
&ParamFieldAnnotationKey{},
&LocalVarAnnotationKey{},
&FuncVarRetAnnotationKey{},
}

// TestKeyEqualsSuite runs the test suite for the `equals` method of all the structs that implement
Expand Down
4 changes: 3 additions & 1 deletion annotation/produce_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,10 @@ func (f *FuncReturn) Prestring() Prestring {
return FuncReturnPrestring{key.RetNum, key.FuncDecl.Name(), ""}
case *CallSiteRetAnnotationKey:
return FuncReturnPrestring{key.RetNum, key.FuncDecl.Name(), key.Location.String()}
case *FuncVarRetAnnotationKey:
return FuncReturnPrestring{key.RetNum, "func variable", key.Location.String()}
default:
panic(fmt.Sprintf("Expected RetAnnotationKey or CallSiteRetAnnotationKey but got: %T", key))
panic(fmt.Sprintf("Expected RetAnnotationKey, CallSiteRetAnnotationKey, or FuncVarRetAnnotationKey but got: %T", key))
}
}

Expand Down
63 changes: 57 additions & 6 deletions assertion/function/assertiontree/parse_expr_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,12 @@ func (r *RootAssertionNode) ParseExprAsProducer(expr ast.Expr, doNotTrack bool)
}

// Check if the method is a function value, e.g., `f := func() {}` and then `f()`.
// TODO: this is a temporary fix to suppress false positives caused by function values.
// Remove this once we have have implemented the function value support.
if r.isVariable(fun) {
if producers := r.getFuncVarReturnProducers(r.Pass().TypesInfo.TypeOf(fun), expr); producers != nil {
return nil, producers
}
// For non-error/ok-returning function variables, suppress false positives
// TODO: this is a temporary fix. Remove once function value support is complete.
return nil, []producer.ParsedProducer{producer.ShallowParsedProducer{Producer: &annotation.ProduceTrigger{
Annotation: &annotation.TrustedFuncNonnil{ProduceTriggerNever: &annotation.ProduceTriggerNever{}},
Expr: expr,
Expand Down Expand Up @@ -337,10 +340,13 @@ func (r *RootAssertionNode) ParseExprAsProducer(expr ast.Expr, doNotTrack bool)
return nil, r.getFuncReturnProducers(fun, expr)

case *ast.SelectorExpr: // method call
// Check if the method is a function value, e.g., `f := func() {}` and then `f()`.
// TODO: this is a temporary fix to handle the case of function values.
// Remove this once we have have implemented the function value support.
// Check if the method is a function value, e.g., `s.f()` where `f` is a function type field.
if r.isVariable(fun.Sel) {
if producers := r.getFuncVarReturnProducers(r.Pass().TypesInfo.TypeOf(fun.Sel), expr); producers != nil {
return nil, producers
}
// For non-error/ok-returning function variables, suppress false positives
// TODO: this is a temporary fix. Remove once function value support is complete.
return nil, []producer.ParsedProducer{producer.ShallowParsedProducer{Producer: &annotation.ProduceTrigger{
Annotation: &annotation.TrustedFuncNonnil{ProduceTriggerNever: &annotation.ProduceTriggerNever{}},
Expr: expr,
Expand Down Expand Up @@ -582,7 +588,52 @@ func (r *RootAssertionNode) getFuncReturnProducers(ident *ast.Ident, expr *ast.C
return producers
}

// parseStructCreateExprAsProducer parses composite expressions used to initialize a struct e.g. A{f1: v1, f2: v2}
// getFuncVarReturnProducers returns producers for function variable calls (e.g., `f()` where f is a variable).
// Returns nil if the function should use the default TrustedFuncNonnil workaround.
func (r *RootAssertionNode) getFuncVarReturnProducers(funType types.Type, expr *ast.CallExpr) []producer.ParsedProducer {
sig := typeshelper.GetFuncSignature(funType)
if sig == nil {
return nil
}

isErrReturning := typeshelper.FuncIsErrReturning(sig)
isOkReturning := typeshelper.FuncIsOkReturning(sig)

// For non-rich-check-effect functions, use default workaround
if !isErrReturning && !isOkReturning {
return nil
}

// Generate FuncReturn producers that integrate with rich check effects
numResults := sig.Results().Len()
producers := make([]producer.ParsedProducer, numResults)
callLocation := r.Pass().Fset.Position(expr.Pos())

for i := 0; i < numResults; i++ {
resultType := sig.Results().At(i).Type()
var shallowAnnotation annotation.ProducingAnnotationTrigger

if typeshelper.TypeBarsNilness(resultType) {
shallowAnnotation = &annotation.TrustedFuncNonnil{ProduceTriggerNever: &annotation.ProduceTriggerNever{}}
} else {
shallowAnnotation = &annotation.FuncReturn{
TriggerIfNilable: &annotation.TriggerIfNilable{
Ann: &annotation.FuncVarRetAnnotationKey{Location: callLocation, RetNum: i},
NeedsGuard: (isErrReturning || isOkReturning) && i != numResults-1,
},
IsFromRichCheckEffectFunc: isErrReturning || isOkReturning,
}
}

producers[i] = producer.ShallowParsedProducer{Producer: &annotation.ProduceTrigger{
Annotation: shallowAnnotation,
Expr: expr,
}}
}
return producers
}

// parseStructCreateExprAsProducer parsed composite expressions used to initialize a struct e.g. A{f1: v1, f2: v2}
func (r *RootAssertionNode) parseStructCreateExprAsProducer(expr ast.Expr, fieldInitializations []ast.Expr) producer.ParsedProducer {
exprType := r.Pass().TypesInfo.TypeOf(expr)

Expand Down
25 changes: 20 additions & 5 deletions inference/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,36 @@ func (p *primitivizer) fullTrigger(trigger annotation.FullTrigger) primitiveFull

// site returns the primitive version of the annotation site.
func (p *primitivizer) site(key annotation.Key, isDeep bool) primitiveSite {
objPath, err := p.objPathEncoder.For(key.Object())
obj := key.Object()

// Handle keys with nil Object (e.g., FuncVarRetAnnotationKey for function variable returns).
// These are local-only annotations that don't participate in cross-package inference.
if obj == nil {
return primitiveSite{
PkgPath: p.pass.Pkg.Path(),
Repr: key.String(),
IsDeep: isDeep,
Exported: false,
ObjectPath: "",
Position: token.Position{},
}
}

objPath, err := p.objPathEncoder.For(obj)
if err != nil {
// An error will occur when trying to get object path for unexported objects, in which case
// we simply assign an empty object path.
objPath = ""
}

pkgRepr := ""
if pkg := key.Object().Pkg(); pkg != nil {
if pkg := obj.Pkg(); pkg != nil {
pkgRepr = pkg.Path()
}

var position token.Position
// For upstream objects, we need to look up the local position cache for correct positions.
if key.Object().Pkg() != p.pass.Pkg {
if obj.Pkg() != p.pass.Pkg {
// Correct upstream information may not always be in the cache: we may not even have it
// since we skipped analysis for standard and 3rd party libraries.
if p, ok := p.upstreamObjPositions[pkgRepr+"."+string(objPath)]; ok {
Expand All @@ -217,14 +232,14 @@ func (p *primitivizer) site(key annotation.Key, isDeep bool) primitiveSite {
// their Object.Pos() and retrieve the position information. However, we must trim the possible
// build-system sandbox prefix from the filenames for cross-package references.
if !position.IsValid() {
position = p.toPosition(key.Object().Pos())
position = p.toPosition(obj.Pos())
}

return primitiveSite{
PkgPath: pkgRepr,
Repr: key.String(),
IsDeep: isDeep,
Exported: key.Object().Exported(),
Exported: obj.Exported(),
ObjectPath: objPath,
Position: position,
}
Expand Down
Binary file added nilaway
Binary file not shown.
73 changes: 73 additions & 0 deletions testdata/src/go.uber.org/trustedfunc/issue277_function_pointer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package trustedfunc

// Test for issue #277: function pointers as rich check effect functions

type S2 struct {
f func() (*int, error)
}

// This is the false positive from the issue - should NOT report a warning
func testFunctionPointerErrorFP(s *S2) {
v, err := s.f()
if err != nil {
return
}
_ = *v // no error - err was checked
}

// This SHOULD report a warning because we don't check err
func testFunctionPointerErrorNoCheck(s *S2) {
v, _ := s.f()
_ = *v // want "deref"
}

// Another checked example - should NOT report
func testFunctionPointerErrorChecked(s *S2) {
v, err := s.f()
if err != nil {
return
}
_ = *v // no error - err was checked
}

// Test with local function variable - no check
func testLocalFunctionVarNoCheck() {
var f func() (*int, error)
v, _ := f()
_ = *v // want "deref"
}

// Test with local function variable - checked
func testLocalFunctionVarChecked() {
var f func() (*int, error)
v, err := f()
if err != nil {
return
}
_ = *v // no error - err was checked
}

// Test with parameter function variable - no check
func testParamFunctionVarNoCheck(f func() (*int, error)) {
v, _ := f()
_ = *v // want "deref"
}

// Test with parameter function variable - checked
func testParamFunctionVarChecked(f func() (*int, error)) {
v, err := f()
if err != nil {
return
}
_ = *v // no error - err was checked
}

// Test from the issue description - this is the false positive case
// Should NOT report a warning
func testme(s *S2) {
v, err := s.f()
if err != nil {
return
}
_ = *v // no error - err was checked
}