
Spectral Norm, Adaptive Softmax, faster CPU ops, anomaly detection (NaNs, etc.), lots of bug fixes, Python 3.7 and CUDA 9.2 support

Released by @soumith on 26 Jul 19:09

Table of Contents

  • Breaking Changes
  • New Features
    • Neural Networks
      • Adaptive Softmax, Spectral Norm, etc.
    • Operators
      • torch.bincount, torch.as_tensor, ...
    • torch.distributions
      • Half Cauchy, Gamma Sampling, ...
    • Other
      • Automatic anomaly detection (detecting NaNs, etc.)
  • Performance
    • Faster CPU ops in a wide variety of cases
  • Other improvements
  • Bug Fixes
  • Documentation Improvements

Breaking Changes

  • torch.stft has changed its signature to be consistent with librosa #9497
    • Before: stft(signal, frame_length, hop, fft_size=None, normalized=False, onesided=True, window=None, pad_end=0)
    • After: stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True)
    • torch.stft now uses FFT internally and is much faster (see the sketch after this list).
  • torch.slice is removed in favor of the tensor slicing notation #7924
  • torch.arange now infers its dtype: if any argument is a floating-point value, the result uses the default dtype; if all arguments are integers, the result is int64. #7016
  • torch.nn.functional.embedding_bag's old signature embedding_bag(weight, input, ...) is deprecated; use embedding_bag(input, weight, ...) instead, which is consistent with torch.nn.functional.embedding
  • torch.nn.functional.sigmoid and torch.nn.functional.tanh are deprecated in favor of torch.sigmoid and torch.tanh #8748
  • Broadcast behavior changed in a (very rare) edge case: [1] x [0] now broadcasts to [0] (used to be [1]) #9209
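
A quick, non-authoritative sketch of the new call patterns above (tensor sizes and the window choice are arbitrary):

    >>> signal = torch.randn(2, 16000)  # (batch, samples); sizes are illustrative
    >>> spec = torch.stft(signal, n_fft=400, hop_length=160, window=torch.hann_window(400))
    >>> # dtype inference in torch.arange
    >>> torch.arange(5).dtype
    torch.int64
    >>> torch.arange(5.).dtype  # the default dtype
    torch.float32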

New Features

Neural Networks

  • Adaptive Softmax nn.AdaptiveLogSoftmaxWithLoss #5287

    >>> in_features = 1000
    >>> n_classes = 200
    >>> adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs=[20, 100, 150])
    >>> adaptive_softmax
    AdaptiveLogSoftmaxWithLoss(
      (head): Linear(in_features=1000, out_features=23, bias=False)
      (tail): ModuleList(
        (0): Sequential(
          (0): Linear(in_features=1000, out_features=250, bias=False)
          (1): Linear(in_features=250, out_features=80, bias=False)
        )
        (1): Sequential(
          (0): Linear(in_features=1000, out_features=62, bias=False)
          (1): Linear(in_features=62, out_features=50, bias=False)
        )
        (2): Sequential(
          (0): Linear(in_features=1000, out_features=15, bias=False)
          (1): Linear(in_features=15, out_features=50, bias=False)
        )
      )
    )
    >>> batch = 15
    >>> input = torch.randn(batch, in_features)
    >>> target = torch.randint(n_classes, (batch,), dtype=torch.long)
    >>> # get the log probabilities of target given input, and mean negative log probability loss
    >>> adaptive_softmax(input, target) 
    ASMoutput(output=tensor([-6.8270, -7.9465, -7.3479, -6.8511, -7.5613, -7.1154, -2.9478, -6.9885,
            -7.7484, -7.9102, -7.1660, -8.2843, -7.7903, -8.4459, -7.2371],
           grad_fn=<ThAddBackward>), loss=tensor(7.2112, grad_fn=<MeanBackward1>))
    >>> # get the log probabilities of all targets given input as a (batch x n_classes) tensor
    >>> adaptive_softmax.log_prob(input)  
    tensor([[-2.6533, -3.3957, -2.7069,  ..., -6.4749, -5.8867, -6.0611],
            [-3.4209, -3.2695, -2.9728,  ..., -7.6664, -7.5946, -7.9606],
            [-3.6789, -3.6317, -3.2098,  ..., -7.3722, -6.9006, -7.4314],
            ...,
            [-3.3150, -4.0957, -3.4335,  ..., -7.9572, -8.4603, -8.2080],
            [-3.8726, -3.7905, -4.3262,  ..., -8.0031, -7.8754, -8.7971],
            [-3.6082, -3.1969, -3.2719,  ..., -6.9769, -6.3158, -7.0805]],
           grad_fn=<CopySlices>)
    >>> # predict: get the class that maximizes log probability for each input
    >>> adaptive_softmax.predict(input)  
    tensor([ 8,  6,  6, 16, 14, 16, 16,  9,  4,  7,  5,  7,  8, 14,  3])
  • Add spectral normalization nn.utils.spectral_norm #6929

    >>> # Usage is similar to weight_norm
    >>> convT = nn.ConvTranspose2d(3, 64, kernel_size=3, padding=1)
    >>> # Can specify number of power iterations applied each time, or use default (1)
    >>> convT = nn.utils.spectral_norm(convT, n_power_iterations=2)
    >>>
    >>> # apply to every conv and conv transpose module in a model
    >>> def add_sn(m):
            for name, c in m.named_children():
                m.add_module(name, add_sn(c))
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                return nn.utils.spectral_norm(m)
            else:
                return m
    
    >>> my_model = add_sn(my_model)
  • nn.ModuleDict and nn.ParameterDict containers #8463 (see the sketch after this list)

  • Add nn.init.zeros_ and nn.init.ones_ #7488

  • Add sparse gradient option to pretrained embedding #7492

  • Add max pooling support to nn.EmbeddingBag #5725

  • Depthwise convolution support for MKLDNN #8782

  • Add nn.FeatureAlphaDropout (featurewise Alpha Dropout layer) #9073
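
A minimal sketch of the new containers and the nn.EmbeddingBag max mode (layer names and sizes here are made up for illustration):

    >>> layers = nn.ModuleDict({
    ...     'conv': nn.Conv2d(3, 16, kernel_size=3),
    ...     'act': nn.ReLU(),
    ... })
    >>> layers['conv']
    Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
    >>> scales = nn.ParameterDict({'gamma': nn.Parameter(torch.ones(16))})
    >>> # mode='max' pools each bag by taking the elementwise maximum
    >>> bag = nn.EmbeddingBag(100, 8, mode='max')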

Operators

Distributions

  • Half Cauchy and Half Normal #8411 (see the sketch after this list)
  • Gamma sampling for CUDA tensors #6855
  • Allow vectorized counts in Binomial Distribution #6720
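
A brief sketch of the new distribution features (parameter values are arbitrary):

    >>> from torch.distributions import HalfCauchy, HalfNormal, Binomial
    >>> HalfCauchy(scale=1.0).sample((5,))       # 5 non-negative, heavy-tailed samples
    >>> HalfNormal(scale=2.0).sample((5,))
    >>> # Binomial now accepts a tensor of per-element counts
    >>> counts = torch.tensor([10., 25.])
    >>> Binomial(total_count=counts, probs=torch.tensor([0.5, 0.1])).sample()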

Misc

Performance

  • Accelerate Bernoulli random number generation on CPU #7171
  • Enable cuFFT plan caching (80% speed-up in certain cases) #8344
  • Fix unnecessary copying in bernoulli_ #8682
  • Fix unnecessary copying in broadcast #8222
  • Speed up multidim sum (2x~6x speed-up in certain cases) #8992
  • Vectorize CPU sigmoid (>3x speed-up in most cases) #8612
  • Optimize CPU nn.LeakyReLU and nn.PReLU (2x speed-up) #9206
  • Vectorize softmax and logsoftmax (4.5x speed-up on single core and 1.8x on 10 threads) #7375
  • Speed up nn.init.sparse (10-20x speed-up) #6899

Improvements

Tensor printing

  • Tensor printing now includes requires_grad and grad_fn information #8211 (illustrated below)
  • Improve number formatting in tensor print #7632
  • Fix scale when printing some tensors #7189
  • Speed up printing of large tensors #6876
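
For example (the exact grad_fn names may differ across versions):

    >>> w = torch.ones(3, requires_grad=True)
    >>> w
    tensor([1., 1., 1.], requires_grad=True)
    >>> (w * 2).sum()
    tensor(6., grad_fn=<SumBackward0>)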

Neural Networks

  • NaN is now propagated through many activation functions #8033
  • Add non_blocking option to nn.Module.to #7312
  • Loss modules now allow target to require gradient #8460
  • Add pos_weight argument to nn.BCEWithLogitsLoss #6856
  • Support gradient clipping for parameters on different devices #9302
  • Remove the requirement that input sequences to pad_sequence be sorted #7928
  • stride argument for max_unpool1d, max_unpool2d, max_unpool3d now defaults to kernel_size #7388
  • Allow calling grad mode context managers (e.g., torch.no_grad, torch.enable_grad) as decorators #7737 (example after this list)
  • torch.optim.lr_scheduler._LRScheduler's __getstate__ now includes optimizer info #7757
  • Add support for accepting Tensor as input in clip_grad_* functions #7769
  • Return NaN in max_pool/adaptive_max_pool for NaN inputs #7670
  • nn.EmbeddingBag can now handle empty bags in all modes #7389
  • torch.optim.lr_scheduler.ReduceLROnPlateau is now serializable #7201
  • Allow only tensors of floating point dtype to require gradients #7034 and #7185
  • Allow resetting of BatchNorm running stats and cumulative moving average #5766
  • Set the gradient of LP-Pooling to zero if the sum of all input elements to the power of p is zero #6766
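
An illustrative sketch of two of the items above, under assumed typical usage (the model, shapes, and weights are placeholders):

    >>> # grad mode context managers can now be used as decorators
    >>> @torch.no_grad()
    ... def evaluate(model, x):
    ...     return model(x)
    ...
    >>> # pos_weight up-weights the positive class in BCEWithLogitsLoss
    >>> criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
    >>> loss = criterion(torch.randn(4, 1), torch.empty(4, 1).random_(2))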

Operators

Distributions

  • Always enable grad when calculating lazy_property #7708

Sparse Tensor

  • Add log1p for sparse tensor #8969 (see the sketch after this list)
  • Better support for adding zero-filled sparse tensors #7479
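
A small sketch of sparse log1p (indices and values are arbitrary):

    >>> i = torch.tensor([[0, 1], [2, 0]])          # 2 x nnz indices
    >>> v = torch.tensor([3., 4.])
    >>> s = torch.sparse_coo_tensor(i, v, (2, 3))
    >>> s.log1p()                                   # elementwise log(1 + x) on the stored values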

Data Parallel

  • Allow modules that return scalars in nn.DataParallel #7973
  • Allow nn.parallel.parallel_apply to take in a list/tuple of tensors #8047

Misc

  • torch.Size can now accept PyTorch scalars #5676
  • Move torch.utils.data.dataset.random_split to torch.utils.data.random_split, and torch.utils.data.dataset.Subset to torch.utils.data.Subset #7816
  • Add serialization for torch.device #7713
  • Allow copy.deepcopy of torch.(int/float/...)* dtype objects #7699
  • torch.load can now take a torch.device as the map_location #7339 (example below)
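
A short sketch of two of the items above (the checkpoint path is hypothetical):

    >>> state = torch.load('checkpoint.pt', map_location=torch.device('cpu'))
    >>> import copy
    >>> copy.deepcopy(torch.float32)
    torch.float32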

Bug Fixes

  • Fix nn.BCELoss sometimes returning negative results #8147
  • Fix tensor._indices on scalar sparse tensor giving wrong result #8197
  • Fix backward of tensor.as_strided not working properly when input has overlapping memory #8721
  • Fix x.pow(0) gradient when x contains 0 #8945
  • Fix CUDA torch.svd and torch.eig returning wrong results in certain cases #9082
  • Fix nn.MSELoss having low precision #9287
  • Fix segmentation fault when calling torch.Tensor.grad_fn #9292
  • Fix torch.topk returning wrong results when input isn't contiguous #9441
  • Fix segfault in convolution on CPU with large inputs / dilation #9274
  • Fix avg_pool2d/3d count_include_pad defaulting to False (should be True) #8645
  • Fix nn.EmbeddingBag's max_norm option #7959
  • Fix returning scalar input in Python autograd function #7934
  • Fix THCUNN SpatialDepthwiseConvolution assuming contiguity #7952
  • Fix bug in seeding random module in DataLoader #7886
  • Don't modify variables in-place for torch.einsum #7765
  • Make the return value of the LBFGS step uniform #7586
  • The return value of uniform.cdf() is now clamped to [0, 1] #7538
  • Fix advanced indexing with negative indices #7345
  • CUDAGenerator will not initialize on the current device anymore, which will avoid unnecessary memory allocation on GPU:0 #7392
  • Fix tensor.type(dtype) not preserving device #7474
  • Make batch sampler return the same results whether used alone or in a DataLoader with num_workers > 0 #7265
  • Fix broadcasting error in LogNormal, TransformedDistribution #7269
  • Fix torch.max and torch.min on CUDA in presence of NaN #7052
  • Fix torch.tensor device-type calculation when used with CUDA #6995
  • Fix a missing '=' in the nn.LPPoolNd repr function #9629

Documentation