From 35bbb0957658db626fe1699b047aa840872cff52 Mon Sep 17 00:00:00 2001 From: peterneyens Date: Tue, 18 Oct 2016 13:03:46 +0200 Subject: [PATCH] Add MonadError instance for StateT. --- core/src/main/scala/cats/data/StateT.scala | 20 ++++++++++++++++--- .../test/scala/cats/tests/StateTTests.scala | 12 ++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index e6bd6f166d..65150e3a2a 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -174,14 +174,19 @@ private[data] sealed trait StateTInstances1 extends StateTInstances2 { new StateTMonadCombine[F, S] { implicit def F = F0 } } -private[data] sealed trait StateTInstances2 { - implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] = - new StateTMonad[F, S] { implicit def F = F0 } +private[data] sealed trait StateTInstances2 extends StateTInstances3 { + implicit def catsDataMonadErrorForStateT[F[_], S, E](implicit F0: MonadError[F, E]): MonadError[StateT[F, S, ?], E] = + new StateTMonadError[F, S, E] { implicit def F = F0 } implicit def catsDataSemigroupKForStateT[F[_], S](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[StateT[F, S, ?]] = new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 } } +private[data] sealed trait StateTInstances3 { + implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] = + new StateTMonad[F, S] { implicit def F = F0 } +} + // To workaround SI-7139 `object State` needs to be defined inside the package object // together with the type alias. private[data] abstract class StateFunctions { @@ -258,3 +263,12 @@ private[data] sealed trait StateTMonadCombine[F[_], S] extends MonadCombine[Stat def empty[A]: StateT[F, S, A] = liftT[F, A](F.empty[A]) } + +private[data] sealed trait StateTMonadError[F[_], S, E] extends StateTMonad[F, S] with MonadError[StateT[F, S, ?], E] { + implicit def F: MonadError[F, E] + + def raiseError[A](e: E): StateT[F, S, A] = StateT.lift(F.raiseError(e)) + + def handleErrorWith[A](fa: StateT[F, S, A])(f: E => StateT[F, S, A]): StateT[F, S, A] = + StateT(s => F.handleErrorWith(fa.run(s))(e => f(e).run(s))) +} diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala index 62696864e9..8d1e876a92 100644 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ b/tests/src/test/scala/cats/tests/StateTTests.scala @@ -1,7 +1,7 @@ package cats package tests -import cats.data.{State, StateT} +import cats.data.{State, StateT, EitherT} import cats.kernel.instances.tuple._ import cats.laws.discipline._ import cats.laws.discipline.eq._ @@ -260,6 +260,16 @@ class StateTTests extends CatsSuite { checkAll("State[Long, ?]", MonadTests[State[Long, ?]].monad[Int, Int, Int]) checkAll("Monad[State[Long, ?]]", SerializableTests.serializable(Monad[State[Long, ?]])) } + + { + // F has a MonadError + implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]] + implicit val eqEitherTFA: Eq[EitherT[StateT[Option, Int , ?], Unit, Int]] = EitherT.catsDataEqForEitherT[StateT[Option, Int , ?], Unit, Int] + + checkAll("StateT[Option, Int, Int]", MonadErrorTests[StateT[Option, Int , ?], Unit].monadError[Int, Int, Int]) + checkAll("MonadError[StateT[Option, Int , ?], Unit]", SerializableTests.serializable(MonadError[StateT[Option, Int , ?], Unit])) + } + } object StateTTests extends StateTTestsInstances {