Coverage for local_installation_linux/mumott/methods/utilities/tensor_operations.py: 70%
28 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1""" Utilities for efficient tensor algebra. """
2import numpy as np
3from numba import jit
6"""
NB: Seems faster to simply let numba work out types and shapes than doing it upfront,
after the first call. There is no parallelization with `prange` because there were
little to no visible gains in the cases checked, probably since the
matrix multiplication should already be parallelized. This may need to be revisited
in the future.
12"""
@jit(nopython=True)
def _framewise_contraction(matrices: np.ndarray[float],
                           data: np.ndarray[float],
                           out: np.ndarray[float]) -> None:
    """Internal jitted kernel that carries out the calculations.

    For every index ``f`` along the first axis of ``matrices``, the product
    ``np.dot(data[f], matrices[f])`` is written into ``out[f]``. Nothing is
    returned; ``out`` is filled in place.
    """
    # NB: Taking the cache size into account would be possible in principle,
    # but attempts at this have not had much success.
    n_frames = matrices.shape[0]
    for frame in range(n_frames):
        np.dot(data[frame], matrices[frame], out=out[frame])
def framewise_contraction(matrices: np.ndarray[float],
                          data: np.ndarray[float],
                          out: np.ndarray[float]) -> None:
    """This function is a more efficient implementation of the
    series of matrix multiplications which is carried out by the expression
    ``np.einsum('ijk, inmj -> inmk', matrices, data, out=out)``, using
    ``numba.jit``, for a four-dimensional :attr:`data` array.
    Used to linearly transform data from one representation to another.

    Parameters
    ----------
    matrices
        Stack of matrices with shape ``(i, j, k)`` where ``i`` is the
        stacking dimension.
    data
        Stack of data with shape ``(i, n, m, ..., j)`` where ``i`` is the
        stacking dimension and ``j`` is the representation dimension.
    out
        Output array with shape ``(i, n, m, ..., k)`` where ``i`` is the
        stacking dimension and ``k`` is the representation dimension.
        Will be modified in-place and must be contiguous.

    Raises
    ------
    ValueError
        If the memory layout of :attr:`out` is such that collapsing its
        intermediate dimensions requires a copy. In that case the result
        could not be written back into :attr:`out`.

    Notes
    -----
    :attr:`matrices` and :attr:`data` are cast to ``out.dtype`` if their
    data types differ from it.
    """
    # This is inefficient to do in the jitted function.
    dtype = out.dtype
    if matrices.dtype != dtype:
        matrices = matrices.astype(dtype)
    if data.dtype != dtype:
        data = data.astype(dtype)

    # Collapse all dimensions between the stacking and the representation
    # dimension so that each frame is a plain two-dimensional matrix.
    n_frames = len(matrices)
    data = data.reshape(n_frames, -1, data.shape[-1])
    out_frames = out.reshape(n_frames, -1, out.shape[-1])

    # ``reshape`` silently returns a copy when no view is possible; the
    # result would then end up in a temporary and ``out`` would stay
    # untouched. Fail loudly instead of returning without writing anything.
    if out.size > 0 and not np.may_share_memory(out_frames, out):
        raise ValueError(
            'The array out must be contiguous so that it can be reshaped '
            'without copying; the result cannot be written in-place otherwise.')

    _framewise_contraction(matrices, data, out_frames)
@jit(nopython=True)
def _framewise_contraction_transpose(matrices: np.ndarray[float],
                                     data: np.ndarray[float],
                                     out: np.ndarray[float]) -> None:
    """Internal jitted kernel for contraction with the transposed matrices.

    For every index ``f`` along the first axis of ``matrices``, the product
    ``np.dot(data[f], matrices[f].T)`` is written into ``out[f]``. Nothing is
    returned; ``out`` is filled in place.
    """
    n_frames = matrices.shape[0]
    for frame in range(n_frames):
        np.dot(data[frame], matrices[frame].T, out=out[frame])
def framewise_contraction_transpose(matrices: np.ndarray[float],
                                    data: np.ndarray[float],
                                    out: np.ndarray[float]) -> None:
    """This function is a more efficient implementation of the
    series of matrix multiplications which is carried out by the expression
    ``np.einsum('ijk, inmk -> inmj', matrices, data, out=out)``, using
    ``numba.jit``, for a four-dimensional :attr:`data` array.
    Used to linearly transform data from one representation to another.

    Parameters
    ----------
    matrices
        Stack of matrices with shape ``(i, j, k)`` where ``i`` is the
        stacking dimension.
    data
        Stack of data with shape ``(i, n, m, ..., k)`` where ``i`` is the
        stacking dimension and ``k`` is the representation dimension.
    out
        Output array with shape ``(i, n, m, ..., j)`` where ``i`` is the
        stacking dimension and ``j`` is the representation dimension.
        Will be modified in-place and must be contiguous.

    Raises
    ------
    ValueError
        If the memory layout of :attr:`out` is such that collapsing its
        intermediate dimensions requires a copy. In that case the result
        could not be written back into :attr:`out`.

    Notes
    -----
    :attr:`matrices` and :attr:`data` are cast to ``out.dtype`` if their
    data types differ from it. The result is only written into :attr:`out`;
    the function itself returns nothing.
    """
    # Casting is done here since it is inefficient in the jitted function.
    dtype = out.dtype
    if matrices.dtype != dtype:
        matrices = matrices.astype(dtype)
    if data.dtype != dtype:
        data = data.astype(dtype)

    # Collapse all dimensions between the stacking and the representation
    # dimension so that each frame is a plain two-dimensional matrix.
    n_frames = len(matrices)
    data = data.reshape(n_frames, -1, data.shape[-1])
    out_frames = out.reshape(n_frames, -1, out.shape[-1])

    # ``reshape`` silently returns a copy when no view is possible; the
    # result would then end up in a temporary and ``out`` would stay
    # untouched. Fail loudly instead of returning without writing anything.
    if out.size > 0 and not np.may_share_memory(out_frames, out):
        raise ValueError(
            'The array out must be contiguous so that it can be reshaped '
            'without copying; the result cannot be written in-place otherwise.')

    _framewise_contraction_transpose(matrices, data, out_frames)