# coding: utf-8

from utils.order import getPGroot
from collections import defaultdict

import awkward as ak
import numpy as np
import coffea.processor as processor

# from config.processes import get as Pget
from processor.generator import GeneratorHelper
from processor.util import *
from processor.util import reduce_and, reduce_or, nano_mask_or, get_ht, normalize
from processor.sf import *
from processor.bbww import *
import tasks.corrections.processors as corr_proc
from utils.coffea import Histogramer, ArrayExporter
from utils.tf_serve import autoClient


class BaseSelection:
    """Object and event selection for the dilepton ``ee`` / ``mumu`` channels.

    ``select`` is a generator that yields its local namespace once per
    iteration of ``self.jec_loop(events)``; ``arrays`` converts such a
    namespace into stacked numpy tensors as described by ``tensors``.

    The class is a mixin: it reads attributes and helpers that are not defined
    here (``year``, ``accumulator``, ``corrections``, ``campaign_inst``,
    ``analysis_inst``, ``get_dataset``, ``get_pu_key``, ``get_lfn``,
    ``jec_loop``) and therefore expects them from a sibling base class.
    """

    skip_cats = {"resolved", "boosted"}
    # per-object fields shared by the "lep" and "jet" tensors
    common = ("energy", "x", "y", "z")
    # fields of the "param" tensor
    param = ("year",)
    # fields of the "hl" tensor; ``arrays`` looks each name up by key in its
    # input, so every entry must exist as a local variable of ``select``
    hl = (
        "MET",
        "dr_ll",
        "ht",
        "min_dphi_jet",
        "min_dr_jet",
        "mll",
        "njets",
    )

    @property
    def trigger(self):
        """Return the per-channel trigger configuration for ``self.year``.

        The mapping is ``channel -> {trigger name -> run specification}``; the
        values are passed unchanged to ``Trigger`` in ``select``.

        Raises:
            KeyError: if ``int(self.year)`` is not configured (only 2017 is).
        """
        return {
            2017: {
                "ee": {
                    "Ele23_Ele12_CaloIdL_TrackIdL_IsoVL": all,
                    "Ele23_Ele12_CaloIdL_TrackIdL_IsoVL_DZ": all,
                },
                "mumu": {
                    "Mu17_TrkIsoVVL_Mu8_TrkIsoVVL_DZ": "B",
                    "Mu17_TrkIsoVVL_Mu8_TrkIsoVVL_DZ_Mass3p8": "C-F",
                },
            },
        }[int(self.year)]

    @property
    def tensors(self):
        """Describe the tensors built by ``arrays``.

        Each value is a 5-tuple ``(source, n, fields, dtype, aux)``:

        * ``source`` -- key of the object collection in the input of
          ``arrays``, or ``None`` when the fields are top-level keys,
        * ``n`` -- number of objects to pad/clip to (``0`` for per-event values),
        * ``fields`` -- names stacked along the last axis,
        * ``dtype`` -- numpy dtype of the resulting tensor,
        * ``aux`` -- additional metadata (not used by ``arrays``).
        """
        common = self.common
        hl = self.hl
        param = self.param
        return {
            "lep": (
                "good_leptons",
                2,
                common + ("pdgId", "charge"),
                np.float32,
                {"groups": ["multiclass", "input", "part"]},
            ),
            "jet": (
                "good_jets",
                1,
                common + ("btagDeepFlavB",),
                np.float32,
                {"groups": ["multiclass", "input", "part"]},
            ),
            "met": (
                "met",
                0,
                ("x", "y"),
                np.float32,
                {"groups": ["multiclass", "input", "part"]},
            ),
            "hl": (
                None,
                0,
                hl,
                np.float32,
                {
                    "groups": ["multiclass", "input"],
                },
            ),
            "param": (None, 0, param, np.float32, {"groups": ["multiclass", "input"]}),
            "eventnr": (
                None,
                0,
                ["eventnr"],
                np.int64,
                {"groups": ["multiclass", "split"]},
            ),
            "procid": (
                None,
                0,
                ["procid"],
                np.int32,
                {"groups": ["multiclass", "class"]},
            ),
        }

    # fmt: on

    def arrays(self, X):
        """Build one numpy array per entry of ``self.tensors``.

        Args:
            X: mapping that provides every ``source`` collection and every
                top-level field named in ``self.tensors`` (presumably the
                namespace yielded by ``select`` -- verify against the caller).

        Returns:
            dict mapping tensor name to an array whose last axis enumerates the
            tensor's fields, cast to the configured dtype, with NaN and
            infinities replaced by zero.
        """
        out = {}
        for name, (source, n, fields, typ, _aux) in self.tensors.items():

            def preproc(x):
                # per-event values are normalized, object collections are
                # padded/clipped to exactly ``n`` entries with 0 as filler
                return ak.to_numpy(
                    normalize(x, pad=1)
                    if n == 0
                    else ak.fill_none(ak.pad_none(x, n, clip=True), 0),
                    allow_missing=False,
                )

            if source is None:
                vals = [preproc(X[field]) for field in fields]
            else:
                vals = [preproc(getattr(X[source], field)) for field in fields]
            out[name] = np.stack(vals, axis=-1)
            # set all nans or infs to 0
            out[name] = np.nan_to_num(out[name], nan=0.0, posinf=0.0, neginf=-0.0).astype(typ)
        return out

    dtype = np.float32

    debug_dataset = (
        "data_F_mumu"  # "ggZH_HToBB_ZToLL_M125"  # "data_F_mumu"  # "TTTo2L2Nu" # "WWW_4F"
    )
    # debug_uuids = {"5CB647A3-93F6-0B40-BD9F-3DB44E8D60F7"}

    def select(self, events):
        """Run the selection and yield the local namespace per JEC variation.

        Leptons are selected once; everything that depends on the jets or MET
        is recomputed for every ``(unc, shift, met, jets, fatjets)`` tuple
        produced by ``self.jec_loop(events)``.

        Args:
            events: event array exposing ``metadata["dataset"]`` and the
                collections accessed below (``Muon``, ``Electron``, ``HLT``,
                ``Flag``, ``event`` and, for MC, ``Generator``, ``Pileup``,
                ``GenPart``, ``L1PreFiringWeight``).

        Yields:
            dict: ``locals()`` of this generator. Consumers read entries by
            name (e.g. ``output``, ``weights``, ``selection``, ``categories``,
            ``good_leptons``, ``clean_good_jets`` and the ``hl`` variables), so
            local variable names are part of the interface.
        """

        def const_arr(value):
            # read-only broadcast view; ``n_events`` is taken from the
            # enclosing scope and is assigned inside the JEC loop below
            arr = np.broadcast_to(value, shape=(n_events,))
            arr.flags.writeable = False
            return arr

        dataset_key = tuple(events.metadata["dataset"])
        dataset_inst = self.get_dataset(events)
        is_data = not dataset_inst.is_mc

        muons = events.Muon
        electrons = events.Electron

        # muons
        muon_selection = reduce_and(
            muons.tightId,
            (muons.pt > 15.0),
            (abs(muons.eta) <= 2.4),
            (abs(muons.dxy) <= 0.05),
            (abs(muons.dz) <= 0.1),
            (muons.pfRelIso04_all <= 0.15),
            (muons.sip3d <= 8.0),
        )

        ele_selection = reduce_and(
            electrons.mvaFall17V1Iso_WP80,
            (electrons.pt > 15.0),
            (abs(electrons.eta) <= 2.5),
            (abs(electrons.dxy) <= 0.05),
            (abs(electrons.dz) <= 0.1),
            (electrons.sip3d <= 8.0),
            (electrons.lostHits <= 1),
        )

        good_electrons = electrons[ele_selection]
        good_muons = muons[muon_selection]

        # leading lepton pt cuts
        leading_muon = good_muons[:, 0:1]
        leading_electron = good_electrons[:, 0:1]
        leading_muon_pt_cut = leading_muon.pt > 25.0
        leading_electron_pt_cut = leading_electron.pt > 25.0
        single_muon_pt_cut = leading_muon.pt > 30.0
        single_electron_pt_cut = leading_electron.pt > 35.0

        trigger = Trigger(events.HLT, self.trigger, dataset_inst)

        for unc, shift, met, jets, fatjets in self.jec_loop(events):
            output = self.accumulator.identity()
            weights = processor.Weights(
                len(events), storeIndividual=True  # self.individal_weights or True
            )
            selection = processor.PackedSelection()
            n_events = len(events)
            output["n_events"][dataset_key] = n_events

            MET = met.pt

            # placeholder; for MC this is overwritten below once
            # pathological generator weights have been zeroed
            if not is_data:
                output["sum_gen_weights"][dataset_key] = ak.sum(events.Generator.weight)
            else:
                output["sum_gen_weights"][dataset_key] = 0.0

            year = np.full(n_events, int(self.year))

            # load objects

            # select good jets
            jet_selection = (jets.isTight) & (jets.pt > 10.0) & (abs(jets.eta) < 2.4)

            good_jets = jets[jet_selection]
            # start object cutflow
            output["object_cutflow"]["total muons"] += ak.sum(ak.num(muons))
            output["object_cutflow"]["total electrons"] += ak.sum(ak.num(electrons))
            output["object_cutflow"]["total jets"] += ak.sum(ak.num(jets))
            output["object_cutflow"]["good electrons"] += ak.sum(ak.num(good_electrons))
            output["object_cutflow"]["good muons"] += ak.sum(ak.num(good_muons))
            output["object_cutflow"]["good jets"] += ak.sum(ak.num(good_jets))

            # now clean all objects

            clean_jet_mu = nano_object_overlap(good_jets, good_muons, dr=0.4)
            clean_jet_ele = nano_object_overlap(good_jets, good_electrons, dr=0.4)

            clean_good_jets = good_jets[(clean_jet_mu & clean_jet_ele)]

            output["object_cutflow"][
                "clean good muons and good electrons from good jets"
            ] += ak.sum(ak.num(clean_good_jets))

            # met filter
            if is_data:
                met_filter = self.campaign_inst.aux["met_filters"]["data"]
            else:
                met_filter = self.campaign_inst.aux["met_filters"]["mc"]

            met_filter_mask = nano_mask_and(events.Flag, met_filter)
            selection.add("met_filter", ak.to_numpy(met_filter_mask))

            # create leptons
            good_leptons = ak.with_name(
                ak.concatenate([good_muons, good_electrons], axis=1),
                "PtEtaPhiMCandidate",
            )

            # select at least one bjet
            passbtag = clean_good_jets.btagDeepFlavB >= btag_wp["2017"]["tight"]
            one_bjets = ak.num(clean_good_jets[passbtag]) >= 1
            zero_bjets = ak.num(clean_good_jets[passbtag]) == 0

            # lepton cuts
            # lep_tpls = good_leptons.distincts()
            lep_tpls = ak.combinations(
                good_leptons[..., :2],
                n=2,
                replacement=False,
                axis=-1,
                fields=["lep1", "lep2"],
            )
            dilep = lep_tpls.lep1 + lep_tpls.lep2
            # dilep = lead_diobj(good_leptons)
            ll_opp_charge = dilep.charge == 0
            met_pt_cut = met.pt > 30
            mll = ak.max(dilep.mass, axis=-1)

            # dijet variables, sort jets by btag value
            # NOTE(review): ak.argsort sorts ascending by default, so the first
            # jet is the one with the lowest b-tag score -- confirm intended
            clean_good_jets = clean_good_jets[ak.argsort(clean_good_jets.btagDeepFlavB)]

            two_leps = good_leptons[:, :2]

            # dr_ll = min_dr(two_leps.distincts())
            dr_ll = normalize(min_dr(good_leptons))
            # dphi_ll = min_dphi(two_leps.distincts())
            dphi_ll = normalize(min_dphi(good_leptons))

            # the names below must match the entries of ``self.hl``, because
            # ``arrays`` looks them up by key in the yielded namespace
            min_dphi_jet = normalize(min_dphi(good_jets))
            min_dr_jet = normalize(min_dr(good_jets))

            ht = get_ht(clean_good_jets)
            njets = ak.num(good_jets)

            # configure channel
            ch_mumu = reduce_and(
                ak.any(leading_muon.pt > 30.0, axis=-1),
                trigger.get("mumu"),
                (ak.num(good_muons) == 2),
                ak.num(good_electrons) == 0,
                ak.any(ll_opp_charge, axis=-1),
                ak.any(met.pt < 200, axis=-1),
            )
            ch_ee = reduce_and(
                ak.any(leading_electron.pt > 25.0, axis=-1),
                trigger.get("ee"),
                (ak.num(good_electrons) == 2),
                (ak.num(good_muons) == 0),
                ak.any(ll_opp_charge, axis=-1),
                ak.any(met.pt < 200, axis=-1),
            )

            # jets
            selection.add("one_bjets", ak.to_numpy(one_bjets))
            selection.add("zero_bjets", ak.to_numpy(zero_bjets))
            # lepton channel
            selection.add("ch_mumu", ak.to_numpy(ch_mumu))
            selection.add("ch_ee", ak.to_numpy(ch_ee))

            # weights after object selection
            if not is_data:
                # zero out generator weights with an absurdly large magnitude
                badgw = np.abs(events.Generator.weight) > 1e10
                gw = np.where(~badgw, events.Generator.weight, 0.0)
                output["sum_gen_weights"][dataset_key] = ak.sum(gw)
                weights.add("gen_weight", gw)

                if "pileup" in self.corrections:
                    weights.add(
                        "pileup",
                        **self.corrections["pileup"](
                            pu_key=self.get_pu_key(events),
                            nTrueInt=events.Pileup.nTrueInt,
                        ),
                    )
                if "pdf" in self.corrections:
                    self.corrections["pdf"](
                        events,
                        weights,
                        clip=10,
                    )
                # PS weight shifts: ISR & FSR
                # add LHEScaleWeights: renormalization & factorization
                gh = GeneratorHelper(events, weights)
                gh.PSWeight(lfn=self.get_lfn(events))
                gh.ScaleWeight()

                # do not need this?!
                # gh.gen_weight(output, dataset_key, datasets)

                # generator_variations(
                # events, weights, dataset_inst, gw, output["sum_gen_weights"]
                # )

                if dataset_inst.name.startswith("TTTo"):
                    w_toppt = ak.to_numpy(top_pT_reweighting(events.GenPart)).squeeze()
                    weights.add(
                        "top_pT_reweighting",
                        w_toppt,
                        weightUp=w_toppt**2,  # w**2
                        weightDown=ak.ones_like(w_toppt),  # w = 1.
                    )

                # L1 ECAL prefiring
                weights.add(
                    "l1_ecal_prefiring",
                    events.L1PreFiringWeight.Nom,
                    weightUp=events.L1PreFiringWeight.Up,
                    weightDown=events.L1PreFiringWeight.Dn,
                )

                # electrons
                if "electron" in self.corrections:
                    POGElectronSF(self.corrections)(good_electrons, weights)
                # muons
                if "muon" in self.corrections:
                    POGMuonSF(self.corrections)(good_muons, weights, self.year)

                # trigger sf
                lep_pt = ak.fill_none(ak.pad_none(good_leptons.pt, 2, clip=True), np.nan)
                if "trigger" in self.corrections:
                    Ctrig = self.corrections["trigger"]["hist"]
                    # both channels are evaluated at the same (lep0, lep1) pt
                    args = lep_pt[:, 0], lep_pt[:, 1]
                    for name, ch, trigger_hist in [
                        (
                            "mumu",
                            ch_mumu,
                            "h_DoubleMu_OR__X__allMET_mu0_pt_vs_mu1_pt_withSysts",
                        ),
                        (
                            "ee",
                            ch_ee,
                            "h_DoubleEl_OR__X__allMET_el0_pt_vs_el1_pt_withSysts",
                        ),
                    ]:
                        trigger_sf = Ctrig[f"ttHbb_dilepton_trigger_{trigger_hist}"](*args)
                        trigger_sf_error = Ctrig[f"ttHbb_dilepton_trigger_{trigger_hist}_error"](
                            *args
                        )
                        # register one weight per channel; events outside the
                        # channel get a neutral scale factor and zero shift
                        weights.add(
                            f"trigger_{name}_sf",
                            np.where(ch, trigger_sf, 1),
                            weightUp=np.where(ch, trigger_sf_error, 0),
                            weightDown=None,
                            shift=True,
                        )

                if "btag" in self.corrections:
                    # AK4 Jet reshape SF
                    corr = self.corrections["btag"]["deepjet"][
                        "POG_reduced"
                        + (
                            "_TuneCP5"
                            if int(self.year) == 2016
                            and "_TuneCP5_" in self.get_lfn(events, default="")
                            else ""
                        )
                    ].get("reshape")
                    if shift is None:
                        shifts = list(self.analysis_inst.aux.get("btag_sf_shifts", []))
                        # shifts = [] # no shifts -> speedup
                        c = "central"
                        corr = corr.reduce(shifts=shifts)
                    else:
                        shifts = []
                        if unc in ("jer", "UnclustEn", "HemIssue"):
                            c = "central"
                        else:
                            # flip up/down direction for certain jec shifts, since they changed direction:
                            # ReducedJECv1 (used for BTagSF) <-> ReducedJECv2 (applied here)
                            # https://twiki.cern.ch/twiki/bin/viewauth/CMS/JECUncertaintySources#Run_2_reduced_set_of_uncertainty
                            # if unc in ("RelativeSample", "RelativeBal"):
                            if unc in ("FlavorQCD", "RelativeBal") or unc.startswith(
                                (
                                    "RelativeSample_",
                                    "EC2",
                                )  # these are suffixed by their year
                            ):
                                btag_shift = dict(up="down", down="up")[shift]
                            else:
                                btag_shift = shift
                            c = f"{btag_shift}_jes{unc}"
                        assert c in corr.shifts, (c, corr.shifts)
                        corr = corr.reduce(shifts=[c], updown=False)
                    # per-event SF is the product over all cleaned jets
                    sfs = ak.prod(
                        corr.eval_awk1(obj=clean_good_jets, discr="btagDeepFlavB"),
                        axis=-1,
                    )
                    sf0 = sfs[c]
                    weights.add("btagWeight", sf0)
                    for btag_shift in shifts:
                        # shifts are stored relative to the central SF
                        weights.add(
                            "btagWeight_%s" % btag_shift,
                            ak.ones_like(sf0),
                            weightUp=sfs[f"up_{btag_shift}"] / sf0,
                            weightDown=sfs[f"down_{btag_shift}"] / sf0,
                        )
                    del sf0, sfs

                if "btagnorm" in self.corrections:
                    btagnorm = self.corrections["btagnorm"][dataset_inst.name]
                    weights.add("btagNorm", btagnorm(ak.num(clean_good_jets)))

            # build categories
            eventnr = ak.flatten(events.event, axis=-1)
            common_filter = ["met_filter"]

            categories = dict(
                # >= 1 bjet
                ee_1b=common_filter + ["ch_ee", "one_bjets"],
                mumu_1b=common_filter + ["ch_mumu", "one_bjets"],
                # 0 bjets
                ee_0b=common_filter + ["ch_ee", "zero_bjets"],
                mumu_0b=common_filter + ["ch_mumu", "zero_bjets"],
            )

            if dataset_inst.is_mc:
                # an MC dataset is expected to carry exactly one process
                (process,) = dataset_inst.processes.values()
                procid = const_arr(process.id)
            else:
                procid = const_arr(self.analysis_inst.processes.get("data").id)
                # raise NotImplementedError(
                #     "There was an old Pget import, which is not anymore available. please check the comment under this error!"
                # )
                # procid = np.full(n_events, Pget("data").id)

            yield locals()


