Coverage for local_installation_linux/mumott/pipelines/optical_flow_alignment.py: 91%

172 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1""" 

2Functions for subpixel alignment algorithm. 

3""" 

4import sys 

5 

6import numpy as np 

7import scipy as sp 

8from scipy import ndimage 

9from tqdm import tqdm 

10 

11from typing import Any 

12 

13from mumott.data_handling import DataContainer 

14from mumott.data_handling.utilities import get_absorbances 

15from mumott.pipelines import run_mitra 

16from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector 

17from mumott.pipelines.utilities import image_processing as imp 

18from mumott.pipelines.utilities.alignment_geometry import get_alignment_geometry 

19 

20 

21def _define_axis_index(rotation_axis_index: int) -> tuple[int, int]: 

22 """ Defines the indices of the geometrical x and y axes relative to the axes of the array, 

23 using the index of the main rotation axis, which by definition is the geometrical y axis. 

24 

25 Parameters 

26 ---------- 

27 rotation_axis_index 

28 The index of the main rotation axis in the projections. 

29 

30 Returns 

31 ------- 

32 A tuple comprising the indices of the geometrical x and y axes for the array. 

33 

34 """ 

35 

36 # define the x and y axis, y being the rotation axis 

37 if rotation_axis_index == 0: 

38 x_axis_index = 1 

39 y_axis_index = 0 

40 else: 

41 x_axis_index = 0 

42 y_axis_index = 1 

43 return x_axis_index, y_axis_index 

44 

45 

