Sunday, April 22, 2012

How to Implement a Priority Queue in Python

Below is a simple implementation of a priority queue using a binary heap.
#!/usr/bin/env python

class PriorityQueue(object):
    """Priority queue backed by a binary heap stored in a Python list.

    Ordering is defined by ``compare(a, b)``, a three-way comparison in the
    style of Python 2's ``cmp``: positive when ``a`` should come out before
    ``b``, zero when they are equal, negative otherwise.  With the default
    comparison (natural ordering) the largest element is removed first.

    Heap positions are 1-based (root at 1, children of ``k`` at ``2k`` and
    ``2k + 1``); position ``k`` lives at list index ``k - 1``.
    """

    def __init__(self, compare=None):
        """Create an empty queue.

        :param compare: three-way comparison function.  ``None`` (the
            default) means natural ordering with ``cmp`` semantics, which
            yields a max-priority queue on both Python 2 and Python 3.
        """
        self.pq = []   # heap storage; kept exactly ``self.n`` items long
        self.n = 0     # number of elements currently in the queue
        if compare is None:
            compare = self._natural_compare
        self.compare = compare

    @staticmethod
    def _natural_compare(x, y):
        """Return 1, 0 or -1 like Python 2's builtin ``cmp``."""
        return (x > y) - (x < y)

    def insert(self, element):
        """Add ``element`` to the queue."""
        self.swim(element)

    def remove(self):
        """Remove and return the highest-priority element, or None if empty."""
        if self.n == 0:
            return None
        top = self.pq[0]
        self.sink()
        return top

    def elements(self):
        """Return a new list of the stored elements in heap (not sorted) order."""
        return self.pq[:self.n]

    def swim(self, element):
        """Append ``element`` at the bottom of the heap and sift it upward."""
        pq = self.pq
        pq.append(element)
        self.n += 1

        current_pos = self.n
        # Test the bound first: the root (position 1) has no parent, and
        # comparing it against position 0 would read pq[-1].
        while current_pos > 1:
            parent_pos = current_pos // 2  # floor division on Python 2 and 3
            if self.compare(pq[current_pos - 1], pq[parent_pos - 1]) <= 0:
                break
            pq[current_pos - 1], pq[parent_pos - 1] = (pq[parent_pos - 1],
                                                       pq[current_pos - 1])
            current_pos = parent_pos

    def sink(self):
        """Drop the root: move the last element to the top and sift it down.

        Raises IndexError if the heap is empty (``remove`` guards this).
        """
        pq = self.pq
        # Popping (rather than copying) the last slot keeps len(pq) == n, so
        # no stale references to removed elements linger in the list.
        last = pq.pop()
        self.n -= 1
        if self.n == 0:
            return  # the root was the only element

        pq[0] = last
        current_pos = 1
        child_pos = self.get_child_position(current_pos)
        while (child_pos != -1 and
               self.compare(pq[child_pos - 1], pq[current_pos - 1]) > 0):
            pq[child_pos - 1], pq[current_pos - 1] = (pq[current_pos - 1],
                                                      pq[child_pos - 1])
            current_pos = child_pos
            child_pos = self.get_child_position(current_pos)

    def get_child_position(self, current_pos):
        """Return the 1-based position of the preferred child of ``current_pos``.

        A binary-heap node has at most two children.  Returns -1 when
        ``current_pos`` is a leaf, the only child when there is one, and
        otherwise the child that ``compare`` ranks higher (the right child
        on a tie).
        """
        first_child_pos = 2 * current_pos
        second_child_pos = first_child_pos + 1
        if first_child_pos > self.n:
            return -1  # no children
        if second_child_pos > self.n:
            return first_child_pos  # only one child
        if self.compare(self.pq[first_child_pos - 1],
                        self.pq[second_child_pos - 1]) > 0:
            return first_child_pos
        return second_child_pos

    def size(self):
        """Return the number of elements in the queue."""
        return self.n
    
def min_compare(x, y):
    """Three-way comparison that inverts natural ordering.

    Returns 1 when ``x`` is less than ``y``, 0 when they are equal, and -1
    in every other case.  A PriorityQueue built with this comparison hands
    out its smallest element first.
    """
    if x < y:
        return 1
    return 0 if x == y else -1
    
if __name__ == "__main__":
    # Self-test: run the same scenario against a max-queue (default compare)
    # and a min-queue (min_compare).  Each queue is drained to empty and then
    # refilled, so reuse of a previously emptied queue is exercised too.
    values = [4, 5, 1, 8, 2, 3, 10, 7, 9, 6]
    pair = [4, 5]

    def insert_all(queue, items):
        """Insert every item of ``items`` into ``queue`` in order."""
        for item in items:
            queue.insert(item)

    def assert_drains(queue, expected):
        """Assert that successive removals yield exactly ``expected``."""
        for item in expected:
            assert item == queue.remove()

    for queue, largest_first in ((PriorityQueue(), True),
                                 (PriorityQueue(min_compare), False)):
        full_order = sorted(values, reverse=largest_first)
        pair_order = sorted(pair, reverse=largest_first)

        insert_all(queue, values)
        assert_drains(queue, full_order)

        insert_all(queue, pair)
        assert_drains(queue, pair_order)
        assert queue.remove() is None

        insert_all(queue, values)
        assert_drains(queue, full_order)

No comments:

Post a Comment