Monday, August 17, 2015

Parallel Merge Sort using Fork and Join

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class MergeSortForkJoin {
    public static class MergeSortTask extends RecursiveAction {
        private static final long serialVersionUID = 1L;
        private final List<Integer> list;
        private final int lo;
        private final int hi;
        
        public MergeSortTask(List<Integer> list, int lo, int hi) {
            this.list = list;
            this.lo = lo;
            this.hi = hi;
        }
        
        @Override
        protected void compute() {
            if (lo >= hi) {
                return;
            } else {
                int mid = (lo + hi) / 2;
                MergeSortTask task1 = new MergeSortTask(list, lo, mid);
                MergeSortTask task2 = new MergeSortTask(list, mid+1, hi);

                invokeAll(task1, task2);

                merge(list, lo, mid, hi);
            }
        }
        
        private void merge(List<Integer> list, int lo, int mid, int hi) {
            List<Integer> tmp = new ArrayList<>();
            for (int i : list) {
                tmp.add(i);
            }
            int left = lo;
            int right = mid + 1;
            int idx = lo;
            while (left <= mid && right <= hi) {
                if (tmp.get(left) <= tmp.get(right)) {
                    Integer element = tmp.get(left);
                    list.set(idx, element);
                    left++;
                    idx++;
                } else {
                    Integer element = tmp.get(right);
                    list.set(idx, element);
                    right++;
                    idx++;
                }
            }
            while (left <= mid) {
                Integer element = tmp.get(left);
                list.set(idx, element);
                idx++;
                left++;
            }
            while (right <= hi) {
                Integer element = tmp.get(right);
                list.set(idx, element);
                idx++;
                right++;
            }
        }
    }
    
    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = new ForkJoinPool();
        List<Integer> list = Arrays.asList(4, 9, 1, 5, 8, 0, 7, 6, 3, 2);
        System.out.println("Unsorted: " + list);
        MergeSortTask task = new MergeSortTask(list, 0, list.size()-1);
        try {
            do {
                pool.execute(task);
            } while (!task.isDone());
        } finally {
            pool.shutdown();
        }
        System.out.println("Sorted:   " + list);
    }
}

No comments:

Post a Comment