Open
Labels: bug (Something isn't working)
Description
with_layout_constraint accepts a Layout with tiling but silently drops it: only major_to_minor is passed to XLA. A TODO in the code acknowledges that this isn't implemented yet.
This limits how far layout optimizations can be coordinated between JAX (configured with global flags) and custom kernels.
Minimal Reproduction
import jax
import jax.numpy as jnp
from jax.experimental.layout import Layout
from jax._src.pjit import with_layout_constraint

# Create a Layout with explicit tiling
layout = Layout(
    major_to_minor=(0, 1, 2),
    tiling=((8, 128), (2, 1)),  # <-- This is silently ignored!
)
print(f"Layout.tiling: {layout.tiling}")  # Shows ((8, 128), (2, 1))

@jax.jit
def apply_constraint(x):
    return with_layout_constraint(x, layout)

x = jnp.ones((4, 16, 128), dtype=jnp.bfloat16)
lowered = apply_constraint.lower(x)
hlo_text = lowered.as_text()

# Check if tiling appears in HLO
print(f'"LayoutConstraint" in HLO: {"LayoutConstraint" in hlo_text}')  # True
print(f'"T(8,128)" in HLO: {"T(8,128)" in hlo_text}')  # False - tiling dropped!

Output:
module @jit_apply_constraint attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<4x16x128xbf16>) -> (tensor<4x16x128xbf16> {jax.result_info = "result"}) {
    %0 = stablehlo.custom_call @LayoutConstraint(%arg0) {backend_config = "", operand_layouts = [dense<[0, 1, 2]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<4x16x128xbf16>) -> tensor<4x16x128xbf16>
    return %0 : tensor<4x16x128xbf16>
  }
}
Layout.tiling: ((8, 128), (2, 1))
"LayoutConstraint" in HLO: True
"T(8,128)" in HLO: False
CONFIRMED: Tiling was specified but NOT passed to HLO!
The generated HLO shows only dimension ordering, no tiling:
result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]
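The substring check in the reproduction only looks for one specific tile. A slightly more general check might scan the lowered text for any tile annotation. The sketch below is plain Python and does not call JAX; `find_tile_annotations` is a hypothetical helper name, and the `tiled_example` string is an illustration of XLA's textual layout syntax as I understand it, not something JAX produced in this report.

```python
import re

# Matches a tile group such as T(8,128) or T(8,128)(2,1).
TILE_PATTERN = re.compile(r"T(?:\(\d+(?:,\d+)*\))+")

def find_tile_annotations(ir_text: str) -> list[str]:
    """Return every tile group found in the given IR text, in order."""
    return TILE_PATTERN.findall(ir_text)

# Fragment copied from the lowered module above: ordering only, no tiles.
observed = "result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]"

# Illustrative only (assumption): a shape with tiling in XLA's layout syntax.
tiled_example = "bf16[4,16,128]{2,1,0:T(8,128)(2,1)}"

print(find_tile_annotations(observed))       # []
print(find_tile_annotations(tiled_example))  # ['T(8,128)(2,1)']
```

Running `find_tile_annotations(lowered.as_text())` in the reproduction should return an empty list for as long as the tiling is dropped.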
Root Cause
The issue is in jax/_src/interpreters/mlir.py in the wrap_with_layout_op function (line ~3001):
op = custom_call('LayoutConstraint', result_types=[result_type], operands=[x],
                 # ...
                 # TODO(yashkatariya): Figure out how to pass tiling to the
                 # custom call.
                 result_layouts=[layout.major_to_minor[::-1]])  # <-- Only this is passed!

There's already a TODO comment acknowledging this limitation. The layout.tiling attribute exists and is populated, but it's never used.
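To make the gap concrete, here is a rough sketch in plain Python of what the custom call receives today versus what a full layout description would carry. It does not use JAX internals: `minor_to_major` and `format_xla_layout` are hypothetical helpers written for illustration, and the rendered string assumes XLA's textual layout form `{minor_to_major:T(tile)(tile)}`. It is not a proposal for how the TODO should be resolved.

```python
def minor_to_major(major_to_minor):
    """Reverse the dimension order, mirroring layout.major_to_minor[::-1]."""
    return tuple(major_to_minor[::-1])

def format_xla_layout(major_to_minor, tiling=None):
    """Hypothetical helper: render a layout in XLA-style text form."""
    dims = ",".join(str(d) for d in minor_to_major(major_to_minor))
    if not tiling:
        # Ordering only, which is all that reaches the custom call today.
        return "{" + dims + "}"
    tiles = "".join(
        "(" + ",".join(str(t) for t in tile) + ")" for tile in tiling
    )
    return "{" + dims + ":T" + tiles + "}"

# Values taken from the reproduction above.
major_to_minor = (0, 1, 2)
tiling = ((8, 128), (2, 1))

print(minor_to_major(major_to_minor))             # (2, 1, 0)
print(format_xla_layout(major_to_minor))          # {2,1,0}
print(format_xla_layout(major_to_minor, tiling))  # {2,1,0:T(8,128)(2,1)}
```

The first two lines printed correspond to the `[2, 1, 0]` seen in `result_layouts`; the third shows the information held in `layout.tiling` that currently has nowhere to go.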
System info (python version, jaxlib version, accelerator, etc.)
- JAX version: 0.8.1.dev20251105 (issue exists in current main)
- Platform: Reproducible on CPU; use case is TPU (v5e, v6e)