Skip to content

Commit

Permalink
nonblock SSL now passes all tests
Browse files Browse the repository at this point in the history
The following public endpoints for SSLHandlerG1 now correctly handle the non-block operations
- open
- write
- receive
- ack
- flush
- close
I added try..finally blocks to ensure processed messages are fired even if a subsequent message caused the SSL to fail.
  • Loading branch information
jon-valliere committed Feb 20, 2024
1 parent 0cbbc8b commit ec93bb7
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 120 deletions.
220 changes: 108 additions & 112 deletions mina-core/src/main/java/org/apache/mina/filter/ssl/SSLHandlerG1.java
Expand Up @@ -68,7 +68,7 @@
/**
* Enable asynchronous tasks
*/
static protected final boolean ENABLE_ASYNC_TASKS = false;
static protected final boolean ENABLE_ASYNC_TASKS = true;

/**
* Indicates whether the first handshake was completed
Expand Down Expand Up @@ -142,25 +142,21 @@ public boolean isConnected() {
*/
@Override
public void open(NextFilter next) throws SSLException {
synchronized (this) {
if (mHandshakeStarted == false) {
mHandshakeStarted = true;
if (mEngine.getUseClientMode()) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} open() - begin handshaking", this);
try {
synchronized (this) {
if (mHandshakeStarted == false) {
mHandshakeStarted = true;
if (mEngine.getUseClientMode()) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} open() - begin handshaking", this);
}
mEngine.beginHandshake();
write_handshake(next);
}
mEngine.beginHandshake();
write_handshake(next);
}
}
}
synchronized (mWriteQueue) {
EncryptedWriteRequest x;
while((x = mWriteQueue.poll()) != null) {
next.filterWrite(mSession, x);
}
}
synchronized (this) {
} finally {
forward_writes(next);
throw_pending_error(next);
}
}
Expand All @@ -170,20 +166,11 @@ public void open(NextFilter next) throws SSLException {
*/
@Override
public void receive(NextFilter next, IoBuffer message) throws SSLException {
receive_start(next, message);
synchronized (mReceiveQueue) {
IoBuffer x;
while((x = mReceiveQueue.poll()) != null) {
next.messageReceived(mSession, x);
}
}
synchronized (mWriteQueue) {
EncryptedWriteRequest x;
while((x = mWriteQueue.poll()) != null) {
next.filterWrite(mSession, x);
}
}
synchronized (this) {
try {
receive_start(next, message);
} finally {
forward_received(next);
forward_writes(next);
throw_pending_error(next);
}
}
Expand Down Expand Up @@ -267,7 +254,7 @@ protected void receive_loop(NextFilter next, IoBuffer message) throws SSLExcepti
LOGGER.debug("{} receive_loop() - result {}", toString(), dest);
}

mReceiveQueue.push(dest);
mReceiveQueue.add(dest);
}

