slh_dsa.py (7591B)
"""
SLH_DSA context class

Provides thin wrappers around the SHA2 / SHAKE primitives and the SLHDSA
context class, which holds the numeric parameters of one SLH-DSA parameter
set and instantiates the functions H_msg, PRF, PRF_msg, F, H and T_l for it.
"""
import hashlib
import hmac
from utils import cdiv, toByte

ALLOWED_PSETS = [
    "SLH-DSA-SHA2-128s",
    "SLH-DSA-SHAKE-128s",
    "SLH-DSA-SHA2-128f",
    "SLH-DSA-SHAKE-128f",
    "SLH-DSA-SHA2-192s",
    "SLH-DSA-SHAKE-192s",
    "SLH-DSA-SHA2-192f",
    "SLH-DSA-SHAKE-192f",
    "SLH-DSA-SHA2-256s",
    "SLH-DSA-SHAKE-256s",
    "SLH-DSA-SHA2-256f",
    "SLH-DSA-SHAKE-256f"]

SHA256_DIGEST_LEN = 32
SHA512_DIGEST_LEN = 64

# Numeric parameters keyed by the trailing "<bits><s|f>" component of the
# parameter-set name. The SHA2 and SHAKE variants of a size share all values.
# Columns: n, h, d, hp, a, k, lg_w, m, sec_lvl, pk_bytes, sig_bytes
_PSET_PARAMS = {
    "128s": (16, 63, 7, 9, 12, 14, 4, 30, 1, 32, 7856),
    "128f": (16, 66, 22, 3, 6, 33, 4, 34, 1, 32, 17088),
    "192s": (24, 63, 7, 9, 14, 17, 4, 39, 3, 48, 16224),
    "192f": (24, 66, 22, 3, 8, 33, 4, 42, 3, 48, 35664),
    "256s": (32, 64, 8, 8, 14, 22, 4, 47, 5, 64, 29792),
    "256f": (32, 68, 17, 4, 9, 35, 4, 49, 5, 64, 49856),
}


def sha256(m: bytes) -> bytes:
    """SHA256 wrapper function"""
    return hashlib.sha256(m).digest()


def sha512(m: bytes) -> bytes:
    """SHA512 wrapper function"""
    return hashlib.sha512(m).digest()


def _shake256(m: bytes, out_len: int) -> bytes:
    """Return out_len bytes of SHAKE256 output for the input m."""
    return hashlib.shake_256(m).digest(out_len)


def _mgf1(hash_fn, hash_len: int, seed: bytes, mlen: int) -> bytes:
    """Shared MGF1 loop used by mgf1_sha256 and mgf1_sha512.

    hash_fn maps bytes to a digest of hash_len bytes. The output is the
    concatenation of hash_fn(seed + toByte(c, 4)) for c = 0, 1, ... with
    enough blocks to cover mlen bytes, truncated to exactly mlen bytes.
    """
    # toByte(c, 4) is presumably the 4-byte big-endian counter encoding that
    # MGF1 requires -- confirm against utils.toByte.
    blocks = (hash_fn(seed + toByte(c, 4))
              for c in range(cdiv(mlen, hash_len)))
    return b"".join(blocks)[:mlen]


def mgf1_sha256(seed: bytes, mlen: int) -> bytes:
    """MGF1 mask generation function with SHA256.
    See NIST SP 800-56B rev 2 section 7.2.2.2."""
    return _mgf1(sha256, SHA256_DIGEST_LEN, seed, mlen)


def mgf1_sha512(seed: bytes, mlen: int) -> bytes:
    """MGF1 mask generation function with SHA512.
    See NIST SP 800-56B rev 2 section 7.2.2.2."""
    return _mgf1(sha512, SHA512_DIGEST_LEN, seed, mlen)


def hmac_sha256(k: bytes, msg: bytes) -> bytes:
    """HMAC-SHA-256 wrapper function"""
    return hmac.digest(k, msg, hashlib.sha256)


def hmac_sha512(k: bytes, msg: bytes) -> bytes:
    """HMAC-SHA-512 wrapper function"""
    return hmac.digest(k, msg, hashlib.sha512)


