TheAlgorithms/C++ 1.0.0
All the algorithms implemented in C++
Loading...
Searching...
No Matches
strassen_matrix_multiplication.cpp
1
#include <cassert>      /// for assert
#include <chrono>       /// for std::chrono timing in test()
#include <cstddef>      /// for size_t
#include <iostream>     /// for std::cout and std::ostream
#include <tuple>
#include <type_traits>  /// for std::enable_if, std::is_integral
#include <utility>      /// for std::pair
#include <vector>       /// for std::vector

namespace divide_and_conquer {

/**
 * @namespace strassens_multiplication
 * @brief Namespace for performing strassen's multiplication
 */
namespace strassens_multiplication {

/// Sentinel meaning "bound not supplied" for Matrix::slice; the complement
/// of 0 converts to the largest representable size_t.
constexpr size_t MAX_SIZE = ~0ULL;

/**
 * @brief Dense rectangular matrix stored as a vector of rows, supporting
 * element-wise arithmetic, stacking, slicing, and naive / Strassen
 * multiplication.
 * @tparam T element type (integral or floating-point)
 */
template <typename T,
          typename = typename std::enable_if<
              std::is_integral<T>::value || std::is_floating_point<T>::value,
              bool>::type>
class Matrix {
    /// Row storage; every row has the same length.
    std::vector<std::vector<T>> _mat;

 public:
    /**
     * @brief Constructs a zero-filled square matrix.
     * @tparam Integer integral type of the dimension
     * @param size number of rows and of columns (expected non-negative)
     */
    template <typename Integer,
              typename = typename std::enable_if<
                  std::is_integral<Integer>::value, Integer>::type>
    explicit Matrix(const Integer size)
        : _mat(static_cast<size_t>(size),
               std::vector<T>(static_cast<size_t>(size), 0)) {}

    /**
     * @brief Constructs a zero-filled `rows` x `cols` matrix.
     * @tparam Integer integral type of the dimensions
     * @param rows number of rows (expected non-negative)
     * @param cols number of columns (expected non-negative)
     */
    template <typename Integer,
              typename = typename std::enable_if<
                  std::is_integral<Integer>::value, Integer>::type>
    Matrix(const Integer rows, const Integer cols)
        : _mat(static_cast<size_t>(rows),
               std::vector<T>(static_cast<size_t>(cols), 0)) {}

    /**
     * @brief Get the matrix shape.
     * @returns {rows, cols}; a matrix with no rows reports {0, 0}
     */
    inline std::pair<size_t, size_t> size() const {
        return {_mat.size(), _mat.empty() ? size_t{0} : _mat[0].size()};
    }

    /**
     * @brief Mutable access to a row.
     * @param index row index (not bounds-checked)
     * @returns reference to the `index`-th row
     */
    template <typename Integer,
              typename = typename std::enable_if<
                  std::is_integral<Integer>::value, Integer>::type>
    inline std::vector<T> &operator[](const Integer index) {
        return _mat[index];
    }

    /**
     * @brief Read-only access to a row (usable on const matrices).
     * @param index row index (not bounds-checked)
     * @returns const reference to the `index`-th row
     */
    template <typename Integer,
              typename = typename std::enable_if<
                  std::is_integral<Integer>::value, Integer>::type>
    inline const std::vector<T> &operator[](const Integer index) const {
        return _mat[index];
    }

    /**
     * @brief Returns a new matrix holding the sub-block
     * [row_start, row_end) x [col_start, col_end).
     * @param row_start first row (inclusive)
     * @param row_end one past the last row; MAX_SIZE means "all rows"
     * @param col_start first column; MAX_SIZE means column 0
     * @param col_end one past the last column; MAX_SIZE means "all columns"
     * @returns the sliced copy
     */
    Matrix slice(const size_t row_start, const size_t row_end = MAX_SIZE,
                 const size_t col_start = MAX_SIZE,
                 const size_t col_end = MAX_SIZE) const {
        const size_t r_end = (row_end != MAX_SIZE ? row_end : _mat.size());
        const size_t c_begin = (col_start != MAX_SIZE ? col_start : 0);
        const size_t c_end = (col_end != MAX_SIZE ? col_end : size().second);
        assert(row_start <= r_end && r_end <= _mat.size());
        assert(c_begin <= c_end && c_end <= size().second);

        const size_t h_size = r_end - row_start;
        const size_t v_size = c_end - c_begin;
        Matrix result(h_size, v_size);
        for (size_t i = 0; i < h_size; ++i) {
            for (size_t j = 0; j < v_size; ++j) {
                result._mat[i][j] = _mat[i + row_start][j + c_begin];
            }
        }
        return result;
    }

