31
votes

I'm trying to implement Dijkstra's algorithm for finding shortest paths using a priority queue. In each step of the algorithm, I remove the vertex with the shortest distance from the priority queue, and then update the distances for each of its neighbors in the priority queue. Now I read that a Priority Queue in Java won't reorder when you edit the elements in it (the elements that determine the ordering), so I tried to force it to reorder by inserting and removing a dummy vertex. But this doesn't seem to be working, and I'm stuck trying to figure it out.

This is the code for the vertex object and the comparator:

class vertex {
    int v, d;
    public vertex(int num, int dis) {
        v=num;
        d=dis;
    }
}

class VertexComparator implements Comparator {
    public int compare (Object a, Object b) {
        vertex v1 = (vertex)a;
        vertex v2 = (vertex)b;
        return v1.d-v2.d;
    }
 }

Here is then where I run the algorithm:

    int[] distances=new int[p];
    Comparator<vertex> comparator = new VertexComparator();
    PriorityQueue<vertex> queue = new PriorityQueue<vertex>(p, comparator);
    for(int i=0; i<p; i++) {
        if(i!=v) {
            distances[i]=MAX;
        }
        else {
            distances[i]=0;
        }
        queue.add(new vertex(i, distances[i]));
    }
    // run dijkstra
    for(int i=0; i<p; i++) {
        vertex cur=queue.poll();
        Iterator itr = queue.iterator();
        while(itr.hasNext()) {
            vertex test = (vertex)(itr.next());
            if(graph[cur.v][test.v]!=-1) {
                test.d=Math.min(test.d, cur.d+graph[cur.v][test.v]);
                distances[test.v]=test.d;
            }
        }
        // force the PQ to resort by adding and then removing a dummy vertex
        vertex resort = new vertex(-1, -1);
        queue.add(resort);
        queue.remove(resort);
    }

I've run several test cases, and I know that the priority queue isn't reordering correctly each time I go through and update the distances for vertices, but I don't know why. Did I make an error somewhere?

7 Answers

25
votes

As you discovered, a priority queue does not re-sort all of its elements whenever an element is added or removed. Doing that would be too expensive (remember the n log n lower bound for comparison sorts), while any reasonable priority queue implementation (including PriorityQueue) promises to add or remove nodes in O(log n).

In fact, it doesn't keep its elements sorted at all; it only maintains a heap, which is why its iterator cannot promise to iterate over the elements in sorted order.
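
To make the point about iteration order concrete, here is a minimal sketch (the class name HeapOrderDemo and the helper drain are made up for illustration). It contrasts iterating over a PriorityQueue with repeatedly calling poll(). The iteration order is unspecified, so the sketch does not show an expected result for it.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

class HeapOrderDemo {
    // Drains the queue with poll(), which always returns the current minimum.
    static List<Integer> drain(PriorityQueue<Integer> queue) {
        List<Integer> out = new ArrayList<>();
        while (!queue.isEmpty()) {
            out.add(queue.poll());
        }
        return out;
    }

    public static void main(String[] args) {
        PriorityQueue<Integer> queue = new PriorityQueue<>(Arrays.asList(5, 1, 4, 2, 3));
        // Iteration exposes the internal heap layout; the order is unspecified.
        System.out.println("iteration order: " + new ArrayList<>(queue));
        // Repeated poll() yields the elements in ascending order.
        System.out.println("poll order:      " + drain(queue));
    }
}
```

Only the poll order is guaranteed to be ascending; code that relies on the iterator to visit elements by priority is relying on an accident of the heap layout.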

PriorityQueue does not offer an API to inform it about a changed node, as that would require it to provide efficient node lookup, which its underlying algorithm does not support. Implementing a priority queue that does is quite involved. The Wikipedia article on priority queues might be a good starting point for reading about that. I am not certain such an implementation would be faster, though.

A straightforward idea is to remove and then re-add the changed node. Do not do that, as remove() takes O(n). Instead, insert another entry for the same node into the PriorityQueue and ignore duplicates when polling the queue, i.e. do something like:

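// Step and Node are not shown here. The assumption is that a Step is an
// immutable (target, distance) pair ordered by distance, and that
// node.steps() returns one Step per outgoing edge of node.
PriorityQueue<Step> queue = new PriorityQueue<>();

void findShortestPath(Node start) {
    start.distance = 0;
    start.reached = true;   // otherwise a step leading back to start would overwrite its distance
    queue.addAll(start.steps());

    Step step;
    while ((step = queue.poll()) != null) {
        Node node = step.target;
        if (!node.reached) {   // skip stale entries for nodes that are already settled
            node.reached = true;
            node.distance = step.distance;
            queue.addAll(node.steps());
        }
    }

}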

Edit: It is not advisable to change the priorities of elements in the PQ, hence the need to insert Steps instead of Nodes.
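
For reference, here is a self-contained sketch of the same "insert duplicates, skip stale entries" idea, adapted to the adjacency-matrix convention from the question (graph[u][w] == -1 means no edge). The class name LazyDijkstra, the method shortestDistances, and the sample graph are made up for illustration, and weights are assumed to be non-negative and small enough not to overflow an int.

```java
import java.util.Arrays;
import java.util.PriorityQueue;

class LazyDijkstra {
    // graph[u][w] is the weight of the edge u -> w, or -1 if there is no edge.
    static int[] shortestDistances(int[][] graph, int source) {
        int n = graph.length;
        int[] dist = new int[n];
        Arrays.fill(dist, Integer.MAX_VALUE); // MAX_VALUE marks "not reached yet"
        dist[source] = 0;
        boolean[] done = new boolean[n];

        // Entries are {vertex, distance} snapshots that are never mutated once queued.
        PriorityQueue<int[]> queue = new PriorityQueue<>((a, b) -> Integer.compare(a[1], b[1]));
        queue.add(new int[] {source, 0});

        while (!queue.isEmpty()) {
            int u = queue.poll()[0];
            if (done[u]) {
                continue; // stale duplicate: u was already settled via a shorter entry
            }
            done[u] = true;
            for (int w = 0; w < n; w++) {
                if (graph[u][w] == -1 || done[w]) {
                    continue;
                }
                int candidate = dist[u] + graph[u][w];
                if (candidate < dist[w]) {
                    dist[w] = candidate;
                    queue.add(new int[] {w, candidate}); // add a new entry instead of updating
                }
            }
        }
        return dist;
    }

    public static void main(String[] args) {
        int[][] graph = {
            {-1,  4,  1, -1},
            { 4, -1,  2,  5},
            { 1,  2, -1, -1},
            {-1,  5, -1, -1},
        };
        System.out.println(Arrays.toString(shortestDistances(graph, 0))); // prints [0, 3, 1, 8]
    }
}
```

The queue can hold several entries for the same vertex, but each vertex is settled only once, so the extra entries cost some memory and a few wasted polls and nothing else.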

5
votes

You will have to delete and re-insert each element that is edited (the actual element, not a dummy one!). So every time you update distances, you need to remove and re-add the elements that were affected by the changed entry.

As far as I know, this is not unique to Java: every priority queue that runs in O(log n) for all operations has to work this way.
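
A rough sketch of what that looks like in Java follows. The names ReinsertDemo, Vertex, and updateDistance are made up for illustration; Vertex only stands in for the question's vertex class. The element is removed before its key changes, so the heap never contains an entry whose priority was modified behind its back.

```java
import java.util.Comparator;
import java.util.PriorityQueue;

class ReinsertDemo {
    // Hypothetical stand-in for the question's vertex class.
    static final class Vertex {
        final int id;
        int d;
        Vertex(int id, int d) { this.id = id; this.d = d; }
    }

    // Changes v's distance while keeping the heap consistent:
    // remove first (a linear scan), then mutate, then add back (O(log n)).
    static void updateDistance(PriorityQueue<Vertex> queue, Vertex v, int newDistance) {
        boolean wasQueued = queue.remove(v);
        v.d = newDistance;
        if (wasQueued) {
            queue.add(v);
        }
    }

    public static void main(String[] args) {
        PriorityQueue<Vertex> queue =
                new PriorityQueue<>(Comparator.comparingInt((Vertex v) -> v.d));
        Vertex a = new Vertex(0, 10);
        Vertex b = new Vertex(1, 20);
        Vertex c = new Vertex(2, 30);
        queue.add(a);
        queue.add(b);
        queue.add(c);

        updateDistance(queue, c, 5);         // c now has the smallest distance
        System.out.println(queue.poll().id); // prints 2
    }
}
```

Note that remove(Object) on java.util.PriorityQueue is a linear-time operation, so this costs O(n) per update.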

4
votes

The disadvantage of Java's PriorityQueue is that remove(Object) requires O(n) time, so every priority update done by removing and re-adding an element costs O(n). It can be done in O(log(n)) time, however. As I wasn't able to find a working implementation via Google, I tried implementing it myself using TreeSet (in Kotlin, as I really prefer that language to Java). This seems to work, and should take O(log(n)) for adding, updating, and removing (updating is done via add):

// update priority by adding element again (old occurrence is removed automatically)
class DynamicPriorityQueue<T>(isMaxQueue: Boolean = false) {

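    private class MyComparator<A>(val queue: DynamicPriorityQueue<A>, isMaxQueue: Boolean) : Comparator<A> {
        val sign = if (isMaxQueue) -1 else 1

        override fun compare(o1: A, o2: A): Int {
            if (o1 == o2)
                return 0
            val byPriority = queue.priorities[o1]!!.compareTo(queue.priorities[o2]!!)
            if (byPriority != 0)
                return sign * byPriority
            // Equal priorities: break the tie by insertion order so that
            // compare(a, b) and compare(b, a) never both return the same sign,
            // which TreeSet needs in order to find elements again on remove.
            return queue.insertionOrder[o1]!!.compareTo(queue.insertionOrder[o2]!!)
        }

    }

    private val priorities = HashMap<T, Double>()
    private val insertionOrder = HashMap<T, Long>()
    private var nextOrder = 0L
    private val treeSet = TreeSet<T>(MyComparator(this, isMaxQueue))

    val size: Int
        get() = treeSet.size

    fun isEmpty() = (size == 0)

    fun add(newElement: T, priority: Double) {
        // remove the old occurrence while its old priority is still stored
        if (newElement in priorities)
            treeSet.remove(newElement)
        priorities[newElement] = priority
        insertionOrder[newElement] = nextOrder++
        treeSet.add(newElement)
    }

    fun remove(element: T) {
        // remove from the tree first: the comparator still needs the stored priority
        treeSet.remove(element)
        priorities.remove(element)
        insertionOrder.remove(element)
    }

    fun getPriorityOf(element: T): Double {
        return priorities[element]!!
    }


    fun first(): T = treeSet.first()
    fun poll(): T {
        val res = treeSet.pollFirst()
        priorities.remove(res)
        insertionOrder.remove(res)
        return res
    }

    fun pollWithPriority(): Pair<T, Double> {
        val res = treeSet.pollFirst()
        val priority = priorities[res]!!
        priorities.remove(res)
        insertionOrder.remove(res)
        return Pair(res, priority)
    }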

}
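
For readers who would rather stay in Java, the same TreeSet idea can be sketched directly inside Dijkstra's algorithm. This is only an illustration under assumptions: the class name TreeSetDijkstra, the method shortestDistances, and the adjacency-matrix convention (graph[u][w] == -1 means no edge, borrowed from the question) are not part of the answer above, and weights are assumed to be non-negative.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.TreeSet;

class TreeSetDijkstra {
    static int[] shortestDistances(int[][] graph, int source) {
        int n = graph.length;
        int[] dist = new int[n];
        Arrays.fill(dist, Integer.MAX_VALUE); // MAX_VALUE marks "not reached yet"
        dist[source] = 0;

        // Order vertices by current distance and break ties by vertex number,
        // so two different vertices never compare as equal (TreeSet would drop one).
        Comparator<Integer> byDistance = Comparator
                .comparingInt((Integer v) -> dist[v])
                .thenComparingInt(v -> v);
        TreeSet<Integer> queue = new TreeSet<>(byDistance);
        queue.add(source);

        while (!queue.isEmpty()) {
            int u = queue.pollFirst();
            for (int w = 0; w < n; w++) {
                if (graph[u][w] == -1) {
                    continue;
                }
                int candidate = dist[u] + graph[u][w];
                if (candidate < dist[w]) {
                    queue.remove(w);      // must run while dist[w] still holds the old key
                    dist[w] = candidate;  // change the key only while w is outside the set
                    queue.add(w);         // re-insert under the new key
                }
            }
        }
        return dist;
    }

    public static void main(String[] args) {
        int[][] graph = {
            {-1,  4,  1, -1},
            { 4, -1,  2,  5},
            { 1,  2, -1, -1},
            {-1,  5, -1, -1},
        };
        System.out.println(Arrays.toString(shortestDistances(graph, 0))); // prints [0, 3, 1, 8]
    }
}
```

TreeSet's add, remove, and pollFirst each take O(log n), so every distance update costs O(log n) instead of the O(n) that PriorityQueue.remove(Object) would need.
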
2
votes

You can avoid updating items in the queue by marking each node as visited=false by default and adding new items to the queue as you go.

Then pop a node from the queue and process it only if it was not visited before.

Dijkstra's algorithm guarantees that each node is visited only once, so even if you may have stale nodes down the queue you never really process them.

Also it's probably easier if you separate the algorithm internals from the graph data structure.

public void dijkstra(Node source) throws Exception{
    // DijkstraHeapItem has to be Comparable (ordered by distance) for this constructor
    PriorityQueue<DijkstraHeapItem> q = new PriorityQueue<>();
    source.work.distance = 0;
    q.add(new DijkstraHeapItem(source));

    while(!q.isEmpty()){
        Node n = q.remove().node;
        Work w = n.work;

        if(!w.visited){   // skip stale queue entries
            w.visited = true;

            Iterator<Edge> adjacents = n.getEdgesIterator();
            while(adjacents.hasNext()){
                Edge e = adjacents.next();
                if(e.weight<0) throw new Exception("Negative weight!!");
                Integer relaxed = e.weight + w.distance;

                Node t = e.to;
                // never relax a settled node; this also protects the source, whose previous stays null
                if (!t.work.visited && (t.work.previous == null || t.work.distance > relaxed)){
                    t.work.distance = relaxed;
                    t.work.previous = n;
                    q.add(new DijkstraHeapItem(t));
                }
            }
        }
    }
}
0
votes

The problem is that you update the distances array, but not the corresponding entry in the queue. To update the appropriate objects in the queue, you need to remove and then add.

0
votes

I solved this problem by dividing my process into time slots (a time scheduler will do just fine) and extending the native PriorityQueue. I implemented a notify method, the key part of which is the following code:

// If the queue has one element or fewer, it doesn't need a reordering
// procedure
if (size() > 1)
{
    // hold the current size, as the size will vary during this
    // process
    int tmpSize = size();
    for (int i = 1; i < tmpSize; i++)
    {
        add(poll());
    }
}
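
A related way to re-establish the order after a batch of in-place changes is to rebuild the queue from a snapshot, sketched below. This is a swapped-in variant, not the notify method above: poll() on a heap whose keys changed underneath it is not guaranteed to return the minimum, so the sketch copies the elements out and re-adds them instead. The names RebuildDemo and rebuild are made up for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

class RebuildDemo {
    // Re-establishes heap order after elements were mutated in place.
    // Costs O(n log n) per call, so it only pays off for batched updates.
    static <E> void rebuild(PriorityQueue<E> queue) {
        List<E> snapshot = new ArrayList<>(queue); // copy in whatever order the heap holds
        queue.clear();
        queue.addAll(snapshot);                    // each add sifts using the current keys
    }

    public static void main(String[] args) {
        // Entries are {id, priority}; the queue orders them by priority.
        PriorityQueue<int[]> queue = new PriorityQueue<>((a, b) -> Integer.compare(a[1], b[1]));
        int[] first = {0, 10};
        int[] second = {1, 20};
        int[] third = {2, 30};
        queue.add(first);
        queue.add(second);
        queue.add(third);

        third[1] = 5;   // in-place change: the queue does not notice this
        rebuild(queue); // restore the heap invariant
        System.out.println(queue.poll()[0]); // prints 2
    }
}
```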

I hope this helps.

0
votes

I implemented an adaptive MinHeap, written in Python, that supports reordering when objects' priorities are updated (O(log n) per update, so O(n log n) for n updates).

class Node:
    """
    Model an object in Heap.
    """

    def __init__(self, key, val, i=-1) -> None:
        self.key = key  # object ID
        self.val = val  # object priority
        self.i = i  # index in heap array


class AdaptiveMinHeap:
    """
    Heap for objects. Supports reordering when objects' priorities are updated.
    """

    def __init__(self) -> None:
        self.hp = {0: Node(-1, -1, 0)}  # Use dict to simulate list (key as the index) to support efficient reordering.
        self.d = dict()

    def __len__(self):
        return len(self.hp)-1

    def _swap(self, anode, bnode):
        d = self.d
        anode.key, bnode.key = bnode.key, anode.key
        anode.val, bnode.val = bnode.val, anode.val
        d[anode.key] = anode
        d[bnode.key] = bnode

    def _swim(self, i):
        hp = self.hp
        while i//2 > 0 and hp[i].val < hp[i//2].val:
            self._swap(hp[i], hp[i//2])
            i = i//2

    def _sink(self, i):
        hp = self.hp
        while i*2 < len(hp):
            if i*2 + 1 >= len(hp) or hp[i*2+1].val >= hp[i*2].val:
                min_child = i*2
            else:
                min_child = i*2+1
            if hp[min_child].val >= hp[i].val:
                break  # heap property already holds from here down
            self._swap(hp[min_child], hp[i])
            i = min_child

    def push(self, key, val):
        hp = self.hp
        d = self.d
        if key in d:
            self.remove(key)
        node = Node(key, val, len(hp))
        d[key] = node
        hp[node.i] = node
        self._swim(node.i)

    def pop(self):
        hp = self.hp
        if len(hp) > 1:
            return self.remove(hp[1].key)

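    def remove(self, key):
        hp = self.hp
        d = self.d
        node = d[key]
        self._swap(hp[node.i], hp[len(hp)-1])
        hp.pop(len(hp)-1)
        # The entry moved into slot node.i may be smaller than its parent or
        # larger than its children, so restore the heap in both directions.
        if node.i < len(hp):
            self._swim(node.i)
        self._sink(node.i)
        return d.pop(key)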

hp = AdaptiveMinHeap()
hp.push(1, 40)
hp.push(4, 8900)
hp.push(2, 500)
hp.push(3, 1075)
hp.push(1, 1)
for _ in range(len(hp)):
    popped = hp.pop()
    print(popped.key, popped.val)
# 1 1
# 2 500
# 3 1075
# 4 8900