Skip to content
Snippets Groups Projects
Commit 66b51804 authored by Jörn Ungermann's avatar Jörn Ungermann
Browse files

included a crude way of supproting masks in the ofrwardmodel.

parent e39061e3
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
from pyramid.kernel import Kernel from pyramid.kernel import Kernel
from pyramid.projector import Projector from pyramid.projector import Projector
from pyramid.magdata import MagData
import logging import logging
...@@ -40,7 +40,7 @@ class ForwardModel(object): ...@@ -40,7 +40,7 @@ class ForwardModel(object):
LOG = logging.getLogger(__name__+'.ForwardModel') LOG = logging.getLogger(__name__+'.ForwardModel')
def __init__(self, projectors, kernel): def __init__(self, projectors, kernel, mask=None):
self.LOG.debug('Calling __init__') self.LOG.debug('Calling __init__')
assert np.all([isinstance(projector, Projector) for projector in projectors]), \ assert np.all([isinstance(projector, Projector) for projector in projectors]), \
'List has to consist of Projector objects!' 'List has to consist of Projector objects!'
...@@ -53,9 +53,14 @@ class ForwardModel(object): ...@@ -53,9 +53,14 @@ class ForwardModel(object):
self.dim_uv = kernel.dim_uv self.dim_uv = kernel.dim_uv
self.size_2d = kernel.size self.size_2d = kernel.size
self.LOG.debug('Creating '+str(self)) self.LOG.debug('Creating '+str(self))
self.mask = mask
def __call__(self, x): def __call__(self, x):
self.LOG.debug('Calling __call__') self.LOG.debug('Calling __call__')
if self.mask is not None:
mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape))
mag_data_rec.set_vector(self.mask, x)
x = mag_data_rec.mag_vec
result = [self.kernel(projector(x)) for projector in self.projectors] result = [self.kernel(projector(x)) for projector in self.projectors]
return np.reshape(result, -1) return np.reshape(result, -1)
...@@ -80,6 +85,11 @@ class ForwardModel(object): ...@@ -80,6 +85,11 @@ class ForwardModel(object):
''' '''
self.LOG.debug('Calling jac_dot') self.LOG.debug('Calling jac_dot')
if self.mask is not None:
mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape))
mag_data_rec.set_vector(self.mask, vector)
vector = mag_data_rec.mag_vec
result = [self.kernel.jac_dot(projector.jac_dot(vector)) for projector in self.projectors] result = [self.kernel.jac_dot(projector.jac_dot(vector)) for projector in self.projectors]
result = np.reshape(result, -1) result = np.reshape(result, -1)
return result return result
...@@ -109,6 +119,11 @@ class ForwardModel(object): ...@@ -109,6 +119,11 @@ class ForwardModel(object):
result = np.zeros(3*size_3d) result = np.zeros(3*size_3d)
for (i, projector) in enumerate(self.projectors): for (i, projector) in enumerate(self.projectors):
result += projector.jac_T_dot(self.kernel.jac_T_dot(vector[i*size_2d:(i+1)*size_2d])) result += projector.jac_T_dot(self.kernel.jac_T_dot(vector[i*size_2d:(i+1)*size_2d]))
if self.mask is not None:
mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape))
mag_data_rec.magnitude[:] = result.reshape(mag_data_rec.magnitude.shape)
result = mag_data_rec.get_vector(self.mask)
return np.reshape(result, -1) return np.reshape(result, -1)
def __repr__(self): def __repr__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment