gestumblinde

Gestumblinde - reference implementation of SLH-DSA
git clone git://www.tkruger.se/gestumblinde.git
Log | Files | Refs | README

slh_dsa.py (7591B)


      1 """
      2 SLH_DSA context class
      3 """
      4 import hashlib
      5 import hmac
      6 from utils import cdiv, toByte
      7 
      8 ALLOWED_PSETS = [
      9     "SLH-DSA-SHA2-128s",
     10     "SLH-DSA-SHAKE-128s",
     11     "SLH-DSA-SHA2-128f",
     12     "SLH-DSA-SHAKE-128f",
     13     "SLH-DSA-SHA2-192s",
     14     "SLH-DSA-SHAKE-192s",
     15     "SLH-DSA-SHA2-192f",
     16     "SLH-DSA-SHAKE-192f",
     17     "SLH-DSA-SHA2-256s",
     18     "SLH-DSA-SHAKE-256s",
     19     "SLH-DSA-SHA2-256f",
     20     "SLH-DSA-SHAKE-256f"]
     21 
     22 SHA256_DIGEST_LEN = 32
     23 SHA512_DIGEST_LEN = 64
     24 
     25 def sha256(m: bytes) -> bytes:
     26     """SHA256 wrapper function"""
     27     h = hashlib.sha256()
     28     h.update(m)
     29     return h.digest()
     30 
     31 def sha512(m: bytes) -> bytes:
     32     """SHA256 wrapper function"""
     33     h = hashlib.sha512()
     34     h.update(m)
     35     return h.digest()
     36 
     37 def mgf1_sha256(seed: bytes, mlen: int) -> bytes:
     38     """MGF1 mask generation function with SHA256.
     39        See NIST SP 800-56B rev 2 section 7.2.2.2."""
     40     hash_len = SHA256_DIGEST_LEN
     41     t = b""
     42     for c in range(0, cdiv(mlen, hash_len)):
     43         t += sha256(seed + toByte(c, 4))
     44     return t[:mlen]
     45 
     46 def mgf1_sha512(seed: bytes, mlen: int) -> bytes:
     47     """MGF1 mask generation function with SHA256.
     48        See NIST SP 800-56B rev 2 section 7.2.2.2."""
     49     hash_len = SHA512_DIGEST_LEN
     50     t = b""
     51     for c in range(0, cdiv(mlen, hash_len)):
     52         t += sha512(seed + toByte(c, 4))
     53     return t[:mlen]
     54 
     55 def hmac_sha256(k: bytes, msg: bytes) -> bytes:
     56     """HMAC-SHA-256 wrapper function"""
     57     return hmac.digest(k, msg, hashlib.sha256)
     58 
     59 def hmac_sha512(k: bytes, msg: bytes) -> bytes:
     60     """HMAC-SHA-512 wrapper function"""
     61     return hmac.digest(k, msg, hashlib.sha512)
     62 
     63 class SLHDSA:
     64     """Context class for SLH-DSA"""
     65     def __init__(self, pset: str):
     66         assert pset in ALLOWED_PSETS
     67 
     68         if pset == "SLH-DSA-SHA2-128s" or pset == "SLH-DSA-SHAKE-128s":
     69             self.n = 16
     70             self.h = 63
     71             self.d = 7
     72             self.hp = 9
     73             self.a = 12
     74             self.k = 14
     75             self.lg_w = 4
     76             self.m = 30
     77             self.sec_lvl = 1
     78             self.pk_bytes = 32
     79             self.sig_bytes = 7856
     80         if pset == "SLH-DSA-SHA2-128f" or pset == "SLH-DSA-SHAKE-128f":
     81             self.n = 16
     82             self.h = 66
     83             self.d = 22
     84             self.hp = 3
     85             self.a = 6
     86             self.k = 33
     87             self.lg_w = 4
     88             self.m = 34
     89             self.sec_lvl = 1
     90             self.pk_bytes = 32
     91             self.sig_bytes = 17088
     92         if pset == "SLH-DSA-SHA2-192s" or pset == "SLH-DSA-SHAKE-192s":
     93             self.n = 24
     94             self.h = 63
     95             self.d = 7
     96             self.hp = 9
     97             self.a = 14
     98             self.k = 17
     99             self.lg_w = 4
    100             self.m = 39
    101             self.sec_lvl = 3
    102             self.pk_bytes = 48
    103             self.sig_bytes = 16224
    104         if pset == "SLH-DSA-SHA2-192f" or pset == "SLH-DSA-SHAKE-192f":
    105             self.n = 24
    106             self.h = 66
    107             self.d = 22
    108             self.hp = 3
    109             self.a = 8
    110             self.k = 33
    111             self.lg_w = 4
    112             self.m = 42
    113             self.sec_lvl = 3
    114             self.pk_bytes = 48
    115             self.sig_bytes = 35664
    116         if pset == "SLH-DSA-SHA2-256s" or pset == "SLH-DSA-SHAKE-256s":
    117             self.n = 32
    118             self.h = 64
    119             self.d = 8
    120             self.hp = 8
    121             self.a = 14
    122             self.k = 22
    123             self.lg_w = 4
    124             self.m = 47
    125             self.sec_lvl = 5
    126             self.pk_bytes = 64
    127             self.sig_bytes = 29792
    128         if pset == "SLH-DSA-SHA2-256f" or pset == "SLH-DSA-SHAKE-256f":
    129             self.n = 32
    130             self.h = 68
    131             self.d = 17
    132             self.hp = 4
    133             self.a = 9
    134             self.k = 35
    135             self.lg_w = 4
    136             self.m = 49
    137             self.sec_lvl = 5
    138             self.pk_bytes = 64
    139             self.sig_bytes = 49856
    140 
    141         if "SHAKE" in pset:
    142             self.shake = True
    143         else:
    144             self.shake = False
    145             if "-128" in pset:
    146                 self.seccat = 1
    147             elif "-192" in pset:
    148                 self.seccat = 3
    149             else:
    150                 self.seccat = 5
    151 
    152         # Setting WOTS+ constants, assuming lg_w == 4
    153         self.wotsp_w = 16
    154         self.wotsp_len1 = 2*self.n
    155         self.wotsp_len2 = 3
    156         self.wotsp_len = 2*self.n + 3
    157 
    158     def h_msg(self, R: bytes, pk_seed: bytes, pk_root: bytes, M: bytes) -> bytes:
    159         """Compute the H_{msg} function"""
    160         if self.shake:
    161             h = hashlib.shake_256()
    162             h.update(R + pk_seed + pk_root + M)
    163             return h.digest(self.m)
    164         else:
    165             if self.seccat == 1:
    166                 return mgf1_sha256(R + pk_seed + sha256(R + pk_seed + pk_root + M), self.m)
    167             else:
    168                 return mgf1_sha512(R + pk_seed + sha512(R + pk_seed + pk_root + M), self.m)
    169 
    170     def prf(self, pk_seed: bytes, sk_seed: bytes, adrs: bytes) -> bytes:
    171         """Compute the PRF function"""
    172         if self.shake:
    173             h = hashlib.shake_256()
    174             h.update(pk_seed + adrs + sk_seed)
    175             return h.digest(self.n)
    176         else:
    177             adrsc = adrs[3:4] + adrs[8:16] + adrs[19:20] + adrs[20:32]
    178             if self.seccat == 1:
    179                 return sha256(pk_seed + toByte(0,64-self.n) + adrsc + sk_seed)[:self.n]
    180             else:
    181                 return sha256(pk_seed + toByte(0,64-self.n) + adrsc + sk_seed)[:self.n]
    182 
    183     def prf_msg(self, sk_prf: bytes, opt_rand: bytes, M: bytes) -> bytes:
    184         """Compute the PRF_{msg} function"""
    185         if self.shake:
    186             h = hashlib.shake_256()
    187             h.update(sk_prf + opt_rand + M)
    188             return h.digest(self.n)
    189         else:
    190             if self.seccat == 1:
    191                 return hmac_sha256(sk_prf, opt_rand + M)[:self.n]
    192             else:
    193                 return hmac_sha512(sk_prf, opt_rand + M)[:self.n]
    194 
    195     def f(self, pk_seed: bytes, adrs: bytes, M1: bytes) -> bytes:
    196         """Compute the F function"""
    197         if self.shake:
    198             h = hashlib.shake_256()
    199             h.update(pk_seed + adrs + M1)
    200             return h.digest(self.n)
    201         else:
    202             adrsc = adrs[3:4] + adrs[8:16] + adrs[19:20] + adrs[20:32]
    203             if self.seccat == 1:
    204                 return sha256(pk_seed + toByte(0,64-self.n) + adrsc + M1)[:self.n]
    205             else:
    206                 return sha256(pk_seed + toByte(0,64-self.n) + adrsc + M1)[:self.n]
    207 
    208     def hf(self, pk_seed: bytes, adrs: bytes, M2: bytes) -> bytes:
    209         """Compute the H function"""
    210         if self.shake:
    211             h = hashlib.shake_256()
    212             h.update(pk_seed + adrs + M2)
    213             return h.digest(self.n)
    214         else:
    215             adrsc = adrs[3:4] + adrs[8:16] + adrs[19:20] + adrs[20:32]
    216             if self.seccat == 1:
    217                 return sha256(pk_seed + toByte(0,64-self.n) + adrsc + M2)[:self.n]
    218             else:
    219                 return sha512(pk_seed + toByte(0,128-self.n) + adrsc + M2)[:self.n]
    220 
    221     def t(self, ell: int, pk_seed: bytes, adrs: bytes, M_ell: bytes) -> bytes:
    222         """Compute the T_{\\ell} function"""
    223         assert len(M_ell) == self.n*ell
    224         if self.shake:
    225             h = hashlib.shake_256()
    226             h.update(pk_seed + adrs + M_ell)
    227             return h.digest(self.n)
    228         else:
    229             adrsc = adrs[3:4] + adrs[8:16] + adrs[19:20] + adrs[20:32]
    230             if self.seccat == 1:
    231                 return sha256(pk_seed + toByte(0,64-self.n) + adrsc + M_ell)[:self.n]
    232             else:
    233                 return sha512(pk_seed + toByte(0,128-self.n) + adrsc + M_ell)[:self.n]
    234