-
Notifications
You must be signed in to change notification settings - Fork 155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add multinomial operator #1000
Add multinomial operator #1000
Conversation
felixhjh
commented
Dec 23, 2022
- Add multinomial operator
tests/test_auto_scan_multinomial.py
Outdated
else: | ||
num_samples = 1 | ||
|
||
dtype = draw(st.sampled_from(["float32"])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
输入的dtype可以为float64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我参考了下,Paddle官方对multinomial api的测试用例,其中并没有输入为float64的测试用例。我在想是文档的问题吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已支持float64的输入
可以merge一下develop,CI就可以通过了 |
tests/test_auto_scan_multinomial.py
Outdated
"num_samples": num_samples, | ||
"replacement": replacement, | ||
"input_spec_shape": [], | ||
"out_dtype": "int64", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是干嘛用的?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是因为,Paddle multinomial的输出是int64,onnx算子的输出可以是int32或者int64(默认是int32),为了和Paddle对齐,就把onnx的输出都指定为int64,这里检验下输出的type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
但是out_dtype应该没在这个文件中用到吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
delete out_dtype