// valence/src/protocol/codec.rs
//
// 319 lines
// 9.6 KiB
// Rust

use std::io::Read;
use std::time::Duration;
use aes::Aes128;
use anyhow::{bail, ensure, Context};
use cfb8::cipher::{AsyncStreamCipher, NewCipher};
use cfb8::Cfb8;
use flate2::bufread::{ZlibDecoder, ZlibEncoder};
use flate2::Compression;
use log::{log_enabled, Level};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::time::timeout;
use super::packets::{DecodePacket, EncodePacket};
use crate::protocol::{Decode, Encode, VarInt, MAX_PACKET_SIZE};
/// Writes framed packets to an underlying writer, with optional compression
/// and encryption that can be switched on after construction.
pub struct Encoder<W> {
/// The writer that receives the finished packet frames.
write: W,
/// Reusable buffer holding the encoded packet and, afterwards, the complete
/// frame that is handed to `write`.
buf: Vec<u8>,
/// Reusable scratch buffer used while framing; it receives the zlib output
/// when a packet is compressed.
compress_buf: Vec<u8>,
/// `None` while compression is disabled. `Some(t)` enables the compressed
/// packet format; packets whose encoded size is at least `t` bytes are
/// zlib-compressed.
compression_threshold: Option<u32>,
/// Stream cipher applied to every outgoing frame once encryption is enabled.
cipher: Option<Cipher>,
/// Time limit applied to each `write_packet` call.
timeout: Duration,
}
impl<W: AsyncWrite + Unpin> Encoder<W> {
    /// Creates an encoder that writes packets to `write`.
    ///
    /// Compression and encryption start out disabled. `timeout` limits how
    /// long a single [`Encoder::write_packet`] call may take.
    pub fn new(write: W, timeout: Duration) -> Self {
        Self {
            write,
            buf: Vec::new(),
            compress_buf: Vec::new(),
            compression_threshold: None,
            cipher: None,
            timeout,
        }
    }

    /// Encodes `packet`, frames it and writes the frame to the underlying
    /// writer.
    ///
    /// # Errors
    ///
    /// Returns an error if the operation does not finish within the
    /// configured timeout, or if encoding, compression, one of the size
    /// checks, or the write itself fails.
    pub async fn write_packet(&mut self, packet: &impl EncodePacket) -> anyhow::Result<()> {
        timeout(self.timeout, self.write_packet_impl(packet)).await?
    }

    /// Builds the frame for `packet` in `self.buf` and writes it out.
    ///
    /// Frame layout:
    ///
    /// * compression disabled: `packet length (VarInt) | packet bytes`
    /// * compression enabled: `packet length (VarInt) | data length (VarInt) |
    ///   body`, where the data length is the uncompressed size and the body is
    ///   zlib-compressed when the packet is at least `threshold` bytes long;
    ///   otherwise the data length is `0` and the body is the raw packet bytes.
    async fn write_packet_impl(&mut self, packet: &impl EncodePacket) -> anyhow::Result<()> {
        // Serialize the packet exactly once. Every path below reuses these
        // bytes, so the length prefix always describes what is actually sent.
        self.buf.clear();
        packet.encode_packet(&mut self.buf)?;

        let data_len = self.buf.len();
        ensure!(data_len <= i32::MAX as usize, "bad packet data length");

        // Place the frame body in `compress_buf` and pick the value of the
        // "data length" field. `None` means compression is disabled and the
        // field is omitted entirely.
        let data_len_field: Option<i32> = match self.compression_threshold {
            Some(threshold) if data_len >= threshold as usize => {
                self.compress_buf.clear();
                let mut z = ZlibEncoder::new(self.buf.as_slice(), Compression::best());
                z.read_to_end(&mut self.compress_buf)?;
                Some(data_len as i32)
            }
            threshold => {
                // Not compressed: the encoded packet itself is the body. The
                // swap moves it into `compress_buf` without copying or
                // re-encoding; `buf` is cleared before the frame is assembled.
                std::mem::swap(&mut self.buf, &mut self.compress_buf);
                // A data length of 0 marks an uncompressed body.
                threshold.map(|_| 0)
            }
        };

        let data_len_field_size = data_len_field.map_or(0, |n| VarInt(n).written_size());
        let packet_len = data_len_field_size + self.compress_buf.len();
        ensure!(packet_len <= MAX_PACKET_SIZE as usize, "bad packet length");

        // Assemble the complete frame in `buf`.
        self.buf.clear();
        VarInt(packet_len as i32).encode(&mut self.buf)?;
        if let Some(n) = data_len_field {
            VarInt(n).encode(&mut self.buf)?;
        }
        self.buf.extend_from_slice(&self.compress_buf);

        // The whole frame, length prefix included, goes through the cipher.
        if let Some(cipher) = &mut self.cipher {
            cipher.encrypt(&mut self.buf);
        }

        self.write.write_all(&self.buf).await?;
        Ok(())
    }

