Skip to content

Commit

Permalink
Merge pull request #90 from recoules/fast-small-extract
Browse files Browse the repository at this point in the history
Add fast path when extraction leads to intnat
  • Loading branch information
antoinemine authored Jan 2, 2023
2 parents 655a0f5 + c65e82e commit 7d0ee55
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 30 deletions.
68 changes: 40 additions & 28 deletions caml_z.c
Original file line number Diff line number Diff line change
Expand Up @@ -902,43 +902,55 @@ CAMLprim value ml_z_format(value f, value v)
CAMLreturn(r);
}

/* Number of bits in a word, selected at compile time:
   64 when ARCH_SIXTYFOUR is defined, 32 otherwise. */
#ifndef ARCH_SIXTYFOUR
#define BITS_PER_WORD 32
#else
#define BITS_PER_WORD 64
#endif
/* Fast path for bit extraction when the result fits in an unboxed OCaml
   integer: returns bits [off, off+len) of arg as a tagged nonnegative
   integer, reading a negative arg in two's complement.

   Preconditions (enforced by the OCaml caller, not rechecked here):
   off >= 0 and 1 <= len < number of bits of an OCaml int.

   NOTE(review): Z_DECL/Z_ARG are assumed to expose the magnitude of arg as
   limbs ptr_arg[0..size_arg-1] plus a nonzero sign_arg for negative
   values -- confirm against their definitions.

   All bit manipulation is done on unsigned words so that shifting a 1 into
   the top bit and the "+1" wrap-around are well defined in C. */
CAMLprim value ml_z_extract_small(value arg, value off, value len)
{
  Z_DECL(arg);
  uintnat o, l; /* caml code ensures off and len are non signed */
  uintnat x;    /* extracted word, kept unsigned to avoid signed overflow */
  mp_size_t c1, c2, csz, i;
  mp_limb_t cr;
  Z_ARG(arg);
  o = (uintnat)Long_val(off);
  l = (uintnat)Long_val(len);
  c1 = o / Z_LIMB_BITS;   /* index of the limb holding bit "off" */
  c2 = o % Z_LIMB_BITS;   /* position of bit "off" inside that limb */
  csz = size_arg - c1;    /* number of limbs available from c1 upwards */
  if (csz > 0) {
    if (c2) {
      x = ptr_arg[c1] >> c2;
      /* The requested window spills into the next limb, if there is one. */
      if ((c2 + l > (intnat)Z_LIMB_BITS) && (csz > 1))
        x |= (ptr_arg[c1 + 1] << (Z_LIMB_BITS - c2));
    }
    else x = ptr_arg[c1];
  }
  else x = 0;             /* offset lies beyond the magnitude */
  if (sign_arg) {
    /* Two's complement of the magnitude is ~m + 1: complement the window,
       then add the carry, which reaches bit "off" only when every
       shifted-out (lower) bit of the magnitude is 0. */
    x = ~x;
    if (csz > 0) {
      /* carry (cr=0 if all shifted-out bits are 0) */
      cr = ptr_arg[c1] & (((mp_limb_t)1 << c2) - 1);
      for (i = 0; !cr && i < c1; i++)
        cr = ptr_arg[i];
      if (!cr) x ++;
    }
  }
  /* Keep only the "len" low bits; the result then fits a nonnegative intnat. */
  x &= ((uintnat)1 << l) - 1;
  return Val_long((intnat)x);
}

