WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 63388b6

Browse files
jenriverchapman20j
authored andcommitted
Fix TPU compatibility for model tests
This commit resolves `BufferError` and assertion failures when running tests on TPU for qwen3, densenet121, vit Key changes: - Fix DLPack BufferError: use np during JAX-to-PyTorch conversion (`torch.tensor(np.array(jax_array))`) to force data transfer from TPU to CPU host memory, bypassing DLPack transfer unsupported between TPU and PyTorch CPU. - Enforce Float32 Precision: Explicitly cast JAX model weights to `float32` and set `jax_default_matmul_precision` to `float32`. Prevent implicit downcasting to `bfloat16` on TPU. These changes ensure consistent test behavior across CPU and TPU environments.
1 parent a11f3f4 commit 63388b6

File tree

3 files changed

+164
-48
lines changed

3 files changed

+164
-48
lines changed

bonsai/models/densenet121/tests/test_outputs_densenet121.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import tensorflow as tf
2020
from absl.testing import absltest
21+
from flax import nnx
2122
from huggingface_hub import snapshot_download
2223

2324
from bonsai.models.densenet121 import modeling, params
@@ -26,10 +27,16 @@
2627
class TestModuleForwardPasses(absltest.TestCase):
2728
def setUp(self):
2829
super().setUp()
30+
jax.config.update("jax_default_matmul_precision", "float32")
2931
try:
3032
self.ref_model = keras_hub.models.ImageClassifier.from_preset("densenet_121_imagenet")
3133
model_ckpt_path = snapshot_download("keras/densenet_121_imagenet")
32-
self.nnx_model = params.create_model_from_h5(model_ckpt_path, modeling.ModelConfig.densenet_121())
34+
graph_def, state = nnx.split(
35+
params.create_model_from_h5(model_ckpt_path, modeling.ModelConfig.densenet_121())
36+
)
37+
state = jax.tree.map(lambda x: x.astype(jnp.float32) if isinstance(x, jax.Array) else x, state)
38+
self.nnx_model = nnx.merge(graph_def, state)
39+
3340
except Exception as e:
3441
self.skipTest(
3542
"Skipping test because tensorflow-text requires 3.12 or below: %s"

bonsai/models/qwen3/tests/test_outputs_qwen3.py

Lines changed: 139 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import jax
22
import jax.numpy as jnp
3+
import numpy as np
34
import torch
45
from absl.testing import absltest
6+
from flax import nnx
57
from huggingface_hub import snapshot_download
68
from jax.sharding import AxisType
79
from jax.sharding import PartitionSpec as P
@@ -18,6 +20,7 @@
1820
class TestModuleForwardPasses(absltest.TestCase):
1921
def setUp(self):
2022
super().setUp()
23+
jax.config.update("jax_default_matmul_precision", "float32")
2124
model_name: str = "Qwen/Qwen3-0.6B"
2225
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
2326

@@ -27,7 +30,13 @@ def setUp(self):
2730
model_ckpt_path = snapshot_download("Qwen/Qwen3-0.6B")
2831
self.mesh = jax.make_mesh(((1, 1)), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit))
2932
jax.set_mesh(self.mesh)
30-
self.nnx_model = params.create_model_from_safe_tensors(model_ckpt_path, self.bonsai_config, self.mesh)
33+
34+
# Cast JAX model to float32 for precision matching with PyTorch CPU
35+
graph_def, state = nnx.split(
36+
params.create_model_from_safe_tensors(model_ckpt_path, self.bonsai_config, self.mesh)
37+
)
38+
state = jax.tree.map(lambda x: x.astype(jnp.float32) if isinstance(x, jax.Array) else x, state)
39+
self.nnx_model = nnx.merge(graph_def, state)
3140

