import numpy as np
from scipy import sparse

########## read and bin data #################################################

class MCS6_DATA:
    '''Reader for MCS6 list files ('.mpa' by default), stored as sparse counts.

    The counts of all given files are summed into ``self.data``, a scipy
    ``csr_array`` of shape (cycles, range): one row per cycle/slice, one
    column per time channel. The class also converts between channel index
    and time, and holds a calibration t = a*sqrt(m) + b.
    '''

    # Header keys read from each file (only from lines before the first
    # "[CHN2]"): (line prefix, field name, converter for the text after it)
    _HEADER_FIELDS = (
        ('bitshift=', 'bitshift', float),
        ('cycles=', 'cycles', int),
        ('SWEEPS: ', 'sweeps', int),
        ('range=', 'range', int),
        ('caloff=', 'caloff', float),
        ('REALTIME: ', 'realtime', float),
        ('REPORT-FILE from ', 'starttime', lambda text: text[:23]),
    )

    def __init__(self, file_names, file_type='.mpa', resolution=0.1):
        '''Read the given files and sum their counts.

        file_names : iterable of file paths without extension, or a single
                     such path given as a string
        file_type  : extension appended to every file name
        resolution : base channel width; index_time/time_index treat
                     resolution * 2**bitshift / 1000 as the width in μs
        '''
        # a bare string would otherwise be iterated character by character
        if isinstance(file_names, str):
            file_names = [file_names]
        self.file_names = file_names
        self.file_type = file_type
        self.resolution = resolution

        # reading also sets self.caloff (μs), self.realtime, self.starttime
        self.data, self.bitshift, self.sweeps = self.__read_file()

        # calibration coefficients of t = a*sqrt(m) + b, set by calibrate()
        self.a, self.b = 0, 0
        self.calibrated = False
        self.calibrate_what = None  # label of the calibrated quantity

    def __read_file(self):
        '''Read all files and return (data, bitshift, sweeps).

        data is a csr_array of shape (cycles, range) with the counts of all
        files summed; sweeps is the total number of sweeps of all files.

        Raises ValueError when no file names are given, or when the headers
        disagree on bitshift, cycles or range (the channels of such files do
        not line up, so their counts cannot be added bin by bin).
        '''
        data = None
        layout = None
        sweeps = 0
        for file_name in self.file_names:
            lines, bitshift, cycles, file_sweeps, range_ = \
                self.__get_header_values(file_name)
            if layout is None:
                layout = (bitshift, cycles, range_)
            elif (bitshift, cycles, range_) != layout:
                raise ValueError(
                    f'{file_name}{self.file_type}: (bitshift, cycles, range) = '
                    f'{(bitshift, cycles, range_)} differs from {layout} of '
                    'the previous files; their counts cannot be added')
            sweeps += file_sweeps

            # With numpy >= 1.23 loadtxt is implemented in C. On older numpy it
            # is >10x slower; pandas.read_csv(path, sep=' ', header=None,
            # dtype='i', skiprows=lines).to_numpy() is a faster replacement
            # (header=None keeps the first data row) -- or just update numpy.
            raw = np.loadtxt(file_name + self.file_type, delimiter=' ',
                             dtype='i', skiprows=lines, ndmin=2)
            # sparsify per file so only one raw table is held in memory
            counts = self.__make_sparse(raw, range_, cycles)
            data = counts if data is None else data + counts

        if data is None:
            raise ValueError('no file names given')
        return data, bitshift, sweeps

    def __get_header_values(self, file_name):
        '''Parse the header of one file.

        Counts the header lines up to and including the "[DATA]" line and
        reads the keys in _HEADER_FIELDS from the lines before the first
        "[CHN2]" line (a later occurrence of a key overwrites an earlier one).
        Stores caloff (converted from ns to μs), realtime and starttime on the
        instance.

        Returns (header_lines, bitshift, cycles, sweeps, range_).
        Raises ValueError if the "[DATA]" line or a required key is missing.
        '''
        path = file_name + self.file_type
        values = {}
        lines = 0
        found_data = False
        in_channel_1 = True
        # latin-1 decodes every byte, so stray non-ASCII characters in
        # free-text header lines cannot abort parsing; all keys are ASCII
        with open(path, 'r', encoding='latin-1') as file:
            for line in file:
                lines += 1
                if '[CHN2]' in line:
                    in_channel_1 = False
                if in_channel_1:
                    # prefix match, so free-text lines that merely contain a
                    # key word (e.g. "range") are not mistaken for that key
                    text = line.strip()
                    for prefix, name, convert in self._HEADER_FIELDS:
                        if text.startswith(prefix):
                            values[name] = convert(text[len(prefix):])
                if '[DATA]' in line:
                    found_data = True
                    break

        if not found_data:
            raise ValueError(f'{path}: no [DATA] line found')
        missing = [name for _, name, _ in self._HEADER_FIELDS
                   if name not in values]
        if missing:
            raise ValueError(f'{path}: header lacks {", ".join(missing)}')

        self.caloff = values['caloff'] / 1000  # convert from ns to us
        self.realtime = values['realtime']
        self.starttime = values['starttime']
        return (lines, values['bitshift'], values['cycles'],
                values['sweeps'], values['range'])

    def __make_sparse(self, data_raw, range_, cycles):
        '''Turn an (n, 3) integer table into a csr_array of shape (cycles, range_).

        Column 0 of the table is the column (channel) index, column 1 the row
        (cycle) index and column 2 the count; scipy sums duplicate
        (row, column) entries. A table without entries gives all zeros.
        '''
        if data_raw.size == 0:
            # an empty [DATA] section may not come back as a 3-column table,
            # so the column indexing below would fail
            return sparse.csr_array((cycles, range_), dtype='i')
        counts, rows, cols = data_raw[:, 2], data_raw[:, 1], data_raw[:, 0]
        return sparse.csr_array((counts, (rows, cols)),
                                shape=(cycles, range_), dtype='i')

    def bin_data(self, data, binsum=1):
        '''Bin data along its last (time) axis and return a dense numpy array.

        Every group of `binsum` neighbouring channels is summed; trailing
        channels that do not fill a complete group are dropped.

        data   : scipy sparse array/matrix or numpy array, 1D or 2D
        binsum : number of channels per bin, >= 1

        Returns an array of shape (n_channels // binsum,) for 1D input or
        (n_rows, n_channels // binsum) for 2D input.
        Raises ValueError if binsum < 1.
        '''
        if binsum < 1:
            raise ValueError(f'binsum must be >= 1, got {binsum}')
        if binsum == 1:
            return data.toarray() if sparse.issparse(data) else np.asarray(data)

        n_bins = data.shape[-1] // binsum
        if data.ndim == 1:
            binned = data[:n_bins*binsum].reshape((n_bins, binsum)).sum(axis=1)
        else:
            n_rows = data.shape[0]
            # row-major reshape puts each group of binsum channels of a row
            # into one row of its own, so summing axis 1 bins all rows at once
            binned = data[:, :n_bins*binsum].reshape((n_rows*n_bins, binsum))
            binned = binned.sum(axis=1).reshape((n_rows, n_bins))
        return np.asarray(binned)

    def add_slices(self, N=1):
        '''Add together every N-th slice (row) of self.data, in place.

        Output row k (0 <= k < N) is the sum of rows k, k+N, k+2N, ... of the
        original data; trailing rows that do not complete a group of N are
        discarded. Afterwards self.data is a csr_array of shape
        (N, n_channels).

        Implemented as a product with a 0/1 matrix that maps row r to output
        row r % N, so scipy.sparse suffices and no 3D intermediate is needed.
        Raises ValueError unless 1 <= N <= number of rows.
        '''
        n_rows = self.data.shape[0]
        if not 1 <= N <= n_rows:
            raise ValueError(f'N must be between 1 and {n_rows}, got {N}')
        n_used = (n_rows // N) * N
        rows = np.arange(n_used)
        # accumulate in the dtype numpy's sum would use (small ints are
        # promoted to the platform integer), reducing the risk of overflow
        acc_dtype = np.zeros(0, dtype=self.data.dtype).sum().dtype
        folding = sparse.csr_array(
            (np.ones(n_used, dtype=acc_dtype), (rows % N, rows)),
            shape=(N, n_used))
        self.data = sparse.csr_array(folding @ self.data[:n_used])

    def add_all_slices(self):
        '''Add together all slices (rows) of self.data, in place.'''
        # NOTE(review): with SciPy versions that support 1D sparse arrays the
        # result may be 1D (shape (range,)) rather than (1, range);
        # bin_data accepts both
        self.data = sparse.csr_array(self.data.sum(axis=0))

    ### converting between time and index ###

    def _channel_width_us(self):
        '''Width of one unbinned channel in μs.'''
        return self.resolution*2**self.bitshift/1000

    def index_time(self, index, binsum=1):
        '''Convert a (binned) channel index to time in μs.

        Returns the time of the bin's first channel; no half-bin centering
        and no caloff offset is applied.
        '''
        return index*binsum*self._channel_width_us()

    def time_index(self, time, binsum=1):
        '''Convert time in μs to a (binned) channel index, truncating.

        Arrays (and numpy scalars) are converted with astype(int); plain
        Python numbers yield an int.
        '''
        index = time/self._channel_width_us()/binsum
        try:
            return index.astype(int)
        except AttributeError:  # plain Python number has no astype
            return int(index)

    ### converting between time and cluster size / mass ###

    def calibrate(self, t, m, what='cluster size n'):
        '''Determine a and b of the calibration t = a*sqrt(m) + b.

        t, m : sequences of equal length (>= 2) of times and the matching
               masses / cluster sizes. With two points the line passes
               through both; with more a least-squares fit (numpy.polyfit)
               is used.
        what : label stored in self.calibrate_what

        Raises ValueError if t and m do not match in length, fewer than two
        points are given, or all masses are equal.
        '''
        t = np.asarray(t, dtype=float)
        root_m = np.sqrt(np.asarray(m, dtype=float))
        if t.ndim != 1 or t.shape != root_m.shape or t.size < 2:
            raise ValueError('t and m must be 1D sequences of equal length >= 2')
        if np.all(root_m == root_m[0]):
            raise ValueError('calibration needs at least two different masses')

        if t.size == 2:
            self.a = (t[0] - t[1]) / (root_m[0] - root_m[1])
            # mean of the two intercepts t_i - a*sqrt(m_i)
            self.b = 0.5 * (t[0] + t[1] - self.a*(root_m[0] + root_m[1]))
        else:
            self.a, self.b = np.polyfit(root_m, t, 1)
        self.calibrated = True
        self.calibrate_what = what

    def mass_time(self, m_q):
        '''Convert mass / cluster size to time using t = a*sqrt(m) + b.'''
        return self.a * np.sqrt(m_q) + self.b

    def time_mass(self, t):
        '''Convert time to mass / cluster size, inverting t = a*sqrt(m) + b.

        The result is squared, so times on either side of b map to the same
        positive value.
        '''
        return ((t-self.b)/self.a)**2