Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize traverse #4498

Merged
merged 31 commits into from May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
858cca4
Optimize List traverse for stack-safe Monads
TimWSpence Aug 18, 2023
847df54
Optimize {List, Vector} traverse for StackSafeMonads
TimWSpence Aug 21, 2023
878c4d0
Optimize {Queue, Map} traverse for StackSafeMonads
TimWSpence Aug 21, 2023
9d86c98
Optimize Seq traverse for StackSafeMonads
TimWSpence Aug 21, 2023
c4bcba2
Optimize ArraySeq traverse for StackSafeMonads
TimWSpence Aug 22, 2023
c4c64ec
Optimize Chain traverse for StackSafeMonads
TimWSpence Aug 22, 2023
1344a70
Add size hint to StackSafeMonad traverse
TimWSpence Aug 22, 2023
b6973dc
Optimize List traverseFilter for StackSafeMonads
TimWSpence Aug 22, 2023
f0648c4
Optimize {Seq, Queue, Vector, ArraySeq} traverseFilter for
TimWSpence Aug 22, 2023
dbccc8e
Optimize Chain filterTraverse for StackSafeMonads
TimWSpence Aug 22, 2023
fbf5ea5
Optimize List traverse_ for StackSafeMonads
TimWSpence Aug 22, 2023
9f60d2b
Optimize Vector traverse_ for StackSafeMonads
TimWSpence Aug 23, 2023
f4fcf57
Optimize {Seq, ArraySeq, Chain, Queue} traverse_ for StackSafeMonads
TimWSpence Aug 23, 2023
d50bf83
Add extra laws tests so we test the multiple branches corresponding to
TimWSpence Aug 23, 2023
5c5049f
Add benchmarks for traverse_ and for Chain
TimWSpence Aug 23, 2023
a1c9ff5
Use .void over .map(_ => ()) to give instances chance to optimize
TimWSpence Aug 25, 2023
8cc742e
Experiment: use immutable data structures in optimized traverse
TimWSpence Nov 8, 2023
e3ae584
Experiment: use immutable List for optimized traverse
TimWSpence Nov 10, 2023
f95b087
Use applicative methods instead of flatMap
TimWSpence Nov 14, 2023
c5d90a6
Experiment: use Chain for aggregating our optimized traverse
TimWSpence Dec 8, 2023
2d5f4d7
Implement optimized traverseFilter in terms of Chain as well
TimWSpence Dec 13, 2023
702ab8b
Vector-based optimized traverse and traverseFilter in the same commit
TimWSpence Dec 19, 2023
b08196e
Retroactive attempt to establish a baseline including the new benchmarks
TimWSpence Dec 22, 2023
92c91e4
Revert "Retroactive attempt to establish a baseline including the new…
TimWSpence Mar 4, 2024
7d615dc
Merge remote-tracking branch 'upstream/main' into optimize-traverse
TimWSpence Mar 4, 2024
eebed5b
Use Applicative#unit to limit allocations
TimWSpence Mar 4, 2024
803107d
More Applicative#unit to limit allocations
TimWSpence Mar 5, 2024
0ee49e6
Remove mutable builder from Map traverse as its unlawful
TimWSpence Mar 7, 2024
6529be7
Use Chain in `traverseDirectly`
valencik Mar 22, 2024
a62ce4c
Use `toList` in seq traverse, not vector
valencik Mar 22, 2024
6cb787d
Merge pull request #2 from valencik/more-chain
TimWSpence Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 87 additions & 0 deletions bench/src/main/scala/cats/bench/TraverseBench.scala
Expand Up @@ -24,15 +24,19 @@ package cats.bench
import cats.{Eval, Traverse, TraverseFilter}
import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, Setup, State}
import org.openjdk.jmh.infra.Blackhole
import cats.data.Chain

