Diary 2019-04-05
So, I cut myself off yesterday, but I've got some ideas for how to make my numerical code, or code very like it, faster.
The basic premise of a variety of optimization libraries and frameworks for Python is that if your inner loop is hot enough, you don't want Python function calls inside it. My own code makes more than one function call per edge in its graph. As ways to implement stacked matrix multiplication go, that's not a great one, and I'm not surprised the technique hasn't taken off. We're talking well over a thousand function calls for a single evaluation of the network, easy. Now, we need at least one function call to make anything actually happen. Can we get down to one function call per error evaluation? Per gradient calculation? Per training step? Per batch of training steps between diagnostic dumps?
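To make the per-edge cost concrete, here's a minimal sketch. It assumes a toy fully connected layer where each edge's contribution comes from its own Python call. `edge_contribution`, `per_edge_layer` and `stacked_layer` are hypothetical names, not functions from my library. The sketch counts the calls and checks the result against a single matrix multiplication in NumPy.

```python
import numpy as np

# Hypothetical call counter, to see how many Python calls the naive version makes.
call_count = 0


def edge_contribution(weight, value):
    """One edge of the graph: a tiny 'shrapnel' function."""
    global call_count
    call_count += 1
    return weight * value


def per_edge_layer(weights, inputs):
    """Evaluate a dense layer with one Python call per edge."""
    n_out, n_in = weights.shape
    outputs = np.zeros(n_out)
    for i in range(n_out):
        for j in range(n_in):
            outputs[i] += edge_contribution(weights[i, j], inputs[j])
    return outputs


def stacked_layer(weights, inputs):
    """The same layer as one matrix-vector product: a single call into NumPy."""
    return weights @ inputs


rng = np.random.default_rng(0)
weights = rng.standard_normal((32, 40))
inputs = rng.standard_normal(40)

naive = per_edge_layer(weights, inputs)
fast = stacked_layer(weights, inputs)
print(call_count)                # 1280 Python calls for one 32x40 layer
print(np.allclose(naive, fast))  # True
```

Even a single small layer lands above a thousand calls, and every additional layer or training step multiplies that count.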
I think I can get that far. It probably won't be amazing on my machine (I don't have fancy GPUs to work with), but I expect to get modest improvements, at least.
So, what does reducing the number of function calls look like? Basically, you take some Python code with function calls, and transform it in some way to remove the Python function calls.
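One version of that transformation is code generation. You describe the computation as a list of small operations, then emit the source for a single Python function that performs them all inline, and `exec` that source. Here's a rough sketch under my own assumptions; it isn't taken from any library. The `fuse` helper, the templates, and the `(target, op, args)` format are hypothetical, and every name is assumed to be a valid Python identifier, since nothing gets sanitized.

```python
import operator

# Hypothetical primitive ops, as Python expression templates for code generation.
TEMPLATES = {
    "add": "{0} + {1}",
    "mul": "{0} * {1}",
    "neg": "-{0}",
}

# The same ops as callables, used by the unfused interpreter.
PRIMITIVES = {
    "add": operator.add,
    "mul": operator.mul,
    "neg": operator.neg,
}


def run_unfused(ops, env):
    """Interpret the op list: one Python function call per op."""
    env = dict(env)
    for target, op, args in ops:
        env[target] = PRIMITIVES[op](*(env[a] for a in args))
    return env[ops[-1][0]]


def fuse(ops, inputs):
    """Generate one function whose body inlines every op (no per-op calls)."""
    lines = [f"def fused({', '.join(inputs)}):"]
    for target, op, args in ops:
        lines.append(f"    {target} = {TEMPLATES[op].format(*args)}")
    lines.append(f"    return {ops[-1][0]}")
    namespace = {}
    exec("\n".join(lines), namespace)
    return namespace["fused"]


# y = -(a * b + c), written as three tiny ops.
ops = [
    ("t0", "mul", ("a", "b")),
    ("t1", "add", ("t0", "c")),
    ("y", "neg", ("t1",)),
]
fused = fuse(ops, ["a", "b", "c"])
print(fused(2.0, 3.0, 1.0))                              # -7.0
print(run_unfused(ops, {"a": 2.0, "b": 3.0, "c": 1.0}))  # -7.0
```

This removes the per-op call overhead, but the fused function is still interpreted Python. It only gets you partway.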
Possible ways to do this include:
- Try to rewrite everything in RPython. I tried to do this. I now do not recommend RPython for things that aren't interpreters.
- Rewrite parts of it in Cython. I haven't used Cython directly myself, so I don't know how well it optimizes the kind of tiny shrapnel functions my library is full of.
- Use an actual machine learning library/framework like Theano or TensorFlow. Theano says it compiles graphs to C, and TensorFlow appears to compile graphs to... something. I probably should use these, but I like feeling like I put every bit of the graph-level logic together myself.
- Compile some functions using LLVM, via Numba. I've experimented some with Numba, and while some of the advanced features aren't all the way there yet (I wanted to wrap a dataclass in a jitclass, and that just... doesn't work), I like how the interface works out. I haven't profiled this, so I'd like to do that before going all in on this idea, but I have reason to be optimistic about LLVM's ability to handle tiny shrapnel functions.
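For the Numba option, I'd want to know whether a tiny helper called from a jitted loop still costs anything once it's compiled. Here's a minimal sketch that assumes Numba is installed. If it isn't, the code falls back to plain Python, so any timings only mean something with Numba present. `njit` is Numba's real decorator. `edge_contribution` and `dense_forward` are hypothetical stand-ins for my library's functions. When one `njit` function calls another, Numba compiles both, which gives LLVM the chance to inline the helper.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # No Numba: run the same code as plain (slow) Python.
    def njit(func):
        return func


@njit
def edge_contribution(weight, value):
    # Tiny 'shrapnel' function; under Numba this is a compiled call, not a Python call.
    return weight * value


@njit
def dense_forward(weights, inputs):
    # Same per-edge structure as the naive version, but the whole loop is compiled.
    n_out, n_in = weights.shape
    outputs = np.zeros(n_out)
    for i in range(n_out):
        for j in range(n_in):
            outputs[i] += edge_contribution(weights[i, j], inputs[j])
    return outputs


weights = np.arange(12, dtype=np.float64).reshape(3, 4)
inputs = np.ones(4)
print(dense_forward(weights, inputs))  # [ 6. 22. 38.]
```

The first call includes compilation time. For any timing, call the function once to warm it up before measuring.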
In any case, before I sink effort into any kind of optimization like that, I should first do some due diligence on timing and profiling my code. I profiled my old code in an ad hoc manner, which isn't great when it comes to knowing which profile data even matters.
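Here's a less ad hoc starting point that uses only the standard library. `cProfile` counts calls and attributes time to each function, and `timeit` gives repeatable wall-clock numbers for a whole evaluation. `tiny_op` and `evaluate` are hypothetical stand-ins for a network evaluation. The raw call count comes from the `stats` attribute of `pstats.Stats`, which is a CPython implementation detail, though a long-standing one.

```python
import cProfile
import io
import pstats
import timeit


def tiny_op(x):
    # Stand-in for one of the small per-edge functions.
    return x * 0.5 + 1.0


def evaluate(n_edges=1000):
    # Stand-in for one network evaluation: one Python call per edge.
    total = 0.0
    for i in range(n_edges):
        total += tiny_op(i)
    return total


# Profile a single evaluation and print the top entries by cumulative time.
profiler = cProfile.Profile()
profiler.runcall(evaluate)
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("cumulative").print_stats(5)
print(stream.getvalue())

# Keys are (filename, line number, function name);
# values start with (primitive calls, total calls, ...).
tiny_calls = next(
    nc for (_, _, name), (cc, nc, *_rest) in stats.stats.items() if name == "tiny_op"
)
print(tiny_calls)  # 1000

# Time whole evaluations; take the best of several repeats to reduce noise.
best = min(timeit.repeat(evaluate, number=100, repeat=5)) / 100
print(f"{best * 1e6:.1f} microseconds per evaluation")
```

The profiler shows where calls pile up, and `timeit` shows whether a change actually helped. Deciding which usage patterns to feed them is a separate question.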
Let's see about working out what kind of patterns of usage to profile, and how... Tomorrow.