Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize traverse #4498

Merged
merged 31 commits into from May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
858cca4
Optimize List traverse for stack-safe Monads
TimWSpence Aug 18, 2023
847df54
Optimize {List, Vector} traverse for StackSafeMonads
TimWSpence Aug 21, 2023
878c4d0
Optimize {Queue, Map} traverse for StackSafeMonads
TimWSpence Aug 21, 2023
9d86c98
Optimize Seq traverse for StackSafeMonads
TimWSpence Aug 21, 2023
c4bcba2
Optimize ArraySeq traverse for StackSafeMonads
TimWSpence Aug 22, 2023
c4c64ec
Optimize Chain traverse for StackSafeMonads
TimWSpence Aug 22, 2023
1344a70
Add size hint to StackSafeMonad traverse
TimWSpence Aug 22, 2023
b6973dc
Optimize List traverseFilter for StackSafeMonads
TimWSpence Aug 22, 2023
f0648c4
Optimize {Seq, Queue, Vector, ArraySeq} traverseFilter for
TimWSpence Aug 22, 2023
dbccc8e
Optimize Chain filterTraverse for StackSafeMonads
TimWSpence Aug 22, 2023
fbf5ea5
Optimize List traverse_ for StackSafeMonads
TimWSpence Aug 22, 2023
9f60d2b
Optimize Vector traverse_ for StackSafeMonads
TimWSpence Aug 23, 2023
f4fcf57
Optimize {Seq, ArraySeq, Chain, Queue} traverse_ for StackSafeMonads
TimWSpence Aug 23, 2023
d50bf83
Add extra laws tests so we test the multiple branches corresponding to
TimWSpence Aug 23, 2023
5c5049f
Add benchmarks for traverse_ and for Chain
TimWSpence Aug 23, 2023
a1c9ff5
Use .void over .map(_ => ()) to give instances chance to optimize
TimWSpence Aug 25, 2023
8cc742e
Experiment: use immutable data structures in optimized traverse
TimWSpence Nov 8, 2023
e3ae584
Experiment: use immutable List for optimized traverse
TimWSpence Nov 10, 2023
f95b087
Use applicative methods instead of flatMap
TimWSpence Nov 14, 2023
c5d90a6
Experiment: use Chain for aggregating our optimized traverse
TimWSpence Dec 8, 2023
2d5f4d7
Implement optimized traverseFilter in terms of Chain as well
TimWSpence Dec 13, 2023
702ab8b
Vector-based optimized traverse and traverseFilter in the same commit
TimWSpence Dec 19, 2023
b08196e
Retroactive attempt to establish a baseline including the new benchmarks
TimWSpence Dec 22, 2023
92c91e4
Revert "Retroactive attempt to establish a baseline including the new…
TimWSpence Mar 4, 2024
7d615dc
Merge remote-tracking branch 'upstream/main' into optimize-traverse
TimWSpence Mar 4, 2024
eebed5b
Use Applicative#unit to limit allocations
TimWSpence Mar 4, 2024
803107d
More Applicative#unit to limit allocations
TimWSpence Mar 5, 2024
0ee49e6
Remove mutable builder from Map traverse as its unlawful
TimWSpence Mar 7, 2024
6529be7
Use Chain in `traverseDirectly`
valencik Mar 22, 2024
a62ce4c
Use `toList` in seq traverse, not vector
valencik Mar 22, 2024
6cb787d
Merge pull request #2 from valencik/more-chain
TimWSpence Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 87 additions & 0 deletions bench/src/main/scala/cats/bench/TraverseBench.scala
Expand Up @@ -24,15 +24,19 @@ package cats.bench
import cats.{Eval, Traverse, TraverseFilter}
import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, Setup, State}
import org.openjdk.jmh.infra.Blackhole
import cats.data.Chain

@State(Scope.Benchmark)
class TraverseBench {
val listT: Traverse[List] = Traverse[List]
val listTFilter: TraverseFilter[List] = TraverseFilter[List]
val chainTFilter: TraverseFilter[Chain] = TraverseFilter[Chain]

val vectorT: Traverse[Vector] = Traverse[Vector]
val vectorTFilter: TraverseFilter[Vector] = TraverseFilter[Vector]

val chainT: Traverse[Chain] = Traverse[Chain]

// the unit of CPU work per iteration
private[this] val Work: Long = 10

Expand All @@ -43,11 +47,13 @@ class TraverseBench {

var list: List[Int] = _
var vector: Vector[Int] = _
var chain: Chain[Int] = _

@Setup
def setup(): Unit = {
list = 0.until(length).toList
vector = 0.until(length).toVector
chain = Chain.fromSeq(0.until(length))
}

@Benchmark
Expand Down Expand Up @@ -83,6 +89,18 @@ class TraverseBench {
}
}

@Benchmark
def traverse_List(bh: Blackhole) = {
val result = listT.traverse_(list) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverseFilterList(bh: Blackhole) = {
val result = listTFilter.traverseFilter(list) { i =>
Expand Down Expand Up @@ -137,6 +155,18 @@ class TraverseBench {
bh.consume(result.value)
}

@Benchmark
def traverse_Vector(bh: Blackhole) = {
val result = vectorT.traverse_(vector) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverseVectorError(bh: Blackhole) = {
val result = vectorT.traverse(vector) { i =>
Expand Down Expand Up @@ -199,4 +229,61 @@ class TraverseBench {

bh.consume(results)
}

@Benchmark
def traverseChain(bh: Blackhole) = {
val result = chainT.traverse(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverse_Chain(bh: Blackhole) = {
val result = chainT.traverse_(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
i * 2
}
}

bh.consume(result.value)
}

@Benchmark
def traverseChainError(bh: Blackhole) = {
val result = chainT.traverse(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)

if (i == length * 0.3) {
throw Failure
}

i * 2
}
}

try {
bh.consume(result.value)
} catch {
case Failure => ()
}
}

@Benchmark
def traverseFilterChain(bh: Blackhole) = {
val result = chainTFilter.traverseFilter(chain) { i =>
Eval.later {
Blackhole.consumeCPU(Work)
if (i % 2 == 0) Some(i * 2) else None
}
}

bh.consume(result.value)
}
}
33 changes: 29 additions & 4 deletions core/src/main/scala-2.13+/cats/instances/arraySeq.scala
Expand Up @@ -102,7 +102,24 @@ private[cats] object ArraySeqInstances {
B.combineAll(fa.iterator.map(f))

def traverse[G[_], A, B](fa: ArraySeq[A])(f: A => G[B])(implicit G: Applicative[G]): G[ArraySeq[B]] =
G.map(Chain.traverseViaChain(fa)(f))(_.iterator.to(ArraySeq.untagged))
G match {
case x: StackSafeMonad[G] =>
x.map(Traverse.traverseDirectly(fa.iterator)(f)(x))(_.iterator.to(ArraySeq.untagged))
case _ =>
G.map(Chain.traverseViaChain(fa)(f))(_.iterator.to(ArraySeq.untagged))

}

override def traverse_[G[_], A, B](fa: ArraySeq[A])(f: A => G[B])(implicit G: Applicative[G]): G[Unit] =
G match {
case x: StackSafeMonad[G] => Traverse.traverse_Directly(fa)(f)(x)
case _ =>
foldRight(fa, Eval.now(G.unit)) { (a, acc) =>
G.map2Eval(f(a), acc) { (_, _) =>
()
}
}.value
}

override def mapAccumulate[S, A, B](init: S, fa: ArraySeq[A])(f: (S, A) => (S, B)): (S, ArraySeq[B]) =
StaticMethods.mapAccumulateFromStrictFunctor(init, fa, f)(this)
Expand Down Expand Up @@ -214,9 +231,17 @@ private[cats] object ArraySeqInstances {
def traverseFilter[G[_], A, B](
fa: ArraySeq[A]
)(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[ArraySeq[B]] =
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[B]))) { case (x, xse) =>
G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ +: o))
}.value
G match {
case x: StackSafeMonad[G] =>
x.map(TraverseFilter.traverseFilterDirectly(fa.iterator)(f)(x))(
_.iterator.to(ArraySeq.untagged)
)
case _ =>
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[B]))) { case (x, xse) =>
G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ +: o))
}.value

}

