
PyTorch 2.2: FlashAttention-v2, AOTInductor

Released by @jcaip on 30 Jan 2024 · commit 8ac9b20

PyTorch 2.2 Release Notes

  • Highlights
  • Backwards Incompatible Changes
  • Deprecations
  • New Features
  • Improvements
  • Bug fixes
  • Performance
  • Documentation

Highlights

We are excited to announce the release of PyTorch® 2.2! PyTorch 2.2 offers ~2x performance improvements to scaled_dot_product_attention via FlashAttention-v2 integration, as well as AOTInductor, a new ahead-of-time compilation and deployment tool built for non-python server-side deployments.
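As a quick illustration, the following is a minimal sketch of calling scaled_dot_product_attention; the shapes are illustrative, and a CUDA GPU with half-precision inputs is assumed so that the FlashAttention-2 backend can be selected.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# On supported GPUs and dtypes this can dispatch to the FlashAttention-2 kernel.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)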

This release also includes improved torch.compile support for Optimizers, a number of new inductor optimizations, and a new logging mechanism called TORCH_LOGS.
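As an example, TORCH_LOGS can be configured from the environment or programmatically; the component and artifact names below are illustrative, see the logging documentation for the full list.

# Shell: TORCH_LOGS="dynamo,graph_breaks,recompiles" python train.py
# Or programmatically (component and artifact names are illustrative):
import logging
import torch

torch._logging.set_logs(dynamo=logging.INFO, graph_breaks=True, recompiles=True)

@torch.compile
def f(x):
    return x.sin() + x.cos()

f(torch.randn(8))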

Please note that we are deprecating macOS x86 support, and PyTorch 2.2.x will be the last version that supports macOS x64.

Along with 2.2, we are also releasing a series of updates to the PyTorch domain libraries. More details can be found in the library updates blog.

This release is composed of 3,628 commits and 521 contributors since PyTorch 2.1. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.2. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

Summary:

  • scaled_dot_product_attention (SDPA) now supports FlashAttention-2, yielding around 2x speedups compared to previous versions.
  • PyTorch 2.2 introduces a new ahead-of-time extension of TorchInductor called AOTInductor, designed to compile and deploy PyTorch programs for non-python server-side.
  • torch.distributed supports a new abstraction for initializing and representing ProcessGroups called device_mesh; see the sketch after this list.
  • PyTorch 2.2 ships a standardized, configurable logging mechanism called TORCH_LOGS.
  • A number of torch.compile improvements are included in PyTorch 2.2, including improved support for compiling Optimizers and improved TorchInductor fusion and layout optimizations.
  • Please note that we are deprecating macOS x86 support, and PyTorch 2.2.x will be the last version that supports macOS x64.
  • torch.ao.quantization now offers a prototype torch.export based flow
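As referenced in the device_mesh bullet above, here is a minimal sketch of initializing a 2D mesh; it assumes the script is launched with torchrun on 8 GPUs, and the mesh shape and dimension names are illustrative.

# Run with e.g.: torchrun --nproc-per-node=8 mesh_example.py
import torch
from torch.distributed.device_mesh import init_device_mesh

# A 2 x 4 mesh: the outer dimension for data-parallel groups,
# the inner dimension for tensor-parallel groups.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

dp_group = mesh_2d.get_group(mesh_dim="dp")  # ProcessGroup for this rank's "dp" slice
tp_group = mesh_2d.get_group(mesh_dim="tp")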
| Stable | Beta | Prototype | Performance Improvements |
| --- | --- | --- | --- |
| | FlashAttentionV2 backend for scaled dot product attention | PT 2 Quantization | Inductor optimizations |
| | AOTInductor | Scaled dot product attention support for jagged layout NestedTensors | aarch64-linux optimizations (AWS Graviton) |
| | TORCH_LOGS | | |
| | torch.distributed.device_mesh | | |
| | torch.compile + Optimizers | | |

*To see a full list of public 2.2 - 1.12 feature submissions click here.

Tracked Regressions

Performance reduction when using NVLSTree algorithm in NCCL 2.19.3 (#117748)

We have noticed a performance regression introduced to all-reduce in NCCL 2.19.3. Please use version 2.19.1 instead.

Poor numeric stability of loss when training with FSDP + DTensor (#117471)

We have observed that the loss can randomly flatline when training with FSDP + DTensor in some instances.

Backwards Incompatible Changes

Building PyTorch from source now requires GCC 9.4 or newer (#112858)

GCC 9.4 is the oldest version fully compatible with C++17, which the PyTorch codebase has migrated to from C++14.

Updated flash attention kernel in scaled_dot_product_attention to use Flash Attention v2 (#105602)

Previously, the v1 flash attention kernel had a Windows implementation, so a user on Windows who explicitly forced the flash attention kernel via the sdp_kernel context manager (with only flash attention enabled) could still run it. The v2 kernel does not support Windows, so in 2.2, if the sdp_kernel context manager must be used on Windows, enable the memory-efficient or math kernel instead.

# 2.1: forcing flash attention on Windows worked
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    torch.nn.functional.scaled_dot_product_attention(q, k, v)

# 2.2: don't force flash attention when using sdp_kernel on Windows;
# enable the math or memory-efficient kernels instead
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    torch.nn.functional.scaled_dot_product_attention(q, k, v)

Rewrote DTensor (Tensor Parallel) APIs to improve UX (#114732)

In PyTorch 2.1 and earlier, users could use ParallelStyles like PairwiseParallel and specify input/output layouts with functions such as make_input_replicate_1d or make_output_replicate_1d, and there were default values for _prepare_input and _prepare_output. The Tensor Parallel UX looked like this:

from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    make_input_replicate_1d,
    make_input_reshard_replicate,
    make_input_shard_1d,
    make_input_shard_1d_last_dim,
    make_sharded_output_tensor,
    make_output_replicate_1d,
    make_output_reshard_tensor,
    make_output_shard_1d,
    make_output_tensor,
    PairwiseParallel,
    parallelize_module,
)
from torch.distributed.tensor import DeviceMesh

module = DummyModule()
device_mesh = DeviceMesh("cuda", list(range(self.world_size)))
parallelize_module(module, device_mesh, PairwiseParallel(_prepare_input=make_input_replicate_1d))
...

Starting with PyTorch 2.2, we simplified the parallel styles to only ColwiseParallel and RowwiseParallel, since other parallel styles can be composed from these two. We also removed the input/output functions and instead use the input_layouts and output_layouts kwargs to specify the sharding layouts of the input/output tensors. Finally, we added the PrepareModuleInput/PrepareModuleOutput styles; these have no default layout arguments, so users must specify the layouts and think explicitly about how their tensors are sharded.

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)
from torch.distributed._tensor import Replicate, Shard, init_device_mesh

module = SimpleMLPModule()
device_mesh = init_device_mesh("cuda", (self.world_size,))
parallelize_module(
   module,
   device_mesh,
   {
      "fqn": PrepareModuleInput(
                input_layouts=Shard(0),
                desired_input_layouts=Replicate()
             ),
      "fqn.net1": ColwiseParallel(),
      "fqn.net2": RowwiseParallel(output_layouts=Shard(0)),
   }
)
...

UntypedStorage.resize_ now uses the original device instead of the current device context (#113386)

Before this PR, UntypedStorage.resize_ would move data to the current CUDA device index (given by torch.cuda.current_device()).
Now, UntypedStorage.resize_() keeps the data on the same device index that it was on before, regardless of the current device index.

2.1:
>>> import torch
>>> with torch.cuda.device('cuda:0'):
...:     a = torch.zeros(0, device='cuda:1')
...:     print(a.device)
...:     a = a.untyped_storage().resize_(0)
...:     print(a.device)
cuda:1
cuda:0

2.2:
>>> import torch
>>> with torch.cuda.device('cuda:0'):
...:     a = torch.zeros(0, device='cuda:1')
...:     print(a.device)
...:     a = a.untyped_storage().resize_(0)
...:     print(a.device)
cuda:1
cuda:1

Wrapping a function with set_grad_enabled will consume its global mutation (#113359)

This bc-breaking change fixes some unexpected behavior when set_grad_enabled is used as a decorator.

2.1:
>>> import torch
>>> @torch.set_grad_enabled(False)  # unexpectedly, this mutates the grad mode!
    def inner_func(x):
        return x.sin()

>>> torch.is_grad_enabled()
False

2.2:
>>> import torch
>>> @torch.set_grad_enabled(False)  # the decorator now consumes the global mutation
    def inner_func(x):
        return x.sin()

>>> torch.is_grad_enabled()
True

Deprecated verbose parameter in LRScheduler constructors (#111302)

As part of our decision to move towards a consolidated logging system, we are deprecating the verbose flag in LRScheduler.

If you would like to print the learning rate during execution, please use get_last_lr()

2.1:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True)
for epoch in range(10):
    train(...)
    val_loss = validate(...)
    # Note that step should be called after validate()
    scheduler.step(val_loss)

2.2:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
    train(...)
    val_loss = validate(...)
    # Note that step should be called after validate()
    scheduler.step(val_loss)
    print(f"Epoch {epoch} has concluded with lr of {scheduler.get_last_lr()}")

Removed deprecated c10d multi-gpu-per-thread APIs (#114156)

In PyTorch 2.1 and earlier, users could use the multi-GPU c10d collective APIs such as all_reduce_multigpu:

2.1:
import torch.distributed as dist


dist.broadcast_multigpu
dist.all_reduce_multigpu
dist.reduce_multigpu
dist.all_gather_multigpu
dist.reduce_scatter_multigpu
...

2.2: removed; see the note below.

In PyTorch 2.2, these APIs are removed because PyTorch Distributed's preferred programming model is one device per thread, as exemplified by the APIs in its documentation. The multi-GPU functions (which stood for multiple GPUs per CPU thread) had been deprecated since PyTorch 1.13.
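For reference, a minimal sketch of the preferred one-device-per-process pattern that replaces the removed *_multigpu calls; it assumes the script is launched with torchrun so the usual environment variables are set.

# Run with e.g.: torchrun --nproc-per-node=8 all_reduce_example.py
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)      # one GPU per process
dist.init_process_group(backend="nccl")

t = torch.ones(4, device="cuda")
dist.all_reduce(t)                     # instead of dist.all_reduce_multigpu([t])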

Rename torch.onnx.ExportOutput* to ONNXProgram* (#112263)

The torch.onnx.dynamo_export’s output was renamed from torch.onnx.ExportOutput to torch.onnx.ONNXProgram to better align with torch.export.export API terminology which returns a torch.export.ExportedProgram. With this change, any ambiguity that could arise with either API is eliminated.

2.1:
export_output: torch.onnx.ExportOutput = torch.onnx.dynamo_export(...)

2.2:
onnx_program: torch.onnx.ONNXProgram = torch.onnx.dynamo_export(...)

Fix functional::smooth_l1_loss signatures to not override beta (#109798)

Previously, there were two ways to pass beta to functional::smooth_l1_loss: via the SmoothL1LossFuncOptions parameter or as a separate function parameter.

Before, a beta specified as a function parameter would override a beta set in the options, which was unexpected behavior. Now, an error is thrown when beta is passed both ways.

Deprecations

Autograd API

Deprecate not passing use_reentrant kwarg to torch.utils.checkpoint.checkpoint_sequential explicitly (#114158)

The use_reentrant parameter should be passed explicitly; in version 2.4 we will raise an exception if it is not. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to the docs for more details on the differences between the two variants.
Note that not passing the use_reentrant kwarg to torch.utils.checkpoint.checkpoint was already deprecated in a previous release.

2.1:
a = torch.randn(3, requires_grad=True)
modules_list = [
    torch.nn.Linear(3, 3),
    torch.nn.Linear(3, 3),
    torch.nn.Linear(3, 3)
]

# This would produce a warning in 2.2
checkpoint_sequential(modules_list, 3, a)

2.2:
# Recommended
checkpoint_sequential(modules_list, 3, a, use_reentrant=False)

# To preserve existing behavior
checkpoint_sequential(modules_list, 3, a, use_reentrant=True)

Deprecate "fallthrough" as autograd fallback default (#113166)

Custom operators that do not have a kernel registered to the Autograd keys (e.g. AutogradCPU and AutogradCUDA) will now produce a warning when used with autograd.
If your custom operator previously returned floating-point or complex Tensors that do not require grad, they will now require grad as long as grad mode is enabled and the inputs require grad.
For users who would like the old behavior, register torch::CppFunction::makeFallthrough() to your Autograd key, as shown here.

The below example uses the torch library API, but if you are writing an operator in a cpp extension, please read this doc for more information.

import torch
import numpy as np

# Define the operator
torch.library.define("mylibrary::sin", "(Tensor x) -> Tensor")

# Add implementations for the cpu device
@torch.library.impl("mylibrary::sin", "cpu")
def f(x):
    return torch.from_numpy(np.sin(x.detach().numpy()))

x = torch.randn(3, requires_grad=True)
y = torch.ops.mylibrary.sin(x)
y.sum().backward()

2.1:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

2.2:
UserWarning: mylibrary::sin: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.

Linalg

Deprecate torch.cross default behavior (#108760)

Calling torch.cross without specifying the dim arg is now deprecated. This behavior will be changed to match that of torch.linalg.cross in a future release.
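To avoid the deprecation warning, pass dim explicitly, or use torch.linalg.cross, whose default dim is -1:

import torch

a = torch.randn(4, 3)
b = torch.randn(4, 3)

c1 = torch.cross(a, b, dim=-1)   # explicit dim: no deprecation warning
c2 = torch.linalg.cross(a, b)    # dim defaults to -1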

Jit

NVFuser functionality has been removed from TorchScript (#110124, #111447, #110881)

Neural Network Compiler (NNC) has replaced NVFuser as the default GPU fuser for TorchScript in PyTorch 2.1, which also added a deprecation warning for NVFuser. The TorchScript functionality for NVFuser has now been fully removed and is no longer supported.

Optimizer

SparseAdam constructor will no longer accept raw Tensor type for params (#114425)

SparseAdam is now consistent with the rest of our optimizers and only accepts containers instead of individual Tensors/Parameters/param groups.

2.1:
import torch
param = torch.rand(16, 32)
optimizer = torch.optim.SparseAdam(param)

2.2:
optimizer = torch.optim.SparseAdam([param])

New Features

torch.compile

Dynamo

  • Fully enabled compiled optimizers (#115906); see the sketch after this list
  • Cudagraphs support for compiled optimizers (#107504)
  • Experimental support for TorchDynamo tracing with DTensor (#108329)
  • Experimental support for torch.compile, activation checkpointing and FSDP (#103953)
  • Dynamo variable trackers are mutable (#113725) - Improves Dynamo compilation time
  • Reduce cache size limit to 8 to quickly fall back to eager for non-compile-friendly functions (#108526)
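A rough sketch of the compiled-optimizer support mentioned in the first bullet; the wrapper function is just for illustration and a CUDA device is assumed.

import torch

model = torch.nn.Linear(64, 64, device="cuda")
opt = torch.optim.Adam(model.parameters())

@torch.compile
def opt_step():
    opt.step()

model(torch.randn(32, 64, device="cuda")).sum().backward()
opt_step()        # the optimizer step itself is traced and compiled
opt.zero_grad()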

Inductor

torch.export

  • Introduce dynamic_shapes API in favor of constraints (#108448, #112298, #110101, #110638, #110276); see the sketch after this list
  • Add torch.export.register_dataclass API (#109152)
  • Expose torch.ops.higher_order.map (#111404)
  • Change export to return full ATen IR (not Core ATen) and add a run_decomposition() function to allow users to pass in a decomposition table (or by default it will decompose to the core ATen decomposition table) (#111030, 8be2611, #110236, #114714)
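A minimal sketch of the dynamic_shapes API mentioned in the first bullet; the module and dimension names are illustrative.

import torch
from torch.export import export, Dim

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x).relu()

batch = Dim("batch")  # symbolic batch dimension
ep = export(MLP(), (torch.randn(2, 8),), dynamic_shapes={"x": {0: batch}})
print(ep)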

Build

  • Add Hopper (CUDA arch 9.0a) support (#110587)

Python API

  • Add torch.unravel_index (#110580); see the example after this list
  • Add multi-dim reductions for torch.{any,all} (#110310)
  • Add file name and size to the serialization metadata logging (#113077)
  • Add torch.distributions.InverseGamma distribution and fix sign bug in torch.distributions.PowerTransform (#104501)
  • Add torch.utils.deterministic.fill_uninitialized_memory flag (#111377)
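A quick example of the new torch.unravel_index mentioned in the first bullet.

import torch

# Convert flat indices into per-dimension coordinates for a (4, 5) shape.
flat = torch.tensor([3, 7, 19])
rows, cols = torch.unravel_index(flat, (4, 5))
print(rows)  # tensor([0, 1, 3])
print(cols)  # tensor([3, 2, 4])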

Profiler

  • Show shapes for lists of tensors in chrome traces (#109751)
  • Add src/dst information to NCCL send/recv (#111811)
  • Populate in/out split size information for NCCL all_to_all from CPU to CUDA kernel (#112308)

Quantization

  • Add CUTLASS-based support for mixed dtypes matrix multiplication (#110981)

Sparse API

  • Add torch.compile support and padding for semi-structured sparsity (#111049, #110583)
  • Add CSR tensor with non-contiguous values support to CuSparseSpMatCsrDescriptor (#111742)
  • Add scatter_mm and bsr_scatter_mm operations. (#110396, #111796)
  • Add is_sparse as a property of MaskedTensor (#110725)

NestedTensor API

  • Add unary out-of-place sin / cos support (#107891)
  • Add binary out-of-place ge.Scalar / eq.Scalar support (#107892)
  • Add binary op support for (B, C, *, *) NT with (C, 1, 1) dense (#107890)
  • Add support for cat with dim=0 (#108361)
  • Add support for matmul of (B, *, C, D) NT with dense (D, E) (#108370)
  • Add support for narrow() on dim=0 (#108362)
  • Add support for cat with dim > 0 when representable as jagged (#108428)
  • Add public API for constructing NT with jagged layout from tensor list (#111078); see the sketch below
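A minimal sketch of the public jagged-layout constructor from the last bullet; the tensor shapes are illustrative.

import torch

# Sequences of different lengths with the same trailing dimension.
ts = [torch.randn(3, 16), torch.randn(5, 16), torch.randn(2, 16)]

nt = torch.nested.nested_tensor(ts, layout=torch.jagged)
print(nt.is_nested)  # True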

Misc

  • Python 3.10 Union operator | support for JIT (#109293)
  • Allow specifying inputs as GradientEdge in autograd APIs (#110867, dev-discuss)
  • Use CapturedTraceback symbolizer for C++ exceptions from Python library (#113207)
  • Add sparse tensor support to dataloader (#112842)
  • Add 0dim Tensor overload for _foreach_div (#113688)
  • Add global_step parameter to SummaryWriter.add_hparams (#109572)

Fx

  • Add a matcher that supports name to node mapping (#110743)
  • Add splitting by tags feature (#109332)
  • Allow tracing calls with Python Enum values. (#109507)
  • Add function to port FX minified graph to HLO via StableHLO (#109084)

ONNX

  • Add symbolic shape support for torch.onnx.dynamo_export(#112179)
  • Add optional torch.export.ExportGraphSignature to ONNXProgram (#113477)
  • Add ONNXProgram.__call__ API to run model with ONNX Runtime (#113495)
  • Add decomposition support for dynamo_export + ExportedProgram (#112444)
  • Add user input mutation support for dynamo_export + ExportedProgram (#114596)
  • Add mutated buffer support for dynamo_export + ExportedProgram (#112272)
  • Add FakeTensor support for dynamo_export + ExportedProgram (#114407)

CPU

  • Add support for torch.cpu.set_device() and torch.cpu.current_device() (#110716, #110987)

MPS

  • Pixel shuffle unshuffle support (#99306)
  • Add lgamma, digamma, and polygamma implementations (#106292)
  • Add support for aten::nextafter (#109685)
  • Add weight_norm_interface support for MPS (#108008)
  • Add searchsorted op (#112829)
  • Add bucketize op (#112830)

Vulkan

  • Add Vulkan support for several ATen operators:
    • Partial implementation of 1D convolution (only supports stride=1, padding=0, dilation=1 for now) (#112880)
    • Add support for 0-size tensors (i.e. a tensor with a size of 0, for example with sizes {2, 1, 0}) (#111512)
    • Add support for 0-dim tensors (i.e. a tensor with sizes of {}) (#111680)

Improvements

torch.compile

Dynamo

  • Dispatch numpy.take_along_axis to torch.take_along_dim (#108880)
  • Force specialization on INT_LIST (#111216)
  • Add custom treespec fqn field (#112428)
  • Better error handling for cond (#108817)
  • Support 'BaseOutput' and subclasses from 'diffusers' in dynamo (#111978)
  • Add infinite generators itertools.{count, repeat, cycle} (#110967)
  • Add support for dict.fromkeys() / OrderedDict.fromkeys() / defaultdict.fromkeys() (#115010)
  • Add support for dict.update(seq2) / OrderedDict.update(seq2) / defaultdict.update(seq2) (#115011)
  • Add support for dict.copy() / OrderedDict.copy() / defaultdict.copy() (#115012)
  • Force synced KJT to trace unbacked SymInt (#108960)

Inductor

  • max-autotune improvements
  • Make codegen stateless (#107320, #107617)
  • Add or improve lowering rules for prims.div, reflection_pad2d, full, sub, _local_scalar_dense, index, reflection_pad2d, index_put, unfold (#102809, #110988, #108166, #108518, #109893, #111015, #111212, #113204, #113259)
  • Add or improve decomposition rules for grid_sampler_2d, full, torch.ops.quantized.embedding_bag_byte_unpack, amax/amin, native_dropout, bmm, mm, complex dtype addition, upsample_nearest_exactNd (#104710, #108443, #109398, #110311, #115040, #109836, #110740, #113749)
  • Add reinplacing pass for scatters + incremental fake tensor updating (#106192)
  • Add meta-registration for _sparse_semi_structured_linear, _cslt_sparse_mm (#114477, #114685 )
  • Provide fallback values for unbacked symint (#109893, #110520)
  • Decompose addmm on cpu for a few special cases (e.g. dot product or small matrix vector multiplication) (#110010, #110456)
  • Don't tune beyond 32 warps (which is a CUDA limit) for the coordinate descent tuner (#108997)
  • Avoid special characters in cache_dir path (#110945)
  • Add DeviceInterface abstraction to make inductor code more device agnostic (#109486 )
  • Allow matmul to have flexible layout when we are not autotuning (#110726)
  • Allow backend compiler skipping a frame for transient errors (#111153)
  • Handle item() on boolean tensor (#114157)
  • Replace rand[n].generator with inductor prim if generator=None (#115051)
  • Support channels-last for XPU convolution in the Inductor layout optimization path (#111018)
  • Early work to improve the static memory planning algorithm (#111402)
  • Added support for symbolic shapes in FX graph cache (#111421)
  • Added config to specify the shape attribute for the generated svg graphs (#114811)
  • Quantization
    • Enable quantization dynamic batch size support (#108550)
    • Enable QConv2d Unary & Binary int8-mixed-bf16 Lowering (#112550, #112551)
    • Enable QLinear int8-mixed-bf16 Lowering (#112486)
    • Enable the lowering of quantized reshape (#114443)
    • Enable the Inductor Lowering of QConv2d post op hardtanh (#114580)
  • Improve the CPU backend
  • Improve the Fx pattern matching passes
    • Generalize pointless_cumsum_replacement pattern (#108373)
    • Improve mem efficiency of constant folding (#108421)
    • Make sure unfuse_addmm and addmm patterns don't overlap (#110235)
    • Improve reinplace_scatters pass (#112801)
    • Make pattern-matcher failure diagnostics lazy and add an error message if format string is too long (#112923)
  • Foreach kernel compilation time improvement
    • Skip searching getitem in group batch fusion pass reduces optimizer compilation time by 60s (#112088)
    • Re-inplace foreach when safe and allow aliasing during lowering (#112440)
  • AOTInductor
    • ABI-compatible mode support
      • Add a C shim layer for libtorch (#109391, #109834)
      • Support _scaled_dot_product_flash_attention fallback (#110085)
      • Add AOTI ABI shim function for repeat_interleave.Tensor (#110745)
      • Add size, stride, storage_offset to RAIIAtenTensorHandle (#110764)
      • Add AOTI ABI shim function for torch.nonzero (#110766)
      • Enable floor_div indexing to work under ABI-compat mode (#113276)
      • Add ABI shim function for torch.scatter (#114027)
      • Support ReinterpretView in ABI mode (#114169)
      • Support at::convolution for AOTInductor (#114961)
    • ProxyExecutor for custom ops support
      • ProxyExecutor skips serializing missing args with default value (#111425)
      • Support List[Tensor] return type (#110182)
      • Proxy Executor for Extern Fallback kernels (#108350)
      • Switch ProxyExecutor to use AtenTensorHandle (#109748)
      • ProxyExecutor supports custom op with tuple output (#110140)
      • ProxyExecutor supports Tuple of Tensor and List[Tensor] in returns (#110187)
      • ProxyExecutor support ReinterpretView inputs (#110451)
      • ProxyExecutor support Dynamic Shape (#110526)
      • Allow using ProxyExecutor for ATen fallbacks (#112976)
      • Use ProxyExecutor for aten op if c-shim is missing (#113918)
    • CPU performance improvement
      • Generate reused thread_locals when tensors probably have static shape (#110892)
      • Cache dtypes and device types at DSO load (#111820)
      • Emit CACHED_TORCH_TYPE only as needed (#113997)
    • UX improvement and refactoring
      • Use array of constants (#111815)
      • Write weight files only if they do not exist yet (#111379)
      • Enforce no_grad for 'run' entry points (#111613)
      • Improve validation for C++ wrapper codegen (#111102)
      • Avoid generating redundant kernel loading code (#110510)
      • Group AOTInductor configs under aot_inductor class (#108369)
      • Include constants in the generated .so file (#108473)
      • Do not hardcode directory with .cubin files (#109151)
      • Add is_cpu for AOTInductorModelContainer (#109287)
      • Pass TorchIR to AOTInductor (#110020)
      • A lightweight model runner (#110158)
      • Remove CUDA dependency for cpp backend (#110409)
      • Delay the fallback kernel naming decision to the codegen time (#113660)
      • Move constant loading logic from Container to Model (#112197)
      • Allow specifying a .so name in the aot_inductor.output_path config (#112651)
      • Improve the two-pass wrapper codegen (#114067)

torch.export

  • Address constant tensors in ExportedPrograms (#113689, #108592)
  • Remove replaced symbols from range_constraints (#110644)
  • Copy graph module before calling PassManager (#108321)
  • Made aot_export_module use dynamo's fake_mode (#114009, #114381)
  • Core ATen Opset
    • Registered additional ATen operators as core (#110882)
    • De-registered full_like and empty_like as core (#110924)
    • Added div.Tensor_mode, div.Scalar_mode, and copy as core operators (#109812)

Composability

  • FakeTensors and meta tensors are used to perform shape propagation when tracing out a graph in torch.compile. There were a number of op coverage improvements this release:
  • We have Python “reference” decompositions for many aten operators. These are used during the tracing step of torch.compile in a few ways: sometimes they directly decompose operators in the captured graph; other times they serve as an alternative to a shape-propagation rule for an operator. There were several improvements to operator coverage in this release:
  • We also have an opset known as “Core ATen IR”, as defined here. Several ops were either added to core ATen or had decompositions added that decompose them into other core ATen operators:
  • decompositions:

Python API

  • Add a UserWarning when using torch.{std,var,std_mean,std_var} with dof<=0 (#109824)
  • Add torch.half support for torch.multinomial on CPU (#104178)
  • Add support for serializing torch.float8_* dtypes (#114662)
  • Add different out dtypes support to torch.addc{mul,div} (#112682)

torch.nn API

  • Add __all__ for torch.nn.utils (#111026)
  • Add Half support for AdaptiveAvgPool2d and AdaptiveMaxPool2d on CPU (#102079)
  • Add Half support for GroupNorm on CPU (#100234)
  • Add Half support for torch.nn.functional.{softmax/log_softmax} on CPU (#103315)
  • Add BFloat16 support to torch.nn.functional.grid_sample (#112331)
  • Add BFloat16 support for nn.utils.parametrizations.weight_norm (#114785)

Linalg API

  • Add fp16 support for gemm on CPU (#99498)
  • Add quantized int4 support for mm (#111403)

Optimizer API

  • Allow torch.float64 scalars for forloop + foreach implementations (#115841, #111008)
  • Add NAdam support for complex dtypes, with has_complex shortcut (#110634)
  • Set default learning rate (lr) value of SGD to 1e-3 (#114467)
  • Add capturable ASGD impl (#107857)

torch.func

  • Add vmap support for various in-place operations (#110692, #113513)
  • Add vmap support for torch.unsafe_chunk (#110862)
  • Add vmap support for Tensor.index_add_ (#112276)
  • Add vmap support for torch.linspace and torch.logspace (#105451)
  • Add vmap support for torch.linalg.eigh (#110640)
  • Add dynamic shapes support for vmap over torch.squeeze and alias (#107577)
  • Add dynamic shapes support for vmap over torch.is_same_size and torch.split_with_sizes (#111491)

Misc

  • Add torch.utils.checkpoint.set_checkpoint_debug_enabled (#110728)
  • StackDataset batched sampling (#110694)
  • Add option to flop counter formula registration to get raw values (#110591)

Quantization

  • Bits Types
    • Enable cat for bits types (e.g.torch.bits8) in cuda (#115044)
    • Enable copy/clone/reshape/contiguous operations for bits types (#113508)
  • PyTorch 2 Export Quantization (see the sketch after this list):
    • Add reference representation for dynamic quantized linear (#108073)
    • Use input_qspec_map for weight quantization of linear (#107105)
    • Make annotation util functions return annotated nodes (#107106)
    • Add dequantize operator duplication pass (#107900)
    • Add metadata porting for nodes added by quantization (#107107)
    • Move to BFS instead of DFS to check for connectedness (#108572)
    • Support int16 quantization (#108453)
    • Support cat (#108382), conv1d (#109830) and mul (#110428) in XNNPACKQuantizer
    • Enable constant folding for quantize ops (#109343)
    • Add util function to convert scalars to attrs (#110427)
    • Support cudnn_batch_norm (#109908) and miopen_batch_norm (#110653) in QAT fusion
    • Preserve source_fn_stack after QAT fusion (#110899, #111515)
    • Cleanup observer insertion logic (#111828) (#112453)
    • Fix QAT conv-bn bias using DerivedQSpec (#112159)
    • Refactor QAT q-dq patterns (#112279)
    • Add "quantization_tag" as metadata to fx.Proxy (#108764)
    • Inductor cpp wrapper: support QLinear (#112378)
    • Enable QAT Quantization flow in X86InductorQuantizer (#111280)
    • Add ConvBNAdd(ReLU) Annotation (#111281), adaptive_avg_pool2d and flatten (#114442), Hardtanh and ReLU6 for conv2d (#114579) in X86InductorQuantizer
    • Enable oneDNN QConv (#112010) and QLinear (#112126) FP32/BF16 output
    • Enable oneDNN QConv2d with hardtanh post op (#114578)
    • Remove the output Annotation of Conv/Linear in x86InductorQuantizer (#112140)
    • Enable quantize_per_tensor/quantize_per_channel to accept bfloat16 input (#112225)
    • Support quantized conv bias in QAT fusion (#112528)
    • Fix custom dtype per channel weight in QAT (#112612)
    • Support allow_implicit_sharing flag (#112929)
    • Add transform_for_annotation method in Quantizer (#113115)
    • Remove add/relu from conv-bn QAT pattern (#113006)
    • Rewrite QAT annotations using SubgraphMatcherWithNameNodeMap (#113709)
    • Support conv1d-bn QAT fusion (#113714)
    • XNNPACKQuantizer skip quantization for input and output to workaround histogram observer problem (#113405)
    • Add support for QAT dynamic quantization for linear in XNNPACKQuantizer (#113288)
  • Support Subclasses of FloatFunctional in torch.ao.quantization.prepare (#109646)
  • Enable pickling model prepared with QAT QConfig (#109288)
  • Suppress empty translation unit warning in QNNPACK (#111475)
  • Add privateuseone allocator support to new_qtensor (#111464)
  • Add support for float8_e4m3fnuz and float8_e5m2fnuz (#107586)
  • Overload vec::dequantize to eliminate rounding error for quantized sigmoid (#114098)
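A rough sketch of the PyTorch 2 Export quantization flow referenced above; these are prototype APIs whose locations may change, and the model is illustrative.

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 16),)

# 1. Capture the graph with the pre-autograd export tracer.
m = capture_pre_autograd_graph(model, example_inputs)

# 2. Annotate with a quantizer, insert observers, calibrate, then convert.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration
m = convert_pt2e(m)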

NestedTensor API

  • Pickle support for NT (#110219)
  • Multiprocessing support for NT (#110292)
  • pin_memory support for NT (#110404)
  • Implement split_with_sizes backward for NT (#110647)
  • Support for as_nested_tensor() with jagged layout + fixed nested_tensor() semantics (#112304)
  • Do not generate zero-numel NT by default in helper and improve to_padded_tensor msg (#113162)
  • Backward support for broadcasting binary ops (#112519)
  • Implement narrow from a regular tensor to jagged tensor (#112770)

Distributed

  • c10d
    • Make TCPStore more robust to one-time interruptions. (#108425)
    • Add functional collective all_to_all_single and support it in Inductor (#110195)
    • Set ProcessGroupNCCL default timeout to 10 min (#110947)
    • Add an explicit _shutdown method to ProcessGroupNCCL (#111392)
    • Enable coalescing manager in DETAIL debug mode (#111878)
    • Avoid recording stream for all-gather, reduce-scatter, broadcast and scatter (#111431, #112896)
    • Relax tensor contiguity requirement for P2P ops (#114982)
    • Add .boxed() to c10d::ProcessGroup and c10d::Work's pybind (#111997)
    • Make init_process_group timeout kwarg override pg_options (#112611, #113094)
    • Use allocator trace callbacks for ProcessGroupNCCL register (#112850)
    • Make FakeProcessGroup traceable (#113314)
    • Add Bfloat16 scalar support to gloo backend (#113557)
    • Opportunistically use ncclCommSplit when creating new NCCL groups (#114385)
    • Add API _set_group_name and group_name to track pg names in C++. (#108813)
  • Distributed Checkpointing (DCP):
    • Stateful Checkpointing for Distributed (#113867)
  • DistributedDataParallel (DDP):
    • Add an API to DDP for dynamically updating the underlying process group. (#113580, #114194)
  • DTensor
    • Supported convolution ops (#113123)
    • Add DTensor constructor: randn (#108285)
    • Add grad placements kwarg to to_local API (#110629)
    • Support aten.where and enabled implicit scalar promotion (#110584)
    • Support lt/gt op (#110585)
    • Enable DTensor TP in the inference mode (#110751)
    • Refactor Parallel Style and TP API to improve UX (#111160, #111166, #111176, #111346, #111353, #111625, #111521)
    • Enable embedding sharding in TP API (#111177)
    • Enable adagrad foreach support (#114151)
    • Enable RMSprop optimizer foreach support (#114152)
    • Introduce full_tensor API to DTensor (#112224, #113322)
    • Enable foreach operators for adam optimizer (#112108)
    • Don’t use make_fx for strategy propagation (#108262)
    • Add assert of shard dim to be less than tensor ndim (#112404)
    • Add rand_like, randn_like, randint_like ops to shard propagation (#112576)
    • Enable min, max and prod sharding propagation rules (#112403)
    • Add support for layer norm op in DTensor (#113105, #113244)
    • Add foreach_zero_ support (#113897)
    • Make _Partial, Replicate frozen dataclasses (#113919)
    • Make replicate -> partial DTensor do division instead (#110898)
    • Use new placements for neg dim in redistribute (#113924)
    • Use new placements for neg dim in from_local (#114134)
    • Use new placements for neg dim in distribute_tensor (#113930)
    • Ensure grad_placements was tuple (#113925)
    • Support Xla backend in distribute_tensor API (#110275)
  • FullyShardedDataParallel (FSDP):
    • Do not materialize ignored modules for FSDP (#108032)
    • Do not move ignored params / buffers to device (#108033)
    • Make checkpoint_wrapper default to NO_REENTRANT (#108435)
    • Make ModuleWrapPolicy callable (#109117)
    • Enable cpu_offload config for optimizer state_dict (#108434)
    • Enable FSDP on CPU when GPU is still available (#112145, #112144)
    • Add cpu_only and ranks_only support for _gather_state_dict (#112836)
    • Implement cpu_offload and full_state_dict for get_state_dict (#112837)
  • TorchElastic:
    • Ensure grandchild processes are restarted correctly (#113231)
    • Avoid terminating parent process if exit code from child isn't valid (#111961)

CPU

  • Use cpuinfo to determine c10::ThreadPool thread number (#107010)
  • Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU (#112132)
  • Add Half support for poisson and use float for Half cumulative distribution on CPU (#112124)
  • Remove memory efficient attention checks (#112375)
  • Enable THP for buffer sizes >=2MB (5a4f136)

CUDA

  • Add lazy initialization for p2p access function (#108589)
  • Add support of CudaHostRegister (#108488)
  • Create per thread task pool for mapping memory space (#111545)
  • Add AMP support to linalg.vecdot. (#108165)
  • bfloat16 support in erfinv (#111257)
  • Add Bfloat16 support to CrossKernel.cu (#108941)
  • Add bf16 support to replicate padding (#112099)
  • Preserve operations order between vectorized and non-vectorized in ln grad input (#111488)

Fx

  • Add mechanism for make_fx to not error on data-dependent-ops (#114129)
  • Preserve non-forward method during torch package serialization (#114702)
  • Add Graph input option for replace_pattern (#112409)
  • Allow preserving non-forward methods during deepcopy (#114849)
  • Replace node.meta source_fn with source_fn_stack (#108595)
  • Fix tree spec matching behavior (#109679)
  • Assert that output must be the last node of the FX graph (#114973)
  • Misc improvements to visualization + utility (#114984)
  • Add stylistic improvements for fx.split_module (#113373)

Jit

  • Skip builtins while enumerating class methods (#91805)
  • Support lovelace for NVRTC (#87611)
  • Add expanded symbolic shape support (movedim) (#91696)

MPS

ONNX

  • torch->onnx export support: quantized::linear_relu (#109755)
  • Add Half for atan2, logaddexp, logaddexp2, hypot, and nextafter on CPU (#112138)
  • Support None in fx.args as torchlib inputs (#108708)
  • Support attn_mask fp16 type (#110306)
  • A better way to safeguard 2GB model serialization (#111984)
  • Fix scalar type promotion between fp16 tensor and fp32 scalar (#113404)
  • Add 'aten::rsub' type promotion (#113697)
  • Relax unsupported node analysis on complex dtype (#113785)

ROCm

  • Enable hipSPARSE const descriptors for version >= 2.4.0 (#110317)

Vulkan

  • Improve binary operators to be able to handle the other argument being a 0-dim tensor (#109035)
  • Improve binary operators to automatically convert the other argument to float in order to handle mismatched input dtype (#114145)
  • Improve aten::addmm and aten::linear to be able to broadcast the bias argument (#108199)
  • Improve aten::layernorm to be able to handle 2D inputs (#110796)
  • Improve aten::slice to be able to return 0-size output (#112879)

Bug fixes

Autograd API

  • Fix in-place custom autograd Functions to not fail when grad returned from backward is undefined (#108353)
  • Update custom Function to preserve the torch function when inputs are returned as-is (#109825)
  • Do not error when printing a view created in no-grad mode that was later modified in-place in no-grad mode (#113716)
  • Fix torch.autograd.gradcheck when fast_mode=True and default device is set (#114560)
  • Fix torch.prod double backward when input tensor contains more than one zero (#113969)

Cpp API

  • Check results dtype in index_out (#108167)
  • Add the appropriate check on div_value to the cpp frontend (#114671)
  • Add input check at the beginning for C++ API interpolate (#108506)
  • Fix the coredump described by #106702 (#108002)
  • Fix torch.nn.GRUCell segfaulting (#108340)
  • Add checks to num_layers for RNN, LSTM, GRU (#108853)
  • torch::nn::AdaptiveLogSoftmaxWithLoss: check length of cutoffs (#106777)

Foreach API

  • Fix 0-size handling for real (#109402)

Linalg API

  • Fallback to GEMM if mkldnn_matmul fails on aarch64 (#115936)
  • Preserve input's NaN values to prevent undefined behavior for matrix_exp function (#111539)

NestedTensor API

  • Fix torch.load(..., weights_only=True) for NT (#112516)

Optimizer API

  • ReduceLROnPlateau now subclasses LRScheduler (#113659)
  • Fix adagrad sparse handling due to incorrect early exit (#110454)
  • Fix pickle error when saving CyclicLR state_dict (#110931)

Python API

  • Fix type checking of lazy submodule import (#109683)
  • Fix unhandled exceptions in torch.{finfo,iinfo} calls (#109743)
  • Fix torch.{size|stride}(dim=None) (#111991)
  • Fix legacy typed storage warning line pointer (#113601)
  • Fix cpu detection error handling (#113771)

Sparse API

  • Fix semi-structured sparse shape mismatch bug (#110420)

torch.compile

Dynamo

  • Add torch.distributed get_rank and get_world_size to constant_fold_functions (#109029)
  • Implement traceable torch.tensor when you have SymInt/SymFloat inputs (#109515)
  • Avoid throwing exception in ClosingTHPObjectPtr (#109758)
  • Fix inductor CI (by updating graph break count) (#110160)
  • Error if you try to run Dynamo compiled function under torch.jit.trace (#111321)
  • Adjust _list_with_default to also work with SymInt input (#113073)
  • Avoid eager imports of classes with custom VariableTrackers (#112319)
  • Uniformly use SourcelessBuilder to handle user defined types (#113390)
  • Register SymInt-aware meta function for mm out, symintify resize (#113202)
  • Use sourceless builder for builtin getattr (#113340)
  • Don't toggle torch logger to NOTSET if it is not set; always use pre-existing (#113842)
  • Fix dict.get with no default (#115048)
  • Improve support for list subclasses (#115052)

Inductor

  • Properly handle unbacked symint in various scenarios (#109603, #109609, #111803)
  • Using floating point 0 rather than integer 0 as default value for tl.load (#113047)
  • Avoid passing a None 'generator' argument to aten.rand which does not accept a generator argument (#112240)
  • Avoid recursion error because of accumulating too much computation in a pointwise IRNode (#112320)
  • Fix 0-sized views of tensors in cudagraphs (#109055)
  • Explicitly use the result's dtype for 'other' values in a masked load to avoid unexpected type promotion (#109325)
  • Work around triton issue that loading from int1 pointer returns int8 (#110388)
  • Avoid max-autotune benchmarking messing up the random number generator (RNG) state (#111381)
  • Fix an out-of-shared-memory issue by avoiding a situation where a single invalid Triton config causes a fatal problem (#112916)
  • Make TORCH_COMPILE_DEBUG=1 work again (#112917)
  • Fix inductor <> ddp_optimizer issue (#108081)
  • Fix visualize_overlap for Inductor comm reordering (#113066)
  • Fix cat decomp where the first tensor was returned even if it was empty and there was only one non-empty tensor (#113514)
  • Correctly codegen math.inf in Inductor (#114159)
  • Do not promote int to float for torch.mm (#115043)
  • Don't pad broadcasting bias dimension in pad mm (#115098)
  • Bug fix for the CPU backend
    • Fix argmax with >1 reduction dims (#113168)
    • Fix add/sub with uint8 dtype to avoid unexpected type promotion (#113253)
    • Fix non-contiguous reduction store (#113261)
  • Bug fix for the Fx pattern matching passes
    • Fix a bug in the merge getitem cat pattern (#113822)
    • Fix shape mismatch in SDPA pattern matcher (#115038)
  • AOTInductor
    • Make AOTInductor work with pip-installed torch (#108319)
    • Fixing a redefining symbol bug (#110041)
    • Make freezing work with AOTInductor (#110055)
    • Make a free function in AOTInductor header file inline to avoid redefining symbol error (#110445)
    • Fix a weight loading issue when the weight size can be 0 (#114280)
    • Handle empty input args (#114682)

torch.export

  • Fix how pass base uses constant fake tensors (#111140)

torch.func API

  • Fix vmap support for torch.real, torch.imag (#110508)
  • Fix vmap support for torch.isfinite, torch.isreal, and torch.log_sigmoid (#110896)
  • Fix vmap support for torch.movedim, torch.tensor_split, Tensor.to, to.* (#110999)
  • Fix vmap support for torch.flatten, torch.linalg.*, torch.linear, torch.log_softmax, torch.logdet, torch.special.* (#110985)

torch.nn API

  • Fix precision issues for torch.nn.LayerNorm on CPU (#108089)
  • Made forward outputs of type collections.namedtuple be preserved instead of being changed to tuple when there are backward hooks on nn.Module (#112433)
  • Fix bug in mem_eff kernel with attention mask and MQA (bc244ee)
  • Fix allowed dtypes for CUDA devices less than SM80 for memory_efficient_attention (#116026)
  • Enforced that both input tensors to nn.CosineEmbeddingLoss have the same size (#112782)
  • Fix type hints for nn.Module.to (#108767)
  • Fix torch.nn.utils.rnn.pad_sequence type hint to allow sequences to be an iterable (#108765)
  • Fix num_batches_tracked of nn.BatchNorm{*}D in load_state_dict to not be reset to 0 if the state_dict does not contain num_batches_tracked (#110850)
  • Fix convert_sync_batchnorm to return SyncBatchNorm layer with same training flag as BatchNorm layer being converted (#111998)
  • Fix 64-bit indexing support for cross-entropy CUDA kernels (#112096)

Build

  • Fix finding Intel MKL, LAPACK, cuDNN and cuSPARSELt on Windows (#108040)
  • Fix ppc64le clang compilation errors (#106446)
  • Compile FBGEMM with ASAN (#111266)

Composability

  • FakeTensors and meta tensors are used to perform shape propagation when tracing out a graph in torch.compile. There were a number of op coverage improvements this release:
  • There were several bugfixes to our python decompositions and reference implementations of our aten operators this release:
    • Operator specific
    • General bugfixes
      • Fix infinite loop with primtorch and .to(meta) (#109632)
      • Fix registering jit decompositions for jvp for out wrapped decomps (#109367)
      • Fix python decomps for OpOverloadPackets and add tests (#107707)
  • Fix issue with lift_fresh_copy when using export + compile (#108243)
  • Removed spurious warnings from calling torch.overrides.get_overridable_functions (#109890)

CPU

  • Fix NULL dereference in binary CPU ops (e57f089)
  • Fix cpuinfo related crash on ppc64 (#110708)

CUDA

  • Release GIL in torch.cuda ops wherever possible. (#109159)
  • Skip CUDA flags if the C++ extension name includes the "arch" substring (#111211)
  • Don't set CUDA_HOME when not compiled with CUDA support (#106310)

Distributed

  • c10d
    • Fix gloo cuda sparse_allreduce dispatch (#111485)
    • Add timeout for master store if clients do not join (#111805)
    • Add cuda to MPI backend capabilities (#109614)
    • Fix send()/recv() to make them respect the timeout the same as non-p2p collectives (#109611)
    • Change default NCCL_ASYNC_ERROR_HANDLING to 3:SkipCleanUp to avoid calling ncclCommAbort, which in some cases hangs (#110723)
  • Distributed Checkpointing (DCP):
    • Fix torch.cpu has no attribute current_device in checkpoint/optimizer.py (#110299)
  • DTensor:
    • Fix DTensor.from_local() returns DTensor with wrong size for uneven sharded tensor (#110781)
    • Make DTensor handle negative dim correctly and fixed TP regression (#111750)
    • Fix pointwise op strategy linearity (#112107)
    • Fix empty shape init for DTensor constructors (#115091)
  • FullyShardedDataParallel:
    • Fix non-Node-0 ranks being unable to receive parameters from Node 0 for HSDP (#108331)
    • Add device to _shard_utils.py to explicitly use the correct device from fsdp_state (#109631)
    • Propagate requires_grad attribute to unsharded params (#109892)
    • Fix logic for FSDP exec order pre-forward recording (#110138)
    • Move local optimizer state to FSDP compute_device (#110929)
    • Fix the FSDP to not reshard parameters twice (#110948)
    • Fix FSDP to reset prefetch flag upon reshard (#111354)
    • Fix FSDP when SHARD_GRAD_OP and forward_prefetch is turned on (#110139)
    • Fix FSDP summon_full_params(..., with_grads=True) when grad precision is not fp32 (#112746)
    • Fix fsdp state_dict to use run_check=False (#114995)
    • Fix pylance issues for torch.distributed.fsdp (#109922)
  • RPC:
    • Fix assertion on vector length during message parsing (#108414)

Fx

  • Fix functorch.compile.minifier error of “'Node' object is not iterable” (#103011)
    • Skip mode issue in minimizer (#109399)
    • Skip the Tensor node in __annotations__ (#109853)
    • Fixed dict size change during iteration error (#111267)
    • Made sure fx code is valid in python (#113345)
    • Updated symbolic_trace’s nn_module_stack format (#114422)
    • Fixed missing meta for proxy.node (#114659)
  • Correctly restore pybind11 error_already_set (#93238)
  • Remove proxy tensor's check for data dependent output (#93265)
  • Make ShapeEnv deepcopy-able (#93403)
  • Fix SubgraphMatcher for case of no anchor found (#86421)
  • Fix for partitioner with symbolic shapes (#86425)
  • Fix getitem in partitioner and make metadata storage more consistent (#87012)
  • Fix magic method try reverse protocol (#88030)
  • Fix FakeTensorProp on Module with Parameters or Buffers (#88700)
  • Fix PassManager to not use a class variable mutable list (#89108)
  • Prevent tracing when we track_tensor_tree (#89139)
  • Make all make_fx invocations isolated (opaque to higher make_fx invocations) by default (#93290)
  • Fix matching args in PatternMatcher (#94375)
  • Allow FakeTensorProp to run on graphs traced with some None inputs (#94569)
  • Copy codegen in legalize_graph (#90023)
  • Fix proxy unwrapping for cond() (#91907)

Jit

  • Fix optimize_for_inference to support modules that don't have a forward method (#110013)
  • Fix errors found by fuzzing and sanitizers (#108417, #108413, #110303, #110441)
  • Fix deprecated python usage for python 3.12 in TorchScript (#113981)
  • Support newer versions of LLVM (#110200, #113455)

Lazy

  • Fix error when inferring shape in AdaptiveAvgPool3d (#109822)

Mps

  • Fix and refactor unary/binary ops with non-zero offset or non-contiguous output (#97085)
  • Fix memory leak in copy_from_mps_ (#114197)
  • Fix crash if nonzero is called concurrently (#108996)
  • Fix nll_loss with default ignore_index (#109574)
  • Fix sort with empty tensor. (#109584)

ONNX

  • Fix module attribute retrieval in ONNX export (#109759)
  • Add dynamic input support for MaxPool (#113318)
  • Fix op-level debug for complex dtype (#114885)
  • Fix indexing for meshgrid op (#109350)
  • Fix torch.diagonal for torch.onnx.export when dim1<0 or dim2<0 (#111130)
  • Fix scope name when parent scope is empty for torch.onnx.export (#112654)
  • Cast ‘scale’ back to float16 after _attention_scale. (#112554)
  • Fix aten::layer_norm for ONNX opset 17 (#114058)
  • Add support for negative dim in _index_fill_reshape_helper (#114050)
  • Disable opmath type promotion (#113780)

Profiler

  • Fix missing dependency in torch.utils.tensorboard (#115598)
  • Fix torch.utils.benchmark API when using privateuse1 (#108548)
  • Ignore some properties when symbolic size/strides exist (#112458)
  • Fix description to use nelems rather than size (#114735)
  • Use PyCFunction_Check to check both PyCMethod_Type and PyC… (#110002)
  • Disable CUPTI Teardown when using CUDA Graphs (#112507)
  • Fix the Chrome trace loading issue with all_to_all input split length > 30 (#113392)

Quantization

  • Make mutation test work with quantized tensors (#108935)
  • Support transitive sharing for SharedQuantizationSpec (#111172)

Releng

Visualization

  • Fix TensorBoard summary writer encoding for torch.bfloat16 tensors (#108351)

Vulkan

  • Fix a bug in aten::sum.dim_IntList where providing negative dims in the opt_dim argument and setting keepdim=false produced incorrect results (#111586)

Performance

Autograd API

  • Avoid saving input for torch.mean backward (#109935)

Cpp API

  • Add ScalarTensor or 0dim overload for _foreach_add (#111079)
  • Vectorize torch.exp2 on CPU and add complex support (#92115)
  • Add various performance fixes to c++ STL usage (#94034)

Linalg API

  • Speedup torch.matrix_exp performance (#110848)
  • Improve speedup of cholesky_solve_backward using output_mask (#112981)

NestedTensor API

  • Reduce overhead in split and chunk for NestedTensor (#108213)

Optimizer API

  • Use a for-loop with a shortcut in Optimizers to speed up Inductor compared to list comprehensions (e.g. complex conversion) (#110613, #112722)
  • Speed up dynamo tracing of optimizer by shortcutting is_sparse iteration in foreach SGD (#110648)

Sparse API

  • Add NVIDIA A100 optimized meta parameters to bsr_dense_mm (#111760)
  • Improve triton bsr_dense_mm performance on column-major ordered inputs with float32 dtype (#108512)
  • Add bsr_dense_addmm triton kernel (#114595)
  • Use more performant bsr_scatter_mm within bsr_dense_mm when blocksize is 16. (#111489)

torch.compile API

Inductor

  • Support convolution layout optimization for models with SDPA (#112045)
  • Scaling down XBLOCK/RBLOCK to increase occupancy for kernels exposing large parallelism and having register pressure (#109275, #109839, #113039, #114284)
  • Horizontal fusion for concat (#111437)
  • Avoid an extra memory copy for views on ExternKernelAlloc (#108635)
  • More memory and performance efficient implementation for Conv+BatchNorm block in eval mode (#109398, #109722)
  • Pointwise fuse cat with pointwise inputs or outputs and <= 4 inputs (#111233)
  • Add a way to force fusion of int_mm with mul (#111413)
  • Add a heuristic to multi-layer reduction to increase the chance that the first splitted reduction can have compatible shape to fuse with a preceding reduction (#111781)
  • Benchmark fusion: either use this to skip slow fusions or analyze patterns from slow fusions (#112450)
  • Optimize sympy expressions where the div denominator is -1 (#112878)
  • Use different conv layout optimization heuristics for inference (#114600)
  • Add or improve Fx pattern matching passes
    • Pre-grad batch relu fusion (#111146)
    • New split cat pattern detection (#110923)
    • Add split-stack-tanh-unbind pattern detection (#111854)
    • Remove split nodes with split section size one (#112922)
    • Normalize nodes created by users (#113179)
    • Post-grad batched linear fusion (#112504)
    • More SDPA patterns (#109156, #110001)
    • Horizontally fusing two matmuls in freezing phase (#111232)
    • A bunch of pattern matcher + indexing fixes (#112476)
  • CPU Backend
    • Fallback scatter_add to eager on CPU to avoid bad perf (#108220)
    • Make OneDNN matmul inputs contiguous to avoid degraded performance (#108560)

torch.func API

CPU

  • S390x complex division (#108516)
  • Add Half support for CPU autocast on eager mode (#112484)
  • Add scalar conversion using avx instructions for half (#102140)

CUDA

  • Release the allocator lock on the slow path (#108367)
  • Faster gc_count update for CUDACachingAllocator (#108071)
  • baddbmm should fall back to addmm for batch=1 (#114992)
  • Speed-up casts to FP8 (#110251)
  • int4 mm kernel enhancement (#111460)
  • vectorized implementation for layer_norm_grad_input_kernel (#111021)

Distributed

  • c10d:
    • Push TCPStore scalability further by staggering client connection and increasing the backlog to 16k. (#109217)
    • Make the minimum wait time in _store_based_barrier adaptive, based on the number of ranks (#109218)
  • DTensor:
    • Fix and improve the sharding cache behavior (#109306, #109428)
    • Switch DTensor and Functional Collective to use optree (#110670)
    • Skip move to device when device_type match (#110774)
    • Skip pytree when not necessary (#110132)
    • Introduce cost model for sharding (#109145)
    • Group dispatch unwrapping to a method (#113846)
    • Cache hash for DTensorSpec (#113915)
    • Compute and recompute DTensorSpec hash lazily (#114322, #114379)
    • Reduce to one isinstance call in is_shard (#114140)
  • FullyShardedDataParallel (FSDP):
    • Fuse allgather for optim_state_dict when use_orig_params is True (#108298)
    • Make the new optimizer allgather fusion work with fine-tuning models (#110540)
    • Skip the parameter in optim state dict if the parameter does not belong to the current FSDP instance (#112804)

Fx

  • Use deque instead of list for BFS (#91139)
  • Refactor the dfs cyclic search from recursive to iterative approach (#91042)

Vulkan

  • Improve matrix multiplication shader performance by up to 50% through new packing schemes and algorithmic improvements (#112918, #113627, #113883, #113943)

Documentation

Autograd API

  • Improve docstring issues in various places (#113266)

Dataloader API

Linalg API

  • Fix typo in example of torch.linalg.solve_triangular (#112361)
  • Remove duplicate sentences in description of torch.linalg.eig (#108230)
  • Fix bug in matrix_power documentation (#108585)

Optimizer API

  • Update documentation for PolynomialLR (#110151)
  • Clarify maximize option in optimizer.py (#112724)
  • Fix docstring errors inside torch/cuda/ and torch/optim/ (#112964)

Python API

  • Fix docstring issues in torch.utils (#113335)
  • Add docstring to Timer.adaptive_autorange (#111612)
  • Fix a typo in torch.cholesky_inverse documentation (#110364)
  • Document torch.from_file and UntypedStorage.from_file properly (#111688)
  • Clarify difference between Tensor.share_memory_ and torch.from_file (#111856)
  • Improve torch.unique docs (#113424)
  • Fix torch.lgamma docs (#108719)
  • Update torch.take_along_dim docs to include dim=None case (#109120)
  • Fix torch.searchsorted docs (#109364)
  • Clarify torch.multinomial usage (#112892)

torch.compile API

Inductor

  • Add a tutorial for AOTInductor (#112457)
  • Add document for cudagraph_mark_step_begin API (#111722)

torch.export API

  • Add torch.cond doc (#108691)
  • Add ir spec (#110394)
  • Update docs to say that export returns full ATen IR (#111161)

torch.func API

  • Fix per-sample-grads notebook (#107988)

torch.nn API

  • Add examples for nn.CosineEmbeddingLoss (#108215)
  • Fix attn_bias in code block in scaled_dot_product_attention documentation (#109086)
  • Add documentation for torch.nn.utils.parametrizations.weight_norm (#113783)
  • Improve type annotation for device parameters when a device ordinal is allowed (#113647)
  • Update scaled_dot_product_attention documentation to point to flash-attn-v2 (#114124)
  • Fix extending torch native API docs (#114863)

Build

  • Fix doc preview page url at CONTRIBUTING.md (#108580)
  • Fix typo in cpp/installing when wheel is used (#111143)

Composability

  • Fix ScalarTensor repr in Extending PyTorch example (#86330)
  • Fix incorrect wrapping of function decorator (#94446)
  • Add __all__ to torch.{autograd, fx, cuda} submodules (#85343)

CUDA

  • Show CUDAExtension example commands as code (#112764)
  • Rewrite docs to describe CUDACachingAllocator semantics (#113282)

Distributed

  • c10d:
    • Fix TCPStore doc for arg wait_for_workers (#111807)
    • Fix the warning message for avoidRecordStreams_ so that the correct name of the environment variable is shown (#108759)
    • Fix an incorrect indent in documentation for dist.send (#108273)
    • Fix warnings and descriptions about distributed exceptions in the logging section of the PyTorch Distributed document (#110157)
    • Fix batch_isend_irecv example incorrect usage (#110408)
    • Clarify the behavior of apply_optimizer_in_backward in the document (#110903)
    • Correct docstring errors for torch.distributed files (#112735, #113511, #113523, #112693, #113241, #113216)
    • Print NCCL_SUFFIX in NCCL version log at PG init (#112560)
  • Distributed Checkpointing (DCP):
    • Fix the no_dist comment in load_state_dict (save -> load) (#112217)
    • Improve DDP checkpoint documentation (#106985)
  • DistributedDataParallel (DDP):
    • Fix import in DDP notes (#111833)
    • Add errors when using _dynamo.optimize_ddp=True and _inductor.keep_output_stride=False together inside DDP (#108235)
  • DTensor:
    • Improve TP documentation (#115880, #115974)
  • FullyShardedDataParallel (FSDP):
    • Fix docstring of FSDP.optim_state_dict_to_load to reflect right ctors (#108383)
    • Fix docstring for FSDP.set_state_dict_type to contain missing Args (#103864)
    • Remove "on CPU" in the comment of FSDP initialization doc (#113753)
  • TorchElastic:
    • Fix a typo in rendezvous/registry.py (#111352)
  • RPC:
    • Fix torch.distributed.rpc example incorrect usage (#112367)
  • Activation checkpointing
    • Clean up comments in activation checkpoint (#86622)
  • RPC
    • Fix non-existing parameters in docstrings in benchmarks (#91115)
  • Tensor parallelism and DTensor:
    • Add more clarifications and fix errors in tensor parallelism docs (#94786)
    • Update 2D parallelism API naming and docs (#94771)
  • FullyShardedDataParallel
    • Add docs to explain running the forward pass of submodules in FSDP (#86343)
    • Clarify warnings to mention collectives (#87478)
    • Remove HSDP Zero-2 from doc (#90503)
    • Improve the comments for FSDP (#92359)
  • Distributed Checkpoint
    • Enable documentation for Distributed Checkpoint. (#92813)
  • Torch Elastic
    • Fix a minor typo in documentation (#90667)
    • Fix torch.distributed.run init connect timeout by comparing host with the current IP list (#90221)

Mps

  • Resolve docstring errors (#113311)

ONNX

  • Update exporter issue report instructions for quantized models (#113494)
  • Fix sample code in onnx_dynamo.rst (#114770)

Profiler

  • Improve the docstring for export_memory_timeline (#110949)
  • Improve torch/csrc/profiler/README.md - stubs, RecordFunction, Autograd interaction (#108470)

Quantization

  • Add pt2 export quantization to main doc (#110260)
  • Add x86 inductor quantization docs (#112648)
  • Use \odot everywhere instead of mixing \odot and * for the Hadamard product (#111763)
  • Add documentation for prepare_pt2e, prepare_qat_pt2e and convert_pt2e (#110097)
  • Docstyle fix for some quantization code (#112992)
  • Update docs for embedding_bag support for fx and eager (#107623)
  • Fix docstring errors in quantized modules (#112695)

Security

Releng

  • Use secure setup-ssh action from test-infra (#111922)
  • Automate passing Conda PyTorchBot Test Token for release (#111821)
  • Migrate MacOS wheel binary builds to ephemeral M1 runners (#110432)