Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of FloatingPointValue constructors for ease of use and extension #110

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion lib/src/arithmetic/floating_point/floating_point.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
export 'floating_point_adder_round.dart';
export 'floating_point_adder_simple.dart';
export 'floating_point_logic.dart';
export 'floating_point_value.dart';
export 'floating_point_values/floating_point_values.dart';
28 changes: 22 additions & 6 deletions lib/src/arithmetic/floating_point/floating_point_logic.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
//

import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/src/arithmetic/floating_point/floating_point_value.dart';
import 'package:rohd_hcl/src/exceptions.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

/// Flexible floating point logic representation
class FloatingPoint extends LogicStructure {
Expand Down Expand Up @@ -84,19 +83,36 @@ class FloatingPoint64 extends FloatingPoint {
mantissaWidth: FloatingPoint64Value.mantissaWidth);
}

/// Eight-bit floating point representation for deep learning
class FloatingPoint8 extends FloatingPoint {
/// Eight-bit floating point representation for deep learning: E4M3
class FloatingPoint8E4M3 extends FloatingPoint {
/// Calculate mantissa width and sanitize
static int _calculateMantissaWidth(int exponentWidth) {
final mantissaWidth = 7 - exponentWidth;
if (!FloatingPoint8Value.isLegal(exponentWidth, mantissaWidth)) {
if (!FloatingPoint8E4M3Value.isLegal(exponentWidth, mantissaWidth)) {
throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2');
} else {
return mantissaWidth;
}
}

/// Construct an 8-bit floating point number
FloatingPoint8({required super.exponentWidth})
FloatingPoint8E4M3({required super.exponentWidth})
: super(mantissaWidth: _calculateMantissaWidth(exponentWidth));
}

/// Eight-bit floating point representation for deep learning: E5M2
class FloatingPoint8E5M2 extends FloatingPoint {
/// Calculate mantissa width and sanitize
static int _calculateMantissaWidth(int exponentWidth) {
final mantissaWidth = 7 - exponentWidth;
if (!FloatingPoint8E5M2Value.isLegal(exponentWidth, mantissaWidth)) {
throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2');
} else {
return mantissaWidth;
}
}

/// Construct an 8-bit floating point number
FloatingPoint8E5M2({required super.exponentWidth})
: super(mantissaWidth: _calculateMantissaWidth(exponentWidth));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause
//
// floating_point_32_value.dart
// Implementation of 32-bit Floating-Point value representations.
//
// 2024 October 15
// Authors:
// Max Korbel <[email protected]>
// Desmond A Kirkpatrick <[email protected]>

import 'dart:typed_data';
import 'package:meta/meta.dart';
import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

/// A representation of a single-precision floating-point value.
class FloatingPoint32Value extends FloatingPointValue {
/// The exponent width
static const int exponentWidth = 8;

/// The mantissa width
static const int mantissaWidth = 23;

@override
@protected
int get constrainedExponentWidth => exponentWidth;

@override
@protected
int get constrainedMantissaWidth => mantissaWidth;

/// Constructor for a single precision floating point value
FloatingPoint32Value(
{required super.sign, required super.exponent, required super.mantissa});

/// Return the [FloatingPoint32Value] representing the constant specified
factory FloatingPoint32Value.getFloatingPointConstant(
FloatingPointConstants constantFloatingPoint) =>
FloatingPoint32Value.fromLogic(
FloatingPointValue.getFloatingPointConstant(
constantFloatingPoint, exponentWidth, mantissaWidth)
.value);

/// [FloatingPoint32Value] constructor from string representation of
/// individual bitfields
FloatingPoint32Value.ofBinaryStrings(
super.sign, super.exponent, super.mantissa)
: super.ofBinaryStrings();

/// [FloatingPoint32Value] constructor from spaced string representation of
/// individual bitfields
FloatingPoint32Value.ofSpacedBinaryString(super.fp)
: super.ofSpacedBinaryString();

/// [FloatingPoint32Value] constructor from a single string representing
/// space-separated bitfields
FloatingPoint32Value.ofString(String fp, {super.radix})
: super.ofString(fp, exponentWidth, mantissaWidth);

/// [FloatingPoint32Value] constructor from a set of [BigInt]s of the binary
/// representation
FloatingPoint32Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
: super.ofBigInts();

/// [FloatingPoint32Value] constructor from a set of [int]s of the binary
/// representation
FloatingPoint32Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();

/// Numeric conversion of a [FloatingPoint32Value] from a host double
factory FloatingPoint32Value.fromDouble(double inDouble) {
final byteData = ByteData(4)..setFloat32(0, inDouble);
final accum = byteData.buffer
.asUint8List()
.map((b) => LogicValue.ofInt(b, 32))
.reduce((accum, v) => (accum << 8) | v);

return FloatingPoint32Value(
sign: accum[-1],
exponent: accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
mantissa: accum.slice(mantissaWidth - 1, 0));
}

/// Construct a [FloatingPoint32Value] from a Logic word
factory FloatingPoint32Value.fromLogic(LogicValue val) =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these don't need to be factorys, and also should also be updatd to fromLogicValue

FloatingPoint32Value(
sign: val[-1],
exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
mantissa: val.slice(mantissaWidth - 1, 0));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause
//
// floating_point_64_value.dart
// Implementation of 64-bit Floating-Point value representations.
//
// 2024 October 15
// Authors:
// Max Korbel <[email protected]>
// Desmond A Kirkpatrick <[email protected]>

import 'dart:typed_data';

import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

/// A representation of a double-precision floating-point value.
class FloatingPoint64Value extends FloatingPointValue {
/// The exponent width
static const int _exponentWidth = 11;

/// The mantissa width
static const int _mantissaWidth = 52;

/// return the exponent width
static int get exponentWidth => _exponentWidth;

/// return the mantissa width
static int get mantissaWidth => _mantissaWidth;

/// Constructor for a double precision floating point value
FloatingPoint64Value(
{required super.sign, required super.mantissa, required super.exponent});

/// Return the [FloatingPoint64Value] representing the constant specified
factory FloatingPoint64Value.getFloatingPointConstant(
FloatingPointConstants constantFloatingPoint) =>
FloatingPoint64Value.fromLogic(
FloatingPointValue.getFloatingPointConstant(
constantFloatingPoint, _exponentWidth, _mantissaWidth)
.value);

/// [FloatingPoint64Value] constructor from string representation of
/// individual bitfields
factory FloatingPoint64Value.ofBinaryStrings(
String sign, String exponent, String mantissa) =>
FloatingPoint64Value(
sign: LogicValue.of(sign),
exponent: LogicValue.of(exponent),
mantissa: LogicValue.of(mantissa));

/// [FloatingPoint64Value] constructor from spaced string representation of
/// individual bitfields
FloatingPoint64Value.ofSpacedBinaryString(super.fp)
: super.ofSpacedBinaryString();

/// [FloatingPoint64Value] constructor from a single string representing
/// space-separated bitfields
FloatingPoint64Value.ofString(String fp, {super.radix})
: super.ofString(fp, exponentWidth, mantissaWidth);

/// [FloatingPoint64Value] constructor from a set of [BigInt]s of the binary
/// representation
FloatingPoint64Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
: super.ofBigInts();

/// [FloatingPoint64Value] constructor from a set of [int]s of the binary
/// representation
FloatingPoint64Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();

/// Numeric conversion of a [FloatingPoint64Value] from a host double
factory FloatingPoint64Value.fromDouble(double inDouble) {
final byteData = ByteData(8)..setFloat64(0, inDouble);
final accum = byteData.buffer
.asUint8List()
.map((b) => LogicValue.ofInt(b, 64))
.reduce((accum, v) => (accum << 8) | v);

return FloatingPoint64Value(
sign: accum[-1],
exponent:
accum.slice(_exponentWidth + _mantissaWidth - 1, _mantissaWidth),
mantissa: accum.slice(_mantissaWidth - 1, 0));
}

/// Construct a [FloatingPoint32Value] from a Logic word
factory FloatingPoint64Value.fromLogic(LogicValue val) =>
FloatingPoint64Value(
sign: val[-1],
exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth),
mantissa: val.slice(mantissaWidth - 1, 0));
}
Loading
Loading