Friday, June 1, 2018

A very serious and completely accurate FAQ on using Deep Neural Nets for Population Genetic Inference


This post accompanies our paper on bioRxiv. My hope is to add a little more detail about the idea of deep learning in population genetics.  

Where did the idea of using neural networks to do pop gen come from?


It’s mostly stolen from Google. We’re not that clever, but we steal from the best. 

The neural nets we built are very similar to those commonly used on the MNIST handwritten-digit data, which is basically neural nets 101.


Machine learning, deep learning, deep neural nets, and artificial intelligence.  What is all this stuff?  And where does statistics fit in?


It’s confusing.  Here’s my impression of how to structure some of the terms:


ML=Machine Learning; DL=Deep Learning; DNN=Deep Neural Nets

ML is very general, and encompasses many things, even including standard stats methods like regression (depending on who you ask).  DL is a catch-all term that refers to a large class of (mostly) neural network methods without calling out any specific one.  And a DNN is just any neural network that has more than one internal layer (i.e. “deep” layers).  There are many sub-types of DNNs too.  For example, we made heavy use of convolutional DNNs. 

And artificial intelligence, or AI.  That term is fuzzy and really used more for marketing. All this stuff is probably AI, or none of it is, or who knows.  You should sprinkle AI into the conversation when you are trying to impress people. 

Finally, where does stats fit in? That’s also debated. I don’t think of neural nets as a statistical method.  But others may disagree.  They certainly incorporate ideas from statistics, and if you know some stats that will probably make it easier to learn neural nets. 


I’ve got this project I want to do, and it’s a really hard problem.  Also there is no good way to get good training data, and even if there were, there’d be no signal in it.  Do you think neural nets can solve this problem?


Nope. There is no magic.  


I’ve got another problem, and it can be easily solved by regression, but I was wondering if I should use a neural net instead?


Nope. Whenever possible stick to the explainable stuff.  But bonus points if you do it both ways and learn something.

I've got an NP-hard problem (like phylogenetics). Can neural nets do that?


Generally speaking, neural nets are inappropriate for optimization problems. Sometimes you can recast optimization as classification and use a neural network, but this only works well if there is a pretty small number of possible classes (and usually in those cases you can just do an exhaustive search).


Do I have to be a math expert to learn neural nets?


No! Ever run a BLAST search?  Do you really know what’s going on under the hood?  Me neither, but we can both use it to get our work done.  Same thing with neural nets.

Also, the math is not so bad, and I don’t think you need to master it to be able to conceptualize what’s going on.

Right now is a magical time to learn this stuff. You can learn about deep learning from the people who pioneered it. Many of them teach online courses or have really great books. Most of their teaching is free. Don’t miss the opportunity to learn from these people. Below are just a few links.

https://goo.gl/ryQ1kD (really great visuals)


What kind of pop gen theory are these deep neural nets based on?


None! Neural nets are completely naive to any relevant pop gen theory.  They haven't even heard of Hardy & Weinberg.

Neural nets are just algorithms that learn to do pattern recognition by computational brute force. And they work freaking well! And on a whole bunch of different problems. Kinda neat huh!  Learn one technique and solve many problems.

The good news is that when we point the neural nets at data they aren’t limited by our current state-of-the-art in pop gen theory.  They just learn patterns and invent their own theory (so to speak) for solving the problem at hand.  So they may be doing things that we never imagined! And I said there was good news, so that probably means there is bad news too.  Namely the innards of a neural net are nearly impossible to interpret. So they are really good for detection (i.e. this is a selective sweep and that’s not), but they are really bad for explanation (i.e. I know this is a sweep because of X).


Oh, I see, so you jerks are putting theorists out of business!


Ah, you got us!  

Actually, no.  See above.  Neural nets are great for detection, but terrible for explanation. What they learn is, unfortunately, a big black box to us. 

We imagine that a hybrid approach might be the best way to use deep neural nets in pop gen.  For example, first test a neural net and other methods on simulated data and figure out what works best for your question (and that might be some ensemble of several methods).  If the neural net looks good, you should use it. It seems silly to do otherwise. Then, to tell your story, use traditional pop gen techniques to try to explain what the neural net found.  Some things might not make sense, like the net finding a selected site that no other method detects.  On the one hand that presents a challenge, but on the other it’s probably worth investigating, because it could be something we’ve been missing and could lead to new theory.

A second thing you should know.  Big advances are being made in explaining what neural nets are learning. And who is going to take the knowledge out of the neural net and translate it to the rest of us?  Theoreticians.


You keep calling it a "black box", what the hell are you talking about?


Good question. Black box is kind of weaselly. It also makes it sound like we can't look inside. That's not the case: we can look in, but I hope to convince you that what we see is just really hard to interpret. Let's build a neural net and crack it open to see what's going on.

I made a neural net that predicts Tajima's D. The nice thing is we know the theory and math behind Tajima's D, so if we train a net to predict it and look at the inside, maybe we can find the components of theory and math in there somewhere??  That's the dream anyway.

Below is the structure of the net. I highlighted the neuron layers in red, and each layer is listed with the shape of its output matrix.

    Layer              Output shape
    InputLayer         (40, 20)
    Conv1D             (39, 128)
    Conv1D             (38, 128)
    AveragePooling1D   (19, 128)
    Dropout            (19, 128)
    Flatten            (2432)
    Dense              (128)
    Dropout            (128)
    OutputLayer        (1)
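The shapes in that table follow from a little arithmetic. Here's a minimal pure-Python sketch that reproduces them, assuming each Conv1D uses a kernel of width 2 with no padding and the pooling window is 2. Those two values are inferred from the shapes (40 to 39 to 38, then 38 to 19), not stated anywhere else, and the function names are just for illustration.

```python
def conv1d_length(length, kernel_size, stride=1):
    """Output length of a 1D convolution with 'valid' (no) padding."""
    return (length - kernel_size) // stride + 1


def pool1d_length(length, pool_size):
    """Output length of non-overlapping 1D pooling (stride = pool_size)."""
    return length // pool_size


def layer_shapes(n_sites=40, n_chromosomes=20, filters=128, dense_units=128,
                 kernel_size=2, pool_size=2):
    """Trace the output shape of each layer in the Tajima's D network."""
    shapes = [("InputLayer", (n_sites, n_chromosomes))]
    length = conv1d_length(n_sites, kernel_size)        # 40 -> 39
    shapes.append(("Conv1D", (length, filters)))
    length = conv1d_length(length, kernel_size)         # 39 -> 38
    shapes.append(("Conv1D", (length, filters)))
    length = pool1d_length(length, pool_size)           # 38 -> 19
    shapes.append(("AveragePooling1D", (length, filters)))
    shapes.append(("Dropout", (length, filters)))       # dropout never changes shape
    shapes.append(("Flatten", (length * filters,)))     # 19 * 128
    shapes.append(("Dense", (dense_units,)))
    shapes.append(("Dropout", (dense_units,)))
    shapes.append(("OutputLayer", (1,)))
    return shapes


for name, shape in layer_shapes():
    print(f"{name:<18}{shape}")
```

The Flatten line is the one worth staring at: 19 positions × 128 filters = 2432 numbers, which is the vector the dense layer sees.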

And you feed this thing images (20 chromosomes and 40 segregating sites) like the one below:

That's the "input layer" in the structure above.  Then it goes through two 1D convolutional layers (among other things noted above), and a fully connected dense layer (each with 128 neurons), and finally a single neuron output layer which just predicts Tajima's D as a continuous value.  Below is the quality of this network on unseen test data after training.


It's pretty darn good.  So it really gets how to predict Tajima's D.  Now let's look at how it does this, using the input sequence alignment given above as an example.  Here's the output of the first convolutional layer:

White pixels indicate a strong signal and black means no signal.  Now here's the next convolutional layer:

And finally below is the fully connected "dense" layer:  (This thing is flattened out, so it's really just a vector. But I gave it a little height so you could see it.)


And that thing feeds into one linear neuron, which sums it up and outputs 1.53. That's really close to the true value of Tajima's D (1.56) for the sequence alignment above.

Like I said, we know what goes into Tajima's D.  What you'd like to find in these images is something corresponding to Watterson's Theta and Pi, and also something about their variances.  Or maybe some bits that appear to count the frequency spectrum or some kind of intermediate stuff that you might use to get to Watterson's Theta and Pi.  But it's just not there.  At least not in any way we know how to extract (and perhaps it is really just not there).  This is what the "black box" term refers to. What's in there is just a bunch of signals (off, on a little, on a lot, etc), and they give rise to other signals, and finally the signals are summed up and the sum (almost magically) comes out really close to the correct value.
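For reference, here's the textbook calculation we were hoping to spot inside the net. It's a minimal pure-Python sketch of Tajima's D (Pi, Watterson's Theta, and the variance terms), assuming a 0/1 haplotype matrix with one row per chromosome. Note that's the transpose of the (40, 20) sites-by-chromosomes layout the network is fed, and the function name is illustrative.

```python
import math


def tajimas_d(haplotypes):
    """Tajima's D from a 0/1 haplotype matrix (rows = chromosomes, columns = sites).

    Returns 0.0 when there are no segregating sites (D is undefined there).
    """
    n = len(haplotypes)
    if n < 4:
        # For n < 4 the variance term is zero, so D is undefined.
        raise ValueError("Tajima's D needs at least four chromosomes")
    n_sites = len(haplotypes[0])

    # Derived-allele count per site; keep only sites that actually segregate.
    counts = [sum(row[j] for row in haplotypes) for j in range(n_sites)]
    counts = [k for k in counts if 0 < k < n]
    S = len(counts)
    if S == 0:
        return 0.0

    # Pi: mean number of pairwise differences between chromosomes.
    pi = sum(k * (n - k) for k in counts) / (n * (n - 1) / 2)

    # Watterson's Theta: S divided by the harmonic number a1.
    a1 = sum(1.0 / i for i in range(1, n))
    a2 = sum(1.0 / i ** 2 for i in range(1, n))
    theta_w = S / a1

    # Standard constants for the variance of (pi - theta_w), from Tajima (1989).
    b1 = (n + 1) / (3 * (n - 1))
    b2 = 2 * (n ** 2 + n + 3) / (9 * n * (n - 1))
    c1 = b1 - 1 / a1
    c2 = b2 - (n + 2) / (a1 * n) + a2 / a1 ** 2
    e1 = c1 / a1
    e2 = c2 / (a1 ** 2 + a2)

    return (pi - theta_w) / math.sqrt(e1 * S + e2 * S * (S - 1))


# An excess of rare variants (all singletons) pushes D negative.
print(tajimas_d([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]]))
```

Every piece of this (counting pairwise differences, counting segregating sites, a harmonic sum, a variance correction) is the kind of thing we'd love to see a recognizable echo of in the activation images, and don't.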

What's worse, there is a vast array of internal configurations of this network that can achieve high accuracy.  So if we retrained this thing from a new random starting position, we would likely see a completely new pattern of signals flowing across the network for this same input.  And yet, the thing would probably again guess a value very near to 1.56.  It's a little like our brains.  If I ask you to imagine ice cream, you can conjure up some mental image.  But if I compare how your brain makes that image to how my brain does it, we surely have different patterns of neurons firing.  So just like our brains, the neural net is not learning some canonical way to get Tajima's D out of a sequence alignment.  Instead it's just finding a local optimum, and there are many, many equivalent local optima out there.  So there is nothing really special about the patterns above, other than the fact that they represent one of the many ways of getting a decent prediction of Tajima's D out of a neural net.

Now imagine we instead built a neural network that could predict something we don't know how to compute (like the signature of a soft sweep).  We might get a network that does really well, and it'd be a dream to back out what the network learned into new theory.  In practice, this would entail making sense and structure out of stuff like the patterns above. And perhaps worse, I'm not sure there is any certainty that there would be new theory in there. Maybe you could toil at this for years, only to later learn that there was never anything there to find. I hate ending on a downer like that, but anyway, hopefully you understand why people call these things a black box.


Ok, I read the stuff above, and I still want to develop a neural net for my research.  What does the process look like?


I’m so glad you are interested!  Let me help you get started.  This is a bit of a long answer. It’s a multistep thing.

First, maybe you have a bunch of genomes from your favorite species where the exact genealogical history of every SNP is known.  And you know every evolutionary force that has ever acted on each SNP too.  Then all you have to do is… 

What’s that? You don’t have that kind of data? Or even a kind of halfway approximate version of it. OK, well, first let me welcome you to the club. In this club we do coalescent simulations to make up fake data sets that are sorta like those of the real critters we are studying.  Then we feed these fake data sets into our neural net to train it and hope that we’re close enough to reality that when we point the thing at the real data it works. 

So step one, you have to simulate training data for your net.  And like I said, what that usually means is building a customized coalescent simulation.  So you have to capture all the important complexity of the evolutionary history of your sample of genomes, and then simulate conditional on that. You also have to simulate the thing you want to detect, like a set of sims with the selective sweep and a set without.  Probably 90% of all your mistakes are going to happen during this stage of the process.  But since everything is super dependent on the exact nature of your sample and the critters you study, all your mistakes will be new and different from all of ours, so I can’t give you a lot of tips other than to talk to coalescent simulation experts. So that’s my advice. Make mistakes, talk to people, fix those mistakes and repeat until all the remaining mistakes are so deep that the average reviewer probably will never notice.   

I’ll assume you accomplished step one, and you have a lot of coalescent sims in hand and they are all really really perfect for your question. And hopefully you still have some of your motivation intact too.

Next you have to get your coalescent sims into Python as matrices (oh yeah, if you don’t already know Python, go learn that too).  We have some code that you can use as an example.  We used ‘ms’ formatted files to store our coalescent sims.  They’re unpleasant to parse. I’d recommend using msprime, which gives you the output of the sim in a nice Python format off the bat.
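If you do end up parsing ms output, here's a minimal sketch of the idea. It assumes the standard layout (command line and seeds, then for each replicate a `//` line, `segsites:`, `positions:`, and one 0/1 string per chromosome). It also adds a hypothetical helper, `to_fixed_width`, that transposes each replicate to sites × chromosomes and trims or zero-pads it to a fixed number of sites, because the net needs a fixed input shape like (40, 20). Padding with monomorphic sites is an assumption; use whatever convention you'll also apply to your real data.

```python
def parse_ms(text):
    """Parse ms output into a list of (positions, haplotypes) replicates.

    haplotypes is a list of lists of 0/1 ints, one row per chromosome.
    """
    replicates = []
    # Everything before the first '//' is the command line and random seeds.
    for block in text.split("//")[1:]:
        segsites, positions, haplotypes = 0, [], []
        for line in block.splitlines():
            line = line.strip()
            if line.startswith("segsites:"):
                segsites = int(line.split(":", 1)[1])
            elif line.startswith("positions:"):
                positions = [float(x) for x in line.split(":", 1)[1].split()]
            elif line and set(line) <= {"0", "1"}:
                haplotypes.append([int(c) for c in line])
        if len(positions) != segsites:
            raise ValueError("positions do not match segsites")
        replicates.append((positions, haplotypes))
    return replicates


def to_fixed_width(haplotypes, n_sites=40):
    """Transpose to (sites, chromosomes) and trim or zero-pad to n_sites rows."""
    n_chrom = len(haplotypes)
    sites = [list(col) for col in zip(*haplotypes)]  # one row per site
    sites = sites[:n_sites]
    # Pad with all-zero (monomorphic) sites: an assumed convention.
    sites += [[0] * n_chrom for _ in range(n_sites - len(sites))]
    return sites
```

With msprime you can skip the parsing entirely: `TreeSequence.genotype_matrix()` already returns a sites × samples array.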

Now we are nearly ready to train our network.  But first we have to build our compute environment. I highly recommend using Keras and TensorFlow. They are two great Python packages for training neural nets. TensorFlow is the math guts of neural nets, and Keras is a package that sits on top and hides a lot of the math from you. TensorFlow is a bear to install.  I recommend using a cloud service like FloydHub where it’s already pre-installed (I’m not affiliated with FloydHub). Or bribe someone to install it for you, but probably just use a cloud service. They’re so cheap and your time is valuable.

Neural networks are computationally intensive, so make sure you use a machine with a GPU, because it will make things way way faster.  What’s a GPU you say?  The truth is I don’t really know.  The Wikipedia page says things like “stream processing” and “compute kernels”, which are all words that mean something I guess.  Here’s what I do know. GPUs do matrix math really fast, and training a neural net involves tons of matrix math.  So they make the whole process way faster.

OK, so you got Keras and TensorFlow up and running on a GPU either by bribing someone or using a cloud service. Now you are ready to train!!

Training is my favorite part, because it involves baby-sitting a computer while listening to music and drinking beer.  So first get some tunes going.  I recommend something electronic because that’s the mother-tongue of neural nets.  The goal is to comfort your net with songs in its language so it can learn faster.  The beer is for you, so that's your call. 

OK, the first step in training (after taking a good pull of your beer) is to split the coalescent sims into three groups: the training group, the validation group, and the testing group.  We’re going to train the neural net with the training group, make parameter changes with the validation group, and then test its accuracy with the test group (pretty creative names huh!).  So the training group should be big.  We usually used something like 100K-200K individual coalescent sims, but even more would probably have been better.  And the validation and testing data can be smaller, like 5K-20K each.  But the trick with the validation and testing data is that they need to be big enough to have all the important features of the training data (and vice versa).  So let’s say we are trying to detect selective sweeps.  It’d be a huge mistake to put all the hard sweeps in the training data and all the soft sweeps in the validation and testing data.  Instead, we need to randomly mix them so all the sets have similar distributions of hard and soft sweeps (and even similar intensities of hard and soft sweepy-ness).
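Here's a minimal sketch of that split, assuming each simulation carries a label such as "neutral", "hard", or "soft". It splits within each label before recombining, so every group gets about the same mix; a plain shuffle gets you there on average, and this just guarantees it. The 5% defaults and the function name are placeholders, not what we used.

```python
import random
from collections import defaultdict


def stratified_split(sims, labels, val_frac=0.05, test_frac=0.05, seed=42):
    """Split sims into train/val/test lists, keeping each label's proportion
    roughly the same in all three groups."""
    if len(sims) != len(labels):
        raise ValueError("sims and labels must have the same length")
    by_label = defaultdict(list)
    for sim, label in zip(sims, labels):
        by_label[label].append(sim)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for label in sorted(by_label):
        group = by_label[label]
        rng.shuffle(group)
        n_test = round(len(group) * test_frac)
        n_val = round(len(group) * val_frac)
        test += group[:n_test]
        val += group[n_test:n_test + n_val]
        train += group[n_test + n_val:]

    # Shuffle again so training batches are not ordered by label.
    for part in (train, val, test):
        rng.shuffle(part)
    return train, val, test
```

In practice you'd probably pass indices into your big simulation array rather than the simulations themselves.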

Once we’ve split the data, we pick a neural net architecture and we start training.  Architectures that work pretty well are here, and if you read the code you’ll see you just specify the architecture layer by layer with Keras.  Network architecture is a dense topic that I don’t completely understand, but experts seem to agree that nobody else understands it either.
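To make "layer by layer" concrete, here's one way the Tajima's D network from the table above could be written with Keras (the `tf.keras` API in TensorFlow 2.x). The kernel width of 2 and pool size of 2 are inferred from the layer shapes; the ReLU activations, dropout rates, optimizer, and loss are assumptions for illustration, not necessarily the settings we used.

```python
from tensorflow import keras


def build_tajimas_d_net(n_sites=40, n_chromosomes=20):
    """Regression CNN whose layer shapes match the architecture table (a sketch)."""
    model = keras.Sequential([
        keras.Input(shape=(n_sites, n_chromosomes)),            # sites x chromosomes
        keras.layers.Conv1D(128, kernel_size=2, activation="relu"),
        keras.layers.Conv1D(128, kernel_size=2, activation="relu"),
        keras.layers.AveragePooling1D(pool_size=2),
        keras.layers.Dropout(0.25),                              # assumed rate
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dropout(0.5),                               # assumed rate
        keras.layers.Dense(1),                                   # linear output: predicted D
    ])
    # Mean squared error is a natural loss for predicting a continuous value.
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


model = build_tajimas_d_net()
model.summary()
```

Training would then be a call like `model.fit(x, y, validation_data=(x_val, y_val), epochs=1)`, with `x` shaped (n, 40, 20) and `y` holding the true D values.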

Each pass over our training samples is called an epoch. Generally speaking one epoch takes about a quarter of a beer. So you fire off one epoch, and sit back, enjoy the beer, and groove to tunes.  Keras provides a little progress bar and 2 numbers called loss and val_loss, and it kinda looks like this at the end of the epoch.

Train on 143064 samples, validate on 7167 samples
Epoch 1/1
143064/143064 [==============================] - 325s - loss: 0.3728 - val_loss: 0.3117

loss is the current state of the loss function on the training data, and val_loss is the current state of the loss function on the validation data.  The loss function is just a measure of how well the net is doing at getting the right answer.  We want both loss and val_loss to be small, but really we want val_loss to be small, because that’s the true test.  And as the model trains we want to see both loss and val_loss reduce with every epoch.

There’s something funny going on above.  The value at loss is actually greater than the one at val_loss.  What that means is that the net does better on data it has never seen than on data it was trained on.  WTF is up with that?  Is it psychic or something??

Relax, it’s cool. In fact, it’s really good. What’s going on is that we are using a technique called dropout.  By using dropout we handicap the net during training, so it thinks it’s dumber than it really is and tries extra hard to learn.  Then when it applies itself to the validation data, we take the handicap away and it does better than it ever thought it could.  This is kind of a shitty thing to do to it of course. Our little net is definitely going to have trust issues down the road, but it’s just a bunch of bits, so tough luck kid. That said, when the computers take over we are all going to pay dearly for this dropout stuff.
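Here's roughly what that handicap looks like. Keras's Dropout layer uses the "inverted" form: during training it zeroes each unit with probability `rate` and scales the survivors by 1 / (1 - rate), and at inference it passes everything through untouched. A minimal pure-Python sketch of that behavior (the function itself is illustrative, not Keras code):

```python
import random


def dropout(activations, rate, training, rng=random):
    """Inverted dropout on a list of activations.

    Training: each unit is zeroed with probability `rate`, and survivors are
    scaled by 1 / (1 - rate) so the expected activation is unchanged.
    Validation/test (training=False): the layer does nothing.
    """
    if not training or rate == 0:
        return list(activations)
    keep = 1.0 - rate
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


x = [1.0, 2.0, 3.0, 4.0]
print(dropout(x, 0.5, training=False))  # unchanged at inference
print(dropout(x, 0.5, training=True))   # random units zeroed, the rest doubled
```

Because the loss printed during an epoch is computed with units dropped, while val_loss is computed with nothing dropped, loss can end up above val_loss.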

Training only takes about ½ a brain and a couple beers.  Maybe less of the former and more of the latter. That’s why I like training. There is one thing you should do at some point.  Figure out the math that makes up your loss function (usually mean squared error or cross-entropy, which are pretty easy equations), and figure out some way to convert it into units you can contextualize, like RMSE or accuracy (Keras will compute accuracy for you).  That way, while this thing runs, you can see how it’s doing in units that make sense, and you can feel good about life as it gets smarter and you get drunker. 
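The common losses really are short equations. A minimal pure-Python sketch, with RMSE and accuracy as the more readable companions; the names are illustrative and the 0.5 threshold is an assumption.

```python
import math


def mse(y_true, y_pred):
    """Mean squared error: the usual regression loss."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true, y_pred):
    """Root mean squared error, in the same units as the target (e.g. Tajima's D)."""
    return math.sqrt(mse(y_true, y_pred))


def binary_cross_entropy(y_true, p_pred, eps=1e-7):
    """Mean binary cross-entropy for 0/1 labels and predicted probabilities."""
    total = 0.0
    for t, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


def accuracy(y_true, p_pred, threshold=0.5):
    """Fraction of 0/1 labels predicted correctly at the given threshold."""
    hits = sum((p >= threshold) == bool(t) for t, p in zip(y_true, p_pred))
    return hits / len(y_true)
```

If your loss is MSE, the square root of val_loss is RMSE in the units of your target: a val_loss of 0.04, say, would mean typical errors of around 0.2 in Tajima's D.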

But don’t get too drunk.  There’s one other thing you have to do.  You have to watch out for overfitting. This is when the net starts predicting your training data super well, but at the same time gets worse at predicting your validation data. Overfitting is just a fancy way of saying that something trained really hard to memorize a specific task, but didn’t learn to solve the problems in a general way. 

Neural nets are really smart, and they love to overfit; however, overfitting is pretty easy to diagnose.  Any time loss is smaller than val_loss you are seeing some overfitting.  That’s why dropout is good.  It keeps loss bigger than val_loss.  Usually a little overfitting is ok, and sometimes it corrects itself after a few epochs.  But when it gets out of hand it’s pretty ugly, and often it spirals until your model is so overfit that it’s pretty much useless at predicting the validation data even though it’s crushing it on the training data.  If it overfits really fast and well before it achieves any kind of accuracy you’d call acceptable, then you might have to change the architecture, or make more training data, or use even more dropout (you heartless jerk).  Here’s a cool blog post that talks you through your options.  I recommend carefully going through it step-by-step, and for now skipping all the steps you don’t understand.
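The usual guard against runaway overfitting is early stopping: keep the epoch with the lowest val_loss, and stop once val_loss hasn't improved for a few epochs. Keras ships this as the `keras.callbacks.EarlyStopping` callback (`monitor="val_loss"`, `patience=...`, `restore_best_weights=True`). Below is a minimal pure-Python sketch of the same logic on per-epoch values; the function names are illustrative.

```python
def early_stop(val_losses, patience=3):
    """Return (best_epoch, stop_epoch) for simple early stopping.

    Stops once val_loss has failed to improve for `patience` consecutive
    epochs. Epochs are 0-indexed; stop_epoch is None if it never stopped.
    """
    best, best_loss, waited = 0, float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best, best_loss, waited = epoch, loss, 0
        else:
            waited += 1
            if waited >= patience:
                return best, epoch
    return best, None


def overfit_gaps(losses, val_losses):
    """Per-epoch val_loss - loss; a growing positive gap signals overfitting."""
    return [v - t for t, v in zip(losses, val_losses)]
```

For example, with made-up val_loss values [0.50, 0.40, 0.35, 0.36, 0.37, 0.38] and patience=3, the best epoch is 2 and training stops at epoch 5.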

On the other hand, maybe your net just keeps trucking along getting better and better.  Lucky you!  But at some point you have to stop drinking and call it good enough.  Then you evaluate your net one last time, this time on test data that were never used in training. This is important, because you tuned the model using the validation data, but now you need an independent assessment, and that's what the untouched test data is for. Then you save the trained net and the accuracy on the test data and get ready to run it on your real data and do real biology!

Have you ever learned about some flashy new computational approach for solving your exact problem? Holy cow, exactly what you needed! Amazing, right! Then you download it and realize that the inconsiderate bastards who wrote this crappy software went out of their way to make it hard to use. You spend weeks reformatting your data and just seething with contempt for these jerks. Important stuff is undocumented, and it even crashes on their stupid example, and whitespace is f-ing important but of course that’s not discussed and… ugh. Who do they think they are?!  How could they be such assholes, wasting countless hours of everyone’s time, just because they couldn’t be bothered to use some standard input formats, or write halfway decent documentation, or provide code that compiles?  Oh do you hate them, just a burning hatred!

OK, first you’ve got some issues, but don’t worry so do I, so no judgment here. But second - and this is the big payoff - guess what, when you make your own neural net to solve your own problem you get to be your own inconsiderate bastard who makes it a huge pain in the ass to use. How cool is that!

Yeah, so I should have mentioned this earlier.  Before you start, you need to think about your real data a lot.  Consider how you want to run it through the net. Then you need to tailor your coalescent sims to match whatever paradigm you choose.  If you have 20 individuals and you want to run 20 kb chunks of the genome through, then you need to simulate 20 individuals and tune the sim parameters to approximate diversity in those 20 kb chunks.  Also, are your data phased or unphased, and how much sequence error do you think you have?  And do the A,C,G, and T bases matter, or can you code all your biallelic sites arbitrarily as 0 & 1?  You gotta work through all this stuff, and tune the sims accordingly.  It all matters, because the neural net will learn from what you give it, not what you meant to give it. Also, you are probably going to tweak and rerun the simulations several times before you get it right.  Hence my comment above about 90% of your mistakes, it’s this stage where they all come to light.  
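Two of those choices, phasing and sequencing error, can be applied to the simulated haplotypes after the fact. A minimal sketch, assuming 0/1-coded rows where rows 2i and 2i+1 belong to the same diploid individual, and a simple independent flip model for sequencing error (real error profiles are messier, so treat this as a placeholder). The function names are hypothetical.

```python
import random


def unphase(haplotypes):
    """Collapse consecutive pairs of haplotype rows into diploid genotypes.

    Assumes rows 2i and 2i+1 are the two chromosomes of individual i; each
    genotype becomes 0, 1 or 2 copies of the derived allele.
    """
    if len(haplotypes) % 2:
        raise ValueError("need an even number of haplotypes")
    return [[a + b for a, b in zip(haplotypes[i], haplotypes[i + 1])]
            for i in range(0, len(haplotypes), 2)]


def add_sequencing_error(haplotypes, error_rate, rng=random):
    """Flip each 0/1 call independently with probability error_rate."""
    return [[1 - x if rng.random() < error_rate else x for x in row]
            for row in haplotypes]


h = [[0, 1], [1, 1], [0, 0], [1, 0]]
print(unphase(h))  # two individuals, genotype counts per site
```

Whatever you do here, do the same thing (or its real-world equivalent) to the data you'll eventually feed the trained net.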

Sounds bad, I know, but just use the neural net that sits on your shoulders and try to imagine what the artificial one in your computer is seeing in the data, and then adjust the sims accordingly. Plus, when you are your own inconsiderate jerk who made the thing a pain to use, at least the feedback loop is usually pretty short.  

You’re almost done!  You made your training data, and trained your net, then you tested it against unseen data, and finally you ran real data through.  And guess what, it found stuff! Time to write a paper! So now you have to explain what it found. One approach is to fall back to non-neural net methods.  Like apply some pop gen summary stats just as you always have.  It’s cool, go ahead, I won’t judge.  Then maybe compare the patterns of the net versus the traditional methods and figure out how to explain it to someone using traditional summary stats.  Or maybe don’t.  Instead you could just point to the quality of your neural net on the test data and rest your case.  If a reviewer asks tell them “Stats! We don’t need no stinkin’ stats!”.  I’m cool with that too.  Whatever works for the story you want to tell.  It might even be helpful to read some cool machine learning papers in biology and see how other people reported discoveries made with a black box network.

Here’s my final pitch: we’ve been doing pattern recognition in pop gen for years (though rarely calling it that). I (and other like-minded folks) think it’s time we adopt more powerful pattern recognition approaches.  I’m excited about neural nets and their potential in this regard, but they are different and will require a little retooling on your part. So, if you decide to give it a whirl and get stuck, reach out to me.  I’m happy to chat and offer any tips I can.  The best ways to reach me are on Twitter @flagelbagel or my email, which is a gmail account with the prefix flag0010.

OK, I’m getting it, and I see this whole thing as being terribly reliant on the quality of my coalescent sims.  Seems like a house of cards.  Right?


Yes, how well your coalescent sims match your real data really matters. I can’t stress that enough.  You can’t train a neural net on sims that are appropriate for elephant demography and expect it to do well on fruit flies.  In fact, you probably can’t even train on the demography of one species of fruit fly and hope it will work well on another. But hey, it’s not hard to get heaps of genomic data, so just estimate the demography and simulate accordingly.  This is kinda like specifying your prior, and in my opinion that's a good thing. 
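For the very simplest case, a single constant-size population under neutrality, Watterson's estimator gets you a starting point: θ_W = S / a1, with a1 = Σ 1/i for i = 1..n-1, and θ = 4·Ne·μ·L. A minimal sketch; the mutation rate in the example is an assumed value, and real demography (bottlenecks, growth, structure) needs proper demographic inference rather than this one-liner.

```python
def wattersons_theta(segregating_sites, n_chromosomes):
    """Watterson's estimator theta_W = S / a1 for one locus."""
    a1 = sum(1.0 / i for i in range(1, n_chromosomes))
    return segregating_sites / a1


def effective_population_size(theta, mutation_rate, locus_length):
    """Back out Ne from theta = 4 * Ne * mu * L (diploid, constant size)."""
    return theta / (4 * mutation_rate * locus_length)


theta = wattersons_theta(40, 20)
# Assumed mutation rate of 1e-8 per bp per generation over a 20 kb locus.
ne = effective_population_size(theta, 1e-8, 20_000)
print(f"theta_W = {theta:.2f}, Ne = {ne:,.0f}")
```

With 40 segregating sites among 20 chromosomes, θ_W ≈ 11.27, and with that assumed μ over 20 kb, Ne ≈ 14,000. That θ is also exactly what ms expects after `-t`, since ms defines θ as 4·N0·μ for the whole locus.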

You also should know that neural nets do well at interpolation, but they struggle with extrapolation.  So for example, if you train one to look for sweeps when the effective population size ranges from 10K-50K, it will work well on real data that falls in that same range, even for points in that range you never actually simulated (interpolation). But it probably won’t work well for real data that lies outside that range (extrapolation).  So you have to be cognizant of the extremes in your data and train all the way to those extremes.  And bear in mind, the coalescent is quite stochastic.  So things like the effective population size are going to have a big range across different loci in the genome, and you need your sims to be similarly extreme.    
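One way to act on this: draw simulation parameters across the whole range you expect, and flag any real-data estimates that land outside it. A minimal sketch using Ne, with the 10K-50K bounds from the example above; drawing log-uniformly rather than uniformly is an assumption, and the function names are illustrative.

```python
import math
import random


def sample_ne(n_sims, low=10_000, high=50_000, seed=1):
    """Draw Ne values log-uniformly over [low, high] for training sims."""
    rng = random.Random(seed)
    log_low, log_high = math.log(low), math.log(high)
    return [math.exp(rng.uniform(log_low, log_high)) for _ in range(n_sims)]


def outside_training_range(estimates, low=10_000, high=50_000):
    """Return real-data estimates that would force the net to extrapolate."""
    return [x for x in estimates if not low <= x <= high]


print(outside_training_range([5_000, 20_000, 80_000]))
```

Anything that function returns is a locus where you should trust the net less, or widen the range and simulate again.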

Finally, traditional summary-stats-based pop gen tends to rely heavily on coalescent simulations too. A common practice is to generate oodles of summary stat values from the real data and then use coalescent sims to determine the cutoff values for significance. So if you do that kind of pop gen already, then either way you slice it, it’s turtles and coalescent sims all the way down.
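That cutoff step is just an empirical quantile of the statistic across null simulations. A minimal sketch, assuming you've already computed the statistic (say, Tajima's D) on each neutral sim; the function name and the one-tailed convention are illustrative.

```python
def empirical_cutoff(null_values, alpha=0.05, tail="lower"):
    """Cutoff such that about alpha of the simulated null values lie beyond it.

    tail="lower" suits statistics like Tajima's D, where sweeps push values
    negative; tail="upper" flags unusually large values instead.
    """
    values = sorted(null_values)
    k = int(alpha * len(values))  # null values allowed beyond the cutoff
    if tail == "lower":
        return values[k]
    if tail == "upper":
        return values[len(values) - 1 - k]
    raise ValueError("tail must be 'lower' or 'upper'")
```

A real-data value below the lower cutoff (or above the upper one) is then "significant" at roughly level alpha, with all the same dependence on how well the sims match reality.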


There’s so much hype around neural nets. What happens in 5 years when they become the next “big data” or “VR” or “wearables”?


Yes, there is a LOT of hype around deep learning in tech right now, and you are absolutely right to be a skeptic when you see this kind of hype. “Deep Learning for X” startups are launching daily, and many are selling pixie dust. They and their financiers will be screwed soon enough, but you should know that the core tech behind neural nets is real and it’s not going away.  This is absolutely happening.  If you see yourself being in a quantitative field in the future, it's probably time to add neural nets to your toolbox.

1 comment:

  1. Great post! You might know this already but when you ask 'what is a GPU?' -- most computers have one in their graphics card -- which tensorflow can use.
