diff --git a/walk/filesystem.go b/walk/filesystem.go index dc7b301c..8915d030 100644 --- a/walk/filesystem.go +++ b/walk/filesystem.go @@ -2,6 +2,7 @@ package walk import ( "context" + "errors" "fmt" "io" "io/fs" @@ -15,6 +16,10 @@ import ( "golang.org/x/sync/errgroup" ) +// errWalkClosed is used internally to abort filepath.Walk when Close() is +// called while process() is still producing. +var errWalkClosed = errors.New("filesystem reader closed") + // FilesystemReader traverses and reads files from a specified root directory and its subdirectories. type FilesystemReader struct { log *log.Logger @@ -22,7 +27,8 @@ type FilesystemReader struct { path string batchSize int - eg *errgroup.Group + eg *errgroup.Group + done chan struct{} stats *stats.Stats filesCh chan *File @@ -69,13 +75,19 @@ func (f *FilesystemReader) process() error { Info: info, } - f.filesCh <- &file + select { + case f.filesCh <- &file: + case <-f.done: + return errWalkClosed + } f.log.Debugf("file queued %s", file.RelPath) return nil }) - if err != nil { + if errors.Is(err, errWalkClosed) { + return nil + } else if err != nil { return fmt.Errorf("failed to walk path %s: %w", path, err) } @@ -118,6 +130,9 @@ LOOP: // Close waits for all filesystem processing to complete. func (f *FilesystemReader) Close() error { + // Unblock process() in case the caller stopped draining Read() before EOF. + close(f.done) + err := f.eg.Wait() if err != nil { return fmt.Errorf("failed to wait for processing to complete: %w", err) @@ -143,7 +158,8 @@ func NewFilesystemReader( path: path, batchSize: batchSize, - eg: &eg, + eg: &eg, + done: make(chan struct{}), stats: statz, filesCh: make(chan *File, batchSize*runtime.NumCPU()), diff --git a/walk/filesystem_test.go b/walk/filesystem_test.go index 45d7d17b..1960563c 100644 --- a/walk/filesystem_test.go +++ b/walk/filesystem_test.go @@ -3,7 +3,11 @@ package walk_test import ( "context" "errors" + "fmt" "io" + "os" + "path/filepath" + "runtime" "testing" "time" @@ -101,3 +105,32 @@ func TestFilesystemReader(t *testing.T) { as.Equal(0, statz.Value(stats.Formatted)) as.Equal(0, statz.Value(stats.Changed)) } + +// TestFilesystemReaderCloseUnblocks verifies that Close() returns even when +// the caller stops draining Read() before EOF. +func TestFilesystemReaderCloseUnblocks(t *testing.T) { + as := require.New(t) + + tempDir := t.TempDir() + + // Enough files to overflow the reader's internal channel so process() + // would block on send if Close() did not unblock it. + n := walk.BatchSize*runtime.GOMAXPROCS(0) + 100 + for i := range n { + as.NoError(os.WriteFile(filepath.Join(tempDir, fmt.Sprintf("f%05d", i)), nil, 0o600)) + } + + statz := stats.New() + reader := walk.NewFilesystemReader(tempDir, "", &statz, walk.BatchSize) + + done := make(chan error, 1) + + go func() { done <- reader.Close() }() + + select { + case err := <-done: + as.NoError(err) + case <-time.After(5 * time.Second): + t.Fatal("Close() did not return; process() is deadlocked") + } +} diff --git a/walk/git.go b/walk/git.go index 607ba524..89825674 100644 --- a/walk/git.go +++ b/walk/git.go @@ -8,8 +8,10 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strconv" "strings" + "sync" "github.com/charmbracelet/log" "github.com/numtide/treefmt/v2/git" @@ -17,6 +19,11 @@ import ( "golang.org/x/sync/errgroup" ) +type gitEntry struct { + relative string + gitlink bool // mode 160000, i.e. a submodule +} + type GitReader struct { root string path string @@ -25,101 +32,30 @@ type GitReader struct { stats *stats.Stats eg *errgroup.Group - scanner *bufio.Scanner + filesCh chan *File + cancel context.CancelFunc } func (g *GitReader) Read(ctx context.Context, files []*File) (n int, err error) { - // ensure we record how many files we traversed defer func() { g.stats.Add(stats.Traversed, n) }() - nextFile := func() (string, error) { - for line := g.scanner.Text(); len(line) > 0; line = g.scanner.Text() { - lineSplit := strings.Split(line, "\t") - - var stage, file string - // Untracked files just show as ``, while tracked files show as ` ` - if len(lineSplit) == 1 { - stage, file = "", lineSplit[0] - } else { - stage, file = lineSplit[0], lineSplit[1] - } - - // 160000 is the mode for submodules, skip them because they are separate projects that may have their own - // formatting rules - if strings.HasPrefix(stage, "160000") { - g.scanner.Scan() - - continue - } - - if file[0] != '"' { - return file, nil - } - - unquoted, err := strconv.Unquote(file) - if err != nil { - return "", fmt.Errorf("failed to unquote file %s: %w", file, err) - } - - return unquoted, nil - } - - return "", io.EOF - } - LOOP: - for n < len(files) { select { - // exit early if the context was cancelled case <-ctx.Done(): return n, ctx.Err() //nolint:wrapcheck - default: - // read the next file - if g.scanner.Scan() { - entry, err := nextFile() - if err != nil { - return n, err - } - - path := filepath.Join(g.root, g.path, entry) - - g.log.Debugf("processing file: %s", path) - - info, err := os.Lstat(path) - - switch { - case os.IsNotExist(err): - // the underlying file might have been removed - g.log.Warnf( - "Path %s is in the worktree but appears to have been removed from the filesystem", path, - ) - - continue - case err != nil: - return n, fmt.Errorf("failed to stat %s: %w", path, err) - case info.Mode()&os.ModeSymlink == os.ModeSymlink: - // we skip reporting symlinks stored in Git, they should - // point to local files which we would list anyway. - continue - } - - files[n] = &File{ - Path: path, - RelPath: filepath.Join(g.path, entry), - Info: info, - } - - n++ - } else { - // nothing more to read + case file, ok := <-g.filesCh: + if !ok { err = io.EOF break LOOP } + + files[n] = file + n++ } } @@ -127,6 +63,10 @@ LOOP: } func (g *GitReader) Close() error { + // Unblock any producer goroutines (and kill the git children) in case the + // caller stopped draining Read() before EOF. + g.cancel() + err := g.eg.Wait() if err != nil { return fmt.Errorf("failed to wait for git command to complete: %w", err) @@ -135,12 +75,131 @@ func (g *GitReader) Close() error { return nil } +func (g *GitReader) stat(ctx context.Context, entry gitEntry) { + if entry.gitlink { + // submodules are separate projects with their own formatting rules + return + } + + path := filepath.Join(g.root, entry.relative) + + g.log.Debugf("processing file: %s", path) + + info, err := os.Lstat(path) + + switch { + case os.IsNotExist(err): + g.log.Warnf( + "Path %s is in the worktree but appears to have been removed from the filesystem", path, + ) + + return + case err != nil: + g.log.Errorf("failed to stat %s: %v", path, err) + + return + case info.Mode()&os.ModeSymlink == os.ModeSymlink: + // symlinks point at files we list anyway + return + } + + select { + case g.filesCh <- &File{Path: path, RelPath: entry.relative, Info: info}: + case <-ctx.Done(): + } +} + +func lsFiles(ctx context.Context, dir string, staged bool, prefix string, out chan<- gitEntry, args ...string) error { + //nolint:gosec // args are fixed flag sets assembled in NewGitReader, not user input. + cmd := exec.CommandContext(ctx, "git", append([]string{"ls-files"}, args...)...) + cmd.Dir = dir + + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + if ctx.Err() != nil { + return nil //nolint:nilerr // reader was closed; cancellation is not a failure + } + + return fmt.Errorf("failed to start git ls-files: %w", err) + } + + scanErr := scanLsFiles(ctx, stdout, staged, prefix, out) + + // Always reap the child. If scanning aborted early git may be blocked on a + // full pipe, so kill it first to guarantee Wait returns. + if scanErr != nil { + _ = cmd.Process.Kill() + } + + waitErr := cmd.Wait() + + if ctx.Err() != nil { + return nil //nolint:nilerr // reader was closed; the kill signal is not a failure + } + + if scanErr != nil { + return scanErr + } + + if waitErr != nil { + return fmt.Errorf("git ls-files failed: %w", waitErr) + } + + return nil +} + +func scanLsFiles(ctx context.Context, r io.Reader, staged bool, prefix string, out chan<- gitEntry) error { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + + var gitlink bool + + path := line + if staged { + // \t + if mode, file, ok := strings.Cut(line, "\t"); ok { + gitlink = strings.HasPrefix(mode, "160000") + path = file + } + } + + if path == "" { + continue + } + + if path[0] == '"' { + unquoted, err := strconv.Unquote(path) + if err != nil { + return fmt.Errorf("failed to unquote file %s: %w", path, err) + } + + path = unquoted + } + + select { + case out <- gitEntry{relative: filepath.Join(prefix, path), gitlink: gitlink}: + case <-ctx.Done(): + return nil + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to read git ls-files output: %w", err) + } + + return nil +} + func NewGitReader( root string, path string, statz *stats.Stats, ) (*GitReader, error) { - // check if the root is a git repository isGit, err := git.IsInsideWorktree(root) if err != nil { return nil, fmt.Errorf("failed to check if %s is a git repository: %w", root, err) @@ -150,34 +209,72 @@ func NewGitReader( return nil, fmt.Errorf("%s is not a git repository", root) } - // create an errgroup for executing in the background - eg := &errgroup.Group{} - - // create a pipe to capture the command output - r, w := io.Pipe() + dir := filepath.Join(root, path) - // create a command which will execute from the specified sub path within root - cmd := exec.CommandContext( - context.Background(), - "git", "ls-files", "--cached", "--others", "--exclude-standard", "--stage", - ) - cmd.Dir = filepath.Join(root, path) - cmd.Stdout = w + ctx, cancel := context.WithCancel(context.Background()) - // execute the command in the background - eg.Go(func() error { - return w.CloseWithError(cmd.Run()) - }) - - // create a new scanner for reading the output - scanner := bufio.NewScanner(r) - - return &GitReader{ - eg: eg, + g := &GitReader{ root: root, path: path, stats: statz, - scanner: scanner, log: log.WithPrefix("walk | git"), - }, nil + eg: &errgroup.Group{}, + filesCh: make(chan *File, BatchSize*runtime.NumCPU()), + cancel: cancel, + } + + entries := make(chan gitEntry, BatchSize) + + // `--cached` and `--others` are queried separately because git buffers all + // output until the untracked scan finishes when both are combined; the + // index-only query streams immediately so formatters start without waiting. + var producers sync.WaitGroup + + producers.Add(2) + + g.eg.Go(func() error { + defer producers.Done() + + return lsFiles(ctx, dir, true, path, entries, "--cached", "--stage") + }) + + g.eg.Go(func() error { + defer producers.Done() + + return lsFiles(ctx, dir, false, path, entries, "--others", "--exclude-standard") + }) + + g.eg.Go(func() error { + producers.Wait() + close(entries) + + return nil + }) + + var workers sync.WaitGroup + + statWorkers := runtime.GOMAXPROCS(0) + + workers.Add(statWorkers) + + for range statWorkers { + g.eg.Go(func() error { + defer workers.Done() + + for e := range entries { + g.stat(ctx, e) + } + + return nil + }) + } + + g.eg.Go(func() error { + workers.Wait() + close(g.filesCh) + + return nil + }) + + return g, nil } diff --git a/walk/git_test.go b/walk/git_test.go index ce4e3d29..bda45af8 100644 --- a/walk/git_test.go +++ b/walk/git_test.go @@ -3,8 +3,12 @@ package walk_test import ( "context" "errors" + "fmt" "io" + "os" "os/exec" + "path/filepath" + "runtime" "testing" "time" @@ -110,3 +114,42 @@ func TestGitReader(t *testing.T) { as.Equal(0, statz.Value(stats.Formatted)) as.Equal(0, statz.Value(stats.Changed)) } + +// TestGitReaderCloseUnblocks verifies that Close() returns even when the +// caller stops draining Read() before EOF. +func TestGitReaderCloseUnblocks(t *testing.T) { + as := require.New(t) + + tempDir := t.TempDir() + + cmd := exec.CommandContext(t.Context(), "git", "init") + cmd.Dir = tempDir + as.NoError(cmd.Run()) + + // Enough files to overflow the reader's internal channel so producers + // would block if Close() did not unblock them. + n := walk.BatchSize*runtime.GOMAXPROCS(0) + 100 + for i := range n { + as.NoError(os.WriteFile(filepath.Join(tempDir, fmt.Sprintf("f%05d", i)), nil, 0o600)) + } + + cmd = exec.CommandContext(t.Context(), "git", "add", ".") + cmd.Dir = tempDir + as.NoError(cmd.Run()) + + statz := stats.New() + + reader, err := walk.NewGitReader(tempDir, "", &statz) + as.NoError(err) + + done := make(chan error, 1) + + go func() { done <- reader.Close() }() + + select { + case err := <-done: + as.NoError(err) + case <-time.After(5 * time.Second): + t.Fatal("Close() did not return; producers are deadlocked") + } +} diff --git a/walk/jujutsu.go b/walk/jujutsu.go index 4ee7a195..79f6e522 100644 --- a/walk/jujutsu.go +++ b/walk/jujutsu.go @@ -23,6 +23,7 @@ type JujutsuReader struct { stats *stats.Stats eg *errgroup.Group + cancel context.CancelFunc scanner *bufio.Scanner } @@ -93,6 +94,9 @@ LOOP: } func (j *JujutsuReader) Close() error { + // Unblock the jj process in case the caller stopped draining Read() before EOF. + j.cancel() + err := j.eg.Wait() if err != nil { return fmt.Errorf("failed to wait for jujutsu command to complete: %w", err) @@ -119,6 +123,8 @@ func NewJujutsuReader( // create an errgroup for async list task eg := &errgroup.Group{} + ctx, cancel := context.WithCancel(context.Background()) + // create a pipe to capture the command output r, w := io.Pipe() @@ -134,13 +140,19 @@ func NewJujutsuReader( } // create the jj command - cmd := exec.CommandContext(context.Background(), "jj", args...) + cmd := exec.CommandContext(ctx, "jj", args...) cmd.Dir = root cmd.Stdout = w // execute the command in the background eg.Go(func() error { - return w.CloseWithError(cmd.Run()) + err := cmd.Run() + if ctx.Err() != nil { + // reader was closed; the kill signal is not a failure + err = nil + } + + return w.CloseWithError(err) }) // create a new scanner for reading the output @@ -148,6 +160,7 @@ func NewJujutsuReader( return &JujutsuReader{ eg: eg, + cancel: cancel, root: root, path: path, stats: statz,