def run_optical_flow_alignment(
    data_container: DataContainer,
    **kwargs,
) -> tuple[np.ndarray[float], np.ndarray[float], np.ndarray[float]]:
    """This pipeline implements the alignment algorithm described in [Odstrcil2019]_.
    It computes the shifts that are needed to align the data
    according to the tomographic problem defined by the given geometry and projector.
    The alignment also relies on the related definition of the reconstruction volume
    and main axis of rotation ('y' axis).
    The procedure can be customized via the keyword arguments described below.

    Parameters
    ----------
    data_container
        Input data.
    volume : tuple[int, int, int]
        **Geometry parameter:**
        The size of the volume of the wanted tomogram.
        If not specified, deduced from information in :attr:`data_container`.
    main_rot_axis : int
        **Geometry parameter:**
        The index of the main rotation axis (``y`` axis on the real geometry) on the array.
        If not specified, deduced from information in :attr:`data_container`.
    smooth_data : bool
        **Data filtering:**
        If ``True`` apply Gaussian filter to smoothen the data; default: ``False``.
    sigma_smooth : float
        **Data filtering:**
        The smoothing kernel related to the overall data smoothing; default: ``0``.
    high_pass_filter : float
        **Data filtering:**
        The kernel for high pass filter in the gradient
        computation to avoid phase artifact; default: ``0.01``.
    optimal_shift : np.ndarray[float]
        **Shifts:**
        The original shift; default: ``np.zeros((nb_projections, 2))``.
    rec_iteration : int
        **Projector parameters:**
        The number of iterations used to solve the tomogram from the
        projections; default: ``20``.
    use_gpu : bool
        **Projector parameters:**
        Use GPU (``True``) or CPU (``False``) for the tomographic computation;
        default: ``False``.
    optimizer_kwargs : dict[str, Any]
        **Projector parameters:**
        Keyword arguments to pass on to the optimizer. A copy is made, and its
        ``'maxiter'`` entry is set to ``rec_iteration``; the caller's dictionary
        is not modified. ``None`` is treated like the default;
        default: ``dict(nestorov_weight = 0.6)``.
    stop_max_iteration : int
        **Alignment parameters:**
        The maximum iteration allowed to find the shifts for the alignment; default: ``20``.
    stop_min_correction : float
        **Alignment parameters:**
        The optimization is terminated if the correction (in pixel)
        drops below this value; default ``0.01``.
    align_horizontal : bool
        **Alignment parameters:**
        Apply the horizontal alignment procedure; default: ``True``.
    align_vertical : bool
        **Alignment parameters:**
        Apply the vertical alignment procedure; default: ``True``.
    center_reconstruction : bool
        **Alignment parameters:**
        Shift the reconstructed tomogram to the center of the volume
        to avoid drift in the alignment procedure; default: ``True``.

    Returns
    -------
    A tuple comprising the shifts used for aligning the data, the resulting (and aligned) projections
    obtained by projecting the reconstructed tomogram based on the aligned data, and the
    resulting tomogram reconstructed using the aligned data.

    Example
    -------
    The alignment procedure is simple to use.

    >>> import numpy as np
    >>> from mumott.data_handling import DataContainer
    >>> from mumott.pipelines import optical_flow_alignment as ofa
    >>> data_container = DataContainer('tests/test_fbp_data.h5')

    We introduce some spurious offsets to this already-aligned data.

    >>> length = len(data_container.geometry)
    >>> data_container.geometry.j_offsets = np.arange(0., length) - length * 0.5
    >>> data_container.geometry.k_offsets = np.cos(np.arange(0., length) * np.pi / length)

    We then perform the alignment with default parameters.

    >>> shifts, sinogram_corr, tomogram_corr = ofa.run_optical_flow_alignment(data_container)
    ...

    To use the alignment shifts, we have to translate them from the reconstruction
    ``(x, y, z)``-coordinates into the projection ``(p, j, k)`` coordinates.
    The function :func:`shifts_to_vector` transforms the shifts in projection space
    into shifts in 3D space based on the detector geometry.
    The function :func:`shifts_to_geometry` transforms the shifts vector from the detector reference
    frame to the sample reference frame.

    >>> j_offsets, k_offsets = ofa.shifts_to_geometry(data_container, shifts)
    >>> data_container.geometry.j_offsets = j_offsets
    >>> data_container.geometry.k_offsets = k_offsets

    We can configure a variety of parameters and pass those to the alignment pipeline.
    For example, we can choose whether to align in the horizontal or vertical directions,
    whether to center the reconstruction, initial guesses for the shifts,
    the number of iterations for the reconstruction,
    the number of iterations for the alignment procedure, and so on.

    A quick way to align very misaligned data is to use the :func:`line_vertical_alignment` function
    to obtain an initial guess.
    It can also be used between alignment steps.
    There are two versions of the vertical line alignment.
    :func:`line_vertical_alignment` uses the :class:`DataContainer` geometry and data,
    while :func:`sinogram_line_vertical_alignment` is based on a sinogram and the geometry.
    Note that the more strongly the object deviates from axial symmetry,
    the worse the vertical line alignment will be.

    >>> initial_shift = ofa.line_vertical_alignment(data_container)

    >>> alignment_param = dict(
    ...    optimal_shift=initial_shift,
    ...    rec_iteration=10,
    ...    stop_max_iteration=3,
    ...    align_horizontal=True,
    ...    align_vertical=True,
    ...    center_reconstruction=True,
    ...    optimizer_kwargs=dict(nestorov_weight=0.6))
    >>> shifts, sinogram_corr, tomogram_corr = ofa.run_optical_flow_alignment(data_container,
    ...    **alignment_param)
    >>> vertical_shift, index = ofa.sinogram_line_vertical_alignment(
    ...    sinogram_corr, ofa.get_alignment_geometry(data_container)[0])
    >>> shifts[:,index] += vertical_shift # to update the shifts with vertical line alignment.
    """

    # deduce main rotation axis and associated volume from information in data container
    main_rot_axis_deduced, volume_deduced = get_alignment_geometry(data_container)

    # number of projections, taken from the first axis of the diode data
    nb_projections = np.size(data_container.diode, 0)

    # ========= geometry ================
    # volume to reconstruct
    volume = kwargs.get('volume', volume_deduced)
    # main rotation axis
    main_rot_axis = kwargs.get('main_rot_axis', main_rot_axis_deduced)

    # ========= data filtering ================
    # overall filtering
    smooth_data = kwargs.get('smooth_data', False)
    sigma_smooth = kwargs.get('sigma_smooth', 0)
    # phase artifact removal filtering
    high_pass_filter = kwargs.get('high_pass_filter', 0.01)

    # ========= original shift ================
    optimal_shift = kwargs.get('optimal_shift', np.zeros((nb_projections, 2)))

    # ========= projector parameters ================
    # number of iterations for the tomogram reconstruction
    rec_iteration = kwargs.get('rec_iteration', 20)

    # ========= alignment parameters ================
    # stoppage criteria: maximum number of iterations, and minimum correction
    # below which the iteration stops
    max_iteration = kwargs.get('stop_max_iteration', 20)
    min_correction = kwargs.get('stop_min_correction', 0.01)

    # horizontal and vertical alignment
    align_horizontal = kwargs.get('align_horizontal', True)
    align_vertical = kwargs.get('align_vertical', True)

    # to avoid a full volume shift
    center_reconstruction = kwargs.get('center_reconstruction', True)

    # use the GPU for tomographic computation
    use_gpu = kwargs.get('use_gpu', False)

    # optimizer kwargs: work on a copy so that injecting 'maxiter' does not
    # mutate a dictionary owned (and possibly reused) by the caller
    user_optimizer_kwargs = kwargs.get('optimizer_kwargs')
    if user_optimizer_kwargs is None:
        optimizer_kwargs = dict(nestorov_weight=0.6)
    else:
        optimizer_kwargs = dict(user_optimizer_kwargs)
    optimizer_kwargs['maxiter'] = rec_iteration

    # call the real function
    return _optical_flow_alignment_full_param(
        data_container,
        volume,
        main_rot_axis,
        smooth_data,
        sigma_smooth,
        high_pass_filter,
        optimal_shift,
        rec_iteration,
        max_iteration,
        min_correction,
        align_horizontal,
        align_vertical,
        center_reconstruction,
        use_gpu,
        optimizer_kwargs,
    )

