import numpy as np
import matplotlib.pyplot as plt
# Global figure style: 8 pt text, and embed fonts as Type 42 (TrueType) in PDF output.
plt.rcParams.update({'font.size': 8, 'pdf.fonttype': 42})
plt.close('all')

from iminuit import Minuit
from scipy.special import erfc

name = '299_298_all_in_one_plot'

# Load both runs and merge them into one event table (column 0: ToF in ns,
# column 1: sweep number). The sweep numbers of run 299 are shifted past the
# last sweep of run 298 so that every sweep number stays unique after merging.
data_a = np.loadtxt('Zn-run_298.csv', skiprows=1, delimiter=',')
data_b = np.loadtxt('Zn-run_299.csv', skiprows=1, delimiter=',')
data_b[:, 1] += 1 + data_a[:, 1].max()
data_all = np.concatenate((data_a, data_b))


def _tof_window(events, bounds):
    """Return the rows of `events` whose ToF (column 0) lies strictly inside `bounds` = (low, high)."""
    low, high = bounds
    tof = events[:, 0]
    return events[(tof > low) & (tof < high)]


# Coarse ToF pre-cuts: each keeps one species and removes the other two peaks.
precut1 = (1.88069e7, 1.88075e7)  # Zn (drops Cu and CaF)
data1 = _tof_window(data_all, precut1)

precut2 = (1.8806e7, 1.88066e7)  # Cu (drops Zn and CaF)
data2 = _tof_window(data_all, precut2)

# CaF (drops Cu and Zn). Alternatively only the first part of file 299 (short
# BG) could be used:
#     data3 = data_all[data_all[:,1] < 680+1+data_a[:,1].max()]
# Interestingly, using all data gives (nearly) the same result: only a single
# z=1 CaF count lies in the higher-BG part of the file.
precut3 = (1.8809e7, 1.88104e7)
data3 = _tof_window(data_all, precut3)

# Half-widths (ns) of the ToF window kept around each peak: counts more than
# `left` below or `right` above the peak centre are discarded.
right = 230
left = 130


def make_t_array(data):
    """Apply the z=1 cut and a re-centred ToF window to an event table.

    Parameters
    ----------
    data : ndarray
        Event table; column 0 holds the ToF (ns), column 1 the sweep number.

    Returns
    -------
    t : ndarray
        Sorted ToF values surviving both cuts.
    cut_left, cut_right : float
        Edges (exclusive) of the final ToF window, ``centre - left`` and
        ``centre + right`` where ``centre`` is the mean ToF inside a first,
        coarse window of the same widths around the overall mean.

    Raises
    ------
    ValueError
        If no events survive the z=1 cut or the coarse window.
    """
    # z=1 cut: keep only sweeps that recorded exactly one count
    _, first_index, counts = np.unique(data[:, 1], return_index=True,
                                       return_counts=True)
    single = data[first_index[counts == 1]]
    t = np.sort(single[:, 0])
    if t.size == 0:
        raise ValueError('no events left after the z=1 cut')

    # Coarse two-sided window around the overall mean. Counts far from the
    # peak pull that mean, so the final window is re-centred on the mean of
    # the counts inside the coarse one. (Previously the lower bound of this
    # coarse window was computed and then overwritten, i.e. never applied.)
    coarse_mean = t.mean()
    core = t[(t > coarse_mean - left) & (t < coarse_mean + right)]
    if core.size == 0:
        raise ValueError('no events inside the coarse ToF window')
    centre = core.mean()
    cut_left = centre - left
    cut_right = centre + right
    t = t[(t > cut_left) & (t < cut_right)]
    return t, cut_left, cut_right
    
# Apply the z=1 cut and the ToF window to each species (1: Zn, 2: Cu, 3: CaF).
(t1, cut_left1, cut_right1), (t2, cut_left2, cut_right2), (t3, cut_left3, cut_right3) = [
    make_t_array(events) for events in (data1, data2, data3)
]

