#!/usr/bin/env python3

# MIT License
#
# Copyright (c) 2018 Sam Kovaka <skovaka@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys                         
import os
import argparse
import time
import re
import time
import subprocess
import traceback
import uncalled as unc
import numpy as np

MAX_SLEEP = 0.01

def index_cmd(args):
    """Build (or reuse) the BWA index and write the parameter presets.

    Reads ``args.fasta_filename``, ``args.bwa_prefix``, ``args.probs`` and
    ``args.speeds``. When ``args.bwa_prefix`` is None it is set (in place) to
    ``args.fasta_filename``. The BWA index is only created when at least one
    file ``bwa_prefix + suffix`` for the suffixes in ``unc.index.BWA_SUFFS``
    is missing.

    ``args.probs`` and ``args.speeds`` are optional comma-separated lists of
    numbers. Each entry adds a ``prob_<value>`` / ``speed_<value>`` preset in
    addition to the ``default`` preset. An entry that cannot be added is
    reported on stderr (with the reason) and skipped; the remaining presets
    are still written.
    """
    if args.bwa_prefix is None:
        args.bwa_prefix = args.fasta_filename

    # The index counts as built only if every expected BWA file is present.
    bwa_built = all(
        os.path.exists(args.bwa_prefix + suff) for suff in unc.index.BWA_SUFFS
    )

    if bwa_built:
        sys.stderr.write("Using previously built BWA index.\nNote: to fully re-build the index delete files with the \"%s.*\" prefix.\n" % args.bwa_prefix)
    else:
        unc.BwaIndex.create(args.fasta_filename, args.bwa_prefix)

    sys.stderr.write("Initializing parameter search\n")
    p = unc.index.IndexParameterizer(args)

    p.add_preset("default", tgt_speed=115)

    # Probability presets are processed before speed presets. The keyword
    # passed to add_preset is "tgt_prob" or "tgt_speed" respectively.
    for kind, targets in (("prob", args.probs), ("speed", args.speeds)):
        if targets is None:
            continue

        for tgt in targets.split(","):
            name = "%s_%s" % (kind, tgt)
            sys.stderr.write("Writing '%s' parameters\n" % name)
            try:
                p.add_preset(name, **{"tgt_" + kind: float(tgt)})
            # Deliberately best-effort: skip the bad preset, but do not
            # swallow KeyboardInterrupt/SystemExit as a bare except would.
            except Exception as e:
                sys.stderr.write("Failed to add '%s': %s\n" % (name, e))

    p.write()

    sys.stderr.write("Done\n")

def fast5_path(fname):
    """Return the absolute path of a fast5 file, or None if `fname` is not one.

    None is returned silently when `fname` starts with "#" (treated as a
    comment line) or does not end with the ".fast5" extension. When the name
    looks like a fast5 file but does not refer to an existing regular file, a
    warning is written to stderr and None is returned.
    """
    # Require the full ".fast5" extension (with the dot), matching the check
    # load_fast5s uses, so names such as "readsfast5" are not accepted.
    if fname.startswith("#") or not fname.endswith(".fast5"):
        return None

    path = os.path.abspath(fname)
    if not os.path.isfile(path):
        sys.stderr.write("Warning: \"%s\" is not a fast5 file.\n" % fname)
        return None

    return path

def load_fast5s(fast5s, recursive):
    """Yield the result of fast5_path() for every candidate file in `fast5s`.

    Each entry of `fast5s` is stripped of surrounding whitespace and may be:
      * a directory - every file in it is a candidate (the whole tree when
        `recursive` is true, otherwise only its direct entries),
      * a path ending in ".fast5" - used directly,
      * anything else - opened as a text file with one candidate per line.

    Yielded values are whatever fast5_path() returns, so callers must skip
    None. If an entry does not exist, an error is written to stderr and the
    process exits with status 1 (when the generator reaches that entry).
    """
    for entry in fast5s:
        entry = entry.strip()

        if not os.path.exists(entry):
            sys.stderr.write("Error: \"%s\" does not exist\n" % entry)
            sys.exit(1)

        if os.path.isdir(entry):
            if recursive:
                # Walk the full directory tree below the entry
                for parent, _subdirs, names in os.walk(entry):
                    yield from (
                        fast5_path(os.path.join(parent, name)) for name in names
                    )
            else:
                # Only look at the directory's direct entries
                yield from (
                    fast5_path(os.path.join(entry, name))
                    for name in os.listdir(entry)
                )
            continue

        if entry.endswith(".fast5"):
            # A single fast5 file named directly
            yield fast5_path(entry)
            continue

        # Otherwise treat the entry as a text file listing fast5 filenames
        with open(entry) as listing:
            yield from (fast5_path(line.strip()) for line in listing)

