import numpy as np


from lossfunc import currentloss



def int(x,y):
  """Integrate y over x with the trapezoidal rule.

  x: sample positions (any order)
  y: sample values, paired element-wise with x

  The (x, y) pairs are sorted by x first (ties are ordered by y), then the
  trapezoid areas between neighbouring points are summed.
  Returns 0.0 when there are fewer than two points, including empty input.

  NOTE: the name shadows the builtin int inside this module; it is kept
  because callers refer to the function by this name.
  """
  points=sorted(zip(x,y))
  if not points:
    # nothing to integrate; the unpacking-based version raised ValueError here
    return 0.0
  area=0.0
  for (x0,y0),(x1,y1) in zip(points,points[1:]):
    area+=((y1+y0)*(x1-x0))/2
  return area


def difference(a,b):
  """Return currentloss(a, b, np) for a pair of arrays.

  a, b: array-like objects exposing .shape

  An input with fewer than three dimensions is reshaped to
  (1, shape[0], shape[1]) before the call, i.e. it is passed as a batch of
  one. Both shape[0] and shape[1] are read, so such an input must be 2-D.
  """
  if len(a.shape)<3:
    a=np.reshape(a,(1,a.shape[0],a.shape[1]))
  if len(b.shape)<3:
    b=np.reshape(b,(1,b.shape[0],b.shape[1]))
  return currentloss(a,b,np)

def caucd(d,y):
  """Build rate curves and an area-under-curve value from scores and labels.

  d:       value by which to sort (one score per sample)
  y:       class (0 or 1); a label counts as class 1 when it is > 0.5

  The samples are ordered by ascending d and walked once. After each sample,
  "seen" counts what has been passed and "left" counts what is still ahead:
    tpr = class-1 left  / class-1 total
    fpr = class-0 left  / class-0 total
    tnr = class-0 seen  / class-0 total
    fnr = class-1 seen  / class-1 total

  Returns a dict with the four curves ("tpr", "fpr", "tnr", "fnr"), "auc"
  (the integral of tpr over fpr), "i30" (index where tpr is closest to 0.3),
  "e30" (fpr at that index), the constants "nw" and "c", and "d0"/"d1" (the
  input scores split by class, in input order).
  """
  eps=0.000000001  # avoids a zero denominator when one class has no samples

  # scores split by class, in the order they were given
  d1=[d[k] for k in range(len(y)) if y[k]>0.5]
  d0=[d[k] for k in range(len(y)) if not (y[k]>0.5)]

  # order the labels by ascending score (tied scores are ordered by label)
  _,labels=zip(*sorted(zip(d, y)))
  labels=np.array(labels)

  # NOTE(review): a label exactly equal to 0.5 goes to d0 above, is counted on
  # the class-1 side of the running counters below, and is in neither total.
  # This does not matter for strict 0/1 labels.
  below=labels<0.5
  total1=np.sum(labels>0.5)
  total0=np.sum(below)

  seen0=np.cumsum(below)    # class-0 samples passed so far
  seen1=np.cumsum(~below)   # remaining samples passed so far
  left0=total0-seen0
  left1=total1-seen1

  tpr=left1/(total1+eps)
  fpr=left0/(total0+eps)
  tnr=seen0/(total0+eps)
  fnr=seen1/(total1+eps)

  auc=int(fpr,tpr)

  # operating point whose tpr is closest to 0.3
  i30=np.argmin((tpr-0.3)**2)
  e30=fpr[i30]

  return {
    "tpr":tpr,
    "fpr":fpr,
    "tnr":tnr,
    "fnr":fnr,
    "auc":auc,
    "e30":e30,
    "i30":i30,
    "nw":True,
    "c":-1,
    "d0":d0,
    "d1":d1,
  }
  
  

def cauc(p,c,y):
  """Score each sample by its prediction error and evaluate the scores with caucd.

  p:  prediction, indexable per sample
  c:  correct, indexable per sample
  y:  class (0 or 1), one entry per sample

  The score of sample i is difference(c[i], p[i]).
  Returns the dict produced by caucd(d, y) for those scores.
  """
  d=np.zeros((len(y),))
  for i in range(len(y)):
    d[i]=difference(c[i],p[i])

  # caucd builds the per-class split of the scores itself, so none is kept here
  return caucd(d,y)