Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
qutech
qutil
Commits
2f7743e7
Commit
2f7743e7
authored
Apr 30, 2019
by
Tobias Hangleiter
Browse files
Minor improvements to tensor
parent
5f3ee05a
Changes
1
Hide whitespace changes
Inline
Sidebyside
qutil/linalg.py
View file @
2f7743e7
...
...
@@ 77,7 +77,8 @@ if numba:
)(
abs2
)
def
tensor
(
*
args
,
rank
:
int
=
2
,
optimize
:
Union
[
bool
,
str
]
=
False
):
def
tensor
(
*
args
,
rank
:
int
=
2
,
optimize
:
Union
[
bool
,
str
]
=
False
)
>
ndarray
:
"""
Fast, flexible tensor product using einsum. The product is taken over the
last *rank* axes (the exception being ``rank == 1`` since a column vector
...
...
@@ 100,11 +101,11 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):

args : array_like
The elements of the tensor product
rank : int, optional
rank : int, optional
(default: 2)
The rank of the tensors. E.g., for a Kronecker product between two
vectors, ``rank == 1``, and between two matrices ``rank == 2``. The
remaining axes are broadcast over.
optimize : boolstr, optional
optimize : boolstr, optional
(default: False)
Optimize the tensor contraction order. Passed through to
:meth:`numpy.einsum`.
...
...
@@ 136,8 +137,7 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
... result = tensor(A, B, rank=2)
... except ValueError as err: # cannot broadcast over axis 0
... print(err)
Incompatible shapes. Could not compute tensor(tensor(*args[:1], rank=2),
args[1], rank=2) with shapes (3, 1, 2) and (2, 2, 2).
Incompatible shapes (3, 1, 2) and (2, 2, 2) for tensor product of rank 2.
>>> result = tensor(A, B, rank=3)
>>> result.shape == (3*2, 1*2, 2*2)
True
...
...
@@ 153,7 +153,7 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
# Vectors, but numpy arrays are still twodimensional
rank
+=
1
chars
=
string
.
ascii_l
owercase
+
string
.
ascii_uppercase
chars
=
string
.
ascii_l
etters
# All the subscripts we need
A_chars
=
chars
[:
rank
]
B_chars
=
chars
[
rank
:
2
*
rank
]
...
...
@@ 161,12 +161,13 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
A_chars
,
B_chars
,
''
.
join
(
i
+
j
for
i
,
j
in
zip
(
A_chars
,
B_chars
))
)
def
get_outshape
(
A
,
B
):
"""Get tensor product result's shape"""
def
tensor_product_shape
(
shape_A
:
Sequence
[
int
],
shape_B
:
Sequence
[
int
],
rank
:
int
):
"""Get shape of the tensor product between A and B of rank rank"""
broadcast_shape
=
()
# Loop over dimensions from last to first, filling the 'shorter' shape
# with 1's once it is exhausted
for
dims
in
zip_longest
(
A
.
shape
[

rank

1
::

1
],
B
.
shape
[

rank

1
::

1
],
for
dims
in
zip_longest
(
shape
_A
[

rank

1
::

1
],
shape
_B
[

rank

1
::

1
],
fillvalue
=
1
):
if
1
in
dims
:
# Broadcast 1d of argument to dimension of other
...
...
@@ 175,45 +176,46 @@ def tensor(*args, rank: int = 2, optimize: Union[bool, str] = False):
# Both arguments have same dimension on axis.
broadcast_shape
=
dims
[:
1
]
+
broadcast_shape
else
:
# Pass the Exception through to binary_tensor_wrapper for a
# meaningful error message
raise
ValueError
raise
ValueError
(
'Incompatible shapes '
+
'{} and {} '
.
format
(
shape_A
,
shape_B
)
+
'for tensor product of rank {}.'
.
format
(
rank
))
# Shape of the actual tensor product is product of each dimension
# Shape of the actual tensor product is product of each dimension,
# again broadcasting if need be
tensor_shape
=
tuple
(
reduce
(
operator
.
mul
,
dimensions
)
for
dimensions
in
zip
(
*
[
arg
.
shape
[

rank
:]
for
arg
in
(
A
,
B
)])
)
reduce
(
operator
.
mul
,
dimensions
)
for
dimensions
in
zip_longest
(
shape_A
[:

rank

1
:

1
],
shape_B
[:

rank

1
:

1
],
fillvalue
=
1
)
)[::

1
]
return
broadcast_shape
+
tensor_shape
def
binary_tensor
(
A
,
B
):
"""Compute the Kronecker product of two tensors"""
if
optimize
:
path
,
_
=
np
.
einsum_path
(
subscripts
,
A
,
B
,
optimize
=
optimize
)
else
:
path
=
False
outshape
=
get_outshape
(
A
,
B
)
return
np
.
einsum
(
subscripts
,
A
,
B
,
optimize
=
path
).
reshape
(
outshape
)
def
binary_tensor_wrapper
(
A
,
B
):
"""Wrap binary_tensor to count function calls and catch Exceptions"""
try
:
binary_tensor
.
calls
+=
1
result
=
binary_tensor
(
A
,
B
)
except
ValueError
:
raise
ValueError
(
'Incompatible shapes. Could not compute tensor(tensor(*'
+
'args[:{0}], rank={1}), args[{0}], rank={1}) '
.
format
(
binary_tensor
.
calls
,
rank
)
+
'with shapes {} and {}.'
.
format
(
A
.
shape
,
B
.
shape
)
)
return
result
# Initialize function call counter to zero
binary_tensor
.
calls
=
0
return
reduce
(
binary_tensor_wrapper
,
args
)
# Add dimensions so that each arg has at least ndim == rank
while
A
.
ndim
<
rank
:
A
=
A
[
None
,
:]
while
B
.
ndim
<
rank
:
B
=
B
[
None
,
:]
outshape
=
tensor_product_shape
(
A
.
shape
,
B
.
shape
,
rank
)
return
np
.
einsum
(
subscripts
,
A
,
B
,
optimize
=
optimize
).
reshape
(
outshape
)
# Compute the tensor products in a binary treelike structure, calculating
# the product of two leaves and working up. This is more memoryefficient
# than reduce(binary_tensor, args) which computes the products
# lefttoright.
n
=
len
(
args
)
bit
=
n
%
2
while
n
>
1
:
args
=
args
[:
bit
]
+
[
binary_tensor
(
*
args
[
i
:
i
+
2
])
for
i
in
range
(
bit
,
n
,
2
)]
n
=
len
(
args
)
bit
=
n
%
2
return
args
[
0
]
def
mdot
(
arr
:
Sequence
,
axis
:
int
=
0
)
>
ndarray
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment