
Sparse support for CUDA, bug fixes, performance improvements

@soumith released this 02 May 22:26

API Changes

  • torch.range is deprecated in favor of torch.arange, which is consistent with numpy and Python's range (see the example after this list).
  • On sparse Tensors, contiguous is renamed to coalesce, and coalesce is now out-of-place.
    (A reminder that the sparse API is still experimental and evolving, so we don't provide backward compatibility.)
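
A minimal sketch of the difference (torch.arange excludes the end point, while the deprecated torch.range includes it):

import torch

torch.arange(0, 5)  # 0, 1, 2, 3, 4     -- end point excluded, like Python's range
torch.range(0, 5)   # 0, 1, 2, 3, 4, 5  -- end point included; now deprecated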

New Features

New layers and functions

  • torch.topk is now supported for all CUDA types, not just torch.cuda.FloatTensor.
  • Added a three-way ranking loss: nn.TripletMarginLoss
  • Added per-instance normalization layers: nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d
    Each channel is treated as an instance to normalize: mean subtraction and division by the standard deviation are applied per instance. This is useful when dealing with larger images and smaller mini-batches where BatchNorm-like effects are desired.
  • nn.ZeroPad2d and nn.ConstantPad2d are added.
  • nn.Bilinear is added, which computes Y = X1 * W * X2 + b
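
A minimal sketch exercising the new layers; all sizes, margins, and padding values below are arbitrary and chosen only for illustration, and the constructor arguments follow the current documentation (treat them as assumptions for this release):

import torch
import torch.nn as nn
from torch.autograd import Variable

# three-way ranking loss over (anchor, positive, negative) embeddings
triplet = nn.TripletMarginLoss(margin=1.0)
a, p, n = [Variable(torch.randn(8, 128)) for _ in range(3)]
loss = triplet(a, p, n)

# per-instance normalization over the spatial dimensions of each channel
inorm = nn.InstanceNorm2d(16)
imgs = Variable(torch.randn(4, 16, 32, 32))
normed = inorm(imgs)

# zero / constant padding of the two spatial dimensions
padded = nn.ZeroPad2d(2)(imgs)             # shape: (4, 16, 36, 36)
padded_c = nn.ConstantPad2d(2, 3.5)(imgs)  # pad with the constant 3.5

# bilinear layer computing Y = X1 * W * X2 + b
bilinear = nn.Bilinear(10, 20, 5)
x1 = Variable(torch.randn(32, 10))
x2 = Variable(torch.randn(32, 20))
y = bilinear(x1, x2)                       # shape: (32, 5)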

Negative dimension support for all functions

Every function that takes a dimension argument now also accepts negative dimensions.

A negative dimension will index the tensor from the last dimension.

For example:

x = torch.randn(10, 20, 30)
y = torch.mean(x, dim = -1)

Here, since x has 3 dimensions and dim=-1 refers to the last dimension, dim=2 is picked for taking the mean.

The functions with dimension arguments are:

narrow, transpose, size, cat, chunk, gather, index_select, split, squeeze,
stack, unbind, unsqueeze, cumprod, cumsum, mean, median, mode, norm, prod, std,
sum, var, kthvalue, max, min, sort, topk, renorm,
index_add, index_copy, index_fill, scatter, select, unfold

CUDA support for Sparse Tensors, faster CPU sparse

Part of the torch.sparse API is now also supported for torch.cuda.sparse.*Tensor.

Functions that are supported on CUDA are:

sparse_mask, to_dense, coalesce, transpose, spaddmm,
spcadd, mul, div, cadd, csub, cmul

nn.Embedding now supports sparse gradients even on CUDA (with the sparse=True flag), leveraging these sparse functions.
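
A minimal sketch of a sparse Embedding on the GPU; the vocabulary size, embedding size, indices, and learning rate are arbitrary, and this assumes a CUDA device is available:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# with sparse=True, gradients w.r.t. the embedding weight are sparse tensors
embedding = nn.Embedding(10000, 128, sparse=True).cuda()
optimizer = optim.SGD(embedding.parameters(), lr=0.1)

indices = Variable(torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]).cuda())
out = embedding(indices)   # shape: (2, 4, 128)
out.sum().backward()       # produces a sparse gradient for embedding.weight
optimizer.step()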

A new hybrid matrix-multiply operation, hspmm, multiplies a sparse matrix with a dense matrix and returns a matrix in the form of a hybrid tensor (i.e. 1 sparse dimension, 1 dense dimension).
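
A minimal sketch of hspmm; exposing the operation as torch.hspmm (as in later releases) is an assumption, and the sparse constructor used is the experimental torch.sparse.FloatTensor from this release:

import torch

# a 2x3 sparse matrix with non-zero entries at (0, 2) and (1, 0)
i = torch.LongTensor([[0, 1], [2, 0]])
v = torch.FloatTensor([3.0, 4.0])
sparse_mat = torch.sparse.FloatTensor(i, v, torch.Size([2, 3]))

dense_mat = torch.randn(3, 5)

# hybrid result: 1 sparse dimension (rows) and 1 dense dimension (columns)
out = torch.hspmm(sparse_mat, dense_mat)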

Several of the CPU sparse functions have more efficient implementations.

In a quickly hacked-up Embedding classifier training script by @martinraison, we see CUDA sparse performing as well as CUDA dense:
https://gist.github.com/martinraison/1e7c18c6f6eda87f1cb4995b0e6a22a5

Table of times in seconds / batch:

           CPU     CUDA
  Dense    10      0.86
  Sparse   0.15    0.13

named_parameters to filter out specific parameter types

Let's say that you want to add weight decay to all parameters of your model except for the biases. How do you get only the biases of your model?
We introduce nn.Module.named_parameters for this.
It joins named_children and named_modules in helping you filter specific attributes of models.

Example of filtering out the biases of a model and giving them a weight_decay of 0:

import torch
import torch.nn as nn
import torch.optim as optim
m = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 20),
    nn.ReLU(),
)

weights, biases = [], []
for name, p in m.named_parameters():
    if 'bias' in name:
        biases += [p]
    else:
        weights += [p]

optimizer = optim.SGD([
    {'params': weights},
    {'params': biases, 'weight_decay': 0}
], lr=1e-2, momentum=0.9, weight_decay=1e-5)

Performance Improvements

  • cumsum and cumprod are now significantly faster on the GPU by using Thrust primitives where appropriate.
  • LSTMCell and GRUCell are now significantly faster on the GPU via a fused kernel
  • The default algorithm for cuDNN convolutions has been changed to PRECOMP_GEMM, which is a
    much faster algorithm that takes a small amount of workspace. Previously, it used to
    be IMPLICIT_GEMM, which took zero workspace but was significantly slower.
  • 5% to 10% improvement in data loader by collating batches directly into shared memory.
  • SVD is now computed on the GPU via divide-and-conquer (sgesdd) which gives a 2x to 5x speedup.
  • The commonly used function expand has been moved to C for better performance in smaller models.

Bug Fixes

  • Added contiguous checks on weight and bias for a large range of THNN functions
  • make the range of random_ correct when both lower and upper bound are specified
  • parallel_apply now can take arguments that are unhashable
  • Reshape grad correctly in the Dot function (inputs don't have to be 1D vectors...)
  • Added Variable.type_as
  • Unify argument names of norm and renorm to have p=norm_type, dim=dim
  • btrisolve works on CPU doubles
  • IPython autocomplete for torch.nn.Module fixed via implementing __dir__
  • device_ids can now be None again in F.data_parallel and will use all available GPUs
  • workaround cudnn bugs in BatchNorm (<5.1.10) and Dilation (6.0.20)
  • Padding bugfix in Conv1d CPU
  • remainder and cremainder are fixed for integer types
  • fix memory leak in btrisolve and getri
  • If an nn.Module's source can't be retrieved because of any exception,
    serialization is handled non-fatally
  • collate_fn now retains the type of the numpy array
  • is_tensor and is_storage are now fixed for old-style Python classes
  • torch.cat now supports keyword arguments
  • CUDA collectives supported coalescing, but the inputs were all assumed
    to be of the same Tensor type. This is fixed.
  • Fix a deadlock bug in autograd because of an underlying glibc bug in specific
    linux distros (ArchLinux in particular)
  • abs is now fixed for char and short CUDA types
  • fix torch.diag autograd when giving a dimension argument
  • fix grouped convolution on CPU when bias=False
  • expose dilated convolutions for ConvTranspose*d
  • Fix a bug in HingeEmbeddingLoss where margin can now be specified via kwargs

Improved error messages

  • Fix errors and messages when no CUDA devices are available.