Strassen's Matrix Multiplication - CSU083 - Shoolini U

Strassen's Matrix Multiplication

Introduction

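This is an introduction to Strassen's Matrix Multiplication algorithm, a divide-and-conquer method that multiplies two square matrices using seven recursive sub-matrix multiplications instead of the eight needed by the straightforward block approach, giving a running time of roughly O(n^2.81) instead of O(n^3).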

Language: C++

#include <iostream>
#include <vector>
using namespace std;

// Element-wise sum of two square matrices.
// The size n is read from A.size() and used for both dimensions; B is
// expected to be at least n x n. Returns a new n x n matrix C with
// C[i][j] = A[i][j] + B[i][j].
vector<vector<int>> matrixAddition(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const int size = static_cast<int>(A.size());
    vector<vector<int>> sum(size, vector<int>(size, 0));

    for (int row = 0; row < size; ++row) {
        const vector<int>& left = A[row];
        const vector<int>& right = B[row];
        vector<int>& out = sum[row];
        for (int col = 0; col < size; ++col) {
            out[col] = left[col] + right[col];
        }
    }
    return sum;
}

// Element-wise difference of two square matrices.
// The size n is read from A.size() and used for both dimensions; B is
// expected to be at least n x n. Returns a new n x n matrix C with
// C[i][j] = A[i][j] - B[i][j].
vector<vector<int>> matrixSubtraction(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const int size = static_cast<int>(A.size());
    vector<vector<int>> difference;
    difference.reserve(size);

    for (int row = 0; row < size; ++row) {
        difference.emplace_back(size);  // append a zero-filled row of width `size`
        vector<int>& line = difference.back();
        for (int col = 0; col < size; ++col) {
            line[col] = A[row][col] - B[row][col];
        }
    }
    return difference;
}

// Schoolbook O(n^3) product of two square matrices, used as the base case of
// the Strassen recursion. n is read from A.size(); both operands are treated
// as n x n. Returns a new n x n matrix C with C[i][j] = sum_k A[i][k] * B[k][j].
vector<vector<int>> standardMatrixMultiplication(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const int size = static_cast<int>(A.size());
    vector<vector<int>> product(size, vector<int>(size, 0));

    for (int row = 0; row < size; ++row) {
        for (int col = 0; col < size; ++col) {
            // Accumulate the dot product of A's row and B's column, k ascending.
            int dot = 0;
            for (int k = 0; k < size; ++k) {
                dot += A[row][k] * B[k][col];
            }
            product[row][col] = dot;
        }
    }
    return product;
}

// Copies the four (n/2) x (n/2) quadrants of A (n = A.size()) into the
// output matrices, which must already be sized at least (n/2) x (n/2):
// A11 = top-left, A12 = top-right, A21 = bottom-left, A22 = bottom-right.
void splitMatrix(const vector<vector<int>>& A, vector<vector<int>>& A11, vector<vector<int>>& A12,
                 vector<vector<int>>& A21, vector<vector<int>>& A22) {
    const int half = static_cast<int>(A.size()) / 2;

    for (int r = 0; r < half; ++r) {
        const vector<int>& upper = A[r];
        const vector<int>& lower = A[r + half];
        for (int c = 0; c < half; ++c) {
            A11[r][c] = upper[c];
            A12[r][c] = upper[c + half];
            A21[r][c] = lower[c];
            A22[r][c] = lower[c + half];
        }
    }
}

// Writes four (n/2) x (n/2) quadrants into C, which must already be sized
// n x n (n = C.size()): C11 goes top-left, C12 top-right, C21 bottom-left,
// C22 bottom-right. Cells outside those quadrants are left untouched.
void joinMatrix(vector<vector<int>>& C, const vector<vector<int>>& C11, const vector<vector<int>>& C12,
                const vector<vector<int>>& C21, const vector<vector<int>>& C22) {
    const int half = static_cast<int>(C.size()) / 2;

    for (int r = 0; r < half; ++r) {
        vector<int>& upper = C[r];
        vector<int>& lower = C[r + half];
        for (int c = 0; c < half; ++c) {
            upper[c] = C11[r][c];
            upper[c + half] = C12[r][c];
            lower[c] = C21[r][c];
            lower[c + half] = C22[r][c];
        }
    }
}

// Strassen's Matrix Multiplication
//
// Computes A * B for two n x n integer matrices with Strassen's
// divide-and-conquer scheme. Each level splits both operands into four
// (n/2) x (n/2) quadrants and builds the result from seven recursive products
// (M1..M7) instead of the eight used by the schoolbook block method.
//
// Parameters:
//   A, B - square matrices; both are treated as n x n with n = A.size().
// Returns:
//   A new n x n matrix holding A * B.
//
// Odd sizes above the base case are zero-padded to n + 1, multiplied, and
// trimmed back to n x n. The extra zero row/column does not affect the
// top-left n x n block of the product. Every n >= 0 is supported, not only
// powers of two.
vector<vector<int>> strassensMatrixMultiplication(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();

    // Base case: small (and empty) matrices are multiplied directly.
    if (n <= 2) {
        return standardMatrixMultiplication(A, B);
    }

    // An odd size cannot be split into four equal quadrants; splitting it
    // anyway (mid = n / 2) would silently drop the last row and column.
    // Pad with zeros to the next even size, recurse, and keep the n x n block.
    if (n % 2 != 0) {
        int padded = n + 1;
        vector<vector<int>> paddedA(padded, vector<int>(padded, 0));
        vector<vector<int>> paddedB(padded, vector<int>(padded, 0));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                paddedA[i][j] = A[i][j];
                paddedB[i][j] = B[i][j];
            }
        }

        vector<vector<int>> paddedResult = strassensMatrixMultiplication(paddedA, paddedB);

        vector<vector<int>> result(n, vector<int>(n));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = paddedResult[i][j];
            }
        }
        return result;
    }

    int mid = n / 2;

    // Divide matrices into submatrices
    vector<vector<int>> A11(mid, vector<int>(mid));
    vector<vector<int>> A12(mid, vector<int>(mid));
    vector<vector<int>> A21(mid, vector<int>(mid));
    vector<vector<int>> A22(mid, vector<int>(mid));

    vector<vector<int>> B11(mid, vector<int>(mid));
    vector<vector<int>> B12(mid, vector<int>(mid));
    vector<vector<int>> B21(mid, vector<int>(mid));
    vector<vector<int>> B22(mid, vector<int>(mid));

    splitMatrix(A, A11, A12, A21, A22);
    splitMatrix(B, B11, B12, B21, B22);

    // The seven Strassen products:
    //   M1 = (A11 + A22)(B11 + B22)   M5 = (A11 + A12) B22
    //   M2 = (A21 + A22) B11          M6 = (A21 - A11)(B11 + B12)
    //   M3 = A11 (B12 - B22)          M7 = (A12 - A22)(B21 + B22)
    //   M4 = A22 (B21 - B11)
    vector<vector<int>> M1 = strassensMatrixMultiplication(matrixAddition(A11, A22), matrixAddition(B11, B22));
    vector<vector<int>> M2 = strassensMatrixMultiplication(matrixAddition(A21, A22), B11);
    vector<vector<int>> M3 = strassensMatrixMultiplication(A11, matrixSubtraction(B12, B22));
    vector<vector<int>> M4 = strassensMatrixMultiplication(A22, matrixSubtraction(B21, B11));
    vector<vector<int>> M5 = strassensMatrixMultiplication(matrixAddition(A11, A12), B22);
    vector<vector<int>> M6 = strassensMatrixMultiplication(matrixSubtraction(A21, A11), matrixAddition(B11, B12));
    vector<vector<int>> M7 = strassensMatrixMultiplication(matrixSubtraction(A12, A22), matrixAddition(B21, B22));

    // Recombine the quadrants of the result:
    //   C11 = M1 + M4 - M5 + M7    C12 = M3 + M5
    //   C21 = M2 + M4              C22 = M1 - M2 + M3 + M6
    vector<vector<int>> C11 = matrixAddition(matrixSubtraction(matrixAddition(M1, M4), M5), M7);
    vector<vector<int>> C12 = matrixAddition(M3, M5);
    vector<vector<int>> C21 = matrixAddition(M2, M4);
    vector<vector<int>> C22 = matrixAddition(matrixAddition(matrixSubtraction(M1, M2), M3), M6);

    // Join submatrices into result matrix
    vector<vector<int>> result(n, vector<int>(n));
    joinMatrix(result, C11, C12, C21, C22);

    return result;
}

// Prints a matrix to standard output, one row per line, with every element
// followed by a single space. The row count (matrix.size()) is also used as
// the number of columns printed, so the matrix is expected to be square.
void displayMatrix(const vector<vector<int>>& matrix) {
    const int size = static_cast<int>(matrix.size());
    for (int row = 0; row < size; ++row) {
        const vector<int>& line = matrix[row];
        for (int col = 0; col < size; ++col) {
            cout << line[col] << " ";
        }
        cout << endl;
    }
}

// Program entry point: reads the size n and two n x n integer matrices from
// standard input, echoes both, and prints their product computed with
// Strassen's algorithm. Returns 0 on success, or 1 (after printing a message
// to standard error) if the size or any matrix element cannot be read.
int main() {
    int n;
    cout << "Enter the size of the square matrices: ";
    // A negative size would be converted to a huge size_t by the vector
    // constructors below, so reject it (and non-numeric input) up front.
    if (!(cin >> n) || n < 0) {
        cerr << "Invalid matrix size." << endl;
        return 1;
    }

    cout << "Enter elements of matrix A of size " << n << "x" << n << " row-wise:" << endl;
    vector<vector<int>> A(n, vector<int>(n));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            // Stop on bad input instead of silently multiplying zeros.
            if (!(cin >> A[i][j])) {
                cerr << "Invalid element for matrix A." << endl;
                return 1;
            }
        }
    }

    cout << "Enter elements of matrix B of size " << n << "x" << n << " row-wise:" << endl;
    vector<vector<int>> B(n, vector<int>(n));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            if (!(cin >> B[i][j])) {
                cerr << "Invalid element for matrix B." << endl;
                return 1;
            }
        }
    }

    cout << "Matrix A:" << endl;
    displayMatrix(A);

    cout << "Matrix B:" << endl;
    displayMatrix(B);

    vector<vector<int>> result = strassensMatrixMultiplication(A, B);

    cout << "Matrix A * B (Strassen's Multiplication):" << endl;
    displayMatrix(result);

    return 0;
}

Output:

Enter the size of the square matrices: 2
Enter elements of matrix A of size 2x2 row-wise:
 1 4 2 5
Enter elements of matrix B of size 2x2 row-wise:
 2 56 4 8
Matrix A:
 1 4
 2 5
Matrix B:
 2 56
 4 8
Matrix A * B (Strassen's Multiplication):
 18 88
 24 152