Skip to main content

crypto/aes/
aes_gcm.rs

1#[cfg(target_arch = "x86_64")]
2use core::arch::x86_64::__m128i;
3
4#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
5use super::aes::RoundKeys;
6use super::{
7    aes::{encrypt_block, key_expand},
8    aes_ctr::Aes256Ctr,
9    ghash::{compute_tag, precompute_ghash_powers, precompute_ghash_table},
10};
11use crate::{Aead, AeadError, StreamCipher, Tag};
12
/// Maximum plaintext/ciphertext length accepted by GCM, in bytes.
///
/// NIST SP 800-38D caps the plaintext at 2³⁹ − 256 bits, i.e. 2³² − 2 blocks
/// of 16 bytes, which keeps the 32-bit block counter from wrapping. On targets
/// whose `usize` cannot represent that bound (32-bit and smaller) no slice can
/// reach it, so the limit saturates at `usize::MAX` instead of overflowing
/// during constant evaluation.
pub(crate) const MAX_GCM_LEN: usize = {
    const LIMIT: u64 = (u32::MAX as u64 - 1) * 16;
    if LIMIT > usize::MAX as u64 {
        usize::MAX
    } else {
        LIMIT as usize
    }
};

/// AES-256-GCM authenticated cipher.
///
/// The [`Aead`] methods use a hardware-accelerated implementation where one is
/// selected — AES-NI + PCLMULQDQ on x86_64 (see `aes_gcm_amd64`) and the ARMv8
/// path on aarch64 (see `aes_gcm_arm64`) — and otherwise the pure-Rust
/// implementation.
///
/// The struct stores **only** the round keys native to the target architecture:
/// - x86_64: `round_keys_ni` (`[__m128i; 15]`) + precomputed GHASH powers
/// - aarch64: `round_keys_arm` (`[uint8x16_t; 15]`) + precomputed GHASH powers
/// - other: `round_keys` (`RoundKeys`)
///
/// The raw 32-byte key is retained so the software path can rebuild the
/// portable key schedule and the CTR keystream when it is used.
pub struct Aes256Gcm {
    /// Raw key; the software path derives its key schedule and CTR keystream from it.
    // NOTE(review): no `Drop` impl in this file zeroizes this key material —
    // confirm whether that is handled elsewhere.
    pub(crate) key: [u8; 32],
    /// x86_64 AES-NI round keys (loaded in `new()`).
    #[cfg(target_arch = "x86_64")]
    pub(crate) round_keys_ni: [__m128i; 15],
    /// Eight precomputed GHASH powers (H¹..H⁸, bit-reversed per byte) loaded in
    /// `new()`, so the aggregated GHASH does not recompute them on every call.
    // NOTE(review): earlier docs described this as [H, H², H³, H⁴] for a 4-block
    // aggregated GHASH; confirm how many powers the AES-NI kernel consumes.
    #[cfg(target_arch = "x86_64")]
    pub(crate) h_powers: [__m128i; 8],
    /// aarch64 ARMv8 round keys (loaded in `new()`).
    #[cfg(target_arch = "aarch64")]
    pub(crate) round_keys_arm: [core::arch::aarch64::uint8x16_t; 15],
    /// Precomputed GHASH powers H¹..H⁸ in bit-reversed-per-byte form.
    /// Used by 8-block aggregated GHASH to avoid recomputing H on every call.
    #[cfg(target_arch = "aarch64")]
    pub(crate) h_powers: [core::arch::aarch64::uint8x16_t; 8],
    /// Software round keys (targets without hardware acceleration).
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    pub(crate) round_keys: RoundKeys,
}
47
48impl Aes256Gcm {
49    pub const KEY_SIZE: usize = 32;
50    /// Create a new `Aes256Gcm` instance from a 32-byte key.
51    ///
52    /// Precomputes the target-specific round keys and GHASH powers (H, H², H³, H⁴)
53    /// using software GF(2¹²⁸) multiplication, so `new()` is safe on any CPU
54    /// and does not require hardware feature detection.
55    pub fn new(key: &[u8; 32]) -> Self {
56        #[cfg(target_arch = "x86_64")]
57        {
58            use core::arch::x86_64::*;
59            let rk_soft = key_expand(key);
60            let mut rk = unsafe { [_mm_setzero_si128(); 15] };
61            for i in 0..15 {
62                rk[i] = unsafe { _mm_loadu_si128(rk_soft[i].as_ptr().cast()) };
63            }
64            let (h_powers_bytes, _h) = precompute_ghash_powers(key);
65            let mut h_powers = unsafe { [_mm_setzero_si128(); 8] };
66            for i in 0..8 {
67                h_powers[i] = unsafe { _mm_loadu_si128(h_powers_bytes[i].as_ptr().cast()) };
68            }
69            Aes256Gcm {
70                key: *key,
71                round_keys_ni: rk,
72                h_powers,
73            }
74        }
75
76        #[cfg(target_arch = "aarch64")]
77        {
78            use core::arch::aarch64::*;
79            let rk_soft = key_expand(key);
80            let mut rk = [unsafe { vdupq_n_u8(0) }; 15];
81            for i in 0..15 {
82                rk[i] = unsafe { vld1q_u8(rk_soft[i].as_ptr()) };
83            }
84            let (h_powers_bytes, _h) = precompute_ghash_powers(key);
85            let mut h_powers = [unsafe { vdupq_n_u8(0) }; 8];
86            for i in 0..8 {
87                h_powers[i] = unsafe { vld1q_u8(h_powers_bytes[i].as_ptr()) };
88            }
89            Aes256Gcm {
90                key: *key,
91                round_keys_arm: rk,
92                h_powers,
93            }
94        }
95
96        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
97        {
98            Aes256Gcm {
99                key: *key,
100                round_keys: key_expand(key),
101            }
102        }
103    }
104
105    /// Pure-Rust encrypt implementation.
106    ///
107    /// The expanded key is recomputed here because on x86_64/aarch64 we store
108    /// only the hardware-specific keys. This path is only reached when
109    /// the hardware accelerator is unavailable, so the overhead is negligible.
110    pub(crate) fn encrypt_in_place_soft(&self, in_out: &mut [u8], nonce: &[u8; 12], aad: &[u8]) -> Tag {
111        assert!(
112            in_out.len() <= MAX_GCM_LEN,
113            "GCM plaintext exceeds maximum allowed length (2^32 - 2 blocks)"
114        );
115        let rk = key_expand(&self.key);
116        let h = encrypt_block(&rk, &[0u8; 16]);
117        let ghash_table = precompute_ghash_table(&h);
118
119        // J0 = nonce || 0x00000001
120        let mut j0 = [0u8; 16];
121        j0[..12].copy_from_slice(nonce);
122        j0[15] = 1;
123
124        let ej0 = encrypt_block(&rk, &j0);
125
126        // CTR starts at J0 + 1 (= nonce || 0x00000002)
127        j0[15] = 2;
128
129        let mut aes_ctr = Aes256Ctr::new(&self.key);
130        aes_ctr.set_counter(&j0);
131        aes_ctr.xor_keystream(in_out);
132        compute_tag(&ghash_table, aad, in_out, &ej0)
133    }
134
135    /// Pure-Rust decrypt implementation.
136    pub(crate) fn decrypt_in_place_soft(
137        &self,
138        in_out: &mut [u8],
139        tag: &[u8; 16],
140        nonce: &[u8; 12],
141        aad: &[u8],
142    ) -> Result<(), AeadError> {
143        if in_out.len() > MAX_GCM_LEN {
144            return Err(AeadError::InvalidCiphertext);
145        }
146        let rk = key_expand(&self.key);
147        let h = encrypt_block(&rk, &[0u8; 16]);
148        let ghash_table = precompute_ghash_table(&h);
149
150        let mut j0 = [0u8; 16];
151        j0[..12].copy_from_slice(nonce);
152        j0[15] = 1;
153
154        let ej0 = encrypt_block(&rk, &j0);
155
156        // Verify tag before decrypting (authenticate-then-decrypt ordering)
157        let expected_tag = compute_tag(&ghash_table, aad, in_out, &ej0);
158
159        // Constant-time comparison to avoid timing oracle
160        let mut diff = 0u8;
161        for i in 0..16 {
162            diff |= expected_tag.as_ref()[i] ^ tag[i];
163        }
164        if diff != 0 {
165            return Err(AeadError::InvalidCiphertext);
166        }
167
168        // CTR starts at J0 + 1 (= nonce || 0x00000002)
169        j0[15] = 2;
170        let mut aes_ctr = Aes256Ctr::new(&self.key);
171        aes_ctr.set_counter(&j0);
172        aes_ctr.xor_keystream(in_out);
173
174        Ok(())
175    }
176}
177
178impl Aead for Aes256Gcm {
179    const TAG_SIZE: usize = 16;
180    const NONCE_SIZE: usize = 12;
181
182    #[inline]
183    #[allow(unreachable_code)]
184    fn encrypt_in_place(&self, in_out: &mut [u8], nonce: &[u8], aad: &[u8]) -> Tag {
185        assert_eq!(nonce.len(), 12, "AES-256-GCM nonce must be 12 bytes");
186        let nonce_arr: &[u8; 12] = nonce.try_into().unwrap();
187
188        #[cfg(target_arch = "aarch64")]
189        {
190            use crate::aes::aes_gcm_arm64::gcm_encrypt_armv8;
191            return unsafe { gcm_encrypt_armv8(&self.round_keys_arm, &self.h_powers, in_out, nonce_arr, aad) };
192        }
193
194        #[cfg(feature = "std")]
195        {
196            #[cfg(target_arch = "x86_64")]
197            {
198                use crate::aes::aes_gcm_amd64::gcm_encrypt_aesni;
199                if std::arch::is_x86_feature_detected!("aes")
200                    && std::arch::is_x86_feature_detected!("pclmulqdq")
201                    && std::arch::is_x86_feature_detected!("ssse3")
202                    && std::arch::is_x86_feature_detected!("sse4.1")
203                {
204                    return unsafe { gcm_encrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, nonce_arr, aad) };
205                }
206            }
207        }
208
209        #[cfg(not(feature = "std"))]
210        {
211            #[cfg(all(
212                target_feature = "aes",
213                target_feature = "pclmulqdq",
214                target_feature = "ssse3",
215                target_feature = "sse4.1"
216            ))]
217            {
218                use crate::aes::aes_gcm_amd64::gcm_encrypt_aesni;
219                return unsafe { gcm_encrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, nonce_arr, aad) };
220            }
221        }
222
223        self.encrypt_in_place_soft(in_out, nonce_arr, aad)
224    }
225
226    #[inline]
227    #[allow(unreachable_code)]
228    fn decrypt_in_place(&self, in_out: &mut [u8], nonce: &[u8], aad: &[u8], tag: &[u8]) -> Result<(), AeadError> {
229        assert_eq!(nonce.len(), 12, "AES-256-GCM nonce must be 12 bytes");
230        let nonce_arr: &[u8; 12] = nonce.try_into().unwrap();
231        let tag_arr: &[u8; 16] = tag.try_into().expect("AES-256-GCM tag must be 16 bytes");
232
233        #[cfg(target_arch = "aarch64")]
234        {
235            use crate::aes::aes_gcm_arm64::gcm_decrypt_armv8;
236            unsafe { return gcm_decrypt_armv8(&self.round_keys_arm, &self.h_powers, in_out, tag_arr, nonce_arr, aad) }
237        }
238
239        #[cfg(feature = "std")]
240        {
241            #[cfg(target_arch = "x86_64")]
242            {
243                use crate::aes::aes_gcm_amd64::gcm_decrypt_aesni;
244                if std::arch::is_x86_feature_detected!("aes")
245                    && std::arch::is_x86_feature_detected!("pclmulqdq")
246                    && std::arch::is_x86_feature_detected!("ssse3")
247                    && std::arch::is_x86_feature_detected!("sse4.1")
248                {
249                    unsafe {
250                        return gcm_decrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, tag_arr, nonce_arr, aad);
251                    }
252                }
253            }
254        }
255
256        #[cfg(not(feature = "std"))]
257        {
258            #[cfg(all(
259                target_feature = "aes",
260                target_feature = "pclmulqdq",
261                target_feature = "ssse3",
262                target_feature = "sse4.1"
263            ))]
264            {
265                use crate::aes::aes_gcm_amd64::gcm_decrypt_aesni;
266                unsafe {
267                    return gcm_decrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, tag_arr, nonce_arr, aad);
268                }
269            }
270        }
271
272        self.decrypt_in_place_soft(in_out, tag_arr, nonce_arr, aad)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // ── AES-256-GCM (NIST SP 800-38D Appendix B and additional vectors) ────────

    include!("aes_gcm_vectors.rs");

    fn run_gcm_vector_soft(v: &GcmVector) {
        let key: [u8; 32] = hex::decode_array::<32>(v.key.as_bytes()).unwrap();
        let nonce: [u8; 12] = hex::decode_array::<12>(v.nonce.as_bytes()).unwrap();
        let pt = hex::decode(v.pt).unwrap();
        let aad = hex::decode(v.aad).unwrap();
        let expected_ct = hex::decode(v.ct).unwrap();
        let expected_tag: [u8; 16] = hex::decode_array::<16>(v.tag.as_bytes()).unwrap();

        let cipher = Aes256Gcm::new(&key);

        // Encrypt
        let mut buf = pt.clone();
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, &aad);
        assert_eq!(buf, expected_ct, "ciphertext mismatch for key={}", v.key);
        assert_eq!(tag.as_ref(), &expected_tag[..], "tag mismatch for key={}", v.key);

        // Decrypt
        let mut buf2 = expected_ct.clone();
        cipher
            .decrypt_in_place_soft(&mut buf2, &expected_tag, &nonce, &aad)
            .expect("decrypt failed");
        assert_eq!(buf2, pt, "plaintext mismatch after decrypt for key={}", v.key);
    }

    #[test]
    fn nist_gcm_test_vectors_soft() {
        for v in NIST_GCM_VECTORS.iter().chain(EXTRA_GCM_VECTORS.iter()) {
            run_gcm_vector_soft(v);
        }
    }

    #[test]
    fn gcm_tag_mismatch_returns_error_soft() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let cipher = Aes256Gcm::new(&key);
        let mut buf = b"hello world".to_vec();
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, &[]);
        // Flip one tag byte
        let mut bad_tag: [u8; 16] = tag.as_ref().try_into().unwrap();
        bad_tag[0] ^= 0xff;
        let mut buf2 = buf.clone();
        assert!(cipher.decrypt_in_place_soft(&mut buf2, &bad_tag, &nonce, &[]).is_err());
    }

    #[test]
    fn gcm_encrypt_decrypt_large_soft() {
        let key = [0xabu8; 32];
        let nonce = [0x01u8; 12];
        let aad = b"additional data";
        let plaintext: Vec<u8> = (0u8..=255u8).cycle().take(1024).collect();

        let cipher = Aes256Gcm::new(&key);
        let mut buf = plaintext.clone();
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, aad);
        let tag_bytes: [u8; 16] = tag.as_ref().try_into().unwrap();
        cipher
            .decrypt_in_place_soft(&mut buf, &tag_bytes, &nonce, aad)
            .expect("decrypt failed");
        assert_eq!(buf, plaintext);
    }

    #[test]
    fn gcm_empty_plaintext_nonempty_aad_soft() {
        let key: [u8; 32] =
            hex::decode_array::<32>(b"feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308").unwrap();
        let nonce: [u8; 12] = hex::decode_array::<12>(b"cafebabefacedbaddecaf888").unwrap();
        let aad = hex::decode("feedfacedeadbeeffeedfacedeadbeef").unwrap();
        let cipher = Aes256Gcm::new(&key);
        let mut buf: Vec<u8> = vec![];
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, &aad);
        let tag_bytes: [u8; 16] = tag.as_ref().try_into().unwrap();
        cipher
            .decrypt_in_place_soft(&mut buf, &tag_bytes, &nonce, &aad)
            .expect("decrypt failed");
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn max_gcm_len_matches_sp800_38d() {
        // SP 800-38D: len(P) <= 2^39 - 256 bits = (2^32 - 2) blocks = 2^36 - 32 bytes.
        assert_eq!(MAX_GCM_LEN as u64, (1u64 << 36) - 32);
    }

    // ── Dispatching wrappers (use hardware path when available) ───────────────

    #[test]
    fn nist_gcm_test_vectors_dispatch() {
        for v in NIST_GCM_VECTORS.iter().chain(EXTRA_GCM_VECTORS.iter()) {
            let key: [u8; 32] = hex::decode_array::<32>(v.key.as_bytes()).unwrap();
            let nonce: [u8; 12] = hex::decode_array::<12>(v.nonce.as_bytes()).unwrap();
            let pt = hex::decode(v.pt).unwrap();
            let aad = hex::decode(v.aad).unwrap();
            let expected_ct = hex::decode(v.ct).unwrap();
            let expected_tag: [u8; 16] = hex::decode_array::<16>(v.tag.as_bytes()).unwrap();

            let cipher = Aes256Gcm::new(&key);

            let mut buf = pt.clone();
            let tag = cipher.encrypt_in_place(&mut buf, &nonce[..], &aad);
            assert_eq!(&buf[..], &expected_ct[..], "dispatch ciphertext mismatch key={}", v.key);
            assert_eq!(tag.as_ref(), &expected_tag[..], "dispatch tag mismatch key={}", v.key);

            let mut buf2 = expected_ct.clone();
            cipher
                .decrypt_in_place(&mut buf2, &nonce[..], &aad, &expected_tag)
                .expect("dispatch decrypt failed");
            assert_eq!(buf2, pt);
        }
    }

    /// Tampering with the tag, ciphertext, or AAD must be rejected by whichever
    /// implementation the dispatcher selects (not just the software path).
    #[test]
    fn gcm_tampering_returns_error_dispatch() {
        let cipher = Aes256Gcm::new(&[0x11u8; 32]);
        let nonce = [0x22u8; 12];
        let aad = b"header";
        let mut ct = b"attack at dawn, bring snacks".to_vec();
        let tag = cipher.encrypt_in_place(&mut ct, &nonce[..], aad);
        let tag: [u8; 16] = tag.as_ref().try_into().unwrap();

        // Flipped tag bit.
        let mut bad_tag = tag;
        bad_tag[15] ^= 0x01;
        let mut buf = ct.clone();
        assert!(cipher.decrypt_in_place(&mut buf, &nonce[..], aad, &bad_tag).is_err());

        // Flipped ciphertext bit.
        let mut buf = ct.clone();
        buf[0] ^= 0x80;
        assert!(cipher.decrypt_in_place(&mut buf, &nonce[..], aad, &tag).is_err());

        // Modified AAD.
        let mut buf = ct.clone();
        assert!(cipher.decrypt_in_place(&mut buf, &nonce[..], b"headeR", &tag).is_err());
    }

    /// The dispatched (possibly hardware) path must agree with the software path
    /// for every length, covering partial blocks and multi-block aggregation.
    #[test]
    fn dispatch_matches_soft_across_lengths() {
        let cipher = Aes256Gcm::new(&[0x5au8; 32]);
        let nonce = [0xa5u8; 12];
        let aads: [&[u8]; 2] = [b"", b"associated data spanning several 16-byte GHASH blocks"];
        for &aad in &aads {
            for len in 0..=260usize {
                let pt: Vec<u8> = (0..len).map(|i| (i * 7 + 3) as u8).collect();

                let mut soft = pt.clone();
                let soft_tag = cipher.encrypt_in_place_soft(&mut soft, &nonce, aad);
                let soft_tag: [u8; 16] = soft_tag.as_ref().try_into().unwrap();

                let mut hw = pt.clone();
                let hw_tag = cipher.encrypt_in_place(&mut hw, &nonce[..], aad);
                assert_eq!(hw, soft, "dispatch/soft ciphertext mismatch at len={}", len);
                assert_eq!(hw_tag.as_ref(), &soft_tag[..], "dispatch/soft tag mismatch at len={}", len);

                cipher
                    .decrypt_in_place(&mut hw, &nonce[..], aad, &soft_tag)
                    .expect("dispatch decrypt failed");
                assert_eq!(hw, pt, "round-trip mismatch at len={}", len);
            }
        }
    }

    // --- Wycheproof test vectors ---

    #[test]
    fn wycheproof_gcm_vectors() {
        let data: serde_json::Value =
            serde_json::from_str(include_str!("../../testdata/wycheproof/testvectors_v1/aes_gcm_test.json")).unwrap();
        let mut valid_tested = 0u64;
        let mut invalid_tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            if group["keySize"].as_u64() != Some(256) {
                continue;
            }
            if group["ivSize"].as_u64() != Some(96) {
                continue;
            }
            if group["tagSize"].as_u64() != Some(128) {
                continue;
            }
            for test in group["tests"].as_array().unwrap() {
                let key_hex = test["key"].as_str().unwrap();
                let iv_hex = test["iv"].as_str().unwrap();
                let msg_hex = test["msg"].as_str().unwrap();
                let aad_hex = test["aad"].as_str().unwrap();
                let ct_hex = test["ct"].as_str().unwrap();
                let tag_hex = test["tag"].as_str().unwrap();
                let result = test["result"].as_str().unwrap();

                let key = hex::decode_array::<32>(key_hex.as_bytes()).unwrap();
                let nonce = hex::decode_array::<12>(iv_hex.as_bytes()).unwrap();
                let expected_ct = hex::decode(ct_hex).unwrap();
                let expected_tag = hex::decode_array::<16>(tag_hex.as_bytes()).unwrap();
                let pt = hex::decode(msg_hex).unwrap();
                let aad = hex::decode(aad_hex).unwrap();

                let cipher = Aes256Gcm::new(&key);

                if result == "valid" {
                    let mut buf = pt.clone();
                    let tag = cipher.encrypt_in_place(&mut buf, &nonce[..], &aad);
                    assert_eq!(buf, expected_ct, "wycheproof GCM tcId={} ct mismatch", test["tcId"]);
                    assert_eq!(
                        tag.as_ref(),
                        &expected_tag[..],
                        "wycheproof GCM tcId={} tag mismatch",
                        test["tcId"]
                    );

                    let mut buf2 = expected_ct.clone();
                    cipher
                        .decrypt_in_place(&mut buf2, &nonce[..], &aad, &expected_tag[..])
                        .expect("wycheproof GCM decrypt failed");
                    assert_eq!(buf2, pt, "wycheproof GCM tcId={} pt mismatch", test["tcId"]);
                    valid_tested += 1;
                } else {
                    let mut buf = expected_ct.clone();
                    let result = cipher.decrypt_in_place(&mut buf, &nonce[..], &aad, &expected_tag[..]);
                    assert!(
                        result.is_err(),
                        "wycheproof GCM tcId={} expected invalid but passed",
                        test["tcId"]
                    );
                    invalid_tested += 1;
                }
            }
        }
        assert!(valid_tested > 0, "no valid AES-GCM wycheproof tests were run");
        assert!(invalid_tested > 0, "no invalid AES-GCM wycheproof tests were run");
    }
}