Skip to content

Commit

Permalink
Merge pull request #234 from zxzxwu/addr
Browse files Browse the repository at this point in the history
Support address resolution offload
  • Loading branch information
barbibulle committed Aug 9, 2023
2 parents 53d66bc + 6399c5f commit fe28473
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 51 deletions.
56 changes: 32 additions & 24 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
HCI_LE_Extended_Create_Connection_Command,
HCI_LE_Rand_Command,
HCI_LE_Read_PHY_Command,
HCI_LE_Set_Address_Resolution_Enable_Command,
HCI_LE_Set_Advertising_Data_Command,
HCI_LE_Set_Advertising_Enable_Command,
HCI_LE_Set_Advertising_Parameters_Command,
Expand Down Expand Up @@ -778,6 +779,7 @@ def __init__(self) -> None:
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
self.gatt_services: List[Dict[str, Any]] = []
self.address_resolution_offload = False

def load_from_dict(self, config: Dict[str, Any]) -> None:
# Load simple properties
Expand Down Expand Up @@ -1029,6 +1031,7 @@ def __init__(
self.discoverable = config.discoverable
self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any
self.address_resolution_offload = config.address_resolution_offload

for service in config.gatt_services:
characteristics = []
Expand Down Expand Up @@ -1256,31 +1259,16 @@ async def power_on(self) -> None:
)

# Load the address resolving list
if self.keystore and self.host.supports_command(
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND
):
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]

resolving_keys = await self.keystore.get_resolving_keys()
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=address.address_type,
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
) # type: ignore[call-arg]
)

# Enable address resolution
# await self.send_command(
# HCI_LE_Set_Address_Resolution_Enable_Command(
# address_resolution_enable=1)
# )
# )
if self.keystore:
await self.refresh_resolving_list()

# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)
# Enable address resolution
if self.address_resolution_offload:
await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1
) # type: ignore[call-arg]
)

if self.classic_enabled:
await self.send_command(
Expand Down Expand Up @@ -1310,6 +1298,26 @@ async def power_off(self) -> None:
await self.host.flush()
self.powered_on = False

async def refresh_resolving_list(self) -> None:
assert self.keystore is not None

resolving_keys = await self.keystore.get_resolving_keys()
# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)

if self.address_resolution_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]

for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=address.address_type,
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
) # type: ignore[call-arg]
)

def supports_le_feature(self, feature):
return self.host.supports_le_feature(feature)

Expand Down
15 changes: 4 additions & 11 deletions bumble/smp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ async def on_pairing(self) -> None:
keys.link_key = PairingKeys.Key(
value=self.link_key, authenticated=authenticated
)
self.manager.on_pairing(self, peer_address, keys)
await self.manager.on_pairing(self, peer_address, keys)

def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
Expand Down Expand Up @@ -1827,20 +1827,13 @@ def request_pairing(self, connection: Connection) -> None:
def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection)

def on_pairing(
async def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:

async def store_keys():
try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys)
except Exception as error:
logger.warning(f'!!! error while storing keys: {error}')

self.device.abort_on('flush', store_keys())
await self.device.keystore.update(str(identity_address), keys)
await self.device.refresh_resolving_list()

# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
Expand Down
31 changes: 15 additions & 16 deletions tests/self_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,16 @@ def __init__(self):
),
]

self.paired = [None, None]
self.paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
]

def on_connection(self, which, connection):
self.connections[which] = connection

def on_paired(self, which, keys):
self.paired[which] = keys
def on_paired(self, which: int, keys: PairingKeys):
self.paired[which].set_result(keys)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -323,8 +326,8 @@ async def _test_self_smp_with_configs(pairing_config1, pairing_config2):
# Pair
await two_devices.devices[0].pair(connection)
assert connection.is_encrypted
assert two_devices.paired[0] is not None
assert two_devices.paired[1] is not None
assert await two_devices.paired[0] is not None
assert await two_devices.paired[1] is not None


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -527,16 +530,12 @@ async def test_self_smp_over_classic():
two_devices.connections[0].encryption = 1
two_devices.connections[1].encryption = 1

paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
]

def on_pairing(which: int, keys: PairingKeys):
paired[which].set_result(keys)

two_devices.connections[0].on('pairing', lambda keys: on_pairing(0, keys))
two_devices.connections[1].on('pairing', lambda keys: on_pairing(1, keys))
two_devices.connections[0].on(
'pairing', lambda keys: two_devices.on_paired(0, keys)
)
two_devices.connections[1].on(
'pairing', lambda keys: two_devices.on_paired(1, keys)
)

# Mock SMP
with patch('bumble.smp.Session', spec=True) as MockSmpSession:
Expand All @@ -547,7 +546,7 @@ def on_pairing(which: int, keys: PairingKeys):

# Start CTKD
await two_devices.connections[0].pair()
await asyncio.gather(*paired)
await asyncio.gather(*two_devices.paired)

# Phase 2 commands should not be invoked
MockSmpSession.send_pairing_confirm_command.assert_not_called()
Expand Down

0 comments on commit fe28473

Please sign in to comment.