import jax
import jax.numpy as jnp
+ import numpy as np
import torch
from absl.testing import absltest
+ from flax import nnx
from huggingface_hub import snapshot_download
from jax.sharding import AxisType
from jax.sharding import PartitionSpec as P
class TestModuleForwardPasses(absltest.TestCase):
    def setUp(self):
        super().setUp()
+         jax.config.update("jax_default_matmul_precision", "float32")
        model_name: str = "Qwen/Qwen3-0.6B"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

@@ -27,7 +30,13 @@ def setUp(self):
        model_ckpt_path = snapshot_download("Qwen/Qwen3-0.6B")
        self.mesh = jax.make_mesh(((1, 1)), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit))
        jax.set_mesh(self.mesh)
-         self.nnx_model = params.create_model_from_safe_tensors(model_ckpt_path, self.bonsai_config, self.mesh)
+
+         # Cast the JAX model to float32 for precision matching with the PyTorch CPU model
+         graph_def, state = nnx.split(
+             params.create_model_from_safe_tensors(model_ckpt_path, self.bonsai_config, self.mesh)
+         )
+         state = jax.tree.map(lambda x: x.astype(jnp.float32) if isinstance(x, jax.Array) else x, state)
+         self.nnx_model = nnx.merge(graph_def, state)

        self.batch_size = 32
        self.num_input_tokens = 5
@@ -39,7 +48,11 @@ def _check_batched_logits(self, left_pads: int, torch_logits: torch.Tensor, nnx_
        max_len = torch_logits.shape[-2]
        for lp, tl, nl in zip(left_pads, torch_logits, nnx_logits):
            torch.testing.assert_close(
-                 torch.tensor(nl)[lp:max_len, :], tl[lp:, :], rtol=self.relaxed_tol, atol=self.relaxed_tol
+                 torch.tensor(np.array(nl, dtype=np.float32))[lp:max_len, :],
+                 tl[lp:, :],
+                 rtol=self.relaxed_tol,
+                 atol=self.relaxed_tol,
+                 check_dtype=False,
            )

    def _setup_torch_attn(self, input_embeddings: torch.Tensor, attention_mask: None = None):
@@ -122,147 +135,231 @@ def test_embedder(self):
        jx = jnp.array(tx.cpu().detach().numpy())

        jy, ty = nm.embedding.value.at[(jx,)].get(), tm(tx)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_decoder_layer(self):
        nm = self.nnx_model.layers[0]
        tm = self.torch_model.model.layers[0].to(torch.float32)

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
        jx = jax.random.normal(jax.random.key(0), shape=shape)
-         tx = torch.tensor(jx)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))
        nnx_cache = self._init_nnx_cache(self.batch_size)
        torch_inputs = self._setup_torch_attn(tx)

        jy, ty = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens))), tm(**torch_inputs)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_all_decoder_layers(self):
        nnx_cache = self._init_nnx_cache(self.batch_size)
        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)

        for nm, tm, nc in zip(self.nnx_model.layers, self.torch_model.model.layers, nnx_cache):
            jx = jax.random.normal(jax.random.key(0), shape=shape)
-             tx = torch.tensor(jx)
+             tx = torch.tensor(np.array(jx, dtype=np.float32))

            jy = nm(jx, nc, jnp.ones((self.batch_size, self.num_input_tokens)))
            torch_inputs = self._setup_torch_attn(tx)
            ty = tm.to(torch.float32)(**torch_inputs)
-             torch.testing.assert_close(torch.tensor(jy), ty, atol=self.relaxed_tol, rtol=self.relaxed_tol)
+             torch.testing.assert_close(
+                 torch.tensor(np.array(jy, dtype=np.float32)),
+                 ty,
+                 atol=self.relaxed_tol,
+                 rtol=self.relaxed_tol,
+                 check_dtype=False,
+             )

    def test_rms_norm(self):
        nm = self.nnx_model.layers[0].input_layernorm
        tm = self.torch_model.model.layers[0].input_layernorm

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
-         jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
-         tx = torch.tensor(jx)
+         jx = jax.random.normal(jax.random.key(0), shape=shape)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))

        jy, ty = nm(jx), tm(tx)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_self_attn(self):
        nm = self.nnx_model.layers[0].attn
        tm = self.torch_model.model.layers[0].self_attn.to(torch.float32)

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
        jx = jax.random.normal(jax.random.key(0), shape=shape)
-         tx = torch.tensor(jx)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))
        torch_inputs = self._setup_torch_attn(tx)
        nnx_cache = self._init_nnx_cache(self.batch_size)

        jy = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens), dtype=jnp.float32))
        ty = tm(**torch_inputs)[0]
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_q_norm(self):
        nm = self.nnx_model.layers[0].attn.q_norm
        tm = self.torch_model.model.layers[0].self_attn.q_norm

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_heads, self.bonsai_config.head_dim)
-         jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
-         tx = torch.tensor(jx)
+         jx = jax.random.normal(jax.random.key(0), shape=shape)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))

        jy, ty = nm(jx), tm(tx)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_k_norm(self):
        nm = self.nnx_model.layers[0].attn.k_norm
        tm = self.torch_model.model.layers[0].self_attn.k_norm

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_kv_heads, self.bonsai_config.head_dim)
-         jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
-         tx = torch.tensor(jx)
+         jx = jax.random.normal(jax.random.key(0), shape=shape)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))

        jy, ty = nm(jx), tm(tx)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_q_proj(self):
        nm = self.nnx_model.layers[0].attn.q_proj
