Skip to content

Commit

Permalink
Fix/update align_joint
Browse files Browse the repository at this point in the history
  • Loading branch information
dgursoy committed May 4, 2023
1 parent f040f79 commit 9152511
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions source/tomopy/prep/alignment.py
Expand Up @@ -214,7 +214,7 @@ def align_seq(


def align_joint(
prj, ang, fdir='.', iters=10, pad=(0, 0),
prj, ang, fdir='./', iters=10, pad=(0, 0),
blur=True, center=None, algorithm='sirt',
upsample_factor=10, rin=0.5, rout=0.8,
save=False, debug=True):
Expand Down Expand Up @@ -300,25 +300,28 @@ def align_joint(
if algorithm != 'gridrec':
extra_kwargs['num_iter'] = 1

# Make a copy of the projections
prj_copy = prj.copy()

# Register each image frame-by-frame.
for n in range(iters):

if np.mod(n, 1) == 0:
_rec = rec

# Reconstruct image.
rec = recon(prj, ang, center=center, algorithm=algorithm,
rec = recon(prj_copy, ang, center=center, algorithm=algorithm,
init_recon=_rec, **extra_kwargs)

# Re-project data and obtain simulated data.
sim = project(rec, ang, center=center, pad=False)

# Blur edges.
if blur:
_prj = blur_edges(prj, rin, rout)
_prj = blur_edges(prj_copy, rin, rout)
_sim = blur_edges(sim, rin, rout)
else:
_prj = prj
_prj = prj_copy
_sim = sim

# Initialize error matrix per iteration.
Expand All @@ -329,23 +332,27 @@ def align_joint(

# Register current projection in sub-pixel precision
shift, error, diffphase = phase_cross_correlation(
_prj[m], _sim[m], upsample_factor=upsample_factor)
_prj[m], _sim[m], normalization=None,
upsample_factor=upsample_factor)
err[m] = np.sqrt(shift[0]*shift[0] + shift[1]*shift[1])
sx[m] += shift[0]
sy[m] += shift[1]

# Register current image with the simulated one
tform = tf.SimilarityTransform(translation=(shift[1], shift[0]))
prj[m] = tf.warp(prj[m], tform, order=5)
tform = tf.SimilarityTransform(translation=(sy[m], sx[m]))
prj_copy[m] = tf.warp(prj[m].copy(), tform, order=5)

if debug:
print('iter=' + str(n) + ', err=' + str(np.linalg.norm(err)))
print('iter=' + str(n) +
', err=' + str(np.linalg.norm(err) / prj.shape[0]))
conv[n] = np.linalg.norm(err)

if save:
write_tiff(prj, 'tmp/iters/prj', n)
write_tiff(sim, 'tmp/iters/sim', n)
write_tiff(rec, 'tmp/iters/rec', n)
write_tiff(_prj, fdir + 'tmp/iters/prj', n)
write_tiff(sim, fdir + 'tmp/iters/sim', n)
write_tiff(rec, fdir + 'tmp/iters/rec', n)
write_tiff(sx, fdir + 'tmp/iters/sx', n)
write_tiff(sy, fdir + 'tmp/iters/sy', n)

# Re-normalize data
prj *= scl
Expand Down

0 comments on commit 9152511

Please sign in to comment.