Skip to content

Version 1.0.0

Latest
Compare
Choose a tag to compare
@levi131 levi131 released this 11 Mar 03:18
· 4 commits to master since this release
2491c5b

Version 1.0.0

OneFlow v1.0.0 release note

OneFlow v1.0.0 came out, welcome to install the new version for a better experience.

  • Highlights
  • New Features
  • Improvements
  • Changes and Fixes
  • Performance

Highlights

This version update includes 447 commits and the following highlights:

  • Released a new interface compile_from_torch. This interface, while sharing the parameter memory, converts a PyTorch Module instance into a OneFlow Module instance. It supports direct Eager execution or conversion into a static graph nn.Graph, further accelerating the process using MLIR compilation. This interface is rapidly evolving and currently supports dynamic shape compilation, validated across typical models such as ResNet50, Faster RCNN, and Stable Diffusion.

  • Made a series of optimizations and refactoring to Eager execution runtime, including unification of system memory pools, integration with CUDA native interfaces, optimization of instruction scheduling mechanisms, introduction of an instruction fusion mechanism, optimization of Autograd graph construction speed, optimization of Op inference process, and decoupling of Instruction and Stream, etc.

  • The static graph distributed physical execution plan supports separate compilation functionality, allowing each process to independently compile its required execution plan, eliminating linear growth of compilation time with GPU scale.

  • Addition of a series of functional automatic differentiation related interface supports, including jvp, vjp, hvp, vhp, jacobian, and hessian.

  • Addition of the Insight module, supporting visualization of kernel invocation, execution time, speed, and other related information within the embedded point intervals.

  • Updates to LiBai (the open-source toolbox for large-scale model training), with native support for fine-tuning and distributed inference of Llama2 and ChatGLM2, supporting full finetune, adapter finetune, lora finetune. lm-eval-harness can be used for language model evaluation and validation.

  • Upgrade of OneFlow Serving functionality, adding support for OneFlow Python backend and OneFlow Lite backend, in addition to the existing support for OneFlow Cpp backend.

New Features

1. compile_from_torch

The compile_from_torch interface, while sharing the parameter memory, converts a PyTorch Module instance into a OneFlow Module instance. It supports direct Eager execution or conversion into a static graph nn.Graph, further accelerating the process using MLIR compilation. (#10404, #10408, #9984, #9754)

Interface Signature and Parameter Introduction:

compile_from_torch(torch_module: torch.nn.Module, \*, use_graph=True, options={})
* torch_module: The Torch Module instance to be converted.
* use_graph: Indicates whether to transform into a static graph nn.Graph and utilize MLIR compilation acceleration. The default is True.
* options:
  * size: When using static graph nn.Graph, the hash value of the graph corresponding to the input shape will be calculated and cached. Size indicates the maximum capacity of the static graph cache. When exceeding the maximum capacity, the graph will be cleared based on the LRU strategy. The default value is 9.
  * dynamic: For the first input with a dynamic shape, the graph will be fully compiled. For subsequent inputs with different shapes, if dynamic is True, shared graph will be used for compilation acceleration; if dynamic is False, the compilation will be performed each time. The default is True.
  * debug: Debug mode and log level settings. -1 disables debug mode, 0 outputs warnings and static graph construction information, 1 additionally outputs graph construction information for each sub-module, 2 additionally outputs progress for each operator, 3 provides more detailed operator information. The default value is -1.

Example of Usage:

import torch
from torchvision import models
import oneflow
from oneflow.framework.infer_compiler import compile_from_torch
DEVICE = torch.device("cuda")
WEIGHT = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=WEIGHT).to(DEVICE)
compile_model = compile_from_torch(model, options={"dynamic": True})

2. Separated Compilation

