Skip to content

Commit

Permalink
Fix master task shutdown (#121)
Browse files Browse the repository at this point in the history
* quick and dirty fix for infinite loop in state machine

* use the type system to make it harder for this bug to occur

* update tests and examples to target .NET 6
  • Loading branch information
jadamcrain committed Jul 24, 2023
1 parent 056c6bb commit 7f997f9
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 54 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion ffi/bindings/dotnet/examples/client/client.csproj
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp3.1</TargetFramework>
<TargetFramework>net6.0</TargetFramework>
<IsPublishable>False</IsPublishable>
</PropertyGroup>

Expand Down
2 changes: 1 addition & 1 deletion ffi/bindings/dotnet/examples/server/server.csproj
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp3.1</TargetFramework>
<TargetFramework>net6.0</TargetFramework>
<IsPublishable>False</IsPublishable>
</PropertyGroup>

Expand Down
2 changes: 1 addition & 1 deletion ffi/bindings/dotnet/rodbus-tests/rodbus-tests.csproj
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netcoreapp3.1</TargetFramework>
<TargetFramework>net6.0</TargetFramework>
<RootNamespace>rodbus_tests</RootNamespace>

<IsPackable>false</IsPackable>
Expand Down
17 changes: 17 additions & 0 deletions rodbus/src/channel.rs
@@ -0,0 +1,17 @@
use crate::Shutdown;

/// wrap a Tokio receiver and only provide a recv() that returns a Result<T, Shutdown>
/// that makes it harder to misuse.
pub(crate) struct Receiver<T>(tokio::sync::mpsc::Receiver<T>);

impl<T> From<tokio::sync::mpsc::Receiver<T>> for Receiver<T> {
fn from(value: tokio::sync::mpsc::Receiver<T>) -> Self {
Self(value)
}
}

impl<T> Receiver<T> {
pub(crate) async fn recv(&mut self) -> Result<T, Shutdown> {
self.0.recv().await.ok_or(Shutdown)
}
}
2 changes: 1 addition & 1 deletion rodbus/src/client/channel.rs
Expand Up @@ -73,7 +73,7 @@ impl Channel {
let _ = crate::serial::client::SerialChannelTask::new(
&path,
serial_settings,
rx,
rx.into(),
retry,
decode,
listener.unwrap_or_else(|| crate::client::NullListener::create()),
Expand Down
94 changes: 52 additions & 42 deletions rodbus/src/client/task.rs
Expand Up @@ -25,17 +25,29 @@ pub(crate) enum SessionError {
Shutdown,
}

impl From<Shutdown> for SessionError {
fn from(_: Shutdown) -> Self {
SessionError::Shutdown
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
pub(crate) enum StateChange {
Disable,
Shutdown,
}

impl From<Shutdown> for StateChange {
fn from(_: Shutdown) -> Self {
StateChange::Shutdown
}
}

impl std::fmt::Display for SessionError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
SessionError::IoError(err) => {
write!(f, "i/o error: {err}")
write!(f, "I/O error: {err}")
}
SessionError::BadFrame => {
write!(f, "Parser encountered a bad frame")
Expand All @@ -51,9 +63,9 @@ impl std::fmt::Display for SessionError {
}

impl SessionError {
pub(crate) fn from(err: &RequestError) -> Option<Self> {
pub(crate) fn from_request_err(err: RequestError) -> Option<Self> {
match err {
RequestError::Io(x) => Some(SessionError::IoError(*x)),
RequestError::Io(x) => Some(SessionError::IoError(x)),
RequestError::BadFrame(_) => Some(SessionError::BadFrame),
// all other errors don't kill the loop
_ => None,
Expand All @@ -62,7 +74,7 @@ impl SessionError {
}

pub(crate) struct ClientLoop {
rx: tokio::sync::mpsc::Receiver<Command>,
rx: crate::channel::Receiver<Command>,
writer: FrameWriter,
reader: FramedReader,
tx_id: TxId,
Expand All @@ -72,7 +84,7 @@ pub(crate) struct ClientLoop {

impl ClientLoop {
pub(crate) fn new(
rx: tokio::sync::mpsc::Receiver<Command>,
rx: crate::channel::Receiver<Command>,
writer: FrameWriter,
reader: FramedReader,
decode: DecodeLevel,
Expand Down Expand Up @@ -118,32 +130,31 @@ impl ClientLoop {

pub(crate) async fn run(&mut self, io: &mut PhysLayer) -> SessionError {
loop {
tokio::select! {
frame = self.reader.next_frame(io, self.decode) => {
match frame {
Ok(frame) => {
tracing::warn!("Received unexpected frame while idle: {:?}", frame.header);
}
Err(err) => {
if let Some(err) = SessionError::from(&err) {
tracing::warn!("{}", err);
return err;
}
}
if let Err(err) = self.poll(io).await {
tracing::warn!("ending session: {}", err);
return err;
}
}
}

pub(crate) async fn poll(&mut self, io: &mut PhysLayer) -> Result<(), SessionError> {
tokio::select! {
frame = self.reader.next_frame(io, self.decode) => {
match frame {
Ok(frame) => {
tracing::warn!("Received unexpected frame while idle: {:?}", frame.header);
Ok(())
}
}
cmd = self.rx.recv() => {
match cmd {
// other side has closed the request channel
None => return SessionError::Shutdown,
Some(cmd) => {
if let Err(err) = self.run_cmd(cmd, io).await {
return err;
}
}
Err(err) => match SessionError::from_request_err(err) {
Some(err) => Err(err),
None => Ok(()),
}
}
}
res = self.rx.recv() => {
let cmd: Command = res?;
self.run_cmd(cmd, io).await
}
}
}

Expand All @@ -166,7 +177,7 @@ impl ClientLoop {

// some request errors are a session error that will
// bubble up and close the session
if let Some(err) = SessionError::from(&err) {
if let Some(err) = SessionError::from_request_err(err) {
return Err(err);
}
}
Expand Down Expand Up @@ -240,21 +251,20 @@ impl ClientLoop {
}

async fn fail_next_request(&mut self) -> Result<(), StateChange> {
match self.rx.recv().await {
None => return Err(StateChange::Disable),
Some(cmd) => match cmd {
Command::Request(mut req) => {
req.details.fail(RequestError::NoConnection);
}
Command::Setting(x) => {
self.change_setting(x);
if !self.enabled {
return Err(StateChange::Disable);
}
match self.rx.recv().await? {
Command::Request(mut req) => {
req.details.fail(RequestError::NoConnection);
Ok(())
}
Command::Setting(x) => {
self.change_setting(x);
if self.enabled {
Ok(())
} else {
Err(StateChange::Disable)
}
},
}
}
Ok(())
}

pub(crate) async fn fail_requests_for(
Expand Down Expand Up @@ -300,7 +310,7 @@ mod tests {
let (tx, rx) = tokio::sync::mpsc::channel(16);
let (mock, io_handle) = sfio_tokio_mock_io::mock();
let mut client_loop = ClientLoop::new(
rx,
rx.into(),
FrameWriter::tcp(),
FramedReader::tcp(),
DecodeLevel::default().application(AppDecodeLevel::DataValues),
Expand Down
1 change: 1 addition & 0 deletions rodbus/src/lib.rs
Expand Up @@ -167,6 +167,7 @@ pub mod constants;
pub mod server;

// modules that are re-exported
pub(crate) mod channel;
pub(crate) mod decode;
pub(crate) mod error;
pub(crate) mod exception;
Expand Down
3 changes: 1 addition & 2 deletions rodbus/src/serial/client.rs
@@ -1,7 +1,6 @@
use crate::common::phys::PhysLayer;
use crate::decode::DecodeLevel;
use crate::serial::SerialSettings;
use tokio::sync::mpsc::Receiver;

use crate::client::message::Command;
use crate::client::task::{ClientLoop, SessionError, StateChange};
Expand All @@ -21,7 +20,7 @@ impl SerialChannelTask {
pub(crate) fn new(
path: &str,
serial_settings: SerialSettings,
rx: Receiver<Command>,
rx: crate::channel::Receiver<Command>,
retry: Box<dyn RetryStrategy>,
decode: DecodeLevel,
listener: Box<dyn Listener<PortState>>,
Expand Down
5 changes: 2 additions & 3 deletions rodbus/src/tcp/client.rs
Expand Up @@ -11,7 +11,6 @@ use crate::error::Shutdown;
use crate::retry::RetryStrategy;

use tokio::net::TcpStream;
use tokio::sync::mpsc::Receiver;

pub(crate) fn spawn_tcp_channel(
host: HostAddr,
Expand All @@ -37,7 +36,7 @@ pub(crate) fn create_tcp_channel(
let task = async move {
TcpChannelTask::new(
host.clone(),
rx,
rx.into(),
TcpTaskConnectionHandler::Tcp,
connect_retry,
decode,
Expand Down Expand Up @@ -81,7 +80,7 @@ pub(crate) struct TcpChannelTask {
impl TcpChannelTask {
pub(crate) fn new(
host: HostAddr,
rx: Receiver<Command>,
rx: crate::channel::Receiver<Command>,
connection_handler: TcpTaskConnectionHandler,
connect_retry: Box<dyn RetryStrategy>,
decode: DecodeLevel,
Expand Down
2 changes: 1 addition & 1 deletion rodbus/src/tcp/tls/client.rs
Expand Up @@ -54,7 +54,7 @@ pub(crate) fn create_tls_channel(
let task = async move {
TcpChannelTask::new(
host.clone(),
rx,
rx.into(),
TcpTaskConnectionHandler::Tls(tls_config),
connect_retry,
decode,
Expand Down

0 comments on commit 7f997f9

Please sign in to comment.