summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-10-22 17:43:58 +0100
committerGitHub <noreply@github.com>2019-10-22 17:43:58 +0100
commitd82298ee9a6e38ff6e286077f52a694acd58d5db (patch)
tree5379274dd9daa4b07a69c56733992b9a5d6e4e72
parent35fa0c0cdd07fd8a26fb15eae28b5b412007e5ba (diff)
downloadframework-d82298ee9a6e38ff6e286077f52a694acd58d5db.tar.gz
framework-d82298ee9a6e38ff6e286077f52a694acd58d5db.tar.bz2
framework-d82298ee9a6e38ff6e286077f52a694acd58d5db.tar.xz
framework-d82298ee9a6e38ff6e286077f52a694acd58d5db.zip
Pdhg last objective (#407)
* save previous iteration at start of iteration * adds very_verbose to run method * modified test closes #396
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/Algorithm.py42
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py9
-rwxr-xr-xWrappers/Python/test/test_algorithms.py3
3 files changed, 33 insertions, 21 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py
index f08688d..78ce438 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py
@@ -60,6 +60,7 @@ class Algorithm(object):
self.timing = []
self._iteration = []
self.update_objective_interval = kwargs.get('update_objective_interval', 1)
+ self.x = None
def set_up(self, *args, **kwargs):
'''Set up the algorithm'''
raise NotImplementedError()
@@ -109,16 +110,23 @@ class Algorithm(object):
'''Returns the solution found'''
return self.x
- def get_last_loss(self):
+ def get_last_loss(self, **kwargs):
'''Returns the last stored value of the loss function
if update_objective_interval is 1 it is the value of the objective at the current
iteration. If update_objective_interval > 1 it is the last stored value.
'''
- return self.__loss[-1]
- def get_last_objective(self):
+ return_all = kwargs.get('return_all', False)
+ objective = self.__loss[-1]
+ if return_all:
+ return list(objective)
+ if isinstance(objective, list):
+ return objective[0]
+ else:
+ return objective
+ def get_last_objective(self, **kwargs):
'''alias to get_last_loss'''
- return self.get_last_loss()
+ return self.get_last_loss(**kwargs)
def update_objective(self):
'''calculates the objective with the current solution'''
raise NotImplementedError()
@@ -155,7 +163,7 @@ class Algorithm(object):
raise ValueError('Update objective interval must be an integer >= 1')
else:
raise ValueError('Update objective interval must be an integer >= 1')
- def run(self, iterations=None, verbose=True, callback=None):
+ def run(self, iterations=None, verbose=True, callback=None, very_verbose=False):
'''run n iterations and update the user with the callback if specified
:param iterations: number of iterations to run. If not set the algorithm will
@@ -163,30 +171,32 @@ class Algorithm(object):
:param verbose: toggles verbose output to screen
:param callback: is a function that receives: current iteration number,
last objective function value and the current solution
+ :param very_verbose: bool, useful for algorithms with primal and dual objectives (PDHG),
+ prints to screen both primal and dual
'''
if self.should_stop():
print ("Stop cryterion has been reached.")
i = 0
if verbose:
- print (self.verbose_header())
+ print (self.verbose_header(very_verbose))
if self.iteration == 0:
if verbose:
- print(self.verbose_output())
+ print(self.verbose_output(very_verbose))
for _ in self:
if (self.iteration) % self.update_objective_interval == 0:
if verbose:
- print (self.verbose_output())
+ print (self.verbose_output(very_verbose))
if callback is not None:
- callback(self.iteration, self.get_last_objective(), self.x)
+ callback(self.iteration, self.get_last_objective(return_all=very_verbose), self.x)
i += 1
if i == iterations:
if self.iteration != self._iteration[-1]:
self.update_objective()
if verbose:
- print (self.verbose_output())
+ print (self.verbose_output(very_verbose))
break
- def verbose_output(self):
+ def verbose_output(self, verbose=False):
'''Creates a nice tabulated output'''
timing = self.timing[-self.update_objective_interval-1:-1]
self._iteration.append(self.iteration)
@@ -198,20 +208,20 @@ class Algorithm(object):
self.iteration,
self.max_iteration,
"{:.3f}".format(t),
- self.objective_to_string()
+ self.objective_to_string(verbose)
)
return out
- def objective_to_string(self):
- el = self.get_last_objective()
+ def objective_to_string(self, verbose=False):
+ el = self.get_last_objective(return_all=verbose)
if type(el) == list:
string = functools.reduce(lambda x,y: x+' {:>13.5e}'.format(y), el[:-1],'')
string += '{:>15.5e}'.format(el[-1])
else:
string = "{:>20.5e}".format(el)
return string
- def verbose_header(self):
- el = self.get_last_objective()
+ def verbose_header(self, verbose=False):
+ el = self.get_last_objective(return_all=verbose)
if type(el) == list:
out = "{:>9} {:>10} {:>13} {:>13} {:>13} {:>15}\n".format('Iter',
'Max Iter',
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 7ed82b2..db1b8dc 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -123,7 +123,10 @@ class PDHG(Algorithm):
def update(self):
-
+ # save previous iteration
+ self.x_old.fill(self.x)
+ self.y_old.fill(self.y)
+
# Gradient ascent for the dual variable
self.operator.direct(self.xbar, out=self.y_tmp)
self.y_tmp *= self.sigma
@@ -145,9 +148,7 @@ class PDHG(Algorithm):
self.xbar += self.x
- self.x_old.fill(self.x)
- self.y_old.fill(self.y)
-
+
def update_objective(self):
p1 = self.f(self.operator.direct(self.x)) + self.g(self.x)
diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py
index 2b38e3f..db13b97 100755
--- a/Wrappers/Python/test/test_algorithms.py
+++ b/Wrappers/Python/test/test_algorithms.py
@@ -195,6 +195,7 @@ class TestAlgorithms(unittest.TestCase):
print ("PDHG Denoising with 3 noises")
# adapted from demo PDHG_TV_Color_Denoising.py in CIL-Demos repository
+ # loader = TestData(data_dir=os.path.join(os.environ['SIRF_INSTALL_PATH'], 'share','ccpi'))
# loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi'))
loader = TestData()
@@ -254,7 +255,7 @@ class TestAlgorithms(unittest.TestCase):
pdhg1 = PDHG(f=f1,g=g,operator=operator, tau=tau, sigma=sigma)
pdhg1.max_iteration = 2000
pdhg1.update_objective_interval = 200
- pdhg1.run(1000)
+ pdhg1.run(1000, very_verbose=True)
rmse = (pdhg1.get_output() - data).norm() / data.as_array().size
print ("RMSE", rmse)