#!/usr/bin/python3

r'''Visualize the difference in projection between N models

SYNOPSIS

  $ mrcal-show-projection-diff before.cameramodel after.cameramodel
  ... a plot pops up showing how these two models differ in their projections

The operation of this tool is documented at
http://mrcal.secretsauce.net/differencing.html

This tool visualizes the results of mrcal.projection_diff()

It is often useful to compare the projection behavior of two camera models. For
instance, one may want to validate a calibration by comparing the results of two
different chessboard dances. Or one may want to evaluate the stability of the
intrinsics in response to mechanical or thermal stresses. This tool makes these
comparisons, and produces a visualization of the results.

In the most common case we're given exactly 2 models to compare. We then display
the projection difference as either a vector field or a heat map. If we're given
more than 2 models, then a vector field isn't possible and we instead display as
a heatmap the standard deviation of the differences between models 1..N and
model0.

The top-level operation of this tool:

- Grid the imager
- Unproject each point in the grid using one camera model
- Apply a transformation to map this point from one camera's coord system to the
  other. How we obtain this transformation is described below
- Project the transformed points to the other camera
- Look at the resulting pixel difference in the reprojection

Several arguments control how we obtain the transformation. Top-level logic:

  if --intrinsics-only:
      Rt10 = identity_Rt()
  else:
      if --radius 0:
          Rt10 = relative_extrinsics(models)
      else:
          Rt10 = implied_Rt10__from_unprojections()

The details of how the comparison is computed, and the meaning of the arguments
controlling this, are in the docstring of mrcal.projection_diff().

'''

import sys
import argparse
import re
import os

