Commit 2f7743e7 authored by Tobias Hangleiter

Minor improvements to tensor

parent 5f3ee05a
@@ -77,7 +77,8 @@ if numba:
     )(abs2)
 
 
-def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
+def tensor(*args, rank: int = 2,
+           optimize: Union[bool, str] = False) -> ndarray:
     """
     Fast, flexible tensor product using einsum. The product is taken over the
     last *rank* axes (the exception being ``rank == 1`` since a column vector
@@ -100,11 +101,11 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
     ----------
     args : array_like
         The elements of the tensor product
-    rank : int, optional
+    rank : int, optional (default: 2)
         The rank of the tensors. E.g., for a Kronecker product between two
         vectors, ``rank == 1``, and between two matrices ``rank == 2``. The
         remaining axes are broadcast over.
-    optimize : bool|str, optional
+    optimize : bool|str, optional (default: False)
         Optimize the tensor contraction order. Passed through to
         :meth:`numpy.einsum`.
@@ -136,8 +137,7 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
     ...     result = tensor(A, B, rank=2)
     ... except ValueError as err:  # cannot broadcast over axis 0
     ...     print(err)
-    Incompatible shapes. Could not compute tensor(tensor(*args[:1], rank=2),
-    args[1], rank=2) with shapes (3, 1, 2) and (2, 2, 2).
+    Incompatible shapes (3, 1, 2) and (2, 2, 2) for tensor product of rank 2.
     >>> result = tensor(A, B, rank=3)
     >>> result.shape == (3*2, 1*2, 2*2)
     True
@@ -153,7 +153,7 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
         # Vectors, but numpy arrays are still two-dimensional
         rank += 1
-    chars = string.ascii_lowercase + string.ascii_uppercase
+    chars = string.ascii_letters
     # All the subscripts we need
     A_chars = chars[:rank]
     B_chars = chars[rank:2*rank]
@@ -161,12 +161,13 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
         A_chars, B_chars, ''.join(i + j for i, j in zip(A_chars, B_chars))
     )
 
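An illustrative sketch, not part of this commit: with the default rank == 2, the subscript string assembled above interleaves the indices of the two operands, so that the reshape in ``binary_tensor`` further down turns the einsum result into a Kronecker product:

    import string

    rank = 2
    chars = string.ascii_letters
    A_chars = chars[:rank]        # 'ab': the last two axes of A
    B_chars = chars[rank:2*rank]  # 'cd': the last two axes of B
    subscripts = '...{},...{}->...{}'.format(
        A_chars, B_chars, ''.join(i + j for i, j in zip(A_chars, B_chars))
    )
    print(subscripts)  # prints '...ab,...cd->...acbd'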
-    def get_outshape(A, B):
-        """Get tensor product result's shape"""
+    def tensor_product_shape(shape_A: Sequence[int], shape_B: Sequence[int],
+                             rank: int):
+        """Get the shape of the tensor product between A and B of given rank"""
         broadcast_shape = ()
         # Loop over dimensions from last to first, filling the 'shorter' shape
         # with 1's once it is exhausted
-        for dims in zip_longest(A.shape[-rank-1::-1], B.shape[-rank-1::-1],
+        for dims in zip_longest(shape_A[-rank-1::-1], shape_B[-rank-1::-1],
                                 fillvalue=1):
             if 1 in dims:
                 # Broadcast 1-d of argument to dimension of other
@@ -175,45 +176,46 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
                 # Both arguments have same dimension on axis.
                 broadcast_shape = dims[:1] + broadcast_shape
             else:
-                # Pass the Exception through to binary_tensor_wrapper for a
-                # meaningful error message
-                raise ValueError
+                raise ValueError('Incompatible shapes ' +
+                                 '{} and {} '.format(shape_A, shape_B) +
+                                 'for tensor product of rank {}.'.format(rank))
 
-        # Shape of the actual tensor product is product of each dimension
+        # Shape of the actual tensor product is product of each dimension,
+        # again broadcasting if need be
         tensor_shape = tuple(
-            reduce(operator.mul, dimensions)
-            for dimensions in zip(*[arg.shape[-rank:] for arg in (A, B)])
-        )
+            reduce(operator.mul, dimensions) for dimensions in zip_longest(
+                shape_A[:-rank-1:-1], shape_B[:-rank-1:-1], fillvalue=1
+            )
+        )[::-1]
 
         return broadcast_shape + tensor_shape
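An illustrative sketch, not part of this commit: the new helper depends only on the shapes, so its logic can be reproduced standalone. The first example uses the shapes from the docstring; the second shows broadcasting over a leading axis:

    from functools import reduce
    from itertools import zip_longest
    import operator

    def tensor_product_shape(shape_A, shape_B, rank):
        # Axes before the last ``rank`` must be broadcast-compatible
        broadcast_shape = ()
        for dims in zip_longest(shape_A[-rank-1::-1], shape_B[-rank-1::-1],
                                fillvalue=1):
            if 1 in dims:
                broadcast_shape = (max(dims),) + broadcast_shape
            elif len(set(dims)) == 1:
                broadcast_shape = dims[:1] + broadcast_shape
            else:
                raise ValueError('Incompatible shapes {} and {} for tensor '
                                 'product of rank {}.'.format(shape_A, shape_B,
                                                              rank))
        # The last ``rank`` axes multiply element-wise, padding the shorter
        # shape with 1
        tensor_shape = tuple(
            reduce(operator.mul, dims) for dims in zip_longest(
                shape_A[:-rank-1:-1], shape_B[:-rank-1:-1], fillvalue=1)
        )[::-1]
        return broadcast_shape + tensor_shape

    print(tensor_product_shape((3, 1, 2), (2, 2, 2), 3))  # (6, 2, 4)
    print(tensor_product_shape((5, 3, 4), (2, 2), 2))     # (5, 6, 8)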
     def binary_tensor(A, B):
         """Compute the Kronecker product of two tensors"""
-        if optimize:
-            path, _ = np.einsum_path(subscripts, A, B, optimize=optimize)
-        else:
-            path = False
-
-        outshape = get_outshape(A, B)
-        return np.einsum(subscripts, A, B, optimize=path).reshape(outshape)
-
-    def binary_tensor_wrapper(A, B):
-        """Wrap binary_tensor to count function calls and catch Exceptions"""
-        try:
-            binary_tensor.calls += 1
-            result = binary_tensor(A, B)
-        except ValueError:
-            raise ValueError(
-                'Incompatible shapes. Could not compute tensor(tensor(*' +
-                'args[:{0}], rank={1}), args[{0}], rank={1}) '.format(
-                    binary_tensor.calls, rank) +
-                'with shapes {} and {}.'.format(A.shape, B.shape)
-            )
-
-        return result
-
-    # Initialize function call counter to zero
-    binary_tensor.calls = 0
-
-    return reduce(binary_tensor_wrapper, args)
+        # Add dimensions so that each arg has at least ndim == rank
+        while A.ndim < rank:
+            A = A[None, :]
+        while B.ndim < rank:
+            B = B[None, :]
+
+        outshape = tensor_product_shape(A.shape, B.shape, rank)
+        return np.einsum(subscripts, A, B, optimize=optimize).reshape(outshape)
+    # Compute the tensor products in a binary tree-like structure, calculating
+    # the product of two leaves and working up. This is more memory-efficient
+    # than reduce(binary_tensor, args), which computes the products
+    # left-to-right.
+    args = list(args)  # so that slicing and concatenation below work
+    n = len(args)
+    bit = n % 2
+    while n > 1:
+        args = args[:bit] + [binary_tensor(*args[i:i+2])
+                             for i in range(bit, n, 2)]
+        n = len(args)
+        bit = n % 2
+
+    return args[0]
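An illustrative sketch, not part of this commit: the pairwise reduction above in isolation, with numpy's kron standing in for binary_tensor (tree_reduce is a hypothetical name):

    from typing import Callable, Sequence

    import numpy as np

    def tree_reduce(binary_op: Callable, args: Sequence):
        # Combine neighbouring leaves pairwise and work upwards; an odd
        # leading element is carried over unchanged until a later pass.
        args = list(args)
        n = len(args)
        bit = n % 2
        while n > 1:
            args = args[:bit] + [binary_op(*args[i:i+2])
                                 for i in range(bit, n, 2)]
            n = len(args)
            bit = n % 2
        return args[0]

    result = tree_reduce(np.kron, [np.eye(2)] * 5)
    print(result.shape)  # (32, 32)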
def mdot(arr: Sequence, axis: int = 0) -> ndarray:
......