xmss.py (2660B)
# From Section 6 (XMSS: Merkle trees of WOTS+ one-time keys)

from address import Address, AddressType
from slh_dsa import SLHDSA
from utils import toByte
from wotsp import wotsp_pkgen, wotsp_sign, wotsp_pk_from_sig


def xmss_node(sk_seed: bytes, i: int, z: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Compute the root of a Merkle subtree of WOTS+ public keys.

    The node is computed recursively from all 2**z WOTS+ public keys
    (leaves) below it, so the cost grows exponentially with ``z``.

    Args:
        sk_seed: Secret seed.
        i: Index of the target node among the nodes at height ``z``.
        z: Height of the target node (0 means a leaf, i.e. a WOTS+ public key).
        pk_seed: Public seed.
        adrs: Address object. It is mutated in place: its type, key-pair,
            tree-height and tree-index fields are overwritten.
        ctx: Parameter-set context providing ``hp`` (the XMSS tree height)
            and the tree hash ``hf``.

    Returns:
        The node value, or ``b""`` (the spec's NULL) when ``z`` exceeds the
        tree height or ``i`` is out of range for height ``z``.
    """
    assert i >= 0
    assert z >= 0

    if z > ctx.hp or i >= 2**(ctx.hp - z):
        return b""  # NULL

    if z == 0:
        # Leaf: the WOTS+ public key of key pair i.
        adrs.set_type_and_clear(AddressType.WOTS_HASH)
        adrs.set_key_pair_address(toByte(i, 4))
        return wotsp_pkgen(sk_seed, pk_seed, adrs, ctx)

    # Internal node: hash the two children at height z - 1.
    lnode = xmss_node(sk_seed, 2*i, z - 1, pk_seed, adrs, ctx)
    rnode = xmss_node(sk_seed, 2*i + 1, z - 1, pk_seed, adrs, ctx)
    # The children calls above clobber adrs, so it is reset afterwards.
    adrs.set_type_and_clear(AddressType.TREE)
    adrs.set_tree_height(z)
    adrs.set_tree_index(i)
    return ctx.hf(pk_seed, bytes(adrs.data), lnode + rnode)


def xmss_sign(M: bytes, sk_seed: bytes, idx: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Generate an XMSS signature.

    Args:
        M: Message to sign; must be exactly ``ctx.n`` bytes.
        sk_seed: Secret seed.
        idx: Index of the WOTS+ key pair (leaf) used to sign;
            ``0 <= idx < 2**ctx.hp``.
        pk_seed: Public seed.
        adrs: Address object; mutated in place.
        ctx: Parameter-set context.

    Returns:
        The WOTS+ signature of ``M`` followed by the authentication path:
        ``ctx.hp`` sibling nodes, ordered from the leaf level upwards.
    """
    assert len(M) == ctx.n
    assert idx >= 0
    assert idx < (2**ctx.hp)

    # At height j, the path node is idx >> j; flipping the low bit selects
    # its sibling, which is what the verifier needs.
    auth = b"".join(
        xmss_node(sk_seed, (idx >> j) ^ 1, j, pk_seed, adrs, ctx)
        for j in range(ctx.hp)
    )

    adrs.set_type_and_clear(AddressType.WOTS_HASH)
    adrs.set_key_pair_address(toByte(idx, 4))
    sig = wotsp_sign(M, sk_seed, pk_seed, adrs, ctx)
    return sig + auth


def xmss_pk_from_sig(idx: int, sig_xmss: bytes, M: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Compute an XMSS public key (tree root) from an XMSS signature.

    Args:
        idx: Index of the WOTS+ key pair (leaf) that produced the signature;
            ``0 <= idx < 2**ctx.hp``.
        sig_xmss: WOTS+ signature (``ctx.wotsp_len * ctx.n`` bytes) followed
            by the authentication path (``ctx.hp * ctx.n`` bytes).
        M: Signed message; must be exactly ``ctx.n`` bytes.
        pk_seed: Public seed.
        adrs: Address object; mutated in place.
        ctx: Parameter-set context.

    Returns:
        The candidate root. The caller compares it against the expected
        public key.
    """
    assert idx >= 0
    # Same range precondition as xmss_sign. Without it, an out-of-range
    # index would silently yield a meaningless root.
    assert idx < (2**ctx.hp)
    assert len(M) == ctx.n
    assert len(sig_xmss) == (ctx.wotsp_len + ctx.hp) * ctx.n

    adrs.set_type_and_clear(AddressType.WOTS_HASH)
    adrs.set_key_pair_address(toByte(idx, 4))
    wots_sig_len = ctx.wotsp_len * ctx.n
    sig = sig_xmss[:wots_sig_len]   # sig <- sig_xmss.getWOTSSig()
    auth = sig_xmss[wots_sig_len:]  # auth <- sig_xmss.getXMSSAUTH()

    # Start from the leaf: the WOTS+ public key recovered from the signature.
    node = wotsp_pk_from_sig(sig, M, pk_seed, adrs, ctx)

    # Climb to the root, combining with one authentication-path node per level.
    adrs.set_type_and_clear(AddressType.TREE)
    adrs.set_tree_index(idx)
    for k in range(ctx.hp):
        adrs.set_tree_height(k + 1)
        sibling = auth[k*ctx.n : (k + 1)*ctx.n]
        if (idx >> k) % 2 == 0:
            # Current node is a left child: sibling goes on the right.
            adrs.set_tree_index(adrs.get_tree_index() // 2)
            node = ctx.hf(pk_seed, bytes(adrs.data), node + sibling)
        else:
            # Current node is a right child: sibling goes on the left.
            adrs.set_tree_index((adrs.get_tree_index() - 1) // 2)
            node = ctx.hf(pk_seed, bytes(adrs.data), sibling + node)
    return node