diff --git a/cmd/restic/cleanup.go b/cmd/restic/cleanup.go index 5a6cf79e11d..90ea93b9235 100644 --- a/cmd/restic/cleanup.go +++ b/cmd/restic/cleanup.go @@ -1,89 +1,41 @@ package main import ( + "context" "os" "os/signal" - "sync" "syscall" "github.com/restic/restic/internal/debug" ) -var cleanupHandlers struct { - sync.Mutex - list []func(code int) (int, error) - done bool - ch chan os.Signal -} - -func init() { - cleanupHandlers.ch = make(chan os.Signal, 1) - go CleanupHandler(cleanupHandlers.ch) - signal.Notify(cleanupHandlers.ch, syscall.SIGINT, syscall.SIGTERM) -} +func createGlobalContext() context.Context { + ctx, cancel := context.WithCancel(context.Background()) -// AddCleanupHandler adds the function f to the list of cleanup handlers so -// that it is executed when all the cleanup handlers are run, e.g. when SIGINT -// is received. -func AddCleanupHandler(f func(code int) (int, error)) { - cleanupHandlers.Lock() - defer cleanupHandlers.Unlock() + ch := make(chan os.Signal, 1) + go cleanupHandler(ch, cancel) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) - // reset the done flag for integration tests - cleanupHandlers.done = false - - cleanupHandlers.list = append(cleanupHandlers.list, f) + return ctx } -// RunCleanupHandlers runs all registered cleanup handlers -func RunCleanupHandlers(code int) int { - cleanupHandlers.Lock() - defer cleanupHandlers.Unlock() - - if cleanupHandlers.done { - return code - } - cleanupHandlers.done = true +// cleanupHandler handles the SIGINT and SIGTERM signals. +func cleanupHandler(c <-chan os.Signal, cancel context.CancelFunc) { + s := <-c + debug.Log("signal %v received, cleaning up", s) + Warnf("%ssignal %v received, cleaning up\n", clearLine(0), s) - for _, f := range cleanupHandlers.list { - var err error - code, err = f(code) - if err != nil { - Warnf("error in cleanup handler: %v\n", err) - } + if val, _ := os.LookupEnv("RESTIC_DEBUG_STACKTRACE_SIGINT"); val != "" { + _, _ = os.Stderr.WriteString("\n--- STACKTRACE START ---\n\n") + _, _ = os.Stderr.WriteString(debug.DumpStacktrace()) + _, _ = os.Stderr.WriteString("\n--- STACKTRACE END ---\n") } - cleanupHandlers.list = nil - return code -} - -// CleanupHandler handles the SIGINT and SIGTERM signals. -func CleanupHandler(c <-chan os.Signal) { - for s := range c { - debug.Log("signal %v received, cleaning up", s) - Warnf("%ssignal %v received, cleaning up\n", clearLine(0), s) - - if val, _ := os.LookupEnv("RESTIC_DEBUG_STACKTRACE_SIGINT"); val != "" { - _, _ = os.Stderr.WriteString("\n--- STACKTRACE START ---\n\n") - _, _ = os.Stderr.WriteString(debug.DumpStacktrace()) - _, _ = os.Stderr.WriteString("\n--- STACKTRACE END ---\n") - } - code := 0 - - if s == syscall.SIGINT || s == syscall.SIGTERM { - code = 130 - } else { - code = 1 - } - - Exit(code) - } + cancel() } -// Exit runs the cleanup handlers and then terminates the process with the -// given exit code. +// Exit terminates the process with the given exit code. func Exit(code int) { - code = RunCleanupHandlers(code) debug.Log("exiting with status code %d", code) os.Exit(code) } diff --git a/cmd/restic/cmd_check.go b/cmd/restic/cmd_check.go index 7bea641ae8e..c44edae7e79 100644 --- a/cmd/restic/cmd_check.go +++ b/cmd/restic/cmd_check.go @@ -199,10 +199,7 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args } cleanup := prepareCheckCache(opts, &gopts) - AddCleanupHandler(func(code int) (int, error) { - cleanup() - return code, nil - }) + defer cleanup() if !gopts.NoLock { Verbosef("create exclusive lock for repository\n") @@ -222,6 +219,9 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args Verbosef("load indexes\n") bar := newIndexProgress(gopts.Quiet, gopts.JSON) hints, errs := chkr.LoadIndex(ctx, bar) + if ctx.Err() != nil { + return ctx.Err() + } errorsFound := false suggestIndexRebuild := false @@ -283,6 +283,9 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args if orphanedPacks > 0 { Verbosef("%d additional files were found in the repo, which likely contain duplicate data.\nThis is non-critical, you can run `restic prune` to correct this.\n", orphanedPacks) } + if ctx.Err() != nil { + return ctx.Err() + } Verbosef("check snapshots, trees and blobs\n") errChan = make(chan error) @@ -316,9 +319,16 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args // Must happen after `errChan` is read from in the above loop to avoid // deadlocking in the case of errors. wg.Wait() + if ctx.Err() != nil { + return ctx.Err() + } if opts.CheckUnused { - for _, id := range chkr.UnusedBlobs(ctx) { + unused, err := chkr.UnusedBlobs(ctx) + if err != nil { + return err + } + for _, id := range unused { Verbosef("unused blob %v\n", id) errorsFound = true } @@ -395,10 +405,13 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args doReadData(packs) } + if ctx.Err() != nil { + return ctx.Err() + } + if errorsFound { return errors.Fatal("repository contains errors") } - Verbosef("no errors were found\n") return nil diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index 410134e41ec..ad6c58a2526 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -53,7 +53,7 @@ func init() { } func runCopy(ctx context.Context, opts CopyOptions, gopts GlobalOptions, args []string) error { - secondaryGopts, isFromRepo, err := fillSecondaryGlobalOpts(opts.secondaryRepoOptions, gopts, "destination") + secondaryGopts, isFromRepo, err := fillSecondaryGlobalOpts(ctx, opts.secondaryRepoOptions, gopts, "destination") if err != nil { return err } @@ -103,6 +103,9 @@ func runCopy(ctx context.Context, opts CopyOptions, gopts GlobalOptions, args [] // also consider identical snapshot copies dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn) } + if ctx.Err() != nil { + return ctx.Err() + } // remember already processed trees across all snapshots visitedTrees := restic.NewIDSet() @@ -147,7 +150,7 @@ func runCopy(ctx context.Context, opts CopyOptions, gopts GlobalOptions, args [] } Verbosef("snapshot %s saved\n", newID.Str()) } - return nil + return ctx.Err() } func similarSnapshots(sna *restic.Snapshot, snb *restic.Snapshot) bool { diff --git a/cmd/restic/cmd_find.go b/cmd/restic/cmd_find.go index e29fe30dc60..81df0ab9882 100644 --- a/cmd/restic/cmd_find.go +++ b/cmd/restic/cmd_find.go @@ -439,7 +439,10 @@ func (f *Finder) packsToBlobs(ctx context.Context, packs []string) error { if err != errAllPacksFound { // try to resolve unknown pack ids from the index - packIDs = f.indexPacksToBlobs(ctx, packIDs) + packIDs, err = f.indexPacksToBlobs(ctx, packIDs) + if err != nil { + return err + } } if len(packIDs) > 0 { @@ -456,13 +459,13 @@ func (f *Finder) packsToBlobs(ctx context.Context, packs []string) error { return nil } -func (f *Finder) indexPacksToBlobs(ctx context.Context, packIDs map[string]struct{}) map[string]struct{} { +func (f *Finder) indexPacksToBlobs(ctx context.Context, packIDs map[string]struct{}) (map[string]struct{}, error) { wctx, cancel := context.WithCancel(ctx) defer cancel() // remember which packs were found in the index indexPackIDs := make(map[string]struct{}) - f.repo.Index().Each(wctx, func(pb restic.PackedBlob) { + err := f.repo.Index().Each(wctx, func(pb restic.PackedBlob) { idStr := pb.PackID.String() // keep entry in packIDs as Each() returns individual index entries matchingID := false @@ -481,6 +484,9 @@ func (f *Finder) indexPacksToBlobs(ctx context.Context, packIDs map[string]struc indexPackIDs[idStr] = struct{}{} } }) + if err != nil { + return nil, err + } for id := range indexPackIDs { delete(packIDs, id) @@ -493,7 +499,7 @@ func (f *Finder) indexPacksToBlobs(ctx context.Context, packIDs map[string]struc } Warnf("some pack files are missing from the repository, getting their blobs from the repository index: %v\n\n", list) } - return packIDs + return packIDs, nil } func (f *Finder) findObjectPack(id string, t restic.BlobType) { @@ -608,6 +614,9 @@ func runFind(ctx context.Context, opts FindOptions, gopts GlobalOptions, args [] for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, opts.Snapshots) { filteredSnapshots = append(filteredSnapshots, sn) } + if ctx.Err() != nil { + return ctx.Err() + } sort.Slice(filteredSnapshots, func(i, j int) bool { return filteredSnapshots[i].Time.Before(filteredSnapshots[j].Time) diff --git a/cmd/restic/cmd_forget.go b/cmd/restic/cmd_forget.go index 9018da21173..92eeed4a174 100644 --- a/cmd/restic/cmd_forget.go +++ b/cmd/restic/cmd_forget.go @@ -188,6 +188,9 @@ func runForget(ctx context.Context, opts ForgetOptions, pruneOptions PruneOption for sn := range FindFilteredSnapshots(ctx, repo, repo, &opts.SnapshotFilter, args) { snapshots = append(snapshots, sn) } + if ctx.Err() != nil { + return ctx.Err() + } var jsonGroups []*ForgetGroup @@ -270,6 +273,10 @@ func runForget(ctx context.Context, opts ForgetOptions, pruneOptions PruneOption } } + if ctx.Err() != nil { + return ctx.Err() + } + if len(removeSnIDs) > 0 { if !opts.DryRun { bar := printer.NewCounter("files deleted") diff --git a/cmd/restic/cmd_init.go b/cmd/restic/cmd_init.go index 7154279e8ba..e6ea694413e 100644 --- a/cmd/restic/cmd_init.go +++ b/cmd/restic/cmd_init.go @@ -80,7 +80,7 @@ func runInit(ctx context.Context, opts InitOptions, gopts GlobalOptions, args [] return err } - gopts.password, err = ReadPasswordTwice(gopts, + gopts.password, err = ReadPasswordTwice(ctx, gopts, "enter password for new repository: ", "enter password again: ") if err != nil { @@ -131,7 +131,7 @@ func runInit(ctx context.Context, opts InitOptions, gopts GlobalOptions, args [] func maybeReadChunkerPolynomial(ctx context.Context, opts InitOptions, gopts GlobalOptions) (*chunker.Pol, error) { if opts.CopyChunkerParameters { - otherGopts, _, err := fillSecondaryGlobalOpts(opts.secondaryRepoOptions, gopts, "secondary") + otherGopts, _, err := fillSecondaryGlobalOpts(ctx, opts.secondaryRepoOptions, gopts, "secondary") if err != nil { return nil, err } diff --git a/cmd/restic/cmd_key_add.go b/cmd/restic/cmd_key_add.go index 83e0cab7f5d..30675462785 100644 --- a/cmd/restic/cmd_key_add.go +++ b/cmd/restic/cmd_key_add.go @@ -60,7 +60,7 @@ func runKeyAdd(ctx context.Context, gopts GlobalOptions, opts KeyAddOptions, arg } func addKey(ctx context.Context, repo *repository.Repository, gopts GlobalOptions, opts KeyAddOptions) error { - pw, err := getNewPassword(gopts, opts.NewPasswordFile) + pw, err := getNewPassword(ctx, gopts, opts.NewPasswordFile) if err != nil { return err } @@ -83,7 +83,7 @@ func addKey(ctx context.Context, repo *repository.Repository, gopts GlobalOption // testKeyNewPassword is used to set a new password during integration testing. var testKeyNewPassword string -func getNewPassword(gopts GlobalOptions, newPasswordFile string) (string, error) { +func getNewPassword(ctx context.Context, gopts GlobalOptions, newPasswordFile string) (string, error) { if testKeyNewPassword != "" { return testKeyNewPassword, nil } @@ -97,7 +97,7 @@ func getNewPassword(gopts GlobalOptions, newPasswordFile string) (string, error) newopts := gopts newopts.password = "" - return ReadPasswordTwice(newopts, + return ReadPasswordTwice(ctx, newopts, "enter new password: ", "enter password again: ") } diff --git a/cmd/restic/cmd_key_passwd.go b/cmd/restic/cmd_key_passwd.go index 70abca6dc20..0836c4cfe3a 100644 --- a/cmd/restic/cmd_key_passwd.go +++ b/cmd/restic/cmd_key_passwd.go @@ -57,7 +57,7 @@ func runKeyPasswd(ctx context.Context, gopts GlobalOptions, opts KeyPasswdOption } func changePassword(ctx context.Context, repo *repository.Repository, gopts GlobalOptions, opts KeyPasswdOptions) error { - pw, err := getNewPassword(gopts, opts.NewPasswordFile) + pw, err := getNewPassword(ctx, gopts, opts.NewPasswordFile) if err != nil { return err } diff --git a/cmd/restic/cmd_list.go b/cmd/restic/cmd_list.go index a3df0c98f5d..27f59b4ab17 100644 --- a/cmd/restic/cmd_list.go +++ b/cmd/restic/cmd_list.go @@ -59,10 +59,9 @@ func runList(ctx context.Context, gopts GlobalOptions, args []string) error { if err != nil { return err } - idx.Each(ctx, func(blobs restic.PackedBlob) { + return idx.Each(ctx, func(blobs restic.PackedBlob) { Printf("%v %v\n", blobs.Type, blobs.ID) }) - return nil }) default: return errors.Fatal("invalid type") diff --git a/cmd/restic/cmd_mount.go b/cmd/restic/cmd_mount.go index cb2b1142d5d..5a10447f36f 100644 --- a/cmd/restic/cmd_mount.go +++ b/cmd/restic/cmd_mount.go @@ -152,28 +152,15 @@ func runMount(ctx context.Context, opts MountOptions, gopts GlobalOptions, args } } - AddCleanupHandler(func(code int) (int, error) { - debug.Log("running umount cleanup handler for mount at %v", mountpoint) - err := umount(mountpoint) - if err != nil { - Warnf("unable to umount (maybe already umounted or still in use?): %v\n", err) - } - // replace error code of sigint - if code == 130 { - code = 0 - } - return code, nil - }) + systemFuse.Debug = func(msg interface{}) { + debug.Log("fuse: %v", msg) + } c, err := systemFuse.Mount(mountpoint, mountOptions...) if err != nil { return err } - systemFuse.Debug = func(msg interface{}) { - debug.Log("fuse: %v", msg) - } - cfg := fuse.Config{ OwnerIsRoot: opts.OwnerRoot, Filter: opts.SnapshotFilter, @@ -187,15 +174,26 @@ func runMount(ctx context.Context, opts MountOptions, gopts GlobalOptions, args Printf("When finished, quit with Ctrl-c here or umount the mountpoint.\n") debug.Log("serving mount at %v", mountpoint) - err = fs.Serve(c, root) - if err != nil { - return err - } - <-c.Ready - return c.MountError -} + done := make(chan struct{}) + + go func() { + defer close(done) + err = fs.Serve(c, root) + }() + + select { + case <-ctx.Done(): + debug.Log("running umount cleanup handler for mount at %v", mountpoint) + err := systemFuse.Unmount(mountpoint) + if err != nil { + Warnf("unable to umount (maybe already umounted or still in use?): %v\n", err) + } + + return ErrOK + case <-done: + // clean shutdown, nothing to do + } -func umount(mountpoint string) error { - return systemFuse.Unmount(mountpoint) + return err } diff --git a/cmd/restic/cmd_mount_integration_test.go b/cmd/restic/cmd_mount_integration_test.go index 590e1503047..d764b4e4f0f 100644 --- a/cmd/restic/cmd_mount_integration_test.go +++ b/cmd/restic/cmd_mount_integration_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + systemFuse "github.com/anacrolix/fuse" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" ) @@ -65,7 +66,7 @@ func testRunMount(t testing.TB, gopts GlobalOptions, dir string, wg *sync.WaitGr func testRunUmount(t testing.TB, dir string) { var err error for i := 0; i < mountWait; i++ { - if err = umount(dir); err == nil { + if err = systemFuse.Unmount(dir); err == nil { t.Logf("directory %v umounted", dir) return } diff --git a/cmd/restic/cmd_prune.go b/cmd/restic/cmd_prune.go index ea5acddf337..cbec100df05 100644 --- a/cmd/restic/cmd_prune.go +++ b/cmd/restic/cmd_prune.go @@ -197,6 +197,9 @@ func runPruneWithRepo(ctx context.Context, opts PruneOptions, gopts GlobalOption if err != nil { return err } + if ctx.Err() != nil { + return ctx.Err() + } if popts.DryRun { printer.P("\nWould have made the following changes:") diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index f9a4d419d22..debaa4e5b5f 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -61,16 +61,22 @@ func runRecover(ctx context.Context, gopts GlobalOptions) error { // tree. If it is not referenced, we have a root tree. trees := make(map[restic.ID]bool) - repo.Index().Each(ctx, func(blob restic.PackedBlob) { + err = repo.Index().Each(ctx, func(blob restic.PackedBlob) { if blob.Type == restic.TreeBlob { trees[blob.Blob.ID] = false } }) + if err != nil { + return err + } Verbosef("load %d trees\n", len(trees)) bar = newProgressMax(!gopts.Quiet, uint64(len(trees)), "trees loaded") for id := range trees { tree, err := restic.LoadTree(ctx, repo, id) + if ctx.Err() != nil { + return ctx.Err() + } if err != nil { Warnf("unable to load tree %v: %v\n", id.Str(), err) continue diff --git a/cmd/restic/cmd_repair_snapshots.go b/cmd/restic/cmd_repair_snapshots.go index 4d9745e1550..b200d100adc 100644 --- a/cmd/restic/cmd_repair_snapshots.go +++ b/cmd/restic/cmd_repair_snapshots.go @@ -145,6 +145,9 @@ func runRepairSnapshots(ctx context.Context, gopts GlobalOptions, opts RepairOpt changedCount++ } } + if ctx.Err() != nil { + return ctx.Err() + } Verbosef("\n") if changedCount == 0 { diff --git a/cmd/restic/cmd_rewrite.go b/cmd/restic/cmd_rewrite.go index 06d4ddbd177..38a868c5c97 100644 --- a/cmd/restic/cmd_rewrite.go +++ b/cmd/restic/cmd_rewrite.go @@ -294,6 +294,9 @@ func runRewrite(ctx context.Context, opts RewriteOptions, gopts GlobalOptions, a changedCount++ } } + if ctx.Err() != nil { + return ctx.Err() + } Verbosef("\n") if changedCount == 0 { diff --git a/cmd/restic/cmd_snapshots.go b/cmd/restic/cmd_snapshots.go index 1a9cd2232b2..faa86d3a6bc 100644 --- a/cmd/restic/cmd_snapshots.go +++ b/cmd/restic/cmd_snapshots.go @@ -69,6 +69,9 @@ func runSnapshots(ctx context.Context, opts SnapshotOptions, gopts GlobalOptions for sn := range FindFilteredSnapshots(ctx, repo, repo, &opts.SnapshotFilter, args) { snapshots = append(snapshots, sn) } + if ctx.Err() != nil { + return ctx.Err() + } snapshotGroups, grouped, err := restic.GroupSnapshots(snapshots, opts.GroupBy) if err != nil { return err diff --git a/cmd/restic/cmd_stats.go b/cmd/restic/cmd_stats.go index 6bf0dbf1935..a7891e5b036 100644 --- a/cmd/restic/cmd_stats.go +++ b/cmd/restic/cmd_stats.go @@ -117,9 +117,8 @@ func runStats(ctx context.Context, opts StatsOptions, gopts GlobalOptions, args return fmt.Errorf("error walking snapshot: %v", err) } } - - if err != nil { - return err + if ctx.Err() != nil { + return ctx.Err() } if opts.countMode == countModeRawData { @@ -352,7 +351,10 @@ func statsDebug(ctx context.Context, repo restic.Repository) error { Warnf("File Type: %v\n%v\n", t, hist) } - hist := statsDebugBlobs(ctx, repo) + hist, err := statsDebugBlobs(ctx, repo) + if err != nil { + return err + } for _, t := range []restic.BlobType{restic.DataBlob, restic.TreeBlob} { Warnf("Blob Type: %v\n%v\n\n", t, hist[t]) } @@ -370,17 +372,17 @@ func statsDebugFileType(ctx context.Context, repo restic.Lister, tpe restic.File return hist, err } -func statsDebugBlobs(ctx context.Context, repo restic.Repository) [restic.NumBlobTypes]*sizeHistogram { +func statsDebugBlobs(ctx context.Context, repo restic.Repository) ([restic.NumBlobTypes]*sizeHistogram, error) { var hist [restic.NumBlobTypes]*sizeHistogram for i := 0; i < len(hist); i++ { hist[i] = newSizeHistogram(2 * chunker.MaxSize) } - repo.Index().Each(ctx, func(pb restic.PackedBlob) { + err := repo.Index().Each(ctx, func(pb restic.PackedBlob) { hist[pb.Type].Add(uint64(pb.Length)) }) - return hist + return hist, err } type sizeClass struct { diff --git a/cmd/restic/cmd_tag.go b/cmd/restic/cmd_tag.go index b0d139fa673..3bf386f2cd3 100644 --- a/cmd/restic/cmd_tag.go +++ b/cmd/restic/cmd_tag.go @@ -122,6 +122,9 @@ func runTag(ctx context.Context, opts TagOptions, gopts GlobalOptions, args []st changeCnt++ } } + if ctx.Err() != nil { + return ctx.Err() + } if changeCnt == 0 { Verbosef("no snapshots were modified\n") } else { diff --git a/cmd/restic/global.go b/cmd/restic/global.go index eded479ada6..6920caa8d82 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -43,7 +43,7 @@ import ( "golang.org/x/term" ) -var version = "0.16.4-dev (compiled manually)" +const version = "0.16.4-dev (compiled manually)" // TimeFormat is the format used for all timestamps printed by restic. const TimeFormat = "2006-01-02 15:04:05" @@ -96,9 +96,6 @@ var globalOptions = GlobalOptions{ stderr: os.Stderr, } -var isReadingPassword bool -var internalGlobalCtx context.Context - func init() { backends := location.NewRegistry() backends.Register(azure.NewFactory()) @@ -112,15 +109,6 @@ func init() { backends.Register(swift.NewFactory()) globalOptions.backends = backends - var cancel context.CancelFunc - internalGlobalCtx, cancel = context.WithCancel(context.Background()) - AddCleanupHandler(func(code int) (int, error) { - // Must be called before the unlock cleanup handler to ensure that the latter is - // not blocked due to limited number of backend connections, see #1434 - cancel() - return code, nil - }) - f := cmdRoot.PersistentFlags() f.StringVarP(&globalOptions.Repo, "repo", "r", "", "`repository` to backup to or restore from (default: $RESTIC_REPOSITORY)") f.StringVarP(&globalOptions.RepositoryFile, "repository-file", "", "", "`file` to read the repository location from (default: $RESTIC_REPOSITORY_FILE)") @@ -165,8 +153,6 @@ func init() { // parse target pack size from env, on error the default value will be used targetPackSize, _ := strconv.ParseUint(os.Getenv("RESTIC_PACK_SIZE"), 10, 32) globalOptions.PackSize = uint(targetPackSize) - - restoreTerminal() } func stdinIsTerminal() bool { @@ -191,40 +177,6 @@ func stdoutTerminalWidth() int { return w } -// restoreTerminal installs a cleanup handler that restores the previous -// terminal state on exit. This handler is only intended to restore the -// terminal configuration if restic exits after receiving a signal. A regular -// program execution must revert changes to the terminal configuration itself. -// The terminal configuration is only restored while reading a password. -func restoreTerminal() { - if !term.IsTerminal(int(os.Stdout.Fd())) { - return - } - - fd := int(os.Stdout.Fd()) - state, err := term.GetState(fd) - if err != nil { - fmt.Fprintf(os.Stderr, "unable to get terminal state: %v\n", err) - return - } - - AddCleanupHandler(func(code int) (int, error) { - // Restoring the terminal configuration while restic runs in the - // background, causes restic to get stopped on unix systems with - // a SIGTTOU signal. Thus only restore the terminal settings if - // they might have been modified, which is the case while reading - // a password. - if !isReadingPassword { - return code, nil - } - err := term.Restore(fd, state) - if err != nil { - fmt.Fprintf(os.Stderr, "unable to restore terminal state: %v\n", err) - } - return code, err - }) -} - // ClearLine creates a platform dependent string to clear the current // line, so it can be overwritten. // @@ -333,24 +285,48 @@ func readPassword(in io.Reader) (password string, err error) { // readPasswordTerminal reads the password from the given reader which must be a // tty. Prompt is printed on the writer out before attempting to read the -// password. -func readPasswordTerminal(in *os.File, out io.Writer, prompt string) (password string, err error) { - fmt.Fprint(out, prompt) - isReadingPassword = true - buf, err := term.ReadPassword(int(in.Fd())) - isReadingPassword = false - fmt.Fprintln(out) +// password. If the context is canceled, the function leaks the password reading +// goroutine. +func readPasswordTerminal(ctx context.Context, in *os.File, out *os.File, prompt string) (password string, err error) { + fd := int(out.Fd()) + state, err := term.GetState(fd) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to get terminal state: %v\n", err) + return "", err + } + + done := make(chan struct{}) + var buf []byte + + go func() { + defer close(done) + fmt.Fprint(out, prompt) + buf, err = term.ReadPassword(int(in.Fd())) + fmt.Fprintln(out) + }() + + select { + case <-ctx.Done(): + err := term.Restore(fd, state) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to restore terminal state: %v\n", err) + } + return "", ctx.Err() + case <-done: + // clean shutdown, nothing to do + } + if err != nil { return "", errors.Wrap(err, "ReadPassword") } - password = string(buf) - return password, nil + return string(buf), nil } // ReadPassword reads the password from a password file, the environment -// variable RESTIC_PASSWORD or prompts the user. -func ReadPassword(opts GlobalOptions, prompt string) (string, error) { +// variable RESTIC_PASSWORD or prompts the user. If the context is canceled, +// the function leaks the password reading goroutine. +func ReadPassword(ctx context.Context, opts GlobalOptions, prompt string) (string, error) { if opts.password != "" { return opts.password, nil } @@ -361,7 +337,7 @@ func ReadPassword(opts GlobalOptions, prompt string) (string, error) { ) if stdinIsTerminal() { - password, err = readPasswordTerminal(os.Stdin, os.Stderr, prompt) + password, err = readPasswordTerminal(ctx, os.Stdin, os.Stderr, prompt) } else { password, err = readPassword(os.Stdin) Verbosef("reading repository password from stdin\n") @@ -379,14 +355,15 @@ func ReadPassword(opts GlobalOptions, prompt string) (string, error) { } // ReadPasswordTwice calls ReadPassword two times and returns an error when the -// passwords don't match. -func ReadPasswordTwice(gopts GlobalOptions, prompt1, prompt2 string) (string, error) { - pw1, err := ReadPassword(gopts, prompt1) +// passwords don't match. If the context is canceled, the function leaks the +// password reading goroutine. +func ReadPasswordTwice(ctx context.Context, gopts GlobalOptions, prompt1, prompt2 string) (string, error) { + pw1, err := ReadPassword(ctx, gopts, prompt1) if err != nil { return "", err } if stdinIsTerminal() { - pw2, err := ReadPassword(gopts, prompt2) + pw2, err := ReadPassword(ctx, gopts, prompt2) if err != nil { return "", err } @@ -469,7 +446,10 @@ func OpenRepository(ctx context.Context, opts GlobalOptions) (*repository.Reposi } for ; passwordTriesLeft > 0; passwordTriesLeft-- { - opts.password, err = ReadPassword(opts, "enter password for repository: ") + opts.password, err = ReadPassword(ctx, opts, "enter password for repository: ") + if ctx.Err() != nil { + return nil, ctx.Err() + } if err != nil && passwordTriesLeft > 1 { opts.password = "" fmt.Printf("%s. Try again\n", err) diff --git a/cmd/restic/global_debug.go b/cmd/restic/global_debug.go index b798074d10b..502b2cf6ed3 100644 --- a/cmd/restic/global_debug.go +++ b/cmd/restic/global_debug.go @@ -15,23 +15,28 @@ import ( "github.com/pkg/profile" ) -var ( - listenProfile string - memProfilePath string - cpuProfilePath string - traceProfilePath string - blockProfilePath string - insecure bool -) +type ProfileOptions struct { + listen string + memPath string + cpuPath string + tracePath string + blockPath string + insecure bool +} + +var profileOpts ProfileOptions +var prof interface { + Stop() +} func init() { f := cmdRoot.PersistentFlags() - f.StringVar(&listenProfile, "listen-profile", "", "listen on this `address:port` for memory profiling") - f.StringVar(&memProfilePath, "mem-profile", "", "write memory profile to `dir`") - f.StringVar(&cpuProfilePath, "cpu-profile", "", "write cpu profile to `dir`") - f.StringVar(&traceProfilePath, "trace-profile", "", "write trace to `dir`") - f.StringVar(&blockProfilePath, "block-profile", "", "write block profile to `dir`") - f.BoolVar(&insecure, "insecure-kdf", false, "use insecure KDF settings") + f.StringVar(&profileOpts.listen, "listen-profile", "", "listen on this `address:port` for memory profiling") + f.StringVar(&profileOpts.memPath, "mem-profile", "", "write memory profile to `dir`") + f.StringVar(&profileOpts.cpuPath, "cpu-profile", "", "write cpu profile to `dir`") + f.StringVar(&profileOpts.tracePath, "trace-profile", "", "write trace to `dir`") + f.StringVar(&profileOpts.blockPath, "block-profile", "", "write block profile to `dir`") + f.BoolVar(&profileOpts.insecure, "insecure-kdf", false, "use insecure KDF settings") } type fakeTestingTB struct{} @@ -41,10 +46,10 @@ func (fakeTestingTB) Logf(msg string, args ...interface{}) { } func runDebug() error { - if listenProfile != "" { - fmt.Fprintf(os.Stderr, "running profile HTTP server on %v\n", listenProfile) + if profileOpts.listen != "" { + fmt.Fprintf(os.Stderr, "running profile HTTP server on %v\n", profileOpts.listen) go func() { - err := http.ListenAndServe(listenProfile, nil) + err := http.ListenAndServe(profileOpts.listen, nil) if err != nil { fmt.Fprintf(os.Stderr, "profile HTTP server listen failed: %v\n", err) } @@ -52,16 +57,16 @@ func runDebug() error { } profilesEnabled := 0 - if memProfilePath != "" { + if profileOpts.memPath != "" { profilesEnabled++ } - if cpuProfilePath != "" { + if profileOpts.cpuPath != "" { profilesEnabled++ } - if traceProfilePath != "" { + if profileOpts.tracePath != "" { profilesEnabled++ } - if blockProfilePath != "" { + if profileOpts.blockPath != "" { profilesEnabled++ } @@ -69,30 +74,25 @@ func runDebug() error { return errors.Fatal("only one profile (memory, CPU, trace, or block) may be activated at the same time") } - var prof interface { - Stop() + if profileOpts.memPath != "" { + prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.MemProfile, profile.ProfilePath(profileOpts.memPath)) + } else if profileOpts.cpuPath != "" { + prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.CPUProfile, profile.ProfilePath(profileOpts.cpuPath)) + } else if profileOpts.tracePath != "" { + prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.TraceProfile, profile.ProfilePath(profileOpts.tracePath)) + } else if profileOpts.blockPath != "" { + prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.BlockProfile, profile.ProfilePath(profileOpts.blockPath)) } - if memProfilePath != "" { - prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.MemProfile, profile.ProfilePath(memProfilePath)) - } else if cpuProfilePath != "" { - prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.CPUProfile, profile.ProfilePath(cpuProfilePath)) - } else if traceProfilePath != "" { - prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.TraceProfile, profile.ProfilePath(traceProfilePath)) - } else if blockProfilePath != "" { - prof = profile.Start(profile.Quiet, profile.NoShutdownHook, profile.BlockProfile, profile.ProfilePath(blockProfilePath)) - } - - if prof != nil { - AddCleanupHandler(func(code int) (int, error) { - prof.Stop() - return code, nil - }) - } - - if insecure { + if profileOpts.insecure { repository.TestUseLowSecurityKDFParameters(fakeTestingTB{}) } return nil } + +func stopDebug() { + if prof != nil { + prof.Stop() + } +} diff --git a/cmd/restic/global_release.go b/cmd/restic/global_release.go index 7cb2e6caf3c..1dab5a293ac 100644 --- a/cmd/restic/global_release.go +++ b/cmd/restic/global_release.go @@ -5,3 +5,6 @@ package main // runDebug is a noop without the debug tag. func runDebug() error { return nil } + +// stopDebug is a noop without the debug tag. +func stopDebug() {} diff --git a/cmd/restic/integration_helpers_test.go b/cmd/restic/integration_helpers_test.go index c87e1071e71..e7a90dd560a 100644 --- a/cmd/restic/integration_helpers_test.go +++ b/cmd/restic/integration_helpers_test.go @@ -252,11 +252,11 @@ func listTreePacks(gopts GlobalOptions, t *testing.T) restic.IDSet { rtest.OK(t, r.LoadIndex(ctx, nil)) treePacks := restic.NewIDSet() - r.Index().Each(ctx, func(pb restic.PackedBlob) { + rtest.OK(t, r.Index().Each(ctx, func(pb restic.PackedBlob) { if pb.Type == restic.TreeBlob { treePacks.Insert(pb.PackID) } - }) + })) return treePacks } @@ -280,11 +280,11 @@ func removePacksExcept(gopts GlobalOptions, t testing.TB, keep restic.IDSet, rem rtest.OK(t, r.LoadIndex(ctx, nil)) treePacks := restic.NewIDSet() - r.Index().Each(ctx, func(pb restic.PackedBlob) { + rtest.OK(t, r.Index().Each(ctx, func(pb restic.PackedBlob) { if pb.Type == restic.TreeBlob { treePacks.Insert(pb.PackID) } - }) + })) // remove all packs containing data blobs rtest.OK(t, r.List(ctx, restic.PackFile, func(id restic.ID, size int64) error { diff --git a/cmd/restic/lock.go b/cmd/restic/lock.go index 69d433df172..99e199a6773 100644 --- a/cmd/restic/lock.go +++ b/cmd/restic/lock.go @@ -21,18 +21,11 @@ func internalOpenWithLocked(ctx context.Context, gopts GlobalOptions, dryRun boo Verbosef("%s", msg) } }, Warnf) - - unlock = lock.Unlock - // make sure that a repository is unlocked properly and after cancel() was - // called by the cleanup handler in global.go - AddCleanupHandler(func(code int) (int, error) { - lock.Unlock() - return code, nil - }) - if err != nil { return nil, nil, nil, err } + + unlock = lock.Unlock } else { repo.SetDryRun() } diff --git a/cmd/restic/main.go b/cmd/restic/main.go index a4acb1cab38..e847b815674 100644 --- a/cmd/restic/main.go +++ b/cmd/restic/main.go @@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "context" "fmt" "log" "os" @@ -24,6 +25,8 @@ func init() { _, _ = maxprocs.Set() } +var ErrOK = errors.New("ok") + // cmdRoot is the base command when no other command has been specified. var cmdRoot = &cobra.Command{ Use: "restic", @@ -74,6 +77,9 @@ The full documentation can be found at https://restic.readthedocs.io/ . // enabled) return runDebug() }, + PersistentPostRun: func(_ *cobra.Command, _ []string) { + stopDebug() + }, } // Distinguish commands that need the password from those that work without, @@ -88,8 +94,6 @@ func needsPassword(cmd string) bool { } } -var logBuffer = bytes.NewBuffer(nil) - func tweakGoGC() { // lower GOGC from 100 to 50, unless it was manually overwritten by the user oldValue := godebug.SetGCPercent(50) @@ -102,6 +106,7 @@ func main() { tweakGoGC() // install custom global logger into a buffer, if an error occurs // we can show the logs + logBuffer := bytes.NewBuffer(nil) log.SetOutput(logBuffer) err := feature.Flag.Apply(os.Getenv("RESTIC_FEATURES"), func(s string) { @@ -115,7 +120,16 @@ func main() { debug.Log("main %#v", os.Args) debug.Log("restic %s compiled with %v on %v/%v", version, runtime.Version(), runtime.GOOS, runtime.GOARCH) - err = cmdRoot.ExecuteContext(internalGlobalCtx) + + ctx := createGlobalContext() + err = cmdRoot.ExecuteContext(ctx) + + if err == nil { + err = ctx.Err() + } else if err == ErrOK { + // ErrOK overwrites context cancelation errors + err = nil + } switch { case restic.IsAlreadyLocked(err): @@ -137,11 +151,13 @@ func main() { } var exitCode int - switch err { - case nil: + switch { + case err == nil: exitCode = 0 - case ErrInvalidSourceData: + case err == ErrInvalidSourceData: exitCode = 3 + case errors.Is(err, context.Canceled): + exitCode = 130 default: exitCode = 1 } diff --git a/cmd/restic/secondary_repo.go b/cmd/restic/secondary_repo.go index 4c46b60df44..2afd36a81b9 100644 --- a/cmd/restic/secondary_repo.go +++ b/cmd/restic/secondary_repo.go @@ -1,6 +1,7 @@ package main import ( + "context" "os" "github.com/restic/restic/internal/errors" @@ -56,7 +57,7 @@ func initSecondaryRepoOptions(f *pflag.FlagSet, opts *secondaryRepoOptions, repo opts.PasswordCommand = os.Getenv("RESTIC_FROM_PASSWORD_COMMAND") } -func fillSecondaryGlobalOpts(opts secondaryRepoOptions, gopts GlobalOptions, repoPrefix string) (GlobalOptions, bool, error) { +func fillSecondaryGlobalOpts(ctx context.Context, opts secondaryRepoOptions, gopts GlobalOptions, repoPrefix string) (GlobalOptions, bool, error) { if opts.Repo == "" && opts.RepositoryFile == "" && opts.LegacyRepo == "" && opts.LegacyRepositoryFile == "" { return GlobalOptions{}, false, errors.Fatal("Please specify a source repository location (--from-repo or --from-repository-file)") } @@ -109,7 +110,7 @@ func fillSecondaryGlobalOpts(opts secondaryRepoOptions, gopts GlobalOptions, rep return GlobalOptions{}, false, err } } - dstGopts.password, err = ReadPassword(dstGopts, "enter password for "+repoPrefix+" repository: ") + dstGopts.password, err = ReadPassword(ctx, dstGopts, "enter password for "+repoPrefix+" repository: ") if err != nil { return GlobalOptions{}, false, err } diff --git a/cmd/restic/secondary_repo_test.go b/cmd/restic/secondary_repo_test.go index ff1a10b03cb..aa511ca992a 100644 --- a/cmd/restic/secondary_repo_test.go +++ b/cmd/restic/secondary_repo_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "os" "path/filepath" "testing" @@ -170,7 +171,7 @@ func TestFillSecondaryGlobalOpts(t *testing.T) { // Test all valid cases for _, testCase := range validSecondaryRepoTestCases { - DstGOpts, isFromRepo, err := fillSecondaryGlobalOpts(testCase.Opts, gOpts, "destination") + DstGOpts, isFromRepo, err := fillSecondaryGlobalOpts(context.TODO(), testCase.Opts, gOpts, "destination") rtest.OK(t, err) rtest.Equals(t, DstGOpts, testCase.DstGOpts) rtest.Equals(t, isFromRepo, testCase.FromRepo) @@ -178,7 +179,7 @@ func TestFillSecondaryGlobalOpts(t *testing.T) { // Test all invalid cases for _, testCase := range invalidSecondaryRepoTestCases { - _, _, err := fillSecondaryGlobalOpts(testCase.Opts, gOpts, "destination") + _, _, err := fillSecondaryGlobalOpts(context.TODO(), testCase.Opts, gOpts, "destination") rtest.Assert(t, err != nil, "Expected error, but function did not return an error") } } diff --git a/helpers/prepare-release/main.go b/helpers/prepare-release/main.go index baf8aa2baeb..703d85e7007 100644 --- a/helpers/prepare-release/main.go +++ b/helpers/prepare-release/main.go @@ -303,7 +303,7 @@ func generateFiles() { } } -var versionPattern = `var version = ".*"` +var versionPattern = `const version = ".*"` const versionCodeFile = "cmd/restic/global.go" @@ -313,7 +313,7 @@ func updateVersion() { die("unable to write version to file: %v", err) } - newVersion := fmt.Sprintf("var version = %q", opts.Version) + newVersion := fmt.Sprintf("const version = %q", opts.Version) replace(versionCodeFile, versionPattern, newVersion) if len(uncommittedChanges("VERSION")) > 0 || len(uncommittedChanges(versionCodeFile)) > 0 { @@ -323,7 +323,7 @@ func updateVersion() { } func updateVersionDev() { - newVersion := fmt.Sprintf(`var version = "%s-dev (compiled manually)"`, opts.Version) + newVersion := fmt.Sprintf(`const version = "%s-dev (compiled manually)"`, opts.Version) replace(versionCodeFile, versionPattern, newVersion) msg("committing cmd/restic/global.go with dev version") diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 146ff3a7ccb..c1f73eea664 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -380,6 +380,7 @@ func (fn *FutureNode) take(ctx context.Context) futureNodeResult { return res } case <-ctx.Done(): + return futureNodeResult{err: ctx.Err()} } return futureNodeResult{err: errors.Errorf("no result")} } diff --git a/internal/archiver/tree_saver.go b/internal/archiver/tree_saver.go index eae524a7805..9c11b48f060 100644 --- a/internal/archiver/tree_saver.go +++ b/internal/archiver/tree_saver.go @@ -90,6 +90,10 @@ func (s *TreeSaver) save(ctx context.Context, job *saveTreeJob) (*restic.Node, I // return the error if it wasn't ignored if fnr.err != nil { debug.Log("err for %v: %v", fnr.snPath, fnr.err) + if fnr.err == context.Canceled { + return nil, stats, fnr.err + } + fnr.err = s.errFn(fnr.target, fnr.err) if fnr.err == nil { // ignore error diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 28f55ce3ad2..1057341bc73 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -106,9 +106,9 @@ func (c *Checker) LoadSnapshots(ctx context.Context) error { return err } -func computePackTypes(ctx context.Context, idx restic.MasterIndex) map[restic.ID]restic.BlobType { +func computePackTypes(ctx context.Context, idx restic.MasterIndex) (map[restic.ID]restic.BlobType, error) { packs := make(map[restic.ID]restic.BlobType) - idx.Each(ctx, func(pb restic.PackedBlob) { + err := idx.Each(ctx, func(pb restic.PackedBlob) { tpe, exists := packs[pb.PackID] if exists { if pb.Type != tpe { @@ -119,7 +119,7 @@ func computePackTypes(ctx context.Context, idx restic.MasterIndex) map[restic.ID } packs[pb.PackID] = tpe }) - return packs + return packs, err } // LoadIndex loads all index files. @@ -169,7 +169,7 @@ func (c *Checker) LoadIndex(ctx context.Context, p *progress.Counter) (hints []e debug.Log("process blobs") cnt := 0 - index.Each(ctx, func(blob restic.PackedBlob) { + err = index.Each(ctx, func(blob restic.PackedBlob) { cnt++ if _, ok := packToIndex[blob.PackID]; !ok { @@ -179,7 +179,7 @@ func (c *Checker) LoadIndex(ctx context.Context, p *progress.Counter) (hints []e }) debug.Log("%d blobs processed", cnt) - return nil + return err }) if err != nil { errs = append(errs, err) @@ -193,8 +193,14 @@ func (c *Checker) LoadIndex(ctx context.Context, p *progress.Counter) (hints []e } // compute pack size using index entries - c.packs = pack.Size(ctx, c.masterIndex, false) - packTypes := computePackTypes(ctx, c.masterIndex) + c.packs, err = pack.Size(ctx, c.masterIndex, false) + if err != nil { + return hints, append(errs, err) + } + packTypes, err := computePackTypes(ctx, c.masterIndex) + if err != nil { + return hints, append(errs, err) + } debug.Log("checking for duplicate packs") for packID := range c.packs { @@ -484,7 +490,7 @@ func (c *Checker) checkTree(id restic.ID, tree *restic.Tree) (errs []error) { } // UnusedBlobs returns all blobs that have never been referenced. -func (c *Checker) UnusedBlobs(ctx context.Context) (blobs restic.BlobHandles) { +func (c *Checker) UnusedBlobs(ctx context.Context) (blobs restic.BlobHandles, err error) { if !c.trackUnused { panic("only works when tracking blob references") } @@ -495,7 +501,7 @@ func (c *Checker) UnusedBlobs(ctx context.Context) (blobs restic.BlobHandles) { ctx, cancel := context.WithCancel(ctx) defer cancel() - c.repo.Index().Each(ctx, func(blob restic.PackedBlob) { + err = c.repo.Index().Each(ctx, func(blob restic.PackedBlob) { h := restic.BlobHandle{ID: blob.ID, Type: blob.Type} if !c.blobRefs.M.Has(h) { debug.Log("blob %v not referenced", h) @@ -503,7 +509,7 @@ func (c *Checker) UnusedBlobs(ctx context.Context) (blobs restic.BlobHandles) { } }) - return blobs + return blobs, err } // CountPacks returns the number of packs in the repository. diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index b0fa4e3e386..9746e9f5d8b 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -180,7 +180,8 @@ func TestUnreferencedBlobs(t *testing.T) { test.OKs(t, checkPacks(chkr)) test.OKs(t, checkStruct(chkr)) - blobs := chkr.UnusedBlobs(context.TODO()) + blobs, err := chkr.UnusedBlobs(context.TODO()) + test.OK(t, err) sort.Sort(blobs) test.Equals(t, unusedBlobsBySnapshot, blobs) diff --git a/internal/checker/testing.go b/internal/checker/testing.go index 9e949af026f..d0014398ff7 100644 --- a/internal/checker/testing.go +++ b/internal/checker/testing.go @@ -43,7 +43,10 @@ func TestCheckRepo(t testing.TB, repo restic.Repository, skipStructure bool) { } // unused blobs - blobs := chkr.UnusedBlobs(context.TODO()) + blobs, err := chkr.UnusedBlobs(context.TODO()) + if err != nil { + t.Error(err) + } if len(blobs) > 0 { t.Errorf("unused blobs found: %v", blobs) } diff --git a/internal/index/index.go b/internal/index/index.go index 73128f7bb21..1c20fe38d60 100644 --- a/internal/index/index.go +++ b/internal/index/index.go @@ -218,7 +218,7 @@ func (idx *Index) AddToSupersedes(ids ...restic.ID) error { // Each passes all blobs known to the index to the callback fn. This blocks any // modification of the index. -func (idx *Index) Each(ctx context.Context, fn func(restic.PackedBlob)) { +func (idx *Index) Each(ctx context.Context, fn func(restic.PackedBlob)) error { idx.m.Lock() defer idx.m.Unlock() @@ -232,6 +232,7 @@ func (idx *Index) Each(ctx context.Context, fn func(restic.PackedBlob)) { return true }) } + return ctx.Err() } type EachByPackResult struct { diff --git a/internal/index/index_test.go b/internal/index/index_test.go index 78e4800cac3..bafd95c4838 100644 --- a/internal/index/index_test.go +++ b/internal/index/index_test.go @@ -339,7 +339,7 @@ func TestIndexUnserialize(t *testing.T) { rtest.Equals(t, oldIdx, idx.Supersedes()) - blobs := listPack(idx, exampleLookupTest.packID) + blobs := listPack(t, idx, exampleLookupTest.packID) if len(blobs) != len(exampleLookupTest.blobs) { t.Fatalf("expected %d blobs in pack, got %d", len(exampleLookupTest.blobs), len(blobs)) } @@ -356,12 +356,12 @@ func TestIndexUnserialize(t *testing.T) { } } -func listPack(idx *index.Index, id restic.ID) (pbs []restic.PackedBlob) { - idx.Each(context.TODO(), func(pb restic.PackedBlob) { +func listPack(t testing.TB, idx *index.Index, id restic.ID) (pbs []restic.PackedBlob) { + rtest.OK(t, idx.Each(context.TODO(), func(pb restic.PackedBlob) { if pb.PackID.Equal(id) { pbs = append(pbs, pb) } - }) + })) return pbs } diff --git a/internal/index/master_index.go b/internal/index/master_index.go index 4c114b955d8..d99a3434df1 100644 --- a/internal/index/master_index.go +++ b/internal/index/master_index.go @@ -223,13 +223,16 @@ func (mi *MasterIndex) finalizeFullIndexes() []*Index { // Each runs fn on all blobs known to the index. When the context is cancelled, // the index iteration return immediately. This blocks any modification of the index. -func (mi *MasterIndex) Each(ctx context.Context, fn func(restic.PackedBlob)) { +func (mi *MasterIndex) Each(ctx context.Context, fn func(restic.PackedBlob)) error { mi.idxMutex.RLock() defer mi.idxMutex.RUnlock() for _, idx := range mi.idx { - idx.Each(ctx, fn) + if err := idx.Each(ctx, fn); err != nil { + return err + } } + return nil } // MergeFinalIndexes merges all final indexes together. @@ -320,6 +323,9 @@ func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, exclude newIndex = NewIndex() } } + if wgCtx.Err() != nil { + return wgCtx.Err() + } } err := newIndex.AddToSupersedes(extraObsolete...) @@ -426,10 +432,6 @@ func (mi *MasterIndex) ListPacks(ctx context.Context, packs restic.IDSet) <-chan defer close(out) // only resort a part of the index to keep the memory overhead bounded for i := byte(0); i < 16; i++ { - if ctx.Err() != nil { - return - } - packBlob := make(map[restic.ID][]restic.Blob) for pack := range packs { if pack[0]&0xf == i { @@ -439,11 +441,14 @@ func (mi *MasterIndex) ListPacks(ctx context.Context, packs restic.IDSet) <-chan if len(packBlob) == 0 { continue } - mi.Each(ctx, func(pb restic.PackedBlob) { + err := mi.Each(ctx, func(pb restic.PackedBlob) { if packs.Has(pb.PackID) && pb.PackID[0]&0xf == i { packBlob[pb.PackID] = append(packBlob[pb.PackID], pb.Blob) } }) + if err != nil { + return + } // pass on packs for packID, pbs := range packBlob { diff --git a/internal/index/master_index_test.go b/internal/index/master_index_test.go index dcf6a94f6e9..fe0364c61dd 100644 --- a/internal/index/master_index_test.go +++ b/internal/index/master_index_test.go @@ -166,9 +166,9 @@ func TestMasterMergeFinalIndexes(t *testing.T) { rtest.Equals(t, 1, idxCount) blobCount := 0 - mIdx.Each(context.TODO(), func(pb restic.PackedBlob) { + rtest.OK(t, mIdx.Each(context.TODO(), func(pb restic.PackedBlob) { blobCount++ - }) + })) rtest.Equals(t, 2, blobCount) blobs := mIdx.Lookup(bhInIdx1) @@ -198,9 +198,9 @@ func TestMasterMergeFinalIndexes(t *testing.T) { rtest.Equals(t, []restic.PackedBlob{blob2}, blobs) blobCount = 0 - mIdx.Each(context.TODO(), func(pb restic.PackedBlob) { + rtest.OK(t, mIdx.Each(context.TODO(), func(pb restic.PackedBlob) { blobCount++ - }) + })) rtest.Equals(t, 2, blobCount) } @@ -319,9 +319,9 @@ func BenchmarkMasterIndexEach(b *testing.B) { for i := 0; i < b.N; i++ { entries := 0 - mIdx.Each(context.TODO(), func(pb restic.PackedBlob) { + rtest.OK(b, mIdx.Each(context.TODO(), func(pb restic.PackedBlob) { entries++ - }) + })) } } diff --git a/internal/pack/pack.go b/internal/pack/pack.go index cd118ab032a..53631a6fb73 100644 --- a/internal/pack/pack.go +++ b/internal/pack/pack.go @@ -389,10 +389,10 @@ func CalculateHeaderSize(blobs []restic.Blob) int { // If onlyHdr is set to true, only the size of the header is returned // Note that this function only gives correct sizes, if there are no // duplicates in the index. -func Size(ctx context.Context, mi restic.MasterIndex, onlyHdr bool) map[restic.ID]int64 { +func Size(ctx context.Context, mi restic.MasterIndex, onlyHdr bool) (map[restic.ID]int64, error) { packSize := make(map[restic.ID]int64) - mi.Each(ctx, func(blob restic.PackedBlob) { + err := mi.Each(ctx, func(blob restic.PackedBlob) { size, ok := packSize[blob.PackID] if !ok { size = headerSize @@ -403,5 +403,5 @@ func Size(ctx context.Context, mi restic.MasterIndex, onlyHdr bool) map[restic.I packSize[blob.PackID] = size + int64(CalculateEntrySize(blob.Blob)) }) - return packSize + return packSize, err } diff --git a/internal/repository/prune.go b/internal/repository/prune.go index 8900fffaa4f..77811e3214f 100644 --- a/internal/repository/prune.go +++ b/internal/repository/prune.go @@ -124,12 +124,15 @@ func PlanPrune(ctx context.Context, opts PruneOptions, repo restic.Repository, g blobCount := keepBlobs.Len() // when repacking, we do not want to keep blobs which are // already contained in kept packs, so delete them from keepBlobs - repo.Index().Each(ctx, func(blob restic.PackedBlob) { + err := repo.Index().Each(ctx, func(blob restic.PackedBlob) { if plan.removePacks.Has(blob.PackID) || plan.repackPacks.Has(blob.PackID) { return } keepBlobs.Delete(blob.BlobHandle) }) + if err != nil { + return nil, err + } if keepBlobs.Len() < blobCount/2 { // replace with copy to shrink map to necessary size if there's a chance to benefit @@ -152,7 +155,7 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re // iterate over all blobs in index to find out which blobs are duplicates // The counter in usedBlobs describes how many instances of the blob exist in the repository index // Thus 0 == blob is missing, 1 == blob exists once, >= 2 == duplicates exist - idx.Each(ctx, func(blob restic.PackedBlob) { + err := idx.Each(ctx, func(blob restic.PackedBlob) { bh := blob.BlobHandle count, ok := usedBlobs[bh] if ok { @@ -166,6 +169,9 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re usedBlobs[bh] = count } }) + if err != nil { + return nil, nil, err + } // Check if all used blobs have been found in index missingBlobs := restic.NewBlobSet() @@ -188,14 +194,18 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re indexPack := make(map[restic.ID]packInfo) // save computed pack header size - for pid, hdrSize := range pack.Size(ctx, idx, true) { + sz, err := pack.Size(ctx, idx, true) + if err != nil { + return nil, nil, err + } + for pid, hdrSize := range sz { // initialize tpe with NumBlobTypes to indicate it's not set indexPack[pid] = packInfo{tpe: restic.NumBlobTypes, usedSize: uint64(hdrSize)} } hasDuplicates := false // iterate over all blobs in index to generate packInfo - idx.Each(ctx, func(blob restic.PackedBlob) { + err = idx.Each(ctx, func(blob restic.PackedBlob) { ip := indexPack[blob.PackID] // Set blob type if not yet set @@ -240,6 +250,9 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re // update indexPack indexPack[blob.PackID] = ip }) + if err != nil { + return nil, nil, err + } // if duplicate blobs exist, those will be set to either "used" or "unused": // - mark only one occurrence of duplicate blobs as used @@ -247,7 +260,7 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re // - if there are no used blobs in a pack, possibly mark duplicates as "unused" if hasDuplicates { // iterate again over all blobs in index (this is pretty cheap, all in-mem) - idx.Each(ctx, func(blob restic.PackedBlob) { + err = idx.Each(ctx, func(blob restic.PackedBlob) { bh := blob.BlobHandle count, ok := usedBlobs[bh] // skip non-duplicate, aka. normal blobs @@ -285,6 +298,9 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re // update indexPack indexPack[blob.PackID] = ip }) + if err != nil { + return nil, nil, err + } } // Sanity check. If no duplicates exist, all blobs have value 1. After handling @@ -528,6 +544,9 @@ func (plan *PrunePlan) Execute(ctx context.Context, printer progress.Printer) (e printer.P("deleting unreferenced packs\n") _ = deleteFiles(ctx, true, repo, plan.removePacksFirst, restic.PackFile, printer) } + if ctx.Err() != nil { + return ctx.Err() + } if len(plan.repackPacks) != 0 { printer.P("repacking packs\n") @@ -578,6 +597,9 @@ func (plan *PrunePlan) Execute(ctx context.Context, printer progress.Printer) (e printer.P("removing %d old packs\n", len(plan.removePacks)) _ = deleteFiles(ctx, true, repo, plan.removePacks, restic.PackFile, printer) } + if ctx.Err() != nil { + return ctx.Err() + } if plan.opts.UnsafeRecovery { err = rebuildIndexFiles(ctx, repo, plan.ignorePacks, nil, true, printer) diff --git a/internal/repository/repack.go b/internal/repository/repack.go index 5588984f6f9..53656252a54 100644 --- a/internal/repository/repack.go +++ b/internal/repository/repack.go @@ -72,7 +72,7 @@ func repack(ctx context.Context, repo restic.Repository, dstRepo restic.Reposito return wgCtx.Err() } } - return nil + return wgCtx.Err() }) worker := func() error { diff --git a/internal/repository/repair_index.go b/internal/repository/repair_index.go index 63e10413278..a6e732b4452 100644 --- a/internal/repository/repair_index.go +++ b/internal/repository/repair_index.go @@ -54,7 +54,10 @@ func RepairIndex(ctx context.Context, repo *Repository, opts RepairIndexOptions, if err != nil { return err } - packSizeFromIndex = pack.Size(ctx, repo.Index(), false) + packSizeFromIndex, err = pack.Size(ctx, repo.Index(), false) + if err != nil { + return err + } } printer.P("getting pack files to read...\n") diff --git a/internal/repository/repair_pack_test.go b/internal/repository/repair_pack_test.go index c5cdf5ed52e..078017d213f 100644 --- a/internal/repository/repair_pack_test.go +++ b/internal/repository/repair_pack_test.go @@ -17,7 +17,7 @@ import ( func listBlobs(repo restic.Repository) restic.BlobSet { blobs := restic.NewBlobSet() - repo.Index().Each(context.TODO(), func(pb restic.PackedBlob) { + _ = repo.Index().Each(context.TODO(), func(pb restic.PackedBlob) { blobs.Insert(pb.BlobHandle) }) return blobs diff --git a/internal/repository/repository.go b/internal/repository/repository.go index ae4528d80bc..cac1551c441 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -704,15 +704,21 @@ func (r *Repository) LoadIndex(ctx context.Context, p *progress.Counter) error { defer cancel() invalidIndex := false - r.idx.Each(ctx, func(blob restic.PackedBlob) { + err := r.idx.Each(ctx, func(blob restic.PackedBlob) { if blob.IsCompressed() { invalidIndex = true } }) + if err != nil { + return err + } if invalidIndex { return errors.New("index uses feature not supported by repository version 1") } } + if ctx.Err() != nil { + return ctx.Err() + } // remove index files from the cache which have been removed in the repo return r.prepareCache() diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index b013c482362..48a56a1fd51 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -370,13 +370,13 @@ func testRepositoryIncrementalIndex(t *testing.T, version uint) { idx, err := loadIndex(context.TODO(), repo, id) rtest.OK(t, err) - idx.Each(context.TODO(), func(pb restic.PackedBlob) { + rtest.OK(t, idx.Each(context.TODO(), func(pb restic.PackedBlob) { if _, ok := packEntries[pb.PackID]; !ok { packEntries[pb.PackID] = make(map[restic.ID]struct{}) } packEntries[pb.PackID][id] = struct{}{} - }) + })) return nil }) if err != nil { diff --git a/internal/restic/repository.go b/internal/restic/repository.go index 89c54ffbb1b..7a3389e00d0 100644 --- a/internal/restic/repository.go +++ b/internal/restic/repository.go @@ -103,8 +103,8 @@ type MasterIndex interface { Lookup(BlobHandle) []PackedBlob // Each runs fn on all blobs known to the index. When the context is cancelled, - // the index iteration return immediately. This blocks any modification of the index. - Each(ctx context.Context, fn func(PackedBlob)) + // the index iteration returns immediately with ctx.Err(). This blocks any modification of the index. + Each(ctx context.Context, fn func(PackedBlob)) error ListPacks(ctx context.Context, packs IDSet) <-chan PackBlobs Save(ctx context.Context, repo Repository, excludePacks IDSet, extraObsolete IDs, opts MasterIndexSaveOpts) error