Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise NonEmptyTraverse implementation #3382

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions core/src/main/scala-2.13+/cats/data/NonEmptyLazyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -473,16 +473,14 @@ sealed abstract private[data] class NonEmptyLazyListInstances extends NonEmptyLa

def extract[A](fa: NonEmptyLazyList[A]): A = fa.head

def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyLazyList[A])(f: A => G[B]): G[NonEmptyLazyList[B]] =
Foldable[LazyList]
.reduceRightToOption[A, G[LazyList[B]]](fa.tail)(a => Apply[G].map(f(a))(LazyList.apply(_))) { (a, lglb) =>
Apply[G].map2Eval(f(a), lglb)(_ +: _)
def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyLazyList[A])(f: A => G[B]): G[NonEmptyLazyList[B]] = {
def loop(head: A, tail: LazyList[A]): Eval[G[NonEmptyLazyList[B]]] =
tail.headOption.fold(Eval.now(Apply[G].map(f(head))(NonEmptyLazyList(_)))) { h =>
Apply[G].map2Eval(f(head), Eval.defer(loop(h, tail.tail)))((b, acc) => b +: acc)
}
Comment on lines +478 to 480
Copy link
Contributor

@diesalbla diesalbla Apr 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bit of LISP here. This code may be more readable as follows:

Suggested change
tail.headOption.fold(Eval.now(Apply[G].map(f(head))(NonEmptyLazyList(_)))) { h =>
Apply[G].map2Eval(f(head), Eval.defer(loop(h, tail.tail)))((b, acc) => b +: acc)
}
val fh = f(head)
tail match {
case _ if tail.isEmpty =>
Eval.now(Apply[G].map(fh)(NonEmptyLazyList(_)))
case h #:: ttail =>
val ftail = Eval.defer(loop(h, ttail))
Apply[G].map2Eval(fh, ftail)(_ +: _)
}

The changes are:

  • Extract f(head), which is eagerly evaluated in both branches, as a val.
  • Replace the Option.fold by a pattern-match.
  • Replace the cases of headOption by two cases on the tail lazy list itself. For the first case, there seems not to be any case-object, like Nil, for the empty list. Discussion.
  • In the second case, abbreviate (b, acc) => b +: acc by underscores _ +: _.

Copy link
Contributor Author

@gagandeepkalra gagandeepkalra Jun 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, I see what you mean there.

On my part, usually I try to avoid pattern match (it's slow)

.map {
case None => Apply[G].map(f(fa.head))(h => create(LazyList(h)))
case Some(gtail) => Apply[G].map2(f(fa.head), gtail)((h, t) => create(LazyList(h) ++ t))
}
.value

loop(fa.head, fa.tail).value
}

def reduceLeftTo[A, B](fa: NonEmptyLazyList[A])(f: A => B)(g: (B, A) => B): B = fa.reduceLeftTo(f)(g)

Expand Down
16 changes: 7 additions & 9 deletions core/src/main/scala/cats/data/NonEmptyChain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -420,16 +420,14 @@ sealed abstract private[data] class NonEmptyChainInstances extends NonEmptyChain
new AbstractNonEmptyInstances[Chain, NonEmptyChain] with Align[NonEmptyChain] {
def extract[A](fa: NonEmptyChain[A]): A = fa.head

def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyChain[A])(f: A => G[B]): G[NonEmptyChain[B]] =
Foldable[Chain]
.reduceRightToOption[A, G[Chain[B]]](fa.tail)(a => Apply[G].map(f(a))(Chain.one)) { (a, lglb) =>
Apply[G].map2Eval(f(a), lglb)(_ +: _)
def nonEmptyTraverse[G[_]: Apply, A, B](fa: NonEmptyChain[A])(f: A => G[B]): G[NonEmptyChain[B]] = {
def loop(head: A, tail: Chain[A]): Eval[G[NonEmptyChain[B]]] =
tail.uncons.fold(Eval.now(Apply[G].map(f(head))(NonEmptyChain(_)))) {
case (h, t) => Apply[G].map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => b +: acc)
}
.map {
case None => Apply[G].map(f(fa.head))(NonEmptyChain.one)
case Some(gtail) => Apply[G].map2(f(fa.head), gtail)((h, t) => create(Chain.one(h) ++ t))
}
.value

loop(fa.head, fa.tail).value
}

override def size[A](fa: NonEmptyChain[A]): Long = fa.length

Expand Down
17 changes: 8 additions & 9 deletions core/src/main/scala/cats/data/NonEmptyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -562,16 +562,15 @@ sealed abstract private[data] class NonEmptyListInstances extends NonEmptyListIn

def extract[A](fa: NonEmptyList[A]): A = fa.head

def nonEmptyTraverse[G[_], A, B](nel: NonEmptyList[A])(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyList[B]] =
Foldable[List]
.reduceRightToOption[A, G[List[B]]](nel.tail)(a => G.map(f(a))(_ :: Nil)) { (a, lglb) =>
G.map2Eval(f(a), lglb)(_ :: _)
def nonEmptyTraverse[G[_], A, B](nel: NonEmptyList[A])(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyList[B]] = {
def loop(head: A, tail: List[A]): Eval[G[NonEmptyList[B]]] =
tail match {
case Nil => Eval.now(G.map(f(head))(NonEmptyList(_, Nil)))
case h :: t => G.map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => NonEmptyList(b, acc.toList))
}
.map {
case None => G.map(f(nel.head))(NonEmptyList(_, Nil))
case Some(gtail) => G.map2(f(nel.head), gtail)(NonEmptyList(_, _))
}
.value

loop(nel.head, nel.tail).value
}

override def traverse[G[_], A, B](
fa: NonEmptyList[A]
Expand Down
19 changes: 9 additions & 10 deletions core/src/main/scala/cats/data/NonEmptyMapImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,15 @@ sealed class NonEmptyMapOps[K, A](val value: NonEmptyMap[K, A]) {
* through the running of this function on all the values in this map,
* returning an NonEmptyMap[K, B] in a G context.
*/
def nonEmptyTraverse[G[_], B](f: A => G[B])(implicit G: Apply[G]): G[NonEmptyMap[K, B]] =
reduceRightToOptionWithKey[A, G[SortedMap[K, B]]](tail)({
case (k, a) =>
G.map(f(a))(b => SortedMap.empty[K, B] + ((k, b)))
}) { (t, lglb) =>
G.map2Eval(f(t._2), lglb)((b, bs) => bs + ((t._1, b)))
}.map {
case None => G.map(f(head._2))(a => NonEmptyMapImpl.one(head._1, a))
case Some(gtail) => G.map2(f(head._2), gtail)((a, bs) => NonEmptyMapImpl((head._1, a), bs))
}.value
def nonEmptyTraverse[G[_], B](f: A => G[B])(implicit G: Apply[G]): G[NonEmptyMap[K, B]] = {
def loop(h: (K, A), t: SortedMap[K, A]): Eval[G[NonEmptyMap[K, B]]] =
if (t.isEmpty)
Eval.now(G.map(f(h._2))(b => NonEmptyMap(h._1 -> b, SortedMap.empty[K, B])))
else
G.map2Eval(f(h._2), Eval.defer(loop(t.head, t.tail)))((b, acc) => NonEmptyMap(h._1 -> b, acc.toSortedMap))

loop(head, tail).value
}

/**
* Typesafe stringification method.
Expand Down
20 changes: 9 additions & 11 deletions core/src/main/scala/cats/data/NonEmptyVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,17 +371,15 @@ sealed abstract private[data] class NonEmptyVectorInstances {
def extract[A](fa: NonEmptyVector[A]): A = fa.head

def nonEmptyTraverse[G[_], A, B](
nel: NonEmptyVector[A]
)(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyVector[B]] =
Foldable[Vector]
.reduceRightToOption[A, G[Vector[B]]](nel.tail)(a => G.map(f(a))(_ +: Vector.empty)) { (a, lglb) =>
G.map2Eval(f(a), lglb)(_ +: _)
}
.map {
case None => G.map(f(nel.head))(NonEmptyVector(_, Vector.empty))
case Some(gtail) => G.map2(f(nel.head), gtail)(NonEmptyVector(_, _))
}
.value
nev: NonEmptyVector[A]
)(f: A => G[B])(implicit G: Apply[G]): G[NonEmptyVector[B]] = {
def loop(head: A, tail: Vector[A]): Eval[G[NonEmptyVector[B]]] =
tail.headOption.fold(Eval.now(G.map(f(head))(NonEmptyVector(_, Vector.empty[B]))))(h =>
G.map2Eval(f(head), Eval.defer(loop(h, tail.tail)))((b, acc) => b +: acc)
)

loop(nev.head, nev.tail).value
}

override def traverse[G[_], A, B](
fa: NonEmptyVector[A]
Expand Down
17 changes: 14 additions & 3 deletions core/src/main/scala/cats/data/OneAnd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,20 @@ sealed abstract private[data] class OneAndLowPriority0 extends OneAndLowPriority
F2: Alternative[F]
): NonEmptyTraverse[OneAnd[F, *]] =
new NonEmptyReducible[OneAnd[F, *], F] with NonEmptyTraverse[OneAnd[F, *]] {
def nonEmptyTraverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Apply[G]): G[OneAnd[F, B]] =
fa.map(a => Apply[G].map(f(a))(OneAnd(_, F2.empty[B])))(F)
.reduceLeft(((acc, a) => G.map2(acc, a)((x: OneAnd[F, B], y: OneAnd[F, B]) => x.combine(y))))
def nonEmptyTraverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Apply[G]): G[OneAnd[F, B]] = {
import syntax.foldable._

def loop(head: A, tail: Iterator[A]): Eval[G[OneAnd[F, B]]] =
if (tail.hasNext) {
val h = tail.next()
val t = tail
G.map2Eval(f(head), Eval.defer(loop(h, t)))((b, acc) => OneAnd(b, acc.unwrap))
} else {
Eval.now(G.map(f(head))(OneAnd(_, F2.empty[B])))
}

loop(fa.head, fa.tail.toIterable.iterator).value
}

override def traverse[G[_], A, B](fa: OneAnd[F, A])(f: (A) => G[B])(implicit G: Applicative[G]): G[OneAnd[F, B]] =
G.map2Eval(f(fa.head), Always(F.traverse(fa.tail)(f)))(OneAnd(_, _)).value
Expand Down
21 changes: 20 additions & 1 deletion laws/src/main/scala/cats/laws/ShortCircuitingLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import java.util.concurrent.atomic.AtomicLong
import cats.instances.option._
import cats.syntax.foldable._
import cats.syntax.traverse._
import cats.syntax.nonEmptyTraverse._
import cats.syntax.traverseFilter._
import cats.{Applicative, Foldable, MonoidK, Traverse, TraverseFilter}
import cats.{Applicative, Foldable, MonoidK, NonEmptyTraverse, Traverse, TraverseFilter}

trait ShortCircuitingLaws[F[_]] {

Expand Down Expand Up @@ -46,6 +47,24 @@ trait ShortCircuitingLaws[F[_]] {
f.invocations.get <-> size
}

def nonEmptyTraverseShortCircuits[A](fa: F[A])(implicit F: NonEmptyTraverse[F]): IsEq[Long] = {
val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None)

fa.nonEmptyTraverse(f)
f.invocations.get <-> (maxInvocationsAllowed + 1).min(size)
}

def nonEmptyTraverseWontShortCircuit[A](fa: F[A])(implicit F: NonEmptyTraverse[F]): IsEq[Long] = {
val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None)

fa.nonEmptyTraverse(f)(nonShortCircuitingApplicative)
f.invocations.get <-> size
}

def traverseFilterShortCircuits[A](fa: F[A])(implicit TF: TraverseFilter[F]): IsEq[Long] = {
implicit val F: Traverse[F] = TF.traverse

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cats.laws.discipline

import cats.laws.ShortCircuitingLaws
import cats.{Eq, Foldable, Traverse, TraverseFilter}
import cats.{Eq, Foldable, NonEmptyTraverse, Traverse, TraverseFilter}
import org.scalacheck.Arbitrary
import org.scalacheck.Prop.forAll
import org.typelevel.discipline.Laws
Expand All @@ -25,11 +25,17 @@ trait ShortCircuitingTests[F[_]] extends Laws {
"traverse won't short-circuit if Applicative[G].map2Eval won't" -> forAll(laws.traverseWontShortCircuit[A] _)
)

def traverseFilter[A: Arbitrary](implicit
TF: TraverseFilter[F],
ArbFA: Arbitrary[F[A]],
lEq: Eq[Long]
): RuleSet = {
def nonEmptyTraverse[A: Arbitrary](implicit TF: NonEmptyTraverse[F], ArbFA: Arbitrary[F[A]], lEq: Eq[Long]): RuleSet =
new DefaultRuleSet(
name = "nonEmptyTraverseShortCircuiting",
parent = Some(traverse[A]),
"nonEmptyTraverse short-circuits if Applicative[G].map2Eval shorts" ->
forAll(laws.nonEmptyTraverseShortCircuits[A] _),
"nonEmptyTraverse short-circuits if Applicative[G].map2Eval won't" ->
forAll(laws.nonEmptyTraverseWontShortCircuit[A] _)
)

def traverseFilter[A: Arbitrary](implicit TF: TraverseFilter[F], ArbFA: Arbitrary[F[A]], lEq: Eq[Long]): RuleSet = {
implicit val T: Traverse[F] = TF.traverse
new DefaultRuleSet(
name = "traverseFilterShortCircuiting",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class NonEmptyStreamSuite extends CatsSuite {

checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].foldable[Int])
checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].traverse[Int])
checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].nonEmptyTraverse[Int])

{
// Test functor and subclasses don't have implicit conflicts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class NonEmptyLazyListSuite extends NonEmptyCollectionSuite[LazyList, NonEmptyLa

checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].foldable[Int])
checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].traverse[Int])
checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].nonEmptyTraverse[Int])

test("show") {
Show[NonEmptyLazyList[Int]].show(NonEmptyLazyList(1, 2, 3)) should ===("NonEmptyLazyList(1, ?)")
Expand Down
1 change: 1 addition & 0 deletions tests/src/test/scala/cats/tests/NonEmptyChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class NonEmptyChainSuite extends NonEmptyCollectionSuite[Chain, NonEmptyChain, N

checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].foldable[Int])
checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].traverse[Int])
checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].nonEmptyTraverse[Int])

{
implicit val partialOrder: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int]
Expand Down
1 change: 1 addition & 0 deletions tests/src/test/scala/cats/tests/NonEmptyListSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class NonEmptyListSuite extends NonEmptyCollectionSuite[List, NonEmptyList, NonE

checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].foldable[Int])
checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].traverse[Int])
checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].nonEmptyTraverse[Int])

{
implicit val A: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int]
Expand Down
2 changes: 2 additions & 0 deletions tests/src/test/scala/cats/tests/NonEmptyMapSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class NonEmptyMapSuite extends CatsSuite {
checkAll("NonEmptyMap[String, Int]", AlignTests[NonEmptyMap[String, *]].align[Int, Int, Int, Int])
checkAll("Align[NonEmptyMap]", SerializableTests.serializable(Align[NonEmptyMap[String, *]]))

checkAll("NonEmptyMap[Int, *]", ShortCircuitingTests[NonEmptyMap[Int, *]].nonEmptyTraverse[Int])

test("Show is not empty and is formatted as expected") {
forAll { (nem: NonEmptyMap[String, Int]) =>
nem.show.nonEmpty should ===(true)
Expand Down
1 change: 1 addition & 0 deletions tests/src/test/scala/cats/tests/NonEmptyVectorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class NonEmptyVectorSuite extends NonEmptyCollectionSuite[Vector, NonEmptyVector

checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].foldable[Int])
checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].traverse[Int])
checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].nonEmptyTraverse[Int])

test("size is consistent with toList.size") {
forAll { (nonEmptyVector: NonEmptyVector[Int]) =>
Expand Down