Skip to content

Commit

Permalink
Add support for aborting RPC with error codes
Browse files Browse the repository at this point in the history
  • Loading branch information
wasdennnoch committed Jan 3, 2024
1 parent b7954a3 commit 6b04045
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 13 deletions.
33 changes: 22 additions & 11 deletions latte/src/main/java/gg/beemo/latte/broker/BrokerSubclients.kt
Expand Up @@ -91,11 +91,12 @@ class ProducerSubclient<T>(
targetInstances = instances,
),
)
return internalSend(msg)
@Suppress("UNCHECKED_CAST")
return internalSend(msg as AbstractBrokerMessage<T?>)
}

internal suspend fun internalSend(msg: AbstractBrokerMessage<T>): MessageId {
if (!isNullable) {
internal suspend fun internalSend(msg: AbstractBrokerMessage<T?>, bypassNullCheck: Boolean = false): MessageId {
if (!bypassNullCheck && !isNullable) {
requireNotNull(msg.value) {
"Cannot send null message for non-nullable type with key '$key' in topic '$topic'"
}
Expand Down Expand Up @@ -146,7 +147,8 @@ class ConsumerSubclient<T>(
headers: BrokerMessageHeaders,
) = coroutineScope {
val data = parseIncoming(value)
if (!isNullable) {
// Disable nullability enforcement for RPC exceptions. The caller has to deal with the unsafe typing now.
if (!isNullable && (headers !is RpcMessageHeaders || !headers.isException)) {
checkNotNull(data) {
"Received null message for non-nullable type with key '$key' in topic '$topic'"
}
Expand Down Expand Up @@ -199,7 +201,7 @@ class RpcClient<RequestT, ResponseT>(
responseIsNullable,
)

suspend fun sendResponse(response: ResponseT, status: RpcStatus, isUpdate: Boolean) {
suspend fun sendResponse(response: ResponseT?, status: RpcStatus, isException: Boolean, isUpdate: Boolean) {
val responseMsg = RpcResponseMessage(
client.toResponseTopic(topic),
client.toResponseKey(key),
Expand All @@ -211,22 +213,27 @@ class RpcClient<RequestT, ResponseT>(
targetInstances = setOf(msg.headers.sourceInstance),
inReplyTo = msg.headers.messageId,
status,
isException,
isUpdate,
),
)
responseProducer.internalSend(responseMsg)
responseProducer.internalSend(responseMsg, bypassNullCheck = isException)
}

val rpcMessage = msg.toRpcRequestMessage<ResponseT> { data, status ->
sendResponse(data, status, true)
sendResponse(data, status, isException = false, isUpdate = true)
}
val (status, response) = try {
callback(rpcMessage)
try {
val (status, response) = callback(rpcMessage)
sendResponse(response, status, false, isUpdate = false)
} catch (_: IgnoreRpcRequest) {
return@consumer
} catch (ex: RpcException) {
sendResponse(null, ex.status, true, isUpdate = false)
return@consumer
} finally {
responseProducer.destroy()
}
sendResponse(response, status, false)
responseProducer.destroy()
}

suspend fun call(
Expand Down Expand Up @@ -267,6 +274,10 @@ class RpcClient<RequestT, ResponseT>(
if (msg.headers.inReplyTo != messageId.get()) {
return@consumer
}
if (msg.headers.isException) {
close(RpcException(msg.headers.status))
return@consumer
}
send(msg)
timeoutLatch?.countDown()
val count = responseCounter.incrementAndGet()
Expand Down
3 changes: 2 additions & 1 deletion latte/src/main/java/gg/beemo/latte/broker/Exceptions.kt
Expand Up @@ -2,4 +2,5 @@ package gg.beemo.latte.broker

sealed class BrokerException(message: String?) : Exception(message)
class RpcRequestTimeout(message: String) : BrokerException(message)
class IgnoreRpcRequest : BrokerException(null)
class IgnoreRpcRequest : BrokerException("Ignoring RPC request")
class RpcException(val status: RpcStatus) : BrokerException("RPC failed with status $status")
10 changes: 9 additions & 1 deletion latte/src/main/java/gg/beemo/latte/broker/RpcMessageHeaders.kt
Expand Up @@ -8,6 +8,9 @@ class RpcMessageHeaders(headers: Map<String, String>) : BrokerMessageHeaders(hea
val status: RpcStatus by lazy {
RpcStatus(headers.getOrDefault(HEADER_STATUS, "999_999").toInt())
}
val isException: Boolean by lazy {
headers.getOrDefault(HEADER_IS_EXCEPTION, "false").toBoolean()
}
val isUpdate: Boolean by lazy {
headers.getOrDefault(HEADER_IS_UPDATE, "false").toBoolean()
}
Expand All @@ -21,6 +24,7 @@ class RpcMessageHeaders(headers: Map<String, String>) : BrokerMessageHeaders(hea
targetInstances: Set<String>,
inReplyTo: MessageId,
status: RpcStatus,
isException: Boolean,
isUpdate: Boolean,
) : this(
createHeadersMap(
Expand All @@ -32,6 +36,7 @@ class RpcMessageHeaders(headers: Map<String, String>) : BrokerMessageHeaders(hea
extra = mapOf(
HEADER_IN_REPLY_TO to inReplyTo,
HEADER_STATUS to status.code.toString(),
HEADER_IS_EXCEPTION to isException.toString(),
HEADER_IS_UPDATE to isUpdate.toString(),
)
)
Expand All @@ -43,6 +48,7 @@ class RpcMessageHeaders(headers: Map<String, String>) : BrokerMessageHeaders(hea
targetInstances: Set<String>,
inReplyTo: MessageId,
status: RpcStatus,
isException: Boolean,
isUpdate: Boolean,
) : this(
connection.serviceName,
Expand All @@ -51,14 +57,16 @@ class RpcMessageHeaders(headers: Map<String, String>) : BrokerMessageHeaders(hea
targetInstances,
inReplyTo,
status,
isException,
isUpdate,
)

companion object {

private const val HEADER_IN_REPLY_TO = "rpc-in-reply-to"
private const val HEADER_IS_UPDATE = "rpc-is-update"
private const val HEADER_STATUS = "rpc-response-status"
private const val HEADER_IS_EXCEPTION = "rpc-is-exception"
private const val HEADER_IS_UPDATE = "rpc-is-update"

}

Expand Down
Expand Up @@ -3,6 +3,7 @@ package gg.beemo.latte.broker
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows

class BrokerClientTest {

Expand All @@ -22,6 +23,14 @@ class BrokerClientTest {
Assertions.assertEquals(RpcStatus(1337), response.headers.status)
}

@Test
fun `test exception RPC`() = withTestClient { client ->
val exception = assertThrows<RpcException> {
client.exceptionRpc.call(null)
}
Assertions.assertEquals(RpcStatus(1337), exception.status)
}

@Test
fun `test safe Long serializer`() = withTestClient { client ->
client.safeLongProducer.send(1337)
Expand Down
Expand Up @@ -36,6 +36,14 @@ class TestBrokerClient(
return@rpc RpcStatus(1337) to null
}

val exceptionRpc = rpc<Unit?, Unit?>(
topic = "exception",
key = "exception",
) {
log.info("exceptionRpc received request: ${it.value}")
throw RpcException(RpcStatus(1337))
}

val safeLongProducer = producer<Long>(
topic = "long",
key = "long",
Expand Down

0 comments on commit 6b04045

Please sign in to comment.