1
1
import sys
2
2
import types
3
3
import unittest
4
+ from collections import OrderedDict
4
5
from bitarray import bitarray
5
- from ctypes import c_ubyte , c_uint64 , LittleEndianStructure , Union
6
-
7
6
8
7
# Public API of this module: only the Bitfield class factory is exported.
__all__ = ["Bitfield"]
9
8
9
class _Bitfield:
    """Base class for the bitfield classes produced by :func:`Bitfield`.

    Each instance keeps its contents in ``self._bits``, a little-endian
    ``bitarray`` of exactly ``_size_bits`` bits.  Every declared field is
    exposed as a property that reads/writes the slice of ``_bits`` it owns;
    the first declared field occupies the least-significant bits.
    """

    @classmethod
    def _build_fields(cls, size_bits, fields):
        """Lay out *fields* over *size_bits* bits and install one property per field.

        :param size_bits: total bit width of the bitfield.
        :param fields: sequence of ``(name, width)`` pairs, least-significant
            field first.  A ``name`` of ``None`` reserves *width* bits as an
            anonymous padding field.
        :raises ValueError: if the field widths add up to more than *size_bits*.
        """
        # Work on a private list copy: accepts tuples too and never mutates the caller's sequence.
        fields = list(fields)
        total_width = sum(width for _, width in fields)
        if total_width > size_bits:
            raise ValueError("field widths exceed declared bit size (%d > %d)" % (total_width, size_bits))
        elif total_width < size_bits:
            # Cover the unused high bits with padding so every bit belongs to a field.
            fields.append((None, size_bits - total_width))

        cls._size_bits = size_bits
        cls._size_bytes = (size_bits + 7) // 8
        cls._named_fields = []        # user-visible field names, in declaration order
        cls._widths = OrderedDict()   # every field (named and padding) -> width in bits

        bit = 0
        for name, width in fields:
            if name is None:
                # Padding fields are named after their starting bit offset.
                name = "_padding_%d" % bit
            else:
                cls._named_fields.append(name)

            cls._create_field(name, bit, width)
            bit += width

    @classmethod
    def _create_field(cls, name, start, width):
        """Install a read/write property *name* covering bits ``[start, start + width)``.

        The getter returns the field as a non-negative ``int``.  The setter
        accepts either an ``int`` in ``range(2 ** width)`` or a ``bitarray`` of
        exactly *width* bits.
        """
        cls._widths[name] = width
        end = start + width
        num_bytes = (width + 7) // 8
        max_int = (1 << width) - 1

        def getter(self):
            # Little-endian bit order plus little-endian byte order: bit `start` is the LSB.
            return int.from_bytes(self._bits[start:end].tobytes(), "little")

        def setter(self, value):
            if isinstance(value, bitarray):
                if len(value) != width:
                    raise ValueError("bitarray of length %d does not fit a %d-bit field" % (len(value), width))
                self._bits[start:end] = value
            else:
                if value > max_int:
                    raise OverflowError("int too big to fit in %d bits" % width)
                # Negative ints are rejected by int.to_bytes with OverflowError.
                b = bitarray(endian="little")
                b.frombytes(value.to_bytes(num_bytes, "little"))
                self._bits[start:end] = b[:width]

        setattr(cls, name, property(getter, setter))

    @classmethod
    def from_int(cls, data):
        """Build an instance from a non-negative int (bit 0 of *data* is bit 0 of the bitfield).

        :raises OverflowError: if *data* is negative or needs more than ``_size_bits`` bits.
        """
        if data >= (1 << cls._size_bits):
            raise OverflowError("int too big to fit in %d bits" % cls._size_bits)
        return cls.from_bytes(data.to_bytes(cls._size_bytes, "little"))

    @classmethod
    def from_bytearray(cls, data):
        """Build an instance from a ``bytearray`` (see :meth:`from_bytes`)."""
        return cls.from_bytes(bytes(data))

    @classmethod
    def from_bytes(cls, data):
        """Build an instance from exactly ``_size_bytes`` little-endian bytes.

        :raises ValueError: if *data* has the wrong length.
        """
        if len(data) != cls._size_bytes:
            raise ValueError("need %d bytes to fill BitArray" % cls._size_bytes)
        b = bitarray(endian="little")
        b.frombytes(data)
        return cls.from_bitarray(b[:cls._size_bits])

    @classmethod
    def from_bitarray(cls, data):
        """Build an instance holding a copy of *data*.

        :param data: little-endian ``bitarray`` of exactly ``_size_bits`` bits.
        :raises TypeError: if *data* is not a ``bitarray``.
        :raises ValueError: if its length or endianness is wrong.
        """
        if not isinstance(data, bitarray):
            raise TypeError("expected a bitarray, got %s" % type(data).__name__)
        # len() rather than bitarray.length(), which was removed in bitarray 2.0.
        if len(data) != cls._size_bits:
            raise ValueError("need %d bits to fill Bitfield" % cls._size_bits)
        if data.endian() != "little":
            raise ValueError("bitarray must be little-endian")
        pack = cls()
        pack._bits = bitarray(data, endian="little")
        return pack

    def __init__(self, *args, **kwargs):
        """Create a zeroed bitfield, then assign fields positionally and/or by keyword.

        Positional arguments fill the named (non-padding) fields in declaration
        order; keyword arguments may name any field.

        :raises ValueError: on too many positional arguments or an unknown field name.
        """
        self._bits = bitarray(self._size_bits, endian="little")
        # Clear explicitly: bitarray(n) is not guaranteed to be zero-filled on older bitarray versions.
        self._bits.setall(0)

        if len(args) > len(self._named_fields):
            raise ValueError("too many arguments for field count (%d > %d)" % (len(args), len(self._named_fields)))

        for field_name, v in zip(self._named_fields, args):
            setattr(self, field_name, v)

        for k, v in kwargs.items():
            if k not in self._widths:
                raise ValueError("unknown field name %s" % k)
            setattr(self, k, v)

    def to_bitarray(self):
        """Return a little-endian copy of the underlying bits."""
        return bitarray(self._bits, endian="little")

    def to_bytes(self):
        """Return the bits packed little-endian into ``bytes``."""
        return self._bits.tobytes()

    def to_bytearray(self):
        """Return the bits packed little-endian into a ``bytearray``."""
        return bytearray(self.to_bytes())

    def to_int(self):
        """Return the whole bitfield as a non-negative int (field 0 in the low bits)."""
        return int.from_bytes(self.to_bytes(), "little")

    def copy(self):
        """Return an independent instance of the same class with the same bits."""
        return self.__class__.from_bitarray(self._bits)

    def bits_repr(self, omit_zero=False, omit_padding=True):
        """Return ``"name=<binary> ..."`` for each field, zero-padded to its width.

        :param omit_zero: skip fields whose value is zero.
        :param omit_padding: skip the anonymous padding fields.
        """
        names = self._named_fields if omit_padding else self._widths.keys()
        fields = []
        for name in names:
            width = self._widths[name]
            value = getattr(self, name)
            if omit_zero and not value:
                continue
            fields.append("{}={:0{}b}".format(name, value, width))
        return " ".join(fields)

    def __repr__(self):
        return "<{}.{} {}>".format(self.__module__, self.__class__.__name__, self.bits_repr())

    def __eq__(self, other):
        # Compare raw bits; defer to Python's default handling for unrelated types.
        if not isinstance(other, _Bitfield):
            return NotImplemented
        return self._bits == other._bits

    def __ne__(self, other):
        if not isinstance(other, _Bitfield):
            return NotImplemented
        return self._bits != other._bits
