Skip to content

Commit

Permalink
Merge pull request #24993 from Costor/TP_ImmutableMatrix
Browse files Browse the repository at this point in the history
TensorProduct fails for ImmutableMatrix: correct to MatrixBase instead of Matrix
  • Loading branch information
hanspi42 committed Mar 31, 2023
2 parents ad253b7 + 01ce737 commit cd9ff53
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
6 changes: 4 additions & 2 deletions sympy/physics/quantum/tensorproduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from sympy.core.mul import Mul
from sympy.core.power import Pow
from sympy.core.sympify import sympify
from sympy.matrices.dense import MutableDenseMatrix as Matrix
from sympy.matrices.dense import DenseMatrix as Matrix
from sympy.matrices.immutable import ImmutableDenseMatrix as ImmutableMatrix
from sympy.printing.pretty.stringpict import prettyForm

from sympy.physics.quantum.qexpr import QuantumError
Expand Down Expand Up @@ -120,7 +121,8 @@ class TensorProduct(Expr):
is_commutative = False

def __new__(cls, *args):
if isinstance(args[0], (Matrix, numpy_ndarray, scipy_sparse_matrix)):
if isinstance(args[0], (Matrix, ImmutableMatrix, numpy_ndarray,
scipy_sparse_matrix)):
return matrix_tensor_product(*args)
c_part, new_args = cls.flatten(sympify(args))
c_part = Mul(*c_part)
Expand Down
12 changes: 11 additions & 1 deletion sympy/physics/quantum/tests/test_tensorproduct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sympy.core.numbers import I
from sympy.core.symbol import symbols
from sympy.core.expr import unchanged
from sympy.matrices import Matrix, SparseMatrix
from sympy.matrices import Matrix, SparseMatrix, ImmutableMatrix

from sympy.physics.quantum.commutator import Commutator as Comm
from sympy.physics.quantum.tensorproduct import TensorProduct
Expand Down Expand Up @@ -125,3 +125,13 @@ def test_eval_trace():
1.0*A*Dagger(C)*Tr(B*Dagger(D)) +
1.0*C*Dagger(A)*Tr(D*Dagger(B)) +
1.0*C*Dagger(C)*Tr(D*Dagger(D)))


def test_pr24993():
from sympy.matrices.expressions.kronecker import matrix_kronecker_product
from sympy.physics.quantum.matrixutils import matrix_tensor_product
X = Matrix([[0, 1], [1, 0]])
Xi = ImmutableMatrix(X)
assert TensorProduct(Xi, Xi) == TensorProduct(X, X)
assert TensorProduct(Xi, Xi) == matrix_tensor_product(X, X)
assert TensorProduct(Xi, Xi) == matrix_kronecker_product(X, X)

0 comments on commit cd9ff53

Please sign in to comment.