import matplotlib.pyplot as plt
import numpy as np

from kelp import Model, Planet, Filter

# Seed NumPy's global legacy RNG so the synthetic noise (and any later
# global-RNG draws in this script) are reproducible.
np.random.seed(42)

# Target system and bandpass. bin_down(10) presumably coarsens the filter's
# wavelength grid by a factor of 10 — confirm in the kelp Filter docs.
planet = Planet.from_name('HD 189733')
filt = Filter.from_name("IRAC 1")
filt.bin_down(10)

# Orbital phases at which the synthetic observations are sampled.
xi = np.linspace(-np.pi, np.pi, 50)

# "True" model parameters used to generate the synthetic data set.
hotspot_offset = np.radians(-40)   # hotspot offset, in radians
alpha = 0.6
omega = 4.5                        # passed to Model as omega_drag
f = 2**-0.5
A_B = 0
lmax = 1
C = [[0],
     [0, 0.15, 0]]                 # only C[1][1] is non-zero

# Standard deviation of the Gaussian noise added to every phase point,
# in the same units as the model flux.
obs_err = 50

model = Model(hotspot_offset=hotspot_offset, alpha=alpha, omega_drag=omega,
              A_B=A_B, C_ml=C, lmax=lmax,
              planet=planet, filt=filt)
obs = model.thermal_phase_curve(xi, f=f).flux
# Perturb the noiseless model in place with white Gaussian noise.
obs += obs_err * np.random.randn(len(xi))

def pc_model(p, x):
    """
    Phase curve model with three free parameters.

    Parameters
    ----------
    p : sequence of float
        ``(offset, c_11, f)``: the hotspot offset in radians, the ``C[1][1]``
        entry of the ``C_ml`` coefficient array, and the ``f`` factor passed
        to ``thermal_phase_curve``.
    x : numpy.ndarray
        Orbital phases at which to evaluate the model.

    Returns
    -------
    numpy.ndarray
        The ``flux`` attribute of the kelp phase curve evaluated at ``x``.

    Notes
    -----
    ``alpha``, ``omega``, ``A_B``, ``planet`` and ``filt`` are read from the
    module-level values used to generate the synthetic data.
    """
    offset, c_11, f = p
    # Same coefficient layout as the data-generating model; only C[1][1]
    # is free, every other entry is held at zero.
    C = [[0],
         [0, c_11, 0]]
    phase_model = Model(hotspot_offset=offset, alpha=alpha,
                        omega_drag=omega, A_B=A_B, C_ml=C, lmax=1,
                        planet=planet, filt=filt)
    # check_sorted=False presumably skips kelp's ordering check on x; the
    # phase grid used in this script (np.linspace) is already ascending.
    return phase_model.thermal_phase_curve(x, f=f, check_sorted=False).flux

def lnprior(p):
    """
    Log-prior: flat within bounds on every fitting parameter.

    Parameters
    ----------
    p : sequence of float
        ``(offset, c_11, f)``.

    Returns
    -------
    float
        ``0`` when ``-pi <= offset <= pi``, ``0 <= c_11 <= 1`` and ``f > 0``;
        ``-inf`` otherwise. NaN in any parameter also yields ``-inf``,
        because every comparison against NaN is False.
    """
    offset, c_11, f = p

    # f was previously unbounded. It is required to be positive here,
    # assuming it acts as a positive scale factor in kelp's temperature
    # model (true value 2**-0.5) — confirm against the kelp docs.
    in_bounds = (-np.pi <= offset <= np.pi
                 and 0 <= c_11 <= 1
                 and f > 0)
    if not in_bounds:
        return -np.inf

    return 0

def lnlike(p, x, y, yerr):
    """
    Gaussian log-likelihood, up to an additive constant: ``-chi^2 / 2``.

    ``chi^2`` is the sum of squared residuals between ``pc_model(p, x)`` and
    the observations ``y``, each divided by ``yerr**2``. ``yerr`` may be a
    scalar or an array that broadcasts against ``y``.
    """
    residuals = pc_model(p, x) - y
    chi2 = np.sum(residuals**2 / yerr**2)
    return -0.5 * chi2

def lnprob(p, x, y, yerr):
    """
    Log-posterior (unnormalised): ``lnprior(p) + lnlike(p, x, y, yerr)``.

    Returns ``-inf`` without evaluating the likelihood when the prior is
    not finite.
    """
    log_prior = lnprior(p)
    if not np.isfinite(log_prior):
        # Outside the prior support: skip the model evaluation entirely.
        return -np.inf
    return log_prior + lnlike(p, x, y, yerr)


from scipy.optimize import minimize

# Starting guess for (offset, c_11, f).
initp = np.array([-0.7, 0.2, 2**-0.5])

# Box bounds for (offset, c_11, f); each lies inside the support of lnprior
# and contains initp. Previously defined but never passed to minimize.
bounds = [[-2, 0], [0.0, 1], [0.5, 0.85]]

# Maximum a posteriori estimate: minimise the negative log-posterior.
# Powell is derivative-free and honours `bounds` (SciPy >= 1.5).
soln = minimize(lambda *args: -lnprob(*args),
                initp, args=(xi, obs, obs_err),
                method='powell', bounds=bounds)

from emcee import EnsembleSampler
from multiprocessing import Pool
from corner import corner

# One dimension per fitted parameter: (offset, c_11, f).
ndim = len(soln.x)
nwalkers = 2 * ndim

# Start the walkers in a Gaussian ball (sigma = 0.1) around the MAP solution.
# A single (nwalkers, ndim) draw consumes the global RNG exactly as one
# randn(ndim) call per walker would.
p0 = soln.x + 0.1 * np.random.randn(nwalkers, ndim)

sampler = EnsembleSampler(nwalkers, ndim, lnprob,
                          args=(xi, obs, obs_err))
# Burn-in: keep only the final walker state, then discard the samples.
p1 = sampler.run_mcmc(p0, 100)
sampler.reset()
# Production run, continuing from the end of the burn-in.
sampler.run_mcmc(p1, 100, progress=True)

# emcee 3 API; `flatchain` / `flatlnprobability` are deprecated aliases.
flat_samples = sampler.get_chain(flat=True)
flat_log_prob = sampler.get_log_prob(flat=True)

# Raw strings: '\D' and '\p' are invalid escape sequences in plain literals.
corner(flat_samples, truths=[hotspot_offset, C[1][1], f],
       labels=[r'$\Delta \phi$', r'$C_{11}$', r'$f$'])
plt.show()

# Highest-log-posterior sample in the chain.
p_map = flat_samples[np.argmax(flat_log_prob)]

errkwargs = dict(color='k', fmt='.', ecolor='silver')
plt.errorbar(xi/np.pi, obs, obs_err, **errkwargs)
plt.plot(xi/np.pi, pc_model(p_map, xi), color='r')
plt.xlabel('$\\xi/\\pi$')
plt.ylabel('$\\rm F_p/F_s$')
plt.tight_layout()
plt.show()