Skip to main content

crypto/
hkdf.rs

1use crate::{Hash, Hasher, HkdfError, hmac::Hmac};
2
3const DEFAULT_SALT: [u8; 64] = [0u8; 64];
4
5/// HKDF extract step: `PRK = HMAC-Hash(salt, IKM)`.
6///
7/// If `salt` is `None`, a string of `H::OUTPUT_SIZE` zero bytes is used.
8///
9/// # Example
10///
11/// ```ignore
12/// use crypto::hkdf;
13/// use crypto::sha2::Sha256;
14///
15/// let prk = hkdf::extract::<Sha256>(Some(b"salt"), b"input key material");
16/// ```
17pub fn extract<H: Hasher>(salt: Option<&[u8]>, ikm: &[u8]) -> Hash {
18    let salt = salt.unwrap_or(&DEFAULT_SALT[..H::OUTPUT_SIZE]);
19    let mut mac = Hmac::<H>::new(salt);
20    mac.update(ikm);
21    return mac.finalize();
22}
23
24/// HKDF expand step: `OKM = T(1) || T(2) || ...`, where
25/// `T(i) = HMAC-Hash(PRK, T(i-1) || info || i)`.
26///
27/// # Example
28///
29/// ```ignore
30/// use crypto::hkdf;
31/// use crypto::sha2::Sha256;
32///
33/// let prk = hkdf::extract::<Sha256>(Some(b"salt"), b"input key material");
34/// let okm: [u8; 32] = hkdf::expand::<Sha256, 32>(&prk, b"context info").unwrap();
35/// ```
36///
37/// # Error
38///
39/// Returns an error if `N > 255 * H::OUTPUT_SIZE` or if `prk.len() < H::OUTPUT_SIZE`.
40pub fn expand<H: Hasher, const N: usize>(prk: &[u8], info: &[u8]) -> Result<[u8; N], HkdfError> {
41    if prk.len() < H::OUTPUT_SIZE {
42        return Err(HkdfError::PrkIsTooShort(H::OUTPUT_SIZE));
43    }
44
45    if N > 255 * H::OUTPUT_SIZE {
46        return Err(HkdfError::OutputIsTooLong);
47    }
48
49    let mut okm = [0u8; N];
50    if N == 0 {
51        return Ok(okm);
52    }
53
54    let mut t = [0u8; 64];
55    let mut t_len = 0usize;
56    let mut offset = 0usize;
57    let mut counter = 1u8;
58
59    while offset < N {
60        let mut mac = Hmac::<H>::new(&prk[..H::OUTPUT_SIZE]);
61        mac.update(&t[..t_len]);
62        mac.update(info);
63        mac.update(&[counter]);
64        let block = mac.finalize();
65        let block_bytes = block.as_ref();
66        let chunk_len = (N - offset).min(H::OUTPUT_SIZE);
67        okm[offset..offset + chunk_len].copy_from_slice(&block_bytes[..chunk_len]);
68        t[..H::OUTPUT_SIZE].copy_from_slice(block_bytes);
69        t_len = H::OUTPUT_SIZE;
70        offset += chunk_len;
71        counter = counter.wrapping_add(1);
72    }
73
74    return Ok(okm);
75}
76
77/// One-shot HKDF: extract-then-expand in a single call.
78///
79/// # Example
80///
81/// ```ignore
82/// use crypto::hkdf;
83/// use crypto::sha2::Sha256;
84///
85/// let okm: [u8; 32] = hkdf::derive_key::<Sha256, 32>(
86///     b"input key material",
87///     b"context info",
88///     Some(b"salt"),
89/// ).unwrap();
90/// ```
91///
92/// # Error
93///
94/// Returns an error if `N > 255 * H::OUTPUT_SIZE`.
95pub fn derive_key<H: Hasher, const N: usize>(
96    ikm: &[u8],
97    info: &[u8],
98    salt: Option<&[u8]>,
99) -> Result<[u8; N], HkdfError> {
100    let prk = extract::<H>(salt, ikm);
101    return expand::<H, N>(prk.as_ref(), info);
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sha2::{Sha256, Sha384, Sha512};

    /// A known-answer vector: hex-encoded inputs plus the expected PRK and OKM.
    struct TestVector {
        ikm: &'static str,
        salt: Option<&'static str>,
        info: &'static str,
        expected_prk: &'static str,
        expected_okm: &'static str,
    }

    /// Decodes a hex string, ignoring any whitespace. Panics on malformed input.
    fn decode_hex(input: &str) -> Vec<u8> {
        let input = input.replace(|c: char| c.is_whitespace(), "");
        (0..input.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&input[i..i + 2], 16).unwrap())
            .collect()
    }

    const SHA256_VECTORS: [TestVector; 4] = [
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some("000102030405060708090a0b0c"),
            info: "f0f1f2f3f4f5f6f7f8f9",
            expected_prk: "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5",
            expected_okm: "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865",
        },
        TestVector {
            ikm: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f\
                  202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f\
                  404142434445464748494a4b4c4d4e4f",
            salt: Some(
                "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f\
                 808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f\
                 a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
            ),
            info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecf\
                  d0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeef\
                  f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
            expected_prk: "06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244",
            expected_okm: "b11e398dc80327a1c8e7f78c596a49344f012eda2d4efad8a050cc4c19afa97c59045a99cac7827271cb41c65e590e09da3275600c2f09b8367793a9aca3db71cc30c58179ec3e87c14c01d5c1f3434f1d87",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some(""),
            info: "",
            expected_prk: "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04",
            expected_okm: "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: None,
            info: "",
            expected_prk: "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04",
            expected_okm: "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8",
        },
    ];

    const SHA512_VECTORS: [TestVector; 4] = [
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some("000102030405060708090a0b0c"),
            info: "f0f1f2f3f4f5f6f7f8f9",
            expected_prk: "665799823737ded04a88e47e54a5890bb2c3d247c7a4254a8e61350723590a26c36238127d8661b88cf80ef802d57e2f7cebcf1e00e083848be19929c61b4237",
            expected_okm: "832390086cda71fb47625bb5ceb168e4c8e26a1a16ed34d9fc7fe92c1481579338da362cb8d9f925d7cb",
        },
        TestVector {
            ikm: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f\
                  202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f\
                  404142434445464748494a4b4c4d4e4f",
            salt: Some(
                "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f\
                 808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f\
                 a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
            ),
            info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecf\
                  d0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeef\
                  f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
            expected_prk: "35672542907d4e142c00e84499e74e1de08be86535f924e022804ad775dde27ec86cd1e5b7d178c74489bdbeb30712beb82d4f97416c5a94ea81ebdf3e629e4a",
            expected_okm: "ce6c97192805b346e6161e821ed165673b84f400a2b514b2fe23d84cd189ddf1b695b48cbd1c8388441137b3ce28f16aa64ba33ba466b24df6cfcb021ecff235f6a2056ce3af1de44d572097a8505d9e7a93",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: Some(""),
            info: "",
            expected_prk: "fd200c4987ac491313bd4a2a13287121247239e11c9ef82802044b66ef357e5b194498d0682611382348572a7b1611de54764094286320578a863f36562b0df6",
            expected_okm: "f5fa02b18298a72a8c23898a8703472c6eb179dc204c03425c970e3b164bf90fff22d04836d0e2343bac",
        },
        TestVector {
            ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            salt: None,
            info: "",
            expected_prk: "fd200c4987ac491313bd4a2a13287121247239e11c9ef82802044b66ef357e5b194498d0682611382348572a7b1611de54764094286320578a863f36562b0df6",
            expected_okm: "f5fa02b18298a72a8c23898a8703472c6eb179dc204c03425c970e3b164bf90fff22d04836d0e2343bac",
        },
    ];

    /// Checks `extract`, `expand` and `derive_key` against known-answer vectors for hash `H`.
    ///
    /// The output length is a const generic, so only the OKM lengths used by the vectors
    /// above (42 and 82 bytes) are supported.
    fn check_vectors<H: Hasher>(vectors: &[TestVector], label: &str) {
        for (i, vector) in vectors.iter().enumerate() {
            let ikm = decode_hex(vector.ikm);
            let salt = vector.salt.map(decode_hex);
            let info = decode_hex(vector.info);
            let expected_prk = decode_hex(vector.expected_prk);
            let expected_okm = decode_hex(vector.expected_okm);

            let prk = extract::<H>(salt.as_deref(), &ikm);
            assert_eq!(prk.as_ref(), expected_prk.as_slice(), "{} vector {} PRK", label, i);

            let okm = match expected_okm.len() {
                42 => expand::<H, 42>(prk.as_ref(), &info).unwrap().to_vec(),
                82 => expand::<H, 82>(prk.as_ref(), &info).unwrap().to_vec(),
                len => unreachable!("{} vector {}: unsupported OKM length {}", label, i, len),
            };
            assert_eq!(okm, expected_okm, "{} vector {} OKM", label, i);

            let derived = match expected_okm.len() {
                42 => derive_key::<H, 42>(&ikm, &info, salt.as_deref()).unwrap().to_vec(),
                82 => derive_key::<H, 82>(&ikm, &info, salt.as_deref()).unwrap().to_vec(),
                len => unreachable!("{} vector {}: unsupported OKM length {}", label, i, len),
            };
            assert_eq!(derived, expected_okm, "{} vector {} derive_key OKM", label, i);
        }
    }

    #[test]
    fn hkdf_sha256_vectors() {
        check_vectors::<Sha256>(&SHA256_VECTORS, "HKDF-SHA-256");
    }

    #[test]
    fn hkdf_sha512_vectors() {
        check_vectors::<Sha512>(&SHA512_VECTORS, "HKDF-SHA-512");
    }

    #[test]
    fn hkdf_zero_length_output() {
        let prk = [0u8; 32];
        assert_eq!(expand::<Sha256, 0>(&prk, b"").unwrap(), [] as [u8; 0]);
        assert_eq!(derive_key::<Sha256, 0>(b"ikm", b"info", None).unwrap(), [] as [u8; 0]);
    }

    #[test]
    fn hkdf_expand_accepts_maximum_output_length() {
        // RFC 5869 allows up to 255 * HashLen bytes of output.
        const N: usize = 255 * Sha256::OUTPUT_SIZE;
        let prk = [0u8; 32];
        assert!(expand::<Sha256, N>(&prk, b"").is_ok());
    }

    #[test]
    fn hkdf_expand_rejects_output_that_is_too_long() {
        // One byte past the RFC 5869 limit of 255 * HashLen.
        const N: usize = 255 * Sha256::OUTPUT_SIZE + 1;
        let prk = [0u8; 32];
        assert_eq!(expand::<Sha256, N>(&prk, b""), Err(HkdfError::OutputIsTooLong));
    }

    #[test]
    fn hkdf_expand_rejects_prk_that_is_too_short() {
        assert_eq!(
            expand::<Sha256, 32>(&[0u8; 31], b""),
            Err(HkdfError::PrkIsTooShort(Sha256::OUTPUT_SIZE))
        );
    }

    // --- Wycheproof test vectors ---

    /// Runs one Wycheproof HKDF vector file (JSON text) through `derive_key::<H, _>`.
    ///
    /// `MAX_OKM` must equal `255 * H::OUTPUT_SIZE` and `SIZE_TOO_LARGE` must equal
    /// `MAX_OKM + 1`; both are const generics because the output length is one.
    /// Valid cases derive `MAX_OKM` bytes and compare the `size`-byte prefix, which is
    /// equivalent because every HKDF output is a prefix of any longer output for the same
    /// inputs. Invalid cases must request more than `MAX_OKM` bytes and must be rejected.
    fn check_wycheproof<H: Hasher, const MAX_OKM: usize, const SIZE_TOO_LARGE: usize>(
        json: &str,
        label: &str,
    ) {
        assert_eq!(MAX_OKM, 255 * H::OUTPUT_SIZE, "{}: MAX_OKM is not 255 * HashLen", label);
        assert_eq!(SIZE_TOO_LARGE, MAX_OKM + 1, "{}: SIZE_TOO_LARGE is not MAX_OKM + 1", label);

        let data: serde_json::Value = serde_json::from_str(json).unwrap();
        let mut valid_tested = 0u64;
        let mut invalid_tested = 0u64;
        for group in data["testGroups"].as_array().unwrap() {
            for test in group["tests"].as_array().unwrap() {
                let ikm_hex = test["ikm"].as_str().unwrap();
                let salt_hex = test["salt"].as_str().unwrap();
                let info_hex = test["info"].as_str().unwrap();
                let size = test["size"].as_u64().unwrap() as usize;
                let expected_okm_hex = test["okm"].as_str().unwrap();
                let result = test["result"].as_str().unwrap();

                let ikm = hex::decode(ikm_hex).unwrap();
                let info = hex::decode(info_hex).unwrap();
                // An empty salt is exercised through the `None` default; the RFC vectors
                // above show both give the same PRK.
                let salt: Option<Vec<u8>> = if salt_hex.is_empty() {
                    None
                } else {
                    Some(hex::decode(salt_hex).unwrap())
                };

                if result == "valid" {
                    let okm = derive_key::<H, MAX_OKM>(&ikm, &info, salt.as_deref()).unwrap();
                    let okm_hex = hex::encode(&okm[..size]);
                    assert_eq!(
                        okm_hex, expected_okm_hex,
                        "wycheproof {} tcId={} size={}",
                        label, test["tcId"], size
                    );
                    valid_tested += 1;
                } else {
                    // The only HKDF failure mode is an oversized output, which is all this
                    // branch can exercise; make sure the vector really is of that kind.
                    assert!(
                        size > MAX_OKM,
                        "wycheproof {} tcId={}: invalid case with in-range size={}",
                        label,
                        test["tcId"],
                        size
                    );
                    assert_eq!(
                        derive_key::<H, SIZE_TOO_LARGE>(&ikm, &info, salt.as_deref()),
                        Err(HkdfError::OutputIsTooLong),
                        "wycheproof {} tcId={} size={} should reject",
                        label,
                        test["tcId"],
                        size
                    );
                    invalid_tested += 1;
                }
            }
        }
        assert!(valid_tested > 0, "no valid {} wycheproof tests were run", label);
        assert!(invalid_tested > 0, "no invalid {} wycheproof tests were run", label);
    }

    #[test]
    fn hkdf_sha256_wycheproof() {
        // Maximum valid HKDF-SHA-256 output: 255 * 32 = 8160 bytes.
        check_wycheproof::<Sha256, 8160, 8161>(
            include_str!("../testdata/wycheproof/testvectors_v1/hkdf_sha256_test.json"),
            "HKDF-SHA-256",
        );
    }

    #[test]
    fn hkdf_sha512_wycheproof() {
        // Maximum valid HKDF-SHA-512 output: 255 * 64 = 16320 bytes.
        check_wycheproof::<Sha512, 16320, 16321>(
            include_str!("../testdata/wycheproof/testvectors_v1/hkdf_sha512_test.json"),
            "HKDF-SHA-512",
        );
    }

    #[test]
    fn hkdf_sha384_wycheproof() {
        // Maximum valid HKDF-SHA-384 output: 255 * 48 = 12240 bytes.
        check_wycheproof::<Sha384, 12240, 12241>(
            include_str!("../testdata/wycheproof/testvectors_v1/hkdf_sha384_test.json"),
            "HKDF-SHA-384",
        );
    }
}