Skip to content

Commit

Permalink
Use decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
lsbardel committed Sep 2, 2023
1 parent 003c9f8 commit e01871e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
12 changes: 7 additions & 5 deletions quantflow/sp/copula.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from abc import ABC, abstractmethod
from decimal import Decimal
from math import isclose

import numpy as np
from pydantic import BaseModel, Field

from quantflow.utils.functions import debye
from quantflow.utils.numbers import ZERO
from quantflow.utils.types import FloatArray, FloatArrayLike


Expand Down Expand Up @@ -71,10 +73,10 @@ class FrankCopula(Copula):
u\right)-1\right)\left(\exp\left(-\kappa
v\right)-1\right)}{\exp\left(-\kappa\right)-1}\right]
"""
kappa: float = Field(default=0, description="Frank copula parameter")
kappa: Decimal = Field(default=ZERO, description="Frank copula parameter")

def __call__(self, u: FloatArrayLike, v: FloatArrayLike) -> FloatArrayLike:
k = self.kappa
k = float(self.kappa)
if isclose(k, 0.0):
return u * v
eu = np.exp(-k * u)
Expand All @@ -84,20 +86,20 @@ def __call__(self, u: FloatArrayLike, v: FloatArrayLike) -> FloatArrayLike:

def tau(self) -> float:
"""Kendall's tau"""
k = self.kappa
k = float(self.kappa)
if isclose(k, 0.0):
return 0
return 1 + 4 * (debye(1, k) - 1) / k

def rho(self) -> float:
"""Spearman's rho"""
k = self.kappa
k = float(self.kappa)
if isclose(k, 0.0):
return 0
return 1 - 12 * (debye(2, -k) - debye(1, -k)) / k

def jacobian(self, u: FloatArrayLike, v: FloatArrayLike) -> FloatArray:
k = self.kappa
k = float(self.kappa)
if isclose(k, 0.0):
return np.array([v, u, v * 0])
eu = np.exp(-k * u)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_copula.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from decimal import Decimal
from math import isclose

import numpy as np
Expand All @@ -13,8 +14,8 @@ def test_independent_copula():


def test_frank_copula():
c = FrankCopula(kappa=0.3)
assert c.kappa == 0.3
c = FrankCopula(kappa=Decimal("0.3"))
assert c.kappa == Decimal("0.3")
assert c.tau() > 0
assert c.rho() < 0
assert c.jacobian(0.3, 0.4).shape == (3,)
Expand Down

0 comments on commit e01871e

Please sign in to comment.