Coverage for local_installation_linux/mumott/methods/utilities/tensor_operations.py: 70%

28 statements  

coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1""" Utilities for efficient tensor algebra. """ 

2import numpy as np 

3from numba import jit 

4 

5 

6""" 

7NB: Seems faster to simply let numba work out types and shapes than doing it upfront, 

8after the first call. There is no parallelization with `prange` because there was 

9little to no visible gains in the cases checked, probably since the 

10matrix multiplication should already be parallelized. This may need to be revisited 

11in the future. 

12""" 

13 

14 

@jit(nopython=True)
def _framewise_contraction(matrices: np.ndarray[float],
                           data: np.ndarray[float],
                           out: np.ndarray[float]) -> None:
    """This is the internal function that carries out calculations."""
    # NB: In principle could consider cache size but have not had much success with this.
    for i in range(len(matrices)):
        np.dot(data[i], matrices[i], out=out[i])

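The module comment above notes that numba is left to work out types and shapes on its own, which it does during the first call; that call therefore also pays the compilation cost. A minimal warm-up sketch (assuming the module is importable as ``mumott.methods.utilities.tensor_operations``, inferred from the file path above; array sizes are arbitrary):

import time
import numpy as np
from mumott.methods.utilities.tensor_operations import _framewise_contraction

rng = np.random.default_rng(0)
matrices = rng.standard_normal((10, 6, 8))   # stack of 10 matrices, shape (i, j, k)
data = rng.standard_normal((10, 50, 6))      # 10 frames of 50 vectors, shape (i, N, j)
out = np.zeros((10, 50, 8))                  # result buffer, shape (i, N, k)

t0 = time.perf_counter()
_framewise_contraction(matrices, data, out)  # first call: numba compiles for float64 inputs
t1 = time.perf_counter()
_framewise_contraction(matrices, data, out)  # second call: reuses the compiled version
t2 = time.perf_counter()
print(f'first call {t1 - t0:.3f} s, second call {t2 - t1:.2e} s')

The compiled specialization is cached in memory per argument-type signature, so only the very first call with a given dtype combination pays the compilation cost.
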

def framewise_contraction(matrices: np.ndarray[float],
                          data: np.ndarray[float],
                          out: np.ndarray[float]) -> None:
    """This function is a more efficient implementation of the series of
    matrix multiplications carried out by the expression
    ``np.einsum('ijk, inmj -> inmk', matrices, data, out=out)``, using
    ``numba.jit``, for a four-dimensional :attr:`data` array.
    Used to linearly transform data from one representation to another.

    Parameters
    ----------
    matrices
        Stack of matrices with shape ``(i, j, k)`` where ``i`` is the
        stacking dimension.
    data
        Stack of data with shape ``(i, n, m, ..., j)`` where ``i`` is the
        stacking dimension and ``j`` is the representation dimension.
    out
        Output array with shape ``(i, n, m, ..., k)`` where ``i`` is the
        stacking dimension and ``k`` is the representation dimension.
        Will be modified in-place and must be contiguous.

    Notes
    -----
    All three arrays will be cast to ``out.dtype``.
    """
    # Casting dtypes is inefficient to do in the jitted function.
    dtype = out.dtype
    if matrices.dtype != dtype:
        matrices = matrices.astype(dtype)
    if data.dtype != dtype:  # coverage: branch never taken (condition never true in the measured run)
        data = data.astype(dtype)
    data = data.reshape(len(matrices), -1, data.shape[-1])
    out = out.reshape(len(matrices), -1, out.shape[-1])
    _framewise_contraction(matrices, data, out)

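As a quick illustration of the docstring above, the following sketch (not part of the module; shapes and sizes are arbitrary) checks the claimed equivalence with ``np.einsum`` for a four-dimensional ``data`` array, again assuming the import path ``mumott.methods.utilities.tensor_operations``:

import numpy as np
from mumott.methods.utilities.tensor_operations import framewise_contraction

rng = np.random.default_rng(0)
matrices = rng.standard_normal((4, 6, 8))    # shape (i, j, k)
data = rng.standard_normal((4, 5, 7, 6))     # shape (i, n, m, j)
out = np.zeros((4, 5, 7, 8))                 # shape (i, n, m, k)

framewise_contraction(matrices, data, out)
reference = np.einsum('ijk, inmj -> inmk', matrices, data)
assert np.allclose(out, reference)

Note that ``out`` must already have the correct shape and a floating-point dtype, since the other arrays are cast to ``out.dtype`` and the result is written into it in-place.
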

@jit(nopython=True)
def _framewise_contraction_transpose(matrices: np.ndarray[float],
                                     data: np.ndarray[float],
                                     out: np.ndarray[float]) -> None:
    """Internal function for matrix contraction."""
    for i in range(len(matrices)):
        np.dot(data[i], matrices[i].T, out=out[i])


def framewise_contraction_transpose(matrices: np.ndarray[float],
                                    data: np.ndarray[float],
                                    out: np.ndarray[float]) -> None:
    """This function is a more efficient implementation of the series of
    matrix multiplications carried out by the expression
    ``np.einsum('ijk, inmk -> inmj', matrices, data, out=out)``, using
    ``numba.jit``, for a four-dimensional :attr:`data` array.
    Used to linearly transform data from one representation to another.

    Parameters
    ----------
    matrices
        Stack of matrices with shape ``(i, j, k)`` where ``i`` is the
        stacking dimension.
    data
        Stack of data with shape ``(i, n, m, ..., k)`` where ``i`` is the
        stacking dimension and ``k`` is the representation dimension.
    out
        Output array with shape ``(i, n, m, ..., j)`` where ``i`` is the
        stacking dimension and ``j`` is the representation dimension.
        Will be modified in-place and must be contiguous.

    Notes
    -----
    All three arrays will be cast to ``out.dtype``.
    """
    dtype = out.dtype
    if matrices.dtype != dtype:
        matrices = matrices.astype(dtype)
    if data.dtype != dtype:  # coverage: branch never taken (condition never true in the measured run)
        data = data.astype(dtype)
    data = data.reshape(len(matrices), -1, data.shape[-1])
    out = out.reshape(len(matrices), -1, out.shape[-1])
    _framewise_contraction_transpose(matrices, data, out)
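
A corresponding sketch for the transposed variant, verifying the documented ``np.einsum`` equivalence under the same assumptions (illustrative shapes, assumed import path):

import numpy as np
from mumott.methods.utilities.tensor_operations import framewise_contraction_transpose

rng = np.random.default_rng(0)
matrices = rng.standard_normal((4, 6, 8))    # shape (i, j, k)
data = rng.standard_normal((4, 5, 7, 8))     # shape (i, n, m, k)
out = np.zeros((4, 5, 7, 6))                 # shape (i, n, m, j)

framewise_contraction_transpose(matrices, data, out)
reference = np.einsum('ijk, inmk -> inmj', matrices, data)
assert np.allclose(out, reference)

Applying the transposed matrices in this way corresponds to the adjoint of the transformation performed by ``framewise_contraction``.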