Skip to content

Commit

Permalink
Merge pull request #1414 from peterneyens/lazy-foldM
Browse files Browse the repository at this point in the history
Lazy foldM for "Iterables"
  • Loading branch information
kailuowang authored Oct 22, 2016
2 parents 7ea2024 + 6f3da09 commit f90245d
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 6 deletions.
41 changes: 41 additions & 0 deletions core/src/main/scala/cats/Foldable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ import simulacrum.typeclass

/**
* Left associative monadic folding on `F`.
*
* The default implementation of this is based on `foldLeft`, and thus will
* always fold across the entire structure. Certain structures are able to
* implement this in such a way that folds can be short-circuited (not
* traverse the entirety of the structure), depending on the `G` result
* produced at a given step.
*/
def foldM[G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
foldLeft(fa, G.pure(z))((gb, a) => G.flatMap(gb)(f(_, a)))
Expand Down Expand Up @@ -372,4 +378,39 @@ object Foldable {
Eval.defer(if (it.hasNext) f(it.next, loop()) else lb)
loop()
}

/**
* Implementation of [[Foldable.foldM]] which can short-circuit for
* structures with an `Iterator`.
*
* For example we can sum a `Stream` of integers and stop if
* the sum reaches 100 (if we reach the end of the `Stream`
* before getting to 100 we return the total sum) :
*
* {{{
* scala> import cats.implicits._
* scala> type LongOr[A] = Either[Long, A]
* scala> def sumStream(s: Stream[Int]): Long =
* | Foldable.iteratorFoldM[LongOr, Int, Long](s.toIterator, 0L){ (acc, n) =>
* | val sum = acc + n
* | if (sum < 100L) Right(sum) else Left(sum)
* | }.merge
*
* scala> sumStream(Stream.continually(1))
* res0: Long = 100
*
* scala> sumStream(Stream(1,2,3,4))
* res1: Long = 10
* }}}
*
* Note that `Foldable[Stream].foldM` uses this method underneath, so
* you wouldn't call this method explicitly like in the example above.
*/
def iteratorFoldM[M[_], A, B](it: Iterator[A], z: B)(f: (B, A) => M[B])(implicit M: Monad[M]): M[B] = {
val go: B => M[Either[B, B]] = { b =>
if (it.hasNext) M.map(f(b, it.next))(Left(_))
else M.pure(Right(b))
}
M.tailRecM(z)(go)
}
}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ trait ListInstances extends cats.kernel.instances.ListInstances {
override def isEmpty[A](fa: List[A]): Boolean = fa.isEmpty

override def filter[A](fa: List[A])(f: A => Boolean): List[A] = fa.filter(f)

override def foldM[G[_], A, B](fa: List[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForList[A:Show]: Show[List[A]] =
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ trait MapInstances extends cats.kernel.instances.MapInstances {
override def size[A](fa: Map[K, A]): Long = fa.size.toLong

override def isEmpty[A](fa: Map[K, A]): Boolean = fa.isEmpty

override def foldM[G[_], A, B](fa: Map[K, A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.valuesIterator, z)(f)
}
// scalastyle:on method.length
}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/set.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ trait SetInstances extends cats.kernel.instances.SetInstances {
fa.forall(p)

override def isEmpty[A](fa: Set[A]): Boolean = fa.isEmpty

override def foldM[G[_], A, B](fa: Set[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForSet[A:Show]: Show[Set[A]] = new Show[Set[A]] {
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ trait StreamInstances extends cats.kernel.instances.StreamInstances {
override def filter[A](fa: Stream[A])(f: A => Boolean): Stream[A] = fa.filter(f)

override def collect[A, B](fa: Stream[A])(f: PartialFunction[A, B]): Stream[B] = fa.collect(f)

override def foldM[G[_], A, B](fa: Stream[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForStream[A: Show]: Show[Stream[A]] =
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances {
override def filter[A](fa: Vector[A])(f: A => Boolean): Vector[A] = fa.filter(f)

override def collect[A, B](fa: Vector[A])(f: PartialFunction[A, B]): Vector[B] = fa.collect(f)

override def foldM[G[_], A, B](fa: Vector[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
Foldable.iteratorFoldM(fa.toIterator, z)(f)
}

implicit def catsStdShowForVector[A:Show]: Show[Vector[A]] =
Expand Down
36 changes: 30 additions & 6 deletions tests/src/test/scala/cats/tests/FoldableTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,35 @@ class FoldableTestsAdditional extends CatsSuite {
larger.value should === (large.map(_ + 1))
}

test("Foldable[List].foldM stack safety") {
def nonzero(acc: Long, x: Long): Option[Long] =
def checkFoldMStackSafety[F[_]](fromRange: Range => F[Int])(implicit F: Foldable[F]): Unit = {
def nonzero(acc: Long, x: Int): Option[Long] =
if (x == 0) None else Some(acc + x)

val n = 100000L
val expected = n*(n+1)/2
val actual = Foldable[List].foldM((1L to n).toList, 0L)(nonzero)
assert(actual.get == expected)
val n = 100000
val expected = n.toLong*(n.toLong+1)/2
val foldMResult = F.foldM(fromRange(1 to n), 0L)(nonzero)
assert(foldMResult.get == expected)
()
}

test("Foldable[List].foldM stack safety") {
checkFoldMStackSafety[List](_.toList)
}

test("Foldable[Stream].foldM stack safety") {
checkFoldMStackSafety[Stream](_.toStream)
}

test("Foldable[Vector].foldM stack safety") {
checkFoldMStackSafety[Vector](_.toVector)
}

test("Foldable[Set].foldM stack safety") {
checkFoldMStackSafety[Set](_.toSet)
}

test("Foldable[Map[String, ?]].foldM stack safety") {
checkFoldMStackSafety[Map[String, ?]](_.map(x => x.toString -> x).toMap)
}

test("Foldable[Stream]") {
Expand All @@ -141,6 +162,9 @@ class FoldableTestsAdditional extends CatsSuite {
// test trampolining
val large = Stream((1 to 10000): _*)
assert(contains(large, 10000).value)

// test laziness of foldM
dangerous.foldM(0)((acc, a) => if (a < 2) Some(acc + a) else None) should === (None)
}
}

Expand Down

0 comments on commit f90245d

Please sign in to comment.