Skip to content

Commit ee0b8b3

Browse files
committedJul 30, 2017
Rework and test raw sockets.
1 parent 265e6f6 commit ee0b8b3

File tree

2 files changed

+248
-43
lines changed

2 files changed

+248
-43
lines changed
 

Diff for: ‎src/socket/raw.rs

+224-43
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use core::cmp::min;
12
use managed::Managed;
23

34
use {Error, Result};
@@ -30,6 +31,15 @@ impl<'a> PacketBuffer<'a> {
3031
fn as_mut<'b>(&'b mut self) -> &'b mut [u8] {
3132
&mut self.payload[..self.size]
3233
}
34+
35+
fn resize<'b>(&'b mut self, size: usize) -> Result<&'b mut Self> {
36+
if self.payload.len() >= size {
37+
self.size = size;
38+
Ok(self)
39+
} else {
40+
Err(Error::Truncated)
41+
}
42+
}
3343
}
3444

3545
impl<'a> Resettable for PacketBuffer<'a> {
@@ -111,87 +121,107 @@ impl<'a, 'b> RawSocket<'a, 'b> {
111121
///
112122
/// This function returns `Err(Error::Exhausted)` if the size is greater than
113123
/// the transmit packet buffer size.
124+
///
125+
/// If the buffer is filled in a way that does not match the socket's
126+
/// IP version or protocol, the packet will be silently dropped.
127+
///
128+
/// **Note:** The IP header is parsed and reserialized, and may not match
129+
/// the header actually transmitted bit for bit.
114130
pub fn send(&mut self, size: usize) -> Result<&mut [u8]> {
115-
let packet_buf = self.tx_buffer.enqueue()?;
116-
packet_buf.size = size;
131+
let packet_buf = self.tx_buffer.try_enqueue(|buf| buf.resize(size))?;
117132
net_trace!("[{}]:{}:{}: buffer to send {} octets",
118133
self.debug_id, self.ip_version, self.ip_protocol,
119134
packet_buf.size);
120-
Ok(&mut packet_buf.as_mut()[..size])
135+
Ok(packet_buf.as_mut())
121136
}
122137

123138
/// Enqueue a packet to send, and fill it from a slice.
124139
///
125140
/// See also [send](#method.send).
126-
pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
127-
let buffer = self.send(data.len())?;
128-
let data = &data[..buffer.len()];
129-
buffer.copy_from_slice(data);
130-
Ok(data.len())
141+
pub fn send_slice(&mut self, data: &[u8]) -> Result<()> {
142+
self.send(data.len())?.copy_from_slice(data);
143+
Ok(())
131144
}
132145

133146
/// Dequeue a packet, and return a pointer to the payload.
134147
///
135148
/// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
149+
///
150+
/// **Note:** The IP header is parsed and reserialized, and may not match
151+
/// the header actually received bit for bit.
136152
pub fn recv(&mut self) -> Result<&[u8]> {
137153
let packet_buf = self.rx_buffer.dequeue()?;
138154
net_trace!("[{}]:{}:{}: receive {} buffered octets",
139155
self.debug_id, self.ip_version, self.ip_protocol,
140156
packet_buf.size);
141-
Ok(&packet_buf.as_ref()[..packet_buf.size])
157+
Ok(&packet_buf.as_ref())
142158
}
143159

144160
/// Dequeue a packet, and copy the payload into the given slice.
145161
///
146162
/// See also [recv](#method.recv).
147163
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize> {
148164
let buffer = self.recv()?;
149-
data[..buffer.len()].copy_from_slice(buffer);
150-
Ok(buffer.len())
165+
let length = min(data.len(), buffer.len());
166+
data[..length].copy_from_slice(&buffer[..length]);
167+
Ok(length)
151168
}
152169

153170
pub(crate) fn process(&mut self, _timestamp: u64, ip_repr: &IpRepr,
154171
payload: &[u8]) -> Result<()> {
155-
match self.ip_version {
156-
IpVersion::Ipv4 => {
157-
if ip_repr.protocol() != self.ip_protocol {
158-
return Err(Error::Rejected);
159-
}
160-
let header_len = ip_repr.buffer_len();
161-
let packet_buf = self.rx_buffer.enqueue()?;
162-
packet_buf.size = header_len + payload.len();
163-
ip_repr.emit(&mut packet_buf.as_mut()[..header_len]);
164-
packet_buf.as_mut()[header_len..header_len + payload.len()]
165-
.copy_from_slice(payload);
166-
net_trace!("[{}]:{}:{}: receiving {} octets",
167-
self.debug_id, self.ip_version, self.ip_protocol,
168-
packet_buf.size);
169-
Ok(())
170-
}
171-
IpVersion::__Nonexhaustive => unreachable!()
172-
}
172+
if ip_repr.version() != self.ip_version { return Err(Error::Rejected) }
173+
if ip_repr.protocol() != self.ip_protocol { return Err(Error::Rejected) }
174+
175+
let header_len = ip_repr.buffer_len();
176+
let total_len = header_len + payload.len();
177+
let packet_buf = self.rx_buffer.try_enqueue(|buf| buf.resize(total_len))?;
178+
ip_repr.emit(&mut packet_buf.as_mut()[..header_len]);
179+
packet_buf.as_mut()[header_len..].copy_from_slice(payload);
180+
net_trace!("[{}]:{}:{}: receiving {} octets",
181+
self.debug_id, self.ip_version, self.ip_protocol,
182+
packet_buf.size);
183+
Ok(())
173184
}
174185

175-
/// See [Socket::dispatch](enum.Socket.html#method.dispatch).
176186
pub(crate) fn dispatch<F, R>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
177187
emit: &mut F) -> Result<R>
178188
where F: FnMut(&IpRepr, &IpPayload) -> Result<R> {
179-
let mut packet_buf = self.tx_buffer.dequeue()?;
180-
net_trace!("[{}]:{}:{}: sending {} octets",
181-
self.debug_id, self.ip_version, self.ip_protocol,
182-
packet_buf.size);
189+
fn prepare(version: IpVersion, protocol: IpProtocol,
190+
buffer: &mut [u8]) -> Result<(IpRepr, RawRepr)> {
191+
match IpVersion::of_packet(buffer.as_ref())? {
192+
IpVersion::Ipv4 => {
193+
let mut packet = Ipv4Packet::new_checked(buffer.as_mut())?;
194+
if packet.protocol() != protocol { return Err(Error::Unaddressable) }
195+
packet.fill_checksum();
183196

184-
match self.ip_version {
185-
IpVersion::Ipv4 => {
186-
let mut ipv4_packet = Ipv4Packet::new_checked(packet_buf.as_mut())?;
187-
ipv4_packet.fill_checksum();
197+
let packet = Ipv4Packet::new(&*packet.into_inner());
198+
let ipv4_repr = Ipv4Repr::parse(&packet)?;
199+
let raw_repr = RawRepr(packet.payload());
200+
Ok((IpRepr::Ipv4(ipv4_repr), raw_repr))
201+
}
202+
IpVersion::Unspecified => unreachable!(),
203+
IpVersion::__Nonexhaustive => unreachable!()
204+
}
205+
}
188206

189-
let ipv4_packet = Ipv4Packet::new(&*ipv4_packet.into_inner());
190-
let raw_repr = RawRepr(ipv4_packet.payload());
191-
let ipv4_repr = Ipv4Repr::parse(&ipv4_packet)?;
192-
emit(&IpRepr::Ipv4(ipv4_repr), &raw_repr)
207+
let mut packet_buf = self.tx_buffer.dequeue()?;
208+
match prepare(self.ip_version, self.ip_protocol, packet_buf.as_mut()) {
209+
Ok((ip_repr, raw_repr)) => {
210+
net_trace!("[{}]:{}:{}: sending {} octets",
211+
self.debug_id, self.ip_version, self.ip_protocol,
212+
ip_repr.buffer_len() + raw_repr.buffer_len());
213+
emit(&ip_repr, &raw_repr)
214+
}
215+
Err(error) => {
216+
net_trace!("[{}]:{}:{}: dropping outgoing packet ({})",
217+
self.debug_id, self.ip_version, self.ip_protocol,
218+
error);
219+
// This case is a bit special because in every other socket, no matter what data
220+
// is put into the socket, it can be sent, but it's possible to put data into
221+
// a raw socket that may not be, and we're generic over the result type, so
222+
// we can't possibly return Ok(()) here.
223+
Err(Error::Rejected)
193224
}
194-
IpVersion::__Nonexhaustive => unreachable!()
195225
}
196226
}
197227
}
@@ -207,3 +237,154 @@ impl<'a> IpPayload for RawRepr<'a> {
207237
payload.copy_from_slice(self.0);
208238
}
209239
}
240+
241+
#[cfg(test)]
242+
mod test {
243+
use wire::{IpAddress, Ipv4Address, IpRepr, Ipv4Repr};
244+
use super::*;
245+
246+
fn buffer(packets: usize) -> SocketBuffer<'static, 'static> {
247+
let mut storage = vec![];
248+
for _ in 0..packets {
249+
storage.push(PacketBuffer::new(vec![0; 24]))
250+
}
251+
SocketBuffer::new(storage)
252+
}
253+
254+
fn socket(rx_buffer: SocketBuffer<'static, 'static>,
255+
tx_buffer: SocketBuffer<'static, 'static>)
256+
-> RawSocket<'static, 'static> {
257+
match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(63),
258+
rx_buffer, tx_buffer) {
259+
Socket::Raw(socket) => socket,
260+
_ => unreachable!()
261+
}
262+
}
263+
264+
const HEADER_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
265+
src_addr: Ipv4Address([10, 0, 0, 1]),
266+
dst_addr: Ipv4Address([10, 0, 0, 2]),
267+
protocol: IpProtocol::Unknown(63),
268+
payload_len: 4
269+
});
270+
const PACKET_BYTES: [u8; 24] = [
271+
0x45, 0x00, 0x00, 0x18,
272+
0x00, 0x00, 0x40, 0x00,
273+
0x40, 0x3f, 0x00, 0x00,
274+
0x0a, 0x00, 0x00, 0x01,
275+
0x0a, 0x00, 0x00, 0x02,
276+
0xaa, 0x00, 0x00, 0xff
277+
];
278+
const PACKET_PAYLOAD: [u8; 4] = [
279+
0xaa, 0x00, 0x00, 0xff
280+
];
281+
282+
#[test]
283+
fn test_send_truncated() {
284+
let mut socket = socket(buffer(0), buffer(1));
285+
assert_eq!(socket.send_slice(&[0; 32][..]), Err(Error::Truncated));
286+
}
287+
288+
#[test]
289+
fn test_send_dispatch() {
290+
let limits = DeviceLimits::default();
291+
292+
let mut socket = socket(buffer(0), buffer(1));
293+
294+
assert!(socket.can_send());
295+
assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
296+
unreachable!()
297+
}), Err(Error::Exhausted) as Result<()>);
298+
299+
assert_eq!(socket.send_slice(&PACKET_BYTES[..]), Ok(()));
300+
assert_eq!(socket.send_slice(b""), Err(Error::Exhausted));
301+
assert!(!socket.can_send());
302+
303+
macro_rules! assert_payload_eq {
304+
($ip_repr:expr, $ip_payload:expr, $expected:expr) => {{
305+
let mut buffer = vec![0; $ip_payload.buffer_len()];
306+
$ip_payload.emit(&$ip_repr, &mut buffer);
307+
assert_eq!(&buffer[..], &$expected[$ip_repr.buffer_len()..]);
308+
}}
309+
}
310+
311+
assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
312+
assert_eq!(ip_repr, &HEADER_REPR);
313+
assert_payload_eq!(ip_repr, ip_payload, PACKET_BYTES);
314+
Err(Error::Unaddressable)
315+
}), Err(Error::Unaddressable) as Result<()>);
316+
/*assert!(!socket.can_send());*/
317+
318+
assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
319+
assert_eq!(ip_repr, &HEADER_REPR);
320+
assert_payload_eq!(ip_repr, ip_payload, PACKET_BYTES);
321+
Ok(())
322+
}), /*Ok(())*/ Err(Error::Exhausted));
323+
assert!(socket.can_send());
324+
}
325+
326+
#[test]
327+
fn test_send_illegal() {
328+
let limits = DeviceLimits::default();
329+
330+
let mut socket = socket(buffer(0), buffer(1));
331+
332+
let mut wrong_version = PACKET_BYTES.clone();
333+
Ipv4Packet::new(&mut wrong_version).set_version(5);
334+
335+
assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
336+
assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
337+
unreachable!()
338+
}), Err(Error::Rejected) as Result<()>);
339+
340+
let mut wrong_protocol = PACKET_BYTES.clone();
341+
Ipv4Packet::new(&mut wrong_protocol).set_protocol(IpProtocol::Tcp);
342+
343+
assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
344+
assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
345+
unreachable!()
346+
}), Err(Error::Rejected) as Result<()>);
347+
}
348+
349+
#[test]
350+
fn test_recv_process() {
351+
let mut socket = socket(buffer(1), buffer(0));
352+
assert!(!socket.can_recv());
353+
354+
let mut cksumd_packet = PACKET_BYTES.clone();
355+
Ipv4Packet::new(&mut cksumd_packet).fill_checksum();
356+
357+
assert_eq!(socket.recv(), Err(Error::Exhausted));
358+
assert_eq!(socket.process(0, &HEADER_REPR, &PACKET_PAYLOAD),
359+
Ok(()));
360+
assert!(socket.can_recv());
361+
362+
assert_eq!(socket.process(0, &HEADER_REPR, &PACKET_PAYLOAD),
363+
Err(Error::Exhausted));
364+
assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
365+
assert!(!socket.can_recv());
366+
}
367+
368+
#[test]
369+
fn test_recv_truncated_slice() {
370+
let mut socket = socket(buffer(1), buffer(0));
371+
372+
assert_eq!(socket.process(0, &HEADER_REPR, &PACKET_PAYLOAD),
373+
Ok(()));
374+
375+
let mut slice = [0; 4];
376+
assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4));
377+
assert_eq!(&slice, &PACKET_BYTES[..slice.len()]);
378+
}
379+
380+
#[test]
381+
fn test_recv_truncated_packet() {
382+
let mut socket = socket(buffer(1), buffer(0));
383+
384+
let mut buffer = vec![0; 128];
385+
buffer[..PACKET_BYTES.len()].copy_from_slice(&PACKET_BYTES[..]);
386+
387+
assert_eq!(socket.process(0, &HEADER_REPR, &buffer),
388+
Err(Error::Truncated));
389+
}
390+
}

