### Minor improvements to tensor

parent 5f3ee05a
 ... ... @@ -77,7 +77,8 @@ if numba: )(abs2) def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False): def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False) -> ndarray: """ Fast, flexible tensor product using einsum. The product is taken over the last *rank* axes (the exception being ``rank == 1`` since a column vector ... ... @@ -100,11 +101,11 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False): ---------- args : array_like The elements of the tensor product rank : int, optional rank : int, optional (default: 2) The rank of the tensors. E.g., for a Kronecker product between two vectors, ``rank == 1``, and between two matrices ``rank == 2``. The remaining axes are broadcast over. optimize : bool|str, optional optimize : bool|str, optional (default: False) Optimize the tensor contraction order. Passed through to :meth:`numpy.einsum`. ... ... @@ -136,8 +137,7 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False): ... result = tensor(A, B, rank=2) ... except ValueError as err: # cannot broadcast over axis 0 ... print(err) Incompatible shapes. Could not compute tensor(tensor(*args[:1], rank=2), args, rank=2) with shapes (3, 1, 2) and (2, 2, 2). Incompatible shapes (3, 1, 2) and (2, 2, 2) for tensor product of rank 2. >>> result = tensor(A, B, rank=3) >>> result.shape == (3*2, 1*2, 2*2) True ... ... @@ -153,7 +153,7 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False): # Vectors, but numpy arrays are still two-dimensional rank += 1 chars = string.ascii_lowercase + string.ascii_uppercase chars = string.ascii_letters # All the subscripts we need A_chars = chars[:rank] B_chars = chars[rank:2*rank] ... ... @@ -161,12 +161,13 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False): A_chars, B_chars, ''.join(i + j for i, j in zip(A_chars, B_chars)) ) def get_outshape(A, B): """Get tensor product result's shape""" def tensor_product_shape(shape_A: Sequence[int], shape_B: Sequence[int], rank: int): """Get shape of the tensor product between A and B of rank rank""" broadcast_shape = () # Loop over dimensions from last to first, filling the 'shorter' shape # with 1's once it is exhausted for dims in zip_longest(A.shape[-rank-1::-1], B.shape[-rank-1::-1], for dims in zip_longest(shape_A[-rank-1::-1], shape_B[-rank-1::-1], fillvalue=1): if 1 in dims: # Broadcast 1-d of argument to dimension of other ... ... @@ -175,45 +176,46 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False): # Both arguments have same dimension on axis. broadcast_shape = dims[:1] + broadcast_shape else: # Pass the Exception through to binary_tensor_wrapper for a # meaningful error message raise ValueError raise ValueError('Incompatible shapes ' + '{} and {} '.format(shape_A, shape_B) + 'for tensor product of rank {}.'.format(rank)) # Shape of the actual tensor product is product of each dimension # Shape of the actual tensor product is product of each dimension, # again broadcasting if need be tensor_shape = tuple( reduce(operator.mul, dimensions) for dimensions in zip(*[arg.shape[-rank:] for arg in (A, B)]) ) reduce(operator.mul, dimensions) for dimensions in zip_longest( shape_A[:-rank-1:-1], shape_B[:-rank-1:-1], fillvalue=1 ) )[::-1] return broadcast_shape + tensor_shape def binary_tensor(A, B): """Compute the Kronecker product of two tensors""" if optimize: path, _ = np.einsum_path(subscripts, A, B, optimize=optimize) else: path = False outshape = get_outshape(A, B) return np.einsum(subscripts, A, B, optimize=path).reshape(outshape) def binary_tensor_wrapper(A, B): """Wrap binary_tensor to count function calls and catch Exceptions""" try: binary_tensor.calls += 1 result = binary_tensor(A, B) except ValueError: raise ValueError( 'Incompatible shapes. Could not compute tensor(tensor(*' + 'args[:{0}], rank={1}), args[{0}], rank={1}) '.format( binary_tensor.calls, rank) + 'with shapes {} and {}.'.format(A.shape, B.shape) ) return result # Initialize function call counter to zero binary_tensor.calls = 0 return reduce(binary_tensor_wrapper, args) # Add dimensions so that each arg has at least ndim == rank while A.ndim < rank: A = A[None, :] while B.ndim < rank: B = B[None, :] outshape = tensor_product_shape(A.shape, B.shape, rank) return np.einsum(subscripts, A, B, optimize=optimize).reshape(outshape) # Compute the tensor products in a binary tree-like structure, calculating # the product of two leaves and working up. This is more memory-efficient # than reduce(binary_tensor, args) which computes the products # left-to-right. n = len(args) bit = n % 2 while n > 1: args = args[:bit] + [binary_tensor(*args[i:i+2]) for i in range(bit, n, 2)] n = len(args) bit = n % 2 return args def mdot(arr: Sequence, axis: int = 0) -> ndarray: ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment