import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import binned_statistic

import theano.tensor as tt
import exoplanet as xo
import pymc3 as pm
import pymc3_ext as pmx

from corner import corner

from lightkurve import search_lightcurve

import astropy.units as u
from astropy.constants import R_sun, R_earth
from astropy.stats import sigma_clip, mad_std

from kelp import Filter
from kelp.theano import reflected_phase_curve, thermal_phase_curve

# dtype used for the angular grids of the thermal model
floatX = 'float64'

# --- Ephemeris (times in JD / days, matching the light curve's `time.jd`) ---
t0 = 2454954.357462     # Bonomo 2017
period = 2.204740       # Stassun 2017
duration = 4.0398 / 24  # Holczer 2016 (divided by 24; presumably hours -> days)
b = 0.4960              # Esteves 2015

# --- Star ---
rstar = 1.991 * R_sun              # Berger 2017
rho_star = 0.27 * u.g / u.cm ** 3  # Stassun 2017
T_s = 6449                         # Berger 2018 (presumably T_eff in K)

# --- Planet ---
rp = 16.9 * R_earth  # Stassun 2017
a = 4.13 * rstar     # Stassun 2017

# --- Derived dimensionless quantities (plain Python floats) ---
a_rs = float(a / rstar)       # semimajor axis / stellar radius
a_rp = float(a / rp)          # semimajor axis / planet radius
rp_rstar = float(rp / rstar)  # planet radius / stellar radius

# Half of the transit duration expressed as a fraction of the orbital period
eclipse_half_dur = (duration / period) / 2

def _phase_masks(times_jd):
    """Compute orbital phases and event masks for an array of JD timestamps.

    Uses the module-level ephemeris (``t0``, ``period``) and
    ``eclipse_half_dur``.

    Parameters
    ----------
    times_jd : array-like
        Observation times as Julian dates.

    Returns
    -------
    phases : array
        Orbital phase in [0, 1), with phase 0 at ``t0``.
    in_eclipse : boolean array
        True within ``eclipse_half_dur`` of phase 0.5.
    in_transit : boolean array
        True within 1.5 * ``eclipse_half_dur`` of phase 0 (wrapping at 1).
    out_of_transit : boolean array
        Logical negation of ``in_transit``.
    """
    orbital_phase = ((times_jd - t0) % period) / period
    eclipse_mask = np.abs(orbital_phase - 0.5) < eclipse_half_dur
    # The transit straddles phase 0, so test both ends of the [0, 1) range;
    # the factor 1.5 pads the window beyond the nominal half duration.
    transit_mask = (orbital_phase < 1.5 * eclipse_half_dur) | (
        orbital_phase > 1 - 1.5 * eclipse_half_dur)
    return (orbital_phase, eclipse_mask, transit_mask,
            np.logical_not(transit_mask))


# Download the Kepler long-cadence light curve(s) for quarter 10.
lcf = search_lightcurve(
    "HAT-P-7", mission="Kepler", cadence="long", quarter=10
).download_all()

# Fail with a clear message instead of an AttributeError on `None` below
# (download_all() is expected to return None for an empty search result).
if lcf is None:
    raise RuntimeError(
        "No Kepler long-cadence light curves found for HAT-P-7 (quarter 10)"
    )

# Combine the downloaded light curves into a single one.
slc = lcf.stitch()

phases, in_eclipse, in_transit, out_of_transit = _phase_masks(slc.time.jd)

# Flatten the light curve, passing the in-transit cadences as the mask,
# then drop NaN cadences.
slc = slc.flatten(
    polyorder=3, break_tolerance=10, window_length=1001, mask=in_transit
).remove_nans()

# `slc` was replaced above (and may have lost cadences), so recompute the
# phases and masks to keep them aligned with the current light curve.
phases, in_eclipse, in_transit, out_of_transit = _phase_masks(slc.time.jd)

# Sigma-clip the out-of-transit fluxes (8 sigma, MAD-based scatter estimate).
oot_flux = np.ascontiguousarray(slc.flux[out_of_transit], dtype=np.float64)
sc = sigma_clip(oot_flux, sigma=8, maxiters=100, stdfunc=mad_std)

# Boolean selector for the out-of-transit points that survived clipping.
keep = ~sc.mask
clipped_flux = sc[keep].data

# Phases and times of the surviving out-of-transit points.
phase = np.ascontiguousarray(phases[out_of_transit][keep], dtype=np.float64)
time = np.ascontiguousarray(
    slc.time.jd[out_of_transit][keep], dtype=np.float64
)

bin_in_eclipse = np.abs(phase - 0.5) < eclipse_half_dur

# Mean of the surviving fluxes, and its offset from unity in ppm.
unbinned_flux_mean = np.mean(clipped_flux)
unbinned_flux_mean_ppm = 1e6 * (unbinned_flux_mean - 1)

# Fluxes relative to their mean, expressed in ppm.
flux_normed = np.ascontiguousarray(
    1e6 * (clipped_flux / unbinned_flux_mean - 1.0), dtype=np.float64
)
# Matching flux uncertainties scaled by 1e6 (ppm).
flux_normed_err = np.ascontiguousarray(
    1e6 * slc.flux_err[out_of_transit][keep].value, dtype=np.float64
)

# Number of equal-width phase bins
bins = 100

# Median flux (ppm) in each phase bin
bs = binned_statistic(phase, flux_normed, statistic=np.median, bins=bins)

# Per-bin uncertainty: 3 x median point error / sqrt(number of points in bin)
bs_err = binned_statistic(
    phase,
    flux_normed_err,
    statistic=lambda x: 3 * np.median(x) / len(x) ** 0.5,
    bins=bins,
)

# Bin centers are the midpoints of consecutive bin edges
bin_edges = bs.bin_edges
binphase = 0.5 * (bin_edges[1:] + bin_edges[:-1])

# Zero-point the binned fluxes: subtract the median of the bins whose centers
# lie within 0.01 of phase 0.5 (the in-eclipse level)
near_mid_eclipse = np.abs(binphase - 0.5) < 0.01
binflux = bs.statistic - np.median(bs.statistic[near_mid_eclipse])
binerror = bs_err.statistic

# Load the Kepler bandpass
filt = Filter.from_name("Kepler")

# Bin the filter curve down in place; this speeds up integration by
# orders of magnitude
filt.bin_down(6)

# Plain arrays for the thermal model: wavelengths in meters and the
# corresponding transmittance
filt_wavelength = filt.wavelength.to(u.m).value
filt_trans = filt.transmittance

# Build the PyMC3 model of the binned phase curve. The only free parameters
# are `omega`, `g`, `ellip_amp` and `doppler_amp`; the eclipse shape and the
# thermal component are computed from fixed values.
with pm.Model() as model:
    # Define a Keplerian orbit using `exoplanet`:
    # t0=0 because the model times below are `binphase * period`, i.e. time
    # elapsed since the transit midpoint.
    orbit = xo.orbits.KeplerianOrbit(
        period=period, t0=0, b=b, rho_star=rho_star.to(u.g / u.cm ** 3),
        r_star=float(rstar / R_sun)
    )

    # Compute the eclipse model (no limb-darkening):
    # both limb-darkening coefficients are zero, and the exposure time is
    # 30 minutes converted to days.
    # NOTE(review): `_flip` is a private `exoplanet` method; presumably it
    # swaps the roles of star and planet so the star occults the planet --
    # confirm against the exoplanet source.
    eclipse_light_curves = xo.LimbDarkLightCurve([0, 0]).get_light_curve(
        orbit=orbit._flip(rp_rstar), r=orbit.r_star,
        t=binphase * period,
        texp=(30 * u.min).to(u.d).value
    )

    # Normalize the eclipse model to unity out of eclipse and
    # zero in-eclipse
    # (the light curves are summed over their last axis and offset by 1)
    eclipse = 1 + pm.math.sum(eclipse_light_curves, axis=-1)

    # Define reflected light phase curve model according to
    # Heng, Morris & Kitzmann (2021)
    # Priors: `omega` uniform on [0, 1]; `g` normal (mu=0, sigma=0.4)
    # truncated to [0, 1].
    omega = pm.Uniform('omega', lower=0, upper=1)
    g = pm.TruncatedNormal('g', lower=0, upper=1, mu=0, sigma=0.4)

    # Returns the reflected-light curve plus two derived quantities that are
    # recorded below as the deterministics 'A_g' and 'q'.
    reflected_ppm, A_g, q = reflected_phase_curve(binphase, omega, g, a_rp)

    # Define the ellipsoidal variation parameterization (simple sinusoid)
    # Two cycles per orbit, offset so the term is zero at phase 0.5 and
    # ranges over [0, 2 * ellipsoidal_amp].
    ellipsoidal_amp = pm.Uniform('ellip_amp', lower=0, upper=50)
    ellipsoidal_model_ppm = - ellipsoidal_amp * tt.cos(
        4 * np.pi * (binphase - 0.5)) + ellipsoidal_amp

    # Define the doppler variation parameterization (simple sinusoid)
    # One cycle per orbit, zero at phases 0 and 0.5.
    doppler_amp = pm.Uniform('doppler_amp', lower=0, upper=50)
    doppler_model_ppm = doppler_amp * tt.sin(
        2 * np.pi * binphase)

    # Define the thermal emission model according to description in
    # Morris et al. (in prep)
    # `xi` maps phase 0.5 to zero and spans (-pi, pi) over one orbit.
    xi = 2 * np.pi * (binphase - 0.5)
    # Angular grid for the thermal model: 75 phi samples over [-2 pi, 2 pi]
    # and 5 theta samples over [0, pi]; the meshgrid outputs have shape
    # (n_phi, n_theta).
    n_phi = 75
    n_theta = 5
    phi = np.linspace(-2 * np.pi, 2 * np.pi, n_phi, dtype=floatX)
    theta = np.linspace(0, np.pi, n_theta, dtype=floatX)
    theta2d, phi2d = np.meshgrid(theta, phi)

    # Fixed (not sampled) thermal-model inputs.
    ln_C_11_kepler = -2.6
    C_11_kepler = tt.exp(ln_C_11_kepler)
    # NOTE(review): `hml_eps` and `hml_f` are not referenced anywhere else in
    # this script; the final argument of `thermal_phase_curve` below is the
    # literal 2 ** -0.5 instead.
    hml_eps = 0.72
    hml_f = (2/3 - hml_eps * 5 / 12) ** 0.25
    delta_phi = 0

    A_B = 0.5

    # Compute the thermal phase curve with zero phase offset
    # `1 / a_rp` is the planet radius divided by the semimajor axis.
    # NOTE(review): the positional literals 4.5, 0.575 and 2 ** -0.5 are
    # undocumented here -- confirm their meaning against the signature of
    # `kelp.theano.thermal_phase_curve`.
    thermal, T = thermal_phase_curve(
        xi, delta_phi, 4.5, 0.575, C_11_kepler, T_s, a_rs, 1 / a_rp, A_B,
        theta2d, phi2d, filt_wavelength, filt_trans, 2 ** -0.5
    )

    # Define the composite phase curve model
    # Sum of all components (thermal scaled by 1e6 to match the ppm terms),
    # multiplied by the eclipse model so it goes to zero in eclipse.
    flux_norm = eclipse * (
            reflected_ppm + ellipsoidal_model_ppm +
            doppler_model_ppm + 1e6 * thermal
    )

    # Keep track of the geometric albedo and integral phase function at
    # each step in the chain
    pm.Deterministic('A_g', A_g)
    pm.Deterministic('q', q)

    # Define the likelihood
    # Gaussian, with the binned fluxes as data and the binned errors as sigma.
    pm.Normal('obs', mu=flux_norm, sigma=binerror, observed=binflux)

    # Optimize a fast maximum a posteriori (MAP) solution to seed posterior
    # draws:
    map_soln = pm.find_MAP()