override def filterA[G[_], A](fa: ArraySeq[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[ArraySeq[A]] =
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[A]))) { case (x, xse) =>
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/scala/cats/Traverse.scala
Expand Up @@ -21,8 +21,10 @@

package cats

import cats.data.Chain
import cats.data.State
import cats.data.StateT
import cats.kernel.compat.scalaVersionSpecific._

/**
* Traverse, also known as Traversable.
Expand Down Expand Up @@ -284,4 +286,26 @@ object Traverse {
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseOps

private[cats] def traverseDirectly[G[_], A, B](
fa: IterableOnce[A]
)(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Chain[B]] = {
fa.iterator.foldLeft(G.pure(Chain.empty[B])) { case (accG, a) =>
G.map2(accG, f(a)) { case (acc, x) =>
acc :+ x
}
}
}

private[cats] def traverse_Directly[G[_], A, B](
fa: IterableOnce[A]
)(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Unit] = {
val iter = fa.iterator
if (iter.hasNext) {
val first = iter.next()
G.void(iter.foldLeft(f(first)) { case (g, a) =>
G.productR(g)(f(a))
})
} else G.unit
}

}
14 changes: 13 additions & 1 deletion core/src/main/scala/cats/TraverseFilter.scala
Expand Up @@ -21,7 +21,8 @@

package cats

import cats.data.State
import cats.data.{Chain, State}
import cats.kernel.compat.scalaVersionSpecific._

import scala.collection.immutable.{IntMap, TreeSet}

Expand Down Expand Up @@ -203,4 +204,15 @@ object TraverseFilter {
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseFilterOps

private[cats] def traverseFilterDirectly[G[_], A, B](
fa: IterableOnce[A]
)(f: A => G[Option[B]])(implicit G: StackSafeMonad[G]): G[Chain[B]] = {
fa.iterator.foldLeft(G.pure(Chain.empty[B])) { case (bldrG, a) =>
G.map2(bldrG, f(a)) {
case (acc, Some(b)) => acc :+ b
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have you benchmarked :+ on Vector vs Chain? I think Chain may be faster since it doesn't have to copy each internal block of items.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's next on my todo list

case (acc, None) => acc
}
}
}

}
43 changes: 32 additions & 11 deletions core/src/main/scala/cats/data/Chain.scala
Expand Up @@ -1243,11 +1243,27 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
def traverse[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else
traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
G match {
johnynek marked this conversation as resolved.
Show resolved Hide resolved
case x: StackSafeMonad[G] =>
Traverse.traverseDirectly(fa.iterator)(f)(x)
case _ =>
traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
}

override def traverse_[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Unit] =
G match {
johnynek marked this conversation as resolved.
Show resolved Hide resolved
case x: StackSafeMonad[G] => Traverse.traverse_Directly(fa.iterator)(f)(x)
case _ =>
foldRight(fa, Eval.now(G.unit)) { (a, acc) =>
G.map2Eval(f(a), acc) { (_, _) =>
()
}
}.value
}

override def mapAccumulate[S, A, B](init: S, fa: Chain[A])(f: (S, A) => (S, B)): (S, Chain[B]) =
StaticMethods.mapAccumulateFromStrictFunctor(init, fa, f)(this)
Expand Down Expand Up @@ -1341,7 +1357,7 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
}

implicit val catsDataTraverseFilterForChain: TraverseFilter[Chain] = new TraverseFilter[Chain] {
def traverse: Traverse[Chain] = Chain.catsDataInstancesForChain
def traverse: Traverse[Chain] with Alternative[Chain] = Chain.catsDataInstancesForChain

override def filter[A](fa: Chain[A])(f: A => Boolean): Chain[A] = fa.filter(f)

Expand All @@ -1356,11 +1372,16 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
def traverseFilter[G[_], A, B](fa: Chain[A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else
traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
G match {
case x: StackSafeMonad[G] =>
TraverseFilter.traverseFilterDirectly(fa.iterator)(f)(x)
case _ =>
johnynek marked this conversation as resolved.
Show resolved Hide resolved
traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
KernelStaticMethods.wrapMutableIndexedSeq(as)
}(f)
}

override def filterA[G[_], A](fa: Chain[A])(f: A => G[Boolean])(implicit G: Applicative[G]): G[Chain[A]] =
traverse
Expand Down