#! /usr/bin/env python3
"""estimate_ionospheric_phase — split-spectrum ionospheric phase estimate.

Python port of csh estimate_ionospheric_phase.csh. Implements the
Gomba et al. (2016) split-spectrum method with Fattahi et al. (2017)
filtering. Iono path only — not exercised by the standard test suite.

Usage:  estimate_ionospheric_phase intf_high intf_low intf_orig
                                   intf_to_be_corrected [xratio yratio]
Output: ph_iono.grd, ph_iono_orig.grd, ph_corrected.grd in cwd.
"""
import glob
import os
import subprocess
import sys

from gmtsar_lib import run, grep_value


def _capture(cmd):
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE,
                          check=False).stdout.decode('utf-8').strip()


def _round_half_away(value, period):
    """Return ``value / period`` rounded to the nearest integer, ties away from zero.

    Reproduces the awk idiom ``int(x/p + 0.5)`` / ``int(x/p - 0.5)``; Python's
    ``int`` truncates toward zero just like awk's ``int``.
    """
    return int(value / period + (0.5 if value >= 0 else -0.5))


def _odd_filter_size(length_m, ratio, pixel_m, grid_inc):
    """Return a filter width in grid cells, snapped to the nearest odd integer.

    ``length_m * ratio`` divided by the pixel size and the grid spacing gives a
    width in grid cells; ``int(x / 2) * 2 + 1`` turns it into an odd count.
    """
    return int(length_m * ratio / pixel_m / grid_inc / 2) * 2 + 1


def _grdfilter(src, dst, mode, geom):
    """Run ``gmt grdfilter`` on *src* into *dst* using filter type *mode*.

    *geom* is the ``(filtx, filty, filt_incx, filt_incy)`` tuple from
    ``_filter_geometry``: filter width for ``-F`` (pixels, because of ``-Dp``)
    and output increment for ``-I``.
    """
    filtx, filty, incx, incy = geom
    run(f"gmt grdfilter {src} -Dp -F{mode}{filtx}/{filty} "
        f"-G{dst} -Vq -Ni -I{incx}/{incy}")


def _filter_geometry(intf_h, prm, rx, ry):
    """Return ``(filtx, filty, filt_incx, filt_incy)`` for the iono smoothing filters.

    A fixed 20 km length is converted to grid cells using the range/azimuth
    pixel sizes read from *prm* and the grid spacing of
    ``intf_h/phasefilt.grd``, scaled by the ratio strings *rx*/*ry*.
    """
    wavelength = 20000.0  # metres, assuming the PRM values below are SI units
    rng_samp_rate = float(grep_value(prm, "rng_samp_rate", 3))
    rng_pxl = 299792458.0 / rng_samp_rate / 2.0  # slant-range pixel: c / fs / 2
    prf = float(grep_value(prm, "PRF", 3))
    vel = float(grep_value(prm, "SC_vel", 3))
    azi_pxl = vel / prf  # along-track distance per pulse

    # Columns 8 and 9 of `grdinfo -C` are the x and y increments; one call
    # fetches both.
    incs = _capture(f"gmt grdinfo {intf_h}/phasefilt.grd -C | awk '{{print $8, $9}}'")
    x_inc, y_inc = (float(tok) for tok in incs.split())

    filtx = _odd_filter_size(wavelength, float(rx), rng_pxl, x_inc)
    filty = _odd_filter_size(wavelength, float(ry), azi_pxl, y_inc)
    print(f"Filtering size: {filtx} along range, {filty} along azimuth")
    # NOTE(review): the increments become 0 when a filter is narrower than
    # 8 cells; GMT presumably rejects a zero -I, so tiny filters would fail
    # in grdfilter. Behaviour left unchanged here.
    return filtx, filty, filtx // 8, filty // 8


