address.py (7999B)
"""
Address class.

Defines the 32-byte SLH-DSA address (ADRS) container and the enum of
address types stored in its type field.
"""

from enum import Enum
import copy
from utils import toByte, toInt

# unmagic numbers, note that end indices are the
# last entry index + 1
#
# Byte layout of an address (half-open ranges):
#   [ 0,  4)  layer address
#   [ 4, 16)  tree address
#   [16, 20)  type
#   [20, 24)  key pair address
#   [24, 28)  chain address  OR tree height  (which one depends on the type)
#   [28, 32)  hash address   OR tree index   (which one depends on the type)
LAYER_ADDRESS_START = 0
LAYER_ADDRESS_END = 4
TREE_ADDRESS_START = 4
TREE_ADDRESS_END = 16
TYPE_START = 16
TYPE_END = 20
KEY_PAIR_ADDRESS_START = 20
KEY_PAIR_ADDRESS_END = 24
CHAIN_ADDRESS_START = 24
CHAIN_ADDRESS_END = 28
TREE_HEIGHT_START = 24
TREE_HEIGHT_END = 28
HASH_ADDRESS_START = 28
HASH_ADDRESS_END = 32
TREE_INDEX_START = 28
TREE_INDEX_END = 32
ADDRESS_DATA_LEN = 32


class AddressType(Enum):
    """Enum for types of address (the integer stored in the type field)"""
    WOTS_HASH = 0
    WOTS_PK = 1
    TREE = 2
    FORS_TREE = 3
    FORS_ROOTS = 4
    WOTS_PRF = 5
    FORS_PRF = 6

    def __int__(self):
        return self.value


# Address types whose layout carries each type-dependent field. Built once
# here instead of rebuilding a list on every has_* call.
_KEY_PAIR_ADDRESS_TYPES = frozenset({
    AddressType.WOTS_HASH,
    AddressType.WOTS_PK,
    AddressType.FORS_TREE,
    AddressType.FORS_ROOTS,
    AddressType.WOTS_PRF,
    AddressType.FORS_PRF,
})
_CHAIN_ADDRESS_TYPES = frozenset({
    AddressType.WOTS_HASH,
    AddressType.WOTS_PRF,
})
_TREE_HEIGHT_TYPES = frozenset({
    AddressType.TREE,
    AddressType.FORS_TREE,
    AddressType.FORS_PRF,
})
# Tree height and tree index are present for exactly the same types.
_TREE_INDEX_TYPES = _TREE_HEIGHT_TYPES