def parse_args(argv=None):
    r'''Parse and sanity-check the commandline of this tool

    ARGUMENTS

    - argv: optional list of argument strings to parse. If omitted or None,
      argparse reads sys.argv[1:], so calling parse_args() with no arguments
      behaves exactly as before

    RETURNED VALUE

    The argparse.Namespace holding the parsed options. args.models is a list of
    at least two model filenames

    The process exits with status 1, with a message on stderr, if fewer than two
    models are given, or if both --title and --extratitle are given. argparse
    itself exits if the arguments are malformed

    '''

    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('--gridn',
                        type=int,
                        default = (60,40),
                        nargs = 2,
                        help='''How densely we should sample the imager. By default we use a 60x40 grid''')
    parser.add_argument('--distance',
                        type=str,
                        help='''Has an effect only without --intrinsics-only.
                        The projection difference varies depending on the range
                        to the observed world points, with the queried range set
                        in this argument. If omitted we look out to infinity. We
                        can also fit multiple distances at the same time by
                        passing them here in a comma-separated, whitespace-less
                        list. If multiple distances are given, we fit the
                        implied-by-the-intrinsics transformation using ALL the
                        distances, but we display the difference for the FIRST
                        distance given. Only one distance is supported if
                        --vectorfield. Multiple distances are especially useful
                        if we have uncertainties: the most confident distance
                        will be found, and displayed.''')

    parser.add_argument('--intrinsics-only',
                        action='store_true',
                        help='''If given, we evaluate the intrinsics of each
                        lens in isolation by assuming that the coordinate
                        systems of each camera line up exactly''')

    parser.add_argument('--where',
                        type=float,
                        nargs=2,
                        help='''Center of the region of interest for this diff.
                        Used only without --intrinsics-only and without
                        "--radius 0". It is usually impossible for the models to
                        match everywhere, but focusing on a particular area can
                        work better. The implied transformation will be fit to
                        match as large as possible an area centered on this
                        argument. If omitted, we will focus on the center of the
                        imager''')
    parser.add_argument('--radius',
                        type=float,
                        default=-1.,
                        help='''Radius of the region of interest. If ==0, we do
                        NOT fit an implied transformation at all, but use the
                        transformations in the models. If omitted or <0, we use
                        a "reasonable" value: the whole imager if we're using
                        uncertainties, or min(width,height)/6 if
                        --no-uncertainties. To fit with data across the whole
                        imager in either case, pass in a very large radius''')
    parser.add_argument('--same-dance',
                        action='store_true',
                        default=False,
                        help=argparse.SUPPRESS)
    # --same-dance IS NOT READY FOR PUBLIC CONSUMPTION. It is purposely
    # undocumented. If given, I assume that the given models all came from the
    # same calibration dance, and I can thus compute implied_Rt10 geometrically,
    # instead of fitting projections. Valid only if we're given exactly two
    # models. Exclusive with --where and --radius (those control the fit that we
    # skip with --same-dance)
    #
    # If putting this back, this was a part of the docstring for this tool:
    #
    # - If we know that our models came from different interpretations of the
    #   same board dance (using different lens models or optimization
    #   parameters, for instance), we can use the geometry of the cameras and
    #   frames to compute implied_Rt10. Indicate this case by passing
    #   --same-dance

    parser.add_argument('--observations',
                        action='store_true',
                        default=False,
                        help='''If given, I show where the chessboard corners were observed at calibration
                        time. These should correspond to the low-diff regions.''')
    parser.add_argument('--valid-intrinsics-region',
                        action='store_true',
                        default=False,
                        help='''If given, I overlay the valid-intrinsics regions onto the plot''')
    parser.add_argument('--cbmax',
                        type=float,
                        default=4,
                        help='''Maximum range of the colorbar''')

    parser.add_argument('--title',
                        type=str,
                        default = None,
                        help='''Title string for the plot. Overrides the default
                        title. Exclusive with --extratitle''')
    parser.add_argument('--extratitle',
                        type=str,
                        default = None,
                        help='''Additional string for the plot to append to the
                        default title. Exclusive with --title''')

    parser.add_argument('--vectorfield',
                        action = 'store_true',
                        default = False,
                        help='''Plot the diff as a vector field instead of as a heat map. The vector field
                        contains more information (magnitude AND direction), but
                        is less clear at a glance''')

    parser.add_argument('--vectorscale',
                        type = float,
                        default = 1.0,
                        help='''If plotting a vectorfield, scale all the vectors by this factor. Useful to
                        improve legibility if the vectors are too small to
                        see''')

    parser.add_argument('--directions',
                        action = 'store_true',
                        help='''If given, the plots are color-coded by the direction of the error, instead of
                        the magnitude''')

    parser.add_argument('--no-uncertainties',
                        action = 'store_true',
                        help='''Used only without --intrinsics-only and without
                        "--radius 0". By default we use the uncertainties in the
                        model to weigh the fit. This will focus the fit on the
                        confident region in the models without --where or
                        --radius. The computation will run faster with
                        --no-uncertainties, but the default --where and --radius
                        may need to be adjusted''')

    parser.add_argument('--hardcopy',
                        type=str,
                        help='''Write the output to disk, instead of making an interactive plot''')
    parser.add_argument('--terminal',
                        type=str,
                        help=r'''gnuplotlib terminal. The default is good almost always, so most people don't
                        need this option''')
    parser.add_argument('--set',
                        type=str,
                        action='append',
                        help='''Extra 'set' directives to gnuplotlib. Can be given multiple times''')
    parser.add_argument('--unset',
                        type=str,
                        action='append',
                        help='''Extra 'unset' directives to gnuplotlib. Can be given multiple times''')

    parser.add_argument('models',
                        type=str,
                        nargs='+',
                        help='''Camera models to diff''')

    # argv is None by default: argparse then falls back to sys.argv[1:]
    args = parser.parse_args(argv)

    # nargs='+' accepts a single model, but a diff needs at least two
    if len(args.models) < 2:
        print(f"I need at least two models to diff. Instead got '{args.models}'", file=sys.stderr)
        sys.exit(1)

    if args.title      is not None and \
       args.extratitle is not None:
        print("--title and --extratitle are exclusive", file=sys.stderr)
        sys.exit(1)

    return args

args = parse_args()

# arg-parsing is done before the imports so that --help works without building
# stuff, so that I can generate the manpages and README

