gestumblinde

Gestumblinde - reference implementation of SLH-DSA
git clone git://www.tkruger.se/gestumblinde.git
Log | Files | Refs | README

fors.py (3346B)


      1 """
      2 From Section 8.
      3 """
      4 
      5 from address import Address, AddressType
      6 from slh_dsa import SLHDSA
      7 from utils import base_2b, cdiv
      8 
      9 def fors_skgen(sk_seed: bytes, pk_seed: bytes, adrs: Address, idx: int, ctx: SLHDSA) -> bytes:
     10     """Generate a FORS private-key value"""
     11     assert idx >= 0
     12 
     13     skadrs = adrs.copy()
     14     skadrs.set_type_and_clear(AddressType.FORS_PRF)
     15     skadrs.set_key_pair_address(adrs.get_key_pair_address())
     16     skadrs.set_tree_index(idx)
     17 
     18     return ctx.prf(pk_seed, sk_seed, bytes(skadrs.data))
     19 
     20 def fors_node(sk_seed: bytes, i: int, z: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     21     """Compute the root of a Merkle subtree of FORS public values"""
     22     assert i >= 0
     23     assert z >= 0
     24 
     25     if z > ctx.a or i >= ctx.k * 2**(ctx.a - z):
     26         return b"" # NULL
     27 
     28     if z == 0:
     29         sk = fors_skgen(sk_seed, pk_seed, adrs, i, ctx)
     30         adrs.set_tree_height(0)
     31         adrs.set_tree_index(i)
     32         node = ctx.f(pk_seed, bytes(adrs.data), sk)
     33     else:
     34         lnode = fors_node(sk_seed, 2*i  , z-1, pk_seed, adrs, ctx)
     35         rnode = fors_node(sk_seed, 2*i+1, z-1, pk_seed, adrs, ctx)
     36         adrs.set_tree_height(z)
     37         adrs.set_tree_index(i)
     38         node = ctx.hf(pk_seed, bytes(adrs.data), lnode + rnode)
     39 
     40     return node
     41 
     42 def fors_sign(md: bytes, sk_seed: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     43     """Generate a FORS signature"""
     44     assert len(md) == cdiv(ctx.k * ctx.a, 8)
     45 
     46     sig_fors = bytes() # initialize sig_fors as a zero-length byte string
     47     indices = base_2b(md, ctx.a, ctx.k)
     48     for i in range(ctx.k):
     49         sig_fors += fors_skgen(sk_seed, pk_seed, adrs, i * 2**ctx.a + indices[i], ctx)
     50 
     51         auth = bytes()
     52         for j in range(ctx.a):
     53             s = indices[i]//(2**j) ^ 0b01
     54             auth += fors_node(sk_seed, i*2**(ctx.a-j) + s, j, pk_seed, adrs, ctx)
     55         sig_fors += auth
     56     return sig_fors
     57 
     58 def fors_pk_from_sig(sig_fors: bytes, md: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
     59     """Compute a FORS public key from a FORS signature"""
     60     assert len(sig_fors) == ctx.k * (ctx.a + 1) * ctx.n
     61     assert len(md) == cdiv(ctx.k * ctx.a, 8)
     62 
     63     indices = base_2b(md, ctx.a, ctx.k)
     64     node =[b"" ,b""]
     65     root = bytes()
     66     for i in range(ctx.k):
     67         sk = sig_fors[i*(ctx.a+1)*ctx.n : (i*(ctx.a+1)+1)*ctx.n] # sk <- sig_fors.getSK(i)
     68         adrs.set_tree_height(0)
     69         adrs.set_tree_index(i * 2**ctx.a + indices[i])
     70         node[0] = ctx.f(pk_seed, bytes(adrs.data), sk)
     71 
     72         auth = sig_fors[(i*(ctx.a+1)+1)*ctx.n : (i+1)*(ctx.a+ctx.n)*ctx.n] # auth <- sig_fors.getAUTH(i)
     73         for j in range(ctx.a):
     74             adrs.set_tree_height(j+1)
     75             if (indices[i]//(2**j)) % 2 == 0:
     76                 adrs.set_tree_index(adrs.get_tree_index()//2)
     77                 node[1] = ctx.hf(pk_seed, bytes(adrs.data), node[0] + auth[j*ctx.n:(j+1)*ctx.n])
     78             else:
     79                 adrs.set_tree_index((adrs.get_tree_index()-1)//2)
     80                 node[1] = ctx.hf(pk_seed, bytes(adrs.data), auth[j*ctx.n:(j+1)*ctx.n] + node[0])
     81             node[0] = node[1]
     82         root += node[0]
     83     forspkadrs = adrs.copy()
     84     forspkadrs.set_type_and_clear(AddressType.FORS_ROOTS)
     85     forspkadrs.set_key_pair_address(adrs.get_key_pair_address())
     86     pk = ctx.t(ctx.k, pk_seed, bytes(forspkadrs.data), root)
     87     return pk