valence/crates/valence_protocol/src/codec.rs
Ryan Johnson 7af119da72
Replace EncodePacket and DecodePacket with Packet (#261)
## Description

Combines the `EncodePacket` and `DecodePacket` traits into a single
`Packet` trait. This makes `valence_protocol` simpler and easier to use.
This can be done because all packets were made to be bidirectional in
#253.

Additionally, a `packet_id` method has been added. This should help with
#238.

## Test Plan

Steps:
1. Run examples, packet_inspector, etc. Behavior should be the same.
2023-02-25 11:21:25 -08:00

626 lines
17 KiB
Rust

#[cfg(feature = "encryption")]
use aes::cipher::{AsyncStreamCipher, NewCipher};
use anyhow::{bail, ensure};
use bytes::{Buf, BufMut, BytesMut};
use tracing::debug;
use crate::var_int::{VarInt, VarIntDecodeError};
use crate::{Encode, Packet, Result, MAX_PACKET_SIZE};
/// The AES block cipher with a 128 bit key, using the CFB-8 mode of
/// operation.
///
/// [`PacketEncoder`] uses it to encrypt outgoing bytes and [`PacketDecoder`]
/// uses it to decrypt incoming bytes. Both construct it from a single 16-byte
/// key that is passed as both the AES key and the IV.
#[cfg(feature = "encryption")]
type Cipher = cfb8::Cfb8<aes::Aes128>;
/// Serializes packets into one contiguous buffer.
///
/// The encoder applies length framing, optional zlib compression and
/// optional AES/CFB-8 encryption. Packets accumulate in the buffer until
/// [`PacketEncoder::take`] is called.
#[derive(Default)]
pub struct PacketEncoder {
    /// Framed packets that have not been taken yet. They are stored
    /// unencrypted; encryption is applied when they are taken.
    buf: BytesMut,
    /// Scratch space for the zlib output of a single packet. It is reused
    /// between packets to avoid reallocating.
    #[cfg(feature = "compression")]
    compress_buf: Vec<u8>,
    /// `Some(threshold)` selects the compressed packet format. `None`
    /// selects the plain length-prefixed format.
    #[cfg(feature = "compression")]
    compression_threshold: Option<u32>,
    /// `Some` once encryption has been enabled.
    #[cfg(feature = "encryption")]
    cipher: Option<Cipher>,
}
impl PacketEncoder {
    /// Creates an empty encoder with compression and encryption disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `bytes` to the end of the buffer exactly as given.
    ///
    /// No length prefix or compression is applied to `bytes`. They are still
    /// encrypted by [`Self::take`] if encryption is enabled.
    #[inline]
    pub fn append_bytes(&mut self, bytes: &[u8]) {
        self.buf.extend_from_slice(bytes)
    }

    /// Encodes `pkt` and inserts it at the *front* of the buffer, ahead of
    /// every packet that has not been [taken] yet.
    ///
    /// # Errors
    ///
    /// Returns an error under the same conditions as
    /// [`Self::append_packet`]. In that case the buffer is left unchanged.
    ///
    /// [taken]: Self::take
    pub fn prepend_packet<'a, P>(&mut self, pkt: &P) -> Result<()>
    where
        P: Packet<'a>,
    {
        let start_len = self.buf.len();
        self.append_packet(pkt)?;
        let end_len = self.buf.len();
        let total_packet_len = end_len - start_len;

        // 1) Move everything back by the length of the packet.
        // 2) Move the packet to the new space at the front.
        // 3) Truncate the old packet away.
        self.buf.put_bytes(0, total_packet_len);
        self.buf.copy_within(..end_len, total_packet_len);
        self.buf.copy_within(total_packet_len + start_len.., 0);
        self.buf.truncate(end_len);

        Ok(())
    }

    /// Encodes `pkt`, frames it and appends it to the end of the buffer.
    ///
    /// Without compression the packet is written as `VarInt(len) ++ data`.
    /// When a compression threshold is set, it is written as
    /// `VarInt(len) ++ VarInt(data_len) ++ zlib(data)` if `data` is longer
    /// than the threshold. Otherwise it is written as
    /// `VarInt(len) ++ VarInt(0) ++ data`. In both formats, `len` counts every
    /// byte after the outer length prefix.
    ///
    /// # Errors
    ///
    /// Returns an error if `pkt` fails to encode, if compression fails, or if
    /// the framed packet would exceed `MAX_PACKET_SIZE`. On error the buffer
    /// is rolled back to its previous contents.
    pub fn append_packet<'a, P>(&mut self, pkt: &P) -> Result<()>
    where
        P: Packet<'a>,
    {
        let start_len = self.buf.len();
        let res = self.append_packet_impl(pkt, start_len);

        if res.is_err() {
            // Drop any partially written or oversized packet so that it cannot
            // corrupt the stream returned by a later `take`.
            self.buf.truncate(start_len);
        }

        res
    }

    /// Encodes and frames `pkt` at the end of the buffer. On entry, the buffer
    /// must have exactly `start_len` bytes.
    ///
    /// On error this may leave partial data past `start_len`.
    /// [`Self::append_packet`] is responsible for removing it.
    fn append_packet_impl<'a, P>(&mut self, pkt: &P, start_len: usize) -> Result<()>
    where
        P: Packet<'a>,
    {
        pkt.encode_packet((&mut self.buf).writer())?;

        let data_len = self.buf.len() - start_len;

        #[cfg(feature = "compression")]
        if let Some(threshold) = self.compression_threshold {
            use std::io::Read;

            use flate2::bufread::ZlibEncoder;
            use flate2::Compression;

            if data_len > threshold as usize {
                let mut z = ZlibEncoder::new(&self.buf[start_len..], Compression::new(4));

                self.compress_buf.clear();

                let data_len_size = VarInt(data_len as i32).written_size();

                let packet_len = data_len_size + z.read_to_end(&mut self.compress_buf)?;

                ensure!(
                    packet_len <= MAX_PACKET_SIZE as usize,
                    "packet exceeds maximum length"
                );

                // Release the encoder's borrow of `self.buf` before rewriting it.
                drop(z);

                // Replace the uncompressed data with the compressed frame.
                self.buf.truncate(start_len);

                let mut writer = (&mut self.buf).writer();

                VarInt(packet_len as i32).encode(&mut writer)?;
                VarInt(data_len as i32).encode(&mut writer)?;
                self.buf.extend_from_slice(&self.compress_buf);
            } else {
                // The data length field is `VarInt(0)`, which occupies one byte.
                let data_len_size = 1;
                let packet_len = data_len_size + data_len;

                ensure!(
                    packet_len <= MAX_PACKET_SIZE as usize,
                    "packet exceeds maximum length"
                );

                let packet_len_size = VarInt(packet_len as i32).written_size();

                let data_prefix_len = packet_len_size + data_len_size;

                // Shift the data right to make room for both prefixes.
                self.buf.put_bytes(0, data_prefix_len);
                self.buf
                    .copy_within(start_len..start_len + data_len, start_len + data_prefix_len);

                let mut front = &mut self.buf[start_len..];

                VarInt(packet_len as i32).encode(&mut front)?;
                // Zero for no compression on this packet.
                VarInt(0).encode(front)?;
            }

            return Ok(());
        }

        let packet_len = data_len;

        ensure!(
            packet_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );

        let packet_len_size = VarInt(packet_len as i32).written_size();

        // Shift the data right to make room for the length prefix.
        self.buf.put_bytes(0, packet_len_size);
        self.buf
            .copy_within(start_len..start_len + data_len, start_len + packet_len_size);

        let front = &mut self.buf[start_len..];
        VarInt(packet_len as i32).encode(front)?;

        Ok(())
    }

    /// Removes and returns every packet written so far, leaving the encoder
    /// empty.
    ///
    /// The packets are encrypted first if encryption is enabled.
    pub fn take(&mut self) -> BytesMut {
        #[cfg(feature = "encryption")]
        if let Some(cipher) = &mut self.cipher {
            cipher.encrypt(&mut self.buf);
        }

        self.buf.split()
    }

    /// Discards every packet that has not been taken yet.
    pub fn clear(&mut self) {
        self.buf.clear();
    }

    /// Sets the compression threshold for packets appended from now on.
    ///
    /// `Some(threshold)` switches to the compressed packet format. In that
    /// format, every packet whose encoded data is longer than `threshold`
    /// bytes is zlib-compressed. `None` switches back to the plain format.
    /// Packets already in the buffer are not affected.
    #[cfg(feature = "compression")]
    pub fn set_compression(&mut self, threshold: Option<u32>) {
        self.compression_threshold = threshold;
    }

    /// Encrypts all future packets **and any packets that have
    /// not been [taken] yet.**
    ///
    /// `key` is used as both the AES key and the IV.
    ///
    /// # Panics
    ///
    /// Panics if encryption is already enabled.
    ///
    /// [taken]: Self::take
    #[cfg(feature = "encryption")]
    pub fn enable_encryption(&mut self, key: &[u8; 16]) {
        assert!(self.cipher.is_none(), "encryption is already enabled");
        self.cipher = Some(NewCipher::new(key.into(), key.into()));
    }
}
/// Encodes `pkt` and appends it to `buf` with a `VarInt` length prefix.
///
/// No compression or encryption is applied.
///
/// # Errors
///
/// Returns an error if `pkt` fails to encode or if the encoded packet exceeds
/// `MAX_PACKET_SIZE`. On error, `buf` is truncated back to its original
/// length, so no partial packet is left behind.
pub fn encode_packet<'a, P>(buf: &mut Vec<u8>, pkt: &P) -> Result<()>
where
    P: Packet<'a>,
{
    let start_len = buf.len();
    let res = encode_packet_impl(buf, pkt, start_len);

    if res.is_err() {
        // Don't leave a partially written or oversized packet in the caller's buffer.
        buf.truncate(start_len);
    }

    res
}

/// Writes `pkt` with a `VarInt` length prefix at the end of `buf`. On entry,
/// `buf` must have exactly `start_len` bytes.
///
/// On error this may leave partial output past `start_len`.
/// [`encode_packet`] is responsible for removing it.
fn encode_packet_impl<'a, P>(buf: &mut Vec<u8>, pkt: &P, start_len: usize) -> Result<()>
where
    P: Packet<'a>,
{
    pkt.encode_packet(&mut *buf)?;

    let packet_len = buf.len() - start_len;

    ensure!(
        packet_len <= MAX_PACKET_SIZE as usize,
        "packet exceeds maximum length"
    );

    let packet_len_size = VarInt(packet_len as i32).written_size();

    // Shift the packet right to make room for the length prefix.
    buf.put_bytes(0, packet_len_size);
    buf.copy_within(
        start_len..start_len + packet_len,
        start_len + packet_len_size,
    );

    let front = &mut buf[start_len..];
    VarInt(packet_len as i32).encode(front)?;

    Ok(())
}
/// Encodes `pkt` and appends it to `buf` using the compressed packet format.
///
/// If the encoded data is longer than `threshold` bytes, the packet is written
/// as `VarInt(len) ++ VarInt(data_len) ++ zlib(data)`. Otherwise it is written
/// as `VarInt(len) ++ VarInt(0) ++ data`.
///
/// `scratch` is a reusable buffer for the compressed bytes. It is cleared and
/// overwritten only when compression is actually applied.
///
/// # Errors
///
/// Returns an error if `pkt` fails to encode, if compression fails, or if the
/// framed packet exceeds `MAX_PACKET_SIZE`. On error, `buf` is truncated back
/// to its original length.
#[cfg(feature = "compression")]
pub fn encode_packet_compressed<'a, P>(
    buf: &mut Vec<u8>,
    pkt: &P,
    threshold: u32,
    scratch: &mut Vec<u8>,
) -> Result<()>
where
    P: Packet<'a>,
{
    let start_len = buf.len();
    let res = encode_packet_compressed_impl(buf, pkt, threshold, scratch, start_len);

    if res.is_err() {
        // Don't leave a partially written or oversized packet in the caller's buffer.
        buf.truncate(start_len);
    }

    res
}