256 

257 

def _optical_flow_alignment_full_param(
    data_container: DataContainer,
    volume: tuple[int, int, int],
    main_rot_axis: int,
    smooth_data: bool = False,
    sigma_smooth: float = 0.0,
    high_pass_filter: float = 0.01,
    optimal_shift: np.ndarray[float] = 0.0,
    rec_iteration: int = 20,
    max_iteration: int = 100,
    min_correction: float = 0.01,
    align_horizontal: bool = True,
    align_vertical: bool = True,
    center_reconstruction: bool = True,
    use_gpu: bool = False,
    optimizer_kwargs: dict[str, Any] = None,
):
    """Compute the shifts necessary to align the data according to the tomographic
    problem defined by the geometry of ``data_container``.

    This is the worker behind :func:`run_optical_flow_alignment`, which resolves
    all defaults from keyword arguments; it is not meant to be called directly.

    Complete re-use and reformatting of the code of Michal Odstrcil, 2016. Based on [Odstrcil2019]_

    Parameters
    ----------
    data_container
        Input data. Its projection diodes are temporarily replaced by shifted
        data during each iteration (see :func:`compute_shifts`).
    volume
        The size of the volume of the wanted tomogram. NOTE: this argument is
        not referenced by the current implementation; it is kept for interface
        compatibility.
    main_rot_axis
        The index of the main rotation axis ('y' axis on the real geometry) on the array.
    smooth_data
        Smooth the data with a Gaussian filter (followed by edge smoothing).
    sigma_smooth
        The smoothing kernel related to the overall data smoothing.
    high_pass_filter
        The kernel for high pass filter in the gradient computation, to avoid phase artifact.
    optimal_shift
        The initial shift, broadcast into the first row of the shift history.
    rec_iteration
        The number of iterations used to solve the tomogram from the projections.
    max_iteration
        The maximum number of alignment iterations; must be at least 1.
    min_correction
        The iteration stops once the largest correction (in pixel) of an
        iteration is at or below this value.
    align_horizontal
        Apply the horizontal alignment procedure.
    align_vertical
        Apply the vertical alignment procedure.
    center_reconstruction
        Recenter the reconstructed tomogram at the center of the volume to avoid drift in the
        alignment procedure.
    use_gpu
        Use GPU or CPU for the tomographic computation.
    optimizer_kwargs
        kwargs for the optimizer, for the run pipeline.

    Returns
    -------
    A tuple ``(shift, sinogram_corr, tomogram)``:

    shift
        The shifts necessary to align the data, one ``(2,)`` row per projection.
    sinogram_corr
        The projections obtained by projecting the last reconstructed tomogram,
        with the projection index moved to the last axis.
    tomogram
        The last tomogram reconstructed from the aligned data.

    Raises
    ------
    ValueError
        If ``max_iteration`` is smaller than 1 (the shift history could not
        hold the first update).
    """
    if max_iteration < 1:
        raise ValueError(f'max_iteration must be at least 1, got {max_iteration}.')

    # define the x and y axis, y being the rotation axis
    x_axis_index, y_axis_index = _define_axis_index(main_rot_axis)

    # absorbance of the data, with the projection index moved to the last axis
    abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
    absorbances = abs_dict['absorbances']
    sinogram_0 = np.moveaxis(np.squeeze(absorbances), 0, -1)

    # reference diodes, same axis layout as sinogram_0
    diodes_ref = np.moveaxis(np.squeeze(data_container.diode), 0, -1)

    # define sinogram sizes; the volume is a parallelepiped with square base
    nb_projections = np.size(sinogram_0, -1)
    nb_layers = data_container.diode.shape[main_rot_axis + 1]
    width = data_container.diode.shape[np.mod(main_rot_axis + 1, 2) + 1]

    # -- pre processing

    # if data is smoothened, smooth sinogram_0 via Gaussian filter
    if smooth_data:
        sinogram_0 = sp.ndimage.gaussian_filter(sinogram_0, sigma_smooth)
        sinogram_0 = imp.smooth_edges(sinogram_0)

    # Tukey window, necessary for the grad computation, to avoid edges issues
    W = imp.compute_tukey_window(width, nb_layers)
    if x_axis_index == 0:
        W = W.transpose(1, 0, 2)

    # -- configuration for the loop
    iteration = 0
    # row k holds the shifts used at iteration k; row k + 1 receives its update
    shifts = np.zeros((max_iteration + 1, nb_projections, 2))

    # shift initial tomogram with a guessed shift
    shifts[0, :, :] = optimal_shift

    # visual feedback of the progression; the context manager closes the bar
    # even if an iteration raises
    with tqdm(total=max_iteration, desc='Alignment iterations', file=sys.stdout) as pbar:
        while True:
            # compute the shifts and the related tomogram and corrected sinograms
            max_step, sinogram_corr, tomogram, shifts = compute_shifts(
                sinogram_0,
                diodes_ref,
                shifts,
                data_container,
                iteration,
                x_axis_index,
                y_axis_index,
                rec_iteration,
                high_pass_filter,
                W,
                align_horizontal,
                align_vertical,
                center_reconstruction,
                use_gpu,
                optimizer_kwargs,
            )

            pbar.update(1)
            # if the step size is too small stop optimization
            if max_step <= min_correction:
                print(
                    f'The largest change ({max_step})'
                    f' has dropped below the stopping criterion ({min_correction}).'
                    ' The alignment is complete.',
                )
                break
            if iteration + 1 >= max_iteration:
                break
            iteration = iteration + 1

    # the last update written by compute_shifts lives in row iteration + 1
    shift = shifts[iteration + 1, :, :]
    return shift, np.moveaxis(sinogram_corr, 0, -1), tomogram

