import tdg, torch
import matplotlib.pyplot as plt

# One figure with a single main axes; the main panel shows the energy with the
# mean-field part subtracted, the inset shows the full energy.
fig, ax = plt.subplots(figsize=(8, 6))

# Hand the main axes to the reference module's comparison routine before
# adding our own curve.
tdg.references.EPJST217153.energy_comparison(ax)

alpha_max = 0.9
x_limits = (-alpha_max, alpha_max)
alpha = torch.linspace(-alpha_max, alpha_max, 1000)

# Quadratic term of eq (2): (3 - 4 ln 2) (α/2)^2.  It is used twice below, once
# on its own and once as part of the full expression.
second_order = (3 - 4 * torch.tensor(2.).log()) * (alpha / 2) ** 2

# Main panel: only the mean-field-subtracted piece of eq (2).
ax.plot(alpha, second_order, color='orange', label='eq. (2)')
ax.set(
    xlim=x_limits,
    ylim=(-0.025, 0.15),
    xlabel='α = 2g',
    ylabel='(E/N - MF)/E_FG',
)

# The inset rectangle [x0, y0, width, height] is given in the data coordinates
# of the main axes because of transform=ax.transData.
inset = ax.inset_axes([0, 0.06, 0.8, 0.08], transform=ax.transData)
blue, gray = tdg.references.EPJST217153.conventional_figure_2()
# Each data set is indexed as [0] -> x, [3] -> y, [4] -> y error bar.
# NOTE(review): the meaning of the remaining entries is not visible here.
for dataset, color, marker in ((blue, 'blue', 'v'), (gray, 'gray', 'o')):
    inset.errorbar(
        dataset[0],
        dataset[3],
        yerr=dataset[4],
        color=color,
        marker=marker,
        linestyle='none',
    )
inset.set(xlim=x_limits, ylabel='E/N / E_FG')
# Inset: the complete eq (2), i.e. 1 + α plus the quadratic term.
inset.plot(alpha, 1 + alpha + second_order, color='orange')

ax.legend()