#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Onboard is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Onboard is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright © 2012, marmuta <marmvta@gmail.com>
#
# This file is part of pypredict, Onboard.
#
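# Example invocation (script and model file names are hypothetical):
#
#   python3 filter_model.py -p 1,1,2 -l en_US -N lang.lm lang_out.lm
#
# prunes 1- and 2-grams with counts <= 1 and 3-grams with counts <= 2,
# then drops misspelled words as well as capitalized words whose
# lowercase variant the en_US hunspell dictionary also knows and which
# occur mostly in lowercase.
#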

import sys
import re
import subprocess
from optparse import OptionParser
from collections import Counter

from pypredict import *


def main():
    parser = OptionParser(usage="Usage: %prog [options] model_in model_out")

    # pruning
    parser.add_option("-p", "--prune-counts",
              type="str", dest="prune_counts", default="",
              help="prune n-grams with counts below or equal <prune-count>")
    parser.add_option("-u", "--max-unigrams", 
              type="int", dest="max_unigrams", default=0,
              help="prune n-grams with counts below or equal the one of the "
                   "least frequent of the top <max-unigrams> unigrams;"
                   "default 0, disabled")

    # regex filters
    parser.add_option("-r", "--regex-unigram",
              type="str", dest="regex_unigram", default="",
              help="regular expression filter applied to unigrams. Matches are dropped.")
    parser.add_option("-n", "--regex-ngram",
              type="str", dest="regex_ngram", default="",
              help="regular expression filter applied to the space separated "
                   "ngram string. Matches are dropped.")
    parser.add_option("-t", "--title-case",
              action="store_true", dest="title_case", default=False,
              help="Keep only capitalized words for bigrams and up")

    # spell check
    parser.add_option("-l", "--language", type="str", dest="lang_id", default="",
              help="language id for the spell checker, e.g. en_US")
    parser.add_option("-N", "--filter-names",
              action="store_true", dest="filter_names", default=False,
              help="Remove capitalized word if the lower caps variant is "
                   "known to the spell checker too.")
    parser.add_option("-x", "--name-exceptions",
              type="str", dest="name_exceptions", default="",
              help="Exempt a comma separated list of words from being "
                   "removed by --filter-names")
    parser.add_option("-i", "--lu-ratio",
              type="float", dest="lu_ratio", default=5.0,
              help="Max. ratio of uppercase to lowercase occurrences of "
                   "capitalized words. Increase to drop more upper case words.")

    # diagnose
    parser.add_option("-d", "--diagnose-ngram",
              type="str", dest="diagnose_ngram", default="",
              help="Test for existence of a comma-separated n-gram "
                   "at every filter step.")

    # not currently used
    parser.add_option("-v", "--vocabulary", type="str", dest="vocabulary_file",
              help="list of words to consider during model creation")
    parser.add_option("-c", "--caps-ngram-len",
              type="int", dest="caps_bigram_len", default=0,
              help="capitalized word filter for bigrams, len=min word length")

    # options
    parser.add_option("-S", "--save-sorted",
              action="store_true", dest="save_sorted", default=False,
              help="Load and re-save the final model to take advantage of"
                   "unigram-sorting on load. Also verifies file integrity.")
    parser.add_option("-q", "--quiet",
              action="store_true", dest="quiet", default=False,
              help="only show the final summary")
    options, args = parser.parse_args()

    if len(args) < 2:
        parser.error("expected model_in and model_out arguments")

    vocabulary = read_vocabulary(options.vocabulary_file) \
                 if options.vocabulary_file else None

    out = None if options.quiet else sys.stdout
    caps_bigram_len = options.caps_bigram_len
    model_in_filename = args[0]
    model_out_filename = args[1]
    lang_id = options.lang_id
    diagnose_ngram = None
    if options.diagnose_ngram:
        diagnose_ngram = options.diagnose_ngram.split(",")

    spell_checker = None
    if lang_id:
        spell_checker = SpellChecker()
        spell_checker.set_backend(0)
        if not spell_checker.set_dict_ids([lang_id]):
            print("No spell checker dictionary found for '{}'".format(lang_id),
                  file = sys.stderr)
            sys.exit(1)

    with timeit("loading " + model_in_filename, out):

        if read_order(model_in_filename) == 1:
            model = UnigramModel()
        else:
            model = DynamicModel()
        model.load(model_in_filename)

    check_ngram(model, diagnose_ngram)

    if options.max_unigrams:
        with timeit("prune by max unigrams", out):
            # Find the count of the least frequent of the top
            # <max-unigrams> unigrams in the model.
            cnt = Counter({it[0][0]: it[1]
                           for it in model.iter_ngrams()
                           if len(it[0]) == 1})
            most_common = cnt.most_common(options.max_unigrams)
            if most_common:
                _prune_count = most_common[-1][1]
                # apply the same threshold to every n-gram level, in the
                # list form prune() also accepts for --prune-counts
                model = model.prune([_prune_count] * model.order)

    check_ngram(model, diagnose_ngram)

    if options.prune_counts:
        prune_counts = [int(c) for c in options.prune_counts.split(",")]
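        # e.g. "-p 1,1,2" prunes 1- and 2-grams with counts <= 1 and
        # 3-grams with counts <= 2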
        with timeit("prune by count", out):
            model = model.prune(prune_counts)

    check_ngram(model, diagnose_ngram)

    if options.regex_unigram:
        with timeit("regex unigram filter", out):
            model = regex_filter_unigram(model, options.regex_unigram)

    check_ngram(model, diagnose_ngram)

    if options.regex_ngram:
        with timeit("regex joined ngram filter", out):
            model = regex_filter_ngram(model, options.regex_ngram)

    check_ngram(model, diagnose_ngram)

    if options.title_case:
        with timeit("title case filter", out):
            model = title_case_filter(model)

    check_ngram(model, diagnose_ngram)

    if caps_bigram_len:
        with timeit("caps ngram filter", out):
            model = caps_bigram_filter(model, caps_bigram_len)

    check_ngram(model, diagnose_ngram)

    if spell_checker:
        name_exceptions = options.name_exceptions.split(",")
        model = spell_check_model(model, spell_checker,
                                  options.filter_names,
                                  name_exceptions, 
                                  options.lu_ratio,
                                  out)

    check_ngram(model, diagnose_ngram)

    with timeit("saving " + model_out_filename, out):
        model.save(model_out_filename)

    if options.save_sorted:
        with timeit("loading " + model_out_filename, out):
            model.load(model_out_filename)

        with timeit("saving " + model_out_filename, out):
            model.save(model_out_filename)

    print_stats(model)

def check_ngram(model, ngram):
    if ngram:
        count = model.get_ngram_count(ngram)
        print ("ngram {} has count {}".format(repr(ngram), count))

def regex_filter_unigram(model, regex):
    out_model = model.__class__(model.order)
    pattern = re.compile(regex, re.VERBOSE)

    for it in model.iter_ngrams():
        ngram = it[0]
        count = it[1]
        for token in ngram:
            if pattern.search(token):
                break
        else:
            out_model.count_ngram(ngram, count)
    return out_model
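
# Example (hypothetical pattern): drop every n-gram containing a token
# with a digit or an underscore. Note re.VERBOSE is in effect, so
# literal whitespace in a pattern is ignored unless escaped or placed
# in a character class.
#
#   model = regex_filter_unigram(model, r"[\d_]")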

def regex_filter_ngram(model, regex):
    out_model = model.__class__(model.order)
    pattern = re.compile(regex, re.VERBOSE)

    for it in model.iter_ngrams():
        ngram = it[0]
        count = it[1]
        ngram_str = " ".join(ngram)
        if not pattern.search(ngram_str):
            out_model.count_ngram(ngram, count)
    return out_model
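
# Example (hypothetical pattern): drop n-grams whose joined string
# starts with a lowercase word:
#
#   model = regex_filter_ngram(model, r"^[a-z]")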

def caps_bigram_filter(model, word_length):
    """
    Rather hard-coded filter to keep only capitalized bigrams
    above a certain length.

    Python's re doesn't support uppercase/lowercase character classes,
    so we can't reliably use the regex filters for this.
    """
    out_model = model.__class__(model.order)

    for it in model.iter_ngrams():
        ngram = it[0]
        count = it[1]

        if len(ngram) == 1 or \
           _is_caps_ngram(ngram, word_length):
            out_model.count_ngram(ngram, count)

    return out_model

def title_case_filter(model):
    """ Leaves only bigrams and up with all words starting uppercase. """
    out_model = model.__class__(model.order)

    for it in model.iter_ngrams():
        ngram = it[0]
        count = it[1]

        if len(ngram) == 1:
            out_model.count_ngram(ngram, count)
        else:
            for word in ngram:
                if word[0].islower():
                    break
            else:
                out_model.count_ngram(ngram, count)

    return out_model
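
# For example, the bigram ("New", "York") survives title_case_filter,
# while ("New", "york") is dropped; unigrams always pass unchanged.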

def _is_caps_ngram(ngram, word_length):
    for i, word in enumerate(ngram):
        minlen = word_length if i else word_length + 1
        if not word[0].isupper() or len(word) < minlen:
            return False
    return True
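
# For example, with word_length=3 the first word needs at least 4
# characters: _is_caps_ngram(["York", "City"], 3) is True, while
# _is_caps_ngram(["New", "York"], 3) is False ("New" is one character
# too short).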

def spell_check_model(model, spell_checker,
                      filter_names, name_exceptions, max_lu_ratio,
                      out):
    out_model = model.__class__(model.order)

    unigrams = []
    with timeit("reading model", out):
        for it in model.iter_ngrams():
            ngram = it[0]
            if len(ngram) == 1:
                unigrams.append(ngram[0])

    with timeit("checking spelling of {} unigrams".format(len(unigrams)), out):
        step = 1000
        correct_words = set()
        correct_words_non_caps = set()
        for i in range(0, len(unigrams), step):

            tokens = unigrams[i:i+step]
            print ("spell-checked {}/{} tokens, checking {} ...)" \
                   .format(i, len(unigrams), len(tokens)))
            results = spell_checker.query(tokens)

            tokens = [w.lower() for w in results]
            results_non_caps = spell_checker.query(tokens)

            correct_words.update(results)
            correct_words_non_caps.update(results_non_caps)

    with timeit("creating spell-checked model", out):
        dropped_caps = []
        name_exceptions += ["Delhi"]
        for it in model.iter_ngrams():
            ngram = it[0]
            count = it[1]

            for word in ngram:
                if word not in name_exceptions:

                    # drop unknown spellings
                    if word not in correct_words:
                        break

                    # drop uppercase words that can be correctly spelled
                    # as lower case words too.
                    word_lower = word.lower()
                    if filter_names and \
                       word[0].isupper() and \
                       word_lower in correct_words_non_caps:

                        # Only drop words that are more frequently
                        # written lower case.
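                        # E.g. with the default --lu-ratio of 5.0: if
                        # "apple" occurs 60 times and "Apple" 10 times,
                        # lu_ratio is 6.0 > 5.0 and "Apple" is dropped.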
                        count_upper = model.get_ngram_count([word])
                        count_lower = model.get_ngram_count([word_lower])
                        lu_ratio = count_lower / float(count_upper)
                        if lu_ratio > max_lu_ratio:
                            if len(ngram) == 1:
                                dropped_caps.append((lu_ratio,
                                                     count_upper,
                                                     count_lower,
                                                     word))
                            break
            else:
                out_model.count_ngram(ngram, count)

        dropped_caps.sort()
        for i, (lu_ratio, count, count_lower, word) in enumerate(dropped_caps):
            print("dropping capitalized: "
                  "#{:6} upper {:8} lower {:8} l/u {:8.1f} {}"
                  .format(i, count, count_lower, lu_ratio, word))

    return out_model

def print_stats(model):
    print("calculating stats...")
    counts, totals = model.get_counts()
    for i, c in enumerate(counts):
        sys.stdout.write("%d-grams: types %10d, occurrences %10d\n" %
              (i+1, c, totals[i]))
    print("mem_size", model.memory_size(), sum(model.memory_size()),
          "(-S for best estimate)")


class SpellChecker:

    def __init__(self):
        self.dict_ids = []

    def set_backend(self, backend):
        pass

    def set_dict_ids(self, dict_ids):
        self.dict_ids = dict_ids
        return True

    def query(self, tokens):
        correct_words = []

        args = ["hunspell", "-G", "-i", "UTF-8"]
        if self.dict_ids:
            args += ["-d", ",".join(self.dict_ids)]

        p = None
        try:
            p = subprocess.Popen(args, stdin=subprocess.PIPE,
                                       stdout=subprocess.PIPE,
                                       close_fds=True)
        except OSError as e:
            print("Failed to execute '{}', {}".format(" ".join(args), e),
                  file=sys.stderr)

        # Check if the process is still running; it might have
        # exited on start due to an unknown dictionary name.
        if p and p.poll() is None:
            for token in tokens:
                line = (token + "\n").encode("UTF-8")
                p.stdin.write(line)
            p.stdin.close()

            while True:
                line = p.stdout.readline().decode("UTF-8")
                if not line:
                    break
                token = line.strip()
                correct_words.append(token)

        if p:
            p.terminate()
            p.wait()

        return correct_words
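
    # Minimal usage sketch (assumes hunspell and an en_US dictionary
    # are installed):
    #
    #   checker = SpellChecker()
    #   checker.set_dict_ids(["en_US"])
    #   checker.query(["speling", "spelling"])  # -> ["spelling"]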


if __name__ == '__main__':
    main()

