import numpy as np
import pandas as pd
import tensorflow as tf
from scipy import signal
from scipy.interpolate import interp1d
import os

import UnitConverter

class ExpDataProcessing:
    """Process experimental TDS data from file with flexible y-axis units.

    On construction the first two columns of the Excel file are loaded,
    passed through the unit converter, smoothed with a Savitzky-Golay filter
    and interpolated onto the model temperature grid.

    Attributes read from the collaborators (as used by this class):
        material: ``mass_density``, ``Thickness``, ``TMin``, ``TMax``, ``ntp``.
        hyperparameters: ``flux_threshold``.

    Attributes populated by processing:
        temperature_raw: first column of the file, in ``temp_units``.
        y_data_raw: second column of the file, in ``y_units``.
        Temperature: model temperature grid (``ntp`` points, TMax excluded).
        Flux: smoothed, converted y-data interpolated onto ``Temperature``.
        TDS_Curve: one-element list holding ``Flux`` as a TensorFlow tensor.
    """

    def __init__(self, file_name, temp_units, y_units, y_type, material, hyperparameters):
        """Store the configuration, then load and process the file immediately.

        Args:
            file_name: Path of the Excel file handed to ``pandas.read_excel``.
            temp_units: Units of the temperature column.
            y_units: Units of the y-axis column.
            y_type: Kind of y-data, 'flux' or 'delta_c'.
            material: Object providing the material/grid attributes listed in
                the class docstring.
            hyperparameters: Object providing ``flux_threshold``.
        """
        self.file_name = file_name
        self.temp_units = temp_units  # units of temp data
        self.y_units = y_units        # units of y-axis data
        self.y_type = y_type          # 'flux' or 'delta_c'
        self.material = material
        self.hp = hyperparameters

        # Initialize data attributes
        self.temperature_raw = None
        self.y_data_raw = None        # Generic y-data (could be flux or delta_c)
        self.Temperature = None
        self.Flux = None              # Always converted to flux for model
        self.TDS_Curve = None

        # Process the data
        self._load_data()
        self._process_data()

    def _load_data(self):
        """Load experimental data from Excel file.

        The first row is treated as a header; column 0 becomes
        ``temperature_raw`` and column 1 becomes ``y_data_raw``.
        """
        data = pd.read_excel(self.file_name, header=0)
        self.temperature_raw = data.iloc[:, 0].to_numpy()
        self.y_data_raw = data.iloc[:, 1].to_numpy()

    def _process_data(self):
        """Convert units, smooth, and downsample the data.

        Populates ``Temperature``, ``Flux`` and ``TDS_Curve``.
        """
        # Unit conversion
        corrected_flux, corrected_temperature = self._apply_unit_conversion()

        # Smooth the data
        smoothed_flux = self._smooth_data(corrected_flux)

        # Downsample to model grid
        self._downsample_data(smoothed_flux, corrected_temperature)

        # Create TensorFlow tensor (wrapped in a single-element list)
        self.TDS_Curve = [tf.convert_to_tensor(self.Flux)]

    def _apply_unit_conversion(self):
        """Apply unit conversion to y-data and temperature.

        Returns:
            Tuple ``(corrected_flux, corrected_temperature)``. Note the order
            is swapped relative to what the converter call is unpacked as
            (temperature first, then flux).
        """
        unit_converter = UnitConverter.UnitConverter(
            atw=1.008,  # presumably the atomic weight of hydrogen -- TODO confirm
            mass_density=self.material.mass_density,
            thickness=self.material.Thickness
        )

        corrected_temperature, corrected_flux = unit_converter.convert_experimental_to_standard(
            temp=self.temperature_raw,
            y_data=self.y_data_raw,
            temp_unit=self.temp_units,
            y_unit=self.y_units,
            y_type=self.y_type
        )

        return corrected_flux, corrected_temperature

    def _smooth_data(self, desorption_rate):
        """Apply Savitzky-Golay smoothing filter (polynomial order 2).

        The window is roughly twice the number of raw samples per model grid
        point, at least 3, always odd, and never longer than the data.

        Args:
            desorption_rate: 1-D sequence of y-values to smooth.

        Returns:
            Smoothed array of the same length. Inputs with fewer than three
            samples cannot support a quadratic fit and are returned as an
            unsmoothed float copy.
        """
        n_samples = len(desorption_rate)
        if n_samples < 3:
            # A window of 1 or 2 is not larger than the polynomial order (2),
            # which savgol_filter rejects; pass the data through instead.
            return np.array(desorption_rate, dtype=float)

        window_length = max(3, 2 * round(n_samples / self.material.ntp))
        if window_length % 2 == 0:
            window_length += 1

        if window_length > n_samples:
            # Clamp to the largest odd length that fits, so the window stays
            # odd (clamping to an even len(data) would undo the step above).
            window_length = n_samples if n_samples % 2 == 1 else n_samples - 1

        return signal.savgol_filter(desorption_rate, window_length, 2)

    def _downsample_data(self, desorption_rate_smoothed, corrected_temperature):
        """Downsample data to model temperature grid using interpolation.

        Stores the grid in ``self.Temperature`` and the interpolated values in
        ``self.Flux``. Grid points below the measured range take the first
        smoothed value; grid points above it take ``hp.flux_threshold``.

        Args:
            desorption_rate_smoothed: Smoothed y-values.
            corrected_temperature: Unit-converted temperatures matching
                ``desorption_rate_smoothed`` element-for-element.
        """
        # Create model temperature grid (ntp points from TMin to TMax, excluding TMax)
        temperature_model = np.linspace(
            self.material.TMin,
            self.material.TMax,
            self.material.ntp + 1
        )

        # Create interpolation function
        interp_function = interp1d(
            corrected_temperature,
            desorption_rate_smoothed,
            kind='cubic',
            bounds_error=False,
            fill_value=(desorption_rate_smoothed[0], self.hp.flux_threshold)
        )

        # Interpolate to model grid
        desorption_rate_downsampled = interp_function(temperature_model)

        # Set extrapolated values to threshold. The comparison must use the
        # converted temperatures: the model grid and the interpolation are in
        # converted units, whereas temperature_raw is still in the input units.
        beyond_range = temperature_model > np.max(corrected_temperature)
        desorption_rate_downsampled[beyond_range] = self.hp.flux_threshold

        # Remove last point (TMax) and store results
        self.Temperature = temperature_model[:-1]
        self.Flux = desorption_rate_downsampled[:-1]

    def get_raw_data(self):
        """Get original unprocessed data as ``(temperature_raw, y_data_raw)``."""
        return self.temperature_raw, self.y_data_raw

    def get_processed_data(self):
        """Get processed data ready for modelling as ``(Temperature, Flux)``."""
        return self.Temperature, self.Flux

    def get_data_info(self):
        """Get information about the data types and units.

        Returns:
            Dict with the original units/type passed to the constructor and
            the fixed unit labels reported for the processed data.
        """
        return {
            'original_temp_units': self.temp_units,
            'original_y_units': self.y_units,
            'original_y_type': self.y_type,
            'processed_temp_units': 'K',
            'processed_flux_units': 'mol/m²/s'
        }

    def plot_comparison(self):
        """Plot raw vs processed data for visualization.

        Shows two side-by-side panels (raw data, processed data). If
        matplotlib cannot be imported a message is printed and nothing is
        plotted.
        """
        # Only the import is guarded, so an ImportError raised by anything
        # else is not mistaken for "matplotlib missing".
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("Matplotlib not available for plotting")
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Raw data
        ax1.plot(self.temperature_raw, self.y_data_raw, 'b-', alpha=0.7)
        ax1.set_xlabel(f'Temperature ({self.temp_units})')
        ax1.set_ylabel(f'{self.y_type.title()} ({self.y_units})')
        ax1.set_title('Raw Data')
        ax1.grid(True, alpha=0.3)

        # Processed data
        ax2.plot(self.Temperature, self.Flux, 'r-', linewidth=2)
        ax2.set_xlabel('Temperature (K)')
        ax2.set_ylabel('Flux (mol/m²/s)')
        ax2.set_title('Processed Data (Converted to Flux)')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
