Skip to content

Commit

Permalink
Merge pull request #422 from google/gbg/bench-rfcomm-params
Browse files Browse the repository at this point in the history
add rfcomm options and fix l2cap mtu negotiation
  • Loading branch information
barbibulle committed Feb 5, 2024
2 parents 6d91e7e + d7489a6 commit f4aeaa6
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 25 deletions.
59 changes: 52 additions & 7 deletions apps/bench.py
Expand Up @@ -87,6 +87,7 @@
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0

DEFAULT_RFCOMM_CHANNEL = 8
DEFAULT_RFCOMM_MTU = 2048


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -896,11 +897,14 @@ async def drain(self):
# RfcommClient
# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
def __init__(self, device, channel, uuid):
def __init__(self, device, channel, uuid, l2cap_mtu, max_frame_size, window_size):
super().__init__()
self.device = device
self.channel = channel
self.uuid = uuid
self.l2cap_mtu = l2cap_mtu
self.max_frame_size = max_frame_size
self.window_size = window_size
self.rfcomm_session = None
self.ready = asyncio.Event()

Expand All @@ -924,13 +928,21 @@ async def on_connection(self, connection):

# Create a client and start it
logging.info(color('*** Starting RFCOMM client...', 'blue'))
rfcomm_client = bumble.rfcomm.Client(connection)
rfcomm_options = {}
if self.l2cap_mtu:
rfcomm_options['l2cap_mtu'] = self.l2cap_mtu
rfcomm_client = bumble.rfcomm.Client(connection, **rfcomm_options)
rfcomm_mux = await rfcomm_client.start()
logging.info(color('*** Started', 'blue'))

logging.info(color(f'### Opening session for channel {channel}...', 'yellow'))
try:
rfcomm_session = await rfcomm_mux.open_dlc(channel)
dlc_options = {}
if self.max_frame_size:
dlc_options['max_frame_size'] = self.max_frame_size
if self.window_size:
dlc_options['window_size'] = self.window_size
rfcomm_session = await rfcomm_mux.open_dlc(channel, **dlc_options)
logging.info(color(f'### Session open: {rfcomm_session}', 'yellow'))
except bumble.core.ConnectionError as error:
logging.info(color(f'!!! Session open failed: {error}', 'red'))
Expand All @@ -955,13 +967,16 @@ async def drain(self):
# RfcommServer
# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
def __init__(self, device, channel):
def __init__(self, device, channel, l2cap_mtu):
super().__init__()
self.dlc = None
self.ready = asyncio.Event()

# Create and register a server
rfcomm_server = bumble.rfcomm.Server(device)
server_options = {}
if l2cap_mtu:
server_options['l2cap_mtu'] = l2cap_mtu
rfcomm_server = bumble.rfcomm.Server(device, **server_options)

# Listen for incoming DLC connections
channel_number = rfcomm_server.listen(self.on_dlc, channel)
Expand Down Expand Up @@ -1298,11 +1313,20 @@ def create_mode(device):

if mode == 'rfcomm-client':
return RfcommClient(
device, channel=ctx.obj['rfcomm_channel'], uuid=ctx.obj['rfcomm_uuid']
device,
channel=ctx.obj['rfcomm_channel'],
uuid=ctx.obj['rfcomm_uuid'],
l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'],
max_frame_size=ctx.obj['rfcomm_max_frame_size'],
window_size=ctx.obj['rfcomm_window_size'],
)

if mode == 'rfcomm-server':
return RfcommServer(device, channel=ctx.obj['rfcomm_channel'])
return RfcommServer(
device,
channel=ctx.obj['rfcomm_channel'],
l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'],
)

raise ValueError('invalid mode')

Expand Down Expand Up @@ -1389,6 +1413,21 @@ def create_role(packet_io):
default=DEFAULT_RFCOMM_UUID,
help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)',
)
@click.option(
'--rfcomm-l2cap-mtu',
type=int,
help='RFComm L2CAP MTU',
)
@click.option(
'--rfcomm-max-frame-size',
type=int,
help='RFComm maximum frame size',
)
@click.option(
'--rfcomm-window-size',
type=int,
help='RFComm window size',
)
@click.option(
'--l2cap-psm',
type=int,
Expand Down Expand Up @@ -1486,6 +1525,9 @@ def bench(
linger,
rfcomm_channel,
rfcomm_uuid,
rfcomm_l2cap_mtu,
rfcomm_max_frame_size,
rfcomm_window_size,
l2cap_psm,
l2cap_mtu,
l2cap_mps,
Expand All @@ -1498,6 +1540,9 @@ def bench(
ctx.obj['att_mtu'] = att_mtu
ctx.obj['rfcomm_channel'] = rfcomm_channel
ctx.obj['rfcomm_uuid'] = rfcomm_uuid
ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu
ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size
ctx.obj['rfcomm_window_size'] = rfcomm_window_size
ctx.obj['l2cap_psm'] = l2cap_psm
ctx.obj['l2cap_mtu'] = l2cap_mtu
ctx.obj['l2cap_mps'] = l2cap_mps
Expand Down
4 changes: 2 additions & 2 deletions bumble/avdtp.py
Expand Up @@ -1470,10 +1470,10 @@ def send_message(self, transaction_label: int, message: Message) -> None:
f'[{transaction_label}] {message}'
)
max_fragment_size = (
self.l2cap_channel.mtu - 3
self.l2cap_channel.peer_mtu - 3
) # Enough space for a 3-byte start packet header
payload = message.payload
if len(payload) + 2 <= self.l2cap_channel.mtu:
if len(payload) + 2 <= self.l2cap_channel.peer_mtu:
# Fits in a single packet
packet_type = self.PacketType.SINGLE_PACKET
else:
Expand Down
2 changes: 1 addition & 1 deletion bumble/hid.py
Expand Up @@ -416,7 +416,7 @@ def handle_get_report(self, pdu: bytes):
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr]
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
Expand Down
13 changes: 8 additions & 5 deletions bumble/l2cap.py
Expand Up @@ -173,7 +173,7 @@
@dataclasses.dataclass
class ClassicChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_MIN_BR_EDR_MTU
mtu: int = L2CAP_DEFAULT_MTU


@dataclasses.dataclass
Expand Down Expand Up @@ -749,6 +749,8 @@ class State(enum.IntEnum):
sink: Optional[Callable[[bytes], Any]]
state: State
connection: Connection
mtu: int
peer_mtu: int

def __init__(
self,
Expand All @@ -765,6 +767,7 @@ def __init__(
self.signaling_cid = signaling_cid
self.state = self.State.CLOSED
self.mtu = mtu
self.peer_mtu = L2CAP_MIN_BR_EDR_MTU
self.psm = psm
self.source_cid = source_cid
self.destination_cid = 0
Expand Down Expand Up @@ -861,7 +864,7 @@ def send_configure_request(self) -> None:
[
(
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE,
struct.pack('<H', L2CAP_DEFAULT_MTU),
struct.pack('<H', self.mtu),
)
]
)
Expand Down Expand Up @@ -926,8 +929,8 @@ def on_configure_request(self, request) -> None:
options = L2CAP_Control_Frame.decode_configuration_options(request.options)
for option in options:
if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE:
self.mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'MTU = {self.mtu}')
self.peer_mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'peer MTU = {self.peer_mtu}')

