#JCourtin 08/2016

#class matrices réelles n-lignes x m-colonnes

class matrice():
    
    ##attributs de class
    n=None        #nombre de lignes
    m=None        #nombre de colonnes
    
    ligne=()      #tuple des lignes   successives sous forme de listes
    colonne=()    #tuple des colonnes successives sous forme de listes
 
    
    
    ##Méthodes de class
    
    
    #constructeur qui mange une liste de liste
    def __init__(self, liste2Lignes):
        ERROR=False
        self.n=len(liste2Lignes);      self.m=len(liste2Lignes[0])
        self.ligne=();                 self.colonne=()
        
        #récupération des lignes
        for ligne in liste2Lignes:
            ERROR=(len(ligne)!=self.m) #si une ligne n'a pas même longueur
            line=()
            for val in ligne:
                line+=(val,)
                if (type(val+0.)!=type(1.)):#si pas int ou float
                    ERROR=True; break
            if (ERROR):
                raise Exception("Erreur dans les données")
            self.ligne+=(line,)
            
        #construction des colonnes
        for j in range(self.m):
            col=()
            for i in range(self.n):
                col+=(self.ligne[i][j],)
            self.colonne+=(col,)
    
    
    #AFFICHAGE------------------------------------------------------------
    def __str__(self):
        #recherche de longeur de chaine max
        max=0
        for i in range(self.n):
            for j in range(self.m):
                if len(str(self.ligne[i][j]))>max:
                    max=len(str(self.ligne[i][j]))
        #affichage 
        myString="matrice "+str(self.n)+"x"+str(self.m)+" :\n"
        for i in range(self.n):
            myString+="| "
            for j in range(self.m):
                #Astuce de Sioux [stackOveFlow.com]
                #0 et 1 sont les indices des arguments passés à Format.
                myString+="{0:{1}}".format(self.ligne[i][j], max)+" "
            myString+="|\n"
        return myString#-------------------------------------------------

    
########################### PARTIE 1 : A MODIFIER ############################    
    
     ## OPERATION INTERNE   
   
    #multiplication par un scalaire
    def __mul__(self, scala):
        
        # Acompléter
        
        return #la matrice en question
        
    __rmul__=__mul__   #gère la multiplication par la droite
    
    
    
    #negation
    def __neg__(self):
        return self.__mul__(42) #A corriger
        
        
        
        
    #Opération de signe /
    def __truediv__(self, scalaire):
        invScalaire=1./scalaire
        return # A compléter
    
    
    
    #sousMatrice(i,j) !!! (i,j) indices MATH commencent à 1 !!!
    def subMatrice(self, k, l):
        k=k-1; l=l-1# expliquer
        matriceTMP=[]
        for i in range(self.n):    #expliquer
            if (i!=n):             #A corriger
                ligneTMP=[] 
                for j in range(self.m):   #expliquer
                    if (j!=m):            #A corriger
                        ligneTMP+=[self.ligne[i][j]]
                matriceTMP+=[ligneTMP]
        return #A completer
    
    
    
    
    #calcul du déterminant : Approche récursive !----------------------
    def determinant(self):
        if (self.n!=self.m):
            print("Opération sur les matrice carrées svp !"); return
        if (self.n==1):
            return self.ligne[0][0] 
        det=0.
        for k in range(self.n):   #A expliquer : en particulier les indices
            i,j=k+1,1             #qui suivente.
            M=self.subMatrice(i,j)
            det+=(-1)**(i+j)*M.determinant()*self.ligne[k][0]
        return  det     
    
    #matrice transposée------------------------------------------------
    def transpose(self):
        
        #A compléter (il y a une astuce !).
        
        return #matrice transposée
    
    
    #calcul de la comatrice--------------------------------------------
    def comatrice(self):
        matriceTMP=[]
        for i in range(1, self.n+1):#Expliquer les indices
            ligneTMP=[] 
            for j in range(1, self.m+1):#idem
                val=0#trouver la bonne expression
                ligneTMP+=[val]
            matriceTMP+=[ligneTMP]
        return #renvoyer la matrice correspondante.
    
    
    
    
    #calcul de la matrice inverse-------------------------------------
    def inverse(self):
        det=self.determinant()
        if (det==0.):
            raise Exception("Matrice non inversible !!!")
        #doit on ou non utiliser l'appel au constructeur ci-dessous ?
        return 1/det*(self.comatrice()).transpose()
     
     
     
     
    #Arrondi à 10-15 près--------------------------------------------
    def round(self,n=15):
        matriceTMP=[]
        for i in range(self.n):         #essayer et commenter cette méthode
            ligneTMP=[]                 
            for j in range(self.m):
                val=round(self.ligne[i][j],n)
                if (abs(val-int(val))<10**(-n)):
                    val=int(val)
                ligneTMP+=[val]
            matriceTMP+=[ligneTMP]
        return matrice(matriceTMP)



    ## OPERATION EXTERNE   
 
    #opération de signe +   --------------------------------------------
    def __add__(self, other):
        if (self.n!= other.n or self.m!=other.m):
            raise Exception("Matrices non semblables")
        matriceTMP=[]
        
        #A compléter
        
        return matrice(matriceTMP)
    
    #soustraction signe -  --------------------------------------------
    def __sub__(self, other):
        return #A compléter (il y a toujours une astuce ici)
    
    
    #test égalité          --------------------------------------------
    def __eq__(self, other):
        test=True                         #A tester :
        for i in range(self.n):           #ce test d'égalité est-il
            for j in range(self.m):       #vraiment pertinent ?? pourquoi ?
                if (self.ligne[i][j]!=other.ligne[i][j]):
                    test=False; break
        return test
    
    
    
    #produit matriciel --> SELF.OTHER (sens de lecture)----------------
    def dot(self, other):
        if (self.m!= other.n):
            raise Exception("Produit impossible")
        matriceTMP=[]
        
        # A compléter
        
        return matrice(matriceTMP)
    ##FIN DE LA CLASS