def _align_unwrap_cycles(intf_h, intf_l, intf_o):
    """Copy the three unwrapped phases locally and shift up_h/up_l onto up_o's 2pi cycle."""
    for name, src in (("up_h", intf_h), ("up_l", intf_l), ("up_o", intf_o)):
        run(f"cp {src}/unwrap.grd ./{name}.grd")

    # Column 12 of `grdinfo -L1 -C` is the median; remove the nearest whole
    # number of 2pi cycles between each sub-band and the full band.
    for side, label in (("h", "high"), ("l", "low")):
        run(f"gmt grdmath up_{side}.grd up_o.grd SUB = tmp.grd")
        c_raw = float(_capture("gmt grdinfo tmp.grd -L1 -C | awk '{print $12}'"))
        c = _round_half_away(c_raw, 6.2831853072)
        print(f"Correcting {label} passed phase by {c} * 2PI ...")
        run(f"gmt grdmath up_{side}.grd {c} 2 PI MUL MUL SUB = tmp.grd")
        run(f"mv tmp.grd up_{side}.grd")


def _build_masks(intf_h, intf_l, thresh):
    """Write mask_up/mask/mask1/mask2.grd and median-filter up_h/up_l in place."""
    # mask_up: 1 where |up_h + up_l - 2*up_o| / 2pi < 0.2, NaN elsewhere.
    run("gmt grdmath up_h.grd up_l.grd ADD up_o.grd 2 MUL SUB = tmp.grd")
    run("gmt grdmath tmp.grd 2 PI MUL DIV ABS 0.2 GE 1 SUB -1 MUL 0 NAN = mask_up.grd")

    # 21x21-pixel median filter of the consistency-masked sub-band phases.
    for side in ("h", "l"):
        run(f"gmt grdmath up_{side}.grd mask_up.grd MUL = tmp.grd")
        run(f"gmt grdfilter tmp.grd -Dp -Fm21/21 -Gup_{side}.grd -Vq -Nr")

    # Common prefix: 1 where the mean of the two coherences >= thresh, else NaN.
    corr_ok = (f"gmt grdmath {intf_h}/corr.grd {intf_l}/corr.grd ADD 2 DIV 0 DENAN "
               f"{thresh} GE 0 NAN")
    # mask: 1 where coherent and mask_up valid, NaN elsewhere.
    run(f"{corr_ok} 0 MUL 1 ADD mask_up.grd MUL = mask.grd")
    # mask1: the same region as 1/0 (no NaNs); mask2 is its complement.
    run(f"{corr_ok} ISNAN 1 SUB -1 MUL mask_up.grd 0 DENAN MUL = mask1.grd")
    run("gmt grdmath mask1.grd 1 SUB -1 MUL = mask2.grd")


def _apply_merge_offsets(intf_h):
    """Apply correct_merge_offset.csh to tmp_ph0.grd when intf_h has merge logs.

    The original tmp_ph0.grd is kept as tmp_ph0_save.grd. Log paths are taken
    from *intf_h* (the same directory whose files are checked and copied)
    rather than a hard-coded ``../intf_h``.
    """
    merge_log = f"{intf_h}/merge_log"
    merge_log1 = f"{intf_h}/merge_log1"
    has_merge_log = os.path.exists(merge_log)
    if not (has_merge_log or os.path.exists(merge_log1)):
        return

    run("cp tmp_ph0.grd tmp_ph0_save.grd")
    if has_merge_log:
        run(f"cp {intf_h}/tmp_phaselist .")
        run(f"correct_merge_offset.csh tmp_phaselist {merge_log} "
            "tmp_ph0.grd tmp_ph0_corrected.grd")
        run("mv tmp_ph0_corrected.grd tmp_ph0.grd")
    else:
        run(f"cp {intf_h}/tmp_phaselist1 .")
        run(f"correct_merge_offset.csh tmp_phaselist1 {merge_log1} "
            "tmp_ph0.grd tmp_ph0_corrected.grd")
        # -f so re-running in the same directory does not trip over an
        # existing tmp_first link.
        run(f"ln -sf {intf_h}/tmp_first .")
        run(f"cp {intf_h}/tmp_phaselist2 .")
        run(f"correct_merge_offset.csh tmp_phaselist2 {intf_h}/merge_log2 "
            "tmp_ph0_corrected.grd tmp_ph0.grd")
        run("rm -f tmp_ph0_corrected.grd")