def assert_exists(fname):
    """Exit with status 1 (after an error on stderr) unless `fname` exists."""
    if os.path.exists(fname):
        return

    sys.stderr.write("Error: '%s' does not exist\n" % fname)
    sys.exit(1)

def map_cmd(conf, args):
    """Map every fast5 named by `args.fast5s` and print the results as PAF.

    Checks that the index files `conf.bwa_prefix + ".bwt"` / `".uncl"` exist
    (and `conf.read_list`, when it is non-empty), loads the fast5 files into
    a `unc.MapPool`, then polls the pool until it stops running, printing
    each result it returns. A keyboard interrupt ends the polling loop early;
    in either case the pool is stopped before returning.
    """
    for suffix in (".bwt", ".uncl"):
        assert_exists(conf.bwa_prefix + suffix)

    if len(conf.read_list) > 0:
        assert_exists(conf.read_list)

    sys.stderr.flush()

    mapper = unc.MapPool(conf)

    sys.stderr.write("Loading fast5s\n")
    for fast5 in load_fast5s(args.fast5s, args.recursive):
        # load_fast5s yields None for entries that are not usable fast5 files
        if fast5 is None:
            continue
        mapper.add_fast5(fast5)

    sys.stderr.flush()

    sys.stderr.write("Mapping\n")
    sys.stderr.flush()

    try:
        while mapper.running():
            loop_start = time.time()

            for result in mapper.update():
                result.print_paf()

            # Sleep for the rest of the MAX_SLEEP interval so that each
            # polling iteration takes at least MAX_SLEEP seconds.
            elapsed = time.time() - loop_start
            if elapsed < MAX_SLEEP:
                time.sleep(MAX_SLEEP - elapsed)
    except KeyboardInterrupt:
        pass

    sys.stderr.write("Finishing\n")
    mapper.stop()

