Description
Hi,
Thank you for the great library! I'm seeing unexpectedly low GEMM performance on Hopper GPUs when the M dimension is small. I followed the example in example14_autotune.py.
Compared to PyTorch, performance is significantly lower, reaching only around 30% of the TFLOPS and memory bandwidth utilization that PyTorch achieves.
I'm not sure whether I am using the API correctly, so I would greatly appreciate any suggestions.
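Because the problem is specific to small M, a quick roofline estimate helps decide which metric matters most. The sketch below is a minimal illustration only, assuming float16/bfloat16 (2 bytes per element), that each operand crosses DRAM exactly once, and approximate H200 peak figures taken as assumptions rather than measurements. `Roofline` and `gemm_roofline` are hypothetical helpers, not part of nvmath-python or PyTorch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Roofline:
    peak_flops: float  # dense tensor-core throughput, FLOP/s
    peak_bw: float     # HBM bandwidth, bytes/s


# Approximate H200 SXM datasheet figures (assumption, not measured):
# ~989 TFLOP/s dense FP16/BF16 and ~4.8 TB/s HBM3e bandwidth.
H200 = Roofline(peak_flops=989e12, peak_bw=4.8e12)


def gemm_roofline(m, n, k, roof, bytes_per_elem=2):
    """Roofline estimate for C(m, n) = A(m, k) @ B(k, n), moving A, B and C once each."""
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    intensity = flops / bytes_moved          # FLOP per byte of DRAM traffic
    ridge = roof.peak_flops / roof.peak_bw   # intensity where compute and bandwidth limits meet
    # The kernel can be no faster than the slower of the two limits.
    t_min = max(flops / roof.peak_flops, bytes_moved / roof.peak_bw)
    return {
        "flops": flops,
        "bytes": bytes_moved,
        "intensity": intensity,
        "memory_bound": intensity < ridge,
        "t_min_us": t_min * 1e6,
    }


if __name__ == "__main__":
    for m in (16, 128, 4096):
        r = gemm_roofline(m, 4096, 4096, H200)
        print(f"m={m:5d}  intensity={r['intensity']:7.1f} FLOP/B  "
              f"memory_bound={r['memory_bound']}  t_min={r['t_min_us']:.2f} us")
```

For m=16 and n=k=4096 this gives roughly 16 FLOP/byte against a ridge point of about 206 FLOP/byte, so the kernel is bandwidth-bound with a floor of about 7 µs. In that regime the GB/s column is the more meaningful comparison, and any fixed per-call host overhead of a few microseconds would be a large fraction of the measured runtime.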
Environment:
• GPU: H200
• CUDA: 12.8
• PyTorch: 2.6.0
• nvmath-python: 0.3.0
Benchmark code:
import argparse

import torch
import nvmath
from triton.testing import do_bench


def profile(m, n, k, dtype):
    device = torch.device("cuda")
    # X is (m, k) and Y is (k, n), so torch and nvmath compute the same
    # (m, n) product for any m, n, k.
    X = torch.randn(m, k, device=device, dtype=dtype)
    Y = torch.randn(k, n, device=device, dtype=dtype)

    _torch_gemm = lambda: torch.matmul(X, Y)

    # Stateful nvmath API: plan once (considering up to 1000 algorithms),
    # autotune the candidates, then reuse the planned object for every call.
    mm = nvmath.linalg.advanced.Matmul(X, Y)
    mm.plan(preferences={"limit": 1000})
    mm.autotune(iterations=1000)
    # print(mm.algorithms[0].capabilities)
    _nvmath_gemm = lambda: mm.execute()

    # do_bench reports runtime in milliseconds.
    t_torch = do_bench(_torch_gemm)
    t_nvmath = do_bench(_nvmath_gemm)

    # Release the resources held by the Matmul object.
    mm.free()
    return t_torch, t_nvmath


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GEMM profile")
    parser.add_argument("--m", type=int, default=4096)
    parser.add_argument("--n", type=int, default=4096)
    parser.add_argument("--k", type=int, default=4096)
    args = parser.parse_args()
    m, n, k = args.m, args.n, args.k

    print("Provider,Operation,dtype,m,n,k,Runtime(ms),GB/s,GFLOP/s")
    for dtype in [torch.float16, torch.bfloat16]:
        t_torch, t_nvmath = profile(m, n, k, dtype)
        # float16/bfloat16 use 2 bytes per element; A, B and C are each moved once.
        bytes_moved = 2 * (m * n + n * k + m * k)
        flops = 2 * m * n * k
        # Runtimes are in ms, hence the 1e3 factor to get per-second rates.
        torch_mem_bd = bytes_moved * 1e3 / t_torch / 1e9
        torch_gflops = flops * 1e3 / t_torch / 1e9
        nv_mem_bd = bytes_moved * 1e3 / t_nvmath / 1e9
        nv_gflops = flops * 1e3 / t_nvmath / 1e9
        print(f"TORCH,0,{dtype},{m},{n},{k},{t_torch},{torch_mem_bd},{torch_gflops}")
        print(f"NVMATH,0,{dtype},{m},{n},{k},{t_nvmath},{nv_mem_bd},{nv_gflops}")
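To read the CSV columns as utilization rather than raw numbers, the benchmark's own formulas can be factored into small helpers and compared against peak figures. This is a sketch assuming 2-byte elements, the same traffic model as the script (A, B and C each moved once), and approximate H200 peaks (~4.8 TB/s, ~989 TFLOP/s dense) as default assumptions. The function names are hypothetical.

```python
def gemm_throughput(m, n, k, runtime_ms, bytes_per_elem=2):
    """Same formulas as the benchmark: (GB/s, GFLOP/s) from a runtime in milliseconds."""
    bytes_moved = bytes_per_elem * (m * n + n * k + m * k)
    flops = 2 * m * n * k
    gbps = bytes_moved * 1e3 / runtime_ms / 1e9
    gflops = flops * 1e3 / runtime_ms / 1e9
    return gbps, gflops


def fraction_of_peak(gbps, gflops, peak_bw=4.8e12, peak_flops=989e12):
    """Achieved fraction of the assumed peak bandwidth (bytes/s) and peak compute (FLOP/s)."""
    return gbps * 1e9 / peak_bw, gflops * 1e9 / peak_flops


def relative_throughput(t_ref_ms, t_ms):
    """Throughput of the measured provider relative to a reference provider (e.g. torch)."""
    return t_ref_ms / t_ms


if __name__ == "__main__":
    # Illustrative runtime only, not a measurement.
    gbps, gflops = gemm_throughput(16, 4096, 4096, runtime_ms=0.02)
    bw_frac, flop_frac = fraction_of_peak(gbps, gflops)
    print(f"{gbps:.1f} GB/s ({bw_frac:.1%} of peak BW), "
          f"{gflops:.1f} GFLOP/s ({flop_frac:.1%} of peak compute)")
```

For a small-M run, the bandwidth fraction is the number to watch, while `relative_throughput(t_torch, t_nvmath)` gives the nvmath-to-PyTorch ratio directly (about 0.3 in the situation described in this issue).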