Skip to main content

crypto/mlkem/
mlkem.rs

1use constant_time_eq::constant_time_eq;
2#[cfg(feature = "zeroize")]
3use zeroize::{Zeroize, ZeroizeOnDrop};
4
5use crate::{
6    Xof,
7    sha3::{Sha3_256, Sha3_512, Shake128, Shake256},
8};
9
/// Size of the shared secret produced by ML-KEM encapsulation/decapsulation (32 bytes).
pub const SHARED_SECRET_SIZE: usize = 32;

/// Number of coefficients in every polynomial.
pub(crate) const N: usize = 256;
/// The ML-KEM modulus q.
pub(crate) const Q: i16 = 3329;
/// Length in bytes of seeds, hashes and messages.
const SYMBYTES: usize = 32;
/// Encoded size of one polynomial: `N` coefficients at 12 bits each (256 * 12 / 8).
const POLY_BYTES: usize = 384;
/// SHAKE128 block size in bytes; `uniform_poly` squeezes the XOF one block at a time.
const SHAKE128_RATE: usize = 168;
/// Q^-1 mod 2^16 as a signed 16-bit value (`QINV * Q ≡ 1 (mod 2^16)`), used by
/// `montgomery_reduce` to make `a - t * Q` divisible by 2^16.
const QINV: i16 = -3327;
/// Final scaling constant of `invntt`: `mont^2 / 128 mod Q`, where `mont = 2^16 mod Q`.
/// Multiplying by it through `fqmul` (which divides by 2^16) scales every coefficient by
/// `mont / 128`. This removes the factor 128 picked up across the seven inverse butterfly
/// layers and leaves the result in Montgomery form.
const MONT_SQUARED_DIV_N: i16 = 1441;
/// NTT twiddle factors: powers of the 256th root of unity 17 in 7-bit bit-reversed order.
/// Each is multiplied by the Montgomery factor `2^16 mod Q` and stored as the signed
/// representative closest to zero (e.g. `ZETAS[0] = 2285 - Q`). `ntt` walks the table
/// forwards from index 1, `invntt` walks it backwards from index 127, and
/// `poly_basemul_montgomery` reads `ZETAS[64..128]`.
const ZETAS: [i16; 128] = [
    -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, 182, 962, -1202, -1474, 1468, 573, -1325, 264,
    383, -829, 1458, -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, 652, -552, 1015, -1293, 1491,
    -282, -1544, 516, -8, -320, -666, -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, -951, -398, 961,
    -1508, -725, 448, -1065, 677, -1275, -1103, 430, 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, -460,
    1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097,
    603, 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, -1187, -1659, -1185, -1530, -1278, 794,
    -1510, -854, -870, 478, -108, -308, 996, 991, 958, -1460, 1522, 1628,
];
29
30pub(crate) const ML_KEM_768: MlKemParams<3> = MlKemParams {
31    eta1: 2,
32    polycompressedbytes: 128,
33    polyveccompressedbytes: 960,
34};
35pub(crate) const ML_KEM_1024: MlKemParams<4> = MlKemParams {
36    eta1: 2,
37    polycompressedbytes: 160,
38    polyveccompressedbytes: 1408,
39};
40
/// Errors reported by the ML-KEM routines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlKemError {
    /// A key was rejected as not valid for the requested operation (for example, a
    /// secret key whose length does not match the parameter set).
    InvalidKey,
}