class Address:
    """Stores an SLH-DSA address (same class for all types).

    The address bytes live in the mutable ``bytearray`` ``self.data``.
    Integer-valued fields are converted with ``toInt``/``toByte`` from
    ``utils``. Bytes 20-32 are interpreted according to the address type, so
    the accessors for those fields assert that the current type carries them.
    """

    def __init__(self, b=None):
        """Initialize an address.

        ``b`` is optional initial content (anything ``bytearray()`` accepts).
        When it is omitted, the address is ADDRESS_DATA_LEN zero bytes.
        """
        # `is None` rather than `== None`: `==` may be overloaded by the
        # argument's type (e.g. array-likes) and not return a plain bool.
        if b is None:
            self.data = bytearray(ADDRESS_DATA_LEN)
        else:
            # NOTE(review): the length of `b` is not checked here; confirm
            # whether any caller relies on non-32-byte input.
            self.data = bytearray(b)

    def __repr__(self) -> str:
        """Return a representation showing the raw address bytes"""
        return f"{type(self).__name__}({bytes(self.data)!r})"

    def copy(self) -> "Address":
        """Make a (deep) copy of an address"""
        # `copy` below is the module imported at the top of the file; the
        # method name does not shadow it inside the method body.
        return copy.deepcopy(self)

    # --- layer address -------------------------------------------------

    def get_layer_address_raw(self) -> bytearray:
        """Get the layer address as raw (4) bytes"""
        return self.data[LAYER_ADDRESS_START:LAYER_ADDRESS_END]

    def set_layer_address_raw(self, layer_address: bytes):
        """Set the layer address from raw (4) bytes"""
        assert len(layer_address) == 4
        self.data[LAYER_ADDRESS_START:LAYER_ADDRESS_END] = layer_address

    def get_layer_address(self) -> int:
        """Get the layer address as an integer"""
        return toInt(self.get_layer_address_raw(), 4)

    def set_layer_address(self, layer_address: int):
        """Set the layer address from an integer (must be a valid uint32)"""
        assert 0 <= layer_address < 2**32
        self.set_layer_address_raw(toByte(layer_address, 4))

    # --- tree address --------------------------------------------------

    def get_tree_address_raw(self) -> bytearray:
        """Get the tree address as raw (12) bytes"""
        return self.data[TREE_ADDRESS_START:TREE_ADDRESS_END]

    def set_tree_address_raw(self, tree_address: bytes):
        """Set the tree address from raw (12) bytes"""
        assert len(tree_address) == 12
        self.data[TREE_ADDRESS_START:TREE_ADDRESS_END] = tree_address

    def get_tree_address(self) -> int:
        """Get the tree address as an integer"""
        return toInt(self.get_tree_address_raw(), 12)

    def set_tree_address(self, tree_address: int):
        """Set the tree address from an integer (must fit in 12 bytes)"""
        assert 0 <= tree_address < 2**(8 * 12)
        self.set_tree_address_raw(toByte(tree_address, 12))

    # --- type ----------------------------------------------------------

    def get_type(self) -> bytearray:
        """Get the type field of the address as raw (4) bytes"""
        return self.data[TYPE_START:TYPE_END]

    def set_type(self, type_: bytes):
        """Set the type field of the address from raw (4) bytes.

        Only the type field is written; use set_type_and_clear to also zero
        the type-dependent fields that follow it.
        """
        assert len(type_) == 4
        self.data[TYPE_START:TYPE_END] = type_

    def get_typeaddress(self) -> AddressType:
        """Get the address type as an AddressType enum.

        Raises ValueError (from the Enum constructor) if the type field does
        not hold one of the AddressType values.
        """
        return AddressType(toInt(self.get_type(), 4))

    def __set_typeaddress(self, new_type: AddressType):
        """Write `new_type` into the type field (no clearing)"""
        self.set_type(toByte(int(new_type), 4))

    def set_type_and_clear(self, new_type: AddressType):
        """Set the type and clear everything following it in the address"""
        # set new address type:
        self.__set_typeaddress(new_type)
        # clear everything after type (key pair / chain / height / index...):
        self.data[TYPE_END:ADDRESS_DATA_LEN] = \
            b"\x00" * (ADDRESS_DATA_LEN - TYPE_END)

    # --- hash address (WOTS_HASH only) ----------------------------------

    def get_hash_address(self) -> bytearray:
        """Get the hash address of a WOTS_HASH address as raw (4) bytes"""
        assert self.get_typeaddress() == AddressType.WOTS_HASH
        return self.data[HASH_ADDRESS_START:HASH_ADDRESS_END]

    def set_hash_address_raw(self, value: bytes):
        """Set the hash address from raw (4) bytes (WOTS_HASH only)"""
        assert self.get_typeaddress() == AddressType.WOTS_HASH
        assert len(value) == 4
        self.data[HASH_ADDRESS_START:HASH_ADDRESS_END] = value

    def set_hash_address(self, value: int):
        """Set the hash address from an int (WOTS_HASH only, valid uint32)"""
        assert self.get_typeaddress() == AddressType.WOTS_HASH
        assert 0 <= value < 2**32  # valid uint32
        self.set_hash_address_raw(toByte(value, 4))

    # --- key pair address ----------------------------------------------

    def has_key_pair_address(self) -> bool:
        """Check if type is one of the types that has a key pair address"""
        return self.get_typeaddress() in _KEY_PAIR_ADDRESS_TYPES

    def get_key_pair_address(self) -> bytearray:
        """Get the key pair address as raw (4) bytes"""
        assert self.has_key_pair_address()
        return self.data[KEY_PAIR_ADDRESS_START:KEY_PAIR_ADDRESS_END]

    def set_key_pair_address(self, new_kpa: bytes):
        """Set the key pair address to `new_kpa` (raw, 4 bytes)"""
        assert self.has_key_pair_address()
        assert len(new_kpa) == 4
        self.data[KEY_PAIR_ADDRESS_START:KEY_PAIR_ADDRESS_END] = new_kpa

    # --- chain address (shares bytes with tree height) -----------------

    def has_chain_address(self) -> bool:
        """Check if type is one of the types that has a chain address"""
        return self.get_typeaddress() in _CHAIN_ADDRESS_TYPES

    def get_chain_address_raw(self) -> bytearray:
        """Get the chain address as raw (4) bytes"""
        assert self.has_chain_address()
        return self.data[CHAIN_ADDRESS_START:CHAIN_ADDRESS_END]

    def set_chain_address_raw(self, new_ca: bytes):
        """Set the chain address from raw (4) bytes"""
        assert self.has_chain_address()
        assert len(new_ca) == 4
        self.data[CHAIN_ADDRESS_START:CHAIN_ADDRESS_END] = new_ca

    def get_chain_address(self) -> int:
        """Get the chain address as an int"""
        return toInt(self.get_chain_address_raw(), 4)

    def set_chain_address(self, new_ca: int):
        """Set the chain address from an int (must be a valid uint32)"""
        assert 0 <= new_ca < 2**32
        self.set_chain_address_raw(toByte(new_ca, 4))

    # --- tree height (shares bytes with chain address) -----------------

    def has_tree_height(self) -> bool:
        """Check if type is one of the types that has a tree height"""
        return self.get_typeaddress() in _TREE_HEIGHT_TYPES

    def get_tree_height_raw(self) -> bytearray:
        """Get the tree height as raw (4) bytes"""
        assert self.has_tree_height()
        return self.data[TREE_HEIGHT_START:TREE_HEIGHT_END]

    def get_tree_height(self) -> int:
        """Get the tree height as an integer"""
        return toInt(self.get_tree_height_raw(), 4)

    def set_tree_height_raw(self, new_height: bytes):
        """Set the tree height from raw (4) bytes"""
        assert self.has_tree_height()
        assert len(new_height) == 4
        self.data[TREE_HEIGHT_START:TREE_HEIGHT_END] = new_height

    def set_tree_height(self, new_height: int):
        """Set the tree height from an integer (must be a valid uint32)"""
        assert 0 <= new_height < 2**32
        self.set_tree_height_raw(toByte(new_height, 4))

    # --- tree index (shares bytes with hash address) -------------------

    def has_tree_index(self) -> bool:
        """Check if type is one of the types that has a tree index"""
        return self.get_typeaddress() in _TREE_INDEX_TYPES

    def get_tree_index_raw(self) -> bytearray:
        """Get the tree index as raw (4) bytes"""
        assert self.has_tree_index()
        return self.data[TREE_INDEX_START:TREE_INDEX_END]

    def get_tree_index(self) -> int:
        """Get the tree index as an integer"""
        return toInt(self.get_tree_index_raw(), 4)

    def set_tree_index_raw(self, new_ti: bytes):
        """Set the tree index from raw (4) bytes"""
        assert self.has_tree_index()
        assert len(new_ti) == 4
        self.data[TREE_INDEX_START:TREE_INDEX_END] = new_ti

    def set_tree_index(self, new_ti: int):
        """Set the tree index from an integer (must be a valid uint32)"""
        assert 0 <= new_ti < 2**32
        self.set_tree_index_raw(toByte(new_ti, 4))