From e166b5a8c75f9ad392ac57278a255a31231074e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Tue, 8 Apr 2025 13:43:04 +0200 Subject: [PATCH 1/5] Encapsulate flags in a struct --- mockgen/mockgen.go | 186 ++++++++++++++++++++++++++-------------- mockgen/mockgen_test.go | 2 +- 2 files changed, 122 insertions(+), 66 deletions(-) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 79cce84b..84058468 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -86,47 +86,21 @@ func main() { return } - var pkg *model.Package - var err error - var packageName string - if *modelGob != "" { - pkg, err = gobMode(*modelGob) - } else if *source != "" { - pkg, err = sourceMode(*source) - } else { - if flag.NArg() != 2 { - usage() - log.Fatal("Expected exactly two arguments") - } - packageName = flag.Arg(0) - interfaces := strings.Split(flag.Arg(1), ",") - if packageName == "." { - dir, err := os.Getwd() - if err != nil { - log.Fatalf("Get current directory failed: %v", err) - } - packageName, err = packageNameOfDir(dir) - if err != nil { - log.Fatalf("Parse package name failed: %v", err) - } - } - parser := packageModeParser{} - pkg, err = parser.parsePackage(packageName, interfaces) - } + target, err := prepareTarget() if err != nil { log.Fatalf("Loading input failed: %v", err) } if *debugParser { - pkg.Print(os.Stdout) + target.pkg.Print(os.Stdout) return } - outputPackageName := *packageOut + outputPackageName := target.packageOut if outputPackageName == "" { // pkg.Name in package mode is the base name of the import path, // which might have characters that are illegal to have in package names. - outputPackageName = "mock_" + sanitize(pkg.Name) + outputPackageName = "mock_" + sanitize(target.pkg.Name) } // outputPackagePath represents the fully qualified name of the package of @@ -135,9 +109,9 @@ func main() { // package (i.e. if there is a type called X then we want to print "X" not // "package.X" since "package" is this package). This can happen if the mock // is output into an already existing package. - outputPackagePath := *selfPackage - if outputPackagePath == "" && *destination != "" { - dstPath, err := filepath.Abs(filepath.Dir(*destination)) + outputPackagePath := target.selfPackage + if outputPackagePath == "" && target.destination != "" { + dstPath, err := filepath.Abs(filepath.Dir(target.destination)) if err == nil { pkgPath, err := parsePackageImport(dstPath) if err == nil { @@ -151,44 +125,43 @@ func main() { } g := &generator{ - buildConstraint: *buildConstraint, + buildConstraint: target.buildConstraint, } - if *source != "" { - g.filename = *source + if target.source != "" { + g.filename = target.source } else { - g.srcPackage = packageName - g.srcInterfaces = flag.Arg(1) + g.srcPackage = target.packageName + g.srcInterfaces = target.interfaces } - g.destination = *destination + g.destination = target.destination - if *mockNames != "" { - g.mockNames = parseMockNames(*mockNames) + if target.mockNames != "" { + g.mockNames = parseMockNames(target.mockNames) } - if *copyrightFile != "" { - header, err := os.ReadFile(*copyrightFile) + if target.copyrightFile != "" { + header, err := os.ReadFile(target.copyrightFile) if err != nil { log.Fatalf("Failed reading copyright file: %v", err) } - g.copyrightHeader = string(header) } - if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { + if err := g.Generate(target, outputPackageName, outputPackagePath); err != nil { log.Fatalf("Failed generating mock: %v", err) } output := g.Output() dst := os.Stdout - if len(*destination) > 0 { - if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { + if len(target.destination) > 0 { + if err := os.MkdirAll(filepath.Dir(target.destination), os.ModePerm); err != nil { log.Fatalf("Unable to create directory: %v", err) } - existing, err := os.ReadFile(*destination) + existing, err := os.ReadFile(target.destination) if err != nil && !errors.Is(err, os.ErrNotExist) { log.Fatalf("Failed reading pre-exiting destination file: %v", err) } if len(existing) == len(output) && bytes.Equal(existing, output) { return } - f, err := os.Create(*destination) + f, err := os.Create(target.destination) if err != nil { log.Fatalf("Failed opening destination file: %v", err) } @@ -200,6 +173,65 @@ func main() { } } +func prepareTarget() (*genTarget, error) { + target := genTarget{ + destination: *destination, + mockNames: *mockNames, + packageOut: *packageOut, + selfPackage: *selfPackage, + writeCmdComment: *writeCmdComment, + writePkgComment: *writePkgComment, + writeSourceComment: *writeSourceComment, + writeGenerateDirective: *writeGenerateDirective, + copyrightFile: *copyrightFile, + buildConstraint: *buildConstraint, + typed: *typed, + } + if *modelGob != "" { + pkg, err := gobMode(*modelGob) + if err != nil { + return nil, err + } + target.pkg = pkg + return &target, nil + } else if *source != "" { + pkg, err := sourceMode(*source) + if err != nil { + return nil, err + } + target.pkg = pkg + target.source = *source + target.imports = *imports + return &target, nil + } else { + if flag.NArg() != 2 { + usage() + log.Fatal("Expected exactly two arguments") + } + packageName := flag.Arg(0) + interfaces := strings.Split(flag.Arg(1), ",") + if packageName == "." { + dir, err := os.Getwd() + if err != nil { + log.Fatalf("Get current directory failed: %v", err) + } + packageName, err = packageNameOfDir(dir) + if err != nil { + log.Fatalf("Parse package name failed: %v", err) + } + } + parser := packageModeParser{} + pkg, err := parser.parsePackage(packageName, interfaces) + if err != nil { + return nil, err + } + target.pkg = pkg + target.packageName = packageName + target.interfaces = flag.Arg(1) + return &target, nil + } +} + func parseMockNames(names string) map[string]string { mocksMap := make(map[string]string) for _, kv := range strings.Split(names, ",") { @@ -253,6 +285,30 @@ Example: ` +type genTarget struct { + pkg *model.Package + + // source mode only + source string + imports string + + packageName string + interfaces string + + // flags + destination string + mockNames string + packageOut string + selfPackage string + writeCmdComment bool + writePkgComment bool + writeSourceComment bool + writeGenerateDirective bool + copyrightFile string + buildConstraint string + typed bool +} + type generator struct { buf bytes.Buffer indent string @@ -303,8 +359,8 @@ func sanitize(s string) string { return t } -func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { - if outputPkgName != pkg.Name && *selfPackage == "" { +func (g *generator) Generate(target *genTarget, outputPkgName string, outputPackagePath string) error { + if outputPkgName != target.pkg.Name && target.selfPackage == "" { // reset outputPackagePath if it's not passed in through -self_package outputPackagePath = "" } @@ -324,14 +380,14 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac } g.p("// Code generated by MockGen. DO NOT EDIT.") - if *writeSourceComment { + if target.writeSourceComment { if g.filename != "" { g.p("// Source: %v", g.filename) } else { g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) } } - if *writeCmdComment { + if target.writeCmdComment { g.p("//") g.p("// Generated by this command:") g.p("//") @@ -345,12 +401,12 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac } // Get all required imports, and generate unique names for them all. - im := pkg.Imports() + im := target.pkg.Imports() im[gomockImportPath] = true // Only import reflect if it's used. We only use reflect in mocked methods // so only import if any of the mocked interfaces have methods. - for _, intf := range pkg.Interfaces { + for _, intf := range target.pkg.Interfaces { if len(intf.Methods) > 0 { im["reflect"] = true break @@ -369,8 +425,8 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac packagesName := createPackageMap(sortedPaths) definedImports := make(map[string]string, len(im)) - if *imports != "" { - for _, kv := range strings.Split(*imports, ",") { + if target.imports != "" { + for _, kv := range strings.Split(target.imports, ",") { eq := strings.Index(kv, "=") if k, v := kv[:eq], kv[eq+1:]; k != "." { definedImports[v] = k @@ -404,7 +460,7 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac } // Avoid importing package if source pkg == output pkg - if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath { + if pth == target.pkg.PkgPath && outputPackagePath == target.pkg.PkgPath { continue } @@ -418,7 +474,7 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac // That is, “generated by” should not be a package comment. g.p("") - if *writePkgComment { + if target.writePkgComment { g.p("// Package %v is a generated GoMock package.", outputPkgName) } g.p("package %v", outputPkgName) @@ -431,18 +487,18 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac } g.p("%v %q", pkgName, pkgPath) } - for _, pkgPath := range pkg.DotImports { + for _, pkgPath := range target.pkg.DotImports { g.p(". %q", pkgPath) } g.out() g.p(")") - if *writeGenerateDirective { + if target.writeGenerateDirective { g.p("//go:generate %v", strings.Join(os.Args, " ")) } - for _, intf := range pkg.Interfaces { - if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { + for _, intf := range target.pkg.Interfaces { + if err := g.GenerateMockInterface(intf, outputPackagePath, target.typed); err != nil { return err } } @@ -484,7 +540,7 @@ func (g *generator) formattedTypeParams(it *model.Interface, pkgOverride string) return long.String(), short.String() } -func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { +func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string, typed bool) error { mockType := g.mockName(intf.Name) longTp, shortTp := g.formattedTypeParams(intf, outputPackagePath) @@ -525,7 +581,7 @@ func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePa g.out() g.p("}") - g.GenerateMockMethods(mockType, intf, outputPackagePath, longTp, shortTp, *typed) + g.GenerateMockMethods(mockType, intf, outputPackagePath, longTp, shortTp, typed) return nil } diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 6b171272..7406e70b 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -245,7 +245,7 @@ func TestGenerateMockInterface_Helper(t *testing.T) { intf.AddMethod(m) } - if err := g.GenerateMockInterface(intf, "somepackage"); err != nil { + if err := g.GenerateMockInterface(intf, "somepackage", false); err != nil { t.Fatal(err) } From 78c795809c250dd9f4aa879cd27eae06899d633a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Tue, 8 Apr 2025 14:06:28 +0200 Subject: [PATCH 2/5] Prepare for generating multiple targets --- mockgen/mockgen.go | 130 ++++++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 61 deletions(-) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 84058468..3580be19 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -86,16 +86,83 @@ func main() { return } - target, err := prepareTarget() + targets, err := prepareTargets() if err != nil { log.Fatalf("Loading input failed: %v", err) } if *debugParser { - target.pkg.Print(os.Stdout) + for i := range targets { + targets[i].pkg.Print(os.Stdout) + } return } + for i := range targets { + generateTarget(&targets[i]) + } +} + +func prepareTargets() ([]genTarget, error) { + target := genTarget{ + destination: *destination, + mockNames: *mockNames, + packageOut: *packageOut, + selfPackage: *selfPackage, + writeCmdComment: *writeCmdComment, + writePkgComment: *writePkgComment, + writeSourceComment: *writeSourceComment, + writeGenerateDirective: *writeGenerateDirective, + copyrightFile: *copyrightFile, + buildConstraint: *buildConstraint, + typed: *typed, + } + if *modelGob != "" { + pkg, err := gobMode(*modelGob) + if err != nil { + return nil, err + } + target.pkg = pkg + return []genTarget{target}, nil + } else if *source != "" { + pkg, err := sourceMode(*source) + if err != nil { + return nil, err + } + target.pkg = pkg + target.source = *source + target.imports = *imports + return []genTarget{target}, nil + } else { + if flag.NArg() != 2 { + usage() + log.Fatal("Expected exactly two arguments") + } + packageName := flag.Arg(0) + interfaces := strings.Split(flag.Arg(1), ",") + if packageName == "." { + dir, err := os.Getwd() + if err != nil { + log.Fatalf("Get current directory failed: %v", err) + } + packageName, err = packageNameOfDir(dir) + if err != nil { + log.Fatalf("Parse package name failed: %v", err) + } + } + parser := packageModeParser{} + pkg, err := parser.parsePackage(packageName, interfaces) + if err != nil { + return nil, err + } + target.pkg = pkg + target.packageName = packageName + target.interfaces = flag.Arg(1) + return []genTarget{target}, nil + } +} + +func generateTarget(target *genTarget) { outputPackageName := target.packageOut if outputPackageName == "" { // pkg.Name in package mode is the base name of the import path, @@ -173,65 +240,6 @@ func main() { } } -func prepareTarget() (*genTarget, error) { - target := genTarget{ - destination: *destination, - mockNames: *mockNames, - packageOut: *packageOut, - selfPackage: *selfPackage, - writeCmdComment: *writeCmdComment, - writePkgComment: *writePkgComment, - writeSourceComment: *writeSourceComment, - writeGenerateDirective: *writeGenerateDirective, - copyrightFile: *copyrightFile, - buildConstraint: *buildConstraint, - typed: *typed, - } - if *modelGob != "" { - pkg, err := gobMode(*modelGob) - if err != nil { - return nil, err - } - target.pkg = pkg - return &target, nil - } else if *source != "" { - pkg, err := sourceMode(*source) - if err != nil { - return nil, err - } - target.pkg = pkg - target.source = *source - target.imports = *imports - return &target, nil - } else { - if flag.NArg() != 2 { - usage() - log.Fatal("Expected exactly two arguments") - } - packageName := flag.Arg(0) - interfaces := strings.Split(flag.Arg(1), ",") - if packageName == "." { - dir, err := os.Getwd() - if err != nil { - log.Fatalf("Get current directory failed: %v", err) - } - packageName, err = packageNameOfDir(dir) - if err != nil { - log.Fatalf("Parse package name failed: %v", err) - } - } - parser := packageModeParser{} - pkg, err := parser.parsePackage(packageName, interfaces) - if err != nil { - return nil, err - } - target.pkg = pkg - target.packageName = packageName - target.interfaces = flag.Arg(1) - return &target, nil - } -} - func parseMockNames(names string) map[string]string { mocksMap := make(map[string]string) for _, kv := range strings.Split(names, ",") { From 755ff08c450b2654782cba2dbb35e6266e19a517 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Tue, 8 Apr 2025 15:06:20 +0200 Subject: [PATCH 3/5] Prepare for parsing multiple packages --- mockgen/package_mode.go | 152 +++++++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 66 deletions(-) diff --git a/mockgen/package_mode.go b/mockgen/package_mode.go index acbe487f..2aa67209 100644 --- a/mockgen/package_mode.go +++ b/mockgen/package_mode.go @@ -17,8 +17,6 @@ var ( ) type packageModeParser struct { - pkgName string - // Mapping from underlying types to aliases used within the package source. // // We prefer to use aliases used in the source rather than underlying type names @@ -35,25 +33,54 @@ type aliasReplacement struct { } func (p *packageModeParser) parsePackage(packageName string, ifaces []string) (*model.Package, error) { - p.pkgName = packageName + parsed, err := p.parsePackages([]parseTarget{{name: packageName, ifaces: ifaces}}) + if err != nil { + return nil, err + } + return parsed[0], nil +} - pkg, err := p.loadPackage(packageName) +type parseTarget struct { + name string + ifaces []string +} + +func (p *packageModeParser) parsePackages(targets []parseTarget) ([]*model.Package, error) { + packageNames := make([]string, len(targets)) + for i := range targets { + packageNames[i] = targets[i].name + } + + pkgs, err := loadPackages(packageNames) if err != nil { return nil, fmt.Errorf("load package: %w", err) } - p.buildAliasReplacements(pkg) + p.buildAliasReplacements(pkgs) - interfaces, err := p.extractInterfacesFromPackage(pkg, ifaces) - if err != nil { - return nil, fmt.Errorf("extract interfaces from package: %w", err) + pkgByPath := make(map[string]*packages.Package, len(pkgs)) + for _, pkg := range pkgs { + pkgByPath[pkg.PkgPath] = pkg } - return &model.Package{ - Name: pkg.Types.Name(), - PkgPath: packageName, - Interfaces: interfaces, - }, nil + parsed := make([]*model.Package, len(targets)) + for i := range targets { + pkg, ok := pkgByPath[targets[i].name] + if !ok { + return nil, fmt.Errorf("package not found: %s", targets[i].name) + } + interfaces, err := p.extractInterfacesFromPackage(pkg, targets[i].ifaces) + if err != nil { + return nil, fmt.Errorf("extract interfaces from package: %w", err) + } + + parsed[i] = &model.Package{ + Name: pkg.Types.Name(), + PkgPath: pkg.PkgPath, + Interfaces: interfaces, + } + } + return parsed, nil } // buildAliasReplacements finds and records any references to aliases @@ -65,7 +92,7 @@ func (p *packageModeParser) parsePackage(packageName string, ifaces []string) (* // the latest one to be inspected will be the one used for mapping. // This is fine, since all aliases and their underlying types are interchangeable // from a type-checking standpoint. -func (p *packageModeParser) buildAliasReplacements(pkg *packages.Package) { +func (p *packageModeParser) buildAliasReplacements(pkgs []*packages.Package) { p.aliasReplacements = make(map[types.Type]aliasReplacement) // checkIdent checks if the given identifier exists @@ -96,51 +123,51 @@ func (p *packageModeParser) buildAliasReplacements(pkg *packages.Package) { pkg: pkg.Path(), } return false - } - for _, f := range pkg.Syntax { - fileScope, ok := pkg.TypesInfo.Scopes[f] - if !ok { - continue - } - ast.Inspect(f, func(node ast.Node) bool { - - // Simple identifiers: check if it is an alias - // from the source package. - if ident, ok := node.(*ast.Ident); ok { - return checkIdent(pkg.Types, ident.String()) - } - - // Selector expressions: check if it is an alias - // from the package represented by the qualifier. - selExpr, ok := node.(*ast.SelectorExpr) - if !ok { - return true - } - - x, sel := selExpr.X, selExpr.Sel - xident, ok := x.(*ast.Ident) - if !ok { - return true - } - - xObj := fileScope.Lookup(xident.String()) - pkgName, ok := xObj.(*types.PkgName) + for _, pkg := range pkgs { + for _, f := range pkg.Syntax { + fileScope, ok := pkg.TypesInfo.Scopes[f] if !ok { - return true + continue } - - xPkg := pkgName.Imported() - if xPkg == nil { - return true - } - return checkIdent(xPkg, sel.String()) - }) + ast.Inspect(f, func(node ast.Node) bool { + // Simple identifiers: check if it is an alias + // from the source package. + if ident, ok := node.(*ast.Ident); ok { + return checkIdent(pkg.Types, ident.String()) + } + + // Selector expressions: check if it is an alias + // from the package represented by the qualifier. + selExpr, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + + x, sel := selExpr.X, selExpr.Sel + xident, ok := x.(*ast.Ident) + if !ok { + return true + } + + xObj := fileScope.Lookup(xident.String()) + pkgName, ok := xObj.(*types.PkgName) + if !ok { + return true + } + + xPkg := pkgName.Imported() + if xPkg == nil { + return true + } + return checkIdent(xPkg, sel.String()) + }) + } } } -func (p *packageModeParser) loadPackage(packageName string) (*packages.Package, error) { +func loadPackages(packageNames []string) ([]*packages.Package, error) { var buildFlagsSet []string if *buildFlags != "" { buildFlagsSet = strings.Split(*buildFlags, " ") @@ -150,25 +177,18 @@ func (p *packageModeParser) loadPackage(packageName string) (*packages.Package, Mode: packages.NeedDeps | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedEmbedFiles | packages.LoadSyntax, BuildFlags: buildFlagsSet, } - pkgs, err := packages.Load(cfg, packageName) + pkgs, err := packages.Load(cfg, packageNames...) if err != nil { return nil, fmt.Errorf("load packages: %w", err) } - if len(pkgs) != 1 { - return nil, fmt.Errorf("packages length must be 1: %d", len(pkgs)) - } - - if len(pkgs[0].Errors) > 0 { - errs := make([]error, len(pkgs[0].Errors)) - for i, err := range pkgs[0].Errors { - errs[i] = err + var errs []error + for _, pkg := range pkgs { + for _, err := range pkg.Errors { + errs = append(errs, err) } - - return nil, errors.Join(errs...) } - - return pkgs[0], nil + return pkgs, errors.Join(errs...) } func (p *packageModeParser) extractInterfacesFromPackage(pkg *packages.Package, ifaces []string) ([]*model.Interface, error) { @@ -244,7 +264,7 @@ func (p *packageModeParser) parseInterface(obj types.Object) (*model.Interface, return &model.Interface{Name: obj.Name(), Methods: methods, TypeParams: typeParams}, nil } -func (o *packageModeParser) isConstraint(t *types.Interface) bool { +func (p *packageModeParser) isConstraint(t *types.Interface) bool { for i := range t.NumEmbeddeds() { embed := t.EmbeddedType(i) if _, ok := embed.Underlying().(*types.Interface); !ok { From b1da60f6ce8d8148eb00be235c5f49900456e80c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Tue, 8 Apr 2025 17:15:31 +0200 Subject: [PATCH 4/5] Add batch mode --- go.mod | 2 +- mockgen/mockgen.go | 120 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 113 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index f1ddbebb..1ab45dbe 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/mod v0.18.0 golang.org/x/tools v0.22.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -13,5 +14,4 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/yuin/goldmark v1.4.13 // indirect golang.org/x/sync v0.7.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 3580be19..f1c0009a 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -39,6 +39,7 @@ import ( "golang.org/x/mod/modfile" toolsimports "golang.org/x/tools/imports" + "gopkg.in/yaml.v3" "go.uber.org/mock/mockgen/model" ) @@ -53,23 +54,32 @@ var ( date = "unknown" ) +const ( + defaultWriteCmdComment = true + defaultWritePkgComment = true + defaultWriteSourceComment = true + defaultWriteGenerateDirective = false + defaultTyped = false +) + var ( source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") destination = flag.String("destination", "", "Output file; defaults to stdout.") mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.") selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.") - writeCmdComment = flag.Bool("write_command_comment", true, "Writes the command used as a comment if true.") - writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") - writeSourceComment = flag.Bool("write_source_comment", true, "Writes original file (source mode) or interface names (package mode) comment if true.") - writeGenerateDirective = flag.Bool("write_generate_directive", false, "Add //go:generate directive to regenerate the mock") + writeCmdComment = flag.Bool("write_command_comment", defaultWriteCmdComment, "Writes the command used as a comment if true.") + writePkgComment = flag.Bool("write_package_comment", defaultWritePkgComment, "Writes package documentation comment (godoc) if true.") + writeSourceComment = flag.Bool("write_source_comment", defaultWriteSourceComment, "Writes original file (source mode) or interface names (package mode) comment if true.") + writeGenerateDirective = flag.Bool("write_generate_directive", defaultWriteGenerateDirective, "Add //go:generate directive to regenerate the mock") copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") buildConstraint = flag.String("build_constraint", "", "If non-empty, added as //go:build ") - typed = flag.Bool("typed", false, "Generate Type-safe 'Return', 'Do', 'DoAndReturn' function") + typed = flag.Bool("typed", defaultTyped, "Generate Type-safe 'Return', 'Do', 'DoAndReturn' function") imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") excludeInterfaces = flag.String("exclude_interfaces", "", "(source mode) Comma-separated names of interfaces to be excluded") modelGob = flag.String("model_gob", "", "Skip package/source loading entirely and use the gob encoded model.Package at the given path") + batch = flag.String("batch", "", "YAML file with mockgen configuration for multiple packages. If used, all other flags are ignored.") debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") showVersion = flag.Bool("version", false, "Print version.") @@ -133,6 +143,56 @@ func prepareTargets() ([]genTarget, error) { target.source = *source target.imports = *imports return []genTarget{target}, nil + } else if *batch != "" { + f, err := os.ReadFile(*batch) + if err != nil { + log.Fatalf("Failed reading batch file: %v", err) + } + var b Batch + if err := yaml.Unmarshal(f, &b); err != nil { + log.Fatalf("Failed parsing batch file: %v", err) + } + + parseTargets := make([]parseTarget, len(b.Targets)) + for i := range b.Targets { + target := b.Targets[i].Target + packageName, ifaces, found := strings.Cut(target, " ") + if !found { + log.Fatalf("Invalid target, must be a package name followed by comma-separated interface names: %s", target) + } + + parseTargets[i] = parseTarget{ + name: packageName, + ifaces: strings.Split(ifaces, ","), + } + } + + parser := packageModeParser{} + pkgs, err := parser.parsePackages(parseTargets) + if err != nil { + return nil, err + } + + targets := make([]genTarget, len(b.Targets)) + for i := range b.Targets { + targets[i] = genTarget{ + pkg: pkgs[i], + packageName: parseTargets[i].name, + interfaces: parseTargets[i].ifaces, + destination: b.Targets[i].Destination, + mockNames: b.Targets[i].MockNames, + packageOut: b.Targets[i].PackageOut, + selfPackage: b.Targets[i].SelfPackage, + writeCmdComment: overrideBool(defaultWriteCmdComment, b.Generator.WriteCmdComment, b.Targets[i].WriteCmdComment), + writePkgComment: overrideBool(defaultWritePkgComment, b.Generator.WritePkgComment, b.Targets[i].WritePkgComment), + writeSourceComment: overrideBool(defaultWriteSourceComment, b.Generator.WriteSourceComment, b.Targets[i].WriteSourceComment), + writeGenerateDirective: overrideBool(defaultWriteGenerateDirective, b.Generator.WriteGenerateDirective, b.Targets[i].WriteGenerateDirective), + copyrightFile: overrideString("", b.Generator.CopyrightFile, b.Targets[i].CopyrightFile), + buildConstraint: overrideString("", b.Generator.BuildConstraint, b.Targets[i].BuildConstraint), + typed: overrideBool(defaultTyped, b.Generator.Typed, b.Targets[i].Typed), + } + } + return targets, nil } else { if flag.NArg() != 2 { usage() @@ -157,11 +217,31 @@ func prepareTargets() ([]genTarget, error) { } target.pkg = pkg target.packageName = packageName - target.interfaces = flag.Arg(1) + target.interfaces = interfaces return []genTarget{target}, nil } } +func overrideBool(deflt bool, global, pkg *bool) bool { + if pkg != nil { + return *pkg + } + if global != nil { + return *global + } + return deflt +} + +func overrideString(deflt, global, pkg string) string { + if pkg != "" { + return pkg + } + if global != "" { + return global + } + return deflt +} + func generateTarget(target *genTarget) { outputPackageName := target.packageOut if outputPackageName == "" { @@ -198,7 +278,7 @@ func generateTarget(target *genTarget) { g.filename = target.source } else { g.srcPackage = target.packageName - g.srcInterfaces = target.interfaces + g.srcInterfaces = strings.Join(target.interfaces, ",") } g.destination = target.destination @@ -293,6 +373,30 @@ Example: ` +type Batch struct { + Generator Flags `yaml:"generator"` + Targets []Target `yaml:"targets"` +} + +type Target struct { + Target string `yaml:"target"` + Destination string `yaml:"destination"` + MockNames string `yaml:"mock_names"` + PackageOut string `yaml:"package"` + SelfPackage string `yaml:"self_package"` + Flags +} + +type Flags struct { + WriteCmdComment *bool `yaml:"write_command_comment"` + WritePkgComment *bool `yaml:"write_package_comment"` + WriteSourceComment *bool `yaml:"write_source_comment"` + WriteGenerateDirective *bool `yaml:"write_generate_directive"` + CopyrightFile string `yaml:"copyright_file"` + BuildConstraint string `yaml:"build_constraint"` + Typed *bool `yaml:"typed"` +} + type genTarget struct { pkg *model.Package @@ -301,7 +405,7 @@ type genTarget struct { imports string packageName string - interfaces string + interfaces []string // flags destination string From 5812d4ad5754ea864cb21470ade5670cc0ff0c35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Thu, 10 Apr 2025 12:26:56 +0200 Subject: [PATCH 5/5] Document batch mode --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/README.md b/README.md index b449eb2f..a869ffbd 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,32 @@ mockgen database/sql/driver Conn,Driver mockgen . Conn,Driver ``` +### Batch mode + +Batch mode works similarly to package mode, but allows generating mocks for multiple packages at once. +This is especially useful in large codebases that call mockgen repeatedly, +as it allows mockgen to parse the code only once, making the generation process shorter. + +To use batch mode you need to prepare a YAML file. +The file provides a list of package and interface names pairs, called generation "targets". +It also allows specifying flags for all packages (in the `generator` section), or for individual packages. +Field names are identical with commandline flags used in package mode. + +The file structure is as follows: + +```yaml +generator: + # package mode flags to use for all packages: + typed: true + copyright_file: copyright.txt +targets: + - target: database/sql/driver Conn,Driver + - target: github.com/example/example/mypackage Client,Server + # package-specific flags: + destination: example/example/mypackage/mocks/generated.go + write_generate_directive: true +``` + ### Flags The `mockgen` command is used to generate source code for a mock