#!/usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright © 2010, 2012 marmuta <marmvta@gmail.com>
#
# This file is part of Onboard.
#
# Onboard is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Onboard is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys, random, string, math
import copy
from optparse import OptionParser

import matplotlib.pyplot as plt


import pypredict
from pypredict import *

class Annealing:
    """Generic iterative optimizer driven by a decaying temperature.

    Subclasses describe the problem by providing three methods:

    * ``init()`` returns the starting candidate,
    * ``modify(candidate, temperature)`` returns a new candidate derived
      from the given one,
    * ``fitness(candidate)`` returns a score; larger is better.

    Only strict improvements are accepted. The temperature is handed to
    ``modify()`` and is not used for any acceptance probability here.
    """

    def __call__(self, iterations, attenuation, progress):
        """Run the optimization loop.

        iterations  -- number of candidates to generate and evaluate
        attenuation -- factor the temperature (starting at 1.0) is
                       multiplied with after every iteration
        progress    -- callable invoked once per iteration as
                       progress(candidate, candidate_fitness,
                                best, best_fitness)

        Returns the tuple (best_fitness, best).
        """
        current = self.init()
        current_fitness = self.fitness(current)
        best, best_fitness = current, current_fitness

        temperature = 1.0
        for _ in range(iterations):
            candidate = self.modify(current, temperature)
            candidate_fitness = self.fitness(candidate)

            # Reported before the acceptance test, so "best" still describes
            # the state prior to this iteration.
            progress(candidate, candidate_fitness, best, best_fitness)

            if candidate_fitness > current_fitness:
                best, best_fitness = candidate, candidate_fitness
                current, current_fitness = candidate, candidate_fitness
                # Single pre-formatted argument: prints the same text under
                # Python 2 and Python 3.
                print("accepted:  %s %s" % (current_fitness, current))

            temperature *= attenuation

        return (best_fitness, best)

class Scenario:
    """Plain attribute container for one set of recency parameters.

    Instances are expected to carry recency_smoothing, recency_ratio,
    recency_halflife and recency_lambdas; they are assigned by the code
    that creates the scenario.
    """

    def __repr__(self):
        """Return all parameters as a single 'name=value' line."""
        parts = (
            "recency_smoothing=%d" % self.recency_smoothing,
            "recency_ratio=%.3f" % self.recency_ratio,
            "recency_halflife=%.0f" % self.recency_halflife,
            "recency_lambdas=[%s]" % ", ".join(
                ["%.3f" % weight for weight in self.recency_lambdas]),
        )
        return " ".join(parts)

def random_step(value, minval, maxval, temperature):
    """Pick a random value near ``value`` without leaving [minval, maxval].

    The sampling window reaches ``(maxval - minval) * temperature`` to
    either side of ``value`` and is clipped to the allowed range, so a
    temperature of 0 returns ``value`` unchanged.
    """
    reach = temperature * (maxval - minval)
    lower = max(minval, value - reach)
    upper = min(maxval, value + reach)
    return random.uniform(lower, upper)

class KsrAnnealing(Annealing):
    """Annealing problem that tunes the recency parameters of the learn model.

    A candidate is a Scenario; its fitness is the value returned by
    pypredict.ksr() (plotted elsewhere as 'ksr [%]').

    fitness() reads the module globals ``models``, ``recency_smoothings``
    and ``testing_sentences``, which optimize_caching() sets up before the
    optimizer is run.
    """

    # Set to True to start from a random scenario instead of the
    # hard-coded starting point below.
    RANDOM_START = False

    def init(self):
        """Reset the fitness cache and return the starting scenario."""
        self.cache = {}
        s = Scenario()

        # Starting points tried previously, kept for reference:
        #   recency_ratio=0.58  recency_halflife=80 recency_lambdas=[0.3, 0.3, 0.3]
        #   recency_ratio=0.813 recency_halflife=97 recency_lambdas=[0.155, 0.147, 0.411]
        #ksr= 26.0088272383 recency_smoothing=0 recency_ratio=0.784 recency_halflife=78 recency_lambdas=[0.154, 0.658, 0.270]
        #ksr= 26.013126218 recency_smoothing=0 recency_ratio=0.811 recency_halflife=96 recency_lambdas=[0.404, 0.831, 0.444]
        #ksr= 26.0145592113 recency_smoothing=0 recency_ratio=0.811 recency_halflife=96 recency_lambdas=[0.159, 0.967, 0.036]
        s.recency_smoothing = 1
        s.recency_ratio = 0.811
        s.recency_halflife = 96
        s.recency_lambdas = [0.159, 0.967, 0.036]

        if self.RANDOM_START:
            s.recency_smoothing = 0
            s.recency_ratio = random.random()
            s.recency_halflife = random.randrange(1, 200)
            s.recency_lambdas = [random.random(), random.random(),
                                 random.random()]
        return s

    def modify(self, scenario, temperature):
        """Return a randomly perturbed copy of ``scenario``.

        ``scenario`` itself is left untouched. recency_smoothing is
        carried over unchanged; all other parameters take a random step
        whose reach is scaled by twice the given temperature.
        """
        temperature *= 2
        s = copy.copy(scenario)
        s.recency_ratio = random_step(s.recency_ratio, 0, 1.0, temperature)
        s.recency_halflife = random_step(s.recency_halflife, 1, 200, temperature)
        # copy.copy() is shallow: build a fresh list instead of assigning
        # into the shared one, otherwise every candidate (including the
        # rejected ones) would overwrite the lambdas of the current and
        # best scenarios.
        s.recency_lambdas = [random_step(weight, 0, 1.0, temperature)
                             for weight in scenario.recency_lambdas]
        return s

    def fitness(self, scenario):
        """Return the ksr value for ``scenario``, cached per parameter set."""
        s = scenario
        key = (s.recency_smoothing, s.recency_ratio, s.recency_halflife,
               tuple(s.recency_lambdas))
        if key in self.cache:
            return self.cache[key]

        for m in models:
            m.smoothing = "abs-disc"
        model = overlay(models)

        # The second model is the one being trained; clear it so each
        # evaluation starts from the same state.
        learn_model = models[1]
        learn_model.clear()
        learn_model.recency_ratio = s.recency_ratio
        learn_model.recency_halflife = s.recency_halflife
        # Clamp the smoothing selector to a valid index.
        index = min(max(int(s.recency_smoothing), 0),
                    len(recency_smoothings) - 1)
        learn_model.recency_smoothing = recency_smoothings[index]
        learn_model.lambdas = s.recency_lambdas

        # NOTE(review): the meaning of the trailing 10 is defined by
        # pypredict.ksr() - confirm against its signature. A progress
        # callback can be appended for debugging, e.g.
        #   lambda i, n, c, p: sys.stdout.write("%d/%d\n" % (i+1, n))
        ksr = pypredict.ksr(model, learn_model, testing_sentences, 10)
        # Synthetic objective for testing the optimizer without models:
        #ksr = -((s.recency_ratio-.2)**2) + 26
        self.cache[key] = ksr
        return ksr

