import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# User settings (edit if needed)
# -----------------------------
csv_path = "Exp_datasets.csv"   # relative to the working directory, or a full path
window = 50                    # rolling-mean window, counted in rows (samples), not seconds
threshold = 0.157              # threshold on the normalised curve k/k0
x_scale = 1000                 # time is divided by this so the axis reads in units of 10^3 s
initial_time_max = 50          # k0 = mean of smoothed k over rows with unscaled time <= this

# One entry per plotted curve. Column indices are 0-based positions in the
# headerless CSV. If your column-to-condition mapping differs, edit or reorder
# the entries below (any number of entries is supported).
# Each entry: (time_col, k_col, label, color, linestyle)
datasets = [
    (8, 9,  r'Steady ($\Delta P = 150$ mbar, $C = 0.3\%$, $d_p = 4.5\ \mu m$)', 'red',    '-'),
    (6, 7,  r'Steady ($\Delta P = 150$ mbar, $C = 0.2\%$, $d_p = 6\ \mu m$)',   'green',  '-'),
    (4, 5,  r'Steady ($\Delta P = 100$ mbar, $C = 0.3\%$, $d_p = 6\ \mu m$)',   'purple', '-'),
    (2, 3,  r'Steady ($\Delta P = 150$ mbar, $C = 0.3\%$, $d_p = 6\ \mu m$)',   'blue',   '-'),
    (0, 1,  r'Steady ($\Delta P = 150$ mbar, $C = 0.3\%$, $d_p = 6\ \mu m$ $reproduced$ )', 'blue', '--'),
]

# -----------------------------
# Plot style (Times New Roman)
# -----------------------------
# Ask for Times New Roman via the serif family with fallbacks. If the font is
# not installed, this gives a Times-like serif (STIXGeneral ships with
# matplotlib) instead of "findfont" warnings and a sans-serif fallback.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman', 'Times', 'STIXGeneral', 'DejaVu Serif']
# Labels use mathtext heavily ($\Delta P$, $d_p$, $k/k_0$). STIX math glyphs are
# designed to blend with Times, so math and text match.
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.size'] = 14

# -----------------------------
# Load CSV (no header)
# -----------------------------
# With header=None the columns are labelled by 0-based integer position, which is
# what the (time_col, k_col) indices in `datasets` refer to. Cell values are
# coerced to numbers later, so stray text or blank cells are tolerated here.
raw = pd.read_csv(csv_path, header=None)

def create_df_from_cols(df_raw, xcol, ycol, *, smooth_window=None,
                        time_scale=None, t_initial_max=None):
    """Extract one (time, k) column pair, smooth it and normalise it to k/k0.

    Processing steps:

    1. Select columns ``xcol`` (time) and ``ycol`` (k) from ``df_raw``, coerce
       both to numbers (non-numeric cells become NaN), drop incomplete rows
       and sort by time.
    2. Smooth k with a rolling mean over ``smooth_window`` consecutive rows
       (a row count, not a time span). The first ``smooth_window - 1`` rows
       have no complete window and are dropped.
    3. Add ``x_scaled = x / time_scale`` for plotting.
    4. Divide the smoothed k by k0, the mean smoothed k over rows with
       ``x <= t_initial_max``. If there are no such rows, or that mean is 0,
       the first smoothed value is used as k0 instead.

    Parameters
    ----------
    df_raw : pandas.DataFrame
        Raw table. ``xcol``/``ycol`` are column labels in it (integer
        positions when the CSV was read with ``header=None``).
    xcol, ycol : hashable
        Labels of the time column and the k column.
    smooth_window : int, optional
        Rolling-mean window in rows; defaults to the module-level ``window``.
    time_scale : float, optional
        Divisor used for ``x_scaled``; defaults to the module-level ``x_scale``.
    t_initial_max : float, optional
        Upper bound (unscaled time) of the interval averaged to get k0;
        defaults to the module-level ``initial_time_max``.

    Returns
    -------
    pandas.DataFrame
        Columns ``x``, ``y`` (raw k), ``y_smooth`` (smoothed k divided by k0)
        and ``x_scaled``. When fewer than ``smooth_window`` valid rows exist,
        an empty frame with these columns is returned and a ``UserWarning``
        is issued, so other curves can still be plotted.

    Raises
    ------
    ValueError
        If k0 (after the fallback) is 0, because k/k0 would be undefined.
    """
    import warnings

    # Module-level settings are only consulted when no override is given.
    win = window if smooth_window is None else smooth_window
    scale = x_scale if time_scale is None else time_scale
    t_init = initial_time_max if t_initial_max is None else t_initial_max

    df = df_raw[[xcol, ycol]].copy()
    df.columns = ['x', 'y']

    # Clean NaNs / non-numeric cells, then order by time for the rolling window.
    df['x'] = pd.to_numeric(df['x'], errors='coerce')
    df['y'] = pd.to_numeric(df['y'], errors='coerce')
    df = df.dropna(subset=['x', 'y']).sort_values('x')

    # Smooth; rows without a full window come out NaN and are dropped.
    # Explicit copy so the column assignments below act on an independent frame.
    df['y_smooth'] = df['y'].rolling(window=win).mean()
    df = df.dropna(subset=['y_smooth']).copy()

    df['x_scaled'] = df['x'] / scale

    if df.empty:
        # Previously this fell through to .iloc[0] and raised an opaque IndexError.
        warnings.warn(
            f"columns ({xcol}, {ycol}): fewer than {win} valid numeric rows; "
            "nothing to normalise or plot",
            stacklevel=2,
        )
        return df

    # Normalise by the initial mean (k0), falling back to the first smoothed value.
    k0 = df.loc[df['x'] <= t_init, 'y_smooth'].mean()
    if pd.isna(k0) or k0 == 0:
        k0 = df['y_smooth'].iloc[0]
    if k0 == 0:
        raise ValueError(
            f"columns ({xcol}, {ycol}): baseline k0 is 0, cannot normalise to k/k0"
        )
    df['y_smooth'] = df['y_smooth'] / k0

    return df

