import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import diags, lil_matrix
from scipy.sparse.linalg import spsolve

# Paramètres
L = 1.0
nx = 50
dx = L / nx
x = np.linspace(dx/2, L - dx/2, nx)

tmax = 0.01
N = 10000

# Conditions initiales
def u0(x):
    return np.where((x >= 0.5-1/8) & (x <= 0.5+1/8), 1.0, 0.0)

# Calcul des coefficients c_i
def c(i):
    if i == 0:
        return 0.25
    else:
        return 4 / i / np.pi * np.cos(i * np.pi / 2) * np.sin(i * np.pi / 8)

# Calcul de u_N(x, t)
def u_N(x, t, N):
    u = np.zeros_like(x)
    for i in range(N + 1):
        u += c(i) * np.exp(-(i**2) * (np.pi**2) * t) * np.cos(i * np.pi * x)
    return u

# Tracé des résultats
# plt.plot(x, u_N(x, tmax, N))
# plt.ylim(-0.1, 1.1)
# plt.show()

# maintenant on passe au theta schéma
cfl = 0.1

dt = cfl * dx 
#print(dt,dx**2/2) 
#dt=dx**2/2+2e-7

# calcul des matrices

# Création de la matrice A en format LIL pour une construction facile
A = lil_matrix((nx, nx), dtype=np.float64)
A.setdiag(2, k=0)
A.setdiag(-1, k=1)
A.setdiag(-1, k=-1)

# matrice identité
Id = lil_matrix((nx, nx), dtype=np.float64)
Id.setdiag(1, k=0)

# Conditions aux limites de Neumann
A[0, 0] = 1 
A[-1, -1] = 1 

# afficher la matrice A au format plein ligne par ligne
#print(A.toarray())
#exit()
# Conversion en format CSR pour une résolution efficace
A = A.tocsr()
Id = Id.tocsr()

theta = 1

Mexp = Id - (1-theta) * dt / dx**2 * A
Mimp = Id + theta * dt / dx**2 * A


#....

#initialisation en temps
t = 0;

un = u0(x)
un = u_N(x,tmax,1000)

# boucle en temps pour evoluer un et unp1
while t < tmax:
    b = Mexp.dot(un)
    unp1 = spsolve(Mimp, b)
    un = unp1
    t += dt

# tracé sur la même figure
uex=u_N(x, 2*tmax, N)
err2=np.linalg.norm(uex-un)*np.sqrt(dx)
print(err2,nx)
plt.plot(x, uex, label='Solution exacte')
plt.plot(x, un, label='Theta schéma')
plt.ylim(-0.1, 1.1)
plt.legend()
plt.show()


