Tuesday, December 11, 2012

How to Implement Matrix Multiplication

/**
 * Dense integer matrix multiplication demo.
 *
 * <p>Matrices are represented as rectangular {@code int[rows][cols]} arrays. A matrix with no
 * rows carries no column information and is treated as having zero columns.
 */
public class Matrix {

    /** Minimum printed cell width; keeps one- and two-digit entries aligned. */
    private static final int MIN_CELL_WIDTH = 2;

    /**
     * Computes the matrix product {@code C = A x B}.
     *
     * @param a left operand, an {@code n x m} rectangular matrix
     * @param b right operand, an {@code m x p} rectangular matrix
     * @return a new {@code n x p} matrix
     * @throws IllegalArgumentException if either operand is jagged, or if the column count of
     *     {@code a} differs from the row count of {@code b}
     * @throws ArithmeticException if any product or partial sum overflows {@code int}
     * @throws NullPointerException if an operand or one of its rows is {@code null}
     */
    private static int[][] multiply(int[][] a, int[][] b) {
        int n = a.length;
        // Validate both shapes up front: jagged rows would otherwise either crash mid-loop
        // or be silently truncated, producing a wrong result.
        int m = columns(a, "A");
        int p = columns(b, "B");
        if (m != b.length) {
            // Uses the validated dimensions, so an empty B no longer triggers b[0] out of bounds.
            throw new IllegalArgumentException(
                "A(" + n + "x" + m + ") did not match " +
                "B(" + b.length + "x" + p + ")");
        }

        int[][] c = new int[n][p];
        for (int i = 0; i < n; i++) {
            int[] rowA = a[i];
            for (int j = 0; j < p; j++) {
                int sum = 0;
                for (int k = 0; k < m; k++) {
                    // Exact arithmetic: fail loudly instead of wrapping around on overflow.
                    sum = Math.addExact(sum, Math.multiplyExact(rowA[k], b[k][j]));
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    /**
     * Returns the column count of {@code matrix} after checking that every row has the same
     * length. A matrix with no rows is reported as having zero columns.
     *
     * @param matrix the matrix to inspect
     * @param name label used in the error message ("A" or "B")
     * @return the common row length
     * @throws IllegalArgumentException if the rows differ in length
     */
    private static int columns(int[][] matrix, String name) {
        if (matrix.length == 0) {
            return 0;
        }
        int cols = matrix[0].length;
        for (int r = 1; r < matrix.length; r++) {
            if (matrix[r].length != cols) {
                throw new IllegalArgumentException(
                    name + " is not rectangular: row 0 has " + cols
                        + " columns but row " + r + " has " + matrix[r].length);
            }
        }
        return cols;
    }

    /**
     * Prints {@code a} to standard output, one row per line.
     *
     * <p>Each entry is right-aligned in a cell as wide as the widest entry in the matrix (at
     * least {@value #MIN_CELL_WIDTH} characters) and is followed by a single space.
     *
     * @param a the matrix to print
     */
    private static void print(int[][] a) {
        int width = MIN_CELL_WIDTH;
        for (int[] row : a) {
            for (int v : row) {
                width = Math.max(width, Integer.toString(v).length());
            }
        }
        for (int[] row : a) {
            StringBuilder line = new StringBuilder();
            for (int v : row) {
                String s = Integer.toString(v);
                // Manual left-padding keeps the digits locale-independent (unlike String.format).
                for (int pad = s.length(); pad < width; pad++) {
                    line.append(' ');
                }
                line.append(s).append(' ');
            }
            System.out.println(line);
        }
    }

    /**
     * Builds a {@code rows x cols} matrix filled with 1, 2, 3, ... in row-major order.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return the filled matrix
     */
    private static int[][] sequential(int rows, int cols) {
        int[][] m = new int[rows][cols];
        int value = 1;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                m[i][j] = value++;
            }
        }
        return m;
    }

    /**
     * Demo: multiplies a 2x3 matrix by a 3x2 matrix, both filled with 1, 2, 3, ..., and prints
     * the operands and the product.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int[][] a = sequential(2, 3);
        System.out.println("Matrix A:");
        print(a);

        int[][] b = sequential(3, 2);
        System.out.println("Matrix B:");
        print(b);

        int[][] c = multiply(a, b);
        System.out.println("Matrix C:");
        print(c);
    }
}

No comments:

Post a Comment