From 3556d2239fa75294454bcaac677d399d37b8f696 Mon Sep 17 00:00:00 2001 From: Trevor Richard Date: Fri, 22 Mar 2024 15:27:45 +0000 Subject: [PATCH] revert when asset does not have decimals --- src/PrizeVault.sol | 10 +++++++++- test/unit/PrizeVault/PrizeVault.t.sol | 23 ++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/PrizeVault.sol b/src/PrizeVault.sol index 7e57e1a..99d4ef3 100644 --- a/src/PrizeVault.sol +++ b/src/PrizeVault.sol @@ -259,6 +259,10 @@ contract PrizeVault is TwabERC20, Claimable, IERC4626, ILiquidationSource, Ownab /// @param minAssets The min asset threshold requested error MinAssetsNotReached(uint256 assets, uint256 minAssets); + /// @notice Thrown when the underlying asset does not specify it's number of decimals. + /// @param asset The underlying asset that was checked + error FailedToGetAssetDecimals(address asset); + //////////////////////////////////////////////////////////////////////////////// // Modifiers //////////////////////////////////////////////////////////////////////////////// @@ -309,7 +313,11 @@ contract PrizeVault is TwabERC20, Claimable, IERC4626, ILiquidationSource, Ownab IERC20 asset_ = IERC20(yieldVault_.asset()); (bool success, uint8 assetDecimals) = _tryGetAssetDecimals(asset_); - _underlyingDecimals = success ? assetDecimals : 18; + if (success) { + _underlyingDecimals = assetDecimals; + } else { + revert FailedToGetAssetDecimals(address(asset_)); + } _asset = asset_; yieldVault = yieldVault_; diff --git a/test/unit/PrizeVault/PrizeVault.t.sol b/test/unit/PrizeVault/PrizeVault.t.sol index 53b9d58..7d9b966 100644 --- a/test/unit/PrizeVault/PrizeVault.t.sol +++ b/test/unit/PrizeVault/PrizeVault.t.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; -import { UnitBaseSetup, PrizePool, TwabController, ERC20, IERC20, IERC4626 } from "./UnitBaseSetup.t.sol"; +import { UnitBaseSetup, PrizePool, TwabController, ERC20, IERC20, IERC4626, YieldVault } from "./UnitBaseSetup.t.sol"; import { IVaultHooks, VaultHooks } from "../../../src/interfaces/IVaultHooks.sol"; import { ERC20BrokenDecimalMock } from "../../contracts/mock/ERC20BrokenDecimalMock.sol"; @@ -180,6 +180,27 @@ contract PrizeVaultTest is UnitBaseSetup { assertEq(decimals, 0); } + function testConstructorFailsWhenDecimalFails() public { + IERC20 brokenDecimalToken = new ERC20BrokenDecimalMock(); + YieldVault brokenDecimalYieldVault = new YieldVault( + address(brokenDecimalToken), + "Test Yield Vault", + "yvTest" + ); + vm.expectRevert(abi.encodeWithSelector(PrizeVault.FailedToGetAssetDecimals.selector, address(brokenDecimalToken))); + new PrizeVault( + "PoolTogether Decimal Fail", + "pDecFail", + brokenDecimalYieldVault, + PrizePool(address(prizePool)), + address(this), + address(this), + YIELD_FEE_PERCENTAGE, + 1e6, + address(this) + ); + } + /* ============ maxDeposit / maxMint ============ */ function testMaxDeposit_SubtractsLatentBalance() public {