Low GEMM Performance on Hopper GPU with Small M Shapes #21

@happierpig

Description

Hi,

Thank you for the great library! I'm observing unexpectedly low GEMM performance on Hopper GPUs when the M dimension is small. I followed the example in example14_autotune.py.

Compared to the PyTorch (torch.matmul) baseline, nvmath-python reaches only around 30% of the TFLOPS and memory-bandwidth utilization that PyTorch achieves.

[Image: benchmark output comparing PyTorch and nvmath-python]

I'm not sure I'm using the API correctly; I would greatly appreciate any suggestions.

Environment:
• GPU: H200
• CUDA: 12.8
• PyTorch: 2.6.0
• nvmath-python: 0.3.0
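
For completeness, here is a minimal sketch of how these versions can be checked from Python (I'm assuming nvmath exposes __version__ in the usual way):

import torch
import nvmath

# Report the library versions and the GPU used for the benchmark.
print("torch:", torch.__version__, "cuda:", torch.version.cuda)
print("nvmath-python:", nvmath.__version__)  # assumes the standard __version__ attribute
print("gpu:", torch.cuda.get_device_name(0))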

Benchmark code:

import argparse

import torch
import nvmath
from triton.testing import do_bench


def profile(m, n, k, dtype):
    device = torch.device("cuda")

    X = torch.randn(m, k, device=device, dtype=dtype)
    Y = torch.randn(k, n, device=device, dtype=dtype)  # (k, n) so that X @ Y is valid for any n, k
    
    _torch_gemm = lambda: torch.matmul(X, Y)
    
    mm = nvmath.linalg.advanced.Matmul(X, Y)
    
    # Request up to 1000 candidate algorithms, then autotune to pick the fastest one.
    mm.plan(preferences={"limit": 1000})
    mm.autotune(iterations=1000)
    
    # print(mm.algorithms[0].capabilities)
    _nvmath_gemm = lambda: mm.execute()
    
    t_torch = do_bench(_torch_gemm)
    t_nvmath = do_bench(_nvmath_gemm)
    
    return t_torch, t_nvmath


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GEMM profile")
    parser.add_argument("--m", type=int, default=4096)
    parser.add_argument("--n", type=int, default=4096)
    parser.add_argument("--k", type=int, default=4096)
    args = parser.parse_args()

    print("Provider,Operation,dtype,m,n,k,Runtime(ms),GB/s,GFLOP/s")
    
    for dtype in [torch.float16, torch.bfloat16]:
        t_torch, t_nvmath = profile(args.m, args.n, args.k, dtype)

        m = args.m
        n = args.n
        k = args.k

        # Bytes moved (2 bytes per fp16/bf16 element) and FLOPs; do_bench reports ms, hence the 1e3 factor.
        torch_mem_bd = 2 * (m * n + n * k + m * k) * 1e3 / t_torch / 1e9
        torch_gflops = 2 * m * n * k * 1e3 / t_torch / 1e9
        nv_mem_bd = 2 * (m * n + n * k + m * k) * 1e3 / t_nvmath / 1e9
        nv_gflops = 2 * m * n * k * 1e3 / t_nvmath / 1e9
        
        print(f"TORCH,0,{dtype},{args.m},{args.n},{args.k},{t_torch},{torch_mem_bd},{torch_gflops}")
        print(f"NVMATH,0,{dtype},{args.m},{args.n},{args.k},{t_nvmath},{nv_mem_bd},{nv_gflops}")
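
For reference, a small-M case like the one in the title can be reproduced by calling the profile() helper above directly; m = 16 below is only an illustrative choice, not necessarily the exact shape behind the numbers in the screenshot:

# Illustrative only: small-M GEMM (m = 16, n = k = 4096) in fp16,
# reusing the profile() helper defined in the script above.
t_torch, t_nvmath = profile(16, 4096, 4096, torch.float16)
print(f"torch: {t_torch:.3f} ms, nvmath: {t_nvmath:.3f} ms")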
