1#![allow(unsafe_op_in_unsafe_fn)]
2#[cfg(target_arch = "aarch64")]
3use core::arch::aarch64::*;
4#[cfg(target_arch = "x86_64")]
5use core::arch::x86_64::*;
6
7use super::aes::{RoundKeys, encrypt_block, key_expand};
13#[cfg(target_arch = "x86_64")]
14use super::aes_amd64::aes_encrypt_block;
15#[cfg(target_arch = "aarch64")]
16use super::aes_arm64::aes_encrypt_block;
17#[cfg(target_arch = "x86_64")]
18use super::aes_ctr_amd64::ctr_inc as ctr_inc_ni;
19#[cfg(target_arch = "aarch64")]
20use super::aes_ctr_arm64::ctr_inc as ctr_inc_arm;
21use crate::StreamCipher;
22
23pub struct Aes256Ctr {
30 round_keys: RoundKeys,
31 #[cfg(target_arch = "x86_64")]
32 round_keys_aesni: [__m128i; 15],
33 #[cfg(target_arch = "aarch64")]
34 round_keys_armv8: [uint8x16_t; 15],
35 counter: [u8; 16],
36}
37
38impl Aes256Ctr {
39 pub fn new(key: &[u8; 32]) -> Self {
43 let round_keys = key_expand(key);
44 #[cfg(target_arch = "x86_64")]
45 {
46 let mut round_keys_aesni = unsafe { [_mm_setzero_si128(); 15] };
47 for i in 0..15 {
48 round_keys_aesni[i] = unsafe { _mm_loadu_si128(round_keys[i].as_ptr().cast()) };
49 }
50 return Self {
51 round_keys,
52 round_keys_aesni,
53 counter: [0u8; 16],
54 };
55 }
56 #[cfg(target_arch = "aarch64")]
57 {
58 let mut round_keys_armv8 = [unsafe { vdupq_n_u8(0) }; 15];
59 for i in 0..15 {
60 round_keys_armv8[i] = unsafe { vld1q_u8(round_keys[i].as_ptr()) };
61 }
62 return Self {
63 round_keys,
64 round_keys_armv8,
65 counter: [0u8; 16],
66 };
67 }
68 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
69 {
70 Self {
71 round_keys,
72 counter: [0u8; 16],
73 }
74 }
75 }
76
77 fn xor_keystream_soft(&mut self, in_out: &mut [u8]) {
78 let n = in_out.len();
79 let mut i = 0;
80 while i + 16 <= n {
81 let ks = encrypt_block(&self.round_keys, &self.counter);
82 for k in 0..16 {
83 in_out[i + k] ^= ks[k];
84 }
85 self.increment_counter();
86 i += 16;
87 }
88 if i < n {
89 let ks = encrypt_block(&self.round_keys, &self.counter);
90 for k in 0..n - i {
91 in_out[i + k] ^= ks[k];
92 }
93 }
94 }
95
96 #[cfg(target_arch = "x86_64")]
97 #[target_feature(enable = "aes,ssse3,sse2")]
98 unsafe fn xor_keystream_aesni(&mut self, in_out: &mut [u8]) {
99 let n = in_out.len();
100 let mut i = 0;
101 let mut ctr = _mm_loadu_si128(self.counter.as_ptr().cast());
102
103 while i + 16 <= n {
104 let ks = aes_encrypt_block(&self.round_keys_aesni, ctr);
105 let p = _mm_loadu_si128(in_out.as_ptr().add(i).cast());
106 _mm_storeu_si128(in_out.as_mut_ptr().add(i).cast(), _mm_xor_si128(p, ks));
107 ctr = ctr_inc_ni(ctr);
108 i += 16;
109 }
110 if i < n {
111 let ks = aes_encrypt_block(&self.round_keys_aesni, ctr);
112 let mut ks_bytes = [0u8; 16];
113 _mm_storeu_si128(ks_bytes.as_mut_ptr().cast(), ks);
114 for k in 0..n - i {
115 in_out[i + k] ^= ks_bytes[k];
116 }
117 }
118
119 _mm_storeu_si128(self.counter.as_mut_ptr().cast(), ctr);
120 }
121
122 #[cfg(target_arch = "aarch64")]
123 unsafe fn xor_keystream_armv8(&mut self, in_out: &mut [u8]) {
124 let n = in_out.len();
125 let mut i = 0;
126 let mut ctr = vld1q_u8(self.counter.as_ptr());
127
128 while i + 16 <= n {
129 let ks = aes_encrypt_block(&self.round_keys_armv8, ctr);
130 let p = vld1q_u8(in_out.as_ptr().add(i));
131 vst1q_u8(in_out.as_mut_ptr().add(i), veorq_u8(p, ks));
132 ctr = ctr_inc_arm(ctr);
133 i += 16;
134 }
135 if i < n {
136 let ks = aes_encrypt_block(&self.round_keys_armv8, ctr);
137 let mut ks_bytes = [0u8; 16];
138 vst1q_u8(ks_bytes.as_mut_ptr(), ks);
139 for k in 0..n - i {
140 in_out[i + k] ^= ks_bytes[k];
141 }
142 }
143
144 vst1q_u8(self.counter.as_mut_ptr(), ctr);
145 }
146
147 #[inline]
151 pub fn set_counter(&mut self, counter: &[u8; 16]) {
152 self.counter = *counter;
153 }
154
155 #[inline]
156 fn increment_counter(&mut self) {
157 let counter_value = u32::from_be_bytes(self.counter[12..16].try_into().unwrap());
158 self.counter[12..16].copy_from_slice(&counter_value.wrapping_add(1).to_be_bytes());
159 }
160}
161
162impl StreamCipher for Aes256Ctr {
163 #[allow(unreachable_code)]
164 fn xor_keystream(&mut self, in_out: &mut [u8]) {
165 #[cfg(target_arch = "aarch64")]
166 {
167 unsafe {
168 self.xor_keystream_armv8(in_out);
169 }
170 return;
171 }
172
173 #[cfg(feature = "std")]
174 {
175 #[cfg(target_arch = "x86_64")]
176 {
177 if std::arch::is_x86_feature_detected!("aes") && std::arch::is_x86_feature_detected!("ssse3") {
178 unsafe {
179 self.xor_keystream_aesni(in_out);
180 }
181 return;
182 }
183 }
184 }
185
186 #[cfg(not(feature = "std"))]
187 {
188 #[cfg(all(target_arch = "x86_64", target_feature = "aes", target_feature = "ssse3"))]
189 {
190 unsafe {
191 self.xor_keystream_aesni(in_out);
192 }
193 return;
194 }
195 }
196
197 self.xor_keystream_soft(in_out);
198 }
199}