The intended usage, and some design challenges, for this feature are described in README.md.

To enable this change, I first had to change TensorFlowTracer to carry explicitly the JAX abstract value for the underlying TF value (previously, we would just compute the abstract value from the TF value). Computing it is not possible when the TF value has a partially known shape; in that case we want the JAX abstract value, using masking.ShapeSpec.

As a beneficial side effect of adding explicit abstract values to TensorFlowTracer, we can clean up all the hacky handling of core.unit (we used to store core.unit as a TF value, hence the need for TfValOrUnit, and we would swap it with tf.nan when going to TF). Now we can just store tf.nan as the TF value and core.abstract_unit as the abstract value. I also added a few extra assertions that the TF value and the abstract value in a TensorFlowTracer are in agreement.

The key smarts in this change are reused from jax.interpreters.masking; all we really add is carefully carrying that information through. A tricky part was carrying the abstract shapes for tf.custom_gradient.
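To illustrate the idea, here is a minimal sketch (not the actual jax2tf code) of a tracer that is handed an explicit JAX abstract value instead of recomputing it from the TF value, and that checks the two are consistent on every statically known dimension. The class name SketchTracer is hypothetical, and the sketch assumes the jax.core API of the time (core.abstract_unit and core.ShapedArray; units have since been removed from newer JAX versions):

```python
# Minimal sketch: a tracer carrying both a TF value and an explicit JAX
# abstract value, rather than deriving the abstract value from the TF shape.
import tensorflow as tf
import jax.core as core
import jax.numpy as jnp

class SketchTracer:
  """Pairs a TF value with an explicitly supplied JAX abstract value."""

  def __init__(self, tf_val, aval: core.AbstractValue):
    self.aval = aval
    if aval is core.abstract_unit:
      # With the abstract value stored explicitly, the TF value for a unit is
      # just a placeholder (e.g. tf.nan); no TfValOrUnit-style wrapper needed.
      self.val = tf_val
      return
    # Assert that the TF value and the abstract value agree on every
    # dimension TF knows statically. Unknown TF dimensions (None) are fine;
    # the aval (e.g. built with masking.ShapeSpec) carries the symbolic
    # shape information for those.
    tf_shape = tf.convert_to_tensor(tf_val).shape
    if tf_shape.rank is not None:
      assert tf_shape.rank == len(aval.shape), (
          f"rank mismatch: {tf_shape} vs {aval.shape}")
      for tf_dim, aval_dim in zip(tf_shape.as_list(), aval.shape):
        assert tf_dim is None or tf_dim == aval_dim, (
            f"dimension mismatch: {tf_shape} vs {aval.shape}")
    self.val = tf.convert_to_tensor(tf_val)

# Example with a fully known shape: the aval must match the TF static shape.
t = SketchTracer(tf.ones((2, 3)), core.ShapedArray((2, 3), jnp.float32))
print(t.aval)  # e.g. ShapedArray(float32[2,3])
```

The design point the sketch tries to capture is that the abstract value is now the source of truth for shapes, so partially known TF shapes and unit values both stop being special cases in the TF-value handling.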