Skip to content

Commit

Permalink
Rework TcpSocket::{send,recv} to remove need for precomputing size.
Browse files Browse the repository at this point in the history
Now these functions pass your closure the largest contiguous slice they can
grab, and the closure returns how many octets it took from that slice.
whitequark committed Oct 31, 2017
1 parent 1e18c03 commit 0091191
Showing 5 changed files with 121 additions and 78 deletions.
8 changes: 4 additions & 4 deletions examples/client.rs
Original file line number Diff line number Diff line change
@@ -66,17 +66,17 @@ fn main() {
tcp_active = socket.is_active();

if socket.may_recv() {
let data = {
let mut data = socket.recv(128).unwrap().to_owned();
let data = socket.recv(|data| {
let mut data = data.to_owned();
if data.len() > 0 {
debug!("recv data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat();
data.reverse();
data.extend(b"\n");
}
data
};
(data.len(), data)
}).unwrap();
if socket.can_send() && data.len() > 0 {
debug!("send data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
4 changes: 3 additions & 1 deletion examples/loopback.rs
Original file line number Diff line number Diff line change
@@ -133,7 +133,9 @@ fn main() {
}

if socket.can_recv() {
debug!("got {:?}", str::from_utf8(socket.recv(32).unwrap()).unwrap());
debug!("got {:?}", socket.recv(|buffer| {
(buffer.len(), str::from_utf8(buffer).unwrap())
}));
socket.close();
done = true;
}
22 changes: 12 additions & 10 deletions examples/server.rs
Original file line number Diff line number Diff line change
@@ -121,17 +121,17 @@ fn main() {
tcp_6970_active = socket.is_active();

if socket.may_recv() {
let data = {
let mut data = socket.recv(128).unwrap().to_owned();
let data = socket.recv(|buffer| {
let mut data = buffer.to_owned();
if data.len() > 0 {
debug!("tcp:6970 recv data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat();
data.reverse();
data.extend(b"\n");
}
data
};
(data.len(), data)
}).unwrap();
if socket.can_send() && data.len() > 0 {
debug!("tcp:6970 send data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
@@ -153,11 +153,12 @@ fn main() {
}

if socket.may_recv() {
if let Ok(data) = socket.recv(65535) {
if data.len() > 0 {
debug!("tcp:6971 recv {:?} octets", data.len());
socket.recv(|buffer| {
if buffer.len() > 0 {
debug!("tcp:6971 recv {:?} octets", buffer.len());
}
}
(buffer.len(), ())
}).unwrap();
} else if socket.may_send() {
socket.close();
}
@@ -171,14 +172,15 @@ fn main() {
}

if socket.may_send() {
if let Ok(data) = socket.send(65535) {
socket.send(|data| {
if data.len() > 0 {
debug!("tcp:6972 send {:?} octets", data.len());
for (i, b) in data.iter_mut().enumerate() {
*b = (i % 256) as u8;
}
}
}
(data.len(), ())
}).unwrap();
}
}

161 changes: 100 additions & 61 deletions src/socket/tcp.rs
Original file line number Diff line number Diff line change
@@ -593,15 +593,8 @@ impl<'a> TcpSocket<'a> {
!self.rx_buffer.is_empty()
}

/// Enqueue a sequence of octets to be sent, and return a pointer to it.
///
/// This function may return a slice smaller than the requested size in case
/// there is not enough contiguous free space in the transmit buffer, down to
/// an empty slice.
///
/// This function returns `Err(Error::Illegal) if the transmit half of
/// the connection is not open; see [may_send](#method.may_send).
pub fn send(&mut self, size: usize) -> Result<&mut [u8]> {
fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
if !self.may_send() { return Err(Error::Illegal) }

// The connection might have been idle for a long time, and so remote_last_ts
@@ -610,14 +603,26 @@ impl<'a> TcpSocket<'a> {
if self.tx_buffer.is_empty() { self.remote_last_ts = None }

let _old_length = self.tx_buffer.len();
let buffer = self.tx_buffer.enqueue_many(size);
if buffer.len() > 0 {
let (size, result) = f(&mut self.tx_buffer);
if size > 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: tx buffer: enqueueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
buffer.len(), _old_length + buffer.len());
size, _old_length + size);
}
Ok(buffer)
Ok(result)
}

/// Call `f` with the largest contiguous slice of octets in the transmit buffer,
/// and enqueue the amount of elements returned by `f`.
///
/// `f` receives a mutable slice of free space in the transmit buffer, writes the
/// outgoing data into it, and returns a pair `(size, result)`. `size` is the number
/// of octets it filled, which are then enqueued for transmission. `result` is handed
/// back to the caller. When the free space wraps around the end of the buffer, the
/// slice covers only one contiguous piece of it. Call `send` again to fill the rest.
///
/// `f` is not called at all when an error is returned.
///
/// This function returns `Err(Error::Illegal)` if the transmit half of
/// the connection is not open; see [may_send](#method.may_send).
pub fn send<'b, F, R>(&'b mut self, f: F) -> Result<R>
        where F: FnOnce(&'b mut [u8]) -> (usize, R) {
    self.send_impl(|tx_buffer| tx_buffer.enqueue_many_with(f))
}

/// Enqueue a sequence of octets to be sent, and fill it from a slice.
@@ -627,46 +632,42 @@ impl<'a> TcpSocket<'a> {
///
/// See also [send](#method.send).
pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
if !self.may_send() { return Err(Error::Illegal) }

// See above.
if self.tx_buffer.is_empty() { self.remote_last_ts = None }

let _old_length = self.tx_buffer.len();
let enqueued = self.tx_buffer.enqueue_slice(data);
if enqueued != 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: tx buffer: enqueueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
enqueued, _old_length + enqueued);
}
Ok(enqueued)
self.send_impl(|tx_buffer| {
let size = tx_buffer.enqueue_slice(data);
(size, size)
})
}

/// Dequeue a sequence of received octets, and return a pointer to it.
///
/// This function may return a slice smaller than the requested size in case
/// there are not enough octets queued in the receive buffer, down to
/// an empty slice.
///
/// This function returns `Err(Error::Illegal) if the receive half of
/// the connection is not open; see [may_recv](#method.may_recv).
pub fn recv(&mut self, size: usize) -> Result<&[u8]> {
pub fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
// We may have received some data inside the initial SYN, but until the connection
// is fully open we must not dequeue any data, as it may be overwritten by e.g.
// another (stale) SYN.
// another (stale) SYN. (We do not support TCP Fast Open.)
if !self.may_recv() { return Err(Error::Illegal) }

let _old_length = self.rx_buffer.len();
let buffer = self.rx_buffer.dequeue_many(size);
self.remote_seq_no += buffer.len();
if buffer.len() > 0 {
let (size, result) = f(&mut self.rx_buffer);
self.remote_seq_no += size;
if size > 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: rx buffer: dequeueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
buffer.len(), _old_length - buffer.len());
size, _old_length - size);
}
Ok(buffer)
Ok(result)
}


/// Call `f` with the largest contiguous slice of octets in the receive buffer,
/// and dequeue the amount of elements returned by `f`.
///
/// `f` receives a slice of the queued incoming octets and returns a pair
/// `(size, result)`. `size` is the number of octets it consumed, which are then
/// dequeued. `result` is handed back to the caller. When the queued data wraps
/// around the end of the buffer, the slice covers only one contiguous piece of it.
/// Call `recv` again to read the rest.
///
/// `f` is not called at all when an error is returned.
///
/// This function returns `Err(Error::Illegal)` if the receive half of
/// the connection is not open; see [may_recv](#method.may_recv).
pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result<R>
        where F: FnOnce(&'b mut [u8]) -> (usize, R) {
    self.recv_impl(|rx_buffer| rx_buffer.dequeue_many_with(f))
}

/// Dequeue a sequence of received octets, and fill a slice from it.
@@ -676,19 +677,10 @@ impl<'a> TcpSocket<'a> {
///
/// See also [recv](#method.recv).
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize> {
// See recv() above.
if !self.may_recv() { return Err(Error::Illegal) }

let _old_length = self.rx_buffer.len();
let dequeued = self.rx_buffer.dequeue_slice(data);
self.remote_seq_no += dequeued;
if dequeued > 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: rx buffer: dequeueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
dequeued, _old_length - dequeued);
}
Ok(dequeued)
self.recv_impl(|rx_buffer| {
let size = rx_buffer.dequeue_slice(data);
(size, size)
})
}

/// Peek at a sequence of received octets without removing them from
@@ -3145,15 +3137,21 @@ mod test {
..RECV_TEMPL
}]);
recv!(s, time 0, Err(Error::Exhausted));
assert_eq!(s.recv(3), Ok(&b"abc"[..]));
s.recv(|buffer| {
assert_eq!(&buffer[..3], b"abc");
(3, ())
}).unwrap();
recv!(s, time 0, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1 + 6),
window_len: 3,
..RECV_TEMPL
}));
recv!(s, time 0, Err(Error::Exhausted));
assert_eq!(s.recv(3), Ok(&b"def"[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"def");
(buffer.len(), ())
}).unwrap();
recv!(s, time 0, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1 + 6),
@@ -3457,7 +3455,10 @@ mod test {
ack_number: Some(REMOTE_SEQ + 1),
..RECV_TEMPL
})));
assert_eq!(s.recv(10), Ok(&b""[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"");
(buffer.len(), ())
}).unwrap();
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
@@ -3469,11 +3470,14 @@ mod test {
window_len: 58,
..RECV_TEMPL
})));
assert_eq!(s.recv(10), Ok(&b"abcdef"[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"abcdef");
(buffer.len(), ())
}).unwrap();
}

