From 878a64430e7ac6eed421151a34e3f10723c1d779 Mon Sep 17 00:00:00 2001
From: Jan Caron <j.caron@fz-juelich.de>
Date: Thu, 23 Jun 2016 14:52:56 +0200
Subject: [PATCH] Added numpy ufunc support and changed testing to nose.

---
 pyramid/fielddata.py                |  9 +++++++++
 pyramid/phasemap.py                 |  9 +++++++++
 pyramid/tests/test_analytic.py      |  4 ++--
 pyramid/tests/test_costfunction.py  |  4 ++--
 pyramid/tests/test_dataset.py       |  4 ++--
 pyramid/tests/test_fielddata.py     |  4 ++--
 pyramid/tests/test_forwardmodel.py  |  4 ++--
 pyramid/tests/test_kernel.py        |  4 ++--
 pyramid/tests/test_magcreator.py    |  4 ++--
 pyramid/tests/test_phasemap.py      |  4 ++--
 pyramid/tests/test_phasemapper.py   | 12 ++----------
 pyramid/tests/test_projector.py     |  8 ++------
 pyramid/tests/test_regularisator.py |  8 ++------
 setup.py                            |  2 +-
 14 files changed, 41 insertions(+), 39 deletions(-)

diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py
index e08aa9f..e4081d7 100644
--- a/pyramid/fielddata.py
+++ b/pyramid/fielddata.py
@@ -170,6 +170,15 @@ class FieldData(object, metaclass=abc.ABCMeta):
         self._log.debug('Calling __imul__')
         return self.__mul__(other)
 
+    def __array__(self, dtype=None):
+        if dtype:
+            return self.field.astype(dtype)
+        else:
+            return self.field
+
+    def __array_wrap__(self, array, _=None):  # _ catches the context, which is not used.
+        return type(self)(self.a, array)
+
     def copy(self):
         """Returns a copy of the :class:`~.FieldData` object
 
diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py
index 7ba0d79..8dd91d5 100644
--- a/pyramid/phasemap.py
+++ b/pyramid/phasemap.py
@@ -257,6 +257,15 @@ class PhaseMap(object):
         self._log.debug('Calling __imul__')
         return self.__mul__(other)
 
+    def __array__(self, dtype=None):
+        if dtype:
+            return self.phase.astype(dtype)
+        else:
+            return self.phase
+
+    def __array_wrap__(self, array, _=None):  # _ catches the context, which is not used.
+        return PhaseMap(self.a, array, self.mask, self.confidence, self.unit)
+
     def copy(self):
         """Returns a copy of the :class:`~.PhaseMap` object
 
diff --git a/pyramid/tests/test_analytic.py b/pyramid/tests/test_analytic.py
index 9fd4704..ae2f79f 100644
--- a/pyramid/tests/test_analytic.py
+++ b/pyramid/tests/test_analytic.py
@@ -57,5 +57,5 @@ class TestCaseAnalytic(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseAnalytic)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_costfunction.py b/pyramid/tests/test_costfunction.py
index 7e4b941..b37c288 100644
--- a/pyramid/tests/test_costfunction.py
+++ b/pyramid/tests/test_costfunction.py
@@ -77,5 +77,5 @@ class TestCaseCostfunction(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseCostfunction)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_dataset.py b/pyramid/tests/test_dataset.py
index 115eb71..3fc5ea0 100644
--- a/pyramid/tests/test_dataset.py
+++ b/pyramid/tests/test_dataset.py
@@ -90,5 +90,5 @@ class TestCaseDataSet(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseDataSet)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_fielddata.py b/pyramid/tests/test_fielddata.py
index f2766ad..1d94ac6 100644
--- a/pyramid/tests/test_fielddata.py
+++ b/pyramid/tests/test_fielddata.py
@@ -118,5 +118,5 @@ class TestCaseVectorData(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseVectorData)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_forwardmodel.py b/pyramid/tests/test_forwardmodel.py
index b646600..fd64179 100644
--- a/pyramid/tests/test_forwardmodel.py
+++ b/pyramid/tests/test_forwardmodel.py
@@ -70,5 +70,5 @@ class TestCaseForwardModel(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseForwardModel)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_kernel.py b/pyramid/tests/test_kernel.py
index b13ed80..46d19c3 100644
--- a/pyramid/tests/test_kernel.py
+++ b/pyramid/tests/test_kernel.py
@@ -33,5 +33,5 @@ class TestCaseKernel(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseKernel)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_magcreator.py b/pyramid/tests/test_magcreator.py
index 53edda2..53b702d 100644
--- a/pyramid/tests/test_magcreator.py
+++ b/pyramid/tests/test_magcreator.py
@@ -82,5 +82,5 @@ class TestCaseMagCreator(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseMagCreator)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_phasemap.py b/pyramid/tests/test_phasemap.py
index e376e62..3884817 100644
--- a/pyramid/tests/test_phasemap.py
+++ b/pyramid/tests/test_phasemap.py
@@ -74,5 +74,5 @@ class TestCasePhaseMap(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMap)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_phasemapper.py b/pyramid/tests/test_phasemapper.py
index d743f2c..0f3914f 100644
--- a/pyramid/tests/test_phasemapper.py
+++ b/pyramid/tests/test_phasemapper.py
@@ -168,13 +168,5 @@ class TestCasePM(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperRDFC)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperRDRC)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperFDFC)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperMIP)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePM)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_projector.py b/pyramid/tests/test_projector.py
index f3498d6..230abd8 100644
--- a/pyramid/tests/test_projector.py
+++ b/pyramid/tests/test_projector.py
@@ -258,9 +258,5 @@ class TestCaseYTiltProjector(unittest.TestCase):
 # TODO: Test RotTiltProjector!!!
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseSimpleProjector)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseXTiltProjector)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseYTiltProjector)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/pyramid/tests/test_regularisator.py b/pyramid/tests/test_regularisator.py
index 7ffaf55..443bbcb 100644
--- a/pyramid/tests/test_regularisator.py
+++ b/pyramid/tests/test_regularisator.py
@@ -131,9 +131,5 @@ class TestCaseFirstOrderRegularisator(unittest.TestCase):
 
 
 if __name__ == '__main__':
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseNoneRegularisator)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseZeroOrderRegularisator)
-    unittest.TextTestRunner(verbosity=2).run(suite)
-    suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseFirstOrderRegularisator)
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    import nose
+    nose.run(defaultTest=__name__)
diff --git a/setup.py b/setup.py
index d2b588f..73510a5 100644
--- a/setup.py
+++ b/setup.py
@@ -147,7 +147,7 @@ setup(name=DISTNAME,
       packages=find_packages(exclude=['tests']),
       include_dirs=[numpy.get_include()],
       requires=['numpy', 'scipy', 'matplotlib', 'Pillow',
-                'mayavi', 'pyfftw', 'hyperspy', 'Cython'],
+                'mayavi', 'pyfftw', 'hyperspy', 'Cython', 'nose'],
       scripts=get_files('scripts'),
       test_suite='nose.collector',
       cmdclass={'build_ext': build_ext, 'build': build},
-- 
GitLab