Coverage for local_installation_linux/mumott/methods/projectors/saxs_projector_cuda.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2 

3import numpy as np 

4from numpy.typing import NDArray 

5from numba.cuda import device_array 

6 

7from mumott.core.john_transform_cuda import john_transform_cuda, john_transform_adjoint_cuda 

8from .saxs_projector import SAXSProjector 

9 

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

11 

12 

class SAXSProjectorCUDA(SAXSProjector):
    """
    Projector for transforms of tensor fields from three-dimensional space
    to projection space. Uses a projection algorithm implemented in
    ``numba.cuda``.

    Parameters
    ----------
    geometry : Geometry
        An instance of :class:`Geometry <mumott.Geometry>` containing the
        necessary vectors to compute forward and adjoint projections.
    """
    @staticmethod
    def _get_zeros_method(array: NDArray):
        """ Internal method for returning a device allocation method.

        Parameters
        ----------
        array
            Not used by this implementation; accepted so that the signature
            matches the one the base class calls (NOTE(review): assumed —
            confirm against :class:`SAXSProjector`).

        Returns
        -------
            ``numba.cuda.device_array``, a factory that allocates arrays in
            CUDA device memory.
        """
        # NOTE(review): ``device_array`` allocates *uninitialized* device
        # memory (analogous to ``numpy.empty``), despite the "zeros" in this
        # method's name. Presumably the CUDA kernels overwrite every output
        # element rather than accumulating into it — TODO confirm.
        return device_array

    def _compile_john_transform(self,
                                field: NDArray[float],
                                projections: NDArray[float],
                                *args) -> None:
        """ Internal method for compiling John transform only as needed.

        Builds both the forward and the adjoint CUDA John transforms from the
        same arguments. It stores them on the instance as
        ``_compiled_john_transform`` and ``_compiled_john_transform_adjoint``.
        Deciding *when* recompilation is needed is left to the caller
        (presumably the base class).

        Parameters
        ----------
        field
            Volume array passed to both transform factories.
        projections
            Projection array passed to both transform factories.
        *args
            Additional arguments forwarded unchanged to both
            ``john_transform_cuda`` and ``john_transform_adjoint_cuda``.
        """
        # The forward and adjoint transforms are always rebuilt together, so
        # the two cached callables stay consistent with each other.
        self._compiled_john_transform = john_transform_cuda(
            field, projections, *args)
        self._compiled_john_transform_adjoint = john_transform_adjoint_cuda(
            field, projections, *args)

    @property
    def dtype(self) -> np.typing.DTypeLike:
        """ Preferred dtype of this ``Projector``.

        Always ``numpy.float32`` for this CUDA implementation.
        """
        return np.float32