1use crate::{
2 curve25519::x25519,
3 mlkem::{self, MlKemError},
4 sha3::{Sha3_256, Shake256},
5};
6
7pub const SECRET_KEY_SIZE: usize = 32;
9pub const PUBLIC_KEY_SIZE: usize = mlkem::PUBLIC_KEY_SIZE_768 + x25519::KEY_SIZE; pub const CIPHERTEXT_SIZE: usize = mlkem::CIPHERTEXT_SIZE_768 + x25519::SHARED_SECRET_SIZE; pub const SHARED_SECRET_SIZE: usize = 32;
15
16const XWING_LABEL: &[u8; 6] = b"\\.//^\\";
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20pub enum XWingError {
21 MlKem(MlKemError),
22}
23
24impl From<MlKemError> for XWingError {
25 fn from(err: MlKemError) -> Self {
26 XWingError::MlKem(err)
27 }
28}
29
30impl core::fmt::Display for XWingError {
31 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
32 match self {
33 XWingError::MlKem(err) => write!(f, "ML-KEM error: {err}"),
34 }
35 }
36}
37
38#[derive(Clone, Debug, PartialEq, Eq)]
55pub struct SecretKey {
56 bytes: [u8; SECRET_KEY_SIZE],
57 x25519_secret_key: x25519::SecretKey,
58 x25519_public_key_bytes: [u8; x25519::KEY_SIZE],
59 mlkem_secret_key: mlkem::SecretKey768,
60}
61
62impl SecretKey {
63 pub fn to_bytes(&self) -> [u8; SECRET_KEY_SIZE] {
64 self.bytes
65 }
66
67 pub fn decapsulate(&self, ct: &[u8; CIPHERTEXT_SIZE]) -> Result<[u8; SHARED_SECRET_SIZE], XWingError> {
68 let ct_m = &ct[..mlkem::CIPHERTEXT_SIZE_768].try_into().unwrap();
69 let ct_x = x25519::PublicKey::from_bytes(&ct[mlkem::CIPHERTEXT_SIZE_768..].try_into().unwrap());
70
71 let ss_m = self.mlkem_secret_key.decapsulate(&ct_m)?;
72 let ss_x = self.x25519_secret_key.ecdh(&ct_x);
73
74 Ok(combiner(&ss_m, &ss_x, &ct_x.to_bytes(), &self.x25519_public_key_bytes))
75 }
76}
77
78#[derive(Clone, Debug, PartialEq, Eq)]
82pub struct PublicKey {
83 mlkem_public_key: mlkem::PublicKey768,
84 x25519_public_key: x25519::PublicKey,
85}
86
87impl PublicKey {
88 pub fn to_bytes(&self) -> [u8; PUBLIC_KEY_SIZE] {
89 let mut bytes = [0u8; PUBLIC_KEY_SIZE];
90 bytes[..mlkem::PUBLIC_KEY_SIZE_768].copy_from_slice(&self.mlkem_public_key.to_bytes());
91 bytes[mlkem::PUBLIC_KEY_SIZE_768..].copy_from_slice(&self.x25519_public_key.to_bytes());
92 bytes
93 }
94
95 pub fn encapsulate(&self) -> ([u8; SHARED_SECRET_SIZE], [u8; CIPHERTEXT_SIZE]) {
96 let eseed: [u8; 64] = rand::random();
97 self.encapsulate_derand(&eseed)
98 }
99
100 fn encapsulate_derand(&self, eseed: &[u8; 64]) -> ([u8; SHARED_SECRET_SIZE], [u8; CIPHERTEXT_SIZE]) {
101 let ek_x = x25519::SecretKey::from_bytes(&eseed[32..64].try_into().unwrap());
102 let ct_x = ek_x.public_key();
103 let ss_x = ek_x.ecdh(&self.x25519_public_key);
104
105 let m = &eseed[..32].try_into().unwrap();
106 let (ct_m, ss_m) = self.mlkem_public_key.encapsulate_derand(&m);
107
108 let ss = combiner(&ss_m, &ss_x, &ct_x.to_bytes(), &self.x25519_public_key.to_bytes());
109
110 let mut ct = [0u8; CIPHERTEXT_SIZE];
111 ct[..mlkem::CIPHERTEXT_SIZE_768].copy_from_slice(&ct_m);
112 ct[mlkem::CIPHERTEXT_SIZE_768..].copy_from_slice(&ct_x.to_bytes());
113
114 (ss, ct)
115 }
116}
117
118pub fn generate_keypair() -> (SecretKey, PublicKey) {
124 let seed: [u8; SECRET_KEY_SIZE] = rand::random();
125 generate_keypair_derand(&seed)
126}
127
128fn generate_keypair_derand(secret_key: &[u8; SECRET_KEY_SIZE]) -> (SecretKey, PublicKey) {
130 let (mlkem_sk, x25519_sk, mlkem_pk, x25519_pk) = expand_decapsulation_key(secret_key);
131
132 let secret_key = SecretKey {
133 bytes: *secret_key,
134 x25519_secret_key: x25519_sk,
135 x25519_public_key_bytes: x25519_pk.to_bytes(),
136 mlkem_secret_key: mlkem_sk,
137 };
138
139 let public_key = PublicKey {
140 mlkem_public_key: mlkem_pk,
141 x25519_public_key: x25519_pk,
142 };
143
144 (secret_key, public_key)
145}
146
147fn expand_decapsulation_key(
148 secret_key: &[u8; 32],
149) -> (mlkem::SecretKey768, x25519::SecretKey, mlkem::PublicKey768, x25519::PublicKey) {
150 let mut expanded_secret_key = [0u8; 96];
151 Shake256::hash(secret_key, &mut expanded_secret_key);
152
153 let (sk_m, pk_m) = derive_mlkeem_keys(&expanded_secret_key);
154
155 let sk_x = x25519::SecretKey::from_bytes(&expanded_secret_key[64..96].try_into().unwrap());
156 let pk_x = sk_x.public_key();
157
158 (sk_m, sk_x, pk_m, pk_x)
159}
160
161fn derive_mlkeem_keys(expnded_secret_key: &[u8; 96]) -> (mlkem::SecretKey768, mlkem::PublicKey768) {
162 mlkem::generate_keypair_768_derand(&expnded_secret_key[..64].try_into().unwrap())
163}
164
165fn combiner(
166 ss_m: &[u8; mlkem::SHARED_SECRET_SIZE],
167 ss_x: &[u8; x25519::KEY_SIZE],
168 ct_x: &[u8; x25519::KEY_SIZE],
169 pk_x: &[u8; x25519::KEY_SIZE],
170) -> [u8; SHARED_SECRET_SIZE] {
171 use crate::Hasher;
172 let mut hasher = Sha3_256::new();
173 hasher.update(ss_m);
174 hasher.update(ss_x);
175 hasher.update(ct_x);
176 hasher.update(pk_x);
177 hasher.update(XWING_LABEL);
178 hasher.sum().as_ref().try_into().unwrap()
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string into a fixed-size byte array.
    ///
    /// Panics (failing the test) on malformed hex or if the decoded length is not `N`.
    fn hex_to_array<const N: usize>(hex_str: &str) -> [u8; N] {
        let bytes = hex::decode(hex_str).expect("test vector must be valid hex");
        bytes
            .try_into()
            .expect("decoded test vector has the wrong length")
    }

    #[test]
    fn constants() {
        assert_eq!(SECRET_KEY_SIZE, 32);
        assert_eq!(PUBLIC_KEY_SIZE, 1216);
        assert_eq!(CIPHERTEXT_SIZE, 1120);
        assert_eq!(SHARED_SECRET_SIZE, 32);
    }

    /// A known-answer vector: key-generation seed, encapsulation seed and the
    /// expected shared secret, all hex-encoded.
    struct TestVector {
        seed: &'static str,
        eseed: &'static str,
        ss: &'static str,
    }

    const TEST_VECTORS: [TestVector; 3] = [
        TestVector {
            seed: "7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26",
            eseed: "3cb1eea988004b93103cfb0aeefd2a686e01fa4a58e8a3639ca8a1e3f9ae57e235b8cc873c23dc62b8d260169afa2f75ab916a58d974918835d25e6a435085b2",
            ss: "d2df0522128f09dd8e2c92b1e905c793d8f57a54c3da25861f10bf4ca613e384",
        },
        TestVector {
            seed: "badfd6dfaac359a5efbb7bcc4b59d538df9a04302e10c8bc1cbf1a0b3a5120ea",
            eseed: "17cda7cfad765f5623474d368ccca8af0007cd9f5e4c849f167a580b14aabdefaee7eef47cb0fca9767be1fda69419dfb927e9df07348b196691abaeb580b32d",
            ss: "f2e86241c64d60f6649fbc6c5b7d17180b780a3f34355e64a85749949c45f150",
        },
        TestVector {
            seed: "ef58538b8d23f87732ea63b02b4fa0f4873360e2841928cd60dd4cee8cc0d4c9",
            eseed: "22a96188d032675c8ac850933c7aff1533b94c834adbb69c6115bad4692d8619f90b0cdf8a7b9c264029ac185b70b83f2801f2f4b3f70c593ea3aeeb613a7f1b",
            ss: "953f7f4e8c5b5049bdc771d1dffada0dd961477d1a2ae0988baa7ea6898d893f",
        },
    ];

    #[test]
    fn test_vectors_from_draft() {
        for (i, tv) in TEST_VECTORS.iter().enumerate() {
            let seed: [u8; 32] = hex_to_array(tv.seed);
            let eseed: [u8; 64] = hex_to_array(tv.eseed);
            let expected_ss: [u8; 32] = hex_to_array(tv.ss);

            let (secret_key, pk) = generate_keypair_derand(&seed);
            assert_eq!(secret_key.to_bytes(), seed, "vector {i}: sk mismatch");

            let (ss, ct) = pk.encapsulate_derand(&eseed);
            assert_eq!(ss, expected_ss, "vector {i}: encaps ss mismatch");

            let decapsulated_ss = secret_key.decapsulate(&ct).unwrap();
            assert_eq!(decapsulated_ss, expected_ss, "vector {i}: decaps ss mismatch");
        }
    }

    #[test]
    fn round_trip() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, ct) = public_key.encapsulate();
        let decapsulated = secret_key.decapsulate(&ct).unwrap();
        assert_eq!(ss, decapsulated);
    }

    #[test]
    fn round_trip_many() {
        for _ in 0..10 {
            let (secret_key, public_key) = generate_keypair();
            let (ss, ct) = public_key.encapsulate();
            let decapsulated = secret_key.decapsulate(&ct).unwrap();
            assert_eq!(ss, decapsulated);
        }
    }

    #[test]
    fn decapsulation_with_wrong_key_produces_different_secret() {
        let (_, pk_a) = generate_keypair();
        let (sk_b, _) = generate_keypair();

        let (ss_a, ct) = pk_a.encapsulate();
        let ss_b = sk_b.decapsulate(&ct).unwrap();
        assert_ne!(ss_a, ss_b);
    }

    #[test]
    fn tampered_ciphertext_produces_different_secret() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, mut ct) = public_key.encapsulate();

        // Flip a bit in the ML-KEM component (ct_M).
        ct[0] ^= 0x80;

        let tampered_ss = secret_key.decapsulate(&ct).unwrap();
        assert_ne!(ss, tampered_ss);
    }

    #[test]
    fn tampered_x25519_ciphertext_produces_different_secret() {
        let (secret_key, public_key) = generate_keypair();
        let (ss, mut ct) = public_key.encapsulate();

        // Flip the low bit of the last byte, which belongs to the X25519 component (ct_X).
        ct[CIPHERTEXT_SIZE - 1] ^= 0x01;

        let tampered_ss = secret_key.decapsulate(&ct).unwrap();
        assert_ne!(ss, tampered_ss);
    }

    #[test]
    fn derandomized_keygen_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);
        let (sk1, pk1) = generate_keypair_derand(&seed);
        let (sk2, pk2) = generate_keypair_derand(&seed);
        assert_eq!(sk1.to_bytes(), sk2.to_bytes());
        assert_eq!(pk1.to_bytes(), pk2.to_bytes());
    }

    #[test]
    fn derandomized_encaps_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);
        let eseed: [u8; 64] = hex_to_array(TEST_VECTORS[0].eseed);
        let (_, pk) = generate_keypair_derand(&seed);

        let (ss1, ct1) = pk.encapsulate_derand(&eseed);
        let (ss2, ct2) = pk.encapsulate_derand(&eseed);
        assert_eq!(ct1, ct2);
        assert_eq!(ss1, ss2);
    }

    #[test]
    fn xwing_label_is_correct() {
        assert_eq!(XWING_LABEL.len(), 6);
        assert_eq!(hex::encode(XWING_LABEL), "5c2e2f2f5e5c");
    }

    #[test]
    fn expand_decapsulation_key_is_deterministic() {
        let seed: [u8; 32] = hex_to_array(TEST_VECTORS[0].seed);

        let (sk_m1, sk_x1, pk_m1, pk_x1) = expand_decapsulation_key(&seed);
        let (sk_m2, sk_x2, pk_m2, pk_x2) = expand_decapsulation_key(&seed);
        assert_eq!(sk_m1, sk_m2);
        assert_eq!(sk_x1, sk_x2);
        assert_eq!(pk_m1, pk_m2);
        assert_eq!(pk_x1, pk_x2);
    }

    #[test]
    fn combiner_is_deterministic() {
        let ss_m = [0x01u8; 32];
        let ss_x = [0x02u8; 32];
        let ct_x = [0x03u8; 32];
        let pk_x = [0x04u8; 32];

        let result1 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
        let result2 = combiner(&ss_m, &ss_x, &ct_x, &pk_x);
        assert_eq!(result1, result2);
    }
}