Skip to content

Commit

Permalink
[test] Add capability for testing the consistency of the cabac state
Browse files Browse the repository at this point in the history
  • Loading branch information
Jovasa committed Jun 28, 2022
1 parent 02931e8 commit 835b7fa
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 25 deletions.
17 changes: 15 additions & 2 deletions src/cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,11 @@ int uvg_config_init(uvg_config *cfg)
cfg->amvr = 0;

cfg->cclm = 0;



cfg->combine_intra_cus = 1;
cfg->force_inter = 0;

cfg->cabac_debug_file_name = NULL;
return 1;
}

Expand Down Expand Up @@ -1459,6 +1460,13 @@ int uvg_config_parse(uvg_config *cfg, const char *name, const char *value)
else if OPT("force-inter") {
cfg->force_inter = atobool(value);
}
else if OPT("cabac-debug-file") {
cfg->cabac_debug_file_name = strdup(value);
if(cfg->cabac_debug_file_name == NULL) {
fprintf(stderr, "Failed to allocate memory for cabac debug file name.\n");
return 0;
}
}
else {
return 0;
}
Expand Down Expand Up @@ -1812,6 +1820,11 @@ int uvg_config_validate(const uvg_config *const cfg)
}
}

if(cfg->owf != 0 && cfg->cabac_debug_file_name) {
fprintf(stderr, "OWF and cabac debugging are not supported at the same time.\n");
error = 1;
}

return !error;
}

Expand Down
3 changes: 3 additions & 0 deletions src/cli.c
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ static const struct option long_options[] = {
{ "no-combine-intra-cus", no_argument, NULL, 0 },
{ "force-inter", no_argument, NULL, 0 },
{ "no-force-inter", no_argument, NULL, 0 },
{ "cabac-debug-file", required_argument, NULL, 0 },
{0, 0, 0, 0}
};

Expand Down Expand Up @@ -460,6 +461,8 @@ void print_help(void)
" bits, lambda, distortion, and qp for each ctu.\n"
" These are meant for debugging and are not\n"
" written unless the prefix is defined.\n"
" --cabac-debug-file : A debug file for cabac context.\n"
" Ignore this, it is only for tests.\n"
"\n"
/* Word wrap to this width to stay under 80 characters (including ") *************/
"Video structure:\n"
Expand Down
4 changes: 4 additions & 0 deletions src/encode_coding_tree.c
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,10 @@ void uvg_encode_coding_tree(encoder_state_t * const state,
assert(0);
exit(1);
}
if (state->encoder_control->cabac_debug_file) {
fprintf(state->encoder_control->cabac_debug_file, "E %4d %4d %d", x, y, depth);
fwrite(&cabac->ctx, 1, sizeof(cabac->ctx), state->encoder_control->cabac_debug_file);
}

end:

Expand Down
14 changes: 14 additions & 0 deletions src/encoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,14 @@ encoder_control_t* uvg_encoder_control_init(const uvg_config *const cfg)
}
}

if(cfg->cabac_debug_file_name) {
encoder->cabac_debug_file = fopen(cfg->cabac_debug_file_name, "wb");
if (!encoder->cabac_debug_file) {
fprintf(stderr, "Could not open cabac debug file.\n");
goto init_failed;
}
}

if (cfg->fast_coeff_table_fn) {
FILE *fast_coeff_table_f = fopen(cfg->fast_coeff_table_fn, "rb");
if (fast_coeff_table_f == NULL) {
Expand Down Expand Up @@ -677,6 +685,8 @@ void uvg_encoder_control_free(encoder_control_t *const encoder)

FREE_POINTER(encoder->cfg.roi.file_path);

FREE_POINTER(encoder->cfg.cabac_debug_file_name);

uvg_scalinglist_destroy(&encoder->scaling_list);

uvg_threadqueue_free(encoder->threadqueue);
Expand All @@ -691,6 +701,10 @@ void uvg_encoder_control_free(encoder_control_t *const encoder)
fclose(encoder->roi_file);
}

if(encoder->cabac_debug_file) {
fclose(encoder->cabac_debug_file);
}

free(encoder);
}

Expand Down
2 changes: 2 additions & 0 deletions src/encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ typedef struct encoder_control_t

int8_t* qp_map[3];

FILE* cabac_debug_file;

} encoder_control_t;