#[test]
fn test_buffer_wraparound() {
fn test_buffer_wraparound_rx() {
let mut s = socket_established();
s.rx_buffer = SocketBuffer::new(vec![0; 6]);
s.assembler = Assembler::new(s.rx_buffer.capacity());
@@ -3483,7 +3487,10 @@ mod test {
payload: &b"abc"[..],
..SEND_TEMPL
});
assert_eq!(s.recv(3), Ok(&b"abc"[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"abc");
(buffer.len(), ())
}).unwrap();
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1 + 3,
ack_number: Some(LOCAL_SEQ + 1),
@@ -3495,6 +3502,38 @@ mod test {
assert_eq!(data, &b"defghi"[..]);
}

// Checks that data enqueued across the wraparound point of the transmit buffer
// is still sent completely. With a 6-octet transmit buffer, 3 octets are sent and
// acknowledged, so the free space is split between the end and the start of the
// buffer. Then 6 more octets are enqueued. They must go out as two segments, one
// per contiguous piece of the buffer.
#[test]
fn test_buffer_wraparound_tx() {
let mut s = socket_established();
s.tx_buffer = SocketBuffer::new(vec![0; 6]);
// Fill the first half of the buffer and transmit it.
assert_eq!(s.send_slice(b"abc"), Ok(3));
recv!(s, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1),
payload: &b"abc"[..],
..RECV_TEMPL
}));
// The peer acknowledges "abc", freeing its 3 octets at the start of the buffer.
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1 + 3),
..SEND_TEMPL
});
// All 6 octets fit, but they are stored as "def" at the end and "ghi" wrapped
// around to the start.
assert_eq!(s.send_slice(b"defghi"), Ok(6));
recv!(s, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1 + 3,
ack_number: Some(REMOTE_SEQ + 1),
payload: &b"def"[..],
..RECV_TEMPL
}));
// "defghi" not contiguous in tx buffer
// so the wrapped-around remainder is emitted as its own segment.
recv!(s, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1 + 3 + 3,
ack_number: Some(REMOTE_SEQ + 1),
payload: &b"ghi"[..],
..RECV_TEMPL
}));
}

// =========================================================================================//
// Tests for packet filtering.
// =========================================================================================//
4 changes: 2 additions & 2 deletions src/socket/udp.rs
Original file line number Diff line number Diff line change
@@ -195,8 +195,8 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
Ok((&packet_buf.as_ref(), packet_buf.endpoint))
}

/// Dequeue a packet received from a remote endpoint, and return the endpoint as well
/// as copy the payload into the given slice.
/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
/// and return the amount of octets copied as well as the endpoint.
///
/// See also [recv](#method.recv).
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> {

0 comments on commit 0091191

Please sign in to comment.