
Figure 3.2: Example Predictions from a Model Trained on the Two Spiral Problem (Whitney, 2021). Reprinted from "Parallelizing neural networks on one GPU with JAX" by W. Whitney, 2021, Will Whitney's Blog (https://willwhitney.com/parallel-training-jax.html). © 2021 by Will Whitney.

The data generated for this problem does not require any cleaning, transformation, or feature engineering, as it is a test problem whose data is generated synthetically. The data is therefore produced within the script and flows directly into the models without modification. One key consideration for parallelized training is that the training data must be expanded into a new leading dimension and repeated once for each network in the swarm. In this application, every network receives the same inputs, so the networks all produce very similar predictions, differentiated only by the random seed used to initialize their parameters. In more complex problems, one might instead train each network on a different portion of the data, yielding something closer to a mixture-of-experts model rather than 1,000 networks that return the same prediction.

As an application of the library, it made sense to apply it to a standard machine learning benchmark. A classic problem in this field is the MNIST handwritten digit dataset, which has been solved many times over the past two decades by a variety of approaches. This provides a good testing ground to verify that the library functions correctly and to gather runtime comparisons on a real-life example. Some examples drawn from this dataset are provided in Figure 3.3:
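The data-replication step described above can be sketched as follows. This is an illustrative example, not the project's actual code: the array sizes and variable names are assumptions, and plain NumPy is used here because `jax.numpy` mirrors the same `broadcast_to` API.

```python
import numpy as np

# Hypothetical two-spiral-style batch: 8 samples, 2 features each.
x_train = np.random.randn(8, 2)
n_networks = 1000  # number of networks in the swarm

# Insert a new leading "network" axis and repeat the same batch for
# every member of the swarm, so each network sees identical inputs.
x_swarm = np.broadcast_to(x_train[None, :, :], (n_networks,) + x_train.shape)

print(x_swarm.shape)  # (1000, 8, 2)
```

Because `broadcast_to` returns a view rather than copying the batch 1,000 times, this replication is essentially free in memory; a vectorized training step (e.g., via `jax.vmap`) can then map over the leading axis together with each network's parameters.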