410 

411 

def recenter_tomogram(
    tomogram: np.ndarray[float],
    step: float = 0.2,
    **kwargs: dict[str, Any],
):
    """ Recenter a tomogram in frame.

    An offset is estimated with ``imp.center`` on the square root of the
    positive part of the tomogram. The per-slice estimates are combined into a
    single mass-weighted value (NaN entries ignored). The tomogram is then
    shifted by ``-step`` times that offset, so repeated calls move it gradually.

    Parameters
    ----------
    tomogram
        The tomogram to shift. The result is written back through
        ``tomogram[:, :, :, 0]``, so a 4D array whose last axis has length 1 is
        expected. NOTE: the array is modified in place.
    step
        The step for the recentering, should be smaller than one. Default is 0.2
    axis : int
        Optional keyword argument: the axis of the squeezed tomogram that is
        moved last, i.e. the axis perpendicular to the transverse slices.
        If omitted, it is 2 when the first two dimensions are equal, 1 when the
        first and third dimensions are equal, and 0 otherwise.

    Returns
    -------
    The shifted tomogram, recentered in frame (the same array object that was
    passed in).

    """
    # remove the extra dimension for this calculation
    tomo_2 = np.squeeze(np.copy(tomogram))

    # the volume is a parallelepiped with square base: the odd-sized axis is
    # the one perpendicular to the transverse slices
    volume = tomo_2.shape
    if volume[0] == volume[1]:
        axis = 2
    elif volume[0] == volume[2]:
        axis = 1
    else:
        axis = 0
    # an explicit ``axis`` keyword overrides the deduced axis
    axis = kwargs.get('axis', axis)

    # the 2 first dimensions must be the transverse slices of the tomogram, i.e., have the same dimension
    tomo_2 = np.moveaxis(tomo_2, axis, -1)

    # try to keep reconstruction in center
    # enforce positivity
    pos_tomo_2 = np.copy(tomo_2)
    pos_tomo_2[pos_tomo_2 < 0] = 0
    x, y, mass = imp.center(np.sqrt(pos_tomo_2))  # x is here first axis, y second axis

    # more robust estimation of the center: mass-weighted mean over non-NaN entries
    eps = np.finfo(float).eps
    ind_x = np.argwhere(~np.isnan(x))
    ind_y = np.argwhere(~np.isnan(y))

    x = (
        np.mean(x[ind_x] * mass[ind_x] ** 2)
        / np.mean(mass[ind_x] ** 2 + eps)
        * np.ones(x.shape)
    )
    y = (
        np.mean(y[ind_y] * mass[ind_y] ** 2)
        / np.mean(mass[ind_y] ** 2 + eps)
        * np.ones(y.shape)
    )

    # shift (slowly) the tomogram to the new center, using only a fraction
    # ``step`` of the estimated offset; x is the first axis, y the second axis,
    # same reference as for the imshift_fft function
    tomo_2 = imp.imshift_fft(tomo_2, -x * step, -y * step)

    # put back the right order
    tomo_2 = np.moveaxis(tomo_2, -1, axis)

    tomogram[:, :, :, 0] = tomo_2

    return tomogram

