from contextlib import contextmanager
from io import BytesIO
from itertools import count
from time import time

import numpy as np
import tensorflow as tf
import h5py
from tensorflow.python.keras.engine.saving import (
    save_weights_to_hdf5_group,
    load_weights_from_hdf5_group,
)
from tensorflow.python.util import deprecation

from analysis.util import atq
from evil import pin, ccall

# Silence TensorFlow deprecation messages by clearing a flag in the private
# `tensorflow.python.util.deprecation` module (private API: may change or
# disappear between TensorFlow versions).
deprecation._PRINT_DEPRECATION_WARNINGS = False


def last_map(func, obj, order=1):
    """Apply ``func`` to the trailing "corner" entry of ``obj``, keeping the rest.

    With ``order == 1`` the result equals ``obj`` except that the last entry
    along the last axis is replaced by ``func(obj[..., -1])``. With a larger
    ``order`` the same rule is applied recursively to that last slice, so
    ``func`` ends up seeing ``obj[..., -1, ..., -1]`` (``order`` trailing
    indices) and every other entry is passed through unchanged.

    ``func`` presumably must return a tensor with the same shape and dtype as
    its input so that ``tf.concat`` can re-assemble the pieces (not checked).
    """
    head = obj[..., :-1]
    tail = obj[..., -1]
    if order > 1:
        mapped = last_map(func=func, obj=tail, order=order - 1)
    else:
        mapped = func(tail)
    return tf.concat([head, tf.expand_dims(mapped, axis=-1)], axis=-1)


def tf_meanstd(val, **kwargs):
    """Return ``(mean, std)`` of ``val`` as TensorFlow tensors.

    ``kwargs`` (e.g. ``axis``, ``keepdims``) are forwarded to both
    ``tf.reduce_mean`` calls, so mean and std are reduced over the same axes.

    The variance is computed as ``E[x^2] - E[x]^2``. Floating-point
    cancellation can make that difference slightly negative, so it is clamped
    at zero before taking the square root. Integer inputs are cast to
    float32 before the square root because ``tf.sqrt`` needs a float dtype.

    Note: ``tf.sqrt`` has an infinite gradient at 0, so do not differentiate
    through ``std`` for (near-)constant inputs.
    """
    mean = tf.reduce_mean(val, **kwargs)
    var = tf.reduce_mean(tf.square(val), **kwargs) - tf.square(mean)
    if not var.dtype.is_floating:
        var = tf.cast(var, tf.float32)
    # Previously the variance itself was returned as `std`; take the root.
    std = tf.sqrt(tf.maximum(var, tf.zeros_like(var)))
    return mean, std


# this is mostly unused ...


class FD(dict):
    """Feed dict that creates a ``tf.placeholder`` for every array added.

    Each placeholder is stored as a key mapping to its data, so the instance
    can be passed directly as ``feed_dict`` to ``Session.run``.
    """

    def add(self, data, dtype=None, shape=None, **kwargs):
        """Create a placeholder for ``data``, register it, and return it.

        Args:
            data: array-like with ``.shape`` and ``.dtype`` attributes.
            dtype: placeholder dtype; defaults to ``data.dtype``.
            shape: ``None`` (default) keeps ``data``'s shape but makes the
                first dimension unknown; ``True`` uses ``data.shape`` exactly;
                anything else is passed through as the placeholder shape.
            **kwargs: forwarded to ``tf.placeholder`` (e.g. ``name``).
        """
        if shape is None:
            shape = (None,) + data.shape[1:]
        elif shape is True:
            shape = data.shape
        placeholder = tf.placeholder(
            data.dtype if dtype is None else dtype, shape=shape, **kwargs
        )
        self[placeholder] = data
        return placeholder