def get_threshold_time(df, *, thresh=None):
    """Return the earliest unscaled time at which ``y_smooth`` drops below a threshold.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with columns ``x`` (unscaled time) and ``y_smooth`` (normalised
        k/k0), e.g. the output of ``create_df_from_cols``.
    thresh : float, optional
        Threshold on ``y_smooth``; defaults to the module-level ``threshold``.

    Returns
    -------
    scalar
        Smallest ``x`` whose ``y_smooth`` is strictly below the threshold, or
        ``np.nan`` if no row is below it (including an empty ``df``).
    """
    level = threshold if thresh is None else thresh
    below = df.loc[df['y_smooth'] < level, 'x']
    return below.min() if not below.empty else np.nan

# -----------------------------
# Build, plot, and annotate
# -----------------------------
# Use the object-oriented API: one figure, one axes, all drawing goes through `ax`.
fig, ax = plt.subplots(figsize=(10, 5))

# Draw every processed curve and record, per curve, the first unscaled time at
# which it goes below the threshold (NaN if it never does) together with its colour.
crossings = []
for time_col, k_col, curve_label, curve_color, line_style in datasets:
    curve = create_df_from_cols(raw, time_col, k_col)
    ax.plot(
        curve['x_scaled'],
        curve['y_smooth'],
        color=curve_color,
        linestyle=line_style,
        linewidth=2,
        label=curve_label,
    )
    crossings.append((get_threshold_time(curve), curve_color))

# Dashed vertical marker at each crossing, converted to plotted (scaled) time.
for t_cross, curve_color in crossings:
    if pd.isna(t_cross):
        continue
    ax.axvline(x=t_cross / x_scale, linestyle='--', color=curve_color, alpha=0.5)

# Dashed horizontal line at the threshold level (empty label keeps it out of the legend).
ax.axhline(y=threshold, linestyle='--', color='black', alpha=0.6, label='')

# Axis labels, grid and fixed 0-14 (x10^3 s) time axis with unit ticks.
ax.set_xlabel(r'Time ($\times 10^3$ sec)', fontsize=16)
ax.set_ylabel(r'$k/k_0$', fontsize=16)
ax.grid(True)
ax.set_xlim(0, 14)
ax.set_xticks(np.arange(15))
ax.tick_params(axis='both', which='major', labelsize=14)
ax.ticklabel_format(axis='x', style='plain')
ax.xaxis.get_offset_text().set_visible(False)

ax.legend(fontsize=12)
fig.tight_layout()
plt.show()
