Commit c7013555 authored by Tobias Hangleiter's avatar Tobias Hangleiter
Browse files

Small improvements for check_phase_eq and dot_HS

parent d553fe5f
......@@ -76,19 +76,18 @@ if numba:
def tensor(*args):
r"""
Fast tensor product using einsum. The arguments may be of arbitrary
dimension but are required to be square on its last two axes and the shapes
of the remaining axes must be the same. For example, the following shapes
are compatible::
dimension but are required to be square or pairwise transposed on its last
two axes and the shapes of the remaining axes must be the same. For
example, the following shapes are compatible::
(a, b, c, d, d), (a, b, c, e, e) -> (a, b, c, d*e, d*e)
(a, b, c), (a, c, b) -> (a, b*c, c*b)
"""
if len(set(arg.shape[:-2] for arg in args)) != 1:
raise ValueError('Require all args to have the same shape except ' +
'on the last two axes.')
if not all(arg.shape[-2] == arg.shape[-1] for arg in args):
raise ValueError('Require all args to be square in its last two axes.')
def binary_tensor(A, B):
d = A.shape[-1]*B.shape[-1]
......@@ -213,8 +212,8 @@ def remove_float_errors(arr: ndarray, eps_scale: float = None):
return arr
def check_phase_eq(psi: Union[qt.Qobj, ndarray],
phi: Union[qt.Qobj, ndarray],
def check_phase_eq(psi: Union[qt.Qobj, Sequence],
phi: Union[qt.Qobj, Sequence],
eps: float = None,
normalized: bool = False) -> Tuple[bool, float]:
r"""
......@@ -229,7 +228,7 @@ def check_phase_eq(psi: Union[qt.Qobj, ndarray],
Parameters
----------
psi, phi : Qobj or ndarray
psi, phi : Qobj or array_like
Vectors or operators to be compared
eps : float
The tolerance below which the two objects are treated as equal, i.e.,
......@@ -245,14 +244,14 @@ def check_phase_eq(psi: Union[qt.Qobj, ndarray],
>>> check_phase_eq(psi, phi)
(True, 1.2345)
"""
d = max(psi.shape)
psi, phi = [obj.full() if isinstance(obj, qt.Qobj) else obj
for obj in (psi, phi)]
if eps is None:
# Tolerance the floating point eps times the dimension squared
eps = max(np.finfo(psi.dtype).eps, np.finfo(phi.dtype).eps)*d**2
# Tolerance the floating point eps times the # of flops for the matrix
# multiplication, i.e. for psi and phi n x m matrices n**2*m
eps = max(np.finfo(psi.dtype).eps, np.finfo(phi.dtype).eps) *\
np.prod(psi.shape)*phi.shape[-1]
if psi.ndim - psi.shape.count(1) == 1:
# Vector
......@@ -290,6 +289,11 @@ def dot_HS(U: Union[ndarray, qt.Qobj],
U, V : Qobj or ndarray
Objects to compute the inner product of.
Returns
-------
result : float, complex
The result rounded to precision eps.
Examples
--------
>>> U, V = qt.sigmax(), qt.sigmay()
......@@ -303,8 +307,6 @@ def dot_HS(U: Union[ndarray, qt.Qobj],
if isinstance(V, qt.Qobj):
V = V.full()
res = np.einsum('ij,ij', U.conj(), V)
if eps is None:
# Tolerance is the dtype precision times the number of flops for the
# matrix multiplication
......@@ -314,22 +316,13 @@ def dot_HS(U: Union[ndarray, qt.Qobj],
# dtype is int and therefore exact
eps = 0
# Deal with real and imaginary part separately
if abs(res.real) <= eps:
real = 0.0
elif abs(1 - abs(res.real)) <= eps:
real = copysign(1, res.real)
else:
real = res.real
if abs(res.imag) <= eps:
imag = 0
elif abs(1 - abs(res.imag)) <= eps:
imag = copysign(1, res.imag)
if eps == 0:
decimals = 0
else:
imag = res.imag
decimals = abs(int(np.log10(eps)))
return real + 1j*imag if imag != 0 else real
res = np.round(np.einsum('ij,ij', U.conj(), V), decimals)
return res if res.imag else res.real
def sparsity(arr: ndarray, eps: float = None) -> float:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment