Aasq GAMP Figures (Version 2)

Copyright (c) 2016-2017, Christian Schou Oxvig, Thomas Arildsen, and Torben Larsen
All rights reserved.

This is an updated version of the "Aasq GAMP Figures.ipynb" notebook available at http://dx.doi.org/10.5278/240710282. It has been updated in connection with a revision of the paper "Entrywise Squared Transforms for High Dimensional Signal Reconstruction via Generalized Approximate Message Passing" which is part of the thesis "Algorithms for Reconstruction of Undersampled Atomic Force Microscopy Images" available at http://dx.doi.org/10.5278/vbn.phd.engsci.00158. The notebook corresponding to the version of the paper which is part of the thesis is the original "Aasq GAMP Figures.ipynb" notebook.

This notebook shows figures detailing noiseless empirical phase transition curves for various Generalized Approximate Message Passing (GAMP) algorithms. The figures are based on the dataset "Generalized Approximate Message Passing Practical 2D Phase Transition Simulations Dataset 2" by Christian Schou Oxvig, Thomas Arildsen, and Torben Larsen licensed under CC BY 4.0 (http://creativecommons.org/licenses/by/4.0/). The full dataset along with its license conditions is available at https://dx.doi.org/10.5281/zenodo.574258.

Furthermore, parts of the results from "Replication of certain details from J. P. Vila and P. Schniter: 'Expectation-Maximization Bernoulli-Gaussian Approximate Message Passing'" by Thomas Arildsen, licensed under the BSD 2-Clause license, are used in Figure 1. The full deposition along with its license conditions is available at https://doi.org/10.5281/zenodo.160700.

In [1]:
%matplotlib inline

from __future__ import division
import itertools

import cycler
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import tables as tb

import magni

mpl.style.use('ggplot')
fig_size = mpl.rcParams['figure.figsize']
cb3 = magni.utils.plotting._ColourCollection(
    {'PuOr': ((230, 97, 1), (253, 184, 99), (94, 60, 153))})
cb5 = magni.utils.plotting._ColourCollection(
    {'BuYlRd': ((44, 123, 182), (171, 217, 233), (255, 255, 191), (253, 174, 97), (215, 25, 28))})
color_cycle = cycler.cycler('color', cb3['PuOr'])
style_cycle_errorbar = cycler.cycler('ls', magni.utils.plotting.linestyles[:2])
magni.utils.plotting.setup_matplotlib({'figure': {'dpi': 800, 'figsize': [fs * 2 for fs in fig_size]},
                                       'axes': {'prop_cycle': color_cycle * style_cycle_errorbar},
                                       'grid': {'alpha': 0.25}})
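The setup cell multiplies two cyclers to build the axes prop_cycle. A product of cyclers is their outer product with the left operand varying slowest, so each of the three colours is paired with both line styles. The sketch below reproduces that ordering with itertools.product. The colour triples are the ones passed to _ColourCollection (magni may rescale them), and the line styles '-' and '--' are an assumption standing in for magni.utils.plotting.linestyles[:2].

```python
import itertools

# RGB triples given to magni's _ColourCollection for 'PuOr' (magni may rescale them)
colours = [(230, 97, 1), (253, 184, 99), (94, 60, 153)]
# Assumed stand-in for magni.utils.plotting.linestyles[:2]
linestyles = ['-', '--']

# color_cycle * style_cycle_errorbar is an outer product with colour varying slowest;
# itertools.product(colours, linestyles) yields the pairs in the same order
prop_cycle = [{'color': colour, 'ls': ls}
              for colour, ls in itertools.product(colours, linestyles)]

for i, props in enumerate(prop_cycle):
    print(i, props['color'], props['ls'])
```

Successive entries alternate between the two line styles, while the colour only changes every second entry.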
In [2]:
# Monkey patch the phase transition plot function to work around a Matplotlib 2.0.0 bug
magni_170_pt_plot = magni.cs.phase_transition.plotting.plot_phase_transitions

def matplotlib_200_mpatched_pt_plot(*args, **kwargs):
    magni_170_pt_plot(*args, **kwargs)
    linestyles = style_cycle_errorbar.by_key()['ls']
    n_styles = len(style_cycle_errorbar)

    if n_styles > 3:
        raise ValueError('Unable to monkeypatch errorbar with more than 3 linestyles.')

    # Manually set linestyle for Matplotlib 2.0.0 errorbar function
    # See bug report: https://github.com/matplotlib/matplotlib/issues/7074
    # Lines k, k + n_styles, k + 2 * n_styles, ... get linestyle k (k = 0 keeps the default)
    for k in range(1, n_styles):
        for line in plt.gca().lines[k::n_styles]:
            if line.get_marker() == 'None':
                line.set_linestyle(linestyles[k])

magni.cs.phase_transition.plotting.plot_phase_transitions = matplotlib_200_mpatched_pt_plot
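The patch relies on stride slicing: with n line styles, lines[k::n] selects lines k, k + n, k + 2n, and so on, and only lines whose marker is 'None' are restyled. A matplotlib-free sketch of that logic follows; MockLine is a hypothetical stand-in that implements only the three Line2D methods involved, and restyle_lines is a hypothetical helper, not part of magni.

```python
class MockLine(object):
    """Hypothetical stand-in for matplotlib.lines.Line2D (only the methods used here)."""

    def __init__(self, marker):
        self._marker = marker
        self._linestyle = '-'

    def get_marker(self):
        return self._marker

    def get_linestyle(self):
        return self._linestyle

    def set_linestyle(self, linestyle):
        self._linestyle = linestyle


def restyle_lines(lines, linestyles):
    """Give marker-less lines k, k + n, k + 2n, ... the k-th linestyle (n = len(linestyles))."""
    n_styles = len(linestyles)
    for k in range(1, n_styles):
        for line in lines[k::n_styles]:
            if line.get_marker() == 'None':
                line.set_linestyle(linestyles[k])
    return lines


lines = [MockLine('None') for _ in range(6)]
restyle_lines(lines, ['-', '--', ':'])
print([line.get_linestyle() for line in lines])
# ['-', '--', ':', '-', '--', ':']
```

A single loop over k covers the one-, two- and three-style cases with the same code path.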

Data paths
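The data-paths cell lists an MD5 and a SHA256 sum for each downloaded file. A minimal sketch of checking a download against those sums with Python's standard hashlib module; the helper names file_digest and verify_file are hypothetical and not part of magni.

```python
import hashlib


def file_digest(path, algorithm='sha256', chunk_size=1 << 20):
    """Return the hex digest of the file at path, reading it in chunks."""
    digest = hashlib.new(algorithm)
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify_file(path, expected_md5, expected_sha256):
    """Return True if both the MD5 and the SHA256 digest of path match the expected values."""
    return (file_digest(path, 'md5') == expected_md5 and
            file_digest(path, 'sha256') == expected_sha256)
```

For example, verify_file('gamp_practical_2d_phase_transitions_ID_0_of_5.hdf5', '3eadb9c6840b1a857097c49547fa0675', '18ba9a8647c9af9555aaab6e6bd5ee6adbd4f641ef98760f53d9e0fa272faa78') should return True for an intact copy of the first file. Reading in chunks keeps memory use flat for large HDF5 files.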

In [3]:
# Paths to "gamp_practical_2d_phase_transitions_ID_[0-4]_of_5.hdf5" from https://dx.doi.org/10.5281/zenodo.574258.
# File                                                MD5SUM                            SHA256SUM
# gamp_practical_2d_phase_transitions_ID_0_of_5.hdf5  3eadb9c6840b1a857097c49547fa0675  18ba9a8647c9af9555aaab6e6bd5ee6adbd4f641ef98760f53d9e0fa272faa78
# gamp_practical_2d_phase_transitions_ID_1_of_5.hdf5  14dfff08e652fca19b8b185cfb73155d  d422a42e38fb9cf38a64123e9e401d7202933089812b507bec8bc03239b91978 
# gamp_practical_2d_phase_transitions_ID_2_of_5.hdf5  a112c9483eab9b8cbb52b11fac5ac7ff  76f3e28eabf1d10b47e87c707bac3ba738591ed35dd8c329ec47a44ccb36df1f
# gamp_practical_2d_phase_transitions_ID_3_of_5.hdf5  9fbeb6754c948d21cf44c40addb079e7  5c95f26cbe32c79543b33730648d28312fc9a7d1d706967f1108c63a6db448a7
# gamp_practical_2d_phase_transitions_ID_4_of_5.hdf5  f7f929ea747b0635d55654592b69524d  e21d3c10ba0d7fae3e3dde938964434a6ae5f9087dc199248519800b8b672ea8
hdf5_database_paths = ['gamp_practical_2d_phase_transitions_ID_0_of_5.hdf5',
                       'gamp_practical_2d_phase_transitions_ID_1_of_5.hdf5',
                       'gamp_practical_2d_phase_transitions_ID_2_of_5.hdf5',
                       'gamp_practical_2d_phase_transitions_ID_3_of_5.hdf5',
                       'gamp_practical_2d_phase_transitions_ID_4_of_5.hdf5',]


# Path to "vila2011_results_old_not_unif.mat" from the deposition at https://doi.org/10.5281/zenodo.160700, which replicates results from https://doi.org/10.1109/ACSSC.2011.6190117
# MD5SUM: 2cf3713c31ddb69de84f3fc5eb3fe94b
# SHA256SUM: a7e6361a31099fa9abb7cb669f97c9ac1a3dddcbba23a8fafb89f95ffd19af6d
vila2011_reference_path = 'vila2011_results_old_not_unif.mat'

# Legend labels (raw strings, so the LaTeX backslashes are kept verbatim)
algorithm_names = {'GAMPFullwnBG_USE': r'USE for BG GAMP using $|\mathbf{A}|^{\circ 2}$',
                   'GAMPFullwnBG_RandomDCT2D': r'SEM for BG GAMP using $|\mathbf{A}|^{\circ 2}$',
                   'GAMPFullwnBG_EX3D': r'EX3D for BG GAMP using $|\mathbf{A}|^{\circ 2}$',
                   'GAMPSAKrzaFrobwnBG_USE': r'USE for BG GAMP using $||\mathbf{A}||_F^2$ SA',
                   'GAMPSAKrzaFrobwnBG_RandomDCT2D': r'SEM for BG GAMP using $||\mathbf{A}||_F^2$ SA',
                   'GAMPSAKrzaFrobwnBG_EX3D': r'EX3D for BG GAMP using $||\mathbf{A}||_F^2$ SA'}

mismatches = [str(x).replace('.', 'd') + '_' for x in [0.75, 0.8, 0.85, 0.9, 0.95, 1.0, 1.05, 1.1, 1.15, 1.20, 1.5, 2.0, 3.0]]
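The mismatch tags encode each factor with 'd' in place of the decimal point plus a trailing underscore, and mismatches[5] is the tag for a factor of 1.0 that the Figure 1 selection uses. The trailing underscore matters because the selection matches tags as substrings: without it, '1d0' would also match '1d05'. A small self-contained check of that property; it inspects only the tags themselves, not the dataset labels.

```python
# Same encoding as the mismatches list in the data-paths cell
factors = [0.75, 0.8, 0.85, 0.9, 0.95, 1.0, 1.05, 1.1, 1.15, 1.20, 1.5, 2.0, 3.0]
mismatches = [str(x).replace('.', 'd') + '_' for x in factors]

print(mismatches[5])  # 1d0_

# Without the trailing underscore, the tag for 1.0 would also match the tag for 1.05
assert '1d0' in '1d05'
# With it, every tag is a substring of exactly one tag in the list (itself)
for tag in mismatches:
    assert sum(tag in other for other in mismatches) == 1
```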

Load data

In [4]:
skips = ['annotations', 'chases', 'fileset', 'gamp_pt_tools', 'parameter_values']

pt_curves = {}
pt_dists = {}
pt_times = {}
pt_norms = {}
pt_mses = {}
pt_percentiles = {}

for hdf5_database_path in hdf5_database_paths:
    with tb.File(hdf5_database_path, mode='r') as h5_file:
        for group in h5_file.walk_groups('/'):
            label = group._v_name
            # Skip the root group and any group whose name contains one of the skip substrings
            if group is h5_file.root or any(skip in label for skip in skips):
                continue

            pt_dists[label] = h5_file.get_node('/' + '/'.join([label, 'dist'])).read()
            pt_times[label] = h5_file.get_node('/' + '/'.join([label, 'time'])).read()
            pt_norms[label] = h5_file.get_node('/' + '/'.join([label, 'norm'])).read()
            pt_mses[label] = h5_file.get_node('/' + '/'.join([label, 'mse'])).read()
            pt_percentiles[label] = h5_file.get_node('/' + '/'.join([label, 'phase_transition_percentiles'])).read()
            pt_curves[label] = magni.cs.phase_transition.io.load_phase_transition(hdf5_database_path, label=label)
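The loading loop skips the root group and every group whose name contains one of the skips entries as a substring (so the match is not exact), then reads five arrays from each remaining group at /<label>/<name>. A minimal sketch of the filter and the node-path construction without PyTables; the helper names and the group names are hypothetical.

```python
skips = ['annotations', 'chases', 'fileset', 'gamp_pt_tools', 'parameter_values']
result_nodes = ['dist', 'time', 'norm', 'mse', 'phase_transition_percentiles']


def is_result_group(label, skips=skips):
    """Return True unless label contains one of the skip substrings."""
    return not any(skip in label for skip in skips)


def node_paths(label):
    """Map each per-label array name to its absolute HDF5 node path."""
    return {name: '/' + '/'.join([label, name]) for name in result_nodes}


# Hypothetical group names for illustration only
labels = ['annotations', 'GAMPFullwnBG_USE_example', 'parameter_values_example']
kept = [label for label in labels if is_result_group(label)]
print(kept)  # ['GAMPFullwnBG_USE_example']
print(node_paths(kept[0])['dist'])  # /GAMPFullwnBG_USE_example/dist
```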

Figure 1

All problem suites with genie knowledge of the true signal prior and the sum approximation constant.

In [5]:
select_criterions = ['BG', 'gaussian', 'genie', mismatches[5]]  # mismatches[5] == '1d0_', i.e. a factor of 1.0
fig_1_pt_curves = []

for curve_label, pt in sorted(pt_curves.items()):
    if all([select_criterion in curve_label for select_criterion in select_criterions]):
        fig_1_pt_curves.append(
            {'delta': pt[0],
             'rho': pt[1],
             'yerr': np.vstack([pt_percentiles[curve_label][0, :], pt_percentiles[curve_label][3, :]]),  # 90/10 errorbars
             'label': algorithm_names['_'.join([curve_label.split('_')[k] for k in [0, 1]])]})

# Sort curves for display
fig_1_pt_curves = sorted(fig_1_pt_curves, key=lambda t: t['label'].replace(r'$|\mathbf{A}|^{\circ 2}', '$$'))

# Reference markers
vila_2011_reference_data = scipy.io.loadmat(vila2011_reference_path)
vila_2011_reference_pt_curves = [
    {'delta': vila_2011_reference_data['delta_values'].ravel(),
     'rho': vila_2011_reference_data['transition_rho_values_EMBGAMP'].ravel(),
     'label': 'EM-BG-GAMP (USE)',
     'style': {'marker': '*', 'linestyle': 'None', 'ms': 12, 'zorder': 10, 'color': '#b2abd2'}}
]

# Plot everything
magni.cs.phase_transition.plotting.plot_phase_transitions(
    fig_1_pt_curves, plot_l1=True, errorevery=2, reference_curves=vila_2011_reference_pt_curves)

legend = plt.gca().legend(loc='upper left')
legend.get_frame().set_alpha(0)  # transparent legend frame

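plt.savefig('fig1.pdf')
# Trim the PDF margins with the external pdfcrop tool via an IPython shell escape
# (pdfcrop writes the result to fig1-crop.pdf by default)
_ = ! pdfcrop 'fig1.pdf'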

plt.show()
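The Figure 1 cell keeps a curve only if its label contains every selection criterion as a substring, then looks up its legend text in algorithm_names using the first two underscore-separated fields of the label. A minimal sketch of those two steps follows; the example labels are hypothetical and only assume that the algorithm and problem-suite names come first, as the algorithm_names keys suggest.

```python
def select_labels(labels, criteria):
    """Keep the (sorted) labels that contain every criterion as a substring."""
    return [label for label in sorted(labels) if all(c in label for c in criteria)]


def algorithm_key(label):
    """Join the first two underscore-separated fields of a label."""
    return '_'.join(label.split('_')[:2])


# Hypothetical labels; the real dataset labels may be laid out differently
labels = ['GAMPFullwnBG_USE_gaussian_genie_1d0_',
          'GAMPFullwnBG_USE_gaussian_genie_1d05_',
          'GAMPSAKrzaFrobwnBG_EX3D_gaussian_genie_1d0_']
criteria = ['BG', 'gaussian', 'genie', '1d0_']
selected = select_labels(labels, criteria)
print([algorithm_key(label) for label in selected])
# ['GAMPFullwnBG_USE', 'GAMPSAKrzaFrobwnBG_EX3D']
```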