3241
self.batch_size = 32
3342
self.num_input_tokens = 5
@@ -39,7 +48,11 @@ def _check_batched_logits(self, left_pads: int, torch_logits: torch.Tensor, nnx_
3948
max_len = torch_logits.shape[-2]
4049
for lp, tl, nl in zip(left_pads, torch_logits, nnx_logits):
4150
torch.testing.assert_close(
42-
torch.tensor(nl)[lp:max_len, :], tl[lp:, :], rtol=self.relaxed_tol, atol=self.relaxed_tol
51+
torch.tensor(np.array(nl, dtype=np.float32))[lp:max_len, :],
52+
tl[lp:, :],
53+
rtol=self.relaxed_tol,
54+
atol=self.relaxed_tol,
55+
check_dtype=False,
4356
)
4457

4558
def _setup_torch_attn(self, input_embeddings: torch.Tensor, attention_mask: None = None):
@@ -122,147 +135,231 @@ def test_embedder(self):
122135
jx = jnp.array(tx.cpu().detach().numpy())
123136

124137
jy, ty = nm.embedding.value.at[(jx,)].get(), tm(tx)
125-
torch.testing.assert_close(torch.tensor(jy), ty)
138+
torch.testing.assert_close(
139+
torch.tensor(np.array(jy, dtype=np.float32)),
140+
ty,
141+
rtol=self.relaxed_tol,
142+
atol=self.relaxed_tol,
143+
check_dtype=False,
144+
)
126145

127146
def test_decoder_layer(self):
128147
nm = self.nnx_model.layers[0]
129148
tm = self.torch_model.model.layers[0].to(torch.float32)
130149

131150
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
132151
jx = jax.random.normal(jax.random.key(0), shape=shape)
133-
tx = torch.tensor(jx)
152+
tx = torch.tensor(np.array(jx, dtype=np.float32))
134153
nnx_cache = self._init_nnx_cache(self.batch_size)
135154
torch_inputs = self._setup_torch_attn(tx)
136155

137156
jy, ty = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens))), tm(**torch_inputs)
138-
torch.testing.assert_close(torch.tensor(jy), ty)
157+
torch.testing.assert_close(
158+
torch.tensor(np.array(jy, dtype=np.float32)),
159+
ty,
160+
rtol=self.relaxed_tol,
161+
atol=self.relaxed_tol,
162+
check_dtype=False,
163+
)
139164

140165
def test_all_decoder_layers(self):
141166
nnx_cache = self._init_nnx_cache(self.batch_size)
142167
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
143168

144169
for nm, tm, nc in zip(self.nnx_model.layers, self.torch_model.model.layers, nnx_cache):
145170
jx = jax.random.normal(jax.random.key(0), shape=shape)
146-
tx = torch.tensor(jx)
171+
tx = torch.tensor(np.array(jx, dtype=np.float32))
147172

148173
jy = nm(jx, nc, jnp.ones((self.batch_size, self.num_input_tokens)))
149174
torch_inputs = self._setup_torch_attn(tx)
150175
ty = tm.to(torch.float32)(**torch_inputs)
151-
torch.testing.assert_close(torch.tensor(jy), ty, atol=self.relaxed_tol, rtol=self.relaxed_tol)
176+
torch.testing.assert_close(
177+
torch.tensor(np.array(jy, dtype=np.float32)),
178+
ty,
179+
atol=self.relaxed_tol,
180+
rtol=self.relaxed_tol,
181+
check_dtype=False,
182+
)
152183

153184
def test_rms_norm(self):
154185
nm = self.nnx_model.layers[0].input_layernorm
155186
tm = self.torch_model.model.layers[0].input_layernorm
156187

157188
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
158-
jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
159-
tx = torch.tensor(jx)
189+
jx = jax.random.normal(jax.random.key(0), shape=shape)
190+
tx = torch.tensor(np.array(jx, dtype=np.float32))
160191

161192
jy, ty = nm(jx), tm(tx)
162-
torch.testing.assert_close(torch.tensor(jy), ty)
193+
torch.testing.assert_close(
194+
torch.tensor(np.array(jy, dtype=np.float32)),
195+
ty,
196+
rtol=self.relaxed_tol,
197+
atol=self.relaxed_tol,
198+
check_dtype=False,
199+
)
163200

164201
def test_self_attn(self):
165202
nm = self.nnx_model.layers[0].attn
166203
tm = self.torch_model.model.layers[0].self_attn.to(torch.float32)
167204

168205
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
169206
jx = jax.random.normal(jax.random.key(0), shape=shape)
170-
tx = torch.tensor(jx)
207+
tx = torch.tensor(np.array(jx, dtype=np.float32))
171208
torch_inputs = self._setup_torch_attn(tx)
172209
nnx_cache = self._init_nnx_cache(self.batch_size)
173210

