#
# Copyright 2017 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
16import binascii
17import json
18from scapy.fields import Field, StrFixedLenField, PadField, IntField, FieldListField, ByteField, StrField, \
19 StrFixedLenField, PacketField
20from scapy.packet import Raw
21
class FixedLenField(PadField):
    """
    PadField variant that only lets the wrapped field parse the first
    ``align`` bytes of the input.

    A trailing Raw payload consisting solely of the pad character is
    stripped from the parsed value.
    """

    def __init__(self, fld, align, padwith='\x00'):
        super(FixedLenField, self).__init__(fld, align, padwith)

    def getfield(self, pkt, s):
        # Split the input into the fixed-size window and everything after it.
        window, rest = s[:self._align], s[self._align:]
        leftover, value = self._fld.getfield(pkt, window)

        # Assumes the wrapped field yields a scapy Packet (``.payload`` is read).
        trailing = value.payload
        if isinstance(trailing, Raw) and not trailing.load.replace(self._padwith, ''):
            # The raw payload is nothing but padding: drop it.
            value.remove_payload()

        # Any part of the window the wrapped field did not consume is handed
        # back ahead of the bytes that follow the window.
        return leftover + rest, value
36
class StrCompoundField(Field):
    """
    String-valued field built from several consecutive string sub-fields.

    On build, the value is split across ``flds`` in order. On dissection,
    the sub-field values are concatenated into a single string; any
    sub-field value that is not a ``str`` is converted with its i2repr().
    """
    __slots__ = ['flds']

    def __init__(self, name, flds):
        super(StrCompoundField, self).__init__(name=name, default=None, fmt='s')
        self.flds = flds
        for member in flds:
            assert not member.holds_packets, 'compound field cannot have packet field members'

    def addfield(self, pkt, s, val):
        remaining = val
        for sub in self.flds:
            # Encode/decode the remaining value through an empty buffer to
            # discover how many characters of it this sub-field consumes.
            _, consumed = sub.getfield(pkt, sub.addfield(pkt, '', remaining))
            width = len(consumed)
            s = sub.addfield(pkt, s, remaining[:width])
            remaining = remaining[width:]
        return s

    def getfield(self, pkt, s):
        pieces = []
        for sub in self.flds:
            s, part = sub.getfield(pkt, s)
            pieces.append(part if isinstance(part, str) else sub.i2repr(pkt, part))
        return s, ''.join(pieces)
63
class XStrFixedLenField(StrFixedLenField):
    """
    StrFixedLenField whose internal value is a hexadecimal string.

    The machine (wire) representation is the raw bytes decoded from that hex
    string, truncated to ``length_from(pkt)`` bytes.
    """

    def i2m(self, pkt, x):
        if x is None:
            return None
        # length_from() is a byte count, so truncate the *decoded* bytes to it.
        # (Previously the hex-character count (length * 2) + 1 was used here,
        # which never truncated anything.)
        return binascii.a2b_hex(x)[:self.length_from(pkt)]

    def m2i(self, pkt, x):
        return None if x is None else binascii.b2a_hex(x)