impl core::fmt::Display for MlKemError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match *self {
            Self::InvalidKey => "key is not valid",
        };
        f.write_str(description)
    }
}
54
/// Per-parameter-set sizes and noise settings; `K` is the module rank (polynomials per vector).
#[derive(Clone, Copy)]
pub(crate) struct MlKemParams<const K: usize> {
    /// Noise parameter passed to `poly_getnoise` for the key-generation secret/error and for
    /// the encryption vector `r`.
    pub(crate) eta1: usize,
    /// Size in bytes of the compressed single polynomial `v` at the end of a ciphertext.
    pub(crate) polycompressedbytes: usize,
    /// Size in bytes of the compressed vector `b` at the start of a ciphertext.
    pub(crate) polyveccompressedbytes: usize,
}
61
62#[derive(Clone, Debug, PartialEq, Eq)]
63#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
64pub(crate) struct Poly {
65    pub(crate) coeffs: [i16; N],
66}
67
68impl Default for Poly {
69    #[inline]
70    fn default() -> Self {
71        Self {
72            coeffs: [0; N],
73        }
74    }
75}
76
77#[derive(Clone, Debug, PartialEq, Eq)]
78#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
79pub(crate) struct PolyVec<const K: usize> {
80    pub(crate) vec: [Poly; K],
81}
82
83impl<const K: usize> Default for PolyVec<K> {
84    #[inline]
85    fn default() -> Self {
86        Self {
87            vec: core::array::from_fn(|_| Poly::default()),
88        }
89    }
90}
91
92#[inline]
93pub(crate) fn crypto_kem_keypair_derand<const K: usize, const SECRET_KEY_SIZE: usize, const PUBLIC_KEY_SIZE: usize>(
94    params: &MlKemParams<K>,
95    coins: &[u8; 64],
96) -> ([u8; SECRET_KEY_SIZE], [u8; PUBLIC_KEY_SIZE]) {
97    let mut public_key = [0u8; PUBLIC_KEY_SIZE];
98    let mut secret_key = [0u8; SECRET_KEY_SIZE];
99
100    indcpa_keypair_derand::<K>(
101        params,
102        &mut public_key,
103        &mut secret_key[..indcpa_secret_key_bytes::<K>()],
104        &coins[..32],
105    );
106    secret_key[indcpa_secret_key_bytes::<K>()..indcpa_secret_key_bytes::<K>() + PUBLIC_KEY_SIZE]
107        .copy_from_slice(&public_key);
108
109    let public_key_hash = hash_h(&public_key);
110    secret_key[SECRET_KEY_SIZE - 64..SECRET_KEY_SIZE - 32].copy_from_slice(&public_key_hash);
111    secret_key[SECRET_KEY_SIZE - 32..].copy_from_slice(&coins[32..]);
112
113    (secret_key, public_key)
114}
115
116#[inline]
117pub(crate) fn crypto_kem_enc_derand<const K: usize, const PUBLIC_KEY_SIZE: usize, const CIPHERTEXT_SIZE: usize>(
118    params: &MlKemParams<K>,
119    public_key: &[u8; PUBLIC_KEY_SIZE],
120    coins: &[u8; 32],
121) -> ([u8; CIPHERTEXT_SIZE], [u8; SHARED_SECRET_SIZE]) {
122    let mut ciphertext = [0u8; CIPHERTEXT_SIZE];
123    let mut buf = [0u8; 64];
124    let mut kr = [0u8; 64];
125
126    buf[..32].copy_from_slice(coins);
127    buf[32..].copy_from_slice(&hash_h(public_key));
128    kr.copy_from_slice(&hash_g(&buf));
129
130    indcpa_enc::<K>(params, &mut ciphertext, &buf[..32], public_key, array_ref_32(&kr[32..64]));
131
132    let mut shared_secret = [0u8; SHARED_SECRET_SIZE];
133    shared_secret.copy_from_slice(&kr[..32]);
134    (ciphertext, shared_secret)
135}
136
137#[inline]
138pub(crate) fn crypto_kem_dec<const K: usize, const SECRET_KEY_SIZE: usize, const CIPHERTEXT_SIZE: usize>(
139    params: &MlKemParams<K>,
140    secret_key: &[u8; SECRET_KEY_SIZE],
141    ciphertext: &[u8; CIPHERTEXT_SIZE],
142) -> Result<[u8; SHARED_SECRET_SIZE], MlKemError> {
143    let public_key_offset = indcpa_secret_key_bytes::<K>();
144    let public_key_size = public_key_bytes::<K>();
145    if SECRET_KEY_SIZE != secret_key_size::<K>() {
146        return Err(MlKemError::InvalidKey);
147    }
148
149    let public_key = &secret_key[public_key_offset..public_key_offset + public_key_size];
150    let mut message_and_hash = [0u8; 64];
151    let mut kr = [0u8; 64];
152    let mut cmp = [0u8; CIPHERTEXT_SIZE];
153
154    indcpa_dec::<K>(
155        params,
156        &mut message_and_hash[..32],
157        ciphertext,
158        &secret_key[..public_key_offset],
159    );
160    message_and_hash[32..].copy_from_slice(&secret_key[SECRET_KEY_SIZE - 64..SECRET_KEY_SIZE - 32]);
161    kr.copy_from_slice(&hash_g(&message_and_hash));
162
163    indcpa_enc::<K>(params, &mut cmp, &message_and_hash[..32], public_key, array_ref_32(&kr[32..64]));
164
165    let mut shared_secret = rkprf(array_ref_32(&secret_key[SECRET_KEY_SIZE - 32..]), ciphertext);
166    cmov(&mut shared_secret, array_ref_32(&kr[..32]), constant_time_eq(ciphertext, &cmp));
167    Ok(shared_secret)
168}
169
170#[inline]
171pub(crate) fn indcpa_keypair_derand<const K: usize>(
172    params: &MlKemParams<K>,
173    public_key: &mut [u8],
174    secret_key: &mut [u8],
175    coins: &[u8],
176) {
177    debug_assert_eq!(public_key.len(), public_key_bytes::<K>());
178    debug_assert_eq!(secret_key.len(), indcpa_secret_key_bytes::<K>());
179    debug_assert_eq!(coins.len(), 32);
180
181    let mut g_input = [0u8; 33];
182    g_input[..32].copy_from_slice(coins);
183    g_input[32] = K as u8;
184    let seed_output = hash_g(&g_input);
185    let public_seed = array_ref_32(&seed_output[..32]);
186    let noise_seed = array_ref_32(&seed_output[32..64]);
187    let matrix = gen_matrix::<K>(public_seed, false);
188
189    let mut skpv = PolyVec::<K>::default();
190    let mut e = PolyVec::<K>::default();
191    for (index, poly) in skpv.vec.iter_mut().enumerate() {
192        *poly = poly_getnoise(noise_seed, index as u8, params.eta1);
193    }
194    for (index, poly) in e.vec.iter_mut().enumerate() {
195        *poly = poly_getnoise(noise_seed, (K + index) as u8, params.eta1);
196    }
197
198    polyvec_ntt(&mut skpv);
199    polyvec_ntt(&mut e);
200
201    let mut pkpv = PolyVec::<K>::default();
202    for i in 0..K {
203        pkpv.vec[i] = polyvec_basemul_acc_montgomery(&matrix[i], &skpv);
204        poly_tomont(&mut pkpv.vec[i]);
205    }
206
207    polyvec_add(&mut pkpv, &e);
208    polyvec_reduce(&mut pkpv);
209
210    pack_sk(secret_key, &skpv);
211    pack_pk(public_key, &pkpv, public_seed);
212}
213
214#[inline]
215pub(crate) fn indcpa_enc<const K: usize>(
216    params: &MlKemParams<K>,
217    ciphertext: &mut [u8],
218    message: &[u8],
219    public_key: &[u8],
220    coins: &[u8; 32],
221) {
222    debug_assert_eq!(ciphertext.len(), ciphertext_bytes(params));
223    debug_assert_eq!(message.len(), 32);
224    debug_assert_eq!(public_key.len(), public_key_bytes::<K>());
225
226    let (pkpv, seed) = unpack_pk::<K>(public_key);
227    let at = gen_matrix::<K>(&seed, true);
228    let k = poly_frommsg(message);
229
230    let mut sp = PolyVec::<K>::default();
231    let mut ep = PolyVec::<K>::default();
232    for (index, poly) in sp.vec.iter_mut().enumerate() {
233        *poly = poly_getnoise(coins, index as u8, params.eta1);
234    }
235    let ep_nonce_offset = sp.vec.len();
236    for (index, poly) in ep.vec.iter_mut().enumerate() {
237        *poly = poly_getnoise(coins, (ep_nonce_offset + index) as u8, 2);
238    }
239    let epp = poly_getnoise(coins, (sp.vec.len() + ep.vec.len()) as u8, 2);
240
241    polyvec_ntt(&mut sp);
242
243    let mut b = PolyVec::<K>::default();
244    for i in 0..K {
245        b.vec[i] = polyvec_basemul_acc_montgomery(&at[i], &sp);
246    }
247    let mut v = polyvec_basemul_acc_montgomery(&pkpv, &sp);
248
249    polyvec_invntt_tomont(&mut b);
250    poly_invntt_tomont(&mut v);
251
252    polyvec_add(&mut b, &ep);
253    poly_add(&mut v, &epp);
254    poly_add(&mut v, &k);
255    polyvec_reduce(&mut b);
256    poly_reduce(&mut v);
257
258    pack_ciphertext(params, ciphertext, &b, &v);
259}
260
261#[inline]
262pub(crate) fn indcpa_dec<const K: usize>(
263    params: &MlKemParams<K>,
264    message: &mut [u8],
265    ciphertext: &[u8],
266    secret_key: &[u8],
267) {
268    debug_assert_eq!(message.len(), 32);
269    debug_assert_eq!(ciphertext.len(), ciphertext_bytes(params));
270    debug_assert_eq!(secret_key.len(), indcpa_secret_key_bytes::<K>());
271
272    let (mut b, v) = unpack_ciphertext::<K>(params, ciphertext);
273    let skpv = unpack_sk::<K>(secret_key);
274
275    polyvec_ntt(&mut b);
276    let mut mp = polyvec_basemul_acc_montgomery(&skpv, &b);
277    poly_invntt_tomont(&mut mp);
278    let product = mp.clone();
279    poly_sub(&mut mp, &v, &product);
280    poly_reduce(&mut mp);
281
282    message.copy_from_slice(&poly_tomsg(&mp));
283}
284
285#[inline]
286fn pack_pk<const K: usize>(out: &mut [u8], pk: &PolyVec<K>, seed: &[u8; 32]) {
287    let polyvec_bytes = polyvec_bytes::<K>();
288    polyvec_tobytes(&mut out[..polyvec_bytes], pk);
289    out[polyvec_bytes..polyvec_bytes + 32].copy_from_slice(seed);
290}
291
292#[inline]
293fn unpack_pk<const K: usize>(packed: &[u8]) -> (PolyVec<K>, [u8; 32]) {
294    let polyvec_bytes = polyvec_bytes::<K>();
295    let pk = polyvec_frombytes::<K>(&packed[..polyvec_bytes]);
296    let mut seed = [0u8; 32];
297    seed.copy_from_slice(&packed[polyvec_bytes..polyvec_bytes + 32]);
298    (pk, seed)
299}
300
301#[inline]
302fn pack_sk<const K: usize>(out: &mut [u8], sk: &PolyVec<K>) {
303    polyvec_tobytes(out, sk);
304}
305
306#[inline]
307fn unpack_sk<const K: usize>(packed: &[u8]) -> PolyVec<K> {
308    polyvec_frombytes(packed)
309}
310
311#[inline]
312fn pack_ciphertext<const K: usize>(params: &MlKemParams<K>, out: &mut [u8], b: &PolyVec<K>, v: &Poly) {
313    let split = params.polyveccompressedbytes;
314    polyvec_compress(params, &mut out[..split], b);
315    poly_compress(params, &mut out[split..split + params.polycompressedbytes], v);
316}
317
318#[inline]
319fn unpack_ciphertext<const K: usize>(params: &MlKemParams<K>, packed: &[u8]) -> (PolyVec<K>, Poly) {
320    let split = params.polyveccompressedbytes;
321    (
322        polyvec_decompress(params, &packed[..split]),
323        poly_decompress(params, &packed[split..split + params.polycompressedbytes]),
324    )
325}
326
327#[inline]
328pub(crate) fn gen_matrix<const K: usize>(seed: &[u8; 32], transpose: bool) -> [PolyVec<K>; K] {
329    let mut matrix = core::array::from_fn(|_| PolyVec::<K>::default());
330    for i in 0..K {
331        for j in 0..K {
332            let (x, y) = if transpose {
333                (i as u8, j as u8)
334            } else {
335                (j as u8, i as u8)
336            };
337            matrix[i].vec[j] = uniform_poly(seed, x, y);
338        }
339    }
340    matrix
341}
342
343#[inline]
344fn uniform_poly(seed: &[u8; 32], x: u8, y: u8) -> Poly {
345    let mut shake = Shake128::new();
346    shake.absorb(seed);
347    shake.absorb(&[x, y]);
348
349    let mut poly = Poly::default();
350    let mut ctr = 0usize;
351    let mut block = [0u8; SHAKE128_RATE];
352    while ctr < N {
353        shake.squeeze(&mut block);
354        ctr += rej_uniform(&mut poly.coeffs[ctr..], &block);
355    }
356    poly
357}
358
359#[inline]
360fn rej_uniform(out: &mut [i16], buf: &[u8]) -> usize {
361    let mut ctr = 0usize;
362    let mut pos = 0usize;
363    while ctr < out.len() && pos + 3 <= buf.len() {
364        let val0 = (((buf[pos] as u16) | ((buf[pos + 1] as u16) << 8)) & 0x0fff) as i16;
365        let val1 = ((((buf[pos + 1] as u16) >> 4) | ((buf[pos + 2] as u16) << 4)) & 0x0fff) as i16;
366        pos += 3;
367
368        if val0 < Q {
369            out[ctr] = val0;
370            ctr += 1;
371        }
372        if ctr < out.len() && val1 < Q {
373            out[ctr] = val1;
374            ctr += 1;
375        }
376    }
377    ctr
378}
379
380#[inline]
381pub(crate) fn poly_getnoise(seed: &[u8; 32], nonce: u8, eta: usize) -> Poly {
382    debug_assert_eq!(eta, 2);
383    let mut input = [0u8; 33];
384    input[..32].copy_from_slice(seed);
385    input[32] = nonce;
386    let mut buf = [0u8; 128];
387    Shake256::hash(&input, &mut buf);
388    cbd2(&buf)
389}
390
391#[inline]
392fn cbd2(buf: &[u8; 128]) -> Poly {
393    let mut poly = Poly::default();
394    for i in 0..(N / 8) {
395        let t = load32(&buf[4 * i..4 * i + 4]);
396        let mut d = t & 0x5555_5555;
397        d += (t >> 1) & 0x5555_5555;
398        for j in 0..8 {
399            let a = ((d >> (4 * j)) & 0x3) as i16;
400            let b = ((d >> (4 * j + 2)) & 0x3) as i16;
401            poly.coeffs[8 * i + j] = a - b;
402        }
403    }
404    poly
405}
406
/// Compresses every coefficient of `a` to `d` bits and packs the results little-endian into `out`.
///
/// - `polyveccompressedbytes == 960`: 10 bits per coefficient, 4 coefficients per 5 bytes.
/// - `polyveccompressedbytes == 1408`: 11 bits per coefficient, 8 coefficients per 11 bytes.
///
/// Panics (`unreachable!`) for any other size.
#[inline]
pub(crate) fn polyvec_compress<const K: usize>(params: &MlKemParams<K>, out: &mut [u8], a: &PolyVec<K>) {
    match params.polyveccompressedbytes {
        960 => {
            let mut offset = 0usize;
            for poly in &a.vec {
                for chunk in poly.coeffs.chunks_exact(4) {
                    let mut t = [0u16; 4];
                    for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
                        let mut u = *coeff as i32;
                        // Add Q to negative coefficients so `u` is non-negative.
                        u += (u >> 15) & Q as i32;
                        // ((u << 10) + 1665) * 1_290_167 >> 32, with 1_290_167 ≈ 2^32 / Q and
                        // 1665 ≈ Q / 2: rounds (u << 10) / Q without a division.
                        let mut d0 = u as u64;
                        d0 <<= 10;
                        d0 += 1665;
                        d0 *= 1_290_167;
                        d0 >>= 32;
                        *dst = (d0 as u16) & 0x03ff;
                    }
                    // Pack four 10-bit values into 5 bytes.
                    out[offset] = t[0] as u8;
                    out[offset + 1] = ((t[0] >> 8) as u8) | ((t[1] << 2) as u8);
                    out[offset + 2] = ((t[1] >> 6) as u8) | ((t[2] << 4) as u8);
                    out[offset + 3] = ((t[2] >> 4) as u8) | ((t[3] << 6) as u8);
                    out[offset + 4] = (t[3] >> 2) as u8;
                    offset += 5;
                }
            }
        }
        1408 => {
            let mut offset = 0usize;
            for poly in &a.vec {
                for chunk in poly.coeffs.chunks_exact(8) {
                    let mut t = [0u16; 8];
                    for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
                        let mut u = *coeff as i32;
                        // Add Q to negative coefficients so `u` is non-negative.
                        u += (u >> 15) & Q as i32;
                        // ((u << 11) + 1664) * 645_084 >> 31, with 645_084 ≈ 2^31 / Q:
                        // rounds (u << 11) / Q without a division.
                        let mut d0 = u as u64;
                        d0 <<= 11;
                        d0 += 1664;
                        d0 *= 645_084;
                        d0 >>= 31;
                        *dst = (d0 as u16) & 0x07ff;
                    }
                    // Pack eight 11-bit values into 11 bytes.
                    out[offset] = t[0] as u8;
                    out[offset + 1] = ((t[0] >> 8) as u8) | ((t[1] << 3) as u8);
                    out[offset + 2] = ((t[1] >> 5) as u8) | ((t[2] << 6) as u8);
                    out[offset + 3] = (t[2] >> 2) as u8;
                    out[offset + 4] = ((t[2] >> 10) as u8) | ((t[3] << 1) as u8);
                    out[offset + 5] = ((t[3] >> 7) as u8) | ((t[4] << 4) as u8);
                    out[offset + 6] = ((t[4] >> 4) as u8) | ((t[5] << 7) as u8);
                    out[offset + 7] = (t[5] >> 1) as u8;
                    out[offset + 8] = ((t[5] >> 9) as u8) | ((t[6] << 2) as u8);
                    out[offset + 9] = ((t[6] >> 6) as u8) | ((t[7] << 5) as u8);
                    out[offset + 10] = (t[7] >> 3) as u8;
                    offset += 11;
                }
            }
        }
        _ => unreachable!(),
    }
}
467
/// Inverse of `polyvec_compress`: unpacks `d`-bit values and maps each `x` back to
/// `round(x * Q / 2^d)`.
///
/// `d` is 10 for `polyveccompressedbytes == 960` and 11 for `polyveccompressedbytes == 1408`.
/// Panics (`unreachable!`) for any other size.
#[inline]
pub(crate) fn polyvec_decompress<const K: usize>(params: &MlKemParams<K>, input: &[u8]) -> PolyVec<K> {
    let mut out = PolyVec::<K>::default();
    match params.polyveccompressedbytes {
        960 => {
            let mut offset = 0usize;
            for poly in &mut out.vec {
                // Each 5-byte group holds four 10-bit values.
                for j in 0..(N / 4) {
                    let t0 = (input[offset] as u16) | ((input[offset + 1] as u16) << 8);
                    let t1 = ((input[offset + 1] as u16) >> 2) | ((input[offset + 2] as u16) << 6);
                    let t2 = ((input[offset + 2] as u16) >> 4) | ((input[offset + 3] as u16) << 4);
                    let t3 = ((input[offset + 3] as u16) >> 6) | ((input[offset + 4] as u16) << 2);
                    offset += 5;
                    // (x * Q + 2^9) >> 10 == round(x * Q / 2^10).
                    poly.coeffs[4 * j] = ((((t0 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
                    poly.coeffs[4 * j + 1] = ((((t1 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
                    poly.coeffs[4 * j + 2] = ((((t2 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
                    poly.coeffs[4 * j + 3] = ((((t3 & 0x03ff) as u32) * Q as u32 + 512) >> 10) as i16;
                }
            }
        }
        1408 => {
            let mut offset = 0usize;
            for poly in &mut out.vec {
                // Each 11-byte group holds eight 11-bit values.
                for j in 0..(N / 8) {
                    let t0 = (input[offset] as u16) | ((input[offset + 1] as u16) << 8);
                    let t1 = ((input[offset + 1] as u16) >> 3) | ((input[offset + 2] as u16) << 5);
                    let t2 = ((input[offset + 2] as u16) >> 6)
                        | ((input[offset + 3] as u16) << 2)
                        | ((input[offset + 4] as u16) << 10);
                    let t3 = ((input[offset + 4] as u16) >> 1) | ((input[offset + 5] as u16) << 7);
                    let t4 = ((input[offset + 5] as u16) >> 4) | ((input[offset + 6] as u16) << 4);
                    let t5 = ((input[offset + 6] as u16) >> 7)
                        | ((input[offset + 7] as u16) << 1)
                        | ((input[offset + 8] as u16) << 9);
                    let t6 = ((input[offset + 8] as u16) >> 2) | ((input[offset + 9] as u16) << 6);
                    let t7 = ((input[offset + 9] as u16) >> 5) | ((input[offset + 10] as u16) << 3);
                    offset += 11;
                    let values = [t0, t1, t2, t3, t4, t5, t6, t7];
                    for (k, value) in values.into_iter().enumerate() {
                        // (x * Q + 2^10) >> 11 == round(x * Q / 2^11).
                        poly.coeffs[8 * j + k] = ((((value & 0x07ff) as u32) * Q as u32 + 1024) >> 11) as i16;
                    }
                }
            }
        }
        _ => unreachable!(),
    }
    out
}
516
517#[inline]
518fn polyvec_tobytes<const K: usize>(out: &mut [u8], polyvec: &PolyVec<K>) {
519    for (i, poly) in polyvec.vec.iter().enumerate() {
520        poly_tobytes(&mut out[i * POLY_BYTES..(i + 1) * POLY_BYTES], poly);
521    }
522}
523
524#[inline]
525fn polyvec_frombytes<const K: usize>(input: &[u8]) -> PolyVec<K> {
526    let mut out = PolyVec::<K>::default();
527    for (i, poly) in out.vec.iter_mut().enumerate() {
528        *poly = poly_frombytes(&input[i * POLY_BYTES..(i + 1) * POLY_BYTES]);
529    }
530    out
531}
532
533#[inline]
534fn polyvec_ntt<const K: usize>(polyvec: &mut PolyVec<K>) {
535    for poly in &mut polyvec.vec {
536        poly_ntt(poly);
537    }
538}
539
540#[inline]
541fn polyvec_invntt_tomont<const K: usize>(polyvec: &mut PolyVec<K>) {
542    for poly in &mut polyvec.vec {
543        poly_invntt_tomont(poly);
544    }
545}
546
547#[inline]
548fn polyvec_basemul_acc_montgomery<const K: usize>(a: &PolyVec<K>, b: &PolyVec<K>) -> Poly {
549    let mut out = poly_basemul_montgomery(&a.vec[0], &b.vec[0]);
550    for i in 1..K {
551        let t = poly_basemul_montgomery(&a.vec[i], &b.vec[i]);
552        poly_add(&mut out, &t);
553    }
554    poly_reduce(&mut out);
555    out
556}
557
558#[inline]
559fn polyvec_reduce<const K: usize>(polyvec: &mut PolyVec<K>) {
560    for poly in &mut polyvec.vec {
561        poly_reduce(poly);
562    }
563}
564
565#[inline]
566fn polyvec_add<const K: usize>(left: &mut PolyVec<K>, right: &PolyVec<K>) {
567    for i in 0..K {
568        poly_add(&mut left.vec[i], &right.vec[i]);
569    }
570}
571
/// Compresses every coefficient of `poly` to `d` bits and packs the results into `out`.
///
/// - `polycompressedbytes == 128`: 4 bits per coefficient, two coefficients per byte.
/// - `polycompressedbytes == 160`: 5 bits per coefficient, 8 coefficients per 5 bytes.
///
/// Panics (`unreachable!`) for any other size.
#[inline]
fn poly_compress<const K: usize>(params: &MlKemParams<K>, out: &mut [u8], poly: &Poly) {
    match params.polycompressedbytes {
        128 => {
            let mut offset = 0usize;
            for chunk in poly.coeffs.chunks_exact(8) {
                let mut t = [0u8; 8];
                for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
                    let mut u = *coeff as i32;
                    // Add Q to negative coefficients so `u` is non-negative.
                    u += (u >> 15) & Q as i32;
                    // ((u << 4) + 1665) * 80_635 >> 28, with 80_635 ≈ 2^28 / Q: rounds
                    // (u << 4) / Q without a division.
                    let mut d0 = ((u as u32) << 4) as u64;
                    d0 += 1665;
                    d0 *= 80_635;
                    d0 >>= 28;
                    *dst = (d0 as u8) & 0x0f;
                }
                out[offset] = t[0] | (t[1] << 4);
                out[offset + 1] = t[2] | (t[3] << 4);
                out[offset + 2] = t[4] | (t[5] << 4);
                out[offset + 3] = t[6] | (t[7] << 4);
                offset += 4;
            }
        }
        160 => {
            let mut offset = 0usize;
            for chunk in poly.coeffs.chunks_exact(8) {
                let mut t = [0u8; 8];
                for (dst, coeff) in t.iter_mut().zip(chunk.iter()) {
                    let mut u = *coeff as i32;
                    // Add Q to negative coefficients so `u` is non-negative.
                    u += (u >> 15) & Q as i32;
                    // ((u << 5) + 1664) * 40_318 >> 27, with 40_318 ≈ 2^27 / Q: rounds
                    // (u << 5) / Q without a division.
                    let mut d0 = ((u as u32) << 5) as u64;
                    d0 += 1664;
                    d0 *= 40_318;
                    d0 >>= 27;
                    *dst = (d0 as u8) & 0x1f;
                }
                // Pack eight 5-bit values into 5 bytes.
                out[offset] = t[0] | (t[1] << 5);
                out[offset + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
                out[offset + 2] = (t[3] >> 1) | (t[4] << 4);
                out[offset + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
                out[offset + 4] = (t[6] >> 2) | (t[7] << 3);
                offset += 5;
            }
        }
        _ => unreachable!(),
    }
}
619
/// Inverse of `poly_compress`: unpacks `d`-bit values and maps each `x` back to
/// `round(x * Q / 2^d)`.
///
/// `d` is 4 for `polycompressedbytes == 128` and 5 for `polycompressedbytes == 160`.
/// Panics (`unreachable!`) for any other size.
#[inline]
fn poly_decompress<const K: usize>(params: &MlKemParams<K>, input: &[u8]) -> Poly {
    let mut out = Poly::default();
    match params.polycompressedbytes {
        128 => {
            // Two 4-bit values per byte, low nibble first.
            for i in 0..(N / 2) {
                out.coeffs[2 * i] = ((((input[i] & 0x0f) as u16) * Q as u16 + 8) >> 4) as i16;
                out.coeffs[2 * i + 1] = ((((input[i] >> 4) as u16) * Q as u16 + 8) >> 4) as i16;
            }
        }
        160 => {
            let mut offset = 0usize;
            for i in 0..(N / 8) {
                // Eight 5-bit values per 5 bytes; the high bits left in some `t*` are
                // discarded by the `& 31` below.
                let t0 = input[offset] >> 0;
                let t1 = (input[offset] >> 5) | (input[offset + 1] << 3);
                let t2 = input[offset + 1] >> 2;
                let t3 = (input[offset + 1] >> 7) | (input[offset + 2] << 1);
                let t4 = (input[offset + 2] >> 4) | (input[offset + 3] << 4);
                let t5 = input[offset + 3] >> 1;
                let t6 = (input[offset + 3] >> 6) | (input[offset + 4] << 2);
                let t7 = input[offset + 4] >> 3;
                offset += 5;
                let values = [t0, t1, t2, t3, t4, t5, t6, t7];
                for (j, value) in values.into_iter().enumerate() {
                    out.coeffs[8 * i + j] = (((value as u32 & 31) * Q as u32 + 16) >> 5) as i16;
                }
            }
        }
        _ => unreachable!(),
    }
    out
}
652
653#[inline]
654fn poly_tobytes(out: &mut [u8], poly: &Poly) {
655    for i in 0..(N / 2) {
656        let mut t0 = poly.coeffs[2 * i] as i32;
657        t0 += (t0 >> 15) & Q as i32;
658        let mut t1 = poly.coeffs[2 * i + 1] as i32;
659        t1 += (t1 >> 15) & Q as i32;
660        out[3 * i] = t0 as u8;
661        out[3 * i + 1] = ((t0 >> 8) as u8) | ((t1 << 4) as u8);
662        out[3 * i + 2] = (t1 >> 4) as u8;
663    }
664}
665
666#[inline]
667fn poly_frombytes(input: &[u8]) -> Poly {
668    let mut out = Poly::default();
669    for i in 0..(N / 2) {
670        out.coeffs[2 * i] = (((input[3 * i] as u16) | ((input[3 * i + 1] as u16) << 8)) & 0x0fff) as i16;
671        out.coeffs[2 * i + 1] = ((((input[3 * i + 1] as u16) >> 4) | ((input[3 * i + 2] as u16) << 4)) & 0x0fff) as i16;
672    }
673    out
674}
675
676#[inline]
677pub(crate) fn poly_frommsg(msg: &[u8]) -> Poly {
678    let mut out = Poly::default();
679    let half_q: i16 = ((Q + 1) / 2) as i16;
680    for i in 0..(N / 8) {
681        for j in 0..8 {
682            let bit = ((msg[i] >> j) & 1) as i16;
683            out.coeffs[8 * i + j] = (-bit) & half_q;
684        }
685    }
686    out
687}
688
689#[inline]
690pub(crate) fn poly_tomsg(poly: &Poly) -> [u8; 32] {
691    let mut msg = [0u8; 32];
692    for i in 0..(N / 8) {
693        for j in 0..8 {
694            let mut t = poly.coeffs[8 * i + j] as i32;
695            t <<= 1;
696            t += 1665;
697            t *= 80_635;
698            t >>= 28;
699            msg[i] |= ((t & 1) as u8) << j;
700        }
701    }
702    msg
703}
704
705#[inline]
706fn poly_ntt(poly: &mut Poly) {
707    ntt(&mut poly.coeffs);
708    poly_reduce(poly);
709}
710
711#[inline]
712fn poly_invntt_tomont(poly: &mut Poly) {
713    invntt(&mut poly.coeffs);
714}
715
716#[inline]
717fn poly_basemul_montgomery(a: &Poly, b: &Poly) -> Poly {
718    let mut out = Poly::default();
719    for i in 0..(N / 4) {
720        let r0 = basemul(
721            [a.coeffs[4 * i], a.coeffs[4 * i + 1]],
722            [b.coeffs[4 * i], b.coeffs[4 * i + 1]],
723            ZETAS[64 + i],
724        );
725        out.coeffs[4 * i] = r0[0];
726        out.coeffs[4 * i + 1] = r0[1];
727
728        let r1 = basemul(
729            [a.coeffs[4 * i + 2], a.coeffs[4 * i + 3]],
730            [b.coeffs[4 * i + 2], b.coeffs[4 * i + 3]],
731            -ZETAS[64 + i],
732        );
733        out.coeffs[4 * i + 2] = r1[0];
734        out.coeffs[4 * i + 3] = r1[1];
735    }
736    out
737}
738
739#[inline]
740fn poly_tomont(poly: &mut Poly) {
741    for coeff in &mut poly.coeffs {
742        *coeff = montgomery_reduce(*coeff as i32 * 1353);
743    }
744}
745
746#[inline]
747fn poly_reduce(poly: &mut Poly) {
748    for coeff in &mut poly.coeffs {
749        *coeff = barrett_reduce(*coeff);
750    }
751}
752
753#[inline]
754fn poly_add(left: &mut Poly, right: &Poly) {
755    for i in 0..N {
756        left.coeffs[i] = (left.coeffs[i] as i32 + right.coeffs[i] as i32) as i16;
757    }
758}
759
760#[inline]
761fn poly_sub(out: &mut Poly, left: &Poly, right: &Poly) {
762    for i in 0..N {
763        out.coeffs[i] = (left.coeffs[i] as i32 - right.coeffs[i] as i32) as i16;
764    }
765}
766
767#[inline]
768fn ntt(r: &mut [i16; N]) {
769    let mut k = 1usize;
770    let mut len = 128usize;
771    while len >= 2 {
772        let mut start = 0usize;
773        while start < N {
774            let zeta = ZETAS[k];
775            k += 1;
776            for j in start..start + len {
777                let t = fqmul(zeta, r[j + len]);
778                let rj = r[j] as i32;
779                r[j + len] = (rj - t as i32) as i16;
780                r[j] = (rj + t as i32) as i16;
781            }
782            start += 2 * len;
783        }
784        len >>= 1;
785    }
786}
787
788#[inline]
789fn invntt(r: &mut [i16; N]) {
790    let mut k = 127usize;
791    let mut len = 2usize;
792    while len <= 128 {
793        let mut start = 0usize;
794        while start < N {
795            let zeta = ZETAS[k];
796            k -= 1;
797            for j in start..start + len {
798                let t = r[j];
799                r[j] = barrett_reduce((t as i32 + r[j + len] as i32) as i16);
800                r[j + len] = fqmul(zeta, (r[j + len] as i32 - t as i32) as i16);
801            }
802            start += 2 * len;
803        }
804        len <<= 1;
805    }
806
807    for coeff in r.iter_mut() {
808        *coeff = fqmul(*coeff, MONT_SQUARED_DIV_N);
809    }
810}
811
812#[inline]
813fn basemul(a: [i16; 2], b: [i16; 2], zeta: i16) -> [i16; 2] {
814    let mut out = [0i16; 2];
815    out[0] = fqmul(a[1], b[1]);
816    out[0] = fqmul(out[0], zeta);
817    out[0] = (out[0] as i32 + fqmul(a[0], b[0]) as i32) as i16;
818    out[1] = (fqmul(a[0], b[1]) as i32 + fqmul(a[1], b[0]) as i32) as i16;
819    out
820}
821
822#[inline]
823fn fqmul(a: i16, b: i16) -> i16 {
824    montgomery_reduce(a as i32 * b as i32)
825}
826
827#[inline]
828fn montgomery_reduce(a: i32) -> i16 {
829    let t = (a as i16).wrapping_mul(QINV) as i32;
830    ((a - t * Q as i32) >> 16) as i16
831}
832
833#[inline]
834fn barrett_reduce(a: i16) -> i16 {
835    const V: i32 = ((1 << 26) + (Q as i32 / 2)) / Q as i32;
836    let t = ((V * a as i32 + (1 << 25)) >> 26) * Q as i32;
837    (a as i32 - t) as i16
838}
839
840#[inline]
841fn hash_h(data: &[u8]) -> [u8; 32] {
842    use crate::Hasher;
843    let mut hasher = Sha3_256::new();
844    hasher.update(data);
845    hasher.sum().as_ref().try_into().unwrap()
846}
847
848#[inline]
849fn hash_g(data: &[u8]) -> [u8; 64] {
850    use crate::Hasher;
851    let mut hasher = Sha3_512::new();
852    hasher.update(data);
853    hasher.sum().as_ref().try_into().unwrap()
854}
855
856#[inline]
857fn rkprf(cipher_key: &[u8; 32], ciphertext: &[u8]) -> [u8; 32] {
858    let mut shake = Shake256::new();
859    shake.absorb(cipher_key);
860    shake.absorb(ciphertext);
861    let mut out = [0u8; 32];
862    shake.squeeze(&mut out);
863    out
864}
865
866/// Constant-time conditional move: if `cond` is true, copies `value` into `out`.
867/// Uses a compiler barrier on the mask to prevent the optimizer from turning this
868/// into a branch (which would leak timing information in the FO transform).
869#[inline]
870fn cmov(out: &mut [u8; 32], value: &[u8; 32], cond: bool) {
871    let mask = ct_mask_u8(cond);
872    for i in 0..32 {
873        out[i] ^= mask & (out[i] ^ value[i]);
874    }
875}
876
877/// Converts a boolean condition to a constant-time mask (0x00 or 0xFF) with a compiler
878/// barrier to prevent optimization into a branch.
879#[inline]
880fn ct_mask_u8(cond: bool) -> u8 {
881    let mask = 0u8.wrapping_sub(cond as u8);
882    // Prevent the compiler from reasoning about the mask value and potentially
883    // converting downstream code into a conditional branch.
884    ct_barrier_u8(mask)
885}
886
/// Optimization barrier for a single byte on x86 / x86_64.
///
/// Returns `value` unchanged. Because the value passes through an (empty)
/// inline-asm statement, the compiler cannot see which value comes back. It
/// therefore cannot constant-fold it or turn mask arithmetic that depends on it
/// into a branch.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline]
fn ct_barrier_u8(mut value: u8) -> u8 {
    // SAFETY: the inline asm is a no-op that forces the compiler to treat `value`
    // as an opaque value, preventing branch-based optimizations.
    unsafe {
        // The template is only an assembler comment, so no instruction is emitted.
        // `reg_byte` is x86's 8-bit register class, matching the `u8` operand.
        // `pure`/`nomem`/`nostack`/`preserves_flags` tell the compiler the asm has
        // no side effects beyond producing its output.
        core::arch::asm!("/* {0} */", inout(reg_byte) value, options(pure, nomem, nostack, preserves_flags));
    }
    value
}
897
/// Optimization barrier for a single byte on AArch64, 32-bit ARM and RISC-V.
///
/// Same contract as the x86 variant: returns `value` unchanged while hiding it
/// from the optimizer behind an empty inline-asm statement.
#[cfg(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "riscv32",
    target_arch = "riscv64"
))]
#[inline]
// `asm_sub_register` warns when a sub-word operand (here a `u8` in a `reg`) is
// formatted with the default full-width register name. That is harmless here:
// `{0}` only appears inside an assembler comment and no instruction reads it.
#[allow(asm_sub_register)]
fn ct_barrier_u8(mut value: u8) -> u8 {
    // SAFETY: the asm template is only a comment, so it executes no instructions;
    // it touches no memory or stack and leaves flags intact (as declared in
    // `options`). The `inout` operand makes the compiler treat `value` as opaque.
    unsafe {
        core::arch::asm!("/* {0} */", inout(reg) value, options(pure, nomem, nostack, preserves_flags));
    }
    value
}
912
/// Portable fallback optimization barrier for targets without an inline-asm
/// variant above.
///
/// Returns `value` unchanged via `core::hint::black_box`. NOTE(review): the
/// standard library documents `black_box` as best-effort only, so on these
/// targets the constant-time property is weaker than on the asm-backed ones.
/// `#[inline(never)]` presumably adds a call boundary as an extra hurdle for the
/// optimizer.
#[cfg(not(any(
    target_arch = "x86",
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "riscv32",
    target_arch = "riscv64"
)))]
#[inline(never)]
fn ct_barrier_u8(value: u8) -> u8 {
    core::hint::black_box(value)
}
925
/// Reads the first four bytes of `input` as a little-endian `u32`.
///
/// Bytes past index 3 are ignored.
///
/// # Panics
/// Panics if `input` holds fewer than four bytes.
#[inline]
fn load32(input: &[u8]) -> u32 {
    u32::from_le_bytes([input[0], input[1], input[2], input[3]])
}
930
931#[inline]
932pub(crate) fn public_key_bytes<const K: usize>() -> usize {
933    polyvec_bytes::<K>() + SYMBYTES
934}
935
936#[inline]
937pub(crate) fn indcpa_secret_key_bytes<const K: usize>() -> usize {
938    polyvec_bytes::<K>()
939}
940
941#[inline]
942fn polyvec_bytes<const K: usize>() -> usize {
943    K * POLY_BYTES
944}
945
946#[inline]
947pub(crate) fn secret_key_size<const K: usize>() -> usize {
948    indcpa_secret_key_bytes::<K>() + public_key_bytes::<K>() + 2 * SYMBYTES
949}
950
951#[inline]
952fn ciphertext_bytes<const K: usize>(params: &MlKemParams<K>) -> usize {
953    params.polyveccompressedbytes + params.polycompressedbytes
954}
955
/// Views a slice as a reference to a 32-byte array without copying.
///
/// # Panics
/// Panics if `input.len() != 32`.
#[inline]
fn array_ref_32(input: &[u8]) -> &[u8; 32] {
    <&[u8; 32]>::try_from(input).expect("slice length should be 32")
}
960
/// Test helper: decodes a hex string into a fixed-size byte array.
///
/// # Panics
/// Panics if `s` is not valid hex or does not decode to exactly `LEN` bytes.
#[cfg(test)]
pub(crate) fn decode_hex_array<const LEN: usize>(s: &str) -> [u8; LEN] {
    // The const parameter is named `LEN` rather than `N` so it does not shadow
    // the module-level ring dimension `N`.
    let bytes = hex::decode(s).expect("valid hex");
    assert_eq!(bytes.len(), LEN, "hex string decodes to the wrong number of bytes");
    let mut out = [0u8; LEN];
    out.copy_from_slice(&bytes);
    out
}
969
/// Test helper: hex encoding of the SHA3-256 digest of `data`, as produced by
/// `hex::encode`.
#[cfg(test)]
pub(crate) fn sha3_256_hex(data: &[u8]) -> String {
    use crate::Hasher;
    let mut sha = Sha3_256::new();
    sha.update(data);
    let digest = sha.sum();
    hex::encode(digest.as_ref())
}
977
#[cfg(test)]
mod tests {
    use super::*;

    /// Distance between `a` and `b` in Z_Q, taking the shorter way around the
    /// ring. Used to bound the error of lossy compress/decompress round trips.
    fn mod_q_distance(a: i32, b: i32) -> i32 {
        ((a - b).rem_euclid(Q as i32)).min((b - a).rem_euclid(Q as i32))
    }

    #[test]
    fn poly_frommsg_tomsg_roundtrip() {
        // Sample the whole 16-bit space so both msg[0] and msg[1] vary. Iterating
        // only 0..=255 would leave msg[1] permanently zero. A stride of 97 keeps
        // the test fast while still covering many bit patterns.
        for pattern in (0..=u16::MAX).step_by(97) {
            let mut msg = [0u8; 32];
            msg[..2].copy_from_slice(&pattern.to_le_bytes());
            let poly = poly_frommsg(&msg);
            let recovered = poly_tomsg(&poly);
            assert_eq!(msg, recovered, "roundtrip failed for pattern {pattern:#06x}");
        }

        // Every bit set: exercises all 256 message bits at once.
        let all_ones = [0xFFu8; 32];
        let recovered = poly_tomsg(&poly_frommsg(&all_ones));
        assert_eq!(all_ones, recovered, "roundtrip failed for the all-ones message");
    }

    #[test]
    fn poly_frommsg_constant_time_produces_expected_values() {
        // A set message bit encodes to ceil(Q/2) = 1665, a clear bit to 0.
        let half_q = (Q + 1) / 2;
        let mut msg = [0u8; 32];
        msg[0] = 0b1010_1010;
        msg[1] = 0b0101_0101;
        let poly = poly_frommsg(&msg);
        assert_eq!(poly.coeffs[0], 0);
        assert_eq!(poly.coeffs[1], half_q);
        assert_eq!(poly.coeffs[2], 0);
        assert_eq!(poly.coeffs[3], half_q);
        assert_eq!(poly.coeffs[8], half_q);
        assert_eq!(poly.coeffs[9], 0);
        assert_eq!(poly.coeffs[10], half_q);
        assert_eq!(poly.coeffs[11], 0);
    }

    #[test]
    fn cmov_selects_correctly() {
        let mut out = [0xAAu8; 32];
        let value = [0xBBu8; 32];
        cmov(&mut out, &value, false);
        assert_eq!(out, [0xAAu8; 32], "cmov with false should not modify output");

        cmov(&mut out, &value, true);
        assert_eq!(out, [0xBBu8; 32], "cmov with true should copy value");
    }

    #[test]
    fn cmov_is_idempotent() {
        let mut out = [0x42u8; 32];
        let value = [0x42u8; 32];
        cmov(&mut out, &value, true);
        assert_eq!(out, [0x42u8; 32]);
        cmov(&mut out, &value, false);
        assert_eq!(out, [0x42u8; 32]);
    }

    #[test]
    fn barrett_reduce_produces_values_in_range() {
        // Only the residue class is checked here: every output must be congruent
        // to its input mod Q, including at the i16 extremes. The exact output
        // range of barrett_reduce is not asserted.
        for val in [0i16, 1, -1, Q - 1, -(Q - 1), Q, -Q, 3000, -3000, i16::MAX, i16::MIN] {
            let reduced = barrett_reduce(val);
            // The reduced value should be congruent to val mod Q
            let diff = (val as i32 - reduced as i32).rem_euclid(Q as i32);
            assert!(diff == 0, "barrett_reduce({val}) = {reduced} not congruent mod Q");
        }
    }

    #[test]
    fn montgomery_reduce_correctness() {
        // Montgomery reduce: given a, return a * R^(-1) mod Q where R = 2^16
        // Verify: montgomery_reduce(a * R) == a mod Q for small a
        let r_mod_q: i32 = (1i32 << 16) % Q as i32; // R mod Q = 65536 mod 3329 = 2285
        for val in [0i16, 1, -1, 100, -100, Q - 1, -(Q - 1)] {
            let product = val as i32 * r_mod_q;
            let result = montgomery_reduce(product);
            // result should be congruent to val mod Q
            let diff = (val as i32 - result as i32).rem_euclid(Q as i32);
            assert!(
                diff == 0,
                "montgomery_reduce({val} * R) = {result}, expected congruent to {val} mod Q"
            );
        }
    }

    #[test]
    fn ntt_invntt_preserves_polynomial_structure() {
        // NTT->InvNTT roundtrip preserves polynomial relationships.
        // The full KEM roundtrip tests already validate NTT correctness,
        // but this verifies that two distinct inputs remain distinct after transform.
        let mut poly_a = Poly::default();
        let mut poly_b = Poly::default();
        for i in 0..N {
            poly_a.coeffs[i] = (i as i16 * 7 + 3) % Q;
            poly_b.coeffs[i] = (i as i16 * 11 + 5) % Q;
        }
        poly_ntt(&mut poly_a);
        poly_ntt(&mut poly_b);
        // NTT outputs should be different for different inputs
        assert_ne!(poly_a.coeffs, poly_b.coeffs);

        poly_invntt_tomont(&mut poly_a);
        poly_invntt_tomont(&mut poly_b);
        // After roundtrip, they should still be different
        assert_ne!(poly_a.coeffs, poly_b.coeffs);
    }

    #[test]
    fn poly_compress_decompress_roundtrip_4bit() {
        // For 4-bit compression (ML-KEM-768)
        let params = &ML_KEM_768;
        let mut poly = Poly::default();
        for i in 0..N {
            poly.coeffs[i] = ((i * 13) % Q as usize) as i16;
        }
        let mut compressed = [0u8; 128];
        poly_compress::<3>(params, &mut compressed, &poly);
        let decompressed = poly_decompress::<3>(params, &compressed);
        // Compression is lossy but within rounding error
        for i in 0..N {
            let orig = poly.coeffs[i] as i32;
            let dec = decompressed.coeffs[i] as i32;
            // Maximum rounding error for d-bit compression: Q / (2^(d+1))
            // For 4 bits: Q/32 ≈ 104
            let error = mod_q_distance(orig, dec);
            assert!(
                error <= Q as i32 / 32 + 1,
                "4-bit compress/decompress error too large at index {i}: orig={orig}, dec={dec}, error={error}"
            );
        }
    }

    #[test]
    fn poly_compress_decompress_roundtrip_5bit() {
        // For 5-bit compression (ML-KEM-1024)
        let params = &ML_KEM_1024;
        let mut poly = Poly::default();
        for i in 0..N {
            poly.coeffs[i] = ((i * 13) % Q as usize) as i16;
        }
        let mut compressed = [0u8; 160];
        poly_compress::<4>(params, &mut compressed, &poly);
        let decompressed = poly_decompress::<4>(params, &compressed);
        for i in 0..N {
            let orig = poly.coeffs[i] as i32;
            let dec = decompressed.coeffs[i] as i32;
            let error = mod_q_distance(orig, dec);
            assert!(
                error <= Q as i32 / 64 + 1,
                "5-bit compress/decompress error too large at index {i}: orig={orig}, dec={dec}, error={error}"
            );
        }
    }

    #[test]
    fn polyvec_compress_decompress_roundtrip_10bit() {
        let params = &ML_KEM_768;
        let mut pv = PolyVec::<3>::default();
        for k in 0..3 {
            for i in 0..N {
                pv.vec[k].coeffs[i] = ((k * 97 + i * 13) % Q as usize) as i16;
            }
        }
        let mut compressed = [0u8; 960];
        polyvec_compress(params, &mut compressed, &pv);
        let decompressed = polyvec_decompress::<3>(params, &compressed);
        for k in 0..3 {
            for i in 0..N {
                let orig = pv.vec[k].coeffs[i] as i32;
                let dec = decompressed.vec[k].coeffs[i] as i32;
                let error = mod_q_distance(orig, dec);
                assert!(
                    error <= Q as i32 / 2048 + 1,
                    "10-bit compress/decompress error at [{k}][{i}]: orig={orig}, dec={dec}, error={error}"
                );
            }
        }
    }

    #[test]
    fn polyvec_compress_decompress_roundtrip_11bit() {
        let params = &ML_KEM_1024;
        let mut pv = PolyVec::<4>::default();
        for k in 0..4 {
            for i in 0..N {
                pv.vec[k].coeffs[i] = ((k * 97 + i * 13) % Q as usize) as i16;
            }
        }
        let mut compressed = [0u8; 1408];
        polyvec_compress(params, &mut compressed, &pv);
        let decompressed = polyvec_decompress::<4>(params, &compressed);
        for k in 0..4 {
            for i in 0..N {
                let orig = pv.vec[k].coeffs[i] as i32;
                let dec = decompressed.vec[k].coeffs[i] as i32;
                let error = mod_q_distance(orig, dec);
                assert!(
                    error <= Q as i32 / 4096 + 1,
                    "11-bit compress/decompress error at [{k}][{i}]: orig={orig}, dec={dec}, error={error}"
                );
            }
        }
    }

    #[test]
    fn poly_tobytes_frombytes_roundtrip() {
        let mut poly = Poly::default();
        for i in 0..N {
            poly.coeffs[i] = (i as i16 * 13) % Q;
        }
        let mut buf = [0u8; POLY_BYTES];
        poly_tobytes(&mut buf, &poly);
        let recovered = poly_frombytes(&buf);
        assert_eq!(poly.coeffs, recovered.coeffs);
    }

    #[test]
    fn gen_matrix_transpose_relationship() {
        let seed = [42u8; 32];
        let matrix = gen_matrix::<3>(&seed, false);
        let transposed = gen_matrix::<3>(&seed, true);
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(
                    matrix[i].vec[j].coeffs, transposed[j].vec[i].coeffs,
                    "A[{i}][{j}] != A^T[{j}][{i}]"
                );
            }
        }
    }

    #[test]
    fn cbd2_produces_values_in_correct_range() {
        // CBD with eta=2 should produce coefficients in [-2, 2]
        let mut buf = [0u8; 128];
        for i in 0..128 {
            buf[i] = (i as u8).wrapping_mul(0x37);
        }
        let poly = cbd2(&buf);
        for (i, &coeff) in poly.coeffs.iter().enumerate() {
            assert!((-2..=2).contains(&coeff), "CBD2 coeff[{i}] = {coeff} out of range [-2, 2]");
        }
    }

    #[test]
    fn rej_uniform_only_accepts_values_less_than_q() {
        // Craft input where val0 = Q (3329 = 0xD01) should be rejected
        // rej_uniform parses 3 bytes into 2 12-bit values:
        // val0 = (buf[0] | buf[1]<<8) & 0x0fff
        // val1 = ((buf[1]>>4) | buf[2]<<4) & 0x0fff
        let buf = [
            0x01, 0x0D,
            0x00, // val0 = 0xD01 = 3329 = Q (rejected), val1 = (0x0D>>4 | 0x00<<4) & 0xfff = 0 (accepted)
            0x00, 0x0D,
            0xD0, // val0 = 0xD00 = 3328 (accepted), val1 = (0x0D>>4 | 0xD0<<4) & 0xfff = 0xD00 = 3328 (accepted)
        ];
        let mut out = [0i16; 256];
        let count = rej_uniform(&mut out, &buf);
        // val0=Q rejected, val1=0 accepted, val0=3328 accepted, val1=3328 accepted
        assert_eq!(count, 3);
        assert_eq!(out[0], 0); // first accepted: val1 from first triple
        assert_eq!(out[1], 3328); // second accepted: val0 from second triple
        assert_eq!(out[2], 3328); // third accepted: val1 from second triple
    }

    #[test]
    fn nist_acvp_ml_kem_768_full_vector() {
        // Verify against NIST FIPS 203 intermediate test vector (ML-KEM-768.txt)
        // These values come from the NIST test file and are validated by the CCTV tests
        let d: [u8; 32] = decode_hex_array("f688563f7c66a5da2d8bdb5a5f3e07bd8dce6f7efcec7f41298d79863459f7cd");
        let z: [u8; 32] = decode_hex_array("d1d49a515250dbceb9f6e3fcc1c7d5306918964b21ddb22207e03e57f0600da8");
        let m: [u8; 32] = decode_hex_array("3dc27ca0a6594b0e56320457c45a0f76bb8a213ea4a76d442186a0aefadbcdb9");

        let mut coins = [0u8; 64];
        coins[..32].copy_from_slice(&d);
        coins[32..].copy_from_slice(&z);

        let (dk, ek) = crypto_kem_keypair_derand::<3, 2400, 1184>(&ML_KEM_768, &coins);
        let (ct, k) = crypto_kem_enc_derand::<3, 1184, 1088>(&ML_KEM_768, &ek, &m);

        // Verify public key hash matches NIST vector
        assert_eq!(
            sha3_256_hex(&ek),
            "42d930a50dfd1f0541ca45c4598daebb4f51cd10d711a001bd9bb87d5c87a4bf"
        );
        // Verify secret key hash
        assert_eq!(
            sha3_256_hex(&dk),
            "db563aebd9fdc875e88563693edad1e5e359cc37b0f685d2d0a3723b37253192"
        );
        // Verify ciphertext hash
        assert_eq!(
            sha3_256_hex(&ct),
            "9d6e358208c4d583050becb319050b7f916de47caad1d589a1d01fea43fe1750"
        );
        // Verify shared secret
        assert_eq!(
            hex::encode(k),
            "ae726da2df66601c6648a7565c02b203a089276ac30f6cc226d048f93fafd78c"
        );

        // Verify decapsulation produces the same shared secret
        let k_dec = crypto_kem_dec::<3, 2400, 1088>(&ML_KEM_768, &dk, &ct).unwrap();
        assert_eq!(k, k_dec, "decapsulation mismatch against NIST vector");
    }

    #[test]
    fn nist_acvp_ml_kem_1024_full_vector() {
        // Verify against NIST FIPS 203 intermediate test vector (ML-KEM-1024.txt)
        let d: [u8; 32] = decode_hex_array("2a62c39ef4fc499f2d132716f480bb7521a49558ae84ee80d9352e66daf1e3a8");
        let z: [u8; 32] = decode_hex_array("5f574ef7f013d4336801fed022178c3ed91d0b6d51325315fc1dcabf4770a2ea");
        let m: [u8; 32] = decode_hex_array("e07d685ed308e609c9c7842026e35732f6ffc6e2fee10f0afd348f2b42a8acb4");

        let mut coins = [0u8; 64];
        coins[..32].copy_from_slice(&d);
        coins[32..].copy_from_slice(&z);

        let (dk, ek) = crypto_kem_keypair_derand::<4, 3168, 1568>(&ML_KEM_1024, &coins);
        let (ct, k) = crypto_kem_enc_derand::<4, 1568, 1568>(&ML_KEM_1024, &ek, &m);

        assert_eq!(
            sha3_256_hex(&ek),
            "3b308d1344ed70366b84d790acb705b86cd3dfd471fff171969aaa338f26dca5"
        );
        assert_eq!(
            sha3_256_hex(&dk),
            "aa63a9e0c035ada6635e7938b71856b24917ff9b3ebca1a4d205a83b502a415a"
        );
        assert_eq!(
            sha3_256_hex(&ct),
            "8caba02733421f12a7ba9a2bcbe4de7c9853156a0637df5a7a0f9127c81da943"
        );
        assert_eq!(
            hex::encode(k),
            "d53825c3ff666bb2881215dbec04a8bdce9099b2a3680938c2f199b54d505953"
        );

        let k_dec = crypto_kem_dec::<4, 3168, 1568>(&ML_KEM_1024, &dk, &ct).unwrap();
        assert_eq!(k, k_dec, "decapsulation mismatch against NIST vector");
    }

    #[test]
    fn compression_constant_time_no_division() {
        // The compressors avoid runtime division by Q with a multiplication trick.
        // Feed them the values where that trick and a naive division are most
        // likely to round differently (0, Q-1 and the values around Q/2). Check
        // that every one of them round-trips within the same bound as the
        // generic compression tests, not just the zero coefficient.
        let params_768 = &ML_KEM_768;
        let params_1024 = &ML_KEM_1024;
        let boundary: [i16; 4] = [0, Q - 1, Q / 2, Q / 2 + 1];

        let mut poly = Poly::default();
        poly.coeffs[..boundary.len()].copy_from_slice(&boundary);

        // 4-bit compression (ML-KEM-768).
        let mut buf4 = [0u8; 128];
        poly_compress::<3>(params_768, &mut buf4, &poly);
        let dec = poly_decompress::<3>(params_768, &buf4);
        // 0 should compress/decompress to exactly 0.
        assert_eq!(dec.coeffs[0], 0);
        for (i, &orig) in boundary.iter().enumerate() {
            let got = dec.coeffs[i];
            let error = mod_q_distance(orig as i32, got as i32);
            assert!(
                error <= Q as i32 / 32 + 1,
                "4-bit boundary value {orig} decompressed to {got}, error={error}"
            );
        }

        // 5-bit compression (ML-KEM-1024).
        let mut buf5 = [0u8; 160];
        poly_compress::<4>(params_1024, &mut buf5, &poly);
        let dec5 = poly_decompress::<4>(params_1024, &buf5);
        assert_eq!(dec5.coeffs[0], 0);
        for (i, &orig) in boundary.iter().enumerate() {
            let got = dec5.coeffs[i];
            let error = mod_q_distance(orig as i32, got as i32);
            assert!(
                error <= Q as i32 / 64 + 1,
                "5-bit boundary value {orig} decompressed to {got}, error={error}"
            );
        }
    }

    #[test]
    fn poly_tomsg_boundary_values() {
        // poly_tomsg implements Compress_1: coefficient x maps to the bit
        // round(2x / Q) mod 2, so the decision boundaries sit near Q/4 and 3Q/4.
        let mut poly = Poly::default();
        // 0 -> round(0) = 0
        poly.coeffs[0] = 0;
        // Q/2 = 1664 (integer division) -> round(3328 / 3329) = 1
        poly.coeffs[1] = Q / 2;
        // Q/4 = 832 -> round(1664 / 3329) = 0, just below the 0.5 boundary
        poly.coeffs[2] = Q / 4;
        // 3Q/4 = 2496 (integer division) -> round(4992 / 3329) = 1, just below 1.5
        poly.coeffs[3] = (3 * Q as i32 / 4) as i16;

        let msg = poly_tomsg(&poly);
        assert_eq!(msg[0] & 1, 0, "0 must decode to bit 0");
        assert_eq!((msg[0] >> 1) & 1, 1, "Q/2 must decode to bit 1");
        assert_eq!((msg[0] >> 2) & 1, 0, "Q/4 must decode to bit 0");
        assert_eq!((msg[0] >> 3) & 1, 1, "3Q/4 must decode to bit 1");
    }
}