from contextlib import contextmanager
from io import BytesIO
from itertools import count
from time import time

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.engine.saving import save_weights_to_hdf5_group, load_weights_from_hdf5_group

from analysis.util import atq
from evil import pin, ccall


from tensorflow.python.util import deprecation
# Globally silence TF's deprecation warning spam.
# NOTE(review): this pokes a private TF attribute and may break across TF versions.
deprecation._PRINT_DEPRECATION_WARNINGS = False


def last_map(func, obj, order=1):
    """Apply ``func`` to the trailing slice of ``obj`` and re-append it.

    Splits ``obj`` into ``obj[..., :-1]`` and ``obj[..., -1]``, maps ``func``
    onto the trailing slice — recursing through ``order`` trailing axes when
    ``order > 1`` — and concatenates the result back along the last axis.
    ``func`` must preserve the shape of the slice it receives, so the concat
    stays valid.
    """
    head = obj[..., :-1]
    tail = obj[..., -1]
    if 1 < order:
        mapped = last_map(func=func, obj=tail, order=order - 1)
    else:
        mapped = func(tail)
    return tf.concat([head, tf.expand_dims(mapped, axis=-1)], axis=-1)


def tf_meanstd(val, **kwargs):
    """Return ``(mean, std)`` of ``val``, reducing per ``**kwargs`` axes.

    Fix: the original returned ``E[x^2] - E[x]^2`` — the *variance* — while
    the name (and the ``_std`` summaries fed from it in ``Chain``) promise a
    standard deviation. We now take the square root, clamping tiny negative
    values caused by floating-point cancellation so the sqrt never yields NaN.
    """
    mean = tf.reduce_mean(val, **kwargs)
    var = tf.reduce_mean(tf.square(val), **kwargs) - tf.square(mean)
    std = tf.sqrt(tf.maximum(var, 0.))
    return mean, std


# this is mostly unused ...

class FD(dict):
    """Feed-dict builder: maps freshly created ``tf.placeholder`` -> data."""

    def add(self, data, dtype=None, shape=None, **kwargs):
        """Create a placeholder matching ``data``, register it, return it.

        shape=True  -> use ``data.shape`` exactly
        shape=None  -> ``data.shape`` with a free (None) leading batch axis
        dtype=None  -> inherit ``data.dtype``
        Extra ``**kwargs`` (e.g. ``name``) go straight to ``tf.placeholder``.
        """
        if dtype is None:
            dtype = data.dtype
        if shape is True:
            shape = data.shape
        elif shape is None:
            shape = (None,) + data.shape[1:]
        placeholder = tf.placeholder(dtype, shape=shape, **kwargs)
        self[placeholder] = data
        return placeholder