97
139
98
140
def Bitfield(name, size_bytes, fields):
    """Create and return a new bitfield class called *name*.

    :param name: ``__name__`` of the generated class.
    :param size_bytes: total size of the bitfield, in bytes.
    :param fields: list of ``(field_name, width_in_bits)`` pairs, the first
        pair occupying the least-significant bits; a ``None`` field name
        reserves anonymous padding bits.
    :returns: a new subclass of ``_Bitfield``.
    """
    # Attribute the class to the caller's module, the same trick namedtuple() uses.
    caller_module = sys._getframe(1).f_globals["__name__"]
    total_bits = 8 * size_bytes

    new_cls = types.new_class(name, (_Bitfield,))
    new_cls.__module__ = caller_module
    new_cls._build_fields(total_bits, fields)
    return new_cls
122
149
123
150
# -------------------------------------------------------------------------------------------------
124
151
@@ -132,7 +159,16 @@ def test_definition(self):
132
159
self .assertEqual (x .b , 2 )
133
160
134
161
def test_large(self):
    """A 64-bit field placed above 8 padding bits survives a to_int() round trip."""
    bitfield_cls = Bitfield("bf", 9, [(None, 8), ("a", 64)])
    expected = (3 << 62) + 1
    instance = bitfield_cls(expected)
    self.assertEqual(instance.to_int(), expected << 8)
166
+
167
def test_huge(self):
    """Fields far wider than 64 bits store and return arbitrary-precision ints."""
    bitfield_cls = Bitfield("bf", 260, [("e", 32), ("m", 2048)])
    mantissa = (30 << 2048) // 31
    instance = bitfield_cls(65537, mantissa)
    self.assertEqual(instance.e, 65537)
    self.assertEqual(instance.m, mantissa)
136
172
137
173
def test_reserved (self ):
138
174
bf = Bitfield ("bf" , 8 , [(None , 1 ), ("a" , 1 )])
@@ -153,6 +189,13 @@ def test_bytearray(self):
153
189
self .assertEqual (x .to_bytearray (), bytearray (b"\x11 \x00 " ))
154
190
self .assertEqual (bf .from_bytearray (x .to_bytearray ()), x )
155
191
192
def test_int(self):
    """to_int() packs the first field into the low bits and from_int() inverts it."""
    bitfield_cls = Bitfield("bf", 2, [("a", 3), ("b", 5)])
    instance = bitfield_cls(1, 2)
    packed = instance.to_int()
    self.assertIsInstance(packed, int)
    self.assertEqual(packed, 17)
    self.assertEqual(bitfield_cls.from_int(packed), instance)
198
+
156
199
def test_bitaray (self ):
157
200
bf = Bitfield ("bf" , 2 , [("a" , 3 ), ("b" , 5 )])
158
201
x = bf (1 , 2 )
0 commit comments