This post is a companion to our paper on bioRxiv. My hope is to add a little more detail around the idea of deep learning in population genetics.
Where did the idea of using neural networks to do pop gen come from?
The neural nets we built are very similar to
those commonly used on MNIST handwriting data, which is like basic neural
nets 101.
Machine learning, deep learning, deep neural nets, and artificial intelligence. What is all this stuff? And where does statistics fit in?
It’s confusing. Here’s my impression of how to structure some of the terms:
ML=Machine Learning; DL=Deep Learning; DNN=Deep Neural Nets
ML is very general, and encompasses many things, even
including standard stats methods like regression (depending on who you
ask). DL is a catch-all term that refers
to a large class of (mostly) neural network methods without calling out any
specific one. And a DNN is just any neural
network that has more than one internal layer (i.e. “deep” layers). There are many sub-types of DNNs too. For example, we made heavy use of
Convolutional DNNs.
And what about artificial intelligence, or AI? That term is
fuzzy and really used more for marketing. All this stuff is probably AI, or
none of it is, or who knows. You should sprinkle AI into the conversation when you are trying to impress people.
Finally, where does stats fit in? That’s also debated. I don’t think of neural nets as
a statistical method. But others may disagree. They certainly incorporate ideas from statistics, and if you know some stats that will probably make it easier to learn neural nets.
I’ve got this project I want to do, and it’s a really hard problem. Also there is no good way to get good training data, and even if there were, there’d be no signal in it. Do you think the neural nets can solve this problem?
Nope. There is no magic.
I’ve got another problem, and it can be easily solved by regression, but I was wondering if I should use a neural net instead?
Nope. Whenever possible stick to the explainable stuff. But bonus points if you do it both ways and learn something.
I've got an NP-hard problem (like phylogenetics), can neural nets do that?
Generally speaking neural nets are inappropriate for optimization problems. Sometimes you can recast optimization as classification and use a neural network, but this only works well if there is a pretty small number of possible classes (and usually in those cases you can just do an exhaustive search).
Do I have to be a math expert to learn neural nets?
No! Ever run a BLAST search? Do you really know what’s going on under the hood? Me neither, but we can both use them to get our work done. Same thing with neural nets.
Also, the math is not so bad, and I don’t think you need to
master it to be able to conceptualize what’s going on.
Right now is a magical time to learn this stuff. You can
learn about deep learning from the people who pioneered it. Many of them teach
online courses or have really great books. Most of their teaching is free.
Don’t miss the opportunity to learn from these people. Below are just a few
links.
https://goo.gl/ryQ1kD (really great visuals)
https://www.youtube.com/watch?v=vOppzHpvTiQ (love this guy)
http://neuralnetworksanddeeplearning.com/
What kind of pop gen theory are these deep neural nets based on?
None! Neural nets are completely naive to any relevant pop gen theory. They haven't even heard of Hardy & Weinberg.
Neural nets are just algorithms that learn to do
pattern recognition by computational brute force. And they work freaking well! And
on a whole bunch of different problems. Kinda neat huh!
Learn one technique and solve many problems.
The good news is that when we point the neural nets at data they
aren’t limited by our current state-of-the-art in pop gen theory. They just learn patterns and invent their own
theory (so to speak) for solving the problem at hand. So they may be doing things that we never
imagined! And I said there was good news, so that probably means there is bad
news too. Namely the innards of a neural
net are nearly impossible to interpret. So they are really good for detection
(i.e. this is a selective sweep and that’s not), but they are really bad for explanation
(i.e. I know this is a sweep because of X).
Oh, I see, so you jerks are putting theorists out of business!
Ah, you got us!
Actually, no. See above. Neural
nets are great for detection, but terrible for explanation. What they learn is unfortunately a big black box to us.
We imagine that a hybrid approach might be the best way to use
deep neural nets in pop gen. For
example, first test a neural net and other methods on simulated data and figure
out what works best for your question (and that might be some ensemble of
several methods). If the neural net
looks good, you should use it. It seems silly to do otherwise. Then to tell
your story use traditional pop gen techniques to try to explain what the neural
net found. Some things might not make
sense. Like the net found a selected
site, but no other method detects it. On
the one hand that presents a challenge, but on the other it’s probably worth
investigating because it could be something we’ve been missing and could lead
to new theory.
A second thing you should know. Big advances are being made in explaining what neural nets are learning.
And who is going to take the knowledge out of the neural net and
translate it to the rest of us?
Theoreticians.
You keep calling it a "black box", what the hell are you talking about?
Good question. Black box is kind of weaselly. Also, black box makes it sound like we can't look inside. That's not the case, we can look in, but I hope to convince you that what we see is just really hard to interpret. Let's build a neural net, crack it open, and see what's going on.
I made a neural net that predicts Tajima's D. The nice thing is we know the theory and math behind Tajima's D, so if we train a net to predict it and look at the inside, maybe we can find the components of theory and math in there somewhere?? That's the dream anyway.
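For reference, the theory and math behind Tajima's D fits in one little function. Here's a quick numpy sketch of the standard Tajima (1989) formula (my own illustration for this post, not anything the net computes):

import numpy as np

def tajimas_d(geno):
    # geno: (chromosomes x segregating sites) 0/1 matrix
    n, S = geno.shape
    counts = geno.sum(axis=0)                                  # derived allele count per site
    pi = np.sum(counts * (n - counts)) / (n * (n - 1) / 2.0)   # mean pairwise differences
    a1 = np.sum(1.0 / np.arange(1, n))                         # Watterson's theta = S / a1
    a2 = np.sum(1.0 / np.arange(1, n) ** 2)
    b1 = (n + 1) / (3.0 * (n - 1))
    b2 = 2.0 * (n ** 2 + n + 3) / (9.0 * n * (n - 1))
    c1 = b1 - 1.0 / a1
    c2 = b2 - (n + 2) / (a1 * n) + a2 / a1 ** 2
    e1, e2 = c1 / a1, c2 / (a1 ** 2 + a2)
    return (pi - S / a1) / np.sqrt(e1 * S + e2 * S * (S - 1))  # (pi - theta_W) / sd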
Below is the structure of the net. For each layer, the shape of its output matrix is given.
InputLayer (40, 20)
Conv1D (39, 128)
Conv1D (38, 128)
AveragePooling1D (19, 128)
Dropout (19, 128)
Flatten (2432)
Dense (128)
Dropout (128)
OutputLayer (1)
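If it helps to see that in code, here's a rough Keras sketch of the same stack (the kernel and pool sizes are inferred from the output shapes above; the dropout rates, activations, and optimizer are my guesses, not necessarily the exact settings we used):

from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Input(shape=(40, 20)),                           # 40 segregating sites x 20 chromosomes
    layers.Conv1D(128, kernel_size=2, activation="relu"),   # -> (39, 128)
    layers.Conv1D(128, kernel_size=2, activation="relu"),   # -> (38, 128)
    layers.AveragePooling1D(pool_size=2),                   # -> (19, 128)
    layers.Dropout(0.25),
    layers.Flatten(),                                       # -> 2432
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.25),
    layers.Dense(1),                                        # Tajima's D as a continuous value
])
model.compile(optimizer="adam", loss="mean_squared_error")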
And you feed this thing images (20 chromosomes and 40 segregating sites) like the one below:
That's the "input layer" in the structure above. Then it goes through two 1D convolutional layers (among other things noted above), and a fully connected dense layer (each with 128 neurons), and finally a single neuron output layer which just predicts Tajima's D as a continuous value. Below is the quality of this network on unseen test data after training.
It's pretty darn good. So it really gets how to predict Tajima's D. So now let's look at how it does this, using the input sequence alignment given above as an example. Here's the output of the first convolutional layer:
White pixels indicate a strong signal and black means no signal. Now here's the next convolutional layer:
And finally below is the fully connected "dense" layer: (This thing is flattened out, so it's really just a vector. But I gave it a little height so you could see it.)
And that thing feeds to one linear neuron and it sums it up and outputs 1.53, which is really close to the true value of Tajima's D (1.56) for the sequence alignment above.
Like I said, we know what goes into Tajima's D. What you'd like to find in these images is something corresponding to Watterson's Theta and Pi, and also something about their variances. Or maybe some bits that appear to count the frequency spectrum or some kind of intermediate stuff that you might use to get to Watterson's Theta and Pi. But it's just not there. At least not in any way we know how to extract (and perhaps it is really just not there). This is what the "black box" term refers to. What's in there is just a bunch of signals (off, on a little, on a lot, etc), and they give rise to other signals, and finally the signals are summed up and the sum (almost magically) comes out really close to the correct value.
What's worse, there are a vast array of internal configurations of this network that can achieve high accuracy. So if we retrained this thing from a new random starting position, we would likely see a completely new pattern of signals flowing across the network for this same input. And yet, the thing would probably again guess a value very near to 1.56. It's a little like our brains. If I ask you to imagine ice cream, you can conjure up some mental image. But if I compare how your brain makes that image to how my brain does it, we surely have different patterns of neurons firing. So just like our brains, the neural net is not learning some canonical way to get Tajima's D out of a sequence alignment. Instead it's just finding a local optimum, and there are many many equivalent local optima out there. So there is nothing really special about the patterns above, other than the fact that they represent one of the many ways of getting a decent prediction of Tajima's D out of a neural net.
Now imagine we instead built a neural network that could predict something we don't know how to compute (like the signature of a soft sweep). We might get a network that does really well, and it'd be a dream to back out what the network learned into new theory. In practice, what this would entail is making sense and structure out of stuff like the patterns above. And perhaps worse, I'm not sure there is any certainty that there would be new theory in there. Like maybe you could toil at this for years, only to later learn that there was never anything there to find. I hate ending on a downer like that, but anyway, you hopefully understand why people call these things a black box.
Ok, I read the stuff above, and I still want to develop a neural net for my research. What does the process look like?
I’m so glad you are interested! Let me help you get started. This is a bit of a long answer. It’s a multistep thing.
First, maybe you have a bunch of genomes from your favorite
species where the exact genealogical history of every SNP is known. And you know every evolutionary force that has ever acted on each SNP too. Then all
you have to do is…
What’s that? You don’t have that kind of data? Or even a
kind of halfway approximate version of it. OK, well, first let me welcome you
to the club. In this club we do coalescent simulations to make up fake data sets that are sorta like those of the real critters we are studying. Then we feed these fake data sets into our
neural net to train it and hope that we’re close enough to reality that when we
point the thing at the real data it works.
So step one, you have to simulate training data for your net. And like I said, what that usually means is
building a customized coalescent simulation.
So you have to capture all the important complexity of the evolutionary
history of your sample of genomes, and then simulate conditional on that. You
also have to simulate the thing you want to detect, like a set of sims with the
selective sweep and a set without.
Probably 90% of all your mistakes are going to happen during this stage
of the process. But since everything is
super dependent on the exact nature of your sample and the critters you study,
all your mistakes will be new and different from all of ours, so I can’t give
you a lot of tips other than to talk to coalescent simulation experts. So
that’s my advice. Make mistakes, talk to people, fix those mistakes and repeat
until all the remaining mistakes are so deep that the average reviewer probably
will never notice.
I’ll assume you accomplished step one, and you have a lot of
coalescent sims in hand and they are all really really perfect for your
question. And hopefully you still have some of your motivation intact too.
Next you have to get your coalescent sims into python as
matrices (oh yeah, if you don’t already know python, go learn that too). We
have some code that you can use as an example. We used ‘ms’ formatted files to store our
coalescent sims. They’re unpleasant to
parse. I’d recommend using msprime, which gives you the output of the sim in a
nice python format off the bat.
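Something like this, for example (a minimal sketch assuming the msprime 1.x API; every parameter value here is a placeholder you'd tune for your own critter):

import msprime

# 10 diploid samples = 20 chromosomes, like the Tajima's D example above
ts = msprime.sim_ancestry(samples=10, ploidy=2, sequence_length=20_000,
                          recombination_rate=1e-8, population_size=10_000,
                          random_seed=42)
mts = msprime.sim_mutations(ts, rate=1e-8, random_seed=42)
geno = mts.genotype_matrix()        # numpy array: (segregating sites, chromosomes)
print(geno.shape)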
Now we are nearly ready to train our network. But first we have to build our compute
environment. I highly recommend using
Keras and Tensorflow. They are two great python packages for training neural
nets. Tensorflow is the math guts of
neural nets, and Keras is a package that sits on top and hides a lot of the
math from you. Tensorflow is a bear to install.
I recommend using a cloud service like FloydHub
where it’s already pre-installed (I’m not affiliated with FloydHub). Or
bribe someone to install it for you, but probably just use a cloud service.
They’re so cheap and your time is valuable.
Neural networks are computationally intensive, so make sure
you use a machine with a GPU, because it will make things way way faster. What’s a GPU you say? The truth is I don’t really know. The
Wikipedia page says things like “stream processing” and “compute kernels”,
which are all words that mean something I guess. Here’s what I do know. GPUs do matrix math really fast, and training a neural net involves tons of matrix math. So they make the whole process way faster.
OK, so you got Keras and Tensorflow up and running on a GPU either by bribing someone or using a cloud service. Now you are ready to train!!
Training is my favorite part, because it involves
baby-sitting a computer while listening to music and drinking beer. So first get some tunes going. I recommend something electronic because
that’s the mother-tongue of neural nets.
The goal is to comfort your net with songs in its language so it can
learn faster. The beer is for you, so that's your call.
OK, the first step in training (after taking a good pull of your
beer) is to split the coalescent sims into three groups: the training group, the validation group, and
the testing group. We’re going to train
the neural net with the training group, make parameter changes with the validation group, and then test its accuracy with the test group
(pretty creative names huh!). So the
training group should be big. We used
like 100K-200K individual coalescent sims usually, but even more would probably
have been better. And the validation and testing data
can be smaller, like 5K-20K each. But the
trick with the validation and testing data is that it needs to be big enough to have all the
important features of the training data (and vice-versa). So let's say we are trying to detect selective
sweeps. It’d be a huge mistake to put
all the hard sweeps in the training data and all the soft sweeps in the validation and testing
data. Instead, we need to randomly mix
them so both sets have similar distributions of hard and soft sweeps (and even
similar intensities of hard and soft sweepy-ness).
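In code the split is nothing fancy. Something like this (a sketch; X and y are stand-ins for your stacked sim matrices and their labels):

import numpy as np

# random stand-ins just so the snippet runs; use your real sims here
X = np.random.randint(0, 2, size=(200_000, 40, 20), dtype=np.int8)
y = np.random.randn(200_000)

# shuffle first so hard and soft sweeps (and everything else) end up in every split
idx = np.random.permutation(len(X))
X, y = X[idx], y[idx]

n_val, n_test = 10_000, 10_000
X_test,  y_test  = X[:n_test],               y[:n_test]
X_val,   y_val   = X[n_test:n_test+n_val],   y[n_test:n_test+n_val]
X_train, y_train = X[n_test+n_val:],         y[n_test+n_val:]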
Once we’ve split the data, we pick a neural net architecture and we start training. Architectures that work pretty well are here, and if you read the code you’ll see you just specify the
architecture layer by layer with Keras.
Network architecture is a dense topic that I don’t completely
understand, but experts seem to agree that nobody else understands it either.
Each pass over our training samples is called an epoch. Generally speaking one epoch
takes about a quarter of a beer. So you fire off one epoch, and sit back, enjoy
the beer, and groove to tunes. Keras
provides a little progress bar and 2 numbers called loss and val_loss, and it
kinda looks like this at the end of the epoch.
Train on 143064 samples, validate on 7167 samples
Epoch 1/1
143064/143064 [==============================] - 325s - loss: 0.3728 - val_loss: 0.3117
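For the record, that output comes from a call that looks roughly like this (a sketch; model, X_train, and friends are from the snippets above, and the batch size is a guess):

history = model.fit(X_train, y_train,
                    validation_data=(X_val, y_val),
                    epochs=1, batch_size=64)
# history.history["loss"] and history.history["val_loss"] hold those two numbers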
loss is the
current state of the loss function on the training data, and val_loss is the current state of the loss
function on the validation data. The loss function is just a measure of how
well the net is doing at getting the right answer. We want both loss and val_loss to be
small, but really we want val_loss to
be small, because that’s the true test. And
as the model trains we want to see both loss
and val_loss reduce with every epoch.
There’s something funny going on above. The value at loss is actually greater than the one at val_loss. What that means is
that the net does better on data it has never seen than on data it was trained
on. WTF is up with that? Is it psychic or something??
Relax, it’s cool. In fact, it’s really good. What’s going on
is that we are using a technique called dropout. By using dropout we handicap the net during
training, so it thinks it’s dumber than it really is and tries extra hard to
learn. Then when it applies itself to
the validation data, we take the handicap away and it does better than it ever thought
it could. This is kind of a shitty thing
to do to it of course. Our little net is definitely going to have trust issues down
the road, but it’s just a bunch of bits, so tough luck kid. That said, when the
computers take over we are all going to pay dearly for this dropout stuff.
Training only takes about ½ a brain and a couple beers. Maybe less of the former and more of the
latter. That’s why I like training. There is one thing you should do at some
point. Figure out the math that makes up
your loss function (it’s usually
mean squared error or cross-entropy, which are pretty easy equations), and
figure out some way to convert this into units you can contextualize (like RMSE
or accuracy (and Keras will do accuracy for you)). That way while this thing runs you can see
how it’s doing in units that make sense and you can feel good about life as it
gets smarter and you get drunker.
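For a regression net like the Tajima's D example, that conversion is about three lines (a sketch; you can also pass metrics=["mae"] to model.compile() if you want Keras to track something interpretable for you):

import numpy as np

preds = model.predict(X_val).flatten()
rmse = np.sqrt(np.mean((preds - y_val) ** 2))   # back in the units of the thing you predict
print("validation RMSE:", rmse)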
But don’t get too drunk.
There’s one other thing you have to do.
You have to watch out for overfitting. This is when the net starts predicting
your training data super well, but at the same time gets worse at predicting
your validation data. Overfitting is just a fancy way of saying that something
trained really hard to memorize a specific task, but didn’t learn to solve the problems in a general way.
Neural nets are really smart, and they love to overfit,
however overfitting is pretty easy to diagnose.
Any time loss is smaller than val_loss you are seeing
overfitting. That’s why dropout is
good. It keeps loss bigger than val_loss. Usually a little overfitting is ok, and
sometimes it corrects itself after a few epochs. But when it gets out of hand it’s pretty
ugly, and often it spirals until your model is so overfit that it’s pretty much
useless at predicting the validation data even though it’s crushing it on the
training data. If it overfits really
fast and well before it achieves any kind of accuracy you’d call acceptable,
then you might have to change the architecture, or make more training data, or use even
more dropout (you heartless jerk). Here’s
a cool blog post that talks you through your options. I recommend carefully going through it
step-by-step, and for now skipping all the steps you don’t understand.
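One easy knob worth knowing about is Keras's EarlyStopping callback, which quits when val_loss stops improving and hangs onto the best weights it saw (a sketch; the patience value is arbitrary):

from tensorflow.keras.callbacks import EarlyStopping

stopper = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
model.fit(X_train, y_train,
          validation_data=(X_val, y_val),
          epochs=100, batch_size=64, callbacks=[stopper])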
On the other hand, maybe your net just keeps trucking along
getting better and better. Lucky
you! But at some point you have to stop
drinking and call it good enough. Then
you evaluate your net one last time, this time on test data that were never used in training. This is important, because you tuned the model using the validation data, but now you need an independent assessment, and that's what the untouched test data is for. Then you
save the trained net and the accuracy on the test data and get ready to run it on your real data and do real biology!
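That last step is just a couple of lines (a sketch; the filename is made up):

test_loss = model.evaluate(X_test, y_test)      # one pass over the untouched test data
print("test loss:", test_loss)
model.save("my_trained_net.h5")
# later: model = keras.models.load_model("my_trained_net.h5")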
Have you ever learned about some flashy new computational
approach for solving your exact problem? Holy cow, exactly what you needed! Amazing,
right! Then you download it and realize that the inconsiderate bastards who
wrote this crappy software went out of their way to make it hard to use. You spend weeks reformatting your data and
just seething with contempt for these jerks. Important stuff is undocumented, and it even crashes on their stupid
example, and whitespace is f-ing important but of course that’s
not discussed and… ugh. Who do they think they are?! How could they be such assholes, wasting
countless hours of everyone’s time, just because they couldn’t be bothered to
use some standard input formats, or write halfway decent documentation, or
provide code that compiles. Oh do you
hate them, just a burning hatred!
OK, first you’ve got some issues, but don’t worry so do I, so no judgment here. But second - and this is the big payoff - guess what,
when you make your own neural net to solve your own problem you get to be your
own inconsiderate bastard who makes it a huge pain in the ass to use. How cool
is that!
Yeah, so I should have mentioned this earlier. Before you start, you need to think about
your real data a lot. Consider how you
want to run it through the net. Then you need to tailor your coalescent sims to
match whatever paradigm you choose. If
you have 20 individuals and you want to run 20 kb chunks of the genome through,
then you need to simulate 20 individuals and tune the sim parameters to
approximate diversity in those 20 kb chunks.
Also, are your data phased or unphased, and how much sequence error do
you think you have? And do the A,C,G,
and T bases matter, or can you code all your biallelic sites arbitrarily as 0
& 1? You gotta work through all this
stuff, and tune the sims accordingly. It
all matters, because the neural net will learn from what you give it, not what
you meant to give it. Also, you are probably going to tweak and rerun the
simulations several times before you get it right. Hence my comment above about 90% of your mistakes, it’s
this stage where they all come to light.
Sounds bad, I know, but just use the neural net that sits on
your shoulders and try to imagine what the artificial one in your computer is seeing
in the data, and then adjust the sims accordingly. Plus when you are your own
inconsiderate jerk who made the thing a pain to use, at least the feedback loop
usually is pretty short.
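To make just one of those choices concrete: if you decide the A/C/G/T identities don't matter, the 0/1 recoding might look something like this (a sketch; 'genotypes' is a hypothetical chromosomes-by-sites array of base calls):

import numpy as np

def encode_binary(genotypes):
    out = np.zeros(genotypes.shape, dtype=np.int8)
    for j in range(genotypes.shape[1]):
        col = genotypes[:, j]
        alleles, counts = np.unique(col, return_counts=True)
        minor = alleles[np.argmin(counts)]       # rarer allele gets coded as 1
        out[:, j] = (col == minor).astype(np.int8)
    return out

example = np.array([["A", "A", "T", "A"],
                    ["C", "C", "C", "G"]]).T     # 4 chromosomes x 2 sites
print(encode_binary(example))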
You’re almost done!
You made your training data, and trained your net, then you tested it
against unseen data, and finally you ran real data through. And guess what, it found stuff! Time to write
a paper! So now you have to explain what it found. One approach is to fall back
to non-neural net methods. Like apply
some pop gen summary stats just as you always have. It’s cool, go ahead, I won’t judge. Then maybe compare the patterns of the net
versus the traditional methods and figure out how to explain it to someone
using traditional summary stats. Or
maybe don’t. Instead you could just
point to the quality of your neural net on the test data and rest your
case. If a reviewer asks tell them “Stats! We don’t need no stinkin’ stats!”. I’m cool with that too. Whatever works for the story you want to
tell. It might even be helpful to read
some cool machine learning papers in biology and see how other people reported
discoveries made with a black box network.
Here’s my final pitch: we’ve been doing pattern recognition
in pop gen for years (though rarely calling it that). I (and other like-minded
folks) think it’s time we adopt more powerful pattern recognition
approaches. I’m excited about neural
nets and their potential in this regard, but they are different and will require
a little retooling on your part. So, if you decide to give it a whirl and get
stuck, reach out to me. I’m happy to
chat and offer any tips I can. The best
ways to reach me are on Twitter @flagelbagel or my email, which is a gmail account with the prefix flag0010.
OK, I’m getting it, and I see this whole thing as being terribly reliant on the quality of my coalescent sims. Seems like a house of cards. Right?
Yes, how well your coalescent sims match your real data really matters. I can’t stress that enough. You can’t train a neural net on sims that are appropriate for elephant demography and expect it to do well on fruit flies. In fact, you probably can’t even train on the demography of one species of fruit fly and hope it will work well on another. But hey, it’s not hard to get heaps of genomic data, so just estimate the demography and simulate accordingly. This is kinda like specifying your prior, and in my opinion that's a good thing.
You also should know that neural nets do well at
interpolation, but they struggle with extrapolation. So for example, if you train one to look for
sweeps when the effective population size ranges from 10K-50K, they will work
well on real data that falls in that same range, even for points in that range
you never actually simulated (interpolation). But they probably won’t work well
for real data that lies outside that range (extrapolation). So you have to be cognizant of the extremes in
your data and train all the way to those extremes. And bear in mind, the coalescent is quite stochastic. So things like the
effective population size are going to have a big range across different loci
in the genome and you need your sims to be similarly extreme.
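In practice that just means drawing the nuisance parameters from generously wide ranges when you simulate. Something like this (a sketch; the ranges and rates are made up):

import msprime
import numpy as np

sims = []
for _ in range(1_000):
    ne = np.random.uniform(10_000, 50_000)       # a different Ne for every simulated locus
    ts = msprime.sim_ancestry(samples=10, ploidy=2, sequence_length=20_000,
                              population_size=ne, recombination_rate=1e-8)
    mts = msprime.sim_mutations(ts, rate=1e-8)
    sims.append(mts.genotype_matrix())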
Finally, traditional summary stats based pop gen also tends
to rely heavily on coalescent simulations too. A common practice is to generate
oodles of summary stat values from the real data and then use coalescent sims
to determine the cutoff values for significance. So if you do that kind of pop
gen already, then either way you slice it, it’s turtles and coalescent sims all
the way down.
There’s so much hype around neural nets. What happens in 5 years when they become the next “big data” or “VR” or “wearables”?
Yes, there is a LOT of hype around deep learning in tech right now, and you are absolutely right to be a skeptic when you see this kind of hype. “Deep Learning for X” startups are launching daily, and many are selling pixie dust. They and their financiers will be screwed soon enough, but you should know that the core tech behind neural nets is real and it’s not going away. This is absolutely happening. If you see yourself being in a quantitative field in the future, it's probably time to add neural nets to your toolbox.