########################### PARTIE 2 : A MODIFIER ############################    

"""
Cette pratie se propose de résoudre un système linéaire de type AX=B dans le cas le plus général
par la méthode du pivot de Gauss.
L'intérêt de cette méthode est d'être beaucoup plus rapide que de calculer la matrice inverse (dont le coût devient vite inacceptable) et de détecter automatiquement si le système n'est pas solvable.

En niveau Bleu on demande simplement de commenter cette partie et de la comprendre pour se convaincre qu'elle fonctionne correctement.

Et cela n'est déjà pas facile : 
-------------------------------
- Il faut bien comprendre et résumer le rôle global de chaque fonction 
- Il faut ensuite fouiller dans les détails de chacune des trois fonctions
- Les dernières lignes du MAIN permettent de tester ce code (tout à la fin)

Bon courage
"""
## Résolution de système linéaire type AX=B  


#Cette premiere fonction pilote l'ensemble de la résolution en s'appuyant sur 3 fonctions :
# 1 - gereColonne
# 2 - getPivot
# 3 - resolSysTrig
def resolSystem(A,B):#----------------------------------------------------------
    if (A.n!=A.m or A.n!=B.n):  #Expliquer
        raise Exception("problème de dimensions")
    
    
    #première partie  :                   Définir l'objectif ???
    (A0, B0) = (A,B)#Que fait-t-on ?
    for indexCol in range(A.m-1):
        A0,B0=gereColonne(A0,B0,indexCol)    #expliquer (voir gère colonne)
    print(A0);print(B0)
    #Le résultat est-il conforme à l'objectif ? 
    
    
    #deuxième partie  :                   Définir l'objectif ???
    X=resolSysTrig(A0,B0)
    
    return X   #----------------------------------------------------------------