class Chain(object):
    """Bundle a training op with scalar summaries and moving averages of them.

    Results are exposed as attributes. ``self.__dict__.update(kwargs)`` copies
    the keyword values (plus ``loss``) onto the instance.
    ``pin(locals(), True, kwargs)`` presumably copies this constructor's
    locals onto it as well (TODO confirm against `evil.pin`). Examples are
    ``step``, ``opt``, ``train``, ``reset_opt``, ``ema``, ``summaries``,
    ``ema_step`` and ``reset_ema``. ``Runner`` reads ``chain.train``,
    ``chain.loss``, ``chain.summaries`` and ``chain.step``.

    Because of that, the NAMES of local variables in ``__init__`` are part of
    the public interface: do not rename them.
    """

    def __init__(
        self,
        __name__,
        loss,
        sumnbs=(),
        step=None,
        opt=None,
        train=None,
        ema=(1, 2, 3, 4, 5),
        **kwargs
    ):
        """Build the summary and EMA graph around ``loss``/``train``.

        Args:
            __name__: prefix for the created variable names and the
                ``tf.name_scope`` of the summaries.
            loss: loss tensor. It is minimized by the default ``train`` op and
                summarized under the key "loss" like the keyword tensors.
            sumnbs: tensors of rank > 1 are only summarized if they appear in
                this collection, or when ``sumnbs is True`` (summarize all).
            step: global-step variable. Defaults to a new non-trainable
                ``tf.Variable(0)`` named "<__name__>_step".
            opt: optimizer; defaults to ``tf.train.AdamOptimizer()``.
            train: training op; defaults to
                ``opt.minimize(loss, global_step=step)``.
            ema: iterable of integers ``i``. For each one, an exponential
                moving average of every summarized value is kept. A falsy
                value disables the moving averages.
            **kwargs: named values to summarize. Non-tensor values are not
                summarized but are still copied onto ``self``.
        """
        if step is None:
            step = tf.Variable(0, trainable=False, name="%s_step" % __name__)
        if opt is None:
            opt = tf.train.AdamOptimizer()
        if train is None:
            train = opt.minimize(loss, global_step=step)
        # Running `reset_opt` re-initializes the optimizer's own variables.
        reset_opt = tf.variables_initializer(opt.variables())

        summaries = []
        # Summary key -> scalar tensor to average.
        _ema_val = {}
        # NOTE(review): SKDict is neither defined nor imported in the visible
        # part of this file; confirm where it comes from.
        _ema_out = SKDict()
        kwargs["loss"] = loss
        with tf.name_scope(__name__):
            for key, val in kwargs.items():
                if not isinstance(val, tf.Tensor):
                    continue
                # NOTE(review): ndims is None for tensors of unknown rank, and
                # `None > 1` raises TypeError on Python 3.
                nd = val.shape.ndims
                # Rank > 1 tensors are only summarized when requested via sumnbs.
                if nd > 1 and sumnbs is not True and val not in sumnbs:
                    continue
                # Booleans are summarized as the fraction of True entries.
                # NOTE(review): `nd` keeps the original rank, so a rank >= 1
                # boolean goes through tf_meanstd as a single scalar and its
                # "_std" summary carries no spread information.
                if val.dtype == tf.bool:
                    val = tf.reduce_mean(tf.cast(val, tf.float32))
                if nd:
                    # Rank >= 1: log the two statistics returned by tf_meanstd.
                    # summaries.append(tf.summary.histogram(key, val))
                    mean, std = tf_meanstd(val)
                    summaries.append(tf.summary.scalar(key + "_mean", mean))
                    summaries.append(tf.summary.scalar(key + "_std", std))
                    _ema_val[key + "_mean"] = mean
                    _ema_val[key + "_std"] = std
                else:
                    summaries.append(tf.summary.scalar(key, val))
                    _ema_val[key] = val

        if ema:
            _ema_ops = []
            reset_ema_to = tf.constant(0, tf.int32)
            # Counts how often the (wrapped) train op has run. Running
            # `reset_ema` (its initializer) sets it back to 0.
            ema_step = tf.Variable(reset_ema_to, trainable=False, name="%s_ema_step" % __name__)
            reset_ema = ema_step.initializer
            for i in ema:
                with tf.name_scope("%s_%d" % (__name__, i)):
                    # decay = 1 - max(0.1**i, (0.1**i) ** (ema_step / 10)).
                    # It is 0 at ema_step == 0 (the average equals the latest
                    # value) and ramps up over 10 steps to 1 - 0.1**i.
                    # NOTE: this rebinds `ema`. The for-loop already holds an
                    # iterator over the original collection, so iteration is
                    # unaffected.
                    ema = tf.train.ExponentialMovingAverage(
                        decay=1.0
                        - tf.maximum(
                            0.1 ** i,
                            tf.exp((np.log(0.1 ** i) / 10.0) * tf.cast(ema_step, tf.float32)),
                        )
                    )
                    _ema_ops.append(ema.apply(_ema_val.values()))
                    for key, val in _ema_val.items():
                        val = ema.average(val)
                        _ema_out[key, i] = val
                        summaries.append(tf.summary.scalar("%s" % key, val))
            # Replace `train` with an op that runs after the original train op
            # and the EMA updates, then increments ema_step.
            # NOTE(review): the EMA update ops were created outside these
            # scopes, so they are not themselves ordered after the original
            # train op.
            with tf.control_dependencies([train]):
                with tf.control_dependencies(_ema_ops):
                    train = ema_step.assign_add(1)

        # Expose the averages, keyed by (summary key, i), under the name `ema`.
        ema = _ema_out
        summaries = tf.summary.merge(summaries)
        self.__dict__.update(kwargs)
        pin(locals(), True, kwargs)


