1 module dnum.matrix;
2 
3 import std.algorithm.comparison : min, max;
4 import std.conv : to;
5 
6 /// Light-weight R-like matrix structure
/// Light-weight R-like matrix structure.
///
/// Elements live in a flat `data` array whose interpretation is selected by
/// `by_row`: when `true` the storage is row-major (element `(i, j)` is at
/// `i * col + j`), otherwise column-major (element `(i, j)` is at
/// `i + j * row`).
struct Matrix {
    /// Flat element storage; expected to hold exactly `row * col` values.
    double[] data;

    /// Number of rows.
    ulong row;

    /// Number of columns.
    ulong col;

    /// Storage order of `data`: `true` = row-major, `false` = column-major.
    bool by_row;

    /// Wrap an existing array (the array is not copied).
    ///
    /// Params:
    ///     vec    = flat element storage; should contain `row * col` values
    ///     row    = number of rows
    ///     col    = number of columns
    ///     by_row = `true` if `vec` is row-major, `false` if column-major
    this(double[] vec, ulong row, ulong col, bool by_row) {
        this.data = vec;
        this.row = row;
        this.col = col;
        this.by_row = by_row;
    }

    /// Create a `row` x `col` matrix with every element set to `val`.
    ///
    /// Params:
    ///     val    = value stored in every element
    ///     row    = number of rows
    ///     col    = number of columns
    ///     by_row = storage order of the new matrix
    this(double val, ulong row, ulong col, bool by_row) {
        this.data = new double[row * col];
        this.data[] = val;
        this.row = row;
        this.col = col;
        this.by_row = by_row;
    }

    /// Formatting hook used by `std.format`, `std.conv.to!string`, `writeln`.
    ///
    /// Emits the table produced by `spread` into `sink` (instead of printing
    /// it directly) so that formatting a matrix into a string works.
    void toString(scope void delegate(const(char)[]) sink) const {
        sink(this.spread());
    }

    /// Return a copy of this matrix with the opposite storage order.
    ///
    /// The logical contents are unchanged (`result[i, j] == this[i, j]`);
    /// only the layout of `data` and the `by_row` flag are flipped.
    /// Works for empty matrices as well.
    Matrix change_shape() const {
        assert(this.row * this.col == this.data.length);
        const bool to_row = !this.by_row;
        auto vec = new double[this.row * this.col];

        foreach (i; 0 .. this.row) {
            foreach (j; 0 .. this.col) {
                // Destination index in the *target* layout.
                const ulong dst = to_row ? i * this.col + j : i + j * this.row;
                vec[dst] = this[i, j];
            }
        }
        return Matrix(vec, this.row, this.col, to_row);
    }

    /// Render the matrix as an R-style table with `c[j]` column headers and
    /// `r[i]` row labels.
    ///
    /// Each element is printed with whichever of `to!string` and `%.4f` is
    /// shorter (ties go to `%.4f`). All value columns share one right-aligned
    /// width, at least 5 characters, which is widened to fit the longest value
    /// or header. Rows are separated by `'\n'`; there is no trailing newline.
    string spread() const {
        import std.format : format;

        assert(this.row * this.col == this.data.length);

        // Right-justify `s` in a field `width` characters wide; text wider
        // than the field is emitted as-is (never truncated, never underflows).
        static string pad(string s, ulong width) {
            return format("%*s", cast(int) width, s);
        }

        // Shorter of the two textual forms of `x` (ties keep the `%.4f` form).
        static string cell(double x) {
            auto fixed = format!"%.4f"(x);
            auto plain = to!string(x);
            return (fixed.length > plain.length) ? plain : fixed;
        }

        // Value-column width: widest cell or header plus one separating space.
        ulong space = 5;
        foreach (x; this.data[0 .. this.row * this.col]) {
            space = max(space, cell(x).length + 1);
        }
        if (this.col > 0) {
            space = max(space, format!"c[%d]"(this.col - 1).length + 1);
        }

        // Row-label column width: at least 5, widened for labels like r[100].
        ulong label = 5;
        if (this.row > 0) {
            label = max(label, format!"r[%d]"(this.row - 1).length);
        }

        string result = pad("", label);
        foreach (j; 0 .. this.col) {
            result ~= pad(format!"c[%d]"(j), space); // Header
        }
        result ~= '\n';

        foreach (i; 0 .. this.row) {
            if (i > 0) {
                result ~= '\n';
            }
            result ~= pad(format!"r[%d]"(i), label);
            foreach (j; 0 .. this.col) {
                result ~= pad(cell(this[i, j]), space);
            }
        }
        return result;
    }

    // =========================================================================
    // Operator Overloading
    // =========================================================================
    /// Linear position of element `(i, j)` in `data` for the current layout.
    private ulong offset(ulong i, ulong j) const pure {
        assert(i < this.row && j < this.col, "Matrix index out of bounds");
        return this.by_row ? i * this.col + j : i + j * this.row;
    }

    /// Getter: `m[i, j]` (0-based row `i`, column `j`).
    pure double opIndex(ulong i, ulong j) const {
        return this.data[this.offset(i, j)];
    }

    /// Setter: `m[i, j] = value` (0-based row `i`, column `j`).
    void opIndexAssign(double value, ulong i, ulong j) {
        this.data[this.offset(i, j)] = value;
    }

