By William Falcon, Founder PyTorch Lightning
We recently released 0.8.1, a major milestone for PyTorch Lightning. With incredible user adoption and growth, we’re continuing to build tools that make AI research easy.
This major release puts us on track for final API changes for our v1.0.0 coming soon!
PyTorch Lightning is a very lightweight structure for PyTorch: it’s more of a style guide than a framework. But once you structure your code, we give you GPU, TPU and 16-bit precision support for free, and much more!
Lightning is just structured PyTorch
This release adds a major new package to Lightning: a multi-GPU metrics package!
There are two key facts about the metrics package in Lightning.
- It works with plain PyTorch!
- It automatically handles multi-GPU training for you via DDP. That means that whether you calculate the accuracy on one GPU or on 20, Lightning combines the results across processes for you automatically.
The metrics package also includes mappings to sklearn metrics that bridge NumPy, sklearn and PyTorch, as well as a fancy base class you can subclass to implement your own metrics.
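Why does combining a metric across GPUs need care? Below is a minimal pure-Python simulation, not Lightning's code: each "process" holds hypothetical counts for its own data shard, and we compare naively averaging per-GPU accuracies with sum-reducing the raw counts first (as `torch.distributed.all_reduce` with `ReduceOp.SUM` would) and dividing once. When shard sizes differ, only the second gives the true global accuracy.

```python
def all_reduce_sum(per_process_values):
    """Simulate all_reduce(SUM): every process ends up with the elementwise sum."""
    return [sum(vals) for vals in zip(*per_process_values)]


# Hypothetical per-GPU results: [num_correct, num_samples] for each process's shard.
per_gpu = [[90, 100], [45, 50], [10, 50]]

# Naive: average the per-GPU accuracies (biased when shard sizes differ).
naive_acc = sum(c / n for c, n in per_gpu) / len(per_gpu)

# Reduce the raw counts across processes first, then divide once.
total_correct, total_samples = all_reduce_sum(per_gpu)
global_acc = total_correct / total_samples

print(round(naive_acc, 4))   # 0.6667
print(round(global_acc, 4))  # 0.725
```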
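import torch
from pytorch_lightning.metrics import TensorMetric

class RMSE(TensorMetric):
    def forward(self, x, y):
        # root-mean-squared error between predictions x and targets y
        return torch.sqrt(torch.mean(torch.pow(x - y, 2.0)))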
The metrics package has over 18 metrics currently implemented (including functional metrics). Check out our documentation for a full list!
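To make the functional-versus-class distinction concrete, here is a pure-Python mirror of the RMSE metric above, handy for sanity-checking values by hand. It is only an illustration of the pattern: `rmse` and `RMSEMetric` are hypothetical names, not part of the Lightning API. The functional form is a plain function of its inputs; the class form wraps it in a named, callable object.

```python
import math
from typing import Sequence


def rmse(preds: Sequence[float], target: Sequence[float]) -> float:
    """Functional form: sqrt(mean((preds - target) ** 2))."""
    if len(preds) != len(target) or not preds:
        raise ValueError("preds and target must be non-empty and the same length")
    mse = sum((p - t) ** 2 for p, t in zip(preds, target)) / len(preds)
    return math.sqrt(mse)


class RMSEMetric:
    """Class form: a named, callable object that delegates to the function."""

    def __init__(self, name: str = "rmse"):
        self.name = name

    def __call__(self, preds, target):
        return rmse(preds, target)


metric = RMSEMetric()
print(metric([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # sqrt(4/3), about 1.1547
```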
This release also cleaned up some really cool debugging tools we’ve had in Lightning for a while. The overfit_batches flag now lets you overfit on a small subset of data to sanity-check that your model doesn’t have major bugs.
The logic is that if you can’t even overfit 1 batch of data, then there’s no point in training on the full dataset. This can help you check that you’ve implemented something correctly, or that your math is right.
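The idea behind this sanity check can be shown without any framework. The sketch below is a hypothetical toy, not what Lightning runs: it fits a two-parameter linear model with plain gradient descent on one fixed batch, over and over. A correct model and gradient should drive the loss on that batch close to zero; if it stalls, something in the model, loss or update is likely wrong.

```python
def mse_and_grads(w, b, xs, ys):
    """Mean squared error of y_hat = w * x + b, plus its gradients w.r.t. w and b."""
    n = len(xs)
    errs = [(w * x + b) - y for x, y in zip(xs, ys)]
    loss = sum(e * e for e in errs) / n
    dw = 2 * sum(e * x for e, x in zip(errs, xs)) / n
    db = 2 * sum(errs) / n
    return loss, dw, db


def overfit_one_batch(xs, ys, steps=2000, lr=0.05):
    """Repeatedly train on the same single batch and record the loss each step."""
    w, b = 0.0, 0.0
    history = []
    for _ in range(steps):
        loss, dw, db = mse_and_grads(w, b, xs, ys)
        history.append(loss)
        w -= lr * dw
        b -= lr * db
    return history


batch_x = [0.0, 1.0, 2.0, 3.0]
batch_y = [1.0, 3.0, 5.0, 7.0]  # generated by y = 2x + 1, so a perfect fit exists
history = overfit_one_batch(batch_x, batch_y)
print(history[0])                         # 21.0
print(history[-1] < 1e-3 * history[0])    # True: the batch was memorized
```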
If you do this in Lightning, this is what you will get:
Faster multi-GPU training
Another key part of this release is the speed-ups we made to distributed training via DDP. They come from allowing DDP to work with num_workers>0 in DataLoaders.
Normally, when you launch DDP via .spawn() and set num_workers>0 in your DataLoader, your program will likely freeze and never start training (this is also true outside of Lightning).
The solution for most people is to set num_workers=0, but that means your training is going to be reaaaaally slow. To enable num_workers>0 AND DDP, we now launch DDP under the hood without spawn. This also removes a lot of other weird restrictions, like the need to pickle everything and the fact that the model weights are not available in the main process once training has finished (because the weights were learned in a subprocess with separate memory).
As a result, this DDP mode is much faster than the spawn-based one. But of course, we keep both for flexibility:
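Both spawn restrictions come from the same mechanism: a spawned child is a fresh interpreter that receives its arguments by pickling them. The sketch below mimics that hand-off with a `pickle` round-trip instead of starting real processes (an assumption made for illustration); `spawn_handoff`, `train_in_child` and `is_picklable` are hypothetical helpers. It shows that updates made in the child never touch the parent's weights, and that some objects, such as lambdas, cannot be sent at all.

```python
import pickle


def spawn_handoff(obj):
    """Mimic a spawned child receiving its arguments: a pickle round-trip
    yields an independent copy living in the child's own memory."""
    return pickle.loads(pickle.dumps(obj))


def train_in_child(weights, steps=3, lr=0.5):
    """Toy 'training' that only ever updates the child's copy of the weights."""
    child_weights = spawn_handoff(weights)
    for _ in range(steps):
        child_weights["w"] -= lr * child_weights["w"]  # toy gradient step
    # With a real spawned process, the parent would only see this result if it
    # were explicitly sent back (e.g. through a queue or a checkpoint file).
    return child_weights


def is_picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


parent_weights = {"w": 8.0}
child_result = train_in_child(parent_weights)
print(parent_weights["w"])  # 8.0: the parent's weights never changed
print(child_result["w"])    # 1.0: the update lives in the child's copy
print(is_picklable(parent_weights), is_picklable(lambda x: x + 1))  # True False
```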
from pytorch_lightning import Trainer
# very fast :)
Trainer(distributed_backend='ddp')
# very slow
Trainer(distributed_backend='ddp_spawn')
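Launching DDP without spawn typically means starting the same training script once per GPU, with each copy learning its rank from environment variables. Here is a minimal sketch, assuming a single node and the `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE` and `RANK` variables read by `torch.distributed`'s `env://` initialization; `LOCAL_RANK`, the default port and `build_ddp_launch_commands` are illustrative assumptions, and this is not Lightning's actual launcher. It only builds the commands; passing each pair to `subprocess.Popen(cmd, env=env)` would start them.

```python
import os
import sys


def build_ddp_launch_commands(num_gpus, master_addr="127.0.0.1", master_port=12910):
    """Sketch: the current process acts as rank 0 (it would need the same variables,
    with rank 0); ranks 1..num_gpus-1 re-run this script as ordinary processes."""
    commands = []
    for local_rank in range(1, num_gpus):
        env = dict(os.environ)  # inherit the parent's environment
        env.update({
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
            "WORLD_SIZE": str(num_gpus),
            "RANK": str(local_rank),        # single node: global rank == local rank
            "LOCAL_RANK": str(local_rank),  # which GPU this copy should use
        })
        cmd = [sys.executable] + sys.argv  # same interpreter, same script and args
        commands.append((cmd, env))
    return commands


launches = build_ddp_launch_commands(num_gpus=4)
print([env["LOCAL_RANK"] for _, env in launches])  # ['1', '2', '3']
```

Because every copy is a normal top-level process rather than a spawned child, nothing has to be pickled across a process boundary to get training started.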
Other cool features of the release
- .test() now automatically loads the best model weights for you!
model = Model()
trainer = Trainer()
trainer.fit(model)
# automatically loads the best weights!
trainer.test()
- You can now install Lightning via conda
conda install pytorch-lightning -c conda-forge
- ModelCheckpoint tracks the path to the best weights
ckpt_callback = ModelCheckpoint(...)
trainer = Trainer(checkpoint_callback=ckpt_callback)
trainer.fit(model)
# path to the checkpoint with the best monitored score
best_weights = ckpt_callback.best_model_path
- Automatically move data to correct device during inference
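import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data
class LitModel(LightningModule):
    @auto_move_data
    def forward(self, x):
        return x
model = LitModel()
x = torch.rand(2, 3)
model = model.cuda(2)
# this works! x is moved to the model's device before forward runs
model(x)
- Many more speed improvements, including single-TPU speed-ups (we already support multi-TPU out of the box as well)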
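The `ModelCheckpoint` and `.test()` items in the list above both rely on remembering which checkpoint scored best. Below is a minimal sketch of that bookkeeping, assuming a monitored validation loss where lower is better; `BestCheckpointTracker` and the file names are hypothetical, and only the `best_model_path` name mirrors the attribute used above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BestCheckpointTracker:
    """Hypothetical tracker: remembers the checkpoint with the best monitored value."""
    mode: str = "min"  # "min" for losses, "max" for accuracies
    best_score: Optional[float] = None
    best_model_path: str = ""

    def __post_init__(self):
        if self.mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")

    def is_better(self, score: float) -> bool:
        if self.best_score is None:
            return True
        return score < self.best_score if self.mode == "min" else score > self.best_score

    def update(self, score: float, path: str) -> bool:
        """Record a newly saved checkpoint; return True if it is the new best."""
        if self.is_better(score):
            self.best_score = score
            self.best_model_path = path
            return True
        return False


tracker = BestCheckpointTracker(mode="min")
for epoch, val_loss in enumerate([0.9, 0.6, 0.7, 0.5, 0.55]):
    tracker.update(val_loss, f"checkpoints/epoch={epoch}.ckpt")

print(tracker.best_model_path)  # checkpoints/epoch=3.ckpt
print(tracker.best_score)       # 0.5
```

A test step that "loads the best weights" then only needs to read from the path this kind of tracker recorded.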
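The idea behind `@auto_move_data` in the list above can be sketched without a GPU: a decorator walks the arguments and moves anything tensor-like onto the model's device before `forward` runs. `FakeTensor`, `move_to_device` and `auto_move_inputs` are hypothetical stand-ins for illustration, not Lightning's implementation.

```python
import functools


class FakeTensor:
    """Stand-in for a tensor: it only records which device it lives on."""

    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        return FakeTensor(self.data, device)


def move_to_device(obj, device):
    """Recursively move tensors inside lists, tuples and dicts to `device`."""
    if isinstance(obj, FakeTensor):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj  # non-tensor values pass through unchanged


def auto_move_inputs(fn):
    """Hypothetical decorator: move every argument to self.device before calling fn."""
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        args = move_to_device(args, self.device)
        kwargs = move_to_device(kwargs, self.device)
        return fn(self, *args, **kwargs)
    return wrapper


class ToyModel:
    def __init__(self, device):
        self.device = device

    @auto_move_inputs
    def forward(self, x):
        return x.device  # report where the input ended up


model = ToyModel(device="cuda:2")
print(model.forward(FakeTensor([1.0, 2.0])))  # cuda:2
```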
Try Lightning today
If you haven’t yet, give Lightning a chance 🙂
This video explains how to refactor your PyTorch code into Lightning.
Bio: William Falcon is an AI Researcher, and Founder at PyTorch Lightning. He is trying to understand the brain, build AI and use it at scale.
Original. Reposted with permission.