1 module dnum.linalg;
2 
3 import dnum.tensor;
4 import std.typecons : tuple;
5 
/++
    Identity matrix.

    Params:
        n = number of rows and columns.

    Returns: a Tensor created with `Tensor(0, n, n)` (presumably fill value 0,
    `n` rows, `n` columns) whose diagonal entries are then set to 1.
+/
Tensor eye(long n) {
  auto identity = Tensor(0, n, n);
  foreach (k; 0 .. identity.data.length) {
    identity.data[k][k] = 1;
  }
  return identity;
}
16 
/++
    LU decomposition (Doolittle form, no pivoting).

    Factors `m` so that `m = L * U`. Here `L` is unit lower triangular: it
    starts as `eye(n)` and only entries below the diagonal are written. `U`
    is upper triangular: it starts as zeros and only entries on or above the
    diagonal are written.

    Params:
        m = matrix to factor. `n = m.nrow` and only `m[i, k]` with
            `i, k < n` is read, so `m` is expected to be square.

    Returns: `tuple(l, u)`.

    Note: no row exchanges are performed. If a pivot `u[i, i]` is zero, the
    division below yields inf/NaN entries (e.g. for `[[0, 1], [1, 0]]`).
+/
auto lu(Tensor m) {
  auto n = m.nrow;

  // Every entry of `u` on or above the diagonal is assigned in the main loop
  // before it is read, so a zero-filled start is all that is needed.
  auto u = Tensor(0, n, n);
  auto l = eye(n);

  foreach (i; 0 .. n) {
    // Row i of u, using rows 0 .. i-1 of u computed in earlier iterations.
    foreach (k; i .. n) {
      double s = 0;
      foreach (j; 0 .. i) {
        s += l[i, j] * u[j, k];
      }
      u[i, k] = m[i, k] - s;
    }

    // Column i of l below the diagonal; needs u[i, i] from the loop above.
    foreach (k; i + 1 .. n) {
      double s = 0;
      foreach (j; 0 .. i) {
        s += l[k, j] * u[j, i];
      }
      l[k, i] = (m[k, i] - s) / u[i, i];
    }
  }
  return tuple(l, u);
}
49 
/++
    Determinant, computed as the product of the diagonal entries of the `u`
    factor returned by `lu`.

    Because `lu` does not pivot, a matrix with a zero leading pivot can
    produce inf/NaN here instead of its true determinant.
+/
auto det(Tensor m) {
  auto upper = m.lu[1];

  double product = 1;
  foreach (idx; 0 .. upper.data.length) {
    product *= upper.data[idx][idx];
  }
  return product;
}
62 
/++
    Splits a matrix into four blocks around `half = m.nrow / 2`.

    Returns: `tuple(m11, m12, m21, m22)`, where `m11` is `half` x `half`,
    `m12` is `half` x `(nrow - half)`, `m21` is `(nrow - half)` x `half`
    and `m22` is `(nrow - half)` x `(nrow - half)`.

    The same split point is used for rows and columns, so `m` is expected to
    be square.
+/
auto block(Tensor m) {
  auto size = m.nrow;
  auto half = size / 2;
  auto rest = size - half;

  auto m11 = Tensor(half, half);
  auto m12 = Tensor(half, rest);
  auto m21 = Tensor(rest, half);
  auto m22 = Tensor(rest, rest);

  foreach (row, ref line; m.data) {
    auto upper = row < half;
    foreach (col, ref value; line) {
      auto left = col < half;
      if (upper && left) {
        m11[row, col] = value;
      } else if (upper) {
        m12[row, col - half] = value;
      } else if (left) {
        m21[row - half, col] = value;
      } else {
        m22[row - half, col - half] = value;
      }
    }
  }
  return tuple(m11, m12, m21, m22);
}
94 
/++
    Inverse of a square matrix via its LU factors.

    With `m = L * U` from `lu`, the inverse is `U^-1 * L^-1`. The triangular
    inverses come from `invU` and `invL`, and `%` combines them (assumed to
    be the Tensor matrix product).

    Because `lu` does not pivot, matrices with a zero leading pivot give
    inf/NaN entries.
+/
Tensor inv(Tensor m) {
  auto factors = m.lu;
  return factors[1].invU % factors[0].invL;
}
108 
109 // =============================================================================
110 // Back-End Utils
111 // =============================================================================
112 
/++
    Assembles four blocks into one matrix laid out as
    ---
    [ a  b ]
    [ c  d ]
    ---

    Params:
        a = top-left block
        b = top-right block (same number of rows as `a`)
        c = bottom-left block (same number of columns as `a`)
        d = bottom-right block (rows of `c`, columns of `b`)

    Returns: an `(a.nrow + c.nrow)` x `(a.ncol + b.ncol)` Tensor.
