Description
The introduction page, in Cell 23, has contour plots showing cross-sections of the posterior.
The legend is missing the entries (patches) needed to distinguish between the two guides (SVI Diagonal Normal and SVI MV Normal).
I think there are a couple of ways to solve this.
Approach 1 (using Pandas and `hue=` in Seaborn)
```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Convert each guide's posterior samples (torch tensors) to a DataFrame
svi_samples_df = pd.DataFrame({k: v.detach().cpu().numpy() for k, v in samples.items()})
svi_mvn_samples_df = pd.DataFrame({k: v.detach().cpu().numpy() for k, v in mvn_samples.items()})
# Tag each row with its guide so seaborn can use the column as `hue`
svi_samples_df["Guide"] = "Diagonal Normal"
svi_mvn_samples_df["Guide"] = "Multivariate Normal"
svi_all_df = pd.concat([svi_samples_df, svi_mvn_samples_df], ignore_index=True)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle("Cross-sections of the Posterior Distribution", fontsize=16)
# `fill=True` replaces the deprecated `shade=True` (seaborn >= 0.11)
sns.kdeplot(data=svi_all_df, x="bA", y="bR", hue="Guide", ax=axs[0], fill=True, alpha=0.5)
sns.kdeplot(data=svi_all_df, x="bR", y="bAR", hue="Guide", ax=axs[1], fill=False)
```
This produces the following plot (I shaded the two subplots differently on purpose):
Approach 2 (creating the legend entries manually)
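Approach 1 generalizes to any number of guides. The sketch below is a minimal version of that idea: the helper name `samples_to_long_df` is hypothetical, and the arrays are random stand-ins rather than real guide samples. It stacks one dict of 1-D arrays per guide into a single long-form table, which could then be passed to `sns.kdeplot(data=..., hue="Guide")` as above.

```python
import numpy as np
import pandas as pd


def samples_to_long_df(samples_by_guide, label_col="Guide"):
    """Stack per-guide sample dicts into one long-form DataFrame.

    Hypothetical helper. `samples_by_guide` maps a guide label to a dict of
    1-D NumPy arrays (one per latent site, all of equal length).
    """
    frames = []
    for label, samples in samples_by_guide.items():
        df = pd.DataFrame(samples)
        df[label_col] = label  # same label on every row of this guide
        frames.append(df)
    # ignore_index gives a clean 0..n-1 index instead of repeated row labels
    return pd.concat(frames, ignore_index=True)


# Synthetic stand-ins for the converted guide samples (illustration only)
rng = np.random.default_rng(0)
fake_diag = {"bA": rng.normal(size=5), "bR": rng.normal(size=5)}
fake_mvn = {"bA": rng.normal(size=3), "bR": rng.normal(size=3)}
long_df = samples_to_long_df({"Diagonal Normal": fake_diag, "Multivariate Normal": fake_mvn})
print(long_df["Guide"].value_counts().to_dict())
```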
```python
import matplotlib.pyplot as plt
import seaborn as sns

# Convert each guide's posterior samples (torch tensors) to NumPy arrays
svi_samples = {k: v.detach().cpu().numpy() for k, v in samples.items()}
svi_mvn_samples = {k: v.detach().cpu().numpy() for k, v in mvn_samples.items()}

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle("Cross-sections of the Posterior Distribution", fontsize=16)
# Diagonal Normal as unfilled contours, MVN as filled contours
sns.kdeplot(x=svi_samples["bA"], y=svi_samples["bR"], ax=axs[0])
sns.kdeplot(x=svi_mvn_samples["bA"], y=svi_mvn_samples["bR"], ax=axs[0], fill=True)
axs[0].set(xlabel="bA", ylabel="bR", xlim=(-2.5, -1.2), ylim=(-0.5, 0.1))
sns.kdeplot(x=svi_samples["bR"], y=svi_samples["bAR"], ax=axs[1])
sns.kdeplot(x=svi_mvn_samples["bR"], y=svi_mvn_samples["bAR"], ax=axs[1], fill=True)
axs[1].set(xlabel="bR", ylabel="bAR", xlim=(-0.45, 0.05), ylim=(-0.15, 0.8))

# The contours carry no labels, so add empty labeled lines in matching colors;
# fig.legend() collects them into a single figure-level legend
for label, color in zip(["SVI (Diagonal Normal)", "SVI (Multivariate Normal)"], sns.color_palette()[:2]):
    plt.plot([], [], label=label, color=color)
fig.legend(loc="upper right")
```
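A variant of Approach 2, as a sketch: instead of drawing empty lines with `plt.plot([], [])` (which lands on whichever axes is current, here `axs[1]`), the legend entries can be built as `matplotlib.lines.Line2D` proxy artists and passed to `fig.legend(handles=...)`, so nothing is added to any axes. The colors `"C0"`/`"C1"` are an assumption: they are the first two entries of matplotlib's default color cycle, which the unlabeled `kdeplot` calls should pick up by default; passing `color=` to each `kdeplot` call would make the match explicit.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

labels = ["SVI (Diagonal Normal)", "SVI (Multivariate Normal)"]
colors = ["C0", "C1"]  # assumed: first two colors of the default cycle

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
# ... the sns.kdeplot calls from Approach 2 would go here ...

# Proxy artists: Line2D objects never added to an axes, used only
# to define how each legend entry looks
handles = [Line2D([], [], color=col, label=lab) for lab, col in zip(labels, colors)]
legend = fig.legend(handles=handles, loc="upper right")
```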
Bandwidth adjustment
Another thing to consider is the bandwidth adjustment parameter of `kdeplot` (`bw_adjust`). Here is the plot with `bw_adjust=4`:
```python
import matplotlib.pyplot as plt
import seaborn as sns

# Convert each guide's posterior samples (torch tensors) to NumPy arrays
svi_samples = {k: v.detach().cpu().numpy() for k, v in samples.items()}
svi_mvn_samples = {k: v.detach().cpu().numpy() for k, v in mvn_samples.items()}

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle("Cross-sections of the Posterior Distribution", fontsize=16)
# bw_adjust=4 widens the KDE bandwidth 4x relative to the default, smoothing the contours
sns.kdeplot(x=svi_samples["bA"], y=svi_samples["bR"], ax=axs[0], bw_adjust=4)
sns.kdeplot(x=svi_mvn_samples["bA"], y=svi_mvn_samples["bR"], ax=axs[0], fill=True, bw_adjust=4)
axs[0].set(xlabel="bA", ylabel="bR", xlim=(-2.5, -1.2), ylim=(-0.5, 0.1))
sns.kdeplot(x=svi_samples["bR"], y=svi_samples["bAR"], ax=axs[1], bw_adjust=4)
sns.kdeplot(x=svi_mvn_samples["bR"], y=svi_mvn_samples["bAR"], ax=axs[1], fill=True, bw_adjust=4)
axs[1].set(xlabel="bR", ylabel="bAR", xlim=(-0.45, 0.05), ylim=(-0.15, 0.8))

# Empty labeled lines act as legend entries for the two guides
for label, color in zip(["SVI (Diagonal Normal)", "SVI (Multivariate Normal)"], sns.color_palette()[:2]):
    plt.plot([], [], label=label, color=color)
fig.legend(loc="upper right")
```
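To my understanding, seaborn's `kdeplot` fits a `scipy.stats.gaussian_kde` with the default bandwidth rule (Scott's) and then multiplies the resulting bandwidth factor by `bw_adjust`, so `bw_adjust=4` means kernels four times wider than the default and therefore smoother contours. A minimal sketch of that scaling using SciPy directly, with random stand-in points rather than the real posterior samples:

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
# Synthetic 2-D stand-in for a (bA, bR) sample cloud; gaussian_kde wants shape (d, n)
points = rng.normal(size=(2, 800))

bw_adjust = 4
kde = gaussian_kde(points, bw_method="scott")
scott_factor = kde.factor  # Scott's rule for unweighted data: n ** (-1 / (d + 4))
kde.set_bandwidth(kde.factor * bw_adjust)  # scale the factor, as bw_adjust does

print(f"Scott factor: {scott_factor:.4f}, adjusted: {kde.factor:.4f}")
```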
The benefit of `bw_adjust` might be that it makes the contours smoother. That would make it easier to see that the Diagonal Normal is indeed axis-aligned, while the MVN has non-zero covariance terms and thus is not axis-aligned.
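As a numeric complement to the visual check, one could compare the sample correlation of two sites under each guide, e.g. `np.corrcoef(svi_samples["bR"], svi_samples["bAR"])` against the MVN version. The sketch below uses synthetic Gaussian samples (not the tutorial's) and a hypothetical helper `sample_corr`: independent coordinates, as a diagonal covariance implies, give a correlation near zero up to sampling noise, while a full covariance with an off-diagonal term gives a clearly non-zero one, which is what tilts the contours off the axes.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
# Synthetic stand-in for a diagonal-covariance guide: independent coordinates
diag = rng.normal(size=(n, 2)) * np.array([0.2, 0.1])
# Synthetic stand-in for a full-covariance guide; true correlation is
# -0.015 / (0.2 * 0.1) = -0.75
cov = np.array([[0.04, -0.015], [-0.015, 0.01]])
mvn = rng.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=n)


def sample_corr(xy):
    """Pearson correlation between the two columns of an (n, 2) array (hypothetical helper)."""
    return np.corrcoef(xy[:, 0], xy[:, 1])[0, 1]


print(f"diagonal: {sample_corr(diag):+.3f}, full: {sample_corr(mvn):+.3f}")
```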
Let me know if you'd like me to make a PR and, if so, with which approach (1 or 2) and with or without bandwidth adjustment.
I'm assuming I'd have to modify this notebook. Would the PR need anything else?