Skip to content

Commit

Permalink
Fix Fusion IR cloning (pytorch#567)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlemo committed Dec 10, 2020
1 parent cd1242b commit c6d8c4a
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 35 deletions.
38 changes: 11 additions & 27 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
Expand Up @@ -72,41 +72,25 @@ Fusion::Fusion(const Fusion& other) {
val_set_.insert(ir_cloner.clone(val));
}

for (auto expr : other.expr_set_) {
expr_set_.insert(ir_cloner.clone(expr));
}

for (auto val : other.val_deque_) {
val_deque_.push_back(ir_cloner.clone(val));
}

for (auto old_expr : other.expr_set_) {
auto new_expr = ir_cloner.clone(old_expr);
expr_set_.insert(new_expr);

// ir_cloner doesn't go through registerStmt, so we need to "Register Expr"
// manually; we would similarly need to do this for Val if there was a step
// in that pass that is also not covered here.
for (Val* input : new_expr->inputs()) {
auto uses_copy = input->uses();
if (std::find(uses_copy.begin(), uses_copy.end(), new_expr) ==
uses_copy.end()) {
uses_copy.push_back(new_expr);
input->setUses(uses_copy);
}
}
// Fixup potentially cyclic pointers
for (auto val : val_set_) {
val->definition_ = ir_cloner.clone(val->definition_);
val->uses_ = ir_cloner.clone(val->uses_);
}

val_type_name_map_ = other.val_type_name_map_;
expr_name_counter_ = other.expr_name_counter_;

inputs_ = ir_cloner.clone(other.inputs_);
outputs_ = ir_cloner.clone(other.outputs_);

for (auto inp : inputs_) {
inp->setIsFusionInput(true);
}
for (auto out : outputs_) {
out->setIsFusionOutput(true);
}

resetTvUses();
}

Fusion::Fusion(Fusion&& other) noexcept {
Expand Down Expand Up @@ -421,16 +405,16 @@ void Fusion::resetTvUses() {
// remove dead exprs, this could reinsert them. getExprs is also bounded by
// inputs as registered inputs will return nullptr as their definition.
const auto all_tvs = ir_utils::filterByType<TensorView>(val_set_);
auto used_exprs = ExprSort::getExprs(this);
const auto used_exprs = ExprSort::getExprs(this);

for (auto tv : all_tvs) {
tv->setUses(std::deque<Expr*>());
tv->setUses({});
}

// Same as in register expr
for (auto expr : used_exprs) {
for (Val* input : expr->inputs()) {
std::deque<Expr*> uses_copy = input->uses();
auto uses_copy = input->uses();
if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
uses_copy.end()) {
uses_copy.push_back(expr);
Expand Down
10 changes: 9 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
Expand Up @@ -53,11 +53,19 @@ Val::Val(ValType _vtype, DataType _dtype, bool register_val)
}
}

// NOTE: we don't clone the definition_ and uses_ here
// since they may introduce cloning cycles. Instead, we copy
// the original pointers and we'll fix them up later as part of the
// Fusion copy
//
Val::Val(const Val* src, IrCloner* ir_cloner)
: Statement(src, ir_cloner),
vtype_(src->vtype_),
dtype_(src->dtype_),
definition_(ir_cloner->clone(src->definition())) {}
is_fusion_input_(src->is_fusion_input_),
is_fusion_output_(src->is_fusion_output_),
definition_(src->definition_),
uses_(src->uses_) {}

namespace {

Expand Down
7 changes: 3 additions & 4 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Expand Up @@ -9,7 +9,6 @@
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <cstdint>
#include <deque>
#include <iostream>
#include <limits>
#include <memory>
Expand Down Expand Up @@ -214,7 +213,7 @@ class TORCH_CUDA_API Val : public Statement {
return definition_;
}

const std::deque<Expr*>& uses() const {
const auto& uses() const {
return uses_;
}

Expand Down Expand Up @@ -272,7 +271,7 @@ class TORCH_CUDA_API Val : public Statement {
is_fusion_output_ = is_fusion_output;
}

void setUses(std::deque<Expr*> uses) {
void setUses(const std::vector<Expr*>& uses) {
uses_ = uses;
}

Expand All @@ -282,7 +281,7 @@ class TORCH_CUDA_API Val : public Statement {
bool is_fusion_output_ = false;

Expr* definition_ = nullptr;
std::deque<Expr*> uses_;
std::vector<Expr*> uses_;
};

//! A Expr represents a "computation." These are functions that takes inputs
Expand Down
6 changes: 5 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_cloner.h
Expand Up @@ -13,7 +13,11 @@ namespace cuda {

class Fusion;

// Clones nodes from an existing Fusion
//! Clones nodes from an existing Fusion
//!
//! \warning IrCloner machinery is a specialized helper for implementing
//! Fusion copy operations and it's not intended for any other uses
//!
class TORCH_CUDA_API IrCloner : private OptInConstDispatch {
friend class Statement;

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/root_domain_map.cpp
Expand Up @@ -71,7 +71,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
TORCH_INTERNAL_ASSERT(producer_tv_->domain() == producer);
TORCH_INTERNAL_ASSERT(consumer_tv_->domain() == consumer);

if (consumer_tv_->getOrigin()->isA<TransposeOp>()) {
if (consumer_tv_->definition()->isA<TransposeOp>()) {
return mapTranspose(
producer, consumer, root_dims_to_map, producer_to_consumer);
}
Expand Down Expand Up @@ -126,7 +126,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::

std::unordered_map<IterDomain*, IterDomain*> dom_map;

TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->getOrigin());
TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->definition());
TORCH_INTERNAL_ASSERT(top != nullptr);

const auto& new2old = top->new2old();
Expand Down

0 comments on commit c6d8c4a

Please sign in to comment.