
Introduction article: "Cross-sections of the Posterior" plot has legend without patches #3016

@nipunbatra

Description


Cell 23 of the introduction page has contour plots showing cross-sections of the posterior.

[Image: intro_long_71_0, the current cross-section plot]

The legend does not have the patches needed to distinguish between the two guides (SVI Diagonal Normal and SVI MV Normal).

I think there are a couple of ways to solve this.

Approach 1: Use pandas and hue= in Seaborn

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# `samples` and `mvn_samples` are the dicts of posterior samples (tensors)
# computed earlier in the notebook for the two guides.
svi_samples_df = pd.DataFrame({k: v.detach().cpu().numpy() for k, v in samples.items()})
svi_mvn_samples_df = pd.DataFrame({k: v.detach().cpu().numpy() for k, v in mvn_samples.items()})

# Tag each row with the guide that produced it so seaborn can split on it via `hue`.
svi_samples_df['Guide'] = 'Diagonal Normal'
svi_mvn_samples_df['Guide'] = 'Multivariate Normal'

svi_all_df = pd.concat([svi_samples_df, svi_mvn_samples_df], ignore_index=True)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle("Cross-sections of the Posterior Distribution", fontsize=16)
# `fill` replaces the deprecated `shade` argument of sns.kdeplot.
sns.kdeplot(data=svi_all_df, x="bA", y="bR", ax=axs[0], hue='Guide', fill=True, alpha=0.5)
sns.kdeplot(data=svi_all_df, x="bR", y="bAR", ax=axs[1], hue='Guide', fill=False)
```

This produces the following (I shaded the two subplots differently on purpose):

[Image 1: output of Approach 1]
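The key requirement for `hue=` is long-form data: one row per sample plus a categorical column naming the guide. Here is a minimal sketch of that reshaping, using synthetic NumPy draws as a stand-in for the notebook's `samples` and `mvn_samples` (the `fake_draws` and `to_long_form` helpers are hypothetical), so it runs with only pandas and NumPy.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)


def fake_draws(n):
    """Hypothetical stand-in for one guide's posterior samples (plain arrays)."""
    return {name: rng.normal(size=n) for name in ("bA", "bR", "bAR")}


def to_long_form(draws_by_guide):
    """Stack per-guide sample dicts into one DataFrame with a 'Guide' column."""
    frames = []
    for guide, draws in draws_by_guide.items():
        df = pd.DataFrame(draws)
        df["Guide"] = guide  # the column seaborn splits on via hue='Guide'
        frames.append(df)
    return pd.concat(frames, ignore_index=True)


svi_all_df = to_long_form({
    "Diagonal Normal": fake_draws(800),
    "Multivariate Normal": fake_draws(800),
})
print(svi_all_df.groupby("Guide").size())
```

Passing a frame shaped like this as `data=` with `hue='Guide'` is what lets `sns.kdeplot` draw one density per guide and build the legend itself.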

Approach 2: Create proxy handles for the legend

```python
import matplotlib.pyplot as plt
import seaborn as sns

svi_samples = {k: v.detach().cpu().numpy() for k, v in samples.items()}
svi_mvn_samples = {k: v.detach().cpu().numpy() for k, v in mvn_samples.items()}

# Fix one color per guide so the legend proxies below match the contours.
colors = sns.color_palette()[:2]

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle("Cross-sections of the Posterior Distribution", fontsize=16)
sns.kdeplot(x=svi_samples["bA"], y=svi_samples["bR"], ax=axs[0], color=colors[0])
sns.kdeplot(x=svi_mvn_samples["bA"], y=svi_mvn_samples["bR"], ax=axs[0], color=colors[1], fill=True)
axs[0].set(xlabel="bA", ylabel="bR", xlim=(-2.5, -1.2), ylim=(-0.5, 0.1))

sns.kdeplot(x=svi_samples["bR"], y=svi_samples["bAR"], ax=axs[1], color=colors[0])
sns.kdeplot(x=svi_mvn_samples["bR"], y=svi_mvn_samples["bAR"], ax=axs[1], color=colors[1], fill=True)
axs[1].set(xlabel="bR", ylabel="bAR", xlim=(-0.45, 0.05), ylim=(-0.15, 0.8))

# The contours add no legend entries, so plot empty lines that carry the labels.
for label, color in zip(["SVI (Diagonal Normal)", "SVI (Multivariate Normal)"], colors):
    axs[1].plot([], [], label=label, color=color)
fig.legend(loc='upper right')
```

[Image 2: output of Approach 2]
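Since the title asks for patches, a variant of Approach 2 is to build `matplotlib.patches.Patch` proxies and pass them to `fig.legend(handles=...)` instead of plotting empty lines. The sketch below swaps `sns.kdeplot` for plain `ax.contour`/`ax.contourf` on two synthetic Gaussian densities so it runs without seaborn or Pyro; the covariance values and contour levels are made up for illustration.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch


def gaussian_density(xx, yy, cov):
    """Zero-mean bivariate normal density on a grid (synthetic stand-in for a KDE)."""
    inv = np.linalg.inv(cov)
    pts = np.stack([xx, yy], axis=-1)
    mahal = np.einsum("...i,ij,...j->...", pts, inv, pts)
    return np.exp(-0.5 * mahal) / (2 * np.pi * np.sqrt(np.linalg.det(cov)))


xx, yy = np.meshgrid(np.linspace(-3, 3, 101), np.linspace(-3, 3, 101))
labels = ["SVI (Diagonal Normal)", "SVI (Multivariate Normal)"]
colors = ["C0", "C1"]

fig, ax = plt.subplots(figsize=(6, 6))
# Unfilled contours for the diagonal guide, filled ones for the full-covariance guide.
ax.contour(xx, yy, gaussian_density(xx, yy, np.diag([1.0, 0.3])), colors=colors[0])
ax.contourf(xx, yy, gaussian_density(xx, yy, np.array([[1.0, 0.6], [0.6, 1.0]])),
            levels=[0.02, 0.08, 0.14, 0.2], colors=colors[1], alpha=0.4)

# Contour sets contribute no labeled handles, so supply one Patch per guide.
handles = [Patch(facecolor=c, edgecolor=c, alpha=0.5, label=lab)
           for c, lab in zip(colors, labels)]
legend = fig.legend(handles=handles, loc="upper right")
```

In the notebook, the same `handles` list could be passed to `fig.legend` after the four `sns.kdeplot` calls, with `colors` taken from `sns.color_palette()[:2]` and passed to each call as `color=`.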

Bandwidth adjustment

Another thing to consider might be the bandwidth adjustment parameter of kdeplot, bw_adjust. Here is the plot with bw_adjust=4:

[Image 3: Approach 2 with bw_adjust=4]

```python
import matplotlib.pyplot as plt
import seaborn as sns

svi_samples = {k: v.detach().cpu().numpy() for k, v in samples.items()}
svi_mvn_samples = {k: v.detach().cpu().numpy() for k, v in mvn_samples.items()}

# Fix one color per guide so the legend proxies below match the contours.
colors = sns.color_palette()[:2]

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle("Cross-sections of the Posterior Distribution", fontsize=16)
# bw_adjust=4 scales the default KDE bandwidth by 4, giving smoother contours.
sns.kdeplot(x=svi_samples["bA"], y=svi_samples["bR"], ax=axs[0], color=colors[0], bw_adjust=4)
sns.kdeplot(x=svi_mvn_samples["bA"], y=svi_mvn_samples["bR"], ax=axs[0], color=colors[1], fill=True, bw_adjust=4)
axs[0].set(xlabel="bA", ylabel="bR", xlim=(-2.5, -1.2), ylim=(-0.5, 0.1))

sns.kdeplot(x=svi_samples["bR"], y=svi_samples["bAR"], ax=axs[1], color=colors[0], bw_adjust=4)
sns.kdeplot(x=svi_mvn_samples["bR"], y=svi_mvn_samples["bAR"], ax=axs[1], color=colors[1], fill=True, bw_adjust=4)
axs[1].set(xlabel="bR", ylabel="bAR", xlim=(-0.45, 0.05), ylim=(-0.15, 0.8))

# The contours add no legend entries, so plot empty lines that carry the labels.
for label, color in zip(["SVI (Diagonal Normal)", "SVI (Multivariate Normal)"], colors):
    axs[1].plot([], [], label=label, color=color)
fig.legend(loc='upper right')
```
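Per the seaborn docs, bw_adjust multiplicatively scales the bandwidth chosen by bw_method (Scott's rule by default, a factor of n^(-1/(d+4)) for n samples in d dimensions), so bw_adjust=4 makes each Gaussian kernel four times wider. Here is a NumPy-only sketch of a 2-D Gaussian KDE under that assumption, modeled on SciPy's `gaussian_kde` (kernel covariance = sample covariance times factor squared), with synthetic draws standing in for the posterior samples. It illustrates the smoothing effect and is not seaborn's actual code.

```python
import numpy as np


def kde_2d(samples, points, bw_adjust=1.0):
    """Gaussian KDE of `samples` (n, d) evaluated at `points` (m, d).

    Assumption: bandwidth factor = Scott's rule * bw_adjust, and the kernel
    covariance is the sample covariance scaled by factor**2.
    """
    n, d = samples.shape
    factor = n ** (-1.0 / (d + 4)) * bw_adjust
    kernel_cov = np.cov(samples, rowvar=False) * factor**2
    inv = np.linalg.inv(kernel_cov)
    norm = 1.0 / np.sqrt((2 * np.pi) ** d * np.linalg.det(kernel_cov))
    diff = points[:, None, :] - samples[None, :, :]  # shape (m, n, d)
    mahal = np.einsum("mnd,de,mne->mn", diff, inv, diff)
    return norm * np.exp(-0.5 * mahal).mean(axis=1)


rng = np.random.default_rng(0)
samples = rng.normal(size=(200, 2))  # synthetic stand-in for posterior draws
step = 0.2
axis = np.arange(-8, 8 + step / 2, step)
grid = np.stack(np.meshgrid(axis, axis), axis=-1).reshape(-1, 2)

dens_1 = kde_2d(samples, grid, bw_adjust=1)
dens_4 = kde_2d(samples, grid, bw_adjust=4)
# Both integrate to about 1, but bw_adjust=4 spreads the mass and lowers the peak.
print(dens_1.sum() * step**2, dens_4.sum() * step**2)
print(dens_1.max(), dens_4.max())
```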

The benefit of bw_adjust might be smoother contours. That would make it easier to see that the Diagonal Normal guide is indeed axis-aligned, while the MVN guide has non-zero covariance terms and is therefore not axis-aligned.
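The axis-alignment point can be checked directly on covariance matrices: the major axis of a Gaussian's elliptical contours is the eigenvector with the largest eigenvalue, which lies along a coordinate axis for a diagonal covariance and tilts once the off-diagonal term is non-zero. The matrices below are illustrative values, not the fitted guide parameters, and `principal_axis_angle` is a hypothetical helper.

```python
import numpy as np


def principal_axis_angle(cov):
    """Angle in degrees, in [0, 180), of the major axis of a 2-D Gaussian's contours."""
    _, eigvecs = np.linalg.eigh(cov)  # eigenvalues come back in ascending order
    major = eigvecs[:, -1]            # eigenvector of the largest eigenvalue
    return float(np.degrees(np.arctan2(major[1], major[0])) % 180.0)


diag_cov = np.diag([1.0, 0.25])                # Diagonal Normal style: no covariance terms
full_cov = np.array([[1.0, 0.8], [0.8, 1.0]])  # MVN style: non-zero covariance

print(principal_axis_angle(diag_cov))  # along the x-axis
print(principal_axis_angle(full_cov))  # tilted away from both axes
```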

Let me know if you'd like me to make a PR and, if so, with which approach (1 or 2) and with or without bandwidth adjustment.

I'm assuming I'd have to modify this notebook. Would the PR need anything else?