# Attach `ccall(Chain, True)` as `Chain.cc`. `ccall` is imported from the
# `evil` helper module and its semantics are not visible here (presumably a
# constructor/call wrapper for Chain -- TODO confirm).
Chain.cc = ccall(Chain, True)


class Runner(object):
    """Thin wrapper around a ``tf.Session`` and a summary ``FileWriter``.

    Calling the runner forwards directly to ``self.sess.run``.
    """

    def __init__(self, sess, writer):
        self.step = 0
        # `pin` (from `evil`) presumably stores `sess` and `writer` on self;
        # the methods below read them as `self.sess` / `self.writer`.
        pin(locals())

    def __call__(self, *args, **kwargs):
        return self.sess.run(*args, **kwargs)

    def train(self, chain, extra=None, step=1, sumiv=10, sumskip=0, **kwargs):
        """Train ``chain`` forever, yielding the value of ``extra`` per iteration.

        Each iteration does the following:

        - Runs ``chain.train`` ``step`` times.
        - Runs it once more together with ``extra`` (default: ``chain.loss``).
        - Yields the fetched value of ``extra``.

        When ``sumiv`` is non-zero, every ``sumiv``-th iteration from
        iteration ``sumskip`` on also fetches ``chain.summaries`` and
        ``chain.step`` and passes them to ``self.writer.add_summary``.
        ``kwargs`` (e.g. ``feed_dict``) are forwarded to every ``sess.run``.
        """
        if extra is None:
            extra = chain.loss
        for i in count():
            for _ in range(step):
                self(chain.train, **kwargs)
            # `sumiv` is checked first so that sumiv=0 disables summaries
            # without ever evaluating `i % 0`.
            write_summary = bool(sumiv) and not (i < sumskip or i % sumiv)
            fetches = (chain.train, extra)
            if write_summary:
                fetches += (chain.summaries, chain.step)
            out = self(fetches, **kwargs)
            if write_summary:
                # out[2:] == (merged summary, global step)
                self.writer.add_summary(*out[2:])
            yield out[1]

    def conf(self, *args, **kwargs):
        """Return a ``Confirmer`` that pulls values from ``self.train(*args, **kwargs)``."""
        gen = self.train(*args, **kwargs)
        # `next(gen)` works on Python 2 and 3; `gen.next` is Python 2 only.
        return Confirmer(lambda: next(gen))

    @classmethod
    @contextmanager
    def make(cls, path, gpuOpts=None, flush_secs=20):
        """Context manager yielding a Runner with a fresh session and a writer at ``path``.

        ``gpuOpts`` is an optional dict of ``tf.GPUOptions`` keyword arguments.
        """
        with cls.make_session(gpuOpts=gpuOpts) as sess:
            with tf.summary.FileWriter(
                path, session=sess, flush_secs=flush_secs, graph=sess.graph
            ) as writer:
                yield cls(sess, writer)

    @classmethod
    def make_session(cls, gpuOpts=None):
        """Create a ``tf.Session`` whose GPU options are built from ``gpuOpts``.

        ``gpuOpts`` is an optional dict; ``None`` means no extra GPU options.
        """
        gpu_options = tf.GPUOptions(**(gpuOpts or {}))
        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))


