Skip to content

Commit

Permalink
Merge pull request #2672 from o1-labs/feature/safe-lagrange-basis
Browse files Browse the repository at this point in the history
Make lagrange basis handling safe
  • Loading branch information
45930 authored Oct 9, 2024
2 parents 045a578 + a70e020 commit 3689889
Show file tree
Hide file tree
Showing 16 changed files with 106 additions and 60 deletions.
2 changes: 1 addition & 1 deletion circuit-construction/src/tests/example_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn test_example_circuit() {
// create SRS
let srs = {
let mut srs = SRS::<Vesta>::create(1 << 7); // 2^7 = 128
srs.add_lagrange_basis(Radix2EvaluationDomain::new(srs.g.len()).unwrap());
srs.get_lagrange_basis_from_domain_size(Radix2EvaluationDomain::new(srs.g.len()));
Arc::new(srs)
};

Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/circuits/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3007,8 +3007,8 @@ pub mod test {
];
let index = {
let constraint_system = ConstraintSystem::fp_for_testing(gates);
let mut srs = SRS::<Vesta>::create(constraint_system.domain.d1.size());
srs.add_lagrange_basis(constraint_system.domain.d1);
let srs = SRS::<Vesta>::create(constraint_system.domain.d1.size());
srs.get_lagrange_basis(constraint_system.domain.d1);
let srs = Arc::new(srs);

let (endo_q, _endo_r) = endos::<Pallas>();
Expand Down
2 changes: 1 addition & 1 deletion kimchi/src/lagrange_basis_evaluations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl<F: FftField> LagrangeBasisEvaluations<F> {
chunked_evals[i * max_poly_size + j] = x_pow;
x_pow *= x;
}
// This uses the same trick as `poly_commitment::srs::SRS::add_lagrange_basis`, but
// This uses the same trick as `poly_commitment::srs::SRS::lagrange_basis`, but
// applied to field elements instead of group elements.
domain.ifft_in_place(&mut chunked_evals);
evals.push(chunked_evals);
Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/prover_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ pub mod testing {
override_srs_size,
|d1: D<G::ScalarField>, size: usize| {
let log2_size = size.ilog2();
let mut srs = if log2_size <= precomputed_srs::SERIALIZED_SRS_SIZE {
let srs = if log2_size <= precomputed_srs::SERIALIZED_SRS_SIZE {
// TODO: we should trim it if it's smaller
precomputed_srs::get_srs()
} else {
// TODO: we should resume the SRS generation starting from the serialized one
SRS::<G>::create(size)
};

srs.add_lagrange_basis(d1);
srs.get_lagrange_basis(d1);
srs
},
)
Expand Down
8 changes: 4 additions & 4 deletions kimchi/src/tests/foreign_field_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ fn create_test_constraint_system_ffadd(
};

let cs = ConstraintSystem::create(gates).public(1).build().unwrap();
let mut srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.add_lagrange_basis(cs.domain.d1);
let srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.get_lagrange_basis(cs.domain.d1);
let srs = Arc::new(srs);

let (endo_q, _endo_r) = endos::<Pallas>();
Expand Down Expand Up @@ -1491,8 +1491,8 @@ fn test_ffadd_finalization() {
.public(num_public_inputs)
.build()
.unwrap();
let mut srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.add_lagrange_basis(cs.domain.d1);
let srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.get_lagrange_basis(cs.domain.d1);
let srs = Arc::new(srs);

let (endo_q, _endo_r) = endos::<Pallas>();
Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/tests/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ fn test_generic_gate_pairing() {
.witness(witness)
.public_inputs(public)
.setup_with_custom_srs(|d1, usize| {
let mut srs = poly_commitment::pairing_proof::PairingSRS::create(x, usize);
srs.full_srs.add_lagrange_basis(d1);
let srs = poly_commitment::pairing_proof::PairingSRS::create(x, usize);
srs.full_srs.get_lagrange_basis(d1);
srs
})
.prove_and_verify::<BaseSponge, ScalarSponge>()
Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/tests/range_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1076,8 +1076,8 @@ fn verify_64_bit_range_check() {
.unwrap();

let index = {
let mut srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.add_lagrange_basis(cs.domain.d1);
let srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.get_lagrange_basis(cs.domain.d1);
let srs = Arc::new(srs);

let (endo_q, _endo_r) = endos::<Pallas>();
Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/tests/rot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ fn test_rot_finalization() {
.public(num_public_inputs)
.build()
.unwrap();
let mut srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.add_lagrange_basis(cs.domain.d1);
let srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.get_lagrange_basis(cs.domain.d1);
let srs = Arc::new(srs);

let (endo_q, _endo_r) = endos::<Pallas>();
Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ mod tests {
serde_json::from_str(&verifier_index_serialize).unwrap();

// add srs with lagrange bases
let mut srs = SRS::<Affine<VestaParameters>>::create(verifier_index.max_poly_size);
srs.add_lagrange_basis(verifier_index.domain);
let srs = SRS::<Affine<VestaParameters>>::create(verifier_index.max_poly_size);
srs.get_lagrange_basis(verifier_index.domain);
verifier_index_deserialize.powers_of_alpha = index.powers_of_alpha;
verifier_index_deserialize.linearization = index.linearization;
verifier_index_deserialize.srs = std::sync::Arc::new(srs);
Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/tests/xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ fn test_xor_finalization() {
.public(num_inputs)
.build()
.unwrap();
let mut srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.add_lagrange_basis(cs.domain.d1);
let srs = SRS::<Vesta>::create(cs.domain.d1.size());
srs.get_lagrange_basis(cs.domain.d1);
let srs = Arc::new(srs);

let (endo_q, _endo_r) = endos::<Pallas>();
Expand Down
4 changes: 2 additions & 2 deletions kimchi/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use rand::thread_rng;
/// The result of a proof verification.
pub type Result<T> = std::result::Result<T, VerifyError>;

#[derive(Debug)]
pub struct Context<'a, G: KimchiCurve, OpeningProof: OpenProof<G>> {
/// The [VerifierIndex] associated to the proof
pub verifier_index: &'a VerifierIndex<G, OpeningProof>,
Expand Down Expand Up @@ -797,8 +798,7 @@ where
}
let lgr_comm = verifier_index
.srs()
.get_lagrange_basis(verifier_index.domain.size())
.expect("pre-computed committed lagrange bases not found");
.get_lagrange_basis(verifier_index.domain.size());
let com: Vec<_> = lgr_comm.iter().take(verifier_index.public).collect();
if public_input.is_empty() {
PolyComm::new(vec![verifier_index.srs().blinding_commitment(); chunk_size])
Expand Down
27 changes: 12 additions & 15 deletions poly-commitment/src/commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,8 @@ impl<G: CommitmentCurve> SRSTrait<G> for SRS<G> {
self.g.len()
}

fn get_lagrange_basis(&self, domain_size: usize) -> Option<&Vec<PolyComm<G>>> {
self.lagrange_bases.get(&domain_size)
fn get_lagrange_basis(&self, domain_size: usize) -> &Vec<PolyComm<G>> {
self.get_lagrange_basis_from_domain_size(domain_size)
}

fn blinding_commitment(&self) -> G {
Expand Down Expand Up @@ -591,10 +591,7 @@ impl<G: CommitmentCurve> SRSTrait<G> for SRS<G> {
domain: D<G::ScalarField>,
plnm: &Evaluations<G::ScalarField, D<G::ScalarField>>,
) -> PolyComm<G> {
let basis = self
.lagrange_bases
.get(&domain.size())
.unwrap_or_else(|| panic!("lagrange bases for size {} not found", domain.size()));
let basis = self.get_lagrange_basis(domain);
let commit_evaluations = |evals: &Vec<G::ScalarField>, basis: &Vec<PolyComm<G>>| {
PolyComm::<G>::multi_scalar_mul(&basis.iter().collect::<Vec<_>>()[..], &evals[..])
};
Expand Down Expand Up @@ -824,8 +821,8 @@ mod tests {
let n = 64;
let domain = D::<Fp>::new(n).unwrap();

let mut srs = SRS::<VestaG>::create(n);
srs.add_lagrange_basis(domain);
let srs = SRS::<VestaG>::create(n);
srs.get_lagrange_basis(domain);

let num_chunks = domain.size() / srs.g.len();

Expand All @@ -838,7 +835,7 @@ mod tests {
})
.collect();

let computed_lagrange_commitments = srs.lagrange_bases.get(&domain.size()).unwrap();
let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size());
for i in 0..n {
assert_eq!(
computed_lagrange_commitments[i],
Expand All @@ -854,8 +851,8 @@ mod tests {
let divisor = 4;
let domain = D::<Fp>::new(n).unwrap();

let mut srs = SRS::<VestaG>::create(n / divisor);
srs.add_lagrange_basis(domain);
let srs = SRS::<VestaG>::create(n / divisor);
srs.get_lagrange_basis(domain);

let num_chunks = domain.size() / srs.g.len();
assert!(num_chunks == divisor);
Expand All @@ -869,7 +866,7 @@ mod tests {
})
.collect();

let computed_lagrange_commitments = srs.lagrange_bases.get(&domain.size()).unwrap();
let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size());
for i in 0..n {
assert_eq!(
computed_lagrange_commitments[i],
Expand All @@ -887,8 +884,8 @@ mod tests {
let n = 64;
let domain = D::<Fp>::new(n).unwrap();

let mut srs = SRS::<VestaG>::create(n / 2 + 1);
srs.add_lagrange_basis(domain);
let srs = SRS::<VestaG>::create(n / 2 + 1);
srs.get_lagrange_basis(domain);

// Is this even taken into account?...
let num_chunks = (domain.size() + srs.g.len() - 1) / srs.g.len();
Expand All @@ -903,7 +900,7 @@ mod tests {
})
.collect();

let computed_lagrange_commitments = srs.lagrange_bases.get(&domain.size()).unwrap();
let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size());
for i in 0..n {
assert_eq!(
computed_lagrange_commitments[i],
Expand Down
41 changes: 41 additions & 0 deletions poly-commitment/src/hash_map_cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use std::{
collections::{hash_map::Entry, HashMap},
hash::Hash,
sync::{Arc, Mutex},
};

#[derive(Debug, Clone, Default)]
pub struct HashMapCache<Key: Hash, Value> {
contents: Arc<Mutex<HashMap<Key, Value>>>,
}

impl<Key: Hash + std::cmp::Eq, Value> HashMapCache<Key, Value> {
pub fn new() -> Self {
HashMapCache {
contents: Arc::new(Mutex::new(HashMap::new())),
}
}

pub fn get_or_generate<F: FnOnce() -> Value>(&self, key: Key, generator: F) -> &Value {
let mut hashmap = self.contents.lock().unwrap();
let entry = (*hashmap).entry(key);
let inner_ptr = match entry {
Entry::Occupied(o) => {
let o_ref = o.into_mut();
&*o_ref as *const Value
}
Entry::Vacant(v) => {
let v_ref = v.insert(generator());
&*v_ref as *const Value
}
};

// This is safe because we never delete entries from the cache, and the value reference
// must live at least at most as long as the cache value.
unsafe { &*inner_ptr }
}

pub fn contains_key(&self, key: &Key) -> bool {
self.contents.lock().unwrap().contains_key(key)
}
}
5 changes: 3 additions & 2 deletions poly-commitment/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod combine;
pub mod commitment;
pub mod error;
pub mod evaluation_proof;
pub mod hash_map_cache;
pub mod pairing_proof;
pub mod srs;

Expand All @@ -29,7 +30,7 @@ pub trait SRS<G: CommitmentCurve> {
fn max_poly_size(&self) -> usize;

/// Retrieve the precomputed Lagrange basis for the given domain size
fn get_lagrange_basis(&self, domain_size: usize) -> Option<&Vec<PolyComm<G>>>;
fn get_lagrange_basis(&self, domain_size: usize) -> &Vec<PolyComm<G>>;

/// Get the group element used for blinding commitments
fn blinding_commitment(&self) -> G;
Expand Down Expand Up @@ -92,7 +93,7 @@ type PolynomialsToCombine<'a, G: CommitmentCurve, D: EvaluationDomain<G::ScalarF
)];

pub trait OpenProof<G: CommitmentCurve>: Sized {
type SRS: SRS<G>;
type SRS: SRS<G> + std::fmt::Debug;

#[allow(clippy::too_many_arguments)]
fn open<EFqSponge, RNG, D: EvaluationDomain<<G as AffineRepr>::ScalarField>>(
Expand Down
9 changes: 5 additions & 4 deletions poly-commitment/src/pairing_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ impl<
self.full_srs.max_poly_size()
}

fn get_lagrange_basis(&self, domain_size: usize) -> Option<&Vec<PolyComm<G>>> {
self.full_srs.get_lagrange_basis(domain_size)
fn get_lagrange_basis(&self, domain_size: usize) -> &Vec<PolyComm<G>> {
self.full_srs
.get_lagrange_basis_from_domain_size(domain_size)
}

fn blinding_commitment(&self) -> G {
Expand Down Expand Up @@ -349,9 +350,9 @@ mod tests {

let x = ScalarField::rand(rng);

let mut srs = SRS::<G1>::create_trusted_setup(x, n);
let srs = SRS::<G1>::create_trusted_setup(x, n);
let verifier_srs = SRS::<G2>::create_trusted_setup(x, 3);
srs.add_lagrange_basis(domain);
srs.get_lagrange_basis(domain);

let srs = PairingSRS {
full_srs: srs,
Expand Down
Loading

0 comments on commit 3689889

Please sign in to comment.