Skip to content

Commit

Permalink
added fee, eth support, tests and small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pum committed Aug 13, 2023
1 parent e8ccb17 commit 5b08aab
Show file tree
Hide file tree
Showing 11 changed files with 1,767 additions and 53 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/node.js.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Node.js CI

on: push

jobs:
build:
runs-on: ubuntu-latest

strategy:
matrix:
node-version: [16.x]
# See supported Node.js release schedule at https://nodejs.org/en/about/releases/

steps:
- uses: actions/checkout@v2
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v2
with:
node-version: ${{ matrix.node-version }}
cache: 'yarn'
- name: Install dependencies
run: yarn install --frozen-lockfile
- name: Run tests
run: npx hardhat test
5 changes: 4 additions & 1 deletion contracts/IVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ pragma solidity ^0.8.0;
import "./IVaultsFactory.sol";

interface IVault {
function initialize(address underlyingToken_, IVaultsFactory factory_) external;
function initialize(address underlyingToken_, IVaultsFactory factory_, bool isEth, string memory name_, string memory symbol_) external;
function emergencyWithdraw(address to_, uint256 amount_) external;

// must return keccak256("Vaults.Vault") ^ bytes32(uint256(uint160(address(VaultsFactory))))
function isVault() external view returns (bytes32);
}
7 changes: 6 additions & 1 deletion contracts/IVaultsFactory.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

import "./IVault.sol";

interface IVaultsFactory {
function feeReceiver() external view returns(address);
function feeBasisPoints() external view returns (uint256);

function unwrapDelay() external view returns (uint256);
function isPaused(address vault) external view returns (bool);
function isPaused(IVault vault) external view returns (bool);
}
7 changes: 7 additions & 0 deletions contracts/IWETH.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

interface IWETH {
function deposit() external payable;
function withdraw(uint amount) external;
}
104 changes: 81 additions & 23 deletions contracts/VaultImplementation.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@ pragma solidity ^0.8.0;

import "@openzeppelin/contracts-upgradeable/token/ERC20/ERC20Upgradeable.sol";
import "@openzeppelin/contracts-upgradeable/token/ERC20/utils/SafeERC20Upgradeable.sol";
import "@openzeppelin/contracts-upgradeable/security/ReentrancyGuardUpgradeable.sol";
import "./IVaultsFactory.sol";
import "./IVault.sol";
import "./IWETH.sol";


