Coverage for local_installation_linux/mumott/core/cuda_utils.py: 79%

20 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2 

3from typing import Tuple 

4 

5from numba import cuda 

6from numpy.typing import DTypeLike, NDArray 

7 

8from numpy import float32, float64 

9 

10logger = logging.getLogger(__name__) 

11 

12 

def _zero_fill_kernel():
    """ Returns the shared CUDA kernel used to zero a flat device array.

    The kernel is built on first use and stored on this function, so that
    repeated allocations reuse the same dispatcher (and its compiled
    specializations) instead of JIT-compiling a new kernel for every call.
    """
    kernel = getattr(_zero_fill_kernel, '_kernel', None)
    if kernel is None:
        @cuda.jit
        def kernel(flat):
            # One thread per element. The bound is read from the argument
            # itself rather than from a closure variable, so the compiled
            # kernel does not depend on the size of any particular array.
            idx = cuda.grid(1)
            if idx < flat.shape[0]:
                flat[idx] = 0

        _zero_fill_kernel._kernel = kernel
    return kernel


def cuda_calloc(shape: Tuple, dtype: DTypeLike = float32) -> NDArray:
    """ Function for creating a zero-initialized
    device array using ``numba.cuda``, as it provides no
    native method for doing so.

    Parameters
    ----------
    shape
        A ``tuple`` with the shape of the desired
        device array.
    dtype
        A ``numba`` or ``numpy`` dtype. ``float64``
        is turned into ``float32``.

    Returns
    -------
        A device array of the requested shape in which every
        element has been set to zero.
    """
    if dtype == float64:
        dtype = float32
    array = cuda.device_array(shape, dtype=dtype)

    size = array.size

    # Nothing to initialize for an empty array; this also avoids
    # requesting a kernel launch with a grid of zero blocks.
    if size > 0:
        # Use 512 threads per block, and blocks per grid based on array size.
        tpb = 512
        bpg = (size + tpb - 1) // tpb
        # The kernel indexes in one dimension, so it is given a flat view.
        _zero_fill_kernel()[bpg, tpb](array.ravel())
    return array