torch.compiler.set_stance
Author: William Wen
torch.compiler.set_stance is a torch.compiler API that enables you to change the behavior of torch.compile across different calls to your model without having to reapply torch.compile to your model. This recipe provides some examples on how to use torch.compiler.set_stance.
Prerequisites: torch >= 2.6
torch.compiler.set_stance can be used as a decorator, context manager, or raw function to change the behavior of torch.compile across different calls to your model.

In the example below, the "force_eager" stance ignores all torch.compile directives.
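The examples that follow call a compiled function foo on an input inp. A minimal sketch of that setup, assuming a body chosen so that the compiled call prints 1 and the eager call prints -1 (matching the outputs shown later):

import torch

@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active: this branch runs in the compiled path
        return x + 1
    else:
        # torch.compile is not active: this branch runs eagerly
        return x - 1

inp = torch.zeros(3)
print(foo(inp))  # compiled, prints tensor([1., 1., 1.])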
Sample decorator usage
@torch.compiler.set_stance("force_eager") def bar(x): # force disable the compiler return foo(x) print(bar(inp)) # not compiled, prints -1
Sample context manager usage
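A minimal sketch of the context manager form, using the same assumed foo and inp:

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1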
Sample raw function usage
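A minimal sketch of the raw function form, again using the assumed foo and inp; the two tensors printed below come from the eager and compiled calls:

torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1

torch.compiler.set_stance("default")
print(foo(inp))  # compiled, prints 1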
tensor([-1., -1., -1.])
tensor([1., 1., 1.])
The torch.compile stance can only be changed outside of any torch.compile region. Attempts to do otherwise will result in an error.
@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1

try:
    baz(inp)
except Exception as e:
    print(e)

@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1

@torch.compile
def outer(x):
    # error!
    return inner(x)

try:
    outer(inp)
except Exception as e:
    print(e)
Attempt to trace forbidden callable <function set_stance at 0x7fad33c79510>
from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
    with torch.compiler.set_stance("force_eager"):

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Attempt to trace forbidden callable <function inner at 0x7fad49a7d240>
from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
    return inner(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
"default"
: The default stance, used for normal compilation.
"eager_on_recompile"
: Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.
"fail_on_recompile"
: Raise an error when recompiling a function.
See the torch.compiler.set_stance doc page for more stances and options. More stances/options may also be added in the future.

Preventing recompilation
Some models do not expect any recompilations - for example, you may always have inputs with the same shape. Since recompilations may be expensive, we may wish to error out when we attempt to recompile so we can detect and fix recompilation cases. The "fail_on_recompile" stance can be used for this.
@torch.compile
def my_big_model(x):
    return torch.relu(x)

# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'
If erroring out is too disruptive, we can use "eager_on_recompile" instead, which will cause torch.compile to fall back to eager instead of erroring out. This may be useful if we don't expect recompilations to happen frequently, but when one is required, we'd rather pay the cost of running eagerly over the cost of recompilation.
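A minimal sketch of this behavior, again assuming a compiled function that returns x + 1 when compiled and x - 1 when run eagerly, so the eager fallback is visible in the output below:

@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1  # compiled path
    else:
        return x - 1  # eager path

# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1, cached compiled code is still valid
    print(my_huge_model(torch.zeros(4)))  # -1, new shape would recompile, so we run eagerly
    print(my_huge_model(torch.zeros(3)))  # 1, cached compiled code is used again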
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])

Measuring performance gains
torch.compiler.set_stance can be used to compare eager vs. compiled performance without having to define a separate eager model.
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x

inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 0.00019046400487422944
compiled: 7.88479968905449e-05

Crashing sooner
Running an eager iteration first before a compiled iteration using the "force_eager" stance can help us catch errors unrelated to torch.compile before attempting a very long compile.
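A minimal sketch, assuming a model with an intentional bug (torch.sin called with two positional arguments) that produces the error printed below:

@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)  # bug: torch.sin takes a single positional argument

try:
    with torch.compiler.set_stance("force_eager"):
        # the eager run surfaces the bug immediately, before any compilation work
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)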
sin() takes 1 positional argument but 2 were given

Conclusion
In this recipe, we have learned how to use the torch.compiler.set_stance API to modify the behavior of torch.compile across different calls to a model without needing to reapply it. The recipe demonstrates using torch.compiler.set_stance as a decorator, context manager, or raw function to control compilation stances like "force_eager", "default", "eager_on_recompile", and "fail_on_recompile".
For more information, see: torch.compiler.set_stance API documentation.
Total running time of the script: (0 minutes 10.367 seconds)