From 66b51804dee53d1c8fe6740d3453987ec917bc19 Mon Sep 17 00:00:00 2001 From: "Joern Ungermann (IEK-7 FZ-Juelich)" <j.ungermann@fz-juelich.de> Date: Fri, 26 Sep 2014 19:35:54 +0200 Subject: [PATCH] included a crude way of supproting masks in the ofrwardmodel. --- pyramid/forwardmodel.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/pyramid/forwardmodel.py b/pyramid/forwardmodel.py index eeb4868..951ef50 100644 --- a/pyramid/forwardmodel.py +++ b/pyramid/forwardmodel.py @@ -7,7 +7,7 @@ import numpy as np from pyramid.kernel import Kernel from pyramid.projector import Projector - +from pyramid.magdata import MagData import logging @@ -40,7 +40,7 @@ class ForwardModel(object): LOG = logging.getLogger(__name__+'.ForwardModel') - def __init__(self, projectors, kernel): + def __init__(self, projectors, kernel, mask=None): self.LOG.debug('Calling __init__') assert np.all([isinstance(projector, Projector) for projector in projectors]), \ 'List has to consist of Projector objects!' @@ -53,9 +53,14 @@ class ForwardModel(object): self.dim_uv = kernel.dim_uv self.size_2d = kernel.size self.LOG.debug('Creating '+str(self)) + self.mask = mask def __call__(self, x): self.LOG.debug('Calling __call__') + if self.mask is not None: + mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape)) + mag_data_rec.set_vector(self.mask, x) + x = mag_data_rec.mag_vec result = [self.kernel(projector(x)) for projector in self.projectors] return np.reshape(result, -1) @@ -80,6 +85,11 @@ class ForwardModel(object): ''' self.LOG.debug('Calling jac_dot') + if self.mask is not None: + mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape)) + mag_data_rec.set_vector(self.mask, vector) + vector = mag_data_rec.mag_vec + result = [self.kernel.jac_dot(projector.jac_dot(vector)) for projector in self.projectors] result = np.reshape(result, -1) return result @@ -109,6 +119,11 @@ class ForwardModel(object): result = np.zeros(3*size_3d) for (i, projector) in enumerate(self.projectors): result += projector.jac_T_dot(self.kernel.jac_T_dot(vector[i*size_2d:(i+1)*size_2d])) + if self.mask is not None: + mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape)) + mag_data_rec.magnitude[:] = result.reshape(mag_data_rec.magnitude.shape) + result = mag_data_rec.get_vector(self.mask) + return np.reshape(result, -1) def __repr__(self): -- GitLab