encoder_control_t* uvg_encoder_control_init(const uvg_config *cfg);
Expand Down
6 changes: 0 additions & 6 deletions src/intra.c
Original file line number Diff line number Diff line change
Expand Up @@ -1539,12 +1539,6 @@ void uvg_intra_recon_cu(
y &= ~7;
}

if (mode_luma != -1 && mode_chroma != -1) {
if (search_data->pred_cu.intra.mip_flag) {
assert(mode_luma == mode_chroma && "Chroma mode must be derived from luma mode if block uses MIP.");
}
}

// Reset CBFs because CBFs might have been set
// for depth earlier
if (mode_luma >= 0) {
Expand Down
40 changes: 24 additions & 16 deletions src/search.c
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,8 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
const int chroma_width = MAX(4, LCU_WIDTH >> (depth + 1));
int8_t scan_order = uvg_get_scan_order(pred_cu->type, pred_cu->intra.mode_chroma, depth);
const unsigned index = xy_to_zorder(LCU_WIDTH_C, lcu_px.x, lcu_px.y);

const bool chroma_can_use_tr_skip = state->encoder_control->cfg.trskip_enable && chroma_width <= (1 << state->encoder_control->cfg.trskip_max_size);
if(pred_cu->joint_cb_cr == 0) {
if (!state->encoder_control->cfg.lossless) {
int index = lcu_px.y * LCU_WIDTH_C + lcu_px.x;
Expand All @@ -569,10 +571,10 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
chroma_width);
chroma_ssd = ssd_u + ssd_v;
}
if(can_use_tr_skip && cb_flag_u) {
if(chroma_can_use_tr_skip && cb_flag_u) {
CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_chroma, tr_cu->tr_skip & 2, tr_tree_bits, "transform_skip_flag");
}
if(can_use_tr_skip && cb_flag_v) {
if(chroma_can_use_tr_skip && cb_flag_v) {
CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_chroma, tr_cu->tr_skip & 4, tr_tree_bits, "transform_skip_flag");
}
coeff_bits += kvz_get_coeff_cost(state, &lcu->coeff.u[index], NULL, chroma_width, COLOR_U, scan_order, tr_cu->tr_skip & 2);
Expand Down Expand Up @@ -1155,6 +1157,11 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
(state->frame->slicetype != UVG_SLICE_I &&
depth < pu_depth_inter.max);

if(state->encoder_control->cabac_debug_file) {
fprintf(state->encoder_control->cabac_debug_file, "S %4d %4d %d", x, y, depth);
fwrite(&state->search_cabac.ctx, 1, sizeof(state->search_cabac.ctx), state->encoder_control->cabac_debug_file);
}

// Recursively split all the way to max search depth.
if (can_split_cu) {
int half_cu = cu_width / 2;
Expand All @@ -1165,6 +1172,21 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
memcpy(&state->search_cabac, &pre_search_cabac, sizeof(post_seach_cabac));


state->search_cabac.update = 1;

double split_bits = 0;

if (depth < MAX_DEPTH) {
// Add cost of cu_split_flag.
kvz_write_split_flag(state, &state->search_cabac,
x > 0 ? LCU_GET_CU_AT_PX(lcu, SUB_SCU(x) - 1, SUB_SCU(y)) : NULL,
y > 0 ? LCU_GET_CU_AT_PX(lcu, SUB_SCU(x), SUB_SCU(y) - 1) : NULL,
1, depth, cu_width, x, y, &split_bits);
}

state->search_cabac.update = 0;
split_cost += split_bits * state->lambda;

// If skip mode was selected for the block, skip further search.
// Skip mode means there's no coefficients in the block, so splitting
// might not give any better results but takes more time to do.
Expand All @@ -1179,20 +1201,6 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
split_cost = INT_MAX;
}

state->search_cabac.update = 1;

double split_bits = 0;

if (depth < MAX_DEPTH) {
// Add cost of cu_split_flag.
uvg_write_split_flag(state, &state->search_cabac,
x > 0 ? LCU_GET_CU_AT_PX(lcu, SUB_SCU(x) - 1, SUB_SCU(y)) : NULL,
y > 0 ? LCU_GET_CU_AT_PX(lcu, SUB_SCU(x), SUB_SCU(y) - 1) : NULL,
1, depth, cu_width, x, y, &split_bits);
}

state->search_cabac.update = 0;
split_cost += split_bits * state->lambda;

// If no search is not performed for this depth, try just the best mode
// of the top left CU from the next depth. This should ensure that 64x64
Expand Down
1 change: 0 additions & 1 deletion src/search_intra.c
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,6 @@ int8_t uvg_search_intra_chroma_rdo(
uvg_intra_build_reference(log2_width, COLOR_V, &luma_px, &pic_px, lcu, &refs[1], state->encoder_control->cfg.wpp, NULL, 0);

const vector2d_t lcu_px = { SUB_SCU(x_px), SUB_SCU(y_px) };
cu_info_t *const tr_cu = LCU_GET_CU_AT_PX(lcu, lcu_px.x, lcu_px.y);
cabac_data_t temp_cabac;
memcpy(&temp_cabac, &state->search_cabac, sizeof(cabac_data_t));
int8_t width = 1 << log2_width;
Expand Down
1 change: 1 addition & 0 deletions src/uvg266.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ typedef struct uvg_config
uint8_t combine_intra_cus;

uint8_t force_inter;
char* cabac_debug_file_name;
} uvg_config;

/**
Expand Down
88 changes: 88 additions & 0 deletions tests/check_cabac_state_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import sys
from pathlib import Path
from pprint import pprint


def get_ctx_count_and_names(cabac_source_file: Path):
count = 0
names = []
with open(cabac_source_file) as file:
for line in file:
line = line.strip().split(";")[0]
if line.startswith("cabac_ctx_t") and not "*" in line:
temp = 1
name = line.split()[1].split("[")[0]
for k in line.split("[")[1:]:
temp *= int(k[:-1])
count += temp
if temp == 1:
names.append(name)
else:
names.extend([f"{name}[{x}]" for x in range(temp)])
return count, names


def main(state_file: Path, ctx_names: list, ctx_count: int = 332, ctx_size: int = 6):
ctx_store = dict()
e_store = set()
was_zero_last = False
frame_num = -1
with open(state_file, "rb") as file:
try:
while True:
type_, x, y, depth = file.read(13).decode().split()
# Reset stored data at the beginning of the frame
if x == '0' and y == '0' and type_ == "S":
if not was_zero_last:
frame_num += 1
ctx_store = dict()
e_store = set()
was_zero_last = True
else:
was_zero_last = False

ctx = file.read(ctx_count * ctx_size)
if type_ == "S":
# These shouldn't happen but just to make sure everything is working as intended
if ctx_store.get((x, y, depth)):
raise RuntimeError
ctx_store[(x, y, depth)] = ctx
else:
if (x, y, depth) in e_store:
raise RuntimeError
e_store.add((x, y, depth))
if (s_ctx := ctx_store[(x, y, depth)]) != ctx:
actual_problem = False

for i in range(ctx_count):
temp_s = s_ctx[i * ctx_size: (i + 1) * ctx_size]
temp_e = ctx[i * ctx_size: (i + 1) * ctx_size]
if temp_s != temp_e:
if ctx_names[i] in ignore_list:
continue
actual_problem = True
print(f"MISSMATCH in {ctx_names[i]} {frame_num=} {x=} {y=} {depth=}")
print(
f"GOT : {int.from_bytes(temp_s[0:2], 'little')}:"
f"{int.from_bytes(temp_s[2:4], 'little')} "
f"rate={int.from_bytes(temp_s[4:5], 'big')}")
print(
f"EXPECTED: {int.from_bytes(temp_e[0:2], 'little')}:"
f"{int.from_bytes(temp_e[2:4], 'little')} "
f"rate={int.from_bytes(temp_e[4:5], 'big')}")
if actual_problem:
exit(1)
except ValueError:
# EOF
pass


if __name__ == '__main__':
ignore_list = {"sao_type_idx_model", "sao_merge_flag_model"}
if len(sys.argv) < 2:
print("Usage: name of the file storing the cabac states")
exit(1)
path = Path(__file__) / "../.." / "src" / "cabac.h"
print(path.resolve())
counts, names_ = get_ctx_count_and_names(path.resolve())
main(Path(sys.argv[1]), names_, counts)
12 changes: 12 additions & 0 deletions tests/test_cabac_state.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/sh

set -eu

. "${0%/*}/util.sh"

cabacfile="$(mktemp)"

valgrind_test 256x128 10 yuv420p --preset veryslow --rd 3 --mip --jccr --mrl -p 1 --owf 0 --no-wpp --cabac-debug-file="${cabacfile}"
python3 check_cabac_state_consistency.py "${cabacfile}"
rm -rf "${cabacfile}"

0 comments on commit 835b7fa

Please sign in to comment.