Skip to content

Commit

Permalink
feat: Add virtual thread support
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Apr 27, 2024
1 parent 51d6e09 commit a25b239
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 4 deletions.
14 changes: 14 additions & 0 deletions actor/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ pekko {
# Valid options:
# - "default-executor" requires a "default-executor" section
# - "fork-join-executor" requires a "fork-join-executor" section
# - "virtual-thread-executor" requires a "virtual-thread-executor" section
# - "thread-pool-executor" requires a "thread-pool-executor" section
# - "affinity-pool-executor" requires an "affinity-pool-executor" section
# - A FQCN of a class extending ExecutorServiceConfigurator
Expand Down Expand Up @@ -539,6 +540,19 @@ pekko {
allow-core-timeout = on
}

# This will be used if you have set "executor = "virtual-thread-executor"
# This executor will execute the every task on a new virtual thread.
# Underlying thread pool implementation is java.util.concurrent.ForkJoinPool for JDK <= 22
# If the current runtime does not support virtual thread,
# then the executor configured in "fallback" will be used.
virtual-thread-executor {
#Please set the the underlying pool with system properties below:
#jdk.virtualThreadScheduler.parallelism
#jdk.virtualThreadScheduler.maxPoolSize
#jdk.virtualThreadScheduler.minRunnable
#jdk.unparker.maxPoolSize
fallback = "fork-join-executor"
}
# How long time the dispatcher will wait for new actors until it shuts down
shutdown-timeout = 1s

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@ package org.apache.pekko.dispatch
import java.{ util => ju }
import java.util.concurrent._

import scala.annotation.tailrec
import scala.annotation.{ nowarn, tailrec }
import scala.concurrent.{ ExecutionContext, ExecutionContextExecutor }
import scala.concurrent.duration.{ Duration, FiniteDuration }
import scala.util.control.NonFatal

import scala.annotation.nowarn
import com.typesafe.config.Config

import org.apache.pekko
import pekko.actor._
import pekko.annotation.InternalStableApi
Expand All @@ -33,6 +30,8 @@ import pekko.event.EventStream
import pekko.event.Logging.{ Debug, Error, LogEventException }
import pekko.util.{ unused, Index, Unsafe }

import com.typesafe.config.Config

final case class Envelope private (message: Any, sender: ActorRef) {

def copy(message: Any = message, sender: ActorRef = sender) = {
Expand Down Expand Up @@ -367,9 +366,16 @@ abstract class MessageDispatcherConfigurator(_config: Config, val prerequisites:
def dispatcher(): MessageDispatcher

def configureExecutor(): ExecutorServiceConfigurator = {
@tailrec
def configurator(executor: String): ExecutorServiceConfigurator = executor match {
case null | "" | "fork-join-executor" =>
new ForkJoinExecutorConfigurator(config.getConfig("fork-join-executor"), prerequisites)
case "virtual-thread-executor" =>
if (VirtualThreadSupport.isSupported) {
new VirtualThreadExecutorConfigurator(config.getConfig("virtual-thread-executor"), prerequisites)
} else {
configurator(config.getString("virtual-thread-executor.fallback"))
}
case "thread-pool-executor" =>
new ThreadPoolExecutorConfigurator(config.getConfig("thread-pool-executor"), prerequisites)
case "affinity-pool-executor" =>
Expand Down Expand Up @@ -401,6 +407,29 @@ abstract class MessageDispatcherConfigurator(_config: Config, val prerequisites:
}
}

class VirtualThreadExecutorConfigurator(config: Config, prerequisites: DispatcherPrerequisites)
extends ExecutorServiceConfigurator(config, prerequisites) {

override def createExecutorServiceFactory(id: String, threadFactory: ThreadFactory): ExecutorServiceFactory = {
val tf = threadFactory match {
case MonitorableThreadFactory(name, _, contextClassLoader, exceptionHandler, _) =>
new ThreadFactory {
private val vtFactory = VirtualThreadSupport.create(name)
override def newThread(r: Runnable): Thread = {
val vt = vtFactory.newThread(r)
vt.setUncaughtExceptionHandler(exceptionHandler)
contextClassLoader.foreach(vt.setContextClassLoader)
vt
}
}
case _ => VirtualThreadSupport.create(prerequisites.settings.name);
}
new ExecutorServiceFactory {
override def createExecutorService: ExecutorService = new NewVirtualThreadPerTaskExecutor(tf)
}
}
}

class ThreadPoolExecutorConfigurator(config: Config, prerequisites: DispatcherPrerequisites)
extends ExecutorServiceConfigurator(config, prerequisites) {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import java.util
import java.util.Collections
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock

private[dispatch] class NewVirtualThreadPerTaskExecutor(threadFactory: ThreadFactory) extends AbstractExecutorService {
import NewVirtualThreadPerTaskExecutor._

/**
* 0 RUNNING
* 1 SHUTDOWN
* 2 TERMINATED
*/
private val state = new AtomicInteger(RUNNING)
private val virtualThreads = ConcurrentHashMap.newKeySet[Thread]()
private val terminateLock = new ReentrantLock()
private val terminatedCondition = terminateLock.newCondition()

override def shutdown(): Unit = {
shutdown(false)
}

private def shutdown(interrupt: Boolean): Unit = {
if (!isShutdown) {
terminateLock.lock()
try {
if (isTerminated) {
()
} else {
if (state.compareAndSet(RUNNING, SHUTDOWN) && interrupt) {
virtualThreads.forEach(thread => {
if (!thread.isInterrupted) {
thread.interrupt()
}
})
}
tryTerminateAndSignal()
}
} finally {
terminateLock.unlock()
}
}
}

private def tryTerminateAndSignal(): Unit = {
if (isTerminated) {
()
}
terminateLock.lock()
try {
if (isTerminated) {
return
}
if (virtualThreads.isEmpty && state.compareAndSet(SHUTDOWN, TERMINATED)) {
terminatedCondition.signalAll()
}
} finally {
terminateLock.unlock()
}
}

override def shutdownNow(): util.List[Runnable] = {
shutdown(true)
Collections.emptyList()
}

override def isShutdown: Boolean = state.get() >= SHUTDOWN

override def isTerminated: Boolean = state.get() >= TERMINATED

private def isRunning: Boolean = state.get() == RUNNING

override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean = {
if (isTerminated) {
return true
}
terminateLock.lock()
try {
var nanosRemaining = unit.toNanos(timeout)
while (!isTerminated && nanosRemaining > 0) {
nanosRemaining = terminatedCondition.awaitNanos(nanosRemaining)
}
} finally {
terminateLock.unlock()
}
isTerminated
}

// TODO AS only this execute method is been used in `Dispatcher.scala`, so `submit` and other methods is not override.
override def execute(command: Runnable): Unit = {
if (state.get() >= SHUTDOWN) {
throw new RejectedExecutionException("Shutdown")
}
var started = false;
try {
val thread = threadFactory.newThread(Task(this, command))
virtualThreads.add(thread)
if (isRunning) {
thread.start()
started = true
} else {
onThreadExit(thread)
}
} finally {
if (!started) {
throw new RejectedExecutionException("Shutdown")
}
}
}

private def onThreadExit(thread: Thread): Unit = {
virtualThreads.remove(thread)
if (state.get() == SHUTDOWN) {
tryTerminateAndSignal()
}
}
}

private[dispatch] object NewVirtualThreadPerTaskExecutor {
private final val RUNNING = 0
private final val SHUTDOWN = 1
private final val TERMINATED = 2

private case class Task(executor: NewVirtualThreadPerTaskExecutor, runnable: Runnable) extends Runnable {
override def run(): Unit = {
try {
runnable.run()
} finally {
executor.onThreadExit(Thread.currentThread())
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import org.apache.pekko.annotation.InternalApi

import java.lang.invoke.{ MethodHandles, MethodType }
import java.util.concurrent.ThreadFactory

@InternalApi
private[dispatch] object VirtualThreadSupport {

/**
* Returns if the current Runtime supports virtual threads.
*/
def isSupported: Boolean = create("testIsSupported") ne null

/**
* Create a virtual thread factory, returns null when failed.
*/
def create(prefix: String): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val lookup = MethodHandles.lookup
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case _: Throwable => null
}
}

0 comments on commit a25b239

Please sign in to comment.