class Confirmer(object):
    """Early-stopping helper: keep calling ``func`` until it stops improving.

    ``func`` takes no arguments and returns a value where lower is better,
    e.g. a loss. A call whose value is strictly greater than ``best``
    increments ``conf``. Any other value triggers ``on_best``, becomes the new
    ``best`` and resets ``conf`` to 0. ``__call__`` stops once ``conf``
    reaches the requested target.

    Class-level settings (override per instance or subclass):
        task: optional progress reporter. When not None, its ``check_stop``,
            ``set_progress_percentage`` and ``set_status_message`` methods
            are called from ``__call__``.
        tdelay: minimum seconds between status updates sent to ``task``.
        tbestbump: seconds by which a new best moves the next status update
            earlier.
    """

    last = np.inf
    task = None
    tdelay = 60
    tbestbump = 10

    def __init__(self, func):
        # `reset` is a property: it clears best/last/step/conf and returns self.
        self.reset.func = func

    def __call__(self, target, its=np.inf, pre=0, **kwargs):
        """Run up to ``its`` iterations or until ``conf`` reaches ``target``.

        Args:
            target: number of consecutive non-improving calls that counts as
                converged; must be positive.
            its: maximum number of iterations. ``np.inf`` (the default) means
                no limit.
            pre: calls for which the incremented ``self.step`` is still below
                ``pre`` only update ``last``; they are not compared with
                ``best``.
            **kwargs: forwarded to ``atq`` (from ``analysis.util``). With no
                kwargs, ``disable=True`` is passed. Otherwise ``auto``
                defaults to ``self._prog`` unless ``disable`` is truthy.

        Returns:
            True if ``conf`` reached ``target``. False if the iteration budget
            ran out or ``task.check_stop()`` requested a stop first.
        """
        if not kwargs.setdefault("disable", not kwargs):
            kwargs.setdefault("auto", self._prog)
        assert 0 < target
        # `xrange(np.inf)` raised for the default `its` (and `xrange` does
        # not exist on Python 3). Without a limit, iterate lazily with count().
        iterations = count() if its == np.inf else range(int(its))
        tnext = 0
        with atq(iterations, **kwargs) as tq:
            for i in tq:
                if self.task is not None:
                    if self.task.check_stop():
                        break
                    if tnext < time():
                        tnext = time() + self.tdelay
                        pTotal = 100.0 * i / its
                        self.task.set_progress_percentage(pTotal)
                        self.task.set_status_message(
                            "total=%.1f%% conf=%.1f%% best=%.2e"
                            % (pTotal, 100.0 * self.conf / target, self.best)
                        )
                if not (self.conf < target):
                    break
                self.step += 1
                self.last = self.func()
                if self.step < pre:
                    continue
                elif self.best < self.last:
                    self.conf += 1
                else:
                    self.on_best()
                    self.conf = 0
                    self.best = self.last
                    # A new best brings the next status update forward.
                    tnext -= self.tbestbump
        return not (self.conf < target)

    def on_best(self):
        """Hook called on every new best, before ``best`` is updated; no-op by default."""
        pass

    @property
    def reset(self):
        """Set ``last``/``best`` to inf and ``step``/``conf`` to 0, then return self."""
        self.last = self.best = np.inf
        self.step = self.conf = 0
        return self

    @property
    def rdlb(self):
        """Relative delta of last to best: ``(last - best) / best``."""
        # np.float64 rather than np.float_, which was removed in NumPy 2.0.
        return np.float64(self.last - self.best) / self.best

    def _prog(self, i, prefix=""):
        """Build the progress-postfix dict for ``atq``, omitting uninformative entries."""
        return {
            prefix + k: v
            for k, v in dict(
                step=self.step if self.step != i else None,
                last=self.last if self.rdlb else None,
                rdlb=self.rdlb or None,
                conf=self.conf or None,
                best=self.best,
            ).items()
            if v is not None
        }

    def __repr__(self):
        return "Confirmer(step=%d, best=%.3e, conf=%d)" % (self.step, self.best, self.conf)


class Bestie(object):
    """Keep the best weights of a Keras-style ``model`` at ``path`` on disk.

    ``save`` snapshots the current weights to ``path``. ``load`` restores
    them. ``finish`` restores the snapshot and then writes the whole model to
    the same ``path``.
    """

    def __init__(self, model, path):
        # `pin` (from `evil`) presumably stores `model` and `path` on self;
        # the methods below read them back as attributes.
        pin(locals())

    def save(self):
        """Write the model's current weights to ``self.path``."""
        model, path = self.model, self.path
        model.save_weights(path)

    def load(self):
        """Load the weights stored at ``self.path`` back into the model."""
        model, path = self.model, self.path
        model.load_weights(path)

    def finish(self):
        """Restore the saved weights, then save the complete model to ``self.path``."""
        self.load()
        model = self.model
        model.save(self.path)


class BestieMemory(Bestie):
    """``Bestie`` variant that keeps the weight snapshot in an in-memory HDF5 file.

    ``save``/``load`` never touch the disk. The inherited ``finish`` still
    writes the complete model to ``self.path``.
    """

    def __init__(self, *args, **kwargs):
        super(BestieMemory, self).__init__(*args, **kwargs)

    def save(self):
        """Snapshot the model's current weights into a fresh in-memory HDF5 file."""
        # Pass mode "w" explicitly: h5py >= 3 defaults to read-only ("r"),
        # which cannot open an empty buffer.
        snapshot = h5py.File(BytesIO(), "w")
        try:
            save_weights_to_hdf5_group(snapshot, self.model.layers)
        except Exception:
            snapshot.close()
            raise
        # Swap in the new snapshot only once it is complete, then release the
        # previous in-memory file instead of leaking it.
        previous = getattr(self, "h5", None)
        self.h5 = snapshot
        if previous is not None:
            previous.close()

    def load(self):
        """Restore the weights captured by the most recent ``save``.

        Raises:
            AttributeError: if ``save`` has not been called yet.
        """
        if getattr(self, "h5", None) is None:
            raise AttributeError("BestieMemory.load() called before save()")
        load_weights_from_hdf5_group(self.h5, self.model.layers)