From 7568ff8b4b320f4c7092ceee6f45ffb49125d91e Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 24 Aug 2022 18:17:51 -0700 Subject: [PATCH] Add serde_nbt --- Cargo.toml | 2 +- serde_nbt/Cargo.toml | 15 + serde_nbt/src/array.rs | 29 ++ serde_nbt/src/binary/de.rs | 363 ++++++++++++++++++ serde_nbt/src/binary/ser.rs | 733 ++++++++++++++++++++++++++++++++++++ serde_nbt/src/error.rs | 54 +++ serde_nbt/src/lib.rs | 90 +++++ serde_nbt/src/tests.rs | 172 +++++++++ serde_nbt/src/value.rs | 519 +++++++++++++++++++++++++ 9 files changed, 1976 insertions(+), 1 deletion(-) create mode 100644 serde_nbt/Cargo.toml create mode 100644 serde_nbt/src/array.rs create mode 100644 serde_nbt/src/binary/de.rs create mode 100644 serde_nbt/src/binary/ser.rs create mode 100644 serde_nbt/src/error.rs create mode 100644 serde_nbt/src/lib.rs create mode 100644 serde_nbt/src/tests.rs create mode 100644 serde_nbt/src/value.rs diff --git a/Cargo.toml b/Cargo.toml index a7c9824..2ed54e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,4 +67,4 @@ num = "0.4" protocol = [] [workspace] -members = ["packet-inspector"] +members = ["packet-inspector", "serde_nbt"] diff --git a/serde_nbt/Cargo.toml b/serde_nbt/Cargo.toml new file mode 100644 index 0000000..a2073df --- /dev/null +++ b/serde_nbt/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "serde_nbt" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1" +byteorder = "1.4.3" +cesu8 = "1.1.0" +serde = "1" +smallvec = { version = "1.9.0", features = ["union", "const_generics"] } + +[dev-dependencies] +ordered-float = { version = "3.0.0", features = ["serde"] } +hematite-nbt = "0.5" diff --git a/serde_nbt/src/array.rs b/serde_nbt/src/array.rs new file mode 100644 index 0000000..52993f0 --- /dev/null +++ b/serde_nbt/src/array.rs @@ -0,0 +1,29 @@ +use serde::{Serialize, Serializer}; +use serde::ser::SerializeTupleStruct; + +use super::{BYTE_ARRAY_MAGIC, INT_ARRAY_MAGIC, LONG_ARRAY_MAGIC}; + +macro_rules! def { + ($name:ident, $magic:ident) => { + pub fn $name(array: T, serializer: S) -> Result + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + T::Item: Serialize, + S: Serializer, + { + let it = array.into_iter(); + let mut sts = serializer.serialize_tuple_struct($magic, it.len())?; + + for item in it { + sts.serialize_field(&item)?; + } + + sts.end() + } + } +} + +def!(byte_array, BYTE_ARRAY_MAGIC); +def!(int_array, INT_ARRAY_MAGIC); +def!(long_array, LONG_ARRAY_MAGIC); diff --git a/serde_nbt/src/binary/de.rs b/serde_nbt/src/binary/de.rs new file mode 100644 index 0000000..af61490 --- /dev/null +++ b/serde_nbt/src/binary/de.rs @@ -0,0 +1,363 @@ +// TODO: recursion limit. +// TODO: serialize and deserialize recursion limit wrappers. (crate: +// serde_reclimit). + +use std::borrow::Cow; +use std::io::Read; + +use anyhow::anyhow; +use byteorder::{BigEndian, ReadBytesExt}; +use cesu8::from_java_cesu8; +use serde::de::{DeserializeOwned, DeserializeSeed, Visitor}; +use serde::{de, forward_to_deserialize_any}; +use smallvec::SmallVec; + +use crate::{Error, Result, Tag}; + +pub fn from_reader(reader: R) -> Result +where + R: Read, + T: DeserializeOwned, +{ + T::deserialize(&mut Deserializer::new(reader, false)) +} + +pub struct Deserializer { + reader: R, + root_name: Option, +} + +impl Deserializer { + pub fn new(reader: R, save_root_name: bool) -> Self { + Self { + reader, + root_name: if save_root_name { + Some(String::new()) + } else { + None + }, + } + } + + pub fn into_inner(self) -> (R, Option) { + (self.reader, self.root_name) + } + + fn read_header(&mut self) -> Result { + let tag = Tag::from_u8(self.reader.read_u8()?)?; + + if tag != Tag::Compound { + return Err(Error(anyhow!( + "unexpected tag `{tag}` (root value must be a compound)" + ))); + } + + if let Some(name) = &mut self.root_name { + let mut buf = SmallVec::<[u8; 128]>::new(); + for _ in 0..self.reader.read_u16::()? { + buf.push(self.reader.read_u8()?); + } + + *name = from_java_cesu8(&buf) + .map_err(|e| Error(anyhow!(e)))? + .into_owned(); + } else { + for _ in 0..self.reader.read_u16::()? { + self.reader.read_u8()?; + } + } + + Ok(tag) + } +} + +impl<'de: 'a, 'a, R: Read + 'de> de::Deserializer<'de> for &'a mut Deserializer { + type Error = Error; + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct seq tuple tuple_struct map + enum identifier ignored_any + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let tag = self.read_header()?; + + PayloadDeserializer { + reader: &mut self.reader, + tag, + } + .deserialize_any(visitor) + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + let tag = self.read_header()?; + + PayloadDeserializer { + reader: &mut self.reader, + tag, + } + .deserialize_struct(name, fields, visitor) + } + + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn is_human_readable(&self) -> bool { + false + } +} + +struct PayloadDeserializer<'a, R> { + reader: &'a mut R, + /// The type of payload to be deserialized. + tag: Tag, +} + +impl<'de: 'a, 'a, R: Read> de::Deserializer<'de> for PayloadDeserializer<'a, R> { + type Error = Error; + + forward_to_deserialize_any! { + i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option seq tuple tuple_struct map enum identifier + ignored_any + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self.tag { + Tag::End => unreachable!("invalid payload tag"), + Tag::Byte => visitor.visit_i8(self.reader.read_i8()?), + Tag::Short => visitor.visit_i16(self.reader.read_i16::()?), + Tag::Int => visitor.visit_i32(self.reader.read_i32::()?), + Tag::Long => visitor.visit_i64(self.reader.read_i64::()?), + Tag::Float => visitor.visit_f32(self.reader.read_f32::()?), + Tag::Double => visitor.visit_f64(self.reader.read_f64::()?), + Tag::ByteArray => { + let len = self.reader.read_i32::()?; + visitor.visit_seq(SeqAccess::new(self.reader, Tag::Byte, len)?) + } + Tag::String => { + let mut buf = SmallVec::<[u8; 128]>::new(); + for _ in 0..self.reader.read_u16::()? { + buf.push(self.reader.read_u8()?); + } + + match from_java_cesu8(&buf).map_err(|e| Error(anyhow!(e)))? { + Cow::Borrowed(s) => visitor.visit_str(s), + Cow::Owned(string) => visitor.visit_string(string), + } + } + Tag::List => { + let element_type = Tag::from_u8(self.reader.read_u8()?)?; + let len = self.reader.read_i32::()?; + visitor.visit_seq(SeqAccess::new(self.reader, element_type, len)?) + } + Tag::Compound => visitor.visit_map(MapAccess::new(self.reader, &[])), + Tag::IntArray => { + let len = self.reader.read_i32::()?; + visitor.visit_seq(SeqAccess::new(self.reader, Tag::Int, len)?) + } + Tag::LongArray => { + let len = self.reader.read_i32::()?; + visitor.visit_seq(SeqAccess::new(self.reader, Tag::Long, len)?) + } + } + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.tag == Tag::Byte { + match self.reader.read_i8()? { + 0 => visitor.visit_bool(false), + 1 => visitor.visit_bool(true), + n => visitor.visit_i8(n), + } + } else { + self.deserialize_any(visitor) + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_struct( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.tag == Tag::Compound { + visitor.visit_map(MapAccess::new(self.reader, fields)) + } else { + self.deserialize_any(visitor) + } + } + + fn is_human_readable(&self) -> bool { + false + } +} + +#[doc(hidden)] +pub struct SeqAccess<'a, R> { + reader: &'a mut R, + element_type: Tag, + remaining: u32, +} + +impl<'a, R: Read> SeqAccess<'a, R> { + fn new(reader: &'a mut R, element_type: Tag, len: i32) -> Result { + if len < 0 { + return Err(Error(anyhow!("list with negative length"))); + } + + if element_type == Tag::End && len != 0 { + return Err(Error(anyhow!( + "list with TAG_End element type must have length zero" + ))); + } + + Ok(Self { + reader, + element_type, + remaining: len as u32, + }) + } + + // TODO: function to check if this is for an array or list. +} + +impl<'de: 'a, 'a, R: Read> de::SeqAccess<'de> for SeqAccess<'a, R> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + if self.remaining > 0 { + self.remaining -= 1; + + seed.deserialize(PayloadDeserializer { + reader: self.reader, + tag: self.element_type, + }) + .map(Some) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.remaining as usize) + } +} + +#[doc(hidden)] +pub struct MapAccess<'a, R> { + reader: &'a mut R, + value_tag: Tag, + /// Provides error context when deserializing structs. + fields: &'static [&'static str], +} + +impl<'a, R: Read> MapAccess<'a, R> { + fn new(reader: &'a mut R, fields: &'static [&'static str]) -> Self { + Self { + reader, + value_tag: Tag::End, + fields, + } + } +} + +impl<'de: 'a, 'a, R: Read> de::MapAccess<'de> for MapAccess<'a, R> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de>, + { + self.value_tag = Tag::from_u8(self.reader.read_u8()?)?; + + if self.value_tag == Tag::End { + return Ok(None); + } + + seed.deserialize(PayloadDeserializer { + reader: self.reader, + tag: Tag::String, + }) + .map(Some) + .map_err(|e| match self.fields { + [f, ..] => e.context(anyhow!("compound key (field `{f}`)")), + [] => e, + }) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + if self.value_tag == Tag::End { + return Err(Error(anyhow!("end of compound?"))); + } + + let field = match self.fields { + [field, rest @ ..] => { + self.fields = rest; + Some(*field) + } + [] => None, + }; + + seed.deserialize(PayloadDeserializer { + reader: self.reader, + tag: self.value_tag, + }) + .map_err(|e| match field { + Some(f) => e.context(anyhow!("compound value (field `{f}`)")), + None => e, + }) + } +} diff --git a/serde_nbt/src/binary/ser.rs b/serde_nbt/src/binary/ser.rs new file mode 100644 index 0000000..dc899dc --- /dev/null +++ b/serde_nbt/src/binary/ser.rs @@ -0,0 +1,733 @@ +use std::io::Write; +use std::result::Result as StdResult; + +use anyhow::anyhow; +use byteorder::{BigEndian, WriteBytesExt}; +use cesu8::to_java_cesu8; +use serde::{ser, Serialize}; + +use crate::{Error, Result, Tag}; + +pub fn to_writer(mut writer: W, root_name: &str, value: &T) -> Result<()> +where + W: Write, + T: Serialize + ?Sized, +{ + value.serialize(&mut Serializer::new(&mut writer, root_name)) +} + +pub struct Serializer<'w, 'n, W: ?Sized> { + writer: &'w mut W, + allowed_tag: AllowedTag, + ser_state: SerState<'n>, +} + +#[derive(Copy, Clone)] +enum AllowedTag { + /// Any tag type is permitted to be serialized. + Any { + /// Set to the type that was serialized. Is unspecified if serialization + /// failed or has not taken place yet. + written_tag: Tag, + }, + /// Only one specific tag type is permitted to be serialized. + One { + /// The permitted tag. + tag: Tag, + /// The error message if a tag mismatch happens. + errmsg: &'static str, + }, +} + +enum SerState<'n> { + /// Serialize just the payload and nothing else. + PayloadOnly, + /// Prefix the payload with the tag and a length. + /// Used for the first element of lists. + FirstListElement { + /// Length of the list being serialized. + len: i32, + }, + /// Prefix the payload with the tag and a name. + /// Used for compound fields and the root compound. + Named { name: &'n str }, +} + +impl<'w, 'n, W: Write + ?Sized> Serializer<'w, 'n, W> { + pub fn new(writer: &'w mut W, root_name: &'n str) -> Self { + Self { + writer, + allowed_tag: AllowedTag::One { + tag: Tag::Compound, + errmsg: "root value must be a compound", + }, + ser_state: SerState::Named { name: root_name }, + } + } + + pub fn writer(&mut self) -> &mut W { + self.writer + } + + pub fn root_name(&self) -> &'n str { + match &self.ser_state { + SerState::Named { name } => *name, + _ => unreachable!(), + } + } + + pub fn set_root_name(&mut self, root_name: &'n str) { + self.ser_state = SerState::Named { name: root_name }; + } + + fn write_header(&mut self, tag: Tag) -> Result<()> { + match &mut self.allowed_tag { + AllowedTag::Any { written_tag } => *written_tag = tag, + AllowedTag::One { + tag: expected_tag, + errmsg, + } => { + if tag != *expected_tag { + let e = anyhow!(*errmsg).context(format!( + "attempt to serialize {tag} where {expected_tag} was expected" + )); + return Err(Error(e)); + } + } + } + + match &mut self.ser_state { + SerState::PayloadOnly => {} + SerState::FirstListElement { len } => { + self.writer.write_u8(tag as u8)?; + self.writer.write_i32::(*len)?; + } + SerState::Named { name } => { + self.writer.write_u8(tag as u8)?; + write_string_payload(*name, self.writer)?; + } + } + + Ok(()) + } +} + +type Impossible = ser::Impossible<(), Error>; + +#[inline] +fn unsupported(typ: &str) -> Result { + Err(Error(anyhow!("{typ} is not supported"))) +} + +impl<'a, W: Write + ?Sized> ser::Serializer for &'a mut Serializer<'_, '_, W> { + type Error = Error; + type Ok = (); + type SerializeMap = SerializeMap<'a, W>; + type SerializeSeq = SerializeSeq<'a, W>; + type SerializeStruct = SerializeStruct<'a, W>; + type SerializeStructVariant = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = SerializeArray<'a, W>; + type SerializeTupleVariant = Impossible; + + fn serialize_bool(self, v: bool) -> Result<()> { + self.write_header(Tag::Byte)?; + Ok(self.writer.write_i8(v as i8)?) + } + + fn serialize_i8(self, v: i8) -> Result<()> { + self.write_header(Tag::Byte)?; + Ok(self.writer.write_i8(v)?) + } + + fn serialize_i16(self, v: i16) -> Result<()> { + self.write_header(Tag::Short)?; + Ok(self.writer.write_i16::(v)?) + } + + fn serialize_i32(self, v: i32) -> Result<()> { + self.write_header(Tag::Int)?; + Ok(self.writer.write_i32::(v)?) + } + + fn serialize_i64(self, v: i64) -> Result<()> { + self.write_header(Tag::Long)?; + Ok(self.writer.write_i64::(v)?) + } + + fn serialize_u8(self, _v: u8) -> Result<()> { + unsupported("u8") + } + + fn serialize_u16(self, _v: u16) -> Result<()> { + unsupported("u16") + } + + fn serialize_u32(self, _v: u32) -> Result<()> { + unsupported("u32") + } + + fn serialize_u64(self, _v: u64) -> Result<()> { + unsupported("u64") + } + + fn serialize_f32(self, v: f32) -> Result<()> { + self.write_header(Tag::Float)?; + Ok(self.writer.write_f32::(v)?) + } + + fn serialize_f64(self, v: f64) -> Result<()> { + self.write_header(Tag::Double)?; + Ok(self.writer.write_f64::(v)?) + } + + fn serialize_char(self, _v: char) -> Result<()> { + unsupported("char") + } + + fn serialize_str(self, v: &str) -> Result<()> { + self.write_header(Tag::String)?; + write_string_payload(v, self.writer) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result<()> { + unsupported("&[u8]") + } + + fn serialize_none(self) -> Result<()> { + unsupported("Option") + } + + fn serialize_some(self, _value: &T) -> Result<()> + where + T: Serialize, + { + unsupported("Option") + } + + fn serialize_unit(self) -> Result<()> { + Ok(()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Ok(()) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + ) -> Result<()> { + match variant_index.try_into() { + Ok(idx) => self.serialize_i32(idx), + Err(_) => Err(Error(anyhow!( + "variant index of {variant_index} is out of range" + ))), + } + } + + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> + where + T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result<()> + where + T: Serialize, + { + unsupported("newtype variant") + } + + fn serialize_seq(self, len: Option) -> Result { + self.write_header(Tag::List)?; + + let len = match len { + Some(len) => len, + None => return Err(Error(anyhow!("list length must be known up front"))), + }; + + match len.try_into() { + Ok(len) => Ok(SerializeSeq { + writer: self.writer, + element_type: Tag::End, + remaining: len, + }), + Err(_) => Err(Error(anyhow!("length of list exceeds i32::MAX"))), + } + } + + fn serialize_tuple(self, _len: usize) -> Result { + unsupported("tuple") + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + let element_type = match name { + crate::BYTE_ARRAY_MAGIC => { + self.write_header(Tag::ByteArray)?; + Tag::Byte + } + crate::INT_ARRAY_MAGIC => { + self.write_header(Tag::IntArray)?; + Tag::Int + } + crate::LONG_ARRAY_MAGIC => { + self.write_header(Tag::LongArray)?; + Tag::Long + } + _ => return unsupported("tuple struct"), + }; + + match len.try_into() { + Ok(len) => { + self.writer.write_i32::(len)?; + Ok(SerializeArray { + writer: self.writer, + element_type, + remaining: len, + }) + } + Err(_) => Err(Error(anyhow!("array length of {len} exceeds i32::MAX"))), + } + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported("tuple variant") + } + + fn serialize_map(self, _len: Option) -> Result { + self.write_header(Tag::Compound)?; + + Ok(SerializeMap { + writer: self.writer, + }) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + self.write_header(Tag::Compound)?; + + Ok(SerializeStruct { + writer: self.writer, + }) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported("struct variant") + } + + fn is_human_readable(&self) -> bool { + false + } +} + +#[doc(hidden)] +pub struct SerializeSeq<'w, W: ?Sized> { + writer: &'w mut W, + /// The element type of this list. TAG_End if unknown. + element_type: Tag, + /// Number of elements left to serialize. + remaining: i32, +} + +impl ser::SerializeSeq for SerializeSeq<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: Serialize, + { + if self.remaining <= 0 { + return Err(Error(anyhow!( + "attempt to serialize more list elements than specified" + ))); + } + + if self.element_type == Tag::End { + let mut ser = Serializer { + writer: self.writer, + allowed_tag: AllowedTag::Any { + written_tag: Tag::End, + }, + ser_state: SerState::FirstListElement { + len: self.remaining, + }, + }; + + value.serialize(&mut ser)?; + + self.element_type = match ser.allowed_tag { + AllowedTag::Any { written_tag } => written_tag, + AllowedTag::One { .. } => unreachable!(), + }; + } else { + value.serialize(&mut Serializer { + writer: self.writer, + allowed_tag: AllowedTag::One { + tag: self.element_type, + errmsg: "list elements must be homogeneous", + }, + ser_state: SerState::PayloadOnly, + })?; + } + + self.remaining -= 1; + + Ok(()) + } + + fn end(self) -> Result<()> { + if self.remaining > 0 { + return Err(Error(anyhow!( + "{} list element(s) left to serialize", + self.remaining + ))); + } + + // Were any elements written? + if self.element_type == Tag::End { + self.writer.write_u8(Tag::End as u8)?; + // List length. + self.writer.write_i32::(0)?; + } + + Ok(()) + } +} + +#[doc(hidden)] +pub struct SerializeArray<'w, W: ?Sized> { + writer: &'w mut W, + element_type: Tag, + remaining: i32, +} + +impl ser::SerializeTupleStruct for SerializeArray<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: Serialize, + { + if self.remaining <= 0 { + return Err(Error(anyhow!( + "attempt to serialize more array elements than specified" + ))); + } + + value.serialize(&mut Serializer { + writer: self.writer, + allowed_tag: AllowedTag::One { + tag: self.element_type, + errmsg: "mismatched array element type", + }, + ser_state: SerState::PayloadOnly, + })?; + + self.remaining -= 1; + + Ok(()) + } + + fn end(self) -> Result<()> { + if self.remaining > 0 { + return Err(Error(anyhow!( + "{} array element(s) left to serialize", + self.remaining + ))); + } + + Ok(()) + } +} + +#[doc(hidden)] +pub struct SerializeMap<'w, W: ?Sized> { + writer: &'w mut W, +} + +impl ser::SerializeMap for SerializeMap<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_key(&mut self, _key: &T) -> Result<()> + where + T: Serialize, + { + Err(Error(anyhow!("map keys cannot be serialized individually"))) + } + + fn serialize_value(&mut self, _value: &T) -> Result<()> + where + T: Serialize, + { + Err(Error(anyhow!( + "map values cannot be serialized individually" + ))) + } + + fn serialize_entry(&mut self, key: &K, value: &V) -> Result<()> + where + K: Serialize, + V: Serialize, + { + key.serialize(MapEntrySerializer { + writer: self.writer, + value, + }) + } + + fn end(self) -> Result<()> { + Ok(self.writer.write_u8(Tag::End as u8)?) + } +} + +struct MapEntrySerializer<'w, 'v, W: ?Sized, V: ?Sized> { + writer: &'w mut W, + value: &'v V, +} + +fn key_not_a_string() -> Result { + Err(Error(anyhow!("map keys must be strings"))) +} + +impl ser::Serializer for MapEntrySerializer<'_, '_, W, V> +where + W: Write + ?Sized, + V: Serialize + ?Sized, +{ + type Error = Error; + type Ok = (); + type SerializeMap = Impossible; + type SerializeSeq = Impossible; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + + fn serialize_bool(self, _v: bool) -> Result<()> { + key_not_a_string() + } + + fn serialize_i8(self, _v: i8) -> Result<()> { + key_not_a_string() + } + + fn serialize_i16(self, _v: i16) -> Result<()> { + key_not_a_string() + } + + fn serialize_i32(self, _v: i32) -> Result<()> { + key_not_a_string() + } + + fn serialize_i64(self, _v: i64) -> Result<()> { + key_not_a_string() + } + + fn serialize_u8(self, _v: u8) -> Result<()> { + key_not_a_string() + } + + fn serialize_u16(self, _v: u16) -> Result<()> { + key_not_a_string() + } + + fn serialize_u32(self, _v: u32) -> Result<()> { + key_not_a_string() + } + + fn serialize_u64(self, _v: u64) -> Result<()> { + key_not_a_string() + } + + fn serialize_f32(self, _v: f32) -> Result<()> { + key_not_a_string() + } + + fn serialize_f64(self, _v: f64) -> Result<()> { + key_not_a_string() + } + + fn serialize_char(self, _v: char) -> Result<()> { + key_not_a_string() + } + + fn serialize_str(self, v: &str) -> Result<()> { + self.value + .serialize(&mut Serializer { + writer: self.writer, + allowed_tag: AllowedTag::Any { + written_tag: Tag::End, + }, + ser_state: SerState::Named { name: v }, + }) + .map_err(|e| e.context(format!("map key `{v}`"))) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result<()> { + key_not_a_string() + } + + fn serialize_none(self) -> Result<()> { + key_not_a_string() + } + + fn serialize_some(self, _value: &T) -> Result<()> + where + T: Serialize, + { + key_not_a_string() + } + + fn serialize_unit(self) -> Result<()> { + key_not_a_string() + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + key_not_a_string() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result<()> { + key_not_a_string() + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result<()> + where + T: Serialize, + { + key_not_a_string() + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result<()> + where + T: Serialize, + { + key_not_a_string() + } + + fn serialize_seq(self, _len: Option) -> StdResult { + key_not_a_string() + } + + fn serialize_tuple(self, _len: usize) -> StdResult { + key_not_a_string() + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> StdResult { + key_not_a_string() + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> StdResult { + key_not_a_string() + } + + fn serialize_map(self, _len: Option) -> StdResult { + key_not_a_string() + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> StdResult { + key_not_a_string() + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> StdResult { + key_not_a_string() + } +} + +#[doc(hidden)] +pub struct SerializeStruct<'w, W: ?Sized> { + writer: &'w mut W, +} + +impl ser::SerializeStruct for SerializeStruct<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: Serialize, + { + value + .serialize(&mut Serializer { + writer: self.writer, + allowed_tag: AllowedTag::Any { + written_tag: Tag::End, + }, + ser_state: SerState::Named { name: key }, + }) + .map_err(|e| e.context(format!("field `{key}`"))) + } + + fn end(self) -> Result<()> { + Ok(self.writer.write_u8(Tag::End as u8)?) + } +} + +fn write_string_payload(string: &str, writer: &mut (impl Write + ?Sized)) -> Result<()> { + let data = to_java_cesu8(string); + match data.len().try_into() { + Ok(len) => writer.write_u16::(len)?, + Err(_) => return Err(Error(anyhow!("string byte length exceeds u16::MAX"))), + }; + + writer.write_all(&data)?; + Ok(()) +} diff --git a/serde_nbt/src/error.rs b/serde_nbt/src/error.rs new file mode 100644 index 0000000..0bfff37 --- /dev/null +++ b/serde_nbt/src/error.rs @@ -0,0 +1,54 @@ +use std::error::Error as StdError; +use std::fmt::{Display, Formatter}; +use std::io; + +use anyhow::anyhow; +use serde::{de, ser}; + +#[derive(Debug)] +pub struct Error(pub(super) anyhow::Error); + +impl Error { + pub(super) fn context(self, ctx: C) -> Self + where + C: Display + Send + Sync + 'static, + { + Self(self.0.context(ctx)) + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.0.source() + } +} + +impl ser::Error for Error { + fn custom(msg: T) -> Self + where + T: Display, + { + Error(anyhow!("{msg}")) + } +} + +impl de::Error for Error { + fn custom(msg: T) -> Self + where + T: Display, + { + Error(anyhow!("{msg}")) + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Error(anyhow::Error::new(e)) + } +} diff --git a/serde_nbt/src/lib.rs b/serde_nbt/src/lib.rs new file mode 100644 index 0000000..482c482 --- /dev/null +++ b/serde_nbt/src/lib.rs @@ -0,0 +1,90 @@ +use std::fmt; +use std::fmt::{Display, Formatter}; +use anyhow::anyhow; + +pub use error::*; +pub use value::*; + +mod error; +mod value; +mod array; + +#[cfg(test)] +mod tests; + +pub use array::*; + +/// (De)serialization support for the binary representation of NBT. +pub mod binary { + pub use ser::*; + pub use de::*; + + mod ser; + mod de; +} + +pub type Result = std::result::Result; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +enum Tag { + End, + Byte, + Short, + Int, + Long, + Float, + Double, + ByteArray, + String, + List, + Compound, + IntArray, + LongArray, +} + +impl Tag { + pub fn from_u8(id: u8) -> Result { + match id { + 0 => Ok(Tag::End), + 1 => Ok(Tag::Byte), + 2 => Ok(Tag::Short), + 3 => Ok(Tag::Int), + 4 => Ok(Tag::Long), + 5 => Ok(Tag::Float), + 6 => Ok(Tag::Double), + 7 => Ok(Tag::ByteArray), + 8 => Ok(Tag::String), + 9 => Ok(Tag::List), + 10 => Ok(Tag::Compound), + 11 => Ok(Tag::IntArray), + 12 => Ok(Tag::LongArray), + _ => Err(Error(anyhow!("invalid tag byte `{id}`"))) + } + } +} + +impl Display for Tag { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let s = match self { + Tag::End => "end", + Tag::Byte => "byte", + Tag::Short => "short", + Tag::Int => "int", + Tag::Long => "long", + Tag::Float => "float", + Tag::Double => "double", + Tag::ByteArray => "byte array", + Tag::String => "string", + Tag::List => "list", + Tag::Compound => "compound", + Tag::IntArray => "int array", + Tag::LongArray => "long array", + }; + + write!(f, "{s}") + } +} + +const BYTE_ARRAY_MAGIC: &str = "__byte_array__"; +const INT_ARRAY_MAGIC: &str = "__int_array__"; +const LONG_ARRAY_MAGIC: &str = "__long_array__"; diff --git a/serde_nbt/src/tests.rs b/serde_nbt/src/tests.rs new file mode 100644 index 0000000..7e47704 --- /dev/null +++ b/serde_nbt/src/tests.rs @@ -0,0 +1,172 @@ +use binary::{from_reader, to_writer, Deserializer, Serializer}; +use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; + +use super::*; + +const ROOT_NAME: &str = "The root name‽"; + +#[derive(PartialEq, Debug, Serialize, Deserialize)] +struct Struct { + byte: i8, + list_of_int: Vec, + list_of_string: Vec, + string: String, + inner: Inner, + #[serde(serialize_with = "int_array")] + int_array: Vec, + #[serde(serialize_with = "byte_array")] + byte_array: Vec, + #[serde(serialize_with = "long_array")] + long_array: Vec, +} + +impl Struct { + pub fn new() -> Self { + Self { + byte: 123, + list_of_int: vec![3, -7, 5], + list_of_string: vec!["foo".to_owned(), "bar".to_owned(), "baz".to_owned()], + string: "aé日".to_owned(), + inner: Inner { + int: i32::MIN, + long: i64::MAX, + nan_float: OrderedFloat(f32::NAN), + neg_inf_double: f64::NEG_INFINITY, + }, + int_array: vec![5, -9, i32::MIN, 0, i32::MAX], + byte_array: vec![0, 1, 2], + long_array: vec![123, 456, 789], + } + } + + pub fn value() -> Value { + Value::Compound( + Compound::from_iter([ + ("byte".into(), 123.into()), + ("list_of_int".into(), List::Int(vec![3, -7, 5]).into()), + ( + "list_of_string".into(), + List::String(vec!["foo".into(), "bar".into(), "baz".into()]).into(), + ), + ("string".into(), "aé日".into()), + ( + "inner".into(), + Compound::from_iter([ + ("int".into(), i32::MIN.into()), + ("long".into(), i64::MAX.into()), + ("nan_float".into(), f32::NAN.into()), + ("neg_inf_double".into(), f64::NEG_INFINITY.into()), + ]) + .into(), + ), + ( + "int_array".into(), + vec![5, -9, i32::MIN, 0, i32::MAX].into(), + ), + ("byte_array".into(), vec![0_i8, 1, 2].into()), + ("long_array".into(), vec![123_i64, 456, 789].into()), + ]) + .into(), + ) + } +} + +#[derive(PartialEq, Debug, Serialize, Deserialize)] +struct Inner { + int: i32, + long: i64, + nan_float: OrderedFloat, + neg_inf_double: f64, +} + +#[test] +fn round_trip() { + let struct_ = Struct::new(); + + let mut buf = Vec::new(); + struct_ + .serialize(&mut Serializer::new(&mut buf, ROOT_NAME)) + .unwrap(); + + let reader = &mut buf.as_slice(); + + let mut de = Deserializer::new(reader, true); + + let example_de = Struct::deserialize(&mut de).unwrap(); + + assert_eq!(struct_, example_de); + + let (_, root) = de.into_inner(); + + assert_eq!(root.unwrap(), ROOT_NAME); +} + +#[test] +fn serialize() { + let struct_ = Struct::new(); + + let mut buf = Vec::new(); + + struct_ + .serialize(&mut Serializer::new(&mut buf, ROOT_NAME)) + .unwrap(); + + let example_de: Struct = nbt::from_reader(&mut buf.as_slice()).unwrap(); + + assert_eq!(struct_, example_de); +} + +#[test] +fn root_requires_compound() { + let mut buf = Vec::new(); + assert!(123 + .serialize(&mut Serializer::new(&mut buf, ROOT_NAME)) + .is_err()); +} + +#[test] +fn invalid_array_element() { + #[derive(Serialize)] + struct Struct { + #[serde(serialize_with = "byte_array")] + data: Vec, + } + + let struct_ = Struct { + data: vec![1, 2, 3], + }; + + let mut buf = Vec::new(); + assert!(struct_ + .serialize(&mut Serializer::new(&mut buf, ROOT_NAME)) + .is_err()); +} + +// #[test] +// fn struct_to_value() { +// let mut buf = Vec::new(); +// +// to_writer(&mut buf, ROOT_NAME, &Struct::new()).unwrap(); +// +// let reader = &mut buf.as_slice(); +// +// let val: Value = from_reader(reader).unwrap(); +// +// eprintln!("{:#?}", Struct::value()); +// +// assert_eq!(val, Struct::value()); +// } + +#[test] +fn value_to_struct() { + let mut buf = Vec::new(); + + to_writer(&mut buf, ROOT_NAME, &Struct::value()).unwrap(); + + let reader = &mut buf.as_slice(); + + let struct_: Struct = from_reader(reader).unwrap(); + + assert_eq!(struct_, Struct::new()); +} diff --git a/serde_nbt/src/value.rs b/serde_nbt/src/value.rs new file mode 100644 index 0000000..899c981 --- /dev/null +++ b/serde_nbt/src/value.rs @@ -0,0 +1,519 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::fmt; +use std::fmt::Formatter; + +use serde::de::{DeserializeSeed, Error, MapAccess, SeqAccess, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{byte_array, int_array, long_array}; + +/// Represents an arbitrary NBT value. +#[derive(Clone, PartialEq, Debug)] +pub enum Value { + Byte(i8), + Short(i16), + Int(i32), + Long(i64), + Float(f32), + Double(f64), + ByteArray(Vec), + String(String), + List(List), + Compound(Compound), + IntArray(Vec), + LongArray(Vec), +} + +pub type Compound = HashMap; + +/// An NBT list value. +/// +/// NBT lists are homogeneous, meaning each list element must be of the same +/// type. This is opposed to a format like JSON where lists can be +/// heterogeneous: +/// +/// ```json +/// [42, "hello", {}] +/// ``` +/// +/// Every possible element type has its own variant in this enum. As a result, +/// heterogeneous lists are unrepresentable. +#[derive(Clone, PartialEq, Debug)] +pub enum List { + Byte(Vec), + Short(Vec), + Int(Vec), + Long(Vec), + Float(Vec), + Double(Vec), + ByteArray(Vec>), + String(Vec), + List(Vec), + Compound(Vec), + IntArray(Vec>), + LongArray(Vec>), +} + +impl List { + pub fn len(&self) -> usize { + match self { + List::Byte(l) => l.len(), + List::Short(l) => l.len(), + List::Int(l) => l.len(), + List::Long(l) => l.len(), + List::Float(l) => l.len(), + List::Double(l) => l.len(), + List::ByteArray(l) => l.len(), + List::String(l) => l.len(), + List::List(l) => l.len(), + List::Compound(l) => l.len(), + List::IntArray(l) => l.len(), + List::LongArray(l) => l.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl From for Value { + fn from(v: i8) -> Self { + Self::Byte(v) + } +} + +impl From for Value { + fn from(v: i16) -> Self { + Self::Short(v) + } +} + +impl From for Value { + fn from(v: i32) -> Self { + Self::Int(v) + } +} + +impl From for Value { + fn from(v: i64) -> Self { + Self::Long(v) + } +} + +impl From for Value { + fn from(v: f32) -> Self { + Self::Float(v) + } +} + +impl From for Value { + fn from(v: f64) -> Self { + Self::Double(v) + } +} + +impl From> for Value { + fn from(v: Vec) -> Self { + Self::ByteArray(v) + } +} + +impl From for Value { + fn from(v: String) -> Self { + Self::String(v) + } +} + +impl<'a> From<&'a str> for Value { + fn from(v: &'a str) -> Self { + Self::String(v.to_owned()) + } +} + +impl<'a> From> for Value { + fn from(v: Cow<'a, str>) -> Self { + Self::String(v.into_owned()) + } +} + +impl From for Value { + fn from(v: List) -> Self { + Self::List(v) + } +} + +impl From for Value { + fn from(v: Compound) -> Self { + Self::Compound(v) + } +} + +impl From> for Value { + fn from(v: Vec) -> Self { + Self::IntArray(v) + } +} + +impl From> for Value { + fn from(v: Vec) -> Self { + Self::LongArray(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::Byte(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::Short(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::Int(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::Long(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::Float(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::Double(v) + } +} + +impl From>> for List { + fn from(v: Vec>) -> Self { + List::ByteArray(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::String(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::List(v) + } +} + +impl From> for List { + fn from(v: Vec) -> Self { + List::Compound(v) + } +} + +impl From>> for List { + fn from(v: Vec>) -> Self { + List::IntArray(v) + } +} + +impl From>> for List { + fn from(v: Vec>) -> Self { + List::LongArray(v) + } +} + +impl Serialize for Value { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Value::Byte(v) => v.serialize(serializer), + Value::Short(v) => v.serialize(serializer), + Value::Int(v) => v.serialize(serializer), + Value::Long(v) => v.serialize(serializer), + Value::Float(v) => v.serialize(serializer), + Value::Double(v) => v.serialize(serializer), + Value::ByteArray(v) => byte_array(v, serializer), + Value::String(v) => v.serialize(serializer), + Value::List(v) => v.serialize(serializer), + Value::Compound(v) => v.serialize(serializer), + Value::IntArray(v) => int_array(v, serializer), + Value::LongArray(v) => long_array(v, serializer), + } + } +} + +impl Serialize for List { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + List::Byte(l) => l.serialize(serializer), + List::Short(l) => l.serialize(serializer), + List::Int(l) => l.serialize(serializer), + List::Long(l) => l.serialize(serializer), + List::Float(l) => l.serialize(serializer), + List::Double(l) => l.serialize(serializer), + List::ByteArray(l) => l.serialize(serializer), + List::String(l) => l.serialize(serializer), + List::List(l) => l.serialize(serializer), + List::Compound(l) => l.serialize(serializer), + List::IntArray(l) => l.serialize(serializer), + List::LongArray(l) => l.serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for Value { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(ValueVisitor) + } +} + +struct ValueVisitor; + +impl<'de> Visitor<'de> for ValueVisitor { + type Value = Value; + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + write!(formatter, "a representable NBT value") + } + + fn visit_i8(self, v: i8) -> Result + where + E: Error, + { + Ok(Value::Byte(v)) + } + + fn visit_i16(self, v: i16) -> Result + where + E: Error, + { + Ok(Value::Short(v)) + } + + fn visit_i32(self, v: i32) -> Result + where + E: Error, + { + Ok(Value::Int(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: Error, + { + Ok(Value::Long(v)) + } + + fn visit_f32(self, v: f32) -> Result + where + E: Error, + { + Ok(Value::Float(v)) + } + + fn visit_f64(self, v: f64) -> Result + where + E: Error, + { + Ok(Value::Double(v)) + } + + fn visit_string(self, v: String) -> Result + where + E: Error, + { + Ok(Value::String(v)) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + Ok(Value::String(v.to_owned())) + } + + fn visit_seq(self, seq: A) -> Result + where + A: SeqAccess<'de>, + { + ListVisitor.visit_seq(seq).map(Value::List) + } + + fn visit_map(self, map: A) -> Result + where + A: MapAccess<'de>, + { + visit_map(map).map(Value::Compound) + } +} + +impl<'de> Deserialize<'de> for List { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_seq(ListVisitor) + } +} + +struct ListVisitor; + +impl<'de> Visitor<'de> for ListVisitor { + type Value = List; + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + write!(formatter, "an NBT list") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut list = List::Byte(Vec::new()); + + while seq + .next_element_seed(DeserializeListElement(&mut list))? + .is_some() + {} + + Ok(list) + } +} + +struct DeserializeListElement<'a>(&'a mut List); + +impl<'de, 'a> DeserializeSeed<'de> for DeserializeListElement<'a> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(self) + } +} + +macro_rules! visit { + ($self:expr, $variant:ident, $value:expr, $error:ty) => { + if $self.0.is_empty() { + *$self.0 = List::$variant(vec![$value]); + Ok(()) + } else if let List::$variant(elems) = $self.0 { + elems.push($value); + Ok(()) + } else { + Err(<$error>::custom("NBT lists must be homogenous")) + } + }; +} + +impl<'de, 'a> Visitor<'de> for DeserializeListElement<'a> { + type Value = (); + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + write!(formatter, "a valid NBT list element") + } + + fn visit_i8(self, v: i8) -> Result + where + E: Error, + { + visit!(self, Byte, v, E) + } + + fn visit_i16(self, v: i16) -> Result + where + E: Error, + { + visit!(self, Short, v, E) + } + + fn visit_i32(self, v: i32) -> Result + where + E: Error, + { + visit!(self, Int, v, E) + } + + fn visit_i64(self, v: i64) -> Result + where + E: Error, + { + visit!(self, Long, v, E) + } + + fn visit_f32(self, v: f32) -> Result + where + E: Error, + { + visit!(self, Float, v, E) + } + + fn visit_f64(self, v: f64) -> Result + where + E: Error, + { + visit!(self, Double, v, E) + } + + fn visit_string(self, v: String) -> Result + where + E: Error, + { + visit!(self, String, v, E) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + visit!(self, String, v.to_owned(), E) + } + + fn visit_seq(self, seq: A) -> Result + where + A: SeqAccess<'de>, + { + visit!(self, List, ListVisitor.visit_seq(seq)?, A::Error) + } + + fn visit_map(self, map: A) -> Result + where + A: MapAccess<'de>, + { + visit!(self, Compound, visit_map(map)?, A::Error) + } +} + +fn visit_map<'de, A>(mut map: A) -> Result +where + A: MapAccess<'de>, +{ + let mut compound = Compound::new(); + + while let Some((k, v)) = map.next_entry()? { + compound.insert(k, v); + } + + Ok(compound) +}