Skip to content

Commit

Permalink
Make reader and writer listen for channel cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
e5l committed May 7, 2024
1 parent 7435299 commit ec01ba7
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 2 deletions.
Expand Up @@ -14,7 +14,6 @@ import io.ktor.client.tests.utils.*
import io.ktor.http.*
import io.ktor.test.dispatcher.*
import io.ktor.util.*
import io.ktor.utils.io.errors.*
import kotlinx.coroutines.*
import kotlin.reflect.*
import kotlin.test.*
Expand Down
Expand Up @@ -282,6 +282,7 @@ public fun CoroutineScope.reader(
block: suspend ReaderScope.() -> Unit
): ReaderJob = reader(coroutineContext, ByteChannel(), block)

@OptIn(InternalCoroutinesApi::class)
public fun CoroutineScope.reader(
coroutineContext: CoroutineContext,
channel: ByteChannel,
Expand All @@ -290,6 +291,10 @@ public fun CoroutineScope.reader(
val job = launch(coroutineContext) {
try {
block(ReaderScope(channel))

if (this.coroutineContext.job.isCancelled) {
channel.cancel(this.coroutineContext.job.getCancellationException())
}
} catch (cause: Throwable) {
channel.close(cause)
} finally {
Expand Down
Expand Up @@ -129,6 +129,7 @@ public fun CoroutineScope.writer(
block: suspend WriterScope.() -> Unit
): WriterJob = writer(coroutineContext, ByteChannel(), block)

@OptIn(InternalCoroutinesApi::class)
public fun CoroutineScope.writer(
coroutineContext: CoroutineContext = EmptyCoroutineContext,
channel: ByteChannel,
Expand All @@ -137,6 +138,10 @@ public fun CoroutineScope.writer(
val job = launch(coroutineContext) {
try {
block(WriterScope(channel))

if (this.coroutineContext.job.isCancelled) {
channel.cancel(this.coroutineContext.job.getCancellationException())
}
} catch (cause: Throwable) {
channel.cancel(cause)
} finally {
Expand Down
2 changes: 1 addition & 1 deletion ktor-io/jvm/src/io/ktor/utils/io/jvm/javaio/Reading.kt
Expand Up @@ -47,7 +47,7 @@ internal class RawSourceChannel(
override val isClosedForRead: Boolean
get() = closedToken != null && buffer.exhausted()

val job = Job()
val job = Job(parent[Job])
val coroutineContext = parent + job + CoroutineName("RawSourceChannel")

@InternalAPI
Expand Down

0 comments on commit ec01ba7

Please sign in to comment.