gestumblinde

Gestumblinde - reference implementation of SLH-DSA
git clone git://www.tkruger.se/gestumblinde.git
Log | Files | Refs | README

address.py (7999B)


      1 """
      2 Address class.
      3 """
      4 
      5 from enum import Enum
      6 import copy
      7 from utils import toByte, toInt
      8 
      9 # unmagic numbers, note that end indices are the
     10 # last entry index + 1
     11 LAYER_ADDRESS_START=0
     12 LAYER_ADDRESS_END=4
     13 TREE_ADDRESS_START=4
     14 TREE_ADDRESS_END=16
     15 TYPE_START=16
     16 TYPE_END=20
     17 HASH_ADDRESS_START=28
     18 HASH_ADDRESS_END=32
     19 ADDRESS_DATA_LEN=32
     20 KEY_PAIR_ADDRESS_START=20
     21 KEY_PAIR_ADDRESS_END=24
     22 CHAIN_ADDRESS_START=24
     23 CHAIN_ADDRESS_END=28
     24 TREE_HEIGHT_START=24
     25 TREE_HEIGHT_END=28
     26 TREE_INDEX_START=28
     27 TREE_INDEX_END=32
     28 
     29 class AddressType(Enum):
     30     """Enum for types of address"""
     31     WOTS_HASH = 0
     32     WOTS_PK = 1
     33     TREE = 2
     34     FORS_TREE = 3
     35     FORS_ROOTS = 4
     36     WOTS_PRF = 5
     37     FORS_PRF = 6
     38 
     39     def __int__(self):
     40         return self.value
     41     
     42 class Address:
     43     """Stores an SLH-DSA address (same class for all types)"""
     44     def __init__(self, b=None):
     45         """Initialize an address"""
     46         if b == None:
     47             self.data = bytearray([0]*32)
     48         else:
     49             self.data = bytearray(b)
     50 
     51     def copy(self):
     52         """Make a (deep) copy of an address"""
     53         cself = copy.deepcopy(self)
     54         return cself
     55 
     56     def get_layer_address_raw(self) -> bytearray:
     57         """Get the layer address as raw bytes"""
     58         return self.data[:4]
     59 
     60     def set_layer_address_raw(self, layer_address: bytes):
     61         """Set the layer address from raw bytes"""
     62         assert len(layer_address) == 4
     63         self.data[LAYER_ADDRESS_START:LAYER_ADDRESS_END] = layer_address
     64 
     65     def get_layer_address(self) -> int:
     66         """Get the layer address as an integer"""
     67         return toInt(self.get_layer_address_raw(), 4)
     68 
     69     def set_layer_address(self, layer_address: int):
     70         """Set the layer address from an integer"""
     71         assert layer_address < 2**32
     72         self.set_layer_address_raw(toByte(layer_address, 4))
     73 
     74     def get_tree_address_raw(self) -> bytearray:
     75         """Get the tree address as raw bytes"""
     76         return self.data[TREE_ADDRESS_START:TREE_ADDRESS_END]
     77 
     78     def set_tree_address_raw(self, tree_address: bytes):
     79         """Set the tree address from raw bytes"""
     80         assert len(tree_address) == 12
     81         self.data[TREE_ADDRESS_START:TREE_ADDRESS_END] = tree_address
     82 
     83     def get_tree_address(self) -> int:
     84         """Get the tree address as integer"""
     85         return toInt(self.get_tree_address_raw(), 12)
     86 
     87     def set_tree_address(self, tree_address: int):
     88         """Set the tree address from raw bytes"""
     89         assert tree_address < 2**(8*12)
     90         self.set_tree_address_raw(toByte(tree_address, 12))
     91 
     92     def get_type(self) -> bytearray:
     93         """Get type of address"""
     94         return self.data[TYPE_START:TYPE_END]
     95 
     96     def set_type(self, type_: bytes):
     97         """Set the type of the address"""
     98         assert len(type_) == 4
     99         self.data[TYPE_START:TYPE_END] = type_
    100 
    101     def get_typeaddress(self) -> AddressType:
    102         """Get the address as a AddressType enum"""
    103         return AddressType(toInt(self.get_type(), 4))
    104 
    105     def __set_typeaddress(self, new_type: AddressType):
    106         self.set_type(toByte(int(new_type), 4))
    107 
    108     def get_hash_address(self) -> bytearray:
    109         """Get the hash address of WOTS_HASH"""
    110         assert self.get_typeaddress() == AddressType.WOTS_HASH
    111         return self.data[HASH_ADDRESS_START:HASH_ADDRESS_END]
    112 
    113     def set_hash_address_raw(self, value: bytes):
    114         """Set the hash address from raw bytes"""
    115         assert self.get_typeaddress() == AddressType.WOTS_HASH
    116         assert len(value) == 4
    117         self.data[HASH_ADDRESS_START:HASH_ADDRESS_END] = value
    118 
    119     def set_hash_address(self, value: int):
    120         """Set the hash address from an int"""
    121         assert self.get_typeaddress() == AddressType.WOTS_HASH
    122         assert value >= 0 # valid uint32
    123         assert value < 2**32 # valid uint32
    124         self.set_hash_address_raw(toByte(value, 4))
    125 
    126     def set_type_and_clear(self, new_type: AddressType):
    127         """Set the stype and clear everything following it in the address"""
    128         # set new address type:
    129         self.__set_typeaddress(new_type)
    130         # clear everything after type:
    131         self.data[TYPE_END:ADDRESS_DATA_LEN] = b"\x00"*(ADDRESS_DATA_LEN - TYPE_END)
    132 
    133     def has_key_pair_address(self) -> bool:
    134         """Check if type is one of the types that has a key pair address"""
    135         return self.get_typeaddress() in \
    136             [AddressType.WOTS_HASH,
    137              AddressType.WOTS_PK,
    138              AddressType.FORS_TREE,
    139              AddressType.FORS_ROOTS,
    140              AddressType.WOTS_PRF,
    141              AddressType.FORS_PRF]
    142 
    143     def get_key_pair_address(self) -> bytearray:
    144         """Get the key pair address"""
    145         assert self.has_key_pair_address()
    146         return self.data[KEY_PAIR_ADDRESS_START:KEY_PAIR_ADDRESS_END]
    147 
    148     def set_key_pair_address(self, new_kpa: bytes):
    149         """Set the keypair adress to `new_kpa`"""
    150         assert self.has_key_pair_address()
    151         assert len(new_kpa) == 4
    152         self.data[KEY_PAIR_ADDRESS_START:KEY_PAIR_ADDRESS_END] = new_kpa
    153 
    154     def has_chain_address(self) -> bool:
    155         """Check if type is one of the types that has a chain address"""
    156         return self.get_typeaddress() in \
    157             [AddressType.WOTS_HASH,
    158              AddressType.WOTS_PRF]
    159 
    160     def get_chain_address_raw(self) -> bytearray:
    161         """Get the chain address as raw bytes"""
    162         assert self.has_chain_address()
    163         return self.data[CHAIN_ADDRESS_START:CHAIN_ADDRESS_END]
    164 
    165     def set_chain_address_raw(self, new_ca: bytes):
    166         """Set the chain address from raw bytes"""
    167         assert self.has_chain_address()
    168         assert len(new_ca) == 4
    169         self.data[CHAIN_ADDRESS_START:CHAIN_ADDRESS_END] = new_ca
    170 
    171     def get_chain_address(self) -> int:
    172         """Get the chain address as an int"""
    173         return toInt(self.get_chain_address_raw(), 4)
    174 
    175     def set_chain_address(self, new_ca: int):
    176         """Set the chain address from an int"""
    177         assert new_ca < 2**32
    178         self.set_chain_address_raw(toByte(new_ca, 4))
    179 
    180     def has_tree_height(self) -> bool:
    181         """Check if type is one of the types that has a tree height"""
    182         return self.get_typeaddress() in \
    183             [AddressType.TREE,
    184              AddressType.FORS_TREE,
    185              AddressType.FORS_PRF]
    186 
    187     def get_tree_height_raw(self) -> bytearray:
    188         """Get the tree height as raw bytes"""
    189         assert self.has_tree_height()
    190         return self.data[TREE_HEIGHT_START:TREE_HEIGHT_END]
    191 
    192     def get_tree_height(self) -> int:
    193         """Get the tree height as an integer"""
    194         return toInt(self.get_tree_height_raw(), 4)
    195 
    196     def set_tree_height_raw(self, new_height: bytes):
    197         """Set the tree height from raw bytes"""
    198         assert self.has_tree_height()
    199         assert len(new_height) == 4
    200         self.data[TREE_HEIGHT_START:TREE_HEIGHT_END] = new_height
    201 
    202     def set_tree_height(self, new_height: int):
    203         """Set the tree height from an integer"""
    204         assert new_height < 2**32
    205         self.set_tree_height_raw(toByte(new_height, 4))
    206 
    207     def has_tree_index(self):
    208         """Check if type is one of the types that has a tree index"""
    209         return self.get_typeaddress() in \
    210             [AddressType.TREE,
    211              AddressType.FORS_TREE,
    212              AddressType.FORS_PRF]
    213 
    214     def get_tree_index_raw(self) -> bytes:
    215         """Get the tree index as raw bytes"""
    216         assert self.has_tree_index()
    217         return self.data[TREE_INDEX_START:TREE_INDEX_END]
    218 
    219     def get_tree_index(self) -> int:
    220         """Get the tree index as an integer"""
    221         return toInt(self.get_tree_index_raw(), 4)
    222 
    223     def set_tree_index_raw(self, new_ti: bytes):
    224         """Set the tree index from raw bytes"""
    225         assert self.has_tree_index()
    226         assert len(new_ti) == 4
    227         self.data[TREE_INDEX_START:TREE_INDEX_END] = new_ti
    228 
    229     def set_tree_index(self, new_ti: int):
    230         """Set the tree index from an integer"""
    231         assert new_ti < 2**32
    232         self.set_tree_index_raw(toByte(new_ti, 4))