Variable-length RNNs, Better Indexing, Sparse Tensors, Faster CPU, Many Bug Fixes

@soumith released this 15 Mar 05:58

New Features

Indexing and Broadcasting Improvements

  • Add broadcasting semantics to expand / expand_as.
    • Previously, expand had no ability to add new dimensions, and unsqueeze
      had to be used to first create singleton dimensions before expansion.
    • Now, singleton dimensions are automatically prepended to the shape of the
      tensor so that it matches the target shape, and existing singleton dimensions
      are expanded to the corresponding sizes.
      Here's an example:
      x = torch.rand(5)
      y = torch.rand(4, 8, 5)
      z = x.expand_as(y) # z is of shape (4, 8, 5)
      
      x = torch.rand(1, 8, 1)
      z = x.expand_as(y) # z is of shape (4, 8, 5)
  • Unsqueeze dimensions using None indexing
    a = torch.randn(10)
    b = a.unsqueeze(0)
    b = a[None, :]     # Equivalent operations
  • Indexing with steps is supported (only positive steps)
    In [1]: a = torch.randn(10)
    In [2]: a
    Out[2]:
    
       0.1338
       1.0789
       1.2302
      -1.3343
      -0.4676
       1.3511
      -0.4374
      -1.0611
      -0.1528
      -1.3994
      [torch.FloatTensor of size 10]
    
    In [3]: a[0:10:3]
    Out[3]:
    
       0.1338
      -1.3343
      -0.4374
      -1.3994
      [torch.FloatTensor of size 4]

Variable-length mini-batches in Recurrent Networks

nn.RNN, nn.LSTM, nn.GRU now support mini-batches where sequences are of variable
lengths.
You can pass an input of type PackedSequence
into these layers.
A PackedSequence holds the data and a list of sequence lengths of a packed mini-batch.
For example, a PackedSequence can hold a mini-batch of sequences such as:

a b c d e
a b c d e f g h
a b
a b c d

Here, each input row is of variable length.

You can construct a PackedSequence using the provided function
pack_padded_sequence

pack_padded_sequence takes a Variable containing padded sequences, i.e. a Tensor
of T x B x *, where B is the size of the mini-batch, and each input is either of
length T or is padded to length T. It also takes a list of lengths of each input.
From these, it constructs a PackedSequence.

For example, given the lengths [8, 5, 4, 2] and an input of size 8 x 4 x 128
that corresponds to:

a b c d e f g h
a b c d e 0 0 0
a b c d 0 0 0 0
a b 0 0 0 0 0 0

The output of the RNN layers will also be a PackedSequence, which can then be inverted
back to a padded Tensor using the inverse function:
pad_packed_sequence
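
As a quick illustration, here is a minimal sketch of the round trip through
pack_padded_sequence, an RNN layer, and pad_packed_sequence, using the [8, 5, 4, 2]
example above (the feature size of 128 and the LSTM hidden size of 64 are arbitrary
choices for this sketch):

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lengths = [8, 5, 4, 2]                     # sequences sorted by decreasing length
padded = Variable(torch.randn(8, 4, 128))  # T x B x *, zero-padded beyond each length

packed = pack_padded_sequence(padded, lengths)   # a PackedSequence

lstm = nn.LSTM(input_size=128, hidden_size=64)
packed_output, (h_n, c_n) = lstm(packed)         # output is also a PackedSequence

# Invert back to a padded Tensor of size T x B x hidden_size, plus the lengths
output, output_lengths = pad_packed_sequence(packed_output)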

Sparse Tensors (CPU)

Original goals:

  • ability to propagate sparse updates in a network (e.g. for updating an embedding matrix)
  • ability to efficiently compute "bag-of-words" sentence embeddings (e.g. weighted average of word embeddings)

Implemented features:

  • enable backpropagation of sparse gradients without conversion to dense tensors. In most cases a runtime exception is thrown when mixing different gradient types for the same variable
  • add some methods for THSTensor: zero, elementwise add and mul, scalar mul and div
  • make addcmul method of THTensor compatible with sparse operands
  • make spmm method accessible from Python as dsmm
  • sparse_mask method on THTensor. This produces a sparse tensor from a dense tensor,
    by using a sparse tensor as a mask. A value is only present in the output sparse
    tensor if it also exists in the mask.
  • update optim.Adagrad to use sparse updates when possible.
  • leave a Variable's gradient as None by default.
    This is because there is no canonical zero gradient anymore (it could be dense or
    sparse, and if it is sparse we don't know how many dimensions are sparse)
  • N-dimensional values for sparse tensors:
    • For use cases such as applying sparse updates to an embedding matrix, only the
      first dimension (the one that corresponds to the word index) is sparse; the
      remaining dimensions are always dense (only whole embedding vectors are updated).
      An elegant solution is to make the values tensor N-dimensional instead of
      1-dimensional. For an embedding matrix, the sparse gradient will have a values
      tensor of size nnz x embedding_size instead of just nnz (see the sketch below).
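
As a rough illustration of these N-dimensional values, here is a minimal sketch that
builds such a sparse gradient by hand. It assumes the torch.sparse.FloatTensor
constructor (taking indices, values and a size) and the to_dense method; the sizes
are made up for the example:

import torch

# Hypothetical sparse gradient for a 10 x 128 embedding matrix in which only
# rows 1 and 4 received updates: one sparse dimension, nnz = 2.
indices = torch.LongTensor([[1, 4]])   # shape: (number of sparse dims) x nnz
values = torch.randn(2, 128)           # shape: nnz x embedding_size (N-dimensional values)
grad = torch.sparse.FloatTensor(indices, values, torch.Size([10, 128]))

dense_grad = grad.to_dense()           # materialize as a dense 10 x 128 tensor if needed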

Common weight initialization methods for neural networks

By default, all Linear and Conv layers in PyTorch are initialized according to
a scheme proposed by LeCun'98.

However, there are several other commonly used initialization methods.
We now support many other methods via torch.nn.init.
Supported methods include:
uniform, normal, constant, xavier_uniform, xavier_normal, kaiming_uniform,
kaiming_normal, orthogonal, sparse

Here's an example of using these initialization methods:

import math
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(5, 10, (3, 3))
        nn.init.xavier_uniform(self.conv1.weight, gain=math.sqrt(2.0))
        nn.init.constant(self.conv1.bias, 0.1)

network = Net()

Other features

  • Added a gradient checker utility torch.autograd.gradcheck that can
    be used to check your implementations. Here's a small example:
    from torch.autograd import Variable, gradcheck
    inputs = Variable(torch.randn(4, 4), requires_grad=True)
    gradcheck(lambda x: 2*x.diag(), (inputs,), eps=1e-3)
  • Add a clip_grad_norm utility to easily clip gradients via constraints on their
    norms (see the sketch after this list).
  • Document nn.ModuleList and nn.ParameterList that are immensely useful when
    storing a list of modules in a Container
  • Optimizers have backward compatibility for old checkpoints.
    __setstate__ and __getstate__ were introduced into optimizers.
  • Add Nesterov momentum to optim.SGD via nesterov=True kwarg
  • DataParallel supports multiple inputs and keyword args (which are also scattered)
    m = nn.DataParallel(model)
    # Now valid
    m(x, y, option=z)
    
    See the documentation for exact behavior.
  • DataLoader's default_collate now also supports numpy arrays
  • Added F.pad that supports Constant, Reflection and Replication padding in a single
    interface: http://pytorch.org/docs/nn.html#pad
  • train() now optionally takes a boolean argument. For example, model.train(False)
    sets the model to eval mode and model.train(True) sets it to train mode.
  • Added a DataLoader sampler, SubsetRandomSampler, that takes a list of indices
    in its constructor and randomly samples from these indices. Useful when you
    want to sample only a particular subset of your dataset.
  • Transpose supports negative dimensions. For example:
    a = torch.randn(2, 3)
    b = a.transpose(0, 1)   # both are equivalent
    b = a.transpose(-2, -1) # both are equivalent
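
To make a couple of the items above concrete (gradient clipping and Nesterov momentum),
here is a minimal sketch of a single training step; the model, data and hyperparameters
are placeholders for the example:

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm

model = nn.Linear(10, 2)                 # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

x = Variable(torch.randn(16, 10))        # placeholder mini-batch
target = Variable(torch.randn(16, 2))

optimizer.zero_grad()
loss = nn.MSELoss()(model(x), target)
loss.backward()
clip_grad_norm(model.parameters(), max_norm=1.0)  # rescale gradients whose total norm exceeds 1.0
optimizer.step()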

Performance Improvements

  • CPU Tensor backend gets faster
    • Explicit AVX, AVX2 and improved SSE intrinsics to speedup copy, fill, add, mul, div
    • Much improved speed for all apply and reduce operations to have better cache hits
    • Added OpenMP in TH_TENSOR_APPLY* operations
    • Overall, 2x to 10x+ faster on many operations, closer to NumPy speeds
    • Runtime dispatch of intrinsics based on CPU features (easy to ship binaries)
  • Serialization Improvements
    • Fixed bugs on serialization for Tensors > 2GB
    • 5x to 10x faster serialization (no longer Tarring Tensors)

Bug Fixes

  • Multi-GPU CuDNN RNN now has separate dropout descriptors per GPU
  • NLLLoss2d has proper shape checks on GPU and stable sizeAverage formulation
  • LogSoftmax2d has a more stable formula
  • Fix prodall (prod without dim arguments) to not average
  • Return correct number of gradients from cuDNN RNN
  • NLLLoss2d has support for weights
  • Fix Unpooling bug for MaxPool1d
  • Fix indexing when using only an ellipsis:
    x = torch.randn(2, 2, 2, 2)
    x[...] # used to fail, fixed now
  • Expose stateless methods (torch.* methods) for torch.cuda.HalfTensor
  • Prevent creation of reference cycles (and hence improve memory usage) when
    leaf variables are used in in-place operations.
  • Fix gradient computation for the indexing operation when indexing with a
    LongTensor.
  • Fix a reshaping bug in the grad_input of basic operations such as +, -, *, / etc.
    This used to fail, but is fixed now:
    x = Variable(torch.randn(4, 6), requires_grad=True)
    b = Variable(torch.rand(12, 1) + 1e-2, requires_grad=True)
    (x + b.mm(Variable(torch.rand(1, 2) + 1e-2))).sum().backward()
  • Revert partial indexing with LongTensor to restore NumPy compatibility
  • References to some Tensors in BatchNorm and Conv are now freed to improve
    memory usage in certain situations. For example, ResNet-152 fine-tuning with
    batch_size 16 used to consume as much memory as batch_size 256; it no longer
    does after this fix.
  • Fix a bug where requires_grad was being propagated forward differently in
    CPU mode and CUDA mode.
  • Fix bugs in torch.multinomial on CUDA, where in rare cases the sampling
    led to nonsensical values
  • Allow backprop through CuDNN RNN in eval() mode.
  • Support np.int16 in conversion to ShortTensor
  • Enable multithreading in MKL (was disabled previously due to a cmake bug).

Improved error messages

  • Print a readable error message when arguments are on different GPUs
  • Add better error message for conversion of CUDA tensors to numpy
  • Add checks for reward type and size in StochasticFunction