Skip to content

Commit

Permalink
optimize key derivation in the wallet (#17991)
Browse files Browse the repository at this point in the history
optimize key derivation in the wallet. instead of deriving the same keys for every wallet, derive the keys once and re-use them for the derived puzzle hashes per wallet
  • Loading branch information
arvidn committed May 13, 2024
1 parent 98df4ce commit 41306ad
Showing 1 changed file with 72 additions and 57 deletions.
129 changes: 72 additions & 57 deletions chia/wallet/wallet_state_manager.py
Expand Up @@ -417,86 +417,101 @@ async def create_more_puzzle_hashes(
self.log.debug(f"Requested to generate puzzle hashes to at least index {unused}")
start_t = time.time()
to_generate = num_additional_phs if num_additional_phs is not None else self.initial_num_public_keys
new_paths: bool = False

# iterate all wallets that need derived keys and establish the start
# index for all of them
start_index: int = 0
start_index_by_wallet: Dict[uint32, int] = {}
last_index = unused + to_generate
for wallet_id in targets:
target_wallet = self.wallets[wallet_id]
if not target_wallet.require_derivation_paths():
self.log.debug("Skipping wallet %s as no derivation paths required", wallet_id)
continue
if from_zero:
start_index_by_wallet[wallet_id] = 0
continue
last: Optional[uint32] = await self.puzzle_store.get_last_derivation_path_for_wallet(wallet_id)
self.log.debug(
"Fetched last record for wallet %r: %s (from_zero=%r, unused=%r)", wallet_id, last, from_zero, unused
)
start_index = 0
derivation_paths: List[DerivationRecord] = []

if last is not None:
start_index = last + 1

# If the key was replaced (from_zero=True), we should generate the puzzle hashes for the new key
if from_zero:
start_index = 0
last_index = unused + to_generate
if start_index >= last_index:
self.log.debug(f"Nothing to create for for wallet_id: {wallet_id}, index: {start_index}")
if last + 1 >= last_index:
self.log.debug(f"Nothing to create for for wallet_id: {wallet_id}, index: {start_index}")
continue
start_index = min(start_index, last + 1)
start_index_by_wallet[wallet_id] = last + 1
else:
creating_msg = (
f"Creating puzzle hashes from {start_index} to {last_index - 1} for wallet_id: {wallet_id}"
)
self.log.info(f"Start: {creating_msg}")
if self.private_key is not None:
intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key)
intermediate_pk_un = master_pk_to_wallet_pk_unhardened_intermediate(self.root_pubkey)
for index in range(start_index, last_index):
if target_wallet.type() == WalletType.POOLING_WALLET:
continue
start_index_by_wallet[wallet_id] = 0

if self.private_key is not None:
# Hardened
pubkey: G1Element = _derive_path(intermediate_sk, [index]).get_g1()
puzzlehash: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey)
self.log.debug(f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash.hex()}")
new_paths = True
derivation_paths.append(
DerivationRecord(
uint32(index),
puzzlehash,
pubkey,
target_wallet.type(),
uint32(target_wallet.id()),
True,
)
)
# Unhardened
pubkey_unhardened: G1Element = _derive_pk_unhardened(intermediate_pk_un, [index])
puzzlehash_unhardened: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey_unhardened)
self.log.debug(
f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash_unhardened.hex()}"
)
# We await sleep here to allow an asyncio context switch (since the other parts of this loop do
# not have await and therefore block). This can prevent networking layer from responding to ping.
await asyncio.sleep(0)
if len(start_index_by_wallet) == 0:
return

# now derive the keysfrom start_index to last_index
# these maps derivation index to public key
hardened_keys: Dict[int, G1Element] = {}
unhardened_keys: Dict[int, G1Element] = {}

if self.private_key is not None:
# Hardened
intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key)
for index in range(start_index, last_index):
hardened_keys[index] = _derive_path(intermediate_sk, [index]).get_g1()

# Unhardened
intermediate_pk_un = master_pk_to_wallet_pk_unhardened_intermediate(self.root_pubkey)
for index in range(start_index, last_index):
unhardened_keys[index] = _derive_pk_unhardened(intermediate_pk_un, [index])

for wallet_id, start_index in start_index_by_wallet.items():
target_wallet = self.wallets[wallet_id]
assert target_wallet.type() != WalletType.POOLING_WALLET
assert start_index < last_index

derivation_paths: List[DerivationRecord] = []
creating_msg = f"Creating puzzle hashes from {start_index} to {last_index - 1} for wallet_id: {wallet_id}"
self.log.info(f"Start: {creating_msg}")
for index in range(start_index, last_index):
pubkey: Optional[G1Element] = hardened_keys.get(index)
if pubkey is not None:
# Hardened
puzzlehash: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey)
self.log.debug(f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash.hex()}")
derivation_paths.append(
DerivationRecord(
uint32(index),
puzzlehash_unhardened,
pubkey_unhardened,
puzzlehash,
pubkey,
target_wallet.type(),
uint32(target_wallet.id()),
False,
True,
)
)
self.log.info(f"Done: {creating_msg} Time: {time.time() - start_t} seconds")
await self.puzzle_store.add_derivation_paths(derivation_paths)
# Unhardened
pubkey = unhardened_keys.get(index)
assert pubkey is not None
puzzlehash_unhardened: bytes32 = target_wallet.puzzle_hash_for_pk(pubkey)
self.log.debug(
f"Puzzle at index {index} wallet ID {wallet_id} puzzle hash {puzzlehash_unhardened.hex()}"
)
derivation_paths.append(
DerivationRecord(
uint32(index),
puzzlehash_unhardened,
pubkey,
target_wallet.type(),
uint32(target_wallet.id()),
False,
)
)
self.log.info(f"Done: {creating_msg} Time: {time.time() - start_t} seconds")
if len(derivation_paths) > 0:
await self.puzzle_store.add_derivation_paths(derivation_paths)
if wallet_id == self.main_wallet.id():
await self.wallet_node.new_peak_queue.subscribe_to_puzzle_hashes(
[record.puzzle_hash for record in derivation_paths]
)
self.state_changed("new_derivation_index", data_object={"index": derivation_paths[-1].index})
if len(unhardened_keys) > 0:
self.state_changed("new_derivation_index", data_object={"index": last_index - 1})
# By default, we'll mark previously generated unused puzzle hashes as used if we have new paths
if mark_existing_as_used and unused > 0 and new_paths:
if mark_existing_as_used and unused > 0 and len(unhardened_keys) > 0:
self.log.info(f"Updating last used derivation index: {unused - 1}")
await self.puzzle_store.set_used_up_to(uint32(unused - 1))

Expand Down

0 comments on commit 41306ad

Please sign in to comment.