print('______________')
print('counts per signal after ToF and z-cut:')
print('Zn:', len(t1), '\nCu:', len(t2), '\nCaF:', len(t3))
    
binwidth = 8


def make_ts_ns_binned(t, cut_left, cut_right, binwidth=8):
    """Histogram ToF values into bins whose first centre is ``cut_left``.

    Used for the "vertical" least-squares view and for plotting. Bin centres
    run from ``cut_left`` in steps of ``binwidth`` until the last bin reaches
    ``cut_right``, so every value in ``[cut_left, cut_right]`` is counted.

    Parameters
    ----------
    t : array_like
        ToF values (ns).
    cut_left, cut_right : float
        ToF window edges (ns).
    binwidth : float, optional
        Bin width in ns.

    Returns
    -------
    xs : ndarray
        Counts per bin.
    ts : ndarray
        Bin centres.
    dxs_sq : ndarray of float
        Squared-uncertainty estimate per bin: the count itself, with empty
        bins set to 1 so they keep a finite weight.
    """
    half = binwidth / 2
    # Edges are built from an integer count rather than
    # np.arange(cut_left - half, cut_right + half, binwidth): arange's stop is
    # exclusive and its length depends on float rounding, so when the window
    # spans a whole number of bins it could drop the final edge and lose the
    # counts in the last half bin below cut_right.
    n_edges = int(np.ceil((cut_right - cut_left) / binwidth + 0.5)) + 1
    bins = (cut_left - half) + binwidth * np.arange(n_edges)
    xs, edges = np.histogram(t, bins=bins)
    ts = edges[:-1] + half
    dxs_sq = xs.astype(float)
    dxs_sq[dxs_sq == 0] = 1  # set weight as 1 for all empty bins
    return xs, ts, dxs_sq

# Bin each species' ToF sample over its own window.
_windows = (
    (t1, cut_left1, cut_right1),
    (t2, cut_left2, cut_right2),
    (t3, cut_left3, cut_right3),
)
(xs1, ts1, dxs_sq1), (xs2, ts2, dxs_sq2), (xs3, ts3, dxs_sq3) = [
    make_ts_ns_binned(sample, lo, hi, binwidth=binwidth) for sample, lo, hi in _windows
]

def hyper_EMG(x, σ, μ, τ1p):
    """Exponentially modified Gaussian with a single exponential tail towards larger x.

    Normalised probability density

        f(x) = 1/(2τ) · exp(σ²/(2τ²) − (x−μ)/τ) · erfc(σ/(√2 τ) − (x−μ)/(√2 σ))

    with Gaussian width ``σ``, location ``μ`` and tail constant ``τ1p``.
    ``x`` may be a scalar or a NumPy array.
    """
    root2 = 2**0.5
    tail_arg = σ / (root2 * τ1p)
    gauss_arg = (x - μ) / (root2 * σ)
    norm = 1 / (2 * τ1p)
    return norm * np.exp(tail_arg**2 - (x - μ) / τ1p) * erfc(tail_arg - gauss_arg)

def negloglik(σ, C_ToF, μ2, μ3, τ1p):
    """Joint negative log-likelihood of the three ToF samples (minimised by Minuit).

    All three peaks share the width ``σ`` and tail ``τ1p``. The Zn centroid is
    not free: it is tied to the Cu (``μ2``) and CaF (``μ3``) centroids via
    ``μ1 = C_ToF*(μ2-μ3) + (μ2+μ3)/2``. Reads the module-level samples
    ``t1`` (Zn), ``t2`` (Cu) and ``t3`` (CaF).
    """
    μ1 = C_ToF * (μ2 - μ3) + (μ2 + μ3) / 2
    samples = ((t1, μ1), (t2, μ2), (t3, μ3))
    return sum(-np.sum(np.log(hyper_EMG(times, σ, centre, τ1p)))
               for times, centre in samples)

