wotsp.py (3102B)
"""
WOTS+ one-time signature routines.

From Section 5.
"""

from address import Address, AddressType
from slh_dsa import SLHDSA
from utils import cdiv, base_2b, toByte


def wotsp_chain(X: bytes, i: int, s: int, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Chaining function used in WOTS+.

    Applies ``ctx.f`` to ``X`` once for every index ``j`` in
    ``range(i, i + s)``, writing ``j`` into the hash-address field of
    ``adrs`` before each application.

    Args:
        X: Chain input; must be exactly ``ctx.n`` bytes long.
        i: Index of the first step to apply.
        s: Number of steps to apply.
        pk_seed: Public seed forwarded to ``ctx.f``.
        adrs: Address object; its hash address is overwritten in place.
        ctx: Parameter set / hash-function context.

    Returns:
        The value after ``s`` applications of ``ctx.f`` (``X`` itself when
        ``s`` is 0), or ``b""`` (standing in for NULL) when
        ``i + s >= ctx.wotsp_w``.
    """
    assert len(X) == ctx.n
    if (i + s) >= ctx.wotsp_w:
        return b""  # NULL

    tmp = X
    for j in range(i, i + s):
        adrs.set_hash_address(j)
        tmp = ctx.f(pk_seed, bytes(adrs.data), tmp)
    return tmp


def _wotsp_msg_with_checksum(M: bytes, ctx: SLHDSA):
    """Convert ``M`` to base-w digits and append the checksum digits.

    Shared by signing and public-key recovery so that both derive the
    chain lengths in exactly the same way.

    Returns:
        The ``ctx.wotsp_len1`` digits produced by ``base_2b`` for ``M``,
        extended with ``ctx.wotsp_len2`` digits encoding the checksum.
    """
    msg = base_2b(M, ctx.lg_w, ctx.wotsp_len1)

    # The checksum is the sum of each message digit's distance from w - 1.
    csum = sum(ctx.wotsp_w - 1 - msg[k] for k in range(ctx.wotsp_len1))

    # When len2 * lg_w is not a multiple of 8, shift left by the missing
    # number of bits so the checksum fills the byte string handed to base_2b.
    checksum_bits = ctx.wotsp_len2 * ctx.lg_w
    csum <<= (8 - (checksum_bits % 8)) % 8
    msg += base_2b(toByte(csum, cdiv(checksum_bits, 8)), ctx.lg_w, ctx.wotsp_len2)
    return msg


def wotsp_pkgen(sk_seed: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Generate a WOTS+ public key.

    Derives one secret value per chain with ``ctx.prf``, walks each chain
    ``ctx.wotsp_w - 1`` steps with :func:`wotsp_chain`, and compresses the
    concatenated chain ends with ``ctx.t``.

    Args:
        sk_seed: Secret seed forwarded to ``ctx.prf``.
        pk_seed: Public seed forwarded to ``ctx.prf``, ``ctx.f`` and ``ctx.t``.
        adrs: Address of the key pair; its chain address (and, through
            :func:`wotsp_chain`, its hash address) is modified in place.
        ctx: Parameter set / hash-function context.

    Returns:
        The output of ``ctx.t`` over all chain end values.
    """
    sk_adrs = adrs.copy()
    sk_adrs.set_type_and_clear(AddressType.WOTS_PRF)
    sk_adrs.set_key_pair_address(adrs.get_key_pair_address())

    chain_ends = []
    for i in range(ctx.wotsp_len):
        sk_adrs.set_chain_address(i)
        # Pass the address as bytes, exactly as wotsp_sign does, so both
        # functions feed ctx.prf the same kind of object.
        sk = ctx.prf(pk_seed, sk_seed, bytes(sk_adrs.data))
        adrs.set_chain_address(i)
        chain_ends.append(wotsp_chain(sk, 0, ctx.wotsp_w - 1, pk_seed, adrs, ctx))
    tmp = b"".join(chain_ends)

    wotspk_adrs = adrs.copy()
    wotspk_adrs.set_type_and_clear(AddressType.WOTS_PK)
    wotspk_adrs.set_key_pair_address(adrs.get_key_pair_address())
    pk = ctx.t(ctx.wotsp_len, pk_seed, bytes(wotspk_adrs.data), tmp)
    return pk


def wotsp_sign(M: bytes, sk_seed: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Generate a WOTS+ signature on an n-byte message.

    Args:
        M: Message to sign.
        sk_seed: Secret seed forwarded to ``ctx.prf``.
        pk_seed: Public seed forwarded to ``ctx.prf`` and ``ctx.f``.
        adrs: Address of the key pair; its chain address (and, through
            :func:`wotsp_chain`, its hash address) is modified in place.
        ctx: Parameter set / hash-function context.

    Returns:
        The concatenation of ``ctx.wotsp_len`` chain values, where chain
        ``i`` has been advanced ``msg[i]`` steps from its secret value.
    """
    msg = _wotsp_msg_with_checksum(M, ctx)

    sk_adrs = adrs.copy()
    sk_adrs.set_type_and_clear(AddressType.WOTS_PRF)
    sk_adrs.set_key_pair_address(adrs.get_key_pair_address())

    sig_parts = []
    for i in range(ctx.wotsp_len):
        sk_adrs.set_chain_address(i)
        sk = ctx.prf(pk_seed, sk_seed, bytes(sk_adrs.data))
        adrs.set_chain_address(i)
        sig_parts.append(wotsp_chain(sk, 0, msg[i], pk_seed, adrs, ctx))
    return b"".join(sig_parts)


def wotsp_pk_from_sig(sig: bytes, M: bytes, pk_seed: bytes, adrs: Address, ctx: SLHDSA) -> bytes:
    """Compute a WOTS+ public key from a message and its signature.

    Each ``ctx.n``-byte slice of ``sig`` is advanced the remaining
    ``ctx.wotsp_w - 1 - msg[i]`` steps of its chain, and the resulting chain
    ends are compressed with ``ctx.t``.

    Args:
        sig: Signature; slice ``i`` is ``sig[ctx.n * i: ctx.n * (i + 1)]``.
        M: Message the signature is claimed to cover.
        pk_seed: Public seed forwarded to ``ctx.f`` and ``ctx.t``.
        adrs: Address of the key pair; its chain address (and, through
            :func:`wotsp_chain`, its hash address) is modified in place.
        ctx: Parameter set / hash-function context.

    Returns:
        The candidate public key produced by ``ctx.t``.
    """
    msg = _wotsp_msg_with_checksum(M, ctx)

    chain_ends = []
    for i in range(ctx.wotsp_len):
        adrs.set_chain_address(i)
        next_chain = wotsp_chain(
            sig[ctx.n * i: ctx.n * (i + 1)],
            msg[i],
            ctx.wotsp_w - 1 - msg[i],
            pk_seed,
            adrs,
            ctx,
        )
        # Sanity check only: start + steps is always w - 1 here, so
        # wotsp_chain never takes its NULL (b"") branch.
        assert next_chain is not None
        chain_ends.append(next_chain)
    tmp = b"".join(chain_ends)

    wotspk_adrs = adrs.copy()
    wotspk_adrs.set_type_and_clear(AddressType.WOTS_PK)
    wotspk_adrs.set_key_pair_address(adrs.get_key_pair_address())
    pk_sig = ctx.t(ctx.wotsp_len, pk_seed, bytes(wotspk_adrs.data), tmp)
    return pk_sig