def _reject_outliers(limit):
    """Restrict the masks to pixels within +/- *limit* of the masked median of tmp_ph0.

    Updates mask.grd, mask1.grd and mask2.grd and writes the masked phase to
    tmp_ph.grd.
    """
    run("gmt grdmath tmp_ph0.grd mask.grd MUL = tmp_ph.grd")
    run("cp tmp_ph.grd tmp_ph1.grd")
    mm = float(_capture("gmt grdinfo tmp_ph1.grd -L1 -C | awk '{print $12}'"))
    # tmp1/tmp2: 1/0 grids for tmp_ph0 <= mm+limit and tmp_ph0 >= mm-limit.
    run(f"gmt grdmath tmp_ph0.grd {mm} {limit} ADD LE = tmp1.grd")
    run(f"gmt grdmath tmp_ph0.grd {mm} {limit} SUB GE = tmp2.grd")
    run("gmt grdmath tmp1.grd tmp2.grd MUL 0 NAN mask.grd MUL = tmp.grd")
    run("mv tmp.grd mask.grd")
    run("gmt grdmath tmp1.grd tmp2.grd MUL 0 NAN ISNAN 1 SUB -1 MUL mask1.grd "
        "MUL 0 DENAN = tmp.grd")
    run("mv tmp.grd mask1.grd")
    run("gmt grdmath tmp_ph0.grd mask.grd MUL = tmp_ph.grd")
    run("gmt grdmath mask1.grd 1 SUB -1 MUL = mask2.grd")


def _iterative_filter(geom):
    """Three rounds of filter + surface + nearest_grid, leaving tmp_ph_interp.grd.

    Odd rounds use a median filter (``m``), even rounds a boxcar (``b``); each
    round's surface result is also kept as tmp_<round>.grd.
    """
    run("nearest_grid tmp_ph.grd tmp_ph_interp.grd")
    for iteration in (1, 2, 3):
        filt_mode = "m" if iteration % 2 == 1 else "b"
        _grdfilter("tmp_ph_interp.grd", "tmp_filt.grd", filt_mode, geom)
        run("gmt grd2xyz tmp_filt.grd -s | gmt surface -Rtmp_ph0.grd -T0.1 -Gtmp.grd")
        run("mv tmp.grd tmp_filt.grd")
        run(f"cp tmp_filt.grd tmp_{iteration}.grd")
        run("gmt grdmath tmp_filt.grd mask.grd MUL = tmp.grd")
        run("nearest_grid tmp.grd tmp2.grd")
        # Masked-out pixels (mask2) take the smoothed value; valid pixels
        # (mask1) keep the raw estimate from tmp_ph0.
        run("gmt grdmath tmp2.grd mask2.grd MUL tmp_ph0.grd 0 DENAN mask1.grd "
            "MUL ADD = tmp_ph_interp.grd")


def _final_filter_and_correct(geom):
    """Final smoothing, then write ph_iono.grd, ph_corrected.grd and ph_iono_orig.grd."""
    wrap = "PI ADD 2 PI MUL MOD PI SUB"  # wrap into [-pi, pi) (GMT MOD is floored)

    region = _capture("gmt grdinfo -I- ph0.grd")     # -R of the target phase
    incr = _capture("gmt grdinfo -I tmp_ph0.grd")    # -I of the estimate grid
    _grdfilter("tmp_ph_interp.grd", "tmp_filt.grd", "b", geom)
    run(f"gmt grd2xyz tmp_filt.grd -s | gmt surface {region} {incr} -Gtmp.grd -T0.1")
    run("mv tmp.grd tmp_filt.grd")
    _grdfilter("tmp_filt.grd", "tmp.grd", "g", geom)
    run("mv tmp.grd tmp_filt.grd")
    run(f"gmt grd2xyz tmp_filt.grd -s | gmt surface {region} {incr} -Gtmp.grd -T0.1")
    run("mv tmp.grd tmp_filt.grd")

    run(f"gmt grdmath tmp_filt.grd {wrap} = tmp_ph.grd")
    run("cp tmp_ph.grd ph_iono.grd")

    # First-pass correction: ph0 - iono, resampled onto ph0's grid.
    run("gmt grdsample tmp_filt.grd -Rph0.grd -Gtmp.grd")
    run(f"gmt grdmath ph0.grd tmp.grd SUB {wrap} = ph_corrected.grd")

    # Shift the iono estimate by the median of the corrected phase, rounded
    # to a whole multiple of pi, then redo the correction.
    cc_raw = float(_capture("gmt grdinfo ph_corrected.grd -L1 -C | awk '{print $12}'"))
    cc = _round_half_away(cc_raw, 3.141592653)
    print(f"Correcting iono phase by {cc} PI ...")
    run(f"gmt grdmath tmp_filt.grd {cc} PI MUL ADD = tmp_ph.grd")
    run(f"gmt grdmath tmp_ph.grd {wrap} = ph_iono.grd")

    run("gmt grdsample tmp_ph.grd -Rph0.grd -Gtmp.grd")
    run(f"gmt grdmath ph0.grd tmp.grd SUB {wrap} = ph_corrected.grd")
    run("mv tmp_ph.grd ph_iono_orig.grd")  # iono estimate before wrapping