def redChi2Poisson(x, t, σ, μ, τ1p, A, params=2):
    """Reduced Poisson chi^2 of binned counts against a scaled EMG model.

    Computes the Poisson deviance ``2*sum(E - O + O*ln(O/E))`` between the
    observed counts ``O = x`` and the expectation
    ``E = A * hyper_EMG(t, σ, μ, τ1p)``, divided by ``len(x) - params``.
    Empty bins contribute ``2*E`` (their log term is forced to zero).

    ``x`` is expected to be a NumPy array of counts per bin and ``t`` the bin
    centres; ``A`` converts the pdf into counts per bin (presumably number of
    events times bin width — confirm with the caller).
    """
    expected = A * hyper_EMG(t, σ, μ, τ1p)
    observed = x.copy()
    # Replace zeros inside the log so 0 * log(1/E) gives 0 instead of NaN.
    log_safe = x.copy()
    log_safe[log_safe == 0] = 1
    deviance = 2 * (expected - observed + observed * np.log(log_safe / expected))
    dof = len(x) - params
    return deviance.sum() / dof


# Starting values: each centroid 50 ns below its sample mean (the EMG tail
# lies at larger ToF, so the mean sits above the centroid), and the C_ToF
# implied by those starting centroids.
μ10, μ20, μ30 = (sample.mean() - 50 for sample in (t1, t2, t3))
C_ToF0 = 0.5 * (2 * μ10 - μ20 - μ30) / (μ20 - μ30)

minimizer = Minuit(negloglik, σ=20, C_ToF=C_ToF0, μ2=μ20, μ3=μ30, τ1p=46)
minimizer.errordef = Minuit.LIKELIHOOD  # negloglik returns -ln L
minimizer.limits = [
    (1e-5, None),            # σ
    (-1, 1),                 # C_ToF
    (μ20 - 100, μ20 + 100),  # μ2
    (μ30 - 100, μ30 + 100),  # μ3
    (1e-5, 1000),            # τ1p
]
minimizer.migrad()
minimizer.hesse()
minimizer.minos()

print('______________')
print('fit success:',minimizer.valid)

# Best-fit values and parabolic (HESSE) errors.
σ = minimizer.values['σ']
τ1p = minimizer.values['τ1p']
C_ToF = minimizer.values['C_ToF']
C_ToF_err = minimizer.errors['C_ToF']
μ2 = minimizer.values['μ2']
μ3 = minimizer.values['μ3']
μ1 = C_ToF*(μ2-μ3) + (μ2+μ3)/2


def _derived_centroid_error(fit, c_tof, mu2, mu3):
    """Propagate the (C_ToF, μ2, μ3) covariance of `fit` to μ1 = C_ToF*(μ2-μ3) + (μ2+μ3)/2.

    Returns NaN when `fit` has no covariance matrix.
    """
    cov = fit.covariance
    if cov is None:
        return float('nan')
    idx = [fit.parameters.index(p) for p in ('C_ToF', 'μ2', 'μ3')]
    sub = np.asarray(cov)[np.ix_(idx, idx)]
    # gradient of μ1 with respect to (C_ToF, μ2, μ3)
    grad = np.array([mu2 - mu3, c_tof + 0.5, 0.5 - c_tof])
    return float(np.sqrt(grad @ sub @ grad))


μ = [μ1, μ2, μ3]
# μ1 is derived, so its error comes from propagation instead of a fixed guess.
μerr = [_derived_centroid_error(minimizer, C_ToF, μ2, μ3),
        minimizer.errors['μ2'], minimizer.errors['μ3']]
species = ['$^{61}$Zn$^{+}$', '$^{61}$Cu$^{+}$', '$^{42}$Ca$^{19}$F$^{+}$']
datalabel = ['binned data', None, None]
fitlabel = [None, 'MLE EMG fit', None]
regionlabel = [None, 'ToF cut', None]

