diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py index e08aa9f1b68e6ae4b65338d1811380abb66fd264..e4081d777d8f222153d58efa0875023f5eb0cfe1 100644 --- a/pyramid/fielddata.py +++ b/pyramid/fielddata.py @@ -170,6 +170,15 @@ class FieldData(object, metaclass=abc.ABCMeta): self._log.debug('Calling __imul__') return self.__mul__(other) + def __array__(self, dtype=None): + if dtype: + return self.field.astype(dtype) + else: + return self.field + + def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. + return type(self)(self.a, array) + def copy(self): """Returns a copy of the :class:`~.FieldData` object diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py index 7ba0d7971ec6a54335550022925728094a293fec..8dd91d58d17e0814932b3d9e185353f4d2d5f680 100644 --- a/pyramid/phasemap.py +++ b/pyramid/phasemap.py @@ -257,6 +257,15 @@ class PhaseMap(object): self._log.debug('Calling __imul__') return self.__mul__(other) + def __array__(self, dtype=None): + if dtype: + return self.phase.astype(dtype) + else: + return self.phase + + def __array_wrap__(self, array, _=None): # _ catches the context, which is not used. + return PhaseMap(self.a, array, self.mask, self.confidence, self.unit) + def copy(self): """Returns a copy of the :class:`~.PhaseMap` object diff --git a/pyramid/tests/test_analytic.py b/pyramid/tests/test_analytic.py index 9fd470419dc094ac2ff0d398a53298067e7e666b..ae2f79fec525a9db481d97fcabba91b81b966124 100644 --- a/pyramid/tests/test_analytic.py +++ b/pyramid/tests/test_analytic.py @@ -57,5 +57,5 @@ class TestCaseAnalytic(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseAnalytic) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_costfunction.py b/pyramid/tests/test_costfunction.py index 7e4b94140807ad8eae451513dedf84e9c5c50f46..b37c28885fde785c90b18eda2369b6503f6297bb 100644 --- a/pyramid/tests/test_costfunction.py +++ b/pyramid/tests/test_costfunction.py @@ -77,5 +77,5 @@ class TestCaseCostfunction(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseCostfunction) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_dataset.py b/pyramid/tests/test_dataset.py index 115eb71d4f5f973bc8934017df14cf25ae3a539a..3fc5ea064cf886490bbe3a10d0ca1ff9f6084170 100644 --- a/pyramid/tests/test_dataset.py +++ b/pyramid/tests/test_dataset.py @@ -90,5 +90,5 @@ class TestCaseDataSet(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseDataSet) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_fielddata.py b/pyramid/tests/test_fielddata.py index f2766ad4847dcbaae9618aab103fbb950ecd2ed8..1d94ac6c0d2ecf6c1d0cad49dbb92533a27f60ee 100644 --- a/pyramid/tests/test_fielddata.py +++ b/pyramid/tests/test_fielddata.py @@ -118,5 +118,5 @@ class TestCaseVectorData(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseVectorData) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_forwardmodel.py b/pyramid/tests/test_forwardmodel.py index b6466000127c2756c088a959517063ddb507bcb6..fd641794f3bc2c49733b7fcaa953a010da12c903 100644 --- a/pyramid/tests/test_forwardmodel.py +++ b/pyramid/tests/test_forwardmodel.py @@ -70,5 +70,5 @@ class TestCaseForwardModel(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseForwardModel) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_kernel.py b/pyramid/tests/test_kernel.py index b13ed805d7e19c71089a822274c2ffda084d82d5..46d19c398ec18836a958a7ebedaa2fbd37533b08 100644 --- a/pyramid/tests/test_kernel.py +++ b/pyramid/tests/test_kernel.py @@ -33,5 +33,5 @@ class TestCaseKernel(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseKernel) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_magcreator.py b/pyramid/tests/test_magcreator.py index 53edda249533a7b7fb19a854617a252ca75b1e4c..53b702deb6660a31c5e604302b78350003e8fa30 100644 --- a/pyramid/tests/test_magcreator.py +++ b/pyramid/tests/test_magcreator.py @@ -82,5 +82,5 @@ class TestCaseMagCreator(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseMagCreator) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_phasemap.py b/pyramid/tests/test_phasemap.py index e376e62d2fafdd31e1d9da0ba864c92c8cb63d72..3884817af5bc0d7c965f92a2d2118c45b305bf07 100644 --- a/pyramid/tests/test_phasemap.py +++ b/pyramid/tests/test_phasemap.py @@ -74,5 +74,5 @@ class TestCasePhaseMap(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMap) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_phasemapper.py b/pyramid/tests/test_phasemapper.py index d743f2c344c986df3e6ef7c4ea92f5c10bec1f42..0f3914fb08e203e42d4791323b2fb0cd233b6118 100644 --- a/pyramid/tests/test_phasemapper.py +++ b/pyramid/tests/test_phasemapper.py @@ -168,13 +168,5 @@ class TestCasePM(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperRDFC) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperRDRC) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperFDFC) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePhaseMapperMIP) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCasePM) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_projector.py b/pyramid/tests/test_projector.py index f3498d6a798ac128fcbf42d0a2344aea5afca83c..230abd87a3f8720e692b5dd3fcc3999a54e61939 100644 --- a/pyramid/tests/test_projector.py +++ b/pyramid/tests/test_projector.py @@ -258,9 +258,5 @@ class TestCaseYTiltProjector(unittest.TestCase): # TODO: Test RotTiltProjector!!! if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseSimpleProjector) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseXTiltProjector) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseYTiltProjector) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/pyramid/tests/test_regularisator.py b/pyramid/tests/test_regularisator.py index 7ffaf55e87e4027ec5a08b53d9c69a674011ee36..443bbcbf91d7c8d79fae0e644c5f283edc81429f 100644 --- a/pyramid/tests/test_regularisator.py +++ b/pyramid/tests/test_regularisator.py @@ -131,9 +131,5 @@ class TestCaseFirstOrderRegularisator(unittest.TestCase): if __name__ == '__main__': - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseNoneRegularisator) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseZeroOrderRegularisator) - unittest.TextTestRunner(verbosity=2).run(suite) - suite = unittest.TestLoader().loadTestsFromTestCase(TestCaseFirstOrderRegularisator) - unittest.TextTestRunner(verbosity=2).run(suite) + import nose + nose.run(defaultTest=__name__) diff --git a/setup.py b/setup.py index d2b588f85ec645f35782cfc7d40c2a545d5d5292..73510a518183fa72a87bd44b6c48b8198333c2b5 100644 --- a/setup.py +++ b/setup.py @@ -147,7 +147,7 @@ setup(name=DISTNAME, packages=find_packages(exclude=['tests']), include_dirs=[numpy.get_include()], requires=['numpy', 'scipy', 'matplotlib', 'Pillow', - 'mayavi', 'pyfftw', 'hyperspy', 'Cython'], + 'mayavi', 'pyfftw', 'hyperspy', 'Cython', 'nose'], scripts=get_files('scripts'), test_suite='nose.collector', cmdclass={'build_ext': build_ext, 'build': build},