    /**
     * @brief Horizontally stacks `other` to the right of this matrix.
     * @param other matrix with the same number of rows (may be *this)
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   Number>::type>
    void h_stack(const Matrix<Number> &other) {
        assert(_mat.size() == other.size().first);
        for (size_t i = 0; i < _mat.size(); ++i) {
            const std::vector<Number> &src = other[i];
            // Snapshot the length and reserve first so stacking a matrix onto
            // itself neither loops forever nor reallocates mid-copy.
            const size_t extra = src.size();
            _mat[i].reserve(_mat[i].size() + extra);
            for (size_t j = 0; j < extra; ++j) {
                _mat[i].push_back(src[j]);
            }
        }
    }

    /**
     * @brief Vertically stacks `other` below this matrix.
     * @param other matrix with the same number of columns (may be *this)
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   Number>::type>
    void v_stack(const Matrix<Number> &other) {
        assert(_mat.empty() || size().second == other.size().second);
        // Snapshot the row count and reserve first so stacking a matrix onto
        // itself neither loops forever nor invalidates `src`.
        const size_t extra_rows = other.size().first;
        _mat.reserve(_mat.size() + extra_rows);
        for (size_t i = 0; i < extra_rows; ++i) {
            const std::vector<Number> &src = other[i];
            _mat.emplace_back(src.begin(), src.end());
        }
    }

    /**
     * @brief Adds two matrices of the same shape.
     * @returns a new matrix holding the element-wise sum
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    Matrix operator+(const Matrix<Number> &other) const {
        assert(this->size() == other.size());
        Matrix C(_mat.size(), size().second);
        for (size_t i = 0; i < _mat.size(); ++i) {
            const std::vector<Number> &row = other[i];
            for (size_t j = 0; j < _mat[i].size(); ++j) {
                C._mat[i][j] = _mat[i][j] + row[j];
            }
        }
        return C;
    }

    /**
     * @brief Adds another matrix of the same shape to this one in place.
     * @returns *this
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    Matrix &operator+=(const Matrix<Number> &other) {
        assert(this->size() == other.size());
        for (size_t i = 0; i < _mat.size(); ++i) {
            const std::vector<Number> &row = other[i];
            for (size_t j = 0; j < _mat[i].size(); ++j) {
                _mat[i][j] += row[j];
            }
        }
        return *this;
    }

    /**
     * @brief Subtracts `other` from this matrix (same shape required).
     * @returns a new matrix holding the element-wise difference
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    Matrix operator-(const Matrix<Number> &other) const {
        assert(this->size() == other.size());
        Matrix C(_mat.size(), size().second);
        for (size_t i = 0; i < _mat.size(); ++i) {
            const std::vector<Number> &row = other[i];
            for (size_t j = 0; j < _mat[i].size(); ++j) {
                C._mat[i][j] = _mat[i][j] - row[j];
            }
        }
        return C;
    }

    /**
     * @brief Subtracts another matrix of the same shape from this one in
     * place.
     * @returns *this
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    Matrix &operator-=(const Matrix<Number> &other) {
        assert(this->size() == other.size());
        for (size_t i = 0; i < _mat.size(); ++i) {
            const std::vector<Number> &row = other[i];
            for (size_t j = 0; j < _mat[i].size(); ++j) {
                _mat[i][j] -= row[j];
            }
        }
        return *this;
    }

    /**
     * @brief Multiplies this matrix by `other`.
     * Uses Strassen's method when both operands are the same even-sized
     * square matrix, otherwise the naive method.
     * @returns a new (rows of this) x (cols of other) matrix
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    inline Matrix operator*(const Matrix<Number> &other) const {
        assert(size().second == other.size().first);
        const std::pair<size_t, size_t> lhs = size();
        // Strassen's split slices BOTH operands into n/2 x n/2 quadrants, so
        // `other` must have exactly the same square shape as this matrix.
        const bool strassen_ok = lhs.first == lhs.second &&
                                 other.size() == lhs && (lhs.first & 1) == 0;
        return strassen_ok ? this->strassens_multiplication(other)
                           : this->naive_multiplication(other);
    }

    /**
     * @brief Multiplies every element by a number.
     * @returns a new matrix holding the scaled values
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    inline Matrix operator*(const Number other) const {
        Matrix C(_mat.size(), size().second);
        for (size_t i = 0; i < _mat.size(); ++i) {
            for (size_t j = 0; j < _mat[i].size(); ++j) {
                C._mat[i][j] = _mat[i][j] * other;
            }
        }
        return C;
    }

    /**
     * @brief Multiplies every element of this matrix by a number in place.
     * @returns *this
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    Matrix &operator*=(const Number other) {
        for (size_t i = 0; i < _mat.size(); ++i) {
            for (size_t j = 0; j < _mat[i].size(); ++j) {
                _mat[i][j] *= other;
            }
        }
        return *this;
    }

    /**
     * @brief Naive O(n^3) multiplication of this matrix by `other`.
     * @param other matrix whose row count equals this matrix's column count
     * @returns a new (rows of this) x (cols of other) matrix
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    Matrix naive_multiplication(const Matrix<Number> &other) const {
        const size_t rows = _mat.size();
        const size_t inner = size().second;
        const size_t cols = other.size().second;
        assert(inner == other.size().first);

        Matrix C(rows, cols);
        // i-k-j order keeps the innermost loop walking a single row of
        // `other` and of `C` contiguously.
        for (size_t i = 0; i < rows; ++i) {
            for (size_t k = 0; k < inner; ++k) {
                const std::vector<Number> &other_row = other[k];
                for (size_t j = 0; j < cols; ++j) {
                    C._mat[i][j] += _mat[i][k] * other_row[j];
                }
            }
        }
        return C;
    }

    /**
     * @brief Strassen's method of multiplying two matrices.
     * References: https://en.wikipedia.org/wiki/Strassen_algorithm
     * Falls back to naive multiplication when the operands are small
     * (n <= 64), odd-sized, or not the same square shape.
     * @param other right-hand operand
     * @returns the product matrix
     */
    template <typename Number, typename = typename std::enable_if<
                                   std::is_integral<Number>::value ||
                                       std::is_floating_point<Number>::value,
                                   bool>::type>
    Matrix strassens_multiplication(const Matrix<Number> &other) const {
        const size_t n = _mat.size();
        const bool same_square = size().second == n && other.size() == size();
        // Base case: small, odd-sized, or non-square-pair inputs go the naive
        // route; otherwise split into quadrants.
        if (!same_square || n <= 64 || (n & 1)) {
            return this->naive_multiplication(other);
        }

        const size_t half = n >> 1;
        // this = [A B; C D], other = [E F; G H]
        const Matrix A = this->slice(0, half, 0, half),
                     B = this->slice(0, half, half, n),
                     C = this->slice(half, n, 0, half),
                     D = this->slice(half, n, half, n);
        const Matrix<Number> E = other.slice(0, half, 0, half),
                             F = other.slice(0, half, half, n),
                             G = other.slice(half, n, 0, half),
                             H = other.slice(half, n, half, n);

        Matrix P1 = A.strassens_multiplication(F - H);
        Matrix P2 = (A + B).strassens_multiplication(H);
        Matrix P3 = (C + D).strassens_multiplication(E);
        Matrix P4 = D.strassens_multiplication(G - E);
        Matrix P5 = (A + D).strassens_multiplication(E + H);
        Matrix P6 = (B - D).strassens_multiplication(G + H);
        Matrix P7 = (A - C).strassens_multiplication(E + F);

        // Result quadrants:
        //       [ C11 | C12 ]
        //   C = [ ----+---- ]
        //       [ C21 | C22 ]
        Matrix C11 = P5 + P4 - P2 + P6;
        Matrix C12 = P1 + P2;
        Matrix C21 = P3 + P4;
        Matrix C22 = P1 + P5 - P3 - P7;

        C21.h_stack(C22);
        C11.h_stack(C12);
        C11.v_stack(C21);

        return C11;
    }

