Sunday, May 22, 2011

Taking advantage of multiple cores with Python and Java

The first round of Google Code Jam 2011 has just ended. I participated in round 1A, waking up at 3AM on a weekend day (to tell you how motivated I was). The first problem was not too difficult, but it still took me a good 45 minutes to solve. Then came the second one, The Killer Word: the one where you have to find the hardest word for your opponent to guess in a game of hangman. I implemented the solution the way it was described, but my first attempt on the small input failed.

I dug into the code again, and found a rule I had not implemented. Alas, my next attempt was also a failure. I had to try harder. Another bug fix later, my third submission was still not right. Time was running low. About twenty minutes before the end my rank was around 900th and falling. If I could not submit another valid solution, I would probably fall behind the qualifying 1000th rank.

After another pass over the code, I found yet another subtlety that I had not taken into account. Clickety-clickety-click (I have a "Das Keyboard" that really clicks, if you want to know). Another submission. Fingers crossed. And... yes! Now it was right. The small problem set was now solved. My rank rose to 500th. I was saved. Time to download a large problem set and run the program over it. Eight minutes should be plenty. The dictionary size was multiplied by 100 and the list of letters by 10, but it should be ok... Except it was not. Not at all, by a long shot.

Sleeping over it

With the knowledge of my qualification and still the problems in my head, I went to bed. When I woke up, I thought that perhaps a multithreaded implementation could have saved my attempt at solving the large problem set. As a first step, to speed up the program, I rewrote it from Python to Java. Straightforward, if a little dull. It ran faster, but still it was impossible for that algorithm to succeed in less than eight minutes on the large set. So: back to thinking.

I implemented a solution another way: instead of computing the score of each dictionary word in turn, I processed all words at once, keeping only the best one at the end. I programmed that new algorithm in Python again, because I really love its "runnable pseudo-code" feeling: no need to write types, literal syntax for everyday structures (tuples, lists, dictionaries). I ended up with a valid solution that managed to solve the large set in about 2'30". Not too bad for Python.
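A minimal sketch of what "processing all words at once" can look like. It assumes these rules: the guesser knows the word length, walks through the letter list in order, skips any letter that no remaining candidate contains, and loses a point for each guessed letter that is not in the word. The function name killer_word and its arguments are hypothetical; this is not the code used in the contest.

```python
from collections import defaultdict

def killer_word(words, letter_order):
    """Return the word that costs the guesser the most points.

    Hypothetical helper: every candidate word is carried through the
    guessing sequence at once, in groups the guesser cannot tell apart.
    """
    # Words of the same length are indistinguishable at the start.
    by_length = defaultdict(list)
    for index, word in enumerate(words):
        by_length[len(word)].append(index)
    # Each group is (points lost so far, indices of its words).
    groups = [(0, members) for members in by_length.values()]

    for letter in letter_order:
        next_groups = []
        for lost, members in groups:
            # Split the group by where the letter appears in each word.
            by_positions = defaultdict(list)
            for index in members:
                positions = tuple(i for i, ch in enumerate(words[index]) if ch == letter)
                by_positions[positions].append(index)
            if len(by_positions) == 1 and () in by_positions:
                # No remaining candidate contains the letter: it is skipped.
                next_groups.append((lost, members))
                continue
            for positions, sub_members in by_positions.items():
                # An empty position tuple means the guess missed: one point lost.
                next_groups.append((lost + (0 if positions else 1), sub_members))
        groups = next_groups

    # Highest loss wins; ties go to the word that comes first in the dictionary.
    lost, neg_index = max((lost, -members[0]) for lost, members in groups)
    return words[-neg_index]

if __name__ == "__main__":
    dictionary = ["banana", "caravan", "pajamas"]
    print(killer_word(dictionary, "abcdefghijklmnopqrstuvwxyz"))
```

Each word is scanned once per letter of the list, so the cost per list is roughly 26 passes over the dictionary instead of one full hangman simulation per word.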

Going multicore

The Python solution was now adequate, but the idea of taking advantage of the multiple cores of my machine was still in my head. So I decided to crawl the web and the excellent reference documentation for Python. My first reflex was to look at the thread and threading modules. But, as you probably know, the standard CPython implementation relies on a Global Interpreter Lock (the GIL) that makes multithreading practically useless for CPU-bound work; it mostly helps with concurrent blocking I/O. Fortunately, the documentation pointed me to the multiprocessing module. The basic principle is that instead of creating threads inside the Python VM, that module lets you create OS-level processes, each one running a distinct Python VM, with communication channels established between them, so that you don't have to leave the comfy Python environment. And of course, it's as platform-independent as possible.
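To make the "separate processes with a communication channel" idea concrete, here is a small sketch using Process and Queue from the multiprocessing module. The names report_square and run_in_processes are made up for illustration; the worker simply reports its own process ID along with a trivial computation.

```python
import os
from multiprocessing import Process, Queue

def report_square(n, results):
    """Hypothetical worker: put (input, worker PID, square) on the queue."""
    results.put((n, os.getpid(), n * n))

def run_in_processes(numbers):
    results = Queue()  # channel from the child processes back to the parent
    workers = [Process(target=report_square, args=(n, results)) for n in numbers]
    for worker in workers:
        worker.start()
    # Read one message per worker before joining, so no child is left
    # blocked while flushing its data to the queue.
    collected = [results.get() for _ in workers]
    for worker in workers:
        worker.join()
    return sorted(collected)

if __name__ == "__main__":
    print("parent process:", os.getpid())
    for n, pid, square in run_in_processes([1, 2, 3]):
        print(f"{n} squared is {square}, computed by process {pid}")
```

The process IDs printed for the workers differ from the parent's, which is the whole point: each one has its own interpreter and its own GIL.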

For my needs, I looked at the Process class, but quickly learned about the Pool class. The API is perfect, almost magic. You create a Pool object (Python guesses the number of CPUs for you). You submit all the function calls you want to run asynchronously. You close it, which means you have no more work for it to do. You join it (you wait for the workers to finish all the calls). Then you only have to get the results and sort them. That's all. In code, it looks like this:

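 from multiprocessing import Pool

 def solve_with_pool(cases):
   pool = Pool()  # defaults to one worker process per CPU
   # Queue one asynchronous call per case; solve_case returns a tuple
   # that starts with the case index.
   async_results = [pool.apply_async(solve_case, args=(case,))
                    for case in cases]
   pool.close()  # no more work will be submitted
   pool.join()   # wait for every worker to finish

   # get() returns a call's result, or re-raises the exception it hit.
   results = [ar.get() for ar in async_results]
   results.sort()
   return results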

It's difficult to get easier than that. solve_case() is the function that, as the name implies, solves a case and returns the result as a tuple whose first element is the case index. That's why the simple call to results.sort() works: it sorts all results by index. In fact, it's probably useless, because the results list is built in the same order as the cases... Anyway, with that in place, the program now runs in 50", which is three times faster than the single-process version, on my quad-core machine. The speedup factor is not four, because the division of work I have made is coarse-grained and there are only ten cases in the large set. So one of the processes probably got to work on two large cases and ended some time after the others had finished. Still, not too bad for ten lines of code.
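The effect of coarse-grained work division can be checked with a little arithmetic. The sketch below simulates idle workers picking up the next case in order, under the assumption (not measured) that all ten cases cost the same 15 seconds each, which adds up to the 2'30" of the single-process run. The makespan function is a hypothetical helper, not part of multiprocessing.

```python
import heapq

def makespan(durations, workers):
    """Wall-clock time when each idle worker takes the next task in order."""
    finish_times = [0.0] * workers  # min-heap of when each worker becomes free
    for duration in durations:
        earliest = heapq.heappop(finish_times)
        heapq.heappush(finish_times, earliest + duration)
    return max(finish_times)

# Assumption: ten equally expensive cases of 15 s each (150 s in total).
durations = [15.0] * 10
total = sum(durations)
elapsed = makespan(durations, workers=4)
print(f"elapsed {elapsed:.0f} s, speedup {total / elapsed:.2f}x")
# prints "elapsed 45 s, speedup 3.33x"
```

Even with perfectly equal cases, ten tasks on four workers cannot do better than 3.33x, because two workers end up with three cases while the other two sit idle after their second. Cases of unequal size push the figure lower still, which is consistent with the 3x observed.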

And Java?

If you read the title of this post carefully, you've seen Java in there. So. I remembered that the Java 1.6 standard library included a part of Doug Lea's fork-join framework, which was previously distributed as a separate package. Except it does not. It's planned to be included in the upcoming JDK 1.7. So in the meantime, I'll have to resort to simpler constructs. But at least, I think java.util.concurrent will give me a higher-level interface than direct use of Thread and synchronization.

The principles are the same as in the Python version, but of course, Java is a bit more verbose, mainly because of explicit typing and also because it lacks first-class functions. So I have to create a FutureTask, built from an anonymous class implementing Callable. Otherwise, the API is at the same level: we create a pool of worker threads, hand each task to the pool for execution, shut the pool down, wait for all computations to end (for one day, because we are forced to provide a timeout value), and then collect all results. The checked exceptions also add a bit to the verbosity: both awaitTermination() and get() can throw exceptions that are of no interest in our case.


  // Requires: import java.util.concurrent.*;
  @SuppressWarnings("unchecked")  // generic array creation below
  public static Result[] solve_with_pool(Case[] cases) {
    // Four threads for a quad-core machine.
    ExecutorService pool = Executors.newFixedThreadPool(4);

    FutureTask<Result>[] fresults = new FutureTask[cases.length];
    for (int i = 0; i < cases.length; ++i) {
      final Case c = cases[i];
      fresults[i] = new FutureTask<Result>(new Callable<Result>() {
        public Result call() {
          return solve_case(c);
        }
      });
      pool.execute(fresults[i]);
    }
    pool.shutdown();  // accept no new tasks; queued ones still run
    try {
      pool.awaitTermination(1, TimeUnit.DAYS);
      Result[] results = new Result[cases.length];
      for (int i = 0; i < cases.length; ++i) {
        results[i] = fresults[i].get();
      }
      return results;
    } catch (InterruptedException e) {
      // Not expected here; fall through and return null.
    } catch (ExecutionException e) {
      // solve_case threw an exception; fall through and return null.
    }
    return null;
  }

The net gain from going multithreaded is less impressive in my case than in Python. The single-threaded program solved the large input set in 25" while the multi-threaded one took 18". I have not investigated why.
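To put the two results side by side, the following sketch computes the speedup and per-core efficiency from the timings quoted in this post (150" to 50" for Python, 25" to 18" for Java). It also inverts Amdahl's law to get the fraction of the run that would have to be parallel to explain each speedup. That last figure is only a rough model: load imbalance between cases would produce the same numbers, and the helper names are made up for illustration.

```python
def speedup(single_seconds, parallel_seconds):
    return single_seconds / parallel_seconds

def parallel_fraction(s, cores):
    """Invert Amdahl's law S = 1 / ((1 - p) + p / n) to estimate p."""
    return (1 - 1 / s) / (1 - 1 / cores)

CORES = 4
for name, single, parallel in [("Python", 150, 50), ("Java", 25, 18)]:
    s = speedup(single, parallel)
    p = parallel_fraction(s, CORES)
    print(f"{name}: speedup {s:.2f}x, efficiency {s / CORES:.0%}, "
          f"implied parallel fraction {p:.0%}")
# prints:
# Python: speedup 3.00x, efficiency 75%, implied parallel fraction 89%
# Java: speedup 1.39x, efficiency 35%, implied parallel fraction 37%
```

With ten cases spread over four threads, a speedup of only 1.39x suggests that a single case or some serial step dominates the 18" run, but that remains a guess until the per-case times are measured.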

Conclusion

Well, turning a single-threaded problem-solving program into a multi-process or multi-threaded version is not too difficult. The available APIs sit at the right level of abstraction, both in Python and Java. The runtime gains are significant, especially in Python. Of note, the Java program, when running full steam on all four cores, tended to make the whole PC sluggish, with the mouse pointer moving less smoothly than it should.

Now I know. And hopefully you do too.