/// Implements [`encode_packet_compressed`]. On entry, `buf` must have exactly
/// `start_len` bytes.
///
/// On error this may leave partial output past `start_len`. The public
/// wrapper is responsible for removing it.
#[cfg(feature = "compression")]
fn encode_packet_compressed_impl<'a, P>(
    buf: &mut Vec<u8>,
    pkt: &P,
    threshold: u32,
    scratch: &mut Vec<u8>,
    start_len: usize,
) -> Result<()>
where
    P: Packet<'a>,
{
    use std::io::Read;

    use flate2::bufread::ZlibEncoder;
    use flate2::Compression;

    pkt.encode_packet(&mut *buf)?;

    let data_len = buf.len() - start_len;

    if data_len > threshold as usize {
        let mut z = ZlibEncoder::new(&buf[start_len..], Compression::new(4));

        scratch.clear();

        let data_len_size = VarInt(data_len as i32).written_size();

        let packet_len = data_len_size + z.read_to_end(scratch)?;

        ensure!(
            packet_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );

        // Release the encoder's borrow of `buf` before rewriting it.
        drop(z);

        // Replace the uncompressed data with the compressed frame.
        buf.truncate(start_len);

        VarInt(packet_len as i32).encode(&mut *buf)?;
        VarInt(data_len as i32).encode(&mut *buf)?;
        buf.extend_from_slice(scratch);
    } else {
        // The data length field is `VarInt(0)`, which occupies one byte.
        let data_len_size = 1;
        let packet_len = data_len_size + data_len;

        ensure!(
            packet_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );

        let packet_len_size = VarInt(packet_len as i32).written_size();

        let data_prefix_len = packet_len_size + data_len_size;

        // Shift the data right to make room for both prefixes.
        buf.put_bytes(0, data_prefix_len);
        buf.copy_within(start_len..start_len + data_len, start_len + data_prefix_len);

        let mut front = &mut buf[start_len..];

        VarInt(packet_len as i32).encode(&mut front)?;
        // Zero for no compression on this packet.
        VarInt(0).encode(front)?;
    }

    Ok(())
}
/// Incrementally decodes length-prefixed packets from a stream of bytes.
///
/// Incoming bytes are decrypted as they are queued (once encryption is
/// enabled). Packet bodies are decompressed as they are decoded (once
/// compression is enabled).
#[derive(Default)]
pub struct PacketDecoder {
    /// Queued bytes, already decrypted. Bytes before `cursor` belong to
    /// packets that have been decoded and are discarded lazily.
    buf: BytesMut,
    /// Offset into `buf` of the first byte that has not been consumed yet.
    cursor: usize,
    /// Scratch buffer that holds the decompressed body of the current packet.
    /// Packets decoded from it may borrow it.
    #[cfg(feature = "compression")]
    decompress_buf: Vec<u8>,
    /// Whether incoming packets use the compressed packet format.
    #[cfg(feature = "compression")]
    compression_enabled: bool,
    /// `Some` once encryption has been enabled. It is applied to bytes as
    /// they are queued.
    #[cfg(feature = "encryption")]
    cipher: Option<Cipher>,
}
impl PacketDecoder {
    /// Creates an empty decoder with compression and encryption disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Decodes the next packet from the queued bytes, if a complete one is
    /// available.
    ///
    /// The bytes of the packet returned by the previous call are discarded
    /// first. Returns `Ok(None)` if the queued bytes do not yet contain a
    /// complete packet. The returned packet may borrow from the decoder's
    /// buffers, so the decoder stays borrowed while the packet is alive.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// - the length prefix is malformed or out of bounds;
    /// - decompression fails;
    /// - the decompressed size differs from the declared size;
    /// - `P` fails to decode;
    /// - `P` does not consume the entire packet body.
    ///
    /// On error the offending packet remains at the front of the queue.
    pub fn try_next_packet<'a, P>(&'a mut self) -> Result<Option<P>>
    where
        P: Packet<'a>,
    {
        // Discard the bytes of the packet returned by the previous call.
        self.buf.advance(self.cursor);
        self.cursor = 0;

        let mut r = &self.buf[..];

        let packet_len = match VarInt::decode_partial(&mut r) {
            Ok(len) => len,
            Err(VarIntDecodeError::Incomplete) => return Ok(None),
            Err(VarIntDecodeError::TooLarge) => bail!("malformed packet length VarInt"),
        };

        ensure!(
            (0..=MAX_PACKET_SIZE).contains(&packet_len),
            "packet length of {packet_len} is out of bounds"
        );

        if r.len() < packet_len as usize {
            // The rest of the packet has not been queued yet.
            return Ok(None);
        }

        r = &r[..packet_len as usize];

        #[cfg(feature = "compression")]
        let packet = if self.compression_enabled {
            use std::io::Read;

            use anyhow::Context;
            use flate2::bufread::ZlibDecoder;

            use crate::Decode;

            let data_len = VarInt::decode(&mut r)?.0;

            ensure!(
                (0..MAX_PACKET_SIZE).contains(&data_len),
                "decompressed packet length of {data_len} is out of bounds"
            );

            // A data length of zero means the body was sent uncompressed.
            if data_len != 0 {
                self.decompress_buf.clear();
                self.decompress_buf.reserve_exact(data_len as usize);
                // `take` stops inflation from producing more than the
                // declared number of bytes.
                let mut z = ZlibDecoder::new(r).take(data_len as u64);

                z.read_to_end(&mut self.decompress_buf)
                    .context("decompressing packet")?;

                // The declared length must match exactly. A shorter body means
                // the packet is malformed and must not be decoded as if it
                // were complete.
                let actual_len = self.decompress_buf.len();
                ensure!(
                    actual_len == data_len as usize,
                    "decompressed packet length of {actual_len} does not match the declared length of {data_len}"
                );

                r = &self.decompress_buf;
                P::decode_packet(&mut r)?
            } else {
                P::decode_packet(&mut r)?
            }
        } else {
            P::decode_packet(&mut r)?
        };

        #[cfg(not(feature = "compression"))]
        let packet = P::decode_packet(&mut r)?;

        if !r.is_empty() {
            let remaining = r.len();

            debug!("packet after partial decode ({remaining} bytes remain): {packet:?}");

            bail!("packet contents were not read completely ({remaining} bytes remain)");
        }

        // Mark the whole frame (prefix + body) as consumed. The bytes are
        // discarded at the start of the next call.
        let total_packet_len = VarInt(packet_len).written_size() + packet_len as usize;
        self.cursor = total_packet_len;

        Ok(Some(packet))
    }

    /// Repeatedly decodes a packet type until all packets in the decoder are
    /// consumed or an error occurs. The decoded packets are returned in a vec.
    ///
    /// Intended for testing purposes with encryption and compression disabled.
    ///
    /// # Panics
    ///
    /// Panics if encryption or compression is enabled.
    #[track_caller]
    pub fn collect_into_vec<'a, P>(&'a mut self) -> Result<Vec<P>>
    where
        P: Packet<'a>,
    {
        #[cfg(feature = "encryption")]
        assert!(
            self.cipher.is_none(),
            "encryption must be disabled to use this method"
        );

        #[cfg(feature = "compression")]
        assert!(
            !self.compression_enabled,
            "compression must be disabled to use this method"
        );

        self.buf.advance(self.cursor);
        self.cursor = 0;

        let mut res = vec![];

        // `cursor` tracks the start of the next frame. The buffer itself is
        // not advanced because the returned packets borrow from it.
        loop {
            let mut r = &self.buf[self.cursor..];

            let packet_len = match VarInt::decode_partial(&mut r) {
                Ok(len) => len,
                Err(VarIntDecodeError::Incomplete) => return Ok(res),
                Err(VarIntDecodeError::TooLarge) => bail!("malformed packet length VarInt"),
            };

            ensure!(
                (0..=MAX_PACKET_SIZE).contains(&packet_len),
                "packet length of {packet_len} is out of bounds"
            );

            if r.len() < packet_len as usize {
                return Ok(res);
            }

            r = &r[..packet_len as usize];

            let packet = P::decode_packet(&mut r)?;

            if !r.is_empty() {
                let remaining = r.len();

                debug!("packet after partial decode ({remaining} bytes remain): {packet:?}");

                bail!("packet contents were not read completely ({remaining} bytes remain)");
            }

            let total_packet_len = VarInt(packet_len).written_size() + packet_len as usize;
            self.cursor += total_packet_len;

            res.push(packet);
        }
    }

    /// Returns whether the unconsumed queued bytes contain at least one
    /// complete packet.
    ///
    /// # Errors
    ///
    /// Returns an error if the next length prefix is malformed or out of
    /// bounds.
    pub fn has_next_packet(&self) -> Result<bool> {
        let mut r = &self.buf[self.cursor..];

        match VarInt::decode_partial(&mut r) {
            Ok(packet_len) => {
                ensure!(
                    (0..=MAX_PACKET_SIZE).contains(&packet_len),
                    "packet length of {packet_len} is out of bounds"
                );

                Ok(r.len() >= packet_len as usize)
            }
            Err(VarIntDecodeError::Incomplete) => Ok(false),
            Err(VarIntDecodeError::TooLarge) => bail!("malformed packet length VarInt"),
        }
    }

    /// Returns whether incoming packets are expected in the compressed packet
    /// format.
    #[cfg(feature = "compression")]
    pub fn compression(&self) -> bool {
        self.compression_enabled
    }

    /// Sets whether incoming packets are expected in the compressed packet
    /// format. This affects every packet decoded afterwards.
    #[cfg(feature = "compression")]
    pub fn set_compression(&mut self, enabled: bool) {
        self.compression_enabled = enabled;
    }

    /// Enables decryption of incoming bytes. `key` is used as both the AES key
    /// and the IV.
    ///
    /// Bytes already queued past the last decoded packet are decrypted
    /// immediately. Call this right after decoding the last unencrypted packet.
    ///
    /// # Panics
    ///
    /// Panics if encryption is already enabled.
    #[cfg(feature = "encryption")]
    pub fn enable_encryption(&mut self, key: &[u8; 16]) {
        assert!(self.cipher.is_none(), "encryption is already enabled");

        let mut cipher = Cipher::new(key.into(), key.into());
        // Don't forget to decrypt the data we already have.
        cipher.decrypt(&mut self.buf[self.cursor..]);
        self.cipher = Some(cipher);
    }

    /// Appends `bytes` to the queue. They are decrypted first if encryption
    /// is enabled.
    pub fn queue_bytes(&mut self, mut bytes: BytesMut) {
        // `bytes` is only mutated when the encryption feature is enabled.
        #![allow(unused_mut)]

        #[cfg(feature = "encryption")]
        if let Some(cipher) = &mut self.cipher {
            cipher.decrypt(&mut bytes);
        }

        self.buf.unsplit(bytes);
    }

    /// Copies `bytes` onto the end of the queue. The copy is decrypted if
    /// encryption is enabled.
    pub fn queue_slice(&mut self, bytes: &[u8]) {
        #[cfg(feature = "encryption")]
        let len = self.buf.len();

        self.buf.extend_from_slice(bytes);

        #[cfg(feature = "encryption")]
        if let Some(cipher) = &mut self.cipher {
            cipher.decrypt(&mut self.buf[len..]);
        }
    }

    /// Returns the entire internal buffer.
    ///
    /// The buffer may still begin with packets that have already been
    /// decoded, because consumed bytes are discarded lazily.
    pub fn queued_bytes(&self) -> &[u8] {
        self.buf.as_ref()
    }

    /// Splits off the internal buffer's spare capacity as an empty `BytesMut`.
    ///
    /// NOTE(review): this is presumably meant to be filled by the caller and
    /// handed back through `queue_bytes`. Confirm against the callers.
    pub fn take_capacity(&mut self) -> BytesMut {
        self.buf.split_off(self.buf.len())
    }

    /// Reserves capacity for at least `additional` more bytes in the internal
    /// buffer.
    pub fn reserve(&mut self, additional: usize) {
        self.buf.reserve(additional);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::block_pos::BlockPos;
    use crate::ident::Ident;
    use crate::item::{ItemKind, ItemStack};
    use crate::text::{Text, TextFormat};
    use crate::tracked_data::PaintingKind;
    use crate::username::Username;
    use crate::var_long::VarLong;
    use crate::Decode;

    /// Arbitrary AES key shared by the encoder and decoder in these tests.
    #[cfg(feature = "encryption")]
    const CRYPT_KEY: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];

    // A packet with one field of each of many encodable types, several of
    // which borrow from the buffer they are decoded from.
    #[derive(PartialEq, Debug, Encode, Decode, Packet)]
    #[packet_id = 42]
    struct TestPacket<'a> {
        a: bool,
        b: u8,
        c: i32,
        d: f32,
        e: f64,
        f: BlockPos,
        g: PaintingKind,
        h: Ident<&'a str>,
        i: Option<ItemStack>,
        j: Text,
        k: Username<&'a str>,
        l: VarInt,
        m: VarLong,
        n: &'a str,
        o: &'a [u8; 10],
        p: [u128; 3],
    }

    impl<'a> TestPacket<'a> {
        /// Builds a packet with fixed values in every field except `n`, which
        /// is used to tell packets apart.
        fn new(n: &'a str) -> Self {
            Self {
                a: true,
                b: 12,
                c: -999,
                d: 5.001,
                e: 1e10,
                f: BlockPos::new(1, 2, 3),
                g: PaintingKind::DonkeyKong,
                h: Ident::new("minecraft:whatever").unwrap(),
                i: Some(ItemStack::new(ItemKind::WoodenSword, 12, None)),
                j: "my ".into_text() + "fancy".italic() + " text",
                k: Username::new("00a").unwrap(),
                l: VarInt(123),
                m: VarLong(456),
                n,
                o: &[7; 10],
                p: [123456789; 3],
            }
        }

        /// Asserts that `self` equals `Self::new(n)`.
        fn check(&self, n: &'a str) {
            assert_eq!(self, &Self::new(n));
        }
    }

    /// Checks that packets survive encoding and decoding when compression and
    /// encryption are switched on partway through the stream.
    #[test]
    fn packets_round_trip() {
        let mut buf = BytesMut::new();

        let mut enc = PacketEncoder::new();

        // "first" is written before compression is enabled.
        enc.append_packet(&TestPacket::new("first")).unwrap();
        // A threshold of 0 compresses every non-empty packet from here on.
        #[cfg(feature = "compression")]
        enc.set_compression(Some(0));
        enc.append_packet(&TestPacket::new("second")).unwrap();
        // Take "first" and "second" now. Enabling encryption also applies to
        // packets that have not been taken yet.
        buf.unsplit(enc.take());
        #[cfg(feature = "encryption")]
        enc.enable_encryption(&CRYPT_KEY);
        enc.append_packet(&TestPacket::new("third")).unwrap();
        // Prepending places "fourth" ahead of "third" in the stream.
        enc.prepend_packet(&TestPacket::new("fourth")).unwrap();
        buf.unsplit(enc.take());

        let mut dec = PacketDecoder::new();

        // The whole stream is queued at once. The decoder switches modes at
        // the same packet boundaries the encoder used.
        dec.queue_bytes(buf);

        dec.try_next_packet::<TestPacket>()
            .unwrap()
            .unwrap()
            .check("first");

        #[cfg(feature = "compression")]
        dec.set_compression(true);

        dec.try_next_packet::<TestPacket>()
            .unwrap()
            .unwrap()
            .check("second");

        // Enabling encryption decrypts the bytes still queued past the cursor.
        #[cfg(feature = "encryption")]
        dec.enable_encryption(&CRYPT_KEY);

        dec.try_next_packet::<TestPacket>()
            .unwrap()
            .unwrap()
            .check("fourth");

        dec.try_next_packet::<TestPacket>()
            .unwrap()
            .unwrap()
            .check("third");
    }

    /// Checks that `collect_into_vec` returns every queued packet in order,
    /// with compression and encryption left disabled.
    #[test]
    fn collect_packets_into_vec() {
        let packets = vec![
            TestPacket::new("foo"),
            TestPacket::new("bar"),
            TestPacket::new("baz"),
        ];

        let mut enc = PacketEncoder::new();
        let mut dec = PacketDecoder::new();

        for pkt in &packets {
            enc.append_packet(pkt).unwrap();
        }

        dec.queue_bytes(enc.take());
        let res = dec.collect_into_vec::<TestPacket>().unwrap();

        assert_eq!(packets, res);
    }
}