import java.util.Arrays;
public class StrassenMultiplyMatrix {

    /**
     * Demo: multiplies a 2x2 matrix by a 2x4 matrix and prints the 2x4 product row by row.
     */
    public static void main(String[] args) {
        int[][] matrix1 = new int[][]{{1, 3}, {2, 4}};
        int[][] matrix2 = new int[][]{{1, 3, 4, 5}, {2, 4, 7, 9}};
        // The operands are not both square, so go through multiply(), which pads them to a
        // common power-of-two size before running Strassen and crops the answer back to 2x4.
        int[][] ans = multiply(matrix1, matrix2);
        for (int[] row : ans) {
            System.out.println(Arrays.toString(row));
        }
    }

    /**
     * Returns the matrix product {@code matrix1 * matrix2} computed with Strassen's algorithm.
     *
     * <p>The operands may be rectangular: {@code matrix1} is {@code r x k} and {@code matrix2}
     * is {@code k x c}. Both are zero-padded to the smallest power-of-two square that holds
     * every dimension, multiplied, and the {@code r x c} product is cropped out. Arithmetic is
     * plain {@code int} arithmetic, so overflow wraps silently.
     *
     * @param matrix1 left operand; every row must have the same length
     * @param matrix2 right operand; every row must have the same length
     * @return a new {@code r x c} array holding the product
     * @throws NullPointerException if either matrix or any of its rows is {@code null}
     * @throws IllegalArgumentException if a matrix is ragged or the inner dimensions differ
     */
    public static int[][] multiply(int[][] matrix1, int[][] matrix2) {
        int leftCols = columnCount(matrix1, "matrix1");
        int cols = columnCount(matrix2, "matrix2");
        int rows = matrix1.length;
        int inner = matrix2.length;
        // An empty matrix1 carries no column count, so conformance is only checkable with rows.
        if (rows > 0 && leftCols != inner) {
            throw new IllegalArgumentException("matrix1 has " + leftCols
                    + " columns but matrix2 has " + inner + " rows");
        }
        if (rows == 0 || inner == 0 || cols == 0) {
            return new int[rows][cols]; // empty sums are zero
        }
        int size = nextPowerOfTwo(Math.max(rows, Math.max(inner, cols)));
        int[][] product = new int[size][size];
        strassenPowerOfTwo(size,
                padded(matrix1, rows, inner, size),
                padded(matrix2, inner, cols, size),
                product);
        int[][] cropped = new int[rows][];
        for (int i = 0; i < rows; i++) {
            cropped[i] = Arrays.copyOf(product[i], cols);
        }
        return cropped;
    }

    /**
     * Multiplies the leading {@code N x N} blocks of {@code matrix1} and {@code matrix2} with
     * Strassen's algorithm and writes the product into the leading {@code N x N} block of
     * {@code result}. Entries of {@code result} outside that block are left untouched.
     *
     * <p>Any {@code N >= 0} is accepted. When {@code N} is not a power of two, the operands
     * are zero-padded to the next power of two. The padding only contributes zero terms to
     * the leading block, so every entry of the {@code N x N} result is computed.
     *
     * @param N size of the square blocks to multiply
     * @param matrix1 left operand, at least {@code N x N}
     * @param matrix2 right operand, at least {@code N x N}
     * @param result destination, at least {@code N x N}
     * @throws IllegalArgumentException if {@code N} is negative
     */
    public static void strassenMatrixMultiply(int N, int[][] matrix1, int[][] matrix2, int[][] result) {
        if (N < 0) {
            throw new IllegalArgumentException("N must be non-negative, got " + N);
        }
        if (N == 0) {
            // Empty product. Without this guard half would be 0 and the recursion would not end.
            return;
        }
        if (Integer.bitCount(N) == 1) {
            strassenPowerOfTwo(N, matrix1, matrix2, result);
            return;
        }
        int size = nextPowerOfTwo(N);
        int[][] product = new int[size][size];
        strassenPowerOfTwo(size, padded(matrix1, N, N, size), padded(matrix2, N, N, size), product);
        for (int i = 0; i < N; i++) {
            System.arraycopy(product[i], 0, result[i], 0, N);
        }
    }