# 1 - Résumer en quelques lignes le rôle de cette fonction :
#
#
def gereColonne(A, B, indexCol=0):#---------------------------------------------
    
    #1 - recopie des lignes pour les pivots précédents
    matriceTMP=[];                           colonneTMP=[]
    for i in range(indexCol):
        matriceTMP+=[list(A.ligne[i])];      colonneTMP+=[list(B.ligne[i])]
    #   Expliquer le rôle de "indexCol" 
    
    
    #2 - Recherche du pivot optimisée:
    nPivot, pivot, indexLignes = getPivot(A, indexCol)#(indicePivot, Pivot, lignes a modifier)
    #expliquer l'affectation précédente ??? Est-elle valide ?
    #
    #je stocke la ligne du pivot (!!!copie par référence mais tuple non modifiable!!!)
    matriceTMP+=[list(A.ligne[nPivot])]
    colonneTMP+=[list(B.ligne[nPivot])]
    
    
    #3 - Re-calcul des lignes si nécessaire ou simple recopie
        # verifier que l'on recopie le bon nombre de ligne et pas une de trop ??
    for i in range(indexCol, A.n):
        
        if (i in indexLignes):#nécessaire
            maLigne=[0]*(indexCol+1)             #Que fait on ici ??? 
            coeff=A.ligne[i][indexCol]           #coeff diviseur : de quoi s'agit-il ?
            
            for j in range(indexCol+1, A.m):
                maLigne+=[A.ligne[i][j]/coeff*pivot - A.ligne[nPivot][j]]
                                                      #D'où vient cette formule
            matriceTMP+=[maLigne]
            colonneTMP+=[[B.ligne[i][0]/coeff*pivot - B.ligne[nPivot][0]]]
        
        elif(i!=nPivot):#simple recopie
            matriceTMP+=[list(A.ligne[i])]
            colonneTMP+=[list(B.ligne[i])]
            
    #renvoie le système re-calculé 
    return matrice(matriceTMP), matrice(colonneTMP)#----------------------------



# 2 - Recherche de Pivot optimisée  :    ---------------------------------------
#
# Expliquer en quelques ligne ce que fait cette partie, 
# comment et en quoi elle se veut optimisée ?
#       
def getPivot(A, indexCol):   #justifier les valeurs ci-dessous :
    i=indexCol;                 nPivot=0;          indexLignes=[]
    while(i<A.n):
        if (A.ligne[i][indexCol]!=0): #lignes de coeff non nul
            indexLignes+=[i]
        i+=1
    if (len(indexLignes)==0): #Que se passe-t-il ici ?
            raise Exception("Système non solvable")
    nPivot=indexLignes.pop(0) #premier indice non nul devient le pivot
    pivot=A.ligne[nPivot][indexCol]
    return (nPivot, pivot, indexLignes)#(indicePivot, Pivot, lignes a modifier)
    #Quelles sont les informations renvoyées dans ces trois éléments ?
    
    
    
# 3 - Résolution d'un système xxxx ? xxxx  :    ---------------------------
#
# Expliquer en quelques ligne ce que fait cette partie, 
#       
def resolSysTrig(T,C):
    X=[0.]*T.n
    for k in range(0, T.n):
        i=T.n-k-1               #comment varie  i  et pourquoi
        val=C.ligne[i][0]
        for j in range(i+1,T.n):     # pourquoir j commence en i+1 ??
            val-=X[j]*T.ligne[i][j]  #que fait-on ?
        X[i]=val/T.ligne[i][i]       #pourquoi cette division ?
    Xmat=[[X[k]] for k in range(len(X))]#expliquer cette construction de Xmat ? 
    return matrice(Xmat)
    
############# MAIN #############################################################
if (__name__=='__main__'):
    
    
    data22=[[1,2],\
            [3,4]]
    
    data42=[[2,2],\
            [3,4],\
            [5,6],\
            [7,8]]
    
    data33=[[1,2,3],\
            [4,5,6],\
            [7,8,9]]     #lignes liées : L1 + L3 - 2*L2 = 0
    
    matrice22=matrice(data22)
    matrice42=matrice(data42)
    matrice33=matrice(data33)
    
    print("matrice 2x2 :\n")
    print(matrice22.ligne)
    print(matrice22.colonne)
    print()
    print("matrice 4x2 :\n")
    print(matrice42.ligne)
    print(matrice42.colonne)
    print()
    print("matrice 3x3 :\n")
    print(matrice33.ligne)
    print(matrice33.colonne)
    
    ############################### TEST RESOLUTION ############################
    
    #Test de la résolution d'un système en 4x4
    mTest=matrice([[4,1,2,3], [5,2,0,3], [0,9,1,4], [2,1,3,4]])
    bTest=matrice([[1], [2], [3], [4]])
    
    X=resolSystem(mTest, bTest)
    #Proposer une formule utilisant la matrice inverse pour comparer les résultats.
    #AY=B soit:
    #Y=??????
    #on compare ensuite X (par Gauss) et Y :
    #print("X.round(14)==Y.round(14) : ",X.round(14)==Y.round(14))