contract VaultImplementation is IVault, Initializable, ERC20Upgradeable {
contract VaultImplementation is IVault, Initializable, ERC20Upgradeable, ReentrancyGuardUpgradeable {
using SafeERC20Upgradeable for IERC20MetadataUpgradeable;

IERC20MetadataUpgradeable public underlyingToken;
IVaultsFactory public factory;
bool public isEth;

bool public emergency = false;

struct PendingUnwrap {
uint256 amount;
Expand All @@ -22,66 +27,119 @@ contract VaultImplementation is IVault, Initializable, ERC20Upgradeable {
mapping(address => PendingUnwrap) public pendingUnwraps;

event Wrapped(address indexed user, uint256 amount);
event Unwrapped(address indexed user, uint256 amount);
event UnwrapRequested(address indexed user, uint256 amount);
event Claimed(address indexed user, uint256 amount);
event UnwrapCancelled(address indexed user, uint256 amount);

modifier notPaused() {
require(!factory.isPaused(address(this)), "Operation is paused");
require(!factory.isPaused(this), "VAULTS: OPERATION_PAUSED");
require(!emergency, "VAULTS: OPERATION_PAUSED_EMERGENCY");
_;
}

function initialize(address underlyingTokenAddress_, IVaultsFactory factory_) public initializer {
function initialize(address underlyingTokenAddress_, IVaultsFactory factory_, bool isEth_, string memory name_, string memory symbol_) public initializer {
underlyingToken = IERC20MetadataUpgradeable(underlyingTokenAddress_);
factory = factory_;
isEth = isEth_;
__ERC20_init(
string(abi.encodePacked("Vaulted ", underlyingToken.name())),
string(abi.encodePacked("v", underlyingToken.symbol()))
bytes(name_).length != 0 ? name_ : string(abi.encodePacked("Vaulted ", underlyingToken.symbol())),
bytes(symbol_).length != 0 ? symbol_ : string(abi.encodePacked("v", underlyingToken.symbol()))
);
__ReentrancyGuard_init();
}

// only accept ETH via fallback from the WETH contract
receive() external payable {
require(isEth, "VAULTS: NOT_ETHER");
require(msg.sender == address(underlyingToken), "VAULTS: RESTRICTED");
}

function decimals() public view virtual override returns (uint8) {
return underlyingToken.decimals();
}

function wrap(uint256 amount_) external notPaused {
function wrapEther() public payable notPaused nonReentrant {
require(isEth, "VAULTS: NOT_ETHER");

IWETH(address(underlyingToken)).deposit{value: msg.value}();

_wrap(msg.value);
}

function wrap(uint256 amount_) public notPaused nonReentrant {
underlyingToken.safeTransferFrom(msg.sender, address(this), amount_);
_mint(msg.sender, amount_);
emit Wrapped(msg.sender, amount_);

_wrap(amount_);
}

function _wrap(uint256 amount_) internal {
require(amount_ > 0, "VAULTS: INVALID_AMOUNT");

uint256 fee = (amount_ * factory.feeBasisPoints()) / 10000;
uint256 afterFeeAmount = amount_ - fee;

if (fee > 0) {
address feeReceiver = factory.feeReceiver();
require(feeReceiver != address(0), "VAULTS: FEE_RECEIVER_NOT_SET");
underlyingToken.safeTransfer(feeReceiver, fee);
}

_mint(msg.sender, afterFeeAmount);
emit Wrapped(msg.sender, afterFeeAmount);
}

function unwrap(uint256 amount_) external notPaused {
require(amount_ > 0, "Amount should be greater than 0");
require(balanceOf(msg.sender) >= amount_, "Insufficient balance to unwrap");
function unwrap(uint256 amount_) external notPaused nonReentrant {
require(amount_ > 0, "VAULTS: INVALID_AMOUNT");
require(balanceOf(msg.sender) >= amount_, "VAULTS: INSUFFICIENT_BALANCE");
pendingUnwraps[msg.sender] = PendingUnwrap(amount_, block.timestamp);
emit Unwrapped(msg.sender, amount_);
emit UnwrapRequested(msg.sender, amount_);
}

function claim() external notPaused {
require(pendingUnwraps[msg.sender].amount > 0, "No unwrap requested");
require(block.timestamp >= pendingUnwraps[msg.sender].timestamp + factory.unwrapDelay(), "Delay has not passed yet");
function claim() external notPaused nonReentrant {
require(pendingUnwraps[msg.sender].amount > 0, "VAULTS: NO_UNWRAP_REQUESTED");
require(block.timestamp >= pendingUnwraps[msg.sender].timestamp + factory.unwrapDelay(), "VAULTS: UNWRAP_DELAY_NOT_MET");

uint256 amount = pendingUnwraps[msg.sender].amount;
delete pendingUnwraps[msg.sender];

_burn(msg.sender, amount);
underlyingToken.safeTransfer(msg.sender, amount);

if (isEth) {
IWETH(address(underlyingToken)).withdraw(amount);
(bool success, ) = msg.sender.call{value:amount}("");
require(success, "VAULTS: ETH_TRANSFER_FAILED");
} else {
underlyingToken.safeTransfer(msg.sender, amount);
}

emit Claimed(msg.sender, amount);
}

function cancelUnwrap() external notPaused {
require(pendingUnwraps[msg.sender].amount > 0, "No unwrap requested to cancel");
function cancelUnwrap() external notPaused nonReentrant {
require(pendingUnwraps[msg.sender].amount > 0, "VAULTS: NO_UNWRAP_TO_CANCEL");
uint256 amount = pendingUnwraps[msg.sender].amount;
delete pendingUnwraps[msg.sender];
emit UnwrapCancelled(msg.sender, amount);
}

function emergencyWithdraw(address to_, uint256 amount_) external {
require(factory.isPaused(address(this)), "Vault is not paused");
require(msg.sender == address(factory), "Only VaultsFactory can perform emergency withdrawal");
require(to_ != address(0), "Zero address not allowed");
function emergencyWithdraw(address to_, uint256 amount_) external nonReentrant {
require(factory.isPaused(this), "VAULTS: NOT_PAUSED");
require(msg.sender == address(factory), "VAULTS: NOT_FACTORY_ADDRESS");
require(to_ != address(0), "VAULTS: ZERO_ADDRESS");

emergency = true;

uint256 withdrawalAmount = (amount_ == 0) ? underlyingToken.balanceOf(address(this)) : amount_;
underlyingToken.safeTransfer(to_, withdrawalAmount);
}

function _beforeTokenTransfer(address from, address /* to */, uint256 amount) internal view override {
if (from != address(0) && pendingUnwraps[from].amount > 0) {
require(balanceOf(from) >= amount + pendingUnwraps[from].amount, "VAULTS: TRANSFER_EXCEEDS_BALANCE");
}
}

function isVault() public view returns (bytes32) {
return keccak256("Vaults.Vault") ^ bytes32(uint256(uint160(address(factory))));
}
}
76 changes: 57 additions & 19 deletions contracts/VaultsFactory.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,54 +7,75 @@ import "./IVaultsFactory.sol";
import "./IVault.sol";

contract VaultsFactory is IVaultsFactory, AccessControlEnumerable {
address public immutable weth;

address public vaultsImplementation;
uint256 public unwrapDelay;

mapping(address => bool) public pausedVaults;
address public feeReceiver;
uint256 public feeBasisPoints;

mapping(IVault => bool) public pausedVaults;
bool public allVaultsPaused = false;

// Role identifiers for pausing, deploying, and admin actions
bytes32 public constant PAUSE_ROLE = keccak256("PAUSE_ROLE");
bytes32 public constant UNPAUSE_ROLE = keccak256("UNPAUSE_ROLE");
bytes32 public constant DEPLOY_ROLE = keccak256("DEPLOY_ROLE");

event VaultDeployed(address vaultAddress);
event VaultPaused(address vaultAddress);
event VaultUnpaused(address vaultAddress);
event VaultDeployed(IVault vaultAddress);
event VaultPaused(IVault vaultAddress);
event VaultUnpaused(IVault vaultAddress);
event AllVaultsPaused();
event AllVaultsUnpaused();

modifier isVault(IVault vault_) {
try vault_.isVault() returns (bytes32 result) {
require(vault_.isVault() == keccak256("Vaults.Vault") ^ bytes32(uint256(uint160(address(this)))), "VAULTS: NOT_VAULT");
} catch {
revert("VAULTS: NOT_VAULT");
}
_;
}

constructor(
address weth_,
address vaultsImplementationAddress_,
uint256 unwrapDelay_,
address rolesAddr_
address rolesAddr_,
address initialFeeReceiver_,
uint256 initialFeeBasisPoints_
) {
weth = weth_;
vaultsImplementation = vaultsImplementationAddress_;
unwrapDelay = unwrapDelay_;

_setupRole(DEFAULT_ADMIN_ROLE, rolesAddr_);
_setupRole(PAUSE_ROLE, rolesAddr_);
_setupRole(UNPAUSE_ROLE, rolesAddr_);
_setupRole(DEPLOY_ROLE, rolesAddr_);

_setFeeReceiver(initialFeeReceiver_);
_setFeeBasisPoints(initialFeeBasisPoints_);
}

function deployVault(address underlyingToken_) external onlyRole(DEPLOY_ROLE) returns (address) {
function deployVault(address underlyingToken_, string memory name_, string memory symbol_) external onlyRole(DEPLOY_ROLE) returns (IVault result) {
TransparentUpgradeableProxy proxy = new TransparentUpgradeableProxy(
vaultsImplementation,
msg.sender,
getRoleMember(DEFAULT_ADMIN_ROLE, 0),
""
);
IVault(address(proxy)).initialize(underlyingToken_, this);
emit VaultDeployed(address(proxy));
return address(proxy);
result = IVault(address(proxy));
result.initialize(underlyingToken_, this, underlyingToken_ == weth, name_, symbol_);
emit VaultDeployed(result);
}

function pauseVault(address vaultAddress_) external onlyRole(PAUSE_ROLE) {
pausedVaults[vaultAddress_] = true;
emit VaultPaused(vaultAddress_);
function pauseVault(IVault vault_) external onlyRole(PAUSE_ROLE) isVault(vault_) {
pausedVaults[vault_] = true;
emit VaultPaused(vault_);
}

function unpauseVault(address vaultAddress_) external onlyRole(UNPAUSE_ROLE) {
function unpauseVault(IVault vaultAddress_) external onlyRole(UNPAUSE_ROLE) isVault(vaultAddress_) {
delete pausedVaults[vaultAddress_];
emit VaultUnpaused(vaultAddress_);
}
Expand All @@ -69,7 +90,7 @@ contract VaultsFactory is IVaultsFactory, AccessControlEnumerable {
emit AllVaultsUnpaused();
}

function isPaused(address vaultAddress_) public view returns (bool) {
function isPaused(IVault vaultAddress_) public view returns (bool) {
return allVaultsPaused || pausedVaults[vaultAddress_];
}

Expand All @@ -78,11 +99,28 @@ contract VaultsFactory is IVaultsFactory, AccessControlEnumerable {
}

function setVaultsImplementation(address vaultsImplementation_) external onlyRole(DEFAULT_ADMIN_ROLE) {
require(vaultsImplementation_ != address(0), "Zero address");
require(vaultsImplementation_ != address(0), "VAULTS: ZERO_ADDRESS");
vaultsImplementation = vaultsImplementation_;
}

function emergencyWithdrawFromVault(address vaultAddress, address to_, uint256 amount_) external onlyRole(DEFAULT_ADMIN_ROLE) {
IVault(vaultAddress).emergencyWithdraw(to_, amount_);
function emergencyWithdrawFromVault(IVault vaultAddress_, address to_, uint256 amount_) external onlyRole(DEFAULT_ADMIN_ROLE) isVault(vaultAddress_) {
vaultAddress_.emergencyWithdraw(to_, amount_);
}

function _setFeeReceiver(address feeReceiver_) internal {
feeReceiver = feeReceiver_;
}

function _setFeeBasisPoints(uint256 feeBasisPoints_) internal {
require(feeBasisPoints_ <= 10000, "VAULTS: EXCESSIVE_FEE_PERCENT"); // Max of 10000 basis points
feeBasisPoints = feeBasisPoints_;
}

function setFeeReceiver(address feeReceiver_) external onlyRole(DEFAULT_ADMIN_ROLE) {
_setFeeReceiver(feeReceiver_);
}

function setFeeBasisPoints(uint256 feeBasisPoints_) external onlyRole(DEFAULT_ADMIN_ROLE) {
_setFeeBasisPoints(feeBasisPoints_);
}
}
}
19 changes: 14 additions & 5 deletions contracts/tests-helpers/MockERC20.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,24 @@ pragma solidity ^0.8.0;
import "@openzeppelin/contracts/token/ERC20/ERC20.sol";

contract MockERC20 is ERC20 {

uint8 internal _decimals;

constructor(
string memory name,
string memory symbol,
uint256 initialSupply
) ERC20(name, symbol) {
_mint(msg.sender, initialSupply);
string memory name_,
string memory symbol_,
uint8 decimals_,
uint256 initialSupply_
) ERC20(name_, symbol_) {
_mint(msg.sender, initialSupply_);
_decimals = decimals_;
}

function mint(address to, uint256 amount) external {
_mint(to, amount);
}

function decimals() public view virtual override returns (uint8) {
return _decimals;
}
}

0 comments on commit 5b08aab

Please sign in to comment.