174211
jy = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens), dtype=jnp.float32))
175212
ty = tm(**torch_inputs)[0]
176-
torch.testing.assert_close(torch.tensor(jy), ty)
213+
torch.testing.assert_close(
214+
torch.tensor(np.array(jy, dtype=np.float32)),
215+
ty,
216+
rtol=self.relaxed_tol,
217+
atol=self.relaxed_tol,
218+
check_dtype=False,
219+
)
177220

178221
def test_q_norm(self):
179222
nm = self.nnx_model.layers[0].attn.q_norm
180223
tm = self.torch_model.model.layers[0].self_attn.q_norm
181224

182225
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_heads, self.bonsai_config.head_dim)
183-
jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
184-
tx = torch.tensor(jx)
226+
jx = jax.random.normal(jax.random.key(0), shape=shape)
227+
tx = torch.tensor(np.array(jx, dtype=np.float32))
185228

186229
jy, ty = nm(jx), tm(tx)
187-
torch.testing.assert_close(torch.tensor(jy), ty)
230+
torch.testing.assert_close(
231+
torch.tensor(np.array(jy, dtype=np.float32)),
232+
ty,
233+
rtol=self.relaxed_tol,
234+
atol=self.relaxed_tol,
235+
check_dtype=False,
236+
)
188237

189238
def test_k_norm(self):
190239
nm = self.nnx_model.layers[0].attn.q_norm
191240
tm = self.torch_model.model.layers[0].self_attn.q_norm
192241

193242
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_kv_heads, self.bonsai_config.head_dim)
194-
jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
195-
tx = torch.tensor(jx)
243+
jx = jax.random.normal(jax.random.key(0), shape=shape)
244+
tx = torch.tensor(np.array(jx, dtype=np.float32))
196245

197246
jy, ty = nm(jx), tm(tx)
198-
torch.testing.assert_close(torch.tensor(jy), ty)
247+
torch.testing.assert_close(
248+
torch.tensor(np.array(jy, dtype=np.float32)),
249+
ty,
250+
rtol=self.relaxed_tol,
251+
atol=self.relaxed_tol,
252+
check_dtype=False,
253+
)
199254

200255
def test_q_proj(self):
201256
nm = self.nnx_model.layers[0].attn.q_proj
202-
tm = self.torch_model.model.layers[0].self_attn.q_proj
257+
tm = self.torch_model.model.layers[0].self_attn.q_proj.to(torch.float32)
203258

204259
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
205-
jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
206-
tx = torch.tensor(jx)
260+
jx = jax.random.normal(jax.random.key(0), shape=shape)
261+
tx = torch.tensor(np.array(jx, dtype=np.float32))
207262

208263
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_heads, self.bonsai_config.head_dim)
209264
jy, ty = nm(jx), tm(tx).reshape(shape)
210-
torch.testing.assert_close(torch.tensor(jy), ty)
265+
torch.testing.assert_close(
266+
torch.tensor(np.array(jy, dtype=np.float32)),
267+
ty,
268+
rtol=self.relaxed_tol,
269+
atol=self.relaxed_tol,
270+
check_dtype=False,
271+
)
211272

212273
def test_k_proj(self):
213274
nm = self.nnx_model.layers[0].attn.k_proj
214-
tm = self.torch_model.model.layers[0].self_attn.k_proj
275+
tm = self.torch_model.model.layers[0].self_attn.k_proj.to(torch.float32)
215276

216277
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
217-
jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
218-
tx = torch.tensor(jx)
278+
jx = jax.random.normal(jax.random.key(0), shape=shape)
279+
tx = torch.tensor(np.array(jx, dtype=np.float32))
219280

220281
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_kv_heads, self.bonsai_config.head_dim)
221282
jy, ty = nm(jx), tm(tx).reshape(shape)
222-
torch.testing.assert_close(torch.tensor(jy), ty)
283+
torch.testing.assert_close(
284+
torch.tensor(np.array(jy, dtype=np.float32)),
285+
ty,
286+
rtol=self.relaxed_tol,
287+
atol=self.relaxed_tol,
288+
check_dtype=False,
289+
)
223290

224291
def test_o_proj(self):
225292
nm = self.nnx_model.layers[0].attn.o_proj
226-
tm = self.torch_model.model.layers[0].self_attn.o_proj
293+
tm = self.torch_model.model.layers[0].self_attn.o_proj.to(torch.float32)
227294