class PlotProgress:
    """Progress callback that logs and plots the optimization history.

    Each call prints the best and current results and redraws a
    matplotlib chart of the current and best ksr per iteration.
    """

    def __init__(self):
        self.xvalues = []    # iteration numbers, 0-based
        self.ksrs = []       # ksr of the candidate tried in each iteration
        self.best_ksrs = []  # best ksr known at each iteration

    def __call__(self, scenario, ksr, best_scenario, best_ksr):
        """Record one iteration and refresh the chart.

        scenario, ksr           -- the candidate just evaluated and its score
        best_scenario, best_ksr -- the best candidate so far and its score
        """
        # Single pre-formatted argument: prints the same text under
        # Python 2 and Python 3.
        print("best:    ksr= %s %s" % (best_ksr, best_scenario))
        print("current: ksr= %s %s" % (ksr, scenario))
        self.xvalues.append(len(self.xvalues))
        self.ksrs.append(ksr)
        self.best_ksrs.append(best_ksr)

        plt.ion()  # interactive mode on
        plt.figure(1)
        plt.clf()
        # plt.plot() returns a list of lines; unpack the single line so the
        # legend receives line handles rather than lists.
        current_line, = plt.plot(self.xvalues, self.ksrs)
        best_line, = plt.plot(self.xvalues, self.best_ksrs)
        lines = [current_line, best_line]
        labels = ["current", "best"]
        plt.xlabel("iteration")
        plt.ylabel('ksr [%]')
        # Add 5% headroom at the top of the y axis.
        ymin, ymax = plt.ylim()
        plt.ylim(ymin, ymax+(ymax-ymin)*0.05)
        plt.figlegend(lines, labels, loc='upper right')
        plt.gcf().suptitle('Optimizing Recency Caching',
                           fontsize=16)
        plt.draw()

def optimize_caching(base_model, testing, order):
    """Optimize the recency caching parameters and plot the progress.

    base_model -- file name of the base language model to load
    testing    -- file name of the text used as testing input
    order      -- order passed to both model constructors

    Sets the module globals that KsrAnnealing.fitness() reads (models,
    recency_smoothings, testing_sentences), runs the optimizer, prints
    its (best_fitness, best_scenario) result and finally blocks in
    plt.show().
    """
    global models, testing_text, testing_sentences, testing_tokens
    global recency_smoothings

    models = []
    recency_smoothings = ["jelinek-mercer", "witten-bell"]

    # NOTE(review): timeit is used as a context manager here, so it is
    # presumably provided by "from pypredict import *" and is not the
    # standard library module of the same name - confirm.
    with timeit("loading base model '%s'" % (base_model,)):
        model = DynamicModelKN(order)
        model.load(base_model)
        models.append(model)

    # Second model: the one KsrAnnealing.fitness() clears and configures
    # as models[1].
    model = CachedDynamicModel(order)
    models.append(model)

    filename = testing
    with timeit("tokenizing '%s'" % (filename,)):
        testing_text = read_corpus(filename)
        testing_sentences = split_sentences(testing_text)
        testing_tokens, spans = tokenize_text(testing_text)

    # simulated annealing
    result = KsrAnnealing()(200, 0.95, PlotProgress())
    # Single argument: prints the tuple the same way under Python 2 and 3.
    print(result)

    plt.show()  # blocks; allows for interaction with the chart, saving images


def main():
    """Command line entry point.

    Expects ``caching <base model> <testing text>`` as positional
    arguments. Prints the usage text and exits with status 1 when
    arguments are missing; prints a message and exits with status 1 when
    the command is unknown.
    """
    parser = OptionParser(usage= \
"""Usage: %prog [options] caching <base model> <testing text>
           <base model>   is a static base language model
           <testing text> is simulated input that incrementally
                          trains the second language model""")
    options, args = parser.parse_args()
    if len(args) < 1:
        parser.print_usage()
        sys.exit(1)

    order = 3
    command = args[0].lower()

    if command == "caching":
        if len(args) < 3:
            parser.print_usage()
            sys.exit(1)
        optimize_caching(args[1], args[2], order)
    else:
        # Single pre-formatted argument: valid under Python 2 and Python 3.
        print("unknown command '%s', exiting" % command)
        sys.exit(1)

# Run the command line interface only when executed as a script,
# not when imported as a module.
if __name__ == '__main__':
    main()


