Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add object regularization for multi-slice reconstruction #525

Draft
wants to merge 15 commits into
base: dev
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 33 additions & 1 deletion ptypy/custom/threepie.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class ThreePIE(stochastic.EPIE):
help = File path for the slice data
doc =

[object_regularization_rate]
default = 0.0
type = float
help = regularization rate for object slices
doc =

"""
def __init__(self, ptycho_parent, pars=None):
super(ThreePIE, self).__init__(ptycho_parent, pars)
Expand Down Expand Up @@ -227,4 +233,30 @@ def multislice_update(self, view):
for i in range(1, self.p.number_of_slices):
self.ob *= self._object[i]

return error
if self.p.object_regularization_rate > 0:
Comment on lines +236 to +237
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably make sense to move this a bit higher up to make sure that self.ob is calculated for plotting after the regulariser has been applied...

self.apply_object_regularization()

return error

def apply_object_regularization(self):
# single mode implementation
# only valide for slices with identical thickness
assert(self.p.number_of_slices > 1)
assert(isinstance(self.p.slice_thickness, float))

shape = self._object[0].S["Sscan_00G00"].data.shape[1:]
psize = self._object[0].S["Sscan_00G00"].psize[0]
kz = np.fft.fftfreq(self.p.number_of_slices, self.p.slice_thickness)[..., np.newaxis, np.newaxis]
ky = np.fft.fftfreq(shape[0], psize)[..., np.newaxis]
kx = np.fft.fftfreq(shape[1], psize)

# calculate the weight array
Comment on lines +247 to +254
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part is basically just calculating some weights w and does not depend on any current update so can be moved into a separate function, e.g. initialize_regularizer and called once in the constructor. In this current implementation the weights are re-calculated for every iteration which seems unnecessary.

w = 1 - 2*np.arctan2(self.p.object_regularization_rate**2 * kz**2, kx**2+ky**2+np.spacing(1))/np.pi

current_object = np.fft.ifftn(np.fft.fftn([self._object[i].S["Sscan_00G00"].data[0,...] for i in range(len(self._object))]) * w)

print("object shape", self._object[0].S["Sscan_00G00"].data.shape)
print("w shape", w.shape)
print("current shape", current_object.shape)
for i in range(len(self._object)):
self._object[i].S["Sscan_00G00"].data[0, ...] = current_object[i, ...]