class Processor(BaseSelection, Histogramer):
    """Histogram-filling processor built on the base selection."""

    # flag read outside this file; set to False here
    jes_shifts = False


class ExporterProcessor(BaseSelection, ArrayExporter):
    """Array-exporting processor that adds a cross-section weight to the selection."""

    sep = "+"
    group = "classification"

    @classmethod
    def requires(cls, task):
        """Return ``task.base_requires`` restricted to MC and with ``btagnorm`` disabled."""
        requirements = task.base_requires(data_source="mc", btagnorm=False)
        return requirements

    def select(self, events):
        """Yield the first namespace of the base selection with an ``xsec`` weight added.

        Only the first item produced by ``BaseSelection.select`` is consumed.
        For data the weight is 1; otherwise it is the nominal 13 TeV cross
        section of the dataset's single process times the campaign's ``lumi``
        auxiliary value.
        """
        selected = next(super().select(events))
        dataset_inst = self.get_dataset(events)
        # exactly one process per dataset is expected; unpacking fails otherwise
        (proc,) = dataset_inst.processes.values()

        if proc.is_data:
            xsec_weight = 1
        else:
            nominal_xsec = proc.xsecs[13].nominal
            lumi = self.config.campaign.get_aux("lumi")
            xsec_weight = nominal_xsec * lumi

        selected["weights"].add("xsec", xsec_weight)
        yield selected


class PUCount(BaseSelection, corr_proc.PUCount):
    """Pile-up counting processor combined with the base selection."""

    # memory setting read outside this file
    memory = "1000MiB"


class BTagSFNormEff(BaseSelection, corr_proc.BTagSFNormEff):
    """B-tag SF normalisation/efficiency processor combined with the base selection."""

    # skip_processes = {Pget("")}
    skip_cuts = {"zero_bjets", "one_bjets"}

    @classmethod
    def requires(cls, task):
        """Return ``task.base_requires`` for MC with the listed options switched off."""
        switched_off = dict.fromkeys(
            ("btagnorm", "vjets", "fake", "metphi", "dyestimation"),
            False,
        )
        return task.base_requires(data_source="mc", **switched_off)

    def njets(self, _locals):
        """Return the per-event count of ``clean_good_jets`` as a numpy array.

        Args:
            _locals: mapping holding a ``"clean_good_jets"`` entry (presumably
                the namespace yielded by ``select`` -- verify against caller).
        """
        jets = _locals["clean_good_jets"]
        counts = ak.num(jets)
        return ak.to_numpy(counts)