477 

478 

def compute_shifts(
    sinogram_0: np.ndarray[float],
    diodes_ref: np.ndarray[float],
    shifts,
    data_container: DataContainer,
    iteration: int,
    x_axis_index: int,
    y_axis_index: int,
    rec_iteration: int,
    high_pass_filter: float,
    W: np.ndarray[float],
    align_horizontal: bool,
    align_vertical: bool,
    center_reconstruction: bool,
    use_gpu: bool,
    optimizer_kwargs: dict[str, Any],
) -> tuple[float, np.ndarray[float], np.ndarray[float], np.ndarray[float]]:
    """ Compute the shifts to align the reference sinogram and the synthesized sinogram.

    During the call, the diodes of ``data_container.projections`` are replaced
    by the shifted diodes so that the reconstruction sees aligned data. They
    are always restored from ``diodes_ref`` before returning, including when
    an exception is raised.

    Parameters
    ----------
    sinogram_0
        The absorbance of the data to align.
    diodes_ref
        The data to align.
    shifts
        Shift history; row ``iteration`` is applied, and row ``iteration + 1``
        is overwritten with the updated shifts.
    data_container
        The data container.
    iteration
        The current iteration.
    x_axis_index
        The index in the array of the geometrical X axis (tilt axis).
    y_axis_index
        The index in the array of the geometrical Y axis (main rotation axis).
    rec_iteration
        The number of iterations used to solve the tomogram from the projections.
    high_pass_filter
        The kernel for high pass filter in the gradient computation
        applied to avoid phase artifact.
    W
        The Tukey window associated with our computation.
    align_horizontal
        Apply the horizontal alignment procedure.
    align_vertical
        Apply the vertical alignment procedure.
    center_reconstruction
        Shift the reconstructed tomogram to the center of the volume to avoid drift in the
        alignment procedure (only applied while ``iteration < 20``).
    use_gpu
        Use GPU (``True``) or CPU (``False``) for the tomographic computation.
    optimizer_kwargs
        Keyword arguments passed on to the optimizer of :func:`run_mitra`.

    Returns
    -------
    Tuple comprising the maximum update in this iteration,
    the resulting aligned projections found by projecting
    the reconstructed tomogram using the aligned data,
    the resulting tomogram reconstructed by the aligned data, and
    the updated shifts.
    """

    nb_projections = np.size(sinogram_0, -1)

    current_shift_0 = shifts[iteration, :, 0]
    current_shift_1 = shifts[iteration, :, 1]

    # shift the original sinogram_0 and the original diodes by the current value
    sinogram_shifted = imp.imshift_fft(np.copy(sinogram_0), current_shift_0, current_shift_1)
    diodes_shifted = imp.imshift_fft(diodes_ref, current_shift_0, current_shift_1)

    nb_data_projections = data_container.projections.diode.shape[0]
    try:
        # feed the shifted data to SIRT / MITRA
        for ii in range(nb_data_projections):
            data_container.projections[ii].diode = diodes_shifted[..., ii]

        # reconstruct the tomogram with the given geometry, shifts, and data
        tomogram = run_mitra(
            data_container,
            use_gpu=use_gpu,
            maxiter=rec_iteration,
            ftol=None,
            enforce_non_negativity=True,
            optimizer_kwargs=optimizer_kwargs,
        )['result']['x']

        # recentering is only applied during the first 20 iterations
        if center_reconstruction and (iteration < 20):
            tomogram = recenter_tomogram(tomogram)
            # positivity constraint
            tomogram[tomogram < 0] = 0

        # compute the projected data from the reconstructed tomogram: sinogram_corr
        if use_gpu:
            projector = SAXSProjectorCUDA(data_container.geometry)
        else:
            projector = SAXSProjector(data_container.geometry)
        # remove extra dimension of the sinogram
        sinogram_corr = np.squeeze(projector.forward(tomogram))

        # find the optimal shift
        shift_hor, shift_vect = compute_shift_through_sinogram(
            sinogram_shifted,
            np.moveaxis(sinogram_corr, 0, -1),
            x_axis_index,
            y_axis_index,
            high_pass_filter,
            W,
            align_horizontal,
            align_vertical,
        )

        # store the values on the right axis
        shift_vector = np.zeros((1, nb_projections, 2))
        shift_vector[:, :, x_axis_index] = shift_hor
        shift_vector[:, :, y_axis_index] = shift_vect

        # apply the optical flow correction method
        step_relaxation = 0.5  # small relaxation is needed to avoid oscillations
        shifts[iteration + 1, :, :] = shifts[iteration, :, :] + shift_vector * step_relaxation
        # remove degree of freedom in the vertical dimension (avoid drifts)
        shifts[iteration + 1, :, y_axis_index] = shifts[iteration + 1, :, y_axis_index] - np.median(
            shifts[iteration + 1, :, y_axis_index]
        )

        max_step = np.maximum(np.max(np.abs(shift_hor)), np.max(np.abs(shift_vect)))
    finally:
        # put back original data, even if reconstruction or projection failed
        for ii in range(nb_data_projections):
            data_container.projections[ii].diode = diodes_ref[..., ii]

    return max_step, sinogram_corr, tomogram, shifts

