From 878a64430e7ac6eed421151a34e3f10723c1d779 Mon Sep 17 00:00:00 2001 From: Jan Caron <j.caron@fz-juelich.de> Date: Thu, 23 Jun 2016 14:52:56 +0200 Subject: [PATCH] Added numpy ufunc support and changed testing to nose. --- pyramid/fielddata.py | 9 +++++++++ pyramid/phasemap.py | 9 +++++++++ pyramid/tests/test_analytic.py | 4 ++-- pyramid/tests/test_costfunction.py | 4 ++-- pyramid/tests/test_dataset.py | 4 ++-- pyramid/tests/test_fielddata.py | 4 ++-- pyramid/tests/test_forwardmodel.py | 4 ++-- pyramid/tests/test_kernel.py | 4 ++-- pyramid/tests/test_magcreator.py | 4 ++-- pyramid/tests/test_phasemap.py | 4 ++-- pyramid/tests/test_phasemapper.py | 12 ++---------- pyramid/tests/test_projector.py | 8 ++------ pyramid/tests/test_regularisator.py | 8 ++------ setup.py | 2 +- 14 files changed, 41 insertions(+), 39 deletions(-) diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py index e08aa9f..e4081d7 100644 --- a/pyramid/fielddata.py +++ b/pyramid/fielddata.py @@ -170,6 +170,15 @@ class FieldData(object, metaclass=abc.ABCMeta): self._log.debug('Calling __imul__') return self.__mul__(other) + def __array__(self, dtype=None): + if dtype: + return self.field.astype(dtype) + else: + return self.field + + def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. + return type(self)(self.a, array) + def copy(self): """Returns a copy of the :class:`~.FieldData` object diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py index 7ba0d79..8dd91d5 100644 --- a/pyramid/phasemap.py +++ b/pyramid/phasemap.py @@ -257,6 +257,15 @@ class PhaseMap(object): self._log.debug('Calling __imul__') return self.__mul__(other) + def __array__(self, dtype=None): + if dtype: + return self.phase.astype(dtype) + else: + return self.phase + + def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. + return PhaseMap(self.a, array, self.mask, self.confidence, self.unit) + def copy(self): """Returns a copy of the :class:`~.PhaseMap` object diff --git a/pyramid/tests/test_analytic.py b/pyramid/tests/test_analytic.py index 9fd4704..ae2f79f 100644 --- a/pyramid/tests/test_analytic.py +++ b/pyramid/tests/test_analytic.py @@ -57,5 +57,5 @@ class TestCaseAnalytic(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseAnalytic) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_costfunction.py b/pyramid/tests/test_costfunction.py index 7e4b941..b37c288 100644 --- a/pyramid/tests/test_costfunction.py +++ b/pyramid/tests/test_costfunction.py @@ -77,5 +77,5 @@ class TestCaseCostfunction(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseCostfunction) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_dataset.py b/pyramid/tests/test_dataset.py index 115eb71..3fc5ea0 100644 --- a/pyramid/tests/test_dataset.py +++ b/pyramid/tests/test_dataset.py @@ -90,5 +90,5 @@ class TestCaseDataSet(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseDataSet) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_fielddata.py b/pyramid/tests/test_fielddata.py index f2766ad..1d94ac6 100644 --- a/pyramid/tests/test_fielddata.py +++ b/pyramid/tests/test_fielddata.py @@ -118,5 +118,5 @@ class TestCaseVectorData(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseVectorData) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_forwardmodel.py b/pyramid/tests/test_forwardmodel.py index b646600..fd64179 100644 --- a/pyramid/tests/test_forwardmodel.py +++ b/pyramid/tests/test_forwardmodel.py @@ -70,5 +70,5 @@ class TestCaseForwardModel(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseForwardModel) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_kernel.py b/pyramid/tests/test_kernel.py index b13ed80..46d19c3 100644 --- a/pyramid/tests/test_kernel.py +++ b/pyramid/tests/test_kernel.py @@ -33,5 +33,5 @@ class TestCaseKernel(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseKernel) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_magcreator.py b/pyramid/tests/test_magcreator.py index 53edda2..53b702d 100644 --- a/pyramid/tests/test_magcreator.py +++ b/pyramid/tests/test_magcreator.py @@ -82,5 +82,5 @@ class TestCaseMagCreator(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseMagCreator) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_phasemap.py b/pyramid/tests/test_phasemap.py index e376e62..3884817 100644 --- a/pyramid/tests/test_phasemap.py +++ b/pyramid/tests/test_phasemap.py @@ -74,5 +74,5 @@ class TestCasePhaseMap(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMap) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_phasemapper.py b/pyramid/tests/test_phasemapper.py index d743f2c..0f3914f 100644 --- a/pyramid/tests/test_phasemapper.py +++ b/pyramid/tests/test_phasemapper.py @@ -168,13 +168,5 @@ class TestCasePM(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperRDFC) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperRDRC) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperFDFC) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperMIP) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePM) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_projector.py b/pyramid/tests/test_projector.py index f3498d6..230abd8 100644 --- a/pyramid/tests/test_projector.py +++ b/pyramid/tests/test_projector.py @@ -258,9 +258,5 @@ class TestCaseYTiltProjector(unittest.TestCase): # TODO: Test RotTiltProjector!!! if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseSimpleProjector) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseXTiltProjector) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseYTiltProjector) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_regularisator.py b/pyramid/tests/test_regularisator.py index 7ffaf55..443bbcb 100644 --- a/pyramid/tests/test_regularisator.py +++ b/pyramid/tests/test_regularisator.py @@ -131,9 +131,5 @@ class TestCaseFirstOrderRegularisator(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseNoneRegularisator) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseZeroOrderRegularisator) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseFirstOrderRegularisator) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/setup.py b/setup.py index d2b588f..73510a5 100644 --- a/setup.py +++ b/setup.py @@ -147,7 +147,7 @@ setup(name=DISTNAME, packages=find_packages(exclude=['tests']), include_dirs=[numpy.get_include()], requires=['numpy', 'scipy', 'matplotlib', 'Pillow', - 'mayavi', 'pyfftw', 'hyperspy', 'Cython'], + 'mayavi', 'pyfftw', 'hyperspy', 'Cython', 'nose'], scripts=get_files('scripts'), test_suite='nose.collector', cmdclass={'build_ext': build_ext, 'build': build}, -- GitLab