1#[cfg(target_arch = "x86_64")]
2use core::arch::x86_64::__m128i;
3
4#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
5use super::aes::RoundKeys;
6use super::{
7 aes::{encrypt_block, key_expand},
8 aes_ctr::Aes256Ctr,
9 ghash::{compute_tag, precompute_ghash_powers, precompute_ghash_table},
10};
11use crate::{Aead, AeadError, StreamCipher, Tag};
12
/// Maximum number of bytes that may be processed under a single nonce.
///
/// NIST SP 800-38D caps a GCM plaintext at 2^39 - 256 bits, i.e. (2^32 - 2)
/// 16-byte blocks. With the keystream counter starting at 2, this is exactly
/// the number of blocks the 32-bit counter can cover without wrapping.
///
/// The product is computed in `u64` so the constant also evaluates on targets
/// whose `usize` is narrower than 64 bits. On such targets the original
/// `usize` arithmetic overflowed at compile time. There the limit saturates to
/// `usize::MAX`, which no slice can reach, so the length check is trivially
/// satisfied.
pub(crate) const MAX_GCM_LEN: usize = {
    let limit: u64 = (u32::MAX as u64 - 1) * 16;
    if limit > usize::MAX as u64 {
        usize::MAX
    } else {
        limit as usize
    }
};
26
/// AES-256 in Galois/Counter Mode (96-bit nonce, 128-bit tag).
///
/// The raw key is always stored because the portable software path works
/// from it. Each target additionally carries its own key-dependent
/// precomputation, filled in by [`Aes256Gcm::new`].
pub struct Aes256Gcm {
    /// The raw 256-bit AES key.
    pub(crate) key: [u8; 32],

    /// Portable targets: the expanded AES-256 key schedule.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    pub(crate) round_keys: RoundKeys,

    /// aarch64: the 15 AES round keys as NEON vectors.
    #[cfg(target_arch = "aarch64")]
    pub(crate) round_keys_arm: [core::arch::aarch64::uint8x16_t; 15],
    /// aarch64: the 8 GHASH key powers produced by `precompute_ghash_powers`.
    #[cfg(target_arch = "aarch64")]
    pub(crate) h_powers: [core::arch::aarch64::uint8x16_t; 8],

