Skip to content
Snippets Groups Projects
Commit 69879fe6 authored by Jörn Ungermann's avatar Jörn Ungermann
Browse files

Added p and TV regularisation.

parent 8faed108
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,8 @@ import numpy as np
from scipy.sparse import coo_matrix, csr_matrix
import jutil.norms as jnorm
import jutil.diff as jdiff
import jutil.operator as joperator
from pyramid.converter import IndexConverter
import logging
......@@ -103,9 +104,13 @@ class ZeroOrderRegularisator(Regularisator):
LOG = logging.getLogger(__name__+'.ZeroOrderRegularisator')
def __init__(self, lam):
def __init__(self, _, lam, p=2):
self.LOG.debug('Calling __init__')
norm = jnorm.L2Square()
self.p = p
if p == 2:
norm = jnorm.L2Square()
else:
norm = jnorm.LPPow(p, 1e-12)
super(ZeroOrderRegularisator, self).__init__(norm, lam)
self.LOG.debug('Created '+str(self))
......@@ -113,11 +118,14 @@ class ZeroOrderRegularisator(Regularisator):
class FirstOrderRegularisator(Regularisator):
# TODO: Docstring!
def __init__(self, mask, lam, x_a=None):
import jutil
D0 = jutil.diff.get_diff_operator(mask, 0, 3)
D1 = jutil.diff.get_diff_operator(mask, 1, 3)
D = jutil.operator.VStack([D0, D1])
norm = jutil.norms.WeightedL2Square(D)
def __init__(self, mask, lam, p=2):
self.p = p
D0 = jdiff.get_diff_operator(mask, 0, 3)
D1 = jdiff.get_diff_operator(mask, 1, 3)
D = joperator.VStack([D0, D1])
if p == 2:
norm = jnorm.WeightedL2Square(D)
else:
norm = jnorm.WeightedTV(jnorm.LPPow(p, 1e-12), D, [D0.shape[0], D.shape[0]])
super(FirstOrderRegularisator, self).__init__(norm, lam)
self.LOG.debug('Created '+str(self))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment