import matplotlib.pyplot as plt
import numpy as np
import astropy.units as u
from kelp import Model, Filter, Planet

import astropy.units as u
from astropy.constants import R_jup, R_sun

# Filter object handed to every Model built in generate_temp_map.
filt = Filter.from_name('IRAC 1')

# Grid resolution passed to Model.temperature_map, and the highest
# spherical-harmonic degree l that is plotted.
n_theta = n_phi = 100
lmax = 3

# NOTE(review): `p` is not referenced anywhere in the visible script;
# generate_temp_map uses R_jup / R_sun rather than this planet's parameters.
p = Planet.from_name('HD 189733')

def indexer(l, m, lmax=3):
    """Return a coefficient table in which only ``C_ml[l][m]`` is 1.

    The table is a list of ``lmax + 1`` rows; row ``k`` holds ``2*k + 1``
    zeros, one slot per order in ``-k..k``. Orders are addressed directly as
    ``C_ml[l][m]``, so a negative ``m`` uses Python's negative-index wrap
    (``C_ml[l][-1]`` is the last slot of row ``l``).

    Parameters
    ----------
    l : int
        Degree, ``0 <= l <= lmax``.
    m : int
        Order, ``-l <= m <= l``.
    lmax : int, optional
        Highest degree represented in the table. Defaults to 3, which
        reproduces the 4-row table this function has always returned.

    Returns
    -------
    list of list of int
        A freshly built table (rows are independent lists).

    Raises
    ------
    ValueError
        If ``l`` or ``m`` is out of range. Without this check an out-of-range
        ``m`` would wrap around onto a different order (``indexer(1, 2)``
        would set the same slot as ``indexer(1, -1)``).
    """
    if not 0 <= l <= lmax:
        raise ValueError(f"l must be in [0, {lmax}], got {l}")
    if not -l <= m <= l:
        raise ValueError(f"m must be in [{-l}, {l}], got {m}")

    C_ml = [[0] * (2 * k + 1) for k in range(lmax + 1)]
    C_ml[l][m] = 1
    return C_ml

def generate_temp_map(a, l, m, *, hotspot_offset=0, alpha=0.9,
                      omega_drag=1.5, A_B=0, T_s=5770, f=1 / np.sqrt(2)):
    """Build a kelp ``Model`` with a single active C_ml term and return its map.

    Parameters
    ----------
    a : float
        Semimajor axis; it is multiplied by ``u.AU`` below, so it is taken
        to be in AU.
    l, m : int
        Degree and order of the one coefficient that ``indexer`` sets to 1.
    hotspot_offset, alpha, omega_drag, A_B, T_s : keyword-only
        Forwarded unchanged to ``Model``; the defaults are the values this
        function previously hard-coded.
    f : float, keyword-only
        Forwarded to ``Model.temperature_map``; defaults to ``1/sqrt(2)``.

    Returns
    -------
    T, theta, phi
        The three values unpacked from ``model.temperature_map``.

    Notes
    -----
    Reads the module-level ``filt``, ``lmax``, ``n_theta`` and ``n_phi``.
    """
    C_ml = indexer(l, m)

    # float() turns the dimensionless astropy Quantity into a plain number,
    # converting away the m/AU unit scale.
    # NOTE(review): planet/star sizes are fixed to R_jup / R_sun rather than
    # taken from a specific planet; confirm this is intended.
    rp_a = float(R_jup / (a * u.AU))
    a_rs = float(a * u.AU / R_sun)

    model = Model(hotspot_offset, alpha, omega_drag, A_B,
                  C_ml, lmax, a_rs=a_rs, rp_a=rp_a, T_s=T_s, filt=filt)

    T, theta, phi = model.temperature_map(n_theta, n_phi, f=f)
    return T, theta, phi

# Only the shape of a coefficient table is used here: one grid row per
# degree l, one grid column per order m of the largest degree.
cml_example = indexer(1, 0)
n_rows = len(cml_example)
n_cols = len(cml_example[-1])

# squeeze=False guarantees a 2-D (n_rows, n_cols) array of axes, laid out
# row-major exactly like add_subplot(n_rows, n_cols, 1 + i).
fig, ax = plt.subplots(
    n_rows, n_cols,
    figsize=(10, 4),
    subplot_kw={"projection": "mollweide"},
    squeeze=False,
)

# Draw one map per (l, m); the column index is m shifted so that m = 0 sits
# in the centre column of the grid.
center_col = len(cml_example[-1]) // 2
for l in range(0, lmax + 1):
    for m in range(-l, l + 1):
        temperature, theta, phi = generate_temp_map(0.17, l, m)
        # Keep only longitudes within [-pi, pi], the span of matplotlib's
        # Mollweide axes.
        phirange = (-np.pi <= phi) & (np.pi >= phi)
        axis = ax[l, m + center_col]
        # NOTE(review): theta is shifted by -pi/2 to map onto latitude; the
        # sign convention depends on how kelp defines theta — confirm.
        axis.pcolormesh(
            phi[phirange], (theta - np.pi/2), temperature[:, phirange],
            rasterized=True
        )
        # Raw f-string: "\," and "\ell" are LaTeX, not Python escapes
        # (a plain string triggers invalid-escape warnings on Python 3.12+).
        axis.set_title(rf'$m = {m},\,\ell = {l}$')
        axis.grid(False)

# Hide the axis decorations on every panel (including grid slots with no
# map, since only -l..l is filled for each row), then lay out and display.
for panel in ax.flat:
    panel.axis('off')

plt.tight_layout(h_pad=0.8, w_pad=0.3)
plt.show()