Fast calculation of theta functions for critical arguments (|q|->1) #626

Open
MontyPythagoras opened this issue May 14, 2022 · 0 comments

Hi all,
Computing the Jacobi theta functions for certain arguments (|q| -> 1) is problematic because the series converge very slowly for these arguments; see the documentation of the mpmath jtheta function: https://mpmath.org/doc/current/functions/elliptic.html#jtheta.
I have written a much faster implementation of the Jacobi theta functions; the code is at the end of this post. For demonstration, I used a critical set of arguments, q = 0.599999+0.8j (|q| ≈ 0.9999994) and z = 0.5+1j. The imaginary part of z makes convergence even slower. Feel free to try other arguments. Here is a sample output of a test run:

```
n = 2
z = (0.5 + 1.0j)
q = (0.599999 + 0.8j)
Precision: 200 digits

 theta = (2.9177_(187)_30027968e+723825 - 6.4828_(187)_82484441e+723825j)   in   4.53 ms
jtheta = (2.9177_(187)_30027968e+723825 - 6.4828_(187)_82484441e+723825j)   in 13578.12 ms

Factor 2994.49
```

Both functions yield the same result. (The output is abbreviated by the helper function "shortnum", which shows only the first few digits after the decimal point and the last few digits; "_(187)_" in the middle means that 187 digits have been hidden.)
The mpmath jtheta function needs more than 13 seconds to compute this value once. The new function needs only 4.53 milliseconds, roughly 3000 times faster (tested on an 8th-generation Intel Core i7). Even at 1000 digits, the new function takes only 78 ms, while jtheta takes more than 5000 times as long (about 7 minutes).
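
A rough back-of-the-envelope estimate shows why the direct q-series is so slow for these arguments and why the reduced series is short. It is an illustration under simplifying assumptions: it models the term magnitudes of a direct summation, ignores the half-integer offsets of theta_2, and is not jtheta's actual stopping rule. The helpers `naive_terms`, `peak_log10` and `reduced_terms` are hypothetical names. `reduced_terms` mirrors the summation bound used in `jacobitheta` in the code at the end of this post.

```python
import math

def naive_terms(q_abs: float, z_imag: float, digits: int) -> int:
    """Hypothetical estimate: terms of sum q**(n*n) * exp(2j*n*z) needed for
    `digits` digits relative to the largest term.

    log|term n| ~ -a*n**2 + b*n with a = -ln|q| and b = 2*|Im z| (worst-case
    sign of n), i.e. L(n_peak) - a*(n - n_peak)**2 with n_peak = b / (2a).
    """
    a = -math.log(q_abs)
    b = 2.0 * abs(z_imag)
    c = digits * math.log(10.0)          # 10**-digits as a natural-log margin
    n_peak = b / (2.0 * a)               # index of the largest term
    return math.ceil(n_peak + math.sqrt(c / a))

def peak_log10(q_abs: float, z_imag: float) -> float:
    """Hypothetical helper: log10 of the largest term, b**2 / (4a) / ln(10)."""
    a = -math.log(q_abs)
    b = 2.0 * abs(z_imag)
    return b * b / (4.0 * a) / math.log(10.0)

def reduced_terms(prec_bits: int, tau_imag: float, frac: float = 0.0) -> int:
    """Terms summed after the reduction, using the same bound as jacobitheta:
    a = sqrt(frac**2 + ln(2)/pi * (prec_bits + 2) / Im(tau))."""
    a = math.sqrt(frac * frac + math.log(2.0) / math.pi * (prec_bits + 2) / tau_imag)
    return int(a - frac) + int(a + frac) + 1

q_abs = abs(complex(0.599999, 0.8))
print(naive_terms(q_abs, 1.0, 200))   # z = 0.5 + 1j
print(naive_terms(q_abs, 0.0, 200))   # same q, real z
print(peak_log10(q_abs, 1.0))
print(reduced_terms(668, 0.5))        # 200 digits is about 668 bits; worst case Im(tau) = 1/2
```

Under these assumptions, a direct summation at 200 digits needs about 1.7 million terms with Im z = 1, compared with roughly 28,000 for real z. The largest term is about 10^723824, close to the e+723825 exponent of the result above. After the reduction to Im τ ≥ 1/2, the final sum needs only about 35 terms.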
The example from the documentation for which jtheta raises a ValueError (n = 1, z = 10, q = 0.99999999 * exp(0.5*j)) is computed by the new function in 18 milliseconds at a precision of 1000 digits.
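
The refusal itself can be reproduced directly. This sketch only shows mpmath rejecting the documented arguments. It assumes a current mpmath release where jtheta raises ValueError when |q| exceeds an internal limit of about 1 - 10^-7, which matches the threshold the benchmark script at the end of this post uses before calling jtheta.

```python
from mpmath import mp, mpf, exp, fabs, jtheta

mp.dps = 15
# documented example: n = 1, z = 10, |q| = 0.99999999
q = mpf("0.99999999") * exp(0.5j)
print(1 - fabs(q))                  # about 1e-8, i.e. closer to 1 than 1e-7

try:
    jtheta(1, 10, q)
    raised = False
except ValueError as err:           # assumed behaviour of jtheta for |q| this close to 1
    raised = True
    print("ValueError:", err)
```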
An explanation of the algorithm can be found here:
https://matheplanet.de/matheplanet/nuke/html/viewtopic.php?rd2&topic=258531&start=0#p1877370
(The code there is slightly different: it computes the natural logarithm of theta at ordinary double precision within microseconds.)
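
For reference, the reduction in the code below can be read as repeated use of standard theta-function identities. This summary is a reading of the code, not a transcription of the linked explanation. It uses the convention of mpmath's jtheta with q = e^{iπτ}:

```math
\begin{aligned}
\theta_3(z\,|\,\tau) &= \sum_{n=-\infty}^{\infty} e^{i\pi n^2\tau + 2inz}, \qquad q = e^{i\pi\tau},\\
\theta_2(z\,|\,\tau) &= e^{i\pi\tau/4 + iz}\,\theta_3\!\left(z + \tfrac{\pi\tau}{2}\,\middle|\,\tau\right),\\
\theta_1(z\,|\,\tau) &= \theta_2\!\left(z - \tfrac{\pi}{2}\,\middle|\,\tau\right), \qquad
\theta_4(z\,|\,\tau) = \theta_3\!\left(z - \tfrac{\pi}{2}\,\middle|\,\tau\right),\\
\theta_3(z + m\pi\tau\,|\,\tau) &= e^{-i\pi m^2\tau - 2imz}\,\theta_3(z\,|\,\tau),\\
\theta_3(z\,|\,\tau + 1) &= \theta_3\!\left(z + \tfrac{\pi}{2}\,\middle|\,\tau\right),\\
\theta_3(z\,|\,\tau) &= (-i\tau)^{-1/2}\, e^{-iz^2/(\pi\tau)}\,\theta_3\!\left(\tfrac{z}{\tau}\,\middle|\,-\tfrac{1}{\tau}\right).
\end{aligned}
```

The `while` loop first shifts τ so that |Re τ| ≤ 1/2. If Im τ < 1/2 as well, then |τ|² < 1/2, so Im(−1/τ) = Im τ / |τ|² > 2 Im τ. Im τ therefore at least doubles per iteration. For |q| = 1 − ε (so Im τ = −ln|q|/π ≥ ε/π), at most about log₂(π/(2ε)) iterations are needed before the short final sum.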

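```python
from mpmath import sqrt, pi, log, expjpi, nan, mpc, fabs, mp, nint, jtheta, nstr
from time import process_time as prtm

def jacobitheta(n, z, q = None, tau = None):
    # Jacobi theta function theta_n(z, q), n = 1..4, in the same convention as
    # mpmath's jtheta(n, z, q). Alternatively pass tau, where q = exp(i*pi*tau).
    if q == 0: return 0 if n < 3 else 1
    currprec = mp.prec
    if tau is None:
        if fabs(q) >= 1: return nan
        # guard bits: more are needed as |q| -> 1 and as |Im z| grows
        mp.prec += 12 + round(3 * log((1 + fabs(z.imag)) / (1 - fabs(q))))
        tau = -1j * log(q) / pi
    else:
        if tau.imag <= 0: return nan
        mp.prec += 10 + round(3 * log(1 + (1 + fabs(z.imag)) / tau.imag))
        # q = exp(i*pi*tau) is unchanged by tau -> tau - 2k
        tau -= 2 * nint(0.5 * tau.real)

    # From here on z is measured in units of pi. The result is built up as
    # f * expjpi(g) * theta_3(pi*z | tau); every transformation updates f and g.
    z /= pi
    f = 1
    # theta_1(z) = theta_2(z - pi/2) and theta_4(z) = theta_3(z - pi/2)
    if n == 1 or n == 4: z -= 0.5
    if n < 3:
        # theta_2(z|tau) = exp(i*(pi*tau/4 + z)) * theta_3(z + pi*tau/2 | tau)
        g = z + 0.25 * tau
        z += 0.5 * tau
    else:
        g = 0

    # quasi-periodicity: subtract a multiple of tau from z to make |Im z| small
    n = nint(z.imag / tau.imag, prec = 0)
    if n != 0:
        a = z
        z -= n * tau
        g -= n * (a + z)

    # modular reduction: tau -> tau - n, then tau -> -1/tau, until Im(tau) >= 1/2
    while tau.imag < 0.5:
        n = nint(tau.real, prec = 0)
        z -= 0.5 * n                  # tau -> tau - n shifts z by -n/2 (units of pi)
        z -= nint(z.real, prec = 0)   # theta_3 has period 1 in z (units of pi)
        tau = -1 / (tau - n)
        # Jacobi imaginary transformation: the factor exp(-i*pi*z^2/tau_old) is
        # expjpi(a * z) with the new z = a * tau, and
        # (-i*tau_old)^(-1/2) = exp(-i*pi/4) * sqrt(tau_new)
        a = z
        z *= tau
        g += a * z - 0.25
        f *= sqrt(tau)

    # second quasi-periodic shift; q is reused for Im(z)/Im(tau), |q| <= 1/2 afterwards
    q = z.imag / tau.imag
    n = nint(q, prec = 0)
    if n != 0:
        q -= n
        a = z
        z -= n * tau
        g -= n * (a + z)

    # sum the now rapidly converging series; terms with
    # (n + q)^2 > q^2 + ln(2)/pi * (currprec + 2) / Im(tau) are below 2^-(currprec + 2)
    s = 0
    z *= 2
    a = sqrt(q * q + 0.2206356 * (currprec + 2) / tau.imag) # 0.2206356 == ln(2)/pi
    for n in range(-int(a + q), int(a - q) + 1): s += expjpi(n * (tau * n + z))
    s *= f * expjpi(g)
    mp.prec = currprec
    return s
# End of function definition

def looptime(f, a, t_test):
    # Returns (f(*a), seconds per call). A short calibration loop sizes the timing
    # loop so that each of 5 repetitions takes about t_test / 5 seconds; the fastest
    # repetition is used. Slow functions are timed by the calibration run alone.
    l = 0
    t = prtm()
    while prtm() - t < min(0.5, 0.05 * t_test):
        r = f(* a)
        l += 1
    t_min = prtm() - t
    if t_min < 0.3 * t_test:
        l = int(t_test / t_min * l * 0.2) + 1
        t_min = 1e100
        for m in range(5):
            t = prtm()
            for k in range(l):
                r = f(* a)
            t = prtm() - t
            t_min = min(t_min, t)
    return r, t_min / l

def shortnum(x, left, right):
    # Abbreviates a real or complex number for display: keeps `left` digits after
    # the decimal point and the last `right` digits of the mantissa; the digits in
    # between are replaced by "…(count)…".
    a = nstr(x, mp.dps, strip_zeros=False)
    if a.find("j") > -1:
        l1 = a.find(" ")
        a, b = a[:l1], a[l1:]
        l1 = a.find(".") + 1 + left
        l2 = a.find("e")
        l2 = l2 + len(a) + 1 - right if l2 < 0 else l2 - right
        r = a if l2 - l1 < 6 else a[:l1] + "…(" + str(l2 - l1) + ")…" + a[l2:]
        l1 = b.find(".") + 1 + left
        l2 = b.find("e")
        l2 = l2 + len(b) - 1 - right if l2 < 0 else l2 - right
        r = r + b if l2 - l1 < 6 else r + b[:l1] + "…(" + str(l2 - l1) + ")…" + b[l2:]
    else:
        l1 = a.find(".") + 1 + left
        l2 = a.find("e")
        l2 = l2 + len(a) + 1 - right if l2 < 0 else l2 - right
        r = a if l2 - l1 < 6 else a[:l1] + "…(" + str(l2 - l1) + ")…" + a[l2:]
    return r

mp.dps = 200
t_test = 5  # target test duration in seconds

# create the arguments with 100 extra digits, then restore the working precision
mp.dps += 100
a = 2, mpc("0.5", "1"), mpc("0.599999", "0.8")
mp.dps -= 100

print("n = {:d}\nz = {}\nq = {}".format(a[0], nstr(a[1], 15), nstr(a[2], 15)))
print("Precision: {} digits\n".format(mp.dps))

r, t_2 = looptime(jacobitheta, a, t_test)
print(" theta = {}   in {:6.2f} ms".format(shortnum(r, 6, 10), t_2 * 1000))

# skip jtheta when |q| is beyond its limit of about 1 - 1e-7
if 1-fabs(a[2]) > 1e-7:
    r, t_1 = looptime(jtheta, a, t_test)
    print("jtheta = {}   in {:6.2f} ms".format(shortnum(r, 6, 10), t_1 * 1000))

    print("\nFactor {:.2f}".format(t_1 / t_2))
else:
    print("jtheta limit exceeded")
```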
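
The `while` loop in `jacobitheta` depends on the Jacobi imaginary transformation, and the theta_2 branch depends on a shift by πτ/2. Both identities can be spot-checked against mpmath's own jtheta at a nome where jtheta converges quickly. This is a sketch with some illustrative choices:

- The helper names `theta3_tau` and `theta3_modular` are hypothetical.
- The test point τ = 0.3+0.4i, z = 0.5+0.2i is arbitrary.
- The theta_2 check assumes Re τ lies in (−1, 1], so that the principal branch of q^(1/4) agrees with e^(iπτ/4).

```python
from mpmath import mp, mpc, jtheta, expjpi, exp, sqrt, pi, fabs

mp.dps = 30

def theta3_tau(z, tau):
    # theta_3(z | tau) evaluated by mpmath, with nome q = exp(i*pi*tau)
    return jtheta(3, z, expjpi(tau))

def theta3_modular(z, tau):
    # right-hand side of theta_3(z|tau) = (-i*tau)^(-1/2) * exp(-i*z^2/(pi*tau)) * theta_3(z/tau | -1/tau)
    return exp(-1j * z**2 / (pi * tau)) * theta3_tau(z / tau, -1 / tau) / sqrt(-1j * tau)

tau = mpc("0.3", "0.4")   # arbitrary test point with Im(tau) > 0 and Re(tau) in (-1, 1]
z = mpc("0.5", "0.2")

lhs = theta3_tau(z, tau)
err_modular = fabs(lhs - theta3_modular(z, tau))

# theta_2(z|tau) = exp(i*(pi*tau/4 + z)) * theta_3(z + pi*tau/2 | tau)
q = expjpi(tau)
err_theta2 = fabs(jtheta(2, z, q) - exp(1j * (pi * tau / 4 + z)) * jtheta(3, z + pi * tau / 2, q))

print(lhs)
print(err_modular, err_theta2)
```

Both differences should be at the level of the working precision. The point is chosen so that jtheta converges quickly on both sides of the transformation (|q| ≈ 0.285 and, after τ -> −1/τ, |q| ≈ 0.0066).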