gestumblinde

Gestumblinde - reference implementation of SLH-DSA
git clone git://www.tkruger.se/gestumblinde.git
Log | Files | Refs | README

xmss.py (2660B)


      1 # From Section 6
      2 
      3 from address import Address, AddressType
      4 from slh_dsa import SLHDSA
      5 from utils import toByte
      6 from wotsp import wotsp_pkgen, wotsp_sign, wotsp_pk_from_sig
      7 
      8 def xmss_node(sk_seed: bytes, i: int, z: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
      9     """Compute the root of a Merkle subtree of WOTS+ public keys"""
     10     assert i >= 0
     11     assert z >= 0
     12 
     13     if z > ctx.hp or i >= 2**(ctx.hp - z):
     14         return b"" # NULL
     15 
     16     if z == 0:
     17         adrs.set_type_and_clear(AddressType.WOTS_HASH)
     18         adrs.set_key_pair_address(toByte(i, 4))
     19         node = wotsp_pkgen(sk_seed, pk_seed, adrs, ctx)
     20     else:
     21         lnode = xmss_node(sk_seed, 2*i    , z-1, pk_seed, adrs, ctx)
     22         rnode = xmss_node(sk_seed, 2*i + 1, z-1, pk_seed, adrs, ctx)
     23         adrs.set_type_and_clear(AddressType.TREE)
     24         adrs.set_tree_height(z)
     25         adrs.set_tree_index(i)
     26         node = ctx.hf(pk_seed, bytes(adrs.data), lnode + rnode)
     27     return node
     28 
     29 def xmss_sign(M: bytes, sk_seed: bytes, idx: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     30     """Generate an XMSS signature"""
     31     assert len(M) == ctx.n
     32     assert idx >= 0
     33     assert idx < (2**ctx.hp)
     34 
     35     auth = bytes()
     36     for j in range(ctx.hp):
     37         k = (idx//(2**j)) ^ 0b01
     38         auth += xmss_node(sk_seed, k, j, pk_seed, adrs, ctx)
     39 
     40     adrs.set_type_and_clear(AddressType.WOTS_HASH)
     41     adrs.set_key_pair_address(toByte(idx, 4))
     42     sig = wotsp_sign(M, sk_seed, pk_seed, adrs, ctx)
     43     sig_xmss = sig + auth
     44     return sig_xmss
     45 
     46 def xmss_pk_from_sig(idx: int, sig_xmss: bytes, M: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     47     """Compute an XMSS public key from an XMSS signature"""
     48     assert idx >= 0
     49     assert len(M) == ctx.n
     50     assert len(sig_xmss) == (ctx.wotsp_len + ctx.hp) * ctx.n
     51 
     52     adrs.set_type_and_clear(AddressType.WOTS_HASH)
     53     adrs.set_key_pair_address(toByte(idx, 4))
     54     sig = sig_xmss[:ctx.wotsp_len * ctx.n]  # sig <- sig_xmss.getWOTSSig()
     55     auth = sig_xmss[ctx.wotsp_len * ctx.n:] # auth <- sig_xmss.getXMSSAUTH()
     56     node = [wotsp_pk_from_sig(sig, M, pk_seed, adrs, ctx), None]
     57 
     58     adrs.set_type_and_clear(AddressType.TREE)
     59     adrs.set_tree_index(idx)
     60     for k in range(ctx.hp):
     61         adrs.set_tree_height(k + 1)
     62         if (idx//(2**k)) % 2 == 0:
     63             adrs.set_tree_index(adrs.get_tree_index()//2)
     64             node[1] = ctx.hf(pk_seed, bytes(adrs.data), node[0] + auth[k*ctx.n : (k+1)*ctx.n])
     65         else:
     66             adrs.set_tree_index((adrs.get_tree_index()-1)//2)
     67             node[1] = ctx.hf(pk_seed, bytes(adrs.data), auth[k*ctx.n : (k+1)*ctx.n] + node[0])
     68         node[0] = node[1]
     69     return node[0]