    /// Enables encryption for all packets written from now on.
    ///
    /// `key` is passed as both arguments of the cipher constructor.
    pub fn enable_encryption(&mut self, key: &[u8; 16]) {
        self.cipher = Some(NewCipher::new(key.into(), key.into()));
    }

    /// Switches to the compressed packet format. Packets whose encoded size is
    /// at least `threshold` bytes are zlib-compressed.
    pub fn enable_compression(&mut self, threshold: u32) {
        self.compression_threshold = Some(threshold);
    }

    /// Consumes the encoder and returns the underlying writer.
    pub fn into_inner(self) -> W {
        self.write
    }
}
/// Reads framed packets from an underlying reader, with optional decryption
/// and decompression that can be switched on after construction.
pub struct Decoder<R> {
/// The reader that packet frames are pulled from.
read: R,
/// Reusable buffer holding the body of the most recently read frame (after
/// decryption, before decompression). Exposed through `packet_buf`.
buf: Vec<u8>,
/// Reusable buffer that receives the inflated payload of a compressed packet.
decompress_buf: Vec<u8>,
/// `Some(_)` once the compressed packet format is enabled. The decoder only
/// checks whether a value is present; the threshold itself is not consulted
/// while reading.
compression_threshold: Option<u32>,
/// Stream cipher applied to every incoming byte once encryption is enabled.
cipher: Option<Cipher>,
/// Time limit applied to each `read_packet` call.
timeout: Duration,
}
impl<R: AsyncRead + Unpin> Decoder<R> {
pub fn new(read: R, timeout: Duration) -> Self {
Self {
read,
buf: Vec::new(),
decompress_buf: Vec::new(),
compression_threshold: None,
cipher: None,
timeout,
}
}
pub async fn read_packet<P: DecodePacket>(&mut self) -> anyhow::Result<P> {
timeout(self.timeout, self.read_packet_impl()).await?
}
async fn read_packet_impl<P: DecodePacket>(&mut self) -> anyhow::Result<P> {
let packet_len = self
.read_var_int_async()
.await
.context("reading packet length")?;
ensure!(
(0..=MAX_PACKET_SIZE).contains(&packet_len),
"invalid packet length of {packet_len}."
);
self.buf.resize(packet_len as usize, 0);
self.read
.read_exact(&mut self.buf)
.await
.context("reading packet body")?;
if let Some(cipher) = &mut self.cipher {
cipher.decrypt(&mut self.buf);
}
let mut packet_contents = self.buf.as_slice();
// Compression enabled?
let packet = if self.compression_threshold.is_some() {
// The length of the packet data once uncompressed (zero indicates no
// compression).
let data_len = VarInt::decode(&mut packet_contents)
.context("reading data length (once uncompressed)")?
.0;
ensure!(
(0..=MAX_PACKET_SIZE).contains(&data_len),
"invalid packet data length of {data_len}."
);
if data_len != 0 {
let mut z = ZlibDecoder::new(&mut packet_contents);
self.decompress_buf.resize(data_len as usize, 0);
z.read_exact(&mut self.decompress_buf)
.context("uncompressing packet body")?;
let mut uncompressed = self.decompress_buf.as_slice();
let packet = P::decode_packet(&mut uncompressed)
.context("decoding packet after uncompressing")?;
ensure!(
uncompressed.is_empty(),
"packet contents were not read completely"
);
packet
} else {
P::decode_packet(&mut packet_contents).context("decoding packet")?
}
} else {
P::decode_packet(&mut packet_contents).context("decoding packet")?
};
if !packet_contents.is_empty() {
if log_enabled!(Level::Debug) {
log::debug!("complete packet after partial decode: {packet:?}");
}
bail!(
"packet contents were not decoded completely ({} bytes remaining)",
packet_contents.len()
);
}
Ok(packet)
}
async fn read_var_int_async(&mut self) -> anyhow::Result<i32> {
let mut val = 0;
for i in 0..VarInt::MAX_SIZE {
let array = &mut [self.read.read_u8().await?];
if let Some(cipher) = &mut self.cipher {
cipher.decrypt(array);
}
let [byte] = *array;
val |= (byte as i32 & 0b01111111) << (i * 7);
if byte & 0b10000000 == 0 {
return Ok(val);
}
}
bail!("var int is too large")
}
pub fn enable_encryption(&mut self, key: &[u8; 16]) {
self.cipher = Some(NewCipher::new(key.into(), key.into()));
}
pub fn enable_compression(&mut self, threshold: u32) {
self.compression_threshold = Some(threshold);
}
pub fn packet_buf(&self) -> &[u8] {
&self.buf
}
pub fn into_inner(self) -> R {
self.read
}
}
/// The AES block cipher with a 128 bit key, using the CFB-8 mode of
/// operation.
///
/// This is the cipher type stored by both [`Encoder`] and [`Decoder`] once
/// `enable_encryption` has been called.
type Cipher = Cfb8<Aes128>;
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::oneshot;
    use super::*;
    use crate::protocol::packets::test::TestPacket;

    /// Key shared by the encoder and the decoder in the round-trip test.
    const CRYPT_KEY: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    /// Per-packet time limit used on both ends.
    const TIMEOUT: Duration = Duration::from_secs(3);
    /// String payload carried in every test packet.
    const TEXT: &str = "abcdefghijklmnopqrstuvwxyz";

    /// Round-trips packets over a local TCP connection: plain first, then
    /// with compression enabled, then with encryption enabled as well.
    #[tokio::test]
    async fn encode_decode() {
        encode_decode_impl().await
    }

    /// Sending half of the round trip; the receiving half runs in a spawned
    /// task.
    async fn encode_decode_impl() {
        let (addr_tx, addr_rx) = oneshot::channel();
        let receiver = tokio::spawn(listen(addr_tx));

        let server_addr = addr_rx.await.unwrap();
        let stream = TcpStream::connect(server_addr).await.unwrap();
        let mut encoder = Encoder::new(stream, TIMEOUT);

        // One plain packet.
        send_test_packet(&mut encoder).await;

        // One packet in the compressed format.
        encoder.enable_compression(10);
        send_test_packet(&mut encoder).await;

        // Three packets that are both compressed and encrypted.
        encoder.enable_encryption(&CRYPT_KEY);
        for _ in 0..3 {
            send_test_packet(&mut encoder).await;
        }

        receiver.await.unwrap()
    }

    /// Receiving half: binds an ephemeral port, reports its address through
    /// `local_addr`, and mirrors the sender's sequence of state changes.
    async fn listen(local_addr: oneshot::Sender<SocketAddr>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();
        local_addr.send(bound_addr).unwrap();

        let (stream, _) = listener.accept().await.unwrap();
        let mut decoder = Decoder::new(stream, TIMEOUT);

        recv_test_packet(&mut decoder).await;

        decoder.enable_compression(10);
        recv_test_packet(&mut decoder).await;

        decoder.enable_encryption(&CRYPT_KEY);
        for _ in 0..3 {
            recv_test_packet(&mut decoder).await;
        }
    }

    /// Writes the fixed test packet through `w`.
    async fn send_test_packet(w: &mut Encoder<TcpStream>) {
        let packet = TestPacket {
            first: TEXT.to_string().into(),
            second: vec![0x1234, 0xabcd],
            third: 0x1122334455667788,
        };
        w.write_packet(&packet).await.unwrap();
    }

    /// Reads one packet from `r` and checks that it matches the fixed test
    /// packet.
    async fn recv_test_packet(r: &mut Decoder<TcpStream>) {
        let received: TestPacket = r.read_packet().await.unwrap();
        let TestPacket {
            first,
            second,
            third,
        } = received;

        assert_eq!(&first, TEXT);
        assert_eq!(&second, &[0x1234, 0xabcd]);
        assert_eq!(third, 0x1122334455667788);
    }
}