diff --git a/lib/src/arithmetic/floating_point/floating_point.dart b/lib/src/arithmetic/floating_point/floating_point.dart index de0b3b40..d7cce900 100644 --- a/lib/src/arithmetic/floating_point/floating_point.dart +++ b/lib/src/arithmetic/floating_point/floating_point.dart @@ -4,4 +4,4 @@ export 'floating_point_adder_round.dart'; export 'floating_point_adder_simple.dart'; export 'floating_point_logic.dart'; -export 'floating_point_value.dart'; +export 'floating_point_values/floating_point_values.dart'; diff --git a/lib/src/arithmetic/floating_point/floating_point_logic.dart b/lib/src/arithmetic/floating_point/floating_point_logic.dart index fcdca310..3a2d7a9a 100644 --- a/lib/src/arithmetic/floating_point/floating_point_logic.dart +++ b/lib/src/arithmetic/floating_point/floating_point_logic.dart @@ -11,8 +11,7 @@ // import 'package:rohd/rohd.dart'; -import 'package:rohd_hcl/src/arithmetic/floating_point/floating_point_value.dart'; -import 'package:rohd_hcl/src/exceptions.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; /// Flexible floating point logic representation class FloatingPoint extends LogicStructure { @@ -84,12 +83,12 @@ class FloatingPoint64 extends FloatingPoint { mantissaWidth: FloatingPoint64Value.mantissaWidth); } -/// Eight-bit floating point representation for deep learning -class FloatingPoint8 extends FloatingPoint { +/// Eight-bit floating point representation for deep learning: E4M3 +class FloatingPoint8E4M3 extends FloatingPoint { /// Calculate mantissa width and sanitize static int _calculateMantissaWidth(int exponentWidth) { final mantissaWidth = 7 - exponentWidth; - if (!FloatingPoint8Value.isLegal(exponentWidth, mantissaWidth)) { + if (!FloatingPoint8E4M3Value.isLegal(exponentWidth, mantissaWidth)) { throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); } else { return mantissaWidth; @@ -97,6 +96,23 @@ class FloatingPoint8 extends FloatingPoint { } /// Construct an 8-bit floating point number - FloatingPoint8({required 
super.exponentWidth}) + FloatingPoint8E4M3({required super.exponentWidth}) + : super(mantissaWidth: _calculateMantissaWidth(exponentWidth)); +} + +/// Eight-bit floating point representation for deep learning: E5M2 +class FloatingPoint8E5M2 extends FloatingPoint { + /// Calculate mantissa width and sanitize + static int _calculateMantissaWidth(int exponentWidth) { + final mantissaWidth = 7 - exponentWidth; + if (!FloatingPoint8E5M2Value.isLegal(exponentWidth, mantissaWidth)) { + throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); + } else { + return mantissaWidth; + } + } + + /// Construct an 8-bit floating point number + FloatingPoint8E5M2({required super.exponentWidth}) : super(mantissaWidth: _calculateMantissaWidth(exponentWidth)); } diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_32_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_32_value.dart new file mode 100644 index 00000000..2ea84c59 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_32_value.dart @@ -0,0 +1,91 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_32_value.dart +// Implementation of 32-bit Floating-Point value representations. +// +// 2024 October 15 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick + +import 'dart:typed_data'; +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A representation of a single-precision floating-point value. 
+class FloatingPoint32Value extends FloatingPointValue {
+  /// The exponent width
+  static const int exponentWidth = 8;
+
+  /// The mantissa width
+  static const int mantissaWidth = 23;
+
+  @override
+  @protected
+  int get constrainedExponentWidth => exponentWidth;
+
+  @override
+  @protected
+  int get constrainedMantissaWidth => mantissaWidth;
+
+  /// Constructor for a single precision floating point value
+  FloatingPoint32Value(
+      {required super.sign, required super.exponent, required super.mantissa});
+
+  /// Return the [FloatingPoint32Value] representing the constant specified
+  factory FloatingPoint32Value.getFloatingPointConstant(
+          FloatingPointConstants constantFloatingPoint) =>
+      FloatingPoint32Value.fromLogic(
+          FloatingPointValue.getFloatingPointConstant(
+                  constantFloatingPoint, exponentWidth, mantissaWidth)
+              .value);
+
+  /// [FloatingPoint32Value] constructor from string representation of
+  /// individual bitfields
+  FloatingPoint32Value.ofBinaryStrings(
+      super.sign, super.exponent, super.mantissa)
+      : super.ofBinaryStrings();
+
+  /// [FloatingPoint32Value] constructor from spaced string representation of
+  /// individual bitfields
+  FloatingPoint32Value.ofSpacedBinaryString(super.fp)
+      : super.ofSpacedBinaryString();
+
+  /// [FloatingPoint32Value] constructor from a single radix-encoded string
+  /// holding the whole sign/exponent/mantissa word
+  FloatingPoint32Value.ofString(String fp, {super.radix})
+      : super.ofString(fp, exponentWidth, mantissaWidth);
+
+  /// Constructor from [BigInt] fields; forwards the fixed widths to the base
+  FloatingPoint32Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofBigInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Constructor from [int] fields; forwards the fixed widths to the base
+  FloatingPoint32Value.ofInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Numeric conversion of a [FloatingPoint32Value] from a host double
+  factory FloatingPoint32Value.fromDouble(double inDouble) {
+    final byteData = ByteData(4)..setFloat32(0, inDouble);
+    final accum = byteData.buffer
+        .asUint8List()
+        .map((b) => LogicValue.ofInt(b, 32))
+        .reduce((accum, v) => (accum << 8) | v);
+
+    return FloatingPoint32Value(
+        sign: accum[-1],
+        exponent: accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+        mantissa: accum.slice(mantissaWidth - 1, 0));
+  }
+
+  /// Construct a [FloatingPoint32Value] from a Logic word
+  factory FloatingPoint32Value.fromLogic(LogicValue val) =>
+      FloatingPoint32Value(
+          sign: val[-1],
+          exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+          mantissa: val.slice(mantissaWidth - 1, 0));
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_64_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_64_value.dart
new file mode 100644
index 00000000..40f18a91
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_64_value.dart
@@ -0,0 +1,93 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_64_value.dart
+// Implementation of 64-bit Floating-Point value representations.
+//
+// 2024 October 15
+// Authors:
+//  Max Korbel
+//  Desmond A Kirkpatrick
+
+import 'dart:typed_data';
+
+import 'package:rohd/rohd.dart';
+import 'package:rohd_hcl/rohd_hcl.dart';
+
+/// A representation of a double-precision floating-point value.
+class FloatingPoint64Value extends FloatingPointValue {
+  /// The exponent width
+  static const int exponentWidth = 11;
+
+  /// The mantissa width
+  static const int mantissaWidth = 52;
+
+  // Pin the field widths so the base constructor rejects non-FP64 operands.
+  @override
+  int get constrainedExponentWidth => exponentWidth;
+
+  @override
+  int get constrainedMantissaWidth => mantissaWidth;
+
+  /// Constructor for a double precision floating point value
+  FloatingPoint64Value(
+      {required super.sign, required super.mantissa, required super.exponent});
+
+  /// Return the [FloatingPoint64Value] representing the constant specified
+  factory FloatingPoint64Value.getFloatingPointConstant(
+          FloatingPointConstants constantFloatingPoint) =>
+      FloatingPoint64Value.fromLogic(
+          FloatingPointValue.getFloatingPointConstant(
+                  constantFloatingPoint, exponentWidth, mantissaWidth)
+              .value);
+
+  /// [FloatingPoint64Value] constructor from string representation of
+  /// individual bitfields
+  factory FloatingPoint64Value.ofBinaryStrings(
+          String sign, String exponent, String mantissa) =>
+      FloatingPoint64Value(
+          sign: LogicValue.of(sign),
+          exponent: LogicValue.of(exponent),
+          mantissa: LogicValue.of(mantissa));
+
+  /// [FloatingPoint64Value] constructor from spaced string representation of
+  /// individual bitfields
+  FloatingPoint64Value.ofSpacedBinaryString(super.fp)
+      : super.ofSpacedBinaryString();
+
+  /// [FloatingPoint64Value] constructor from a single radix-encoded string
+  /// holding the whole sign/exponent/mantissa word
+  FloatingPoint64Value.ofString(String fp, {super.radix})
+      : super.ofString(fp, exponentWidth, mantissaWidth);
+
+  /// Constructor from [BigInt] fields; forwards the fixed widths to the base
+  FloatingPoint64Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofBigInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Constructor from [int] fields; forwards the fixed widths to the base
+  FloatingPoint64Value.ofInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Numeric conversion of a [FloatingPoint64Value] from a host double
+  factory FloatingPoint64Value.fromDouble(double inDouble) {
+    final byteData = ByteData(8)..setFloat64(0, inDouble);
+    final accum = byteData.buffer
+        .asUint8List()
+        .map((b) => LogicValue.ofInt(b, 64))
+        .reduce((accum, v) => (accum << 8) | v);
+
+    return FloatingPoint64Value(
+        sign: accum[-1],
+        exponent: accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+        mantissa: accum.slice(mantissaWidth - 1, 0));
+  }
+
+  /// Construct a [FloatingPoint64Value] from a Logic word
+  factory FloatingPoint64Value.fromLogic(LogicValue val) =>
+      FloatingPoint64Value(
+          sign: val[-1],
+          exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+          mantissa: val.slice(mantissaWidth - 1, 0));
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_8_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_8_value.dart
new file mode 100644
index 00000000..2a8b598e
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_8_value.dart
@@ -0,0 +1,179 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_8_value.dart
+// Implementation of 8-bit Floating-Point value representations.
+//
+// 2024 October 15
+// Authors:
+//  Max Korbel
+//  Desmond A Kirkpatrick
+
+import 'dart:math';
+import 'package:meta/meta.dart';
+import 'package:rohd/rohd.dart';
+import 'package:rohd_hcl/rohd_hcl.dart';
+
+/// The E4M3 representation of an 8-bit floating point value as defined in
+/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
+class FloatingPoint8E4M3Value extends FloatingPointValue {
+  /// The exponent width
+  static const int exponentWidth = 4;
+
+  /// The mantissa width
+  static const int mantissaWidth = 3;
+
+  @override
+  @protected
+  int get constrainedExponentWidth => exponentWidth;
+
+  @override
+  @protected
+  int get constrainedMantissaWidth => mantissaWidth;
+
+  /// The largest magnitude representable by the E4M3 format
+  static double get maxValue => 448.toDouble();
+
+  /// The smallest positive magnitude representable by the E4M3 format
+  static double get minValue => pow(2, -9).toDouble();
+
+  /// Return if the exponent and mantissa widths match E4M3
+  static bool isLegal(int exponentWidth, int mantissaWidth) =>
+      (exponentWidth == 4) && (mantissaWidth == 3);
+
+  /// Constructor for an E4M3 8-bit floating point value
+  FloatingPoint8E4M3Value(
+      {required super.sign, required super.exponent, required super.mantissa});
+
+  /// [FloatingPoint8E4M3Value] constructor from string representation of
+  /// individual bitfields
+  FloatingPoint8E4M3Value.ofBinaryStrings(
+      super.sign, super.exponent, super.mantissa)
+      : super.ofBinaryStrings();
+
+  /// [FloatingPoint8E4M3Value] constructor from spaced string representation of
+  /// individual bitfields
+  FloatingPoint8E4M3Value.ofSpacedBinaryString(super.fp)
+      : super.ofSpacedBinaryString();
+
+  /// [FloatingPoint8E4M3Value] constructor from a single radix-encoded string
+  /// holding the whole sign/exponent/mantissa word
+  FloatingPoint8E4M3Value.ofString(String fp, {super.radix})
+      : super.ofString(fp, exponentWidth, mantissaWidth);
+
+  /// Constructor from [BigInt] fields; forwards the fixed widths to the base
+  FloatingPoint8E4M3Value.ofBigInts(super.exponent, super.mantissa,
+      {super.sign})
+      : super.ofBigInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Constructor from [int] fields; forwards the fixed widths to the base
+  FloatingPoint8E4M3Value.ofInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Numeric conversion of a [FloatingPoint8E4M3Value] from a host double
+  factory FloatingPoint8E4M3Value.fromDouble(double inDouble) {
+    if ((inDouble.abs() > maxValue) ||
+        ((inDouble != 0) && (inDouble.abs() < minValue))) {
+      throw RohdHclException('Number exceeds E4M3 range');
+    }
+    final fpv = FloatingPointValue.fromDouble(inDouble,
+        exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+    return FloatingPoint8E4M3Value(
+        sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa);
+  }
+
+  /// Construct a [FloatingPoint8E4M3Value] from a Logic word
+  factory FloatingPoint8E4M3Value.fromLogic(LogicValue val) {
+    if (val.width != 8) {
+      throw RohdHclException('Width must be 8');
+    }
+    return FloatingPoint8E4M3Value(
+        sign: val[-1] == LogicValue.one ? LogicValue.one : LogicValue.zero,
+        exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+        mantissa: val.slice(mantissaWidth - 1, 0));
+  }
+}
+
+/// The E5M2 representation of an 8-bit floating point value as defined in
+/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
+class FloatingPoint8E5M2Value extends FloatingPointValue {
+  /// The exponent width
+  static const int exponentWidth = 5;
+
+  /// The mantissa width
+  static const int mantissaWidth = 2;
+
+  @override
+  @protected
+  int get constrainedExponentWidth => exponentWidth;
+
+  @override
+  @protected
+  int get constrainedMantissaWidth => mantissaWidth;
+
+  /// The largest magnitude representable by the E5M2 format
+  static double get maxValue => 57344.toDouble();
+
+  /// The smallest positive magnitude representable by the E5M2 format
+  static double get minValue => pow(2, -16).toDouble();
+
+  /// Return if the exponent and mantissa widths match E5M2
+  static bool isLegal(int exponentWidth, int mantissaWidth) =>
+      (exponentWidth == 5) && (mantissaWidth == 2);
+
+  /// Constructor for an E5M2 8-bit floating point value
+  FloatingPoint8E5M2Value(
+      {required super.sign, required super.exponent, required super.mantissa});
+
+  /// [FloatingPoint8E5M2Value] constructor from string representation of
+  /// individual bitfields
+  FloatingPoint8E5M2Value.ofBinaryStrings(
+      super.sign, super.exponent, super.mantissa)
+      : super.ofBinaryStrings();
+
+  /// [FloatingPoint8E5M2Value] constructor from spaced string representation of
+  /// individual bitfields
+  FloatingPoint8E5M2Value.ofSpacedBinaryString(super.fp)
+      : super.ofSpacedBinaryString();
+
+  /// [FloatingPoint8E5M2Value] constructor from a single radix-encoded string
+  /// holding the whole sign/exponent/mantissa word
+  FloatingPoint8E5M2Value.ofString(String fp, {super.radix})
+      : super.ofString(fp, exponentWidth, mantissaWidth);
+
+  /// Constructor from [BigInt] fields; forwards the fixed widths to the base
+  FloatingPoint8E5M2Value.ofBigInts(super.exponent, super.mantissa,
+      {super.sign})
+      : super.ofBigInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Constructor from [int] fields; forwards the fixed widths to the base
+  FloatingPoint8E5M2Value.ofInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Numeric conversion of a [FloatingPoint8E5M2Value] from a host double
+  factory FloatingPoint8E5M2Value.fromDouble(double inDouble) {
+    if ((inDouble.abs() > maxValue) ||
+        ((inDouble != 0) && (inDouble.abs() < minValue))) {
+      throw RohdHclException('Number exceeds E5M2 range');
+    }
+    final fpv = FloatingPointValue.fromDouble(inDouble,
+        exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+    return FloatingPoint8E5M2Value(
+        sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa);
+  }
+
+  /// Construct a [FloatingPoint8E5M2Value] from a Logic word
+  factory FloatingPoint8E5M2Value.fromLogic(LogicValue val) {
+    if (val.width != 8) {
+      throw RohdHclException('Width must be 8');
+    }
+    return FloatingPoint8E5M2Value(
+        sign: val[-1],
+        exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+        mantissa: val.slice(mantissaWidth - 1, 0));
+  }
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_bf16_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_bf16_value.dart
new file mode 100644
index 00000000..8c47219a
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_bf16_value.dart
@@ -0,0 +1,84 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_bf16_value.dart
+// Implementation of BF16 Floating-Point value representations.
+//
+// 2024 October 15
+// Authors:
+//  Max Korbel
+//  Desmond A Kirkpatrick
+
+import 'package:meta/meta.dart';
+import 'package:rohd/rohd.dart';
+import 'package:rohd_hcl/rohd_hcl.dart';
+
+/// A representation of a BF16 floating-point value.
+class FloatingPointBF16Value extends FloatingPointValue {
+  /// The exponent width
+  static const int exponentWidth = 8;
+
+  /// The mantissa width
+  static const int mantissaWidth = 7;
+
+  @override
+  @protected
+  int get constrainedExponentWidth => exponentWidth;
+
+  @override
+  @protected
+  int get constrainedMantissaWidth => mantissaWidth;
+
+  /// Constructor for a BF16 floating point value
+  FloatingPointBF16Value(
+      {required super.sign, required super.exponent, required super.mantissa});
+
+  /// Return the [FloatingPointBF16Value] representing the constant specified
+  factory FloatingPointBF16Value.getFloatingPointConstant(
+          FloatingPointConstants constantFloatingPoint) =>
+      FloatingPointBF16Value.fromLogic(
+          FloatingPointValue.getFloatingPointConstant(
+                  constantFloatingPoint, exponentWidth, mantissaWidth)
+              .value);
+
+  /// [FloatingPointBF16Value] constructor from string representation of
+  /// individual bitfields
+  FloatingPointBF16Value.ofBinaryStrings(
+      super.sign, super.exponent, super.mantissa)
+      : super.ofBinaryStrings();
+
+  /// [FloatingPointBF16Value] constructor from spaced string representation of
+  /// individual bitfields
+  FloatingPointBF16Value.ofSpacedBinaryString(super.fp)
+      : super.ofSpacedBinaryString();
+
+  /// [FloatingPointBF16Value] constructor from a single radix-encoded string
+  /// holding the whole sign/exponent/mantissa word
+  FloatingPointBF16Value.ofString(String fp, {super.radix})
+      : super.ofString(fp, exponentWidth, mantissaWidth);
+
+  /// Constructor from [BigInt] fields; forwards the fixed widths to the base
+  FloatingPointBF16Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofBigInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Constructor from [int] fields; forwards the fixed widths to the base
+  FloatingPointBF16Value.ofInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Numeric conversion of a [FloatingPointBF16Value] from a host double
+  factory FloatingPointBF16Value.fromDouble(double inDouble) {
+    final fpv = FloatingPointValue.fromDouble(inDouble,
+        exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+    return FloatingPointBF16Value.fromLogic(fpv.value);
+  }
+
+  /// Construct a [FloatingPointBF16Value] from a Logic word
+  factory FloatingPointBF16Value.fromLogic(LogicValue val) =>
+      FloatingPointBF16Value(
+          sign: val[-1],
+          exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+          mantissa: val.slice(mantissaWidth - 1, 0));
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_fp16_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_fp16_value.dart
new file mode 100644
index 00000000..b9fa2e10
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_fp16_value.dart
@@ -0,0 +1,84 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_fp16_value.dart
+// Implementation of FP16 Floating-Point value representations.
+//
+// 2024 October 15
+// Authors:
+//  Max Korbel
+//  Desmond A Kirkpatrick
+
+import 'package:meta/meta.dart';
+import 'package:rohd/rohd.dart';
+import 'package:rohd_hcl/rohd_hcl.dart';
+
+/// A representation of an FP16 floating-point value.
+class FloatingPointFP16Value extends FloatingPointValue {
+  /// The exponent width
+  static const int exponentWidth = 5;
+
+  /// The mantissa width
+  static const int mantissaWidth = 10;
+
+  @override
+  @protected
+  int get constrainedExponentWidth => exponentWidth;
+
+  @override
+  @protected
+  int get constrainedMantissaWidth => mantissaWidth;
+
+  /// Constructor for an FP16 (half precision) floating point value
+  FloatingPointFP16Value(
+      {required super.sign, required super.exponent, required super.mantissa});
+
+  /// Return the [FloatingPointFP16Value] representing the constant specified
+  factory FloatingPointFP16Value.getFloatingPointConstant(
+          FloatingPointConstants constantFloatingPoint) =>
+      FloatingPointFP16Value.fromLogic(
+          FloatingPointValue.getFloatingPointConstant(
+                  constantFloatingPoint, exponentWidth, mantissaWidth)
+              .value);
+
+  /// [FloatingPointFP16Value] constructor from string representation of
+  /// individual bitfields
+  FloatingPointFP16Value.ofBinaryStrings(
+      super.sign, super.exponent, super.mantissa)
+      : super.ofBinaryStrings();
+
+  /// [FloatingPointFP16Value] constructor from spaced string representation of
+  /// individual bitfields
+  FloatingPointFP16Value.ofSpacedBinaryString(super.fp)
+      : super.ofSpacedBinaryString();
+
+  /// [FloatingPointFP16Value] constructor from a single radix-encoded string
+  /// holding the whole sign/exponent/mantissa word
+  FloatingPointFP16Value.ofString(String fp, {super.radix})
+      : super.ofString(fp, exponentWidth, mantissaWidth);
+
+  /// Constructor from [BigInt] fields; forwards the fixed widths to the base
+  FloatingPointFP16Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofBigInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Constructor from [int] fields; forwards the fixed widths to the base
+  FloatingPointFP16Value.ofInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Numeric conversion of a [FloatingPointFP16Value] from a host double
+  factory FloatingPointFP16Value.fromDouble(double inDouble) {
+    final fpv = FloatingPointValue.fromDouble(inDouble,
+        exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+    return FloatingPointFP16Value.fromLogic(fpv.value);
+  }
+
+  /// Construct a [FloatingPointFP16Value] from a Logic word
+  factory FloatingPointFP16Value.fromLogic(LogicValue val) =>
+      FloatingPointFP16Value(
+          sign: val[-1],
+          exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+          mantissa: val.slice(mantissaWidth - 1, 0));
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_tf32_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_tf32_value.dart
new file mode 100644
index 00000000..7438a75f
--- /dev/null
+++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_tf32_value.dart
@@ -0,0 +1,83 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// floating_point_tf32_value.dart
+// Implementation of TF32 Floating-Point value representations.
+//
+// 2024 October 15
+// Authors:
+//  Max Korbel
+//  Desmond A Kirkpatrick
+
+import 'package:meta/meta.dart';
+import 'package:rohd/rohd.dart';
+import 'package:rohd_hcl/rohd_hcl.dart';
+
+/// A representation of a TF32 floating-point value.
+class FloatingPointTF32Value extends FloatingPointValue {
+  /// The exponent width
+  static const int exponentWidth = 8;
+
+  /// The mantissa width
+  static const int mantissaWidth = 10;
+
+  @override
+  @protected
+  int get constrainedExponentWidth => exponentWidth;
+
+  @override
+  @protected
+  int get constrainedMantissaWidth => mantissaWidth;
+
+  /// Constructor for a TF32 floating point value
+  FloatingPointTF32Value(
+      {required super.sign, required super.exponent, required super.mantissa});
+
+  /// Return the [FloatingPointTF32Value] representing the constant specified
+  factory FloatingPointTF32Value.getFloatingPointConstant(
+          FloatingPointConstants constantFloatingPoint) =>
+      FloatingPointTF32Value.fromLogic(
+          FloatingPointValue.getFloatingPointConstant(
+                  constantFloatingPoint, exponentWidth, mantissaWidth)
+              .value);
+
+  /// [FloatingPointTF32Value] constructor from string representation of
+  /// individual bitfields
+  FloatingPointTF32Value.ofBinaryStrings(
+      super.sign, super.exponent, super.mantissa)
+      : super.ofBinaryStrings();
+
+  /// [FloatingPointTF32Value] constructor from spaced string representation of
+  /// individual bitfields
+  FloatingPointTF32Value.ofSpacedBinaryString(super.fp)
+      : super.ofSpacedBinaryString();
+
+  /// [FloatingPointTF32Value] constructor from a single radix-encoded string
+  /// holding the whole sign/exponent/mantissa word
+  FloatingPointTF32Value.ofString(String fp, {super.radix})
+      : super.ofString(fp, exponentWidth, mantissaWidth);
+
+  /// Constructor from [BigInt] fields; forwards the fixed widths to the base
+  FloatingPointTF32Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofBigInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Constructor from [int] fields; forwards the fixed widths to the base
+  FloatingPointTF32Value.ofInts(super.exponent, super.mantissa, {super.sign})
+      : super.ofInts(
+            exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+
+  /// Numeric conversion of a [FloatingPointTF32Value] from a host double
+  factory FloatingPointTF32Value.fromDouble(double inDouble) {
+    final fpv = FloatingPointValue.fromDouble(inDouble,
+        exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);
+    return FloatingPointTF32Value.fromLogic(fpv.value);
+  }
+
+  /// Construct a [FloatingPointTF32Value] from a Logic word
+  factory FloatingPointTF32Value.fromLogic(LogicValue val) =>
+      FloatingPointTF32Value(
+          sign: val[-1],
+          exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
+          mantissa: val.slice(mantissaWidth - 1, 0));
+}
diff --git a/lib/src/arithmetic/floating_point/floating_point_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_value.dart similarity index 55% rename from lib/src/arithmetic/floating_point/floating_point_value.dart rename to lib/src/arithmetic/floating_point/floating_point_values/floating_point_value.dart index a3ca64d6..fd968b9f 100644 --- a/lib/src/arithmetic/floating_point/floating_point_value.dart +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_value.dart @@ -10,7 +10,6 @@ // Desmond A Kirkpatrick { final LogicValue value; /// The sign of the value: 1 means a negative value - final LogicValue sign; + late final LogicValue sign; /// The exponent of the floating point: this is biased about a midpoint for /// positive and negative exponents @@ -119,135 +118,160 @@ class FloatingPointValue implements Comparable { final int _maxExp; final int _minExp; - /// Factory (static) constructor of a [FloatingPointValue] from - /// sign, mantissa and exponent - factory FloatingPointValue( + /// A Map from the (exponentWidth, mantissaWidth) pair to the + /// FloatingPointValue subtype + static Map< + ({int exponentWidth, int mantissaWidth}), + FloatingPointValue Function( + {required LogicValue sign, + required LogicValue exponent, + required LogicValue mantissa})> factoryConstructorMap = { + ( + exponentWidth: FloatingPoint32Value.exponentWidth, + mantissaWidth: FloatingPoint32Value.mantissaWidth + ): 
FloatingPoint32Value.new, + ( + exponentWidth: FloatingPoint64Value.exponentWidth, + mantissaWidth: FloatingPoint64Value.mantissaWidth + ): FloatingPoint64Value.new, + (exponentWidth: 4, mantissaWidth: 3): FloatingPoint8E4M3Value.new, + (exponentWidth: 5, mantissaWidth: 2): FloatingPoint8E5M2Value.new, + (exponentWidth: 5, mantissaWidth: 10): FloatingPointFP16Value.new, + (exponentWidth: 8, mantissaWidth: 7): FloatingPointBF16Value.new, + (exponentWidth: 8, mantissaWidth: 10): FloatingPointTF32Value.new, + }; + + /// Constructor for a [FloatingPointValue] with a sign, exponent, and + /// mantissa. + @protected + FloatingPointValue( + {required this.sign, required this.exponent, required this.mantissa}) + : value = [sign, exponent, mantissa].swizzle(), + _bias = computeBias(exponent.width), + _minExp = computeMinExponent(exponent.width), + _maxExp = computeMaxExponent(exponent.width) { + if (sign.width != 1) { + throw RohdHclException('FloatingPointValue: sign width must be 1'); + } + if (constrainedMantissaWidth != null && + mantissa.width != constrainedMantissaWidth) { + throw RohdHclException('FloatingPointValue: mantissa width must be ' + '$constrainedMantissaWidth'); + } + if (constrainedExponentWidth != null && + exponent.width != constrainedExponentWidth) { + throw RohdHclException('FloatingPointValue: exponent width must be ' + '$constrainedExponentWidth'); + } + } + + /// Constructs a [FloatingPointValue] with a sign, exponent, and mantissa + /// using one of the builders provided from [factoryConstructorMap] if + /// available, otherwise using the default constructor. 
+ factory FloatingPointValue.mapped( {required LogicValue sign, required LogicValue exponent, required LogicValue mantissa}) { - if (exponent.width == FloatingPoint32Value.exponentWidth && - mantissa.width == FloatingPoint32Value.mantissaWidth) { - return FloatingPoint32Value( - sign: sign, mantissa: mantissa, exponent: exponent); - } else if (exponent.width == FloatingPoint64Value._exponentWidth && - mantissa.width == FloatingPoint64Value._mantissaWidth) { - return FloatingPoint64Value( - sign: sign, mantissa: mantissa, exponent: exponent); - } else { - return FloatingPointValue.withConstraints( - sign: sign, mantissa: mantissa, exponent: exponent); + final key = (exponentWidth: exponent.width, mantissaWidth: mantissa.width); + + if (!factoryConstructorMap.containsKey(key)) { + return FloatingPointValue( + sign: sign, exponent: exponent, mantissa: mantissa); } + + return factoryConstructorMap[key]!( + sign: sign, exponent: exponent, mantissa: mantissa); } + /// Converts this [FloatingPointValue] to a [FloatingPointValue] with the same + /// sign, exponent, and mantissa using the constructor provided in + /// [factoryConstructorMap] if available, otherwise using the default + /// constructor. + FloatingPointValue toMappedType() => FloatingPointValue.mapped( + sign: sign, exponent: exponent, mantissa: mantissa); + + /// [constrainedMantissaWidth] is the hard-coded mantissa width of the + /// sub-class of this floating-point value + @protected + int? get constrainedMantissaWidth => null; + + /// [constrainedExponentWidth] is the hard-coded exponent width of the + /// sub-class of this floating-point value + @protected + int? 
get constrainedExponentWidth => null; + /// [FloatingPointValue] constructor from a binary string representation of /// individual bitfields - factory FloatingPointValue.ofBinaryStrings( - String sign, String exponent, String mantissa) { - if (sign.length != 1) { - throw RohdHclException('Sign string must be of length 1'); - } - - return FloatingPointValue( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); - } + FloatingPointValue.ofBinaryStrings( + String sign, String exponent, String mantissa) + : this( + sign: LogicValue.of(sign), + exponent: LogicValue.of(exponent), + mantissa: LogicValue.of(mantissa)); /// [FloatingPointValue] constructor from a single binary string representing /// space-separated bitfields - factory FloatingPointValue.ofSeparatedBinaryStrings(String fp) { - final s = fp.split(' '); - if (s.length != 3) { - throw RohdHclException('FloatingPointValue requires three strings ' - 'to initialize'); - } - return FloatingPointValue.ofBinaryStrings(s[0], s[1], s[2]); - } + FloatingPointValue.ofSpacedBinaryString(String fp) + : this.ofBinaryStrings( + fp.split(' ')[0], fp.split(' ')[1], fp.split(' ')[2]); /// [FloatingPointValue] constructor from a radix-encoded string /// representation and the size of the exponent and mantissa - factory FloatingPointValue.ofString( - String fp, int exponentWidth, int mantissaWidth, - {int radix = 2}) { + FloatingPointValue.ofString(String fp, int exponentWidth, int mantissaWidth, + {int radix = 2}) + : this.ofBinaryStrings( + _extractBinaryStrings(fp, exponentWidth, mantissaWidth, radix).sign, + _extractBinaryStrings(fp, exponentWidth, mantissaWidth, radix) + .exponent, + _extractBinaryStrings(fp, exponentWidth, mantissaWidth, radix) + .mantissa); + + /// Helper function for extracting binary strings from a longer + /// binary string and the known exponent and mantissa widths. 
+ static ({String sign, String exponent, String mantissa}) + _extractBinaryStrings( + String fp, int exponentWidth, int mantissaWidth, int radix) { final binaryFp = LogicValue.ofBigInt( BigInt.parse(fp, radix: radix), exponentWidth + mantissaWidth + 1) .bitString; - final (sign, exponent, mantissa) = ( - binaryFp.substring(0, 1), - binaryFp.substring(1, 1 + exponentWidth), - binaryFp.substring(1 + exponentWidth, 1 + exponentWidth + mantissaWidth) + return ( + sign: binaryFp.substring(0, 1), + exponent: binaryFp.substring(1, 1 + exponentWidth), + mantissa: binaryFp.substring( + 1 + exponentWidth, 1 + exponentWidth + mantissaWidth) ); - return FloatingPointValue.ofBinaryStrings(sign, exponent, mantissa); } + // TODO(desmonddak): toRadixString() would be useful, not limited to binary + /// [FloatingPointValue] constructor from a set of [BigInt]s of the binary /// representation and the size of the exponent and mantissa - factory FloatingPointValue.ofBigInts(BigInt exponent, BigInt mantissa, - {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - - return FloatingPointValue( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } + FloatingPointValue.ofBigInts(BigInt exponent, BigInt mantissa, + {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) + : this( + sign: LogicValue.ofBigInt(sign ? 
BigInt.one : BigInt.zero, 1), + exponent: LogicValue.ofBigInt(exponent, exponentWidth), + mantissa: LogicValue.ofBigInt(mantissa, mantissaWidth)); /// [FloatingPointValue] constructor from a set of [int]s of the binary /// representation and the size of the exponent and mantissa - factory FloatingPointValue.ofInts(int exponent, int mantissa, - {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), - LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth) - ); - - return FloatingPointValue( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// Constructor enabling subclasses. - FloatingPointValue.withConstraints( - {required this.sign, - required this.exponent, - required this.mantissa, - int? mantissaWidth, - int? exponentWidth}) - : value = [sign, exponent, mantissa].swizzle(), - _bias = computeBias(exponent.width), - _minExp = computeMinExponent(exponent.width), - _maxExp = computeMaxExponent(exponent.width) { - if (sign.width != 1) { - throw RohdHclException('FloatingPointValue: sign width must be 1'); - } - if (mantissaWidth != null && mantissa.width != mantissaWidth) { - throw RohdHclException( - 'FloatingPointValue: mantissa width must be $mantissaWidth'); - } - if (exponentWidth != null && exponent.width != exponentWidth) { - throw RohdHclException( - 'FloatingPointValue: exponent width must be $exponentWidth'); - } - } - - /// Construct a [FloatingPointValue] from a Logic word - factory FloatingPointValue.fromLogic( - int exponentWidth, int mantissaWidth, LogicValue val) { - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); - final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? 
BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - return FloatingPointValue( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } + FloatingPointValue.ofInts(int exponent, int mantissa, + {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) + : this( + sign: LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + exponent: LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), + mantissa: + LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth)); + + /// Construct a [FloatingPointValue] from a [LogicValue] + FloatingPointValue.fromLogicValue( + int exponentWidth, int mantissaWidth, LogicValue val) + : this( + sign: val[-1], + exponent: + val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth), + mantissa: val.slice(mantissaWidth - 1, 0)); /// Return the [FloatingPointValue] representing the constant specified factory FloatingPointValue.getFloatingPointConstant( @@ -526,9 +550,9 @@ class FloatingPointValue implements Comparable { (mantissa == other.mantissa); } + // TODO(desmonddak): figure out the difference with Infinity /// Return true if the represented floating point number is considered /// NaN or 'Not a Number' due to overflow - // TODO(desmonddak): figure out the difference with Infinity bool isNaN() { if ((exponent.width == 4) & (mantissa.width == 3)) { // FP8 E4M3 does not support infinities @@ -625,297 +649,3 @@ class FloatingPointValue implements Comparable { FloatingPointValue abs() => FloatingPointValue( sign: LogicValue.zero, exponent: exponent, mantissa: mantissa); } - -/// A representation of a single precision floating point value -class FloatingPoint32Value extends FloatingPointValue { - /// The exponent width - static const int exponentWidth = 8; - - /// The mantissa width - static const int mantissaWidth = 23; - - /// Constructor for a single precision floating point value - FloatingPoint32Value( - {required super.sign, 
required super.exponent, required super.mantissa}) - : super.withConstraints( - mantissaWidth: mantissaWidth, exponentWidth: exponentWidth); - - /// Return the [FloatingPoint32Value] representing the constant specified - factory FloatingPoint32Value.getFloatingPointConstant( - FloatingPointConstants constantFloatingPoint) => - FloatingPointValue.getFloatingPointConstant( - constantFloatingPoint, exponentWidth, mantissaWidth) - as FloatingPoint32Value; - - /// [FloatingPoint32Value] constructor from string representation of - /// individual bitfields - factory FloatingPoint32Value.ofStrings( - String sign, String exponent, String mantissa) => - FloatingPoint32Value( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); - - /// [FloatingPoint32Value] constructor from a single string representing - /// space-separated bitfields - factory FloatingPoint32Value.ofString(String fp) { - final s = fp.split(' '); - assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); - return FloatingPoint32Value.ofStrings(s[0], s[1], s[2]); - } - - /// [FloatingPoint32Value] constructor from a set of [BigInt]s of the binary - /// representation - factory FloatingPoint32Value.ofBigInts(BigInt exponent, BigInt mantissa, - {bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - - return FloatingPoint32Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// [FloatingPoint32Value] constructor from a set of [int]s of the binary - /// representation - factory FloatingPoint32Value.ofInts(int exponent, int mantissa, - {bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? 
BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), - LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth) - ); - - return FloatingPoint32Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// Numeric conversion of a [FloatingPoint32Value] from a host double - factory FloatingPoint32Value.fromDouble(double inDouble) { - final byteData = ByteData(4) - ..setFloat32(0, inDouble) - ..buffer.asUint8List(); - final bytes = byteData.buffer.asUint8List(); - final lv = bytes.map((b) => LogicValue.ofInt(b, 32)); - - final accum = lv.reduce((accum, v) => (accum << 8) | v); - - final sign = accum[-1]; - final exponent = - accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth); - final mantissa = accum.slice(mantissaWidth - 1, 0); - - return FloatingPoint32Value( - sign: sign, exponent: exponent, mantissa: mantissa); - } - - /// Construct a [FloatingPoint32Value] from a Logic word - factory FloatingPoint32Value.fromLogic(LogicValue val) { - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth); - final mantissa = val.slice(mantissaWidth - 1, 0); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? 
BigInt.one : BigInt.zero, 1), - exponent, - mantissa - ); - return FloatingPoint32Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } -} - -/// A representation of a double precision floating point value -class FloatingPoint64Value extends FloatingPointValue { - static const int _exponentWidth = 11; - static const int _mantissaWidth = 52; - - /// return the exponent width - static int get exponentWidth => _exponentWidth; - - /// return the mantissa width - static int get mantissaWidth => _mantissaWidth; - - /// Constructor for a double precision floating point value - FloatingPoint64Value( - {required super.sign, required super.mantissa, required super.exponent}) - : super.withConstraints( - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - - /// Return the [FloatingPoint64Value] representing the constant specified - factory FloatingPoint64Value.getFloatingPointConstant( - FloatingPointConstants constantFloatingPoint) => - FloatingPointValue.getFloatingPointConstant( - constantFloatingPoint, _exponentWidth, _mantissaWidth) - as FloatingPoint64Value; - - /// [FloatingPoint64Value] constructor from string representation of - /// individual bitfields - factory FloatingPoint64Value.ofStrings( - String sign, String exponent, String mantissa) => - FloatingPoint64Value( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); - - /// [FloatingPoint64Value] constructor from a single string representing - /// space-separated bitfields - factory FloatingPoint64Value.ofString(String fp) { - final s = fp.split(' '); - assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); - return FloatingPoint64Value.ofStrings(s[0], s[1], s[2]); - } - - /// [FloatingPoint64Value] constructor from a set of [BigInt]s of the binary - /// representation - factory FloatingPoint64Value.ofBigInts(BigInt exponent, BigInt mantissa, - {bool sign = false}) => - FloatingPointValue.ofBigInts(exponent, 
mantissa, - sign: sign, - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth) as FloatingPoint64Value; - - /// [FloatingPoint64Value] constructor from a set of [int]s of the binary - /// representation - factory FloatingPoint64Value.ofInts(int exponent, int mantissa, - {bool sign = false}) => - FloatingPointValue.ofInts(exponent, mantissa, - sign: sign, - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth) as FloatingPoint64Value; - - /// Numeric conversion of a [FloatingPoint64Value] from a host double - factory FloatingPoint64Value.fromDouble(double inDouble) { - final byteData = ByteData(8) - ..setFloat64(0, inDouble) - ..buffer.asUint8List(); - final bytes = byteData.buffer.asUint8List(); - final lv = bytes.map((b) => LogicValue.ofInt(b, 64)); - - final accum = lv.reduce((accum, v) => (accum << 8) | v); - - final sign = accum[-1]; - final exponent = - accum.slice(_exponentWidth + _mantissaWidth - 1, _mantissaWidth); - final mantissa = accum.slice(_mantissaWidth - 1, 0); - - return FloatingPoint64Value( - sign: sign, mantissa: mantissa, exponent: exponent); - } - - /// Construct a [FloatingPoint32Value] from a Logic word - factory FloatingPoint64Value.fromLogic(LogicValue val) { - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); - final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - return FloatingPoint64Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } -} - -/// A representation of a 8-bit floating point value as defined in -/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433). 
-class FloatingPoint8Value extends FloatingPointValue { - /// The exponent width - late final int exponentWidth; - - /// The mantissa width - late final int mantissaWidth; - - static double get _e4m3max => 448.toDouble(); - static double get _e5m2max => 57344.toDouble(); - static double get _e4m3min => pow(2, -9).toDouble(); - static double get _e5m2min => pow(2, -16).toDouble(); - - /// Return if the exponent and mantissa widths match E4M3 or E5M2 - static bool isLegal(int exponentWidth, int mantissaWidth) { - if (((exponentWidth == 4) & (mantissaWidth == 3)) | - ((exponentWidth == 5) & (mantissaWidth == 2))) { - return true; - } else { - return false; - } - } - - /// Constructor for a double precision floating point value - FloatingPoint8Value( - {required super.sign, required super.mantissa, required super.exponent}) - : super.withConstraints() { - exponentWidth = exponent.width; - mantissaWidth = mantissa.width; - if (!isLegal(exponentWidth, mantissaWidth)) { - throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); - } - } - - /// [FloatingPoint8Value] constructor from string representation of - /// individual bitfields - factory FloatingPoint8Value.ofStrings( - String sign, String exponent, String mantissa) => - FloatingPoint8Value( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); - - /// [FloatingPoint8Value] constructor from a single string representing - /// space-separated bitfields - factory FloatingPoint8Value.ofString(String fp) { - final s = fp.split(' '); - assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); - return FloatingPoint8Value.ofStrings(s[0], s[1], s[2]); - } - - /// Construct a [FloatingPoint8Value] from a Logic word - factory FloatingPoint8Value.fromLogic(LogicValue val, int exponentWidth) { - if (val.width != 8) { - throw RohdHclException('Width must be 8'); - } - - final mantissaWidth = 7 - exponentWidth; - if (!isLegal(exponentWidth, 
mantissaWidth)) { - throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); - } - - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); - final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - return FloatingPoint8Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// Numeric conversion of a [FloatingPoint8Value] from a host double - factory FloatingPoint8Value.fromDouble(double inDouble, - {required int exponentWidth}) { - final mantissaWidth = 7 - exponentWidth; - if (!isLegal(exponentWidth, mantissaWidth)) { - throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); - } - if (exponentWidth == 4) { - if ((inDouble > _e4m3max) | (inDouble < _e4m3min)) { - throw RohdHclException('Number exceeds E4M3 range'); - } - } else if (exponentWidth == 5) { - if ((inDouble > _e5m2max) | (inDouble < _e5m2min)) { - throw RohdHclException('Number exceeds E5M2 range'); - } - } - final fpv = FloatingPointValue.fromDouble(inDouble, - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - return FloatingPoint8Value( - sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa); - } -} diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_values.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_values.dart new file mode 100644 index 00000000..102b10cb --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_values.dart @@ -0,0 +1,10 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +export 'floating_point_32_value.dart'; +export 'floating_point_64_value.dart'; +export 'floating_point_8_value.dart'; +export 
'floating_point_bf16_value.dart'; +export 'floating_point_fp16_value.dart'; +export 'floating_point_tf32_value.dart'; +export 'floating_point_value.dart'; diff --git a/test/arithmetic/floating_point/floating_point_value_test.dart b/test/arithmetic/floating_point/floating_point_value_test.dart index 252d3326..d9574cab 100644 --- a/test/arithmetic/floating_point/floating_point_value_test.dart +++ b/test/arithmetic/floating_point/floating_point_value_test.dart @@ -118,7 +118,7 @@ void main() { }); test('FloatingPointValue string conversion', () { const str = '0 10000001 01000100000000000000000'; // 5.0625 - final fp = FloatingPoint32Value.ofString(str); + final fp = FloatingPoint32Value.ofSpacedBinaryString(str); expect(fp.toString(), str); expect(fp.toDouble(), 5.0625); }); @@ -163,14 +163,13 @@ void main() { exponentWidth: 4, mantissaWidth: 3); expect(val, fp.toDouble()); expect(str, fp.toString()); - final fp8 = FloatingPointValue.fromDouble(val, - exponentWidth: 4, mantissaWidth: 3); + final fp8 = FloatingPoint8E4M3Value.fromDouble(val); expect(val, fp8.toDouble()); expect(str, fp8.toString()); } }); - test('FP8: E5M2', () { + test('FPV8: E5M2', () { final corners = [ ['0 00000 00', 0.toDouble()], ['0 11110 11', 57344.toDouble()], @@ -185,8 +184,7 @@ void main() { exponentWidth: 5, mantissaWidth: 2); expect(val, fp.toDouble()); expect(str, fp.toString()); - final fp8 = FloatingPointValue.fromDouble(val, - exponentWidth: 5, mantissaWidth: 2); + final fp8 = FloatingPoint8E5M2Value.fromDouble(val); expect(val, fp8.toDouble()); expect(str, fp8.toString()); } @@ -199,16 +197,16 @@ void main() { final fp2 = FloatingPoint64() ..put(FloatingPoint64Value.fromDouble(1.5).value); expect(fp2.floatingPointValue.toDouble(), 1.5); - final fp8e4m3 = FloatingPoint8(exponentWidth: 4) - ..put(FloatingPoint8Value.fromDouble(1.5, exponentWidth: 4).value); + final fp8e4m3 = FloatingPoint8E4M3(exponentWidth: 4) + ..put(FloatingPoint8E4M3Value.fromDouble(1.5).value); 
expect(fp8e4m3.floatingPointValue.toDouble(), 1.5); - final fp8e5m2 = FloatingPoint8(exponentWidth: 5) - ..put(FloatingPoint8Value.fromDouble(1.5, exponentWidth: 5).value); + final fp8e5m2 = FloatingPoint8E5M2(exponentWidth: 5) + ..put(FloatingPoint8E5M2Value.fromDouble(1.5).value); expect(fp8e5m2.floatingPointValue.toDouble(), 1.5); }); test('FPV: round nearest even Guard and Sticky', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0000100000000000000000000000000000000000000000000001'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1000', '0001'); @@ -218,7 +216,7 @@ void main() { expect(fpConvert, equals(fpRound)); }); test('FPV: round nearest even Guard and Round', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0000110000000000000000000000000000000000000000000000'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1000', '0001'); @@ -229,7 +227,7 @@ void main() { expect(fpConvert, equals(fpRound)); }); test('FPV: rounding nearest even increment', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0001100000000000000000000000000000000000000000000000'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1000', '0010'); @@ -239,7 +237,7 @@ void main() { expect(fpConvert, equals(fpRound)); }); test('FPV: rounding nearest even increment carry into exponent', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '1111100000000000000000000000000000000000000000000000'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1001', '0000'); @@ -249,7 +247,7 @@ void main() { expect(fpConvert, equals(fpRound)); }); test('FPV: rounding nearest even truncate', () { - final fp64 = 
FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0010100000000000000000000000000000000000000000000000'); final fpTrunc = FloatingPointValue.ofBinaryStrings('0', '1000', '0010');