matrix-multiplication.java 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import java.io.BufferedReader;
  2. import java.io.FileReader;
  3. import java.io.IOException;
  4. import java.util.LinkedList;
  5. import java.util.List;
  6. import java.util.ArrayList;
  7. import java.util.concurrent.Callable;
  8. import java.util.concurrent.ExecutionException;
  9. import java.util.concurrent.ExecutorService;
  10. import java.util.concurrent.Executors;
  11. import java.util.concurrent.Future;
  12. public class Shell {
  13. static List<ArrayList<ArrayList<Integer>>> read(String filename) {
  14. ArrayList<ArrayList<Integer>> A = new ArrayList<ArrayList<Integer>>();
  15. ArrayList<ArrayList<Integer>> B = new ArrayList<ArrayList<Integer>>();
  16. String thisLine;
  17. try {
  18. BufferedReader br = new BufferedReader(new FileReader(filename));
  19. // Begin reading A
  20. while ((thisLine = br.readLine()) != null) {
  21. if (thisLine.trim().equals("")) {
  22. break;
  23. } else {
  24. ArrayList<Integer> line = new ArrayList<Integer>();
  25. String[] lineArray = thisLine.split("\t");
  26. for (String number : lineArray) {
  27. line.add(Integer.parseInt(number));
  28. }
  29. A.add(line);
  30. }
  31. }
  32. // Begin reading B
  33. while ((thisLine = br.readLine()) != null) {
  34. ArrayList<Integer> line = new ArrayList<Integer>();
  35. String[] lineArray = thisLine.split("\t");
  36. for (String number : lineArray) {
  37. line.add(Integer.parseInt(number));
  38. }
  39. B.add(line);
  40. }
  41. br.close();
  42. } catch (IOException e) {
  43. System.err.println("Error: " + e);
  44. }
  45. List<ArrayList<ArrayList<Integer>>> res = new LinkedList<ArrayList<ArrayList<Integer>>>();
  46. res.add(A);
  47. res.add(B);
  48. return res;
  49. }
  50. static void printMatrix(int[][] matrix) {
  51. for (int[] line : matrix) {
  52. int i = 0;
  53. StringBuilder sb = new StringBuilder(matrix.length);
  54. for (int number : line) {
  55. if (i != 0) {
  56. sb.append("\t");
  57. } else {
  58. i++;
  59. }
  60. sb.append(number);
  61. }
  62. System.out.println(sb.toString());
  63. }
  64. }
  65. public static int[][] parallelMult(ArrayList<ArrayList<Integer>> A,
  66. ArrayList<ArrayList<Integer>> B, int threadNumber) {
  67. int[][] C = new int[A.size()][B.get(0).size()];
  68. ExecutorService executor = Executors.newFixedThreadPool(threadNumber);
  69. List<Future<int[][]>> list = new ArrayList<Future<int[][]>>();
  70. int part = A.size() / threadNumber;
  71. if (part < 1) {
  72. part = 1;
  73. }
  74. for (int i = 0; i < A.size(); i += part) {
  75. System.err.println(i);
  76. Callable<int[][]> worker = new LineMultiplier(A, B, i, i+part);
  77. Future<int[][]> submit = executor.submit(worker);
  78. list.add(submit);
  79. }
  80. // now retrieve the result
  81. int start = 0;
  82. int CF[][];
  83. for (Future<int[][]> future : list) {
  84. try {
  85. CF = future.get();
  86. for (int i=start; i < start+part; i += 1) {
  87. C[i] = CF[i];
  88. }
  89. } catch (InterruptedException e) {
  90. e.printStackTrace();
  91. } catch (ExecutionException e) {
  92. e.printStackTrace();
  93. }
  94. start+=part;
  95. }
  96. executor.shutdown();
  97. return C;
  98. }
  99. public static void main(String[] args) {
  100. String filename;
  101. int cores = Runtime.getRuntime().availableProcessors();
  102. System.err.println("Number of cores:\t" + cores);
  103. int threads;
  104. if (args.length < 3) {
  105. filename = "3.in";
  106. threads = cores;
  107. } else {
  108. filename = args[1];
  109. threads = Integer.parseInt(args[2]);
  110. }
  111. List<ArrayList<ArrayList<Integer>>> matrices = read(filename);
  112. ArrayList<ArrayList<Integer>> A = matrices.get(0);
  113. ArrayList<ArrayList<Integer>> B = matrices.get(1);
  114. int[][] C = parallelMult(A, B, threads);
  115. printMatrix(C);
  116. }
  117. }