Matrix Multiplication with Java Threads - Optimized Code (Parallel)

1. Introduction


In this tutorial, we will write Java code for matrix multiplication, first with the normal (sequential) approach and then with multiple threads. This question is often asked in interviews to see whether you can improve the performance for large matrices.

We'll implement the programs for both cases.

Basic matrix multiplication rules:

Matrix 1 order = m x n (m rows and n columns)
Matrix 2 order = n x p (n rows and p columns)

Result matrix order = m x p (m rows and p columns)

Here, we assume the input is valid: the number of columns in Matrix 1 equals the number of rows in Matrix 2 (both are n).
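If you would rather check the orders than assume them, a few lines are enough. The following is a minimal sketch. The class name DimensionCheck and the method resultOrder() are hypothetical, and it assumes non-empty, rectangular int[][] arrays.

```java
import java.util.Arrays;

public class DimensionCheck {

    // Returns {m, p}, the order of matrix1 x matrix2. Throws if the column
    // count of matrix1 (n) differs from the row count of matrix2.
    static int[] resultOrder(int[][] matrix1, int[][] matrix2) {
        int n = matrix1[0].length;
        if (n != matrix2.length) {
            throw new IllegalArgumentException("Cannot multiply: matrix1 has " + n
                    + " columns but matrix2 has " + matrix2.length + " rows");
        }
        return new int[] {matrix1.length, matrix2[0].length};
    }

    public static void main(String[] args) {
        int[][] a = new int[2][3]; // 2 x 3
        int[][] b = new int[3][4]; // 3 x 4
        System.out.println(Arrays.toString(resultOrder(a, b))); // prints [2, 4]
    }
}
```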





2. Matrix Generator Utility class


Let us create a class MatrixGeneratorUtil with a static method generateMatrix() that takes two int parameters, rows and columns, and returns a matrix of those dimensions filled with random values (multiples of 10 from 0 to 990). The class also has a print() method to display a matrix.

We will use this class in both versions of the code.


package com.java.w3schools.blog.java.program.to.threads.matrix;

import java.util.Random;

/**
 * Utility class to generate the matrix.
 *
 * @author JavaProgramTo.com
 *
 */
public class MatrixGeneratorUtil {

    public static int[][] generateMatrix(int rows, int columns) {

        // output array to store the matrix values
        int[][] result = new int[rows][columns];

        // to generate random integers
        Random random = new Random();

        // adding values at each index: multiples of 10 from 0 to 990
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                result[i][j] = random.nextInt(100) * 10;
            }
        }

        // returning output.
        return result;
    }

    // to print the matrix, one row per line
    public static void print(int[][] matrix) {

        System.out.println();

        int rows = matrix.length;
        int columns = matrix[0].length;

        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                System.out.print(matrix[i][j] + " ");
            }
            System.out.println();
        }
    }
}
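Because generateMatrix() uses an unseeded Random, every run multiplies different matrices. If you want the sequential and parallel programs to work on identical input, you can seed the generator. This is a sketch of that variant. The class SeededMatrixGenerator and the extra seed parameter are assumptions and are not part of the original utility.

```java
import java.util.Arrays;
import java.util.Random;

public class SeededMatrixGenerator {

    // Same value range as generateMatrix() (multiples of 10 from 0 to 990),
    // but driven by a fixed seed so every call with that seed gives the same matrix.
    public static int[][] generateMatrix(int rows, int columns, long seed) {
        Random random = new Random(seed);
        int[][] result = new int[rows][columns];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                result[i][j] = random.nextInt(100) * 10;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] first = generateMatrix(3, 3, 42L);
        int[][] second = generateMatrix(3, 3, 42L);
        // prints true: same seed, same sequence of values
        System.out.println(Arrays.deepEquals(first, second));
    }
}
```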

3. Traditional Approach


Now create a class that does the core logic of matrix multiplication. Its static method multiply() takes two matrices and returns the resulting matrix, which it computes with a triple for loop.

Below is the source code for this.

package com.java.w3schools.blog.java.program.to.threads.matrix;

import java.util.Date;

/**
 * Normal way to do matrix multiplication.
 *
 * @author JavaProgramTo.com
 *
 */
public class MatrixMultiplication {

    public static void main(String[] args) {

        Date start = new Date();

        int[][] m1 = MatrixGeneratorUtil.generateMatrix(3, 3);
        int[][] m2 = MatrixGeneratorUtil.generateMatrix(3, 3);

        int[][] result = multiply(m1, m2);

        System.out.println("Matrix 1 : ");
        MatrixGeneratorUtil.print(m1);

        System.out.println("\nMatrix 2 : ");
        MatrixGeneratorUtil.print(m2);

        System.out.println("\nOutput Matrix : ");
        MatrixGeneratorUtil.print(result);

        Date end = new Date();
        System.out.println("\nTime taken in milli seconds: " + (end.getTime() - start.getTime()));
    }

    public static int[][] multiply(int[][] matrix1, int[][] matrix2) {
        int resultRows = matrix1.length;       // m
        int resultColumns = matrix2[0].length; // p
        int commonSize = matrix1[0].length;    // n, must equal matrix2.length

        int[][] result = new int[resultRows][resultColumns];

        for (int i = 0; i < resultRows; i++) {
            for (int j = 0; j < resultColumns; j++) {
                result[i][j] = 0;
                // sum over the shared dimension n, not over the result columns
                for (int k = 0; k < commonSize; k++) {
                    result[i][j] += matrix1[i][k] * matrix2[k][j];
                }
            }
        }

        return result;
    }
}

Output:

Matrix 1 : 

830 930 360
840 140 580
200 270 370

Matrix 2 :

440 600 390
630 590 390
830 220 530

Output Matrix :

1249900 1125900 877200
939200 714200 689600
565200 360700 379400

Time taken in milli seconds: 2
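A square test cannot catch a wrong loop bound. The inner loop has to run over the column count of the first matrix (n), not over the column count of the result (p). A quick check with a 2 x 3 and a 3 x 2 matrix exercises that bound. Below is a self-contained sketch of the same triple loop. The class name NonSquareCheck is hypothetical.

```java
import java.util.Arrays;

public class NonSquareCheck {

    // Triple loop for an m x n by n x p product; the inner bound is n,
    // the number of columns in matrix1.
    static int[][] multiply(int[][] matrix1, int[][] matrix2) {
        int m = matrix1.length;
        int n = matrix1[0].length;
        int p = matrix2[0].length;
        int[][] result = new int[m][p];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += matrix1[i][k] * matrix2[k][j];
                }
                result[i][j] = sum;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};      // 2 x 3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}}; // 3 x 2
        System.out.println(Arrays.deepToString(multiply(a, b)));
        // prints [[58, 64], [139, 154]]
    }
}
```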

The MatrixMultiplication run above took 2 ms for 3 x 3 matrices. For larger inputs such as 1000 x 1000 or 2000 x 2000, it takes much longer:

Time taken for order 1000 x 1000 in milli seconds: 2158
Time taken for order 2000 x 2000 in milli seconds: 65879
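These timings are measured with Date and also include generating the random matrices. If you want a figure for the multiplication step alone, you can wrap just the multiply() call with the monotonic System.nanoTime() clock. The sketch below does that. The class MultiplyTimer and its helper methods are hypothetical. Single runs are still affected by JIT warm-up, so a harness such as JMH would give more reliable numbers.

```java
import java.util.Random;
import java.util.concurrent.TimeUnit;

public class MultiplyTimer {

    // Fills a matrix with multiples of 10 from 0 to 990, like generateMatrix().
    static int[][] randomMatrix(int rows, int columns, Random random) {
        int[][] m = new int[rows][columns];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                m[i][j] = random.nextInt(100) * 10;
            }
        }
        return m;
    }

    // Sequential triple-loop product of an m x n and an n x p matrix.
    static int[][] multiply(int[][] a, int[][] b) {
        int n = a[0].length;
        int[][] c = new int[a.length][b[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b[0].length; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    // Times only the multiply() call; matrix generation and printing are excluded.
    static long timeMultiplyMillis(int[][] a, int[][] b) {
        long start = System.nanoTime();
        multiply(a, b);
        long elapsed = System.nanoTime() - start;
        return TimeUnit.NANOSECONDS.toMillis(elapsed);
    }

    public static void main(String[] args) {
        Random random = new Random(7);
        for (int size : new int[] {200, 400}) {
            int[][] a = randomMatrix(size, size, random);
            int[][] b = randomMatrix(size, size, random);
            System.out.println(size + " x " + size + ": " + timeMultiplyMillis(a, b) + " ms");
        }
    }
}
```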

4. Multithreading code for matrix multiplication


To optimize the performance of this program, we can take advantage of multiple CPU cores. We create a separate thread for each row of the result matrix, so that rows are computed in parallel and the total processing time drops.

Let us create a task class, RowMultiplyWorker, that implements the Runnable interface and computes a single row of the result.

package com.java.w3schools.blog.java.program.to.threads.matrix;

/**
 * Computes one row of the result: result[row] = matrix1[row] x matrix2.
 */
public class RowMultiplyWorker implements Runnable {

    private final int[][] result;
    private final int[][] matrix1;
    private final int[][] matrix2;
    private final int row;

    public RowMultiplyWorker(int[][] result, int[][] matrix1, int[][] matrix2, int row) {
        this.result = result;
        this.matrix1 = matrix1;
        this.matrix2 = matrix2;
        this.row = row;
    }

    @Override
    public void run() {
        // Each worker writes only to its own row, so no two workers touch the same cell.
        for (int i = 0; i < matrix2[0].length; i++) {
            result[row][i] = 0;
            for (int j = 0; j < matrix1[row].length; j++) {
                result[row][i] += matrix1[row][j] * matrix2[j][i];
            }
        }
    }
}


Next, create a class that starts at most 10 threads at a time. Starting 2000 threads at once for a 2000 x 2000 matrix adds a lot of thread-creation and scheduling overhead and can make the application hang. So the threads run in groups of 10: the program waits for a group to finish before starting the next one, until every row has been multiplied.

package com.java.w3schools.blog.java.program.to.threads.matrix;

import java.util.ArrayList;
import java.util.List;

public class ParallelThreadsCreator {

    private static final int BATCH_SIZE = 10;

    // Starts up to 10 threads, waits for them to complete, then repeats
    // until every row has been multiplied.
    public static void multiply(int[][] matrix1, int[][] matrix2, int[][] result) {
        List<Thread> threads = new ArrayList<>();
        int rows1 = matrix1.length;
        for (int i = 0; i < rows1; i++) {
            RowMultiplyWorker task = new RowMultiplyWorker(result, matrix1, matrix2, i);
            Thread thread = new Thread(task);
            thread.start();
            threads.add(thread);
            if (threads.size() == BATCH_SIZE) {
                waitForThreads(threads);
            }
        }
        // Wait for the last group when the row count is not a multiple of 10.
        waitForThreads(threads);
    }

    private static void waitForThreads(List<Thread> threads) {
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException e) {
                // Restore the interrupt status and stop instead of returning a partial result.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for row threads", e);
            }
        }
        threads.clear();
    }
}
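The batching only works if the last, partial group of threads is also joined. Otherwise, for a row count that is not a multiple of 10, the method can return while some rows are still being computed. The self-contained sketch below follows the same pattern, with a lambda standing in for RowMultiplyWorker. It checks a 25-row input against a sequential result. The class name BatchedRowThreads and the test data are illustrative assumptions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BatchedRowThreads {

    static final int BATCH_SIZE = 10;

    // One thread per row, joined in batches of BATCH_SIZE. The leftover partial
    // batch is joined after the loop, so every row is finished on return.
    static void multiply(int[][] matrix1, int[][] matrix2, int[][] result) {
        List<Thread> batch = new ArrayList<>();
        for (int i = 0; i < matrix1.length; i++) {
            final int row = i;
            Thread thread = new Thread(() -> {
                for (int j = 0; j < matrix2[0].length; j++) {
                    int sum = 0;
                    for (int k = 0; k < matrix1[row].length; k++) {
                        sum += matrix1[row][k] * matrix2[k][j];
                    }
                    result[row][j] = sum;
                }
            });
            thread.start();
            batch.add(thread);
            if (batch.size() == BATCH_SIZE) {
                joinAll(batch);
            }
        }
        joinAll(batch); // e.g. rows 20..24 when there are 25 rows
    }

    static void joinAll(List<Thread> threads) {
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve the interrupt status
                throw new IllegalStateException("Interrupted while waiting for row threads", e);
            }
        }
        threads.clear();
    }

    // Plain triple loop used as the reference result.
    static int[][] sequential(int[][] a, int[][] b) {
        int[][] c = new int[a.length][b[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b[0].length; j++) {
                for (int k = 0; k < a[0].length; k++) {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int n = 25; // deliberately not a multiple of BATCH_SIZE
        int[][] a = new int[n][n];
        int[][] b = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = i + j;
                b[i][j] = i - j;
            }
        }
        int[][] parallel = new int[n][n];
        multiply(a, b, parallel);
        System.out.println(Arrays.deepEquals(parallel, sequential(a, b))); // prints true
    }
}
```

Thread.join() also guarantees that the rows written by each thread are visible to the caller once it returns, so the result array needs no extra synchronization.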

Now let us create the main class to measure the time taken with this approach.

package com.java.w3schools.blog.java.program.to.threads.matrix;

import java.util.Date;

/**
 * Parallel matrix multiplication: one thread per row, at most 10 at a time.
 *
 * @author JavaProgramTo.com
 *
 */
public class MatrixMultiplicationParallel {

    public static void main(String[] args) {

        Date start = new Date();

        int[][] m1 = MatrixGeneratorUtil.generateMatrix(2000, 2000);
        int[][] m2 = MatrixGeneratorUtil.generateMatrix(2000, 2000);

        int[][] result = new int[m1.length][m2[0].length];
        ParallelThreadsCreator.multiply(m1, m2, result);

        Date end = new Date();
        System.out.println("\nTime taken in milli seconds: " + (end.getTime() - start.getTime()));
    }
}

Output:

Time taken in milli seconds: 17103

5. Conclusion


In this article, we have seen how to build a concurrent program for matrix multiplication. For 2000 x 2000 matrices, the normal program took about 66 seconds (65879 ms), while the multithreaded version took only about 17 seconds (17103 ms).

The parallel version can still be improved. Runtime.getRuntime().availableProcessors() returns the number of processors available to the JVM, and that count can be used instead of the hard-coded 10 threads.
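One way to apply the availableProcessors() suggestion is a fixed-size thread pool from java.util.concurrent. The pool reuses a small number of worker threads instead of starting a new Thread for every row. This ExecutorService approach is a swapped-in technique rather than the article's ParallelThreadsCreator. The class name PooledMatrixMultiplication is hypothetical, and this is a sketch rather than a tuned implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class PooledMatrixMultiplication {

    static int[][] multiply(int[][] matrix1, int[][] matrix2) {
        int rows = matrix1.length;
        int columns = matrix2[0].length;
        int shared = matrix1[0].length;
        int[][] result = new int[rows][columns];

        // One pool thread per available processor instead of a hard-coded 10.
        int poolSize = Runtime.getRuntime().availableProcessors();
        ExecutorService pool = Executors.newFixedThreadPool(poolSize);
        try {
            List<Future<?>> futures = new ArrayList<>();
            for (int i = 0; i < rows; i++) {
                final int row = i;
                // Each task computes one row, like RowMultiplyWorker.
                futures.add(pool.submit(() -> {
                    for (int j = 0; j < columns; j++) {
                        int sum = 0;
                        for (int k = 0; k < shared; k++) {
                            sum += matrix1[row][k] * matrix2[k][j];
                        }
                        result[row][j] = sum;
                    }
                }));
            }
            for (Future<?> future : futures) {
                future.get(); // waits for the row and rethrows any task failure
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while multiplying", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};      // 2 x 3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}}; // 3 x 2
        System.out.println(Arrays.deepToString(multiply(a, b)));
        // prints [[58, 64], [139, 154]]
    }
}
```

Submitting all rows up front and letting the pool queue them removes the need for manual groups of 10. Future.get() also makes each task's writes to result visible to the calling thread.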

So, for large inputs it pays to use multiple threads efficiently.
