Skip to main content

crypto/aes/
aes_ctr.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2#[cfg(target_arch = "aarch64")]
3use core::arch::aarch64::*;
4#[cfg(target_arch = "x86_64")]
5use core::arch::x86_64::*;
6
7/// AES-256-CTR stream cipher with hardware acceleration.
8///
9/// Wraps the AES-256 block cipher in CTR mode as a [`StreamCipher`].
10/// On x86-64 with AES-NI + SSSE3, and on aarch64 with ARMv8 Crypto
11/// extensions, the keystream generation is hardware-accelerated.
12use super::aes::{RoundKeys, encrypt_block, key_expand};
13#[cfg(target_arch = "x86_64")]
14use super::aes_amd64::aes_encrypt_block;
15#[cfg(target_arch = "aarch64")]
16use super::aes_arm64::aes_encrypt_block;
17#[cfg(target_arch = "x86_64")]
18use super::aes_ctr_amd64::ctr_inc as ctr_inc_ni;
19#[cfg(target_arch = "aarch64")]
20use super::aes_ctr_arm64::ctr_inc as ctr_inc_arm;
21use crate::StreamCipher;
22
23/// AES-256 in CTR mode.
24///
25/// Create a new cipher with [`new`](Aes256Ctr::new).
26/// [`xor_keystream`](StreamCipher::xor_keystream) to encrypt or decrypt
27/// (CTR mode is symmetric).
28/// You can move in the keystream with [`set_counter`](Aes256Ctr::set_counter).
29pub struct Aes256Ctr {
30    round_keys: RoundKeys,
31    #[cfg(target_arch = "x86_64")]
32    round_keys_aesni: [__m128i; 15],
33    #[cfg(target_arch = "aarch64")]
34    round_keys_armv8: [uint8x16_t; 15],
35    counter: [u8; 16],
36}
37
38impl Aes256Ctr {
39    /// Create a new cipher from a 32-byte key.
40    ///
41    /// The initial counter is zeroed.
42    pub fn new(key: &[u8; 32]) -> Self {
43        let round_keys = key_expand(key);
44        #[cfg(target_arch = "x86_64")]
45        {
46            let mut round_keys_aesni = unsafe { [_mm_setzero_si128(); 15] };
47            for i in 0..15 {
48                round_keys_aesni[i] = unsafe { _mm_loadu_si128(round_keys[i].as_ptr().cast()) };
49            }
50            return Self {
51                round_keys,
52                round_keys_aesni,
53                counter: [0u8; 16],
54            };
55        }
56        #[cfg(target_arch = "aarch64")]
57        {
58            let mut round_keys_armv8 = [unsafe { vdupq_n_u8(0) }; 15];
59            for i in 0..15 {
60                round_keys_armv8[i] = unsafe { vld1q_u8(round_keys[i].as_ptr()) };
61            }
62            return Self {
63                round_keys,
64                round_keys_armv8,
65                counter: [0u8; 16],
66            };
67        }
68        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
69        {
70            Self {
71                round_keys,
72                counter: [0u8; 16],
73            }
74        }
75    }
76
77    fn xor_keystream_soft(&mut self, in_out: &mut [u8]) {
78        let n = in_out.len();
79        let mut i = 0;
80        while i + 16 <= n {
81            let ks = encrypt_block(&self.round_keys, &self.counter);
82            for k in 0..16 {
83                in_out[i + k] ^= ks[k];
84            }
85            self.increment_counter();
86            i += 16;
87        }
88        if i < n {
89            let ks = encrypt_block(&self.round_keys, &self.counter);
90            for k in 0..n - i {
91                in_out[i + k] ^= ks[k];
92            }
93        }
94    }
95
96    #[cfg(target_arch = "x86_64")]
97    #[target_feature(enable = "aes,ssse3,sse2")]
98    unsafe fn xor_keystream_aesni(&mut self, in_out: &mut [u8]) {
99        let n = in_out.len();
100        let mut i = 0;
101        let mut ctr = _mm_loadu_si128(self.counter.as_ptr().cast());
102
103        while i + 16 <= n {
104            let ks = aes_encrypt_block(&self.round_keys_aesni, ctr);
105            let p = _mm_loadu_si128(in_out.as_ptr().add(i).cast());
106            _mm_storeu_si128(in_out.as_mut_ptr().add(i).cast(), _mm_xor_si128(p, ks));
107            ctr = ctr_inc_ni(ctr);
108            i += 16;
109        }
110        if i < n {
111            let ks = aes_encrypt_block(&self.round_keys_aesni, ctr);
112            let mut ks_bytes = [0u8; 16];
113            _mm_storeu_si128(ks_bytes.as_mut_ptr().cast(), ks);
114            for k in 0..n - i {
115                in_out[i + k] ^= ks_bytes[k];
116            }
117        }
118
119        _mm_storeu_si128(self.counter.as_mut_ptr().cast(), ctr);
120    }
121
122    #[cfg(target_arch = "aarch64")]
123    unsafe fn xor_keystream_armv8(&mut self, in_out: &mut [u8]) {
124        let n = in_out.len();
125        let mut i = 0;
126        let mut ctr = vld1q_u8(self.counter.as_ptr());
127
128        while i + 16 <= n {
129            let ks = aes_encrypt_block(&self.round_keys_armv8, ctr);
130            let p = vld1q_u8(in_out.as_ptr().add(i));
131            vst1q_u8(in_out.as_mut_ptr().add(i), veorq_u8(p, ks));
132            ctr = ctr_inc_arm(ctr);
133            i += 16;
134        }
135        if i < n {
136            let ks = aes_encrypt_block(&self.round_keys_armv8, ctr);
137            let mut ks_bytes = [0u8; 16];
138            vst1q_u8(ks_bytes.as_mut_ptr(), ks);
139            for k in 0..n - i {
140                in_out[i + k] ^= ks_bytes[k];
141            }
142        }
143
144        vst1q_u8(self.counter.as_mut_ptr(), ctr);
145    }
146
147    /// Set the 16-byte counter block.
148    ///
149    /// For GCM this is `nonce || 0x00000002` (J₀ + 1).
150    #[inline]
151    pub fn set_counter(&mut self, counter: &[u8; 16]) {
152        self.counter = *counter;
153    }
154
155    #[inline]
156    fn increment_counter(&mut self) {
157        let counter_value = u32::from_be_bytes(self.counter[12..16].try_into().unwrap());
158        self.counter[12..16].copy_from_slice(&counter_value.wrapping_add(1).to_be_bytes());
159    }
160}
161
162impl StreamCipher for Aes256Ctr {
163    #[allow(unreachable_code)]
164    fn xor_keystream(&mut self, in_out: &mut [u8]) {
165        #[cfg(target_arch = "aarch64")]
166        {
167            unsafe {
168                self.xor_keystream_armv8(in_out);
169            }
170            return;
171        }
172
173        #[cfg(feature = "std")]
174        {
175            #[cfg(target_arch = "x86_64")]
176            {
177                if std::arch::is_x86_feature_detected!("aes") && std::arch::is_x86_feature_detected!("ssse3") {
178                    unsafe {
179                        self.xor_keystream_aesni(in_out);
180                    }
181                    return;
182                }
183            }
184        }
185
186        #[cfg(not(feature = "std"))]
187        {
188            #[cfg(all(target_arch = "x86_64", target_feature = "aes", target_feature = "ssse3"))]
189            {
190                unsafe {
191                    self.xor_keystream_aesni(in_out);
192                }
193                return;
194            }
195        }
196
197        self.xor_keystream_soft(in_out);
198    }
199}