CAMLprim value ml_z_extract(value arg, value off, value len)
{
intnat o, l, x;
uintnat o, l; /* caml code ensures off and len are non signed */
intnat x;
mp_size_t sz, c1, c2, csz, i;
mp_limb_t cr;
value r;
Z_DECL(arg);
Z_MARK_OP;
MAYBE_UNUSED x;
o = Long_val(off);
l = Long_val(len);
if (o < 0) caml_invalid_argument("Z.extract: negative bit offset");
if (l <= 0) caml_invalid_argument("Z.extract: nonpositive bit length");
#if Z_USE_NATINT
/* Fast path */
if (Is_long(arg)) {
x = Long_val(arg);
/* Shift away low "o" bits. If "o" too big, just replicate sign bit. */
if (o >= BITS_PER_WORD) o = BITS_PER_WORD - 1;
x = x >> o;
/* Extract "l" low bits, if "l" is small enough */
if (l < BITS_PER_WORD - 1) {
x = x & (((intnat)1 << l) - 1);
return Val_long(x);
} else {
/* If x >= 0, the extraction of "l" low bits keeps x unchanged. */
if (x >= 0) return Val_long(x);
/* If x < 0, fall through slow path */
}
}
#endif
o = (uintnat)Long_val(off);
l = (uintnat)Long_val(len);
Z_MARK_SLOW;
{
CAMLparam1(arg);
Expand Down
4 changes: 4 additions & 0 deletions tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ bench:: timings.exe
# Benchmark rule for pi: requires pi.exe, prints a banner, then times a run of
# ./pi.exe with argument 10000, discarding its standard output.
# Written as a double-colon rule, like the other "bench::" rule of this file.
bench:: pi.exe
@echo "Benchmarking pi"; time ./pi.exe 10000 > /dev/null

# Test rule for extract: requires tst_extract.exe and runs it.
# Prints "tst_extract: passed" when the program exits successfully; otherwise
# prints "tst_extract: FAILED" and makes the recipe fail with exit status 2.
test:: tst_extract.exe
@echo "Testing extract..."
@if ./tst_extract.exe; then echo "tst_extract: passed"; else echo "tst_extract: FAILED"; exit 2; fi

tofloat.exe: tofloat.ml setround.o ../zarith.cmxa
ocamlopt -I .. -ccopt "-L.." zarith.cmxa -o tofloat.exe \
setround.o tofloat.ml
Expand Down
35 changes: 35 additions & 0 deletions tests/tst_extract.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module I = Z

(* [pr ch x] writes the string form of [x] (as given by [I.to_string]) on
   channel [ch], then flushes [ch].  Its argument order makes it usable with
   the [%a] directive of [Printf.printf]. *)
let pr ch x =
  let text = I.to_string x in
  output_string ch text;
  flush ch

(* [chk_extract x o l] checks [I.extract x o l] against a reference value
   computed with shifts and masks: [(x asr o) land (2^l - 1)].
   On mismatch, prints the arguments with the expected and actual values,
   then raises [Failure "test failed"].
   The comparison uses [I.equal] rather than the polymorphic [<>]. *)
let chk_extract x o l =
  let expected =
    I.logand (I.shift_right x o) (I.pred (I.shift_left (I.of_int 1) l))
  and actual =
    I.extract x o l in
  if not (I.equal actual expected) then begin
    Printf.printf "extract %a %d %d = %a found %a\n"
      pr x o l pr expected pr actual;
    failwith "test failed"
  end

(* [doit ()] runs [chk_extract] for every length [l] in 1..128, every offset
   [o] in 0..256 and, for each [n] in 0..256, a family of values built around
   [x = 2^n]: [x], its square and cube, [x+1], [x-1], and the negation of
   each of them.  Progress is printed every 16 lengths.
   Raises [Failure] (from [chk_extract]) on the first mismatch. *)
let doit () =
  let max_len = 128 in
  for l = 1 to max_len do
    (* Structural equality on ints; [%!] flushes the progress line. *)
    if l mod 16 = 0 then Printf.printf "%i/%i\n%!" l max_len;
    for o = 0 to 256 do
      for n = 0 to 256 do
        let x = I.shift_left I.one n in
        (* Square and cube are computed once and reused below. *)
        let x2 = I.mul x x in
        let x3 = I.mul x x2 in
        chk_extract x o l;
        chk_extract x2 o l;
        chk_extract x3 o l;
        chk_extract (I.succ x) o l;
        chk_extract (I.pred x) o l;
        chk_extract (I.neg x2) o l;
        chk_extract (I.neg x3) o l;
        chk_extract (I.neg x) o l;
        chk_extract (I.neg (I.succ x)) o l;
        chk_extract (I.neg (I.pred x)) o l
      done
    done
  done

(* Program entry point: run the whole test.  [let ()] makes the compiler
   check that [doit ()] really returns [unit]. *)
let () = doit ()
31 changes: 30 additions & 1 deletion z.ml
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ external fits_int: t -> bool = "ml_z_fits_int" [@@noalloc]
external fits_int32: t -> bool = "ml_z_fits_int32" [@@noalloc]
external fits_int64: t -> bool = "ml_z_fits_int64" [@@noalloc]
external fits_nativeint: t -> bool = "ml_z_fits_nativeint" [@@noalloc]
external extract: t -> int -> int -> t = "ml_z_extract"
external powm: t -> t -> t -> t = "ml_z_powm"
external pow: t -> int -> t = "ml_z_pow"
external powm_sec: t -> t -> t -> t = "ml_z_powm_sec"
Expand Down Expand Up @@ -323,6 +322,36 @@ let testbit x n =
(* [is_odd n] is the result of [testbit_internal] on [n] at bit index 0. *)
let is_odd n = testbit_internal n 0
(* [is_even n] is the negation of [testbit_internal] on [n] at bit index 0. *)
let is_even n = if testbit_internal n 0 then false else true

(* C primitive "ml_z_extract_small": extraction variant declared [@@noalloc].
   In this file it is only called from [extract_internal], for a [t] that is
   not a small int and a length strictly below [Sys.int_size]. *)
external c_extract_small: t -> int -> int -> t
= "ml_z_extract_small" [@@noalloc]
(* C primitive "ml_z_extract": general extraction, used by [extract_internal]
   whenever neither fast path applies.  Arguments are the value, the bit
   offset and the bit length, in that order. *)
external c_extract: t -> int -> int -> t = "ml_z_extract"

(* [extract_internal x o l] extracts [l] bits of [x] starting at bit [o],
   without validating [o] and [l] (callers such as [extract] do that).
   Dispatches between:
   - pure OCaml arithmetic when [x] is a small int and the result is known
     to fit in an [int];
   - the non-allocating C primitive when [x] is not a small int but
     [l < Sys.int_size];
   - the general C primitive otherwise. *)
let extract_internal x o l =
  let short = l < Sys.int_size in
  if not (is_small_int x) then begin
    (* No allocation is needed when the result fits in an int. *)
    if short then c_extract_small x o l else c_extract x o l
  end else begin
    (* Clamp the offset: shifting by more than [Sys.int_size - 1] would only
       replicate the sign bit anyway. *)
    let o = if o < Sys.int_size then o else Sys.int_size - 1 in
    let z = unsafe_to_int x asr o in
    if short then
      (* Keep the [l] low bits. *)
      of_int (z land ((1 lsl l) - 1))
    else if z >= 0 then
      (* A nonnegative [z] is unchanged by keeping its [l] low bits. *)
      of_int z
    else
      (* Negative [z] with a long length: defer to the slow path. *)
      c_extract x o l
  end

(* [extract x o l] validates its arguments, then delegates to
   [extract_internal].
   @raise Invalid_argument if [o] is negative or if [l] is not positive. *)
let extract x o l =
  if o < 0 then invalid_arg "Z.extract: negative bit offset"
  else if l < 1 then invalid_arg "Z.extract: nonpositive bit length"
  else extract_internal x o l

let signed_extract x o l =
if o < 0 then invalid_arg "Z.signed_extract: negative bit offset";
if l < 1 then invalid_arg "Z.signed_extract: nonpositive bit length";
Expand Down
2 changes: 1 addition & 1 deletion z.mli
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ val log2up: t -> int
external size: t -> int = "ml_z_size" [@@noalloc]
(** Returns the number of machine words used to represent the number. *)

external extract: t -> int -> int -> t = "ml_z_extract"
val extract: t -> int -> int -> t
(** [extract a off len] returns a nonnegative number corresponding to bits
[off] to [off]+[len]-1 of [a].
Negative [a] are considered in infinite-length 2's complement
Expand Down

0 comments on commit 7d0ee55

Please sign in to comment.