import numpy as np
from matplotlib.pyplot import *
from matplotlib import animation
# Text style shared by the region labels ('Conducteur', 'Vide') drawn on the plot.
font = dict(color='black', weight='normal', size=16)

# Constants

# Length scale used by k2() in the conductor's dispersion relation
# (presumably the skin depth), arbitrary unit.
# NOTE(review): k2() = (1 + 1j) / delta, so Im(k2) = 1 / delta. With a
# negative delta the transmitted field exp(i * k2 * x) GROWS for x > 0
# instead of being attenuated inside the conductor - confirm the sign is
# intentional.
delta = -10
omega = 0.2  # angular frequency, rad/s
c = 1.0  # wave speed in the first medium, m/s

# Relation de dispersion dans chaque milieu

def k1(W):
    """Return the wavenumber in the first medium (vacuum): k = W / c.

    W is the angular frequency; c is the module-level wave speed.
    """
    speed = c
    return W / speed

def k2():
    """Return the complex wavenumber in the second medium: k = (1 + i) / delta.

    Reads the module-level constant ``delta``. Since Im(k) = 1 / delta, the
    wave exp(i * k * x) decays for x > 0 only when delta > 0.
    """
    numerator = complex(1, 1)
    return numerator / delta

# Relations de définition des coefficients de réflexion et de transmission

def Dephasage(Vect, Pos):
    """Return the phase factor exp(i * Vect * Pos).

    Vect (a wavenumber) and Pos (a position) may be scalars or numpy arrays;
    numpy broadcasting applies.
    """
    argument = 1j * Vect * Pos
    return np.exp(argument)

def coeff_reflexion(Vect1, Vect2, Interface):
    """Return the complex amplitude reflection coefficient at x = Interface.

    With an incident wave exp(i*Vect1*x) and a reflected wave
    r * exp(-i*Vect1*x), this returns
        r = (Vect1 - Vect2) / (Vect1 + Vect2) * exp(2i * Vect1 * Interface),
    which is consistent with continuity of the field and of its
    x-derivative at the interface.
    """
    ratio = (Vect1 - Vect2) / (Vect1 + Vect2)
    return ratio * Dephasage(2 * Vect1, Interface)

def coeff_transmission(Vect1, Vect2, Interface):
    """Return the complex amplitude transmission coefficient at x = Interface.

    Solves field continuity at the interface X = Interface:
        exp(i*Vect1*X) + r * exp(-i*Vect1*X) = t * exp(i*Vect2*X),
    where r is given by coeff_reflexion.
    """
    incident = Dephasage(Vect1, Interface)
    reflected = Dephasage(- Vect1, Interface) * coeff_reflexion(Vect1, Vect2, Interface)
    return (incident + reflected) * Dephasage(- Vect2, Interface)

# Définition d'une OPPS

def OPPS(Vect, W, x, t):
    """Return a unit-amplitude plane progressive sinusoidal wave exp(i(Vect*x - W*t)).

    Vect is the (possibly complex) wavenumber, W the angular frequency, x the
    position (scalar or numpy array) and t the time.
    """
    phase = Vect * x - W * t
    return np.exp(1j * phase)

# Spatial grid

# Abscissa of the vacuum/conductor interface.
x_interface = 0.0
# Left edge of the vacuum region (about -141.3); expression kept as written.
xmin = -3 * 3 * 3.14 * 5
# Right edge of the conductor region.
xmax = 100
# Number of samples on EACH of [xmin; x_interface] and [x_interface; xmax].
N = 1000

                                                            # CALCUL et ANIMATION DES ONDES
def FCT(t):
    """Compute the incident, reflected, total and transmitted fields at time t.

    The first medium (vacuum) is sampled with N points on [xmin, x_interface],
    and the second medium (conductor) with N points on [x_interface, xmax].

    Parameters
    ----------
    t : float
        Evaluation time (seconds, given omega in rad/s).

    Returns
    -------
    tuple
        (x1, x2, E1_i, E1_r, E1, E2):
        - x1, x2: the two abscissa grids;
        - E1_i, E1_r, E1: complex incident, reflected and total
          (incident + reflected) fields on x1;
        - E2: complex transmitted field on x2.
    """
    x1 = np.linspace(xmin, x_interface, N)
    x2 = np.linspace(x_interface, xmax, N)

    # The wavenumbers depend on neither x nor t: evaluate them once per call.
    # (The former np.zeros(..., dtype=np.complex) preallocations were dead
    # stores, and np.complex no longer exists in NumPy >= 1.24.)
    kv = k1(omega)
    kc = k2()

    # Unit-amplitude incident wave travelling towards +x.
    E1_i = 1.0 * OPPS(kv, omega, x1, t)
    # Reflected wave travelling towards -x.
    E1_r = coeff_reflexion(kv, kc, x_interface) * OPPS(-kv, omega, x1, t)
    E1 = E1_i + E1_r
    # Transmitted wave in the conductor.
    E2 = coeff_transmission(kv, kc, x_interface) * OPPS(kc, omega, x2, t)
    return (x1, x2, E1_i, E1_r, E1, E2)

# Animation clock state: temps is advanced by dt in animate().
temps = 0.0
dt = 0.001
# Initial fields, drawn at t = 0.
(x1, x2, E1_i, E1_r, E1, E2) = FCT(0.0)

fig, ax = subplots()
ax.set_title("Delta = %1.2f " % delta)

# Only the real part of each complex field is plotted.
line1, = ax.plot(x1, np.real(E1_i), 'r--', label="Inc.")
line2, = ax.plot(x1, np.real(E1_r), 'g--', label="Réfl.")
line3, = ax.plot(x1, np.real(E1), 'b', label="Inc. + Refl.", linewidth=2)
line4, = ax.plot(x2, np.real(E2), 'y', label="Transmis", linewidth=2)
ax.grid()

# Shade the second medium and label both regions.
ax.axvspan(x_interface, xmax, facecolor='lightgrey', alpha=0.5)
ax.text(30, -1.7, 'Conducteur', fontdict=font)
ax.text(-100, -1.7, 'Vide', fontdict=font)

ax.set_xlabel('x (unité de delta)')
ax.set_ylabel('E (unité arbitraire)')
ax.axis([xmin, xmax, -2, 3])
ax.legend(loc='upper right')

def animate(i):
    """Redraw the four curves for animation frame ``i`` (FuncAnimation callback).

    NOTE(review): the fields are evaluated at t = i (the frame index), not at
    the accumulated ``temps``. ``temps`` is advanced by ``dt`` on every frame
    but is never read here. With dt = 0.001, switching to ``temps`` would make
    the wave move 1000x more slowly per frame. Confirm which time base is
    intended.
    """
    global temps
    temps += dt
    (x1, x2, E1_i, E1_r, E1, E2) = FCT(i)
    curves = (
        (line1, x1, E1_i),
        (line2, x1, E1_r),
        (line3, x1, E1),
        (line4, x2, E2),
    )
    for artist, abscissa, field in curves:
        artist.set_xdata(abscissa)
        artist.set_ydata(np.real(field))

# Call animate() every 200 ms. The animation object is bound to a name so it
# stays referenced while the window is open (the matplotlib animation docs
# require keeping a reference to it).
anim = animation.FuncAnimation(fig, func=animate, interval=200)
show()

