Skip to content

Commit

Permalink
Merge pull request #271 from zxzxwu/device_typing
Browse files Browse the repository at this point in the history
Typing transport and relateds
  • Loading branch information
zxzxwu committed Sep 8, 2023
2 parents a1b6eb6 + b312170 commit 01603ca
Show file tree
Hide file tree
Showing 19 changed files with 188 additions and 99 deletions.
8 changes: 6 additions & 2 deletions bumble/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import logging
import asyncio
import itertools
Expand Down Expand Up @@ -58,8 +60,10 @@
HCI_Packet,
HCI_Role_Change_Event,
)
from typing import Optional, Union, Dict
from typing import Optional, Union, Dict, TYPE_CHECKING

if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource

# -----------------------------------------------------------------------------
# Logging
Expand Down Expand Up @@ -104,7 +108,7 @@ def __init__(
self,
name,
host_source=None,
host_sink=None,
host_sink: Optional[TransportSink] = None,
link=None,
public_address: Optional[Union[bytes, str, Address]] = None,
):
Expand Down
44 changes: 37 additions & 7 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@
import logging
from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
TYPE_CHECKING,
)

from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
Expand Down Expand Up @@ -152,6 +163,9 @@
from . import l2cap
from . import core

if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink


# -----------------------------------------------------------------------------
# Logging
Expand Down Expand Up @@ -942,7 +956,13 @@ def on_characteristic_subscription(
pass

@classmethod
def with_hci(cls, name, address, hci_source, hci_sink):
def with_hci(
cls,
name: str,
address: Address,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
'''
Create a Device instance with a Host configured to communicate with a controller
through an HCI source/sink
Expand All @@ -951,18 +971,25 @@ def with_hci(cls, name, address, hci_source, hci_sink):
return cls(name=name, address=address, host=host)

@classmethod
def from_config_file(cls, filename):
def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
return cls(config=config)

@classmethod
def from_config_with_hci(cls, config, hci_source, hci_sink):
def from_config_with_hci(
cls,
config: DeviceConfiguration,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
host = Host(controller_source=hci_source, controller_sink=hci_sink)
return cls(config=config, host=host)

@classmethod
def from_config_file_with_hci(cls, filename, hci_source, hci_sink):
def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink)
Expand Down Expand Up @@ -2238,9 +2265,11 @@ async def pair(self, connection):
def request_pairing(self, connection):
return self.smp_manager.request_pairing(connection)

async def get_long_term_key(self, connection_handle, rand, ediv):
async def get_long_term_key(
self, connection_handle: int, rand: bytes, ediv: int
) -> Optional[bytes]:
if (connection := self.lookup_connection(connection_handle)) is None:
return
return None

# Start by looking for the key in an SMP session
ltk = self.smp_manager.get_long_term_key(connection, rand, ediv)
Expand All @@ -2260,6 +2289,7 @@ async def get_long_term_key(self, connection_handle, rand, ediv):

if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
return keys.ltk_peripheral.value
return None

async def get_link_key(self, address: Address) -> Optional[bytes]:
if self.keystore is None:
Expand Down
26 changes: 20 additions & 6 deletions bumble/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import struct

from typing import Optional
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable

from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
Expand Down Expand Up @@ -73,10 +73,14 @@
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
InvalidStateError,
)
from .utils import AbortableEventEmitter
from .transport.common import TransportLostError

if TYPE_CHECKING:
from .transport.common import TransportSink, TransportSource


# -----------------------------------------------------------------------------
# Logging
Expand Down Expand Up @@ -116,10 +120,21 @@ def on_acl_pdu(self, pdu: bytes) -> None:

# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None):
connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: TransportSink
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]]

def __init__(
self,
controller_source: Optional[TransportSource] = None,
controller_sink: Optional[TransportSink] = None,
) -> None:
super().__init__()

self.hci_sink = None
self.hci_metadata = None
self.ready = False # True when we can accept incoming packets
self.reset_done = False
Expand Down Expand Up @@ -299,7 +314,7 @@ async def reset(self, driver_factory=drivers.get_driver_for_host):
self.reset_done = True

@property
def controller(self):
def controller(self) -> TransportSink:
return self.hci_sink

@controller.setter
Expand All @@ -308,13 +323,12 @@ def controller(self, controller):
if controller:
controller.set_packet_sink(self)

def set_packet_sink(self, sink):
def set_packet_sink(self, sink: TransportSink) -> None:
self.hci_sink = sink

def send_hci_packet(self, packet: HCI_Packet) -> None:
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)

self.hci_sink.on_packet(bytes(packet))

async def send_command(self, command, check_result=False):
Expand Down
11 changes: 7 additions & 4 deletions bumble/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import os

from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller
from ..snoop import create_snooper

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -119,7 +118,8 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'file':
from .file import open_file_transport

return await open_file_transport(spec[0] if spec else None)
assert spec is not None
return await open_file_transport(spec[0])

if scheme == 'vhci':
from .vhci import open_vhci_transport
Expand All @@ -134,12 +134,14 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'usb':
from .usb import open_usb_transport

return await open_usb_transport(spec[0] if spec else None)
assert spec is not None
return await open_usb_transport(spec[0])

if scheme == 'pyusb':
from .pyusb import open_pyusb_transport

return await open_pyusb_transport(spec[0] if spec else None)
assert spec is not None
return await open_pyusb_transport(spec[0])

if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
Expand Down Expand Up @@ -168,6 +170,7 @@ async def open_transport_or_link(name: str) -> Transport:
"""
if name.startswith('link-relay:'):
from ..controller import Controller
from ..link import RemoteLink # lazy import

link = RemoteLink(name[11:])
Expand Down
7 changes: 4 additions & 3 deletions bumble/transport/android_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import grpc.aio

from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport

# pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
Expand All @@ -33,7 +33,7 @@


# -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec):
async def open_android_emulator_transport(spec: str | None) -> Transport:
'''
Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax:
Expand Down Expand Up @@ -66,7 +66,7 @@ async def write(self, packet):
# Parse the parameters
mode = 'host'
server_host = 'localhost'
server_port = 8554
server_port = '8554'
if spec is not None:
params = spec.split(',')
for param in params:
Expand All @@ -82,6 +82,7 @@ async def write(self, packet):
logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)

service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
if mode == 'host':
# Connect as a host
service = EmulatedBluetoothServiceStub(channel)
Expand Down
4 changes: 3 additions & 1 deletion bumble/transport/android_netsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def cleanup():


# -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(server_host, server_port):
async def open_android_netsim_controller_transport(
server_host: str | None, server_port: int
) -> Transport:
if not server_port:
raise ValueError('invalid port')
if server_host == '_' or not server_host:
Expand Down

0 comments on commit 01603ca

Please sign in to comment.