#! /usr/bin/env python3
"""correct_merge_offset — subtract sub-swath offsets from a merged grid.

Python port of csh correct_merge_offset.csh. After merging 2 or 3 TOPS
sub-swaths, residual offsets across stitching boundaries can remain
(e.g. due to range-time differences). This estimates each offset from the
difference between narrow strips on either side of the boundary and
subtracts it from the sub-swath(s) to the right of that boundary.

Usage:  correct_merge_offset merge_list merge_log input.grd output.grd
"""
import os
import subprocess
import sys

from gmtsar_lib import run


def _capture(cmd):
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE,
                          check=False).stdout.decode("utf-8").strip()


def _grdinfo_field(grd, field):
    """Return column *field* (1-based, awk numbering) of ``gmt grdinfo -C``.

    The value comes back as a stripped string; callers convert it to a
    number themselves. An empty string means the command printed nothing.
    """
    awk_program = "'{print $" + str(field) + "}'"
    command = "gmt grdinfo " + str(grd) + " -C | awk " + awk_program
    return _capture(command)


def _read_merge_list(merge_list):
    """Return the grid names listed in *merge_list* (2 or 3 of them).

    Only lines containing ``:`` are considered. The grid name is everything
    after the first ``:``. The entry count is validated *before* any line is
    indexed, so an empty list produces the intended error message instead of
    an IndexError.
    """
    with open(merge_list) as f:
        lines = [ln.strip() for ln in f if ":" in ln]
    if len(lines) not in (2, 3):
        sys.exit("incorrect number of files in merge_list (expected 2 or 3)")
    return [ln.split(":", 1)[1] for ln in lines]


def _to_number(text, what, kind=float):
    """Convert shell output *text* with *kind*, exiting with a clear message on failure.

    The shell helpers return "" when a command fails or matches nothing. That
    would otherwise surface as a bare ``ValueError: could not convert ''``.
    """
    try:
        return kind(text)
    except ValueError:
        sys.exit(f"correct_merge_offset: could not read {what} (got {text!r})")


def _tmp_path(out, tag):
    """Return the scratch path ``<tag>_<basename(out)>`` in the directory of *out*.

    Prefixing the basename rather than the whole path keeps the scratch file
    in an existing directory when *out* contains a directory component.
    """
    head, tail = os.path.split(out)
    return os.path.join(head, f"{tag}_{tail}")


def _boundary_offset(inp, stitch, incx, ymin, ymax, wid, spc,
                     tmp_left, tmp_right, diff_grd):
    """Estimate the offset across the stitch at column *stitch* of *inp*.

    Cuts one strip *wid* columns wide ending *spc* columns left of the stitch
    and one starting *spc* columns right of column ``stitch + 1``. It then
    relabels the right strip onto the left strip's region so the two can be
    subtracted node by node. The return value is column 12 of
    ``gmt grdinfo -L1 -C`` on (right - left), i.e. the reported median.
    Writes *tmp_left*, *tmp_right* and *diff_grd* as scratch files.
    """
    x_left = incx * (stitch - spc - wid)
    x_left2 = incx * (stitch - spc)
    x_right = incx * (stitch + 1 + spc)
    x_right2 = incx * (stitch + 1 + spc + wid)
    run(f"gmt grdcut {inp} -R{x_left}/{x_left2}/{ymin}/{ymax} -G{tmp_left}")
    run(f"gmt grdcut {inp} -R{x_right}/{x_right2}/{ymin}/{ymax} -G{tmp_right}")
    # Give the right strip the left strip's region so grdmath can difference them.
    run(f"gmt grdedit {tmp_right} -R{tmp_left} -G{tmp_right}")
    run(f"gmt grdmath {tmp_right} {tmp_left} SUB = {diff_grd}")
    median = _capture(f"gmt grdinfo {diff_grd} -L1 -C | awk '{{print $12}}'")
    return _to_number(median, f"median offset from {diff_grd}")