# --vectorscale is meaningful only for the vector-field plot. A non-default
# value without --vectorfield is reported as an error instead of being ignored
if args.vectorscale != 1.0 and not args.vectorfield:
    print("Error: --vectorscale only makes sense with --vectorfield",
          file = sys.stderr)
    sys.exit(1)

# --distance arrives as a comma-separated string. I convert it to a list of
# floats, or leave it as None if the option was omitted
if args.distance is None:
    distance = None
else:
    try:
        distance = [float(d) for d in args.distance.split(',')]
    except ValueError:
        # float() raises ValueError for an unparseable string. I catch only
        # that, so that KeyboardInterrupt and SystemExit propagate normally
        print("Error: distances must be given a comma-separated list of floats in --distance",
              file = sys.stderr)
        sys.exit(1)

    if len(distance) > 1 and args.vectorfield:
        print("Error: --distance can support at most one value if --vectorfield. Not clear how to plot otherwise",
              file = sys.stderr)
        sys.exit(1)

# These options are defined only for a diff of exactly 2 models. parse_args()
# has already rejected fewer than 2
if len(args.models) > 2:
    if args.vectorfield:
        print("Error: --vectorfield works only with exactly 2 models",
              file = sys.stderr)
        sys.exit(1)
    if args.directions:
        print("Error: --directions works only with exactly 2 models",
              file = sys.stderr)
        sys.exit(1)
    if args.same_dance:
        print("Error: --same-dance works only with exactly 2 models",
              file = sys.stderr)
        sys.exit(1)

# --radius defaults to -1, so "args.radius >= 0" means that the user passed a
# radius explicitly
if args.same_dance and \
   (args.where is not None or args.radius >= 0):
    print("Error: --same-dance is exclusive with --where and --radius",
          file = sys.stderr)
    sys.exit(1)


import mrcal
import numpy as np
import numpysane as nps


# Optional plot settings that are passed on to the plotting call. An option is
# included only if it was given on the commandline, so the plotting defaults
# apply otherwise
plotkwargs_extra = \
    { option: getattr(args, option)
      for option in ('set', 'unset', 'title', 'extratitle')
      if getattr(args, option) is not None }

def openmodel(f):
    r'''Load one camera model, or exit the process if that fails

    f is passed to mrcal.cameramodel() as is. On any Exception a message naming
    f and the error is written to stderr, and the process exits with status 1.
    On success the loaded model is returned

    '''
    try:
        model = mrcal.cameramodel(f)
    except Exception as e:
        print(f"Couldn't load camera model '{f}': {e}",
              file=sys.stderr)
        sys.exit(1)
    else:
        return model

# Load every model named on the commandline. openmodel() exits the process if
# any of them can't be loaded
models = [openmodel(modelfilename) for modelfilename in args.models]


