training dev #1976

Open
wants to merge 33 commits into dev_train
Commits (33)
29aa4cc
Add GetModelInputDataTypesMap interface
Sep 21, 2023
0b852d8
Add backward operator for Log
Sep 22, 2023
3b0f2da
fix bug: broadcast type
Oct 16, 2023
d1c2440
fix bug: broadcast type identifying
Oct 17, 2023
44b0c64
Rename GetModelInputDataTypeMap func
doxutx Oct 19, 2023
3b9495f
GetModelInputDataTypeMap
doxutx Oct 19, 2023
836e331
GetModelInputDataTypeMap
doxutx Oct 19, 2023
0bd23ce
Update tnn_impl_default.cc
doxutx Oct 19, 2023
4c1e33a
Update tnn_impl_default.h
doxutx Oct 19, 2023
f015419
Update tnn_impl_rknpu.cc
doxutx Oct 19, 2023
4d51c1c
Update tnn_impl_rknpu.h
doxutx Oct 19, 2023
920a4d0
Update tnn_impl_coreml.h
doxutx Oct 19, 2023
76d8335
Update tnn_impl_coreml.mm
doxutx Oct 19, 2023
f9501ef
Update tnn.cc
doxutx Oct 19, 2023
5fc5393
Merge branch 'Tencent:dev_train' into dev_train
kotrue Oct 19, 2023
49b5f66
fix: constants fix for onnx2tnn
Nov 24, 2023
e996528
fix: convertor constants
Nov 28, 2023
11d8de1
update
Nov 28, 2023
e1f993d
fix: fix conversion bug
Nov 30, 2023
bf3de2c
fixbug
Dec 1, 2023
449fe0b
update
Dec 5, 2023
10adeb9
feat: add ZeroGrad function and GradParam parameter
Dec 14, 2023
fad2795
bug fix
Dec 14, 2023
00356c2
fix: weights grad update, considering data format
Dec 15, 2023
09b4b51
feat: backward propagation for binary operators
Dec 19, 2023
5fd150b
Some optimizations
Dec 21, 2023
705aee9
test: add grad tests for binary ops
Dec 22, 2023
0e1478f
fix: fix bug when fetching grad values
Dec 22, 2023
2fbfe69
test: add backward computation for binary ops
Dec 25, 2023
9173e3e
update
Dec 25, 2023
330594e
Implement unary grad
Dec 26, 2023
6b1502b
InnerProduct gradient
Dec 27, 2023
c62ad9b
fix: zero out gradients
Jan 3, 2024
5 changes: 5 additions & 0 deletions include/tnn/core/blob.h
@@ -47,6 +47,11 @@ struct PUBLIC BlobDesc {
struct PUBLIC BlobHandle {
void *base = NULL;
uint64_t bytes_offset = 0;

template <typename T>
T force_to() {
return reinterpret_cast<T>(base ? ((char *)base + bytes_offset) : nullptr);
}
};

class BlobImpl;
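The new force_to<T>() helper folds bytes_offset into the returned pointer, which callers previously had to compute by hand from base and bytes_offset. A minimal usage sketch follows; the helper function and blob contents are illustrative, assuming a Blob whose GetHandle() returns this BlobHandle:

#include "tnn/core/blob.h"

using namespace TNN_NS;

// Hypothetical helper: read the first element of a float blob.
// Previously: reinterpret_cast<float *>((char *)handle.base + handle.bytes_offset)
float ReadFirstFloat(Blob *blob) {
    BlobHandle handle = blob->GetHandle();
    float *data = handle.force_to<float *>();  // base + bytes_offset, or nullptr if base is null
    return data ? data[0] : 0.0f;
}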
2 changes: 1 addition & 1 deletion include/tnn/core/instance.h
@@ -41,7 +41,7 @@ typedef std::function<void(std::vector<Blob*>& blobs, LayerInfo* info)> BlobStat

class PUBLIC Instance {
public:
Instance(NetworkConfig& net_config, ModelConfig& model_config);
Instance(const NetworkConfig& net_config, const ModelConfig& model_config);

~Instance();

2 changes: 1 addition & 1 deletion include/tnn/core/status.h
@@ -149,7 +149,7 @@ enum StatusCode {
class PUBLIC Status {
public:
~Status();
Status(int code = TNN_OK, std::string message = "OK");
Status(int code = TNN_OK, const std::string &message = "OK");

Status &operator=(int code);

1 change: 1 addition & 0 deletions source/tnn/core/abstract_network.cc
@@ -36,6 +36,7 @@ Status AbstractNetwork::SetCpuNumThreads(int num_threads) {
Status AbstractNetwork::TrainStep() {
return Status(TNNERR_TRAIN_ERROR, "Subclass of AbstractNetwork doesn't implement TrainStep func");
}

Status AbstractNetwork::GetTrainingFeedback(TrainingFeedback& feed_back) {
return Status(TNNERR_TRAIN_ERROR, "Subclass of AbstractNetwork doesn't implement GetTrainingFeedback func");
}
9 changes: 2 additions & 7 deletions source/tnn/core/blob.cc
@@ -22,19 +22,14 @@ namespace TNN_NS {

std::string BlobDesc::description(bool all_message) {
std::ostringstream os;
//name
os << "name: " <<name;

//data type
os << " data type: " << data_type;

//shape
os << " data_type: " << data_type;
os << " data_format: " << data_format;
os << " shape: [ " ;
for (auto iter : dims) {
os << iter << " " ;
}
os << "]";

return os.str();
}

9 changes: 8 additions & 1 deletion source/tnn/core/blob_manager.cc
@@ -199,8 +199,15 @@ Status BlobManager::AllocateBlobMemory(int flag) {

BlobMemorySizeInfo info = device_->Calculate(current_blob->GetBlobDesc());
// find an available BlobMemory
bool use_new_memory = false;
#if TNN_TRAIN
// in train mode, ZeroGrad resets all grads to 0 at the start of TrainStep(), so use_count-based reuse does not apply
if (config_.train_config.run_mode == TRAIN_MODE_TRAIN) {
use_new_memory = true;
}
#endif
BlobMemory *blob_memory =
blob_memory_pool_map_[info.dims.size()]->BorrowBlobMemory(use_count, info, false);
blob_memory_pool_map_[info.dims.size()]->BorrowBlobMemory(use_count, info, use_new_memory);
blob_memory_mapping_.insert(std::make_pair(current_blob, blob_memory));
}
}
29 changes: 27 additions & 2 deletions source/tnn/core/default_network.cc
@@ -279,6 +279,7 @@ Status DefaultNetwork::InitLayers(NetStructure *net_structure, NetResource *net_
}

// init layer
std::map<std::string, BaseLayer*> layer_map;
for (auto layer_info : net_structure->layers) {
if (runtime_model_ == RUNTIME_MODE_NORMAL && const_layers.find(layer_info->name) != const_layers.end()) {
continue;
@@ -345,6 +346,20 @@ Status DefaultNetwork::InitLayers(NetStructure *net_structure, NetResource *net_
cur_layer->SetRuntimeMode(runtime_model_);
cur_layer->SetConstantResource(&net_resource->constant_map);
cur_layer->SetConstantResourceFlag(&net_resource->constant_blob_flags);

#if TNN_TRAIN
if (layer_info->type == LayerType::LAYER_GRADIENT) {
auto &forward_layer_name = net_structure->back2forward[layer_info->name];
auto it_forward_layer = layer_map.find(forward_layer_name);
if (it_forward_layer == layer_map.end() || layer_info->param == nullptr) {
return Status(TNNERR_TRAIN_ERROR, "backward layer[" + layer_info->name + "] miss its forward layer");
}

GradientParam *grad_param = dynamic_cast<GradientParam*>(layer_info->param.get());
CHECK_PARAM_NULL(grad_param);
grad_param->forward_layer = it_forward_layer->second;
}
#endif
ret = cur_layer->Init(context_, layer_info->param.get(), layer_resource, inputs, outputs, device_);
if (ret != TNN_OK) {
LOGE("Error Init layer %s (err: %d or 0x%X)\n", cur_layer->GetLayerName().c_str(), (int)ret, (int)ret);
@@ -355,6 +370,7 @@ Status DefaultNetwork::InitLayers(NetStructure *net_structure, NetResource *net_
cur_layer->SetRuntimeBlobMemoryPool(runtime_blob_pool_);

layers_.push_back(cur_layer);
layer_map[cur_layer->GetLayerName()] = cur_layer;
}
forward_layer_count_ = layers_.size();
return ret;
@@ -501,9 +517,11 @@ Status DefaultNetwork::PrepareDoReshape(const InputShapesMap& inputs, bool& shap
for (auto iter : inputs) {
Blob *blob = blob_manager_->GetBlob(iter.first);
if (blob == nullptr) {
#if TNN_TRAIN
if (config_.train_config.run_mode == TRAIN_MODE_TRAIN) {
continue; // inputs contain the ground-truth labels, so continue
}
#endif
LOGE("DefaultNetwork reshape blob is empty, maybe the blob name is wrong\n");
return Status(TNNERR_PARAM_ERR, "DefaultNetwork reshape blob is empty, maybe the blob name is wrong");
}
@@ -636,9 +654,16 @@ Status DefaultNetwork::Forward() {
}
#endif // DUMP_INPUT_BLOB

LOGD("layer name: %s\n", layer->GetLayerName().c_str());
for (auto blob : inputs) {
LOGD("input - %s\n", blob->GetBlobDesc().description().c_str());
}

status = layer->Forward();
LOGD("layer name: %s, forward result: %d \n", layer->GetLayerName().c_str(), (int)status);
LOGD("Output Shape: [%s]\n", layer->GetOutputBlobs()[0]->GetBlobDesc().description().c_str());

for (auto blob : outputs) {
LOGD("output - %s\n", blob->GetBlobDesc().description().c_str());
}
if (status != TNN_OK) {
LOGE("Forward error %s, exit\n", status.description().c_str());
return status;
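For reference, the gradient-layer wiring added to InitLayers above resolves two maps in sequence: net_structure->back2forward maps a backward layer's name to the name of the forward layer it differentiates, and the new layer_map resolves that name to the already-constructed BaseLayer*. A condensed sketch of the same lookup; only the field names come from this diff, the layer names and helper are illustrative:

#include <map>
#include <string>

class BaseLayer;  // TNN forward-layer type

// Illustrative only: how a gradient layer finds its forward counterpart.
std::map<std::string, std::string> back2forward = {
    {"conv1_grad", "conv1"},  // hypothetical backward -> forward layer names
};
std::map<std::string, BaseLayer *> layer_map;  // filled as forward layers are created

BaseLayer *ResolveForwardLayer(const std::string &grad_layer_name) {
    auto name_it = back2forward.find(grad_layer_name);
    if (name_it == back2forward.end()) return nullptr;
    auto layer_it = layer_map.find(name_it->second);
    // The forward layer is expected to exist already, because forward layers
    // appear before their gradient layers in net_structure->layers.
    return layer_it == layer_map.end() ? nullptr : layer_it->second;
}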
2 changes: 1 addition & 1 deletion source/tnn/core/instance.cc
@@ -34,7 +34,7 @@ namespace TNN_NS {
* It wraps the network object to keep consistency of the header.
*/

Instance::Instance(NetworkConfig& net_config, ModelConfig& model_config) {
Instance::Instance(const NetworkConfig& net_config, const ModelConfig& model_config) {
net_config_ = net_config;
model_config_ = model_config; // note that, the params in model_config is empty, don't use it
}
2 changes: 1 addition & 1 deletion source/tnn/core/status.cc
@@ -41,7 +41,7 @@ Status::~Status() {
}

//constructor with code and message
Status::Status(int code, std::string message) {
Status::Status(int code, const std::string& message) {
code_ = code;
message_ = (message != "OK" && message.length() > 0) ? message : StatusGetDefaultMessage(code);
}
1 change: 1 addition & 0 deletions source/tnn/device/arm/acc/TNNVector.h
@@ -201,6 +201,7 @@ struct TNNVector {
}
return dst;
}

static TNNVector<T, len> add(const TNNVector<T, len>& v1, const TNNVector<T, len>& v2) {
TNNVector<T, len> dst;
for (int i = 0; i < len; ++i) {
48 changes: 1 addition & 47 deletions source/tnn/device/arm/acc/arm_binary_layer_acc.cc
@@ -141,53 +141,7 @@ Status ArmBinaryLayerAcc::Init(Context *context, LayerParam *param, LayerResourc
RETURN_ON_NEQ(allocateBufferParamHalf(inputs, outputs), TNN_OK);
}
#endif

auto layer_param = dynamic_cast<MultidirBroadcastLayerParam *>(param_);
CHECK_PARAM_NULL(layer_param);
auto layer_res = dynamic_cast<EltwiseLayerResource *>(resource_);

// prepare input shapes
input_shapes_.clear();
input_shapes_.reserve(4);
auto output = outputs[0];
auto output_dims = output->GetBlobDesc().dims;

if (broadcast_.GetBytesSize() > 0) {
DimsVector input_shape0 = inputs[0]->GetBlobDesc().dims;
if (layer_param->weight_input_index == 0) {
// bias as another input
input_shapes_.push_back(layer_res->element_shape);
input_shapes_.push_back(input_shape0);
} else {
input_shapes_.push_back(input_shape0);
input_shapes_.push_back(layer_res->element_shape);
}
} else {
if (inputs.size() == 1) {
input_shapes_.push_back(inputs[0]->GetBlobDesc().dims);
input_shapes_.push_back(inputs[0]->GetBlobDesc().dims);
} else {
for (size_t inid = 0; inid < inputs.size(); inid++) {
input_shapes_.push_back(inputs[inid]->GetBlobDesc().dims);
}
}
}

btype_ = BroadcastTypeUnknown;
// check broadcast type is general or other optimized ncxhwx types
// if type is general, go to nchw general impl
DimsVector input_pad_shape;
input_pad_shape.resize(output_dims.size());
for (int i = 0; i < input_shapes_.size(); i++) {
int pad_size = output_dims.size() - input_shapes_[i].size();
PadShape(pad_size, output_dims.size(), input_pad_shape, input_shapes_[i]);
BroadCastTypeFilter(output_dims, input_pad_shape, btype_);
if (btype_ == BroadcastTypeGeneral) {
break;
}
}

return TNN_OK;
return Reshape(inputs, outputs);
}

// if reshape, reset input_shapes and broadcast type
4 changes: 4 additions & 0 deletions source/tnn/device/arm/acc/arm_binary_layer_acc.h
@@ -60,6 +60,10 @@ class ArmBinaryLayerAcc : public ArmLayerAcc {

virtual Status RefreshBuffers(const std::vector<Blob *> &inputs, const std::vector<Blob *> &outputs) override;

RawBuffer &GetResource() { return broadcast_;}
std::vector<void *> &GetInputPtrs() { return input_ptrs_;}
std::vector<DimsVector> &GetInputShapes() { return input_shapes_;}

protected:
virtual bool DataTypeSupported(DataType data_type) override;
virtual Status ConfigBuffer2ArmBlobDesc(BlobDesc &desc) override;
1 change: 1 addition & 0 deletions source/tnn/device/arm/acc/arm_layer_acc.cc
@@ -245,6 +245,7 @@ Status ArmLayerAcc::ReloadConstantBlobs(const std::vector<Blob *> &inputs, bool
RETURN_ON_NEQ(status, TNN_OK);

blob->SetFlag(DATA_FLAG_CHANGE_NEVER);
blob->GetBlobDesc().name = name;
const_blob_map[name] = blob;
iter->SetHandle(blob->GetHandle());
iter->GetBlobDesc() = blob->GetBlobDesc();
2 changes: 1 addition & 1 deletion source/tnn/device/arm/acc/compute/binary_function.cc
@@ -16,7 +16,7 @@

namespace TNN_NS {

void PadShape(const int pad_size, const int dim_size, DimsVector &pad_shape, DimsVector in_shape) {
void PadShape(const int pad_size, const int dim_size, DimsVector &pad_shape, const DimsVector &in_shape) {
int j = 0;
for (; j < pad_size; j++) {
pad_shape[j] = 1;
2 changes: 1 addition & 1 deletion source/tnn/device/arm/acc/compute/binary_function.h
@@ -25,7 +25,7 @@ dtype binary_op(const dtype &a, const dtype &b, float alpha = 0, float beta = 0)
return a;
}

void PadShape(const int pad_size, const int dim_size, DimsVector &pad_shape, DimsVector in_shape);
void PadShape(const int pad_size, const int dim_size, DimsVector &pad_shape, const DimsVector& in_shape);

void BroadCastTypeFilter(const DimsVector &dims_output, const DimsVector &dims_input, BroadcastType &type);

13 changes: 6 additions & 7 deletions source/tnn/device/arm/acc/gradient/arm_add_grad.cc
@@ -19,17 +19,16 @@ namespace TNN_NS {
// z = x + y
// dz/dx = 1
// dz/dy = 1
typedef struct arm_add_grad_function: arm_binary_grad_function {
virtual std::pair<float, float> operator()(const float &i_0, const float &i_1, const float &o, const float &og) {
class ArmAddGradOp : public ArmBinaryGradOp {
private:
virtual std::pair<float, float> cal_grad(const float &i_0, const float &i_1, const float &o, const float &og) override {
return {og, og};
}
virtual std::pair<Float4, Float4> operator()(const Float4 &i_0, const Float4 &i_1, const Float4 &o,
const Float4 &og) {
virtual std::pair<Float4, Float4> cal_grad(const Float4 &i_0, const Float4 &i_1, const Float4 &o,
const Float4 &og) override {
return {og, og};
}
} ARM_ADD_GRAD_FUNC;

DEFINE_ARM_BINARY_GRAD_OP(Add, ARM_ADD_GRAD_FUNC)
};

REGISTER_ARM_GRAD_OP(Add, LAYER_ADD)
REGISTER_ARM_GRAD_LAYOUT(LAYER_ADD, DATA_FORMAT_NC4HW4)
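With this refactor each binary gradient op only supplies its local derivative through cal_grad, while the shared ArmBinaryGradOp base handles broadcasting and element iteration. As an illustration of the pattern (a sketch, not part of this PR), a multiply gradient, where z = x0 * x1 gives dz/dx0 = x1 and dz/dx1 = x0, could be written the same way, assuming LAYER_MUL and the registration macros above apply to it:

// Sketch only: a multiply gradient written in the same style as ArmAddGradOp above.
class ArmMulGradOp : public ArmBinaryGradOp {
private:
    virtual std::pair<float, float> cal_grad(const float &i_0, const float &i_1,
                                             const float &o, const float &og) override {
        return {i_1 * og, i_0 * og};  // dz/dx0 = x1, dz/dx1 = x0
    }
    virtual std::pair<Float4, Float4> cal_grad(const Float4 &i_0, const Float4 &i_1,
                                               const Float4 &o, const Float4 &og) override {
        return {i_1 * og, i_0 * og};
    }
};

REGISTER_ARM_GRAD_OP(Mul, LAYER_MUL)
REGISTER_ARM_GRAD_LAYOUT(LAYER_MUL, DATA_FORMAT_NC4HW4)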
@@ -20,19 +20,19 @@ namespace TNN_NS {
// y = -x1*log(x0) - (1-x1)*log(1-x0)
// dy/dx0 = (1-x1)/(1-x0) -x1/x0
// dy/dx1 = log(1-x0) - log(x0)
typedef struct arm_bce_grad_function: arm_binary_grad_function {
virtual std::pair<float, float> operator()(const float &i_0, const float &i_1, const float &o, const float &og) {
class ArmBinaryCrossEntropyGradOp : public ArmBinaryGradOp {
private:
virtual std::pair<float, float> cal_grad(const float &i_0, const float &i_1, const float &o, const float &og) override {
return {((1.0 - i_1) / (1.0 - i_0) - i_1 / i_0) * og, (std::log(1.0 - i_0) - std::log(i_0)) * og};
}
virtual std::pair<Float4, Float4> operator()(const Float4 &i_0, const Float4 &i_1, const Float4 &o,
const Float4 &og) {
virtual std::pair<Float4, Float4> cal_grad(const Float4 &i_0, const Float4 &i_1, const Float4 &o,
const Float4 &og) override {
Float4 g0 = Float4::div(Float4(1.0) - i_1, Float4(1.0) - i_0) - Float4::div(i_1, i_0);
Float4 g1 = Float4::log(Float4(1.0) - i_0) - Float4::log(i_0);
return {g0 * og, g1 * og};
}
} ARM_BCE_GRAD_FUNC;
};

DEFINE_ARM_BINARY_GRAD_OP(BinaryCrossEntropy, ARM_BCE_GRAD_FUNC)

REGISTER_ARM_GRAD_OP(BinaryCrossEntropy, LAYER_BINARY_CROSSENTROPY)
REGISTER_ARM_GRAD_LAYOUT(LAYER_BINARY_CROSSENTROPY, DATA_FORMAT_NC4HW4)
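The derivative formulas in the comments above can be verified numerically against the closed forms returned by cal_grad. A minimal, self-contained finite-difference check (independent of TNN; the sample values are arbitrary):

#include <cmath>
#include <cstdio>

// y = -x1*log(x0) - (1-x1)*log(1-x0)
double bce(double x0, double x1) { return -x1 * std::log(x0) - (1.0 - x1) * std::log(1.0 - x0); }

int main() {
    const double x0 = 0.3, x1 = 0.8, eps = 1e-6;
    double analytic_dx0 = (1.0 - x1) / (1.0 - x0) - x1 / x0;  // cal_grad's first output with og = 1
    double analytic_dx1 = std::log(1.0 - x0) - std::log(x0);  // cal_grad's second output with og = 1
    double numeric_dx0  = (bce(x0 + eps, x1) - bce(x0 - eps, x1)) / (2 * eps);
    double numeric_dx1  = (bce(x0, x1 + eps) - bce(x0, x1 - eps)) / (2 * eps);
    std::printf("dy/dx0: analytic %.6f, numeric %.6f\n", analytic_dx0, numeric_dx0);
    std::printf("dy/dx1: analytic %.6f, numeric %.6f\n", analytic_dx1, numeric_dx1);
    return 0;
}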