Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
multi_scalar_mul.test.cpp
Go to the documentation of this file.
2#include "acir_format.hpp"
6
7#include <gtest/gtest.h>
8#include <vector>
9
10using namespace ::acir_format;
11
/// @brief Selects which MSM inputs are passed to the constraint as constants rather than witnesses.
enum class InputConstancy : uint8_t {
    None = 0,    // Points and scalars are both witnesses
    Points = 1,  // Points are constants, scalars are witnesses
    Scalars = 2, // Scalars are constants, points are witnesses
    Both = 3     // Points and scalars are both constants
};
13
23template <typename Builder_, InputConstancy Constancy> class MultiScalarMulTestingFunctions {
24 public:
25 using Builder = Builder_;
26 using AcirConstraint = MultiScalarMul;
29 using FF = bb::fr;
30
32 public:
/// @brief Which part of the MSM constraint invalidate_witness() should tamper with.
enum class Target : uint8_t {
    None = 0,    // No invalidation target
    Points = 1,  // Invalidate point inputs
    Scalars = 2, // Invalidate scalar inputs
    Result = 3   // Invalidate result output
};
39
44
/// @brief Human-readable names for each InvalidWitness::Target, listed in enumerator order.
static std::vector<std::string> get_labels()
{
    std::vector<std::string> labels;
    for (const char* name : { "None", "Points", "Scalars", "Result" }) {
        labels.emplace_back(name);
    }
    return labels;
}
46 };
47
48 static ProgramMetadata generate_metadata() { return ProgramMetadata{}; }
49
50 static void generate_constraints(AcirConstraint& msm_constraint, WitnessVector& witness_values)
51 {
52 // Generate a single point and scalar for simplicity
54 bb::fq scalar_native = bb::fq::random_element();
55 GrumpkinPoint result = point * scalar_native;
56 BB_ASSERT(result != GrumpkinPoint::one()); // Ensure that tampering works correctly
57
58 // Split scalar into low and high limbs (128 bits each) as FF for witness values
59 uint256_t scalar_u256 = uint256_t(scalar_native);
60 FF scalar_lo = scalar_u256.slice(0, 128);
61 FF scalar_hi = scalar_u256.slice(128, 256);
62
63 // Determine which inputs are constants based on the Constancy template parameter
64 constexpr bool points_are_constant = (Constancy == InputConstancy::Points || Constancy == InputConstancy::Both);
65 constexpr bool scalars_are_constant =
66 (Constancy == InputConstancy::Scalars || Constancy == InputConstancy::Both);
67
68 // Helper to add points: either as witnesses or constants based on Constancy
69 auto construct_points = [&]() -> std::vector<WitnessOrConstant<FF>> {
70 if constexpr (points_are_constant) {
71 // Points are constants
72 return { WitnessOrConstant<FF>::from_constant(point.x),
73 WitnessOrConstant<FF>::from_constant(point.y),
74 WitnessOrConstant<FF>::from_constant(point.is_point_at_infinity() ? FF(1) : FF(0)) };
75 }
76 // Points are witnesses
77 std::vector<uint32_t> point_indices = add_to_witness_and_track_indices(witness_values, point);
78 return { WitnessOrConstant<FF>::from_index(point_indices[0]),
79 WitnessOrConstant<FF>::from_index(point_indices[1]),
80 WitnessOrConstant<FF>::from_index(point_indices[2]) };
81 };
82
83 // Helper to add scalars: either as witnesses or constants based on Constancy
84 auto construct_scalars = [&]() -> std::vector<WitnessOrConstant<FF>> {
85 if constexpr (scalars_are_constant) {
86 // Scalars are constants
87 return { WitnessOrConstant<FF>::from_constant(scalar_lo),
88 WitnessOrConstant<FF>::from_constant(scalar_hi) };
89 }
90 // Scalars are witnesses
91 uint32_t scalar_lo_index = static_cast<uint32_t>(witness_values.size());
92 witness_values.emplace_back(scalar_lo);
93 uint32_t scalar_hi_index = static_cast<uint32_t>(witness_values.size());
94 witness_values.emplace_back(scalar_hi);
95 return { WitnessOrConstant<FF>::from_index(scalar_lo_index),
96 WitnessOrConstant<FF>::from_index(scalar_hi_index) };
97 };
98
99 // Add points and scalars according to constancy template parameter
100 auto point_fields = construct_points();
101 auto scalar_fields = construct_scalars();
102
103 // Construct result and predicate as witnesses
104 std::vector<uint32_t> result_indices = add_to_witness_and_track_indices(witness_values, result);
105 uint32_t predicate_index = static_cast<uint32_t>(witness_values.size());
106 witness_values.emplace_back(FF::one()); // predicate
107
108 // Build the constraint
109 msm_constraint = MultiScalarMul{
110 .points = point_fields,
111 .scalars = scalar_fields,
112 .predicate = WitnessOrConstant<FF>::from_index(predicate_index),
113 .out_point_x = result_indices[0],
114 .out_point_y = result_indices[1],
115 .out_point_is_infinite = result_indices[2],
116 };
117 }
118
120 AcirConstraint constraint, WitnessVector witness_values, const InvalidWitness::Target& invalid_witness_target)
121 {
122 switch (invalid_witness_target) {
124 // Invalidate the point by adding 1 to x coordinate
125 if constexpr (Constancy == InputConstancy::None || Constancy == InputConstancy::Scalars) {
126 witness_values[constraint.points[0].index] += bb::fr(1);
127 } else {
128 constraint.points[0] = WitnessOrConstant<FF>::from_constant(constraint.points[0].value + bb::fr(1));
129 }
130 break;
131 }
133 // Invalidate the scalar by adding 1 to the low limb
134 if constexpr (Constancy == InputConstancy::None || Constancy == InputConstancy::Points) {
135 witness_values[constraint.scalars[0].index] += bb::fr(1);
136 } else {
137 constraint.scalars[0] = WitnessOrConstant<FF>::from_constant(constraint.scalars[0].value + bb::fr(1));
138 }
139 break;
140 }
142 // Tamper with the result by setting it to the generator point
143 witness_values[constraint.out_point_x] = GrumpkinPoint::one().x;
144 witness_values[constraint.out_point_y] = GrumpkinPoint::one().y;
145 witness_values[constraint.out_point_is_infinite] = FF::zero();
146 break;
147 }
149 default:
150 break;
151 }
152
153 return { constraint, witness_values };
154 };
155};
156
157template <typename Builder>
159 : public ::testing::Test,
160 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::None>> {
161 protected:
163};
164
165template <typename Builder>
167 : public ::testing::Test,
168 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::Points>> {
169 protected:
171};
172
173template <typename Builder>
175 : public ::testing::Test,
176 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::Scalars>> {
177 protected:
179};
180
181template <typename Builder>
183 : public ::testing::Test,
184 public TestClassWithPredicate<MultiScalarMulTestingFunctions<Builder, InputConstancy::Both>> {
185 protected:
187};
188
189using BuilderTypes = testing::Types<UltraCircuitBuilder, MegaCircuitBuilder>;
190
195
197{
199 TestFixture::template test_vk_independence<Flavor>();
200}
201
203{
205 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
206}
207
209{
211 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
212}
213
215{
217 TestFixture::test_witness_false_slow();
218}
219
221{
223 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
224}
225
227{
229 TestFixture::template test_vk_independence<Flavor>();
230}
231
233{
235 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
236}
237
239{
241 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
242}
243
245{
247 TestFixture::test_witness_false_slow();
248}
249
251{
253 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
254}
255
257{
259 TestFixture::template test_vk_independence<Flavor>();
260}
261
263{
265 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
266}
267
269{
271 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
272}
273
275{
277 TestFixture::test_witness_false_slow();
278}
279
281{
283 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
284}
285
287{
289 TestFixture::template test_vk_independence<Flavor>();
290}
291
293{
295 TestFixture::test_constant_true(TestFixture::InvalidWitnessTarget::Result);
296}
297
299{
301 TestFixture::test_witness_true(TestFixture::InvalidWitnessTarget::Result);
302}
303
305{
307 TestFixture::test_witness_false_slow();
308}
309
311{
313 [[maybe_unused]] std::vector<std::string> _ = TestFixture::test_invalid_witnesses();
314}
315
316// ============================================================
317// Infinity flag tests
318// ============================================================
319
320// ACIR convention for encoding a curve point: (x, y, is_infinite) as field values.
322using MsmFF = bb::fr;
323
327 {
328 return { p.x, p.y, p.is_point_at_infinity() ? MsmFF(1) : MsmFF(0) };
329 }
330 static MsmAcirPoint infinity() { return { MsmFF(0), MsmFF(0), MsmFF(1) }; }
331};
332
333// Grumpkin scalar split into low 128-bit and high 128-bit field limbs.
334struct MsmScalar {
336 static MsmScalar zero() { return { MsmFF(0), MsmFF(0) }; }
337 static MsmScalar from_native(const bb::fq& s)
338 {
339 uint256_t u = uint256_t(s);
340 return { u.slice(0, 128), u.slice(128, 256) };
341 }
342};
343
344template <typename Builder> class MultiScalarMulInfinityTests : public ::testing::Test {
345 protected:
347
348 // Push an MsmAcirPoint to witness; return [x, y, inf] indices.
349 static std::array<uint32_t, 3> push_point(WitnessVector& witness, const MsmAcirPoint& pt)
350 {
351 uint32_t xi = static_cast<uint32_t>(witness.size());
352 witness.emplace_back(pt.x);
353 uint32_t yi = static_cast<uint32_t>(witness.size());
354 witness.emplace_back(pt.y);
355 uint32_t ii = static_cast<uint32_t>(witness.size());
356 witness.emplace_back(pt.inf);
357 return { xi, yi, ii };
358 }
359
360 // Push a scalar (lo, hi) to witness; return [lo_idx, hi_idx].
361 static std::array<uint32_t, 2> push_scalar(WitnessVector& witness, const MsmScalar& s)
362 {
363 uint32_t lo_idx = static_cast<uint32_t>(witness.size());
364 witness.emplace_back(s.lo);
365 uint32_t hi_idx = static_cast<uint32_t>(witness.size());
366 witness.emplace_back(s.hi);
367 return { lo_idx, hi_idx };
368 }
369
370 // Build a single-term MSM constraint (predicate=1) from a point, scalar, and expected result.
371 // Returns the constraint and the populated witness vector.
373 {
374 WitnessVector witness;
375 auto p = push_point(witness, point);
376 auto s = push_scalar(witness, scalar);
377 auto r = push_point(witness, result);
378 uint32_t pred_idx = static_cast<uint32_t>(witness.size());
379 witness.emplace_back(MsmFF(1));
380
381 MultiScalarMul c{
382 .points = { WitnessOrConstant<MsmFF>::from_index(p[0]),
383 WitnessOrConstant<MsmFF>::from_index(p[1]),
384 WitnessOrConstant<MsmFF>::from_index(p[2]) },
385 .scalars = { WitnessOrConstant<MsmFF>::from_index(s[0]), WitnessOrConstant<MsmFF>::from_index(s[1]) },
386 .predicate = WitnessOrConstant<MsmFF>::from_index(pred_idx),
387 .out_point_x = r[0],
388 .out_point_y = r[1],
389 .out_point_is_infinite = r[2],
390 };
391 return { c, witness };
392 }
393
394 // Run the circuit and return (satisfied, error_string).
395 static std::pair<bool, std::string> run_circuit(MultiScalarMul constraint, WitnessVector witness)
396 {
397 AcirFormat cs = constraint_to_acir_format(constraint);
398 AcirProgram program{ cs, witness };
399 auto builder = create_circuit<Builder>(program, ProgramMetadata{});
400 bool ok = CircuitChecker::check(builder) && !builder.failed();
401 return { ok, builder.err() };
402 }
403};
404
406
407// scalar=0 → result = ∞: valid proof with out_point_is_infinite=1.
409{
412 auto [constraint, witness] =
413 TestFixture::make_msm(MsmAcirPoint::from_native(point), MsmScalar::zero(), MsmAcirPoint::infinity());
414
415 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
416 EXPECT_TRUE(ok) << "0 * P = infinity should produce a valid circuit";
417}
418
419// A finite result with out_point_is_infinite=1 (forged) must fail.
420TYPED_TEST(MultiScalarMulInfinityTests, ForgedInfinityFlagOnFiniteResultFails)
421{
424 bb::fq scalar_native = bb::fq::random_element();
425 while (scalar_native.is_zero()) {
426 scalar_native = bb::fq::random_element();
427 }
428 MsmGrumpkinPoint result = point * scalar_native;
429 ASSERT_FALSE(result.is_point_at_infinity());
430 auto [constraint, witness] = TestFixture::make_msm(
432 witness[constraint.out_point_is_infinite] = MsmFF(1); // forge: finite result claimed as infinite
433
434 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
435 EXPECT_TRUE(!ok || err.find("assert_eq") != std::string::npos)
436 << "Forged infinity flag on finite result should fail";
437}
438
439// An infinity result with out_point_is_infinite=0 (forged) must fail.
440TYPED_TEST(MultiScalarMulInfinityTests, ForgedFiniteFlagOnInfinityResultFails)
441{
444 // Forge result: (0,0) coordinates but is_infinite=0 (should be 1)
445 auto [constraint, witness] = TestFixture::make_msm(
447
448 auto [ok, err] = TestFixture::run_circuit(constraint, witness);
449 EXPECT_TRUE(!ok || err.find("assert_eq") != std::string::npos)
450 << "Forged finite flag on infinity result should fail";
451}
#define BB_ASSERT(expression,...)
Definition assert.hpp:70
#define BB_DISABLE_ASSERTS()
Definition assert.hpp:33
static std::array< uint32_t, 3 > push_point(WitnessVector &witness, const MsmAcirPoint &pt)
static std::pair< bool, std::string > run_circuit(MultiScalarMul constraint, WitnessVector witness)
static std::array< uint32_t, 2 > push_scalar(WitnessVector &witness, const MsmScalar &s)
static std::pair< MultiScalarMul, WitnessVector > make_msm(MsmAcirPoint point, MsmScalar scalar, MsmAcirPoint result)
Testing functions to generate the MultiScalarMul test suite. Constancy specifies which inputs to the ...
static ProgramMetadata generate_metadata()
static void generate_constraints(AcirConstraint &msm_constraint, WitnessVector &witness_values)
static std::pair< AcirConstraint, WitnessVector > invalidate_witness(AcirConstraint constraint, WitnessVector witness_values, const InvalidWitness::Target &invalid_witness_target)
static bool check(const Builder &circuit)
Check the witness satisfies the circuit.
constexpr bool is_point_at_infinity() const noexcept
static affine_element random_element(numeric::RNG *engine=nullptr) noexcept
Samples a random point on the curve.
static constexpr affine_element one() noexcept
group class. Represents an elliptic curve group element. Group is parametrised by Fq and Fr
Definition group.hpp:38
group_elements::affine_element< Fq, Fr, Params > affine_element
Definition group.hpp:44
constexpr uint256_t slice(uint64_t start, uint64_t end) const
AluTraceBuilder builder
Definition alu.test.cpp:124
TYPED_TEST(MultiScalarMulTestsNoneConstant, GenerateVKFromConstraints)
bb::fr MsmFF
TYPED_TEST_SUITE(MultiScalarMulTestsNoneConstant, BuilderTypes)
bb::group< bb::fr, bb::fq, G1Params > g1
Definition grumpkin.hpp:46
std::filesystem::path bb_crs_path()
void init_file_crs_factory(const std::filesystem::path &path)
field< Bn254FrParams > fr
Definition fr.hpp:155
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
::testing::Types< UltraCircuitBuilder, MegaCircuitBuilder > BuilderTypes
static MsmAcirPoint infinity()
static MsmAcirPoint from_native(const MsmGrumpkinPoint &p)
static MsmScalar zero()
static MsmScalar from_native(const bb::fq &s)
static constexpr field one()
static field random_element(numeric::RNG *engine=nullptr) noexcept
BB_INLINE constexpr bool is_zero() const noexcept
static constexpr field zero()