Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

当模型predict参数为ndarray的时候会将参数修改成list #100

Open
danerlt opened this issue Jul 27, 2023 · 1 comment
Open

当模型predict参数为ndarray的时候会将参数修改成list #100

danerlt opened this issue Jul 27, 2023 · 1 comment

Comments

@danerlt
Copy link

danerlt commented Jul 27, 2023

我有一个推理函数如下:

@log_execution_time
def batch_predict(src_data: pd.DataFrame) -> list:
    torque_angle_trace: ndarray = preprocess(src_data)
    start = time.time()
    predict_res = model.predict(torque_angle_trace)
    ene = time.time()
    logger.info(f"单个预测耗时:{ene - start}")
    res = predict_res.tolist()
    return res

其中model.predict接收的参数是一个ndarry
下面是使用ThreadedStreamer之后的代码:

stream = ThreadedStreamer(model.predict, batch_size=10, max_latency=0.1)


def batch_predict_stream(src_data: pd.DataFrame) -> list:
    start = time.time()
    torque_angle_trace: ndarray = preprocess(src_data)
    predict_res = stream.predict([torque_angle_trace])
    ene = time.time()
    logger.info(f"stream预测耗时:{ene - start}")
    return predict_res

在调用stream.predict的时候我已经将数据处理成ndarray传进去了。
然后在运行的时候提示TypeError: X is not of a supported input data type.X must be in a supported mtype format for Panel, found <class 'list'>Use datatypes.check_is_mtype to check conformance with specifications.

我查看源码之后发现问题在下图将队列中取到的数据放到了一个list中然后传递给predict函数

image

请问是我使用ThreadedStreamer方法不对,还是predict函数不支持ndarray的参数。

@danerlt
Copy link
Author

danerlt commented Jul 27, 2023

针对上面的问题我将183行处的循环改成了如下所示,就正常运行了。

        model_inputs = []
        is_ndarray = False
        for i in batch:
            model_input = i[3]
            if isinstance(model_input, np.ndarray):
                is_ndarray = True
            model_inputs.append(model_input)
        if is_ndarray:
            model_inputs = np.vstack(model_inputs)
        model_outputs = self.model_predict(model_inputs)

        if is_ndarray:
            model_outputs = model_outputs.tolist()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant