[Bug]: Partial trace of state vector backend #73

Open
masa10-f opened this issue Jul 11, 2023 · 7 comments
Labels
bug Something isn't working

Comments

@masa10-f
Contributor

masa10-f commented Jul 11, 2023

Describe the bug

The partial trace of a state vector is generally expected to return a density matrix, but Statevec.ptrace currently returns a state vector. This is an incorrect operation for non-separable quantum states, such as a Bell state.
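To make the expected result concrete, here is a minimal NumPy sketch of a partial trace that returns a density matrix. The helper name `ptrace_to_density_matrix` is hypothetical (it is not part of graphix), and it assumes qubit 0 is the most significant index of the flattened vector, i.e. a `(2,) * n` tensor flattened in C order.

```python
import numpy as np


def ptrace_to_density_matrix(psi, nqubit, traced):
    """Hypothetical helper: reduced density matrix of a pure state.

    Assumes qubit 0 is the most significant bit of the flattened vector.
    """
    keep = [q for q in range(nqubit) if q not in traced]
    # Move the kept qubits to the front and the traced-out qubits to the back.
    tensor = np.asarray(psi).reshape((2,) * nqubit).transpose(keep + list(traced))
    mat = tensor.reshape(2 ** len(keep), 2 ** len(traced))
    # rho_A = Tr_B |psi><psi| = M M^dagger for the reshaped matrix M.
    return mat @ mat.conj().T


# Bell state (|00> + |11>)/sqrt(2): tracing out qubit 1 should give I/2.
bell = np.array([1, 0, 0, 1]) / np.sqrt(2)
rho = ptrace_to_density_matrix(bell, 2, [1])
print(np.round(rho.real, 3))
```

For the Bell state no single vector reproduces this `rho`, which is why returning a state vector cannot be correct in general.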

To Reproduce

You can check the above behavior with the following code.

# %%
import numpy as np
from graphix.sim.statevec import Statevec

# %%
# prepare H and CNOT gates.
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0],
                 [0, 0, 0, 1], [0, 0, 1, 0]])

# %%
# make a bell state
sv = Statevec(plus_states=False, nqubit=2)
sv.evolve(H, [0])
sv.evolve(CNOT, [0, 1])

print(sv.flatten())
# %%
# trace out 2nd qubit
sv.ptrace([1])
print(sv.flatten())
# the result should be (|0><0| + |1><1|)/2, but it is |0> or |1>

Expected behavior

Statevec.ptrace should return a density matrix. Code that converts a reduced density matrix back into a state vector should check its purity before the conversion.
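The purity check could look like the following sketch. `density_matrix_to_statevector` is a hypothetical name, and the tolerance `tol=1e-10` is an assumption; the idea is that Tr(rho^2) = 1 exactly when rho is pure, and then the eigenvector with the largest eigenvalue is the state vector (up to a global phase).

```python
import numpy as np


def density_matrix_to_statevector(rho, tol=1e-10):
    """Hypothetical helper: return |psi> if rho = |psi><psi|, else raise ValueError."""
    purity = np.trace(rho @ rho).real
    if abs(purity - 1.0) > tol:
        raise ValueError(f"state is mixed (purity = {purity:.3f})")
    # eigh returns eigenvalues in ascending order; for a pure state the
    # eigenvector with eigenvalue 1 is |psi>, up to a global phase.
    eigvals, eigvecs = np.linalg.eigh(rho)
    return eigvecs[:, np.argmax(eigvals)]


plus = np.array([1, 1]) / np.sqrt(2)
print(density_matrix_to_statevector(np.outer(plus, plus.conj())))
```

Passing the reduced Bell-state density matrix `np.eye(2) / 2` (purity 0.5) raises instead of silently returning a vector.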

Environment (please complete the following information):

  • OS: Ubuntu 20.04
  • Python version: 3.8.10
  • Related module versions if applicable: graphix=0.2.0

Additional context

N/A

@masa10-f masa10-f added the bug Something isn't working label Jul 11, 2023
@shinich1
Contributor

Just to clarify: this is indeed something to be taken care of, but pattern.simulate_pattern(backend='statevector') works fine as it is, because the measured qubits are separable from the rest.
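Whether a qubit is separable from the rest can be checked through its Schmidt rank: reshape the state into a 2 x 2^(n-1) matrix with that qubit as the row index and count the non-negligible singular values. This is a sketch with a hypothetical helper name and an assumed tolerance, not a graphix function.

```python
import numpy as np


def is_separable_qubit(psi, nqubit, k, tol=1e-10):
    """Hypothetical helper: True if qubit k is in a product state with the rest."""
    tensor = np.asarray(psi).reshape((2,) * nqubit)
    # Rows: qubit k; columns: all remaining qubits.
    mat = np.moveaxis(tensor, k, 0).reshape(2, -1)
    # The singular values are the Schmidt coefficients across this cut.
    singular_values = np.linalg.svd(mat, compute_uv=False)
    # Schmidt rank 1 <=> separable.
    return bool(np.sum(singular_values > tol) == 1)


bell = np.array([1, 0, 0, 1]) / np.sqrt(2)
zero_plus = np.kron([1, 0], [1, 1]) / np.sqrt(2)  # |0>|+>
print(is_separable_qubit(bell, 2, 0), is_separable_qubit(zero_plus, 2, 1))
```

When the Schmidt rank is 1, dropping the qubit by slicing the state vector is exact; for the Bell state it is not.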

@masa10-f
Contributor Author

masa10-f commented Jul 14, 2023

Below is a draft improvement plan for partially tracing out a separable qubit.

# %%
import numpy as np
from graphix.sim.statevec import Statevec, meas_op

import time


def truncate_one_qubit(k, sv):
    n = len(sv.dims())
    taken = np.zeros(2**(n - 1), dtype=np.complex128)
    state = sv.flatten()
    for i in range(2**k):
        for j in range(2**(n - k - 1)):
            taken[i * 2**(n - k - 1) + j] = state[2**(n - k) * i + j]

    norm = taken.dot(taken.flatten().conjugate())
    taken = taken / norm**0.5
    return taken


# %%
# prepare a non-separable state
k = 3
n = 10
statevec = Statevec(nqubit=n)
for i in range(n):
    statevec.entangle((i, (i + 1) % n))
# print(statevec.flatten())

# %%
# measure qubit k
m_op = meas_op(np.pi / 5)
statevec.evolve(m_op, [k])
# print(statevec.flatten())

# %%
# discard qubit k
start = time.perf_counter()
reduced = truncate_one_qubit(k, statevec)
end = time.perf_counter()
print("time(new method)", end - start)
# print(reduced)

# %%
# reference
start = time.perf_counter()
statevec.ptrace([k])
end = time.perf_counter()
print("time(ptrace)", end - start)
# print(statevec.flatten())

# %%
# check inner product
inner_product = reduced.dot(statevec.flatten().conjugate())
print(np.abs(inner_product))

In my environment, the execution time improved by three orders of magnitude when the number of qubits (nqubit) is 10.
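The index arithmetic in the double loop (`i * 2**(n - k) + j`) selects exactly the amplitudes whose k-th bit is 0. The same selection can be written as a reshape into (bits above k, bit k, bits below k), which avoids the Python-level loop. A sketch under the same qubit-0-most-significant ordering assumption; both function names are hypothetical.

```python
import numpy as np


def truncate_one_qubit_reshape(k, state, n):
    """Keep the |0> component of qubit k and renormalize (vectorized)."""
    # Flat index = i * 2**(n-k) + b * 2**(n-k-1) + j, with b the bit of qubit k.
    blocks = np.asarray(state).reshape(2**k, 2, 2 ** (n - k - 1))
    taken = blocks[:, 0, :].ravel()
    return taken / np.linalg.norm(taken)


def truncate_one_qubit_loop(k, state, n):
    """Loop version with the same indexing as the draft above."""
    taken = np.zeros(2 ** (n - 1), dtype=np.complex128)
    for i in range(2**k):
        for j in range(2 ** (n - k - 1)):
            taken[i * 2 ** (n - k - 1) + j] = state[i * 2 ** (n - k) + j]
    return taken / np.linalg.norm(taken)


rng = np.random.default_rng(0)
n, k = 5, 2
state = rng.normal(size=2**n) + 1j * rng.normal(size=2**n)
print(np.allclose(truncate_one_qubit_reshape(k, state, n),
                  truncate_one_qubit_loop(k, state, n)))
```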

@masa10-f
Contributor Author

masa10-f commented Jul 14, 2023

The above method is not yet complete: the k-th qubit is not necessarily in the |0> state, and the function returns a plain NumPy array rather than a Statevec. However, there are several ways to resolve this.
Once this is completed, I'd like to see a performance comparison with the previous method.

@nabe98
Contributor

nabe98 commented Jul 24, 2023

This can be completed as follows.

def truncate_one_qubit(self, qarg):
    """truncate one qubit

    Args:
        qarg (int): qubit index
    """
    # extract |***0_{qarg}***> components if not zero else |***1_{qarg}***>
    psi = self.psi.take(indices=0, axis=qarg)
    self.psi = psi if psi[(0,) * psi.ndim] != 0.0 else self.psi.take(indices=1, axis=qarg)
    self.normalize()
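One caveat: `psi[(0,) * psi.ndim] != 0.0` tests only the first amplitude of the |...0_k...> slice, not the whole slice. For the two-qubit state |1>|0> with k = 1, that amplitude is zero even though the slice is not, so the code above would pick the all-zero |1> slice and normalization would divide by zero. A sketch that tests the norm of the whole slice instead (hypothetical name, assumed tolerance `tol=1e-12`):

```python
import numpy as np


def truncate_one_qubit_norm(psi, k, tol=1e-12):
    """Drop qubit k from a (2,)*n tensor, assuming k is separable from the rest."""
    psi0 = psi.take(indices=0, axis=k)
    # Check the whole |...0_k...> slice rather than only its first amplitude.
    if np.linalg.norm(psi0) > tol:
        reduced = psi0
    else:
        reduced = psi.take(indices=1, axis=k)
    return reduced / np.linalg.norm(reduced)


# |1>|0>: the first amplitude of the k=1 |0> slice is zero, but the slice is not.
state = np.array([0, 0, 1, 0], dtype=complex).reshape(2, 2)
print(truncate_one_qubit_norm(state, 1))
```

The norm check costs one extra pass over half the amplitudes, which may matter for the timing comparison below.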

The performance can be compared with the following script.

# %%
import numpy as np
from graphix.sim.statevec import Statevec, meas_op

import time


def truncate_one_qubit_old(k, sv):
    n = len(sv.dims())
    taken = np.zeros(2**(n - 1), dtype=np.complex128)
    state = sv.flatten()
    for i in range(2**k):
        for j in range(2**(n - k - 1)):
            taken[i * 2**(n - k - 1) + j] = state[i * 2**(n - k) + j]

    norm = taken.dot(taken.flatten().conjugate())
    taken = taken / norm**0.5
    return taken

def truncate_one_qubit(k, sv):
    # extract |***0_{qarg}***> components if not zero else |***1_{qarg}***>
    psi = sv.psi.take(indices=0, axis=k)
    psi = psi if psi[(0,) * psi.ndim] != 0.0 else sv.psi.take(indices=1, axis=k)
    norm = psi.flatten().dot(psi.flatten().conjugate())
    psi = psi / norm**0.5
    return psi.flatten()


# %%
# prepare a non-separable state
k = 3
n = 10
statevec = Statevec(nqubit=n)
for i in range(n):
    statevec.entangle((i, (i + 1) % n))
# print(statevec.flatten())

# %%
# measure qubit k
m_op = meas_op(np.pi / 5)
statevec.evolve(m_op, [k])
# print(statevec.flatten())

# %%
# discard qubit k (old)
start = time.perf_counter()
reduced = truncate_one_qubit_old(k, statevec)
end = time.perf_counter()
print("time(old method)", end - start)
# print(reduced)

# %%
# discard qubit k (new)

start = time.perf_counter()
reduced2 = truncate_one_qubit(k, statevec)
end = time.perf_counter()
print("time(new method)", end - start)
# print(reduced2)

# %%
# reference
start = time.perf_counter()
statevec.ptrace([k])
end = time.perf_counter()
print("time(ptrace)", end - start)
# print(statevec.flatten())

# %%
# check inner product
inner_product = reduced.dot(statevec.flatten().conjugate())
print(np.abs(inner_product))
inner_product = reduced2.dot(statevec.flatten().conjugate())
print(np.abs(inner_product))

In my environment, the new version is sometimes slower than the old truncate_one_qubit in this sample. However, when I replaced the implementation in statevec with this new truncate_one_qubit, it ran nearly 2 times faster than the old one.
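A single `time.perf_counter()` sample per method is noisy, which may explain the "sometimes slower" observation. A common remedy is to repeat the measurement and report the best and median time per call; the sketch below uses the standard-library `timeit` module, and the `repeat`/`number` defaults are arbitrary assumptions.

```python
import timeit

import numpy as np


def benchmark(func, *args, repeat=7, number=100):
    """Return (best, median) seconds per call of func(*args)."""
    timer = timeit.Timer(lambda: func(*args))
    # Each entry is the total time for `number` consecutive calls.
    totals = timer.repeat(repeat=repeat, number=number)
    per_call = np.array(totals) / number
    return float(per_call.min()), float(np.median(per_call))


psi = np.ones((2,) * 10) / np.sqrt(2**10)
best, med = benchmark(lambda p: p.take(indices=0, axis=3), psi)
print(f"best {best:.2e} s, median {med:.2e} s")
```

The minimum is the least disturbed by other processes; the median shows typical behavior.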

@shinich1
Contributor

@nabe98
Thanks! Could you paste a plot of the speed comparison for visual inspection? For example, could you compare the pattern simulation speed for varying pattern sizes with and without the new code?

Also, what do you mean by the following? Is it sometimes slower somehow?

the new version sometimes slower than the old trancate_one_qubit in this sample

@masa10-f
Contributor Author

masa10-f commented Jul 25, 2023

@nabe98 Thank you!
@shinich1 I compared the performance of the three methods.
Even accounting for the standard deviations, @nabe98's method is faster than mine.

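[Figure: comparison_truncation, mean run time of old, new, and ptrace (log scale)]
[Figure: comparison_truncation_old_new, mean run time of old vs. new]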

# %%
import numpy as np
import matplotlib.pyplot as plt
from statistics import stdev, mean
from graphix.sim.statevec import Statevec, meas_op

import time
from copy import deepcopy


def truncate_one_qubit_old(k, sv):
    n = len(sv.dims())
    taken = np.zeros(2**(n - 1), dtype=np.complex128)
    state = sv.flatten()
    for i in range(2**k):
        for j in range(2**(n - k - 1)):
            taken[i * 2**(n - k - 1) + j] = state[i * 2**(n - k) + j]

    norm = taken.dot(taken.flatten().conjugate())
    taken = taken / norm**0.5
    return taken


def truncate_one_qubit(k, sv):
    # extract |***0_{qarg}***> components if not zero else |***1_{qarg}***>
    psi = sv.psi.take(indices=0, axis=k)
    psi = psi if psi[(0,) * psi.ndim] != 0.0 else sv.psi.take(indices=1, axis=k)
    norm = psi.flatten().dot(psi.flatten().conjugate())
    psi = psi / norm**0.5
    return psi.flatten()


# %%
time_old = []
time_new = []
time_ptrace = []
iteration = 30

# prepare a non-separable state
k = 3
n = 10
statevec = Statevec(nqubit=n)
for i in range(n):
    statevec.entangle((i, (i + 1) % n))
# print(statevec.flatten())

# %%
# measure qubit k
m_op = meas_op(np.pi / 5)
statevec.evolve(m_op, [k])
# print(statevec.flatten())

# %%
# discard qubit k (old)
for i in range(iteration):
    start = time.perf_counter()
    reduced = truncate_one_qubit_old(k, statevec)
    end = time.perf_counter()
    time_old.append(end - start)
    # print("time(old method)", end - start)
    # print(reduced)

# %%
# discard qubit k (new)
for i in range(iteration):
    start = time.perf_counter()
    reduced2 = truncate_one_qubit(k, statevec)
    end = time.perf_counter()
    time_new.append(end - start)
    # print("time(new method)", end - start)
    # print(reduced2)

# %%
# reference
for i in range(iteration):
    statevec_cp = deepcopy(statevec)
    start = time.perf_counter()
    statevec_cp.ptrace([k])
    end = time.perf_counter()
    time_ptrace.append(end - start)
    # print("time(ptrace)", end - start)
    # print(statevec_cp.flatten())

# %%
# check inner product
inner_product = reduced.dot(statevec_cp.flatten().conjugate())
print(np.abs(inner_product))
inner_product = reduced2.dot(statevec_cp.flatten().conjugate())
print(np.abs(inner_product))
# %%
# acquire statistics

mean_old = mean(time_old)
mean_new = mean(time_new)
mean_ptrace = mean(time_ptrace)

std_old = stdev(time_old)
std_new = stdev(time_new)
std_ptrace = stdev(time_ptrace)

# %%
# plot

plt.bar(["old", "new", "ptrace"], [mean_old, mean_new, mean_ptrace],
        yerr=[std_old, std_new, std_ptrace])
plt.ylabel("time (s)")
plt.yscale("log")
plt.show()
# %%

# plot without ptrace
plt.bar(["old", "new"], [mean_old, mean_new], yerr=[std_old, std_new])
plt.ylabel("time (s)")
plt.show()

@masa10-f
Contributor Author

To make a comparison within the pattern simulator, we need to modify the 'measure' method, which is tailored to the ptrace method and traces out a group of measured qubits together. With the new truncation method, we can trace out each measured qubit individually, which should improve performance.
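When several measured qubits are truncated one at a time, removing an axis shifts the positions of all later axes. Processing the qubits in descending index order avoids re-indexing. The sketch below is a hypothetical helper operating on a raw `(2,) * n` tensor, not graphix's `measure` method, and it assumes every removed qubit is separable from the rest.

```python
import numpy as np


def truncate_qubits(psi, qubits, tol=1e-12):
    """Hypothetical helper: drop several separable qubits from a (2,)*n tensor."""
    # Highest index first, so the axes of the qubits still to remove do not move.
    for q in sorted(qubits, reverse=True):
        psi0 = psi.take(indices=0, axis=q)
        # Keep the |0> slice unless it is (numerically) empty, else the |1> slice.
        psi = psi0 if np.linalg.norm(psi0) > tol else psi.take(indices=1, axis=q)
    return psi / np.linalg.norm(psi)


zero, one = np.array([1, 0]), np.array([0, 1])
plus, minus = np.array([1, 1]) / np.sqrt(2), np.array([1, -1]) / np.sqrt(2)
# |0>|+>|1>|->; removing qubits 0 and 2 should leave |+>|->.
state = np.kron(np.kron(np.kron(zero, plus), one), minus).reshape((2,) * 4)
print(truncate_qubits(state, [0, 2]).ravel())
```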

Projects
None yet
Development

No branches or pull requests

3 participants