228295
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_heads, self.bonsai_config.head_dim)
229-
jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
230-
tx = torch.tensor(jx).reshape(self.batch_size, self.num_input_tokens, -1)
296+
jx = jax.random.normal(jax.random.key(0), shape=shape)
297+
tx = torch.tensor(np.array(jx, dtype=np.float32)).reshape(self.batch_size, self.num_input_tokens, -1)
231298

232299
jy, ty = nm(jx), tm(tx)
233-
torch.testing.assert_close(torch.tensor(jy), ty)
300+
torch.testing.assert_close(
301+
torch.tensor(np.array(jy, dtype=np.float32)),
302+
ty,
303+
rtol=self.relaxed_tol,
304+
atol=self.relaxed_tol,
305+
check_dtype=False,
306+
)
234307

235308
def test_mlp(self):
236309
nm = self.nnx_model.layers[0].mlp
237310
tm = self.torch_model.model.layers[0].mlp.to(torch.float32)
238311

239312
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
240313
jx = jax.random.normal(jax.random.key(0), shape=shape)
241-
tx = torch.tensor(jx)
314+
tx = torch.tensor(np.array(jx, dtype=np.float32))
242315

243316
jy, ty = nm(jx), tm(tx)
244-
torch.testing.assert_close(torch.tensor(jy), ty, rtol=self.relaxed_tol, atol=self.relaxed_tol)
317+
torch.testing.assert_close(
318+
torch.tensor(np.array(jy, dtype=np.float32)),
319+
ty,
320+
rtol=self.relaxed_tol,
321+
atol=self.relaxed_tol,
322+
check_dtype=False,
323+
)
245324

246325
def test_lm_head(self):
247326
nm = self.nnx_model.lm_head
248327
tm = self.torch_model.lm_head.to(torch.float32)
249328

250329
shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
251330
jx = jax.random.normal(jax.random.key(0), shape=shape)
252-
tx = torch.tensor(jx)
331+
tx = torch.tensor(np.array(jx, dtype=np.float32))
253332

254333
jy, ty = nm(jx), tm(tx)
255-
torch.testing.assert_close(torch.tensor(jy), ty)
334+
torch.testing.assert_close(
335+
torch.tensor(np.array(jy, dtype=np.float32)),
336+
ty,
337+
rtol=self.relaxed_tol,
338+
atol=self.relaxed_tol,
339+
check_dtype=False,
340+
)
256341

257342
def test_sin_cos(self):
258343
batch_size, seq_len, dim = 2, 10, 128
259344
hidden_states = torch.ones((batch_size, seq_len, dim))
260345
jp = jnp.stack([jnp.arange(seq_len), jnp.arange(seq_len)])
261346
js, jc = modeling._generate_pos_embeddings(jp, dim)
262-
tc, ts = self.torch_model.model.rotary_emb(hidden_states, torch.tensor(jp))
347+
tc, ts = self.torch_model.model.rotary_emb(hidden_states, torch.tensor(np.array(jp, dtype=np.float32)))
263348
tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2]
264-
torch.testing.assert_close(torch.tensor(js), ts)
265-
torch.testing.assert_close(torch.tensor(jc), tc)
349+
torch.testing.assert_close(
350+
torch.tensor(np.array(js, dtype=np.float32)),
351+
ts,
352+
rtol=self.relaxed_tol,
353+
atol=self.relaxed_tol,
354+
check_dtype=False,
355+
)
356+
torch.testing.assert_close(
357+
torch.tensor(np.array(jc, dtype=np.float32)),
358+
tc,
359+
rtol=self.relaxed_tol,
360+
atol=self.relaxed_tol,
361+
check_dtype=False,
362+
)
266363

267364
def test_full(self):
268365
query = ["Why is the sky blue instead of any other color like purple?"]
@@ -275,7 +372,11 @@ def test_full(self):
275372
torch_inputs = self._process_hf_tokens(query)
276373
torch_logits = self.torch_model(**torch_inputs).logits
277374
torch.testing.assert_close(
278-
torch.tensor(nnx_logits)[:, :token_len, :], torch_logits, rtol=self.relaxed_tol, atol=self.relaxed_tol
375+
torch.tensor(np.array(nnx_logits, dtype=np.float32))[:, :token_len, :],
376+
torch_logits,
377+
rtol=self.relaxed_tol,
378+
atol=self.relaxed_tol,
379+
check_dtype=False,
279380
)
280381

281382
def test_full_batched(self):

0 commit comments

Comments
 (0)