Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 第四期】No.1:为 Paddle 新增 finfo API update #406

Merged
merged 4 commits into from
Mar 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 102 additions & 73 deletions rfcs/APIs/20220330_api_design_for_finfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

| API名称 | paddle.finfo |
| ------------------------------------------------------------ | -------------------------------- |
| 提交作者<input type="checkbox" class="rowselector hidden"> | 林旭(isLinXu) |
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2022-04-12 |
| 版本号 | V2.0 |
| 提交作者<input type="checkbox" class="rowselector hidden"> | lisamhy,林旭(isLinXu) |
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2022-04-12 |
| 版本号 | V2.0 |
| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | develop |
| 文件名 | 20220330_api-design_for_finfo.md |
| 文件名 | 20220330_api-design_for_finfo.md |

# 一、概述

Expand Down Expand Up @@ -484,91 +484,120 @@ API设计为`paddle.finfo(dtype)`,根据选择计算方法(比如eps、max、m

通过设计实现与API对应的Class,并通过pybind将相应的成员函数绑定到python,从而实现该API。

- `.h`头文件定义声明
- `pybind.cc` finfo class 实现

```cpp
namespace paddle {
namespace pybind {
void BindFinfoVarDsec(pybind11::module *m);
void BindIinfoVarDsec(pybind11::module *m);
}
}
struct finfo {
int64_t bits;
double eps;
double min; // lowest()
double max;
double tiny;
double smallest_normal; // min()
double resolution;
std::string dtype;

explicit finfo(const framework::proto::VarType::Type &type) {
switch (type) {
case framework::proto::VarType::FP16:
eps = std::numeric_limits<paddle::platform::float16>::epsilon();
min = std::numeric_limits<paddle::platform::float16>::lowest();
max = std::numeric_limits<paddle::platform::float16>::max();
smallest_normal = std::numeric_limits<paddle::platform::float16>::min();
tiny = smallest_normal;
resolution = std::pow(
10, -std::numeric_limits<paddle::platform::float16>::digits10);
bits = 16;
dtype = "float16";
break;
case framework::proto::VarType::FP32:
case framework::proto::VarType::COMPLEX64:
eps = std::numeric_limits<float>::epsilon();
min = std::numeric_limits<float>::lowest();
max = std::numeric_limits<float>::max();
smallest_normal = std::numeric_limits<float>::min();
tiny = smallest_normal;
resolution = std::pow(10, -std::numeric_limits<float>::digits10);
bits = 32;
dtype = "float32";
break;
case framework::proto::VarType::FP64:
case framework::proto::VarType::COMPLEX128:
eps = std::numeric_limits<double>::epsilon();
min = std::numeric_limits<double>::lowest();
max = std::numeric_limits<double>::max();
smallest_normal = std::numeric_limits<double>::min();
tiny = smallest_normal;
resolution = std::pow(10, -std::numeric_limits<double>::digits10);
bits = 64;
dtype = "float64";
break;
case framework::proto::VarType::BF16:
eps = std::numeric_limits<paddle::platform::bfloat16>::epsilon();
min = std::numeric_limits<paddle::platform::bfloat16>::lowest();
max = std::numeric_limits<paddle::platform::bfloat16>::max();
smallest_normal =
std::numeric_limits<paddle::platform::bfloat16>::min();
tiny = smallest_normal;
resolution = std::pow(
10, -std::numeric_limits<paddle::platform::bfloat16>::digits10);
bits = 16;
dtype = "bfloat16";
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"the argument of paddle.finfo can only be paddle.float32, "
"paddle.float64, paddle.float16, paddle.bfloat16"
"paddle.complex64, or paddle.complex128"));
break;
}
}
};
```

- `.cc`绑定实现设计
- `pybind.cc` finfo 绑定实现

```cpp

void BindFInfoVarDsec(pybind11::module *m){
pybind11::class_<pd::VarDesc> finfo_var_desc(*m, "VarDesc", "");
finfo_var_desc.def(pybind11::init<const std::string &>())
.def("bits", &pd::Tinfo::Bits)
.def("eps", &pd::Tinfo::Eps)
.def("min", &pd::Tinfo::Min)
.def("max", &pd::Tinfo::Max)
.def("tiny", &pd::Tinfo::Tiny)
.def("resolution", &pd::Tinfo::Resolution)
}
py::class_<finfo>(m, "finfo")
.def(py::init<const framework::proto::VarType::Type &>())
.def_readonly("min", &finfo::min)
.def_readonly("max", &finfo::max)
.def_readonly("bits", &finfo::bits)
.def_readonly("eps", &finfo::eps)
.def_readonly("resolution", &finfo::resolution)
.def_readonly("smallest_normal", &finfo::smallest_normal)
.def_readonly("tiny", &finfo::tiny)
.def_readonly("dtype", &finfo::dtype)
.def("__repr__", [](const finfo &a) {
std::ostringstream oss;
oss << "paddle.finfo(min=" << a.min;
oss << ", max=" << a.max;
oss << ", eps=" << a.eps;
oss << ", resolution=" << a.resolution;
oss << ", smallest_normal=" << a.smallest_normal;
oss << ", tiny=" << a.tiny;
oss << ", bits=" << a.bits;
oss << ", dtype=" << a.dtype << ")";
return oss.str();
});
```

- `dtype.py` python 暴露 finfo API

```python
from ..fluid.core import finfo as core_finfo

def finfo(dtype):
return core_finfo(dtype)
```

实现思路:

- 从调研Torch的实现方案来看,它并没有使用OP或者重写Kernel来进行实现,而是通过设计实现一个Class来进行返回API结果。

- 因此要实现该API,需要如上抽象出一个符合要求的Class,同时并声明定义类下的成员函数来分别实现功能

- 通过类的成员函数分别来实现eps、min、max等函数,通过Pybind11来进行接口与参数的绑定



## API实现方案

在paddle/fluid/framework/Info.h与Info.cc下新增实现函数
定义class为`Tinfo`(借鉴Torch的结构设计,将finfo与iinfo合并为一个类进行实现)

```c
class Tinfo {
public:
int Bits(const at::ScalarType& type)
float Eps(const at::ScalarType& type)
float Min(const at::ScalarType& type)
float Max(const at::ScalarType& type)
float Tiny(const at::ScalarType& type)
float Resolution(const at::ScalarType& type)
}
```

`.cc`实现

```cpp
int Tinfo::Bits(const at::ScalarType& type){
int bits = elementSize(self->type) * 8;
return THPUtils_packInt64(bits);
}

float Tinfo::Eps(const at::ScalarType& type){
return std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon());
}

float Tinfo::Min(const at::ScalarType& type){
return std::numeric_limits<at::scalar_value_type<scalar_t>::type>::lowest());
}

float Tinfo::Max(const at::ScalarType& type){
return std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
}

float Tinfo::Tiny(const at::ScalarType& type){
return std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
}

float Tinfo::Resolution(const at::ScalarType& type){
return std::numeric_limits<at::scalar_value_type<scalar_t>::type>::resolution());

}
```
- 通过类的成员函数分别来实现 eps、min、max、bits、resolution、tiny、smallest_normal、dtype 等函数,通过Pybind11来进行接口与参数的绑定



Expand Down