Diff for: ‎src/wire/ip.rs

+24
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,29 @@ use super::{Ipv4Address, Ipv4Packet, Ipv4Repr};
66
/// Internet protocol version.
77
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
88
pub enum Version {
9+
Unspecified,
910
Ipv4,
1011
#[doc(hidden)]
1112
__Nonexhaustive,
1213
}
1314

15+
impl Version {
16+
/// Return the version of an IP packet stored in the provided buffer.
17+
///
18+
/// This function never returns `Ok(IpVersion::Unspecified)`; instead,
19+
/// unknown versions result in `Err(Error::Unrecognized)`.
20+
pub fn of_packet(data: &[u8]) -> Result<Version> {
21+
match data[0] >> 4 {
22+
4 => Ok(Version::Ipv4),
23+
_ => Err(Error::Unrecognized)
24+
}
25+
}
26+
}
27+
1428
impl fmt::Display for Version {
1529
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1630
match self {
31+
&Version::Unspecified => write!(f, "IPv?"),
1732
&Version::Ipv4 => write!(f, "IPv4"),
1833
&Version::__Nonexhaustive => unreachable!()
1934
}
@@ -171,6 +186,15 @@ pub enum IpRepr {
171186
}
172187

173188
impl IpRepr {
189+
/// Return the protocol version.
190+
pub fn version(&self) -> Version {
191+
match self {
192+
&IpRepr::Unspecified { .. } => Version::Unspecified,
193+
&IpRepr::Ipv4(_) => Version::Ipv4,
194+
&IpRepr::__Nonexhaustive => unreachable!()
195+
}
196+
}
197+
174198
/// Return the source address.
175199
pub fn src_addr(&self) -> Address {
176200
match self {

0 commit comments

Comments
 (0)
Please sign in to comment.