Support Layout.tiling in with_layout_constraint #33543

@yuwei-qin

Description

with_layout_constraint accepts a Layout with tiling, but silently drops the tiling: only major_to_minor is passed to XLA. A TODO in the code acknowledges that this is not implemented yet.

This limits how far layout optimizations can be coordinated between JAX (configured through global flags) and kernels.

Minimal Reproduction

import jax
import jax.numpy as jnp
from jax.experimental.layout import Layout
from jax._src.pjit import with_layout_constraint

# Create a Layout with explicit tiling
layout = Layout(
    major_to_minor=(0, 1, 2),
    tiling=((8, 128), (2, 1)),  # <-- This is silently ignored!
)

print(f"Layout.tiling: {layout.tiling}")  # Shows ((8, 128), (2, 1))

@jax.jit
def apply_constraint(x):
    return with_layout_constraint(x, layout)

x = jnp.ones((4, 16, 128), dtype=jnp.bfloat16)
lowered = apply_constraint.lower(x)
hlo_text = lowered.as_text()

# Check if tiling appears in HLO
print(f'"LayoutConstraint" in HLO: {"LayoutConstraint" in hlo_text}')  # True
print(f'"T(8,128)" in HLO: {"T(8,128)" in hlo_text}')  # False - tiling dropped!

Output:

module @jit_apply_constraint attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<4x16x128xbf16>) -> (tensor<4x16x128xbf16> {jax.result_info = "result"}) {
    %0 = stablehlo.custom_call @LayoutConstraint(%arg0) {backend_config = "", operand_layouts = [dense<[0, 1, 2]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<4x16x128xbf16>) -> tensor<4x16x128xbf16>
    return %0 : tensor<4x16x128xbf16>
  }
}

Layout.tiling: ((8, 128), (2, 1))
"LayoutConstraint" in HLO: True
"T(8,128)" in HLO: False
CONFIRMED: Tiling was specified but NOT passed to HLO!

The generated HLO contains only the dimension ordering (in minor-to-major order) and no tiling:
result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]
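To check programmatically which layout information reached the custom call, the layout arrays can be pulled out of the lowered module text. This is a minimal sketch, assuming the printed form seen in the output above (dense<[...]> : tensor<Nxindex>); the helper name custom_call_layouts is hypothetical, and a regex over printed MLIR is a quick diagnostic, not a stable API.

```python
import re

# Matches e.g. "result_layouts = [dense<[2, 1, 0]>" and captures the name
# of the attribute plus the comma-separated dimension list.
_LAYOUT_RE = re.compile(
    r"(operand_layouts|result_layouts)\s*=\s*\[dense<\[([\d,\s]*)\]>"
)


def custom_call_layouts(hlo_text):
    """Return {attr_name: tuple_of_dims} for the first match of each attribute.

    Hypothetical diagnostic helper: it only reads printed MLIR text and
    assumes a single-operand custom call like the LayoutConstraint above.
    """
    layouts = {}
    for name, digits in _LAYOUT_RE.findall(hlo_text):
        dims = tuple(int(d) for d in digits.split(",") if d.strip())
        layouts.setdefault(name, dims)  # keep the first occurrence only
    return layouts


# The custom_call line copied from the output above.
line = ('%0 = stablehlo.custom_call @LayoutConstraint(%arg0) {backend_config = "", '
        'operand_layouts = [dense<[0, 1, 2]> : tensor<3xindex>], '
        'result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : '
        '(tensor<4x16x128xbf16>) -> tensor<4x16x128xbf16>')
print(custom_call_layouts(line))
```

Applied to hlo_text in the repro, the result holds only dimension orderings; no field carries the (8, 128) or (2, 1) tiles.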

Root Cause

The issue is in jax/_src/interpreters/mlir.py in the wrap_with_layout_op function (line ~3001):

op = custom_call('LayoutConstraint', result_types=[result_type], operands=[x],
                 # ...
                 # TODO(yashkatariya): Figure out how to pass tiling to the
                 # custom call.
                 result_layouts=[layout.major_to_minor[::-1]])  # <-- Only this is passed!

The layout.tiling attribute exists and is populated, but wrap_with_layout_op never reads it; the TODO comment above acknowledges this limitation.
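For comparison, XLA's HLO text notation carries tiling inside the layout itself, after the minor-to-major list, e.g. {2,1,0:T(8,128)(2,1)} for the layout in the repro. The sketch below uses a hypothetical helper, xla_layout_string (not part of JAX or XLA), to build that string from the two Layout fields and make the dropped part explicit. It assumes the XLA notation of one "T" prefix followed by one parenthesized group per tile.

```python
def xla_layout_string(major_to_minor, tiling=()):
    """Format a layout in XLA's HLO text notation, e.g. {2,1,0:T(8,128)(2,1)}.

    Hypothetical helper for illustration. XLA lists dimensions
    minor-to-major, which is why JAX passes major_to_minor[::-1]
    as result_layouts.
    """
    minor_to_major = ",".join(str(d) for d in reversed(major_to_minor))
    if not tiling:
        return "{" + minor_to_major + "}"
    # One "T" prefix, then one "(...)" group per tile, outermost tile first.
    tiles = "T" + "".join("(" + ",".join(str(t) for t in tile) + ")" for tile in tiling)
    return "{" + minor_to_major + ":" + tiles + "}"


print(xla_layout_string((0, 1, 2), ((8, 128), (2, 1))))  # {2,1,0:T(8,128)(2,1)}
print(xla_layout_string((0, 1, 2)))                      # {2,1,0}
```

What the LayoutConstraint custom call currently receives corresponds only to the second, untiled form.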


System info (python version, jaxlib version, accelerator, etc.)

  • JAX version: 0.8.1.dev20251105 (issue exists in current main)
  • Platform: Reproducible on CPU; the target use case is TPU (v5e, v6e)

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)