Thursday, June 27, 2013

The Optimization of Java's HashMap class

Yesterday I was on Quora.com sifting through some Q&A when I ran across someone describing an optimization that was made to Java's HashMap class, around version 1.4 according to the poster.  It amazed me.  I didn't understand how it could work at first, but with a little digging I figured it out, and it's simple yet very clever.  First I'll briefly explain how a hash map works, for laypeople (I think everyone can follow most of this), and then I'll describe the change.


Description of What a Hash Map Is

A hash map is a way for a computer to store a set of things in memory for quick access.  Picture a function that takes a word, let's say "bird", and converts it to a number, such as 7.  As long as you give it the same input, you always get the same output.  So when you want to access a bunch of information labeled "bird", you can find it in storage bin number 7.  You only have to look in one bin, so it's super fast.
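Java's own String.hashCode() is a real example of such a function.  Its Javadoc documents the formula s[0]*31^(n-1) + s[1]*31^(n-2) + ... + s[n-1], computed with int arithmetic.  As a minimal sketch (WordHash and hashOf are made-up names for illustration), here's that formula recomputed by hand and compared with the built-in method:

```java
// Hypothetical demo class: recomputes String.hashCode() by hand to show that
// a hash function is just a deterministic calculation over its input.
class WordHash {
    // Same formula the String.hashCode() Javadoc documents:
    // s[0]*31^(n-1) + s[1]*31^(n-2) + ... + s[n-1], with int overflow.
    static int hashOf(String word) {
        int h = 0;
        for (int i = 0; i < word.length(); i++) {
            h = 31 * h + word.charAt(i); // Horner's rule for the polynomial
        }
        return h;
    }

    public static void main(String[] args) {
        // The same input always produces the same output.
        System.out.println(hashOf("bird"));          // 3024057
        System.out.println("bird".hashCode());       // 3024057
        System.out.println(hashOf("bird") == hashOf("bird")); // true
    }
}
```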

Ideally, your function would produce a unique number for every unique word.  In practice it won't be perfect: if "bird" and "potato" both produce a 7, then looking up either one might mean checking two spots in memory instead of one, which takes longer.  This is called a "collision", and you want a function that avoids collisions as much as possible.
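Collisions aren't just hypothetical.  Taking Java's String.hashCode() as the hash function, the two-letter strings "Aa" and "BB" both come out to 2112, so they'd land in the same spot however many buckets there are.  A small sketch (CollisionDemo and collide are hypothetical names):

```java
// Two different strings with the same String.hashCode() value.
class CollisionDemo {
    // True when the keys are different but their hash codes are equal.
    static boolean collide(String x, String y) {
        return !x.equals(y) && x.hashCode() == y.hashCode();
    }

    public static void main(String[] args) {
        // 'A' = 65, 'a' = 97:  65 * 31 + 97 = 2112
        // 'B' = 66:            66 * 31 + 66 = 2112
        System.out.println("Aa".hashCode());     // 2112
        System.out.println("BB".hashCode());     // 2112
        System.out.println(collide("Aa", "BB")); // true
    }
}
```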

Now, if you had a billion words, it's unrealistic for your computer to set aside a billion separate spots in memory, one for every number the function might produce.  What a hash map does instead is take the number of spots it DOES have (let's say 16), divide the function's output by it, and use the remainder.  This is the "modulo operation", written with the percent (%) sign.  That way, you never try to put something in a memory location the hash map doesn't have.  So if your function said "banana" should go in spot 39, you'd compute 39 % 16 = 7 and put it in spot 7.
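Here's the remainder step as a small Java sketch (ModuloBuckets and bucketFor are hypothetical names).  One assumption worth flagging: real Java hash codes can be negative, and Java's % then returns a negative remainder, so this version uses Math.floorMod (Java 8 and later), which always lands between 0 and numBuckets - 1:

```java
// Picks a bucket by taking the remainder of the hash code.
class ModuloBuckets {
    static int bucketFor(int hashCode, int numBuckets) {
        // Math.floorMod keeps the result in [0, numBuckets) even for a negative
        // hashCode, where plain % would return a negative remainder.
        return Math.floorMod(hashCode, numBuckets);
    }

    public static void main(String[] args) {
        System.out.println(bucketFor(39, 16));  // 7, because 39 = 2 * 16 + 7
        System.out.println(bucketFor(7, 16));   // 7 as well, so 7 and 39 collide
        System.out.println(-39 % 16);           // -7: why floorMod is used above
        System.out.println(bucketFor(-39, 16)); // 9
    }
}
```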

You're certainly going to get plenty of collisions, but a couple of key optimizations help.  First, you want the function to produce numbers that are as evenly distributed as possible, so you avoid the bad scenario where most of the words land on, say, the number 7 and you end up searching through nearly all of them.  People have already worked out good hash functions for common kinds of keys, so you can usually just use those.  Second, when the hash map gets too full, it increases the number of spots available and moves all the old words to their new locations based on the new number of spots.
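To make collisions and resizing concrete, here's a deliberately tiny chained hash map.  It's a sketch for illustration, not how java.util.HashMap is actually written: ToyHashMap is a made-up class, each bucket is just a list of entries, it uses the plain remainder to pick a bucket, and the 0.75 "too full" threshold is an assumption I picked (it happens to match HashMap's default load factor).  It doesn't handle null keys.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical toy map: each bucket is a list, and the table doubles when too full.
class ToyHashMap<K, V> {
    private static final double MAX_LOAD = 0.75; // assumed "too full" threshold

    private static final class Entry<K, V> {
        final K key;
        V value;
        Entry(K key, V value) { this.key = key; this.value = value; }
    }

    private List<List<Entry<K, V>>> buckets = newTable(16);
    private int size = 0;

    private static <K, V> List<List<Entry<K, V>>> newTable(int n) {
        List<List<Entry<K, V>>> table = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            table.add(new ArrayList<>());
        }
        return table;
    }

    // Remainder of the key's hash code, kept non-negative.
    private static int bucketFor(Object key, int numBuckets) {
        return Math.floorMod(key.hashCode(), numBuckets);
    }

    public void put(K key, V value) {
        List<Entry<K, V>> bucket = buckets.get(bucketFor(key, buckets.size()));
        for (Entry<K, V> e : bucket) {
            if (e.key.equals(key)) { e.value = value; return; } // replace existing value
        }
        bucket.add(new Entry<>(key, value));
        size++;
        if (size > MAX_LOAD * buckets.size()) {
            resize();
        }
    }

    public V get(K key) {
        // Only one bucket is searched; collisions just make that list longer.
        for (Entry<K, V> e : buckets.get(bucketFor(key, buckets.size()))) {
            if (e.key.equals(key)) return e.value;
        }
        return null;
    }

    public int bucketCount() { return buckets.size(); }

    // Double the bucket count and re-place every entry for the new count.
    private void resize() {
        List<List<Entry<K, V>>> bigger = newTable(buckets.size() * 2);
        for (List<Entry<K, V>> bucket : buckets) {
            for (Entry<K, V> e : bucket) {
                bigger.get(bucketFor(e.key, bigger.size())).add(e);
            }
        }
        buckets = bigger;
    }

    public static void main(String[] args) {
        ToyHashMap<String, String> map = new ToyHashMap<>();
        for (int i = 0; i < 20; i++) map.put("word" + i, "info" + i);
        System.out.println(map.get("word7"));  // info7
        System.out.println(map.bucketCount()); // 32: grew once, after the 13th key
    }
}
```

Because the bucket index depends on the number of buckets, resize can't just copy the lists over; every entry has to be re-placed, which is why the old words move to new locations.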

Just so you can talk the talk, the spots in memory a hash map has available are called "buckets".  The function that converts words to numbers is called a "hash function".  The numbers are called "hash codes".  The words are called "keys", and the "bunch of information" attached to a key is called a "value".
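In Java's own java.util.HashMap, those terms map directly onto the API: put stores a value under a key, get looks it up, and the key's hashCode() supplies the hash code the map starts from.  A minimal example (the class name and sample values are made up):

```java
import java.util.HashMap;
import java.util.Map;

// The terms above, applied to java.util.HashMap with String keys and values.
class VocabularyDemo {
    static Map<String, String> build() {
        Map<String, String> info = new HashMap<>();
        info.put("bird", "has feathers, lays eggs"); // key "bird" -> value
        info.put("potato", "starchy tuber");         // key "potato" -> value
        return info;
    }

    public static void main(String[] args) {
        Map<String, String> info = build();
        System.out.println(info.get("bird"));  // has feathers, lays eggs
        System.out.println(info.get("tiger")); // null: no such key
        // HashMap starts from the key's own hash code when choosing a bucket.
        System.out.println("bird".hashCode());
    }
}
```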


Java's Hash Map Optimization


The change is in a small method called indexFor, which picks the bucket for a hash code h when there are length buckets: instead of taking a modulo, it returns h & (length - 1).  That trick depends on a couple of other behaviors of Java's hash map, but first I'll review what's going on.  I mentioned how modulo is used to determine which bucket a specific hash code maps to.  The & replaces that modulo with a "bitwise AND".  I'm not going to review too much about binary here, but it's all 1's and 0's instead of the digits 0-9 in the base-10 (decimal) numbers you're used to.  So 1 & 1 gives 1, but if either or both bits are 0, you get 0.  Picture converting the hash code and the number of buckets minus 1 into a bunch of 1 and 0 "bits", then doing this AND on each pair of bits, from right (least significant/smallest) to left.
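Here's a sketch of the two approaches side by side.  IndexFor, indexForModulo, and indexForMask are illustrative names; only the expression h & (length - 1) reflects what HashMap.indexFor returns in the JDK 1.4 through 7 sources.

```java
// Two ways to turn a hash code h into a bucket index for `length` buckets.
class IndexFor {
    // The straightforward version: needs a division, and goes negative for negative h.
    static int indexForModulo(int h, int length) {
        return h % length;
    }

    // The masked version, same expression as HashMap.indexFor in JDK 1.4 through 7.
    // Only correct when length is a power of 2.
    static int indexForMask(int h, int length) {
        return h & (length - 1);
    }

    public static void main(String[] args) {
        System.out.println(indexForModulo(39, 16)); // 7
        System.out.println(indexForMask(39, 16));   // 7  (100111 & 001111 = 000111)
    }
}
```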

If you think about it, you might wonder how this works, because it's not the same thing as modulo.  If you have 5 buckets, length - 1 is 4, which is 100 in binary (google "4 in binary").  That means whatever your hash code is, only its third bit from the right matters, because every other bit gets ANDed with a 0.  indexFor will always output either 0 or 4, so buckets 1, 2, and 3 never get used.  That's a crazy amount of collisions.
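A quick way to see the 5-bucket problem is to run a range of hash codes through h & (length - 1) and collect the bucket indices that actually come out.  This is a hypothetical check (FiveBuckets and indicesUsed are made-up names), not code from the JDK:

```java
import java.util.TreeSet;

// What goes wrong with h & (length - 1) when length is NOT a power of 2.
class FiveBuckets {
    // Collects every bucket index produced for hash codes 0..maxHash.
    static TreeSet<Integer> indicesUsed(int length, int maxHash) {
        TreeSet<Integer> used = new TreeSet<>();
        for (int h = 0; h <= maxHash; h++) {
            used.add(h & (length - 1)); // for length 5, length - 1 is 4 (binary 100)
        }
        return used;
    }

    public static void main(String[] args) {
        System.out.println(indicesUsed(5, 1000)); // [0, 4]: buckets 1, 2, 3 never used
        System.out.println(indicesUsed(8, 1000)); // [0, 1, 2, 3, 4, 5, 6, 7]
    }
}
```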

First Trick


There will never be 5 buckets.  Java's hash map rounds its starting capacity up to a power of 2, and when it expands, it multiplies the number of buckets by 2, so you'll always have a power of 2 (1, 2, 4, 8, 16, etc.).  When you convert a power of 2 to binary, exactly one bit is a 1.  When you subtract 1 from a power of 2 and convert that to binary, that bit becomes a 0 and all the bits to the right of it become 1's: 16 is 10000 and 15 is 01111.  AND any number with a run of 1's like that and you get the same result as taking that number modulo the value of those 1's plus 1; for example, 39 & 15 = 7, just like 39 % 16 = 7.  ANDing bits like this is much faster than a modulo, which requires a division.  As a bonus, Java's % gives a negative result for a negative hash code, while the AND always gives a valid, non-negative bucket index.
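The sketch below checks both halves of the trick: that h & (length - 1) matches a true modulo when length is a power of 2 (including negative h, compared against Math.floorMod), and one simple way to round a requested size up to a power of 2.  PowerOfTwo and its methods are hypothetical names; the rounding loop is similar in spirit to what HashMap's constructor does, not a copy of it.

```java
// Checks the "First Trick" for power-of-two bucket counts.
class PowerOfTwo {
    // Hypothetical helper: smallest power of 2 >= requested.
    // Capped at 2^30 so the shift can't overflow into an infinite loop.
    static int roundUpToPowerOfTwo(int requested) {
        if (requested > (1 << 30)) {
            throw new IllegalArgumentException("too large for this sketch: " + requested);
        }
        int capacity = 1;
        while (capacity < requested) {
            capacity <<= 1; // 1, 2, 4, 8, 16, ...
        }
        return capacity;
    }

    // For a power-of-two length, masking with length - 1 should match the
    // mathematical (always non-negative) modulo, even for negative h.
    static boolean maskMatchesModulo(int h, int length) {
        return (h & (length - 1)) == Math.floorMod(h, length);
    }

    public static void main(String[] args) {
        System.out.println(roundUpToPowerOfTwo(5));         // 8
        System.out.println(Integer.toBinaryString(16));     // 10000
        System.out.println(Integer.toBinaryString(16 - 1)); // 1111
        System.out.println(maskMatchesModulo(39, 16));      // true
        System.out.println(maskMatchesModulo(-39, 16));     // true
    }
}
```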

Second Trick


There's also a concern that might not be obvious: if you rely on only the smallest bits of your hash code, you can easily get an uneven distribution of keys across your buckets unless you have a really good hash function.  What Java's hash map implementation does is "rehash" the hash code, mixing its bits with a few shifts and exclusive ORs before picking a bucket.
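For reference, here is the supplemental hash that java.util.HashMap applied in JDK 6 and 7, wrapped in a hypothetical SupplementalHash class alongside the masking indexFor.  Treat the exact shift amounts as version-specific: other JDK releases mix the bits differently (JDK 8, for example, reduced it to h ^ (h >>> 16)).  The main method compares two hash codes that differ only in their high bits:

```java
// The JDK 6/7 HashMap supplemental hash, plus the masking indexFor.
class SupplementalHash {
    static int hash(int h) {
        // Fold higher bits down so they influence the low bits used for indexing.
        h ^= (h >>> 20) ^ (h >>> 12);
        return h ^ (h >>> 7) ^ (h >>> 4);
    }

    static int indexFor(int h, int length) {
        return h & (length - 1); // length must be a power of 2
    }

    public static void main(String[] args) {
        int a = 0x10000; // bit 16 set
        int b = 0x20000; // bit 17 set: differs from a only in high bits
        // Without the rehash, both land in bucket 0 of a 16-bucket table.
        System.out.println(indexFor(a, 16) + " " + indexFor(b, 16));             // 0 0
        // With the rehash, the high bits reach the low bits and the keys separate.
        System.out.println(indexFor(hash(a), 16) + " " + indexFor(hash(b), 16)); // 1 2
    }
}
```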


The rehash function looks scary, but all it does is take your mostly unique hash code and scramble it (deterministically, so the same input still gives the same output) in a way that spreads keys relatively evenly across the "lower bits".  For curiosity's sake, I'll mention that >>> shifts the bits in your hash code to the right, filling in 0's on the left: if you had a 4, or 100 in binary, and you did 4 >>> 2, you'd end up with 001, because it's been right-shifted twice.  The ^ is an "exclusive OR" (XOR) operation, which is similar to AND but outputs a 1 only when the two bits are different (one is a 1 and the other is a 0).  Essentially, the rehash ensures that the more significant bits in your hash code affect the least significant bits that are ultimately used to choose each bucket.
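A few one-liners make the two operators concrete.  The xorFold helper is a hypothetical name for a single "shift and XOR" step, the basic move the rehash repeats, and it shows a high bit leaking into the low bits:

```java
// Small checks of the two operators the rehash is built from.
class BitOps {
    // One "shift and XOR" step: XORs each bit with the bit `shift` places
    // above it, so high bits start influencing low ones.
    static int xorFold(int h, int shift) {
        return h ^ (h >>> shift);
    }

    public static void main(String[] args) {
        System.out.println(4 >>> 2);        // 1  (100 shifted right twice is 001)
        System.out.println(0b101 ^ 0b011);  // 6  (101 XOR 011 = 110)
        // >>> is the unsigned shift: it always fills with 0's from the left.
        // >> copies the sign bit instead, so the two differ for negative numbers.
        System.out.println(-8 >>> 28);      // 15
        System.out.println(-8 >> 28);       // -1
        // 0x100 has no bits set in its low 4 bits; after folding by 8 it does.
        System.out.println(xorFold(0x100, 8) & 0xF); // 1
    }
}
```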


Hope you found this all as righteous as I did!