switch (result.getHandshakeStatus()) {
Expand Down Expand Up @@ -323,25 +310,25 @@ protected void receive_loop(NextFilter next, IoBuffer message) throws SSLExcepti
*/
@Override
public void ack(NextFilter next, WriteRequest request) throws SSLException {
synchronized (this) {
if (mAckQueue.remove(request)) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} ack() - {}", toString(), request);
}
try {
synchronized (this) {
if (mAckQueue.remove(request)) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} ack() - accepted {}", toString(), request);
}

if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} ack() - checking to see if any messages can be flushed", toString(), request);
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} ack() - checking to see if any messages can be flushed", toString(), request);
}
flush_start(next);
} else {
if(LOGGER.isWarnEnabled()) {
LOGGER.warn("{} ack() - unknown message {}", toString(), request);
}
}
flush_start(next);
}
}
synchronized (mWriteQueue) {
EncryptedWriteRequest x;
while((x = mWriteQueue.poll()) != null) {
next.filterWrite(mSession, x);
}
}
synchronized (this) {
} finally {
forward_writes(next);
throw_pending_error(next);
}
}
Expand All @@ -351,41 +338,37 @@ public void ack(NextFilter next, WriteRequest request) throws SSLException {
*/
@Override
public void write(NextFilter next, WriteRequest request) throws SSLException, WriteRejectedException {
synchronized (this) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write() - source {}", toString(), request);
}
if (mOutboundClosing) {
throw new WriteRejectedException(request, "closing");
}
if (mEncodeQueue.isEmpty()) {
if (write_loop(next, request) == false) {
try {
synchronized (this) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write() - source {}", toString(), request);
}
if (mOutboundClosing) {
throw new WriteRejectedException(request, "closing");
}
if (mEncodeQueue.isEmpty()) {
if (write_loop(next, request) == false) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write() - unable to write right now, saving request for later", toString(),
request);
}
if (mEncodeQueue.size() == MAX_QUEUED_MESSAGES) {
throw new BufferOverflowException();
}
mEncodeQueue.add(request);
}
} else {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write() - unable to write right now, saving request for later", toString(),
request);
LOGGER.debug("{} write() - unable to write right now, saving request for later", toString(), request);
}
if (mEncodeQueue.size() == MAX_QUEUED_MESSAGES) {
throw new BufferOverflowException();
}
mEncodeQueue.add(request);
}
} else {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write() - unable to write right now, saving request for later", toString(), request);
}
if (mEncodeQueue.size() == MAX_QUEUED_MESSAGES) {
throw new BufferOverflowException();
}
mEncodeQueue.add(request);
}
}
synchronized (mWriteQueue) {
EncryptedWriteRequest x;
while((x = mWriteQueue.poll()) != null) {
next.filterWrite(mSession, x);
}
}
synchronized (this) {
} finally {
forward_writes(next);
throw_pending_error(next);
}
}
Expand All @@ -404,7 +387,7 @@ public void write(NextFilter next, WriteRequest request) throws SSLException, Wr
@SuppressWarnings("incomplete-switch")
synchronized protected boolean write_loop(NextFilter next, WriteRequest request) throws SSLException {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - source {}", toString(), request);
LOGGER.debug("{} write_loop() - source {}", toString(), request);
}

IoBuffer source = IoBuffer.class.cast(request.getMessage());
Expand All @@ -413,7 +396,7 @@ synchronized protected boolean write_loop(NextFilter next, WriteRequest request)
SSLEngineResult result = mEngine.wrap(source.buf(), dest.buf());

if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - bytes-consumed {}, bytes-produced {}, status {}, handshake {}",
LOGGER.debug("{} write_loop() - bytes-consumed {}, bytes-produced {}, status {}, handshake {}",
toString(), result.bytesConsumed(), result.bytesProduced(), result.getStatus(),
result.getHandshakeStatus());
}
Expand All @@ -426,39 +409,36 @@ synchronized protected boolean write_loop(NextFilter next, WriteRequest request)
EncryptedWriteRequest encrypted = new EncryptedWriteRequest(dest, null);

if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - result {}", toString(), encrypted);
LOGGER.debug("{} write_loop() - result {}", toString(), encrypted);
}

mWriteQueue.push(encrypted);
mWriteQueue.add(encrypted);
// do not return because we want to enter the handshake switch
} else {
// then we probably consumed some data
dest.flip();

if (source.hasRemaining()) {
EncryptedWriteRequest encrypted = new EncryptedWriteRequest(dest, null);
mAckQueue.add(encrypted);

if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - result {}", toString(), encrypted);
LOGGER.debug("{} write_loop() - result {}", toString(), encrypted);
}

mWriteQueue.push(encrypted);
mWriteQueue.add(encrypted);

if (mAckQueue.size() < MAX_UNACK_MESSAGES) {
if (mWriteQueue.size() + mAckQueue.size() < MAX_UNACK_MESSAGES) {
return write_loop(next, request); // write additional chunks
}

return false;
} else {
EncryptedWriteRequest encrypted = new EncryptedWriteRequest(dest, request);
mAckQueue.add(encrypted);


if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - result {}", toString(), encrypted);
LOGGER.debug("{} write_loop() - result {}", toString(), encrypted);
}

mWriteQueue.push(encrypted);
mWriteQueue.add(encrypted);

return true;
}
Expand All @@ -469,22 +449,22 @@ synchronized protected boolean write_loop(NextFilter next, WriteRequest request)
switch (result.getHandshakeStatus()) {
case NEED_TASK:
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - handshake needs task, scheduling", toString());
LOGGER.debug("{} write_loop() - handshake needs task, scheduling", toString());
}

