Copyright (c) 2016, Christian Schou Oxvig, Thomas Arildsen, and Torben Larsen
All rights reserved.
This notebook shows figures detailing noiseless empirical phase transition curves for various Generalized Approximate Message Passing (GAMP) algorithms. The figures are based on the dataset "Generalized Approximate Message Passing Practical 2D Phase Transition Simulations Dataset" by Christian Schou Oxvig, Thomas Arildsen, and Torben Larsen licensed under CC BY 4.0 (http://creativecommons.org/licenses/by/4.0/). The full dataset along with its license conditions is available at https://dx.doi.org/10.5281/zenodo.165051.
Furthermore, parts of the results from "Replication of certain details from J. P. Vila and P. Schniter: "Expectation-Maximization Bernoulli-Gaussian Approximate Message Passing"" by Thomas Arildsen licensed under the BSD 2-Clause license are used in Figure 1. The full deposition along with its license conditions is available at https://doi.org/10.5281/zenodo.160700.
%matplotlib inline
from __future__ import division
import itertools
import cycler
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import tables as tb
import magni
# Global plotting configuration shared by every figure in this notebook.
mpl.style.use('ggplot')
fig_size = mpl.rcParams['figure.figsize']

# Five-entry 'RdYlBu' palette given as RGB triplets.
# NOTE(review): _ColourCollection is a private magni class - confirm it is
# still available when upgrading magni.
cb5 = magni.utils.plotting._ColourCollection({
    'RdYlBu': ((215, 25, 28),
               (253, 174, 97),
               (255, 255, 191),
               (171, 217, 233),
               (44, 123, 182)),
})
color_cycle = cycler.cycler('color', cb5['RdYlBu'])

# Every line style is preceded by two empty line styles.
# For matplotlib 1.5.3, see https://github.com/matplotlib/matplotlib/issues/7074
style_cycle_errorbar = cycler.cycler(
    'ls',
    itertools.chain.from_iterable(
        ('', '', line_style)
        for line_style in magni.utils.plotting.linestyles[:2]))

# Figures are twice the default size in both directions.
doubled_fig_size = [2 * side for side in fig_size]
magni.utils.plotting.setup_matplotlib({
    'figure': {'dpi': 800, 'figsize': doubled_fig_size},
    'axes': {'prop_cycle': color_cycle * style_cycle_errorbar},
    'grid': {'alpha': 0.25},
})
# Paths to "gamp_practical_2d_phase_transitions_ID_[0-4]_of_5.hdf5" from https://dx.doi.org/10.5281/zenodo.165051.
# File MD5SUM SHA256SUM
# gamp_practical_2d_phase_transitions_ID_0_of_5.hdf5 a9b7cc5c64ff1f900f8fa0ab607f0543 f74fa5f976f00d5d3a32c0ac98b5b37cd11bfd55248b88fbc18ac98fef518640
# gamp_practical_2d_phase_transitions_ID_1_of_5.hdf5 333cce779423e884bf9a0684321bfcf7 118182086822abc1ec22980fc22d8fd2c622de0b19fe29e4d0e6095e7971a346
# gamp_practical_2d_phase_transitions_ID_2_of_5.hdf5 4bd9dad316839435c043f7d0d59c4f4d 365b3118f10da83d7f173301c96cbf5bca47835e48a4e0da144818622ca1866f
# gamp_practical_2d_phase_transitions_ID_3_of_5.hdf5 6ad91ff617f55b1c87d75a75c1b9c9f7 b0d57022f9c4e72a2ac0a441b3d120df06791cc33d35d0bcfb2129334496720f
# gamp_practical_2d_phase_transitions_ID_4_of_5.hdf5 6d89403eaa297eaff5f287a6ff3f4d49 391f8096bec5915c93ca6167c6a41da28ebe01fbb61bb8df151c2430f23037d5
# The five database parts differ only in their ID, so the names are generated.
hdf5_database_paths = [
    'gamp_practical_2d_phase_transitions_ID_{}_of_5.hdf5'.format(part_id)
    for part_id in range(5)]

# Path to "vila2011_results_old_not_unif.mat" from https://doi.org/10.1109/ACSSC.2011.6190117
# MD5SUM: 2cf3713c31ddb69de84f3fc5eb3fe94b
# SHA256SUM: a7e6361a31099fa9abb7cb669f97c9ac1a3dddcbba23a8fafb89f95ffd19af6d
vila2011_reference_path = 'vila2011_results_old_not_unif.mat'
# Legend labels keyed by '<algorithm>_<variant>', i.e. the first and last
# underscore-separated parts of a phase transition label in the databases.
# Labels holding LaTeX are raw strings so that sequences such as "\m" and "\c"
# are never parsed as (invalid) Python escape sequences.
algorithm_names = {
    'AMPresidual_genie': 'DMM AMP R',
    'AMPmedian_genie': 'DMM AMP M',
    'GAMPFull_genie': r'GAMP using $|\mathbf{A}|^{\circ 2}$ (genie)',
    'GAMPFull_vila': r'GAMP using $|\mathbf{A}|^{\circ 2}$ (EM)',
    'GAMPSAKrzakala1m_genie': 'GAMP using Krzakalas SA with variance $1/m$ (genie)',
    'GAMPSAKrzakala1m_vila': 'GAMP using Krzakalas SA with variance $1/m$ (EM)',
    'GAMPSAKrzakala1n_genie': 'GAMP using Krzakalas SA with variance $1/n$ (genie)',
    'GAMPSAKrzakala1n_vila': 'GAMP using Krzakalas SA with variance $1/n$ (EM)',
    'GAMPSARangan_genie': 'GAMP using Rangans SA (genie)',
    'GAMPSARangan_vila': 'GAMP using Rangans SA (EM)',
}
# Groups whose name contains any of these substrings are not collected.
skips = ['annotations', 'chases', 'fileset', 'gamp_pt_tools', 'parameter_values']

# Result containers, each keyed by the name of the HDF5 group it was read from.
pt_curves = {}
pt_dists = {}
pt_times = {}
pt_norms = {}
pt_mses = {}
pt_percentiles = {}

# Name of the node inside a result group -> dictionary receiving its contents.
result_stores = (
    ('dist', pt_dists),
    ('time', pt_times),
    ('norm', pt_norms),
    ('mse', pt_mses),
    ('phase_transition_percentiles', pt_percentiles),
)

for hdf5_database_path in hdf5_database_paths:
    with tb.File(hdf5_database_path, mode='r') as h5_file:
        for group in h5_file.walk_groups('/'):
            label = group._v_name
            if group is h5_file.root:
                continue
            if any(skip in label for skip in skips):
                continue
            # Nodes are addressed as '/<label>/<node name>'.
            for node_name, store in result_stores:
                store[label] = h5_file.get_node('/{}/{}'.format(label, node_name)).read()
            pt_curves[label] = magni.cs.phase_transition.io.load_phase_transition(
                hdf5_database_path, label=label)
fig_1_pt_curves = [
{'delta': pt[0],
'rho': pt[1],
'yerr': np.vstack([pt_percentiles[curve_label][0, :], pt_percentiles[curve_label][3, :]]), # 90/10 errorbars
'label': algorithm_names['_'.join([curve_label.split('_')[k] for k in [0, -1]])]}
for curve_label, pt in sorted(pt_curves.items())
if all([select_criterion in curve_label
for select_criterion in ['32', 'USE']])]
vila_2011_reference_data = scipy.io.loadmat(vila2011_reference_path)
vila_2011_reference_pt_curves = [
{'delta': vila_2011_reference_data['delta_values'].ravel(),
'rho': vila_2011_reference_data['transition_rho_values_EMBGAMP'].ravel(),
'label': 'EM-BG-GAMP',
'style': {'color': (0.99, 0.68, 0.38), 'marker': '*', 'linestyle':'None', 'ms': 12}}
]
magni.cs.phase_transition.plotting.plot_phase_transitions(fig_1_pt_curves, plot_l1=True, errorevery=2, reference_curves=vila_2011_reference_pt_curves)
plt.gcf().gca().axes.lines[-2].set_linestyle('--')
plt.gcf().gca().legend(loc='upper left')
plt.gcf().gca().get_legend().get_frame().set_alpha(0)
plt.savefig('fig1.pdf')
_ = ! pdfcrop 'fig1.pdf'
plt.show()
fig_2_pt_curves = [
{'delta': pt[0],
'rho': pt[1],
'yerr': np.vstack([pt_percentiles[curve_label][0, :], pt_percentiles[curve_label][3, :]]), # 90/10 errorbars
'label': algorithm_names['_'.join([curve_label.split('_')[k] for k in [0, -1]])]}
for curve_label, pt in sorted(pt_curves.items())
if all([select_criterion in curve_label
for select_criterion in ['32', 'RandomDCT2D']])]
magni.cs.phase_transition.plotting.plot_phase_transitions(fig_2_pt_curves, plot_l1=True, errorevery=2)
plt.gcf().gca().axes.lines[-1].set_linestyle('--')
plt.gcf().gca().legend(loc='upper left')
plt.gcf().gca().get_legend().get_frame().set_alpha(0)
plt.savefig('fig2.pdf')
_ = ! pdfcrop 'fig2.pdf'
plt.show()
fig_3_pt_curves = [
{'delta': pt[0],
'rho': pt[1],
'yerr': np.vstack([pt_percentiles[curve_label][0, :], pt_percentiles[curve_label][3, :]]), # 90/10 errorbars
'label': algorithm_names['_'.join([curve_label.split('_')[k] for k in [0, -1]])]}
for curve_label, pt in sorted(pt_curves.items())
if all([select_criterion in curve_label
for select_criterion in ['32', 'DFTDWT']])]
magni.cs.phase_transition.plotting.plot_phase_transitions(fig_3_pt_curves, plot_l1=True, errorevery=2)
plt.gcf().gca().axes.lines[-1].set_linestyle('--')
plt.gcf().gca().legend(loc='upper left')
plt.gcf().gca().get_legend().get_frame().set_alpha(0)
plt.savefig('fig3.pdf')
_ = ! pdfcrop 'fig3.pdf'
plt.show()
# Display (without saving) the phase transitions for the two larger problem
# sizes of each matrix ensemble.
ensemble_names = {'RandomDCT2D': 'SEM', 'DFTDWT': 'MRI'}

# product() runs through the sizes for the first ensemble before moving on to
# the next one, i.e. the order of two nested loops.
for matrix_ensemble, problem_size in itertools.product(['RandomDCT2D', 'DFTDWT'], ['64', '128']):
    print('{} for a problem size of {}^2'.format(ensemble_names[matrix_ensemble], problem_size))

    disp_pt_curves = [
        {'delta': curve[0],
         'rho': curve[1],
         # Rows 0 and 3 of the stored percentiles form the 90/10 errorbars.
         'yerr': np.vstack((pt_percentiles[name][0, :], pt_percentiles[name][3, :])),
         # The legend key is '<first label part>_<last label part>'.
         'label': algorithm_names[name.partition('_')[0] + '_' + name.rpartition('_')[2]]}
        for name, curve in sorted(pt_curves.items())
        if problem_size in name and matrix_ensemble in name]

    magni.cs.phase_transition.plotting.plot_phase_transitions(
        disp_pt_curves, plot_l1=True, errorevery=2)

    ax = plt.gca()
    # NOTE(review): the last line is presumably the l1 curve - confirm against
    # magni's plotting order.
    ax.lines[-1].set_linestyle('--')
    ax.legend(loc='upper left')
    ax.get_legend().get_frame().set_alpha(0)  # fully transparent legend frame
    plt.show()