Skip to content

Commit

Permalink
Merge pull request #129 from google/gbg/smp-improvements
Browse files Browse the repository at this point in the history
improve smp compatibility with other OS flows
  • Loading branch information
barbibulle committed Feb 15, 2023
2 parents 9874bb3 + a8beb6b commit fbc7cf0
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 76 deletions.
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,10 @@
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled"
"pylint.importStrategy": "useBundled",
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
113 changes: 72 additions & 41 deletions apps/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import os
import logging
import click
import aioconsole
from colors import color
from prompt_toolkit.shortcuts import PromptSession

from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
Expand All @@ -42,9 +42,23 @@
)


# -----------------------------------------------------------------------------
class Waiter:
instance = None

def __init__(self):
self.done = asyncio.get_running_loop().create_future()

def terminate(self):
self.done.set_result(None)

async def wait_until_terminated(self):
return await self.done


# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
Expand All @@ -58,7 +72,18 @@ def __init__(self, mode, connection, capability_string, prompt):
self.mode = mode
self.peer = Peer(connection)
self.peer_name = None
self.prompt = prompt
self.do_prompt = do_prompt

def print(self, message):
print(color(message, 'yellow'))

async def prompt(self, message):
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)

session = PromptSession(message)
response = await session.prompt_async()
return response.lower().strip()

async def update_peer_name(self):
if self.peer_name is not None:
Expand All @@ -73,19 +98,15 @@ async def update_peer_name(self):
self.peer_name = '[?]'

async def accept(self):
if self.prompt:
if self.do_prompt:
await self.update_peer_name()

# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)

# Prompt for acceptance
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing request from {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing request from {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip()
response = await self.prompt('>>> Accept? ')

if response == 'yes':
return True
Expand All @@ -96,23 +117,17 @@ async def accept(self):
# Accept silently
return True

async def compare_numbers(self, number, digits=6):
async def compare_numbers(self, number, digits):
await self.update_peer_name()

# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)

# Prompt for a numeric comparison
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
response = await self.prompt(
f'>>> Does the other device display {number:0{digits}}? '
)
response = response.lower().strip()

if response == 'yes':
return True
Expand All @@ -123,30 +138,24 @@ async def compare_numbers(self, number, digits=6):
async def get_number(self):
await self.update_peer_name()

# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)

# Prompt for a PIN
while True:
try:
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow')))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
return int(await self.prompt('>>> Enter PIN: '))
except ValueError:
pass

async def display_number(self, number, digits=6):
async def display_number(self, number, digits):
await self.update_peer_name()

# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)

# Display a PIN code
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color(f'### PIN: {number:0{digits}}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print(f'### PIN: {number:0{digits}}')
self.print('###-----------------------------------')


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -238,13 +247,15 @@ def on_pairing(keys):
print(color('*** Paired!', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()


# -----------------------------------------------------------------------------
def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
Waiter.instance.terminate()


# -----------------------------------------------------------------------------
Expand All @@ -262,6 +273,8 @@ async def pair(
hci_transport,
address_or_name,
):
Waiter.instance = Waiter()

print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
Expand Down Expand Up @@ -332,7 +345,19 @@ async def pair(
# Advertise so that peers can find us and connect
await device.start_advertising(auto_restart=True)

await hci_source.wait_for_termination()
# Run until the user asks to exit
await Waiter.instance.wait_until_terminated()


# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))

def emit(self, record):
message = self.format(record)
print(message)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -388,7 +413,13 @@ def main(
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Setup logging
log_handler = LogHandler()
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())

# Pair
asyncio.run(
pair(
mode,
Expand Down
74 changes: 46 additions & 28 deletions bumble/smp.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,13 +518,15 @@ async def accept(self) -> bool:
async def confirm(self) -> bool:
return True

async def compare_numbers(self, _number: int, _digits: int = 6) -> bool:
# pylint: disable-next=unused-argument
async def compare_numbers(self, number: int, digits: int) -> bool:
return True

async def get_number(self) -> int:
return 0

async def display_number(self, _number: int, _digits: int = 6) -> None:
# pylint: disable-next=unused-argument
async def display_number(self, number: int, digits: int) -> None:
pass

async def key_distribution_response(
Expand Down Expand Up @@ -661,7 +663,8 @@ def __init__(self, manager, connection, pairing_config):
self.peer_expected_distributions = []
self.dh_key = None
self.confirm_value = None
self.passkey = 0
self.passkey = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
Expand Down Expand Up @@ -839,6 +842,7 @@ def display_passkey(self):
# Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000)
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set()

# The value of TK is computed from the PIN code
if not self.sc:
Expand All @@ -859,6 +863,8 @@ def after_input(passkey):
self.tk = passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')

self.passkey_ready.set()

if next_steps is not None:
next_steps()

Expand Down Expand Up @@ -910,17 +916,29 @@ def send_pairing_confirm_command(self):
logger.debug(f'generated random: {self.r.hex()}')

if self.sc:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return

if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
async def next_steps():
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()

z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return

if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))

self.send_command(
SMP_Pairing_Confirm_Command(confirm_value=confirm_value)
)

# Perform the next steps asynchronously in case we need to wait for input
self.connection.abort_on('disconnection', next_steps())
else:
confirm_value = crypto.c1(
self.tk,
Expand All @@ -933,7 +951,7 @@ def send_pairing_confirm_command(self):
self.ra,
)

self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))

def send_pairing_random_command(self):
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
Expand Down Expand Up @@ -1364,8 +1382,8 @@ def on_smp_pairing_response_command(self, command):

# Start phase 2
if self.sc:
if self.pairing_method == self.PASSKEY and self.passkey_display:
self.display_passkey()
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey()

self.send_public_key_command()
else:
Expand Down Expand Up @@ -1426,18 +1444,22 @@ def on_smp_pairing_random_command_legacy(self, command):
else:
srand = self.r
mrand = command.random_value
stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {stk.hex()}')
self.stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {self.stk.hex()}')

# Generate LTK
self.ltk = crypto.r()

if self.is_initiator:
self.start_encryption(stk)
self.start_encryption(self.stk)
else:
self.send_pairing_random_command()

def on_smp_pairing_random_command_secure_connections(self, command):
if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return

# pylint: disable=too-many-return-statements
if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
Expand Down Expand Up @@ -1565,17 +1587,13 @@ def on_smp_pairing_public_key_command(self, command):
logger.debug(f'DH key: {self.dh_key.hex()}')

if self.is_initiator:
if self.pairing_method == self.PASSKEY:
if self.passkey_display:
self.send_pairing_confirm_command()
else:
self.input_passkey(self.send_pairing_confirm_command)
self.send_pairing_confirm_command()
else:
# Send our public key back to the initiator
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey(self.send_public_key_command)
else:
self.send_public_key_command()
self.display_or_input_passkey()

# Send our public key back to the initiator
self.send_public_key_command()

if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
# We can now send the confirmation value
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ package_dir =
bumble.apps = apps
include-package-data = True
install_requires =
aioconsole >= 0.4.1
ansicolors >= 1.1
appdirs >= 1.4
click >= 7.1.2; platform_system!='Emscripten'
Expand Down

0 comments on commit fbc7cf0

Please sign in to comment.