schedule_task(next);
break;

case NEED_WRAP:
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - handshake needs wrap, looping", toString());
LOGGER.debug("{} write_loop() - handshake needs wrap, looping", toString());
}

return write_loop(next, request);

case FINISHED:
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} write_user_loop() - handshake finished, flushing queue", toString());
LOGGER.debug("{} write_loop() - handshake finished, flushing queue", toString());
}

finish_handshake(next);
Expand Down Expand Up @@ -581,7 +561,7 @@ protected boolean write_handshake_loop(NextFilter next, IoBuffer source, IoBuffe
}

EncryptedWriteRequest encrypted = new EncryptedWriteRequest(dest, null);
mWriteQueue.push(encrypted);
mWriteQueue.add(encrypted);
}

switch (result.getHandshakeStatus()) {
Expand Down Expand Up @@ -641,14 +621,10 @@ synchronized protected void finish_handshake(NextFilter next) throws SSLExceptio
* {@inheritDoc}
*/
public void flush(NextFilter next) throws SSLException {
flush_start(next);
synchronized (mWriteQueue) {
EncryptedWriteRequest x;
while((x = mWriteQueue.poll()) != null) {
next.filterWrite(mSession, x);
}
}
synchronized (this) {
try {
flush_start(next);
} finally {
forward_writes(next);
throw_pending_error(next);
}
}
Expand Down Expand Up @@ -701,14 +677,10 @@ synchronized protected void flush_start(NextFilter next) throws SSLException {
*/
@Override
public void close(NextFilter next, boolean linger) throws SSLException {
close_start(next, linger);
synchronized (mWriteQueue) {
EncryptedWriteRequest x;
while((x = mWriteQueue.poll()) != null) {
next.filterWrite(mSession, x);
}
}
synchronized (this) {
try {
close_start(next, linger);
} finally {
forward_writes(next);
throw_pending_error(next);
}
}
Expand Down Expand Up @@ -747,13 +719,13 @@ synchronized protected void close_start(NextFilter next, boolean linger) throws
*/
synchronized protected void throw_pending_error(NextFilter next) throws SSLException {
SSLException sslException = mPendingError;

if (sslException != null) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} throw_pending_error() - throwing pending error");
}
// Loop to send back the alert messages
receive_loop(next, null);

mPendingError = null;

// And finally rethrow the exception
throw sslException;
}
Expand All @@ -770,6 +742,31 @@ synchronized protected void store_pending_error(SSLException sslException) {
}
}

protected void forward_received(NextFilter next) {
synchronized (mReceiveQueue) {
IoBuffer x;
while ((x = mReceiveQueue.poll()) != null) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} forward_received() - received {}", toString(), x);
}
next.messageReceived(mSession, x);
}
}
}

protected void forward_writes(NextFilter next) {
synchronized (mWriteQueue) {
EncryptedWriteRequest x;
while ((x = mWriteQueue.poll()) != null) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} forward_writes() - writing {}", toString(), x);
}
mAckQueue.add(x);
next.filterWrite(mSession, x);
}
}
}

/**
* Schedule a SSLEngine task for execution, either using an Executor, or immediately.
*
Expand All @@ -792,7 +789,6 @@ protected void schedule_task(NextFilter next) {
*/
synchronized protected void execute_task(NextFilter next) {
Runnable task;

while ((task = mEngine.getDelegatedTask()) != null) {
try {
if (LOGGER.isDebugEnabled()) {
Expand Down

0 comments on commit ec93bb7

Please sign in to comment.