Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add virtual thread support #1299

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions actor/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ pekko {
# Valid options:
# - "default-executor" requires a "default-executor" section
# - "fork-join-executor" requires a "fork-join-executor" section
# - "virtual-thread-executor" requires a "virtual-thread-executor" section
# - "thread-pool-executor" requires a "thread-pool-executor" section
# - "affinity-pool-executor" requires an "affinity-pool-executor" section
# - A FQCN of a class extending ExecutorServiceConfigurator
Expand Down Expand Up @@ -539,6 +540,19 @@ pekko {
allow-core-timeout = on
}

# This will be used if you have set "executor = "virtual-thread-executor"
# This executor will execute the every task on a new virtual thread.
# Underlying thread pool implementation is java.util.concurrent.ForkJoinPool for JDK <= 22
# If the current runtime does not support virtual thread,
# then the executor configured in "fallback" will be used.
virtual-thread-executor {
#Please set the the underlying pool with system properties below:
#jdk.virtualThreadScheduler.parallelism
#jdk.virtualThreadScheduler.maxPoolSize
#jdk.virtualThreadScheduler.minRunnable
#jdk.unparker.maxPoolSize
fallback = "fork-join-executor"
}
# How long time the dispatcher will wait for new actors until it shuts down
shutdown-timeout = 1s

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@ package org.apache.pekko.dispatch
import java.{ util => ju }
import java.util.concurrent._

import scala.annotation.tailrec
import scala.annotation.{ nowarn, tailrec }
import scala.concurrent.{ ExecutionContext, ExecutionContextExecutor }
import scala.concurrent.duration.{ Duration, FiniteDuration }
import scala.util.control.NonFatal

import scala.annotation.nowarn
import com.typesafe.config.Config

import org.apache.pekko
import pekko.actor._
import pekko.annotation.InternalStableApi
Expand All @@ -33,6 +30,8 @@ import pekko.event.EventStream
import pekko.event.Logging.{ Debug, Error, LogEventException }
import pekko.util.{ unused, Index, Unsafe }

import com.typesafe.config.Config

final case class Envelope private (message: Any, sender: ActorRef) {

def copy(message: Any = message, sender: ActorRef = sender) = {
Expand Down Expand Up @@ -367,9 +366,16 @@ abstract class MessageDispatcherConfigurator(_config: Config, val prerequisites:
def dispatcher(): MessageDispatcher

def configureExecutor(): ExecutorServiceConfigurator = {
@tailrec
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not look like a recursion func.

def configurator(executor: String): ExecutorServiceConfigurator = executor match {
case null | "" | "fork-join-executor" =>
new ForkJoinExecutorConfigurator(config.getConfig("fork-join-executor"), prerequisites)
case "virtual-thread-executor" =>
if (VirtualThreadSupport.isSupported) {
new VirtualThreadExecutorConfigurator(config.getConfig("virtual-thread-executor"), prerequisites)
} else {
configurator(config.getString("virtual-thread-executor.fallback"))
}
case "thread-pool-executor" =>
new ThreadPoolExecutorConfigurator(config.getConfig("thread-pool-executor"), prerequisites)
case "affinity-pool-executor" =>
Expand Down Expand Up @@ -401,6 +407,29 @@ abstract class MessageDispatcherConfigurator(_config: Config, val prerequisites:
}
}

class VirtualThreadExecutorConfigurator(config: Config, prerequisites: DispatcherPrerequisites)
extends ExecutorServiceConfigurator(config, prerequisites) {

override def createExecutorServiceFactory(id: String, threadFactory: ThreadFactory): ExecutorServiceFactory = {
val tf = threadFactory match {
case MonitorableThreadFactory(name, _, contextClassLoader, exceptionHandler, _) =>
new ThreadFactory {
private val vtFactory = VirtualThreadSupport.create(name)
override def newThread(r: Runnable): Thread = {
val vt = vtFactory.newThread(r)
vt.setUncaughtExceptionHandler(exceptionHandler)
contextClassLoader.foreach(vt.setContextClassLoader)
vt
}
}
case _ => VirtualThreadSupport.create(prerequisites.settings.name);
}
new ExecutorServiceFactory {
override def createExecutorService: ExecutorService = new NewVirtualThreadPerTaskExecutor(tf)
}
}
}

class ThreadPoolExecutorConfigurator(config: Config, prerequisites: DispatcherPrerequisites)
extends ExecutorServiceConfigurator(config, prerequisites) {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import java.util
import java.util.Collections
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock

private[dispatch] class NewVirtualThreadPerTaskExecutor(threadFactory: ThreadFactory) extends AbstractExecutorService {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will use method handle to reflection create this instead of this class

import NewVirtualThreadPerTaskExecutor._

/**
* 0 RUNNING
* 1 SHUTDOWN
* 2 TERMINATED
*/
private val state = new AtomicInteger(RUNNING)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will replace with a single field and a fieldUpdater later

private val virtualThreads = ConcurrentHashMap.newKeySet[Thread]()
private val terminateLock = new ReentrantLock()
private val terminatedCondition = terminateLock.newCondition()

override def shutdown(): Unit = {
shutdown(false)
}

private def shutdown(interrupt: Boolean): Unit = {
if (!isShutdown) {
terminateLock.lock()
try {
if (isTerminated) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reduce this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fee free to fork and change on it:( I'm a little busy at work. if we remove this, we must ensure this method is only be called when we hold the lock.

()
} else {
if (state.compareAndSet(RUNNING, SHUTDOWN) && interrupt) {
virtualThreads.forEach(thread => {
if (!thread.isInterrupted) {
thread.interrupt()
}
})
}
tryTerminateAndSignal()
}
} finally {
terminateLock.unlock()
}
}
}

private def tryTerminateAndSignal(): Unit = {
if (isTerminated) {
()
}
terminateLock.lock()
try {
if (isTerminated) {
return
}
if (virtualThreads.isEmpty && state.compareAndSet(SHUTDOWN, TERMINATED)) {
terminatedCondition.signalAll()
}
} finally {
terminateLock.unlock()
}
}

override def shutdownNow(): util.List[Runnable] = {
shutdown(true)
Collections.emptyList()
}

override def isShutdown: Boolean = state.get() >= SHUTDOWN

override def isTerminated: Boolean = state.get() >= TERMINATED

private def isRunning: Boolean = state.get() == RUNNING

override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean = {
if (isTerminated) {
return true
}
terminateLock.lock()
try {
var nanosRemaining = unit.toNanos(timeout)
while (!isTerminated && nanosRemaining > 0) {
nanosRemaining = terminatedCondition.awaitNanos(nanosRemaining)
}
} finally {
terminateLock.unlock()
}
isTerminated
}

// TODO AS only this execute method is been used in `Dispatcher.scala`, so `submit` and other methods is not override.
override def execute(command: Runnable): Unit = {
if (state.get() >= SHUTDOWN) {
throw new RejectedExecutionException("Shutdown")
}
var started = false;
try {
val thread = threadFactory.newThread(Task(this, command))
virtualThreads.add(thread)
Copy link

@alexandru alexandru Apr 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic here is obviously thread-unsafe because stage.get() >= SHUTDOWN can be reordered with virtualThreads.add(thread).

In other words, you can start and add new threads that won't be waited on while the thread-pool is shutting down.

I don't know if this is a concern we need to have, or how the other implementations do it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, would you like to continue the work? I will continue this when find time.

I have a isStarted check beblow, that would help?
Another issue is, only execute method is been used, not sure what the origin design requires an ExecutionService.

if (isRunning) {
thread.start()
started = true
} else {
onThreadExit(thread)
}
} finally {
if (!started) {
throw new RejectedExecutionException("Shutdown")
}
}
}

private def onThreadExit(thread: Thread): Unit = {
virtualThreads.remove(thread)
if (state.get() == SHUTDOWN) {
tryTerminateAndSignal()
}
}
}

private[dispatch] object NewVirtualThreadPerTaskExecutor {
private final val RUNNING = 0
private final val SHUTDOWN = 1
private final val TERMINATED = 2

private case class Task(executor: NewVirtualThreadPerTaskExecutor, runnable: Runnable) extends Runnable {
override def run(): Unit = {
try {
runnable.run()
} finally {
executor.onThreadExit(Thread.currentThread())
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import org.apache.pekko.annotation.InternalApi

import java.lang.invoke.{ MethodHandles, MethodType }
import java.util.concurrent.ThreadFactory

@InternalApi
private[dispatch] object VirtualThreadSupport {

/**
* Returns if the current Runtime supports virtual threads.
*/
def isSupported: Boolean = create("testIsSupported") ne null
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better with a lazy val, will update later.


/**
* Create a virtual thread factory, returns null when failed.
*/
def create(prefix: String): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val lookup = MethodHandles.lookup
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case _: Throwable => null
}
}