Skip to content

Commit

Permalink
Use schema names for generating svg hierarchy. (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
keyurva authored May 28, 2024
1 parent 0daedbf commit d98bfde
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 45 deletions.
5 changes: 3 additions & 2 deletions simple/stats/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ def _generate_svg_hierarchy(self):
vertical_specs_fh.basename())
vertical_specs = stat_var_hierarchy_generator.load_vertical_specs(
vertical_specs_fh.read_string())
svg_triples = stat_var_hierarchy_generator.generate(sv_triples,
vertical_specs)
# TODO: get dcid to name mappings and pass to generator.
svg_triples = stat_var_hierarchy_generator.generate(
triples=sv_triples, vertical_specs=vertical_specs, dcid2name={})
logging.info("Inserting %s SVG triples into DB.", len(svg_triples))
self.db.insert_triples(svg_triples)

Expand Down
90 changes: 53 additions & 37 deletions simple/stats/stat_var_hierarchy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from stats.data import VerticalSpec


def generate(triples: list[Triple],
vertical_specs: list[VerticalSpec]) -> list[Triple]:
def generate(triples: list[Triple], vertical_specs: list[VerticalSpec],
dcid2name: dict[str, str]) -> list[Triple]:
"""Given a list of input triples (including stat vars),
generates an SV hierarchy and returns a list of output triples
representing the hierarchy.
"""
return _generate_internal(triples, vertical_specs).svg_triples
return _generate_internal(triples, vertical_specs, dcid2name).svg_triples


def load_vertical_specs(data: str) -> list[VerticalSpec]:
Expand All @@ -46,9 +46,9 @@ def load_vertical_specs(data: str) -> list[VerticalSpec]:


# TODO: Pruning (e.g. ignore Thing).
def _generate_internal(
triples: list[Triple],
vertical_specs: list[VerticalSpec]) -> "StatVarHierarchy":
def _generate_internal(triples: list[Triple],
vertical_specs: list[VerticalSpec],
dcid2name: dict[str, str]) -> "StatVarHierarchy":
"""Given a list of input triples (including stat vars),
generates an SV hierarchy and returns it as a StatVarHierarchy
(its svg_triples hold the output triples representing the hierarchy).
Expand All @@ -57,14 +57,14 @@ def _generate_internal(
# Extract SVs.
svs = _extract_svs(triples)
# Create SVGs.
svgs = _create_all_svgs(svs)
svgs = _create_all_svgs(svs, dcid2name)
# Sort by SVG ID so it's easier to follow the hierarchy.
svgs = dict(sorted(svgs.items()))

# Get pop type svgs (they don't have a parent set at this stage).
pop_type_svgs = _get_pop_type_svgs(svgs)
# Attach verticals to pop type svgs.
vertical_svgs = _attach_verticals(pop_type_svgs, vertical_specs)
vertical_svgs = _attach_verticals(pop_type_svgs, vertical_specs, dcid2name)
# Sort by SVG ID so it's easier to follow the verticals.
vertical_svgs = dict(sorted(vertical_svgs.items()))

Expand All @@ -90,10 +90,10 @@ def gen_pv_id(self) -> str:
return f"{_to_dcid_token(self.prop)}-{_to_dcid_token(self.val)}"
return _to_dcid_token(self.prop)

def gen_pv_name(self) -> str:
def gen_pv_name(self, dcid2name: dict[str, str]) -> str:
if self.val:
return f"{_capitalize_and_split(self.prop)} = {_capitalize_and_split(self.val)}"
return _capitalize_and_split(self.prop)
return f"{_gen_name(self.prop, dcid2name)} = {_gen_name(self.val, dcid2name)}"
return _gen_name(self.prop, dcid2name)


# TODO: DPV handling.
Expand All @@ -113,18 +113,19 @@ def gen_svg_id(self):
svg_id = f"{svg_id}_{pv.gen_pv_id()}"
return svg_id

def gen_svg_name(self):
svg_name = _capitalize_and_split(self.population_type)
def gen_svg_name(self, dcid2name: dict[str, str]):
svg_name = _gen_name(self.population_type, dcid2name)
if self.pvs:
pvs_str = ", ".join(map(lambda pv: pv.gen_pv_name(), self.pvs))
pvs_str = ", ".join(map(lambda pv: pv.gen_pv_name(dcid2name), self.pvs))
svg_name = f"{svg_name} With {pvs_str}"
return svg_name

def gen_specialized_name(self, parent_pvs: Self) -> str:
def gen_specialized_name(self, parent_pvs: Self, dcid2name: dict[str,
str]) -> str:
parent_parts = parent_pvs._get_pv_parts()
child_parts = self._get_pv_parts()
parts = [part for part in child_parts if part not in parent_parts]
return ", ".join(map(lambda part: _capitalize_and_split(part), parts))
return ", ".join(map(lambda part: _gen_name(part, dcid2name), parts))

# Creates and returns a new SVPropVals object with the same fields as this object
# except for PVs which are set to the specified list.
Expand Down Expand Up @@ -187,9 +188,11 @@ def triples(self) -> list[Triple]:

return triples

def gen_specialized_name(self, parent_svg: Self) -> str:
def gen_specialized_name(self, parent_svg: Self, dcid2name: dict[str,
str]) -> str:
if self.sample_sv and parent_svg.sample_sv:
return self.sample_sv.gen_specialized_name(parent_svg.sample_sv)
return self.sample_sv.gen_specialized_name(parent_svg.sample_sv,
dcid2name)
return ""

# For testing.
Expand All @@ -216,7 +219,8 @@ class StatVarHierarchy:

# Attaches matching pop type svgs to vertical svgs, creates those vertical svgs and returns them.
def _attach_verticals(poptype2svg: dict[str, SVG],
vertical_specs: list[VerticalSpec]) -> dict[str, SVG]:
vertical_specs: list[VerticalSpec],
dcid2name: dict[str, str]) -> dict[str, SVG]:
vertical_svgs: dict[str, SVG] = {}
for vertical_spec in vertical_specs:
pop_type_svg = poptype2svg.get(vertical_spec.population_type)
Expand All @@ -229,7 +233,8 @@ def _attach_verticals(poptype2svg: dict[str, SVG],
continue
# Put pop type svg under all verticals in the spec.
for vertical in vertical_spec.verticals:
vertical_svg = _get_or_create_vertical_svg(vertical, vertical_svgs)
vertical_svg = _get_or_create_vertical_svg(vertical, vertical_svgs,
dcid2name)
vertical_svgs[vertical_svg.svg_id] = vertical_svg
vertical_svg.child_svg_id_2_specialized_name[pop_type_svg.svg_id] = ""
pop_type_svg.parent_svg_ids[vertical_svg.svg_id] = True
Expand All @@ -244,12 +249,12 @@ def _attach_verticals(poptype2svg: dict[str, SVG],
return vertical_svgs


def _get_or_create_vertical_svg(vertical: str, vertical_svgs: dict[str,
SVG]) -> SVG:
def _get_or_create_vertical_svg(vertical: str, vertical_svgs: dict[str, SVG],
dcid2name: dict[str, str]) -> SVG:
vertical_svg_id = f"{sc.CUSTOM_SVG_PREFIX}{vertical}"
vertical_svg = vertical_svgs.get(vertical_svg_id)
if not vertical_svg:
vertical_svg = SVG(vertical_svg_id, _capitalize_and_split(vertical))
vertical_svg = SVG(vertical_svg_id, _gen_name(vertical, dcid2name))
vertical_svg.parent_svg_ids[sc.DEFAULT_CUSTOM_ROOT_SVG_ID] = True
return vertical_svg

Expand All @@ -264,11 +269,12 @@ def _get_pop_type_svgs(svgs: dict[str, SVG]) -> dict[str, SVG]:
return poptype2svg


def _get_or_create_svg(svgs: dict[str, SVG], sv: SVPropVals) -> SVG:
def _get_or_create_svg(svgs: dict[str, SVG], sv: SVPropVals,
dcid2name: dict[str, str]) -> SVG:
svg_id = sv.gen_svg_id()
svg = svgs.get(svg_id)
if not svg:
svg = SVG(svg_id=svg_id, svg_name=sv.gen_svg_name())
svg = SVG(svg_id=svg_id, svg_name=sv.gen_svg_name(dcid2name))
svg.sample_sv = sv
svgs[svg_id] = svg
# Add SV mprop to the SVG.
Expand All @@ -284,18 +290,20 @@ def _create_all_svg_triples(svgs: dict[str, SVG]):
return triples


def _create_all_svgs(svs: list[SVPropVals]) -> dict[str, SVG]:
svgs = _create_leaf_svgs(svs)
def _create_all_svgs(svs: list[SVPropVals],
dcid2name: dict[str, str]) -> dict[str, SVG]:
svgs = _create_leaf_svgs(svs, dcid2name)
for svg_id in list(svgs.keys()):
_create_parent_svgs(svg_id, svgs)
_create_parent_svgs(svg_id, svgs, dcid2name)
return svgs


# Create SVGs that the SVs are directly attached to.
def _create_leaf_svgs(svs: list[SVPropVals]) -> dict[str, SVG]:
def _create_leaf_svgs(svs: list[SVPropVals],
dcid2name: dict[str, str]) -> dict[str, SVG]:
svgs: dict[str, SVG] = {}
for sv in svs:
svg = _get_or_create_svg(svgs, sv)
svg = _get_or_create_svg(svgs, sv, dcid2name)
# Insert SV into SVG.
svg.sv_ids[sv.sv_id] = True
return svgs
Expand All @@ -314,24 +322,26 @@ def _add_measured_properties_to_parent_svgs(mprops: dict[str, bool],


def _create_parent_svg(parent_sv: SVPropVals, svg: SVG, svgs: dict[str, SVG],
svg_has_prop_without_val: bool):
parent_svg = _get_or_create_svg(svgs, parent_sv)
svg_has_prop_without_val: bool, dcid2name: dict[str,
str]):
parent_svg = _get_or_create_svg(svgs, parent_sv, dcid2name)

# Add parent child relationships.
svg.parent_svg_ids[parent_svg.svg_id] = True
parent_svg.child_svg_id_2_specialized_name[
svg.svg_id] = svg.gen_specialized_name(parent_svg)
svg.svg_id] = svg.gen_specialized_name(parent_svg, dcid2name)

# Add child mprops to all parents recursively.
_add_measured_properties_to_parent_svgs(svg.measured_properties,
svg.parent_svg_ids, svgs)

if not parent_svg.parent_svgs_processed:
parent_svg.has_prop_without_val = svg_has_prop_without_val
_create_parent_svgs(parent_svg.svg_id, svgs)
_create_parent_svgs(parent_svg.svg_id, svgs, dcid2name)


def _create_parent_svgs(svg_id: str, svgs: dict[str, SVG]):
def _create_parent_svgs(svg_id: str, svgs: dict[str, SVG],
dcid2name: dict[str, str]):
svg = svgs[svg_id]
sv = svg.sample_sv

Expand All @@ -355,7 +365,8 @@ def _create_parent_svgs(svg_id: str, svgs: dict[str, SVG]):
_create_parent_svg(parent_sv=sv.with_pvs(parent_pvs),
svg=svg,
svgs=svgs,
svg_has_prop_without_val=False)
svg_has_prop_without_val=False,
dcid2name=dcid2name)
# Process SVGs with vals.
else:
for pv1 in sv.pvs:
Expand All @@ -370,7 +381,8 @@ def _create_parent_svgs(svg_id: str, svgs: dict[str, SVG]):
_create_parent_svg(parent_sv=sv.with_pvs(parent_pvs),
svg=svg,
svgs=svgs,
svg_has_prop_without_val=True)
svg_has_prop_without_val=True,
dcid2name=dcid2name)

svg.parent_svgs_processed = True

Expand All @@ -397,6 +409,10 @@ def _capitalize_and_split(s: str) -> str:
return _split_camel_case(_capitalize(s))


def _gen_name(dcid: str, dcid2name: dict[str, str]) -> str:
  """Generates the name to use for a DCID.

  Looks the DCID up in dcid2name and falls back to the DCID itself when
  there is no entry or the entry is empty. The chosen string is then passed
  through _capitalize_and_split.
  """
  name = dcid2name.get(dcid)
  if not name:
    # No usable schema name for this DCID: fall back to the raw DCID.
    name = dcid
  return _capitalize_and_split(name)


def _to_dcid_token(token: str) -> str:
# Remove all non-alphanumeric characters.
result = re.sub("[^0-9a-zA-Z]+", "", token)
Expand Down
54 changes: 48 additions & 6 deletions simple/tests/stats/stat_var_hierarchy_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def _mcf_to_triples(mcf_path: str) -> list[Triple]:
def _test_generate_internal(test: unittest.TestCase,
test_name: str,
is_mcf_input: bool = False,
has_vertical_specs: bool = False):
has_vertical_specs: bool = False,
has_schema_names: bool = False):
test.maxDiff = None

with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -91,6 +92,13 @@ def _test_generate_internal(test: unittest.TestCase,
with open(input_vertical_specs_path, "r") as file:
vertical_specs = load_vertical_specs(file.read())

dcid2name: dict[str, str] = {}
if has_schema_names:
input_schema_names_path = os.path.join(_INPUT_DIR,
f"{test_name}.schema_names.json")
with open(input_schema_names_path, "r") as file:
dcid2name = json.load(file)

output_svgs_json_path = os.path.join(temp_dir, f"{test_name}_svgs.json")
expected_svgs_json_path = os.path.join(_EXPECTED_DIR,
f"{test_name}_svgs.json")
Expand All @@ -99,7 +107,7 @@ def _test_generate_internal(test: unittest.TestCase,
expected_triples_csv_path = os.path.join(_EXPECTED_DIR,
f"{test_name}_triples.csv")

hierarchy = _generate_internal(input_triples, vertical_specs)
hierarchy = _generate_internal(input_triples, vertical_specs, dcid2name)
# Write SVGs json
svgs_json = [svg.json() for _, svg in hierarchy.svgs.items()]
with open(output_svgs_json_path, "w") as out:
Expand Down Expand Up @@ -142,6 +150,9 @@ def test_generate_internal_svs_with_mprops(self):
def test_generate_internal_verticals(self):
  # Runs the shared hierarchy test for the "verticals" case with vertical
  # specs enabled.
  _test_generate_internal(test=self,
                          test_name="verticals",
                          has_vertical_specs=True)

def test_generate_internal_schema_names(self):
  # Runs the shared hierarchy test for the "schema_names" case; the helper
  # loads the dcid-to-name mapping from the case's .schema_names.json input.
  _test_generate_internal(test=self,
                          test_name="schema_names",
                          has_schema_names=True)

def test_extract_svs(self):
input_triples: list[Triple] = [
Triple("sv1", "typeOf", "StatisticalVariable", ""),
Expand Down Expand Up @@ -191,7 +202,7 @@ def test_extract_svs(self):
population_type="",
pvs=[PropVal("gender", "Female"),
PropVal("race", "Asian")],
measured_property=""), "Female"),
measured_property=""), {}, "Female"),
(SVPropVals(sv_id="",
population_type="",
pvs=[PropVal("gender", "Female")],
Expand All @@ -200,8 +211,39 @@ def test_extract_svs(self):
population_type="",
pvs=[PropVal("gender", "Female"),
PropVal("race", "")],
measured_property=""), "Race")
measured_property=""), {}, "Race"),
(SVPropVals(sv_id="",
population_type="",
pvs=[PropVal("gender", "Female")],
measured_property=""),
SVPropVals(
sv_id="",
population_type="",
pvs=[PropVal("gender", "Female"),
PropVal("povertyStatus", "")],
measured_property=""), {
"povertyStatus": "State of poverty"
}, "State of poverty"),
(SVPropVals(
sv_id="",
population_type="",
pvs=[PropVal("gender", "Female"),
PropVal("povertyStatus", "")],
measured_property=""),
SVPropVals(sv_id="",
population_type="",
pvs=[
PropVal("gender", "Female"),
PropVal("povertyStatus",
"BelowPovertyLevelInThePast12Months")
],
measured_property=""), {
"povertyStatus":
"State of poverty",
"BelowPovertyLevelInThePast12Months":
"BelowPovertyLevel in the last year"
}, "Below Poverty Level in the last year")
])
def test_gen_specialized_name(self, parent: SVPropVals, child: SVPropVals,
expected: str):
self.assertEqual(child.gen_specialized_name(parent), expected)
dcid2name: dict[str, str], expected: str):
self.assertEqual(child.gen_specialized_name(parent, dcid2name), expected)
Loading

0 comments on commit d98bfde

Please sign in to comment.