The static graph distributed physical execution plan supports separate compilation , allowing each process to independently compile its required execution plan, thereby preventing linear growth of compilation time with GPU scale. The separate compilation feature supports 3D hybrid parallel (data parallelism + model parallelism + pipeline parallelism) scenarios and can be used together with LiBai (the open-source large-scale model training toolbox). To enable this feature, use the command: export ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE=1. (#9920, #10140, #10141, #10124, #10102)

Below are the test results on a 128-card A100-PCIE-40GB device with LiBai on the GPT2 model:

Parallelism Separated Compilation Enabled Execution Plan Compilation Time
Data Parallelism (DP128 MP1 PP1) No Over 20 minutes
Data Parallelism (DP128 MP1 PP1) Yes 108.21 s
3D Parallelism (DP4 MP4 PP8) No 445.16 s
3D Parallelism (DP4 MP4 PP8) Yes 82.88 s

3. Functional Automatic Differentiation Interfaces

A series of functional automatic differentiation-related interfaces have been introduced, including jvp, vjp, hvp, vhp, jacobian, and hessian. (#10412, #10428)

Example of Usage:

import oneflow as flow

# jacobian example
def exp_reducer(x):
    return x.exp().sum(dim=1)

input = flow.rand(2, 2)
jac_rslt = flow.autograd.functional.jacobian(exp_reducer, input)

# vhp example
def pow_reducer(x):
    return x.pow(3).sum()

input = flow.rand(2, 2)
v = flow.ones(2, 2)
vhp_rslt = flow.autograd.functional.vhp(pow_reducer, input, v)

4. Insight Module

Introduced a new Insight module, enabling the visualization of kernel invocation, execution time, speed, and other related information within the embedded point intervals. (#10370)

Usage:

  • Step 1: Set embedded point intervals in the code using the OneFlow Profiler module.
  • Step 2: Run the code and use NVIDIA Nsight Systems to generate a .sqlite file.
  • Step 3: Use the OneFlow Insight module to generate a .json file.
  • Step 4: Open the .json file in the browser at chrome://tracing/ or edge://tracing/ to obtain the visualization interface.

For more detailed information, please refer to: https://github.com/Oneflow-Inc/oneflow/tree/master/python/oneflow/utils/insight#usage

5. LiBai Version Update

  • LiBai (the open-source toolbox for large-scale model training) has been upgraded to version v0.3.0. It now natively supports finetuning and distributed inference of large language models Llama2 and ChatGLM2. It supports full full finetune, adapter finetune, lora finetune. lm-eval-harness can be used for language model evaluation and validation.

  • The distributed training and inference support for ChatGLM and Llama2 are as follows:

Models 2D (tp+pp) Inference 3D Parallel Training
ChatGLM
Llama2

Example of Usage:

# full finetune
bash tools/train.sh projects/Llama/train_net.py projects/Llama/configs/llama_sft.py 8
# adapter finetune
bash tools/train.sh projects/Llama/adapter/train_net.py projects/Llama/adapter/adapter_sft.py 8
# inference
bash tools/infer.sh projects/Llama/pipeline.py 8
# eval
python projects/Llama/utils/eval_adapter.py

6. Other New Features

  • Added FFT-related operators. (#10027)

  • Added zeta operator. (#10189)

  • Added tril_ operator. (#9996)

  • Added clone operator. (#9800)

  • Added frac and frac_ operator. (#9979)

  • Added exp2 operator. (#9958)

  • Added rrelu operator. (#9736)

  • Added lgamma backward operator. (#10177)

  • Added digamma operator. (#10066)

  • Added trigamma operator. (#10117)

  • Added bitwise_not operator. (#9859)

  • Added squared_relu operator. (#10316)

  • Added skip_rms_norm operator. (#10036)

  • Added multi_tensor_amp_grad_scaler related operators. (#10071)

  • Added bitwise_and, bitwise_or, bitwise_xor operators. (#9842)

  • Added fused_attention_concat_past_key_value operator. (#9963)

  • Added fused_multi_head_attention_inference_v2 operator. (#9933)

  • Added fused_codegeex_qkv_reshape operator. (#9927)

  • Added fused_apply_rotary_emb operator. (#9914)

  • Added skip_layer_norm operator. (#9906)

  • Added groupwise_dequantize, fused_linear_with_groupwise_quantized_weight operators. (#9900)

  • Added fused_scale_mask_bias_softmax, fused_scale_mask_bias_softmax_grad operators. (#9867)

  • Added depend operator for describing dependency relationships in the computation graph. (#9807)

  • Added operators for handling complex data types: real, imag, conj, conj_physical. (#10034, #10281)

  • Added CPU support for the nms operator. (#10225)

  • Added support for the cast operator to convert bool to int16 data type. (#10211)

  • Added support for the arange operator for the fp16 data type. (#10019)

  • Added support for the adaptive_avg_pool operator for the fp16 data type. (#10004)

  • Added support for the nonzero operator for the fp16 data type. (#9826)

  • Added support for the exponential operator for the half data type. (#10005)

  • Added support for the arg_sort and top_k operators for the half data type. (#10000)

  • Added support for some basic operators like add, sub, mul, mm, sqrt, div for complex data types. (#10269, #10136, #10284, #10049)

  • Added support for basic binary operators for discontinuous memory input tensors. (#9986)

  • Added a virtual jit interface to support mocking of torch for user code that imports but does not actually use the interface. (#10395)

  • Added the mem_get_info interface to return overall and free memory information for a specified CUDA device. (#10398)

  • Added the tensor.new interface. (#9881)

  • Added the tensor.is_cpu interface. (#10172)

  • Added the tensor.is_view interface. (#10101)

  • Added the tensor.data_ptr interface. (#10111, #10139)

  • Added the tensor.baddbmm interface. (#9918)

  • Added interfaces like special.erf, special.erfc, etc. (#9982)

  • Added the layout and frombuffer interfaces. (#10171)

  • Added prune-related interfaces. (#9730)

  • Added the utils.model_zoo interface. (#10183)

  • Added the get_rng_state and get_rng_state_all interfaces. (#9760)

  • Added the set_rng_state and set_rng_state_all interfaces. (#10250)

  • Added support for the float16 data type. (#9697)

  • Added support for the char and short data types. (#10086)

  • Added support for the complex64 and complex128 data types. (#9987)

  • Integrated Transform Dialect into the MLIR codegen process. (#10224, #10227)

  • Added code generation support for the matmul operator. 。(#10283)

  • Added code generation support for the softmax operator. (#10263, #10272)

  • Added code generation support for the transform.oneflow.apply_patterns operator. (#10255)

  • Added support for byte attributes in the MLIR codegen process. (#10276)

  • Added extra_libs functionality to the mock_torch module, enabling flowvision to mimic torchvision's functionality. (#10223)

  • Added lazy parameter to the mock_torch module, allowing non-existent interfaces to return a fake object without immediate errors. (#9876)

  • Added skip_init functionality and introduced meta device. (#10008)

  • Introduced the HostMemoryInput mechanism, allowing an operator's specific input to be defined as HostMemoryInput type for accessing data within the kernel's host function body. (#9928)

  • Added fusion mechanism for nccl logical operations to reduce excessive synchronization overhead in scenarios like ZERO, where too many fragmented nccl calls lead to significant training speed reduction. (#9879)

  • Introduced a mechanism for re-computation of tensor operations. (#9861)

  • Added support for backward_hook, register_full_backward_hook, and register_state_dict_pre_hook. (#9837, #9710)

  • Added content related to the stochastic weight averaging algorithm to the optimizers module. (#9781)

  • Added graph-level flattening algorithm. (#9718, #9748)

  • Added DelayVariableOpExecutionPass optimization pass for the computation graph. (#9745)

  • Added MulCastPattern operator fusion rule. (#9715)

  • Added the environment variable ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT to control whether to automatically place global tensors used by operators through the to_global operation on the largest rank. (#10073)

  • Added the environment variable ONEFLOW_EAGER_NCCL_USE_COMPUTE_STREAM to control whether nccl and regular computations in eager mode are on the same stream. The default value is false. (#10230)

  • Added the environment variable VLOG_REMAT to handle dynamic graph recomputation logs and interface with ComputeComplexityFn to estimate op computation time. (#10212)

  • Added the environment variable ENABLE_ACTOR_DEBUG_LOG to print detailed logs of actor message sending, receiving, and execution on the current rank. (#10081)

  • Added the environment variable ONEFLOW_RUN_GRAPH_BY_VM to control whether to use VM to run static graph nn.Graph. (#9884)

  • Added the environment variable ONEFLOW_DISABLE_MOCK_TORCH to control whether to disable the mock_torch functionality. (#9805)

  • Added the environment variable ONEFLOW_VM_MULTI_THREAD to control the number of threads used in the VM. (#9698)

  • Added support for the second-order optimizer lbfgs. (#10265)

Improvements

1. Eager Runtime Optimization and Refactoring

A series of optimizations and refactoring has been implemented for the Eager runtime, including:

  • Unified system memory pool to manage memory resources across all allocators on the same device. (#8591)

  • Integration with CUDA native interfaces to accelerate kernel launches.(#8571)

  • Optimization of instruction scheduling mechanism to reduce system overhead.(#8796)

  • Introduction of instruction fusion mechanism to accelerate instruction dispatch. (#7399)

  • Speed improvement in Autograd graph construction. (#8606)

  • Optimization of op deduction process to accelerate kernel execution. (#8672, #8619, #8662)

  • Consolidation of redundant concepts within the eager runtime, decoupling Instruction and Stream. (#8583, #8590, #7607)

Users can configure the Eager runtime using various environment variables:

Environment Variable Meaning Default Value
ONEFLOW_VM_COMPUTE_ON_WORKER_THREAD Whether to perform computation on worker threads true
ONEFLOW_VM_MULTI_THREAD Whether to use multi-threaded collaboration for Eager computation true
ONEFLOW_VM_ENABLE_STREAM_WAIT Whether to use stream_wait mechanism for dependencies between multiple streams true
ONEFLOW_VM_ENABLE_SCHEDULE_YIELD Whether to use yield mechanism to reduce scheduler thread's busy wait true
ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE Whether to cache operator output metadata during computation true
ONEFLOW_VM_WORKER_THREAD_LIMIT Number of worker threads 16
ONEFLOW_VM_PENDING_HANDLE_WINDOW_SIZE Maximum size for fusing vm instructions 10
ONEFLOW_VM_BLOCKING_DEBUG_INSTRUCTIONS_DISPLAY_LIMIT Number of unprocessed instructions to be printed when vm execution times out 1000

2. Upgrade of OneFlow Serving Features

OneFlow Serving features have been upgraded to support additional backends, including OneFlow Python backend and OneFlow Lite backend, in addition to the existing support for the OneFlow Cpp backend.

  • The OneFlow Cpp backend enables deployment in a Python-independent environment to achieve the highest performance.
  • The OneFlow Lite backend enables deployment on edge devices.
  • The OneFlow Python backend facilitates the deployment of complex models with minimal migration cost.

For usage instructions, refer to: https://github.com/Oneflow-Inc/serving/blob/main/README.md

3. Other Functionality Improvements

  • Optimized certain code implementations to accommodate CUDA 12.x. (#10367)

  • Optimized the glu operator implementation to support bias-less inputs.(#9874)

  • Optimized pooling operator implementation to support the channels_last parameter. (#10242)

  • Optimized the flip operator implementation to address memory access inefficiencies when dim = -1. (#10310)

  • Optimized the bincount operator implementation for accelerated performance. (#10308)

  • Optimized the index_add operator implementation by dispatching varied logic based on index length to enhance performance for smaller indices.(#9751)

  • Optimized the topk operator implementation to boost performance when batch size equals 1. (#10009)

  • Optimized implementations of operators such as conv and arange to facilitate CUDA graph usage. (#9761)

  • Optimized the upsample operator implementation to include input/output size validation.(#9737)

  • Optimized the grouped_matmul_bias operator implementation by introducing tensor parallelism sbp derivation rules. (#9934)

  • Optimized the reshape operator implementation with added nd sbp derivation rules. (#9858)

  • Optimized error messages and completed test cases for mask_fill and in_top_k operators. (#10062)

  • Optimized the higher-order differentiation rules for the tanh operator to optimize performance under third-order differentiation. (#10188, #10237)

  • Optimized conv interface implementation to support device and dtype parameters. (#10228)

  • Optimized conv interface implementation to automatically expand input dimensions.(#9721)

  • Optimized sum interface implementation to accommodate dtype parameters.(#10204)

  • Optimized softmax interface implementation to support dtype parameters. (#10069)

  • Optimized maxpool interface implementation to support 3D input tensors. (#10110)

  • Optimized ctc_loss interface implementation parameters with PyTorch interface. (#9887)

  • Optimized copy interface implementation to support scenarios where input and output have different devices and dtypes. (#9888)

  • Optimized grad interface implementation to support the allow_unused parameter.(#10251)

  • Optimized load interface implementation to provide more user-friendly error messages.(#10138)

  • Optimized fused_matmul_bias operator and interface implementation to support alpha and beta parameters. (#10015)

  • Optimized normal operator and interface implementation to align behavior with PyTorch. (#10185)

  • Optimized fused attention operator and interface implementation to allow None for pasti_key and past_value. (#9977)

  • Optimized fused_attention operator and interface implementation to add support for variable sequence lengths. (#9991)

  • Optimized fused_multi_head_attention_inference operator and interface implementation to include attn_bias parameter. (#9853)

  • Optimized bn-related functor implementation. Merging bn_add_relu and bn_relu operations to expedite inference. (#10239)

  • Optimized MLIR CodeGen-based processes and upgraded LLVM version to 16.0.0. (#9985)

  • Optimized MLIR codegen-based processes by adding AppendOneFlowStream, MgpuToOneFlowStream, and CastOneFlowInputToSignlessPass passes. (#10149, #10151, #10099)

  • Optimized MLIR codegen-based processes by linking LibDevice to support NVVM IR conversion to cubin. (#10200)

  • Optimized MLIR codegen-based processes by utilizing tmpbuffer as MemPool in MLIR. (#10159)

  • Optimized MLIR codegen-based processes by enabling bufferizable operator dispatch. (#9787)

  • Optimized MLIR codegen-based processes to expedite ofmempool and related processes. (#10152, #10168, #10184, #10239)

  • Optimized stacktrace call stack information.(#9912, #9937, #10260, #10161)

  • Optimized random number generator implementation by adding caching to avoid regeneration with each call. (#10387)

  • Optimized graph load functionality to support loading the graph onto a new device.(#10335)

  • Optimized dummy array initialization implementation using fold expressions. (#10271)

  • Optimized MemoryFormat class organization, exposed to Python layer via cpython to support changing tensor's MemoryFormat using Tensor.to interface. (#10181)

  • Optimized implementations of steam, device, and vm to support more device types. (#10166)

  • Optimized error messages for MapAt, adding printing of key values.(#10090)

  • Optimized OOM error messages to differentiate CUDA and CPU devices and display size. (#9938)

  • Optimized error messages for CHECK_XX_OR_RETURN macros. (#9921)

  • Optimized error messages for graph-related issues. (#9821)

  • Optimized error messages for convolution operator-related issues. (#9707)

  • Optimized model initialization to minimize additional overhead. (#10088)

  • Optimized thread manager implementation to accommodate three usage scenarios: unrestricted threads, master as a thread, and n threads. (#10060)

  • Optimized numpy array release mechanism to release in the main thread to reduce time-consuming GIL requests. (#10050)

  • Optimized graph save runtime_state_dict implementation to enhance performance and address related issues. (#10016)

  • Optimized parsing of different calling methods for interfaces like Tensor.foo(*args) using a unified PyParseArgs function. (#9983)

  • Optimized the implementation of the ArgsTree class to support arbitrary output types and conducted file location migration. (#9846)

  • Optimized memory allocation mechanism to achieve ordered allocation based on streams. (#9818)

Changes and Fixes

1. Functional Changes

  • Removed deallocate context. (#10143)

  • Removed debug compilation mode in graph compilation. (#10145)

  • Removed unused logic for MemChain merge. (#10097)

  • Removed default settings for some unused distributed environment variables. (#9803)

  • Refactored collective boxing implementation under lazy mode. (#10098)

  • Refactored registration of EagerCclS2S.(#10100)

  • Refactored implementation of collective_boxing_executor_backend. (#10082)

  • Refactored implementation of running global nn.graph using VM. (#10048)

  • Refactored implementation of local to global related interfaces.(#9870)

  • Refactored operator dispatch dialect implementation in MLIR codegen process. (#9693)

  • Refactored implementation of random generator and distribution kernels. (#9691)

  • Refactored implementation of fast_atomic_add operator. (#9680)

  • Refactored error check related macros in glog. (#10176)

  • Refactored implementation of random generator. (#10025)

  • Refactored implementation of some elementwise primitive operations. (#9857)

  • Refactored code related to device descriptions. (#9791)

  • Refactored implementation of ParseDeviceString and ParseDeviceNameConf. (#9833)

  • Refactored implementation of ActorMsg related functionalities, introducing IBVerbsActorMsgWrapper wrapper to reduce the size of ActorMsg. (#9762)

  • Refactored implementation of save and load interfaces, migrating the method of saving graphs to the _save_graph function, adding some _open* helper classes to differentiate between paths and memory, enabling saving weights to BytesIO in save, and supporting file streaming in load. (#10021)

  • Refactored implementation of some tensor-related interfaces, migrating code from Python layer to C++ layer. (#9990, #9964)

  • Upgraded PyBind version used in the project to 2.11.1. (#10391)

2. Bug Fixes

  • Fixed default dynamic linking settings in CMake files to avoid LLVM15 linking errors. (#10373, #10131)

  • Fixed cast-related bugs in MLIR codegen. (#10105)

  • Fixed logic handling for cpg attr in Module._apply function. (#10343)

  • Fixed inheritance issue for DummyModule when attr is mro_entries. (#9976)

  • Fixed size checking issue for _handle_size_arg in full op. (#9975)

  • Fixed residual environment variables after launching mock via command line, causing subsequent API mock parameter errors. (#9970)

  • Fixed inability to exit when two processes encounter exceptions. (#10054)

  • Fixed bug in grouped quantization sbp derivation. (#10132)

  • Fixed kMaxInputCount check issue in GroupedMatmulFunctor. (#10322)

  • Fixed 0-size tensor broadcast issue.(#10186)

  • Fixed issue where double type attr was not updated when using shared_graph. (#10279)

  • Fixed data type error in GetItemInScalarTensor. (#10226)

  • Fixed gradient issue in GroupNorm, calling GroupNormParamGrad only when gamma and beta gradients are required. (#10045)

  • Fixed error when reading tensors with partial ranks in global mode. (#10056)

  • Fixed control boundary issues in checkpointing under PP, affecting task graph construction under separate compilation. (#10057)

  • Fixed bug when using 3D parallelism and enabling activation checkpointing simultaneously. (#10031)

  • Fixed adaptation bug of AutoMixedPrecision pass on non-CUDA devices and bug related to device combinations in LayerNorm Module. (#10026)

  • Fixed default value setting issue for reduce parameter in scatter operator. (#10002)

  • Fixed incomplete disable of some Torch variables in mock.disable, causing lingering references in other globals. (#9989)

  • Fixed destructor issue in vm::TensorStorage. (#9962)

  • Fixed offload issue where small tensors were not released from CUDA memory.(#9974)

  • Fixed occasional segmentation fault in Python stack getter due to thread unsafety.(#9955)

  • Fixed element lookup issue in set under separate compilation scenario. (#9952)

  • Aligned qkv and output_layout in fused_multi_head_attention operator. (#9950)

  • Fixed inconsistency in seed behavior of random series operators between graph and checkpointing. (#9941)

  • Fixed parameter reload failure issue in Eager mode. (#9935)

  • Fixed infinite loop issue in specific cases of mock torch lazy functionality. (#9926)

  • Fixed issue where code in stft_kernel.cu file was not compiled by default. (#9922)

  • Fixed deadlock and memory allocation errors caused by invalid topological order due to incomplete TaskGraph under separate compilation in order_in_graph. (#9909 )

  • Fixed xrt compilation issue where fmt could not be found. (#9894)

  • Fixed imbalance in GPU memory allocation among processes during local to global process where sbp is B. (#9852)

  • Aligned OneFlow and PyTorch behaviors related to the third parameter of CTCLoss. (#9845)

  • Fixed initialization issues related to thread_global_id and rank_group_scope. (#9841)

  • Fixed inplace handling errors in dropout operator implementation. (#9808)

  • Fixed errors in loading non-tensor objects saved by PyTorch in the load function. (#9804)

  • Fixed conflicts between contiguous memory and GPU memory allocation strategies. (#9786)

  • Fixed memory allocation issues in EagerBlobObject::ByteSizeOfBlobBody when considering non-contiguous cases. (#9782)

  • Fixed dtype inference errors in fill_ operator during autocast. (#9776)

  • Fixed sbp derivation rule issues in fused_glu operator. (#10108)

  • Fixed issues related to calling nn.Graph.__map_io. (#10084)

  • Fixed inconsistency between set_grad_mode interface and PyTorch behavior. (#10059)

  • Fixed an issue related to the map_location parameter in the load interface and added support for passing lambda functions. (#10052)

  • Fixed stride inference errors after unsqueeze operation in view mode. (#9775)

  • Fixed problems in conv op with unbatched input and bias, and added support for unbatched input in deconv op. (#9740)

  • Fixed logic errors in trunc_normal_ implementation. (#9711)

  • Fixed default value issue in dim parameter of topk operator. (#9703)

  • Fixed issues where placement of some networks was incorrectly set to CPU during static graph printing. (#9770)

  • Fixed conflict between include paths of trt_flash_attention and native flash attention. (#9750)

  • Fixed segmentation fault caused by is_shutting_down and gil in stack getter. (#9681)

  • Fixed issues related to the separate compilation feature found in distributed unit testing.(#9749)

  • Fixed memory handling issues in flatten algorithm implementation. (#9746)

  • Fixed a deadlock issue in the execution flow. (#9738)

  • Fixed errors in isinstance check for DummyModule. (#10207)

  • Corrected behavior where default size was erroneously overridden when introducing llvm::SmallVector. (#9932)

  • Fixed errors in calculating memory size of non-contiguous memory tensors. (#9819)

  • Fixed issues with calling CHECK_JUST in the TensorStorage destructor function. (#9752)

Performance

1. OneFlow compile_from_torch VS PyTorch compile

Compile and execute the backbone parts of ResNet50 and Faster RCNN models using OneFlow compile_from_torch and PyTorch compile interfaces to test the compilation time with inputs of different shapes. The results are shown in the table below:

Model input shape PyTorch compile OneFlow compile_from_torch dynamic test timing
ResNet50 (1, 3, 512, 512) 21.328 s 3.205 s False initial compilation and execution
ResNet50 (2, 3, 896, 512) 14.167 s 1.523 s False continuous compilation and execution
ResNet50 (2, 3, 512, 896) 13.364 s 1.402 s False continuous compilation and execution
ResNet50 (3, 3, 896, 896) 15.056 s 1.539 s False continuous compilation and execution
ResNet50 (2, 3, 1024, 896) 14.167 s 1.500 s False continuous compilation and execution
ResNet50 (2, 3, 896, 1024) 12.891 s 1.494 s False continuous compilation and execution
ResNet50 (6, 3, 1024, 1024) 14.859 s 1.872 s False continuous compilation and execution
ResNet50 (1, 3, 512, 512) 170.446 s 3.143 s True initial compilation and execution
ResNet50 (2, 3, 896, 512) 185.672 s 0.851 s True continuous compilation and execution
ResNet50 (2, 3, 512, 896) 0.089 s 0.836 s True continuous compilation and execution
ResNet50 (3, 3, 896, 896) 0.084 s 0.980 s True continuous compilation and execution
ResNet50 (2, 3, 1024, 896) 0.077 s 0.942 s True continuous compilation and execution
ResNet50 (2, 3, 896, 1024) 0.080 s 0.931 s True continuous compilation and execution
ResNet50 (6, 3, 1024, 1024) 0.084 s 1.406 s True continuous compilation and execution
Faster RCNN (1, 3, 512, 512) 18.224 s 5.483 s False initial compilation and execution
Faster RCNN (2, 3, 896, 512) 9.200 s 3.011 s False continuous compilation and execution
Faster RCNN (2, 3, 512, 896) 9.331 s 3.025 s False continuous compilation and execution
Faster RCNN (3, 3, 896, 896) 9.301 s 2.854 s False continuous compilation and execution
Faster RCNN (2, 3, 1024, 896) 9.290 s 2.805 s False continuous compilation and execution
Faster RCNN (2, 3, 896, 1024) 9.123 s 2.851 s False continuous compilation and execution
Faster RCNN (6, 3, 1024, 1024) 9.377 s 3.180 s False continuous compilation and execution
Faster RCNN (1, 3, 512, 512) 25.444 s 5.430 s True initial compilation and execution
Faster RCNN (2, 3, 896, 512) 25.381 s 1.899 s True continuous compilation and execution
Faster RCNN (2, 3, 512, 896) 0.116 s 1.886 s True continuous compilation and execution
Faster RCNN (3, 3, 896, 896) 1.982 s 1.793 s True continuous compilation and execution
Faster RCNN (2, 3, 1024, 896) 0.114 s 1.803 s True continuous compilation and execution
Faster RCNN (2, 3, 896, 1024) 0.111 s 1.778 s True continuous compilation and execution
Faster RCNN (6, 3, 1024, 1024) 0.143 s 2.110 s True continuous compilation and execution

Using the OneFlow compile_from_torch and PyTorch compile interfaces, the unet section of the Stable Diffusion model was compiled and executed to test the compilation time and execution time with outputs of different shapes. The results are presented in the table below:

Model Output shape PyTorch compile OneFlow compile_from_torch dynamic test timing
Stable Diffusion (2, 512, 512) 103.701 s 63.670 s False initial compilation and execution
Stable Diffusion (1, 512, 768) 95.137 s 53.864 s False continuous compilation and execution
Stable Diffusion (2, 768, 512) 90.259 s 55.271 s False continuous compilation and execution
Stable Diffusion (1, 768, 768) 90.196 s 51.590 s False continuous compilation and execution
Stable Diffusion (2, 512, 512) 275.660 s 57.117 s True initial compilation and execution
Stable Diffusion (1, 512, 768) 345.774 s 43.752 s True continuous compilation and execution
Stable Diffusion (2, 768, 512) 349.835 s 47.653 s True continuous compilation and execution
Stable Diffusion (1, 768, 768) 7.224 s 45.720 s True continuous compilation and execution
Stable Diffusion (2, 512, 512) 4.088 s 2.831 s False subsequent execution
Stable Diffusion (1, 512, 768) 3.296 s 2.325 s False subsequent execution
Stable Diffusion (2, 768, 512) 5.594 s 5.157 s False subsequent execution
Stable Diffusion (1, 768, 768) 4.713 s 3.557 s False subsequent execution
Stable Diffusion (2, 512, 512) 4.448 s 2.801 s True subsequent execution
Stable Diffusion (1, 512, 768) 3.201 s 2.314 s True subsequent execution
Stable Diffusion (2, 768, 512) 6.093 s 4.166 s True subsequent execution
Stable Diffusion (1, 768, 768) 4.920 s 3.557 s True subsequent execution

Conclusion: The OneFlow compile_from_torch interface generally has shorter compilation times compared to the PyTorch compile interface. Additionally, benefiting from the exceptional operator optimizations in the OneFlow framework, there is superior execution performance on the Stable Diffusion model.

Note: The tests were conducted with GPU 3090, PyTorch v2.1.2 and CUDA 12.2.

2. OneFlow Eager vs PyTorch Eager

Model GPU model number of GPUs macro batch PyTorch performance(iter/s) OneFlow performance(iter/s) speedup ratio
ResNet50 3090 1 1 31.37 38.81 23.72%
ResNet50 3090 1 2 32.06 48.45 51.12%
ResNet50 3090 2 1 31.10 33.46 7.59%
ResNet50 3090 2 2 31.76 34.83 9.67%
ResNet50 A100 1 1 24.60 46.64 89.59%
ResNet50 A100 1 2 25.06 49.88 99.04%
ResNet50 A100 2 1 25.28 39.18 54.98%
ResNet50 A100 2 2 24.09 32.84 36.32%
Bert 3090 1 1 8.93 10.41 16.57%
Bert 3090 1 2 13.11 14.31 9.15%
Bert 3090 2 1 6.94 8.27 19.16%
Bert 3090 2 2 12.19 15.58 27.81%
Bert A100 1 1 10.45 12.72 21.72%
Bert A100 1 2 20.24 21.57 6.57%
Bert A100 2 1 12.63 16.09 27.39%
Bert A100 2 2 24.86 29.84 20.03%

Conclusion: Compared to PyTorch Eager, using OneFlow Eager shows significant performance advantages in small batch scenarios for both ResNet50 and BERT models.

Note: The tests were conducted using PyTorch v2.1.0 and CUDA 12.1.

Version 1.0.0

OneFlow v1.0.0 release note

OneFlow 发布 v1.0.0 版本, 欢迎大家安装使用。

  • 重点内容
  • 新特性
  • 功能改进
  • 改动与修复
  • 性能

重点内容

本次版本更新包含 447 个 commits 和如下重点内容:

  • 发布新接口 compile_from_torch。该接口在共享参数显存的情况下,将 PyTorch 的 Module 实例转化成 OneFlow 的 Module 实例,支持直接 Eager 运行或者转化为静态图 nn.Graph 并进一步使用 MLIR 编译加速。该接口仍在快速演进中,目前支持了动态形状编译并在ResNet50、Faster RCNN、Stable Diffusion三个典型模型上做了验证。

  • 对 Eager 运行时做了一系列优化与重构,包括统一系统内存池、对接 CUDA 原生接口、优化指令调度机制、引入指令融合机制、优化 Autograd 构图速度、优化 Op 推导过程、解耦 Instruction 与 Stream 等。

  • 静态图分布式物理执行计划支持分离编译功能,每个进程独立编译自己所需的执行计划,使得编译时间不再随 GPU 规模线性增长。

  • 新增一系列函数式自动微分相关接口支持,包括 jvp、vjp、hvp、vhp、jacobian、hessian。

  • 新增 Insight 模块,支持可视化地展示埋点区间内 kernel 调用、执行时间、速度等信息。

  • 大规模模型训练开源工具箱 LiBai 版本更新,原生支持大语言模型 Llama2 和 ChatGLM2 的 finetune 和分布式推理,支持 full finetune、adapter finetune、lora finetune,可使用 lm-eval-harness 对语言模型进行评测验证。

  • OneFlow Serving 功能升级,在原有支持 OneFlow Cpp 后端的基础上,新增支持 OneFlow Python 后端和 OneFlow Lite 后端。

新特性

1、compile_from_torch

compile_from_torch 接口在共享参数显存的情况下,将 PyTorch 的 Module 实例转化成 OneFlow 的 Module 实例,支持直接 Eager 运行或者转化为静态图 nn.Graph 并进一步使用 MLIR 编译加速。(#10404, #10408, #9984, #9754)

接口签名及参数介绍:

compile_from_torch(torch_module: torch.nn.Module, \*, use_graph=True, options={})
* torch_module:需要被转换的 Torch Module 实例。
* use_graph:是否转化为静态图 nn.Graph 并使用 MLIR 编译加速,默认为 True。
* options:
  * size: 使用静态图 nn.Graph 后会根据输入的 shape 计算 hash 值缓存相应的 graph ,size 表示静态图缓存的最大容量,超过最大容量会根据 LRU 策略对 graph 进行清理,默认值为 9。
  * dynamic:对于动态 shape 的输入第一次会完整编译 graph,之后的对于不同 shape 的输入当 dynamic 为 True 时会启用共享图进行编译加速,dynamic 为 False 时每次都会重新进行编译,默认为 True。
  * debug:调试模式和日志级别设置,-1 禁用调试模式,0 输出警告和静态图构建信息,1 额外输出每个子模块的构图信息,2 额外输出每个算子的进度,3 输出更详细的算子信息,默认为 -1。

使用示例:

import torch
from torchvision import models

import oneflow
from oneflow.framework.infer_compiler import compile_from_torch

DEVICE = torch.device("cuda")
WEIGHT = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=WEIGHT).to(DEVICE)
compile_model = compile_from_torch(model, options={"dynamic": True})

2、分离编译

静态图分布式物理执行计划支持分离编译功能,每个进程独立编译自己所需的执行计划,使得编译时间不再随 GPU 规模线性增长。分离编译功能支持 3D 混合并行(数据并行+模型并行+流水并行)场景,可与大规模模型训练开源工具箱 LiBai 一同使用,打开方式为:export ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE=1。(#9920, #10140, #10141, #10124, #10102)

以下是在 128 卡 A100-PCIE-40GB 设备上,配合 LiBai 在 GPT2 模型上的测试结果:

并行方式 是否开启分离编译 执行计划编译时间
数据并行 (DP128 MP1 PP1) 超过 20 minutes
数据并行 (DP128 MP1 PP1) 108.21 s
3D 并行 (DP4 MP4 PP8) 445.16 s
3D 并行 (DP4 MP4 PP8) 82.88 s

3、函数式自动微分接口

新增一系列函数式自动微分相关接口支持,包括 jvp、vjp、hvp、vhp、jacobian、hessian。(#10412, #10428)

使用示例:

import oneflow as flow

# jacobian example
def exp_reducer(x):
    return x.exp().sum(dim=1)

input = flow.rand(2, 2)
jac_rslt = flow.autograd.functional.jacobian(exp_reducer, input)

# vhp example
def pow_reducer(x):
    return x.pow(3).sum()

input = flow.rand(2, 2)
v = flow.ones(2, 2)
vhp_rslt = flow.autograd.functional.vhp(pow_reducer, input, v)

4、Insight模块

新增 Insight 模块,支持可视化地展示埋点区间内 kernel 调用、执行时间、速度等信息。(#10370)

使用方法如下:

  • 步骤一:使用 OneFlow Profiler 模块在代码中设置埋点区间。
  • 步骤二:运行代码并使用 NVIDIA Nsight Systems 生成 sqlite 后缀文件。
  • 步骤三:使用 OneFlow Insight 模块生成 json 文件。
  • 步骤四:在网址 chrome://tracing/ 或 edge://tracing/ 中打开 json 文件得到可视化界面。

更详细的介绍可参考:https://github.com/Oneflow-Inc/oneflow/tree/master/python/oneflow/utils/insight#usage

5、LiBai版本更新

  • 大规模模型训练开源工具箱 LiBai 功能升级,发布新版本 v0.3.0,原生支持大语言模型 Llama2 和 ChatGLM2 的 finetune 和分布式推理,支持 full finetune、adapter finetune、lora finetune,可使用 lm-eval-harness 对语言模型进行评测验证。

  • ChatGLM 和 Llama2 的分布式训练和推理支持情况如下:

Models 2D (tp+pp) Inference 3D Parallel Training
ChatGLM
Llama2

使用示例:

# full finetune
bash tools/train.sh projects/Llama/train_net.py projects/Llama/configs/llama_sft.py 8

# adapter finetune
bash tools/train.sh projects/Llama/adapter/train_net.py projects/Llama/adapter/adapter_sft.py 8

# inference
bash tools/infer.sh projects/Llama/pipeline.py 8

# eval
python projects/Llama/utils/eval_adapter.py

6、其他新特性

  • 新增 FFT 相关算子。(#10027)
  • 新增 zeta 算子。(#10189)
  • 新增 tril_ 算子。(#9996)
  • 新增 clone 算子。(#9800)
  • 新增 frac、frac_ 算子。(#9979)
  • 新增 exp2 算子。(#9958)
  • 新增 rrelu 算子。(#9736)
  • 新增 lgamma 反向算子。(#10177)
  • 新增 digamma 算子。(#10066)
  • 新增 trigamma 算子。(#10117)
  • 新增 bitwise_not 算子。(#9859)
  • 新增 squared_relu 算子。(#10316)
  • 新增 skip_rms_norm 算子。(#10036)
  • 新增 multi_tensor_amp_grad_scaler 相关算子。(#10071)
  • 新增 bitwise_and、bitwise_or、bitwise_xor 算子。(#9842)
  • 新增 fused_attention_concat_past_key_value 算子。(#9963)
  • 新增 fused_multi_head_attention_inference_v2 算子。(#9933)
  • 新增 fused_codegeex_qkv_reshape 算子。(#9927)
  • 新增 fused_apply_rotary_emb 算子。(#9914)
  • 新增 skip_layer_norm 算子。(#9906)
  • 新增 groupwise_dequantize、fused_linear_with_groupwise_quantized_weight 算子。(#9900)
  • 新增 fused_scale_mask_bias_softmax、fused_scale_mask_bias_softmax_grad 算子。(#9867)
  • 新增 depend 算子,用于描述计算图中依赖关系。(#9807)
  • 新增 real, imag, conj, conj_physical 复数数据类型相关算子。(#10034, #10281)
  • 新增 nms 算子 cpu 支持。(#10225)
  • 新增 cast 算子对 bool to int16 数据类型转换支持。(#10211)
  • 新增 arange 算子对 fp16 数据类型的支持。(#10019)
  • 新增 adaptive_avg_pool 算子对 fp16 数据类型的支持。(#10004)
  • 新增 nonzero 算子对 fp16 数据类型的支持。(#9826)
  • 新增 exponential 算子对 half 数据类型的支持。(#10005)
  • 新增 arg_sort、top_k 算子对 half 数据类型的支持。(#10000)
  • 新增 add、sub、mul、mm、sqrt、div 等算子对复数数据类型支持。(#10269, #10136, #10284, #10049)
  • 新增基础 binary 算子对不连续内存输入张量的支持。(#9986)
  • 新增虚拟 jit 接口,支持对 import 而未实际使用该接口的用户代码 mock_torch。(#10395)
  • 新增 mem_get_info 接口,用于返回指定 cuda 设备的总体和空闲内存信息。(#10398)
  • 新增 tensor.new 接口。(#9881)
  • 新增 tensor.is_cpu 接口。(#10172)
  • 新增 tensor.is_view 接口。(#10101)
  • 新增 tensor.data_ptr 接口。(#10111, #10139)
  • 新增 tensor.baddbmm 接口。(#9918)
  • 新增 special.erf、special.erfc 等接口。(#9982)
  • 新增 layout 和 frombuffer 接口。(#10171)
  • 新增 prune 相关接口。(#9730)
  • 新增 utils.model_zoo 接口。(#10183)
  • 新增 get_rng_state 和 get_rng_state_all 接口。(#9760)
  • 新增 set_rng_state 和 set_rng_state_all 接口。(#10250)
  • 新增对 float16 数据类型支持。(#9697)
  • 新增对 char 和 short 数据类型支持。(#10086)
  • 新增对 complex64 和 complex128 数据类型支持。(#9987)
  • 新增 Transform Dialect 到 MLIR codegen 流程中。(#10224, #10227)
  • 新增对 matmul 算子的代码生成支持。(#10283)
  • 新增对 softmax 算子的代码生成支持。(#10263, #10272)
  • 新增对 transform.oneflow.apply_patterns 算子的代码生成支持。(#10255)
  • 新增 MLIR codegen 流程中对 byte attr 支持。(#10276)
  • 新增 extra_libs 功能 到 mock_torch 模块,使其可以实现 flowvision 去模拟 torchvision 的功能。(#10223)
  • 新增 lazy 参数到 mock_torch 模块,对不存在的接口会返回一个假对象而不立即报错。(#9876)
  • 新增 skip_init 功能,并引入 meta device。(#10008)
  • 新增 HostMemoryInput机制,将算子某个输入定义为 HostMemoryInput 类型后可以在 kernel 的 host 函数体内访问数据。(#9928)
  • 新增 nccl 逻辑运算的融合机制,可以降低 ZERO 等场景,过多碎 nccl 导致同步开销太大降低训练速度的问题。(#9879)
  • 新增张量运算的重计算机制。(#9861)
  • 新增 backward_hook、register_full_backward_hook、register_state_dict_pre_hook 支持。(#9837, #9710)
  • 新增 stochastic weight averaging 算法相关内容到 optimizers 模块。(#9781)
  • 新增计算图层面的拉直算法。(#9718, #9748)
  • 新增 DelayVariableOpExecutionPass 计算图优化 pass。(#9745)
  • 新增 MulCastPattern 算子融合规则。(#9715)
  • 新增环境变量 ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT,控制是否自动将算子用到的 global_tensor 通过 to_global 操作放到最大的 rank 上。(#10073)
  • 新增环境变量 ONEFLOW_EAGER_NCCL_USE_COMPUTE_STREAM 用于控制eager 模式下 nccl 和普通的计算是否在同一个stream上,默认值为false。(#10230)
  • 新增环境变量 VLOG_REMAT 处理动态图重计算的日志并对接 ComputeComplexityFn 估计 op 计算时间。(#10212)
  • 新增环境变量 ENABLE_ACTOR_DEBUG_LOG 用于打印当前 rank 上 actor 收发消息、执行的详细日志。(#10081)
  • 新增环境变量 ONEFLOW_RUN_GRAPH_BY_VM 用于控制是否使用 VM 来运行静态图 nn.Graph。(#9884)
  • 新增环境变量 ONEFLOW_DISABLE_MOCK_TORCH 用于控制是否让 mock_torch 功能失效。(#9805)
  • 新增环境变量 ONEFLOW_VM_MULTI_THREAD 用于控制 vm 中使用的线程数。(#9698)
  • 新增二阶优化器 lbfgs 支持。(#10265)

功能改进

1、Eager 运行时优化与重构

对 Eager 运行时做了一系列优化与重构,主要包括:

  • 统一系统内存池,打通同设备下的所有分配器的内存资源。(#8591)
  • 对接 CUDA 原生接口,加速 kernel launch。(#8571)
  • 优化指令调度机制,降低系统负担。(#8796)
  • 引入指令融合机制,加速指令分发。(#7399)
  • 优化 Autograd 构图部分的速度。(#8606)
  • 优化op推导过程,加速kernel执行。(#8672, #8619, #8662)
  • 合并eager运行时中的冗余概念,解耦Instruction与Stream。(#8583, #8590, #7607)

可以通过一些环境变量设定 Eager 运行时行为:

环境变量 意义 默认值
ONEFLOW_VM_COMPUTE_ON_WORKER_THREAD 是否在 worker 线程上完成计算 true
ONEFLOW_VM_MULTI_THREAD 是否使用多线程协同执行 Eager 运算 true
ONEFLOW_VM_ENABLE_STREAM_WAIT 多 stream 间的依赖是否使用 stream_wait 机制 true
ONEFLOW_VM_ENABLE_SCHEDULE_YIELD 是否使用 yield 机制减少 scheduler 线程 busy wait 程度 true
ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE 计算过程中是否缓存算子输出的元信息 true
ONEFLOW_VM_WORKER_THREAD_LIMIT worker 线程的个数 16
ONEFLOW_VM_PENDING_HANDLE_WINDOW_SIZE vm 融合指令的最大 size 10
ONEFLOW_VM_BLOCKING_DEBUG_INSTRUCTIONS_DISPLAY_LIMIT vm 执行超时时打印未处理指令的数量 1000

2、OneFlow Serving功能升级

OneFlow Serving 功能升级,在原有支持 OneFlow Cpp 后端的基础上,新增支持 OneFlow Python 后端和 OneFlow Lite 后端。

  • 使用 OneFlow Cpp 后端可以在脱离 Python 的环境中部署以达到最高的性能。
  • 使用 OneFLow Lite 后端可以实现在端侧设备上的部署。
  • 使用 OneFlow Python 后端可以以极小的迁移代价完成复杂模型的部署。

使用方法参考:https://github.com/Oneflow-Inc/serving/blob/main/README.md

3、其他功能改进

  • 改进部分代码实现以支持 cuda 12.x 版本。(#10367)
  • 改进 glu 算子实现,支持无bias 输入。(#9874)
  • 改进池化算子实现,支持 channels_last 参数 。(#10242)
  • 改进 flip 算子实现,针对 dim = -1 时候访存无法合并的情况进行优化。(#10310)
  • 改进 bincount 算子实现,实现优化加速。(#10308)
  • 改进 index_add 算子实现,根据 index 的长度派发不同的实现逻辑以改善索引比较小的时候的性能。(#9751)
  • 改进 topk 算子实现,优化 batch_size 是1时的性能。(#10009)
  • 改进 conv、arange 等算子实现,支持启用cuda graph。(#9761)
  • 改进 upsample 算子实现,增加对输入/输出大小检查。(#9737)
  • 改进 grouped_matmul_bias 算子实现,增加张量并行的 sbp 推导规则。(#9934)
  • 改进 reshape 算子实现,增加对 nd sbp 推导规则。(#9858)
  • 改进 mask_fill 和 in_top_k 算子的报错信息并完善测试样例。(#10062)
  • 改进 tanh 算子的高阶微分规则,优化三阶微分下的性能。(#10188, #10237)
  • 改进 conv 接口实现,支持 device 和 dtype 参数。(#10228)
  • 改进 conv 接口实现,支持对输入自动扩展维度。(#9721)
  • 改进 sum 接口实现,支持 dtype 参数。(#10204)
  • 改进 softmax 接口实现,支持 dtype 参数。(#10069)
  • 改进 maxpool 接口实现,支持 3D 输入张量。(#10110)
  • 改进 ctc_loss 接口实现,参数与 PyTorch 接口对齐。(#9887)
  • 改进 copy 接口实现,支持输入和输出的 device 和 dtype 都不同的情况。(#9888)
  • 改进 grad 接口实现,支持 allow_unused 参数。(#10251)
  • 改进 load 接口实现,提供更加用户友好的报错信息。(#10138)
  • 改进 fused_matmul_bias 算子及接口实现,支持 alpha 和 beta 参数。(#10015)
  • 改进 normal 算子及接口实现以和 pytorch 行为对齐。(#10185)
  • 改进 fused attention 算子及接口实现,允许 pasti_key 和 past_value 为 None 的情况。(#9977)
  • 改进 fused_attention 算子及接口实现,增加对可变序列长度的支持。(#9991)
  • 改进 fused_multi_head_attention_inference 算子及接口实现,增加attn_bias 参数。(#9853)
  • 改进 bn 相关 functor 实现,融合bn_add_relu和bn_relu操作加速推理。(#10239)
  • 改进基于 MLIR CodeGen 流程,将 LLVM 版本更新到 16.0.0。(#9985)
  • 改进基于 MLIR codegen 流程,增加 AppendOneFlowStream、MgpuToOneFlowStream、CastOneFlowInputToSignlessPass pass。(#10149, #10151, #10099)
  • 改进基于 MLIR codegen 流程,通过链接 LibDevice 支持 NVVM IR 转化为 cubin。(#10200)
  • 改进基于 MLIR codegen 流程,支持在 MLIR 中使用 tmpbuffer 作为 MemPool。(#10159)
  • 改进基于 MLIR codegen 流程,支持 bufferizable 算子分发。(#9787)
  • 改进基于 MLIR codegen 流程,进行 ofmempool 等相关流程加速。(#10152, #10168, #10184, #10239)
  • 改进 stacktrace 调用栈信息。(#9912, #9937, #10260, #10161)
  • 改进随机数生成器部分实现,增加缓存避免每次调用重新生成。(#10387)
  • 改进 graph load 功能,支持将 graph 加载到新设备上。(#10335)
  • 改进 dummy 数组初始化实现,使用 fold 表达式。(#10271)
  • 改进 MemoryFormat 类组织形式,通过 cpython 暴露到 python 层中,支持使用 Tensor.to 接口更改张量的 MemoryFormat。(#10181)
  • 改进 steam、device、vm 部分实现以支持更多设备类型。(#10166)
  • 改进 MapAt 的报错信息,新增打印 key 的值。(#10090)
  • 改进 OOM 报错信息,支持区分 CUDA 和 CPU 设备且显示 size。(#9938)
  • 改进 CHECK_XX_OR_RETURN 宏报错信息。(#9921)
  • 改进 graph 相关报错信息。(#9821)
  • 改进 卷积算子相关报错信息。(#9707)
  • 改进模型初始化方式,避免额外的开销。(#10088)
  • 改进 thread manager 实现,可以兼容不限制线程、master 作为线程、n个线程的三种使用场景。(#10060)
  • 改进 numpy 数组释放方式,在主线程中释放以减少耗时的 gil 请求。(#10050)
  • 改进 graph save runtime_state_dict 实现,提升性能并修复相关问题。(#10016)
  • 改进形如 Tensor.foo(*args) 接口不同调用方式的解析,使用统一的 PyParseArgs 函数完成。(#9983)
  • 改进 ArgsTree 类实现,支持任意输出类型并进行文件位置迁移。(#9846)
  • 改进内存分配机制,实现按 stream 有序分配。(#9818)

改动与修复

1、功能改动

  • 移除 deallocate context。(#10143)
  • 移除图编译中的调试编译模式。(#10145)
  • 移除不再使用的 MemChain merge 的逻辑。(#10097)
  • 移除一些分布式相关的环境变量的默认设置。(#9803)
  • 重构 lazy 模式下的 collective boxing 实现。(#10098)
  • 重构 EagerCclS2S 的注册。(#10100)
  • 重构 collective_boxing_executor_backend 的实现。(#10082)
  • 重构使用 VM 跑 global nn.graph 的实现。(#10048)
  • 重构 local to global 相关接口实现。(#9870)
  • 重构 MLIR codegen 流程中算子分发 dialect 实现。(#9693)
  • 重构 random generator 和 distribution kernels 实现。(#9691)
  • 重构 fast_atomic_add 算子实现。(#9680)
  • 重构 glog 中的错误检查相关宏定义。(#10176)
  • 重构 random generator 实现。(#10025)
  • 重构部分 elementwise primitive 的实现。(#9857)
  • 重构部分 device 描述相关代码。(#9791)
  • 重构 ParseDeviceString 和 ParseDeviceNameConf 实现。(#9833)
  • 重构 ActorMsg 相关实现,引入 IBVerbsActorMsgWrapper 封装以减少 ActorMsg 的大小。(#9762)
  • 重构 save 和 load 接口实现,迁移保存 Graph 逻辑的方法到 _save_graph 函数,添加部分 _open* 辅助类区分路径和内存, save 支持将权重保存到 BytesIO 中,load 支持文件流。(#10021)
  • 重构部分 tensor 相关接口实现,代码从 python 层迁移到 C++ 层。(#9990, #9964)
  • 升级项目使用的 PyBind 版本至 2.11.1。(#10391)

2、问题修复

  • 修复 cmake 文件中动态链接默认设置以避免 llvm15 链接错误。(#10373, #10131)
  • 修复基于 MLIR codegen 中 cast 相关 bug。(#10105)
  • 修复 Module._apply 函数中对 cpg attr 处理的逻辑问题。(#10343)
  • 修复 DummyModule 在 attr 为 __mro_entries__ 情况下无法被继承的问题。(#9976)
  • 修复 full op 中 _handle_size_arg 对传入 size 判断的问题。(#9975)
  • 修复通过命令行启动 mock 后环境变量残留导致后续 api 方式的 mock 参数错误的问题。(#9970)
  • 修复两个进程异常时无法退出的问题。(#10054)
  • 修复了分组量化 sbp 推导的 bug。(#10132)
  • 修复 GroupedMatmulFunctor 中的 kMaxInputCount 检查问题。(#10322)
  • 修复 0-size tensor broadcast 的问题。(#10186)
  • 修复使用 shared_graph 时 double 类型 attr 没有更新的问题。(#10279)
  • 修复 GetItemInScalarTensor 中的数据类型错误。(#10226)
  • 修复 GroupNorm 梯度问题,仅当 gamma 和 beta 需要梯度时,才调用 GroupNormParamGrad。(#10045)
  • 修复 global mode 在读取 placement 为部分 ranks 的 tensor 时会报错的问题。(#10056)
  • 修复 checkpointing 在 PP 下可能会有跨出 rank 控制的边,从而导致影响分离编译下的 task graph 构建的问题。(#10057)
  • 修复同时使用 3D 并行和打开 activation checkpointing 时的 bug。(#10031)
  • 修复 AutoMixedPrecision pass 在其他非 cuda 设备上的适配 bug 和 LayerNorm Module相关设备组合的 bug。(#10026)
  • 修复 scatter 算子 reduce 参数默认值设置问题。(#10002)
  • 修复 mock.disable 时,有些 Torch 变量依旧内置于其他引用的 globals 里而导致 disable 不彻底的问题。(#9989)
  • 修复 vm::TensorStorage 析构问题。(#9962)
  • 修复 offload,解决小 tensor 释放清理不出 Cuda Memory 的问题。(#9974)
  • 修复线程不安全导致的 Python stack getter 偶发 segmentation fault 的问题。(#9955)
  • 修复分离编译场景下的 set 中元素查找不到的问题。(#9952)
  • 修复 fused_multi_head_attention 算子,对齐 qkv 和 output_layout。(#9950)
  • 修复 random 系列算子在 graph 和 checkpointing 中 seed 表现不一致的问题。(#9941)
  • 修复 Eager 模式下 parameter reload 失败问题。(#9935)
  • 修复 mock torch lazy 功能特定情况下死循环的问题。(#9926)
  • 修复 stft_kernel.cu 文件中的代码默认情况下不会被编译的问题。(#9922)
  • 修复 order_in_graph 在分离编译下,由于 TaskGraph 不是完整的图。(缺少其他 rank 的信息)导致拓扑序失效造成 死锁、内存分配写错的 BUG。(#9909 )
  • 修复 xrt 编译找不到 fmt 的问题。(#9894)
  • 修复 local to global 过程中,当 sbp 为 B 时,各进程显存分配不平衡的问题。(#9852)
  • 修复 CTCLoss 的第三个参数相关 OneFlow 和 PyTorch 行为不对齐的问题。(#9845)
  • 修复 thread_global_id 和 rank_group_scope 初始化相关问题。(#9841)
  • 修复 dropout 算子实现中 inplace 处理相关错误。(#9808)
  • 修复 load 功能在加载 PyTorch 保存的非张量对象时的错误。(#9804)
  • 修复连续内存/显存分配策略之间的冲突问题。(#9786)
  • 修复 EagerBlobObject::ByteSizeOfBlobBody 内存分配时未考虑非连续情况的问题。(#9782)
  • 修复 fill_ 算子在 autocast 时的 dtype infer 错误。(#9776)
  • 修复 fused_glu 算子 sbp 推导规则相关问题。(#10108)
  • 修复调用 nn.Graph.__map_io 的相关问题。(#10084)
  • 修复 set_grad_mode 接口和 PyTorch 行为不一致的问题。(#10059)
  • 修复 load 接口中 map_location 参数相关的一个问题并支持传入 lambda 函数。(#10052)
  • 修复 view 模式下的 unsqueeze 操作后 stride 推断错误。(#9775)
  • 修复 conv op 在 unbatched 输入且有 bias 时的问题,为 deconv op 添加 unbatched 输入支持。(#9740)
  • 修复 trunc_normal_ 实现的逻辑错误。(#9711)
  • 修复 topk 算子 dim 参数默认值的问题。(#9703)
  • 修复打印静态图时部分网络的 placement 为 CPU 的问题。(#9770)
  • 修复 trt_flash_attention 的 include 路径和原生 flash attention 路径冲突问题。(#9750)
  • 修复 is_shutting_down 和 gil 引起的 stack getter 段错误。(#9681)
  • 修复分离编译特性在分布式单测中暴露相关的问题。(#9749)
  • 修复拉直算法实现中内存处理相关问题。(#9746)
  • 修复执行流程中一个死锁问题。(#9738)
  • 修复 DummyModule 在 isinstance 判断时报错的相关问题。(#10207)
  • 修复在引入 llvm::SmallVector 时错误覆盖默认 size 的行为。(#9932)
  • 修复非连续内存张量内存大小计算错误问题。(#9819)
  • 修复在 TensorStorage 析构函数中调用 CHECK_JUST 的问题。(#9752)

性能

1、OneFlow compile_from_torch VS PyTorch compile

对 ResNet50 模型和 Faster RCNN 模型的 backbone 部分使用 OneFlow compile_from_torch 和 PyTorch compile 接口进行编译并执行,测试不同 shape 输入时的编译时间,结果如下表:

模型 输入 shape PyTorch compile OneFlow compile_from_torch dynamic 测试时机
ResNet50 (1, 3, 512, 512) 21.328 s 3.205 s False 首次编译执行
ResNet50 (2, 3, 896, 512) 14.167 s 1.523 s False 连续编译执行
ResNet50 (2, 3, 512, 896) 13.364 s 1.402 s False 连续编译执行
ResNet50 (3, 3, 896, 896) 15.056 s 1.539 s False 连续编译执行
ResNet50 (2, 3, 1024, 896) 14.167 s 1.500 s False 连续编译执行
ResNet50 (2, 3, 896, 1024) 12.891 s 1.494 s False 连续编译执行
ResNet50 (6, 3, 1024, 1024) 14.859 s 1.872 s False 连续编译执行
ResNet50 (1, 3, 512, 512) 170.446 s 3.143 s True 首次编译执行
ResNet50 (2, 3, 896, 512) 185.672 s 0.851 s True 连续编译执行
ResNet50 (2, 3, 512, 896) 0.089 s 0.836 s True 连续编译执行
ResNet50 (3, 3, 896, 896) 0.084 s 0.980 s True 连续编译执行
ResNet50 (2, 3, 1024, 896) 0.077 s 0.942 s True 连续编译执行
ResNet50 (2, 3, 896, 1024) 0.080 s 0.931 s True 连续编译执行
ResNet50 (6, 3, 1024, 1024) 0.084 s 1.406 s True 连续编译执行
Faster RCNN (1, 3, 512, 512) 18.224 s 5.483 s False 首次编译执行
Faster RCNN (2, 3, 896, 512) 9.200 s 3.011 s False 连续编译执行
Faster RCNN (2, 3, 512, 896) 9.331 s 3.025 s False 连续编译执行
Faster RCNN (3, 3, 896, 896) 9.301 s 2.854 s False 连续编译执行
Faster RCNN (2, 3, 1024, 896) 9.290 s 2.805 s False 连续编译执行
Faster RCNN (2, 3, 896, 1024) 9.123 s 2.851 s False 连续编译执行
Faster RCNN (6, 3, 1024, 1024) 9.377 s 3.180 s False 连续编译执行
Faster RCNN (1, 3, 512, 512) 25.444 s 5.430 s True 首次编译执行
Faster RCNN (2, 3, 896, 512) 25.381 s 1.899 s True 连续编译执行
Faster RCNN (2, 3, 512, 896) 0.116 s 1.886 s True 连续编译执行
Faster RCNN (3, 3, 896, 896) 1.982 s 1.793 s True 连续编译执行
Faster RCNN (2, 3, 1024, 896) 0.114 s 1.803 s True 连续编译执行
Faster RCNN (2, 3, 896, 1024) 0.111 s 1.778 s True 连续编译执行
Faster RCNN (6, 3, 1024, 1024) 0.143 s 2.110 s True 连续编译执行

对 Stable Diffusion 模型的 unet 部分使用 OneFlow compile_from_torch 和 PyTorch compile 接口进行编译并执行,测试不同 shape 输出时的编译时间和推理时间,结果如下表:

模型 输出 shape PyTorch compile OneFlow compile_from_torch dynamic 测试时机
Stable Diffusion (2, 512, 512) 103.701 s 63.670 s False 首次编译执行
Stable Diffusion (1, 512, 768) 95.137 s 53.864 s False 连续编译执行
Stable Diffusion (2, 768, 512) 90.259 s 55.271 s False 连续编译执行
Stable Diffusion (1, 768, 768) 90.196 s 51.590 s False 连续编译执行
Stable Diffusion (2, 512, 512) 275.660 s 57.117 s True 首次编译执行
Stable Diffusion (1, 512, 768) 345.774 s 43.752 s True 连续编译执行
Stable Diffusion (2, 768, 512) 349.835 s 47.653 s True 连续编译执行
Stable Diffusion (1, 768, 768) 7.224 s 45.720 s True 连续编译执行
Stable Diffusion (2, 512, 512) 4.088 s 2.831 s False 后续执行
Stable Diffusion (1, 512, 768) 3.296 s 2.325 s False 后续执行
Stable Diffusion (2, 768, 512) 5.594 s 5.157 s False 后续执行
Stable Diffusion (1, 768, 768) 4.713 s 3.557 s False 后续执行
Stable Diffusion (2, 512, 512) 4.448 s 2.801 s True 后续执行
Stable Diffusion (1, 512, 768) 3.201 s 2.314 s True 后续执行
Stable Diffusion (2, 768, 512) 6.093 s 4.166 s True 后续执行
Stable Diffusion (1, 768, 768) 4.920 s 3.557 s True 后续执行

结论:使用 OneFlow compile_from_torch 接口有相对于 PyTorch compile 接口平均更短的编译时间,另外得益于 OneFlow 框架中极致的算子优化,在 Stable Diffusion 模型上有更优的执行性能。

备注:测试使用 GPU 型号为 3090,PyTorch 版本为 v2.1.2,cuda 版本为 12.2。

2、OneFlow Eager vs PyTorch Eager

模型 GPU 型号 卡数 macro batch PyTorch 性能(iter/s) OneFlow 性能(iter/s) 加速比
ResNet50 3090 1 1 31.37 38.81 23.72%
ResNet50 3090 1 2 32.06 48.45 51.12%
ResNet50 3090 2 1 31.10 33.46 7.59%
ResNet50 3090 2 2 31.76 34.83 9.67%
ResNet50 A100 1 1 24.60 46.64 89.59%
ResNet50 A100 1 2 25.06 49.88 99.04%
ResNet50 A100 2 1 25.28 39.18 54.98%
ResNet50 A100 2 2 24.09 32.84 36.32%
Bert 3090 1 1 8.93 10.41 16.57%
Bert 3090 1 2 13.11 14.31 9.15%
Bert 3090 2 1 6.94 8.27 19.16%
Bert 3090 2 2 12.19 15.58 27.81%
Bert A100 1 1 10.45 12.72 21.72%
Bert A100 1 2 20.24 21.57 6.57%
Bert A100 2 1 12.63 16.09 27.39%
Bert A100 2 2 24.86 29.84 20.03%

结论:使用 OneFlow Eager 相对于 PyTorch Eager 在 ResNet50 和 Bert 两个模型小 batch 场景下有明显性能优势。

备注:测试使用PyTorch版本为 v2.1.0,cuda 版本为 12.1。