608 

609 

def compute_shift_through_sinogram(
    sinogram_shifted: np.ndarray[float],
    sinogram_corr: np.ndarray[float],
    x_axis_index: int,
    y_axis_index: int,
    high_pass_filter: float,
    W: np.ndarray[float],
    align_horizontal: bool,
    align_vertical: bool,
) -> tuple[np.ndarray[float], np.ndarray[float]]:
    """ Compute the shift needed to obtain a better sinogram, based on the actual shifted
    sinogram and the synthetic sinogram.

    Parameters
    ----------
    sinogram_shifted
        The sinogram, aka data, to compute the shift correction on; the
        projection index is on the last axis.
    sinogram_corr
        The synthetic sinogram obtained by reconstructing a tomogram from
        sinogram_shifted and reprojecting it; same layout as sinogram_shifted.
    x_axis_index
        The index in the array of the geometrical X axis (tilt axis).
    y_axis_index
        The index in the array of the geometrical Y axis (main rotation axis).
    high_pass_filter
        The kernel for high pass filter in the gradient computation, to avoid phase artifact.
    W
        The Tukey window associated with our computation.
    align_horizontal
        Compute the horizontal alignment shift.
    align_vertical
        Compute the vertical alignment shift.

    Returns
    -------
    A tuple comprising the horizontal and vertical shifts, each clipped to at
    most 1 pixel in magnitude. A disabled direction yields zeros of shape
    ``(1, nb_projections)``.
    """

    nb_projections = np.size(sinogram_shifted, -1)
    eps = np.finfo(float).eps

    # The edge-smoothed synthetic sinogram does not depend on the projection
    # index, so compute it once instead of once per projection.
    # NOTE(review): assumes imp.smooth_edges returns a new array and does not
    # modify sinogram_corr in place.
    smoothed_corr = imp.smooth_edges(sinogram_corr)

    # image gradients of each synthetic projection
    d_vect = np.zeros(sinogram_corr.shape)
    d_hor = np.zeros(sinogram_corr.shape)
    for index in range(nb_projections):
        d_hor[..., index], d_vect[..., index] = imp.get_img_grad(
            smoothed_corr[..., index], x_axis_index, y_axis_index
        )
    DS = sinogram_corr - sinogram_shifted

    # apply high pass filter => get rid of phase artefacts
    DS = imp.imfilter_high_pass_1d(DS, x_axis_index, high_pass_filter)

    # align horizontal
    if align_horizontal:
        # calculate optimal shift of the 2D projections in horiz direction
        d_hor = imp.imfilter_high_pass_1d(d_hor, x_axis_index, high_pass_filter).real
        shift_hor = -(
            np.sum(W * d_hor * DS, axis=(0, 1)) / np.sum(W * d_hor ** 2 + eps, axis=(0, 1))
        )
        # do not allow more than 1px shift per iteration !
        shift_hor = np.minimum(np.ones(shift_hor.shape), np.abs(shift_hor)) * np.sign(shift_hor)
    else:
        # if disabled
        shift_hor = np.zeros((1, nb_projections))

    # align vertical
    if align_vertical:
        # calculate optimal shift of the 2D projections in vert direction;
        # as for DS, the high-pass filter is applied along the x axis
        d_vect = imp.imfilter_high_pass_1d(d_vect, x_axis_index, high_pass_filter).real
        shift_vect = -(
            np.sum(W * d_vect * DS, axis=(0, 1)) / np.sum(W * d_vect ** 2 + eps, axis=(0, 1))
        )

        # remove the common vertical offset
        shift_vect = shift_vect - np.mean(shift_vect)
        # do not allow more than 1px shift per iteration !
        shift_vect = np.minimum(np.ones(shift_vect.shape), np.abs(shift_vect)) * np.sign(shift_vect)
    else:
        # if disabled
        shift_vect = np.zeros((1, nb_projections))

    return shift_hor, shift_vect