def estimate_ionospheric_phase():
    """Estimate the ionospheric phase by split spectrum and remove it from an interferogram.

    Command-line driven (reads ``sys.argv``)::

        intf_high intf_low intf_orig intf_to_be_corrected [xratio yratio]

    ``intf_high`` must hold ``unwrap.grd``, ``corr.grd``, ``phasefilt.grd``,
    ``params1`` (center/high/low frequencies) and at least two ``*PRM`` files;
    ``intf_low`` must hold ``unwrap.grd`` and ``corr.grd``; ``intf_orig`` must
    hold ``unwrap.grd``; ``intf_to_be_corrected/phasefilt.grd`` is the phase
    to correct. The optional ratios (default 1) scale the 20 km smoothing
    length along range and azimuth.

    Writes ph_iono.grd (wrapped iono phase), ph_iono_orig.grd (unwrapped iono
    phase) and ph_corrected.grd to the current directory, plus many
    intermediate tmp*/mask*/up_* grids.

    Exits via ``sys.exit`` with a message on a bad argument count or when
    fewer than two PRM files are found in ``intf_high``.
    """
    if len(sys.argv) not in (5, 7):
        sys.exit(
            "Usage: estimate_ionospheric_phase intf_high intf_low intf_orig "
            "intf_to_be_corrected [xratio yratio]\n"
            "  Estimates ionosphere via split-spectrum (Gomba 2016 + Fattahi 2017)."
        )
    intf_h, intf_l, intf_o, intf = sys.argv[1:5]
    rx, ry = ("1", "1") if len(sys.argv) == 5 else (sys.argv[5], sys.argv[6])

    prms = sorted(glob.glob(f"{intf_h}/*PRM"))
    if len(prms) < 2:
        sys.exit(f"estimate_ionospheric_phase: need at least 2 PRMs in {intf_h}, "
                 f"found {len(prms)}")
    prm1 = prms[0]  # only the first PRM's geometry is needed

    # Split-spectrum band frequencies (kept as text for the grdmath commands).
    params = f"{intf_h}/params1"
    fc = grep_value(params, "center_freq", 3)
    fh = grep_value(params, "high_freq", 3)
    fl = grep_value(params, "low_freq", 3)
    thresh = 0.1  # minimum mean coherence of the two sub-bands

    print(f"Applying split spectrum result to estimate ionospheric phase ({fh} {fl})...")
    run(f"cp {intf}/phasefilt.grd ./ph0.grd")

    geom = _filter_geometry(intf_h, prm1, rx, ry)
    # Outlier window: fh*fl / (fh^2 - fl^2) * pi.
    limit = (float(fh) * float(fl)) / (float(fh) ** 2 - float(fl) ** 2) * 3.1415926

    _align_unwrap_cycles(intf_h, intf_l, intf_o)
    _build_masks(intf_h, intf_l, thresh)

    # Split-spectrum iono phase (Gomba et al. 2016), as evaluated by the RPN:
    #   tmp_ph0 = fl*fh / (fh^2 - fl^2) * (fh/fc * up_l - fl/fc * up_h)
    run(f"gmt grdmath {fh} {fc} DIV up_l.grd MUL {fl} {fc} DIV up_h.grd MUL SUB "
        f"{fl} {fh} MUL {fh} {fh} MUL {fl} {fl} MUL SUB DIV MUL = tmp_ph0.grd")

    _apply_merge_offsets(intf_h)
    _reject_outliers(limit)
    _iterative_filter(geom)
    _final_filter_and_correct(geom)


if __name__ == "__main__":
    # Script entry point; the function returns None, so the exit status is 0
    # on normal completion (errors exit earlier via sys.exit or an exception).
    sys.exit(estimate_ionospheric_phase())
