
jax.experimental.sparse.CSR accepts wrong shapes #33514

@johannahaffner

Description

The following code reproduces the error:

```python
import jax.numpy as jnp
from jax.experimental.sparse import CSR


matrix = jnp.eye(2)
csr = CSR.fromdense(matrix)

args = (csr.data, csr.indices, csr.indptr)  # buffers of a valid 2x2 CSR matrix
new_csr = CSR(args, shape=(3, 3))  # Wrong shape, does not raise

print(csr, new_csr)
print(CSR.todense(csr), CSR.todense(new_csr))  # Only raises here, in todense
```
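To see why `shape=(3, 3)` cannot be right, it helps to build the CSR buffers for a 2x2 identity by hand. The sketch below uses NumPy and a hypothetical helper `dense_to_csr` (not a JAX API); for `eye(2)` it should produce the same `data`, `indices` and `indptr` that `CSR.fromdense` stores, though that correspondence is an assumption rather than something checked against every JAX version. Because `indptr` holds a leading zero plus one running offset per row, its length fixes the row count at `len(indptr) - 1`.

```python
import numpy as np


def dense_to_csr(dense):
    """Hypothetical helper: build CSR (data, indices, indptr) buffers by hand."""
    dense = np.asarray(dense)
    data, indices, indptr = [], [], [0]
    for row in dense:
        cols = np.nonzero(row)[0]        # column positions of nonzeros in this row
        data.extend(row[cols].tolist())  # their values, left to right
        indices.extend(cols.tolist())
        indptr.append(len(indices))      # running count of stored entries
    return np.array(data), np.array(indices), np.array(indptr)


data, indices, indptr = dense_to_csr(np.eye(2))
print(data, indices, indptr)  # [1. 1.] [0 1] [0 1 2]

implied_rows = len(indptr) - 1
print(implied_rows)  # 2, not the 3 claimed by shape=(3, 3)
```

With three `indptr` entries the buffers describe exactly two rows, so a `(3, 3)` shape is inconsistent before any densification happens. The column count, by contrast, is only bounded from below by the largest stored column index plus one.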

I think it would help end users if the error currently raised in `todense` were raised already when the sparse matrix is constructed with an inconsistent shape.
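A minimal sketch of the kind of construction-time checks being asked for, assuming an unbatched 2-D CSR matrix with `data`, `indices` and `indptr` buffers as in the reproduction. `check_csr_shape` is a hypothetical helper, not part of `jax.experimental.sparse`; it uses NumPy, so concrete (non-traced) JAX arrays can be passed through `np.asarray`. It only inspects the first `indptr[-1]` column indices, on the assumption that any slots beyond that may be padding.

```python
import numpy as np


def check_csr_shape(data, indices, indptr, shape):
    """Hypothetical helper (not part of jax.experimental.sparse).

    Raises ValueError if unbatched 2-D CSR buffers cannot describe `shape`.
    """
    data, indices, indptr = (np.asarray(a) for a in (data, indices, indptr))
    if len(shape) != 2:
        raise ValueError(f"expected a 2-D shape, got {shape}")
    n_rows, n_cols = shape
    # indptr stores a leading zero plus one offset per row.
    if indptr.ndim != 1 or indptr.shape[0] != n_rows + 1:
        raise ValueError(
            f"indptr has shape {indptr.shape}; shape {shape} needs ({n_rows + 1},)"
        )
    if data.ndim != 1 or data.shape != indices.shape:
        raise ValueError(f"data {data.shape} and indices {indices.shape} must be matching 1-D arrays")
    if indptr[0] != 0 or np.any(np.diff(indptr) < 0):
        raise ValueError("indptr must start at 0 and be non-decreasing")
    nnz = int(indptr[-1])
    if nnz > indices.shape[0]:
        raise ValueError(f"indptr[-1] = {nnz} exceeds the {indices.shape[0]} stored entries")
    # Only the first nnz column indices are referenced; later slots may be padding.
    used = indices[:nnz]
    if used.size and (used.min() < 0 or used.max() >= n_cols):
        raise ValueError(f"column indices must lie in [0, {n_cols})")


# Buffers of the 2x2 identity, as in the reproduction.
data = np.array([1.0, 1.0])
indices = np.array([0, 1])
indptr = np.array([0, 1, 2])

check_csr_shape(data, indices, indptr, (2, 2))  # consistent: no error
try:
    check_csr_shape(data, indices, indptr, (3, 3))
except ValueError as err:
    print(err)  # indptr has shape (3,); shape (3, 3) needs (4,)
```

Calling such a check on `args` before `CSR(args, shape=(3, 3))` in the reproduction would move the failure to construction time. The row count is checked exactly, while the column count can only be bounded from below, so a shape such as `(2, 3)` also passes. Under `jax.jit` the buffers are tracers, so only the static `.shape` checks could run at trace time; the value checks need concrete arrays.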

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.8.1
jaxlib: 0.8.1
numpy: 2.3.3
python: 3.13.3 (main, Apr 8 2025, 13:54:08) [Clang 16.0.0 (clang-1600.0.26.6)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node=REDACTED, release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:26 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8112', machine='arm64')

Labels: bug (Something isn't working)