691 

692 

def sinogram_line_vertical_alignment(
    sinogram: np.ndarray[float],
    main_rot_axis: int,
    threshold: float = 0
) -> tuple[np.ndarray[float], int]:
    """
    Compute the vertical, i.e., transverse, weight on a given sinogram and
    align it by aligning the center of mass.
    This yields only a very rough alignment but it is very efficient if the misalignment is large.
    It can be used as the initial step in an alignment procedure or between different steps of the alignment.

    Parameters
    ----------
    sinogram
        The sinogram to align, as a stack of 2D arrays with the projection
        index on the last axis.
    main_rot_axis
        The index of the main rotation axis within each 2D frame (0 or 1).
        The centers of mass are computed along the other of the first two axes,
        on the central line across ``main_rot_axis``.
    threshold
        If positive, values above it are clipped to it before the centers of
        mass are computed, so outliers influence the alignment less.
        ``0`` (default) disables clipping.

    Returns
    -------
    A tuple comprising the computed shifts in the vertical direction (one per
    projection: the negated deviation of each projection's center of mass from
    their mean), and the index (``0`` or ``1``) of the axis these shifts refer to.
    """

    # the axis along which centers of mass are computed is the one of the
    # first two axes that is not the main rotation axis
    vertical_axis = 0 if main_rot_axis == 1 else 1

    # work on a copy: thresholding below modifies the array
    im = np.copy(sinogram)
    # select the center of the sinogram to have only a 2D array left
    # (position along vertical_axis, projection)
    if vertical_axis == 0:
        im = im[:, round(sinogram.shape[1] / 2), :]
    else:
        im = im[round(sinogram.shape[0] / 2), :, :]

    # threshold value to avoid outliers to influence the alignment too much
    if threshold > 0:
        im[im > threshold] = threshold

    # center of mass of each line of the sinogram (one per projection)
    center_vect = np.zeros(im.shape[-1])
    for ii in range(im.shape[-1]):
        center_vect[ii] = ndimage.center_of_mass(im[:, ii])[0]

    # we want all vertical center of mass to be aligned with the global center of mass
    center_vect = -(center_vect - np.mean(center_vect))

    return center_vect, vertical_axis

745 

746 

