Hi there. I've been following this tutorial (https://docs.jaxstack.ai/en/latest/digits_diffusion_model.html). When I reach the model initialization step, I get a segmentation fault.
I'm running on an AMD Ryzen Threadripper PRO 7945WX s (24) @ 5.380GHz with 32GB of RAM and an NVIDIA RTX 5000 Ada Generation GPU. I'm trying to run this model on the GPU.
Many thanks in advance.