    /**
     * @brief Checks whether two matrices have the same shape and equal
     * elements.
     * @returns true if equal, false otherwise
     */
    bool operator==(const Matrix<T> &other) const {
        return _mat == other._mat;
    }

    /**
     * @brief Writes each row as space-separated values, one row per line,
     * followed by an extra blank line.
     */
    friend std::ostream &operator<<(std::ostream &out, const Matrix<T> &mat) {
        for (auto &row : mat._mat) {
            for (auto &elem : row) {
                out << elem << " ";
            }
            out << "\n";
        }
        return out << "\n";
    }
};

}  // namespace strassens_multiplication

}  // namespace divide_and_conquer

427static void test() {
428 const size_t s = 512;
429 auto matrix_demo =
431
432 for (size_t i = 0; i < s; ++i) {
433 for (size_t j = 0; j < s; ++j) {
434 matrix_demo[i][j] = i + j;
435 }
436 }
437
438 auto matrix_demo2 =
440 for (size_t i = 0; i < s; ++i) {
441 for (size_t j = 0; j < s; ++j) {
442 matrix_demo2[i][j] = 2 + i + j;
443 }
444 }
445
446 auto start = std::chrono::system_clock::now();
447 auto Mat3 = matrix_demo2 * matrix_demo;
448 auto end = std::chrono::system_clock::now();
449
450 std::chrono::duration<double> time = (end - start);
451 std::cout << "Strassen time: " << time.count() << "s" << std::endl;
452
453 start = std::chrono::system_clock::now();
454 auto conf = matrix_demo2.naive_multiplication(matrix_demo);
455 end = std::chrono::system_clock::now();
456
457 time = end - start;
458 std::cout << "Normal time: " << time.count() << "s" << std::endl;
459
460 // std::cout << Mat3 << conf << std::endl;
461 assert(Mat3 == conf);
462}
463
468int main() {
469 test(); // run self-test implementation
470 return 0;
471}
void test()
Matrix slice(const size_t row_start, const size_t row_end=MAX_SIZE, const size_t col_start=MAX_SIZE, const size_t col_end=MAX_SIZE) const
Returns a new matrix holding a sub-block (slice) of this matrix.
Matrix & operator-=(const Matrix< Number > &other) const
Subtracts another matrix from the current matrix in place.
Matrix(const Integer rows, const Integer cols)
Constructor.
bool operator==(const Matrix< T > &other) const
Checks whether two matrices have the same shape and equal elements.
Matrix naive_multiplication(const Matrix< Number > &other) const
Multiplies this matrix by another using the naive triple-loop method.
Matrix operator*(const Matrix< Number > &other) const
Multiply two matrices and returns a new matrix.
Matrix operator-(const Matrix< Number > &other) const
Subtract two matrices and returns a new matrix.
Matrix strassens_multiplication(const Matrix< Number > &other) const
Strassen's method of multiplying two matrices. Reference: https://en.wikipedia.org/wiki/Strassen_algorithm
void h_stack(const Matrix< Number > &other)
Horizontally stack the matrix (one after the other)
std::vector< T > & operator[](const Integer index)
Returns a reference to the i-th row of the matrix.
Matrix operator+(const Matrix< Number > &other) const
Add two matrices and returns a new matrix.
Matrix & operator+=(const Matrix< Number > &other) const
Adds another matrix to the current matrix in place.
std::pair< size_t, size_t > size() const
Get the matrix shape.
Matrix operator*(const Number other) const
Multiply matrix with a number and returns a new matrix.
Matrix & operator*=(const Number other) const
Multiplies the current matrix by a number in place.
void v_stack(const Matrix< Number > &other)
Vertically stack the matrices (current matrix above the other).
int main()
Main function.
for std::vector
Namespace for performing strassen's multiplication.