fors.py (3346B)
"""
From Section 8.
"""

from address import Address, AddressType
from slh_dsa import SLHDSA
from utils import base_2b, cdiv


def fors_skgen(sk_seed: bytes, pk_seed: bytes, adrs: Address, idx: int, ctx: SLHDSA) -> bytes:
    """Generate a FORS private-key value.

    Args:
        sk_seed: secret seed, passed to ``ctx.prf``.
        pk_seed: public seed, passed to ``ctx.prf``.
        adrs: address whose key-pair address is copied into a fresh
            FORS_PRF address. ``adrs`` itself is only read (copied and
            queried), never modified here.
        idx: index of the requested private value across all FORS trees
            (callers pass ``i * 2**ctx.a + leaf``).
        ctx: parameter set providing ``prf``.

    Returns:
        ``ctx.prf(pk_seed, sk_seed, <FORS_PRF address bytes>)``.
    """
    assert idx >= 0

    # Derive the PRF address from a copy so the caller's address is untouched.
    skadrs = adrs.copy()
    skadrs.set_type_and_clear(AddressType.FORS_PRF)
    # set_type_and_clear presumably clears the non-layer/tree fields, so the
    # key-pair address is re-applied from the original address.
    skadrs.set_key_pair_address(adrs.get_key_pair_address())
    skadrs.set_tree_index(idx)

    return ctx.prf(pk_seed, sk_seed, bytes(skadrs.data))


def fors_node(sk_seed: bytes, i: int, z: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Compute the root of a Merkle subtree of FORS public values.

    The node is computed recursively: a leaf (``z == 0``) is ``ctx.f`` of the
    corresponding FORS private value; an inner node is ``ctx.hf`` of its two
    children concatenated left-to-right.

    Args:
        sk_seed: secret seed used to regenerate leaf private values.
        i: index of the node among all nodes at height ``z``, counted
            across all ``ctx.k`` trees.
        z: height of the node (0 = leaf, ``ctx.a`` = tree root).
        pk_seed: public seed passed to ``ctx.f`` / ``ctx.hf``.
        adrs: FORS tree address; its tree height and tree index are
            overwritten as a side effect.
        ctx: parameter set providing ``a``, ``k``, ``f``, ``hf``, ``prf``.

    Returns:
        The node value, or ``b""`` (NULL) when ``(i, z)`` lies outside the
        ``ctx.k`` trees of height ``ctx.a``.
    """
    assert i >= 0
    assert z >= 0

    # Out-of-range height or index: there is no such node.
    if z > ctx.a or i >= ctx.k * 2**(ctx.a - z):
        return b""  # NULL

    if z == 0:
        sk = fors_skgen(sk_seed, pk_seed, adrs, i, ctx)
        adrs.set_tree_height(0)
        adrs.set_tree_index(i)
        node = ctx.f(pk_seed, bytes(adrs.data), sk)
    else:
        lnode = fors_node(sk_seed, 2 * i, z - 1, pk_seed, adrs, ctx)
        rnode = fors_node(sk_seed, 2 * i + 1, z - 1, pk_seed, adrs, ctx)
        # The recursive calls clobber adrs, so set height/index only afterwards.
        adrs.set_tree_height(z)
        adrs.set_tree_index(i)
        node = ctx.hf(pk_seed, bytes(adrs.data), lnode + rnode)

    return node


def fors_sign(md: bytes, sk_seed: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Generate a FORS signature.

    ``md`` is split by ``base_2b`` into ``ctx.k`` indices of ``ctx.a`` bits.
    For each tree ``i`` the signature contains the private value of leaf
    ``indices[i]`` followed by its ``ctx.a`` authentication-path nodes.

    Args:
        md: message digest of exactly ``ceil(ctx.k * ctx.a / 8)`` bytes.
        sk_seed: secret seed.
        pk_seed: public seed.
        adrs: FORS tree address (mutated by ``fors_node``).
        ctx: parameter set.

    Returns:
        The concatenation, over all trees, of ``sk || auth[0] || ... || auth[a-1]``.
    """
    assert len(md) == cdiv(ctx.k * ctx.a, 8)

    indices = base_2b(md, ctx.a, ctx.k)
    parts = []  # joined once at the end instead of repeated bytes concatenation
    for i in range(ctx.k):
        # Revealed private value for the selected leaf of tree i.
        parts.append(fors_skgen(sk_seed, pk_seed, adrs, i * 2**ctx.a + indices[i], ctx))

        # Authentication path: at height j, the sibling of the path node
        # (flip the low bit of the node's index within the tree).
        for j in range(ctx.a):
            s = (indices[i] // 2**j) ^ 0b01
            parts.append(fors_node(sk_seed, i * 2**(ctx.a - j) + s, j, pk_seed, adrs, ctx))

    return b"".join(parts)


def fors_pk_from_sig(sig_fors: bytes, md: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Compute a FORS public key from a FORS signature.

    Each tree root is recomputed from the revealed private value and its
    authentication path; the ``ctx.k`` roots are then compressed with
    ``ctx.t`` under a FORS_ROOTS address.

    Args:
        sig_fors: FORS signature of exactly ``ctx.k * (ctx.a + 1) * ctx.n``
            bytes, laid out per tree as ``sk || auth[0] || ... || auth[a-1]``
            (each chunk ``ctx.n`` bytes).
        md: message digest of exactly ``ceil(ctx.k * ctx.a / 8)`` bytes.
        pk_seed: public seed.
        adrs: FORS tree address; its tree height and tree index are
            overwritten as a side effect.
        ctx: parameter set providing ``n``, ``a``, ``k``, ``f``, ``hf``, ``t``.

    Returns:
        The output of ``ctx.t`` over the concatenated tree roots.
    """
    assert len(sig_fors) == ctx.k * (ctx.a + 1) * ctx.n
    assert len(md) == cdiv(ctx.k * ctx.a, 8)

    n, a = ctx.n, ctx.a
    indices = base_2b(md, a, ctx.k)
    roots = []
    for i in range(ctx.k):
        # Tree i occupies (a + 1) consecutive n-byte chunks: sk, then a auth nodes.
        # Both slices are derived from one offset so they cannot drift apart
        # (the previous AUTH slice end used (a + n) instead of (a + 1)).
        tree_off = i * (a + 1) * n
        sk = sig_fors[tree_off : tree_off + n]                  # sig_fors.getSK(i)
        auth = sig_fors[tree_off + n : tree_off + (a + 1) * n]  # sig_fors.getAUTH(i)

        adrs.set_tree_height(0)
        adrs.set_tree_index(i * 2**a + indices[i])
        node = ctx.f(pk_seed, bytes(adrs.data), sk)

        for j in range(a):
            sibling = auth[j * n : (j + 1) * n]
            adrs.set_tree_height(j + 1)
            # Bit j of the leaf index says whether the current node is a
            # left (even) or right (odd) child, i.e. the concatenation order.
            if (indices[i] // 2**j) % 2 == 0:
                adrs.set_tree_index(adrs.get_tree_index() // 2)
                node = ctx.hf(pk_seed, bytes(adrs.data), node + sibling)
            else:
                adrs.set_tree_index((adrs.get_tree_index() - 1) // 2)
                node = ctx.hf(pk_seed, bytes(adrs.data), sibling + node)
        roots.append(node)

    forspkadrs = adrs.copy()
    forspkadrs.set_type_and_clear(AddressType.FORS_ROOTS)
    forspkadrs.set_key_pair_address(adrs.get_key_pair_address())
    pk = ctx.t(ctx.k, pk_seed, bytes(forspkadrs.data), b"".join(roots))
    return pk