class Chain(object):
    """Bundle a loss with its optimizer, train op, scalar summaries and
    exponential-moving-average (EMA) smoothed copies of those scalars.

    Every tf.Tensor passed via **kwargs (plus the loss itself) is summarized:
    scalars directly, higher-rank tensors via their mean/std (rank > 1 only
    when listed in `sumnbs`, or when `sumnbs is True`).
    """

    def __init__(self, __name__, loss, sumnbs=(), step=None, opt=None, train=None, ema=(1, 2, 3, 4, 5), **kwargs):
        # Default graph pieces: global step, Adam, and a minimize op on `loss`.
        if step is None:
            step = tf.Variable(0, trainable=False, name="%s_step" % __name__)
        if opt is None:
            opt = tf.train.AdamOptimizer()
        if train is None:
            train = opt.minimize(loss, global_step=step)
        # Op to re-initialize the optimizer's slot variables (e.g. Adam moments).
        reset_opt = tf.variables_initializer(opt.variables())

        summaries = []
        _ema_val = {}
        # NOTE(review): SKDict is not imported in the visible file head —
        # presumably provided by one of the project imports; confirm.
        _ema_out = SKDict()
        kwargs["loss"] = loss
        with tf.name_scope(__name__):
            for key, val in kwargs.items():
                if not isinstance(val, tf.Tensor):
                    continue
                nd = val.shape.ndims
                # Skip rank>1 tensors unless explicitly whitelisted via `sumnbs`.
                if nd > 1 and sumnbs is not True and val not in sumnbs:
                    continue
                if val.dtype == tf.bool:
                    # Summarize booleans as their mean "true rate".
                    val = tf.reduce_mean(tf.cast(val, tf.float32))
                if nd:
                    # summaries.append(tf.summary.histogram(key, val))
                    mean, std = tf_meanstd(val)
                    summaries.append(tf.summary.scalar(key + "_mean", mean))
                    summaries.append(tf.summary.scalar(key + "_std", std))
                    _ema_val[key + "_mean"] = mean
                    _ema_val[key + "_std"] = std
                else:
                    summaries.append(tf.summary.scalar(key, val))
                    _ema_val[key] = val

        if ema:
            _ema_ops = []
            reset_ema_to = tf.constant(0, tf.int32)
            ema_step = tf.Variable(reset_ema_to, trainable=False, name="%s_ema_step" % __name__)
            reset_ema = ema_step.initializer
            # One EMA per requested decade i: decay ramps from ~0 towards
            # 1 - 0.1**i as ema_step grows (warm-up over ~10 steps per decade).
            for i in ema:
                with tf.name_scope("%s_%d" % (__name__, i)):
                    # NOTE(review): `ema` (the parameter tuple) is rebound here to
                    # the EMA object — safe in Python (the loop holds its own
                    # iterator), but easy to misread.
                    ema = tf.train.ExponentialMovingAverage(decay=1. - tf.maximum(
                        0.1**i, tf.exp((np.log(0.1**i) / 10.) * tf.cast(ema_step, tf.float32))
                    ))
                    # NOTE(review): dict.values() passed directly — fine on
                    # Python 2 (a list); confirm `apply` accepts a view on 3.
                    _ema_ops.append(ema.apply(_ema_val.values()))
                    for key, val in _ema_val.items():
                        val = ema.average(val)
                        _ema_out[key, i] = val
                        summaries.append(tf.summary.scalar("%s" % key, val))
            # Chain: train -> EMA updates -> ema_step increment; `train` now
            # names the combined op.
            with tf.control_dependencies([train]):
                with tf.control_dependencies(_ema_ops):
                    train = ema_step.assign_add(1)

        # Expose the smoothed tensors under the (rebound) name `ema`.
        ema = _ema_out
        summaries = tf.summary.merge(summaries)
        self.__dict__.update(kwargs)
        # pin (from evil) presumably publishes the locals as attributes — TODO confirm.
        pin(locals(), True, kwargs)


Chain.cc = ccall(Chain, True)


class Runner(object):
    """Thin wrapper around a tf.Session and a summary FileWriter.

    Fixes in this revision:
    - `count` was used without being imported anywhere in the file
      (NameError on the first `train()` call); now provided by the
      `from itertools import count` added to the file imports.
    - mutable default arguments (`gpuOpts={}`) replaced with a `None`
      sentinel — backward compatible for all callers.
    - local `sum` renamed so the builtin is no longer shadowed.
    """

    def __init__(self, sess, writer):
        # pin (from evil) presumably publishes sess/writer as attributes — TODO confirm.
        self.step = 0
        pin(locals())

    def __call__(self, *args, **kwargs):
        """Proxy straight through to ``session.run``."""
        return self.sess.run(*args, **kwargs)

    def train(self, chain, extra=None, step=1, sumiv=10, sumskip=0, **kwargs):
        """Endless generator: run ``chain.train`` and yield ``extra``.

        Per outer iteration, runs `step` extra train ops, then one combined
        fetch; writes summaries every `sumiv`-th iteration once `sumskip`
        warm-up iterations have passed (sumiv=0 disables summaries).
        ``**kwargs`` (e.g. feed_dict) are forwarded to every run call.
        """
        if extra is None:
            extra = chain.loss
        for i in count():
            for _ in xrange(step):
                self(chain.train, **kwargs)
            do_summary = sumiv and not (i < sumskip or i % sumiv)
            fetches = (chain.train, extra)
            if do_summary:
                fetches += (chain.summaries, chain.step)
            out = self(fetches, **kwargs)
            if do_summary:
                # out[2:] == (serialized summary, global step)
                self.writer.add_summary(*out[2:])
            yield out[1]

    def conf(self, *args, **kwargs):
        """Wrap the train generator's step function in a Confirmer."""
        return Confirmer(self.train(*args, **kwargs).next)

    @classmethod
    @contextmanager
    def make(cls, path, gpuOpts=None, flush_secs=20):
        """Context manager yielding a ready Runner (session + FileWriter at `path`)."""
        with cls.make_session(gpuOpts=gpuOpts) as sess:
            with tf.summary.FileWriter(path, session=sess, flush_secs=flush_secs, graph=sess.graph) as writer:
                yield cls(sess, writer)

    @classmethod
    def make_session(cls, gpuOpts=None):
        """Create a tf.Session; `gpuOpts` are keyword args for tf.GPUOptions."""
        # `None` sentinel instead of a shared mutable default dict.
        gpuOpts = {} if gpuOpts is None else gpuOpts
        return tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(**gpuOpts)))


