#! /usr/bin/env python3
"""p2p_S1_TOPS_doublediff — three consecutive S1 TOPS scenes spanning the same
time interval are combined into two interferograms (merge1: scenes 1->2,
merge2: scenes 2->3) and a phase double-difference (doublediff/phase_diff.grd).

Python port of `p2p_S1_TOPS_doublediff.csh` (bundled with the
S1_SLC_TOPS_Ross_doubledifference test).

Usage:
  p2p_S1_TOPS_doublediff s1.SAFE s1.EOF s2.SAFE s2.EOF s3.SAFE s3.EOF \\
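      config.py pol parallel

SAFE and EOF names are resolved relative to ./raw/; ./topo/dem.grd must exist.
parallel=1 processes the three subswaths of each pair concurrently.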
"""
import os, sys, glob, shutil, subprocess, multiprocessing
from gmtsar_lib import run


def _ann_stem(safe_root, iw, pol):
    """Return the .xml stem in <safe>/annotation/ matching subswath iw and pol.
    csh: `ls */*iw1*$pol*xml | awk '{print substr($1,12,length($1)-15)}'`
    The substr(...,12,len-15) drops 'annotation/' prefix and '.xml' suffix."""
    ann = os.path.join(safe_root, "annotation")
    if not os.path.isdir(ann):
        sys.exit(f"missing {ann}")
    for fn in os.listdir(ann):
        if fn.startswith('._') or fn.startswith('.'):
            continue
        if iw in fn and pol in fn and fn.endswith('.xml'):
            return fn[:-4]
    sys.exit(f"no {iw}/{pol}/.xml in {ann}")


def _patch_config(src, dst):
    """Mirror csh sed: set threshold_geocode=0 and threshold_snaphu=0."""
    if not os.path.isfile(src):
        sys.exit(f"_patch_config: src config missing: {src} — staged config or "
                 f"bundled config.txt not found. Refusing to write empty {dst}.")
    with open(src) as f:
        lines = f.readlines()
    overrides = {"threshold_geocode": "0", "threshold_snaphu": "0"}
    seen, out = set(), []
    for line in lines:
        s = line.lstrip()
        replaced = False
        for k, v in overrides.items():
            if s.startswith(k + " ") or s.startswith(k + "="):
                out.append(f"{k} = {v}\n"); seen.add(k); replaced = True; break
        if not replaced:
            out.append(line)
    for k, v in overrides.items():
        if k not in seen:
            out.append(f"{k} = {v}\n")
    with open(dst, "w") as f:
        f.writelines(out)


def _setup_subswath(dirname, conf_src, src_safe, src_eof, src_xml,
                    aln_safe, aln_eof, aln_xml):
    """Stage one subswath directory: <dirname>/{raw,topo,<config>}.

    Writes a patched copy of `conf_src` into `dirname`, then symlinks the
    shared DEM plus the master and aligned annotation xml, measurement tiff
    and orbit file into <dirname>/raw. SAFE roots and EOF names are relative
    to the case directory's raw/. The xml stems look like
    's1a-iw1-slc-hh-...-005'.

    Each EOF is linked under the name <xml_stem>.EOF rather than its original
    POEORB filename, because pre_proc's S1_TOPS path looks it up by xml stem
    (matching p2p_S1_TOPS_Frame's linkFiles).
    """
    raw = f"{dirname}/raw"
    for sub in (dirname, raw, f"{dirname}/topo"):
        os.makedirs(sub, exist_ok=True)
    _patch_config(conf_src, f"{dirname}/{os.path.basename(conf_src)}")

    # "<target> <link>" arguments for `ln -sf`, applied in this exact order.
    link_specs = [
        f"../../topo/dem.grd {dirname}/topo/",
        f"../topo/dem.grd {raw}/",
    ]
    for safe, eof, stem in ((src_safe, src_eof, src_xml),
                            (aln_safe, aln_eof, aln_xml)):
        link_specs.extend([
            f"../../raw/{safe}/annotation/{stem}.xml {raw}/",
            f"../../raw/{safe}/measurement/{stem}.tiff {raw}/",
            f"../../raw/{eof} {raw}/{stem}.EOF",
        ])
    for spec in link_specs:
        run(f"ln -sf {spec}")


def _run_p2p(dirname, m_xml, a_xml, conf_name):
    """Run `p2p_processing S1_TOPS <m_xml> <a_xml> <conf_name>` inside `dirname`.

    p2p_processing handles the full pipeline, including align_tops
    internally, matching p2p_S1_TOPS_Frame's processOneSubswath. The csh
    original called align_tops + p2p_S1_TOPS.csh separately with short
    S1_<date>_<time>_F<n> prefixes. The Python p2p_processing instead expects
    raw XML stems and re-derives the short prefixes internally via
    renameMasterAlignedForS1tops.

    The previous working directory is restored even if `run` raises (whether
    it raises on a failed command depends on gmtsar_lib.run). Otherwise a
    failure would leave every later relative path resolving inside `dirname`.
    """
    cwd = os.getcwd()
    os.chdir(dirname)
    try:
        run(f"p2p_processing S1_TOPS {m_xml} {a_xml} {conf_name}")
    finally:
        os.chdir(cwd)


def _process_pair(swaths, conf_name, parallel):
    """swaths: list of (dirname, m_xml, m_eof, a_xml, a_eof)."""
    tasks = [(s[0], s[1], s[3], conf_name) for s in swaths]
    if parallel:
        with multiprocessing.Pool(processes=len(tasks)) as pool:
            pool.starmap(_run_p2p, tasks)
    else:
        for t in tasks:
            _run_p2p(*t)


def _merge_dir(merge_name, subswath_dirs, conf_name, det_stitch="0"):
    """Build merge<n>/ from given F<x> subswath dirs and run merge_unwrap_geocode_tops.

    Inside `merge_name` this does the following:
    - links ../<conf_name> and ../topo/dem.grd;
    - links gauss* from the first subswath's intf/*/ directory;
    - writes tmp.filelist with one "<dir>/:<prm1>:<prm2>" row per subswath,
      taking the first two *.PRM files of ../<sw>/intf/*/ in sorted glob
      order;
    - runs `merge_unwrap_geocode_tops tmp.filelist <conf_name> <det_stitch>`.

    Args:
        merge_name: output directory name, e.g. 'merge1'.
        subswath_dirs: subswath directory names in merge order.
        conf_name: config filename expected in the parent directory.
        det_stitch: passed verbatim as the third merge_unwrap_geocode_tops arg.

    Exits (sys.exit) if any subswath has fewer than two PRM files. The
    working directory is restored on every exit path.
    """
    os.makedirs(merge_name, exist_ok=True)
    cwd = os.getcwd()
    os.chdir(merge_name)
    try:
        run(f"ln -sf ../{conf_name} .")
        run("ln -sf ../topo/dem.grd .")
        run(f"ln -sf ../{subswath_dirs[0]}/intf/*/gauss* .")
        rows = []
        for sw in subswath_dirs:
            prms = sorted(glob.glob(f"../{sw}/intf/*/*.PRM"))
            if len(prms) < 2:
                sys.exit(f"{sw}/intf/*/*.PRM: need 2 PRM files, found {len(prms)}")
            pth = os.path.dirname(prms[0]) + "/"
            rows.append(f"{pth}:{os.path.basename(prms[0])}:{os.path.basename(prms[1])}")
        with open("tmp.filelist", "w") as f:
            f.write("\n".join(rows) + "\n")
        run(f"merge_unwrap_geocode_tops tmp.filelist {conf_name} {det_stitch}")
    finally:
        os.chdir(cwd)


def _doublediff():
    """csh:
        cp merge1/phasefilt_ll.grd ./phasefilt_1_temp.grd
        cp merge2/phasefilt_ll.grd ./phasefilt_2_temp.grd
        gmt grdsample phasefilt_2_temp.grd -Rphasefilt_1_temp.grd -Gphasefilt_2.grd
        gmt grdsample phasefilt_1_temp.grd -Rphasefilt_2.grd -Gphasefilt_1.grd
        gmt grdmath phasefilt_1.grd phasefilt_2.grd SUB = temp.grd
        gmt grdmath temp.grd 6.2832 ADD 6.2832 MOD PI SUB = phase_diff.grd
    """
    # Pre-condition: both merges must have produced phasefilt_ll.grd.
    for required in ("merge1/phasefilt_ll.grd", "merge2/phasefilt_ll.grd"):
        if not os.path.exists(required):
            sys.exit(f"doublediff: missing {required} — earlier merge stage failed silently. "
                     f"Cannot compute phase_diff without both pair-wise unwraps.")
    os.makedirs("doublediff", exist_ok=True)
    cwd = os.getcwd()
    os.chdir("doublediff")
    shutil.copy("../merge1/phasefilt_ll.grd", "./phasefilt_1_temp.grd")
    shutil.copy("../merge2/phasefilt_ll.grd", "./phasefilt_2_temp.grd")
    run("gmt grdsample phasefilt_2_temp.grd -Rphasefilt_1_temp.grd -Gphasefilt_2.grd -V")
    run("gmt grdsample phasefilt_1_temp.grd -Rphasefilt_2.grd -Gphasefilt_1.grd -V")
    for f in glob.glob("*_temp*"):
        os.remove(f)
    run("gmt grdmath phasefilt_1.grd phasefilt_2.grd SUB = temp.grd")
    run("gmt grdmath temp.grd 6.2832 ADD 6.2832 MOD PI SUB = phase_diff.grd")
    run("gmt makecpt -Crainbow -T-3.1416/3.1416 -Z > phase.cpt")
    run("grd2kml phase_diff phase.cpt")
    if os.path.exists("temp.grd"):
        os.remove("temp.grd")
    for f in glob.glob("gmt*"):
        try: os.remove(f)
        except OSError: pass
    os.chdir(cwd)


def main():
    """Entry point: stage, process, merge and double-difference three S1 TOPS scenes.

    Expects to run from the case directory containing raw/<SAFE>, raw/<EOF>,
    topo/dem.grd and the config file. SAFE/EOF arguments are relative to raw/.
    Pair 1 (scene1 -> scene2) is processed in F1..F3 and merged into merge1.
    Pair 2 (scene2 -> scene3) is processed in F2_1..F2_3 and merged into
    merge2. Finally, doublediff/ is built from both merges.

    Exits (sys.exit) with a message on a wrong argument count, a non-integer
    `parallel` argument, or a missing topo/dem.grd.
    """
    if len(sys.argv) != 10:
        sys.exit("Usage: p2p_S1_TOPS_doublediff s1.SAFE s1.EOF s2.SAFE s2.EOF "
                 "s3.SAFE s3.EOF config.py pol parallel")
    safe1, eof1, safe2, eof2, safe3, eof3, conf, pol, par = sys.argv[1:10]
    try:
        parallel = int(par) == 1
    except ValueError:
        sys.exit(f"p2p_S1_TOPS_doublediff: parallel must be an integer "
                 f"(1 = parallel, otherwise serial), got {par!r}")
    # Every subswath topo/ dir and both merge dirs symlink this DEM. Fail now
    # rather than leave dangling links for p2p_processing to trip over.
    if not os.path.exists("topo/dem.grd"):
        sys.exit("p2p_S1_TOPS_doublediff: missing topo/dem.grd")
    conf_name = os.path.basename(conf)

    scenes = ((safe1, eof1), (safe2, eof2), (safe3, eof3))
    # stems[i][k]: annotation xml stem of scene i (0-based) for subswath iw<k+1>.
    stems = [[_ann_stem(f"raw/{safe}", f"iw{k}", pol) for k in (1, 2, 3)]
             for safe, _eof in scenes]

    # (subswath dirs, master scene index, aligned scene index, merge dir)
    pairs = ((("F1", "F2", "F3"), 0, 1, "merge1"),
             (("F2_1", "F2_2", "F2_3"), 1, 2, "merge2"))

    # Stage all subswath directories for both pairs first.
    for dirs, m, a, _merge in pairs:
        (m_safe, m_eof), (a_safe, a_eof) = scenes[m], scenes[a]
        for k, dirname in enumerate(dirs):
            _setup_subswath(dirname, conf, m_safe, m_eof, stems[m][k],
                            a_safe, a_eof, stems[a][k])

    # Process each pair's three subswaths.
    for dirs, m, a, _merge in pairs:
        swaths = [(dirname, stems[m][k], scenes[m][1], stems[a][k], scenes[a][1])
                  for k, dirname in enumerate(dirs)]
        _process_pair(swaths, conf_name, parallel)

    # Merge each pair.
    for dirs, _m, _a, merge_name in pairs:
        _merge_dir(merge_name, list(dirs), conf_name)

    # Double difference.
    _doublediff()


# Script entry point: all argument handling and processing happens in main().
if __name__ == "__main__":
    main()