    /// x86_64: the 15 AES round keys as SSE registers.
    #[cfg(target_arch = "x86_64")]
    pub(crate) round_keys_ni: [__m128i; 15],
    /// x86_64: the 8 GHASH key powers produced by `precompute_ghash_powers`.
    #[cfg(target_arch = "x86_64")]
    pub(crate) h_powers: [__m128i; 8],
}
47
48impl Aes256Gcm {
49 pub const KEY_SIZE: usize = 32;
50 pub fn new(key: &[u8; 32]) -> Self {
56 #[cfg(target_arch = "x86_64")]
57 {
58 use core::arch::x86_64::*;
59 let rk_soft = key_expand(key);
60 let mut rk = unsafe { [_mm_setzero_si128(); 15] };
61 for i in 0..15 {
62 rk[i] = unsafe { _mm_loadu_si128(rk_soft[i].as_ptr().cast()) };
63 }
64 let (h_powers_bytes, _h) = precompute_ghash_powers(key);
65 let mut h_powers = unsafe { [_mm_setzero_si128(); 8] };
66 for i in 0..8 {
67 h_powers[i] = unsafe { _mm_loadu_si128(h_powers_bytes[i].as_ptr().cast()) };
68 }
69 Aes256Gcm {
70 key: *key,
71 round_keys_ni: rk,
72 h_powers,
73 }
74 }
75
76 #[cfg(target_arch = "aarch64")]
77 {
78 use core::arch::aarch64::*;
79 let rk_soft = key_expand(key);
80 let mut rk = [unsafe { vdupq_n_u8(0) }; 15];
81 for i in 0..15 {
82 rk[i] = unsafe { vld1q_u8(rk_soft[i].as_ptr()) };
83 }
84 let (h_powers_bytes, _h) = precompute_ghash_powers(key);
85 let mut h_powers = [unsafe { vdupq_n_u8(0) }; 8];
86 for i in 0..8 {
87 h_powers[i] = unsafe { vld1q_u8(h_powers_bytes[i].as_ptr()) };
88 }
89 Aes256Gcm {
90 key: *key,
91 round_keys_arm: rk,
92 h_powers,
93 }
94 }
95
96 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
97 {
98 Aes256Gcm {
99 key: *key,
100 round_keys: key_expand(key),
101 }
102 }
103 }
104
105 pub(crate) fn encrypt_in_place_soft(&self, in_out: &mut [u8], nonce: &[u8; 12], aad: &[u8]) -> Tag {
111 assert!(
112 in_out.len() <= MAX_GCM_LEN,
113 "GCM plaintext exceeds maximum allowed length (2^32 - 2 blocks)"
114 );
115 let rk = key_expand(&self.key);
116 let h = encrypt_block(&rk, &[0u8; 16]);
117 let ghash_table = precompute_ghash_table(&h);
118
119 let mut j0 = [0u8; 16];
121 j0[..12].copy_from_slice(nonce);
122 j0[15] = 1;
123
124 let ej0 = encrypt_block(&rk, &j0);
125
126 j0[15] = 2;
128
129 let mut aes_ctr = Aes256Ctr::new(&self.key);
130 aes_ctr.set_counter(&j0);
131 aes_ctr.xor_keystream(in_out);
132 compute_tag(&ghash_table, aad, in_out, &ej0)
133 }
134
135 pub(crate) fn decrypt_in_place_soft(
137 &self,
138 in_out: &mut [u8],
139 tag: &[u8; 16],
140 nonce: &[u8; 12],
141 aad: &[u8],
142 ) -> Result<(), AeadError> {
143 if in_out.len() > MAX_GCM_LEN {
144 return Err(AeadError::InvalidCiphertext);
145 }
146 let rk = key_expand(&self.key);
147 let h = encrypt_block(&rk, &[0u8; 16]);
148 let ghash_table = precompute_ghash_table(&h);
149
150 let mut j0 = [0u8; 16];
151 j0[..12].copy_from_slice(nonce);
152 j0[15] = 1;
153
154 let ej0 = encrypt_block(&rk, &j0);
155
156 let expected_tag = compute_tag(&ghash_table, aad, in_out, &ej0);
158
159 let mut diff = 0u8;
161 for i in 0..16 {
162 diff |= expected_tag.as_ref()[i] ^ tag[i];
163 }
164 if diff != 0 {
165 return Err(AeadError::InvalidCiphertext);
166 }
167
168 j0[15] = 2;
170 let mut aes_ctr = Aes256Ctr::new(&self.key);
171 aes_ctr.set_counter(&j0);
172 aes_ctr.xor_keystream(in_out);
173
174 Ok(())
175 }
176}
177
178impl Aead for Aes256Gcm {
179 const TAG_SIZE: usize = 16;
180 const NONCE_SIZE: usize = 12;
181
182 #[inline]
183 #[allow(unreachable_code)]
184 fn encrypt_in_place(&self, in_out: &mut [u8], nonce: &[u8], aad: &[u8]) -> Tag {
185 assert_eq!(nonce.len(), 12, "AES-256-GCM nonce must be 12 bytes");
186 let nonce_arr: &[u8; 12] = nonce.try_into().unwrap();
187
188 #[cfg(target_arch = "aarch64")]
189 {
190 use crate::aes::aes_gcm_arm64::gcm_encrypt_armv8;
191 return unsafe { gcm_encrypt_armv8(&self.round_keys_arm, &self.h_powers, in_out, nonce_arr, aad) };
192 }
193
194 #[cfg(feature = "std")]
195 {
196 #[cfg(target_arch = "x86_64")]
197 {
198 use crate::aes::aes_gcm_amd64::gcm_encrypt_aesni;
199 if std::arch::is_x86_feature_detected!("aes")
200 && std::arch::is_x86_feature_detected!("pclmulqdq")
201 && std::arch::is_x86_feature_detected!("ssse3")
202 && std::arch::is_x86_feature_detected!("sse4.1")
203 {
204 return unsafe { gcm_encrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, nonce_arr, aad) };
205 }
206 }
207 }
208
209 #[cfg(not(feature = "std"))]
210 {
211 #[cfg(all(
212 target_feature = "aes",
213 target_feature = "pclmulqdq",
214 target_feature = "ssse3",
215 target_feature = "sse4.1"
216 ))]
217 {
218 use crate::aes::aes_gcm_amd64::gcm_encrypt_aesni;
219 return unsafe { gcm_encrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, nonce_arr, aad) };
220 }
221 }
222
223 self.encrypt_in_place_soft(in_out, nonce_arr, aad)
224 }
225
226 #[inline]
227 #[allow(unreachable_code)]
228 fn decrypt_in_place(&self, in_out: &mut [u8], nonce: &[u8], aad: &[u8], tag: &[u8]) -> Result<(), AeadError> {
229 assert_eq!(nonce.len(), 12, "AES-256-GCM nonce must be 12 bytes");
230 let nonce_arr: &[u8; 12] = nonce.try_into().unwrap();
231 let tag_arr: &[u8; 16] = tag.try_into().expect("AES-256-GCM tag must be 16 bytes");
232
233 #[cfg(target_arch = "aarch64")]
234 {
235 use crate::aes::aes_gcm_arm64::gcm_decrypt_armv8;
236 unsafe { return gcm_decrypt_armv8(&self.round_keys_arm, &self.h_powers, in_out, tag_arr, nonce_arr, aad) }
237 }
238
239 #[cfg(feature = "std")]
240 {
241 #[cfg(target_arch = "x86_64")]
242 {
243 use crate::aes::aes_gcm_amd64::gcm_decrypt_aesni;
244 if std::arch::is_x86_feature_detected!("aes")
245 && std::arch::is_x86_feature_detected!("pclmulqdq")
246 && std::arch::is_x86_feature_detected!("ssse3")
247 && std::arch::is_x86_feature_detected!("sse4.1")
248 {
249 unsafe {
250 return gcm_decrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, tag_arr, nonce_arr, aad);
251 }
252 }
253 }
254 }
255
256 #[cfg(not(feature = "std"))]
257 {
258 #[cfg(all(
259 target_feature = "aes",
260 target_feature = "pclmulqdq",
261 target_feature = "ssse3",
262 target_feature = "sse4.1"
263 ))]
264 {
265 use crate::aes::aes_gcm_amd64::gcm_decrypt_aesni;
266 unsafe {
267 return gcm_decrypt_aesni(&self.round_keys_ni, &self.h_powers, in_out, tag_arr, nonce_arr, aad);
268 }
269 }
270 }
271
272 self.decrypt_in_place_soft(in_out, tag_arr, nonce_arr, aad)
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // Provides `GcmVector`, `NIST_GCM_VECTORS` and `EXTRA_GCM_VECTORS`.
    include!("aes_gcm_vectors.rs");

    /// Checks one known-answer vector against the software encrypt and decrypt paths.
    fn run_gcm_vector_soft(v: &GcmVector) {
        let key: [u8; 32] = hex::decode_array::<32>(v.key.as_bytes()).unwrap();
        let nonce: [u8; 12] = hex::decode_array::<12>(v.nonce.as_bytes()).unwrap();
        let pt = hex::decode(v.pt).unwrap();
        let aad = hex::decode(v.aad).unwrap();
        let expected_ct = hex::decode(v.ct).unwrap();
        let expected_tag: [u8; 16] = hex::decode_array::<16>(v.tag.as_bytes()).unwrap();

        let cipher = Aes256Gcm::new(&key);

        let mut buf = pt.clone();
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, &aad);
        assert_eq!(buf, expected_ct, "ciphertext mismatch for key={}", v.key);
        assert_eq!(tag.as_ref(), &expected_tag[..], "tag mismatch for key={}", v.key);

        let mut buf2 = expected_ct.clone();
        cipher
            .decrypt_in_place_soft(&mut buf2, &expected_tag, &nonce, &aad)
            .expect("decrypt failed");
        assert_eq!(buf2, pt, "plaintext mismatch after decrypt for key={}", v.key);
    }

    #[test]
    fn nist_gcm_test_vectors_soft() {
        for v in NIST_GCM_VECTORS.iter().chain(EXTRA_GCM_VECTORS.iter()) {
            run_gcm_vector_soft(v);
        }
    }

    #[test]
    fn gcm_tag_mismatch_returns_error_soft() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let cipher = Aes256Gcm::new(&key);
        let mut buf = b"hello world".to_vec();
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, &[]);
        let mut bad_tag: [u8; 16] = tag.as_ref().try_into().unwrap();
        bad_tag[0] ^= 0xff;
        let mut buf2 = buf.clone();
        assert!(cipher.decrypt_in_place_soft(&mut buf2, &bad_tag, &nonce, &[]).is_err());
        // A failed authentication must not release (partially) decrypted data.
        assert_eq!(buf2, buf, "ciphertext was modified despite tag mismatch");
    }

    #[test]
    fn gcm_encrypt_decrypt_large_soft() {
        let key = [0xabu8; 32];
        let nonce = [0x01u8; 12];
        let aad = b"additional data";
        let plaintext: Vec<u8> = (0u8..=255u8).cycle().take(1024).collect();

        let cipher = Aes256Gcm::new(&key);
        let mut buf = plaintext.clone();
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, aad);
        let tag_bytes: [u8; 16] = tag.as_ref().try_into().unwrap();
        cipher
            .decrypt_in_place_soft(&mut buf, &tag_bytes, &nonce, aad)
            .expect("decrypt failed");
        assert_eq!(buf, plaintext);
    }

    #[test]
    fn gcm_empty_plaintext_nonempty_aad_soft() {
        let key: [u8; 32] =
            hex::decode_array::<32>(b"feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308").unwrap();
        let nonce: [u8; 12] = hex::decode_array::<12>(b"cafebabefacedbaddecaf888").unwrap();
        let aad = hex::decode("feedfacedeadbeeffeedfacedeadbeef").unwrap();
        let cipher = Aes256Gcm::new(&key);
        let mut buf: Vec<u8> = vec![];
        let tag = cipher.encrypt_in_place_soft(&mut buf, &nonce, &aad);
        let tag_bytes: [u8; 16] = tag.as_ref().try_into().unwrap();
        cipher
            .decrypt_in_place_soft(&mut buf, &tag_bytes, &nonce, &aad)
            .expect("decrypt failed");
    }

    #[test]
    fn nist_gcm_test_vectors_dispatch() {
        for v in NIST_GCM_VECTORS.iter().chain(EXTRA_GCM_VECTORS.iter()) {
            let key: [u8; 32] = hex::decode_array::<32>(v.key.as_bytes()).unwrap();
            let nonce: [u8; 12] = hex::decode_array::<12>(v.nonce.as_bytes()).unwrap();
            let pt = hex::decode(v.pt).unwrap();
            let aad = hex::decode(v.aad).unwrap();
            let expected_ct = hex::decode(v.ct).unwrap();
            let expected_tag: [u8; 16] = hex::decode_array::<16>(v.tag.as_bytes()).unwrap();

            let cipher = Aes256Gcm::new(&key);

            let mut buf = pt.clone();
            let tag = cipher.encrypt_in_place(&mut buf, &nonce[..], &aad);
            assert_eq!(&buf[..], &expected_ct[..], "dispatch ciphertext mismatch key={}", v.key);
            assert_eq!(tag.as_ref(), &expected_tag[..], "dispatch tag mismatch key={}", v.key);

            let mut buf2 = expected_ct.clone();
            cipher
                .decrypt_in_place(&mut buf2, &nonce[..], &aad, &expected_tag)
                .expect("dispatch decrypt failed");
            assert_eq!(buf2, pt);
        }
    }

    #[test]
    fn wycheproof_gcm_vectors() {
        let data: serde_json::Value =
            serde_json::from_str(include_str!("../../testdata/wycheproof/testvectors_v1/aes_gcm_test.json")).unwrap();
        let mut valid_tested = 0u64;
        let mut invalid_tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            if group["keySize"].as_u64() != Some(256) {
                continue;
            }
            if group["ivSize"].as_u64() != Some(96) {
                continue;
            }
            if group["tagSize"].as_u64() != Some(128) {
                continue;
            }
            for test in group["tests"].as_array().unwrap() {
                let key_hex = test["key"].as_str().unwrap();
                let iv_hex = test["iv"].as_str().unwrap();
                let msg_hex = test["msg"].as_str().unwrap();
                let aad_hex = test["aad"].as_str().unwrap();
                let ct_hex = test["ct"].as_str().unwrap();
                let tag_hex = test["tag"].as_str().unwrap();
                let result = test["result"].as_str().unwrap();

                let key = hex::decode_array::<32>(key_hex.as_bytes()).unwrap();
                let nonce = hex::decode_array::<12>(iv_hex.as_bytes()).unwrap();
                let expected_ct = hex::decode(ct_hex).unwrap();
                let expected_tag = hex::decode_array::<16>(tag_hex.as_bytes()).unwrap();
                let pt = hex::decode(msg_hex).unwrap();
                let aad = hex::decode(aad_hex).unwrap();

                let cipher = Aes256Gcm::new(&key);

                if result == "valid" {
                    let mut buf = pt.clone();
                    let tag = cipher.encrypt_in_place(&mut buf, &nonce[..], &aad);
                    assert_eq!(buf, expected_ct, "wycheproof GCM tcId={} ct mismatch", test["tcId"]);
                    assert_eq!(
                        tag.as_ref(),
                        &expected_tag[..],
                        "wycheproof GCM tcId={} tag mismatch",
                        test["tcId"]
                    );

                    let mut buf2 = expected_ct.clone();
                    cipher
                        .decrypt_in_place(&mut buf2, &nonce[..], &aad, &expected_tag[..])
                        .expect("wycheproof GCM decrypt failed");
                    assert_eq!(buf2, pt, "wycheproof GCM tcId={} pt mismatch", test["tcId"]);
                    valid_tested += 1;
                } else {
                    let mut buf = expected_ct.clone();
                    let result = cipher.decrypt_in_place(&mut buf, &nonce[..], &aad, &expected_tag[..]);
                    assert!(
                        result.is_err(),
                        "wycheproof GCM tcId={} expected invalid but passed",
                        test["tcId"]
                    );
                    invalid_tested += 1;
                }
            }
        }
        assert!(valid_tested > 0, "no valid AES-GCM wycheproof tests were run");
        assert!(invalid_tested > 0, "no invalid AES-GCM wycheproof tests were run");
    }

    /// The dispatched (possibly hardware) path must reject a corrupted tag and
    /// mismatched AAD.
    #[test]
    fn gcm_tag_mismatch_returns_error_dispatch() {
        let key = [0x11u8; 32];
        let nonce = [0x22u8; 12];
        let cipher = Aes256Gcm::new(&key);
        let mut ct = vec![0x33u8; 100];
        let tag = cipher.encrypt_in_place(&mut ct, &nonce[..], b"aad");
        let good_tag: [u8; 16] = tag.as_ref().try_into().unwrap();

        let mut bad_tag = good_tag;
        bad_tag[15] ^= 0x01;
        let mut buf = ct.clone();
        assert!(cipher.decrypt_in_place(&mut buf, &nonce[..], b"aad", &bad_tag).is_err());

        let mut buf = ct.clone();
        assert!(cipher.decrypt_in_place(&mut buf, &nonce[..], b"aae", &good_tag).is_err());
    }

    /// Differential test: the dispatched path must agree with the software path
    /// for lengths that straddle block boundaries and the multi-block fast paths.
    #[test]
    fn dispatch_matches_soft_across_lengths() {
        let mut key = [0u8; 32];
        for (i, b) in key.iter_mut().enumerate() {
            *b = (i as u8).wrapping_mul(7).wrapping_add(3);
        }
        let nonce = [0x24u8; 12];
        let cipher = Aes256Gcm::new(&key);

        for len in (0..=160usize).chain([255usize, 256, 257, 511, 512, 513, 1023, 1024, 1025]) {
            let pt: Vec<u8> = (0..len).map(|i| (i * 31 + 7) as u8).collect();
            let aad: Vec<u8> = (0..(len % 37)).map(|i| i as u8).collect();

            let mut dispatched = pt.clone();
            let dispatched_tag = cipher.encrypt_in_place(&mut dispatched, &nonce[..], &aad);
            let mut soft = pt.clone();
            let soft_tag = cipher.encrypt_in_place_soft(&mut soft, &nonce, &aad);

            let dispatched_tag_bytes: &[u8] = dispatched_tag.as_ref();
            let soft_tag_bytes: &[u8] = soft_tag.as_ref();
            assert_eq!(dispatched, soft, "ciphertext mismatch at len={}", len);
            assert_eq!(dispatched_tag_bytes, soft_tag_bytes, "tag mismatch at len={}", len);

            let tag_arr: [u8; 16] = dispatched_tag_bytes.try_into().unwrap();
            cipher
                .decrypt_in_place(&mut dispatched, &nonce[..], &aad, &tag_arr)
                .expect("dispatch decrypt failed");
            assert_eq!(dispatched, pt, "roundtrip mismatch at len={}", len);
        }
    }
}