with model:
    # Draw posterior samples with a single chain on a single core, starting
    # from the MAP solution; the result is kept as a PyMC3 trace rather
    # than an InferenceData object.
    trace = pmx.sample(
        start=map_soln,
        draws=1000,
        tune=50,
        chains=1,
        cores=1,
        target_accept=0.95,
        initial_accept=0.2,
        compute_convergence_checks=False,
        return_inferencedata=False,
    )

    # Corner plot of the sampled parameters
    posterior_df = pm.trace_to_dataframe(trace)
    corner(posterior_df)
    plt.show()

# Binned observations with their uncertainties
plt.errorbar(binphase, binflux, binerror, fmt='.', color='k', ecolor='silver')

with model:
    # Individual model components to overplot: (tensor, line color, label).
    # The order matches the order in which the curves are drawn.
    components = [
        (reflected_ppm, 'DodgerBlue', 'reflected'),
        (1e6 * thermal, 'm', 'thermal'),
        (ellipsoidal_model_ppm, 'b', 'ellipsoidal'),
        (doppler_model_ppm, 'g', 'doppler'),
    ]

    # Overplot the model for 10 samples taken from the trace
    for i, sample in enumerate(pmx.get_samples_from_trace(trace, size=10)):
        # Composite model
        plt.plot(binphase, pmx.eval_in_model(flux_norm, sample), alpha=0.5,
                 color='r', zorder=10)

        for tensor, color, label in components:
            # Label only the first sample so each component appears once
            # in the legend
            plt.plot(binphase, pmx.eval_in_model(tensor, sample),
                     color=color, zorder=10,
                     label=label if i == 0 else None)

plt.legend()
plt.ylim([-30, 120])
for sp in ['right', 'top']:
    plt.gca().spines[sp].set_visible(False)
# Raw string: `\m` is not a valid Python escape sequence, so the non-raw
# literal triggers an invalid-escape warning; the text itself is unchanged.
plt.gca().set(xlabel='Phase', ylabel=r'$F_p/F_\mathrm{star}$ [ppm]',
              title='HAT-P-7 b')
plt.show()