-         tm = self.torch_model.model.layers[0].self_attn.q_proj
+         tm = self.torch_model.model.layers[0].self_attn.q_proj.to(torch.float32)

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
-         jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
-         tx = torch.tensor(jx)
+         jx = jax.random.normal(jax.random.key(0), shape=shape)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_heads, self.bonsai_config.head_dim)
        jy, ty = nm(jx), tm(tx).reshape(shape)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_k_proj(self):
        nm = self.nnx_model.layers[0].attn.k_proj
-         tm = self.torch_model.model.layers[0].self_attn.k_proj
+         tm = self.torch_model.model.layers[0].self_attn.k_proj.to(torch.float32)

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
-         jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
-         tx = torch.tensor(jx)
+         jx = jax.random.normal(jax.random.key(0), shape=shape)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_kv_heads, self.bonsai_config.head_dim)
        jy, ty = nm(jx), tm(tx).reshape(shape)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_o_proj(self):
        nm = self.nnx_model.layers[0].attn.o_proj
-         tm = self.torch_model.model.layers[0].self_attn.o_proj
+         tm = self.torch_model.model.layers[0].self_attn.o_proj.to(torch.float32)

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.num_heads, self.bonsai_config.head_dim)
-         jx = jax.random.normal(jax.random.key(0), shape=shape, dtype=jnp.bfloat16)
-         tx = torch.tensor(jx).reshape(self.batch_size, self.num_input_tokens, -1)
+         jx = jax.random.normal(jax.random.key(0), shape=shape)
+         tx = torch.tensor(np.array(jx, dtype=np.float32)).reshape(self.batch_size, self.num_input_tokens, -1)

        jy, ty = nm(jx), tm(tx)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_mlp(self):
        nm = self.nnx_model.layers[0].mlp
        tm = self.torch_model.model.layers[0].mlp.to(torch.float32)

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
        jx = jax.random.normal(jax.random.key(0), shape=shape)
-         tx = torch.tensor(jx)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))

        jy, ty = nm(jx), tm(tx)
-         torch.testing.assert_close(torch.tensor(jy), ty, rtol=self.relaxed_tol, atol=self.relaxed_tol)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_lm_head(self):
        nm = self.nnx_model.lm_head
        tm = self.torch_model.lm_head.to(torch.float32)

        shape = (self.batch_size, self.num_input_tokens, self.bonsai_config.emb_dim)
        jx = jax.random.normal(jax.random.key(0), shape=shape)
-         tx = torch.tensor(jx)
+         tx = torch.tensor(np.array(jx, dtype=np.float32))

        jy, ty = nm(jx), tm(tx)
-         torch.testing.assert_close(torch.tensor(jy), ty)
+         torch.testing.assert_close(
+             torch.tensor(np.array(jy, dtype=np.float32)),
+             ty,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_sin_cos(self):
        batch_size, seq_len, dim = 2, 10, 128
        hidden_states = torch.ones((batch_size, seq_len, dim))
        jp = jnp.stack([jnp.arange(seq_len), jnp.arange(seq_len)])
        js, jc = modeling._generate_pos_embeddings(jp, dim)
-         tc, ts = self.torch_model.model.rotary_emb(hidden_states, torch.tensor(jp))
+         tc, ts = self.torch_model.model.rotary_emb(hidden_states, torch.tensor(np.array(jp, dtype=np.float32)))
        tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2]
-         torch.testing.assert_close(torch.tensor(js), ts)
-         torch.testing.assert_close(torch.tensor(jc), tc)
+         torch.testing.assert_close(
+             torch.tensor(np.array(js, dtype=np.float32)),
+             ts,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )
+         torch.testing.assert_close(
+             torch.tensor(np.array(jc, dtype=np.float32)),
+             tc,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
+         )

    def test_full(self):
        query = ["Why is the sky blue instead of any other color like purple?"]
@@ -275,7 +372,11 @@ def test_full(self):
        torch_inputs = self._process_hf_tokens(query)
        torch_logits = self.torch_model(**torch_inputs).logits
        torch.testing.assert_close(
-             torch.tensor(nnx_logits)[:, :token_len, :], torch_logits, rtol=self.relaxed_tol, atol=self.relaxed_tol
+             torch.tensor(np.array(nnx_logits, dtype=np.float32))[:, :token_len, :],
+             torch_logits,
+             rtol=self.relaxed_tol,
+             atol=self.relaxed_tol,
+             check_dtype=False,
        )

    def test_full_batched(self):