def line_vertical_alignment(
    data_container: DataContainer,
    threshold: float = 0,
    **kwargs
) -> np.ndarray[float]:
    """
    Compute the vertical, i.e. transverse weight on the sinogram of a given DataContainer and
    align it by aligning the center of mass.
    It is a very rough alignment but very efficient for large misalignment. It can be used
    as initialisation of an alignment procedure or between different steps of the alignment.

    Parameters
    ----------
    data_container
        The DataContainer with the absorption data to align.
    threshold
        Maximum value to take into account for the weighting; forwarded to
        :func:`sinogram_line_vertical_alignment` (``0`` disables clipping).
    main_rot_axis : int
        Geometry parameter.
        The index of the main rotation axis (``y`` axis on the real geometry) of the array.
        If not specified, deduced from information in :attr:`data_container`.

    Returns
    -------
    Computed shifts, as an array of shape ``(nb_projections, 2)`` whose column
    for the vertical axis holds the vertical shifts and whose other column is zero.
    """
    # deduce main rotation axis and associated volume from information in data container
    main_rot_axis_deduced, _ = get_alignment_geometry(data_container)

    # main rotation axis
    main_rot_axis = kwargs.get('main_rot_axis', main_rot_axis_deduced)

    # get absorbance for the sinogram, projection index on the last axis
    abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
    absorbances = abs_dict['absorbances']
    sinogram_0 = np.moveaxis(np.squeeze(absorbances), 0, -1)

    vertical_shift, vertical_axis = sinogram_line_vertical_alignment(
        sinogram_0, main_rot_axis, threshold=threshold)

    # realign sinogram that was shifted originally
    shift_all = np.zeros((sinogram_0.shape[-1], 2))
    shift_all[:, vertical_axis] = vertical_shift

    return shift_all

793 

794 

def shifts_to_vector(
    shifts: np.ndarray[float],
    data_container: DataContainer,
    **kwargs,
) -> np.ndarray[float]:
    """
    Express detector-referenced shifts as vectors in a 3D Cartesian coordinate
    system built from the detector frame and the inner and outer axes of the geometry.

    For every projection, the shift component along the main rotation axis is
    scaled by that projection's ``inner_axis``, the other component by its
    ``outer_axis``, and the two contributions are summed.

    Parameters
    ----------
    shifts
        The shifts expressed in the detector frame, indexed as
        ``shifts[projection, component]``.
    data_container
        The data container containing the geometry information.
    main_rot_axis : int
        Geometry parameter.
        The index of the main rotation axis (``y`` axis on the real geometry) of the array.
        If not specified, deduced from information in :attr:`data_container`.

    Returns
    -------
    An array with one 3D shift vector per projection of the geometry.
    """

    # rotation axis: explicit keyword wins over the value deduced from the container
    deduced_axis, _ = get_alignment_geometry(data_container)
    rotation_axis = kwargs.get('main_rot_axis', deduced_axis)

    # component of the shift that is not along the rotation axis
    if rotation_axis == 1:
        transverse_axis = 0
    else:
        transverse_axis = 1

    per_projection = []
    for index, projection_geometry in enumerate(data_container.geometry):
        along_rotation = shifts[index, rotation_axis] * projection_geometry.inner_axis
        across_rotation = shifts[index, transverse_axis] * projection_geometry.outer_axis
        per_projection.append(along_rotation + across_rotation)

    return np.array(per_projection)

837 

838 

def shifts_to_geometry(
    data_container: DataContainer,
    shifts: np.ndarray[float],
) -> tuple[np.ndarray[float], np.ndarray[float]]:
    """
    Express detector-referenced shifts in the object coordinate system, i.e.,
    as offsets along the ``j`` and ``k``-directions.

    The shifts are first converted to 3D vectors with :func:`shifts_to_vector`.
    Each vector is then projected onto ``geometry.j_direction_0`` and
    ``geometry.k_direction_0`` via a dot product per projection.

    Parameters
    ----------
    data_container
        The data container containing the geometry information connected to the data.
    shifts
        The shifts expressed in the detector frame.

    Returns
    -------
    Two arrays: the shifts in the j direction and in the k direction, one
    entry per projection.
    """

    # 3D shift vectors, one row per projection
    vector_shifts = shifts_to_vector(shifts, data_container)
    geometry = data_container.geometry

    # row-wise dot products with the reference j and k directions
    j_offsets = (vector_shifts * geometry.j_direction_0[np.newaxis]).sum(axis=1)
    k_offsets = (vector_shifts * geometry.k_direction_0[np.newaxis]).sum(axis=1)

    return j_offsets, k_offsets