[BUG] Running Inference from ONNX model #2133

Open · 2 of 19 tasks
dfustes opened this issue Nov 13, 2023 · 1 comment

SynapseML version

1.0.1

System information

  • Language version: Python 3.8
  • Spark version: 3.2.1
  • Spark platform: Python standalone

Describe the problem

Hi, I'm trying to run an ONNX model exported from Hugging Face. I can execute the model with ONNX Runtime in Python without issues. However, when I load it with SynapseML's ONNXModel, the job fails because Spark can't deserialize the model output: the write aborts with "[F is not a valid external type for schema of float" (full trace below). In JVM terms, "[F" is the type descriptor for float[], so the runtime is handing back a float array where the inferred schema expects a scalar float.
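
For reference, this is roughly how I verified the model in plain ONNX Runtime. A minimal sketch, assuming the same model.onnx and ./tokenizer paths used in the repro below:

import onnxruntime as ort
from transformers import MPNetTokenizerFast

# Inspect the declared outputs (name, shape, element type).
sess = ort.InferenceSession("model.onnx")
print([(o.name, o.shape, o.type) for o in sess.get_outputs()])

tokenizer = MPNetTokenizerFast.from_pretrained("./tokenizer")
enc = tokenizer("This is an example sentence", max_length=512,
                padding="max_length", return_attention_mask=True,
                return_token_type_ids=True, return_tensors="np")
# Feed only the inputs the model actually declares.
feed = {i.name: enc[i.name] for i in sess.get_inputs()}
print(sess.run(None, feed)[0].shape)  # runs fine here, unlike the Spark path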

Code to reproduce issue

import pyspark
import synapse
from pyspark.sql.types import StringType
from synapse.ml.onnx import ONNXModel

spark = pyspark.sql.SparkSession.builder.appName("MyApp") \
    .config("spark.jars.packages",
            "com.microsoft.azure:synapseml_2.12:1.0.1,com.microsoft.onnxruntime:onnxruntime:1.8.1") \
    .config("spark.driver.memory", "10g") \
    .config("spark.executor.memory", "10g") \
    .getOrCreate()


def process_rows(rows):
    from transformers import MPNetTokenizerFast
    tokenizer_kwargs = dict(
        max_length=512,
        padding="max_length",
        return_attention_mask=True,
        return_token_type_ids=True
    )
    tokenizer = MPNetTokenizerFast.from_pretrained(
        "./tokenizer")

    for row in rows:
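        # Each yielded BatchEncoding keeps its tensors in a `data` attribute,
        # which Spark appears to infer as a struct column named "data"
        # (hence the data.* selects below).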
        yield tokenizer(row["text"], **tokenizer_kwargs)


data = spark.createDataFrame(["This is an example sentence", "Each sentence is converted"],
                             StringType()).withColumnRenamed("value", "text")

features = (
    data.rdd
    .mapPartitions(process_rows)
    .toDF()
    .select("data.input_ids", "data.attention_mask", "data.token_type_ids")
)

print(synapse.ml.core.__spark_package_version__)

model_path = "model.onnx"
onnx_ml = (
    ONNXModel()
    .setModelLocation(model_path)
    .setFeedDict({
        "input_ids": "input_ids",
        "attention_mask": "attention_mask",
        "token_type_ids": "token_type_ids",
    })
    .setFetchDict({"logits": "logits"})
    .setMiniBatchSize(5000)
)

onnx_ml.transform(features).write.mode("overwrite").json("news-relevance-predictions")
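
A quick way to see the mismatch before the failing write, a hedged sketch run in the same session:

# Hypothetical diagnostic: check how ONNXModel types the fetched column.
transformed = onnx_ml.transform(features)
transformed.printSchema()  # shows the inferred type of the "logits" column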

Other info / logs

  File "news_relevance_synapseml.py", line 45, in <module>
    onnx_ml.transform(features).write.mode("overwrite").json("news-relevance-predictions")
  File "/Users/diegofustes/opt/anaconda3/envs/dpe-tagging/lib/python3.8/site-packages/pyspark/sql/readwriter.py", line 846, in json
    self._jwrite.json(path)
  File "/Users/diegofustes/opt/anaconda3/envs/dpe-tagging/lib/python3.8/site-packages/py4j/java_gateway.py", line 1321, in __call__
    return_value = get_return_value(
  File "/Users/diegofustes/opt/anaconda3/envs/dpe-tagging/lib/python3.8/site-packages/pyspark/sql/utils.py", line 111, in deco
    return f(*a, **kw)
  File "/Users/diegofustes/opt/anaconda3/envs/dpe-tagging/lib/python3.8/site-packages/py4j/protocol.py", line 326, in get_return_value
    raise Py4JJavaError(
py4j.protocol.Py4JJavaError: An error occurred while calling o120.json.
: org.apache.spark.SparkException: Job aborted.
        at org.apache.spark.sql.errors.QueryExecutionErrors$.jobAbortedError(QueryExecutionErrors.scala:496)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:251)
        at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:186)
        at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:113)
        at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:111)
        at org.apache.spark.sql.execution.command.DataWritingCommandExec.executeCollect(commands.scala:125)
        at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.$anonfun$applyOrElse$1(QueryExecution.scala:110)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
        at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
        at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
        at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:110)
        at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:106)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:481)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:82)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:481)
        at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:30)
        at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:267)
        at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:263)
        at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
        at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:457)
        at org.apache.spark.sql.execution.QueryExecution.eagerlyExecuteCommands(QueryExecution.scala:106)
        at org.apache.spark.sql.execution.QueryExecution.commandExecuted$lzycompute(QueryExecution.scala:93)
        at org.apache.spark.sql.execution.QueryExecution.commandExecuted(QueryExecution.scala:91)
        at org.apache.spark.sql.execution.QueryExecution.assertCommandExecuted(QueryExecution.scala:128)
        at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:848)
        at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:382)
        at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:355)
        at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:239)
        at org.apache.spark.sql.DataFrameWriter.json(DataFrameWriter.scala:763)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
        at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
        at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 7 in stage 3.0 failed 1 times, most recent failure: Lost task 7.0 in stage 3.0 (TID 13) (192.168.1.46 executor driver): java.lang.RuntimeException: [F is not a valid external type for schema of float
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.serializefromobject_doConsume_0$(Unknown Source)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
        at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
        at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:286)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$write$16(FileFormatWriter.scala:229)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
        at org.apache.spark.scheduler.Task.run(Task.scala:131)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
        at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
        at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2454)
        at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2403)
        at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2402)
        at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
        at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
        at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
        at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2402)
        at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1160)
        at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1160)
        at scala.Option.foreach(Option.scala:407)
        at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1160)
        at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2642)
        at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2584)
        at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2573)
        at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
        at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:938)
        at org.apache.spark.SparkContext.runJob(SparkContext.scala:2214)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:218)
        ... 42 more
Caused by: java.lang.RuntimeException: [F is not a valid external type for schema of float
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.serializefromobject_doConsume_0$(Unknown Source)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
        at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
        at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:286)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$write$16(FileFormatWriter.scala:229)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
        at org.apache.spark.scheduler.Task.run(Task.scala:131)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
        ... 1 more

What component(s) does this bug affect?

  • area/cognitive: Cognitive project
  • area/core: Core project
  • area/deep-learning: DeepLearning project
  • area/lightgbm: LightGBM project
  • area/opencv: OpenCV project
  • area/vw: VW project
  • area/website: Website
  • area/build: Project build system
  • area/notebooks: Samples under notebooks folder
  • area/docker: Docker usage
  • area/models: Models-related issues

What language(s) does this bug affect?

  • language/scala: Scala source code
  • language/python: Pyspark APIs
  • language/r: R APIs
  • language/csharp: .NET APIs
  • language/new: Proposals for new client languages

What integration(s) does this bug affect?

  • integrations/synapse: Azure Synapse integrations
  • integrations/azureml: Azure ML integrations
  • integrations/databricks: Databricks integrations
dfustes added the bug label on Nov 13, 2023

Hey @dfustes 👋!
Thank you so much for reporting the issue/feature request 🚨.
Someone from the SynapseML team will be looking to triage this issue soon.
We appreciate your patience.

memoryz self-assigned this Nov 13, 2023