以前にガウスの消去法を用いて連立方程式を解くプログラムの記事を書いたのですが、今回はLU分解を用いた方法で連立方程式を解くことを考えてみます
LU分解とは

そもそもLU分解ってなんのこと?
あんま聞かないですよね。LU分解
ザックーリ言うと、
連立方程式を解く前に式変形をして解きやすいようにしよう!
って感じです。
具体的には以下のような連立方程式があったとして

このA(係数行列)をA=LUと分解します。
ただし適当に分解するのではなく
Lを対角成分が1の下三角行列、Uを上三角行列と言う条件で分解します。
(対角成分に関しては1にしない場合もありますがこの記事では1で考えます)

するとAX=Bが LUX=Bとなりますね。
その後は
- LY = BとなるYを求める
- UX = YとなるXを求める
と言う手順を踏めばXを導出できます。
分解の手順の計算量がO(n^3)、その後のX導出がO(n^2)で行うことができます

ガウスの消去法の計算量がO(n^3)だからLU分解をわざわざ行う必要はなくない?
一つの連立方程式を解くだけだと計算量のメリットはないです。
ただAX=BのBの部分が違うパターンの問題をたくさん解くときはメリットがあります。

一度Aの分解を行なってしまえば残りの計算は全てO(n^2)で行えるからです。
一般的なLとU
LとUの各要素を求める前に、ここでは係数行列Aと分解行列L,Uのi列j行の要素を記事中ではそれぞれa_ij、l_ij、u_ijとします。(図中ではijは添字です)
AがNxN行列としてそのときLとUもNxN行列です。上にも書いたようにLの対角成分は全て1で固定します。
この時のLとUは以下のような計算で得られます

この計算に関しては特に解説しません。
(そういうものだと思ってありがたく使わせてもらいましょう)
ただ上の式の計算はそれぞれの計算の中で、他の計算をしないと得られない値を用いているので計算の順番が大事です。
コードの実装
コードを実装していきましょう。まずアルゴリズムの部分の前に、準備の段階です。
とりあえずインクルードとマクロです
#include <stdio.h>
#define NMAX 64
#define N 3
問題文をファイルから抽出する関数
void SCAN(double A[N][N], char s[NMAX]){
FILE *fp;
if((fp = fopen(s,"r"))==NULL)
printf("ファイルを開けません\n");
else{
for(int i=0; i<N; i++){
double a[N];
for(int j=0; j<N; j++){
fscanf(fp,"%lf",&a[j]);
A[i][j] = a[j];
}
}
fclose(fp);
}
}
このコードではNxNの配列Aに係数行列を代入してます。
ファイル操作についてわからなかったら以下の記事をどうぞ
配列の中身を表示する関数
void SHOW_NxN(double A[N][N]){
int i, j;
for(i=0; i<N; i++){
for(j=0; j<N; j++){
printf("%f ",A[i][j]);
}
printf("\n");
}
}
LとUの要素を求める関数
void set_L(double L[N][N]){
for(int i=0; i<N; i++)
L[i][i] = 1;
}
void u_1k(double U[N][N], double A[N][N]){
for(int k=0; k<N; k++)
U[0][k] = A[0][k];
}
void l_j1(double L[N][N], double A[N][N], double U[N][N]){
for(int j=1; j<N; j++)
L[j][0] = A[j][0]/U[0][0];
}
void u_jk(double U[N][N], double L[N][N], double A[N][N], int j){
for(int k=j; k<N; k++){
double t=0;
for(int s=0; s<=j-1; s++)
t += L[j][s]*U[s][k];
U[j][k] = A[j][k] - t;
}
}
void l_jk(double U[N][N], double L[N][N], double A[N][N], int k){
for(int j=k+1; j<N; j++){
double t=0;
for(int s=0; s<=k-1; s++)
t += L[j][s]*U[s][k];
L[j][k] = (A[j][k] - t)/U[k][k];
}
}
set_Lは対角成分を代入する関数です。
残りは関数名の通りそれぞれ要素を求める関数です。
変数も上で説明したものと同じです。まあ特に難しいことはないと思います。
main関数
int main(int argc, const char * argv[]) {
char name[NMAX];
double A[N][N] = {0};
printf("使用するファイルを入力してください:");
scanf("%s",name);
SCAN(A,"test_lu.txt");
double L[N][N] = {0};
double U[N][N] = {0};
set_L(L);
u_1k(U, A);
l_j1(L, A, U);
for(int q=1; q<N; q++){
u_jk(U, L, A, q);
if(q == N-1){
break;
}else{
l_jk(U, L, A, q);
}
}
printf("A\n");
SHOW_NxN(A);
printf("L\n");
SHOW_NxN(L);
printf("U\n");
SHOW_NxN(U);
return 0;
}
とりあえずUの一行目とLの一列目までは普通の流れですね。
forの中もUとLを順に解いていきます。
Uから解いていくことで全ての要素について解くことができます。
ただこのまま解いていくと最後のl_jkは必要ない(ちょうど対角成分の1なので解かなくて良い)のでそこはif文で判断します。
全体のコード
#include <stdio.h>
#define NMAX 64
#define N 3
void SCAN(double A[N][N], char s[NMAX]){
FILE *fp;
if((fp = fopen(s,"r"))==NULL)
printf("ファイルを開けません\n");
else{
for(int i=0; i<N; i++){
double a[N];
for(int j=0; j<N; j++){
fscanf(fp,"%lf",&a[j]);
A[i][j] = a[j];
}
}
fclose(fp);
}
}
void SHOW_NxN(double A[N][N]){
int i, j;
for(i=0; i<N; i++){
for(j=0; j<N; j++){
printf("%f ",A[i][j]);
}
printf("\n");
}
}
void set_L(double L[N][N]){
for(int i=0; i<N; i++)
L[i][i] = 1;
}
void u_1k(double U[N][N], double A[N][N]){
for(int k=0; k<N; k++)
U[0][k] = A[0][k];
}
void l_j1(double L[N][N], double A[N][N], double U[N][N]){
for(int j=1; j<N; j++)
L[j][0] = A[j][0]/U[0][0];
}
void u_jk(double U[N][N], double L[N][N], double A[N][N], int j){
for(int k=j; k<N; k++){
double t=0;
for(int s=0; s<=j-1; s++)
t += L[j][s]*U[s][k];
U[j][k] = A[j][k] - t;
}
}
void l_jk(double U[N][N], double L[N][N], double A[N][N], int k){
for(int j=k+1; j<N; j++){
double t=0;
for(int s=0; s<=k-1; s++)
t += L[j][s]*U[s][k];
L[j][k] = (A[j][k] - t)/U[k][k];
}
}
int main(int argc, const char * argv[]) {
char name[NMAX];
double A[N][N] = {0};
printf("使用するファイルを入力してください:");
scanf("%s",name);
SCAN(A,"test_lu.txt");
double L[N][N] = {0};
double U[N][N] = {0};
set_L(L);
u_1k(U, A);
l_j1(L, A, U);
for(int q=1; q<N; q++){
u_jk(U, L, A, q);
if(q == N-1){
break;
}else{
l_jk(U, L, A, q);
}
}
printf("A\n");
SHOW_NxN(A);
printf("L\n");
SHOW_NxN(L);
printf("U\n");
SHOW_NxN(U);
return 0;
}
全体のコードです。
計算法が確立されているので簡単に実装できました。
ぜひ上のコードを使ってください!
コメント