Skip to content

Commit

Permalink
Rework responses to TCP packets and factor in RST replies to TcpSocket.
Browse files Browse the repository at this point in the history
whitequark committed Aug 22, 2017
1 parent 7e6e379 commit bc2a894
Showing 4 changed files with 142 additions and 99 deletions.
150 changes: 65 additions & 85 deletions src/iface/ethernet.rs
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ enum Response<'a> {
Nop,
Arp(ArpRepr),
Icmpv4(Ipv4Repr, Icmpv4Repr<'a>),
Tcpv4(Ipv4Repr, TcpRepr<'a>)
Tcp(IpRepr, TcpRepr<'a>)
}

impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
@@ -220,10 +220,10 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
match ipv4_repr.protocol {
IpProtocol::Icmp =>
Self::process_icmpv4(ipv4_repr, ipv4_packet.payload()),
IpProtocol::Tcp =>
Self::process_tcpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()),
IpProtocol::Udp =>
Self::process_udpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()),
IpProtocol::Tcp =>
Self::process_tcp(sockets, timestamp, ipv4_repr.into(), ipv4_packet.payload()),
_ if handled_by_raw_socket =>
Ok(Response::Nop),
_ => {
@@ -307,11 +307,9 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
Ok(Response::Icmpv4(ipv4_reply_repr, icmp_reply_repr))
}

fn process_tcpv4<'frame>(sockets: &mut SocketSet, timestamp: u64,
ipv4_repr: Ipv4Repr, ip_payload: &'frame [u8]) ->
Result<Response<'frame>> {
let ip_repr = IpRepr::Ipv4(ipv4_repr);

fn process_tcp<'frame>(sockets: &mut SocketSet, timestamp: u64,
ip_repr: IpRepr, ip_payload: &'frame [u8]) ->
Result<Response<'frame>> {
for tcp_socket in sockets.iter_mut().filter_map(
<Socket as AsSocket<TcpSocket>>::try_as_socket) {
match tcp_socket.process(timestamp, &ip_repr, ip_payload) {
@@ -327,99 +325,81 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {

// The packet wasn't handled by a socket, send a TCP RST packet.
let tcp_packet = TcpPacket::new_checked(ip_payload)?;
if tcp_packet.rst() {
// Don't reply to a TCP RST packet with another TCP RST packet.
return Ok(Response::Nop)
let tcp_repr = TcpRepr::parse(&tcp_packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
if tcp_repr.control == TcpControl::Rst {
// Never reply to a TCP RST packet with another TCP RST packet.
Ok(Response::Nop)
} else {
let (ip_reply_repr, tcp_reply_repr) = TcpSocket::rst_reply(&ip_repr, &tcp_repr);
Ok(Response::Tcp(ip_reply_repr, tcp_reply_repr))
}
let tcp_reply_repr = TcpRepr {
src_port: tcp_packet.dst_port(),
dst_port: tcp_packet.src_port(),
control: TcpControl::Rst,
push: false,
seq_number: tcp_packet.ack_number(),
ack_number: Some(tcp_packet.seq_number() +
tcp_packet.segment_len()),
window_len: 0,
max_seg_size: None,
payload: &[]
};
let ipv4_reply_repr = Ipv4Repr {
src_addr: ipv4_repr.dst_addr,
dst_addr: ipv4_repr.src_addr,
protocol: IpProtocol::Tcp,
payload_len: tcp_reply_repr.buffer_len()
};
Ok(Response::Tcpv4(ipv4_reply_repr, tcp_reply_repr))
}

// Transmit the reply produced by the packet-processing stage through
// `self.device`.
// NOTE(review): this span is a scraped diff with the `+`/`-` markers stripped,
// so pre-commit and post-commit lines are interleaved. The `ip_response!`
// macro, the `Response::Arp(repr)` / `Response::Tcpv4` arms and the
// `Response::Nop => { Ok(()) }` arm are the removed code. `emit_packet!`, the
// `Response::Arp(arp_repr)` / `Response::Icmpv4(ipv4_repr, ..)` /
// `Response::Tcp` arms and `Response::Nop => Ok(())` are the added code. Read
// the two versions separately; the span as shown is not valid Rust.
fn send_response(&mut self, timestamp: u64, response: Response) -> Result<()> {
macro_rules! ip_response {
($tx_buffer:ident, $frame:ident, $ip_repr:ident) => ({
let dst_hardware_addr =
match self.arp_cache.lookup(&$ip_repr.dst_addr.into()) {
None => return Err(Error::Unaddressable),
Some(hardware_addr) => hardware_addr
};

let tx_len = EthernetFrame::<&[u8]>::buffer_len($ip_repr.buffer_len() +
$ip_repr.payload_len);
$tx_buffer = self.device.transmit(timestamp, tx_len)?;
debug_assert!($tx_buffer.as_ref().len() == tx_len);
// emit_packet!(Ethernet, len, |frame| code): ask the device for a transmit
// buffer sized for an Ethernet frame carrying `len` payload bytes, fill in
// our source MAC, run `code` to write the rest of the frame, and evaluate
// to Ok(()). Transmit errors propagate through `?`.
macro_rules! emit_packet {
(Ethernet, $buffer_len:expr, |$frame:ident| $code:stmt) => ({
let tx_len = EthernetFrame::<&[u8]>::buffer_len($buffer_len);
let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
debug_assert!(tx_buffer.as_ref().len() == tx_len);

$frame = EthernetFrame::new(&mut $tx_buffer);
let mut $frame = EthernetFrame::new(&mut tx_buffer);
$frame.set_src_addr(self.hardware_addr);
$frame.set_dst_addr(dst_hardware_addr);
$frame.set_ethertype(EthernetProtocol::Ipv4);

let mut ip_packet = Ipv4Packet::new($frame.payload_mut());
$ip_repr.emit(&mut ip_packet);
ip_packet
$code

Ok(())
});

// emit_packet!(Ip, repr, |payload| code): lower `repr` against the
// interface's protocol addresses, resolve the destination MAC through the
// ARP cache (Err(Error::Unaddressable) on a miss), emit the IP header, and
// run `code` with `payload` bound to the bytes following that header.
(Ip, $ip_repr:expr, |$payload:ident| $code:stmt) => ({
let ip_repr = $ip_repr.lower(&self.protocol_addrs)?;
match self.arp_cache.lookup(&ip_repr.dst_addr()) {
None => Err(Error::Unaddressable),
Some(dst_hardware_addr) => {
emit_packet!(Ethernet, ip_repr.total_len(), |frame| {
frame.set_dst_addr(dst_hardware_addr);
match ip_repr {
IpRepr::Ipv4(_) => frame.set_ethertype(EthernetProtocol::Ipv4),
// NOTE(review): assumes lower() only yields IPv4 reprs here.
_ => unreachable!()
}

ip_repr.emit(frame.payload_mut());

let $payload = &mut frame.payload_mut()[ip_repr.buffer_len()..];
$code
})
}
}
})
}

match response {
Response::Arp(repr) => {
let tx_len = EthernetFrame::<&[u8]>::buffer_len(repr.buffer_len());
let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
debug_assert!(tx_buffer.as_ref().len() == tx_len);

let mut frame = EthernetFrame::new(&mut tx_buffer);
frame.set_src_addr(self.hardware_addr);
frame.set_dst_addr(match repr {
ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr,
_ => unreachable!()
});
frame.set_ethertype(EthernetProtocol::Arp);
// ARP replies are addressed straight to the target hardware address carried
// in the ARP representation; the ARP cache is not consulted.
Response::Arp(arp_repr) => {
let dst_hardware_addr =
match arp_repr {
ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr,
_ => unreachable!()
};

let mut packet = ArpPacket::new(frame.payload_mut());
repr.emit(&mut packet);
emit_packet!(Ethernet, arp_repr.buffer_len(), |frame| {
frame.set_dst_addr(dst_hardware_addr);
frame.set_ethertype(EthernetProtocol::Arp);

Ok(())
let mut packet = ArpPacket::new(frame.payload_mut());
arp_repr.emit(&mut packet);
})
},

Response::Icmpv4(ip_repr, icmp_repr) => {
let mut tx_buffer;
let mut frame;
let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr);
let mut icmp_packet = Icmpv4Packet::new(ip_packet.payload_mut());
icmp_repr.emit(&mut icmp_packet);
Ok(())
// ICMPv4 replies go through the generic IP path of emit_packet!.
Response::Icmpv4(ipv4_repr, icmpv4_repr) => {
emit_packet!(Ip, IpRepr::Ipv4(ipv4_repr), |payload| {
icmpv4_repr.emit(&mut Icmpv4Packet::new(payload));
})
}

Response::Tcpv4(ip_repr, tcp_repr) => {
let mut tx_buffer;
let mut frame;
let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr);
let mut tcp_packet = TcpPacket::new(ip_packet.payload_mut());
tcp_repr.emit(&mut tcp_packet,
&IpAddress::Ipv4(ip_repr.src_addr),
&IpAddress::Ipv4(ip_repr.dst_addr));
Ok(())
}

Response::Nop => {
Ok(())
// TCP replies, including the RSTs built by process_tcp, carry a
// protocol-independent IpRepr. The checksum uses that repr's addresses.
Response::Tcp(ip_repr, tcp_repr) => {
emit_packet!(Ip, ip_repr, |payload| {
tcp_repr.emit(&mut TcpPacket::new(payload),
&ip_repr.src_addr(), &ip_repr.dst_addr());
})
}
Response::Nop => Ok(())
}
}

54 changes: 46 additions & 8 deletions src/socket/tcp.rs
Original file line number Diff line number Diff line change
@@ -285,10 +285,10 @@ impl<'a> TcpSocket<'a> {
listen_address: IpAddress::default(),
local_endpoint: IpEndpoint::default(),
remote_endpoint: IpEndpoint::default(),
local_seq_no: TcpSeqNumber(0),
remote_seq_no: TcpSeqNumber(0),
remote_last_seq: TcpSeqNumber(0),
remote_last_ack: TcpSeqNumber(0),
local_seq_no: TcpSeqNumber::default(),
remote_seq_no: TcpSeqNumber::default(),
remote_last_seq: TcpSeqNumber::default(),
remote_last_ack: TcpSeqNumber::default(),
remote_win_len: 0,
remote_mss: DEFAULT_MSS,
retransmit: Retransmit::new(),
@@ -335,10 +335,10 @@ impl<'a> TcpSocket<'a> {
self.listen_address = IpAddress::default();
self.local_endpoint = IpEndpoint::default();
self.remote_endpoint = IpEndpoint::default();
self.local_seq_no = TcpSeqNumber(0);
self.remote_seq_no = TcpSeqNumber(0);
self.remote_last_seq = TcpSeqNumber(0);
self.remote_last_ack = TcpSeqNumber(0);
self.local_seq_no = TcpSeqNumber::default();
self.remote_seq_no = TcpSeqNumber::default();
self.remote_last_seq = TcpSeqNumber::default();
self.remote_last_ack = TcpSeqNumber::default();
self.remote_win_len = 0;
self.remote_mss = DEFAULT_MSS;
self.retransmit.reset();
@@ -681,6 +681,44 @@ impl<'a> TcpSocket<'a> {
self.state = state
}

/// Build a skeleton reply to the segment `tcp_repr` that arrived over `ip_repr`.
///
/// The TCP half swaps the source and destination ports. It has no control
/// flag, no PSH, a zero sequence number, no ACK, a zero window, no MSS option
/// and an empty payload. Callers overwrite whatever fields their particular
/// reply needs.
///
/// The IP half is an `IpRepr::Unspecified` with:
/// * the addresses swapped;
/// * protocol TCP;
/// * `payload_len` equal to the emitted length of that skeleton TCP segment.
pub(crate) fn reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) {
    let skeleton = TcpRepr {
        // Answer on the same connection tuple, seen from our side.
        src_port: tcp_repr.dst_port,
        dst_port: tcp_repr.src_port,
        control: TcpControl::None,
        push: false,
        seq_number: TcpSeqNumber(0),
        ack_number: None,
        window_len: 0,
        max_seg_size: None,
        payload: &[],
    };
    let segment_bytes = skeleton.buffer_len();

    let ip_header = IpRepr::Unspecified {
        src_addr: ip_repr.dst_addr(),
        dst_addr: ip_repr.src_addr(),
        protocol: IpProtocol::Tcp,
        payload_len: segment_bytes,
    };

    (ip_header, skeleton)
}

pub(crate) fn rst_reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) {
debug_assert!(tcp_repr.control != TcpControl::Rst);

let (ip_reply_repr, mut tcp_reply_repr) = Self::reply(ip_repr, tcp_repr);

// See https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for explanation
// of why we sometimes send an RST and sometimes an RST|ACK
tcp_reply_repr.control = TcpControl::Rst;
tcp_reply_repr.seq_number = tcp_repr.ack_number.unwrap_or_default();
if tcp_repr.control == TcpControl::Syn {
tcp_reply_repr.ack_number = Some(tcp_repr.seq_number +
tcp_repr.segment_len());
}

(ip_reply_repr, tcp_reply_repr)
}

pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr,
payload: &[u8]) -> Result<()> {
debug_assert!(ip_repr.protocol() == IpProtocol::Tcp);
17 changes: 17 additions & 0 deletions src/wire/ip.rs
Original file line number Diff line number Diff line change
@@ -177,6 +177,12 @@ pub enum IpRepr {
__Nonexhaustive
}

/// Lift an IPv4 representation into the protocol-independent `IpRepr`.
impl From<Ipv4Repr> for IpRepr {
    /// Wrap `ipv4_repr` in the `IpRepr::Ipv4` variant unchanged.
    fn from(ipv4_repr: Ipv4Repr) -> Self {
        IpRepr::Ipv4(ipv4_repr)
    }
}

impl IpRepr {
/// Return the protocol version.
pub fn version(&self) -> Version {
@@ -323,6 +329,17 @@ impl IpRepr {
unreachable!()
}
}

/// Return the number of bytes a packet emitted from this high-level
/// representation occupies: the header length (`buffer_len()`) followed by
/// the payload length (`payload_len()`).
///
/// # Panics
/// This function panics if invoked on an unspecified representation.
pub fn total_len(&self) -> usize {
    let header_len = self.buffer_len();
    let body_len = self.payload_len();
    header_len + body_len
}
}

pub mod checksum {
20 changes: 14 additions & 6 deletions src/wire/tcp.rs
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ use super::ip::checksum;
///
/// A sequence number is a monotonically advancing integer modulo 2<sup>32</sup>.
/// Sequence numbers do not have a discontiguity when compared pairwise across a signed overflow.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub struct SeqNumber(pub i32);

impl fmt::Display for SeqNumber {
@@ -275,7 +275,6 @@ impl<T: AsRef<[u8]>> Packet<T> {
}

/// Return the length of the segment, in terms of sequence space.
#[inline]
pub fn segment_len(&self) -> usize {
let data = self.buffer.as_ref();
let mut length = data.len() - self.header_len() as usize;
@@ -695,10 +694,9 @@ impl<'a> Repr<'a> {
}

/// Emit a high-level representation into a Transmission Control Protocol packet.
pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>,
src_addr: &IpAddress,
dst_addr: &IpAddress)
where T: AsRef<[u8]> + AsMut<[u8]> {
pub fn emit<T>(&self, packet: &mut Packet<&mut T>,
src_addr: &IpAddress, dst_addr: &IpAddress)
where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized {
packet.set_src_port(self.src_port);
packet.set_dst_port(self.dst_port);
packet.set_seq_number(self.seq_number);
@@ -727,6 +725,16 @@ impl<'a> Repr<'a> {
packet.payload_mut().copy_from_slice(self.payload);
packet.fill_checksum(src_addr, dst_addr)
}

/// Return how much sequence space this segment consumes: the payload length,
/// plus one when the control flag is SYN or FIN, since each of those flags
/// occupies a sequence number of its own.
pub fn segment_len(&self) -> usize {
    let flag_len = match self.control {
        Control::Syn | Control::Fin => 1,
        _ => 0,
    };
    self.payload.len() + flag_len
}
}

impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {

0 comments on commit bc2a894

Please sign in to comment.