def realtime_cmd(conf, args):
    """Run the real-time (or simulated) chunk-processing loop.

    Used for both the "sim" and "realtime" subcommands; `args.subcmd == "sim"`
    selects a `unc.ClientSim` fed from fast5 files, anything else selects a
    `unc.minknow_client.Client` connected to `conf.host`/`conf.port`.

    Each loop iteration:
      1. collects finished results from `pool.update()`, tells the client to
         either unblock the read or stop receiving it, and prints the PAF;
      2. fetches new chunks from the client and adds them to the pool;
      3. stops once `client.get_runtime()` reaches the configured duration.

    The loop also ends when `client.is_running` becomes false, on a keyboard
    interrupt, or on any other Exception (whose traceback is written to
    stderr rather than re-raised). Afterwards a live client is reset (unless
    that already happened) and the pool is stopped. The process exits with
    status 1 if the index files are missing, the read_until module is not
    loaded, or `client.run()` returns a false value; since SystemExit is not
    caught here, the cleanup at the end is skipped in those cases.
    """

    #TODO replace with conf mode
    sim = args.subcmd == "sim"

    assert_exists(conf.bwa_prefix + ".bwt")
    assert_exists(conf.bwa_prefix + ".uncl")

    # Start as None so the cleanup after the try block can tell whether
    # each object was actually created.
    pool = None
    client = None

    try:
        if sim:
            client = unc.ClientSim(conf)
            unc.sim_utils.load_sim(client, conf)

            # load_fast5s yields None for entries that are not usable files
            for fast5 in load_fast5s(args.fast5s, args.recursive):
                if fast5 != None:
                    client.add_fast5(fast5)

            client.load_fast5s()

        else:
            if not unc.minknow_client.ru_loaded:
                sys.stderr.write("Error: read_until module not installed. Please install \"read_until_api\" submodule.\n")
                sys.exit(1)
            client = unc.minknow_client.Client(conf.host, conf.port, conf.chunk_time, conf.num_channels)

        if not client.run():
            sys.exit(1)

        deplete = conf.realtime_mode == unc.RealtimePool.DEPLETE
        even = conf.active_chs == unc.RealtimePool.EVEN #TODO: do within mapper

        # raw_type is only used when building unc.Chunk objects from live
        # data below, so it is left undefined in sim mode.
        if not sim:
            raw_type = str(client.signal_dtype)

        pool = unc.RealtimePool(conf)

        # Per-channel bookkeeping, indexed by channel-1 (channel numbers are
        # presumably 1-based - TODO confirm):
        #   chunk_times: when the latest chunk for the channel was queued
        #   unblocked:   read number of the last read unblocked on the channel
        chunk_times = [time.time() for c in range(conf.num_channels)]
        unblocked = [None for c in range(conf.num_channels)]

        # No duration (None or 0) means run without a time limit.
        # NOTE(review): the *60*60 suggests conf.duration is in hours and
        # client.get_runtime() in seconds - confirm units.
        if conf.duration == None or conf.duration == 0:
            end_time = float("inf")
        else:
            end_time = conf.duration*60*60

        while client.is_running:
            t0 = time.time()

            # Act on every result the pool has finished since the last pass
            for ch, nm, paf in pool.update():
                # Time elapsed since the last chunk was queued on this channel
                t = time.time()-chunk_times[ch-1]
                if paf.is_ended():
                    paf.set_float(unc.Paf.ENDED, t)
                    client.stop_receiving_read(ch, nm)

                # Reject when "mapped" and "deplete" agree: the read mapped
                # while depleting, or did not map while not depleting.
                elif (paf.is_mapped() and deplete) or not (paf.is_mapped() or deplete):

                    # The simulator always unblocks; a live client is asked
                    # first whether it should eject.
                    if sim or client.should_eject():
                        paf.set_float(unc.Paf.EJECT, t)
                        u = client.unblock_read(ch, nm)

                        # Only the simulator's return value is recorded
                        if sim:
                            paf.set_int(unc.Paf.DELAY, u)

                        # Remember the read so later chunks from it are skipped
                        unblocked[ch-1] = nm
                    else:
                        paf.set_float(unc.Paf.IN_SCAN, t)
                        client.stop_receiving_read(ch, nm)

                # Otherwise the read is kept: just stop receiving its chunks
                else:
                    paf.set_float(unc.Paf.KEEP, t)
                    client.stop_receiving_read(ch, nm)

                paf.print_paf()

            # Feed new chunks to the pool. The two branches differ only in
            # how chunks are fetched and in whether they must be wrapped in
            # a unc.Chunk first.
            if sim:
                read_batch = client.get_read_chunks()
                for channel, read in read_batch:
                    # With "even" set, odd-numbered channels are not processed
                    if even and channel % 2 == 1:
                        client.stop_receiving_read(channel, read.number)
                    else:
                        # Ignore chunks from a read that was already unblocked
                        if unblocked[channel-1] == read.number:
                            sys.stdout.write("# recieved chunk from %s after unblocking\n" % read.id)
                            continue

                        chunk_times[channel-1] = time.time()
                        pool.add_chunk(read)

            else:

                read_batch = client.get_read_chunks(batch_size=client.queue_length)
                for channel, read in read_batch:
                    # With "even" set, odd-numbered channels are not processed
                    if even and channel % 2 == 1:
                        client.stop_receiving_read(channel, read.number)
                    else:
                        # Ignore chunks from a read that was already unblocked
                        if unblocked[channel-1] == read.number:
                            sys.stdout.write("# recieved chunk from %s after unblocking\n" % read.id)
                            continue

                        chunk_times[channel-1] = time.time()
                        pool.add_chunk(unc.Chunk(read.id, 
                                                     channel, 
                                                     read.number,
                                                     read.chunk_start_sample,
                                                     raw_type,
                                                     read.raw_data))


            # Time limit reached: reset a live client here and clear the
            # reference so the cleanup below does not reset it a second time.
            if client.get_runtime() >= end_time:
                if not sim:
                    client.reset()
                client = None
                break

            # Sleep for the rest of the MAX_SLEEP interval so each iteration
            # takes at least MAX_SLEEP seconds.
            dt = time.time() - t0;
            if dt < MAX_SLEEP:
                time.sleep(MAX_SLEEP - dt);

    except KeyboardInterrupt:
        sys.stderr.write("Keyboard interrupt\n")

    # Any other error is reported but not re-raised, so cleanup still runs
    except Exception as e:
        sys.stderr.write(traceback.format_exc())

    #client.log("Finished")

    # Cleanup: reset a live client that is still referenced, then stop the pool
    if client != None and not sim:
        client.reset()

    if pool != None:
        pool.stop_all()