@State(Scope.Benchmark)
class TraverseBench {
val listT: Traverse[List] = Traverse[List]
val listTFilter: TraverseFilter[List] = TraverseFilter[List]
val chainTFilter: TraverseFilter[Chain] = TraverseFilter[Chain]

val vectorT: Traverse[Vector] = Traverse[Vector]
val vectorTFilter: TraverseFilter[Vector] = TraverseFilter[Vector]

val chainT: Traverse[Chain] = Traverse[Chain]

// the unit of CPU work per iteration
private[this] val Work: Long = 10

Expand All @@ -43,11 +47,13 @@ class TraverseBench {

var list: List[Int] = _
var vector: Vector[Int] = _
var chain: Chain[Int] = _

@Setup
def setup(): Unit = {
list = 0.until(length).toList
vector = 0.until(length).toVector
chain = Chain.fromSeq(0.until(length))
}

@Benchmark
Expand Down Expand Up @@ -83,6 +89,18 @@ class TraverseBench {
}
}

@Benchmark
def traverse_List(bh: Blackhole) = {
val result = listT.traverse_(list) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverseFilterList(bh: Blackhole) = {
val result = listTFilter.traverseFilter(list) { i =>
Expand Down Expand Up @@ -137,6 +155,18 @@ class TraverseBench {
bh.consume(result.value)
}

@Benchmark
def traverse_Vector(bh: Blackhole) = {
val result = vectorT.traverse_(vector) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverseVectorError(bh: Blackhole) = {
val result = vectorT.traverse(vector) { i =>
Expand Down Expand Up @@ -199,4 +229,61 @@ class TraverseBench {

bh.consume(results)
}

@Benchmark
def traverseChain(bh: Blackhole) = {
val result = chainT.traverse(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverse_Chain(bh: Blackhole) = {
val result = chainT.traverse_(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverseChainError(bh: Blackhole) = {
val result = chainT.traverse(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)

if (i == length * 0.3) {
throw Failure
}

i * 2
}
}

try {
bh.consume(result.value)
} catch {
case Failure => ()
}
}

@Benchmark
def traverseFilterChain(bh: Blackhole) = {
val result = chainTFilter.traverseFilter(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
if (i % 2 == 0) Some(i * 2) else None
}
}

bh.consume(result.value)
}
}
33 changes: 29 additions & 4 deletions core/src/main/scala-2.13+/cats/instances/arraySeq.scala
Expand Up @@ -102,7 +102,24 @@ private[cats] object ArraySeqInstances {
B.combineAll(fa.iterator.map(f))

def traverse[G[_], A, B](fa: ArraySeq[A])(f: A => G[B])(implicit G: Applicative[G]): G[ArraySeq[B]] =
G.map(Chain.traverseViaChain(fa)(f))(_.iterator.to(ArraySeq.untagged))
G match {
case x: StackSafeMonad[G] =>
x.map(Traverse.traverseDirectly(Vector.newBuilder[B])(fa.iterator)(f)(x))(_.to(ArraySeq.untagged))
case _ =>
G.map(Chain.traverseViaChain(fa)(f))(_.iterator.to(ArraySeq.untagged))

}

override def traverse_[G[_], A, B](fa: ArraySeq[A])(f: A => G[B])(implicit G: Applicative[G]): G[Unit] =
G match {
case x: StackSafeMonad[G] => Traverse.traverse_Directly(fa)(f)(x)
case _ =>
foldRight(fa, Always(G.pure(()))) { (a, acc) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
foldRight(fa, Always(G.pure(()))) { (a, acc) =>
foldRight(fa, Eval.now(G.unit)) { (a, acc) =>

G.map2Eval(f(a), acc) { (_, _) =>
()
}
}.value
}

override def mapAccumulate[S, A, B](init: S, fa: ArraySeq[A])(f: (S, A) => (S, B)): (S, ArraySeq[B]) =
StaticMethods.mapAccumulateFromStrictFunctor(init, fa, f)(this)
Expand Down Expand Up @@ -214,9 +231,17 @@ private[cats] object ArraySeqInstances {
def traverseFilter[G[_], A, B](
fa: ArraySeq[A]
)(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[ArraySeq[B]] =
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[B]))) { case (x, xse) =>
G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ +: o))
}.value
G match {
case x: StackSafeMonad[G] =>
x.map(TraverseFilter.traverseFilterDirectly(Vector.newBuilder[B])(fa.iterator)(f)(x))(
_.to(ArraySeq.untagged)
)
case _ =>
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[B]))) { case (x, xse) =>
G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ +: o))
}.value

}

override def filterA[G[_], A](fa: ArraySeq[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[ArraySeq[A]] =
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[A]))) { case (x, xse) =>
Expand Down
34 changes: 34 additions & 0 deletions core/src/main/scala/cats/Traverse.scala
Expand Up @@ -23,6 +23,9 @@ package cats

import cats.data.State
import cats.data.StateT
import cats.kernel.compat.scalaVersionSpecific._
import cats.StackSafeMonad
import scala.collection.mutable

/**
* Traverse, also known as Traversable.
Expand Down Expand Up @@ -284,4 +287,35 @@ object Traverse {
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseOps

private[cats] def traverseDirectly[Coll[x] <: IterableOnce[x], G[_], A, B](
builder: mutable.Builder[B, Coll[B]]
)(fa: IterableOnce[A])(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Coll[B]] = {
val size = fa.knownSize
if (size >= 0) {
builder.sizeHint(size)
}
G.map(fa.iterator.foldLeft(G.pure(builder)) { case (accG, a) =>
G.flatMap(accG) { acc =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment why not G.map2(accG, f(a))(_ :: _)

Since we know G is lazy (a StackSafeMonad has to be I think), I'm not sure the downside here. One answer would be we don't have to call f if we have already failed for a short-circuiting monad, but we are still iterating the whole list, so we are doing O(N) work. Adding the call to allocate the Monad doesn't seem like a big problem, since we have to allocate the function to pass to flatMap in the current case.

By calling map2 we are at least communicating to G what we are doing, and in principle, some monads could optimize this (e.g. a Parser can make a more optimized map2 than flatMap, and it can also be StackSafe since runs lazily only when input is passed to the resulting parser).

G.map(f(a)) { a =>
acc += a
acc
}
}
})(_.result())
}

private[cats] def traverse_Directly[G[_], A, B](
fa: IterableOnce[A]
)(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Unit] = {
val iter = fa.iterator
if (iter.hasNext) {
val first = iter.next()
G.map(iter.foldLeft(f(first)) { case (g, a) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be G.void(.. so that has a chance for an optimized implementation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good catch!

G.flatMap(g) { _ =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not G.productR(g, f(a)) here? Is it to avoid calling f when g may have already failed? I think a comment would be helpful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was no good reason 😅 Thanks, I'll change it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it to avoid calling f when g may have already failed?

Maybe this is something we should consider though?

Edit: ahh, I see your comment on f95b087. Fair enough.

f(a)
}
})(_ => ())
} else G.unit
}

}
19 changes: 19 additions & 0 deletions core/src/main/scala/cats/TraverseFilter.scala
Expand Up @@ -22,8 +22,10 @@
package cats

import cats.data.State
import cats.kernel.compat.scalaVersionSpecific._

import scala.collection.immutable.{IntMap, TreeSet}
import scala.collection.mutable

/**
* `TraverseFilter`, also known as `Witherable`, represents list-like structures
Expand Down Expand Up @@ -203,4 +205,21 @@ object TraverseFilter {
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseFilterOps

private[cats] def traverseFilterDirectly[Coll[x] <: IterableOnce[x], G[_], A, B](
builder: mutable.Builder[B, Coll[B]]
)(fa: IterableOnce[A])(f: A => G[Option[B]])(implicit G: StackSafeMonad[G]): G[Coll[B]] = {
val size = fa.knownSize
if (size >= 0) {
builder.sizeHint(size)
}
G.map(fa.iterator.foldLeft(G.pure(builder)) { case (bldrG, a) =>
G.flatMap(bldrG) { bldr =>
G.map(f(a)) {
case Some(b) => bldr += b
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see why we know this is safe or lawful.

We have a mutable data structure that could potentially be in a multithreaded situation with G.

Also, consider cases like IO where you have a long computation that finally fails, then you recover some part of it to succeed. It feels like this mutable builder could remember things from failed branches.

I think using an immutable builder (like for instance just building up a List, Chain or Vector) would be much easier to verify it is lawful.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a mutable data structure that could potentially be in a multithreaded situation with G.

It can't be: every append to the builder happens in flatMap so by law it must be sequential, not concurrent/parallel.

Also, consider cases like IO where you have a long computation that finally fails, then you recover some part of it to succeed. It feels like this mutable builder could remember things from failed branches.

Hmm. I'm not entirely sure how this applies to traverse. It doesn't have a notion of "recovery". For sure, each individual step that the traverse runs may have a notion of recovery, but that's just it: it will either succeed or it will fail. But there's no way to "recover" an intermediate structure from the traverse itself.

I think using an immutable builder (like for instance just building up a List, Chain or Vector) would be much easier to verify it is lawful.

Won't disagree here. We could do a benchmark to see how much performance we are leaving on the table with that strategy.

case None => bldr
}
}
})(_.result())
}

}
41 changes: 31 additions & 10 deletions core/src/main/scala/cats/data/Chain.scala
Expand Up @@ -1241,11 +1241,27 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
def traverse[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else
traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
G match {
johnynek marked this conversation as resolved.
Show resolved Hide resolved
case x: StackSafeMonad[G] =>
x.map(Traverse.traverseDirectly(List.newBuilder[B])(fa.iterator)(f)(x))(Chain.fromSeq(_))
case _ =>
traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
}

override def traverse_[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Unit] =
G match {
johnynek marked this conversation as resolved.
Show resolved Hide resolved
case x: StackSafeMonad[G] => Traverse.traverse_Directly(fa.iterator)(f)(x)
case _ =>
foldRight(fa, Always(G.pure(()))) { (a, acc) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Eval.now(G.unit) I think is going to have 1 allocation in the common case of override val unit ..., vs Alway(G.pure(())) having 3 allocations: 1: Always, 2: thunk, 3. pure wrapper.

G.map2Eval(f(a), acc) { (_, _) =>
()
}
}.value
}

override def mapAccumulate[S, A, B](init: S, fa: Chain[A])(f: (S, A) => (S, B)): (S, Chain[B]) =
StaticMethods.mapAccumulateFromStrictFunctor(init, fa, f)(this)
Expand Down Expand Up @@ -1354,11 +1370,16 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
def traverseFilter[G[_], A, B](fa: Chain[A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else
traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
G match {
case x: StackSafeMonad[G] =>
G.map(TraverseFilter.traverseFilterDirectly(List.newBuilder[B])(fa.iterator)(f)(x))(Chain.fromSeq(_))
case _ =>
johnynek marked this conversation as resolved.
Show resolved Hide resolved
traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
}

override def filterA[G[_], A](fa: Chain[A])(f: A => G[Boolean])(implicit G: Applicative[G]): G[Chain[A]] =
traverse
Expand Down