There are many interesting threads in this post, one of which is using “non-standard interpretations” of programs, enabling the compiler to augment the human-written code with the extra pieces necessary to get gradients, propagate uncertainties, etc. I wonder whether there’s a more unified discussion of the potential of these methods. I suspect that a lot of “solvers” (each typically with its own DSL for specifying the problem) might be nicely formulated in such a framework. (Particularly in the case of autodiff, I found recent work/talks by Conal Elliott and Tom Minka quite enlightening.)
Tangentially, thinking about Julia: while one is initially awed by the speed, and then the multiple dispatch, I wonder whether its deepest superpower (one we’re still discovering) might be the expressiveness to augment the compiler to do interesting things with a piece of code. Generic programming then acts as a lever to apply these improvements to a variety of use cases, and the speed is merely the icing on the cake!
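To make the "non-standard interpretation" idea concrete, here's a minimal sketch (my own toy illustration, not from the post) of forward-mode AD via dual numbers in Julia. The `Dual` type is a hypothetical miniature of what ForwardDiff.jl does for real; multiple dispatch is what lets unmodified generic code run under the new interpretation:

```julia
# Toy dual numbers: a value paired with its derivative. Overloading the
# arithmetic operators gives a "non-standard interpretation" under which
# ordinary Julia code computes its own derivative as it runs.
struct Dual <: Number
    val::Float64   # primal value
    der::Float64   # derivative with respect to the input
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(a::Real, b::Dual) = Dual(a * b.val, a * b.der)

# Any generic function written for plain numbers now works unchanged:
f(x) = 3 * x * x + 2 * x + 1
f(Dual(2.0, 1.0))   # Dual(17.0, 14.0): f(2) = 17 and f'(2) = 14
```

The same generic `f` runs on floats, duals, intervals, or unitful quantities; that reuse is exactly the lever described above.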
The whole ecosystem is bubbling with activity, and I'm very eager to see whether, in the coming years, it can transform into a solid foundation able to rival the layers of warts stacked on top of Python.
Just a nit: Zygote is not technically a language level library. In fact, none of the above libraries really are language level — they are all just Julia libraries, some of which use macros. With the exception of metaprogramming with macros, none of these libraries customize inference/opt to do something funky.
Zygote (for example) is based on a language level feature — generated functions which allow multistage programming in Julia — but beyond using this feature it does not configure or otherwise modify the normal compiler pipeline.
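For readers unfamiliar with the feature: a generated function receives the *types* of its arguments at compile time and returns an expression that becomes the compiled method body. A minimal sketch (my own toy example; Zygote actually uses the feature to fetch and transform the lowered IR of the function being differentiated):

```julia
# @generated: the body below runs at compile time, sees only the types
# of the arguments, and returns an Expr that becomes the compiled method.
@generated function unrolled_sum(t::NTuple{N,Float64}) where {N}
    ex = :(0.0)
    for i in 1:N
        ex = :($ex + t[$i])   # build ((0.0 + t[1]) + t[2]) + ... symbolically
    end
    return ex                 # a fresh, fully unrolled method body per N
end

unrolled_sum((1.0, 2.0, 3.0))   # 6.0, with no loop left at runtime
```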
With Julia, it is sometimes tough to tell — but Diffractor (for example) I would consider a language level library, as it modifies the traditional compiler pipeline to perform inference and optimization in a specific way.
I guess it depends on what you call “language-level”; I mentioned these libraries because they all make use of macros and/or AST manipulation, which I would put somewhere in a grey area between language and compiler.
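A small illustration of that grey area: a macro receives the unevaluated AST and can rewrite it before the compiler ever runs, which feels "language level" even though it ships as an ordinary library. A toy sketch (hypothetical macro, for illustration only):

```julia
# The macro sees the expression as data (an AST), not a value: it can
# stringify it, rewrite it, or wrap it before compilation.
macro trace(ex)
    label = string(ex)            # captured at macro-expansion time
    return quote
        result = $(esc(ex))       # evaluate the original expression
        println($label, " = ", result)
        result
    end
end

@trace 1 + 2 * 3   # prints "1 + 2 * 3 = 7", returns 7
```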
Autodiff isn't even the most impressive application of non-standard interpretation. The most impressive one is sparsity detection: if you have some arbitrary code that returns a matrix, you can infer which elements of the matrix are guaranteed to be zero, independent of the input, by taking all the branches of the program under a non-standard interpretation.
This discovered sparsity can be exploited by a lot of the generic algorithms in Julia that are able to work efficiently on sparse matrices. There is also a way of combining this with AD to get automatic sparse Hessians for optimization, which is huge.
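A minimal sketch of how that works (my own toy tracer; real Julia tools such as SparsityDetection.jl are far more careful about branches and indexing): run the code on a type that records only whether each value could possibly be nonzero.

```julia
# Non-standard interpretation for sparsity: track "can this be nonzero?"
# instead of an actual number. Sums can be nonzero if either input can;
# products only if both inputs can.
struct MaybeNonzero
    possible::Bool
end

Base.:+(a::MaybeNonzero, b::MaybeNonzero) = MaybeNonzero(a.possible || b.possible)
Base.:*(a::MaybeNonzero, b::MaybeNonzero) = MaybeNonzero(a.possible && b.possible)
Base.zero(::Type{MaybeNonzero}) = MaybeNonzero(false)

# Arbitrary user code that happens to build a diagonal matrix:
function build(x::AbstractVector{T}) where {T}
    A = fill(zero(T), 3, 3)
    for i in 1:3
        A[i, i] = x[i] * x[i]   # only the diagonal is ever written
    end
    return A
end

pattern = build([MaybeNonzero(true) for _ in 1:3])
# pattern[i, j].possible is true only for i == j: the off-diagonal
# entries are guaranteed zero for every input, so a generic algorithm
# can safely switch to a sparse (here diagonal) representation.
```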
Just a comment: you’re right on the money here. This is the dream that a few people in the Julia community are working towards.
The framework of abstract interpretation, when combined with multiple dispatch as a language design feature, is absolutely insane.
I think programming language enthusiasts might meditate on these points, and get quite excited about the direction that the Julia compiler implementation is heading.
I remember spending a summer using Template Model Builder (TMB), a useful R/C++ automatic differentiation (AD) framework, for working with accelerated failure time models. For these models, the survival to time t given covariates X is defined by S(t|X) = P(T>t|X) = S_0(t exp(-beta^T X)) for a baseline survival S_0(t). I wanted to use splines for the baseline survival and then use AD for gradients and random effects. Unfortunately, after implementing the splines in template C++, I found a web page entitled "Things you should NOT do in TMB" (https://github.com/kaskr/adcomp/wiki/Things-you-should-NOT-d...), which includes using if statements that depend on coefficients. In this case, the splines for S_0 depend on beta, which is exactly this excluded case :(. An older framework (ADMB) did not have this constraint, but dissemination of code was more difficult. Finally, PyTorch did not have an implementation of B-splines or of Laplace's approximation. Returning to my opening comment: there is no free lunch.
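For concreteness, here is a tiny sketch of the model structure in Julia (assuming, purely for illustration, a Weibull baseline rather than the splines the comment is about):

```julia
using LinearAlgebra: dot

# Accelerated failure time model: covariates rescale time.
# S(t | X) = S_0(t * exp(-beta' X)), with an illustrative Weibull baseline.
S0(t; lambda = 1.0, k = 1.5) = exp(-(t / lambda)^k)
S(t, X, beta) = S0(t * exp(-dot(beta, X)))

S(2.0, [0.5, 1.0], [0.1, -0.2])   # survival probability at t = 2
```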
There is definitely no free lunch; it's good to really delineate the engineering trade-offs you're making! A lot of this work actually comes from the fact that some people I work with were building tools that could efficiently handle dynamic control flow without requiring tracing (see the description of Zygote.jl https://arxiv.org/abs/1810.07951). I had to bring up the question: why? It's much harder to build, needs more machinery, and in some cases can make fewer assumptions and thus perform fewer fusions (a general form of vmap is much harder, for example, if you cannot trace; see KernelAbstractions.jl for details). This line of inquiry led to an example of why you might want to support such dynamic behaviors, so I'll leave it up to someone else to declare whether the maintenance or complexity cost is worth it to them. I wouldn't say that this means Jax or Tensorflow are doomed (far from it: simple ML architectures are quasi-static, so they're building for the correct audience), but it's good to know what exactly you're leaving out when you make a simplifying assumption.
Were you optimizing over the knots as well? Otherwise I can't see why this would be disallowed using either forward- or reverse-mode AD. An infinitesimal perturbation of beta will not cause t * exp(-beta^T x) to cross a knot, so the whole thing is smooth. (And with B-splines the derivatives are continuous from piece to piece anyway.) But in general I agree--a good spline implementation is the thing I miss most when moving from scipy.interpolate to jax.scipy. Given that the SciPy implementation is mostly F77 code written before I was born, I do not see this situation resolving itself anytime soon.
A short answer: this requires a basis matrix for the splines rather than interpolation.
A longer answer: the splines require a basis matrix B(t), so that g(S_0(t)) = B(t) gamma for a vector of parameters gamma and some transformation g of the survival. A classical choice would be to use M-splines and I-splines with g(S) = -log(S) and a penalised likelihood, with the constraint that the gammas should be increasing. In R, this would use the splines2 package, while in Julia one could use the splines2.jl package (disclaimer: which I maintain). The computational challenge is that the basis matrices need to be re-evaluated whenever the coefficients for the covariates (that is, the betas) change.
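To make "basis matrix" concrete, here's a rough sketch of evaluating a B-spline basis (Cox-de Boor recursion) and assembling B(t). This is my own illustration, not the splines2 implementation; real packages handle boundary knots, I-spline integration, and the monotonicity constraints properly.

```julia
# Cox-de Boor recursion for one B-spline basis function of degree k.
function bspline(t, knots, i, k)
    if k == 0
        return knots[i] <= t < knots[i+1] ? 1.0 : 0.0
    end
    left  = knots[i+k] == knots[i] ? 0.0 :
            (t - knots[i]) / (knots[i+k] - knots[i]) * bspline(t, knots, i, k - 1)
    right = knots[i+k+1] == knots[i+1] ? 0.0 :
            (knots[i+k+1] - t) / (knots[i+k+1] - knots[i+1]) * bspline(t, knots, i + 1, k - 1)
    return left + right
end

# Assemble the basis matrix B: one row per time point, one column per basis.
# The AFT catch: the evaluation points t * exp(-beta' X) depend on beta,
# so B must be rebuilt every time beta changes during optimization.
knots = [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0]
ts = range(0.0, 2.99; length = 5)
B = [bspline(t, knots, i, 3) for t in ts, i in 1:(length(knots) - 4)]
```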
It's not about smoothness, it's about how to JIT the gradient function. ML libs don't generally do interpolation, partly because it's tricky to vectorize (you have to search for which segment to use for each element) and partly because most ML practitioners don't need it. What I've done in my code is use all the vertices for all the elements, but with weights that are mostly zero. It's pretty fast on GPU because I don't use that many vertices.
"Interpolation" here I think effectively boils down to np.searchsorted() or its equivalent, which is implemented in all the major ML libs (jittable and backpropable).
"I wanted to use splines for the baseline survival" - isn't this a modelling step that could have been revisited? It seems a somewhat arbitrary choice (there are other ways to ~interpolate that are much more framework friendly) - and it seems that it forced you down a bit of a rabbit-hole.
It seems like you can't solve this kind of thing with a new Jax primitive for the algorithm, but what prevents new function transformations from doing what the mentioned Julia libraries do? It seems like between new function transformations and new primitives, you ought to be able to do just about anything. Is XLA the issue, such that you could run but not jit the result?
XLA is the limiting factor in a lot of these cases, though maybe "limiting factor" is the wrong phrase because it's more of a "trade-off factor". XLA wants to know the static size of a lot of arguments so it can build a mathematical description of the compute graph and fuse linear algebra commands freely. What the Julia libraries like Zygote do is say "there is no good mathematical description of this, so I will generate source code instead" (and some programs like Tapenade are similar). For example, while loops are translated into for loops where a stack of the Boolean choices is stored so they can be run in reverse during the backpass. The Julia libraries can sometimes have more trouble automatically fusing linear algebra commands, though, since then they need to say "my IR lets non-static things occur, so therefore I need to prove it's static before doing transformation X". It's much easier to know you can do such transformations if anything written in the IR obeys the rules required for the transform! So it's a trade-off. In the search for allowing differentiation of any program in the language, the Julia AD tools have gone for extreme flexibility (and can rely on the fact that Julia has a compiler that can JIT compile any generated valid Julia code), and I find it really interesting to try to elucidate what you actually gain from that.
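A hand-written sketch of that while-loop transformation (what I'd write to mimic it by hand; not Zygote's actual generated code). Here the stack records the loop's intermediate values, which covers both the trip count and the primal values the adjoint needs:

```julia
# Forward pass: run the dynamic while loop, pushing each iterate onto a
# stack so the reverse pass can replay the same trip count backwards.
function forward(x)
    tape = Float64[x]
    while x < 100.0          # data-dependent trip count
        x = x * x            # loop body f(x) = x^2
        push!(tape, x)
    end
    return x, tape
end

# Reverse pass: pop the stack, accumulating the adjoint of each
# iteration in the opposite order (the adjoint of x -> x^2 is 2x).
function backward(tape, xbar)
    for i in (length(tape) - 1):-1:1
        xbar = 2 * tape[i] * xbar
    end
    return xbar
end

y, tape = forward(1.5)       # four iterations for this particular input
backward(tape, 1.0)          # dy/dx, assembled in reverse
```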
> and can rely on the fact that Julia has a compiler that can JIT compile any generated valid Julia code
This seems to be the key bit. It’s a great data point around the meme of “with a sufficiently advanced compiler…” In this case we have sufficiently advanced compilers making very different JIT trade-offs. XLA is differently powerful compared to Julia. Very cool, thanks for the insight.
If the next machine learning killer-app model requires autodiff'ed dynamic control flow, do you think Google/Facebook will build that capability into XLA/TorchScript? Seems like if NLP SOTA requires dynamic control flow, Google will build it? Maybe they let you declare some subgraph as "dynamic" to avoid static optimizations? But maybe the static graph assumption is so deeply embedded into the XLA architecture, they'd be better off just adopting Julia? (I honestly don't know the answer, asking your opinion!)
"Maybe they let you declare some subgraph as 'dynamic' to avoid static optimizations?" What you just described is Tensorflow Eager and why it has some performance issues (but more flexibility!). XLA makes some pretty strong assumptions and I don't think that should change. Tensorflow's ability to automatically generate good automatically parallelized production code stems from the restrictions it has imposed. So I wouldn't even try for a "one true AD to rule them all" since making things more flexible will reduce the amount of compiler optimizations that can be automatically performed.
To get the more flexible form, you really would want to do it in a way that uses a full programming language's IR as its target. I think trying to use a fully dynamic programming language's IR (Python, R, etc.) directly would be pretty insane, because it would be hard to enforce rules and get performance. So some language with a front end over an optimizing compiler (LLVM) would probably make the most sense. Zygote and Diffractor use Julia's IR, but there are other ways to do this as well. Enzyme (https://github.com/wsmoses/Enzyme.jl) uses the LLVM IR directly for doing source-to-source translations. Using some dialect of LLVM (provided by MLIR) might be an interesting place to write a more ML-focused flexible AD system. Swift for Tensorflow used the Swift IR. This mindset starts to show why those tools were chosen.
Makes sense. I don't use TF Eager, but I do use Jax, and Jax lets you arbitrarily compose JITed and non-JITed code, which made me think that might be a viable pattern. I guess I wondered if there might be something like "nonstatic_jit(foo)" that would do "julia style" compiling on function foo, in addition to "jit(foo)" that compiles foo to optimized XLA ops. Probably impractical. Thanks.
static_argnums is really just a way to supply a few more assumptions to attempt to build quasi-static code even if it's using dynamic constructs. In this example it will force tracing of only one of the two branches (whichever one static_argnums sends it down). That is going to generate incorrect code for input values which should have traced the other branch (so the real solution, `lax.cond`, is to always trace and always compute both branches, as mentioned in the post). If the computation is actually not quasi-static, there's no good choice for a static argnum. See the factorial example.
Tangentially related: Faster training of Neural ODEs is super exciting! There are a lot of promising applications (although personally I believe that the intuition of "magically choosing the number of layers" is misguided, but I'm not an expert and might be wrong). Right now it takes forever to train even on toy problems, but I'm sure that enough work in this direction will eventually lead to more practical methods.
The example that fails in Jax would work fine in PyTorch. If you're working on purely training the model, TorchScript doesn't give many benefits, if any.
I mean, that's kind of the point. These are things you can make work with these frameworks; you just have to opt out of all the optimizers and JIT compilers (including TorchScript) to make it work. Yes, there are some examples where you can get away with that because the matmuls are big enough. No, that's not every model you could ever use, or even close to it. In fact, in the context of ODEs we found, for example, that torchdiffeq is ~20x-5000x slower than it should be when simulating many scientific models, in large part because TorchScript optimizations cannot be enabled on that code (https://gist.github.com/ChrisRackauckas/cc6ac746e2dfd285c28e..., which shows that enabling TorchScript does nothing, and you can trace it back to the dynamic constructs). So this is a proof of concept that some ML algorithms are not handled by the optimizers and require the fallbacks to the slower setups.