self.send_control_frame(
L2CAP_Configure_Response(
Expand Down Expand Up @@ -1026,7 +1029,7 @@ def __str__(self) -> str:
return (
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'state={self.state.name})'
)

Expand Down
17 changes: 12 additions & 5 deletions bumble/rfcomm.py
Expand Up @@ -104,6 +104,7 @@ class MccType(enum.IntEnum):
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])

RFCOMM_DEFAULT_L2CAP_MTU = 2048
RFCOMM_DEFAULT_WINDOW_SIZE = 7
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000

Expand Down Expand Up @@ -473,7 +474,7 @@ def __init__(
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
)

def change_state(self, new_state: State) -> None:
Expand Down Expand Up @@ -908,16 +909,19 @@ class Client:
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel]

def __init__(self, connection: Connection) -> None:
def __init__(
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
self.connection = connection
self.l2cap_mtu = l2cap_mtu
self.l2cap_channel = None
self.multiplexer = None

async def start(self) -> Multiplexer:
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(RFCOMM_PSM)
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu)
)
except ProtocolError as error:
logger.warning(f'L2CAP connection failed: {error}')
Expand Down Expand Up @@ -955,15 +959,18 @@ async def __aexit__(self, *args) -> None:
class Server(EventEmitter):
acceptors: Dict[int, Callable[[DLC], None]]

def __init__(self, device: Device) -> None:
def __init__(
self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
super().__init__()
self.device = device
self.multiplexer = None
self.acceptors = {}

# Register ourselves with the L2CAP channel manager
self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu),
handler=self.on_connection,
)

def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
Expand Down
4 changes: 2 additions & 2 deletions examples/run_a2dp_source.py
Expand Up @@ -74,7 +74,7 @@ def codec_capabilities():
# -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(packet_source.codec_capabilities, packet_pump)
Expand All @@ -98,7 +98,7 @@ async def stream_packets(read_function, protocol):

# Stream the packets
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(packet_source.codec_capabilities, packet_pump)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -52,7 +52,7 @@ install_requires =
pyserial-asyncio >= 0.5; platform_system!='Emscripten'
pyserial >= 3.5; platform_system!='Emscripten'
pyusb >= 1.2; platform_system!='Emscripten'
websockets >= 8.1; platform_system!='Emscripten'
websockets >= 12.0; platform_system!='Emscripten'

[options.entry_points]
console_scripts =
Expand Down
22 changes: 22 additions & 0 deletions tests/l2cap_test.py
Expand Up @@ -227,12 +227,34 @@ def on_client_data(data):
assert server_received_bytes == message_bytes


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu():
devices = TwoDevices()
await devices.setup_connection()

def on_channel_open(channel):
assert channel.peer_mtu == 456

def on_channel(channel):
channel.on('open', lambda: on_channel_open(channel))

server = devices.devices[1].create_l2cap_server(
spec=ClassicChannelSpec(mtu=345), handler=on_channel
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=ClassicChannelSpec(server.psm, mtu=456)
)
assert client_channel.peer_mtu == 345


# -----------------------------------------------------------------------------
async def run():
test_helpers()
await test_basic_connection()
await test_transfer()
await test_bidirectional_transfer()
await test_mtu()


# -----------------------------------------------------------------------------
Expand Down
6 changes: 4 additions & 2 deletions tests/rfcomm_test.py
Expand Up @@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
import asyncio
import pytest
from typing import List

from . import test_utils
from bumble import core
Expand Down Expand Up @@ -59,17 +60,18 @@ def test_frames():

# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
async def test_basic_connection() -> None:
devices = test_utils.TwoDevices()
await devices.setup_connection()

accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
channel = Server(devices[0]).listen(acceptor=accept_future.set_result)

assert devices.connections[1]
multiplexer = await Client(devices.connections[1]).start()
dlcs = await asyncio.gather(accept_future, multiplexer.open_dlc(channel))

queues = [asyncio.Queue(), asyncio.Queue()]
queues: List[asyncio.Queue] = [asyncio.Queue(), asyncio.Queue()]
for dlc, queue in zip(dlcs, queues):
dlc.sink = queue.put_nowait

Expand Down

0 comments on commit f4aeaa6

Please sign in to comment.