Skip to content

Commit

Permalink
Merge pull request #807 from rpiotrow/bugfix/806-ThreadLocalRandomGen…
Browse files Browse the repository at this point in the history
…erator-not-serializable

Make thread local field transient in `ThreadLocalRandomGenerator`
  • Loading branch information
dlwh committed Apr 19, 2021
2 parents b965a1c + bf0ef50 commit 31e9a79
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.apache.commons.math3.random.RandomGenerator
**/
@SerialVersionUID(1L)
class ThreadLocalRandomGenerator(genThunk: => RandomGenerator) extends RandomGenerator with Serializable {
private val genTL = new ThreadLocal[RandomGenerator] {
@transient private lazy val genTL = new ThreadLocal[RandomGenerator] {
override def initialValue(): RandomGenerator = genThunk
}
def nextBytes(bytes: Array[Byte]) = genTL.get().nextBytes(bytes)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package breeze.stats.distributions

import org.apache.commons.math3.random.MersenneTwister
import org.scalatest.{FunSuite, Matchers}

import java.io._

class ThreadLocalRandomGeneratorTest extends FunSuite with Matchers {
test("ThreadLocalRandomGeneratorTest should be serializable") {
val generator = new ThreadLocalRandomGenerator(new MersenneTwister())
serialize(generator)
}

test("ThreadLocalRandomGeneratorTest should be serializable after usage") {
val generator = new ThreadLocalRandomGenerator(new MersenneTwister())
generator.nextInt()
serialize(generator)
}

test("ThreadLocalRandomGeneratorTest should be deserializable") {
val generator = new ThreadLocalRandomGenerator(new MersenneTwister())
val i1 = generator.nextInt()
val bytes = serialize(generator)
val deserialized = deserialize(bytes)
val i2 = deserialized.nextInt()

i1 should not be i2
}

private def serialize(generator: ThreadLocalRandomGenerator): Array[Byte] = {
val outputStream = new ByteArrayOutputStream(512)
val out = new ObjectOutputStream(outputStream)
try {
out.writeObject(generator)
outputStream.toByteArray
} catch {
case _: IOException => fail("cannot serialize")
} finally {
if (out != null) out.close()
}
}

private def deserialize(bytes: Array[Byte]): ThreadLocalRandomGenerator = {
val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
try {
in.readObject().asInstanceOf[ThreadLocalRandomGenerator]
} catch {
case _: IOException => fail("cannot deserialize")
} finally {
if (in != null) in.close()
}
}
}

0 comments on commit 31e9a79

Please sign in to comment.