Reduce time complexity of node replacement in PyTorch frontend #1

Open · wants to merge 4 commits into main · Changes from 3 commits

110 changes: 65 additions & 45 deletions python/tvm/relay/frontend/pytorch.py
@@ -293,7 +293,7 @@ def make_elemwise(self, name):
def elemwise(inputs, input_types):
if name == "divide":
# https://pytorch.org/docs/stable/generated/torch.div.html#torch.div
# None - default behavior. Performs no rounding and, if both input and
# other are integer types, promotes the inputs to the default scalar type.
if all(["int" in input_type for input_type in input_types[:2]]):
input_types[:2] = ["float32"] * 2
@@ -744,7 +744,7 @@ def tensordot(self, input, input_types):
y = input[1]
xshape = self.infer_shape(x)
yshape = self.infer_shape(y)

# handle all types of inputs for `dims`
if isinstance(dims, int):
pairs = []
@@ -773,7 +773,7 @@ def tensordot(self, input, input_types):
for j in range(len(dims[0])):
if dims[0][j] < 0:
dims[0][j] += len(xshape)

dims[1] = list(dims[1])
for j in range(len(dims[1])):
if dims[1][j] < 0:
@@ -793,15 +793,15 @@ def tensordot(self, input, input_types):
dim_to_char = OrderedDict()
dim_to_char[0] = OrderedDict()
dim_to_char[1] = OrderedDict()


x_str = ""
for i, j in enumerate(xshape):
if i not in dim_to_char[0]:
dim_to_char[0][i] = alphabet[l]
l += 1
x_str = x_str + dim_to_char[0][i]


y_str = ""
for i, j in enumerate(yshape):
@@ -990,7 +990,7 @@ def fill(self, inputs, input_type):
def full(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]

# Convert to scalar if the provided value is a TVM call expression and is not dependent on any inputs (i.e. constant)
fill_value = _infer_value(fill_value, {}).numpy().item() if type(fill_value) == _expr.Call and len(_analysis.free_vars(fill_value)) == 0 else fill_value
if isinstance(fill_value, tvm.relay.expr.TupleGetItem):
@@ -1389,7 +1389,7 @@ def conv2d(self, inputs, input_types):
# Add no output padding and move groups from inputs[6] to inputs[8]
inputs.append([0, 0])
inputs.append(inputs[6])
inputs[6] = 0
return self.convolution(inputs, input_types)

def convolution(self, inputs, input_types):
@@ -1812,7 +1812,7 @@ def view(self, inputs, input_types):
if isinstance(data, _expr.Constant):
old_shape = data.data.shape
num_new_dims = len(new_shape) - len(old_shape)

if num_new_dims > 1:
data = _op.transform.expand_dims(data, -1, num_new_dims)
return _op.transform.reshape(data, new_shape)
@@ -2543,7 +2543,7 @@ def broadcast_tensors(self, inputs, input_types):

def broadcast_to(self, inputs, input_types):
tensor_list = inputs[1]

if type(tensor_list) is list:
res_shape = tensor_list
else:
@@ -2689,7 +2689,7 @@ def embedding_bag(self, inputs, input_types):

take = []
take.append(_op.embedding(weight, indices.astype("int32"), axis=0))

if mode == "sum":
out = _op.sum(take[0], axis=0, keepdims=True)
elif mode == "mean":
@@ -2721,7 +2721,7 @@ def index(self, inputs, input_types):
if indices[0] == None:
# Remove first None argument (represents ':')
indices.pop(0)

assert len(_infer_shape(data)) == 2 and len(indices) == 1, "Currently supports only 2D tensors with a single mask"

indices = indices[0]
@@ -2796,7 +2796,7 @@ def index(self, inputs, input_types):
# Extract indices from boolean mask
indices = _op.transform.argwhere(indices)
# Doing this reshape to remove dynamic shapes caused by argwhere op (e.g. '?' shapes). This
# reshape will ensure that the output of argwhere (and following ops) is "predictable" in a
# manner to support further TVM compilation. However, this is only valid if this op falls back
# on CPU. Otherwise, this reshape will cause incorrect results.
# indices = _op.reshape(indices, newshape=_infer_shape(data)[0])
@@ -3040,19 +3040,19 @@ def index_put(self, inputs, input_types):
mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)

# Narrow index tensor to match input tensor
shape_diff = len(_infer_shape(index_tensor)) - len(_infer_shape(in_tensor))
if shape_diff > 0:
for shape in range(shape_diff):
index_tensor = _op.squeeze(index_tensor, axis=[0])

# If indexes are given as a boolean mask instead of indices, use where op
# instead of scatter_nd
if _infer_type(index_tensor).checked_type.dtype == "bool":
if isinstance(values, float):
values = _expr.const(values, dtype=_infer_type(in_tensor).checked_type.dtype)

# Make sure that dynamic output will be 1D vector
# index_tensor = _op.reshape(index_tensor, newshape=(-1,))
# Make sure that dynamic output will be 1D vector
@@ -3065,13 +3065,13 @@ def index_put(self, inputs, input_types):
indices = _op.transform.argwhere(index_tensor)
indices = _op.transpose(indices, (1, 0))
indices = _op.squeeze(indices, _expr.const([0])) if len(_infer_shape(indices)) == 2 and _infer_shape(indices)[0] == 1 else indices

# Make sure that dynamic output will be 1D vector
values = _op.reshape(values, newshape=(-1,))

# Reduce data to 1D vector if possible
in_tensor = _op.reshape(in_tensor, newshape=(-1,))

res = _op.scatter_elements(in_tensor, indices, values, 0, "add")

return res
@@ -3322,11 +3322,11 @@ def replace_inf(inp, replacement_val=1e4):
value = _op.broadcast_to_like(value, mask)

one_const = _expr.const(1, dtype="float32")

# Original implementation
# return _op.where(mask, value, inputs[0])
# Implementation without using the where operator in order to avoid numerical instability
# for certain models caused by the future matmul (once where is decomposed)
return _op.add(_op.multiply(inputs[0], _op.subtract(one_const, mask)), _op.multiply(value, mask))
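# Editor's note (illustrative, not part of the diff): with a 0/1 mask the expression
# above is arithmetically equivalent to _op.where(mask, value, inputs[0]):
#   mask == 1:  inputs[0] * (1 - 1) + value * 1 == value
#   mask == 0:  inputs[0] * (1 - 0) + value * 0 == inputs[0]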

@@ -4054,7 +4054,7 @@ def all_any_common(self, op, inputs, input_types):
dim = inputs[1]
else:
dim = 0

if len(inputs) > 2:
keepdim = inputs[2]
else:
@@ -4273,7 +4273,7 @@ def tril(self, inputs, input_types):
y = np.tril(np.ones(x_shape)).astype(_convert_tvm_to_np_dtype(input_types[0]))
y = tvm.nd.array(y)
y = tvm.relay.Constant(y)

return _op.multiply(x, y)


@@ -4288,7 +4288,7 @@ def triu(self, inputs, input_types):
zeros = np.zeros(x_shape).astype(_convert_tvm_to_np_dtype(input_types[0]))
zeros = tvm.nd.array(zeros)
zeros = tvm.relay.Constant(zeros)

return _op.where(mask, x, zeros)


@@ -4330,15 +4330,15 @@ def as_strided(self, inputs, input_types):

rc_begin += stride_col
rc_end = rc_begin + (n_out_col * stride_col)

rc_rows = _op.concatenate(rc_rows, axis=0)
rc_rows = _op.expand_dims(rc_rows, axis=0)
time_rows = np.append(time_rows, rc_rows)

time_rows = _op.concatenate(time_rows, axis=0)
time_rows = _op.expand_dims(time_rows, axis=0)
batch_rows = np.append(batch_rows, time_rows)

return _op.concatenate(batch_rows, axis=0)


@@ -4388,7 +4388,7 @@ def alias(self, inputs, inputs_types):

# Get constant dtype
dtype = _convert_data_type(shape.data.dtype, default_dtype="float32")

# Convert to numpy array
shape = shape.data.numpy()
if len(shape.shape) == 0:
@@ -4464,7 +4464,7 @@ def scaled_dot_product_attention(self, inputs, input_types):

scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=dtype)
scale_factor = _op.broadcast_to(scale_factor, shape=tuple(1 for _ in range(len(query_shape))))

# Early out if not decomposing
return _op.nn.scaled_dot_product_attention(
query,
@@ -4505,7 +4505,7 @@ def scaled_dot_product_attention(self, inputs, input_types):
batch_size = key_shape[0]
else:
batch_size = query_shape[0]

if len(query_shape) == 4 and len(key_shape) == 4:
query = _op.reshape(query, newshape=[-3, -2])
key = _op.reshape(key, newshape=[-3, -2])
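# Editor's note (illustrative, not part of the diff): in Relay's reshape, -3 folds two
# consecutive input dims into their product and -2 copies the remaining dims, so
# newshape=[-3, -2] maps a [batch, heads, seq, dim] tensor to [batch * heads, seq, dim].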
@@ -5650,11 +5650,11 @@ def get_relay_ty(ishape, itype, pt_type):

input_vars = {}


def get_new_input_infos(input_infos):
new_input_infos = []
for num, inp in enumerate(input_infos):

if not isinstance(inp, tuple):
msg = "Graph input {} is not a tuple".format(num)
raise RuntimeError(msg)
@@ -5663,7 +5663,7 @@ def get_new_input_infos(input_infos):
"Graph input {} is not valid,"
" expected ('name', shape) or ('name', (shape, dtype))".format(inp)
)

raise RuntimeError(msg)
if isinstance(inp[1], (list, tuple)) and isinstance(inp[1][0], (list, tuple)) and isinstance(inp[1][0][0], str):
new_input_infos.append((inp[0], get_new_input_infos(inp[1])))
@@ -5672,9 +5672,9 @@ def get_new_input_infos(input_infos):
else:
new_input_infos.append(inp)
return new_input_infos

new_input_infos = get_new_input_infos(input_infos)

def get_input_types(input_infos, graph_input_types):
input_types = []
for (name, info), gi_type in zip(input_infos, graph_input_types):
@@ -5685,11 +5685,11 @@ def get_input_types(input_infos, graph_input_types):
input_types.append((name, get_relay_ty(info[0], info[1], gi_type), info[1])) # info[1] is the framework datatype, which may differ after being converted to relay
return input_types


graph_input_types = [gi.type() for gi in graph_inputs]
input_types = get_input_types(new_input_infos, graph_input_types)

def get_input_vars(input_types, graph_input_names, use_tuple_type=False, tuple_name=""):
input_vars = {} if not use_tuple_type else []
for gi_name, gi_type in zip(graph_input_names, input_types):
name, itype = gi_type[0], gi_type[1]
@@ -5828,7 +5828,7 @@ def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False
elif full_attr in state_dict:
if var_name in vars_by_name:
var = vars_by_name[var_name]
# we need to remap inputs that pointed to the old
input_remap[full_attr_node_name] = outputs_by_var_name[var_name]
else:
torch_tensor = state_dict[full_attr]
@@ -5871,7 +5871,19 @@ def export_c_graph(location, graph):
fname = os.path.join(location, f"tvm_exported_c_graph_{time_stamp}.txt")
with open(f"{fname}", "w") as f:
f.write(str(graph))


def _binary_search(lst, func):
"""Binary search for the first index that func returns True"""
l, r = 0, len(lst)
while l < r:
m = (l + r) // 2
if func(lst[m]):
r = m
else:
l = m + 1
return l
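# Editor's sketch (not part of the diff): _binary_search assumes func is monotone over
# lst (False for a prefix, True for the rest) and returns the first index where it
# flips to True, or len(lst) if it never does. For example:
#   _binary_search([1, 3, 5, 7], lambda v: v >= 4)    # -> 2
#   _binary_search([1, 3, 5, 7], lambda v: v >= 100)  # -> 4, i.e. len(lst)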


def outplace_inplace_ops(opnodes):

replace_map = []
@@ -5880,18 +5892,26 @@ def outplace_inplace_ops(opnodes):
for i, (node_name, op_node) in enumerate(opnodes):
operator = op_node.kind()
# Check if op is in-place (avoid '__not__', etc.)
if operator[-1] == '_' and operator[-2:] != "__":
input_node = op_node.inputsAt(0)
replace_map.append((i, input_node, op_node.outputsAt(0)))

# Replace future uses of a node that an in-place op was applied to with the output of that op
for node_idx, orig_node, replacement_node in replace_map:
relevant_ops = opnodes[node_idx+1:]
node_inputs_map = {}
for idx, (node_name, op_node) in enumerate(opnodes):
for inp in op_node.inputs():
if inp not in node_inputs_map:
node_inputs_map[inp] = []
node_inputs_map[inp].append((idx, op_node))

for _, node in relevant_ops:
if orig_node in node.inputs():
node.replaceInputWith(orig_node, replacement_node)
for node_idx, orig_node, replacement_node in replace_map:
if orig_node not in node_inputs_map:
continue

relevant_ops = node_inputs_map[orig_node]
begin_idx = _binary_search(relevant_ops, lambda x: x[0] > node_idx)
for idx, node in relevant_ops[begin_idx:]:
node.replaceInputWith(orig_node, replacement_node)
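# Editor's sketch (not part of the diff): the rewrite trades the old per-replacement
# scan of every later op for a precomputed consumer map plus a binary search, so each
# in-place op only touches the ops that actually consume its input. On plain data the
# same pattern looks like this (op_inputs / rewrite_input are hypothetical stand-ins
# for op_node.inputs() / node.replaceInputWith()):
#
#   consumers = {}                          # value -> [(index, op), ...] in graph order
#   for idx, op in enumerate(ops):
#       for value in op_inputs(op):
#           consumers.setdefault(value, []).append((idx, op))
#
#   later = consumers.get(old_value, [])
#   start = _binary_search(later, lambda item: item[0] > inplace_idx)
#   for _, op in later[start:]:
#       rewrite_input(op, old_value, new_value)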

def from_pytorch(
script_module,
@@ -5901,7 +5921,7 @@
use_parser_friendly_name=False,
keep_quantized_weight=False,
export_renamed_c_graph_path=None,
do_convert_params=True,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.