def sim_cmd(args):
    """Call `unc.sim_utils.load_sim()` with no arguments; `args` is unused.

    NOTE(review): realtime_cmd calls `load_sim(client, conf)` with two
    arguments, and the `__main__` dispatch in this file routes the "sim"
    subcommand to realtime_cmd rather than to this function, so this looks
    like leftover code - confirm before relying on it.
    """
    unc.sim_utils.load_sim()

#TODO fix
def list_ports_cmd(args):
    """Print device, start time and gRPC port for each started instance.

    Looks in `args.log_dir` for files whose names start with
    "mk_manager_svc" (presumably MinKNOW manager service logs) and reads the
    one whose filename sorts last. For every line that matches the
    "instance_started" log pattern, writes "<device> (<timestamp>): <port>"
    to stdout. Matching lines that lack a port or device field are reported
    on stderr and skipped.

    Exits with status 1 (after an error on stderr) when the directory
    contains no such log file.
    """
    # Raw strings keep the regex escapes (\-, \d, \s) away from Python's own
    # string-escape processing.
    log_re = re.compile(r"^([0-9\-]+ [0-9:]+).+ :.+instance_started.+")
    port_re = re.compile(r"grpc_port = (\d+)")
    device_re = re.compile(r"(device_id|instance) = ([^\s,]+)")

    log_fnames = [
        f for f in os.listdir(args.log_dir) if f.startswith("mk_manager_svc")
    ]
    if not log_fnames:
        sys.stderr.write("Error: no \"mk_manager_svc\" log files found in \"%s\"\n" % args.log_dir)
        sys.exit(1)

    # The filename that sorts last is taken to be the latest log.
    latest_log = os.path.join(args.log_dir, max(log_fnames))

    with open(latest_log) as log_file:
        for line in log_file:
            lm = log_re.match(line)
            if not lm:
                continue

            pm = port_re.search(line)
            dm = device_re.search(line)

            if pm is None or dm is None:
                sys.stderr.write("Error: failed to parse \"%s\"\n" % line.strip())
                continue

            timestamp = lm.group(1)
            port = pm.group(1)
            device = dm.group(2)

            sys.stdout.write("%s (%s): %s\n" % (device, timestamp, port))



if __name__ == "__main__":
    # Load the parser, configuration and parsed arguments from the command
    # line, then route to the handler for the selected subcommand.
    parser, conf, args = unc.args.load_conf(sys.argv)

    subcommands = {
        "index": lambda: index_cmd(args),
        "map": lambda: map_cmd(conf, args),
        "sim": lambda: realtime_cmd(conf, args),
        "realtime": lambda: realtime_cmd(conf, args),
        "list-ports": lambda: list_ports_cmd(args),
        "pafstats": lambda: unc.pafstats.run(args),
    }

    # Any other (or missing) subcommand falls back to the usage text.
    run_subcmd = subcommands.get(args.subcmd, lambda: parser.print_help())
    run_subcmd()
        
