Skip to content

Commit

Permalink
Opt-In Panic Recovery (#364)
Browse files Browse the repository at this point in the history
* Initial feature implementation

* move tests

* Add test for recoverFromPanicsOption.String()

* Fix small comments
  • Loading branch information
JacobOaks authored Dec 13, 2022
1 parent cbba855 commit c29a62e
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 4 deletions.
13 changes: 12 additions & 1 deletion constructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (n *constructorNode) String() string {

// Call calls this constructor if it hasn't already been called and
// injects any values produced by it into the provided container.
func (n *constructorNode) Call(c containerStore) error {
func (n *constructorNode) Call(c containerStore) (err error) {
if n.called {
return nil
}
Expand All @@ -142,6 +142,17 @@ func (n *constructorNode) Call(c containerStore) error {
}
}

if n.s.recoverFromPanics {
defer func() {
if p := recover(); p != nil {
err = PanicError{
fn: n.location,
Panic: p,
}
}
}()
}

args, err := n.paramList.BuildList(c)
if err != nil {
return errArgumentsFailed{
Expand Down
19 changes: 19 additions & 0 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,25 @@ func (deferAcyclicVerificationOption) applyOption(c *Container) {
c.scope.deferAcyclicVerification = true
}

// RecoverFromPanics is an [Option] to recover from panics that occur while
// running functions given to the container. When set, recovered panics
// will be placed into a [PanicError], and returned at the invoke callsite.
// See [PanicError] for an example on how to handle panics with this option
// enabled, and distinguish them from errors.
func RecoverFromPanics() Option {
return recoverFromPanicsOption{}
}

type recoverFromPanicsOption struct{}

func (recoverFromPanicsOption) String() string {
return "RecoverFromPanics()"
}

func (recoverFromPanicsOption) applyOption(c *Container) {
c.scope.recoverFromPanics = true
}

// Changes the source of randomness for the container.
//
// This will help provide determinism during tests.
Expand Down
6 changes: 6 additions & 0 deletions container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,10 @@ func TestOptionStrings(t *testing.T) {
assert.Equal(t, "DryRun(true)", fmt.Sprint(DryRun(true)))
assert.Equal(t, "DryRun(false)", fmt.Sprint(DryRun(false)))
})

t.Run("RecoverFromPanics()", func(t *testing.T) {
t.Parallel()

assert.Equal(t, "RecoverFromPanics()", fmt.Sprint(RecoverFromPanics()))
})
}
13 changes: 12 additions & 1 deletion decorate.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func newDecoratorNode(dcor interface{}, s *Scope) (*decoratorNode, error) {
return n, nil
}

func (n *decoratorNode) Call(s containerStore) error {
func (n *decoratorNode) Call(s containerStore) (err error) {
if n.state == decoratorCalled {
return nil
}
Expand All @@ -109,6 +109,17 @@ func (n *decoratorNode) Call(s containerStore) error {
}
}

if n.s.recoverFromPanics {
defer func() {
if p := recover(); p != nil {
err = PanicError{
fn: n.location,
Panic: p,
}
}
}()
}

args, err := n.params.BuildList(n.s)
if err != nil {
return errArgumentsFailed{
Expand Down
73 changes: 73 additions & 0 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,79 @@ func TestGroups(t *testing.T) {

// --- END OF END TO END TESTS

func TestRecoverFromPanic(t *testing.T) {

tests := []struct {
name string
setup func(*digtest.Container)
invoke interface{}
wantErr []string
}{
{
name: "panic in provided function",
setup: func(c *digtest.Container) {
c.RequireProvide(func() int {
panic("terrible sadness")
})
},
invoke: func(i int) {},
wantErr: []string{
`could not build arguments for function "go.uber.org/dig_test".TestRecoverFromPanic.\S+`,
`failed to build int:`,
`panic: "terrible sadness" in func: "go.uber.org/dig_test".TestRecoverFromPanic.\S+`,
},
},
{
name: "panic in decorator",
setup: func(c *digtest.Container) {
c.RequireProvide(func() string { return "" })
c.RequireDecorate(func(s string) string {
panic("great sadness")
})
},
invoke: func(s string) {},
wantErr: []string{
`could not build arguments for function "go.uber.org/dig_test".TestRecoverFromPanic.\S+`,
`failed to build string:`,
`panic: "great sadness" in func: "go.uber.org/dig_test".TestRecoverFromPanic.\S+`,
},
},
{
name: "panic in invoke",
setup: func(c *digtest.Container) {},
invoke: func() { panic("terrible woe") },
wantErr: []string{
`panic: "terrible woe" in func: "go.uber.org/dig_test".TestRecoverFromPanic.\S+`,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

t.Run("without option", func(t *testing.T) {
c := digtest.New(t)
tt.setup(c)
assert.Panics(t, func() { c.Container.Invoke(tt.invoke) },
"expected panic without dig.RecoverFromPanics() option",
)
})

t.Run("with option", func(t *testing.T) {
c := digtest.New(t, dig.RecoverFromPanics())
tt.setup(c)
err := c.Container.Invoke(tt.invoke)
require.Error(t, err)
dig.AssertErrorMatches(t, err, tt.wantErr[0], tt.wantErr[1:]...)
var pe dig.PanicError
assert.True(t, errors.As(err, &pe), "expected error chain to contain a PanicError")
_, ok := dig.RootCause(err).(dig.PanicError)
assert.True(t, ok, "expected root cause to be a PanicError")
})
})
}
}

func TestProvideConstructorErrors(t *testing.T) {
t.Run("multiple-type constructor returns multiple objects of same type", func(t *testing.T) {
c := digtest.New(t)
Expand Down
54 changes: 53 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,55 @@ type digError interface {
fmt.Formatter
}

// A PanicError occurs when a panic occurs while running functions given to the container
// with the [RecoverFromPanic] option being set. It contains the panic message from the
// original panic. A PanicError does not wrap other errors, and it does not implement
// dig.Error, meaning it will be returned from [RootCause]. With the [RecoverFromPanic]
// option set, a panic can be distinguished from dig errors and errors from provided/
// invoked/decorated functions like so:
//
// rootCause := dig.RootCause(err)
//
// var pe dig.PanicError
// var de dig.Error
// if errors.As(rootCause, &pe) {
// // This is caused by a panic
// } else if errors.As(err, &de) {
// // This is a dig error
// } else {
// // This is an error from one of my provided/invoked functions or decorators
// }
//
// Or, if only interested in distinguishing panics from errors:
//
// var pe dig.PanicError
// if errors.As(err, &pe) {
// // This is caused by a panic
// } else {
// // This is an error
// }
type PanicError struct {

// The function the panic occurred at
fn *digreflect.Func

// The panic that was returned from recover()
Panic any
}

// Format will format the PanicError, expanding the corresponding function if in +v mode.
func (e PanicError) Format(w fmt.State, c rune) {
if w.Flag('+') && c == 'v' {
fmt.Fprintf(w, "panic: %q in func: %+v", e.Panic, e.fn)
} else {
fmt.Fprintf(w, "panic: %q in func: %v", e.Panic, e.fn)
}
}

func (e PanicError) Error() string {
return fmt.Sprint(e)
}

// formatError will call a dig.Error's writeMessage() method to print the error message
// and then will automatically attempt to print errors wrapped underneath (which can create
// a recursive effect if the wrapped error's Format() method then points back to this function).
Expand Down Expand Up @@ -96,8 +145,11 @@ func formatError(e digError, w fmt.State, v rune) {
// if errors.As(rootCause, &de) {
// // Is a Dig error
// } else {
// // Is an error thrown by one of my provided or invoked functions
// // Is an error thrown by one of my provided/invoked/decorated functions
// }
//
// See [PanicError] for an example showing how to additionally detect
// and handle panics in provided/invoked/decorated functions.
func RootCause(err error) error {
var de Error
// Dig down to first non dig.Error, or bottom of chain
Expand Down
17 changes: 16 additions & 1 deletion invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ type InvokeOption interface {
//
// The function may return an error to indicate failure. The error will be
// returned to the caller as-is.
//
// If the [RecoverFromPanics] option was given to the container and a panic
// occurs when invoking, a [PanicError] with the panic contained will be
// returned. See [PanicError] for more info.
func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error {
return c.scope.Invoke(function, opts...)
}
Expand All @@ -54,7 +58,7 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error {
//
// The function may return an error to indicate failure. The error will be
// returned to the caller as-is.
func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error {
func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) (err error) {
ftype := reflect.TypeOf(function)
if ftype == nil {
return newErrInvalidInput("can't invoke an untyped nil", nil)
Expand Down Expand Up @@ -90,6 +94,17 @@ func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error {
Reason: err,
}
}
if s.recoverFromPanics {
defer func() {
if p := recover(); p != nil {
err = PanicError{
fn: digreflect.InspectFunc(function),
Panic: p,
}
}
}()
}

returned := s.invokerFn(reflect.ValueOf(function), args)
if len(returned) == 0 {
return nil
Expand Down
4 changes: 4 additions & 0 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ type Scope struct {
// Defer acyclic check on provide until Invoke.
deferAcyclicVerification bool

// Recover from panics in user-provided code and wrap in an exported error type.
recoverFromPanics bool

// invokerFn calls a function with arguments provided to Provide or Invoke.
invokerFn invokerFn

Expand Down Expand Up @@ -115,6 +118,7 @@ func (s *Scope) Scope(name string, opts ...ScopeOption) *Scope {
child.parentScope = s
child.invokerFn = s.invokerFn
child.deferAcyclicVerification = s.deferAcyclicVerification
child.recoverFromPanics = s.recoverFromPanics

// child copies the parent's graph nodes.
child.gh.nodes = append(child.gh.nodes, s.gh.nodes...)
Expand Down

0 comments on commit c29a62e

Please sign in to comment.