AAI_2025_Capstone_Chronicles_Combined

computational buffers and maximizing hardware efficiency. This approach allows 100 small neural networks to be trained simultaneously in a time comparable to training a single network in PyTorch on a Titan GPU (Whitney, 2021). However, Whitney’s implementation has several limitations. First, the API used in his demonstration is now outdated: it relies on the optim module from the Flax library, which has since been replaced by Optax. More critically, his approach is not abstracted for general use, meaning that users must manually set up vectorization for each new network or dataset.

The goal of this project, through the development of TurbaNet, is to create an abstraction that allows users to take advantage of JAX’s vectorization capabilities without needing to manually configure each individual network. The core principle behind this approach is the use of vmap to facilitate parallel training and evaluation of small neural networks. JAX’s vmap function applies a given operation over multiple inputs in parallel, automatically batching computations to fully utilize available hardware (Google Research, 2024). Normally, when training a single small neural network, GPU memory is underutilized because the network’s parameter space is too small to fill the buffer. By using vmap to process multiple small networks simultaneously, we can better occupy the GPU’s memory and computational resources, leading to more efficient execution. Combined with JAX’s JIT compiler, the performance gains for simultaneous training can be substantial.

TurbaNet takes a different approach than ensemble bagging libraries such as scikit-learn’s BaggingClassifier or distributed multi-GPU orchestration tools such as Horovod. Instead of coordinating across multiple devices or wrapping full training loops, TurbaNet focuses on training many small, independent models in parallel on a single GPU, optimizing for high throughput and efficiency.
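The mechanics described above can be sketched in a few lines of JAX. The network architecture, dimensions, and SGD step below are illustrative assumptions for this sketch, not TurbaNet's actual API: the key idea is that vmap maps a single-model function over a stacked axis of parameters, and jit then compiles all the per-model updates into one fused computation.

```python
import jax
import jax.numpy as jnp

# A deliberately small two-layer network (illustrative sizes, not TurbaNet's API).
def init_params(key):
    k1, k2 = jax.random.split(key)
    return {"w1": jax.random.normal(k1, (4, 8)) * 0.1,
            "w2": jax.random.normal(k2, (8, 1)) * 0.1}

def forward(params, x):
    return jnp.tanh(x @ params["w1"]) @ params["w2"]

def loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

def sgd_step(params, x, y, lr=0.01):
    # Plain SGD stands in here for an Optax optimizer to keep the sketch minimal.
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Stack 100 independent parameter sets: each leaf gains a leading model axis.
keys = jax.random.split(jax.random.PRNGKey(0), 100)
stacked = jax.vmap(init_params)(keys)

# vmap maps over the model axis of the parameters (in_axes=0) while sharing
# the data batch across all models (in_axes=None); jit then compiles the 100
# updates into a single fused computation, which is where the throughput
# gain over 100 separate training loops comes from.
parallel_step = jax.jit(jax.vmap(sgd_step, in_axes=(0, None, None)))

x = jax.random.normal(jax.random.PRNGKey(1), (32, 4))
y = jnp.sum(x, axis=1, keepdims=True)

before = jax.vmap(loss, in_axes=(0, None, None))(stacked, x, y)
stacked = parallel_step(stacked, x, y)
after = jax.vmap(loss, in_axes=(0, None, None))(stacked, x, y)
print(before.mean(), "->", after.mean())
```

After the single parallel step, the mean loss across the 100 models decreases, and `stacked["w1"]` has shape (100, 4, 8): one set of weights per model, all updated in one call.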
Rather than replicating models across devices or relying on complex scheduling and inter-device communication, TurbaNet exploits intra-GPU parallelism to efficiently utilize compute resources with minimal overhead. GPUs are designed to handle large-scale matrix operations efficiently, but when training small networks, a large portion of the available memory and computational throughput is wasted (left side of Figure 2.1). Traditional deep learning workflows may only make use of a fraction of a GPU’s available

