gestumblinde

Gestumblinde - reference implementation of SLH-DSA
git clone git://www.tkruger.se/gestumblinde.git
Log | Files | Refs | README

wotsp.py (3102B)


      1 """
      2 From Section 5.
      3 """
      4 
      5 from address import Address, AddressType
      6 from slh_dsa import SLHDSA
      7 from utils import cdiv, base_2b, toByte
      8 
      9 def wotsp_chain(X: bytes, i: int, s: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     10     """Chaining function used in WOTS+"""
     11     assert len(X) == ctx.n
     12     if (i + s) >= ctx.wotsp_w:
     13         return b"" # NULL
     14 
     15     tmp = X
     16 
     17     for j in range(i, i + s):
     18         adrs.set_hash_address(j)
     19         tmp = ctx.f(pk_seed, bytes(adrs.data), tmp)
     20     return tmp
     21 
     22 def wotsp_pkgen(sk_seed: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     23     """Generate a WOTS+ public key"""
     24     sk_adrs = adrs.copy()
     25     sk_adrs.set_type_and_clear(AddressType.WOTS_PRF)
     26     sk_adrs.set_key_pair_address(adrs.get_key_pair_address())
     27     tmp = bytes()
     28     for i in range(ctx.wotsp_len):
     29         sk_adrs.set_chain_address(i)
     30         sk = ctx.prf(pk_seed, sk_seed, sk_adrs.data)
     31         adrs.set_chain_address(i)
     32         tmp += wotsp_chain(sk, 0, ctx.wotsp_w - 1, pk_seed, adrs, ctx)
     33     wotspk_adrs = adrs.copy()
     34     wotspk_adrs.set_type_and_clear(AddressType.WOTS_PK)
     35     wotspk_adrs.set_key_pair_address(adrs.get_key_pair_address())
     36     pk = ctx.t(ctx.wotsp_len, pk_seed, bytes(wotspk_adrs.data), tmp)
     37     return pk
     38 
     39 def wotsp_sign(M: bytes, sk_seed: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     40     """Generate a WOTS+ signature on an n-byte message"""
     41     csum = 0
     42     msg = base_2b(M, ctx.lg_w, ctx.wotsp_len1)
     43 
     44     for i in range(ctx.wotsp_len1):
     45         csum += ctx.wotsp_w - 1 - msg[i]
     46 
     47     csum <<= ((8 - ((ctx.wotsp_len2 * ctx.lg_w) % 8)) % 8)
     48     msg += base_2b(toByte(csum, cdiv(ctx.wotsp_len2 * ctx.lg_w, 8)),
     49                    ctx.lg_w, ctx.wotsp_len2)
     50 
     51     sk_adrs = adrs.copy()
     52     sk_adrs.set_type_and_clear(AddressType.WOTS_PRF)
     53     sk_adrs.set_key_pair_address(adrs.get_key_pair_address())
     54     sig = b""
     55     for i in range(ctx.wotsp_len):
     56         sk_adrs.set_chain_address(i)
     57         sk = ctx.prf(pk_seed, sk_seed, bytes(sk_adrs.data))
     58         adrs.set_chain_address(i)
     59         sig += wotsp_chain(sk, 0, msg[i], pk_seed, adrs, ctx)
     60     return sig
     61 
     62 def wotsp_pk_from_sig(sig: bytes, M: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     63     """Computes a WOTS+ public key from a message and its signature"""
     64     csum = 0
     65     msg = base_2b(M, ctx.lg_w, ctx.wotsp_len1)
     66 
     67     for i in range(ctx.wotsp_len1):
     68         csum += ctx.wotsp_w - 1 - msg[i]
     69 
     70     csum <<= ((8 - ((ctx.wotsp_len2 * ctx.lg_w) % 8)) % 8)
     71     msg += base_2b(toByte(csum, cdiv(ctx.wotsp_len2 * ctx.lg_w, 8)),
     72                    ctx.lg_w, ctx.wotsp_len2)
     73     tmp = bytes()
     74     for i in range(ctx.wotsp_len):
     75         adrs.set_chain_address(i)
     76         next_chain = wotsp_chain(sig[ctx.n * i: ctx.n * (i + 1)], msg[i], ctx.wotsp_w-1-msg[i], pk_seed, adrs, ctx)
     77         assert next_chain != None
     78         tmp += next_chain
     79     wotspk_adrs = adrs.copy()
     80     wotspk_adrs.set_type_and_clear(AddressType.WOTS_PK)
     81     wotspk_adrs.set_key_pair_address(adrs.get_key_pair_address())
     82     pk_sig = ctx.t(ctx.wotsp_len, pk_seed, bytes(wotspk_adrs.data), tmp)
     83     return pk_sig