Skip to content

Commit

Permalink
fix testing
Browse files Browse the repository at this point in the history
  • Loading branch information
larme committed Feb 14, 2023
1 parent 4495625 commit b565287
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/integration/frameworks/models/onnx.py
Expand Up @@ -289,7 +289,7 @@ def _check(
def make_bert_onnx_model(tmpdir) -> tuple[onnx.ModelProto, t.Any]:
model_id = TINY_BERT_MODEL_ID
bert_model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
sample_text = "This is a sample"
sample_input = tokenizer(sample_text, return_tensors="pt")
model_path = os.path.join(tmpdir, "bert-tiny.onnx")
Expand Down Expand Up @@ -318,8 +318,9 @@ def make_bert_onnx_model(tmpdir) -> tuple[onnx.ModelProto, t.Any]:
return (onnx_model, expected_data)


onnx_bert_raw_model, _expected_data = make_bert_onnx_model()
bert_input, bert_expected_output = _expected_data
with tempfile.TemporaryDirectory() as tmpdir:
onnx_bert_raw_model, _expected_data = make_bert_onnx_model(tmpdir)
bert_input, bert_expected_output = _expected_data


def method_caller_kwargs(
Expand All @@ -342,6 +343,7 @@ def to_numpy(item):
input_names = {k: list(v) for k, v in kwargs}
output_names = [o.name for o in ort_sess.get_outputs()]
out = getattr(ort_sess, method)(output_names, input_names)[0]
print("hahahah lkasjdfklasfdsaf")
return out


Expand Down

0 comments on commit b565287

Please sign in to comment.