    /**
     * Core Strassen recursion. Multiplies the leading {@code n x n} blocks of the operands,
     * where {@code n} must be a power of two, and writes the product into {@code result}.
     */
    private static void strassenPowerOfTwo(int n, int[][] matrix1, int[][] matrix2, int[][] result) {
        if (n == 1) {
            result[0][0] = matrix1[0][0] * matrix2[0][0];
            return;
        }
        int half = n / 2;
        // Split matrix1 = [A B; C D] and matrix2 = [E F; G H] into half x half quadrants.
        int[][] A = quadrant(matrix1, 0, 0, half);
        int[][] B = quadrant(matrix1, 0, half, half);
        int[][] C = quadrant(matrix1, half, 0, half);
        int[][] D = quadrant(matrix1, half, half, half);
        int[][] E = quadrant(matrix2, 0, 0, half);
        int[][] F = quadrant(matrix2, 0, half, half);
        int[][] G = quadrant(matrix2, half, 0, half);
        int[][] H = quadrant(matrix2, half, half, half);

        // Seven recursive products instead of the eight a naive block multiply needs.
        // tempA/tempB are scratch operands, reused once each recursive call has returned.
        int[][] tempA = new int[half][half];
        int[][] tempB = new int[half][half];
        int[][] P1 = new int[half][half];
        int[][] P2 = new int[half][half];
        int[][] P3 = new int[half][half];
        int[][] P4 = new int[half][half];
        int[][] P5 = new int[half][half];
        int[][] P6 = new int[half][half];
        int[][] P7 = new int[half][half];

        matrixSub(F, H, tempB);
        strassenPowerOfTwo(half, A, tempB, P1);         // P1 = A(F - H)
        matrixAdd(A, B, tempA);
        strassenPowerOfTwo(half, tempA, H, P2);         // P2 = (A + B)H
        matrixAdd(C, D, tempA);
        strassenPowerOfTwo(half, tempA, E, P3);         // P3 = (C + D)E
        matrixSub(G, E, tempB);
        strassenPowerOfTwo(half, D, tempB, P4);         // P4 = D(G - E)
        matrixAdd(A, D, tempA);
        matrixAdd(E, H, tempB);
        strassenPowerOfTwo(half, tempA, tempB, P5);     // P5 = (A + D)(E + H)
        matrixSub(B, D, tempA);
        matrixAdd(G, H, tempB);
        strassenPowerOfTwo(half, tempA, tempB, P6);     // P6 = (B - D)(G + H)
        matrixSub(A, C, tempA);
        matrixAdd(E, F, tempB);
        strassenPowerOfTwo(half, tempA, tempB, P7);     // P7 = (A - C)(E + F)

        // Recombine: result = [P5 + P4 - P2 + P6, P1 + P2; P3 + P4, P1 + P5 - P3 - P7].
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                result[i][j] = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j];
                result[i][j + half] = P1[i][j] + P2[i][j];
                result[i + half][j] = P3[i][j] + P4[i][j];
                result[i + half][j + half] = P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j];
            }
        }
    }

    /** Copies the {@code size x size} block of {@code m} whose top-left corner is ({@code row}, {@code col}). */
    private static int[][] quadrant(int[][] m, int row, int col, int size) {
        int[][] q = new int[size][size];
        for (int i = 0; i < size; i++) {
            System.arraycopy(m[row + i], col, q[i], 0, size);
        }
        return q;
    }

    /** Copies the leading {@code rowCount x colCount} block of {@code m} into a zeroed {@code size x size} array. */
    private static int[][] padded(int[][] m, int rowCount, int colCount, int size) {
        int[][] out = new int[size][size];
        for (int i = 0; i < rowCount; i++) {
            System.arraycopy(m[i], 0, out[i], 0, colCount);
        }
        return out;
    }

    /** Returns the smallest power of two that is {@code >= n}, for {@code n >= 1}. */
    private static int nextPowerOfTwo(int n) {
        int highest = Integer.highestOneBit(n);
        return highest == n ? n : highest << 1;
    }

    /** Returns the common row length of {@code m} (0 when it has no rows), rejecting null or ragged input. */
    private static int columnCount(int[][] m, String name) {
        if (m == null) {
            throw new NullPointerException(name + " must not be null");
        }
        int cols = -1;
        for (int i = 0; i < m.length; i++) {
            if (m[i] == null) {
                throw new NullPointerException(name + "[" + i + "] must not be null");
            }
            if (cols < 0) {
                cols = m[i].length;
            } else if (m[i].length != cols) {
                throw new IllegalArgumentException(name + " is ragged: row " + i + " has "
                        + m[i].length + " columns, expected " + cols);
            }
        }
        return Math.max(cols, 0);
    }

    /**
     * Writes the element-wise difference {@code matrixA - matrixB} into {@code result}.
     *
     * <p>Iterates over the shape of {@code matrixA}, using each row's own length so that
     * rectangular matrices are fully processed. {@code matrixB} and {@code result} must be
     * at least that large. {@code result} may be the same array as either operand.
     */
    public static void matrixSub(int[][] matrixA, int[][] matrixB, int[][] result) {
        for (int i = 0; i < matrixA.length; i++) {
            for (int j = 0; j < matrixA[i].length; j++) {
                result[i][j] = matrixA[i][j] - matrixB[i][j];
            }
        }
    }

    /**
     * Writes the element-wise sum {@code matrixA + matrixB} into {@code result}.
     *
     * <p>Iterates over the shape of {@code matrixA}, using each row's own length so that
     * rectangular matrices are fully processed. {@code matrixB} and {@code result} must be
     * at least that large. {@code result} may be the same array as either operand.
     */
    public static void matrixAdd(int[][] matrixA, int[][] matrixB, int[][] result) {
        for (int i = 0; i < matrixA.length; i++) {
            for (int j = 0; j < matrixA[i].length; j++) {
                result[i][j] = matrixA[i][j] + matrixB[i][j];
            }
        }
    }
}
// References
// - https://en.wikipedia.org/wiki/Strassen_algorithm