t = [t1, t2, t3]
ts = [ts1, ts2, ts3]
xs = [xs1, xs2, xs3]
cut_left = [cut_left1, cut_left2, cut_left3]
cut_right = [cut_right1, cut_right2, cut_right3]

from matplotlib import colors

# 'firebrick' blended 15 % into white; used to shade the ToF-cut regions.
lightfirebrick = np.array(colors.to_rgb('firebrick'))*0.15+np.full(3, 1-0.15)

cm = 1 / 2.54  # inches per cm (figsize is given in inches)
figsize = (9.3*cm, 3.5*cm)
fig, ax = plt.subplots(figsize=figsize)


def _axis_us(tof_ns):
    """Convert a ToF in ns to the plotted x coordinate, (ToF - 18800 µs) in µs."""
    return tof_ns/1e3-18800


for i in range(3):
    # shaded ToF-cut window
    plt.fill_betweenx(y=(0,300), x1=_axis_us(cut_left[i]), x2=_axis_us(cut_right[i]),
                      color=lightfirebrick, label=regionlabel[i], lw=0)

    # binned data
    plt.bar(_axis_us(ts[i]), xs[i], width=binwidth/1e3, fc='grey', label=datalabel[i])

    # fitted EMG, pdf scaled by (number of events * bin width) to counts per bin
    ts_plot = np.linspace(cut_left[i], cut_right[i], num=70)
    plt.plot(_axis_us(ts_plot),
             hyper_EMG(ts_plot, σ, μ[i], τ1p) *len(t[i])*binwidth,
             lw=0.8, c='firebrick', alpha=0.7,
             label=fitlabel[i])

    # species label above the peak
    plt.text(_axis_us(μ[i])+0.05, xs[i].max()*1.5, species[i], ha='center')


plt.xlabel('(ToF - 18800 µs) / µs')
plt.ylabel('counts / '+str(binwidth)+' ns bin')
plt.ylim(bottom = 0.4)
plt.xlim(_axis_us(np.min(cut_left)-200), _axis_us(np.max(cut_right)+200))
plt.minorticks_on()
plt.yscale('log')

# raw string: '\m' is an invalid escape sequence in a normal string literal
plt.legend(loc=(0.4,0.3),
           title=r'$C_\mathrm{ToF}=$'+str(round(C_ToF, 4))+'('+str(round(C_ToF_err*1e4))+')')

plt.savefig(name+'.pdf', bbox_inches='tight', pad_inches=0.01)

# Mass excess of 61Zn from C_ToF and the two reference species (61Cu, 42Ca19F).
# All masses are in u; the measured ions are singly charged cations, so one
# electron mass is removed before and added back after the ToF relation.

m_e = 0.0005485799      # electron mass / u
m_61Cu = 60.9334574     # 61Cu / u
m_CaF =  60.95702094    # 42Ca19F / u

m1 = m_61Cu
m2 = m_CaF

dC_ToF = C_ToF_err

m1_ion = m1 - m_e
m2_ion = m2 - m_e
sqrt_m1 = m1_ion**0.5
sqrt_m2 = m2_ion**0.5
delta_ref = sqrt_m1 - sqrt_m2
sum_ref = sqrt_m1 + sqrt_m2

# sqrt(ion mass) of Zn follows the same C_ToF relation as the ToF centroids
sqrt_m_Zn_ion = C_ToF * delta_ref + 0.5 * sum_ref
m_Zn = sqrt_m_Zn_ion**2 + m_e

# only the C_ToF uncertainty is propagated; reference masses are taken as exact
dm_Zn = abs(dC_ToF * 2 * delta_ref * sqrt_m_Zn_ion)

amu_keV = 931494.102    # keV per u

ME_Zn = (m_Zn-61) * amu_keV
dME_Zn = dm_Zn*amu_keV

# print results
print('______________')
print(' C_ToF', round(C_ToF, 5))
print('dC_ToF', round(dC_ToF, 5))
print('______________')
print(' ME:', round(ME_Zn, 1))
print('dME:', round(dME_Zn, 1))
