diff --git a/annotated.go b/annotated.go index 4a732f1e3..e4ba858ce 100644 --- a/annotated.go +++ b/annotated.go @@ -248,6 +248,10 @@ func (pt paramTagsAnnotation) apply(ann *annotated) error { // build builds and returns a constructor after applying a ParamTags annotation func (pt paramTagsAnnotation) build(ann *annotated) (any, error) { + if ann.paramTagsBuilt { + return ann.Target, nil + } + ann.paramTagsBuilt = true paramTypes, remap := pt.parameters(ann) resultTypes, _ := ann.currentResultTypes() @@ -267,6 +271,7 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) ( types []reflect.Type, remap func([]reflect.Value) []reflect.Value, ) { + tags := ann.ParamTags ft := reflect.TypeOf(ann.Target) types = make([]reflect.Type, ft.NumIn()) for i := 0; i < ft.NumIn(); i++ { @@ -275,7 +280,7 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) ( // No parameter annotations. Return the original types // and an identity function. - if len(pt.tags) == 0 { + if len(tags) == 0 { return types, func(args []reflect.Value) []reflect.Value { return args } @@ -295,8 +300,8 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) ( Type: origField.Type, Tag: origField.Tag, } - if i-1 < len(pt.tags) { - field.Tag = reflect.StructTag(pt.tags[i-1]) + if i-1 < len(tags) { + field.Tag = reflect.StructTag(tags[i-1]) } inFields = append(inFields, field) @@ -318,8 +323,8 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) ( Name: fmt.Sprintf("Field%d", i), Type: t, } - if i < len(pt.tags) { - field.Tag = reflect.StructTag(pt.tags[i]) + if i < len(tags) { + field.Tag = reflect.StructTag(tags[i]) } inFields = append(inFields, field) @@ -353,6 +358,76 @@ func ParamTags(tags ...string) Annotation { return paramTagsAnnotation{tags} } +type genericParamTagAnnotation struct { + typ reflect.Type + tag string + nth int +} + +var _ Annotation = genericParamTagAnnotation{} + +func (pt genericParamTagAnnotation) apply(ann *annotated) error { + if err := verifyAnnotateTag(pt.tag); err != nil { + return err + } + if pt.nth < 1 { + return fmt.Errorf("fx.ParamTag[%v]: nth must be >= 1, got %d", pt.typ, pt.nth) + } + + targetTyp := reflect.TypeOf(ann.Target) + if targetTyp == nil || targetTyp.Kind() != reflect.Func { + return fmt.Errorf("fx.ParamTag[%v]: must annotate a function, got %T", pt.typ, ann.Target) + } + + idx, total := nthOfType(targetTyp, pt.typ, pt.nth, "in") + if total == 0 { + return fmt.Errorf("fx.ParamTag[%v]: %s has no %v parameters", + pt.typ, fxreflect.FuncName(ann.Target), pt.typ) + } + if idx < 0 { + return fmt.Errorf("fx.ParamTag[%v]: %s has only %d %v parameters, requested #%d", + pt.typ, fxreflect.FuncName(ann.Target), total, pt.typ, pt.nth) + } + + ensureLen(&ann.ParamTags, targetTyp.NumIn()) + + if ann.ParamTags[idx] != "" { + return fmt.Errorf("fx.ParamTag[%v]: parameter %d of %s already has tag %q", + pt.typ, idx, fxreflect.FuncName(ann.Target), ann.ParamTags[idx]) + } + + ann.ParamTags[idx] = pt.tag + return nil +} + +func (pt genericParamTagAnnotation) build(ann *annotated) (any, error) { + return paramTagsAnnotation{}.build(ann) // delegate +} + +// ParamTag is an Annotation that annotates a function parameter of type T. +// +// By default, it annotates the first such parameter. If there are multiple parameters of type T, +// nth selects the 1-based occurrence to annotate. +// +// For example, this tags the second [*sql.DB] parameter: +// +// fx.Annotate(constructor, fx.ParamTag[*sql.DB](`name:"ro"`, 2)) +// +// Multiple ParamTag annotations may be applied to the same function as long as +// they target different parameters. +func ParamTag[T any](tag string, nth ...int) Annotation { + finalNth := 1 + if len(nth) > 0 { + finalNth = nth[0] + } + + return genericParamTagAnnotation{ + typ: reflect.TypeFor[T](), + tag: tag, + nth: finalNth, + } +} + type resultTagsAnnotation struct { tags []string } @@ -391,6 +466,10 @@ func (rt resultTagsAnnotation) apply(ann *annotated) error { // build builds and returns a constructor after applying a ResultTags annotation func (rt resultTagsAnnotation) build(ann *annotated) (any, error) { + if ann.resultTagsBuilt { + return ann.Target, nil + } + ann.resultTagsBuilt = true paramTypes := ann.currentParamTypes() resultTypes, remapResults := rt.results(ann) origFn := reflect.ValueOf(ann.Target) @@ -409,6 +488,7 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( types []reflect.Type, remap func([]reflect.Value) []reflect.Value, ) { + tags := ann.ResultTags types, hasError := ann.currentResultTypes() if hasError { @@ -417,7 +497,7 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( // No result annotations. Return the original types // and an identity function. - if len(rt.tags) == 0 { + if len(tags) == 0 { return types, func(results []reflect.Value) []reflect.Value { return results } @@ -441,8 +521,8 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( Name: fmt.Sprintf("Field%d", i), Type: t, } - if i < len(rt.tags) { - field.Tag = reflect.StructTag(rt.tags[i]) + if i < len(tags) { + field.Tag = reflect.StructTag(tags[i]) } newOut.Offsets = append(newOut.Offsets, len(newOut.Fields)) newOut.Fields = append(newOut.Fields, field) @@ -452,7 +532,7 @@ func (rt resultTagsAnnotation) results(ann *annotated) ( // apply the tags to the existing type taggedFields := make([]reflect.StructField, t.NumField()) taggedFields[0] = _outAnnotationField - for j, tag := range rt.tags { + for j, tag := range tags { if j+1 < t.NumField() { field := t.Field(j + 1) taggedFields[j+1] = reflect.StructField{ @@ -543,6 +623,76 @@ func ResultTags(tags ...string) Annotation { return resultTagsAnnotation{tags} } +type genericResultTagAnnotation struct { + typ reflect.Type + tag string + nth int +} + +var _ Annotation = genericResultTagAnnotation{} + +func (rt genericResultTagAnnotation) apply(ann *annotated) error { + if err := verifyAnnotateTag(rt.tag); err != nil { + return err + } + if rt.nth < 1 { + return fmt.Errorf("fx.ResultTag[%v]: nth must be >= 1, got %d", rt.typ, rt.nth) + } + + targetTyp := reflect.TypeOf(ann.Target) + if targetTyp == nil || targetTyp.Kind() != reflect.Func { + return fmt.Errorf("fx.ResultTag[%v]: must annotate a function, got %T", rt.typ, ann.Target) + } + + idx, total := nthOfType(targetTyp, rt.typ, rt.nth, "out") + if total == 0 { + return fmt.Errorf("fx.ResultTag[%v]: %s has no %v results", + rt.typ, fxreflect.FuncName(ann.Target), rt.typ) + } + if idx < 0 { + return fmt.Errorf("fx.ResultTag[%v]: %s has only %d %v results, requested #%d", + rt.typ, fxreflect.FuncName(ann.Target), total, rt.typ, rt.nth) + } + + ensureLen(&ann.ResultTags, targetTyp.NumOut()) + + if ann.ResultTags[idx] != "" { + return fmt.Errorf("fx.ResultTag[%v]: return value %d of %s already has tag %q", + rt.typ, idx, fxreflect.FuncName(ann.Target), ann.ResultTags[idx]) + } + + ann.ResultTags[idx] = rt.tag + return nil +} + +func (rt genericResultTagAnnotation) build(ann *annotated) (any, error) { + return resultTagsAnnotation{}.build(ann) // delegate +} + +// ResultTag is an Annotation that annotates a function return value of type T. +// +// By default, it annotates the first such return value. If there are multiple return +// values of type T, nth selects the 1-based occurrence to annotate. +// +// For example, this tags the second [*sql.DB] return value: +// +// fx.Annotate(constructor, fx.ResultTag[*sql.DB](`name:"ro"`, 2)) +// +// Multiple ResultTag annotations may be applied to the same function as long as +// they target different return values. +func ResultTag[T any](tag string, nth ...int) Annotation { + finalNth := 1 + if len(nth) > 0 { + finalNth = nth[0] + } + + return genericResultTagAnnotation{ + typ: reflect.TypeFor[T](), + tag: tag, + nth: finalNth, + } +} + type outStructInfo struct { Fields []reflect.StructField // fields of the struct Offsets []int // Offsets[i] is the index of result i in Fields @@ -1210,6 +1360,42 @@ func isIn(t reflect.Type) bool { dig.IsIn(reflect.New(t).Elem().Interface())) } +// find the index of the n-th occurrence (nth >= 1) of typ in the +// function funcTyp parameters or results. Also returns the total number of occurrences. +func nthOfType(funcTyp, typ reflect.Type, nth int, op string) (idx, seen int) { + count := funcTyp.NumIn() + get := funcTyp.In + if op == "out" { + count = funcTyp.NumOut() + get = funcTyp.Out + } + + idx = -1 + for i := 0; i < count; i++ { + if get(i) != typ { + continue + } + + seen++ + if seen == nth { + idx = i + } + } + + return +} + +// grows the slice and fills it with zero values +func ensureLen[T any](s *[]T, targetLen int) { + if len(*s) >= targetLen { + return + } + + res := make([]T, targetLen) + copy(res, *s) + *s = res +} + var _ Annotation = (*asAnnotation)(nil) // As is an Annotation that annotates the result of a function (i.e. a @@ -1613,14 +1799,16 @@ func (fr *fromAnnotation) parameters(ann *annotated) ( } type annotated struct { - Target any - Annotations []Annotation - ParamTags []string - ResultTags []string - As [][]asType - From []reflect.Type - FuncPtr uintptr - Hooks []*lifecycleHookAnnotation + Target any + Annotations []Annotation + ParamTags []string + ResultTags []string + As [][]asType + From []reflect.Type + FuncPtr uintptr + Hooks []*lifecycleHookAnnotation + paramTagsBuilt bool + resultTagsBuilt bool // container is used to build private scopes for lifecycle hook functions // added via fx.OnStart and fx.OnStop annotations. container *dig.Container diff --git a/annotated_test.go b/annotated_test.go index e10fb120c..4a8a3bf30 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -2534,3 +2534,369 @@ func TestHookAnnotationFunctionFlexibility(t *testing.T) { }) } } + +func TestParamTag(t *testing.T) { + t.Parallel() + + type DB struct { + name string + } + type Redis struct { + name string + } + type Service struct { + db *DB + dbRO *DB + redis *Redis + redisRO *Redis + } + + newDB := func(name string) func() *DB { + return func() *DB { + return &DB{name: name} + } + } + newRedis := func(name string) func() *Redis { + return func() *Redis { + return &Redis{name: name} + } + } + newService := func(db *DB, dbRO *DB, redis *Redis, redisRO *Redis) *Service { + return &Service{db: db, dbRO: dbRO, redis: redis, redisRO: redisRO} + } + + t.Run("tags first parameter of a type", func(t *testing.T) { + t.Parallel() + + var service *Service + app := fxtest.New(t, + fx.Provide( + newDB("db1"), + fx.Annotate(newRedis("redis1"), fx.ResultTags(`name:"primary"`)), + newRedis("redis2"), + + fx.Annotate(newService, + fx.ParamTag[*Redis](`name:"primary"`), + ), + ), + fx.Populate(&service), + ) + defer app.RequireStart().RequireStop() + + require.NoError(t, app.Err()) + require.NotNil(t, service) + + assert.Equal(t, "db1", service.db.name) + assert.Equal(t, "db1", service.dbRO.name) + assert.Equal(t, "redis1", service.redis.name) + assert.Equal(t, "redis2", service.redisRO.name) + }) + + t.Run("tags multiple parameters of the same type", func(t *testing.T) { + t.Parallel() + + var service *Service + app := fxtest.New(t, + fx.Provide( + fx.Annotate(newDB("db1"), fx.ResultTags(`name:"primary"`)), + fx.Annotate(newDB("db2"), fx.ResultTags(`name:"ro"`)), + newRedis("r1"), + fx.Annotate(newRedis("r2"), fx.ResultTags(`name:"ro"`)), + + fx.Annotate(newService, + fx.ParamTag[*DB](`name:"primary"`), + fx.ParamTag[*DB](`name:"ro"`, 2), + fx.ParamTag[*Redis](`name:"ro"`, 2), + ), + ), + fx.Populate(&service), + ) + defer app.RequireStart().RequireStop() + + require.NoError(t, app.Err()) + require.NotNil(t, service) + + assert.Equal(t, "db1", service.db.name) + assert.Equal(t, "db2", service.dbRO.name) + assert.Equal(t, "r1", service.redis.name) + assert.Equal(t, "r2", service.redisRO.name) + }) + + t.Run("adds tags after ParamTags", func(t *testing.T) { + t.Parallel() + + var service *Service + app := fxtest.New(t, + fx.Provide( + fx.Annotate(newDB("db1"), fx.ResultTags(`name:"primary"`)), + fx.Annotate(newDB("db2"), fx.ResultTags(`name:"ro"`)), + newRedis("r1"), + fx.Annotate(newRedis("r2"), fx.ResultTags(`name:"ro"`)), + + fx.Annotate(newService, + fx.ParamTags(``, `name:"ro"`), + fx.ParamTag[*DB](`name:"primary"`), + fx.ParamTag[*Redis](`name:"ro"`, 2), + ), + ), + fx.Populate(&service), + ) + defer app.RequireStart().RequireStop() + + require.NoError(t, app.Err()) + require.NotNil(t, service) + + assert.Equal(t, "db1", service.db.name) + assert.Equal(t, "db2", service.dbRO.name) + assert.Equal(t, "r1", service.redis.name) + assert.Equal(t, "r2", service.redisRO.name) + }) + + t.Run("errors if type is missing", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + newDB("db1"), + newRedis("r1"), + fx.Annotate(newService, + fx.ParamTag[io.Reader](`name:"db1"`), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "has no io.Reader parameters") + }) + + t.Run("errors if nth exceeds count", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + newDB("db1"), + newRedis("r1"), + fx.Annotate(newService, + fx.ParamTag[*DB](`name:"ro"`, 15), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "has only 2 *fx_test.DB parameters") + assert.Contains(t, err.Error(), "requested #15") + }) + + t.Run("errors if nth is less than 1", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + newDB("db1"), + newRedis("r1"), + fx.Annotate(newService, + fx.ParamTag[*DB](`name:"ro"`, 0), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "nth must be >= 1, got 0") + }) + + t.Run("errors if parameter already has tag", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + newDB("db1"), + newRedis("r1"), + fx.Annotate(newService, + fx.ParamTag[*DB](`name:"ro"`, 2), + fx.ParamTag[*DB](`optional:"true"`, 2), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "already has tag") + }) +} + +func TestResultTag(t *testing.T) { + t.Parallel() + + type DB struct { + name string + } + type Redis struct { + name string + } + + constructor := func() (*DB, *DB, *Redis, *Redis, error) { + return &DB{name: "db1"}, &DB{name: "db2"}, &Redis{name: "redis1"}, &Redis{name: "redis2"}, nil + } + + t.Run("tags first result of a type", func(t *testing.T) { + t.Parallel() + + var db *DB + var dbRO *DB + var redis *Redis + var redisRO *Redis + + err := fxtest.New(t, + fx.Provide(fx.Annotate(constructor, + fx.ResultTag[*DB](`name:"primary"`), + fx.ResultTag[*Redis](`name:"primary"`), + )), + + fx.Populate(fx.Annotate(&db, fx.ParamTags(`name:"primary"`))), + fx.Populate(&dbRO), + fx.Populate(fx.Annotate(&redis, fx.ParamTags(`name:"primary"`))), + fx.Populate(&redisRO), + ).Err() + + require.NoError(t, err) + require.Equal(t, "db1", db.name) + require.Equal(t, "db2", dbRO.name) + require.Equal(t, "redis1", redis.name) + require.Equal(t, "redis2", redisRO.name) + }) + + t.Run("tags multiple results of the same type", func(t *testing.T) { + t.Parallel() + + var db *DB + var dbRO *DB + var redis *Redis + var redisRO *Redis + + err := fxtest.New(t, + fx.Provide(fx.Annotate(constructor, + fx.ResultTag[*DB](`name:"primary"`), + fx.ResultTag[*DB](`name:"ro"`, 2), + fx.ResultTag[*Redis](`name:"primary"`), + fx.ResultTag[*Redis](`name:"ro"`, 2), + )), + + fx.Populate(fx.Annotate(&db, fx.ParamTags(`name:"primary"`))), + fx.Populate(fx.Annotate(&dbRO, fx.ParamTags(`name:"ro"`))), + fx.Populate(fx.Annotate(&redis, fx.ParamTags(`name:"primary"`))), + fx.Populate(fx.Annotate(&redisRO, fx.ParamTags(`name:"ro"`))), + ).Err() + + require.NoError(t, err) + require.Equal(t, "db1", db.name) + require.Equal(t, "db2", dbRO.name) + require.Equal(t, "redis1", redis.name) + require.Equal(t, "redis2", redisRO.name) + }) + + t.Run("adds tags after ResultTags", func(t *testing.T) { + t.Parallel() + + var db *DB + var dbRO *DB + var redis *Redis + var redisRO *Redis + + err := fxtest.New(t, + fx.Provide(fx.Annotate(constructor, + fx.ResultTags(``, `name:"ro"`), + fx.ResultTag[*DB](`name:"primary"`), + fx.ResultTag[*Redis](`name:"primary"`), + fx.ResultTag[*Redis](`name:"ro"`, 2), + )), + + fx.Populate(fx.Annotate(&db, fx.ParamTags(`name:"primary"`))), + fx.Populate(fx.Annotate(&dbRO, fx.ParamTags(`name:"ro"`))), + fx.Populate(fx.Annotate(&redis, fx.ParamTags(`name:"primary"`))), + fx.Populate(fx.Annotate(&redisRO, fx.ParamTags(`name:"ro"`))), + ).Err() + + require.NoError(t, err) + require.Equal(t, "db1", db.name) + require.Equal(t, "db2", dbRO.name) + require.Equal(t, "redis1", redis.name) + require.Equal(t, "redis2", redisRO.name) + }) + + t.Run("errors if type is missing", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + fx.Annotate(constructor, + fx.ResultTag[io.Reader](`name:"primary"`), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "has no io.Reader results") + }) + + t.Run("tagging of terminal error is ignored", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + fx.Annotate(constructor, + fx.ResultTag[*DB](`name:"primary"`), + fx.ResultTag[*Redis](`name:"primary"`), + fx.ResultTag[error](`name:"failure"`), + ), + ), + ).Err() + + require.NoError(t, err) + }) + + t.Run("errors if nth exceeds count", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + fx.Annotate(constructor, + fx.ResultTag[*DB](`name:"ro"`, 15), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "has only 2 *fx_test.DB results") + assert.Contains(t, err.Error(), "requested #15") + }) + + t.Run("errors if nth is less than 1", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + fx.Annotate(constructor, + fx.ResultTag[*DB](`name:"ro"`, 0), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "nth must be >= 1, got 0") + }) + + t.Run("errors if result already has tag", func(t *testing.T) { + t.Parallel() + + err := NewForTest(t, + fx.Provide( + fx.Annotate(constructor, + fx.ResultTag[*DB](`name:"ro"`, 2), + fx.ResultTag[*DB](`group:"databases"`, 2), + ), + ), + ).Err() + + require.Error(t, err) + assert.Contains(t, err.Error(), "already has tag") + }) +}