diff --git a/apps/bench.py b/apps/bench.py index 1f9d45f7..83625f00 100644 --- a/apps/bench.py +++ b/apps/bench.py @@ -87,6 +87,7 @@ DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0 DEFAULT_RFCOMM_CHANNEL = 8 +DEFAULT_RFCOMM_MTU = 2048 # ----------------------------------------------------------------------------- @@ -896,11 +897,14 @@ async def drain(self): # RfcommClient # ----------------------------------------------------------------------------- class RfcommClient(StreamedPacketIO): - def __init__(self, device, channel, uuid): + def __init__(self, device, channel, uuid, l2cap_mtu, max_frame_size, window_size): super().__init__() self.device = device self.channel = channel self.uuid = uuid + self.l2cap_mtu = l2cap_mtu + self.max_frame_size = max_frame_size + self.window_size = window_size self.rfcomm_session = None self.ready = asyncio.Event() @@ -924,13 +928,21 @@ async def on_connection(self, connection): # Create a client and start it logging.info(color('*** Starting RFCOMM client...', 'blue')) - rfcomm_client = bumble.rfcomm.Client(connection) + rfcomm_options = {} + if self.l2cap_mtu: + rfcomm_options['l2cap_mtu'] = self.l2cap_mtu + rfcomm_client = bumble.rfcomm.Client(connection, **rfcomm_options) rfcomm_mux = await rfcomm_client.start() logging.info(color('*** Started', 'blue')) logging.info(color(f'### Opening session for channel {channel}...', 'yellow')) try: - rfcomm_session = await rfcomm_mux.open_dlc(channel) + dlc_options = {} + if self.max_frame_size: + dlc_options['max_frame_size'] = self.max_frame_size + if self.window_size: + dlc_options['window_size'] = self.window_size + rfcomm_session = await rfcomm_mux.open_dlc(channel, **dlc_options) logging.info(color(f'### Session open: {rfcomm_session}', 'yellow')) except bumble.core.ConnectionError as error: logging.info(color(f'!!! Session open failed: {error}', 'red')) @@ -955,13 +967,16 @@ async def drain(self): # RfcommServer # ----------------------------------------------------------------------------- class RfcommServer(StreamedPacketIO): - def __init__(self, device, channel): + def __init__(self, device, channel, l2cap_mtu): super().__init__() self.dlc = None self.ready = asyncio.Event() # Create and register a server - rfcomm_server = bumble.rfcomm.Server(device) + server_options = {} + if l2cap_mtu: + server_options['l2cap_mtu'] = l2cap_mtu + rfcomm_server = bumble.rfcomm.Server(device, **server_options) # Listen for incoming DLC connections channel_number = rfcomm_server.listen(self.on_dlc, channel) @@ -1298,11 +1313,20 @@ def create_mode(device): if mode == 'rfcomm-client': return RfcommClient( - device, channel=ctx.obj['rfcomm_channel'], uuid=ctx.obj['rfcomm_uuid'] + device, + channel=ctx.obj['rfcomm_channel'], + uuid=ctx.obj['rfcomm_uuid'], + l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'], + max_frame_size=ctx.obj['rfcomm_max_frame_size'], + window_size=ctx.obj['rfcomm_window_size'], ) if mode == 'rfcomm-server': - return RfcommServer(device, channel=ctx.obj['rfcomm_channel']) + return RfcommServer( + device, + channel=ctx.obj['rfcomm_channel'], + l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'], + ) raise ValueError('invalid mode') @@ -1389,6 +1413,21 @@ def create_role(packet_io): default=DEFAULT_RFCOMM_UUID, help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)', ) +@click.option( + '--rfcomm-l2cap-mtu', + type=int, + help='RFComm L2CAP MTU', +) +@click.option( + '--rfcomm-max-frame-size', + type=int, + help='RFComm maximum frame size', +) +@click.option( + '--rfcomm-window-size', + type=int, + help='RFComm window size', +) @click.option( '--l2cap-psm', type=int, @@ -1486,6 +1525,9 @@ def bench( linger, rfcomm_channel, rfcomm_uuid, + rfcomm_l2cap_mtu, + rfcomm_max_frame_size, + rfcomm_window_size, l2cap_psm, l2cap_mtu, l2cap_mps, @@ -1498,6 +1540,9 @@ def bench( ctx.obj['att_mtu'] = att_mtu ctx.obj['rfcomm_channel'] = rfcomm_channel ctx.obj['rfcomm_uuid'] = rfcomm_uuid + ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu + ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size + ctx.obj['rfcomm_window_size'] = rfcomm_window_size ctx.obj['l2cap_psm'] = l2cap_psm ctx.obj['l2cap_mtu'] = l2cap_mtu ctx.obj['l2cap_mps'] = l2cap_mps diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 3be1e157..f7851099 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -1470,10 +1470,10 @@ def send_message(self, transaction_label: int, message: Message) -> None: f'[{transaction_label}] {message}' ) max_fragment_size = ( - self.l2cap_channel.mtu - 3 + self.l2cap_channel.peer_mtu - 3 ) # Enough space for a 3-byte start packet header payload = message.payload - if len(payload) + 2 <= self.l2cap_channel.mtu: + if len(payload) + 2 <= self.l2cap_channel.peer_mtu: # Fits in a single packet packet_type = self.PacketType.SINGLE_PACKET else: diff --git a/bumble/hid.py b/bumble/hid.py index 5ea9b98a..fc5c8074 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -416,7 +416,7 @@ def handle_get_report(self, pdu: bytes): data = bytearray() data.append(report_id) data.extend(ret.data) - if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr] + if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr] self.send_control_data(report_type=report_type, data=data) else: self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index f91a269f..cec14b85 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -173,7 +173,7 @@ @dataclasses.dataclass class ClassicChannelSpec: psm: Optional[int] = None - mtu: int = L2CAP_MIN_BR_EDR_MTU + mtu: int = L2CAP_DEFAULT_MTU @dataclasses.dataclass @@ -749,6 +749,8 @@ class State(enum.IntEnum): sink: Optional[Callable[[bytes], Any]] state: State connection: Connection + mtu: int + peer_mtu: int def __init__( self, @@ -765,6 +767,7 @@ def __init__( self.signaling_cid = signaling_cid self.state = self.State.CLOSED self.mtu = mtu + self.peer_mtu = L2CAP_MIN_BR_EDR_MTU self.psm = psm self.source_cid = source_cid self.destination_cid = 0 @@ -861,7 +864,7 @@ def send_configure_request(self) -> None: [ ( L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE, - struct.pack(' None: options = L2CAP_Control_Frame.decode_configuration_options(request.options) for option in options: if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE: - self.mtu = struct.unpack(' str: return ( f'Channel({self.source_cid}->{self.destination_cid}, ' f'PSM={self.psm}, ' - f'MTU={self.mtu}, ' + f'MTU={self.mtu}/{self.peer_mtu}, ' f'state={self.state.name})' ) diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 5500bc12..6ca0f509 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -104,6 +104,7 @@ class MccType(enum.IntEnum): 0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF ]) +RFCOMM_DEFAULT_L2CAP_MTU = 2048 RFCOMM_DEFAULT_WINDOW_SIZE = 7 RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000 @@ -473,7 +474,7 @@ def __init__( # Compute the MTU max_overhead = 4 + 1 # header with 2-byte length + fcs self.mtu = min( - max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead + max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead ) def change_state(self, new_state: State) -> None: @@ -908,8 +909,11 @@ class Client: multiplexer: Optional[Multiplexer] l2cap_channel: Optional[l2cap.ClassicChannel] - def __init__(self, connection: Connection) -> None: + def __init__( + self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU + ) -> None: self.connection = connection + self.l2cap_mtu = l2cap_mtu self.l2cap_channel = None self.multiplexer = None @@ -917,7 +921,7 @@ async def start(self) -> Multiplexer: # Create a new L2CAP connection try: self.l2cap_channel = await self.connection.create_l2cap_channel( - spec=l2cap.ClassicChannelSpec(RFCOMM_PSM) + spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu) ) except ProtocolError as error: logger.warning(f'L2CAP connection failed: {error}') @@ -955,7 +959,9 @@ async def __aexit__(self, *args) -> None: class Server(EventEmitter): acceptors: Dict[int, Callable[[DLC], None]] - def __init__(self, device: Device) -> None: + def __init__( + self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU + ) -> None: super().__init__() self.device = device self.multiplexer = None @@ -963,7 +969,8 @@ def __init__(self, device: Device) -> None: # Register ourselves with the L2CAP channel manager self.l2cap_server = device.create_l2cap_server( - spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection + spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu), + handler=self.on_connection, ) def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int: diff --git a/examples/run_a2dp_source.py b/examples/run_a2dp_source.py index 92812fe1..46452293 100644 --- a/examples/run_a2dp_source.py +++ b/examples/run_a2dp_source.py @@ -74,7 +74,7 @@ def codec_capabilities(): # ----------------------------------------------------------------------------- def on_avdtp_connection(read_function, protocol): packet_source = SbcPacketSource( - read_function, protocol.l2cap_channel.mtu, codec_capabilities() + read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities() ) packet_pump = MediaPacketPump(packet_source.packets) protocol.add_source(packet_source.codec_capabilities, packet_pump) @@ -98,7 +98,7 @@ async def stream_packets(read_function, protocol): # Stream the packets packet_source = SbcPacketSource( - read_function, protocol.l2cap_channel.mtu, codec_capabilities() + read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities() ) packet_pump = MediaPacketPump(packet_source.packets) source = protocol.add_source(packet_source.codec_capabilities, packet_pump) diff --git a/setup.cfg b/setup.cfg index 8ef11b1c..91d61832 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ install_requires = pyserial-asyncio >= 0.5; platform_system!='Emscripten' pyserial >= 3.5; platform_system!='Emscripten' pyusb >= 1.2; platform_system!='Emscripten' - websockets >= 8.1; platform_system!='Emscripten' + websockets >= 12.0; platform_system!='Emscripten' [options.entry_points] console_scripts = diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index 5cb285c3..6323ddfa 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -227,12 +227,34 @@ def on_client_data(data): assert server_received_bytes == message_bytes +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_mtu(): + devices = TwoDevices() + await devices.setup_connection() + + def on_channel_open(channel): + assert channel.peer_mtu == 456 + + def on_channel(channel): + channel.on('open', lambda: on_channel_open(channel)) + + server = devices.devices[1].create_l2cap_server( + spec=ClassicChannelSpec(mtu=345), handler=on_channel + ) + client_channel = await devices.connections[0].create_l2cap_channel( + spec=ClassicChannelSpec(server.psm, mtu=456) + ) + assert client_channel.peer_mtu == 345 + + # ----------------------------------------------------------------------------- async def run(): test_helpers() await test_basic_connection() await test_transfer() await test_bidirectional_transfer() + await test_mtu() # ----------------------------------------------------------------------------- diff --git a/tests/rfcomm_test.py b/tests/rfcomm_test.py index 2ab3c2c4..4ce4d116 100644 --- a/tests/rfcomm_test.py +++ b/tests/rfcomm_test.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- import asyncio import pytest +from typing import List from . import test_utils from bumble import core @@ -59,17 +60,18 @@ def test_frames(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_basic_connection(): +async def test_basic_connection() -> None: devices = test_utils.TwoDevices() await devices.setup_connection() accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future() channel = Server(devices[0]).listen(acceptor=accept_future.set_result) + assert devices.connections[1] multiplexer = await Client(devices.connections[1]).start() dlcs = await asyncio.gather(accept_future, multiplexer.open_dlc(channel)) - queues = [asyncio.Queue(), asyncio.Queue()] + queues: List[asyncio.Queue] = [asyncio.Queue(), asyncio.Queue()] for dlc, queue in zip(dlcs, queues): dlc.sink = queue.put_nowait