class Confirmer(object):
    """Early-stopping driver: repeatedly calls a step function until its
    value has failed to improve on the best for `target` consecutive steps.
    """

    # Class-level defaults; instances shadow them once running.
    last = np.inf       # most recent value returned by func
    task = None         # optional external task handle (progress/stop API) — see __call__
    tdelay = 60         # seconds between task status updates
    tbestbump = 10      # seconds pulled forward on a new best, so it reports sooner

    def __init__(self, func):
        # `reset` is a property with side effects: it zeroes the counters and
        # returns self — so this one line both resets and stores `func`.
        self.reset.func = func

    def __call__(self, target, its=np.inf, pre=0, **kwargs):
        """Run until `target` consecutive non-improvements (or `its` steps).

        `pre` steps run as warm-up without counting. Returns True iff the
        confirmation target was reached (as opposed to running out of `its`
        or the task requesting a stop).
        """
        # NOTE(review): with no kwargs at all, disable defaults to True, i.e.
        # the progress bar is off; passing any kwarg enables it — confirm intent.
        if not kwargs.setdefault("disable", not kwargs):
            kwargs.setdefault("auto", self._prog)
        assert 0 < target
        tnext = 0
        with atq(xrange(its), **kwargs) as tq:
            for i in tq:
                if self.task is not None:
                    # presumably a luigi-style task exposing check_stop /
                    # set_progress_percentage / set_status_message — TODO confirm
                    if self.task.check_stop():
                        break
                    if tnext < time():
                        # rate-limited status report
                        tnext = time() + self.tdelay
                        pTotal = 100. * i / its
                        self.task.set_progress_percentage(pTotal)
                        self.task.set_status_message("total=%.1f%% conf=%.1f%% best=%.2e" % (
                            pTotal, 100. * self.conf / target, self.best
                        ))
                if not (self.conf < target):
                    break
                self.step += 1
                self.last = self.func()
                if self.step < pre:
                    # warm-up: neither confirm nor record a best
                    continue
                elif self.best < self.last:
                    # no improvement: one more confirmation towards `target`
                    self.conf += 1
                else:
                    # new best: notify, reset confirmations, report sooner
                    self.on_best()
                    self.conf = 0
                    self.best = self.last
                    tnext -= self.tbestbump
        return not (self.conf < target)

    def on_best(self):
        """Hook invoked on every new best value; default no-op."""
        pass

    @property
    def reset(self):
        """Reset counters/best and return self (fluent, side-effecting)."""
        self.last = self.best = np.inf
        self.step = self.conf = 0
        return self

    @property
    def rdlb(self):
        # Relative Delta Last to Best
        return np.float_(self.last - self.best) / self.best

    def _prog(self, i, prefix=""):
        """Build the progress-bar postfix dict (None values are omitted)."""
        return {
            prefix + k: v
            for k, v in dict(
                step=self.step if self.step != i else None,
                last=self.last if self.rdlb else None,
                rdlb=self.rdlb or None,
                conf=self.conf or None,
                best=self.best,
            ).items()
            if v is not None
        }

    def __repr__(self):
        return "Confirmer(step=%d, best=%.3e, conf=%d)" % (
            self.step,
            self.best,
            self.conf
        )


class Bestie(object):
    """Checkpoints the best-so-far weights of a Keras-style model at `path`."""

    def __init__(self, model, path):
        # pin (from evil) presumably publishes model/path as attributes — TODO confirm.
        pin(locals())

    def save(self):
        # weights only; the architecture is not written here
        self.model.save_weights(self.path)

    def load(self):
        self.model.load_weights(self.path)

    def finish(self):
        # restore the best weights, then overwrite `path` with the full model
        self.load()
        self.model.save(self.path)


class BestieMemory(Bestie):
    """Bestie variant holding the best weights in memory (HDF5 over BytesIO)
    instead of on disk; `path` is only used by the inherited ``finish()``.

    Fixes in this revision:
    - the pure pass-through ``__init__`` (super call only) was redundant and
      has been removed — behavior is identical;
    - ``h5py.File(BytesIO())`` now passes an explicit ``"w"`` mode: h5py >= 3.0
      defaults to read mode, which fails on an empty buffer.
    """

    def save(self):
        """Snapshot the current weights into a fresh in-memory HDF5 group.

        Each call replaces the previous snapshot (the old BytesIO-backed
        file is simply dropped).
        """
        self.h5 = h5py.File(BytesIO(), "w")
        save_weights_to_hdf5_group(self.h5, self.model.layers)

    def load(self):
        """Restore the latest snapshot; ``save()`` must have been called first."""
        load_weights_from_hdf5_group(self.h5, self.model.layers)