Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 205 additions & 17 deletions annotated.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ func (pt paramTagsAnnotation) apply(ann *annotated) error {

// build builds and returns a constructor after applying a ParamTags annotation
func (pt paramTagsAnnotation) build(ann *annotated) (any, error) {
if ann.paramTagsBuilt {
return ann.Target, nil
}
ann.paramTagsBuilt = true
paramTypes, remap := pt.parameters(ann)
resultTypes, _ := ann.currentResultTypes()

Expand All @@ -267,6 +271,7 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) (
types []reflect.Type,
remap func([]reflect.Value) []reflect.Value,
) {
tags := ann.ParamTags
ft := reflect.TypeOf(ann.Target)
types = make([]reflect.Type, ft.NumIn())
for i := 0; i < ft.NumIn(); i++ {
Expand All @@ -275,7 +280,7 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) (

// No parameter annotations. Return the original types
// and an identity function.
if len(pt.tags) == 0 {
if len(tags) == 0 {
return types, func(args []reflect.Value) []reflect.Value {
return args
}
Expand All @@ -295,8 +300,8 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) (
Type: origField.Type,
Tag: origField.Tag,
}
if i-1 < len(pt.tags) {
field.Tag = reflect.StructTag(pt.tags[i-1])
if i-1 < len(tags) {
field.Tag = reflect.StructTag(tags[i-1])
}

inFields = append(inFields, field)
Expand All @@ -318,8 +323,8 @@ func (pt paramTagsAnnotation) parameters(ann *annotated) (
Name: fmt.Sprintf("Field%d", i),
Type: t,
}
if i < len(pt.tags) {
field.Tag = reflect.StructTag(pt.tags[i])
if i < len(tags) {
field.Tag = reflect.StructTag(tags[i])
}

inFields = append(inFields, field)
Expand Down Expand Up @@ -353,6 +358,76 @@ func ParamTags(tags ...string) Annotation {
return paramTagsAnnotation{tags}
}

type genericParamTagAnnotation struct {
typ reflect.Type
tag string
nth int
}

var _ Annotation = genericParamTagAnnotation{}

func (pt genericParamTagAnnotation) apply(ann *annotated) error {
if err := verifyAnnotateTag(pt.tag); err != nil {
return err
}
if pt.nth < 1 {
return fmt.Errorf("fx.ParamTag[%v]: nth must be >= 1, got %d", pt.typ, pt.nth)
}

targetTyp := reflect.TypeOf(ann.Target)
if targetTyp == nil || targetTyp.Kind() != reflect.Func {
return fmt.Errorf("fx.ParamTag[%v]: must annotate a function, got %T", pt.typ, ann.Target)
}

idx, total := nthOfType(targetTyp, pt.typ, pt.nth, "in")
if total == 0 {
return fmt.Errorf("fx.ParamTag[%v]: %s has no %v parameters",
pt.typ, fxreflect.FuncName(ann.Target), pt.typ)
}
if idx < 0 {
return fmt.Errorf("fx.ParamTag[%v]: %s has only %d %v parameters, requested #%d",
pt.typ, fxreflect.FuncName(ann.Target), total, pt.typ, pt.nth)
}

ensureLen(&ann.ParamTags, targetTyp.NumIn())

if ann.ParamTags[idx] != "" {
return fmt.Errorf("fx.ParamTag[%v]: parameter %d of %s already has tag %q",
pt.typ, idx, fxreflect.FuncName(ann.Target), ann.ParamTags[idx])
}

ann.ParamTags[idx] = pt.tag
return nil
}

func (pt genericParamTagAnnotation) build(ann *annotated) (any, error) {
return paramTagsAnnotation{}.build(ann) // delegate
}

// ParamTag is an Annotation that annotates a function parameter of type T.
//
// By default, it annotates the first such parameter. If there are multiple parameters of type T,
// nth selects the 1-based occurrence to annotate.
//
// For example, this tags the second [*sql.DB] parameter:
//
// fx.Annotate(constructor, fx.ParamTag[*sql.DB](`name:"ro"`, 2))
//
// Multiple ParamTag annotations may be applied to the same function as long as
// they target different parameters.
func ParamTag[T any](tag string, nth ...int) Annotation {
finalNth := 1
if len(nth) > 0 {
finalNth = nth[0]
}

return genericParamTagAnnotation{
typ: reflect.TypeFor[T](),
tag: tag,
nth: finalNth,
}
}

type resultTagsAnnotation struct {
tags []string
}
Expand Down Expand Up @@ -391,6 +466,10 @@ func (rt resultTagsAnnotation) apply(ann *annotated) error {

// build builds and returns a constructor after applying a ResultTags annotation
func (rt resultTagsAnnotation) build(ann *annotated) (any, error) {
if ann.resultTagsBuilt {
return ann.Target, nil
}
ann.resultTagsBuilt = true
paramTypes := ann.currentParamTypes()
resultTypes, remapResults := rt.results(ann)
origFn := reflect.ValueOf(ann.Target)
Expand All @@ -409,6 +488,7 @@ func (rt resultTagsAnnotation) results(ann *annotated) (
types []reflect.Type,
remap func([]reflect.Value) []reflect.Value,
) {
tags := ann.ResultTags
types, hasError := ann.currentResultTypes()

if hasError {
Expand All @@ -417,7 +497,7 @@ func (rt resultTagsAnnotation) results(ann *annotated) (

// No result annotations. Return the original types
// and an identity function.
if len(rt.tags) == 0 {
if len(tags) == 0 {
return types, func(results []reflect.Value) []reflect.Value {
return results
}
Expand All @@ -441,8 +521,8 @@ func (rt resultTagsAnnotation) results(ann *annotated) (
Name: fmt.Sprintf("Field%d", i),
Type: t,
}
if i < len(rt.tags) {
field.Tag = reflect.StructTag(rt.tags[i])
if i < len(tags) {
field.Tag = reflect.StructTag(tags[i])
}
newOut.Offsets = append(newOut.Offsets, len(newOut.Fields))
newOut.Fields = append(newOut.Fields, field)
Expand All @@ -452,7 +532,7 @@ func (rt resultTagsAnnotation) results(ann *annotated) (
// apply the tags to the existing type
taggedFields := make([]reflect.StructField, t.NumField())
taggedFields[0] = _outAnnotationField
for j, tag := range rt.tags {
for j, tag := range tags {
if j+1 < t.NumField() {
field := t.Field(j + 1)
taggedFields[j+1] = reflect.StructField{
Expand Down Expand Up @@ -543,6 +623,76 @@ func ResultTags(tags ...string) Annotation {
return resultTagsAnnotation{tags}
}

type genericResultTagAnnotation struct {
typ reflect.Type
tag string
nth int
}

var _ Annotation = genericResultTagAnnotation{}

func (rt genericResultTagAnnotation) apply(ann *annotated) error {
if err := verifyAnnotateTag(rt.tag); err != nil {
return err
}
if rt.nth < 1 {
return fmt.Errorf("fx.ResultTag[%v]: nth must be >= 1, got %d", rt.typ, rt.nth)
}

targetTyp := reflect.TypeOf(ann.Target)
if targetTyp == nil || targetTyp.Kind() != reflect.Func {
return fmt.Errorf("fx.ResultTag[%v]: must annotate a function, got %T", rt.typ, ann.Target)
}

idx, total := nthOfType(targetTyp, rt.typ, rt.nth, "out")
if total == 0 {
return fmt.Errorf("fx.ResultTag[%v]: %s has no %v results",
rt.typ, fxreflect.FuncName(ann.Target), rt.typ)
}
if idx < 0 {
return fmt.Errorf("fx.ResultTag[%v]: %s has only %d %v results, requested #%d",
rt.typ, fxreflect.FuncName(ann.Target), total, rt.typ, rt.nth)
}

ensureLen(&ann.ResultTags, targetTyp.NumOut())

if ann.ResultTags[idx] != "" {
return fmt.Errorf("fx.ResultTag[%v]: return value %d of %s already has tag %q",
rt.typ, idx, fxreflect.FuncName(ann.Target), ann.ResultTags[idx])
}

ann.ResultTags[idx] = rt.tag
return nil
}

func (rt genericResultTagAnnotation) build(ann *annotated) (any, error) {
return resultTagsAnnotation{}.build(ann) // delegate
}

// ResultTag is an Annotation that annotates a function return value of type T.
//
// By default, it annotates the first such return value. If there are multiple return
// values of type T, nth selects the 1-based occurrence to annotate.
//
// For example, this tags the second [*sql.DB] return value:
//
// fx.Annotate(constructor, fx.ResultTag[*sql.DB](`name:"ro"`, 2))
//
// Multiple ResultTag annotations may be applied to the same function as long as
// they target different return values.
func ResultTag[T any](tag string, nth ...int) Annotation {
finalNth := 1
if len(nth) > 0 {
finalNth = nth[0]
}

return genericResultTagAnnotation{
typ: reflect.TypeFor[T](),
tag: tag,
nth: finalNth,
}
}

type outStructInfo struct {
Fields []reflect.StructField // fields of the struct
Offsets []int // Offsets[i] is the index of result i in Fields
Expand Down Expand Up @@ -1210,6 +1360,42 @@ func isIn(t reflect.Type) bool {
dig.IsIn(reflect.New(t).Elem().Interface()))
}

// find the index of the n-th occurrence (nth >= 1) of typ in the
// function funcTyp parameters or results. Also returns the total number of occurrences.
func nthOfType(funcTyp, typ reflect.Type, nth int, op string) (idx, seen int) {
count := funcTyp.NumIn()
get := funcTyp.In
if op == "out" {
count = funcTyp.NumOut()
get = funcTyp.Out
}

idx = -1
for i := 0; i < count; i++ {
if get(i) != typ {
continue
}

seen++
if seen == nth {
idx = i
}
}

return
}

// grows the slice and fills it with zero values
func ensureLen[T any](s *[]T, targetLen int) {
if len(*s) >= targetLen {
return
}

res := make([]T, targetLen)
copy(res, *s)
*s = res
}

var _ Annotation = (*asAnnotation)(nil)

// As is an Annotation that annotates the result of a function (i.e. a
Expand Down Expand Up @@ -1613,14 +1799,16 @@ func (fr *fromAnnotation) parameters(ann *annotated) (
}

type annotated struct {
Target any
Annotations []Annotation
ParamTags []string
ResultTags []string
As [][]asType
From []reflect.Type
FuncPtr uintptr
Hooks []*lifecycleHookAnnotation
Target any
Annotations []Annotation
ParamTags []string
ResultTags []string
As [][]asType
From []reflect.Type
FuncPtr uintptr
Hooks []*lifecycleHookAnnotation
paramTagsBuilt bool
resultTagsBuilt bool
// container is used to build private scopes for lifecycle hook functions
// added via fx.OnStart and fx.OnStop annotations.
container *dig.Container
Expand Down
Loading