from abc import abstractmethod, ABC
from basis import *


class SourseFunction(ABC):
    """Abstract source function f, paired with a linear functional l.

    Subclasses implement point evaluation (``get_val``) and the application
    of l to a basis function through the ``|`` operator. Presumably l is used
    to assemble the right-hand side of the discrete problem; confirm against
    the callers.

    NOTE: "Sourse" is a misspelling of "Source". The name is kept unchanged
    because other modules may import it.
    """

    @abstractmethod
    def get_val(self, x: float):
        """Return the value of the source function at ``x``."""

    @abstractmethod
    def __or__(self, basis: Basis):
        """Return l(basis); the ``|`` operator is a substitute for l(bi).

        Usage::

            bi = Basis(...)
            self | bi  # := l(bi)
        """


class ConstantSource(SourseFunction):
    """Source function f(x) = const, identical at every point."""

    def __init__(self, const: float):
        # Returned by get_val for every x; scales every l(basis) integral.
        self.const = const

    def get_val(self, x: float):
        """Return the constant value; ``x`` is ignored."""
        return self.const

    @staticmethod
    def _integrate_quadratic(a: float, b: float, r1: float, r2: float) -> float:
        """Return the integral of (t - r1) * (t - r2) dt from ``a`` to ``b``.

        The integrand is expanded to t**2 - (r1 + r2) * t + r1 * r2, and each
        monomial is integrated exactly.
        """
        root_sum = r1 + r2
        root_prod = r1 * r2
        return (
            (b**3 - a**3) / 3
            - root_sum * (b**2 - a**2) / 2
            + root_prod * (b - a)
        )

    def __or__(self, basis: Basis):
        """Return l(basis), the integral of ``const * basis(x)`` over its support.

        Supported bases:
          * ``Hat``: ``get_coords()`` yields (x_m, x, x_p). The result is
            ``const`` times the triangle area 0.5 * (x_p - x_m).
          * ``Quadratic`` with ``f_odd`` set: ``get_coords()`` yields
            (x_m, _, x_p) and ``A`` is a single coefficient. The result is the
            integral of const * A * (x - x_m) * (x - x_p) over [x_m, x_p].
          * any other ``Quadratic``: ``get_coords()`` yields
            (x_mm, x_m, x_c, x_p, x_pp) and ``A`` unpacks into (A_l, A_r).
            The result is the integral of const * A_l * (x - x_mm) * (x - x_m)
            over [x_mm, x_c], plus the integral of
            const * A_r * (x - x_p) * (x - x_pp) over [x_c, x_pp].

        Raises:
            NotImplementedError: if ``basis`` is of any other type.
        """
        if isinstance(basis, Hat):
            x_m, _, x_p = basis.get_coords()
            # Triangle area = 0.5 * base * height, with the height taken as 1.
            # NOTE(review): assumes the hat peaks at 1; confirm against Hat.
            return self.const * (0.5 * (x_p - x_m))

        if isinstance(basis, Quadratic):
            if basis.f_odd:
                x_m, _, x_p = basis.get_coords()
                return basis.A * self.const * self._integrate_quadratic(
                    x_m, x_p, x_m, x_p
                )

            x_mm, x_m, x_c, x_p, x_pp = basis.get_coords()
            A_l, A_r = basis.A
            # Left piece:  (x - x_mm)(x - x_m) on [x_mm, x_c]
            # Right piece: (x - x_p)(x - x_pp) on [x_c, x_pp]
            left = self._integrate_quadratic(x_mm, x_c, x_mm, x_m)
            right = self._integrate_quadratic(x_c, x_pp, x_p, x_pp)
            return A_l * self.const * left + A_r * self.const * right

        # The original `raise NotImplemented` raised a TypeError, because
        # NotImplemented is a constant and not an exception class.
        raise NotImplementedError(
            f"ConstantSource | {type(basis).__name__} is not implemented"
        )


class DiracDelta(SourseFunction):
    """Point source ``scalar * delta(x - shift)``.

    Applying the functional to a basis function samples that basis at
    ``shift``; this is the delta's sifting property.
    """

    def __init__(self, shift: float, scalar: float = 1):
        # Weight of the impulse.
        self.scalar = scalar
        # Location of the impulse.
        self.shift = shift

    def get_val(self, x: float):
        """Return ``scalar`` when ``x`` equals ``shift`` exactly, else 0."""
        if x == self.shift:
            return self.scalar
        return 0

    def __or__(self, basis: Basis):
        """Return l(basis) = scalar * basis(shift)."""
        sampled = basis.get_val(self.shift)
        return self.scalar * sampled