From 7b0a90e063a3209255cee25ba3da35b68b8a1316 Mon Sep 17 00:00:00 2001
From: caron <j.caron@fz-juelich.de>
Date: Fri, 12 Jan 2018 09:23:38 +0100
Subject: [PATCH] get_vector_field_errors now accepts masks.

---
 pyramid/diagnostics.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/pyramid/diagnostics.py b/pyramid/diagnostics.py
index 328e73d..8e8b78f 100644
--- a/pyramid/diagnostics.py
+++ b/pyramid/diagnostics.py
@@ -609,16 +609,26 @@ class LCurve(object):
         # TODO: Don't plot the steep part on the right...
 
 
-def get_vector_field_errors(vector_data, vector_data_ref):
+def get_vector_field_errors(vector_data, vector_data_ref, mask=None):
     """After Kemp et. al.: Analysis of noise-induced errors in vector-field electron tomography"""
-    v, vr = vector_data.field, vector_data_ref.field
-    va, vra = vector_data.field_amp, vector_data_ref.field_amp
-    volume = np.prod(vector_data.dim)
+    if mask is not None:
+        vector_data_masked = VectorData(vector_data.a, np.zeros(vector_data.shape))
+        vector_data_masked.set_vector(vector_data.get_vector(mask), mask)
+        vector_data_ref_masked = VectorData(vector_data_ref.a, np.zeros(vector_data_ref.shape))
+        vector_data_ref_masked.set_vector(vector_data_ref.get_vector(mask), mask)
+        v, vr = vector_data_masked.field, vector_data_ref_masked.field
+        va, vra = vector_data_masked.field_amp, vector_data_ref_masked.field_amp
+        volume = mask.sum()
+    else:
+        v, vr = vector_data.field, vector_data_ref.field
+        va, vra = vector_data.field_amp, vector_data_ref.field_amp
+        volume = np.prod(vector_data.dim)
     # Total error:
     amp_sum_sqr = np.nansum((v - vr)**2)
     rms_tot = np.sqrt(amp_sum_sqr / np.nansum(vra**2))
     # Directional error:
-    scal_prod = np.clip(np.nansum(vr * v, axis=0) / (vra * va), -1, 1)  # arccos float pt. inacc.!
+    with np.errstate(divide='ignore', invalid='ignore'):  # ignore "invalid value in true_divide"!
+        scal_prod = np.clip(np.nansum(vr * v, axis=0) / (vra * va), -1, 1)  # arccos float inacc.!
     rms_dir = np.sqrt(np.nansum(np.arccos(scal_prod)**2) / volume) / np.pi
     # Magnitude error:
     rms_mag = np.sqrt(np.nansum((va - vra)**2) / np.nansum(vra**2))
-- 
GitLab