
Source: https://github.com/pytorch/pytorch/issues/133974

Crash When Using torch.compile with Math scaled_dot_product_attention in AMP Mode (Issue #133974, pytorch/pytorch)

🐛 Describe the bug

Summary
Compiling a graph that contains a math-backend scaled_dot_product_attention (SDPA) call with torch.compile under AMP (autocast) makes the script crash. The crash does not occur with a single Flash Attention operation.
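The "math" backend is the unfused fallback of scaled_dot_product_attention that evaluates attention with ordinary tensor ops. As a hedged illustration of what that path computes (not the actual ATen kernel, which also handles dropout, arbitrary masks, and dtype details), the causal case reduces to softmax(Q K^T / sqrt(d) + M) V, where M is -inf above the diagonal. The helper name `math_sdpa` is hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def math_sdpa(q, k, v, is_causal=False):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale  # (batch, heads, L, S)
    if is_causal:
        L, S = q.size(-2), k.size(-2)
        # Lower-triangular mask: position i may attend only to positions <= i
        keep = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return attn @ v


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 16, 128, 64) for _ in range(3))
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(torch.allclose(math_sdpa(q, k, v, is_causal=True), ref, atol=1e-5))
```

In fp32 on CPU the sketch should agree with the built-in op to within a small tolerance; it does not model the low-precision behavior that autocast introduces.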

Reproducible Script
Below is a minimal script that reproduces the crash:

from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config

import os
torch._inductor.config.fallback_random = True


# torch._dynamo.config.base_dir = os.environ["TORCHINDUCTOR_CACHE_DIR"]

# Disable the fused CUDA SDPA backends so only the math backend runs (until an XPU implementation exists)
torch.backends.cuda.enable_cudnn_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)



from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()



    def forward(self, L_L_self_modules_self_attn_modules_out_proj_parameters_weight_ : torch.Tensor, L_L_self_modules_self_attn_modules_out_proj_parameters_bias_ : torch.Tensor, key_states, value_states, query_states_1):
        l_l_self_modules_self_attn_modules_out_proj_parameters_weight_ = L_L_self_modules_self_attn_modules_out_proj_parameters_weight_
        l_l_self_modules_self_attn_modules_out_proj_parameters_bias_ = L_L_self_modules_self_attn_modules_out_proj_parameters_bias_
        attn_output = torch._C._nn.scaled_dot_product_attention(query_states_1, key_states, value_states, attn_mask = None, dropout_p = 0.0, is_causal = True);  query_states_1 = key_states = value_states = None
        attn_output_1 = attn_output.transpose(1, 2);  attn_output = None
        attn_output_2 = attn_output_1.reshape(1, 1024, 1024);  attn_output_1 = None
        attn_output_3 = torch._C._nn.linear(attn_output_2, l_l_self_modules_self_attn_modules_out_proj_parameters_weight_, l_l_self_modules_self_attn_modules_out_proj_parameters_bias_);  attn_output_2 = l_l_self_modules_self_attn_modules_out_proj_parameters_weight_ = l_l_self_modules_self_attn_modules_out_proj_parameters_bias_ = None
        return (attn_output_3,)


mod = Repro()

def load_args(reader):
    buf0 = reader.storage('18f6e31c9f38831ae36b922d77c4c0546a0b5bb5', 4194304, device=device(type='cuda', index=0))
    reader.tensor(buf0, (1024, 1024), requires_grad=True, is_leaf=True)  # L_L_self_modules_self_attn_modules_out_proj_parameters_weight_
    buf1 = reader.storage('1ceaf73df40e531df3bfb26b4fb7cd95fb7bff1d', 4096, device=device(type='cuda', index=0))
    reader.tensor(buf1, (1024,), requires_grad=True, is_leaf=True)  # L_L_self_modules_self_attn_modules_out_proj_parameters_bias_
    buf2 = reader.storage('9f89f5ef3ecccc9aa2d2d9ec5d6882b315806934', 2097152, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf2, (1, 16, 1024, 64), dtype=torch.float16, requires_grad=True)  # key_states
    buf3 = reader.storage('74883a0ed186fbeb2dbe8b5c549db706a634fdb0', 2097152, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf3, (1, 16, 1024, 64), dtype=torch.float16, requires_grad=True)  # value_states
    buf4 = reader.storage('a22f237f6ff537f598c9f0b2728e44c66b52a875', 2097152, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf4, (1, 16, 1024, 64), dtype=torch.float16, requires_grad=True)  # query_states_1
load_args._version = 0

if __name__ == '__main__':
    from torch._dynamo.repro.after_dynamo import run_repro
    run_repro(mod, load_args, accuracy=False, command='run',
        save_dir='/home/yunfei/code/piece/checkpoints', autocast=True, backend='inductor')
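For triage, the same computation can be written without the minifier harness (`run_repro`, `load_args`). The sketch below is an approximation under stated assumptions: it uses random inputs instead of the saved checkpoint tensors, falls back to bfloat16 when no GPU is present, and has not been confirmed to trigger the crash. Its dtypes follow `load_args`, whose storage sizes correspond to an fp32 `out_proj` weight (1024 × 1024 × 4 = 4,194,304 bytes) and bias (1024 × 4 = 4,096 bytes) alongside fp16 key/value/query tensors (1 × 16 × 1024 × 64 × 2 = 2,097,152 bytes each), so autocast has to reconcile fp32 parameters with fp16 activations. The backward pass is included on the assumption that it mirrors `run_fwd_maybe_bwd`, which runs backward when outputs require grad. The names `attn_out_proj`, `run`, and `seq_len` are hypothetical.

```python
import torch
import torch.nn.functional as F

# Sketch of the captured graph without the minifier harness. Random data replaces
# the saved tensors, so this is NOT guaranteed to reproduce the crash.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# float16 matches the issue on CUDA; bfloat16 on CPU is an assumption for portability.
LOW = torch.float16 if DEVICE == "cuda" else torch.bfloat16
HEADS, HEAD_DIM = 16, 64  # 16 * 64 = 1024 = hidden size of out_proj

# Keep only the math SDPA backend, as the original script does.
if hasattr(torch.backends.cuda, "enable_cudnn_sdp"):  # absent in older releases
    torch.backends.cuda.enable_cudnn_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)


def attn_out_proj(weight, bias, key, value, query):
    # (1, HEADS, seq, HEAD_DIM) causal attention, same op sequence as Repro.forward
    out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
    )
    seq = out.size(2)
    out = out.transpose(1, 2).reshape(1, seq, HEADS * HEAD_DIM)  # merge heads
    return F.linear(out, weight, bias)  # fp32 params; autocast casts them down


def run(use_compile: bool, seq_len: int = 1024) -> torch.Tensor:
    torch.manual_seed(0)
    hidden = HEADS * HEAD_DIM
    # fp32 leaf parameters, like buf0/buf1 in load_args
    weight = torch.randn(hidden, hidden, device=DEVICE, requires_grad=True)
    bias = torch.randn(hidden, device=DEVICE, requires_grad=True)
    # low-precision q/k/v, like buf2..buf4 in load_args
    q, k, v = (
        torch.randn(1, HEADS, seq_len, HEAD_DIM, device=DEVICE, dtype=LOW, requires_grad=True)
        for _ in range(3)
    )
    fn = torch.compile(attn_out_proj, backend="inductor") if use_compile else attn_out_proj
    with torch.autocast(device_type=DEVICE, dtype=LOW):
        out = fn(weight, bias, k, v, q)
    # Backward outside the autocast region; upcast the reduction to avoid fp16 overflow
    out.float().sum().backward()
    return out


if __name__ == "__main__":
    run(use_compile=True)
```

Running with `use_compile=False` gives an eager baseline for comparison; setting `seq_len` below 1024 keeps that baseline cheap on CPU.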
Versions

Collecting environment information...
PyTorch version: 2.5.0a0+gitf1c439c
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.30.0
Libc version: glibc-2.31

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-190-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.2.91
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-PCIE-40GB
Nvidia driver version: 535.183.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 64
On-line CPU(s) list: 0-63
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 1
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7713 64-Core Processor
Stepping: 1
Frequency boost: enabled
CPU MHz: 1448.894
CPU max MHz: 2000.0000
CPU min MHz: 1500.0000
BogoMIPS: 3999.52
Virtualization: AMD-V
L1d cache: 2 MiB
L1i cache: 2 MiB
L2 cache: 32 MiB
L3 cache: 256 MiB
NUMA node0 CPU(s): 0-7
NUMA node1 CPU(s): 8-15
NUMA node2 CPU(s): 16-23
NUMA node3 CPU(s): 24-31
NUMA node4 CPU(s): 32-39
NUMA node5 CPU(s): 40-47
NUMA node6 CPU(s): 48-55
NUMA node7 CPU(s): 56-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] bert_pytorch==0.0.1a4
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] mypy==1.10.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] onnx==1.16.1
[pip3] optree==0.12.1
[pip3] pytorch-labs-segment-anything-fast==0.2
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0a0+gitf1c439c
[pip3] torch_geometric==2.4.0
[pip3] torchao==0.3.1
[pip3] torchaudio==2.3.1+3edcf69
[pip3] torchmultimodal==0.1.0b0
[pip3] torchvision==0.19.0a0+d23a6e1
[conda] bert-pytorch 0.0.1a4 dev_0
[conda] functorch 1.14.0a0+b71aa0b pypi_0 pypi
[conda] numpy 1.24.3 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch-labs-segment-anything-fast 0.2 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0a0+gitf1c439c dev_0
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torchao 0.3.1 pypi_0 pypi
[conda] torchaudio 2.3.1+3edcf69 pypi_0 pypi
[conda] torchmultimodal 0.1.0b0 pypi_0 pypi
[conda] torchvision 0.19.0a0+d23a6e1 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki @mcarilli @ptrblck @leslie-fang-intel @jgong5 @chauhang @penguinwu @voznesenskym @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire

