diff --git a/cmd/ipfs/daemon.go b/cmd/ipfs/daemon.go index b2104ce6fc8..9b8d0f022df 100644 --- a/cmd/ipfs/daemon.go +++ b/cmd/ipfs/daemon.go @@ -80,6 +80,15 @@ func daemonFunc(req cmds.Request, res cmds.Response) { // let the user know we're going. fmt.Printf("Initializing daemon...\n") + ctx := req.Context() + + go func() { + select { + case <-ctx.Context.Done(): + fmt.Println("Received interrupt signal, shutting down...") + } + }() + // first, whether user has provided the initialization flag. we may be // running in an uninitialized state. initialize, _, err := req.Option(initOptionKwd).Bool() @@ -111,7 +120,6 @@ func daemonFunc(req cmds.Request, res cmds.Response) { return } - ctx := req.Context() cfg, err := ctx.GetConfig() if err != nil { res.SetError(err, cmds.ErrNormal) @@ -149,7 +157,19 @@ func daemonFunc(req cmds.Request, res cmds.Response) { res.SetError(err, cmds.ErrNormal) return } - defer node.Close() + + defer func() { + // We wait for the node to close first, as the node has children + // that it will wait for before closing, such as the API server. + node.Close() + + select { + case <-ctx.Context.Done(): + log.Info("Gracefully shut down daemon") + default: + } + }() + req.Context().ConstructNode = func() (*core.IpfsNode, error) { return node, nil } @@ -262,9 +282,6 @@ func daemonFunc(req cmds.Request, res cmds.Response) { corehttp.VersionOption(), } - // our global interrupt handler can now try to stop the daemon - close(req.Context().InitDone) - if rootRedirect != nil { opts = append(opts, rootRedirect) } diff --git a/cmd/ipfs/main.go b/cmd/ipfs/main.go index d501c77b4e1..2ad90cfc148 100644 --- a/cmd/ipfs/main.go +++ b/cmd/ipfs/main.go @@ -11,6 +11,7 @@ import ( "runtime" "runtime/pprof" "strings" + "sync" "syscall" "time" @@ -39,7 +40,6 @@ const ( cpuProfile = "ipfs.cpuprof" heapProfile = "ipfs.memprof" errorFormat = "ERROR: %v\n\n" - shutdownMessage = "Received interrupt signal, shutting down..." ) type cmdInvocation struct { @@ -132,15 +132,10 @@ func main() { os.Exit(1) } - // our global interrupt handler may try to stop the daemon - // before the daemon is ready to be stopped; this dirty - // workaround is for the daemon only; other commands are always - // ready to be stopped - if invoc.cmd != daemonCmd { - close(invoc.req.Context().InitDone) - } - // ok, finally, run the command invocation. + intrh, ctx := invoc.SetupInterruptHandler(ctx) + defer intrh.Close() + output, err := invoc.Run(ctx) if err != nil { printErr(err) @@ -157,8 +152,6 @@ func main() { } func (i *cmdInvocation) Run(ctx context.Context) (output io.Reader, err error) { - // setup our global interrupt handler. - i.setupInterruptHandler() // check if user wants to debug. option OR env var. debug, _, err := i.req.Option("debug").Bool() @@ -226,7 +219,6 @@ func (i *cmdInvocation) Parse(ctx context.Context, args []string) error { if err != nil { return err } - i.req.Context().Context = ctx repoPath, err := getRepoPath(i.req) if err != nil { @@ -279,6 +271,8 @@ func callCommand(ctx context.Context, req cmds.Request, root *cmds.Command, cmd log.Info(config.EnvDir, " ", req.Context().ConfigRoot) var res cmds.Response + req.Context().Context = ctx + details, err := commandDetails(req.Path(), root) if err != nil { return nil, err @@ -474,59 +468,70 @@ func writeHeapProfileToFile() error { return pprof.WriteHeapProfile(mprof) } -// listen for and handle SIGTERM -func (i *cmdInvocation) setupInterruptHandler() { +// IntrHandler helps set up an interrupt handler that can +// be cleanly shut down through the io.Closer interface. +type IntrHandler struct { + sig chan os.Signal + wg sync.WaitGroup +} + +func NewIntrHandler() *IntrHandler { + ih := &IntrHandler{} + ih.sig = make(chan os.Signal, 1) + return ih +} + +func (ih *IntrHandler) Close() error { + close(ih.sig) + ih.wg.Wait() + return nil +} - ctx := i.req.Context() - sig := allInterruptSignals() +// Handle starts handling the given signals, and will call the handler +// callback function each time a signal is catched. The function is passed +// the number of times the handler has been triggered in total, as +// well as the handler itself, so that the handling logic can use the +// handler's wait group to ensure clean shutdown when Close() is called. +func (ih *IntrHandler) Handle(handler func(count int, ih *IntrHandler), sigs ...os.Signal) { + signal.Notify(ih.sig, sigs...) + ih.wg.Add(1) go func() { - // first time, try to shut down. - - // loop because we may be - for count := 0; ; count++ { - <-sig - - // if we're still initializing, cannot use `ctx.GetNode()` - select { - default: // initialization not done - fmt.Println(shutdownMessage) - os.Exit(-1) - case <-ctx.InitDone: - } - - switch count { - case 0: - fmt.Println(shutdownMessage) - if ctx.Online { - go func() { - // TODO cancel the command context instead - n, err := ctx.GetNode() - if err != nil { - log.Error(err) - fmt.Println(shutdownMessage) - os.Exit(-1) - } - n.Close() - log.Info("Gracefully shut down.") - }() - } else { - os.Exit(0) - } - - default: - fmt.Println("Received another interrupt before graceful shutdown, terminating...") - os.Exit(-1) - } + defer ih.wg.Done() + count := 0 + for _ = range ih.sig { + count++ + handler(count, ih) } + signal.Stop(ih.sig) }() } -func allInterruptSignals() chan os.Signal { - sigc := make(chan os.Signal, 1) - signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, - syscall.SIGTERM) - return sigc +func (i *cmdInvocation) SetupInterruptHandler(ctx context.Context) (io.Closer, context.Context) { + + intrh := NewIntrHandler() + ctx, cancelFunc := context.WithCancel(ctx) + + handlerFunc := func(count int, ih *IntrHandler) { + switch count { + case 1: + fmt.Println() // Prevent un-terminated ^C character in terminal + + ih.wg.Add(1) + go func() { + defer ih.wg.Done() + cancelFunc() + }() + + default: + fmt.Println("Received another interrupt before graceful shutdown, terminating...") + os.Exit(-1) + } + } + + intrh.Handle(handlerFunc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM) + + return intrh, ctx } func profileIfEnabled() (func(), error) { diff --git a/commands/http/client.go b/commands/http/client.go index 31ffc6df087..a34c89d1ae2 100644 --- a/commands/http/client.go +++ b/commands/http/client.go @@ -82,25 +82,44 @@ func (c *client) Send(req cmds.Request) (cmds.Response, error) { version := config.CurrentVersionNumber httpReq.Header.Set("User-Agent", fmt.Sprintf("/go-ipfs/%s/", version)) - httpRes, err := http.DefaultClient.Do(httpReq) - if err != nil { - return nil, err - } + ec := make(chan error, 1) + rc := make(chan cmds.Response, 1) + dc := req.Context().Context.Done() - // using the overridden JSON encoding in request - res, err := getResponse(httpRes, req) - if err != nil { - return nil, err - } - - if found && len(previousUserProvidedEncoding) > 0 { - // reset to user provided encoding after sending request - // NB: if user has provided an encoding but it is the empty string, - // still leave it as JSON. - req.SetOption(cmds.EncShort, previousUserProvidedEncoding) + go func() { + httpRes, err := http.DefaultClient.Do(httpReq) + if err != nil { + ec <- err + return + } + // using the overridden JSON encoding in request + res, err := getResponse(httpRes, req) + if err != nil { + ec <- err + return + } + rc <- res + }() + + for { + select { + case <-dc: + log.Debug("Context cancelled, cancelling HTTP request...") + tr := http.DefaultTransport.(*http.Transport) + tr.CancelRequest(httpReq) + dc = nil // Wait for ec or rc + case err := <-ec: + return nil, err + case res := <-rc: + if found && len(previousUserProvidedEncoding) > 0 { + // reset to user provided encoding after sending request + // NB: if user has provided an encoding but it is the empty string, + // still leave it as JSON. + req.SetOption(cmds.EncShort, previousUserProvidedEncoding) + } + return res, nil + } } - - return res, nil } func getQuery(req cmds.Request) (string, error) { @@ -162,6 +181,8 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error dec := json.NewDecoder(httpRes.Body) outputType := reflect.TypeOf(req.Command().Type) + ctx := req.Context().Context + for { var v interface{} var err error @@ -175,6 +196,14 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error fmt.Println(err.Error()) return } + + select { + case <-ctx.Done(): + close(outChan) + return + default: + } + if err == io.EOF { close(outChan) return diff --git a/commands/request.go b/commands/request.go index 17b9dfa7e7d..8938900e16a 100644 --- a/commands/request.go +++ b/commands/request.go @@ -30,7 +30,6 @@ type Context struct { node *core.IpfsNode ConstructNode func() (*core.IpfsNode, error) - InitDone chan bool } // GetConfig returns the config of the current Command exection @@ -288,7 +287,7 @@ func NewRequest(path []string, opts OptMap, args []string, file files.File, cmd optDefs = make(map[string]Option) } - ctx := Context{Context: context.TODO(), InitDone: make(chan bool)} + ctx := Context{Context: context.TODO()} values := make(map[string]interface{}) req := &request{path, opts, args, file, cmd, ctx, optDefs, values, os.Stdin} err := req.ConvertOptions() diff --git a/core/corehttp/corehttp.go b/core/corehttp/corehttp.go index 2c679eb1f50..ff9bac70440 100644 --- a/core/corehttp/corehttp.go +++ b/core/corehttp/corehttp.go @@ -2,6 +2,7 @@ package corehttp import ( "net/http" + "time" manners "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/braintree/manners" ma "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" @@ -63,6 +64,9 @@ func listenAndServe(node *core.IpfsNode, addr ma.Multiaddr, handler http.Handler var serverError error serverExited := make(chan struct{}) + node.Children().Add(1) + defer node.Children().Done() + go func() { serverError = server.ListenAndServe(host, handler) close(serverExited) @@ -75,8 +79,22 @@ func listenAndServe(node *core.IpfsNode, addr ma.Multiaddr, handler http.Handler // if node being closed before server exits, close server case <-node.Closing(): log.Infof("server at %s terminating...", addr) + + // make sure keep-alive connections do not keep the server running + server.InnerServer.SetKeepAlivesEnabled(false) + server.Shutdown <- true - <-serverExited // now, DO wait until server exit + + outer: + for { + // wait until server exits + select { + case <-serverExited: + break outer + case <-time.After(5 * time.Second): + log.Infof("waiting for server at %s to terminate...", addr) + } + } } log.Infof("server at %s terminated", addr)