Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: smoltcp-rs/smoltcp
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 44aa8db751f1
Choose a base ref
...
head repository: smoltcp-rs/smoltcp
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: a81dd55cca5e
Choose a head ref
  • 2 commits
  • 1 file changed
  • 1 contributor

Commits on Dec 26, 2016

  1. Copy the full SHA
    0494c9f View commit details
  2. Copy the full SHA
    a81dd55 View commit details
Showing with 231 additions and 81 deletions.
  1. +231 −81 src/socket/tcp.rs
312 changes: 231 additions & 81 deletions src/socket/tcp.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use core::fmt;

use Error;
use Managed;
use wire::{IpProtocol, IpEndpoint};
use wire::{IpProtocol, IpAddress, IpEndpoint};
use wire::{TcpPacket, TcpRepr, TcpControl};
use socket::{Socket, IpRepr, IpPayload};

@@ -162,6 +162,7 @@ impl Retransmit {
#[derive(Debug)]
pub struct TcpSocket<'a> {
state: State,
listen_address: IpAddress,
local_endpoint: IpEndpoint,
remote_endpoint: IpEndpoint,
local_seq_no: i32,
@@ -185,6 +186,7 @@ impl<'a> TcpSocket<'a> {

Socket::Tcp(TcpSocket {
state: State::Closed,
listen_address: IpAddress::default(),
local_endpoint: IpEndpoint::default(),
remote_endpoint: IpEndpoint::default(),
local_seq_no: 0,
@@ -235,6 +237,7 @@ impl<'a> TcpSocket<'a> {
pub fn listen(&mut self, endpoint: IpEndpoint) {
assert!(self.state == State::Closed);

self.listen_address = endpoint.addr;
self.local_endpoint = endpoint;
self.remote_endpoint = IpEndpoint::default();
self.set_state(State::Listen);
@@ -260,20 +263,37 @@ impl<'a> TcpSocket<'a> {

// Reject packets addressed to a closed socket.
if self.state == State::Closed {
net_trace!("tcp:{}:{}:{}: packet sent to a closed socket",
net_trace!("tcp:{}:{}:{}: packet received by a closed socket",
self.local_endpoint, ip_repr.src_addr(), repr.src_port);
return Err(Error::Malformed)
}

// Reject unacceptable acknowledgements.
match (self.state, repr) {
// The initial SYN cannot contain an acknowledgement.
// The initial SYN (or whatever) cannot contain an acknowledgement.
(State::Listen, TcpRepr { ack_number: Some(_), .. }) => {
net_trace!("tcp:{}:{}: ACK in initial SYN",
net_trace!("tcp:{}:{}: ACK received by a socket in LISTEN state",
self.local_endpoint, self.remote_endpoint);
return Err(Error::Malformed)
}
(State::Listen, TcpRepr { ack_number: None, .. }) => (),
// A reset received in response to initial SYN is acceptable if it acknowledges
// the initial SYN.
(State::SynSent, TcpRepr { control: TcpControl::Rst, ack_number: None, .. }) => {
net_trace!("tcp:{}:{}: unacceptable RST (expecting RST|ACK) \
in response to initial SYN",
self.local_endpoint, self.remote_endpoint);
return Err(Error::Malformed)
}
(State::SynSent, TcpRepr {
control: TcpControl::Rst, ack_number: Some(ack_number), ..
}) => {
if ack_number != self.local_seq_no {
net_trace!("tcp:{}:{}: unacceptable RST|ACK in response to initial SYN",
self.local_endpoint, self.remote_endpoint);
return Err(Error::Malformed)
}
}
// Every packet after the initial SYN must be an acknowledgement.
(_, TcpRepr { ack_number: None, .. }) => {
net_trace!("tcp:{}:{}: expecting an ACK",
@@ -295,27 +315,53 @@ impl<'a> TcpSocket<'a> {
}
}

// Reject segments not occupying a valid portion of the receive window.
// For now, do not try to reassemble out-of-order segments.
if self.state != State::Listen {
let next_remote_seq = self.remote_seq_no + self.rx_buffer.len() as i32 +
repr.control.len();
if repr.seq_number - next_remote_seq > 0 {
net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
self.local_endpoint, self.remote_endpoint,
repr.seq_number, next_remote_seq);
return Err(Error::Malformed)
} else if repr.seq_number - next_remote_seq != 0 {
net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
self.local_endpoint, self.remote_endpoint,
repr.seq_number, next_remote_seq);
return Ok(())
match (self.state, repr) {
// In LISTEN and SYN_SENT states, we have not yet synchronized with the remote end.
(State::Listen, _) => (),
(State::SynSent, _) => (),
// In all other states, segments must occupy a valid portion of the receive window.
// For now, do not try to reassemble out-of-order segments.
(_, TcpRepr { control, seq_number, .. }) => {
let next_remote_seq = self.remote_seq_no + self.rx_buffer.len() as i32 +
control.len();
if seq_number - next_remote_seq > 0 {
net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
self.local_endpoint, self.remote_endpoint,
seq_number, next_remote_seq);
return Err(Error::Malformed)
} else if seq_number - next_remote_seq != 0 {
net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
self.local_endpoint, self.remote_endpoint,
seq_number, next_remote_seq);
return Ok(())
}
}
}

// Validate and update the state.
let old_state = self.state;
match (self.state, repr) {
// RSTs are ignored in the LISTEN state.
(State::Listen, TcpRepr { control: TcpControl::Rst, .. }) =>
return Ok(()),

// RSTs in SYN_RECEIVED flip the socket back to the LISTEN state.
(State::SynReceived, TcpRepr { control: TcpControl::Rst, .. }) => {
self.local_endpoint.addr = self.listen_address;
self.remote_endpoint = IpEndpoint::default();
self.set_state(State::Listen);
return Ok(())
}

// RSTs in any other state close the socket.
(_, TcpRepr { control: TcpControl::Rst, .. }) => {
self.local_endpoint = IpEndpoint::default();
self.remote_endpoint = IpEndpoint::default();
self.set_state(State::Closed);
return Ok(())
}

// SYN packets in the LISTEN state change it to SYN_RECEIVED.
(State::Listen, TcpRepr {
src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
}) => {
@@ -327,11 +373,13 @@ impl<'a> TcpSocket<'a> {
self.retransmit.reset()
}

// SYN|ACK packets in the SYN_RECEIVED state change it to ESTABLISHED.
(State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => {
self.set_state(State::Established);
self.retransmit.reset()
}

// ACK packets in ESTABLISHED state do nothing.
(State::Established, TcpRepr { control: TcpControl::None, .. }) => (),

_ => {
@@ -569,6 +617,9 @@ mod test {
}
}

// =========================================================================================//
// Tests for the CLOSED state.
// =========================================================================================//
#[test]
fn test_closed() {
let mut s = socket();
@@ -580,137 +631,236 @@ mod test {
}, Err(Error::Rejected));
}

#[test]
fn test_listen() {
// =========================================================================================//
// Tests for the LISTEN state.
// =========================================================================================//
fn socket_listen() -> TcpSocket<'static> {
let mut s = socket();
s.listen(IpEndpoint::new(IpAddress::default(), LOCAL_PORT));
assert_eq!(s.state(), State::Listen);
s.state = State::Listen;
s.local_endpoint = IpEndpoint::new(IpAddress::default(), LOCAL_PORT);
s
}

#[test]
fn test_handshake() {
let mut s = socket();
s.state = State::Listen;
s.local_endpoint = IpEndpoint::new(IpAddress::default(), LOCAL_PORT);
fn test_listen_syn_no_ack() {
let mut s = socket_listen();
send!(s, TcpRepr {
control: TcpControl::Syn,
seq_number: REMOTE_SEQ,
ack_number: Some(LOCAL_SEQ),
..SEND_TEMPL
}, Err(Error::Malformed));
assert_eq!(s.state, State::Listen);
}

#[test]
fn test_listen_rst() {
let mut s = socket_listen();
send!(s, [TcpRepr {
control: TcpControl::Syn,
control: TcpControl::Rst,
seq_number: REMOTE_SEQ,
ack_number: None,
..SEND_TEMPL
}]);
assert_eq!(s.state(), State::SynReceived);
assert_eq!(s.local_endpoint(), LOCAL_END);
assert_eq!(s.remote_endpoint(), REMOTE_END);
recv!(s, [TcpRepr {
control: TcpControl::Syn,
seq_number: LOCAL_SEQ,
ack_number: Some(REMOTE_SEQ + 1),
..RECV_TEMPL
}]);
}

// =========================================================================================//
// Tests for the SYN_RECEIVED state.
// =========================================================================================//
fn socket_syn_received() -> TcpSocket<'static> {
let mut s = socket();
s.state = State::SynReceived;
s.local_endpoint = LOCAL_END;
s.remote_endpoint = REMOTE_END;
s.local_seq_no = LOCAL_SEQ;
s.remote_seq_no = REMOTE_SEQ;
s
}

#[test]
fn test_syn_received_rst() {
let mut s = socket_syn_received();
send!(s, [TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
control: TcpControl::Rst,
seq_number: REMOTE_SEQ,
ack_number: Some(LOCAL_SEQ),
..SEND_TEMPL
}]);
assert_eq!(s.state(), State::Established);
assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
assert_eq!(s.state, State::Listen);
assert_eq!(s.local_endpoint, IpEndpoint::new(IpAddress::Unspecified, LOCAL_END.port));
assert_eq!(s.remote_endpoint, IpEndpoint::default());
}

#[test]
fn test_no_ack() {
// =========================================================================================//
// Tests for the SYN_SENT state.
// =========================================================================================//
fn socket_syn_sent() -> TcpSocket<'static> {
let mut s = socket();
s.state = State::Established;
s.state = State::SynSent;
s.local_endpoint = LOCAL_END;
s.remote_endpoint = REMOTE_END;
s.local_seq_no = LOCAL_SEQ + 1;
s.remote_seq_no = REMOTE_SEQ + 1;
s.local_seq_no = LOCAL_SEQ;
s
}

#[test]
fn test_syn_sent_rst() {
let mut s = socket_syn_sent();
send!(s, [TcpRepr {
control: TcpControl::Rst,
seq_number: REMOTE_SEQ,
ack_number: Some(LOCAL_SEQ),
..SEND_TEMPL
}]);
assert_eq!(s.state, State::Closed);
}

#[test]
fn test_syn_sent_rst_no_ack() {
let mut s = socket_syn_sent();
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
control: TcpControl::Rst,
seq_number: REMOTE_SEQ,
ack_number: None,
..SEND_TEMPL
}, Err(Error::Malformed));
assert_eq!(s.state, State::SynSent);
}

#[test]
fn test_bad_ack_listen() {
let mut s = socket();
s.state = State::Listen;
s.local_endpoint = IpEndpoint::new(IpAddress::default(), LOCAL_PORT);

fn test_syn_sent_rst_bad_ack() {
let mut s = socket_syn_sent();
send!(s, TcpRepr {
control: TcpControl::Syn,
control: TcpControl::Rst,
seq_number: REMOTE_SEQ,
ack_number: Some(LOCAL_SEQ),
ack_number: Some(1234),
..SEND_TEMPL
}, Err(Error::Malformed));
assert_eq!(s.state, State::SynSent);
}

#[test]
fn test_bad_ack_established() {
// =========================================================================================//
// Tests for the ESTABLISHED state.
// =========================================================================================//
fn socket_established() -> TcpSocket<'static> {
let mut s = socket();
s.state = State::Established;
s.state = State::Established;
s.local_endpoint = LOCAL_END;
s.remote_endpoint = REMOTE_END;
s.local_seq_no = LOCAL_SEQ + 1;
s.remote_seq_no = REMOTE_SEQ + 1;
s.tx_buffer.enqueue_slice(b"abcdef");
s
}

#[test]
fn test_established_data() {
let mut s = socket_established();
send!(s, [TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
payload: &b"abcdef"[..],
..SEND_TEMPL
}]);
recv!(s, [TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1 + 6),
window_len: 122,
..RECV_TEMPL
}]);
assert_eq!(s.rx_buffer.dequeue(6), &b"abcdef"[..]);
}

#[test]
fn test_established_no_ack() {
let mut s = socket_established();
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: None,
..SEND_TEMPL
}, Err(Error::Malformed));
}

#[test]
fn test_established_bad_ack() {
let mut s = socket_established();
// Already acknowledged data.
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ - 1),
..SEND_TEMPL
}, Err(Error::Malformed));

assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
// Data not yet transmitted.
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 10),
..SEND_TEMPL
}, Err(Error::Malformed));
assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
}

#[test]
fn test_unacceptable_seq() {
let mut s = socket();
s.state = State::Established;
s.local_endpoint = LOCAL_END;
s.remote_endpoint = REMOTE_END;
s.local_seq_no = LOCAL_SEQ + 1;
s.remote_seq_no = REMOTE_SEQ + 1;

fn test_established_bad_seq() {
let mut s = socket_established();
// Data outside of receive window.
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1 + 256,
ack_number: Some(LOCAL_SEQ + 1),
..SEND_TEMPL
}, Err(Error::Malformed));
assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
}

#[test]
fn test_recv_data() {
let mut s = socket();
s.state = State::Established;
s.local_endpoint = LOCAL_END;
s.remote_endpoint = REMOTE_END;
s.local_seq_no = LOCAL_SEQ + 1;
s.remote_seq_no = REMOTE_SEQ + 1;

fn test_established_rst() {
let mut s = socket_established();
send!(s, [TcpRepr {
control: TcpControl::Rst,
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
payload: &b"abcdef"[..],
..SEND_TEMPL
}]);
assert_eq!(s.state, State::Closed);
}

// =========================================================================================//
// Tests for transitioning through multiple states.
// =========================================================================================//
#[test]
fn test_listen() {
let mut s = socket();
s.listen(IpEndpoint::new(IpAddress::default(), LOCAL_PORT));
assert_eq!(s.state, State::Listen);
}

#[test]
fn test_three_way_handshake() {
let mut s = socket();
s.state = State::Listen;
s.local_endpoint = IpEndpoint::new(IpAddress::default(), LOCAL_PORT);

send!(s, [TcpRepr {
control: TcpControl::Syn,
seq_number: REMOTE_SEQ,
ack_number: None,
..SEND_TEMPL
}]);
assert_eq!(s.state(), State::SynReceived);
assert_eq!(s.local_endpoint(), LOCAL_END);
assert_eq!(s.remote_endpoint(), REMOTE_END);
recv!(s, [TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1 + 6),
window_len: 122,
control: TcpControl::Syn,
seq_number: LOCAL_SEQ,
ack_number: Some(REMOTE_SEQ + 1),
..RECV_TEMPL
}]);
assert_eq!(s.rx_buffer.dequeue(6), &b"abcdef"[..]);
send!(s, [TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
..SEND_TEMPL
}]);
assert_eq!(s.state(), State::Established);
assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
}
}