74
class MultipleTypeField(object):
    """MultipleTypeField are used for fields that can be implemented by
    various Field subclasses, depending on conditions on the packet.

    It is initialized with `flds` and `default`.

    `default` is the default field type, to be used when none of the
    conditions matched the current packet.

    `flds` is a list of tuples (`fld`, `cond`), where `fld` is a field
    type, and `cond` a "condition" to determine if `fld` is the field type
    that should be used. Entries are tried in order; the first match wins.

    `cond` is either:

      - a callable `cond_pkt` that accepts one argument (the packet) and
        returns True if `fld` should be used, False otherwise.

      - a tuple (`cond_pkt`, `cond_pkt_val`), where `cond_pkt` is the same
        as in the previous case and `cond_pkt_val` is a callable that
        accepts two arguments (the packet, and the value to be set) and
        returns True if `fld` should be used, False otherwise.

    Attributes not defined on this class are looked up on `default`, since
    no packet is available at that point to select a variant.

    See scapy.layers.l2.ARP (type "help(ARP)" in Scapy) for an example of
    use.
    """

    __slots__ = ["flds", "default", "name"]

    def __init__(self, flds, default):
        self.flds = flds
        self.default = default
        # The compound field is known under the default variant's name.
        self.name = self.default.name

    def _find_fld_pkt(self, pkt):
        """Given a Packet instance `pkt`, returns the Field subclass to be
        used. If you know the value to be set (e.g., in .addfield()), use
        ._find_fld_pkt_val() instead.
        """
        for fld, cond in self.flds:
            if isinstance(cond, tuple):
                cond = cond[0]
            if cond(pkt):
                return fld
        return self.default

    def _find_fld_pkt_val(self, pkt, val):
        """Given a Packet instance `pkt` and the value `val` to be set,
        returns the Field subclass to be used.
        """
        for fld, cond in self.flds:
            if isinstance(cond, tuple):
                if cond[1](pkt, val):
                    return fld
            elif cond(pkt):
                return fld
        return self.default

    def getfield(self, pkt, s):
        return self._find_fld_pkt(pkt).getfield(pkt, s)

    def addfield(self, pkt, s, val):
        return self._find_fld_pkt_val(pkt, val).addfield(pkt, s, val)

    def any2i(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).any2i(pkt, val)

    def h2i(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).h2i(pkt, val)

    def i2h(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2h(pkt, val)

    def i2m(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2m(pkt, val)

    def i2len(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2len(pkt, val)

    def i2repr(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2repr(pkt, val)

    def register_owner(self, cls):
        """Register `cls` as an owner of every variant, including the default."""
        for fld, _ in self.flds:
            fld.owners.append(cls)
        # Was `self.dflt`, an attribute that does not exist on this class.
        self.default.owners.append(cls)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. If a slot itself is missing
        # (e.g. an instance created without __init__, as copy/unpickling do),
        # fail cleanly instead of recursing through `self.default`.
        if attr in MultipleTypeField.__slots__:
            raise AttributeError(attr)
        # No packet is available here to pick a variant, so delegate to the
        # default one. (The previous `self._find_fld()` did not exist and
        # recursed infinitely through __getattr__.)
        return getattr(self.default, attr)
164
class OmciSerialNumberField(StrCompoundField):
    """
    Serial-number field made of a 4-character ``vendor_id`` followed by a
    ``vendor_serial_number`` held as hex text and encoded as 4 raw bytes.

    ``default``, when given, must be a 12-character string: 4 vendor
    characters followed by 8 hex digits.
    """

    def __init__(self, name, default=None):
        assert default is None or (isinstance(default, str) and len(default) == 12), 'invalid default serial number'
        if default is None:
            vendor_default = vendor_serial_default = None
        else:
            vendor_default, vendor_serial_default = default[:4], default[4:12]
        members = [
            StrFixedLenField('vendor_id', vendor_default, 4),
            XStrFixedLenField('vendor_serial_number', vendor_serial_default, 4),
        ]
        super(OmciSerialNumberField, self).__init__(name, members)
173
class OmciTableField(MultipleTypeField):
    """
    Field for an OMCI table attribute whose encoding depends on the message.

    - In a Get response (message_id == OmciGetResponseMessageId), it is
      encoded as an IntField named 'table_length'.
    - In a Get Next response (message_id == OmciGetNextResponseMessageId),
      it is a raw StrField chunk padded to PDU_SIZE.
    - Otherwise, the row PacketField ``tblfld`` itself is used.

    Row objects (instances of ``tblfld.cls``) must provide ``index()``,
    which returns the row key, and ``is_delete()``, which returns True when
    the row requests removal of the row with that key.
    """

    PDU_SIZE = 29  # Baseline message set raw get-next PDU size
    # Message ids are duplicated here because importing them would create a
    # circular dependency.
    OmciGetResponseMessageId = 0x29
    OmciGetNextResponseMessageId = 0x3a

    def __init__(self, tblfld):
        assert isinstance(tblfld, PacketField)
        assert hasattr(tblfld.cls, 'index'), 'No index() method defined for OmciTableField row object'
        assert hasattr(tblfld.cls, 'is_delete'), 'No is_delete() method defined for OmciTableField row object'
        super(OmciTableField, self).__init__(
            [
                (IntField('table_length', 0), (self.cond_pkt, self.cond_pkt_val)),
                (PadField(StrField('me_type_table', None), OmciTableField.PDU_SIZE),
                 (self.cond_pkt2, self.cond_pkt_val2)),
            ],
            tblfld)

    def cond_pkt(self, pkt):
        return pkt is not None and pkt.message_id == self.OmciGetResponseMessageId

    def cond_pkt_val(self, pkt, val):
        return pkt is not None and pkt.message_id == self.OmciGetResponseMessageId

    def cond_pkt2(self, pkt):
        return pkt is not None and pkt.message_id == self.OmciGetNextResponseMessageId

    def cond_pkt_val2(self, pkt, val):
        return pkt is not None and pkt.message_id == self.OmciGetNextResponseMessageId

    def to_json(self, new_values, old_values_json):
        """Apply row updates to a JSON-encoded table and return the new JSON.

        :param new_values: a single row object, which is merged into the
            table in ``old_values_json``, or a list of row objects, which
            replaces the old table entirely. A row whose ``is_delete()`` is
            true removes the row with the same ``index()``; removing a row
            that is not present is a no-op.
        :param old_values_json: JSON string of the current table (as
            produced by this method), or None for an empty table.
        :return: compact JSON string holding each row's ``fields`` dict,
            ordered by ``index()``.
        """
        if isinstance(new_values, list):
            old_values_json = None  # a full vector of rows erases the old table
        else:
            new_values = [new_values]  # a scalar augments the old table

        rows = {}
        for old in self.load_json(old_values_json):
            rows[old.index()] = old
        for new in new_values:
            index = new.index()
            if new.is_delete():
                rows.pop(index, None)
            else:
                rows[index] = new

        new_table = []
        for index in sorted(rows):
            row = rows[index]
            assert isinstance(row, self.default.cls), 'object type for Omci Table row object invalid'
            new_table.append(row.fields)

        return json.dumps(new_table, separators=(',', ':'))

    def load_json(self, json_str):
        """Decode a JSON table into row objects sorted by ``index()``.

        :param json_str: JSON list of field dicts, each passed as keyword
            arguments to the row class. None or '' means an empty table.
        :return: list of row objects. If several entries share an index,
            the last one wins.
        """
        if not json_str:
            json_str = '[]'
        rows = {}
        for json_value in json.loads(json_str):
            row = self.default.cls(**json_value)
            rows[row.index()] = row
        return [rows[index] for index in sorted(rows)]
239 return table