    /// Unary ops: `-m` negates every element, `+m` returns a copy.
    ///
    /// Throws: `Exception` for any other unary operator.
    Matrix opUnary(string op)() const {
        static if (op == "-" || op == "+") {
            auto vec = new double[this.row * this.col];
            foreach (i, ref v; vec) {
                v = mixin(op ~ "this.data[i]");
            }
            return Matrix(vec, this.row, this.col, this.by_row);
        } else {
            throw new Exception("No operation!");
        }
    }

    /// Element-wise ops with a scalar on the right:
    /// `m + x`, `m - x`, `m * x`, `m / x`, `m ^^ x`.
    ///
    /// Throws: `Exception` for any other operator.
    Matrix opBinary(string op)(double rhs) const {
        static if (op == "+" || op == "-" || op == "*" || op == "/" || op == "^^") {
            auto vec = new double[this.row * this.col];
            foreach (i, ref v; vec) {
                v = mixin("this.data[i] " ~ op ~ " rhs");
            }
            return Matrix(vec, this.row, this.col, this.by_row);
        } else {
            throw new Exception("No operation!");
        }
    }

    /// Element-wise ops with a scalar on the left:
    /// `x + m`, `x - m`, `x * m`, `x / m`, `x ^^ m`.
    ///
    /// The scalar is the left operand, so `x - m` yields `x - m[i, j]`.
    ///
    /// Throws: `Exception` for any other operator.
    Matrix opBinaryRight(string op)(double lhs) const {
        static if (op == "+" || op == "-" || op == "*" || op == "/" || op == "^^") {
            auto vec = new double[this.row * this.col];
            foreach (i, ref v; vec) {
                v = mixin("lhs " ~ op ~ " this.data[i]");
            }
            return Matrix(vec, this.row, this.col, this.by_row);
        } else {
            throw new Exception("No Operation!");
        }
    }

    /// Binary ops with another matrix.
    ///
    /// `+ - * /` are element-wise and require equal dimensions. If `rhs` uses
    /// the other storage order, it is converted first. `%` is the matrix
    /// product and requires `this.col == rhs.row`. The result keeps this
    /// matrix's storage order.
    ///
    /// Throws: `Exception` for any other operator.
    Matrix opBinary(string op)(Matrix rhs) const {
        static if (op == "+" || op == "-" || op == "*" || op == "/") {
            assert(this.row == rhs.row && this.col == rhs.col);
            // Element-wise ops pair raw storage, so both sides need one layout.
            if (rhs.by_row != this.by_row) {
                rhs = rhs.change_shape();
            }
            auto vec = new double[this.row * this.col];
            foreach (i, ref v; vec) {
                v = mixin("this.data[i] " ~ op ~ " rhs.data[i]");
            }
            return Matrix(vec, this.row, this.col, this.by_row);
        } else static if (op == "%") {
            assert(this.col == rhs.row);
            auto m = Matrix(0.0, this.row, rhs.col, this.by_row);
            foreach (i; 0 .. this.row) {
                foreach (j; 0 .. rhs.col) {
                    double s = 0;
                    foreach (k; 0 .. this.col) {
                        s += this[i, k] * rhs[k, j];
                    }
                    m[i, j] = s;
                }
            }
            return m;
        } else {
            throw new Exception("No operation!");
        }
    }

    // =========================================================================
    // Basic row & col ops
    // =========================================================================
    /// Transpose without copying.
    ///
    /// The result shares `data` with this matrix and reinterprets it with the
    /// dimensions swapped and the storage order flipped.
    Matrix transpose() {
        return Matrix(this.data, this.col, this.row, !this.by_row);
    }

    /// Extract column `idx` (0-based) as a newly allocated array of `row`
    /// values, top to bottom. TODO: slice op
    double[] cols(ulong idx) const {
        assert(idx < this.col);
        auto container = new double[this.row];
        foreach (i, ref v; container) {
            v = this[i, idx];
        }
        return container;
    }

    /// Extract row `idx` (0-based) as a newly allocated array of `col`
    /// values, left to right. TODO: slice op
    double[] rows(ulong idx) const {
        assert(idx < this.row);
        auto container = new double[this.col];
        foreach (j, ref v; container) {
            v = this[idx, j];
        }
        return container;
    }
}
333 
334 /// R-like matrix wrapper
/// R-like matrix factory; a free-function spelling of the `Matrix`
/// constructor.
///
/// Params:
///     vec    = flat element storage, passed straight to `Matrix`
///     row    = number of rows
///     col    = number of columns
///     by_row = `true` for row-major storage, `false` for column-major
/// Returns: a `Matrix` built from the arguments as given.
Matrix matrix(double[] vec, ulong row, ulong col, bool by_row) {
    auto built = Matrix(vec, row, col, by_row);
    return built;
}
338 
339 // =============================================================================
340 // Back-end utils
341 // =============================================================================
342 /// Flexible tab
/// Flexible tab: right-justify `s` in a field of `space` characters.
///
/// Params:
///     s     = text to place in the field
///     space = total field width, measured in `char`s (bytes)
/// Returns: `s` prefixed with enough spaces to reach `space` characters,
///          or `s` unchanged when it is already at least that wide.
string tab(string s, ulong space) {
    import std.array : replicate;

    // Guard the unsigned subtraction below: `space - s.length` would wrap to
    // an enormous count (and exhaust memory) if `s` were wider than the field.
    if (s.length >= space) {
        return s;
    }
    return " ".replicate(space - s.length) ~ s;
}