# --same-dance: compute the cam1 <-- cam0 transformation (Rt10) from the
# calibration data stored in the models, instead of fitting projections. The
# argument validation above guarantees that exactly 2 models are present when
# --same-dance is given.
#
# Three alternative implementations follow. They are selected with hard-coded
# "if 1 / elif 1 / else", so only the first one ever runs; the other two are
# unreachable and are kept for experimentation.
#
# NOTE(review): the Rt10 computed here is rebound by the
# mrcal.show_projection_diff() call that follows this block, and nothing in
# between reads it. As written, the result of this block appears to be unused.
# Confirm whether it was meant to be passed into the diff computation
if args.same_dance:

    # EXPERIMENTAL, UNDOCUMENTED STUFF
    if 1:

        # default method: fit the corner observations in the coord system of the
        # camera. Each observation contributes a point to the fit


        # I need to make sure to pick the same inlier set when choosing the frame
        # observations. I pick the intersection of the two
        i_f_ci_ce = [model.optimization_inputs()['indices_frame_camintrinsics_camextrinsics'] \
                     for model in models]
        if i_f_ci_ce[0].shape != i_f_ci_ce[1].shape or \
           not np.all(i_f_ci_ce[0] == i_f_ci_ce[1]):
            print(f"--same-dance means that the two models must have an identical indices_frame_camintrinsics_camextrinsics, but they do not. Shapes {i_f_ci_ce[0].shape},{i_f_ci_ce[1].shape}",
                  file=sys.stderr)
            sys.exit(1)

        # Only the layout (shape) of the observations is required to match; the
        # values themselves may differ between the two models
        obs = [model.optimization_inputs()['observations_board'] \
                     for model in models]
        if obs[0].shape != obs[1].shape:
            print(f"--same-dance means that the two models must have an identically-laid-out observations_board, but they do not. Shapes {obs[0].shape},{obs[1].shape}",
                  file=sys.stderr)
            sys.exit(1)

        # Per-model boolean mask: element 2 of the last dimension is > 0.
        # Presumably this element is the observation weight, with outliers
        # marked by a non-positive value -- TODO confirm
        idx_inliers = [ obs[i][ ...,2] > 0 \
                        for i in range(len(models)) ]

        # The product of the two boolean masks is their intersection: a corner
        # is used only if it is an inlier in both models.
        # NOTE(review): [-2] picks the second-to-last value returned by
        # hypothesis_board_corner_positions(); presumably the corner positions
        # for the selected inliers -- verify against the mrcal documentation
        calobjects = \
            [ mrcal.hypothesis_board_corner_positions(model.icam_intrinsics(),
                                                      **model.optimization_inputs(),
                                                      idx_inliers = idx_inliers[0]*idx_inliers[1] )[-2] \
              for model in models ]

        # Rigid alignment of the two point sets: points from model 1 are passed
        # first, points from model 0 second
        Rt10 = mrcal.align_procrustes_points_Rt01(calobjects[1], calobjects[0])

    elif 1:

        # UNREACHABLE: the "if 1" branch above is always taken

        # new method being tested. Fit the corners in the ref coordinate
        # systems. This explicitly models the extra existing rotation, so it
        # should be most correct

        optimization_inputs0 = models[0].optimization_inputs()
        optimization_inputs1 = models[1].optimization_inputs()

        # The board dimensions and spacing are read from model 0 only, and are
        # then used for both models
        calobject_height,calobject_width = optimization_inputs0['observations_board'].shape[1:3]
        calobject_spacing                = optimization_inputs0['calibration_object_spacing']

        # shape (Nh, Nw, 3)
        calibration_object0 = \
            mrcal.ref_calibration_object(calobject_width, calobject_height,
                                         calobject_spacing,
                                         calobject_warp = optimization_inputs0['calobject_warp'])
        calibration_object1 = \
            mrcal.ref_calibration_object(calobject_width, calobject_height,
                                         calobject_spacing,
                                         calobject_warp = optimization_inputs1['calobject_warp'])

        # shape (Nframes, Nh, Nw, 3)
        pcorners_ref0 = \
            mrcal.transform_point_rt( nps.mv( optimization_inputs0['frames_rt_toref'], -2, -4),
                                      calibration_object0 )
        pcorners_ref1 = \
            mrcal.transform_point_rt( nps.mv( optimization_inputs1['frames_rt_toref'], -2, -4),
                                      calibration_object1 )

        # shape (4,3)
        Rt_ref10 = mrcal.align_procrustes_points_Rt01(# shape (N,3)
                                                      nps.clump(pcorners_ref1, n=3),

                                                      # shape (N,3)
                                                      nps.clump(pcorners_ref0, n=3))

        # I have the ref-ref transform. I convert to a cam-cam transform

        icam_extrinsics0 = \
            mrcal.corresponding_icam_extrinsics(models[0].icam_intrinsics(),
                                                **optimization_inputs0)
        icam_extrinsics1 = \
            mrcal.corresponding_icam_extrinsics(models[1].icam_intrinsics(),
                                                **optimization_inputs1)

        # A negative icam_extrinsics is handled by using an identity cam-ref
        # transform. Presumably this means that the camera sits at the
        # reference coordinate system -- TODO confirm
        if icam_extrinsics0 >= 0:
            Rt_cr0 = mrcal.Rt_from_rt(optimization_inputs0['extrinsics_rt_fromref'][icam_extrinsics0])
        else:
            Rt_cr0 = mrcal.identity_Rt()
        if icam_extrinsics1 >= 0:
            Rt_cr1 = mrcal.Rt_from_rt(optimization_inputs1['extrinsics_rt_fromref'][icam_extrinsics1])
        else:
            Rt_cr1 = mrcal.identity_Rt()

        # cam1 <-- ref1 <-- ref0 <-- cam0
        Rt10 = \
            mrcal.compose_Rt(Rt_cr1,
                             Rt_ref10,
                             mrcal.invert_Rt(Rt_cr0))

    else:

        # UNREACHABLE: the "if 1" branch above is always taken

        # default method used in the uncertainty computation. compute the mean
        # of the frame points

        optimization_inputs0 = models[0].optimization_inputs()
        icam_intrinsics0     = models[0].icam_intrinsics()
        icam_extrinsics0     = \
            mrcal.corresponding_icam_extrinsics(icam_intrinsics0, **optimization_inputs0)
        if icam_extrinsics0 >= 0:
            Rt_rc0 = mrcal.invert_Rt(mrcal.Rt_from_rt(optimization_inputs0['extrinsics_rt_fromref'][icam_extrinsics0]))
        else:
            Rt_rc0 = mrcal.identity_Rt()
        Rt_fr0 = mrcal.invert_Rt(mrcal.Rt_from_rt(optimization_inputs0['frames_rt_toref']))

        optimization_inputs1 = models[1].optimization_inputs()
        icam_intrinsics1     = models[1].icam_intrinsics()
        icam_extrinsics1     = \
        mrcal.corresponding_icam_extrinsics(icam_intrinsics1, **optimization_inputs1)
        if icam_extrinsics1 >= 0:
            Rt_cr1 = mrcal.Rt_from_rt(optimization_inputs1['extrinsics_rt_fromref'][icam_extrinsics1])
        else:
            Rt_cr1 = mrcal.identity_Rt()
        Rt_rf1 = mrcal.Rt_from_rt(optimization_inputs1['frames_rt_toref'])

        # not a geometric transformation: the R is not in SO3
        Rt_r1r0 = np.mean( mrcal.compose_Rt(Rt_rf1, Rt_fr0), axis=-3)

        # NOTE(review): Rt_c1c0 is computed but never read below
        Rt_c1c0 = mrcal.compose_Rt( Rt_cr1, Rt_r1r0, Rt_rc0)


        # Push one sample pixel (at 2/5 of the imager size) from camera 0
        # through every frame into camera 1, and average the results.
        # NOTE(review): "distance" is either None or a LIST of floats here, so
        # the multiplication below scales by a list -- confirm this is intended
        q0 = models[0].imagersize()*2./5.
        v0_cam = mrcal.unproject(q0, optimization_inputs0['lensmodel'], optimization_inputs0['intrinsics'][icam_intrinsics0],
                                 normalize = True)
        p0_cam = v0_cam * (1e5 if distance is None else distance)

        # print(f"v0_cam = {v0_cam}")
        # print(f"reprojected = {mrcal.project(v0_cam, optimization_inputs0['lensmodel'], optimization_inputs0['intrinsics'][icam_intrinsics0])}")
        # print(f"reprojected other = {mrcal.project(np.array([-0.2471004,-0.21598248,0.9446126 ]), optimization_inputs0['lensmodel'], optimization_inputs0['intrinsics'][icam_intrinsics0])}")
        # print(f"reprojected other = {mrcal.project(np.array([-0.2471004,-0.21598248,0.9446126 ]), *models[0].intrinsics())}")
        # print(f"p0_cam = {p0_cam}")

        p0_ref = mrcal.transform_point_Rt(Rt_rc0, p0_cam)
        p0_frame  = mrcal.transform_point_Rt(Rt_fr0, p0_ref)
        p1_refall = mrcal.transform_point_Rt(Rt_rf1, p0_frame)
        p1_ref    = np.mean(p1_refall, axis=0)

        # print(f"p0_ref = {p0_ref}")
        # print(f"rt_rf0[0] = {optimization_inputs0['frames_rt_toref'][0]}")
        # print(f"Rt_fr0[0] = {Rt_fr0[0]}")
        # print(f"p0_frame = {p0_frame}")
        # print(f"p1_refall = {p1_refall}")
        # print(f"p1_ref = {p1_ref}")

        # NOTE(review): q1 is used only by the commented-out debugging prints
        p1_cam = mrcal.transform_point_Rt(Rt_cr1, p1_ref)
        q1 = mrcal.project(p1_cam, *models[1].intrinsics())

        # print(f"p1_cam = {p1_cam}")
        # print(f"q1 = {q1}")
        # print(f"before:\n{Rt10}")

        # The ref-ref mean is reported as Rt10, NOT the cam-cam Rt_c1c0
        Rt10 = Rt_r1r0

        # print(f"after:\n{Rt10}")



        # print(f"--same-dance Rt10: {Rt10}")
        # print(f"--same-dance Ninliers = {np.count_nonzero(idx_inliers[0]*idx_inliers[1])} Ntotal = {np.size(idx_inliers[0])}")

        # print(f"idx_inliers[0].shape: {idx_inliers[0].shape}")
        # import gnuplotlib as gp
        # gp.plot(nps.mag(calobjects[0] - calobjects[1]), wait=1)
        # # gp.plot( (calobjects[0] - calobjects[1],), tuplesize=-3, _with='points', square=1, _3d=1, xlabel='x', ylabel = 'y', zlabel = 'z')
        # # import IPython
        # # IPython.embed()
        # # sys.exit()

