diff --git a/pyramid/regularisator.py b/pyramid/regularisator.py index 5f4c2cfa71a389c59ab1188ed12397478e319eb4..f34713cb8eef0ec918bd5e04ee2119c70ebda6ac 100644 --- a/pyramid/regularisator.py +++ b/pyramid/regularisator.py @@ -12,7 +12,8 @@ import numpy as np from scipy.sparse import coo_matrix, csr_matrix import jutil.norms as jnorm - +import jutil.diff as jdiff +import jutil.operator as joperator from pyramid.converter import IndexConverter import logging @@ -103,9 +104,13 @@ class ZeroOrderRegularisator(Regularisator): LOG = logging.getLogger(__name__+'.ZeroOrderRegularisator') - def __init__(self, lam): + def __init__(self, _, lam, p=2): self.LOG.debug('Calling __init__') - norm = jnorm.L2Square() + self.p = p + if p == 2: + norm = jnorm.L2Square() + else: + norm = jnorm.LPPow(p, 1e-12) super(ZeroOrderRegularisator, self).__init__(norm, lam) self.LOG.debug('Created '+str(self)) @@ -113,11 +118,14 @@ class ZeroOrderRegularisator(Regularisator): class FirstOrderRegularisator(Regularisator): # TODO: Docstring! - def __init__(self, mask, lam, x_a=None): - import jutil - D0 = jutil.diff.get_diff_operator(mask, 0, 3) - D1 = jutil.diff.get_diff_operator(mask, 1, 3) - D = jutil.operator.VStack([D0, D1]) - norm = jutil.norms.WeightedL2Square(D) + def __init__(self, mask, lam, p=2): + self.p = p + D0 = jdiff.get_diff_operator(mask, 0, 3) + D1 = jdiff.get_diff_operator(mask, 1, 3) + D = joperator.VStack([D0, D1]) + if p == 2: + norm = jnorm.WeightedL2Square(D) + else: + norm = jnorm.WeightedTV(jnorm.LPPow(p, 1e-12), D, [D0.shape[0], D.shape[0]]) super(FirstOrderRegularisator, self).__init__(norm, lam) self.LOG.debug('Created '+str(self))