I have been looking for a way to efficiently ship data from Java data pipelines to machine learning (ML) frameworks in Python (specifically PyTorch).
The goal is to pass a large amount of data (one training batch at a time) to PyTorch running in Python with minimal overhead, so that I can train and run experiments in Python.
@saudet has kindly provided guidance in bytedeco/javacpp-presets#1107 and I have an initial prototype working (see below), but I thought this might be a more appropriate place for further discussion. It is very fast, but I know JavaCPP may not have been designed for this, so I would like to hear advice from experts in this area on where things could break and what the potential flaws are. Thanks in advance!
Current solution with JavaCPP
Where I see JavaCPP coming in is that it lets us allocate a large native data array from Java (via JavaCPP's FloatPointer), then wrap that same array as the training-batch tensor in PyTorch, so we don't waste time copying the data multiple times.
Specifically, I use JavaCPP's Pointer classes (e.g. FloatPointer) as the return type of a Java data iterator like:
import org.bytedeco.javacpp.FloatPointer;

public class MyJavaDataIter {
    // Returns one training batch as a natively allocated float buffer
    // (e.g. a FloatPointer sized numSamples * numFeatures, filled with the batch data).
    public FloatPointer nextBatch(...) {
        // ...
    }
}
Then, when I want to train in Python, I use Python's ctypes to wrap the address returned by tensor.address(), like:
import ctypes
import numpy as np
from jnius import autoclass

dataiter = autoclass(...MyJavaDataIter)  # expose the Java iterator via PyJNIus
tensor = dataiter.nextBatch()  # tensor is a JavaCPP FloatPointer object here
p = ctypes.cast(tensor.address(), ctypes.POINTER(ctypes.c_float))  # tensor.address() is the JavaCPP method
arr = np.ctypeslib.as_array(p, [100000, 443])  # zero-copy NumPy view over the native buffer
Then arr can be wrapped in a PyTorch tensor without allocating new memory or Python ever copying the data, which makes this method very fast. After the batch is done, I can simply call tensor.close(), perhaps in a try-finally block. Here I use PyJNIus to expose the Java function, but @saudet has suggested simpler/better options in bytedeco/javacpp-presets#1107.
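For illustration, here is a minimal sketch of the consumption side, continuing from the snippet above (dataiter already created). torch.from_numpy is the standard zero-copy way to wrap a NumPy array, and the try-finally cleanup around tensor.close() is how I imagine handling deallocation; the shape is just the example figure from above:

import ctypes
import numpy as np
import torch

tensor = dataiter.nextBatch()
try:
    p = ctypes.cast(tensor.address(), ctypes.POINTER(ctypes.c_float))
    arr = np.ctypeslib.as_array(p, [100000, 443])
    batch = torch.from_numpy(arr)  # shares memory with arr, no copy
    # ... train on batch ...
finally:
    tensor.close()  # release the native buffer once the batch is no longer needed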
I want to emphasize, though, that the value of using JavaCPP here is not to expose functions but to pass data efficiently. IMO this is a bit different from the related discussion bytedeco/javacpp#17, as I am looking to build a Java program with JavaCPP, but my use case is to pass data efficiently to PyTorch in Python for consumption without converting the data to Python objects. For a back-of-the-envelope calculation: a typical batch may have around 1e4 samples, each with around 1e3 features, which means the FloatPointer is carrying 1e7 four-byte floats, i.e. roughly 40 MB per batch. I have tried other options, but this is by far the fastest.
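To make the sizing concrete, a quick check of the per-batch footprint (the sample and feature counts are just the rough figures above):

num_samples = 10_000      # ~1e4 samples per batch
num_features = 1_000      # ~1e3 features per sample
bytes_per_float = 4
batch_bytes = num_samples * num_features * bytes_per_float
print(batch_bytes / 1e6)  # ~40 MB per batch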