Optimizing 6-step FFT algorithm #967

Open
NimaSarajpoor opened this issue May 13, 2024 · 9 comments
Labels
enhancement New feature or request


@NimaSarajpoor
Collaborator

This issue is to optimize the 6-step FFT algorithm, as initially discussed in #938. We will try to improve the performance of each block of the algorithm. Each result MUST come with end-to-end code, and the code MUST include assertions to make sure the output is correct.

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented May 13, 2024

RFFT (Real FFT)

We start with the RFFT algorithm, which uses the 6-step FFT algorithm as provided here. We test it by comparing its output against numpy.fft.rfft (scipy.fft.rfft is used as the timing baseline), and we show its performance.

Code

# saved in file: rfft_optimized_v0.py

import math
import numpy as np
from numba import njit


@njit(fastmath=True)
def _fft_block(n, s, eo, x, y):
    """
    A recursive function that is used as part of fft algorithm

    n : int
    s : int
    eo: bool
    x : numpy.array 1D
    y : numpy.array 1D
    """
    if n == 2:
        if eo:
            z = y
        else:
            z = x

        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b

    elif n >= 4:
        m = n // 2
        sm = s * m

        theta = math.pi / m
        for p in range(m):
            sp = s * p
            w = math.cos(p * theta) - 1j * math.sin(p * theta)
            for q in range(s):
                idx = sp + q
                a = x[idx]
                b = x[idx + sm]

                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w

        _fft_block(m, 2*s, not eo, y, x)

    else:
        pass


@njit(fastmath=True)
def _eightstep_fft(x, y):
    """
    Apply 8-step FFT algorithm and update x in-place.
    """
    n = len(x)
    m = n // 2

    theta = math.pi / m 
    for i in range(m):
        w = math.cos(i * theta) - 1j * math.sin(i * theta)
        j = i + m
        y[i] = x[i] + x[j]
        y[j] = (x[i] - x[j]) * w

    _sixstep_fft(y[:m], x[:m])
    _sixstep_fft(y[m:], x[m:])

    for p in range(m):
        x[2 * p] = y[p]
        x[2 * p + 1] = y[p + m]

    return


@njit(fastmath=True)
def _sixstep_fft(x, y):
    """
    Apply 6-step FFT algorithm and update x in-place.
    """
    n = len(x)
    n_sqrt = int(np.sqrt(n))

    theta = 2 * math.pi / n_sqrt
    c_theta = math.cos(theta) - 1j * math.sin(theta)


    # step 1: matrix transpose
    for k in range(n_sqrt):
        kn = k * n_sqrt
        for p in range(k + 1, n_sqrt):
            i = k + p * n_sqrt
            j = p + kn
            x[i], x[j] = x[j], x[i]

    # step 2
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, x[start:], y[start:])

    # step 3 and 4: tranpose with twiddle_factor
    for p in range(n_sqrt):
        theta0 = 2 * p * math.pi / n
        for k in range(p, n_sqrt):
            theta = k * theta0
            w = math.cos(theta) - 1j * math.sin(theta)
    
            if k == p:
                i = p * n_sqrt + p
                x[i] = x[i] * w
            else:
                i = p * n_sqrt + k
                j = k * n_sqrt + p
                x[j], x[i] = x[i] * w, x[j] * w
                
    # step 5
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, x[start:], y[start:])

    # step 6: matrix transpose
    for k in range(n_sqrt):
        kn = k * n_sqrt
        for p in range(k + 1, n_sqrt):
            i = k + p * n_sqrt
            j = p + kn
            x[i], x[j] = x[j], x[i]

    return


@njit(fastmath=True)
def _compute_fft(x, y):
    n = len(x)
    n_logtwo = int(np.log2(n))

    if n_logtwo == 1:
        a = x[0]
        b = x[1]
        x[0] = a + b
        x[1] = a - b
    elif n_logtwo % 2 == 0:
        _sixstep_fft(x, y)
    else:
        _eightstep_fft(x, y)

    return


@njit(fastmath=True)
def _compute_rfft(T):
    n = len(T)
    half_n = n // 2

    x = np.empty(half_n, dtype=np.complex128)
    for i in range(half_n):
        x[i] = T[2 * i] + 1j * T[2 * i + 1]

    y = np.empty(half_n + 1, dtype=np.complex128)
    _compute_fft(x, y[:half_n])

    y[0] = x[0].real + x[0].imag
    y[n // 4] = x[n // 4].conjugate()
    y[half_n] = x[0].real - x[0].imag

    w = 0.5j
    factor = math.cos(math.pi / half_n) - 1j * math.sin(math.pi / half_n)
    for k in range(1, n//4):
        w = w * factor
        v = (x[k] - x[half_n - k].conjugate()) * (0.5 + w)

        y[k] = x[k] - v
        y[half_n - k] = x[half_n - k] + v.conjugate()

    return y
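_compute_rfft above relies on a standard trick: pack the n real samples into n/2 complex numbers (even samples as the real part, odd samples as the imaginary part), run one complex FFT of length n/2, then separate and recombine the two halves. The NumPy sketch below is a hedged reference for that idea, not a replica of _compute_rfft: it uses np.fft.fft in place of _compute_fft and the textbook even/odd recombination instead of the incremental w update, and the name rfft_via_half_length_fft is hypothetical.

```python
import numpy as np


def rfft_via_half_length_fft(T):
    """Real FFT of an even-length real array via one complex FFT of half the length."""
    n = len(T)
    half_n = n // 2
    # even samples -> real part, odd samples -> imaginary part
    z = T[0::2] + 1j * T[1::2]
    Z = np.fft.fft(z)  # stand-in for the 6-step / 8-step complex FFT

    k = np.arange(half_n + 1)
    Zk = Z[k % half_n]               # Z[k]; Z[half_n] wraps around to Z[0]
    Zc = np.conj(Z[(-k) % half_n])   # conj(Z[half_n - k])
    even = 0.5 * (Zk + Zc)           # DFT of T[0::2]
    odd = -0.5j * (Zk - Zc)          # DFT of T[1::2]
    # final radix-2 combination with twiddle exp(-2j * pi * k / n)
    return even + np.exp(-2j * np.pi * k / n) * odd


T = np.random.default_rng(0).random(16)
print(np.allclose(rfft_via_half_length_fft(T), np.fft.rfft(T)))  # True
```

Being vectorized, a reference like this can also serve as a readable oracle for the assertions while the optimized code keeps changing.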

Performance (+ Assertion)

import time

import numpy as np
import scipy.fft

from matplotlib import pyplot as plt
from rfft_optimized_v0 import _compute_rfft

funcs_dict = {
    'rfft_v0': _compute_rfft,
    'numpy_rfft': np.fft.rfft,
    'scipy_rfft': scipy.fft.rfft
}

ref_func_key = 'numpy_rfft'
ref_func = funcs_dict[ref_func_key]

seed = 0
np.random.seed(seed)
data = np.random.rand(2 ** 23)

p_vals = np.arange(2, 24)

n_iter = 2
performance = {}
for func_name, func_obj in funcs_dict.items():
    print(f'============ {func_name} ============')
    running_time = []
    for p in p_vals:
        print(f'{p}', end='-->')
        T = data[:2 ** p]

        F_ref = ref_func(T)
        F_comp = func_obj(T)
        np.testing.assert_allclose(F_ref, F_comp, atol=1e-7)

        lst = []
        for _ in range(n_iter):
            t1 = time.perf_counter()
            func_obj(T)
            t2 = time.perf_counter()
            lst.append(t2 - t1)

        running_time.append(lst)
    
    performance[func_name] = running_time
    print('Done!')


### plotting
plt.figure(figsize=(20, 5))
plt.title('Comparing performances of different functions for computing rfft')

colors = ['cyan', 'r', 'b', 'orange', 'k', 'm', 'g', 'yellow', 'brown']

baseline_key = 'scipy_rfft'
baseline = np.array([np.median(lst) for lst in performance[baseline_key]])

for i, key in enumerate(list(performance.keys())):
    if key == baseline_key:
        continue
    
    arr = np.array([np.median(lst) for lst in performance[key]])
    r = arr / baseline
    
    plt.plot(r, color=colors[i], marker='o', label=key)

plt.axhline(y=1, color='k', linestyle='--', label=f'y=1 (baseline: {baseline_key})')
plt.xlabel('The log2 of input array length', fontsize=13)
plt.ylabel(f"Running time ratio w.r.t baseline", fontsize=13)
plt.xticks(ticks=np.arange(len(p_vals)), labels=p_vals, fontsize=13)
plt.yticks(fontsize=13)
plt.grid()
plt.legend(fontsize=13)
plt.show()

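On my macOS machine, the result is:

[Figure: running-time ratio of rfft_v0 and numpy_rfft w.r.t. scipy_rfft, vs. log2 of input length]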

This Colab notebook contains the code.


[Update]
I ran it again with 100 iterations (n_iter = 100) and got this:

[Figure: the same comparison with n_iter = 100]

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented May 13, 2024

Matrix Transposition

In the RFFT code provided in the previous comment, the function _compute_fft does most of the computation. For inputs longer than 2, it calls the 6-step or the 8-step function depending on whether the log2 of its input length is even or odd. We start with the blocks of the 6-step algorithm. Its first step transforms the 1D array x into x.reshape(n, n).T.ravel(), where $n=\sqrt{\mathrm{len}(x)}$.
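For orientation, the six steps can be written with NumPy's vectorized operations. This is a sketch for checking the index bookkeeping, assuming len(x) is a perfect square (an even log2 length) and using np.fft.fft for the row FFTs of steps 2 and 5; sixstep_fft_reference is a hypothetical name, and the version is not meant to be fast.

```python
import numpy as np


def sixstep_fft_reference(x):
    """Six-step FFT with NumPy; len(x) must be a perfect square."""
    N = len(x)
    n = int(round(np.sqrt(N)))
    assert n * n == N, "the six-step layout assumes a square matrix"

    a = x.reshape(n, n).T.copy()                # step 1: transpose
    a = np.fft.fft(a, axis=1)                   # step 2: n row FFTs of length n
    j1 = np.arange(n)[:, None]
    k2 = np.arange(n)[None, :]
    a = a * np.exp(-2j * np.pi * j1 * k2 / N)   # step 3: twiddle factors
    a = a.T                                     # step 4: transpose
    a = np.fft.fft(a, axis=1)                   # step 5: n row FFTs of length n
    return a.T.ravel()                          # step 6: transpose and flatten
```

Writing the input index as j1 + n*j2 and the output index as k2 + n*k1, the twiddle exp(-2j*pi*j1*k2/N) in step 3 is what ties the two sets of length-n FFTs together.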

Code

Currently, this is happening via the following code:

@njit(fastmath=True)
def _tranpose_v0(x, n_sqrt):
    for k in range(n_sqrt):
        kn = k * n_sqrt
        for p in range(k + 1, n_sqrt):
            i = k + p * n_sqrt
            j = p + kn
            x[i], x[j] = x[j], x[i]
    
    return

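However, as discussed in #938, we can use a cache-friendly approach for matrix transposition (see #965 on cache-oblivious algorithms). The blocked version below transposes 32 x 32 tiles into a separate output array, x_transpose: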

@njit(fastmath=True)
def _tranpose_v1(x, n_sqrt, x_transpose):
    blocksize = 32
    blocksize = min(blocksize, n_sqrt)

    x = x.reshape(n_sqrt, n_sqrt)
    x_transpose = x_transpose.reshape(n_sqrt, n_sqrt)
    for i in range(0, n_sqrt, blocksize):
        for j in range(0, n_sqrt, blocksize):
            x_transpose[i:i + blocksize, j:j + blocksize] = np.transpose(x[j:j + blocksize, i:i + blocksize])

    return
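_tranpose_v1 writes into a second buffer. If that buffer matters, one possible variant (an assumption, not something benchmarked in this issue) swaps tiles in place: transpose each diagonal tile within itself, and swap each off-diagonal tile (i, j) with the transpose of tile (j, i). The sketch is plain NumPy without @njit, and transpose_blocked_inplace is a hypothetical name.

```python
import numpy as np


def transpose_blocked_inplace(a, blocksize=32):
    """Transpose the square 2D array `a` in place, one tile pair at a time."""
    n = a.shape[0]
    for i in range(0, n, blocksize):
        i_end = min(i + blocksize, n)
        # diagonal tile: transpose it within itself (the copy avoids aliasing)
        a[i:i_end, i:i_end] = a[i:i_end, i:i_end].T.copy()
        # off-diagonal tiles: swap tile (i, j) with the transpose of tile (j, i)
        for j in range(i_end, n, blocksize):
            j_end = min(j + blocksize, n)
            upper = a[i:i_end, j:j_end].copy()
            a[i:i_end, j:j_end] = a[j:j_end, i:i_end].T
            a[j:j_end, i:i_end] = upper.T
    return a


x = np.arange(16, dtype=np.complex128)
transpose_blocked_inplace(x.reshape(4, 4), blocksize=2)  # reshape is a view, so x changes
```

Because reshape of a contiguous 1D array returns a view, the 1D array x itself ends up holding x.reshape(n, n).T.ravel().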

Performance (+ Assertion)

import time
import numpy as np

from matplotlib import pyplot as plt
from numba import njit

@njit(fastmath=True)
def _tranpose_v0(x, n_sqrt):
    for k in range(n_sqrt):
        kn = k * n_sqrt
        for p in range(k + 1, n_sqrt):
            i = k + p * n_sqrt
            j = p + kn
            x[i], x[j] = x[j], x[i]
    
    return


@njit(fastmath=True)
def _tranpose_v1(x, n_sqrt, x_transpose):
    blocksize = 32
    blocksize = min(blocksize, n_sqrt)

    x = x.reshape(n_sqrt, n_sqrt)
    x_transpose = x_transpose.reshape(n_sqrt, n_sqrt)
    for i in range(0, n_sqrt, blocksize):
        for j in range(0, n_sqrt, blocksize):
            x_transpose[i:i + blocksize, j:j + blocksize] = np.transpose(x[j:j + blocksize, i:i + blocksize])

    return

funcs_dict = {
    'tranpose_v0': _tranpose_v0,
    'tranpose_v1': _tranpose_v1,
}

ref_func_key = 'tranpose_v0'
ref_func = funcs_dict[ref_func_key]

seed = 0
np.random.seed(seed)
data = np.random.rand(2 ** 23)
data = data[::2] + 1j * data[1::2]

p_vals = np.arange(2, 22 + 1, 2)

n_iter = 500
performance = {}
for func_name, func_obj in funcs_dict.items():
    print(f'============ {func_name} ============')
    running_time = []
    for p in p_vals:
        print(f'{p}', end='-->')
        T = data[:2 ** p]
        n = len(T)
        n_sqrt = int(np.sqrt(n))

        y = np.empty(n, dtype=np.complex128)
    
        ref = T.copy()
        ref_func(ref, n_sqrt)

        if func_name == 'tranpose_v0':
            comp = T.copy()
            func_obj(comp, n_sqrt)
        else:
            comp = y.copy()
            func_obj(T, n_sqrt, comp)

        np.testing.assert_allclose(ref, comp, atol=1e-7)

        lst = []
        for _ in range(n_iter):
            x = T.copy()
            
            if func_name == 'tranpose_v0':
                t1 = time.perf_counter()
                func_obj(x, n_sqrt)
                t2 = time.perf_counter()
            else:
                t1 = time.perf_counter()
                func_obj(x, n_sqrt, y)
                t2 = time.perf_counter()
            lst.append(t2 - t1)

        running_time.append(lst)
    
    performance[func_name] = running_time
    print('Done!')


### plotting
plt.figure(figsize=(20, 5))
plt.title('Comparing performances of different functions\n for performing matrix transpose')

colors = ['cyan', 'r', 'b', 'orange', 'k', 'm', 'g', 'yellow', 'brown']

baseline_key = 'tranpose_v0'
baseline = np.array([np.median(lst) for lst in performance[baseline_key]])

for i, key in enumerate(list(performance.keys())):
    if key == baseline_key:
        continue
    
    arr = np.array([np.median(lst) for lst in performance[key]])
    r = arr / baseline
    
    plt.plot(r, color=colors[i], marker='o', label=key)

plt.axhline(y=1, color='k', linestyle='--', label=f'y=1 (baseline: {baseline_key})')
plt.xlabel('The log2 of input array length', fontsize=13)
plt.ylabel(f"Running time ratio w.r.t baseline", fontsize=13)
plt.xticks(ticks=np.arange(len(p_vals)), labels=p_vals, fontsize=13)
plt.yticks(fontsize=13)
plt.grid()
plt.legend(fontsize=13)
plt.show()
[Figure: running-time ratio of tranpose_v1 w.r.t. tranpose_v0, vs. log2 of input length]

This Colab notebook contains the code.

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented May 13, 2024

fft_block

In steps 2 and 5 of the 6-step FFT algorithm, we use the following recursive function:

# fft_block_v0

import math
from numba import njit

@njit(fastmath=True)
def _fft_block(n, s, eo, x, y):
    """
    A recursive function that is used as part of fft algorithm

    n : int
    s : int
    eo: bool
    x : numpy.array 1D
    y : numpy.array 1D
    """
    if n == 2:
        if eo:
            z = y
        else:
            z = x

        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b

    elif n >= 4:
        m = n // 2
        sm = s * m

        theta = math.pi / m
        for p in range(m):
            sp = s * p
            c = math.cos(p * theta) - 1j * math.sin(p * theta)
            for q in range(s):
                idx = sp + q
                a = x[idx]
                b = x[idx + sm]

                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * c

        _fft_block(m, 2*s, not eo, y, x)

    else:
        pass

We can think of the following ways to speed this up:

(1) Avoid calling math.cos and math.sin in each iteration of the outer for-loop. We can compute the base factor c once, outside the outer for-loop, and then update w by repeated multiplication:

# will be used in new version: fft_block_v1

theta = math.pi / m
c = math.cos(theta) - 1j * math.sin(theta)
w = 1.0
for p in range(m):
    sp = s * p
    for q in range(s):
        idx = sp + q
        a = x[idx]
        b = x[idx + sm]

        y[idx + sp] = a + b
        y[idx + sp + s] = (a - b) * w
    
    w = w * c
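Optimization (1) replaces m cos/sin pairs with m complex multiplications. The sketch below (hypothetical helper names, plain Python without @njit) builds the same twiddles both ways and reports the largest deviation. Repeated multiplication can accumulate rounding error as p grows, so it seems worth keeping an eye on this next to the atol=1e-7 used in the assertions.

```python
import math

import numpy as np


def twiddles_direct(m):
    # one cos/sin pair per p, as in fft_block_v0
    theta = math.pi / m
    return np.array([complex(math.cos(p * theta), -math.sin(p * theta)) for p in range(m)])


def twiddles_recurrence(m):
    # one complex multiplication per p, as in (1)
    theta = math.pi / m
    c = complex(math.cos(theta), -math.sin(theta))
    out = np.empty(m, dtype=np.complex128)
    w = 1.0 + 0.0j
    for p in range(m):
        out[p] = w
        w = w * c
    return out


for m in (8, 1024, 2 ** 16):
    print(m, np.max(np.abs(twiddles_recurrence(m) - twiddles_direct(m))))
```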

(2) In (1), we still compute c = math.cos(theta) - 1j * math.sin(theta), with theta = math.pi / m, once in every call of the recursive function fft_block. We can avoid this by adding c as a parameter to the function's signature: since m is halved at each recursion level, theta doubles, so the factor for the next level is simply c * c.

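# will be used in new version: fft_block_v2

@njit(fastmath=True)
def _fft_block(n, s, eo, x, y, c):
    if n == 2:
        ...  # same base case as before (c is not needed here)
    elif n >= 4:
        ...  # butterfly loop from (1), using the c passed in instead of computing it
        _fft_block(n // 2, 2 * s, not eo, y, x, c * c)  # next level: twice the angle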

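The caller computes the initial c once for the top-level n, i.e. c = math.cos(2 * math.pi / n) - 1j * math.sin(2 * math.pi / n) (which equals exp(-1j * pi / m) with m = n // 2), and passes it to _fft_block.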
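To check (2) end to end, here is a plain-Python transcription (no @njit; fft_block_v2 and fft_v2 are hypothetical names) that passes c through the recursion and squares it at each level, with a small driver that computes the initial c for the top-level n. The output is compared with np.fft.fft.

```python
import math

import numpy as np


def fft_block_v2(n, s, eo, x, y, c):
    """Version (2) without @njit: c = exp(-1j * pi / (n // 2)) is passed in."""
    if n == 2:
        z = y if eo else x
        for i in range(s):
            j = i + s
            a, b = x[i], x[j]
            z[i] = a + b
            z[j] = a - b
    elif n >= 4:
        m = n // 2
        sm = s * m
        w = 1.0
        for p in range(m):
            sp = s * p
            for q in range(s):
                idx = sp + q
                a, b = x[idx], x[idx + sm]
                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w
            w = w * c
        # next level: m is halved, the angle doubles, so the factor is c * c
        fft_block_v2(m, 2 * s, not eo, y, x, c * c)


def fft_v2(x):
    """Compute the initial c for the top-level n once, then recurse; the result is in x."""
    n = len(x)
    theta = 2 * math.pi / n  # equals math.pi / (n // 2)
    c = math.cos(theta) - 1j * math.sin(theta)
    fft_block_v2(n, 1, False, x, np.empty_like(x), c)
    return x
```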

(3) The factor w in y[idx + sp + s] = (a - b) * w is updated m times in each call of the recursive function. Next, we reduce the number of updates within the for-loop for p in range(m) by exploiting a symmetry of w.

For a given m, the factor w in y[idx + sp + s] = (a - b) * w has the following relationship with p (see # fft_block_v0 provided at the top of this comment):

$w_{p} = \cos(\theta_p) - j\sin(\theta_p)$, where $\theta_p = \frac{p\pi}{m}$

$w_{m-p} = \cos(\theta_{m-p}) - j\sin(\theta_{m-p})$, where $\theta_{m-p} = \frac{(m-p)\pi}{m}$

Note that $\theta_p + \theta_{m-p} = \pi$. Since $\cos(\pi - \theta) = -\cos(\theta)$ and $\sin(\pi - \theta) = \sin(\theta)$:

$w_{m-p} = \cos(\pi - \theta_p) - j\sin(\pi - \theta_p) = -\cos(\theta_p) - j\sin(\theta_p) = -w_{p}^{*}$

Therefore, in this version, in addition to (2), we can replace the following for-loop

w = 1.0
for p in range(m):
    sp = s * p
    for q in range(s):
        idx = sp + q
        a = x[idx]
        b = x[idx + sm]

        y[idx + sp] = a + b
        y[idx + sp + s] = (a - b) * w
    
    w = w * c

with this code:

# will be used in new version: fft_block_v3

        # p = 0
        for i in range(s):
            j = i + sm
            y[i] = x[i] + x[j]
            y[i + s] = x[i] - x[j]

        w = 1.0
        for p in range(1, m // 2):
            # p  --> 1, 2, 3, ..., m//2 - 1
            w = w * c
            sp = s * p
            for i in range(sp, sp + s):
                b = x[i + sm]
                k = i + sp
                y[k] = x[i] + b
                y[k + s] = (x[i] - b) * w

            # p = m - p --> m - 1, m - 2, m - 3, ..., m - m // 2 + 1
            sp = sm - sp 
            w_conj = w.conjugate()
            for i in range(sp, sp + s):
                b = x[i + sm]
                k = i + sp
                y[k] = x[i] + b
                y[k + s] = (b - x[i]) * w_conj

        # p = m // 2
        w = w * c
        sp = sm // 2
        for i in range(sp, sp + s):
            b = x[i + sm]
            k = i + sp
            y[k] = x[i] + b
            y[k + s] = (x[i] - b) * w

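(4) We can precompute all the factors and pass the precomputed array as an argument to the recursive function. The following code computes the array: for each recursion level, it stores the m = n // 2 powers of that level's factor, followed by the factors of the next level: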

# will be used in new version: fft_block_v4

@njit(fastmath=True)
def _fill_c_array(c_arr, n, c):
    m = n // 2
    w = 1.0
    for i in range(m):
        c_arr[i] = w
        w = w * c

    if m > 2:
        _fill_c_array(c_arr[m:], m, c * c)
    else:
        return

@njit(fastmath=True)
def fill_c_array(n):
    """
    n is square root of length of input array in 6-step function.
    """
    theta = 2 * math.pi / n
    c = math.cos(theta) - 1j * math.sin(theta)
    c_arr = np.empty(n, dtype=np.complex128)

    _fill_c_array(c_arr, n, c)

    return c_arr
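The comment does not show how the recursive function consumes the precomputed array. One way, sketched here as an assumption (fft_block_v4 and fft_v4 are hypothetical names, plain Python without @njit), mirrors the layout produced by _fill_c_array: level n reads its twiddles from c_arr[p] and hands c_arr[m:] to the next level, just as _fill_c_array recursed on c_arr[m:]. The fill functions are copied from above with np.complex128.

```python
import math

import numpy as np


def _fill_c_array(c_arr, n, c):
    # as above: the m powers of c for this level, then recurse with c * c
    m = n // 2
    w = 1.0
    for i in range(m):
        c_arr[i] = w
        w = w * c
    if m > 2:
        _fill_c_array(c_arr[m:], m, c * c)


def fill_c_array(n):
    theta = 2 * math.pi / n
    c_arr = np.empty(n, dtype=np.complex128)
    _fill_c_array(c_arr, n, math.cos(theta) - 1j * math.sin(theta))
    return c_arr


def fft_block_v4(n, s, eo, x, y, c_arr):
    """Hypothetical consumer: c_arr[p] is the twiddle of butterfly p at this level."""
    if n == 2:
        z = y if eo else x
        for i in range(s):
            a, b = x[i], x[i + s]
            z[i] = a + b
            z[i + s] = a - b
    elif n >= 4:
        m = n // 2
        sm = s * m
        for p in range(m):
            sp = s * p
            w = c_arr[p]  # table lookup instead of w = w * c
            for q in range(s):
                idx = sp + q
                a, b = x[idx], x[idx + sm]
                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w
        # the next level's twiddles start right after this level's m entries
        fft_block_v4(m, 2 * s, not eo, y, x, c_arr[m:])


def fft_v4(x):
    """Precompute the table once, then recurse; the result ends up in x."""
    fft_block_v4(len(x), 1, False, x, np.empty_like(x), fill_c_array(len(x)))
    return x
```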

Performance (+ Assertion)

In this Colab notebook, the performance of these four new versions is checked and compared with the baseline. In the 6-step FFT, the recursive function _fft_block(n, s, eo, x, y) is called $\sqrt{N}$ times in step 2 and another $\sqrt{N}$ times in step 5, where $N$ is the length of the input to _sixstep_fft. So, to better reflect the impact of each version on the 6-step algorithm, the benchmark uses the same number of calls. Overall, version (2) of fft_block seems to perform better than the others.

@NimaSarajpoor
Collaborator Author

transpose + twiddle factor

This comment is about steps 3 and 4 of the 6-step algorithm. Originally, we have the following code:

# version 0

    for p in range(n_sqrt):
        theta0 = 2 * p * math.pi / n
        for k in range(p, n_sqrt):
            theta = k * theta0
            w = math.cos(theta) - 1j * math.sin(theta)
    
            if k == p:
                i = p * n_sqrt + p
                x[i] = x[i] * w
            else:
                i = p * n_sqrt + k
                j = k * n_sqrt + p
                x[j], x[i] = x[i] * w, x[j] * w

This code transposes the matrix and multiplies its elements by the twiddle factors. The twiddle factor $w$ is $e^{-j\theta}$, where $\theta = \frac{2\pi p k}{n}$ and $n$ is len(x). This factor multiplies both element (p, k) and element (k, p) of x.reshape(n_sqrt, n_sqrt), where n_sqrt is $\sqrt{n}$.
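When benchmarking the versions below, a vectorized NumPy reference is handy for the assertions. This sketch (hypothetical names, plain Python without @njit) wraps version 0 in a function and compares it with a short equivalent: scale element (p, k) of the n_sqrt x n_sqrt matrix by exp(-2j*pi*p*k/n), then transpose. Since the factor matrix is symmetric, the order of scaling and transposing does not matter.

```python
import math

import numpy as np


def twiddle_transpose_v0(x):
    """Version 0 above, wrapped in a function (plain Python, no @njit); updates x in place."""
    n = len(x)
    n_sqrt = int(np.sqrt(n))
    for p in range(n_sqrt):
        theta0 = 2 * p * math.pi / n
        for k in range(p, n_sqrt):
            theta = k * theta0
            w = math.cos(theta) - 1j * math.sin(theta)
            if k == p:
                i = p * n_sqrt + p
                x[i] = x[i] * w
            else:
                i = p * n_sqrt + k
                j = k * n_sqrt + p
                x[j], x[i] = x[i] * w, x[j] * w
    return x


def twiddle_transpose_reference(x):
    """Scale element (p, k) by exp(-2j*pi*p*k/n), then transpose; returns a new array."""
    n = len(x)
    n_sqrt = int(round(np.sqrt(n)))
    p = np.arange(n_sqrt)[:, None]
    k = np.arange(n_sqrt)[None, :]
    return (x.reshape(n_sqrt, n_sqrt) * np.exp(-2j * np.pi * p * k / n)).T.ravel()
```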

We can avoid calling math.cos and math.sin in the inner for-loop by computing an initial value for w and updating it within the inner for-loop, as shown below:

# version 1

    theta_init = 2 * math.pi / n
    for p in range(n_sqrt):
        theta0 = theta_init * p

        c = math.cos(theta0) - 1j * math.sin(theta0)
        w = math.cos(theta0 * p) - 1j * math.sin(theta0 * p)
        for k in range(p, n_sqrt):
            i = p * n_sqrt + k

            if p == k:
                x[i] = x[i] * w
            else:
                j = k * n_sqrt + p
                x[j], x[i] = x[i] * w, x[j] * w

            w = w * c

We still call math.cos and math.sin twice in each iteration of the outer for-loop. We can again move these calls outside the outer for-loop by starting from initial values and updating them with multiplications, as follows:

# version 2

    wp = 1.0
    cp = 1.0

    theta = 2 * math.pi / n
    factor = math.cos(theta) - 1j * math.sin(theta)
    for p in range(n_sqrt):
        pns = p * n_sqrt
        c = cp
        w = wp
        x[pns + p] = x[pns + p] * w
        for q in range(p + 1, n_sqrt):
            w = w * c
            i = pns + q
            j = q * n_sqrt + p
            x[j], x[i] = x[i] * w, x[j] * w

        cp_new = factor * cp
        wp = wp * cp_new * cp
        cp = cp_new
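The update of wp in version 2 follows from (p + 1)^2 = p^2 + p + (p + 1): with cp = exp(-2j*pi*p/n) and wp = exp(-2j*pi*p^2/n), multiplying wp by the current cp and by the next cp gives the next wp. A hedged sketch (hypothetical function name, plain Python) that runs the two recurrences and reports their largest deviation from these closed forms:

```python
import cmath
import math


def version2_recurrence_error(n_sqrt):
    """Largest deviation of the version-2 recurrences from their closed forms."""
    n = n_sqrt * n_sqrt
    theta = 2 * math.pi / n
    factor = math.cos(theta) - 1j * math.sin(theta)
    wp, cp = 1.0, 1.0
    worst = 0.0
    for p in range(n_sqrt):
        # closed forms: cp == exp(-2j*pi*p/n) and wp == exp(-2j*pi*p*p/n)
        worst = max(worst,
                    abs(cp - cmath.exp(-2j * math.pi * p / n)),
                    abs(wp - cmath.exp(-2j * math.pi * p * p / n)))
        cp_new = factor * cp
        wp = wp * cp_new * cp  # uses (p + 1)**2 == p**2 + (p + 1) + p
        cp = cp_new
    return worst


for n_sqrt in (2, 16, 1024):
    print(n_sqrt, version2_recurrence_error(n_sqrt))
```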

Performance (+ Assertion)

This Colab Notebook contains the code that shows the performance of these three versions.

[Figure: performance of versions 0, 1, and 2 of the transpose + twiddle factor step]

As observed, the last version mentioned above outperforms the others.

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented May 15, 2024

Function _sixstep_fft (Let's put the optimized blocks together)

Now, let's compare the original function _sixstep_fft (used in our fft algorithm as provided in this comment) with the new version where each block of code is replaced with its optimized version based on the results provided in previous comments. The new version of this function is provided below:

%%writefile sixstep_fft_v1.py
import math
import numpy as np
from numba import njit


@njit(fastmath=True)
def _tranpose(x, n_sqrt, x_transpose):
    blocksize = 32
    blocksize = min(blocksize, n_sqrt)

    x = x.reshape(n_sqrt, n_sqrt)
    x_transpose = x_transpose.reshape(n_sqrt, n_sqrt)
    for i in range(0, n_sqrt, blocksize):
        for j in range(0, n_sqrt, blocksize):
            x_transpose[i:i + blocksize, j:j + blocksize] = np.transpose(x[j:j + blocksize, i:i + blocksize])

    return


@njit(fastmath=True)
def _fft_block(n, s, eo, x, y, c):
    """
    A recursive function that is used as part of fft algorithm

    n : int
    s : int
    eo: bool
    x : numpy.array 1D
    y : numpy.array 1D
    c : complex, twiddle step exp(-1j * pi / (n // 2)) for this level
    """
    if n == 2:
        if eo:
            z = y
        else:
            z = x

        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b

    elif n >= 4:
        m = n // 2
        sm = s * m

        w = 1.0
        for p in range(m):
            sp = s * p
            for q in range(s):
                idx = sp + q
                a = x[idx]
                b = x[idx + sm]

                y[idx + sp] = a + b
                y[idx + sp + s] = (a - b) * w

            w = w * c

        _fft_block(m, 2*s, not eo, y, x, c * c)

    else:
        pass


@njit(fastmath=True)
def _sixstep_fft(x, y):
    """
    Apply 6-step FFT algorithm and update x in-place.
    """
    n = len(x)
    n_sqrt = int(np.sqrt(n))


    # step 1: matrix transpose
    _tranpose(x, n_sqrt, y)

    # step 2
    theta = 2 * math.pi / n_sqrt
    c_theta = math.cos(theta) - 1j * math.sin(theta)
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, y[start:], x[start:], c_theta)

    # step 3 and 4: tranpose with twiddle_factor
    wp = 1.0
    cp = 1.0

    theta_twiddle = 2 * math.pi / n
    factor = math.cos(theta_twiddle) - 1j * math.sin(theta_twiddle)
    for p in range(n_sqrt):
        pns = p * n_sqrt
        c = cp
        w = wp
        y[pns + p] = y[pns + p] * w
        for q in range(p + 1, n_sqrt):
            w = w * c
            i = pns + q
            j = q * n_sqrt + p
            y[j], y[i] = y[i] * w, y[j] * w

        cp_new = factor * cp
        wp = wp * cp_new * cp
        cp = cp_new

    # step 5
    for start in range(0, n, n_sqrt):
        _fft_block(n_sqrt, 1, False, y[start:], x[start:], c_theta)

    # step 6: matrix transpose
    _tranpose(y, n_sqrt, x)

    return

Performance (+ Assertion)

The code is provided in this Colab notebook. The result is provided below:

[Figure: performance of the new _sixstep_fft compared with the original]

As observed, we see a 50-80% improvement, except for input of length 2^2.

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented May 15, 2024

Eight-step function

We now work on the eight-step function, as provided in this comment:

@njit(fastmath=True)
def _eightstep_fft(x, y):
    """
    Apply 8-step FFT algorithm and update x in-place.
    """
    n = len(x)
    m = n // 2

    theta = math.pi / m 
    for i in range(m):
        w = math.cos(i * theta) - 1j * math.sin(i * theta)
        j = i + m
        y[i] = x[i] + x[j]
        y[j] = (x[i] - x[j]) * w

    _sixstep_fft(y[:m], x[:m])
    _sixstep_fft(y[m:], x[m:])

    for p in range(m):
        x[2 * p] = y[p]
        x[2 * p + 1] = y[p + m]

    return
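The idea behind _eightstep_fft can be checked in isolation: one radix-2 decimation-in-frequency step turns a length-n FFT into two length-n/2 FFTs, one giving the even-indexed outputs and one the odd-indexed outputs. When log2(n) is odd, n/2 has an even log2, so the 6-step function can handle both halves. A NumPy sketch, with np.fft.fft standing in for _sixstep_fft and a hypothetical function name:

```python
import numpy as np


def eightstep_fft_reference(x):
    """One radix-2 DIF split, two half-length FFTs, then interleave the results."""
    n = len(x)
    m = n // 2
    w = np.exp(-1j * np.pi * np.arange(m) / m)   # same w as version 0: angle i * pi / m
    even = np.fft.fft(x[:m] + x[m:])              # X[0], X[2], X[4], ...
    odd = np.fft.fft((x[:m] - x[m:]) * w)         # X[1], X[3], X[5], ...
    out = np.empty(n, dtype=np.complex128)
    out[0::2] = even
    out[1::2] = odd
    return out
```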

We focus on the following part of the code:

# version 0

m = len(x) // 2

theta = math.pi / m
for i in range(m):
    w = math.cos(i * theta) - 1j * math.sin(i * theta)
    j = i + m
    y[i] = x[i] + x[j]
    y[j] = (x[i] - x[j]) * w

Note 1: We can avoid calling math.cos and math.sin in every iteration: compute factor = math.cos(theta) - 1j * math.sin(theta) once, outside the for-loop, and update w = w * factor in each iteration (# version 1).

Note 2: The factor w of the i-th iteration is the negative conjugate of the w of the (m - i)-th iteration, so only about half of the factors need to be computed (# version 2).
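Version 2 itself is only in the Colab notebook. A hedged sketch of what it might look like (hypothetical names, plain Python without @njit): compute w for 0 < i < m // 2 by recurrence, reuse its negative conjugate for m - i, and handle i = 0 (where w = 1) and i = m // 2 (where w = -1j) separately. Version 0 is wrapped in a function for comparison.

```python
import math

import numpy as np


def first_pass_v0(x, y):
    # version 0 above, wrapped in a function
    m = len(x) // 2
    theta = math.pi / m
    for i in range(m):
        w = math.cos(i * theta) - 1j * math.sin(i * theta)
        j = i + m
        y[i] = x[i] + x[j]
        y[j] = (x[i] - x[j]) * w


def first_pass_symmetric(x, y):
    """Hypothetical version 2: compute w for i < m // 2, reuse it for m - i."""
    m = len(x) // 2
    theta = math.pi / m
    factor = math.cos(theta) - 1j * math.sin(theta)
    y[0] = x[0] + x[m]                  # i = 0: w = 1
    y[m] = x[0] - x[m]
    w = 1.0
    for i in range(1, m // 2):
        w = w * factor
        y[i] = x[i] + x[i + m]
        y[i + m] = (x[i] - x[i + m]) * w
        k = m - i                       # w_{m-i} = -conj(w_i), folded into (b - a)
        y[k] = x[k] + x[k + m]
        y[k + m] = (x[k + m] - x[k]) * w.conjugate()
    if m >= 2:                          # i = m // 2: w = exp(-1j * pi / 2) = -1j
        h = m // 2
        y[h] = x[h] + x[h + m]
        y[h + m] = (x[h] - x[h + m]) * -1j
```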

Performance (+ Assertion)

This Colab Notebook contains the code. The following result is obtained:

[Figure: performance of versions 0, 1, and 2 of the eight-step loop]

As observed, versions 1 and 2 show similar performance. Since the code in version 2 is more complicated, we go with version 1:

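@njit(fastmath=True)
def _eightstep_fft(x, y):
    """
    Apply 8-step FFT algorithm and update x in-place (version 1).
    """
    m = len(x) // 2

    theta = math.pi / m
    factor = math.cos(theta) - 1j * math.sin(theta)
    w = 1.0
    for i in range(m):
        j = i + m
        y[i] = x[i] + x[j]
        y[j] = (x[i] - x[j]) * w
        w = w * factor

    _sixstep_fft(y[:m], x[:m])
    _sixstep_fft(y[m:], x[m:])

    for p in range(m):
        x[2 * p] = y[p]
        x[2 * p + 1] = y[p + m]

    return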

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented May 19, 2024

Put pieces together...

It is now time to put the pieces together and compare our optimized-so-far RFFT with the one provided in this comment. The code is available in this Colab notebook. The running time of RFFT is recorded for inputs of length 2^2, ..., 2^20, and for each length the function is called 5000 times. Since the running time is small, a small deviation can cause a considerable difference in the speed-up ratio (w.r.t. the running time of scipy.fft.rfft). To account for this, I computed a range: out of the 5000 running-time samples, I removed those outside $[\mu - 2\sigma, \mu + 2\sigma]$ and then took the min, max, and mean. The max and min speed-up of our RFFT w.r.t. scipy (both are running-time ratios, so lower is better) can then be computed as follows:

max_speed_up = max_running_time_of_RFFT / min_running_time_of_Scipy
min_speed_up = min_running_time_of_RFFT / max_running_time_of_Scipy

and, we can calculate the mean-based speed-up as follows:

mean_speed_up = mean_running_time_of_RFFT / mean_running_time_of_Scipy
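A minimal sketch of the filtering and ratio computation described above, assuming the per-call timings are available as arrays; the function names are hypothetical, and np.std uses the population standard deviation (ddof=0) by default.

```python
import numpy as np


def filtered_stats(times, k=2.0):
    """Drop samples outside [mu - k*sigma, mu + k*sigma]; return (min, max, mean)."""
    t = np.asarray(times, dtype=np.float64)
    mu, sigma = t.mean(), t.std()
    kept = t[(t >= mu - k * sigma) & (t <= mu + k * sigma)]
    return kept.min(), kept.max(), kept.mean()


def speed_up_ratios(times_rfft, times_scipy):
    """Running-time ratios of RFFT w.r.t. scipy (lower is better)."""
    r_min, r_max, r_mean = filtered_stats(times_rfft)
    s_min, s_max, s_mean = filtered_stats(times_scipy)
    return {
        "max_speed_up": r_max / s_min,
        "min_speed_up": r_min / s_max,
        "mean_speed_up": r_mean / s_mean,
    }


# the single 100.0 outlier lies beyond mu + 2*sigma and is dropped
print(speed_up_ratios([1.0] * 9 + [100.0], [2.0, 2.0, 4.0, 4.0]))
```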

I am not sure this approach is statistically rigorous. Still, it gives us some idea about the range. In Colab, I got the following result (lower is better):

[Figure: min, max, and mean speed-up ratios of the optimized RFFT w.r.t. scipy.fft.rfft (Colab)]

And on my macOS machine, I got this:

[Figure: min, max, and mean speed-up ratios of the optimized RFFT w.r.t. scipy.fft.rfft (macOS)]

@NimaSarajpoor
Collaborator Author

NimaSarajpoor commented May 19, 2024

As a follow-up to my previous comment, I would like to check the performance of numpy.fft.rfft as well. The result (from running the code on my macOS machine) is shown below. Lower is better.

[Figure: speed-up ratios w.r.t. scipy.fft.rfft, including numpy.fft.rfft]

seanlaw added the enhancement label on May 22, 2024.
@NimaSarajpoor
Collaborator Author

[Recap]
In the previous comment, I showed that the optimized version of our FFT implementation (shown in red) is significantly better than the initial implementation. However, it is still outperformed by Numpy and Scipy when the size of input is large. Our FFT implementation calls a recursive function which was optimized according to the study described in this comment.

[Now]
We replace the recursive function with a for-loop. Furthermore, we call different functions for inputs of different sizes. We call this version rfft_v2, and it gives a considerable performance gain. It is NOT a clean approach yet, and I need to work on it more. But... it gives us some hope! :)

We are now getting closer to Scipy's performance!

[Figure: performance of rfft_v2 compared with Scipy]

The code is available in this Colab notebook, and it includes assertions to make sure the output is correct.
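rfft_v2 itself is only in the Colab notebook. As a hedged illustration of its first change, replacing recursion with a loop, here is one way to unroll _fft_block(n, 1, False, x, y): each pass of the while-loop performs one level of the former recursion, and swapping the two buffers and flipping eo replaces the recursive call _fft_block(m, 2*s, not eo, y, x). The names fft_block_iterative and fft_iterative are hypothetical, the size-dependent dispatch of rfft_v2 is not shown, and the code is plain Python without @njit.

```python
import math

import numpy as np


def fft_block_iterative(n, x, y):
    """Loop form of _fft_block(n, 1, False, x, y); the result ends up in x."""
    s, eo = 1, False
    src, dst = x, y
    while n >= 4:
        m = n // 2
        sm = s * m
        theta = math.pi / m
        c = math.cos(theta) - 1j * math.sin(theta)
        w = 1.0
        for p in range(m):
            sp = s * p
            for q in range(s):
                idx = sp + q
                a, b = src[idx], src[idx + sm]
                dst[idx + sp] = a + b
                dst[idx + sp + s] = (a - b) * w
            w = w * c
        # replaces the recursive call _fft_block(m, 2 * s, not eo, y, x)
        n, s, eo = m, 2 * s, not eo
        src, dst = dst, src
    if n == 2:
        z = dst if eo else src  # same rule as the n == 2 branch of _fft_block
        for i in range(s):
            a, b = src[i], src[i + s]
            z[i] = a + b
            z[i + s] = a - b


def fft_iterative(x):
    fft_block_iterative(len(x), x, np.empty_like(x))
    return x


x = np.exp(2j * np.pi * np.random.default_rng(0).random(16))
expected = np.fft.fft(x)
fft_iterative(x)
print(np.allclose(x, expected))  # True
```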
