Saturday, January 31, 2015

BufferedLatch

Multithreading can be hard. When you have several threads modifying and contending for shared objects, variables, and resources, many things can go wrong if you are not deliberate and careful.

In Java, that is why it is critical to use synchronizers like CountDownLatch, Semaphore, and CyclicBarrier (in addition to other multithreading tools like the volatile and final keywords). These either help ensure thread safety or prevent threads from proceeding until certain conditions are met and it is safe to continue.

An analogy that I find really helpful is comparing this to the Kentucky Derby. The race horses (the "threads") walk up to the gate (the "synchronizer"), but the gate is closed so they cannot proceed any further (in a state of "await()"). Once a certain condition is cleared, the gate opens ("notifyAll()") and the horses are free to charge forward.
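To make the analogy concrete, here is a minimal sketch of a starting gate. It assumes a CountDownLatch with a count of one as the gate, so the gate opens with countDown() rather than a raw notifyAll(). The class name StartingGateDemo and the method race() are made up for this illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

class StartingGateDemo {

    // Lines up `horses` threads behind one gate, opens it, and returns how many ran.
    static int race(int horses) throws InterruptedException {
        CountDownLatch gate = new CountDownLatch(1);   // the closed gate
        AtomicInteger finishers = new AtomicInteger();
        List<Thread> threads = new ArrayList<>();

        for (int i = 0; i < horses; i++) {
            Thread horse = new Thread(() -> {
                try {
                    gate.await();                  // horse waits at the gate
                    finishers.incrementAndGet();   // gate is open, horse runs
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads.add(horse);
            horse.start();
        }

        gate.countDown();                // open the gate for every waiting horse
        for (Thread horse : threads) {
            horse.join();                // wait for each horse to finish
        }
        return finishers.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("Horses finished: " + race(5));   // prints "Horses finished: 5"
    }
}
```

No horse can increment the counter before the gate opens, no matter how early its thread starts.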

The CountDownLatch is probably the synchronizer I use the most. Its policy is simple and effective. It starts at a specified number on construction, and decrements on each call to "countDown()". When it reaches zero, the CountDownLatch lets any other threads waiting on it pass through.

If you have three tasks each running on a separate thread, but you don't want the main thread to proceed until all three are done, create a CountDownLatch with a count of three and have each task call countDown() when it finishes. Have the main thread await() on the CountDownLatch, and when all three tasks are done the main thread will proceed.
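The three-task scenario above could be sketched roughly like this. The class name ThreeTasksDemo, the method runTasks(), and the task names are placeholders for illustration. The "work" is just adding a string to a thread-safe list.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class ThreeTasksDemo {

    // Runs three tasks on a pool and returns only after all three have counted down.
    static List<String> runTasks() throws InterruptedException {
        List<String> results = new CopyOnWriteArrayList<>();
        CountDownLatch done = new CountDownLatch(3);        // count is known up front
        ExecutorService executor = Executors.newFixedThreadPool(3);
        try {
            for (String name : new String[] {"A", "B", "C"}) {
                executor.submit(() -> {
                    try {
                        results.add("task " + name);        // stand-in for real work
                    } finally {
                        done.countDown();                   // always count down, even on failure
                    }
                });
            }
            done.await();                                   // main thread blocks until count hits zero
        } finally {
            executor.shutdown();
        }
        return results;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("Completed tasks: " + runTasks().size());   // prints "Completed tasks: 3"
    }
}
```

Note that the count of three has to be passed to the constructor before any task is submitted.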

The BufferedLatch

However, what if we don't know the count in advance? We have no number to start the CountDownLatch at, and we will only learn it later in the process. I ran into this issue a lot when dealing with buffered streams from database queries. As I looped through the records, I wanted to wrap each one in a task and submit it to a fixed thread pool. The problem is that once all the records have been looped through, how do I pause until all the runnables processing those records are complete?

ResultSet rs = //issue some query
ExecutorService executor = //create a fixed thread pool

while (rs.next()) { 
 FinanceDay financeDay = convertRecordToFinanceDay(rs);
 executor.submit(() -> addToConcurrentHashMap(financeDay));
}

//I want to wait for the executor to finish all the runnables here
rs.close();
executor.shutdown();

Looking at the conceptual code above, we are looping through each record in the query as it comes in. But we do not know how many records there will be until rs.next() returns false and we finish looping. This is problematic if we are going to convert each record into an object (in this case "FinanceDay"), and then pass that object to a Runnable that runs on an entirely separate thread. If we don't find a synchronization solution, we run a serious risk of creating a race condition where this method finishes prematurely. When it returns, the FinanceDay objects may still be getting processed and may not be ready for the application to use.

Introducing the only synchronizer I ever wrote... the BufferedLatch

public final class BufferedLatch {
 private int recordCount;           //records handed off so far
 private int processedRecordCount;  //records whose runnable has finished
 private boolean iterationComplete = false;

 //called by the iterating thread once per record
 public synchronized void incrementRecordCount() {
  if (iterationComplete) {
   throw new RuntimeException("Cannot increment record count after iteration is flagged complete!");
  }
  recordCount++;
 }
 //called by a worker when it finishes processing a record
 public synchronized void incrementProcessedRecordCount() {
  processedRecordCount++;
  if (iterationComplete && recordCount == processedRecordCount) {
   this.notifyAll();
  }
 }
 //called once by the iterating thread when no more records are coming
 public synchronized void setIterationComplete() {
  iterationComplete = true;
  if (recordCount == processedRecordCount) {
   this.notifyAll();
  }
 }
 //must be synchronized: wait() throws IllegalMonitorStateException
 //unless the calling thread holds this object's monitor
 public synchronized void await() throws InterruptedException {
  while (! (iterationComplete && recordCount == processedRecordCount)) {
   this.wait();
  }
 }
}

The BufferedLatch solves this problem. Its purpose is very similar to the CountDownLatch, except it is for cases where the count is unknown up front and only becomes known mid-process.

There are three methods that control the state of the synchronizer (incrementRecordCount(), incrementProcessedRecordCount(), and setIterationComplete()) as well as a method for waiting threads (await()).

incrementRecordCount() is called every time a record is iterated.

incrementProcessedRecordCount() is called every time the runnable for a record completes.

setIterationComplete() is only called once by the looping thread to flag that the iteration is complete.

Any threads waiting for the runnables to complete will need to call await(). In my uses, this has always been the thread that does the iteration: it calls setIterationComplete(), then calls await() and sits until all the runnables are done.

For our example above, this is how BufferedLatch would be used.

ResultSet rs = //issue some query
ExecutorService executor = //create a fixed thread pool
BufferedLatch latch = new BufferedLatch();

while (rs.next()) {
 FinanceDay financeDay = convertRecordToFinanceDay(rs);
 executor.submit(() -> {
  try {
   addToConcurrentHashMap(financeDay);
  } finally {
   //count the record even if processing throws, so await() cannot hang
   latch.incrementProcessedRecordCount();
  }
 });
 latch.incrementRecordCount();
}
rs.close();
latch.setIterationComplete();
latch.await();
executor.shutdown();


Here is how it works now. A BufferedLatch is created before any ResultSet iteration starts. After a record is iterated, converted to a FinanceDay object, and passed off as a Runnable to the executor, incrementRecordCount() is called.

When a Runnable (the lambda passed to the executor's submit() method) finishes processing the FinanceDay, it calls incrementProcessedRecordCount().

After the entire ResultSet is looped through, setIterationComplete() is called to flag that no more records are coming in; the query has been iterated through completely. If incrementRecordCount() is called after that, a RuntimeException will be thrown, because iterating records should not happen after setIterationComplete() is called.

Finally, the original thread will come to the latch's await() method and will pause until the runnables complete by calling incrementProcessedRecordCount() enough times to match the recordCount. After that, every record has been iterated and processed, and the original thread is now free to shut down the executor and move on.
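For anyone who wants to run the whole flow without a database, here is a self-contained sketch. It makes several assumptions that are not part of the example above: an in-memory Iterator stands in for the ResultSet, squaring a number stands in for addToConcurrentHashMap(), and the names BufferedLatchDemo and squareAll() are invented. The nested BufferedLatch declares await() as synchronized, since wait() can only be called by a thread that holds the object's monitor.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class BufferedLatchDemo {

    static final class BufferedLatch {
        private int recordCount;
        private int processedRecordCount;
        private boolean iterationComplete = false;

        synchronized void incrementRecordCount() {
            if (iterationComplete) {
                throw new RuntimeException("Cannot increment record count after iteration is flagged complete!");
            }
            recordCount++;
        }

        synchronized void incrementProcessedRecordCount() {
            processedRecordCount++;
            if (iterationComplete && recordCount == processedRecordCount) {
                notifyAll();
            }
        }

        synchronized void setIterationComplete() {
            iterationComplete = true;
            if (recordCount == processedRecordCount) {
                notifyAll();
            }
        }

        synchronized void await() throws InterruptedException {
            while (!(iterationComplete && recordCount == processedRecordCount)) {
                wait();   // releases the monitor while waiting, reacquires on wake-up
            }
        }
    }

    // Processes records of unknown count on a pool; returns once every one is done.
    static Map<Integer, Integer> squareAll(Iterator<Integer> records) throws InterruptedException {
        Map<Integer, Integer> results = new ConcurrentHashMap<>();
        ExecutorService executor = Executors.newFixedThreadPool(4);
        BufferedLatch latch = new BufferedLatch();
        try {
            while (records.hasNext()) {
                int record = records.next();
                executor.submit(() -> {
                    try {
                        results.put(record, record * record);   // stand-in for real processing
                    } finally {
                        latch.incrementProcessedRecordCount();
                    }
                });
                latch.incrementRecordCount();
            }
            latch.setIterationComplete();   // no more records are coming
            latch.await();                  // block until processed == recorded
        } finally {
            executor.shutdown();
        }
        return results;
    }

    public static void main(String[] args) throws InterruptedException {
        Map<Integer, Integer> squares = squareAll(Arrays.asList(1, 2, 3, 4, 5).iterator());
        System.out.println("Processed " + squares.size() + " records");   // prints "Processed 5 records"
    }
}
```

A worker may finish before the iterating thread calls incrementRecordCount() for that record. That is harmless here, because no waiter is released until setIterationComplete() has been called and both counts match.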

Conclusions

My only regret is that this latch adds some boilerplate to the client code by having four different methods that need to be called, whereas CountDownLatch typically only needs two (countDown() and await()). If anybody has suggestions, I am very willing to hear them. But I have not found a latch that accomplishes anything like this, perhaps because this problem is somewhat niche.
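One possible suggestion, offered as a sketch and not as a replacement for the latch: if the thread pool is dedicated to this one batch and can be thrown away afterward, the standard ExecutorService methods shutdown() and awaitTermination() already wait for every submitted task without any counting. This does not help when the pool is shared or long-lived, which is where something like BufferedLatch still earns its keep. The class name AwaitTerminationDemo, the method sumOfSquares(), and the one-minute timeout are assumptions for illustration.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class AwaitTerminationDemo {

    // Submits one task per record, then waits for the whole pool to drain.
    static int sumOfSquares(Iterator<Integer> records) throws InterruptedException {
        AtomicInteger total = new AtomicInteger();
        ExecutorService executor = Executors.newFixedThreadPool(4);

        while (records.hasNext()) {
            int record = records.next();
            executor.submit(() -> {
                total.addAndGet(record * record);   // stand-in for real processing
            });
        }

        executor.shutdown();   // reject new tasks; already-submitted tasks still run
        if (!executor.awaitTermination(1, TimeUnit.MINUTES)) {
            throw new IllegalStateException("Tasks did not finish within the timeout");
        }
        return total.get();
    }

    public static void main(String[] args) throws InterruptedException {
        int sum = sumOfSquares(Arrays.asList(1, 2, 3, 4, 5).iterator());
        System.out.println("Sum of squares: " + sum);   // prints "Sum of squares: 55"
    }
}
```

The trade-off is that the executor cannot be reused once it has been shut down.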

One disclaimer: like any multithreading decision... first evaluate if it is even worth multithreading the task in question. Test and ensure there will be performance gains over a single-threaded approach. If your database query is quick, it might be worth importing all the data first before doing anything with it. But if you have worked with painfully slow data connections like me, or are issuing an intensive query, you may want to utilize that idle CPU time and use the solution above.