# --observations needs the optimization_inputs of every model. I check this up
# front to produce a clear error message
if args.observations:
    inputs_per_model = [ m.optimization_inputs() for m in models ]
    have_all_inputs  = all( inputs is not None for inputs in inputs_per_model )
    if not have_all_inputs:
        print("mrcal-show-projection-diff --observations requires optimization_inputs to be available for all models, but this is missing for some models",
              file=sys.stderr)
        sys.exit(1)

# Compute the diff and make the plot. The call returns the plot object and the
# cam1 <-- cam0 transformation that was used
plot,Rt10 = \
    mrcal.show_projection_diff(models,

                               # sampling
                               gridn_width             = args.gridn[0],
                               gridn_height            = args.gridn[1],
                               distance                = distance,

                               # how the transformation is obtained
                               intrinsics_only         = args.intrinsics_only,
                               use_uncertainties       = not args.no_uncertainties,
                               focus_center            = args.where,
                               focus_radius            = args.radius,

                               # what is drawn
                               observations            = args.observations,
                               valid_intrinsics_region = args.valid_intrinsics_region,
                               vectorfield             = args.vectorfield,
                               vectorscale             = args.vectorscale,
                               directions              = args.directions,
                               cbmax                   = args.cbmax,

                               # where it is drawn
                               hardcopy                = args.hardcopy,
                               terminal                = args.terminal,
                               **plotkwargs_extra)

# The transformation is reported only if an implied transformation was
# requested (no --intrinsics-only, and not "--radius 0"), and a single (4,3) Rt
# transformation came back
implied_fit_requested = not args.intrinsics_only and args.radius != 0
if implied_fit_requested and \
   Rt10 is not None and Rt10.shape == (4,3):
    rt10 = mrcal.rt_from_Rt(Rt10)

    print(f"Transformation cam1 <-- cam0:  rotation: {nps.mag(rt10[:3])*180./np.pi:.03f} degrees, translation: {rt10[3:]} m")

# Without --hardcopy the plot is interactive: block until the user closes it
interactive = args.hardcopy is None
if interactive:
    plot.wait()
