Multi-armed bandits & Clojure

An (interactive) exploration of a classic reinforcement learning problem through the lens of Clojure.

Introduction

Recently, I've been learning the programming language Clojure. Clojure is a functional Lisp dialect hosted on the JVM. Clojure's creator, Rich Hickey, has given some really interesting conference talks, and like many others, I went down the rabbit hole. His talks resonate with me because he has a way of explaining things about building software that I have always felt but never been able to put my finger on. Here's an example of what I mean:

So yeah, I have drunk the Clojure Kool-aid, and it tastes pretty good!

In this article, I'm going to walk you through solving an example problem using Clojure. Hopefully, it can be a stepping stone after you get through the basic "hello world" type examples. I'm going to assume very basic familiarity with the language, so if you've never seen Clojure before at all, you may want to check out the "Learn Clojure" section on clojure.org before reading.

All the code samples in here are interactive (enabled by clj-browser-eval and Scittle). Please experiment with them and have some fun.

Let's get started.

Multi-armed Bandit Problem

The problem we will look at is the multi-armed bandit problem, a classic problem in reinforcement learning. The problem is this:

Imagine you are in a casino and there are 3 slot machines in front of you. Each pull of a slot machine is free, and as a result of each pull, you can either win or lose some money. A result of -1 means you lost 1 dollar and a result of 0.5 means you won 50 cents.

Now I (a veteran of this particular casino) tell you that these 3 machines are built differently and some are better than others. Inconveniently, I don't remember which one is the best.

For example, machine 0 may give you a result of 0.5 on average while machine 1 may give you an average result of -1.5, and machine 2 an average result of 0.75.

Machine | Average Reward
------- | --------------
0       | 0.5
1       | -1.5
2       | 0.75

You have 100 pulls and your goal is to end up with as much money as possible. What's your strategy?

Ideally, in the example above, you would like to choose machine 2 for all 100 pulls because it gives the highest reward on average, but the tricky part is that you don't know this information ahead of time. You can only learn which machine is best by trying each one out. And remember, these machines are random, so a single pull of each one won't necessarily tell you which machine is best. For example, your first few results might look like this:

Pull Number | Machine | Reward
----------- | ------- | ------
0           | 0       | 0.55
1           | 1       | -1.4
2           | 2       | 0.5

If, based on these results, you naively concluded that machine 0 is best and decided to choose only machine 0 for the remaining 97 pulls, you would miss out on the higher average rewards from machine 2. On the other hand, if you pull each of the three machines 33 times, you will have a really strong inkling of which machine is best (machine 2), but only 1 pull left to actually use this knowledge!

Exploration vs Exploitation

The fundamental challenge illustrated by this problem is the balancing of exploration (discovering new knowledge) and exploitation (putting your knowledge to use). A good strategy for this problem must balance these in order to earn an optimal reward.

For example, you could decide that half the time, you will explore (choose a machine randomly) and half the time, you will exploit (choose the machine that you estimate to be the best based on what you've seen so far). Or maybe you should explore 25% of the time and exploit 75% of the time. You could also use a different strategy entirely.

Let's build a simulation and test one of these strategies out.

Exploring the problem with Clojure

What does a bandit look like?

In our program, we're going to talk about a bandit with N arms: an n-armed bandit. Using this language, the example above would be a 3-armed bandit. For our program, let's consider a 10-armed bandit. We might create a bandit like so:
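
Here's a minimal sketch of what that might look like. The original interactive snippet is described in the paragraph below; the rand-normal helper (a Box-Muller implementation) is just one possible way to sample from a normal distribution.

```clojure
;; Hypothetical helper: sample a value from a standard normal distribution
;; using the Box-Muller transform.
(defn rand-normal []
  (* (Math/sqrt (* -2 (Math/log (- 1.0 (rand)))))
     (Math/cos (* 2 Math/PI (rand)))))

;; Sample from a normal distribution, optionally shifted by a mean.
(defn sample-normal
  ([] (rand-normal))
  ([mean] (+ mean (rand-normal))))

;; A bandit is just a vector of per-arm average rewards.
(defn n-armed-bandit [n]
  (vec (take n (repeatedly sample-normal))))

(n-armed-bandit 10)
```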


      

(Click the "Evaluate" button above to see what this code does.)

So here we have 3 functions. The first 2 are small utility functions that just help us achieve what we are really trying to implement, the n-armed-bandit function, so I'll ignore those for now. The n-armed-bandit function is what actually creates what we will consider our "bandit". Our bandit can be entirely described by a list of values, each one representing the "average reward" for an arm, so that is what our function returns. Our function takes in the number of arms it needs to generate "average reward" values for. Then, it takes that many samples from the infinite list that results from repeatedly applying our sample-normal function. We convert the resulting list to a vector with vec. Vectors are like arrays in other programming languages: they are useful when we want to quickly access an element based on its index (which we will need to do very shortly).

The main operation we are concerned with when it comes to bandits is "pull arm". This should simulate pulling a single arm of the bandit and give us back a reward. We can implement this like so:
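
A sketch of how this might look, reusing sample-normal from the previous snippet (the name pull-arm matches the operation described here):

```clojure
;; Pull a single arm: sample a reward centered on that arm's average reward.
(defn pull-arm [bandit arm]
  (sample-normal (nth bandit arm)))

(def bandit (n-armed-bandit 10))

(pull-arm bandit 0)
```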


      

Here, we retrieve the average reward value for the specified arm and return a sample from a normal distribution using that arm's average reward value as the mean. This is where it comes in handy that we can fetch the "average reward" value for an arm of the bandit using its index.

Press "Evaluate" a few times to pull arm 0 and see what you get!

Now, if we wrote this correctly, we should be able to see that, after pulling a certain arm many times, the rewards average out to the arm's "average reward" value. Let's verify that is true:
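
A quick check along these lines should do it (10,000 pulls is an arbitrary choice here):

```clojure
;; Pull arm 0 many times and compare the observed average reward
;; with the arm's actual average reward.
(let [bandit   (n-armed-bandit 10)
      rewards  (repeatedly 10000 #(pull-arm bandit 0))
      observed (/ (reduce + rewards) (count rewards))]
  {:actual-average   (nth bandit 0)
   :observed-average observed})
```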


      

You should be able to see that the actual and observed values for arm 0 are pretty close to one another.

Our decision making "agent"

The decision maker in a reinforcement learning problem is usually called an "agent". In this section, we'll start thinking about how our agents will make decisions.

As previously discussed, one possible strategy is to choose randomly some percentage (which we'll call epsilon ε) of the time. The rest of the time, we could act "greedily", or choose the arm that we currently think is best. This is called an epsilon-greedy method.

To perform the "greedy" action, the agent needs to have some estimate of the value of each arm so it can pick the one whose esimate is the highest. The obvious way to accomplish this is to just remember every single observation for each arm, then we can calculate the estimated value for each arm by averaging all of the observations. This would involve repeatedly calculating the mean over all observations for each arm every time we wanted to choose the "best" arm. A more efficient way is to calculate the mean of the observations incrementally. To do this, the agent only needs to know how many times it has pulled an arm and what the current average observation is for that arm.

With these considerations in mind, we may end up with the following representation for our epsilon greedy agent:
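
Here's a sketch of that representation as a plain map (epsilon-greedy-agent is just a placeholder constructor name; the keys match the ones used later in the post):

```clojure
;; An agent is a map of its exploration rate, per-arm value estimates,
;; and per-arm pull counts.
(defn epsilon-greedy-agent [epsilon n-arms]
  {:epsilon         epsilon
   :value-estimates (vec (repeat n-arms 0.0))
   :pulls-per-arm   (vec (repeat n-arms 0))})

(epsilon-greedy-agent 0.5 10)
```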


      

The two main operations relating to an agent are:

  1. Choose which arm to pull next
  2. Update the agent's value estimates with the observed result

The first is straightforward:
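
Here's a sketch of argmax and choose-arm along those lines (the example value estimates at the end are made up, with arm 0 deliberately set as the highest):

```clojure
;; argmax returns the index of the largest element in a vector.
(defn argmax [v]
  (first (apply max-key second (map-indexed vector v))))

;; Explore with probability epsilon, otherwise pick the best-looking arm.
(defn choose-arm [{:keys [epsilon value-estimates]}]
  (if (< (rand) epsilon)
    (rand-int (count value-estimates))
    (argmax value-estimates)))

(choose-arm {:epsilon 0.5
             :value-estimates [0.75 0.1 -0.3 0.2 0.0 -1.2 0.4 0.3 -0.5 0.6]})
```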

      
Let's break down what's happening here. argmax is a function we define that returns the index of the largest element in the input vector. In choose-arm, we first generate a random number between 0 and 1. If that number is below the agent's epsilon, we choose an arm randomly; otherwise, we choose the arm that currently has the highest estimated value. By doing this, we take the exploratory action with probability epsilon.

In the example in the code above, you can see that epsilon is 0.5, which means half the time, we choose randomly and half the time, we choose the arm whose estimate is currently the highest (which is 0 in this example).

Side note: I have really grown to appreciate the value of dynamic typing when writing little sample scripts like this and even when writing tests in my day job. As you can see above, when I call choose-arm {:epsilon 0.5 :value-estimates [0.75 ...]} I don't bother passing in the :pulls-per-arm [0 0 ...] field that would typically go along with an agent because this function doesn't care about that so why pass it in? I find static typing makes things like this a lot more tedious.

Now, for updating the agent's value estimates. Clojure discourages state mutation and prefers immutable data, but our agent needs to be able to make observations about the bandits and learn from them. To accomplish this without state mutation, our "update" function will actually just generate a whole new agent, like so:
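
Here's a sketch of such an update function (update-agent is a placeholder name), using the incremental-average formula from earlier:

```clojure
;; Return a brand new agent that has folded one more observation for `arm`
;; into its running average and pull count.
(defn update-agent [agent arm reward]
  (let [n (get-in agent [:pulls-per-arm arm])]
    (-> agent
        (update-in [:value-estimates arm]
                   (fn [estimate]
                     ;; incremental average: old + (reward - old) / (n + 1)
                     (+ estimate (/ (- reward estimate) (inc n)))))
        (update-in [:pulls-per-arm arm] inc))))

;; The old agent is untouched; update-agent returns a new one.
(let [agent (epsilon-greedy-agent 0.5 3)]
  {:old agent
   :new (-> agent
            (update-agent 0 0.55)
            (update-agent 2 0.5))})
```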


      

As I discussed above, the agent tracks the average of the rewards seen for each arm in an incremental way.

Once we calculate the new average, we use update-in to update the value estimate for the arm we pulled, and update-in again to increment the pull count for that arm. And remember, in Clojure the core data structures are immutable, so update-in actually returns a new object without affecting the old one.

Side note: For those of you who are new to Clojure, -> is the thread-first macro, which essentially provides a way to take some value and pass it through a pipeline of functions.

A note on immutability: If you evaluate the code above, you can see that after adding a few observations, we have access to the new updated agent, but our old agent is still perfectly intact. Many proponents of Clojure and functional programming argue that this immutability makes programs a lot simpler. In the paper "Out of the Tar Pit," the authors attempt to identify the primary sources of complexity in software and they identify "state" as one of them. "State" is the reason why "turning it off and on again" is still the most likely resolution to your IT problem in 2022 (because you are resetting the system's state). And as you can see from the example above, Clojure encourages us and enables us to avoid state in many circumstances.

Tying it all together

Now that we have our bandit and we have our agent, we can write the primary function that ties them both together:
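
A sketch of one such interaction step (step is a placeholder name):

```clojure
;; One round: the agent picks an arm, we pull it, and the agent learns
;; from the observed reward.
(defn step [agent bandit]
  (let [arm    (choose-arm agent)
        reward (pull-arm bandit arm)]
    {:agent  (update-agent agent arm reward)
     :reward reward}))
```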


      

Basically, we ask our agent which arm it wants to choose, we pull that arm on the bandit, then we "update" the agent with the new observation. Let's see that in action:
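
Here's a sketch of a full run against the 3-armed bandit from the introduction, using 1000 pulls:

```clojure
;; Run n-pulls rounds and report the average reward per pull.
(defn run-simulation [agent bandit n-pulls]
  (loop [agent agent, rewards []]
    (if (= (count rewards) n-pulls)
      {:average-reward  (/ (reduce + rewards) n-pulls)
       :value-estimates (:value-estimates agent)}
      (let [{:keys [agent reward]} (step agent bandit)]
        (recur agent (conj rewards reward))))))

;; The 3-armed bandit from the introduction, 1000 pulls, epsilon = 0.5.
(run-simulation (epsilon-greedy-agent 0.5 3) [0.5 -1.5 0.75] 1000)
```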


      

Here we can see how our epsilon greedy agent actually performs against the bandit that I initially described (I'm using 1000 pulls though instead of 100 so the results have less variance). As we said before, the optimal strategy would be to pick arm 2 every time. If we somehow knew to do that, we would expect an average reward per pull of 0.75 so we can consider that a benchmark to compare against.

If you evaluate the code above a few times, I expect you'll see that epsilon=0.5 is not ideal - a bit too much exploring. The agent needs to be more greedy. Try it out for yourself and see if you can find an epsilon value that works better!

Keep in mind, this is only one of many possible strategies for this problem. I would keep going but this post is getting too long.

Conclusion

I hope the example problem here was a nice showcase of how to use Clojure to solve a "real" problem - "real" in the sense that it's at least one step above the sort of trivial "generate the Fibonacci numbers" examples.

Clojure is a really fun language and if nothing else, it forces you to think in a different way. It's kinda like lifting weights, but for your brain.

To be a little more specific though, I think Clojure is valuable because it makes it easy and practical to program real applications with immutable data. I think the concept of immutability is starting to creep in everywhere in our industry, and people are starting to understand its value. Some examples are:

I think programming primarily with immutable data in our application code is still pretty niche, but it's only a matter of time before more people start to understand its value there as well.

So yeah, Clojure is pretty cool.

And reinforcement learning is cool.

And if you made it this far, thanks for reading!
