@@ -25,14 +25,19 @@ impl<'a> SocketBuffer<'a> {
25
25
}
26
26
}
27
27
28
- /// Return the amount of octets enqueued in the buffer.
28
+ /// Return the maximum amount of octets that can be enqueued in the buffer.
29
+ pub fn capacity ( & self ) -> usize {
30
+ self . storage . len ( )
31
+ }
32
+
33
+ /// Return the amount of octets already enqueued in the buffer.
29
34
pub fn len ( & self ) -> usize {
30
35
self . length
31
36
}
32
37
33
- /// Return the maximum amount of octets that can be enqueued in the buffer.
34
- pub fn capacity ( & self ) -> usize {
35
- self . storage . len ( )
38
+ /// Return the amount of octets that remain to be enqueued in the buffer.
39
+ pub fn window ( & self ) -> usize {
40
+ self . capacity ( ) - self . len ( )
36
41
}
37
42
38
43
/// Enqueue a slice of octets up to the given size into the buffer, and return a pointer
@@ -135,14 +140,14 @@ impl Retransmit {
135
140
/// A Transmission Control Protocol data stream.
136
141
#[ derive( Debug ) ]
137
142
pub struct TcpSocket < ' a > {
138
- state : State ,
139
- local_end : IpEndpoint ,
140
- remote_end : IpEndpoint ,
141
- local_seq_no : i32 ,
142
- remote_seq_no : i32 ,
143
- retransmit : Retransmit ,
144
- rx_buffer : SocketBuffer < ' a > ,
145
- tx_buffer : SocketBuffer < ' a >
143
+ state : State ,
144
+ local_endpoint : IpEndpoint ,
145
+ remote_endpoint : IpEndpoint ,
146
+ local_seq_no : i32 ,
147
+ remote_seq_no : i32 ,
148
+ retransmit : Retransmit ,
149
+ rx_buffer : SocketBuffer < ' a > ,
150
+ tx_buffer : SocketBuffer < ' a >
146
151
}
147
152
148
153
impl < ' a > TcpSocket < ' a > {
@@ -156,14 +161,14 @@ impl<'a> TcpSocket<'a> {
156
161
}
157
162
158
163
Socket :: Tcp ( TcpSocket {
159
- state : State :: Closed ,
160
- local_end : IpEndpoint :: default ( ) ,
161
- remote_end : IpEndpoint :: default ( ) ,
162
- local_seq_no : 0 ,
163
- remote_seq_no : 0 ,
164
- retransmit : Retransmit :: new ( ) ,
165
- tx_buffer : tx_buffer. into ( ) ,
166
- rx_buffer : rx_buffer. into ( )
164
+ state : State :: Closed ,
165
+ local_endpoint : IpEndpoint :: default ( ) ,
166
+ remote_endpoint : IpEndpoint :: default ( ) ,
167
+ local_seq_no : 0 ,
168
+ remote_seq_no : 0 ,
169
+ retransmit : Retransmit :: new ( ) ,
170
+ tx_buffer : tx_buffer. into ( ) ,
171
+ rx_buffer : rx_buffer. into ( )
167
172
} )
168
173
}
169
174
@@ -176,23 +181,23 @@ impl<'a> TcpSocket<'a> {
176
181
/// Return the local endpoint.
177
182
#[ inline( always) ]
178
183
pub fn local_endpoint ( & self ) -> IpEndpoint {
179
- self . local_end
184
+ self . local_endpoint
180
185
}
181
186
182
187
/// Return the remote endpoint.
183
188
#[ inline( always) ]
184
189
pub fn remote_endpoint ( & self ) -> IpEndpoint {
185
- self . remote_end
190
+ self . remote_endpoint
186
191
}
187
192
188
193
fn set_state ( & mut self , state : State ) {
189
194
if self . state != state {
190
- if self . remote_end . addr . is_unspecified ( ) {
195
+ if self . remote_endpoint . addr . is_unspecified ( ) {
191
196
net_trace ! ( "tcp:{}: state={}→{}" ,
192
- self . local_end , self . state, state) ;
197
+ self . local_endpoint , self . state, state) ;
193
198
} else {
194
199
net_trace ! ( "tcp:{}:{}: state={}→{}" ,
195
- self . local_end , self . remote_end , self . state, state) ;
200
+ self . local_endpoint , self . remote_endpoint , self . state, state) ;
196
201
}
197
202
}
198
203
self . state = state
@@ -205,8 +210,8 @@ impl<'a> TcpSocket<'a> {
205
210
pub fn listen ( & mut self , endpoint : IpEndpoint ) {
206
211
assert ! ( self . state == State :: Closed ) ;
207
212
208
- self . local_end = endpoint;
209
- self . remote_end = IpEndpoint :: default ( ) ;
213
+ self . local_endpoint = endpoint;
214
+ self . remote_endpoint = IpEndpoint :: default ( ) ;
210
215
self . set_state ( State :: Listen ) ;
211
216
}
212
217
@@ -219,49 +224,83 @@ impl<'a> TcpSocket<'a> {
219
224
let packet = try!( TcpPacket :: new ( payload) ) ;
220
225
let repr = try!( TcpRepr :: parse ( & packet, src_addr, dst_addr) ) ;
221
226
222
- if self . local_end . port != repr. dst_port { return Err ( Error :: Rejected ) }
223
- if !self . local_end . addr . is_unspecified ( ) &&
224
- self . local_end . addr != * dst_addr { return Err ( Error :: Rejected ) }
227
+ // Reject packets with a wrong destination.
228
+ if self . local_endpoint . port != repr. dst_port { return Err ( Error :: Rejected ) }
229
+ if !self . local_endpoint . addr . is_unspecified ( ) &&
230
+ self . local_endpoint . addr != * dst_addr { return Err ( Error :: Rejected ) }
225
231
226
- if self . remote_end . port != 0 &&
227
- self . remote_end . port != repr. src_port { return Err ( Error :: Rejected ) }
228
- if !self . remote_end . addr . is_unspecified ( ) &&
229
- self . remote_end . addr != * src_addr { return Err ( Error :: Rejected ) }
232
+ // Reject packets from a source to which we aren't connected.
233
+ if self . remote_endpoint . port != 0 &&
234
+ self . remote_endpoint . port != repr. src_port { return Err ( Error :: Rejected ) }
235
+ if !self . remote_endpoint . addr . is_unspecified ( ) &&
236
+ self . remote_endpoint . addr != * src_addr { return Err ( Error :: Rejected ) }
230
237
231
238
match ( self . state , repr) {
232
- ( State :: Closed , _) => Err ( Error :: Rejected ) ,
239
+ // Reject packets addressed to a closed socket.
240
+ ( State :: Closed , TcpRepr { src_port, .. } ) => {
241
+ net_trace ! ( "tcp:{}:{}:{}: packet sent to a closed socket" ,
242
+ self . local_endpoint, src_addr, src_port) ;
243
+ return Err ( Error :: Malformed )
244
+ }
245
+ // Don't care about ACKs when performing the handshake.
246
+ ( State :: Listen , _) => ( ) ,
247
+ ( State :: SynSent , _) => ( ) ,
248
+ // Every packet after the initial SYN must be an acknowledgement.
249
+ ( _, TcpRepr { ack_number : None , .. } ) => {
250
+ net_trace ! ( "tcp:{}:{}: expecting an ACK packet" ,
251
+ self . local_endpoint, self . remote_endpoint) ;
252
+ return Err ( Error :: Malformed )
253
+ }
254
+ // Reject unacceptable acknowledgements.
255
+ ( state, TcpRepr { ack_number : Some ( ack_number) , .. } ) => {
256
+ let unacknowledged =
257
+ if state != State :: SynReceived { self . rx_buffer . len ( ) as i32 } else { 1 } ;
258
+ if !( ack_number - self . local_seq_no > 0 &&
259
+ ack_number - ( self . local_seq_no + unacknowledged) <= 0 ) {
260
+ net_trace ! ( "tcp:{}:{}: unacceptable ACK ({} not in {}..{})" ,
261
+ self . local_endpoint, self . remote_endpoint,
262
+ ack_number, self . local_seq_no, self . local_seq_no + unacknowledged) ;
263
+ return Err ( Error :: Malformed )
264
+ }
265
+ }
266
+ }
233
267
268
+ // Handle the incoming packet.
269
+ match ( self . state , repr) {
234
270
( State :: Listen , TcpRepr {
235
- src_port, dst_port, control : TcpControl :: Syn , seq_number, ack_number : None , ..
271
+ src_port, dst_port, control : TcpControl :: Syn , seq_number, ack_number : None ,
272
+ payload, ..
236
273
} ) => {
237
- self . local_end = IpEndpoint :: new ( * dst_addr, dst_port) ;
238
- self . remote_end = IpEndpoint :: new ( * src_addr, src_port) ;
239
- self . remote_seq_no = seq_number;
240
- // FIXME: use something more secure
241
- self . local_seq_no = !seq_number;
274
+ // FIXME: don't do this, just enqueue the payload
275
+ if payload. len ( ) > 0 {
276
+ net_trace ! ( "tcp:{}:{}: SYN with payload rejected" ,
277
+ IpEndpoint :: new( * dst_addr, dst_port) ,
278
+ IpEndpoint :: new( * src_addr, src_port) ) ;
279
+ return Err ( Error :: Malformed )
280
+ }
281
+
282
+ self . local_endpoint = IpEndpoint :: new ( * dst_addr, dst_port) ;
283
+ self . remote_endpoint = IpEndpoint :: new ( * src_addr, src_port) ;
284
+ self . remote_seq_no = seq_number + 1 ;
285
+ self . local_seq_no = -seq_number; // FIXME: use something more secure
242
286
self . set_state ( State :: SynReceived ) ;
243
287
244
- // FIXME: queue data from SYN
245
288
self . retransmit . reset ( ) ;
246
289
Ok ( ( ) )
247
290
}
248
291
249
292
( State :: SynReceived , TcpRepr {
250
293
control : TcpControl :: None , ack_number : Some ( ack_number) , ..
251
294
} ) => {
252
- if ack_number != self . local_seq_no + 1 { return Err ( Error :: Rejected ) }
295
+ self . local_seq_no = ack_number ;
253
296
self . set_state ( State :: Established ) ;
254
297
255
298
// FIXME: queue data from ACK
256
- // FIXME: update sequence numbers
257
299
self . retransmit . reset ( ) ;
258
300
Ok ( ( ) )
259
301
}
260
302
261
- _ => {
262
- // This will cause the interface to reply with an RST.
263
- Err ( Error :: Rejected )
264
- }
303
+ _ => Err ( Error :: Malformed )
265
304
}
266
305
}
267
306
@@ -270,12 +309,12 @@ impl<'a> TcpSocket<'a> {
270
309
IpProtocol , & PacketRepr ) -> Result < ( ) , Error > )
271
310
-> Result < ( ) , Error > {
272
311
let mut repr = TcpRepr {
273
- src_port : self . local_end . port ,
274
- dst_port : self . remote_end . port ,
312
+ src_port : self . local_endpoint . port ,
313
+ dst_port : self . remote_endpoint . port ,
275
314
control : TcpControl :: None ,
276
315
seq_number : 0 ,
277
316
ack_number : None ,
278
- window_len : ( self . rx_buffer . capacity ( ) - self . rx_buffer . len ( ) ) as u16 ,
317
+ window_len : self . rx_buffer . window ( ) as u16 ,
279
318
payload : & [ ]
280
319
} ;
281
320
@@ -291,9 +330,9 @@ impl<'a> TcpSocket<'a> {
291
330
if !self . retransmit . check ( ) { return Err ( Error :: Exhausted ) }
292
331
repr. control = TcpControl :: Syn ;
293
332
repr. seq_number = self . local_seq_no ;
294
- repr. ack_number = Some ( self . remote_seq_no + 1 ) ;
333
+ repr. ack_number = Some ( self . remote_seq_no ) ;
295
334
net_trace ! ( "tcp:{}:{}: SYN sent" ,
296
- self . local_end , self . remote_end ) ;
335
+ self . local_endpoint , self . remote_endpoint ) ;
297
336
}
298
337
299
338
State :: Established => {
@@ -304,7 +343,7 @@ impl<'a> TcpSocket<'a> {
304
343
_ => unreachable ! ( )
305
344
}
306
345
307
- f ( & self . local_end . addr , & self . remote_end . addr , IpProtocol :: Tcp , & repr)
346
+ f ( & self . local_endpoint . addr , & self . remote_endpoint . addr , IpProtocol :: Tcp , & repr)
308
347
}
309
348
}
310
349
@@ -342,7 +381,7 @@ mod test {
342
381
const LOCAL_END : IpEndpoint = IpEndpoint :: new ( LOCAL_IP , LOCAL_PORT ) ;
343
382
const REMOTE_END : IpEndpoint = IpEndpoint :: new ( REMOTE_IP , REMOTE_PORT ) ;
344
383
const LOCAL_SEQ : i32 = 100 ;
345
- const REMOTE_SEQ : i32 = ! 100 ;
384
+ const REMOTE_SEQ : i32 = - 100 ;
346
385
347
386
const SEND_TEMPL : TcpRepr < ' static > = TcpRepr {
348
387
src_port : REMOTE_PORT , dst_port : LOCAL_PORT ,
@@ -434,22 +473,24 @@ mod test {
434
473
435
474
send ! ( s, TcpRepr {
436
475
control: TcpControl :: Syn ,
437
- seq_number: LOCAL_SEQ , ack_number: None ,
476
+ seq_number: REMOTE_SEQ , ack_number: None ,
438
477
..SEND_TEMPL
439
478
} ) ;
440
479
assert_eq ! ( s. state( ) , State :: SynReceived ) ;
441
480
assert_eq ! ( s. local_endpoint( ) , LOCAL_END ) ;
442
481
assert_eq ! ( s. remote_endpoint( ) , REMOTE_END ) ;
443
482
recv ! ( s, TcpRepr {
444
483
control: TcpControl :: Syn ,
445
- seq_number: REMOTE_SEQ , ack_number: Some ( LOCAL_SEQ + 1 ) ,
484
+ seq_number: LOCAL_SEQ , ack_number: Some ( REMOTE_SEQ + 1 ) ,
446
485
..RECV_TEMPL
447
486
} ) ;
448
487
send ! ( s, TcpRepr {
449
488
control: TcpControl :: None ,
450
- seq_number: LOCAL_SEQ + 1 , ack_number: Some ( REMOTE_SEQ + 1 ) ,
489
+ seq_number: REMOTE_SEQ + 1 , ack_number: Some ( LOCAL_SEQ + 1 ) ,
451
490
..SEND_TEMPL
452
491
} ) ;
453
492
assert_eq ! ( s. state( ) , State :: Established ) ;
493
+ assert_eq ! ( s. local_seq_no, LOCAL_SEQ + 1 ) ;
494
+ assert_eq ! ( s. remote_seq_no, REMOTE_SEQ + 1 ) ;
454
495
}
455
496
}
0 commit comments