Skip to content

Commit

Permalink
Merge pull request #483 from zxzxwu/rfc
Browse files Browse the repository at this point in the history
RFCOMM: Handle packets received before DLC sink set
  • Loading branch information
zxzxwu committed May 10, 2024
2 parents 7fbfdb6 + 9682077 commit 8781943
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 5 deletions.
33 changes: 29 additions & 4 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import logging
import asyncio
import collections
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
Expand Down Expand Up @@ -54,6 +55,7 @@
# fmt: off

RFCOMM_PSM = 0x0003
DEFAULT_RX_QUEUE_SIZE = 32

class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
Expand Down Expand Up @@ -445,7 +447,8 @@ class State(enum.IntEnum):
RESET = 0x05

connection_result: Optional[asyncio.Future]
sink: Optional[Callable[[bytes], None]]
_sink: Optional[Callable[[bytes], None]]
_enqueued_rx_packets: collections.deque[bytes]

def __init__(
self,
Expand All @@ -466,17 +469,32 @@ def __init__(
self.state = DLC.State.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.sink = None
self.connection_result = None
self.drained = asyncio.Event()
self.drained.set()
# Queued packets when sink is not set.
self._enqueued_rx_packets = collections.deque(maxlen=DEFAULT_RX_QUEUE_SIZE)
self._sink = None

# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
)

@property
def sink(self) -> Optional[Callable[[bytes], None]]:
return self._sink

@sink.setter
def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
self._sink = sink
# Dump queued packets to sink
if sink:
for packet in self._enqueued_rx_packets:
sink(packet) # pylint: disable=not-callable
self._enqueued_rx_packets.clear()

def change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
self.state = new_state
Expand Down Expand Up @@ -549,8 +567,15 @@ def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if data:
if self.sink:
self.sink(data) # pylint: disable=not-callable
if self._sink:
self._sink(data) # pylint: disable=not-callable
else:
self._enqueued_rx_packets.append(data)
if (
self._enqueued_rx_packets.maxlen
and len(self._enqueued_rx_packets) >= self._enqueued_rx_packets.maxlen
):
logger.warning(f'DLC [{self.dlci}] received packet queue is full')

# Update the credits
if self.rx_credits > 0:
Expand Down
25 changes: 25 additions & 0 deletions tests/rfcomm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
RFCOMM_PSM,
)

_TIMEOUT = 0.1


# -----------------------------------------------------------------------------
def basic_frame_check(x):
Expand Down Expand Up @@ -82,6 +84,29 @@ async def test_basic_connection() -> None:
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_receive_pdu_before_open_dlc_returns() -> None:
devices = await test_utils.TwoDevices.create_with_connection()
DATA = b'123'

accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
channel = Server(devices[0]).listen(acceptor=accept_future.set_result)

assert devices.connections[1]
multiplexer = await Client(devices.connections[1]).start()
open_dlc_task = asyncio.create_task(multiplexer.open_dlc(channel))

dlc_responder = await accept_future
dlc_responder.write(DATA)

dlc_initiator = await open_dlc_task
dlc_initiator_queue = asyncio.Queue() # type: ignore[var-annotated]
dlc_initiator.sink = dlc_initiator_queue.put_nowait

assert await asyncio.wait_for(dlc_initiator_queue.get(), timeout=_TIMEOUT) == DATA


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_record():
Expand Down
9 changes: 8 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
from typing import List, Optional
from typing import List, Optional, Type
from typing_extensions import Self

from bumble.controller import Controller
from bumble.link import LocalLink
Expand Down Expand Up @@ -81,6 +82,12 @@ async def setup_connection(self) -> None:
def __getitem__(self, index: int) -> Device:
return self.devices[index]

@classmethod
async def create_with_connection(cls: Type[Self]) -> Self:
devices = cls()
await devices.setup_connection()
return devices


# -----------------------------------------------------------------------------
async def async_barrier():
Expand Down

0 comments on commit 8781943

Please sign in to comment.