-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add object regularization for multi-slice reconstruction #525
base: dev
Are you sure you want to change the base?
Changes from all commits
6949de8
97a51e4
b35c3de
67eb25a
6a7e8f9
8da41aa
c9e71b1
d4b178b
aa97276
a4b83ec
55a7b26
7a65a77
90058d1
391cc4f
381065c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
A simple implementation of Multislice for the | ||
ePIE algorithm. | ||
|
||
authors: Benedikt J. Daurer and more... | ||
""" | ||
from ptypy.engines import stochastic | ||
from ptypy.engines import register | ||
from ptypy.core import geometry | ||
from ptypy.utils import Param | ||
from ptypy.utils.verbose import logger | ||
from ptypy import io | ||
import numpy as np | ||
|
||
@register() | ||
class ThreePIE(stochastic.EPIE): | ||
""" | ||
An extension of EPIE to include multislice | ||
|
||
Defaults: | ||
|
||
[name] | ||
default = ThreePIE | ||
type = str | ||
help = | ||
doc = | ||
|
||
[number_of_slices] | ||
default = 2 | ||
type = int | ||
help = The number of slices | ||
doc = Defines how many slices are used for the multi-slice object. | ||
|
||
[slice_thickness] | ||
default = 1e-6 | ||
type = float, list, tuple | ||
help = Thickness of a single slice in meters | ||
doc = A single float value or a list of float values. If a single value is used, all the slice will be assumed to be of the same thickness. | ||
|
||
[slice_start_iteration] | ||
default = 0 | ||
type = int, list, tuple | ||
help = iteration number to start using a specific slice | ||
doc = | ||
|
||
[fslices] | ||
default = slices.h5 | ||
type = str | ||
help = File path for the slice data | ||
doc = | ||
|
||
""" | ||
def __init__(self, ptycho_parent, pars=None): | ||
super(ThreePIE, self).__init__(ptycho_parent, pars) | ||
self.article = dict( | ||
title='{Ptychographic transmission microscopy in three dimensions using a multi-slice approach', | ||
author='A. M. Maiden et al.', | ||
journal='J. Opt. Soc. Am. A', | ||
volume=29, | ||
year=2012, | ||
page=1606, | ||
doi='10.1364/JOSAA.29.001606', | ||
comment='The 3PIE reconstruction algorithm', | ||
) | ||
self.ptycho.citations.add_article(**self.article) | ||
|
||
def engine_initialize(self): | ||
super().engine_initialize() | ||
|
||
# Create a list of objects and exit waves (one for each slice) | ||
self._object = [None] * self.p.number_of_slices | ||
self._probe = [None] * self.p.number_of_slices | ||
self._exits = [None] * self.p.number_of_slices | ||
for i in range(self.p.number_of_slices): | ||
self._object[i] = self.ob.copy(self.ob.ID + "_o_" + str(i)) | ||
self._probe[i] = self.pr.copy(self.pr.ID + "_p_" + str(i)) | ||
self._exits[i] = self.pr.copy(self.pr.ID + "_e_" + str(i)) | ||
|
||
# ToDo: | ||
# - allow for non equal slice spacing | ||
# - allow for start_slice_update at a freely chosen iteration | ||
# for each slice separately - works, but not if the | ||
# most downstream slice is switched off | ||
|
||
if isinstance(self.p.slice_start_iteration, int): | ||
self.p.slice_start_iteration = np.ones(self.p.number_of_slices) * self.p.slice_start_iteration | ||
#if ĺen(self.p.slice_start_iteration) != self.p.number_of_slices: | ||
# logger.info(f'dimension of given slice_start_iteration ({ĺen(self.p.slice_start_iteration)}) does not match number of slices ({self.p.number_of_slices})') | ||
|
||
scan = list(self.ptycho.model.scans.values())[0] | ||
geom = scan.geometries[0] | ||
g = Param() | ||
g.energy = geom.energy | ||
g.distance = self.p.slice_thickness | ||
g.psize = geom.resolution | ||
g.shape = geom.shape | ||
g.propagation = "nearfield" | ||
|
||
self.fw = [] | ||
self.bw = [] | ||
if type(self.p.slice_thickness) in [list, tuple]: | ||
assert(len(self.p.slice_thickness) == self.p.number_of_slices-1) | ||
for thickness in self.p.slice_thickness: | ||
g.distance = thickness | ||
G = geometry.Geo(owner=None, pars=g) | ||
self.fw.append(G.propagator.fw) | ||
self.bw.append(G.propagator.bw) | ||
else: | ||
g.distance = self.p.slice_thickness | ||
G = geometry.Geo(owner=None, pars=g) | ||
self.fw = [G.propagator.fw for i in range(self.p.number_of_slices-1)] | ||
self.bw = [G.propagator.bw for i in range(self.p.number_of_slices-1)] | ||
|
||
def engine_iterate(self, num=1): | ||
""" | ||
Compute one iteration. | ||
""" | ||
vieworder = list(self.di.views.keys()) | ||
vieworder.sort() | ||
rng = np.random.default_rng() | ||
|
||
for it in range(num): | ||
|
||
error_dct = {} | ||
rng.shuffle(vieworder) | ||
|
||
for name in vieworder: | ||
view = self.di.views[name] | ||
if not view.active: | ||
continue | ||
|
||
# Multislice update | ||
error_dct[name] = self.multislice_update(view) | ||
|
||
self.curiter += 1 | ||
|
||
return error_dct | ||
|
||
def engine_finalize(self): | ||
self.ob.fill(self._object[0]) | ||
for i in range(1, self.p.number_of_slices): | ||
self.ob *= self._object[i] | ||
|
||
# Save the slices | ||
slices_info = Param() | ||
slices_info.number_of_slices = self.p.number_of_slices | ||
slices_info.slice_thickness = self.p.slice_thickness | ||
slices_info.objects = {ob.ID: {ID: S._to_dict() for ID, S in ob.storages.items()} | ||
for ob in self._object} | ||
slices_info.slice_start_iteration = self.p.slice_start_iteration | ||
|
||
header = {'description': 'multi-slices result details.'} | ||
|
||
h5opt = io.h5options['UNSUPPORTED'] | ||
io.h5options['UNSUPPORTED'] = 'ignore' | ||
logger.info(f'Saving to {self.p.fslices}') | ||
io.h5write(self.p.fslices, header=header, content=slices_info) | ||
io.h5options['UNSUPPORTED'] = h5opt | ||
|
||
return super().engine_finalize() | ||
|
||
def multislice_update(self, view): | ||
""" | ||
Performs one 'iteration' of 3PIE (multislice ePIE) for a single view. | ||
Based on https://doi.org/10.1364/JOSAA.29.001606 | ||
""" | ||
|
||
for i in range(self.p.number_of_slices-1): | ||
for name, pod in view.pods.items(): | ||
# exit wave for this slice | ||
if self.curiter >= self.p.slice_start_iteration[i]: | ||
self._exits[i][pod.pr_view] = self._probe[i][pod.pr_view] * self._object[i][pod.ob_view] | ||
else: | ||
self._exits[i][pod.pr_view] = self._probe[i][pod.pr_view] * 1. | ||
# incident wave for next slice | ||
self._probe[i+1][pod.pr_view] = self.fw[i](self._exits[i][pod.pr_view]) | ||
|
||
for name, pod in view.pods.items(): | ||
# Exit wave for last slice | ||
if self.curiter >= self.p.slice_start_iteration[-1]: | ||
self._exits[-1][pod.pr_view] = self._probe[-1][pod.pr_view] * self._object[-1][pod.ob_view] | ||
else: | ||
self._exits[-1][pod.pr_view] = self._probe[-1][pod.pr_view] * 1. | ||
# Save final state into pod (need for ptypy fourier update) | ||
pod.probe = self._probe[-1][pod.pr_view] | ||
pod.object = self._object[-1][pod.ob_view] | ||
pod.exit = self._exits[-1][pod.pr_view] | ||
|
||
# Fourier update | ||
error = self.fourier_update(view) | ||
|
||
# Object/probe update for the last slice | ||
if self.curiter >= self.p.slice_start_iteration[-1]: | ||
self.object_update(view, {pod.ID:self._exits[-1][pod.pr_view] for name, pod in view.pods.items()}) | ||
self.probe_update(view, {pod.ID:self._exits[-1][pod.pr_view] for name, pod in view.pods.items()}) | ||
for name, pod in view.pods.items(): | ||
self._object[-1][pod.ob_view] = pod.object | ||
self._probe[-1][pod.pr_view] = pod.probe | ||
else: | ||
for name, pod in view.pods.items(): | ||
self._probe[-1][pod.pr_view] = pod.exit * 1. | ||
|
||
# Object/probe update for other slices (backwards) | ||
for i in range(self.p.number_of_slices-2, -1, -1): | ||
if self.curiter >= self.p.slice_start_iteration[i]: | ||
|
||
for name, pod in view.pods.items(): | ||
# Backwards propagation of the probe | ||
pod.exit = self.bw[i](self._probe[i+1][pod.pr_view]) | ||
# Save state into pods | ||
pod.probe = self._probe[i][pod.pr_view] | ||
pod.object = self._object[i][pod.ob_view] | ||
|
||
# Actual object/probe update | ||
self.object_update(view, {pod.ID:self._exits[i][pod.pr_view] for name, pod in view.pods.items()}) | ||
self.probe_update(view, {pod.ID:self._exits[i][pod.pr_view] for name, pod in view.pods.items()}) | ||
for name, pod in view.pods.items(): | ||
self._object[i][pod.ob_view] = pod.object | ||
self._probe[i][pod.pr_view] = pod.probe | ||
else: | ||
for name, pod in view.pods.items(): | ||
self._probe[i][pod.pr_view] = self.bw[i](self._probe[i+1][pod.pr_view]) | ||
|
||
# set the object as the product of all slices for better live plotting | ||
self.ob.fill(self._object[0]) | ||
for i in range(1, self.p.number_of_slices): | ||
self.ob *= self._object[i] | ||
|
||
return error |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,12 @@ class ThreePIE(stochastic.EPIE): | |
help = File path for the slice data | ||
doc = | ||
|
||
[object_regularization_rate] | ||
default = 0.0 | ||
type = float | ||
help = regularization rate for object slices | ||
doc = | ||
|
||
""" | ||
def __init__(self, ptycho_parent, pars=None): | ||
super(ThreePIE, self).__init__(ptycho_parent, pars) | ||
|
@@ -227,4 +233,30 @@ def multislice_update(self, view): | |
for i in range(1, self.p.number_of_slices): | ||
self.ob *= self._object[i] | ||
|
||
return error | ||
if self.p.object_regularization_rate > 0: | ||
self.apply_object_regularization() | ||
|
||
return error | ||
|
||
def apply_object_regularization(self): | ||
# single mode implementation | ||
# only valide for slices with identical thickness | ||
assert(self.p.number_of_slices > 1) | ||
assert(isinstance(self.p.slice_thickness, float)) | ||
|
||
shape = self._object[0].S["Sscan_00G00"].data.shape[1:] | ||
psize = self._object[0].S["Sscan_00G00"].psize[0] | ||
kz = np.fft.fftfreq(self.p.number_of_slices, self.p.slice_thickness)[..., np.newaxis, np.newaxis] | ||
ky = np.fft.fftfreq(shape[0], psize)[..., np.newaxis] | ||
kx = np.fft.fftfreq(shape[1], psize) | ||
|
||
# calculate the weight array | ||
w = 1 - 2*np.arctan2(self.p.object_regularization_rate**2 * kz**2, kx**2+ky**2+np.spacing(1))/np.pi | ||
Comment on lines
+247
to
+254
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this part is basically just calculating some weights |
||
|
||
current_object = np.fft.ifftn(np.fft.fftn([self._object[i].S["Sscan_00G00"].data[0,...] for i in range(len(self._object))]) * w) | ||
|
||
print("object shape", self._object[0].S["Sscan_00G00"].data.shape) | ||
print("w shape", w.shape) | ||
print("current shape", current_object.shape) | ||
for i in range(len(self._object)): | ||
self._object[i].S["Sscan_00G00"].data[0, ...] = current_object[i, ...] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
""" | ||
This script is a test for ptychographic reconstruction in the absence | ||
of actual data. It uses the test Scan class | ||
`ptypy.core.data.MoonFlowerScan` to provide "data". | ||
""" | ||
from ptypy.core import Ptycho | ||
from ptypy import utils as u | ||
from ptypy.custom import ePIE_multislice | ||
|
||
import tempfile | ||
tmpdir = tempfile.gettempdir() | ||
|
||
p = u.Param() | ||
|
||
# for verbose output | ||
p.verbose_level = "info" | ||
|
||
# set home path | ||
p.io = u.Param() | ||
p.io.home = "/".join([tmpdir, "ptypy"]) | ||
|
||
# saving intermediate results | ||
p.io.autosave = u.Param(active=False) | ||
|
||
# opens plotting GUI if interaction set to active) | ||
p.io.autoplot = u.Param(active=True) | ||
p.io.interaction = u.Param(active=True) | ||
|
||
# max 200 frames (128x128px) of diffraction data | ||
p.scans = u.Param() | ||
p.scans.MF = u.Param() | ||
# now you have to specify which ScanModel to use with scans.XX.name, | ||
# just as you have to give 'name' for engines and PtyScan subclasses. | ||
p.scans.MF.name = 'GradFull' | ||
p.scans.MF.data= u.Param() | ||
p.scans.MF.data.name = 'MoonFlowerScan' | ||
p.scans.MF.data.shape = 128 | ||
p.scans.MF.data.num_frames = 200 | ||
p.scans.MF.data.save = None | ||
|
||
# position distance in fraction of illumination frame | ||
p.scans.MF.data.density = 0.2 | ||
# total number of photon in empty beam | ||
p.scans.MF.data.photons = 1e8 | ||
# Gaussian FWHM of possible detector blurring | ||
p.scans.MF.data.psf = 0. | ||
|
||
# attach a reconstrucion engine | ||
p.engines = u.Param() | ||
p.engines.engine00 = u.Param() | ||
p.engines.engine00.name = 'ePIE_multislice' | ||
p.engines.engine00.numiter = 200 | ||
p.engines.engine00.probe_center_tol = None | ||
p.engines.engine00.compute_log_likelihood = True | ||
p.engines.engine00.object_norm_is_global = True | ||
p.engines.engine00.alpha = 1 | ||
p.engines.engine00.beta = 1 | ||
p.engines.engine00.probe_update_start = 0 | ||
p.engines.engine00.number_of_slices = 2 | ||
p.engines.engine00.slice_thickness = 60e-9 | ||
|
||
# prepare and run | ||
if __name__ == "__main__": | ||
P = Ptycho(p,level=5) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would probably make sense to move this a bit higher up to make sure that
self.ob
is calculated for plotting after the regulariser has been applied...