gestumblinde

Gestumblinde - reference implementation of SLH-DSA
git clone git://www.tkruger.se/gestumblinde.git
Log | Files | Refs | README

test-sha2-192f.py (14185B)


"""
Unit tests for the SLH-DSA-SHA2-192f parameter set.

Covers the Address helper, WOTS+, XMSS, the hypertree, FORS and the
top-level SLH-DSA keygen/sign/verify functions. Most tests compare results
against JSON test-vector files opened via paths relative to the current
working directory ("../slh-dsa-sha2-192f-test-vectors.json" and
"../slh-dsa-sha2-192f-ref-vectors.json").
"""
      4 
      5 import unittest
      6 import json
      7 from slh_dsa import SLHDSA
      8 from address import Address, AddressType
      9 from wotsp import wotsp_pkgen, wotsp_sign, wotsp_pk_from_sig
     10 from xmss import xmss_node, xmss_sign, xmss_pk_from_sig
     11 from fors import fors_skgen, fors_node, fors_sign, fors_pk_from_sig
     12 from slh import slh_keygen, slh_sign, slh_verify
     13 from ht import ht_sign, ht_verify
     14 from utils import toByte, cdiv
     15 
     16 def read_json_test_vectors(filename):
     17     """Read test vector from JSON file"""
     18     with open(filename, "r") as infile:
     19         json_object = infile.read()
     20     r = json.loads(json_object)
     21     # convert to bytes/tuples of bytes
     22     converted = dict()
     23     for k in r.keys():
     24         if type(r[k]) == list:
     25             if len(r[k]) in [2,4] and (("SLH SK" in k) or ("SLH PK" in k)):
     26                 converted[k] = tuple([bytes(x) for x in r[k]])
     27             else:
     28                 converted[k] = bytes(r[k])
     29         else:
     30             converted[k] = r[k]
     31     return converted
     32 
     33 class TestSHA2_192F_Address(unittest.TestCase):
     34     """Unit tests for Address class"""
     35 
     36     def test_init(self):
     37         """Test Address initialisation"""
     38         a = Address()
     39         self.assertEqual(len(a.data), 32)
     40         for i in range(len(a.data)):
     41             self.assertEqual(a.data[i], 0)
     42         
     43 class TestSHA2_192F_WotsPlus(unittest.TestCase):
     44     """Unit tests for WOTS+ stuff"""
     45 
     46     def test_pkgen(self):
     47         """Testing WOTS+ public key generation"""
     48         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
     49         ctx = SLHDSA("SLH-DSA-SHA2-192f")
     50         pfx = "SLH-DSA-SHA2-192f WOTS+ "
     51         adrs_bytes   = tv[pfx+"ADDRESS"]
     52         sk_seed      = tv[pfx+"SK_SEED"]
     53         pk_seed      = tv[pfx+"PK_SEED"]
     54         pk_cmp       = tv[pfx+"PUBLIC_KEY"]
     55         adrs = Address()
     56         adrs.data = bytearray(adrs_bytes)
     57         pk = wotsp_pkgen(sk_seed, pk_seed, adrs, ctx)
     58         self.assertEqual(len(pk), ctx.n)
     59         self.assertEqual(pk, pk_cmp)
     60 
     61     def test_sign(self):
     62         """Testing signature."""
     63         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
     64         ctx = SLHDSA("SLH-DSA-SHA2-192f")
     65         pfx = "SLH-DSA-SHA2-192f WOTS+ "
     66         adrs_bytes   = tv[pfx+"ADDRESS"]
     67         sk_seed      = tv[pfx+"SK_SEED"]
     68         pk_seed      = tv[pfx+"PK_SEED"]
     69         msg          = tv[pfx+"MSG"]
     70         sig_ref      = tv[pfx+"SIGNATURE"]
     71         adrs = Address()
     72         adrs.data = bytearray(adrs_bytes)
     73         sig = wotsp_sign(msg, sk_seed, pk_seed, adrs, ctx)
     74         self.assertEqual(sig, sig_ref)
     75 
     76     def test_verify(self):
     77         """Testing verification of own signatures"""
     78         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
     79         ctx = SLHDSA("SLH-DSA-SHA2-192f")
     80         pfx = "SLH-DSA-SHA2-192f WOTS+ "
     81         adrs_bytes   = tv[pfx+"ADDRESS"]
     82         sk_seed      = tv[pfx+"SK_SEED"]
     83         pk_seed      = tv[pfx+"PK_SEED"]
     84         pk           = tv[pfx+"PUBLIC_KEY"]
     85         msg          = tv[pfx+"MSG"]
     86         sig          = tv[pfx+"SIGNATURE"]
     87         adrs = Address()
     88         adrs.data = bytearray(adrs_bytes)
     89         # verify the signature; good (msg, sig)
     90         pk_from_sig = wotsp_pk_from_sig(sig, msg, pk_seed, adrs, ctx)
     91         self.assertEqual(pk, pk_from_sig)
     92         # fail verify with bitflipped sig
     93         sigp = bytearray(sig) # copy signature
     94         sigp[3] ^= 0x10 # flip a bit
     95         pk_from_sigp = wotsp_pk_from_sig(bytes(sigp), msg, pk_seed, adrs, ctx)
     96         self.assertNotEqual(pk, pk_from_sigp)
     97         # fail verify with bitflipped msg
     98         msgp = bytearray(msg) # copy message
     99         msgp[6] ^= 0x04 # flip a bit
    100         pk_from_sigm = wotsp_pk_from_sig(sig, bytes(msgp), pk_seed, adrs, ctx)
    101         self.assertNotEqual(pk, pk_from_sigm)
    102 
    103 class TestSHA2_192F_XMSS(unittest.TestCase):
    104     """Unit tests for XMSS stuff"""
    105 
    106     def test_node(self):
    107         """Testing XMSS node function"""
    108         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    109         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    110         pfx = "SLH-DSA-SHA2-192f XMSS "
    111         adrs_bytes   = tv[pfx+"ADDRESS"]
    112         sk_seed      = tv[pfx+"SK_SEED"]
    113         pk_seed      = tv[pfx+"PK_SEED"]
    114         node_ref     = tv[pfx+"NODE"]
    115         i            = tv[pfx+"NODEI"]
    116         z            = tv[pfx+"NODEZ"]
    117         adrs = Address()
    118         adrs.data = bytearray(adrs_bytes)
    119         node = xmss_node(sk_seed, i, z, pk_seed, adrs, ctx)
    120         self.assertEqual(node, node_ref)
    121 
    122     def test_sign(self):
    123         """Testing XMSS signing"""
    124         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    125         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    126         pfx = "SLH-DSA-SHA2-192f XMSS "
    127         adrs_bytes   = tv[pfx+"ADDRESS"]
    128         sk_seed      = tv[pfx+"SK_SEED"]
    129         pk_seed      = tv[pfx+"PK_SEED"]
    130         msg          = tv[pfx+"MSG"]
    131         idx          = tv[pfx+"SIGNIDX"]
    132         sig_ref      = tv[pfx+"SIGNATURE"]
    133         adrs = Address()
    134         adrs.data = bytearray(adrs_bytes)
    135         sig = xmss_sign(msg, sk_seed, idx, pk_seed, adrs, ctx)
    136         self.assertEqual(sig, sig_ref)
    137 
    138     def test_verify(self):
    139         """Self-testing XMSS signature verification"""
    140         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    141         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    142         pfx = "SLH-DSA-SHA2-192f XMSS "
    143         adrs_bytes   = tv[pfx+"ADDRESS"]
    144         sk_seed      = tv[pfx+"SK_SEED"]
    145         pk_seed      = tv[pfx+"PK_SEED"]
    146         msg          = tv[pfx+"MSG"]
    147         idx          = tv[pfx+"SIGNIDX"]
    148         sig          = tv[pfx+"SIGNATURE"]
    149         adrs = Address()
    150         adrs.data = bytearray(adrs_bytes)
    151         # public key is node at level h'
    152         pk = xmss_node(sk_seed, 0, ctx.hp, pk_seed, adrs, ctx)
    153         # verify the signature; good (idx, msg, sig)
    154         pk_from_sig = xmss_pk_from_sig(idx, sig, msg, pk_seed, adrs, ctx)
    155         self.assertEqual(pk, pk_from_sig)
    156         # fail verify with bitflipped sig
    157         sigp = bytearray(sig) # copy signature
    158         sigp[3] ^= 0x10 # flip a bit
    159         pk_from_sigp = xmss_pk_from_sig(idx, bytes(sigp), msg, pk_seed, adrs, ctx)
    160         self.assertNotEqual(pk, pk_from_sigp)
    161         # fail verify with bitflipped msg
    162         msgp = bytearray(msg) # copy message
    163         msgp[6] ^= 0x04 # flip a bit
    164         pk_from_sigm = xmss_pk_from_sig(idx, sig, bytes(msgp), pk_seed, adrs, ctx)
    165         self.assertNotEqual(pk, pk_from_sigm)
    166         # fail verify with wrong idx
    167         idxp = 2
    168         pk_from_sigi = xmss_pk_from_sig(idxp, sig, msg, pk_seed, adrs, ctx)
    169         self.assertNotEqual(pk, pk_from_sigi)
    170 
    171         # testing with address of pk_root (as in ht_verify)
    172         adrs = Address(toByte(0, 32)) # adrs <- toByte(0,32)
    173         adrs.set_layer_address(ctx.d - 1)
    174         pk_root = xmss_node(sk_seed, 0, ctx.hp, pk_seed, adrs, ctx)
    175         sig2 = xmss_sign(msg, sk_seed, idx, pk_seed, adrs, ctx)
    176         pk_from_sig = xmss_pk_from_sig(idx, sig2, msg, pk_seed, adrs, ctx)
    177         self.assertEqual(pk_from_sig, pk_root)
    178 
    179 class TestSHA2_192F_HT(unittest.TestCase):
    180     """Unit test for hypertree stuff"""
    181 
    182     def test_sign(self):
    183         """Hypertree signature test"""
    184         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    185         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    186         pfx = "SLH-DSA-SHA2-192f HT "
    187         sk_seed      = tv[pfx+"SK_SEED"]
    188         pk_seed      = tv[pfx+"PK_SEED"]
    189         msg          = tv[pfx+"MSG"]
    190         sig_ref      = tv[pfx+"SIGNATURE"]
    191         idx_tree     = tv[pfx+"IDX_TREE"]
    192         idx_leaf     = tv[pfx+"IDX_LEAF"]
    193 
    194         sig = ht_sign(msg, sk_seed, pk_seed, idx_tree, idx_leaf, ctx)
    195 
    196         self.assertEqual(sig, sig_ref)
    197 
    198     def test_verify(self):
    199         """Self-testing hypertree signature verification"""
    200         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    201         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    202         pfx = "SLH-DSA-SHA2-192f HT "
    203         sk_seed      = tv[pfx+"SK_SEED"]
    204         pk_seed      = tv[pfx+"PK_SEED"]
    205         msg          = tv[pfx+"MSG"]
    206         sig          = tv[pfx+"SIGNATURE"]
    207         idx_tree     = tv[pfx+"IDX_TREE"]
    208         idx_leaf     = tv[pfx+"IDX_LEAF"]
    209 
    210         # Calculate PK.root as in slh_keygen()
    211         adrs = Address(toByte(0, 32)) # adrs <- toByte(0,32)
    212         adrs.set_layer_address(ctx.d - 1)
    213         pk_root = xmss_node(sk_seed, 0, ctx.hp, pk_seed, adrs, ctx)
    214 
    215         self.assertTrue(ht_verify(msg, sig, pk_seed, idx_tree, idx_leaf, pk_root, ctx))
    216 
    217 class TestSHA2_192F_FORS(unittest.TestCase):
    218     """Unit test for FORS stuff"""
    219 
    220     def test_skgen(self):
    221         """Testing FORS secret key generation function"""
    222         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    223         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    224         pfx = "SLH-DSA-SHA2-192f FORS "
    225         sk_seed      = tv[pfx+"SK_SEED"]
    226         pk_seed      = tv[pfx+"PK_SEED"]
    227         fors_sk      = tv[pfx+"SK"]
    228         idx          = tv[pfx+"IDX"]
    229         adrs_bytes   = tv[pfx+"ADDRESS"]
    230         adrs = Address()
    231         adrs.data = bytearray(adrs_bytes)
    232 
    233         sk = fors_skgen(sk_seed, pk_seed, adrs, idx, ctx)
    234 
    235         self.assertTrue(sk, fors_sk)
    236 
    237     def test_node(self):
    238         """Testing FORS node function"""
    239         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    240         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    241         pfx = "SLH-DSA-SHA2-192f FORS "
    242         sk_seed      = tv[pfx+"SK_SEED"]
    243         pk_seed      = tv[pfx+"PK_SEED"]
    244         adrs_bytes   = tv[pfx+"ADDRESS"]
    245         i            = tv[pfx+"NODEI"]
    246         z            = tv[pfx+"NODEZ"]
    247         node_ref     = tv[pfx+"NODE"]
    248         adrs = Address()
    249         adrs.data = bytearray(adrs_bytes)
    250 
    251         node = fors_node(sk_seed, i, z, pk_seed, adrs, ctx)
    252 
    253         self.assertEqual(node, node_ref)
    254 
    255     def test_sign(self):
    256         """Testing FORS sign function"""
    257         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    258         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    259         pfx = "SLH-DSA-SHA2-192f FORS "
    260         sk_seed      = tv[pfx+"SK_SEED"]
    261         pk_seed      = tv[pfx+"PK_SEED"]
    262         adrs_bytes   = tv[pfx+"ADDRESS"]
    263         sig_ref      = tv[pfx+"SIGNATURE"]
    264         md           = tv[pfx+"MD"]
    265         adrs = Address()
    266         adrs.data = bytearray(adrs_bytes)
    267 
    268         sig = fors_sign(md, sk_seed, pk_seed, adrs, ctx)
    269 
    270         self.assertEqual(sig, sig_ref)
    271 
    272     def test_pk_from_sig(self):
    273         """Self-testing FORS verification (pk from sig)"""
    274         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    275         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    276         pfx = "SLH-DSA-SHA2-192f FORS "
    277         sk_seed      = tv[pfx+"SK_SEED"]
    278         pk_seed      = tv[pfx+"PK_SEED"]
    279         adrs_bytes   = tv[pfx+"ADDRESS"]
    280         sig          = tv[pfx+"SIGNATURE"]
    281         node         = tv[pfx+"NODE"]
    282         md           = tv[pfx+"MD"]
    283         adrs = Address()
    284         adrs.data = bytearray(adrs_bytes)
    285 
    286         pk = fors_pk_from_sig(sig, md, pk_seed, adrs, ctx)
    287 
    288         # generate the pk from roots from sk_seed
    289         roots = b""
    290         for j in range(ctx.k):
    291             node = fors_node(sk_seed, j, ctx.a, pk_seed, adrs, ctx)
    292             roots += node
    293         forspkadrs = adrs.copy()
    294         forspkadrs.set_type_and_clear(AddressType.FORS_ROOTS)
    295         forspkadrs.set_key_pair_address(adrs.get_key_pair_address())
    296         pk_ref = ctx.t(ctx.k, pk_seed, bytes(forspkadrs.data), roots)
    297 
    298         self.assertEqual(pk, pk_ref)
    299 
    300 class TestSHA2_192F_SLH(unittest.TestCase):
    301     """Unit test for top-level SLH-DSA function"""
    302 
    303     def test_keygen(self):
    304         """Sanity checks of SLH-DSA keygen"""
    305         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    306         key = slh_keygen(ctx)
    307         self.assertEqual(type(key), tuple)
    308         self.assertEqual(len(key),    2)
    309         self.assertEqual(type(key[0]), tuple)
    310         self.assertEqual(len(key[0]), 4)
    311         self.assertEqual(type(key[1]), tuple)
    312         self.assertEqual(len(key[1]), 2)
    313         self.assertEqual(len(key[0][0]), ctx.n) 
    314         self.assertEqual(len(key[0][1]), ctx.n) 
    315         self.assertEqual(len(key[0][2]), ctx.n) 
    316         self.assertEqual(len(key[0][3]), ctx.n) 
    317         self.assertEqual(len(key[1][0]), ctx.n) 
    318         self.assertEqual(len(key[1][1]), ctx.n) 
    319 
    320     def test_sign(self):
    321         """Test signature generation"""
    322         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    323         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    324         pfx = "SLH-DSA-SHA2-192f SLH "
    325         sk           = tv[pfx+"SK"]
    326         sig_ref      = tv[pfx+"SIGNATURE"]
    327         msg          = tv[pfx+"MSG"]
    328 
    329         sig = slh_sign(msg, sk, ctx)
    330         
    331         # sanity
    332         self.assertEqual(len(sig), ctx.sig_bytes)
    333         self.assertEqual(len(sig), len(sig_ref))
    334 
    335         self.assertEqual(sig, sig_ref)
    336 
    337     def test_verify(self):
    338         """Self-testing signature verification"""
    339         tv = read_json_test_vectors("../slh-dsa-sha2-192f-test-vectors.json")
    340         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    341         pfx = "SLH-DSA-SHA2-192f SLH "
    342         pk           = tv[pfx+"PK"]
    343         sig          = tv[pfx+"SIGNATURE"]
    344         msg          = tv[pfx+"MSG"]
    345 
    346         # verify original (msg,sig) pair
    347         self.assertTrue(slh_verify(msg, sig, pk, ctx))
    348 
    349         # flip a bit and the verification should fail
    350         msg2 = bytes([msg[0] ^ 0x10]) + msg[1:]
    351         self.assertFalse(slh_verify(msg2, sig, pk, ctx))
    352 
    353 class TestSHA2_192F_SLHREF(unittest.TestCase):
    354     """Unit test for SLH-DSA wrp hacked SPHINCS+ ref impl"""
    355 
    356     def test_sign(self):
    357         """Test signature generation"""
    358         tv = read_json_test_vectors("../slh-dsa-sha2-192f-ref-vectors.json")
    359         ctx = SLHDSA("SLH-DSA-SHA2-192f")
    360         pfx = "SLH-DSA-SHA2-192f SLH "
    361         sk           = tv[pfx+"SK"]
    362         sig_ref      = tv[pfx+"SIGNATURE"]
    363         msg          = tv[pfx+"MSG"]
    364 
    365         sig = slh_sign(msg, sk, ctx)
    366         
    367         # sanity
    368         self.assertEqual(len(sig), ctx.sig_bytes)
    369         self.assertEqual(len(sig), len(sig_ref))
    370 
    371         self.assertEqual(sig, sig_ref)
    372 
    373 
    374 if __name__ == "__main__":
    375     unittest.main()