1use super::mlkem::{
2 ML_KEM_768, MlKemError, SHARED_SECRET_SIZE, crypto_kem_dec, crypto_kem_enc_derand, crypto_kem_keypair_derand,
3 indcpa_secret_key_bytes,
4};
5
/// Length in bytes of an encoded ML-KEM-768 public key (the array held by `PublicKey768`).
pub const PUBLIC_KEY_SIZE_768: usize = 1184;
/// Length in bytes of an encoded ML-KEM-768 secret key (the array held by `SecretKey768`).
pub const SECRET_KEY_SIZE_768: usize = 2400;
/// Length in bytes of an ML-KEM-768 ciphertext, as returned by encapsulation and consumed by decapsulation.
pub const CIPHERTEXT_SIZE_768: usize = 1088;
9
10#[derive(Clone, Debug, PartialEq, Eq)]
23#[cfg_attr(feature = "zeroize", derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop))]
24pub struct SecretKey768 {
25 bytes: [u8; SECRET_KEY_SIZE_768],
26}
27
/// ML-KEM-768 public key, stored as its fixed-size byte encoding.
///
/// This is the key that `encapsulate` operates on.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PublicKey768 {
    /// Encoded key, exactly `PUBLIC_KEY_SIZE_768` bytes.
    bytes: [u8; PUBLIC_KEY_SIZE_768],
}
35
36#[inline]
42pub fn generate_keypair_768() -> (SecretKey768, PublicKey768) {
43 SecretKey768::generate()
44}
45
46#[inline]
47pub(crate) fn generate_keypair_768_derand(coins: &[u8; 64]) -> (SecretKey768, PublicKey768) {
48 SecretKey768::generate_derand(coins)
49}
50
51impl SecretKey768 {
52 pub fn from_bytes(bytes: &[u8; SECRET_KEY_SIZE_768]) -> Self {
53 Self {
54 bytes: *bytes,
55 }
56 }
57
58 pub fn to_bytes(&self) -> [u8; SECRET_KEY_SIZE_768] {
59 self.bytes
60 }
61
62 pub fn generate() -> (Self, PublicKey768) {
63 let coins: [u8; 64] = rand::random();
64 Self::generate_derand(&coins)
65 }
66
67 pub(crate) fn generate_derand(coins: &[u8; 64]) -> (Self, PublicKey768) {
68 let (sk_bytes, pk_bytes) =
69 crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, coins);
70 (
71 Self {
72 bytes: sk_bytes,
73 },
74 PublicKey768 {
75 bytes: pk_bytes,
76 },
77 )
78 }
79
80 pub fn decapsulate(&self, ciphertext: &[u8; CIPHERTEXT_SIZE_768]) -> Result<[u8; SHARED_SECRET_SIZE], MlKemError> {
81 crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &self.bytes, ciphertext)
82 }
83
84 pub fn public_key(&self) -> PublicKey768 {
85 let offset = indcpa_secret_key_bytes::<3>();
86 let mut pk_bytes = [0u8; PUBLIC_KEY_SIZE_768];
87 pk_bytes.copy_from_slice(&self.bytes[offset..offset + PUBLIC_KEY_SIZE_768]);
88 PublicKey768 {
89 bytes: pk_bytes,
90 }
91 }
92}
93
94impl From<&[u8; SECRET_KEY_SIZE_768]> for SecretKey768 {
95 fn from(bytes: &[u8; SECRET_KEY_SIZE_768]) -> Self {
96 Self::from_bytes(bytes)
97 }
98}
99
100impl TryFrom<&[u8]> for SecretKey768 {
101 type Error = MlKemError;
102
103 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
104 Ok(Self::from_bytes(bytes.try_into().map_err(|_| MlKemError::InvalidKey)?))
105 }
106}
107
108impl PublicKey768 {
109 pub fn from_bytes(bytes: &[u8; PUBLIC_KEY_SIZE_768]) -> Self {
110 Self {
111 bytes: *bytes,
112 }
113 }
114
115 pub fn to_bytes(&self) -> [u8; PUBLIC_KEY_SIZE_768] {
116 self.bytes
117 }
118
119 pub fn encapsulate(&self) -> ([u8; CIPHERTEXT_SIZE_768], [u8; SHARED_SECRET_SIZE]) {
120 let coins: [u8; 32] = rand::random();
121 self.encapsulate_derand(&coins)
122 }
123
124 pub(crate) fn encapsulate_derand(&self, coins: &[u8; 32]) -> ([u8; CIPHERTEXT_SIZE_768], [u8; SHARED_SECRET_SIZE]) {
125 crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &self.bytes, coins)
126 }
127}
128
129impl From<&[u8; PUBLIC_KEY_SIZE_768]> for PublicKey768 {
130 fn from(bytes: &[u8; PUBLIC_KEY_SIZE_768]) -> Self {
131 Self::from_bytes(bytes)
132 }
133}
134
135impl TryFrom<&[u8]> for PublicKey768 {
136 type Error = MlKemError;
137
138 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
139 Ok(Self::from_bytes(bytes.try_into().map_err(|_| MlKemError::InvalidKey)?))
140 }
141}
142
143#[cfg(test)]
144mod tests {
145 use super::{
146 super::mlkem::{
147 ML_KEM_768, crypto_kem_dec, crypto_kem_enc_derand, crypto_kem_keypair_derand, decode_hex_array,
148 sha3_256_hex,
149 },
150 *,
151 };
152
153 #[test]
154 fn ml_kem_768_round_trip() {
155 let (private_key, public_key) = generate_keypair_768();
156 let (ciphertext, encapsulated_secret) = public_key.encapsulate();
157 let decapsulated_secret = private_key.decapsulate(&ciphertext).unwrap();
158
159 assert_eq!(encapsulated_secret, decapsulated_secret);
160 }
161
162 #[test]
163 fn ml_kem_768_decapsulation_rejects_tampered_ciphertext() {
164 let (private_key, public_key) = generate_keypair_768();
165 let (mut ciphertext, encapsulated_secret) = public_key.encapsulate();
166
167 ciphertext[0] ^= 0x80;
168
169 let decapsulated_secret = private_key.decapsulate(&ciphertext).unwrap();
170
171 assert_ne!(encapsulated_secret, decapsulated_secret);
172 }
173
174 #[test]
175 fn ml_kem_768_deterministic_derand_vectors_are_stable() {
176 let key_coins = [7u8; 64];
177 let enc_coins = [9u8; 32];
178 let (secret_key, public_key) =
179 crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &key_coins);
180 let (ciphertext, shared_secret) =
181 crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &public_key, &enc_coins);
182 let decapsulated =
183 crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &secret_key, &ciphertext)
184 .unwrap();
185
186 assert_eq!(shared_secret, decapsulated);
187 assert_eq!(
188 hex::encode(&public_key[..32]),
189 "925a2700ad064ff778b4da4cf51457a48224a52751250a8ee10b251c818bafca"
190 );
191 assert_eq!(
192 hex::encode(&ciphertext[..32]),
193 "766c326c3483444c5b6d917cdddc3c07fbf935295c8f17c92a187a80dc4d15f2"
194 );
195 assert_eq!(
196 hex::encode(shared_secret),
197 "afcf18dfd6b710a09b5cf591d0eb8229d83aa10904934a3ca60a52da5ff36b96"
198 );
199 }
200
201 #[test]
202 fn ml_kem_768_cctv_accumulated_10k() {
203 use crate::{Xof, sha3::Shake128};
204
205 let mut rng = Shake128::new();
206 rng.absorb(&[]);
207
208 let mut acc = Shake128::new();
209
210 for _ in 0..10_000u32 {
211 let mut d = [0u8; 32];
212 let mut z = [0u8; 32];
213 let mut m = [0u8; 32];
214 let mut ct_random = [0u8; CIPHERTEXT_SIZE_768];
215
216 rng.squeeze(&mut d);
217 rng.squeeze(&mut z);
218 rng.squeeze(&mut m);
219 rng.squeeze(&mut ct_random);
220
221 let mut coins = [0u8; 64];
222 coins[..32].copy_from_slice(&d);
223 coins[32..].copy_from_slice(&z);
224
225 let (dk, ek) =
226 crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
227 let (ct, k_encaps) =
228 crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &ek, &m);
229
230 let k_decaps =
231 crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &ct).unwrap();
232 assert_eq!(k_encaps, k_decaps);
233
234 let k_decaps_random =
235 crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &ct_random).unwrap();
236
237 acc.absorb(&ek);
238 acc.absorb(&dk);
239 acc.absorb(&ct);
240 acc.absorb(&k_encaps);
241 acc.absorb(&k_decaps_random);
242 }
243
244 let mut hash = [0u8; 32];
245 acc.squeeze(&mut hash);
246 assert_eq!(
247 hex::encode(hash),
248 "f959d18d3d1180121433bf0e05f11e7908cf9d03edc150b2b07cb90bef5bc1c1",
249 "ML-KEM-768 CCTV accumulated hash mismatch"
250 );
251 }
252
253 #[test]
254 fn ml_kem_768_cctv_intermediate_vector() {
255 let d: [u8; 32] = decode_hex_array("f688563f7c66a5da2d8bdb5a5f3e07bd8dce6f7efcec7f41298d79863459f7cd");
256 let z: [u8; 32] = decode_hex_array("d1d49a515250dbceb9f6e3fcc1c7d5306918964b21ddb22207e03e57f0600da8");
257 let m: [u8; 32] = decode_hex_array("3dc27ca0a6594b0e56320457c45a0f76bb8a213ea4a76d442186a0aefadbcdb9");
258
259 let mut coins = [0u8; 64];
260 coins[..32].copy_from_slice(&d);
261 coins[32..].copy_from_slice(&z);
262
263 let (dk, ek) = crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
264 let (ct, k) = crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &ek, &m);
265
266 assert_eq!(
267 sha3_256_hex(&ek),
268 "42d930a50dfd1f0541ca45c4598daebb4f51cd10d711a001bd9bb87d5c87a4bf"
269 );
270 assert_eq!(
271 sha3_256_hex(&dk),
272 "db563aebd9fdc875e88563693edad1e5e359cc37b0f685d2d0a3723b37253192"
273 );
274 assert_eq!(
275 sha3_256_hex(&ct),
276 "9d6e358208c4d583050becb319050b7f916de47caad1d589a1d01fea43fe1750"
277 );
278 assert_eq!(
279 hex::encode(k),
280 "ae726da2df66601c6648a7565c02b203a089276ac30f6cc226d048f93fafd78c"
281 );
282 }
283
284 #[test]
285 fn ml_kem_768_decapsulation_with_wrong_key_rejects() {
286 let (_, alice_pk) = generate_keypair_768();
287 let (bob_sk, _bob_pk) = generate_keypair_768();
288 let (ct, _alice_ss) = alice_pk.encapsulate();
289
290 let wrong_ss = bob_sk.decapsulate(&ct).unwrap();
291 assert_ne!(_alice_ss, wrong_ss);
292 }
293
294 #[test]
295 fn ml_kem_768_round_trip_many() {
296 for _ in 0..100 {
297 let (sk, pk) = generate_keypair_768();
298 let (ct, ss_enc) = pk.encapsulate();
299 let ss_dec = sk.decapsulate(&ct).unwrap();
300 assert_eq!(ss_enc, ss_dec);
301 }
302 }
303
304 #[test]
305 fn ml_kem_768_all_zero_ciphertext_does_not_panic() {
306 let (sk, _pk) = generate_keypair_768();
307 let ct = [0u8; CIPHERTEXT_SIZE_768];
308 let _result = sk.decapsulate(&ct);
309 }
310
311 #[test]
312 fn ml_kem_768_all_ones_ciphertext_does_not_panic() {
313 let (sk, _pk) = generate_keypair_768();
314 let ct = [0xffu8; CIPHERTEXT_SIZE_768];
315 let _result = sk.decapsulate(&ct);
316 }
317
318 #[test]
319 fn ml_kem_768_derand_keygen_is_deterministic() {
320 let coins = [7u8; 64];
321 let (sk1, pk1) = crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
322 let (sk2, pk2) = crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &coins);
323 assert_eq!(sk1, sk2);
324 assert_eq!(pk1, pk2);
325 }
326
327 #[test]
328 fn ml_kem_768_key_sizes_are_correct() {
329 let (sk, pk) = generate_keypair_768();
330 let sk_bytes = sk.to_bytes();
331 let pk_bytes = pk.to_bytes();
332 assert_eq!(sk_bytes.len(), SECRET_KEY_SIZE_768);
333 assert_eq!(pk_bytes.len(), PUBLIC_KEY_SIZE_768);
334 let (ct, _) = pk.encapsulate();
335 assert_eq!(ct.len(), CIPHERTEXT_SIZE_768);
336 }
337
338 #[test]
339 fn ml_kem_768_encaps_is_deterministic_with_same_coins() {
340 let enc_coins = [9u8; 32];
341 let key_coins = [7u8; 64];
342 let (_sk, pk) =
343 crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &key_coins);
344 let (ct1, ss1) =
345 crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &pk, &enc_coins);
346 let (ct2, ss2) =
347 crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &pk, &enc_coins);
348 assert_eq!(ct1, ct2);
349 assert_eq!(ss1, ss2);
350 }
351
352 #[test]
353 fn ml_kem_768_decapsulation_with_wrong_key_is_deterministic() {
354 let (_, pk_a) = generate_keypair_768();
355 let (sk_b, _pk_b) = generate_keypair_768();
356 let (ct, _) = pk_a.encapsulate();
357
358 let ss1 = sk_b.decapsulate(&ct).unwrap();
359 let ss2 = sk_b.decapsulate(&ct).unwrap();
360 assert_eq!(ss1, ss2, "implicit rejection must be deterministic");
361 }
362
363 #[test]
364 fn ml_kem_768_wycheproof_keygen() {
365 let data: serde_json::Value = serde_json::from_str(include_str!(
366 "../../testdata/wycheproof/testvectors_v1/mlkem_768_keygen_seed_test.json"
367 ))
368 .unwrap();
369 let mut tested = 0u64;
370 for group in data["testGroups"].as_array().unwrap() {
371 if group["parameterSet"].as_str() != Some("ML-KEM-768") {
372 continue;
373 }
374 for test in group["tests"].as_array().unwrap() {
375 let seed_hex = test["seed"].as_str().unwrap();
376 let expected_ek_hex = test["ek"].as_str().unwrap();
377 let expected_dk_hex = test["dk"].as_str().unwrap();
378 let result = test["result"].as_str().unwrap();
379
380 let seed = hex::decode_array::<64>(seed_hex.as_bytes()).unwrap();
381
382 let (dk, ek) =
383 crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &seed);
384
385 let ek_hex = hex::encode(ek);
386 let dk_hex = hex::encode(dk);
387
388 if result == "valid" {
389 assert_eq!(
390 ek_hex, expected_ek_hex,
391 "wycheproof keygen KAT tcId={} ek mismatch",
392 test["tcId"]
393 );
394 assert_eq!(
395 dk_hex, expected_dk_hex,
396 "wycheproof keygen KAT tcId={} dk mismatch",
397 test["tcId"]
398 );
399 }
400 tested += 1;
401 }
402 }
403 assert!(tested > 0, "no ML-KEM-768 keygen tests were run");
404 }
405
406 fn wycheproof_kem_skip_invalid_lengths(seed_hex: &str, c_hex: &str, ct_size: usize) -> bool {
407 seed_hex.len() != 128 || c_hex.len() != ct_size * 2
408 }
409
410 #[test]
411 fn ml_kem_768_wycheproof_kem() {
412 let data: serde_json::Value =
413 serde_json::from_str(include_str!("../../testdata/wycheproof/testvectors_v1/mlkem_768_test.json")).unwrap();
414 let mut tested = 0u64;
415 for group in data["testGroups"].as_array().unwrap() {
416 if group["parameterSet"].as_str() != Some("ML-KEM-768") {
417 continue;
418 }
419 for test in group["tests"].as_array().unwrap() {
420 let seed_hex = test["seed"].as_str().unwrap();
421 let c_hex = test["c"].as_str().unwrap();
422 let expected_k_hex = test["K"].as_str().unwrap();
423 let result = test["result"].as_str().unwrap();
424
425 if wycheproof_kem_skip_invalid_lengths(seed_hex, c_hex, CIPHERTEXT_SIZE_768) {
426 tested += 1;
427 continue;
428 }
429
430 let seed = hex::decode_array::<64>(seed_hex.as_bytes()).unwrap();
431
432 let (dk, ek) =
433 crypto_kem_keypair_derand::<3, SECRET_KEY_SIZE_768, PUBLIC_KEY_SIZE_768>(&ML_KEM_768, &seed);
434
435 if let Some(expected_ek_hex) = test.get("ek").and_then(|v| v.as_str()) {
436 let ek_hex = hex::encode(ek);
437 assert_eq!(ek_hex, expected_ek_hex, "wycheproof KEM KAT tcId={} ek mismatch", test["tcId"]);
438 }
439
440 let c = decode_hex_array::<CIPHERTEXT_SIZE_768>(c_hex);
441 let shared_secret = crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &c);
442
443 if result == "valid" {
444 let k = shared_secret.unwrap();
445 let k_hex = hex::encode(k);
446 assert_eq!(k_hex, expected_k_hex, "wycheproof KEM KAT tcId={} K mismatch", test["tcId"]);
447 } else {
448 assert!(
449 shared_secret.is_ok(),
450 "wycheproof KEM KAT tcId={} unexpected error",
451 test["tcId"]
452 );
453 }
454 tested += 1;
455 }
456 }
457 assert!(tested > 0, "no ML-KEM-768 KEM tests were run");
458 }
459
460 #[test]
461 fn ml_kem_768_wycheproof_encaps() {
462 let data: serde_json::Value = serde_json::from_str(include_str!(
463 "../../testdata/wycheproof/testvectors_v1/mlkem_768_encaps_test.json"
464 ))
465 .unwrap();
466 let mut tested = 0u64;
467 for group in data["testGroups"].as_array().unwrap() {
468 if group["parameterSet"].as_str() != Some("ML-KEM-768") {
469 continue;
470 }
471 for test in group["tests"].as_array().unwrap() {
472 let ek_hex = test["ek"].as_str().unwrap();
473 let m_hex = test["m"].as_str().unwrap();
474 let expected_c_hex = test["c"].as_str().unwrap();
475 let expected_k_hex = test["K"].as_str().unwrap();
476 let result = test["result"].as_str().unwrap();
477
478 if ek_hex.len() != PUBLIC_KEY_SIZE_768 * 2 {
479 tested += 1;
480 continue;
481 }
482
483 let ek = decode_hex_array::<PUBLIC_KEY_SIZE_768>(ek_hex);
484
485 if result == "valid" {
486 let m = decode_hex_array::<32>(m_hex);
487 let (c, k) =
488 crypto_kem_enc_derand::<3, PUBLIC_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &ek, &m);
489 let c_hex_out = hex::encode(c);
490 let k_hex_out = hex::encode(k);
491 assert_eq!(
492 c_hex_out, expected_c_hex,
493 "wycheproof encaps KAT tcId={} c mismatch",
494 test["tcId"]
495 );
496 assert_eq!(
497 k_hex_out, expected_k_hex,
498 "wycheproof encaps KAT tcId={} K mismatch",
499 test["tcId"]
500 );
501 }
502 tested += 1;
503 }
504 }
505 assert!(tested > 0, "no ML-KEM-768 encaps tests were run");
506 }
507
508 #[test]
509 fn ml_kem_768_wycheproof_decaps_validation() {
510 let data: serde_json::Value = serde_json::from_str(include_str!(
511 "../../testdata/wycheproof/testvectors_v1/mlkem_768_semi_expanded_decaps_test.json"
512 ))
513 .unwrap();
514 let mut tested = 0u64;
515 for group in data["testGroups"].as_array().unwrap() {
516 if group["parameterSet"].as_str() != Some("ML-KEM-768") {
517 continue;
518 }
519 for test in group["tests"].as_array().unwrap() {
520 let flags: Vec<&str> = test["flags"]
521 .as_array()
522 .map(|a| a.iter().filter_map(|v| v.as_str()).collect())
523 .unwrap_or_default();
524 let dk_hex = test["dk"].as_str().unwrap();
525 let c_hex = test["c"].as_str().unwrap();
526
527 if flags.contains(&"IncorrectDecapsulationKeyLength") || flags.contains(&"IncorrectCiphertextLength") {
528 tested += 1;
529 continue;
530 }
531
532 let dk = decode_hex_array::<SECRET_KEY_SIZE_768>(dk_hex);
533 let c = decode_hex_array::<CIPHERTEXT_SIZE_768>(c_hex);
534
535 let result = crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &dk, &c);
536
537 assert!(result.is_ok(), "wycheproof decaps tcId={} panicked", test["tcId"]);
538 tested += 1;
539 }
540 }
541 assert!(tested > 0, "no ML-KEM-768 decaps validation tests were run");
542 }
543
544 #[test]
545 fn ml_kem_768_cross_implementation_pqcrypto() {
546 let data: serde_json::Value =
549 serde_json::from_str(include_str!("../../testdata/mlkem/pqcrypto_768_vectors.json")).unwrap();
550 let vectors = data.as_array().unwrap();
551 assert!(vectors.len() >= 5, "not enough cross-impl vectors");
552
553 for (i, vector) in vectors.iter().enumerate() {
554 let sk_hex = vector["sk"].as_str().unwrap();
555 let ct_hex = vector["ct"].as_str().unwrap();
556 let expected_ss_hex = vector["ss"].as_str().unwrap();
557
558 let sk = decode_hex_array::<SECRET_KEY_SIZE_768>(sk_hex);
559 let ct = decode_hex_array::<CIPHERTEXT_SIZE_768>(ct_hex);
560
561 let ss = crypto_kem_dec::<3, SECRET_KEY_SIZE_768, CIPHERTEXT_SIZE_768>(&ML_KEM_768, &sk, &ct).unwrap();
562 assert_eq!(
563 hex::encode(ss),
564 expected_ss_hex,
565 "cross-impl pqcrypto vector {i} decapsulation mismatch"
566 );
567 }
568 }
569}