
Google Summer of Code 2022: Final Report

Getting accepted as a GSoC contributor and working with Dr. Kevin Murphy was a dream!

In this post, I share my GSoC’22 journey with the TensorFlow organization on the project “JAX ML Textbook”, mentored by Dr. Kevin Murphy. The project covers tasks related to the upcoming textbook Probabilistic Machine Learning: Advanced Topics (book2) and the existing textbook Probabilistic Machine Learning: An Introduction (book1), both by Dr. Murphy. I contributed mainly to two repos: 1) probml/pyprobml, a public repo where the code is published; 2) probml/bookv2, a private repo for changes to the LaTeX source code of the textbook.

Highlights

Main contributions

Broadly, my GSoC project had 3 main tasks:

| No. | Task | Time allocated (%) |
|---|---|---|
| 1 | JAX implementation of algorithms | 30 |
| 2 | Improving figure quality with matplotlib tricks | 40 |
| 3 | Refactoring the codebase | 30 |

1. JAX implementation of algorithms:

In this task, I converted existing Python code to the JAX framework or created new educational demos that generate figures for the textbook. I mainly worked on examples of Markov Chain Monte Carlo (MCMC) algorithms, but I also created a few notebooks on other topics. The major pieces of work are as follows:

1) Automatic Differentiation Variational Inference (ADVI) from scratch in JAX [Notebook]: This was the first demo I created for the textbook. I implemented the ADVI algorithm from scratch on a beta-binomial model, following the original ADVI paper. This is where I started to explore JAX tricks such as using lax.scan() instead of a for loop to speed up the program (by 50x in this demo!).
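The lax.scan() trick can be sketched on a toy problem. The snippet below is not the ADVI code from the notebook: the quadratic `loss`, the learning rate, and all function names are hypothetical stand-ins chosen for illustration. It only shows how the body of a Python for loop becomes a `step` function that `lax.scan` compiles as a single loop.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 100
LEARNING_RATE = 0.1


def loss(theta):
    # Toy objective (an assumption), standing in for a negative ELBO.
    return jnp.sum((theta - 3.0) ** 2)


grad_loss = jax.grad(loss)


def step(theta, _):
    # One gradient-descent update. scan passes the first return value
    # to the next iteration and stacks the second one into an array.
    new_theta = theta - LEARNING_RATE * grad_loss(theta)
    return new_theta, loss(new_theta)


def fit_with_loop(theta0, n_steps):
    # Plain Python loop: every iteration is dispatched separately.
    theta, losses = theta0, []
    for _ in range(n_steps):
        theta, current = step(theta, None)
        losses.append(current)
    return theta, jnp.stack(losses)


@jax.jit
def fit_with_scan(theta0):
    # Same computation, expressed as one compiled loop.
    return lax.scan(step, theta0, xs=None, length=N_STEPS)


theta_scan, losses_scan = fit_with_scan(jnp.zeros(2))
```

Both versions return the same parameters and loss history; any speed-up depends on the model and hardware, so the 50x figure above applies to the notebook, not to this sketch.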


2) Posterior of a beta-Bernoulli model using MCMC [PR] [Notebook]: Approximated the posterior of a beta-binomial model using the blackjax sampling library and the No-U-Turn Sampler (NUTS). This demo was my first introduction to the blackjax library, which I used a lot in later demos.


3) Laplace approximation for beta-binomial model [PR] [Notebook]: Approximated the posterior of a beta-Bernoulli model for the coin toss problem using the Laplace approximation, implemented from scratch.
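As a rough illustration of the idea (not the notebook's implementation, which may work in a different parameterization), the sketch below fits a Gaussian at the mode of the Beta posterior of a coin's bias. The function name and the default Beta(1, 1) prior are assumptions. Because the posterior is Beta(alpha, beta) in closed form, the mode and the curvature at the mode can be written down directly.

```python
import math


def laplace_beta_bernoulli(heads, tails, a=1.0, b=1.0):
    """Gaussian (mean, std) approximating the posterior of a coin's bias.

    Prior: Beta(a, b). Posterior: Beta(a + heads, b + tails).
    The Gaussian is centred at the posterior mode, and its variance is
    the inverse of the negative second derivative of the log posterior
    evaluated at that mode.
    """
    alpha, beta = a + heads, b + tails
    if alpha <= 1 or beta <= 1:
        raise ValueError("an interior mode requires alpha > 1 and beta > 1")
    mode = (alpha - 1) / (alpha + beta - 2)
    # d^2/dtheta^2 of (alpha-1)*log(theta) + (beta-1)*log(1-theta), negated
    neg_hessian = (alpha - 1) / mode**2 + (beta - 1) / (1 - mode) ** 2
    return mode, math.sqrt(1.0 / neg_hessian)


mean, std = laplace_beta_bernoulli(heads=5, tails=5)
```

The approximation is centred on the mode rather than the posterior mean, so it is least accurate when the posterior is skewed, for example after only a few tosses.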


4) Markov chain convergence on uniform distribution [Notebook] [PR]: Recreated the figure that illustrates Markov chain convergence.
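To make the idea of convergence to a uniform distribution concrete, here is a small sketch using a lazy random walk on a cycle. This chain is chosen only for illustration and is not necessarily the one in the textbook figure; the function names are hypothetical. Its transition matrix is doubly stochastic, so the uniform distribution is stationary.

```python
import numpy as np


def lazy_cycle_walk(n_states):
    """Transition matrix: stay with prob. 1/2, else move left or right."""
    T = np.zeros((n_states, n_states))
    for i in range(n_states):
        T[i, i] += 0.5
        T[i, (i - 1) % n_states] += 0.25
        T[i, (i + 1) % n_states] += 0.25
    return T


def evolve(p0, T, n_steps):
    """Return the state distribution after 0, 1, ..., n_steps transitions."""
    history = [np.asarray(p0, dtype=float)]
    for _ in range(n_steps):
        history.append(history[-1] @ T)
    return np.stack(history)


n = 5
p0 = np.zeros(n)
p0[0] = 1.0  # start with all probability mass on state 0
history = evolve(p0, lazy_cycle_walk(n), n_steps=200)

# Total variation distance to the uniform distribution at each step.
tv_distance = 0.5 * np.abs(history - 1.0 / n).sum(axis=1)
```

Plotting `tv_distance` against the step index shows the distance to the uniform distribution shrinking as the chain runs.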


5) MCMC diagnostic: R-hat [PR] [Notebook]: R-hat (the potential scale reduction factor) is an MCMC diagnostic that quantifies the convergence of MCMC samples. I reproduced the figure given in this paper to illustrate the difference between split R-hat and non-split R-hat.
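The difference between the two diagnostics can be sketched with the classic (non-rank-normalized) formulas. This is a simplified illustration on synthetic chains, not the notebook code or the paper's exact setup; `make_trending_chains` is a hypothetical helper that builds chains whose means agree but which drift in opposite directions.

```python
import numpy as np


def rhat(chains):
    """Classic R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1]
    between = n * chains.mean(axis=1).var(ddof=1)   # B
    within = chains.var(axis=1, ddof=1).mean()      # W
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)


def split_rhat(chains):
    """Split each chain in half, then treat the halves as separate chains."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    halves = np.concatenate(
        [chains[:, :half], chains[:, half : 2 * half]], axis=0
    )
    return rhat(halves)


def make_trending_chains(rng, n_chains=4, n_draws=1000, noise=0.5):
    # Hypothetical non-stationary chains: equal means, opposite trends.
    trend = np.linspace(-1.0, 1.0, n_draws)
    signs = np.where(np.arange(n_chains) % 2 == 0, 1.0, -1.0)
    return signs[:, None] * trend + noise * rng.normal(size=(n_chains, n_draws))


chains = make_trending_chains(np.random.default_rng(0))
plain, split = rhat(chains), split_rhat(chains)
```

On these chains the plain statistic stays close to 1 because the chain means agree, while the split version is clearly larger because the two halves of each chain disagree.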


6) MCMC diagnostics: Trace plots and rank plots [PR] [Notebook]: Trace plots and rank plots are also MCMC diagnostics, used to judge convergence. I reproduced the numpyro demo using the blackjax library. This was the first time I used the arviz library (a popular plotting library for probabilistic ML algorithms) together with blackjax.

[Figures: plots for the diffuse prior and for the sensible prior]

7) Non-centered parameterization in a hierarchical Bayesian model [PR] [Notebook]: Recreated a pymc example in JAX that shows the problem with using a Bayesian hierarchical model without the reparameterization trick.
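The reparameterization itself is simple to state. In the centered form a group-level parameter is drawn as theta ~ Normal(mu, tau); in the non-centered form one draws z ~ Normal(0, 1) and sets theta = mu + tau * z. The sketch below uses hypothetical function names and plain NumPy sampling rather than the notebook's model, and only checks that both forms describe the same distribution.

```python
import numpy as np


def sample_centered(rng, mu, tau, size):
    # theta is drawn directly, so its scale depends on tau.
    return rng.normal(loc=mu, scale=tau, size=size)


def sample_non_centered(rng, mu, tau, size):
    # z has unit scale whatever tau is; theta is a deterministic
    # transform of z, which is what a sampler would explore instead.
    z = rng.normal(size=size)
    return mu + tau * z


mu, tau = 1.5, 0.1
theta_c = sample_centered(np.random.default_rng(0), mu, tau, size=100_000)
theta_nc = sample_non_centered(np.random.default_rng(1), mu, tau, size=100_000)
```

The distributions match, but a gradient-based sampler working on z sees a geometry that does not change with tau, which is what avoids the funnel-shaped posterior of the centered form.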


8) Bayesian neural networks (BNN) using SGD & SGLD [PR] [Notebook]: This demo compares the predictive uncertainty of a Bayesian algorithm (SGLD) with that of a non-Bayesian algorithm (SGD). I first explored the flax library to create a multi-layer perceptron (MLP) neural network model, then used the SGLD algorithm from the blackjax library to sample the weights.
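For readers unfamiliar with SGLD, the update is an SGD step on the log posterior plus Gaussian noise whose variance equals the step size. The sketch below applies it to a deliberately tiny problem, the mean of Gaussian data with unit noise, instead of a neural network, and it does not use blackjax or flax. The function name, prior scale, and step size are all assumptions made for illustration.

```python
import numpy as np


def sgld_gaussian_mean(data, n_steps, step_size, batch_size, rng,
                       prior_scale=10.0):
    """Draw SGLD samples of the mean of unit-variance Gaussian data."""
    n = len(data)
    theta = 0.0
    samples = np.empty(n_steps)
    for t in range(n_steps):
        batch = rng.choice(data, size=batch_size, replace=False)
        grad_log_prior = -theta / prior_scale**2
        # Minibatch gradient of the log-likelihood, rescaled to the full data.
        grad_log_lik = (n / batch_size) * np.sum(batch - theta)
        noise = rng.normal(scale=np.sqrt(step_size))
        theta = theta + 0.5 * step_size * (grad_log_prior + grad_log_lik) + noise
        samples[t] = theta
    return samples


rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=1.0, size=100)
samples = sgld_gaussian_mean(data, n_steps=5000, step_size=1e-3,
                             batch_size=10, rng=rng)
posterior_draws = samples[500:]  # discard burn-in
```

Unlike SGD, which would settle near a single value, the retained draws keep fluctuating around the data mean, and that spread is what provides an uncertainty estimate.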


9) Change of variable in Hamiltonian Monte Carlo (HMC) [PR] [Notebook]: In this demo I illustrated the need for a change of variable when defining the log joint density function of Bayesian hierarchical models. High-level inference libraries such as pymc, pyro, and numpyro do this implicitly, but it must be done manually when using a low-level library such as blackjax. I also added this demo to the blackjax library’s examples [PR].
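The manual step amounts to adding a log-Jacobian term. As a minimal sketch, assume a positive scale parameter sigma with an Exponential(1) prior (an assumption for this example, not the notebook's model) that is sampled on the unconstrained scale phi = log(sigma). The density over phi is the prior evaluated at exp(phi), times |d sigma / d phi| = exp(phi).

```python
import numpy as np


def log_prior_sigma(sigma, rate=1.0):
    # Exponential(rate) log-density on the constrained scale, sigma > 0.
    return np.log(rate) - rate * sigma


def log_density_unconstrained(phi, rate=1.0):
    # Correct density over phi = log(sigma): include the log-Jacobian.
    sigma = np.exp(phi)
    log_jacobian = phi  # log|d sigma / d phi| = log(exp(phi)) = phi
    return log_prior_sigma(sigma, rate) + log_jacobian


def log_density_missing_jacobian(phi, rate=1.0):
    # Common mistake: transform the variable but drop the Jacobian.
    return log_prior_sigma(np.exp(phi), rate)


def trapezoid(y, x):
    # Simple trapezoidal rule, written out to avoid NumPy version differences.
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0)


phi_grid = np.linspace(-20.0, 5.0, 20001)
mass_correct = trapezoid(np.exp(log_density_unconstrained(phi_grid)), phi_grid)
mass_wrong = trapezoid(np.exp(log_density_missing_jacobian(phi_grid)), phi_grid)
```

With the Jacobian the density over phi integrates to about 1; without it the mass grows with the width of the grid, so a sampler would be targeting the wrong distribution.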


2. Improving figure quality with matplotlib tricks:

In the first draft of the book, most figures were made with the defaults of matplotlib (a plotting library), which did not fit the textbook settings. The issues included fonts in figures being small compared to the caption, font styles not matching the text, and missing labels (x-label, y-label, legend, etc.). I coordinated with Zeel (a fellow GSoC contributor) and improved almost all of the figures that needed it by latexifying them. One example of the improvement is shown below.

[Figure: the same plot before and after the PR]
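In matplotlib terms, latexification mostly means creating the figure at its final printed size and setting font sizes to match the surrounding text. The sketch below is not the helper used in the pyprobml repo; `latexify_rc`, `apply_latexify`, and the default sizes are hypothetical. It only collects standard rcParams keys in one place.

```python
def latexify_rc(width_in=3.0, height_in=2.0, font_size=9):
    """Return matplotlib rcParams for a figure included at 1:1 scale."""
    return {
        "figure.figsize": (width_in, height_in),
        "font.size": font_size,
        "axes.labelsize": font_size,
        "axes.titlesize": font_size,
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        "font.family": "serif",
        # Set to True only if a LaTeX installation is available.
        "text.usetex": False,
        "savefig.bbox": "tight",
    }


def apply_latexify(**kwargs):
    # Imported lazily so the settings can be inspected without matplotlib.
    import matplotlib

    matplotlib.rcParams.update(latexify_rc(**kwargs))


params = latexify_rc(width_in=6.0, height_in=4.0, font_size=10)
```

Calling `apply_latexify()` before creating a figure makes every later plot use these settings; because the figure is saved at its final size, LaTeX does not need to rescale it and the fonts keep their intended size.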

I can classify this task into 3 sub-tasks:

| No. | Sub-task | Description/PRs |
|---|---|---|
| 1 | Latexification of figures | Example PRs: #723 #726 #891 #1101 #1038 |
| 2 | Editing the LaTeX source code to add updated figures | All edits were made in the private bookv2/ repo |
| 3 | Reviewing code from non-GSoC contributors who were helping us complete this task | I reviewed 50+ PRs in the pyprobml-review/pyprobml-review repo |

3. Refactoring pyprobml repo for better management:

There are 425+ notebooks in the pyprobml repo, and their code uses almost all of the ML libraries in Python. A codebase this large needs mechanisms to keep it manageable. During my GSoC, I coordinated with Zeel on refactoring tasks including, but not limited to: organizing the structure of the repo, converting .py files to .ipynb notebooks, redirecting figure code links from the textbook, creating workflows that run on PRs, creating dashboards of notebooks, generating well-organized readme files, checking for dead URLs in the textbook, converting figure PDFs to CMYK format, and resolving comments on the book from MIT Press. Some of these tasks are listed in the following table:

| Task | PR |
|---|---|
| Resolved notebook errors caused by library updates or other reasons | #765 #774 #935 #936 #960 |
| Detected dead URLs in the book and created a dummy notebook for each figure that has more than one notebook; the dummy notebook contains links to the original notebooks | #781 |
| Added auto-generated readme.md files for book1 | #815 #816 |
| Moved tutorial notebooks to their corresponding chapters programmatically | #841 |
| Moved 49 duplicated notebooks having the same names from notebooks/misc to deprecated/ | #858 |
| Added a mapping from figure name to figure height, which removes the need to look in the LaTeX source of the book when setting the figure height in code | #869 #911 |
| Added notebooks.md, an index of all notebooks; clicking the hyperlink in a figure’s caption takes the reader to that notebook’s entry in this table | #932 |
| Renamed notebooks from camel case to snake case | #942 |
| Removed suffixes such as _pymc and _blackjax from book2 notebooks | #940 |
| Updated book1 redirection links in the Firestore database to introduce the new redirection to notebooks.md | #1100 |
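Several of these tasks were done programmatically. As one hedged illustration, a camel-case to snake-case rename could look like the sketch below. This is not the script from the PR, and the function names and the example file name are made up.

```python
import re
from pathlib import Path


def camel_to_snake(name):
    """Convert a camelCase or PascalCase name to snake_case."""
    # Put an underscore before a capitalised word: "linregDemo" -> "linreg_Demo"
    s = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Split a lowercase letter or digit from a following capital.
    s = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s)
    # Collapse doubled underscores created from names that already had one.
    return re.sub(r"_+", "_", s).lower()


def snake_case_filename(filename):
    """Apply the conversion to the stem only, keeping the extension."""
    path = Path(filename)
    return camel_to_snake(path.stem) + path.suffix


new_name = snake_case_filename("linregDemo.ipynb")  # hypothetical file name
```

A real rename also has to update every link that points at the old file name, which is why a task like this ties in with the redirection work listed in the table.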

Conclusion

GSoC’22 was a great learning experience for me. During this period, I improved a lot in both technical and non-technical skills. Special thanks to my mentor Dr. Murphy and to Prof. Nipun Batra for their consistent support and this amazing opportunity. I would also like to thank Zeel for always helping me whenever I got stuck.