class SLHDSA:
    """Context class for SLH-DSA

    An instance stores the numeric parameters of one parameter set
    (n, h, d, hp, a, k, lg_w, m, sec_lvl, pk_bytes, sig_bytes), the derived
    flags shake / seccat, the WOTS+ constants, and exposes the hash-based
    functions instantiated for that parameter set.
    """

    def __init__(self, pset: str):
        """Load the parameters for the parameter set named pset.

        pset must be one of ALLOWED_PSETS; otherwise AssertionError is raised.
        """
        if pset not in ALLOWED_PSETS:
            # Explicit raise instead of `assert` so the check still runs under
            # `python -O`. AssertionError is kept so existing callers observe
            # the same exception type as before.
            raise AssertionError(f"unsupported SLH-DSA parameter set: {pset!r}")

        # The last dash-separated component ("128s", "256f", ...) selects the
        # numeric parameters.
        size = pset.rsplit("-", 1)[1]
        (self.n, self.h, self.d, self.hp, self.a, self.k, self.lg_w, self.m,
         self.sec_lvl, self.pk_bytes, self.sig_bytes) = _PSET_PARAMS[size]

        self.shake = "SHAKE" in pset
        # Security category (1, 3 or 5). For every allowed parameter set this
        # equals sec_lvl; both attributes are kept for existing users.
        self.seccat = self.sec_lvl

        # Setting WOTS+ constants, assuming lg_w == 4
        self.wotsp_w = 16
        self.wotsp_len1 = 2 * self.n
        self.wotsp_len2 = 3
        self.wotsp_len = self.wotsp_len1 + self.wotsp_len2

    @staticmethod
    def _compress_adrs(adrs: bytes) -> bytes:
        """Return the 22-byte compressed form of a 32-byte address.

        Keeps byte 3, bytes 8..15 and bytes 19..31 of adrs.
        """
        return adrs[3:4] + adrs[8:16] + adrs[19:32]

    def _sha2_trunc(self, hash_fn, block_len: int, pk_seed: bytes,
                    adrs: bytes, tail: bytes) -> bytes:
        """Common SHA2 construction shared by PRF, F, H and T_l.

        Computes hash_fn(pk_seed + padding + compressed adrs + tail) and
        truncates the digest to n bytes. The padding is
        toByte(0, block_len - n) (presumably that many zero bytes -- confirm
        against utils.toByte), which fills pk_seed up to block_len bytes.
        """
        data = (pk_seed + toByte(0, block_len - self.n)
                + self._compress_adrs(adrs) + tail)
        return hash_fn(data)[:self.n]

    def h_msg(self, R: bytes, pk_seed: bytes, pk_root: bytes, M: bytes) -> bytes:
        """Compute the H_{msg} function

        Returns m bytes: SHAKE256 over R + pk_seed + pk_root + M for SHAKE
        parameter sets, otherwise MGF1 over R + pk_seed + Hash(R + pk_seed +
        pk_root + M), using SHA-256 for security category 1 and SHA-512 for
        categories 3 and 5.
        """
        if self.shake:
            return _shake256(R + pk_seed + pk_root + M, self.m)
        if self.seccat == 1:
            inner, mgf1 = sha256, mgf1_sha256
        else:
            inner, mgf1 = sha512, mgf1_sha512
        return mgf1(R + pk_seed + inner(R + pk_seed + pk_root + M), self.m)

    def prf(self, pk_seed: bytes, sk_seed: bytes, adrs: bytes) -> bytes:
        """Compute the PRF function

        Returns n bytes. For SHA2 parameter sets SHA-256 is used in every
        security category (see FIPS 205 section 11.2).
        """
        if self.shake:
            return _shake256(pk_seed + adrs + sk_seed, self.n)
        return self._sha2_trunc(sha256, 64, pk_seed, adrs, sk_seed)

    def prf_msg(self, sk_prf: bytes, opt_rand: bytes, M: bytes) -> bytes:
        """Compute the PRF_{msg} function

        Returns n bytes: SHAKE256 over sk_prf + opt_rand + M for SHAKE
        parameter sets, otherwise the truncated HMAC of opt_rand + M keyed
        with sk_prf (HMAC-SHA-256 for category 1, HMAC-SHA-512 otherwise).
        """
        if self.shake:
            return _shake256(sk_prf + opt_rand + M, self.n)
        mac = hmac_sha256 if self.seccat == 1 else hmac_sha512
        return mac(sk_prf, opt_rand + M)[:self.n]

    def f(self, pk_seed: bytes, adrs: bytes, M1: bytes) -> bytes:
        """Compute the F function

        Returns n bytes. For SHA2 parameter sets SHA-256 is used in every
        security category (see FIPS 205 section 11.2).
        """
        if self.shake:
            return _shake256(pk_seed + adrs + M1, self.n)
        return self._sha2_trunc(sha256, 64, pk_seed, adrs, M1)

    def hf(self, pk_seed: bytes, adrs: bytes, M2: bytes) -> bytes:
        """Compute the H function

        Returns n bytes. For SHA2 parameter sets, category 1 uses SHA-256 with
        pk_seed padded to 64 bytes; categories 3 and 5 use SHA-512 with
        pk_seed padded to 128 bytes.
        """
        if self.shake:
            return _shake256(pk_seed + adrs + M2, self.n)
        if self.seccat == 1:
            return self._sha2_trunc(sha256, 64, pk_seed, adrs, M2)
        return self._sha2_trunc(sha512, 128, pk_seed, adrs, M2)

    def t(self, ell: int, pk_seed: bytes, adrs: bytes, M_ell: bytes) -> bytes:
        """Compute the T_{\\ell} function

        M_ell must be exactly n*ell bytes long. Returns n bytes, using the
        same hash selection as the H function.
        """
        assert len(M_ell) == self.n*ell
        if self.shake:
            return _shake256(pk_seed + adrs + M_ell, self.n)
        if self.seccat == 1:
            return self._sha2_trunc(sha256, 64, pk_seed, adrs, M_ell)
        return self._sha2_trunc(sha512, 128, pk_seed, adrs, M_ell)