+/
Tensor combine(Tensor a, Tensor b, Tensor c, Tensor d) {
  auto x1 = a.ncol;
  auto x2 = b.ncol;
  auto y1 = a.nrow;
  auto y2 = c.nrow;

  assert(b.nrow == y1 && c.ncol == x1 && d.nrow == y2 && d.ncol == x2,
      "combine: block dimensions do not fit together");

  auto x = x1 + x2; // total columns
  auto y = y1 + y2; // total rows

  // The constructor takes the row count first (compare `Tensor(l, r - l)`
  // in `block`), so rows `y` come before columns `x`.
  auto m = Tensor(y, x);

  foreach (i; 0 .. y1) {
    foreach (j; 0 .. x1) {
      m[i, j] = a[i, j];
    }
    foreach (j; x1 .. x) {
      m[i, j] = b[i, j - x1];
    }
  }

  foreach (i; y1 .. y) {
    foreach (j; 0 .. x1) {
      m[i, j] = c[i - y1, j];
    }
    foreach (j; x1 .. x) {
      m[i, j] = d[i - y1, j - x1];
    }
  }

  return m;
}
147 
/++
    Inverse of a lower triangular matrix (any non-zero diagonal, not only 1s).

    Larger matrices are split with `block` and the block identity is applied
    recursively:
    ---
    [ L1  0  ]^-1   [ L1^-1                0     ]
    [ L3  L4 ]    = [ -L4^-1 * L3 * L1^-1  L4^-1 ]
    ---
    The 1x1 and 2x2 cases are solved in closed form. `%` is assumed to be
    the Tensor matrix product.

    Params:
        l = square lower triangular matrix. Entries above the diagonal are
            copied into the result unchanged, so they are expected to be zero.

    Returns: the inverse, built in `Tensor(l)`. A zero diagonal entry
    produces inf/NaN entries.
+/
Tensor invL(Tensor l) {
  auto r = l.nrow;
  auto m = Tensor(l);

  if (r == 0) {
    // Nothing to invert. Without this guard, `block` would keep returning
    // 0x0 pieces and the recursion below would never terminate.
    return m;
  } else if (r == 1) {
    m[0, 0] = 1 / l[0, 0];
    return m;
  } else if (r == 2) {
    // [[a, 0], [c, d]]^-1 = [[1/a, 0], [-c/(a*d), 1/d]]
    auto a = l[0, 0];
    auto c = l[1, 0];
    auto d = l[1, 1];

    m[0, 0] = 1 / a;
    m[1, 0] = -c / (a * d);
    m[1, 1] = 1 / d;
    return m;
  } else {
    auto ls = m.block;
    auto l1 = ls[0];
    auto l2 = ls[1];
    auto l3 = ls[2];
    auto l4 = ls[3];

    auto m1 = l1.invL;
    auto m2 = l2; // upper-right block stays as given (zero for lower triangular)
    auto m4 = l4.invL;
    auto m3 = -(m4 % l3 % m1);

    return combine(m1, m2, m3, m4);
  }
}
175 
/++
    Inverse of an upper triangular matrix (any non-zero diagonal).

    Larger matrices are split with `block` and the block identity is applied
    recursively:
    ---
    [ U1  U2 ]^-1   [ U1^-1  -U1^-1 * U2 * U4^-1 ]
    [ 0   U4 ]    = [ 0       U4^-1              ]
    ---
    The 1x1 and 2x2 cases are solved in closed form. `%` is assumed to be
    the Tensor matrix product.

    Params:
        u = square upper triangular matrix. Entries below the diagonal are
            copied into the result unchanged, so they are expected to be zero.

    Returns: the inverse, built in `Tensor(u)`. A zero diagonal entry
    produces inf/NaN entries.
+/
Tensor invU(Tensor u) {
  auto r = u.nrow;
  auto m = Tensor(u);

  if (r == 0) {
    // Nothing to invert. Without this guard, `block` would keep returning
    // 0x0 pieces and the recursion below would never terminate.
    return m;
  } else if (r == 1) {
    m[0, 0] = 1 / u[0, 0];
    return m;
  } else if (r == 2) {
    // [[a, b], [0, c]]^-1 = [[1/a, -b/(a*c)], [0, 1/c]]
    auto a = m[0, 0];
    auto b = m[0, 1];
    auto c = m[1, 1];
    auto d = a * c;

    m[0, 0] = 1 / a;
    m[0, 1] = -b / d;
    m[1, 1] = 1 / c;

    return m;
  } else {
    auto us = u.block;
    auto u1 = us[0];
    auto u2 = us[1];
    auto u3 = us[2];
    auto u4 = us[3];

    auto m1 = u1.invU;
    auto m3 = u3; // lower-left block stays as given (zero for upper triangular)
    auto m4 = u4.invU;
    auto m2 = -(m1 % u2 % m4);

    return combine(m1, m2, m3, m4);
  }
}