def correct_merge_offset():
    """Command-line entry point: ``correct_merge_offset merge_list merge_log input.grd output.grd``.

    Steps:

    1. Read the sub-swath grids from *merge_list* (2 or 3 entries) and the
       ``n1``/``n2``/``ovl`` values from *merge_log*.
    2. Use these to locate the stitch column(s) in input.grd.
    3. Estimate the offset across each stitch with ``_boundary_offset``.
    4. Subtract the offset from everything right of that stitch. The third
       sub-swath gets both offsets.
    5. Paste the pieces back into output.grd.

    Scratch grids are written next to output.grd and removed in a
    ``finally`` block.
    """
    if len(sys.argv) != 5:
        sys.exit(
            "Usage: correct_merge_offset merge_list merge_log input.grd output.grd\n"
            "  Removes residual offsets at sub-swath stitching boundaries."
        )
    merge_list, merge_log, inp, out = sys.argv[1:5]
    wid, spc = 5, 5  # strip width and gap from the stitch, in columns

    grids = _read_merge_list(merge_list)
    nf = len(grids)
    grd1, grd2 = grids[0], grids[1]  # grd2 is only used when nf == 3

    # grdinfo -C columns: 2..5 = w/e/s/n, 8 = x increment, 10 = n_columns.
    incx = _to_number(_grdinfo_field(inp, 8), f"x increment of {inp}")
    xmin = _to_number(_grdinfo_field(inp, 2), f"x_min of {inp}")
    xmax = _to_number(_grdinfo_field(inp, 3), f"x_max of {inp}")
    ymin = _to_number(_grdinfo_field(inp, 4), f"y_min of {inp}")
    ymax = _to_number(_grdinfo_field(inp, 5), f"y_max of {inp}")

    nx1 = _to_number(_grdinfo_field(grd1, 10), f"column count of {grd1}", int)
    n1 = _to_number(
        _capture(f"grep n1 {merge_log} | awk '{{print $NF}}'"),
        f"n1 from {merge_log}", int)
    ovl12 = _to_number(
        _capture(f"grep ovl {merge_log} | awk -F: '{{print $2}}' | awk -F, '{{print $1}}'"),
        f"first overlap from {merge_log}", int)
    stitch_position1 = nx1 + n1 - ovl12
    # NOTE(review): x = incx * column assumes the merged grid starts at x = 0.
    position1 = incx * stitch_position1

    stitch_position2 = position2 = None
    if nf == 3:
        nx2 = _to_number(_grdinfo_field(grd2, 10), f"column count of {grd2}", int)
        n2 = _to_number(
            _capture(f"grep n2 {merge_log} | awk '{{print $NF}}'"),
            f"n2 from {merge_log}", int)
        ovl23 = _to_number(
            _capture(f"grep ovl {merge_log} | awk -F: '{{print $2}}' | awk -F, '{{print $2}}'"),
            f"second overlap from {merge_log}", int)
        stitch_position2 = nx1 + nx2 - 1 - ovl12 - ovl23 + n2
        position2 = incx * stitch_position2

    print(f"Stitch positions {stitch_position1} {stitch_position2} ...")

    tmp1, tmp2, tmp3, tmp4 = (_tmp_path(out, f"tmp{i}") for i in range(1, 5))
    diff12 = _tmp_path(out, "tmp12_diff")
    diff23 = _tmp_path(out, "tmp23_diff")

    try:
        diff1 = _boundary_offset(inp, stitch_position1, incx, ymin, ymax,
                                 wid, spc, tmp1, tmp2, diff12)
        if nf == 2:
            run(f"gmt grdcut {inp} -R{xmin}/{position1}/{ymin}/{ymax} -G{tmp1}")
            run(f"gmt grdcut {inp} -R{position1}/{xmax}/{ymin}/{ymax} -G{tmp2}")
            print(f"Correcting second image by {diff1} ...")
            run(f"gmt grdmath {tmp2} {diff1} SUB = {tmp2}")
            run(f"gmt grdpaste {tmp1} {tmp2} -G{out}")
        else:  # nf == 3
            diff2 = _boundary_offset(inp, stitch_position2, incx, ymin, ymax,
                                     wid, spc, tmp1, tmp2, diff23)
            run(f"gmt grdcut {inp} -R{xmin}/{position1}/{ymin}/{ymax} -G{tmp1}")
            run(f"gmt grdcut {inp} -R{position1}/{position2}/{ymin}/{ymax} -G{tmp2}")
            run(f"gmt grdcut {inp} -R{position2}/{xmax}/{ymin}/{ymax} -G{tmp3}")
            print(f"Correcting second image by {diff1} ...")
            print(f"Correcting third image by {diff1} + {diff2} ...")
            run(f"gmt grdmath {tmp2} {diff1} SUB = {tmp2}")
            # diff2 was measured before swath 2 was shifted, so swath 3 needs both.
            run(f"gmt grdmath {tmp3} {diff1} SUB {diff2} SUB = {tmp3}")
            run(f"gmt grdpaste {tmp1} {tmp2} -G{tmp4}")
            run(f"gmt grdpaste {tmp4} {tmp3} -G{out}")
    finally:
        # rm -f tolerates files that were never created on this path.
        run(f"rm -f {tmp1} {tmp2} {tmp3} {tmp4} {diff12} {diff23}")


if __name__ == "__main__":
    # Script entry point. The function parses sys.argv itself and returns
    # None, so the process exit status stays 0 on success.
    sys.exit(correct_merge_offset())
