Skip to content

Commit

Permalink
wrapped context
Browse files Browse the repository at this point in the history
  • Loading branch information
k88hudson-cfa committed Oct 20, 2024
1 parent 86fa968 commit 3e9d288
Showing 1 changed file with 113 additions and 13 deletions.
126 changes: 113 additions & 13 deletions src/people.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{context::Context, define_data_plugin};
use std::{
any::{Any, TypeId},
cell::{RefCell, RefMut},
collections::HashMap,
collections::{HashMap, HashSet},
fmt,
};

Expand All @@ -12,14 +12,18 @@ use std::{
struct PeopleData {
current_population: usize,
properties_map: RefCell<HashMap<TypeId, Box<dyn Any>>>,
dependency_map: RefCell<HashMap<TypeId, HashSet<TypeId>>>,
initialized_derived_properties: RefCell<HashSet<TypeId>>,
}

define_data_plugin!(
PeoplePlugin,
PeopleData,
PeopleData {
current_population: 0,
properties_map: RefCell::new(HashMap::new())
properties_map: RefCell::new(HashMap::new()),
dependency_map: RefCell::new(HashMap::new()),
initialized_derived_properties: RefCell::new(HashSet::new()),
}
);

Expand All @@ -36,6 +40,41 @@ impl fmt::Debug for PersonId {
}
}

struct WrappedContext<'a> {
context: &'a Context,
type_id: TypeId,
}

impl<'a> WrappedContext<'a> {
// Takes a reference to Context instead of a value
fn new(context: &'a Context, type_id: TypeId) -> Self {
Self { context, type_id }
}

fn get_person_property<T: PersonProperty + 'static>(
&self,
person_id: PersonId,
property: T,
) -> T::Value {
let data_container = self.context.get_data_container(PeoplePlugin)
.expect("PeoplePlugin is not initialized; make sure you add a person before accessing properties");
if !data_container
.initialized_derived_properties
.borrow()
.contains(&self.type_id)
{
data_container
.dependency_map
.borrow_mut()
.entry(self.type_id)
.or_default()
.insert(TypeId::of::<T>());
}

self.context.get_person_property(person_id, property)
}
}

// Individual characteristics or states related to a person, such as age or
// disease status, are represented as "person properties". These properties
// * are represented by a struct type that implements the PersonProperty trait,
Expand All @@ -48,6 +87,9 @@ pub trait PersonProperty: Copy {
fn is_derived() -> bool {
false
}
fn dependencies() -> Vec<TypeId> {
Vec::new()
}
fn compute(context: &Context, person_id: PersonId) -> Self::Value;
}

Expand Down Expand Up @@ -103,10 +145,14 @@ macro_rules! define_derived_property {
true
}
fn compute(
_context: &$crate::context::Context,
context: &$crate::context::Context,
_person: $crate::people::PersonId,
) -> Self::Value {
$compute(_context, _person)
let wrapped_context = $crate::people::WrappedContext::new(
context,
std::any::TypeId::of::<$derived_property>(),
);
$compute(&wrapped_context, _person)
}
}
};
Expand Down Expand Up @@ -216,6 +262,8 @@ pub trait ContextPeopleExt {
_property: T,
value: T::Value,
);

fn register_derived_property<T: PersonProperty + 'static>(&self);
}

impl ContextPeopleExt for Context {
Expand Down Expand Up @@ -290,6 +338,15 @@ impl ContextPeopleExt for Context {
}
};

// When there are indexes, we will update them here + also for the derived properties
if data_container
.dependency_map
.borrow()
.contains_key(&TypeId::of::<T>())
{
println!("Derived properties were changed");
}

let change_event: PersonPropertyChangeEvent<T> = PersonPropertyChangeEvent {
person_id,
current: value,
Expand All @@ -298,28 +355,52 @@ impl ContextPeopleExt for Context {
data_container.set_person_property(person_id, property, value);
self.emit_event(change_event);
}

fn register_derived_property<T: PersonProperty + 'static>(&self) {
let type_id = TypeId::of::<T>();
let data_container = self.get_data_container(PeoplePlugin)
.expect("PeoplePlugin is not initialized; make sure you add a person before accessing properties");
if !data_container
.initialized_derived_properties
.borrow()
.contains(&type_id)
{
data_container
.dependency_map
.borrow_mut()
.entry(type_id)
.or_default()
.insert(TypeId::of::<T>());
}
}
}

#[cfg(test)]
mod test {
use super::{ContextPeopleExt, PersonCreatedEvent, PersonId, PersonPropertyChangeEvent};
use super::{
ContextPeopleExt, PersonCreatedEvent, PersonId, PersonPropertyChangeEvent, WrappedContext,
};
use crate::{context::Context, people::PeoplePlugin};
use std::{cell::RefCell, rc::Rc};
use std::{any::TypeId, cell::RefCell, collections::HashSet, rc::Rc};

define_person_property!(Age, u8);
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum AgeGroupType {
Child,
Adult,
}
define_derived_property!(AgeGroup, AgeGroupType, |context: &Context, person| {
let age = context.get_person_property(person, Age);
if age < 18 {
AgeGroupType::Child
} else {
AgeGroupType::Adult
define_derived_property!(
AgeGroup,
AgeGroupType,
|context: &WrappedContext, person| {
let age = context.get_person_property(person, Age);
if age < 18 {
AgeGroupType::Child
} else {
AgeGroupType::Adult
}
}
});
);

#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum RiskCategory {
Expand Down Expand Up @@ -606,4 +687,23 @@ mod test {
context.execute();
assert!(*flag.borrow());
}

#[test]
fn derived_property_registers_dependencies() {
let mut context = Context::new();
let person = context.add_person();
context.initialize_person_property(person, Age, 17);
context.get_person_property(person, AgeGroup);
context.execute();
let actual = context
.get_data_container(PeoplePlugin)
.unwrap()
.dependency_map
.borrow()
.get(&TypeId::of::<AgeGroup>())
.unwrap()
.clone();
let expected = HashSet::from_iter([TypeId::of::<Age>()]);
assert_eq!(actual, expected);
}
}

0 comments on commit 3e9d288

Please sign in to comment.