From 66b51804dee53d1c8fe6740d3453987ec917bc19 Mon Sep 17 00:00:00 2001
From: "Joern Ungermann (IEK-7 FZ-Juelich)" <j.ungermann@fz-juelich.de>
Date: Fri, 26 Sep 2014 19:35:54 +0200
Subject: [PATCH] included a crude way of supproting masks in the ofrwardmodel.

---
 pyramid/forwardmodel.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/pyramid/forwardmodel.py b/pyramid/forwardmodel.py
index eeb4868..951ef50 100644
--- a/pyramid/forwardmodel.py
+++ b/pyramid/forwardmodel.py
@@ -7,7 +7,7 @@ import numpy as np
 
 from pyramid.kernel import Kernel
 from pyramid.projector import Projector
-
+from pyramid.magdata import MagData
 import logging
 
 
@@ -40,7 +40,7 @@ class ForwardModel(object):
 
     LOG = logging.getLogger(__name__+'.ForwardModel')
 
-    def __init__(self, projectors, kernel):
+    def __init__(self, projectors, kernel, mask=None):
         self.LOG.debug('Calling __init__')
         assert np.all([isinstance(projector, Projector) for projector in projectors]), \
             'List has to consist of Projector objects!'
@@ -53,9 +53,14 @@ class ForwardModel(object):
         self.dim_uv = kernel.dim_uv
         self.size_2d = kernel.size
         self.LOG.debug('Creating '+str(self))
+        self.mask = mask
 
     def __call__(self, x):
         self.LOG.debug('Calling __call__')
+        if self.mask is not None:
+            mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape))
+            mag_data_rec.set_vector(self.mask, x)
+            x = mag_data_rec.mag_vec
         result = [self.kernel(projector(x)) for projector in self.projectors]
         return np.reshape(result, -1)
 
@@ -80,6 +85,11 @@ class ForwardModel(object):
 
         '''
         self.LOG.debug('Calling jac_dot')
+        if self.mask is not None:
+            mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape))
+            mag_data_rec.set_vector(self.mask, vector)
+            vector = mag_data_rec.mag_vec
+
         result = [self.kernel.jac_dot(projector.jac_dot(vector)) for projector in self.projectors]
         result = np.reshape(result, -1)
         return result
@@ -109,6 +119,11 @@ class ForwardModel(object):
         result = np.zeros(3*size_3d)
         for (i, projector) in enumerate(self.projectors):
             result += projector.jac_T_dot(self.kernel.jac_T_dot(vector[i*size_2d:(i+1)*size_2d]))
+        if self.mask is not None:
+            mag_data_rec = MagData(self.a, np.zeros((3,)+self.mask.shape))
+            mag_data_rec.magnitude[:] = result.reshape(mag_data_rec.magnitude.shape)
+            result = mag_data_rec.get_vector(self.mask)
+
         return np.reshape(result, -1)
 
     def __repr__(self):
-- 
GitLab