use std::io::Read; use std::time::Duration; use aes::Aes128; use anyhow::{bail, ensure, Context}; use cfb8::cipher::{AsyncStreamCipher, NewCipher}; use cfb8::Cfb8; use flate2::bufread::{ZlibDecoder, ZlibEncoder}; use flate2::Compression; use log::{log_enabled, Level}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::time::timeout; use super::packets::{DecodePacket, EncodePacket}; use crate::protocol::{Decode, Encode, VarInt, MAX_PACKET_SIZE}; pub struct Encoder { write: W, buf: Vec, compress_buf: Vec, compression_threshold: Option, cipher: Option, timeout: Duration, } impl Encoder { pub fn new(write: W, timeout: Duration) -> Self { Self { write, buf: Vec::new(), compress_buf: Vec::new(), compression_threshold: None, cipher: None, timeout, } } pub async fn write_packet(&mut self, packet: &impl EncodePacket) -> anyhow::Result<()> { timeout(self.timeout, self.write_packet_impl(packet)).await? } async fn write_packet_impl(&mut self, packet: &impl EncodePacket) -> anyhow::Result<()> { self.buf.clear(); packet.encode_packet(&mut self.buf)?; let data_len = self.buf.len(); ensure!(data_len <= i32::MAX as usize, "bad packet data length"); if let Some(threshold) = self.compression_threshold { if data_len >= threshold as usize { let mut z = ZlibEncoder::new(self.buf.as_slice(), Compression::best()); self.compress_buf.clear(); z.read_to_end(&mut self.compress_buf)?; let data_len_len = VarInt(data_len as i32).written_size(); let packet_len = data_len_len + self.compress_buf.len(); ensure!(packet_len <= MAX_PACKET_SIZE as usize, "bad packet length"); self.buf.clear(); VarInt(packet_len as i32).encode(&mut self.buf)?; VarInt(data_len as i32).encode(&mut self.buf)?; self.buf.extend_from_slice(&self.compress_buf); } else { let packet_len = VarInt(0).written_size() + data_len; ensure!(packet_len <= MAX_PACKET_SIZE as usize, "bad packet length"); self.buf.clear(); VarInt(packet_len as i32).encode(&mut self.buf)?; VarInt(0).encode(&mut self.buf)?; packet.encode_packet(&mut self.buf)?; } } else { let packet_len = data_len; ensure!(packet_len <= MAX_PACKET_SIZE as usize, "bad packet length"); self.buf.clear(); VarInt(packet_len as i32).encode(&mut self.buf)?; packet.encode_packet(&mut self.buf)?; } if let Some(cipher) = &mut self.cipher { cipher.encrypt(&mut self.buf); } self.write.write_all(&self.buf).await?; Ok(()) } pub fn enable_encryption(&mut self, key: &[u8; 16]) { self.cipher = Some(NewCipher::new(key.into(), key.into())); } pub fn enable_compression(&mut self, threshold: u32) { self.compression_threshold = Some(threshold); } pub fn into_inner(self) -> W { self.write } } pub struct Decoder { read: R, buf: Vec, decompress_buf: Vec, compression_threshold: Option, cipher: Option, timeout: Duration, } impl Decoder { pub fn new(read: R, timeout: Duration) -> Self { Self { read, buf: Vec::new(), decompress_buf: Vec::new(), compression_threshold: None, cipher: None, timeout, } } pub async fn read_packet(&mut self) -> anyhow::Result

{ timeout(self.timeout, self.read_packet_impl()).await? } async fn read_packet_impl(&mut self) -> anyhow::Result

{ let packet_len = self .read_var_int_async() .await .context("reading packet length")?; ensure!( (0..=MAX_PACKET_SIZE).contains(&packet_len), "invalid packet length of {packet_len}." ); self.buf.resize(packet_len as usize, 0); self.read .read_exact(&mut self.buf) .await .context("reading packet body")?; if let Some(cipher) = &mut self.cipher { cipher.decrypt(&mut self.buf); } let mut packet_contents = self.buf.as_slice(); // Compression enabled? let packet = if self.compression_threshold.is_some() { // The length of the packet data once uncompressed (zero indicates no // compression). let data_len = VarInt::decode(&mut packet_contents) .context("reading data length (once uncompressed)")? .0; ensure!( (0..=MAX_PACKET_SIZE).contains(&data_len), "invalid packet data length of {data_len}." ); if data_len != 0 { let mut z = ZlibDecoder::new(&mut packet_contents); self.decompress_buf.resize(data_len as usize, 0); z.read_exact(&mut self.decompress_buf) .context("uncompressing packet body")?; let mut uncompressed = self.decompress_buf.as_slice(); let packet = P::decode_packet(&mut uncompressed) .context("decoding packet after uncompressing")?; ensure!( uncompressed.is_empty(), "packet contents were not read completely" ); packet } else { P::decode_packet(&mut packet_contents).context("decoding packet")? } } else { P::decode_packet(&mut packet_contents).context("decoding packet")? }; if !packet_contents.is_empty() { if log_enabled!(Level::Debug) { log::debug!("complete packet after partial decode: {packet:?}"); } bail!( "packet contents were not decoded completely ({} bytes remaining)", packet_contents.len() ); } Ok(packet) } async fn read_var_int_async(&mut self) -> anyhow::Result { let mut val = 0; for i in 0..VarInt::MAX_SIZE { let array = &mut [self.read.read_u8().await?]; if let Some(cipher) = &mut self.cipher { cipher.decrypt(array); } let [byte] = *array; val |= (byte as i32 & 0b01111111) << (i * 7); if byte & 0b10000000 == 0 { return Ok(val); } } bail!("var int is too large") } pub fn enable_encryption(&mut self, key: &[u8; 16]) { self.cipher = Some(NewCipher::new(key.into(), key.into())); } pub fn enable_compression(&mut self, threshold: u32) { self.compression_threshold = Some(threshold); } pub fn into_inner(self) -> R { self.read } } /// The AES block cipher with a 128 bit key, using the CFB-8 mode of /// operation. type Cipher = Cfb8; #[cfg(test)] mod tests { use std::net::SocketAddr; use std::time::Duration; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; use super::*; use crate::protocol::packets::test::TestPacket; #[tokio::test] async fn encode_decode() { encode_decode_impl().await } const CRYPT_KEY: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; const TIMEOUT: Duration = Duration::from_secs(3); async fn encode_decode_impl() { let (tx, rx) = oneshot::channel(); let t = tokio::spawn(listen(tx)); let stream = TcpStream::connect(rx.await.unwrap()).await.unwrap(); let mut encoder = Encoder::new(stream, TIMEOUT); send_test_packet(&mut encoder).await; encoder.enable_compression(10); send_test_packet(&mut encoder).await; encoder.enable_encryption(&CRYPT_KEY); send_test_packet(&mut encoder).await; send_test_packet(&mut encoder).await; send_test_packet(&mut encoder).await; t.await.unwrap() } async fn listen(local_addr: oneshot::Sender) { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); local_addr.send(listener.local_addr().unwrap()).unwrap(); let stream = listener.accept().await.unwrap().0; let mut decoder = Decoder::new(stream, TIMEOUT); recv_test_packet(&mut decoder).await; decoder.enable_compression(10); recv_test_packet(&mut decoder).await; decoder.enable_encryption(&CRYPT_KEY); recv_test_packet(&mut decoder).await; recv_test_packet(&mut decoder).await; recv_test_packet(&mut decoder).await; } async fn send_test_packet(w: &mut Encoder) { w.write_packet(&TestPacket { first: "abcdefghijklmnopqrstuvwxyz".to_string().into(), second: vec![0x1234, 0xabcd], third: 0x1122334455667788, }) .await .unwrap(); } async fn recv_test_packet(r: &mut Decoder) { let TestPacket { first, second, third, } = r.read_packet().await.unwrap(); assert_eq!(&first, "abcdefghijklmnopqrstuvwxyz"); assert_eq!(&second, &[0x1234, 0xabcd]); assert_eq!(third, 0x1122334455667788); } }