diff --git a/Cargo.toml b/Cargo.toml index 2ed54e6..b6ab29c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,9 @@ async-trait = "0.1" base64 = "0.13" bitfield-struct = "0.1" bitvec = "1" +bumpalo = "3.11.0" byteorder = "1" +cesu8 = "1.1.0" cfb8 = "0.7" flate2 = "1" flume = "0.10" @@ -67,4 +69,4 @@ num = "0.4" protocol = [] [workspace] -members = ["packet-inspector", "serde_nbt"] +members = ["packet-inspector", "valence_nbt"] diff --git a/serde_nbt/src/array.rs b/serde_nbt/src/array.rs deleted file mode 100644 index 52993f0..0000000 --- a/serde_nbt/src/array.rs +++ /dev/null @@ -1,29 +0,0 @@ -use serde::{Serialize, Serializer}; -use serde::ser::SerializeTupleStruct; - -use super::{BYTE_ARRAY_MAGIC, INT_ARRAY_MAGIC, LONG_ARRAY_MAGIC}; - -macro_rules! def { - ($name:ident, $magic:ident) => { - pub fn $name(array: T, serializer: S) -> Result - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - T::Item: Serialize, - S: Serializer, - { - let it = array.into_iter(); - let mut sts = serializer.serialize_tuple_struct($magic, it.len())?; - - for item in it { - sts.serialize_field(&item)?; - } - - sts.end() - } - } -} - -def!(byte_array, BYTE_ARRAY_MAGIC); -def!(int_array, INT_ARRAY_MAGIC); -def!(long_array, LONG_ARRAY_MAGIC); diff --git a/serde_nbt/src/binary/de.rs b/serde_nbt/src/binary/de.rs deleted file mode 100644 index af61490..0000000 --- a/serde_nbt/src/binary/de.rs +++ /dev/null @@ -1,363 +0,0 @@ -// TODO: recursion limit. -// TODO: serialize and deserialize recursion limit wrappers. (crate: -// serde_reclimit). 
- -use std::borrow::Cow; -use std::io::Read; - -use anyhow::anyhow; -use byteorder::{BigEndian, ReadBytesExt}; -use cesu8::from_java_cesu8; -use serde::de::{DeserializeOwned, DeserializeSeed, Visitor}; -use serde::{de, forward_to_deserialize_any}; -use smallvec::SmallVec; - -use crate::{Error, Result, Tag}; - -pub fn from_reader(reader: R) -> Result -where - R: Read, - T: DeserializeOwned, -{ - T::deserialize(&mut Deserializer::new(reader, false)) -} - -pub struct Deserializer { - reader: R, - root_name: Option, -} - -impl Deserializer { - pub fn new(reader: R, save_root_name: bool) -> Self { - Self { - reader, - root_name: if save_root_name { - Some(String::new()) - } else { - None - }, - } - } - - pub fn into_inner(self) -> (R, Option) { - (self.reader, self.root_name) - } - - fn read_header(&mut self) -> Result { - let tag = Tag::from_u8(self.reader.read_u8()?)?; - - if tag != Tag::Compound { - return Err(Error(anyhow!( - "unexpected tag `{tag}` (root value must be a compound)" - ))); - } - - if let Some(name) = &mut self.root_name { - let mut buf = SmallVec::<[u8; 128]>::new(); - for _ in 0..self.reader.read_u16::()? { - buf.push(self.reader.read_u8()?); - } - - *name = from_java_cesu8(&buf) - .map_err(|e| Error(anyhow!(e)))? - .into_owned(); - } else { - for _ in 0..self.reader.read_u16::()? { - self.reader.read_u8()?; - } - } - - Ok(tag) - } -} - -impl<'de: 'a, 'a, R: Read + 'de> de::Deserializer<'de> for &'a mut Deserializer { - type Error = Error; - - forward_to_deserialize_any! 
{ - bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string - bytes byte_buf option unit unit_struct seq tuple tuple_struct map - enum identifier ignored_any - } - - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let tag = self.read_header()?; - - PayloadDeserializer { - reader: &mut self.reader, - tag, - } - .deserialize_any(visitor) - } - - fn deserialize_struct( - self, - name: &'static str, - fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - let tag = self.read_header()?; - - PayloadDeserializer { - reader: &mut self.reader, - tag, - } - .deserialize_struct(name, fields, visitor) - } - - fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - fn is_human_readable(&self) -> bool { - false - } -} - -struct PayloadDeserializer<'a, R> { - reader: &'a mut R, - /// The type of payload to be deserialized. - tag: Tag, -} - -impl<'de: 'a, 'a, R: Read> de::Deserializer<'de> for PayloadDeserializer<'a, R> { - type Error = Error; - - forward_to_deserialize_any! { - i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string - bytes byte_buf option seq tuple tuple_struct map enum identifier - ignored_any - } - - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - match self.tag { - Tag::End => unreachable!("invalid payload tag"), - Tag::Byte => visitor.visit_i8(self.reader.read_i8()?), - Tag::Short => visitor.visit_i16(self.reader.read_i16::()?), - Tag::Int => visitor.visit_i32(self.reader.read_i32::()?), - Tag::Long => visitor.visit_i64(self.reader.read_i64::()?), - Tag::Float => visitor.visit_f32(self.reader.read_f32::()?), - Tag::Double => visitor.visit_f64(self.reader.read_f64::()?), - Tag::ByteArray => { - let len = self.reader.read_i32::()?; - visitor.visit_seq(SeqAccess::new(self.reader, Tag::Byte, len)?) 
- } - Tag::String => { - let mut buf = SmallVec::<[u8; 128]>::new(); - for _ in 0..self.reader.read_u16::()? { - buf.push(self.reader.read_u8()?); - } - - match from_java_cesu8(&buf).map_err(|e| Error(anyhow!(e)))? { - Cow::Borrowed(s) => visitor.visit_str(s), - Cow::Owned(string) => visitor.visit_string(string), - } - } - Tag::List => { - let element_type = Tag::from_u8(self.reader.read_u8()?)?; - let len = self.reader.read_i32::()?; - visitor.visit_seq(SeqAccess::new(self.reader, element_type, len)?) - } - Tag::Compound => visitor.visit_map(MapAccess::new(self.reader, &[])), - Tag::IntArray => { - let len = self.reader.read_i32::()?; - visitor.visit_seq(SeqAccess::new(self.reader, Tag::Int, len)?) - } - Tag::LongArray => { - let len = self.reader.read_i32::()?; - visitor.visit_seq(SeqAccess::new(self.reader, Tag::Long, len)?) - } - } - } - - fn deserialize_bool(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - if self.tag == Tag::Byte { - match self.reader.read_i8()? { - 0 => visitor.visit_bool(false), - 1 => visitor.visit_bool(true), - n => visitor.visit_i8(n), - } - } else { - self.deserialize_any(visitor) - } - } - - fn deserialize_unit(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_unit() - } - - fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_unit() - } - - fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - fn deserialize_struct( - self, - _name: &'static str, - fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - if self.tag == Tag::Compound { - visitor.visit_map(MapAccess::new(self.reader, fields)) - } else { - self.deserialize_any(visitor) - } - } - - fn is_human_readable(&self) -> bool { - false - } -} - -#[doc(hidden)] -pub struct SeqAccess<'a, R> { - reader: &'a mut R, - element_type: Tag, - 
remaining: u32, -} - -impl<'a, R: Read> SeqAccess<'a, R> { - fn new(reader: &'a mut R, element_type: Tag, len: i32) -> Result { - if len < 0 { - return Err(Error(anyhow!("list with negative length"))); - } - - if element_type == Tag::End && len != 0 { - return Err(Error(anyhow!( - "list with TAG_End element type must have length zero" - ))); - } - - Ok(Self { - reader, - element_type, - remaining: len as u32, - }) - } - - // TODO: function to check if this is for an array or list. -} - -impl<'de: 'a, 'a, R: Read> de::SeqAccess<'de> for SeqAccess<'a, R> { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: DeserializeSeed<'de>, - { - if self.remaining > 0 { - self.remaining -= 1; - - seed.deserialize(PayloadDeserializer { - reader: self.reader, - tag: self.element_type, - }) - .map(Some) - } else { - Ok(None) - } - } - - fn size_hint(&self) -> Option { - Some(self.remaining as usize) - } -} - -#[doc(hidden)] -pub struct MapAccess<'a, R> { - reader: &'a mut R, - value_tag: Tag, - /// Provides error context when deserializing structs. - fields: &'static [&'static str], -} - -impl<'a, R: Read> MapAccess<'a, R> { - fn new(reader: &'a mut R, fields: &'static [&'static str]) -> Self { - Self { - reader, - value_tag: Tag::End, - fields, - } - } -} - -impl<'de: 'a, 'a, R: Read> de::MapAccess<'de> for MapAccess<'a, R> { - type Error = Error; - - fn next_key_seed(&mut self, seed: K) -> Result> - where - K: DeserializeSeed<'de>, - { - self.value_tag = Tag::from_u8(self.reader.read_u8()?)?; - - if self.value_tag == Tag::End { - return Ok(None); - } - - seed.deserialize(PayloadDeserializer { - reader: self.reader, - tag: Tag::String, - }) - .map(Some) - .map_err(|e| match self.fields { - [f, ..] 
=> e.context(anyhow!("compound key (field `{f}`)")), - [] => e, - }) - } - - fn next_value_seed(&mut self, seed: V) -> Result - where - V: DeserializeSeed<'de>, - { - if self.value_tag == Tag::End { - return Err(Error(anyhow!("end of compound?"))); - } - - let field = match self.fields { - [field, rest @ ..] => { - self.fields = rest; - Some(*field) - } - [] => None, - }; - - seed.deserialize(PayloadDeserializer { - reader: self.reader, - tag: self.value_tag, - }) - .map_err(|e| match field { - Some(f) => e.context(anyhow!("compound value (field `{f}`)")), - None => e, - }) - } -} diff --git a/serde_nbt/src/binary/ser.rs b/serde_nbt/src/binary/ser.rs deleted file mode 100644 index dc899dc..0000000 --- a/serde_nbt/src/binary/ser.rs +++ /dev/null @@ -1,733 +0,0 @@ -use std::io::Write; -use std::result::Result as StdResult; - -use anyhow::anyhow; -use byteorder::{BigEndian, WriteBytesExt}; -use cesu8::to_java_cesu8; -use serde::{ser, Serialize}; - -use crate::{Error, Result, Tag}; - -pub fn to_writer(mut writer: W, root_name: &str, value: &T) -> Result<()> -where - W: Write, - T: Serialize + ?Sized, -{ - value.serialize(&mut Serializer::new(&mut writer, root_name)) -} - -pub struct Serializer<'w, 'n, W: ?Sized> { - writer: &'w mut W, - allowed_tag: AllowedTag, - ser_state: SerState<'n>, -} - -#[derive(Copy, Clone)] -enum AllowedTag { - /// Any tag type is permitted to be serialized. - Any { - /// Set to the type that was serialized. Is unspecified if serialization - /// failed or has not taken place yet. - written_tag: Tag, - }, - /// Only one specific tag type is permitted to be serialized. - One { - /// The permitted tag. - tag: Tag, - /// The error message if a tag mismatch happens. - errmsg: &'static str, - }, -} - -enum SerState<'n> { - /// Serialize just the payload and nothing else. - PayloadOnly, - /// Prefix the payload with the tag and a length. - /// Used for the first element of lists. - FirstListElement { - /// Length of the list being serialized. 
- len: i32, - }, - /// Prefix the payload with the tag and a name. - /// Used for compound fields and the root compound. - Named { name: &'n str }, -} - -impl<'w, 'n, W: Write + ?Sized> Serializer<'w, 'n, W> { - pub fn new(writer: &'w mut W, root_name: &'n str) -> Self { - Self { - writer, - allowed_tag: AllowedTag::One { - tag: Tag::Compound, - errmsg: "root value must be a compound", - }, - ser_state: SerState::Named { name: root_name }, - } - } - - pub fn writer(&mut self) -> &mut W { - self.writer - } - - pub fn root_name(&self) -> &'n str { - match &self.ser_state { - SerState::Named { name } => *name, - _ => unreachable!(), - } - } - - pub fn set_root_name(&mut self, root_name: &'n str) { - self.ser_state = SerState::Named { name: root_name }; - } - - fn write_header(&mut self, tag: Tag) -> Result<()> { - match &mut self.allowed_tag { - AllowedTag::Any { written_tag } => *written_tag = tag, - AllowedTag::One { - tag: expected_tag, - errmsg, - } => { - if tag != *expected_tag { - let e = anyhow!(*errmsg).context(format!( - "attempt to serialize {tag} where {expected_tag} was expected" - )); - return Err(Error(e)); - } - } - } - - match &mut self.ser_state { - SerState::PayloadOnly => {} - SerState::FirstListElement { len } => { - self.writer.write_u8(tag as u8)?; - self.writer.write_i32::(*len)?; - } - SerState::Named { name } => { - self.writer.write_u8(tag as u8)?; - write_string_payload(*name, self.writer)?; - } - } - - Ok(()) - } -} - -type Impossible = ser::Impossible<(), Error>; - -#[inline] -fn unsupported(typ: &str) -> Result { - Err(Error(anyhow!("{typ} is not supported"))) -} - -impl<'a, W: Write + ?Sized> ser::Serializer for &'a mut Serializer<'_, '_, W> { - type Error = Error; - type Ok = (); - type SerializeMap = SerializeMap<'a, W>; - type SerializeSeq = SerializeSeq<'a, W>; - type SerializeStruct = SerializeStruct<'a, W>; - type SerializeStructVariant = Impossible; - type SerializeTuple = Impossible; - type SerializeTupleStruct = 
SerializeArray<'a, W>; - type SerializeTupleVariant = Impossible; - - fn serialize_bool(self, v: bool) -> Result<()> { - self.write_header(Tag::Byte)?; - Ok(self.writer.write_i8(v as i8)?) - } - - fn serialize_i8(self, v: i8) -> Result<()> { - self.write_header(Tag::Byte)?; - Ok(self.writer.write_i8(v)?) - } - - fn serialize_i16(self, v: i16) -> Result<()> { - self.write_header(Tag::Short)?; - Ok(self.writer.write_i16::(v)?) - } - - fn serialize_i32(self, v: i32) -> Result<()> { - self.write_header(Tag::Int)?; - Ok(self.writer.write_i32::(v)?) - } - - fn serialize_i64(self, v: i64) -> Result<()> { - self.write_header(Tag::Long)?; - Ok(self.writer.write_i64::(v)?) - } - - fn serialize_u8(self, _v: u8) -> Result<()> { - unsupported("u8") - } - - fn serialize_u16(self, _v: u16) -> Result<()> { - unsupported("u16") - } - - fn serialize_u32(self, _v: u32) -> Result<()> { - unsupported("u32") - } - - fn serialize_u64(self, _v: u64) -> Result<()> { - unsupported("u64") - } - - fn serialize_f32(self, v: f32) -> Result<()> { - self.write_header(Tag::Float)?; - Ok(self.writer.write_f32::(v)?) - } - - fn serialize_f64(self, v: f64) -> Result<()> { - self.write_header(Tag::Double)?; - Ok(self.writer.write_f64::(v)?) 
- } - - fn serialize_char(self, _v: char) -> Result<()> { - unsupported("char") - } - - fn serialize_str(self, v: &str) -> Result<()> { - self.write_header(Tag::String)?; - write_string_payload(v, self.writer) - } - - fn serialize_bytes(self, _v: &[u8]) -> Result<()> { - unsupported("&[u8]") - } - - fn serialize_none(self) -> Result<()> { - unsupported("Option") - } - - fn serialize_some(self, _value: &T) -> Result<()> - where - T: Serialize, - { - unsupported("Option") - } - - fn serialize_unit(self) -> Result<()> { - Ok(()) - } - - fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { - Ok(()) - } - - fn serialize_unit_variant( - self, - _name: &'static str, - variant_index: u32, - _variant: &'static str, - ) -> Result<()> { - match variant_index.try_into() { - Ok(idx) => self.serialize_i32(idx), - Err(_) => Err(Error(anyhow!( - "variant index of {variant_index} is out of range" - ))), - } - } - - fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> - where - T: Serialize, - { - value.serialize(self) - } - - fn serialize_newtype_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _value: &T, - ) -> Result<()> - where - T: Serialize, - { - unsupported("newtype variant") - } - - fn serialize_seq(self, len: Option) -> Result { - self.write_header(Tag::List)?; - - let len = match len { - Some(len) => len, - None => return Err(Error(anyhow!("list length must be known up front"))), - }; - - match len.try_into() { - Ok(len) => Ok(SerializeSeq { - writer: self.writer, - element_type: Tag::End, - remaining: len, - }), - Err(_) => Err(Error(anyhow!("length of list exceeds i32::MAX"))), - } - } - - fn serialize_tuple(self, _len: usize) -> Result { - unsupported("tuple") - } - - fn serialize_tuple_struct( - self, - name: &'static str, - len: usize, - ) -> Result { - let element_type = match name { - crate::BYTE_ARRAY_MAGIC => { - self.write_header(Tag::ByteArray)?; - Tag::Byte - } - 
crate::INT_ARRAY_MAGIC => { - self.write_header(Tag::IntArray)?; - Tag::Int - } - crate::LONG_ARRAY_MAGIC => { - self.write_header(Tag::LongArray)?; - Tag::Long - } - _ => return unsupported("tuple struct"), - }; - - match len.try_into() { - Ok(len) => { - self.writer.write_i32::(len)?; - Ok(SerializeArray { - writer: self.writer, - element_type, - remaining: len, - }) - } - Err(_) => Err(Error(anyhow!("array length of {len} exceeds i32::MAX"))), - } - } - - fn serialize_tuple_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> Result { - unsupported("tuple variant") - } - - fn serialize_map(self, _len: Option) -> Result { - self.write_header(Tag::Compound)?; - - Ok(SerializeMap { - writer: self.writer, - }) - } - - fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - self.write_header(Tag::Compound)?; - - Ok(SerializeStruct { - writer: self.writer, - }) - } - - fn serialize_struct_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> Result { - unsupported("struct variant") - } - - fn is_human_readable(&self) -> bool { - false - } -} - -#[doc(hidden)] -pub struct SerializeSeq<'w, W: ?Sized> { - writer: &'w mut W, - /// The element type of this list. TAG_End if unknown. - element_type: Tag, - /// Number of elements left to serialize. 
- remaining: i32, -} - -impl ser::SerializeSeq for SerializeSeq<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_element(&mut self, value: &T) -> Result<()> - where - T: Serialize, - { - if self.remaining <= 0 { - return Err(Error(anyhow!( - "attempt to serialize more list elements than specified" - ))); - } - - if self.element_type == Tag::End { - let mut ser = Serializer { - writer: self.writer, - allowed_tag: AllowedTag::Any { - written_tag: Tag::End, - }, - ser_state: SerState::FirstListElement { - len: self.remaining, - }, - }; - - value.serialize(&mut ser)?; - - self.element_type = match ser.allowed_tag { - AllowedTag::Any { written_tag } => written_tag, - AllowedTag::One { .. } => unreachable!(), - }; - } else { - value.serialize(&mut Serializer { - writer: self.writer, - allowed_tag: AllowedTag::One { - tag: self.element_type, - errmsg: "list elements must be homogeneous", - }, - ser_state: SerState::PayloadOnly, - })?; - } - - self.remaining -= 1; - - Ok(()) - } - - fn end(self) -> Result<()> { - if self.remaining > 0 { - return Err(Error(anyhow!( - "{} list element(s) left to serialize", - self.remaining - ))); - } - - // Were any elements written? - if self.element_type == Tag::End { - self.writer.write_u8(Tag::End as u8)?; - // List length. 
- self.writer.write_i32::(0)?; - } - - Ok(()) - } -} - -#[doc(hidden)] -pub struct SerializeArray<'w, W: ?Sized> { - writer: &'w mut W, - element_type: Tag, - remaining: i32, -} - -impl ser::SerializeTupleStruct for SerializeArray<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_field(&mut self, value: &T) -> Result<()> - where - T: Serialize, - { - if self.remaining <= 0 { - return Err(Error(anyhow!( - "attempt to serialize more array elements than specified" - ))); - } - - value.serialize(&mut Serializer { - writer: self.writer, - allowed_tag: AllowedTag::One { - tag: self.element_type, - errmsg: "mismatched array element type", - }, - ser_state: SerState::PayloadOnly, - })?; - - self.remaining -= 1; - - Ok(()) - } - - fn end(self) -> Result<()> { - if self.remaining > 0 { - return Err(Error(anyhow!( - "{} array element(s) left to serialize", - self.remaining - ))); - } - - Ok(()) - } -} - -#[doc(hidden)] -pub struct SerializeMap<'w, W: ?Sized> { - writer: &'w mut W, -} - -impl ser::SerializeMap for SerializeMap<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_key(&mut self, _key: &T) -> Result<()> - where - T: Serialize, - { - Err(Error(anyhow!("map keys cannot be serialized individually"))) - } - - fn serialize_value(&mut self, _value: &T) -> Result<()> - where - T: Serialize, - { - Err(Error(anyhow!( - "map values cannot be serialized individually" - ))) - } - - fn serialize_entry(&mut self, key: &K, value: &V) -> Result<()> - where - K: Serialize, - V: Serialize, - { - key.serialize(MapEntrySerializer { - writer: self.writer, - value, - }) - } - - fn end(self) -> Result<()> { - Ok(self.writer.write_u8(Tag::End as u8)?) 
- } -} - -struct MapEntrySerializer<'w, 'v, W: ?Sized, V: ?Sized> { - writer: &'w mut W, - value: &'v V, -} - -fn key_not_a_string() -> Result { - Err(Error(anyhow!("map keys must be strings"))) -} - -impl ser::Serializer for MapEntrySerializer<'_, '_, W, V> -where - W: Write + ?Sized, - V: Serialize + ?Sized, -{ - type Error = Error; - type Ok = (); - type SerializeMap = Impossible; - type SerializeSeq = Impossible; - type SerializeStruct = Impossible; - type SerializeStructVariant = Impossible; - type SerializeTuple = Impossible; - type SerializeTupleStruct = Impossible; - type SerializeTupleVariant = Impossible; - - fn serialize_bool(self, _v: bool) -> Result<()> { - key_not_a_string() - } - - fn serialize_i8(self, _v: i8) -> Result<()> { - key_not_a_string() - } - - fn serialize_i16(self, _v: i16) -> Result<()> { - key_not_a_string() - } - - fn serialize_i32(self, _v: i32) -> Result<()> { - key_not_a_string() - } - - fn serialize_i64(self, _v: i64) -> Result<()> { - key_not_a_string() - } - - fn serialize_u8(self, _v: u8) -> Result<()> { - key_not_a_string() - } - - fn serialize_u16(self, _v: u16) -> Result<()> { - key_not_a_string() - } - - fn serialize_u32(self, _v: u32) -> Result<()> { - key_not_a_string() - } - - fn serialize_u64(self, _v: u64) -> Result<()> { - key_not_a_string() - } - - fn serialize_f32(self, _v: f32) -> Result<()> { - key_not_a_string() - } - - fn serialize_f64(self, _v: f64) -> Result<()> { - key_not_a_string() - } - - fn serialize_char(self, _v: char) -> Result<()> { - key_not_a_string() - } - - fn serialize_str(self, v: &str) -> Result<()> { - self.value - .serialize(&mut Serializer { - writer: self.writer, - allowed_tag: AllowedTag::Any { - written_tag: Tag::End, - }, - ser_state: SerState::Named { name: v }, - }) - .map_err(|e| e.context(format!("map key `{v}`"))) - } - - fn serialize_bytes(self, _v: &[u8]) -> Result<()> { - key_not_a_string() - } - - fn serialize_none(self) -> Result<()> { - key_not_a_string() - } - - fn 
serialize_some(self, _value: &T) -> Result<()> - where - T: Serialize, - { - key_not_a_string() - } - - fn serialize_unit(self) -> Result<()> { - key_not_a_string() - } - - fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { - key_not_a_string() - } - - fn serialize_unit_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - ) -> Result<()> { - key_not_a_string() - } - - fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result<()> - where - T: Serialize, - { - key_not_a_string() - } - - fn serialize_newtype_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _value: &T, - ) -> Result<()> - where - T: Serialize, - { - key_not_a_string() - } - - fn serialize_seq(self, _len: Option) -> StdResult { - key_not_a_string() - } - - fn serialize_tuple(self, _len: usize) -> StdResult { - key_not_a_string() - } - - fn serialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - ) -> StdResult { - key_not_a_string() - } - - fn serialize_tuple_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> StdResult { - key_not_a_string() - } - - fn serialize_map(self, _len: Option) -> StdResult { - key_not_a_string() - } - - fn serialize_struct( - self, - _name: &'static str, - _len: usize, - ) -> StdResult { - key_not_a_string() - } - - fn serialize_struct_variant( - self, - _name: &'static str, - _variant_index: u32, - _variant: &'static str, - _len: usize, - ) -> StdResult { - key_not_a_string() - } -} - -#[doc(hidden)] -pub struct SerializeStruct<'w, W: ?Sized> { - writer: &'w mut W, -} - -impl ser::SerializeStruct for SerializeStruct<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> - where - T: Serialize, - { - value - .serialize(&mut Serializer { - writer: self.writer, - allowed_tag: AllowedTag::Any { - written_tag: Tag::End, - }, 
- ser_state: SerState::Named { name: key }, - }) - .map_err(|e| e.context(format!("field `{key}`"))) - } - - fn end(self) -> Result<()> { - Ok(self.writer.write_u8(Tag::End as u8)?) - } -} - -fn write_string_payload(string: &str, writer: &mut (impl Write + ?Sized)) -> Result<()> { - let data = to_java_cesu8(string); - match data.len().try_into() { - Ok(len) => writer.write_u16::(len)?, - Err(_) => return Err(Error(anyhow!("string byte length exceeds u16::MAX"))), - }; - - writer.write_all(&data)?; - Ok(()) -} diff --git a/serde_nbt/src/lib.rs b/serde_nbt/src/lib.rs deleted file mode 100644 index 482c482..0000000 --- a/serde_nbt/src/lib.rs +++ /dev/null @@ -1,90 +0,0 @@ -use std::fmt; -use std::fmt::{Display, Formatter}; -use anyhow::anyhow; - -pub use error::*; -pub use value::*; - -mod error; -mod value; -mod array; - -#[cfg(test)] -mod tests; - -pub use array::*; - -/// (De)serialization support for the binary representation of NBT. -pub mod binary { - pub use ser::*; - pub use de::*; - - mod ser; - mod de; -} - -pub type Result = std::result::Result; - -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] -enum Tag { - End, - Byte, - Short, - Int, - Long, - Float, - Double, - ByteArray, - String, - List, - Compound, - IntArray, - LongArray, -} - -impl Tag { - pub fn from_u8(id: u8) -> Result { - match id { - 0 => Ok(Tag::End), - 1 => Ok(Tag::Byte), - 2 => Ok(Tag::Short), - 3 => Ok(Tag::Int), - 4 => Ok(Tag::Long), - 5 => Ok(Tag::Float), - 6 => Ok(Tag::Double), - 7 => Ok(Tag::ByteArray), - 8 => Ok(Tag::String), - 9 => Ok(Tag::List), - 10 => Ok(Tag::Compound), - 11 => Ok(Tag::IntArray), - 12 => Ok(Tag::LongArray), - _ => Err(Error(anyhow!("invalid tag byte `{id}`"))) - } - } -} - -impl Display for Tag { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let s = match self { - Tag::End => "end", - Tag::Byte => "byte", - Tag::Short => "short", - Tag::Int => "int", - Tag::Long => "long", - Tag::Float => "float", - Tag::Double => "double", - 
Tag::ByteArray => "byte array", - Tag::String => "string", - Tag::List => "list", - Tag::Compound => "compound", - Tag::IntArray => "int array", - Tag::LongArray => "long array", - }; - - write!(f, "{s}") - } -} - -const BYTE_ARRAY_MAGIC: &str = "__byte_array__"; -const INT_ARRAY_MAGIC: &str = "__int_array__"; -const LONG_ARRAY_MAGIC: &str = "__long_array__"; diff --git a/serde_nbt/Cargo.toml b/valence_nbt/Cargo.toml similarity index 65% rename from serde_nbt/Cargo.toml rename to valence_nbt/Cargo.toml index a2073df..f59e8c3 100644 --- a/serde_nbt/Cargo.toml +++ b/valence_nbt/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "serde_nbt" +name = "valence_nbt" version = "0.1.0" edition = "2021" @@ -7,9 +7,12 @@ edition = "2021" anyhow = "1" byteorder = "1.4.3" cesu8 = "1.1.0" +indexmap = { version = "1.9.1", features = ["serde"] } serde = "1" smallvec = { version = "1.9.0", features = ["union", "const_generics"] } [dev-dependencies] -ordered-float = { version = "3.0.0", features = ["serde"] } hematite-nbt = "0.5" +serde_json = "1.0.85" +pretty_assertions = "1.2.1" + diff --git a/valence_nbt/src/array.rs b/valence_nbt/src/array.rs new file mode 100644 index 0000000..eb7e82d --- /dev/null +++ b/valence_nbt/src/array.rs @@ -0,0 +1,74 @@ +use std::fmt::Formatter; +use std::marker::PhantomData; + +use serde::de::value::SeqAccessDeserializer; +use serde::de::{EnumAccess, IgnoredAny, SeqAccess, VariantAccess, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{ + ARRAY_ENUM_NAME, BYTE_ARRAY_VARIANT_NAME, INT_ARRAY_VARIANT_NAME, LONG_ARRAY_VARIANT_NAME, +}; + +macro_rules! 
def_mod { + ($index:literal, $mod_name:ident, $display_name:literal, $variant_name:ident) => { + pub mod $mod_name { + use super::*; + + pub fn serialize(array: &T, serializer: S) -> Result + where + T: Serialize, + S: Serializer, + { + serializer.serialize_newtype_variant(ARRAY_ENUM_NAME, $index, $variant_name, array) + } + + pub fn deserialize<'de, T, D>(deserializer: D) -> Result + where + T: Deserialize<'de>, + D: Deserializer<'de>, + { + struct ArrayVisitor(PhantomData); + + impl<'de, T: Deserialize<'de>> Visitor<'de> for ArrayVisitor { + type Value = T; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + write!( + formatter, + concat!("an NBT ", $display_name, " encoded as an enum or seq") + ) + } + + fn visit_seq(self, seq: A) -> Result + where + A: SeqAccess<'de>, + { + T::deserialize(SeqAccessDeserializer::new(seq)) + } + + fn visit_enum(self, data: A) -> Result + where + A: EnumAccess<'de>, + { + // Ignore the variant name. + let (_, variant) = data.variant::()?; + + variant.newtype_variant() + } + } + + let variants = &[ + BYTE_ARRAY_VARIANT_NAME, + INT_ARRAY_VARIANT_NAME, + LONG_ARRAY_VARIANT_NAME, + ]; + + deserializer.deserialize_enum(ARRAY_ENUM_NAME, variants, ArrayVisitor(PhantomData)) + } + } + }; +} + +def_mod!(0, byte_array, "byte array", BYTE_ARRAY_VARIANT_NAME); +def_mod!(1, int_array, "int array", INT_ARRAY_VARIANT_NAME); +def_mod!(2, long_array, "long array", LONG_ARRAY_VARIANT_NAME); diff --git a/valence_nbt/src/binary/de.rs b/valence_nbt/src/binary/de.rs new file mode 100644 index 0000000..1e34b4f --- /dev/null +++ b/valence_nbt/src/binary/de.rs @@ -0,0 +1,20 @@ +use std::io::Read; + +pub use root::RootDeserializer as Deserializer; +use serde::de::DeserializeOwned; + +use crate::Error; + +mod array; +mod compound; +mod list; +mod payload; +mod root; + +pub fn from_reader(reader: R) -> Result +where + R: Read, + T: DeserializeOwned, +{ + T::deserialize(&mut Deserializer::new(reader, false)) +} diff --git 
a/valence_nbt/src/binary/de/array.rs b/valence_nbt/src/binary/de/array.rs new file mode 100644 index 0000000..607b664 --- /dev/null +++ b/valence_nbt/src/binary/de/array.rs @@ -0,0 +1,158 @@ +use std::io::Read; + +use anyhow::anyhow; +use byteorder::{BigEndian, ReadBytesExt}; +use serde::de::value::StrDeserializer; +use serde::de::{DeserializeSeed, Error as _, SeqAccess, Unexpected, Visitor}; +use serde::{de, forward_to_deserialize_any, Deserializer}; + +use crate::binary::de::payload::PayloadDeserializer; +use crate::{ + ArrayType, Error, BYTE_ARRAY_VARIANT_NAME, INT_ARRAY_VARIANT_NAME, LONG_ARRAY_VARIANT_NAME, +}; + +pub struct EnumAccess<'r, R: ?Sized> { + pub(super) reader: &'r mut R, + pub(super) array_type: ArrayType, +} + +impl<'de: 'r, 'r, R: Read + ?Sized> de::EnumAccess<'de> for EnumAccess<'r, R> { + type Error = Error; + type Variant = VariantAccess<'r, R>; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: DeserializeSeed<'de>, + { + let variant_name = match self.array_type { + ArrayType::Byte => BYTE_ARRAY_VARIANT_NAME, + ArrayType::Int => INT_ARRAY_VARIANT_NAME, + ArrayType::Long => LONG_ARRAY_VARIANT_NAME, + }; + + Ok(( + seed.deserialize(StrDeserializer::::new(variant_name))?, + VariantAccess { + reader: self.reader, + array_type: self.array_type, + }, + )) + } +} + +pub struct VariantAccess<'r, R: ?Sized> { + reader: &'r mut R, + array_type: ArrayType, +} + +impl<'de: 'r, 'r, R: Read + ?Sized> de::VariantAccess<'de> for VariantAccess<'r, R> { + type Error = Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + Err(Error::invalid_type( + Unexpected::NewtypeVariant, + &"unit variant", + )) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + seed.deserialize(ArrayDeserializer { + reader: self.reader, + array_type: self.array_type, + }) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + 
Err(Error::invalid_type( + Unexpected::NewtypeVariant, + &"tuple variant", + )) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(Error::invalid_type( + Unexpected::NewtypeVariant, + &"struct variant", + )) + } +} + +struct ArrayDeserializer<'r, R: ?Sized> { + reader: &'r mut R, + array_type: ArrayType, +} + +impl<'de: 'r, 'r, R: Read + ?Sized> Deserializer<'de> for ArrayDeserializer<'r, R> { + type Error = Error; + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum identifier ignored_any + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let len = self.reader.read_i32::()?; + + if len < 0 { + return Err(Error(anyhow!("array with negative length"))); + } + + visitor.visit_seq(ArraySeqAccess { + reader: self.reader, + array_type: self.array_type, + remaining: len, + }) + } + + fn is_human_readable(&self) -> bool { + false + } +} + +struct ArraySeqAccess<'r, R: ?Sized> { + reader: &'r mut R, + array_type: ArrayType, + remaining: i32, +} + +impl<'de: 'r, 'r, R: Read + ?Sized> SeqAccess<'de> for ArraySeqAccess<'r, R> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + if self.remaining > 0 { + self.remaining -= 1; + + seed.deserialize(PayloadDeserializer { + reader: self.reader, + tag: self.array_type.element_tag(), + }) + .map(Some) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.remaining as usize) + } +} diff --git a/valence_nbt/src/binary/de/compound.rs b/valence_nbt/src/binary/de/compound.rs new file mode 100644 index 0000000..8ff5167 --- /dev/null +++ b/valence_nbt/src/binary/de/compound.rs @@ -0,0 +1,77 @@ +use std::io::Read; + +use anyhow::anyhow; +use byteorder::ReadBytesExt; +use 
serde::de; +use serde::de::DeserializeSeed; + +use crate::binary::de::payload::PayloadDeserializer; +use crate::{Error, Tag}; + +pub struct MapAccess<'r, R: ?Sized> { + reader: &'r mut R, + value_tag: Tag, + /// Provides error context when deserializing structs. + fields: &'static [&'static str], +} + +impl<'r, R: Read + ?Sized> MapAccess<'r, R> { + pub fn new(reader: &'r mut R, fields: &'static [&'static str]) -> Self { + Self { + reader, + value_tag: Tag::End, + fields, + } + } +} + +impl<'de: 'r, 'r, R: Read + ?Sized> de::MapAccess<'de> for MapAccess<'r, R> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + self.value_tag = Tag::from_u8(self.reader.read_u8()?)?; + + if self.value_tag == Tag::End { + return Ok(None); + } + + seed.deserialize(PayloadDeserializer { + reader: self.reader, + tag: Tag::String, + }) + .map(Some) + .map_err(|e| match self.fields { + [f, ..] => e.context(anyhow!("compound key (field `{f}`)")), + [] => e, + }) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + if self.value_tag == Tag::End { + return Err(Error(anyhow!("end of compound?"))); + } + + let field = match self.fields { + [field, rest @ ..] 
=> { + self.fields = rest; + Some(*field) + } + [] => None, + }; + + seed.deserialize(PayloadDeserializer { + reader: self.reader, + tag: self.value_tag, + }) + .map_err(|e| match field { + Some(f) => e.context(anyhow!("compound value (field `{f}`)")), + None => e, + }) + } +} diff --git a/valence_nbt/src/binary/de/list.rs b/valence_nbt/src/binary/de/list.rs new file mode 100644 index 0000000..4014015 --- /dev/null +++ b/valence_nbt/src/binary/de/list.rs @@ -0,0 +1,38 @@ +use std::io::Read; + +use serde::de; +use serde::de::DeserializeSeed; + +use crate::binary::de::payload::PayloadDeserializer; +use crate::{Error, Tag}; + +pub(super) struct SeqAccess<'r, R: ?Sized> { + pub reader: &'r mut R, + pub element_tag: Tag, + pub remaining: u32, +} + +impl<'de: 'r, 'r, R: Read + ?Sized> de::SeqAccess<'de> for SeqAccess<'r, R> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + if self.remaining > 0 { + self.remaining -= 1; + + seed.deserialize(PayloadDeserializer { + reader: self.reader, + tag: self.element_tag, + }) + .map(Some) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.remaining as usize) + } +} diff --git a/valence_nbt/src/binary/de/payload.rs b/valence_nbt/src/binary/de/payload.rs new file mode 100644 index 0000000..2bf9c99 --- /dev/null +++ b/valence_nbt/src/binary/de/payload.rs @@ -0,0 +1,124 @@ +use std::borrow::Cow; +use std::io::Read; + +use anyhow::anyhow; +use byteorder::{BigEndian, ReadBytesExt}; +use cesu8::from_java_cesu8; +use serde::de::Visitor; +use serde::{de, forward_to_deserialize_any}; +use smallvec::SmallVec; + +use crate::binary::de::array::EnumAccess; +use crate::binary::de::compound::MapAccess; +use crate::binary::de::list::SeqAccess; +use crate::{ArrayType, Error, Tag}; + +pub(super) struct PayloadDeserializer<'w, R: ?Sized> { + pub reader: &'w mut R, + /// The type of payload to be deserialized. 
+ pub tag: Tag, +} + +impl<'de: 'w, 'w, R: Read + ?Sized> de::Deserializer<'de> for PayloadDeserializer<'w, R> { + type Error = Error; + + forward_to_deserialize_any! { + i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map enum identifier ignored_any + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self.tag { + Tag::End => unreachable!("invalid payload tag"), + Tag::Byte => visitor.visit_i8(self.reader.read_i8()?), + Tag::Short => visitor.visit_i16(self.reader.read_i16::()?), + Tag::Int => visitor.visit_i32(self.reader.read_i32::()?), + Tag::Long => visitor.visit_i64(self.reader.read_i64::()?), + Tag::Float => visitor.visit_f32(self.reader.read_f32::()?), + Tag::Double => visitor.visit_f64(self.reader.read_f64::()?), + Tag::ByteArray => visitor.visit_enum(EnumAccess { + reader: self.reader, + array_type: ArrayType::Byte, + }), + Tag::String => { + let mut buf = SmallVec::<[u8; 128]>::new(); + for _ in 0..self.reader.read_u16::()? { + buf.push(self.reader.read_u8()?); + } + + match from_java_cesu8(&buf).map_err(|e| Error(anyhow!(e)))? 
{ + Cow::Borrowed(s) => visitor.visit_str(s), + Cow::Owned(string) => visitor.visit_string(string), + } + } + Tag::List => { + let element_tag = Tag::from_u8(self.reader.read_u8()?)?; + let len = self.reader.read_i32::()?; + + if len < 0 { + return Err(Error(anyhow!("list with negative length"))); + } + + if element_tag == Tag::End && len != 0 { + return Err(Error(anyhow!( + "list with TAG_End element type must have length zero" + ))); + } + + visitor.visit_seq(SeqAccess { + reader: self.reader, + element_tag, + remaining: len as u32, + }) + } + Tag::Compound => visitor.visit_map(MapAccess::new(self.reader, &[])), + Tag::IntArray => visitor.visit_enum(EnumAccess { + reader: self.reader, + array_type: ArrayType::Int, + }), + Tag::LongArray => visitor.visit_enum(EnumAccess { + reader: self.reader, + array_type: ArrayType::Long, + }), + } + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.tag == Tag::Byte { + match self.reader.read_i8()? { + 0 => visitor.visit_bool(false), + 1 => visitor.visit_bool(true), + n => visitor.visit_i8(n), + } + } else { + self.deserialize_any(visitor) + } + } + + fn deserialize_struct( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.tag == Tag::Compound { + visitor.visit_map(MapAccess::new(self.reader, fields)) + } else { + self.deserialize_any(visitor) + } + } + + fn is_human_readable(&self) -> bool { + false + } +} diff --git a/valence_nbt/src/binary/de/root.rs b/valence_nbt/src/binary/de/root.rs new file mode 100644 index 0000000..3eac6b3 --- /dev/null +++ b/valence_nbt/src/binary/de/root.rs @@ -0,0 +1,102 @@ +use std::borrow::Cow; +use std::io::Read; + +use anyhow::anyhow; +use byteorder::{BigEndian, ReadBytesExt}; +use cesu8::from_java_cesu8; +use serde::de::Visitor; +use serde::{forward_to_deserialize_any, Deserializer}; +use smallvec::SmallVec; + +use crate::binary::de::payload::PayloadDeserializer; 
+use crate::{Error, Tag}; + +#[non_exhaustive] +pub struct RootDeserializer { + pub reader: R, + pub root_name: String, + pub save_root_name: bool, +} + +impl RootDeserializer { + pub fn new(reader: R, save_root_name: bool) -> Self { + Self { + reader, + root_name: String::new(), + save_root_name, + } + } + + fn read_name(&mut self) -> Result { + let tag = Tag::from_u8(self.reader.read_u8()?)?; + + if tag != Tag::Compound { + return Err(Error(anyhow!( + "unexpected tag `{tag}` (root value must be a compound)" + ))); + } + + if self.save_root_name { + let mut buf = SmallVec::<[u8; 128]>::new(); + for _ in 0..self.reader.read_u16::()? { + buf.push(self.reader.read_u8()?); + } + + match from_java_cesu8(&buf).map_err(|e| Error(anyhow!(e)))? { + Cow::Borrowed(s) => s.clone_into(&mut self.root_name), + Cow::Owned(s) => self.root_name = s, + } + } else { + for _ in 0..self.reader.read_u16::()? { + self.reader.read_u8()?; + } + } + + Ok(tag) + } +} + +impl<'de: 'a, 'a, R: Read> Deserializer<'de> for &'a mut RootDeserializer { + type Error = Error; + + forward_to_deserialize_any! 
{ + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map enum identifier ignored_any + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let tag = self.read_name()?; + + PayloadDeserializer { + reader: &mut self.reader, + tag, + } + .deserialize_any(visitor) + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + let tag = self.read_name()?; + + PayloadDeserializer { + reader: &mut self.reader, + tag, + } + .deserialize_struct(name, fields, visitor) + } + + fn is_human_readable(&self) -> bool { + false + } +} diff --git a/valence_nbt/src/binary/ser.rs b/valence_nbt/src/binary/ser.rs new file mode 100644 index 0000000..b54b2a2 --- /dev/null +++ b/valence_nbt/src/binary/ser.rs @@ -0,0 +1,36 @@ +use std::io::Write; + +use anyhow::anyhow; +use byteorder::{BigEndian, WriteBytesExt}; +use cesu8::to_java_cesu8; +pub use root::RootSerializer as Serializer; +use serde::{ser, Serialize}; + +use crate::{Error, Result}; + +mod map; +mod payload; +mod root; +mod seq; +mod structs; + +pub fn to_writer(writer: W, value: &T) -> Result<()> +where + W: Write, + T: Serialize + ?Sized, +{ + value.serialize(&mut Serializer::new(writer, "")) +} + +type Impossible = ser::Impossible<(), Error>; + +fn write_string(mut writer: impl Write, string: &str) -> Result<()> { + let data = to_java_cesu8(string); + match data.len().try_into() { + Ok(len) => writer.write_u16::(len)?, + Err(_) => return Err(Error(anyhow!("string byte length exceeds u16::MAX"))), + }; + + writer.write_all(&data)?; + Ok(()) +} diff --git a/valence_nbt/src/binary/ser/map.rs b/valence_nbt/src/binary/ser/map.rs new file mode 100644 index 0000000..18656d4 --- /dev/null +++ b/valence_nbt/src/binary/ser/map.rs @@ -0,0 +1,236 @@ +use std::io::Write; + +use anyhow::anyhow; +use 
byteorder::WriteBytesExt; +use serde::{ser, Serialize, Serializer}; + +use crate::binary::ser::payload::PayloadSerializer; +use crate::binary::ser::Impossible; +use crate::{Error, Tag}; + +pub struct SerializeMap<'w, W: ?Sized> { + pub(super) writer: &'w mut W, +} + +impl<'w, W: Write + ?Sized> ser::SerializeMap for SerializeMap<'w, W> { + type Error = Error; + type Ok = (); + + fn serialize_key(&mut self, _key: &T) -> Result<(), Error> + where + T: Serialize, + { + Err(Error(anyhow!("map keys cannot be serialized individually"))) + } + + fn serialize_value(&mut self, _value: &T) -> Result<(), Error> + where + T: Serialize, + { + Err(Error(anyhow!( + "map values cannot be serialized individually" + ))) + } + + fn serialize_entry( + &mut self, + key: &K, + value: &V, + ) -> Result<(), Self::Error> + where + K: Serialize, + V: Serialize, + { + key.serialize(MapEntrySerializer { + writer: self.writer, + value, + }) + } + + fn end(self) -> Result { + Ok(self.writer.write_u8(Tag::End as u8)?) 
+ } +} + +struct MapEntrySerializer<'w, 'v, W: ?Sized, V: ?Sized> { + writer: &'w mut W, + value: &'v V, +} + +fn key_not_a_string(typ: &str) -> Result { + Err(Error(anyhow!("map keys must be strings (got {typ})"))) +} + +impl Serializer for MapEntrySerializer<'_, '_, W, V> { + type Error = Error; + type Ok = (); + type SerializeMap = Impossible; + type SerializeSeq = Impossible; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + + fn serialize_bool(self, _v: bool) -> Result { + key_not_a_string("bool") + } + + fn serialize_i8(self, _v: i8) -> Result { + key_not_a_string("i8") + } + + fn serialize_i16(self, _v: i16) -> Result { + key_not_a_string("i16") + } + + fn serialize_i32(self, _v: i32) -> Result { + key_not_a_string("i32") + } + + fn serialize_i64(self, _v: i64) -> Result { + key_not_a_string("i64") + } + + fn serialize_u8(self, _v: u8) -> Result { + key_not_a_string("u8") + } + + fn serialize_u16(self, _v: u16) -> Result { + key_not_a_string("u16") + } + + fn serialize_u32(self, _v: u32) -> Result { + key_not_a_string("u32") + } + + fn serialize_u64(self, _v: u64) -> Result { + key_not_a_string("u64") + } + + fn serialize_f32(self, _v: f32) -> Result { + key_not_a_string("f32") + } + + fn serialize_f64(self, _v: f64) -> Result { + key_not_a_string("f64") + } + + fn serialize_char(self, _v: char) -> Result { + key_not_a_string("char") + } + + fn serialize_str(self, v: &str) -> Result { + self.value + .serialize(&mut PayloadSerializer::named(self.writer, v)) + .map_err(|e| e.context(format!("key `{v}`"))) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + key_not_a_string("&[u8]") + } + + fn serialize_none(self) -> Result { + key_not_a_string("None") + } + + fn serialize_some(self, _value: &T) -> Result + where + T: Serialize, + { + key_not_a_string("Some") + } + + fn serialize_unit(self) -> Result { + 
key_not_a_string("()") + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + key_not_a_string("unit struct") + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + key_not_a_string("unit variant") + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + key_not_a_string("newtype struct") + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + key_not_a_string("newtype variant") + } + + fn serialize_seq(self, _len: Option) -> Result { + key_not_a_string("seq") + } + + fn serialize_tuple(self, _len: usize) -> Result { + key_not_a_string("tuple") + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + key_not_a_string("tuple struct") + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + key_not_a_string("tuple variant") + } + + fn serialize_map(self, _len: Option) -> Result { + key_not_a_string("map") + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + key_not_a_string("struct") + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + key_not_a_string("struct variant") + } + + fn is_human_readable(&self) -> bool { + false + } +} diff --git a/valence_nbt/src/binary/ser/payload.rs b/valence_nbt/src/binary/ser/payload.rs new file mode 100644 index 0000000..84e08d6 --- /dev/null +++ b/valence_nbt/src/binary/ser/payload.rs @@ -0,0 +1,334 @@ +use std::io::Write; + +use anyhow::anyhow; +use byteorder::{BigEndian, WriteBytesExt}; +use serde::{Serialize, Serializer}; + +use crate::binary::ser::map::SerializeMap; +use 
crate::binary::ser::seq::SerializeSeq; +use crate::binary::ser::structs::SerializeStruct; +use crate::binary::ser::{write_string, Impossible}; +use crate::{ArrayType, Error, Tag}; + +pub struct PayloadSerializer<'w, 'n, W: ?Sized> { + writer: &'w mut W, + state: State<'n>, +} + +#[derive(Clone, Copy)] +enum State<'n> { + Named(&'n str), + FirstListElement { len: i32, written_tag: Tag }, + SeqElement { element_type: Tag }, + Array(ArrayType), +} + +impl<'w, 'n, W: Write + ?Sized> PayloadSerializer<'w, 'n, W> { + pub(super) fn named(writer: &'w mut W, name: &'n str) -> Self { + Self { + writer, + state: State::Named(name), + } + } + + pub(super) fn first_list_element(writer: &'w mut W, len: i32) -> Self { + Self { + writer, + state: State::FirstListElement { + len, + written_tag: Tag::End, + }, + } + } + + pub(super) fn seq_element(writer: &'w mut W, element_type: Tag) -> Self { + Self { + writer, + state: State::SeqElement { element_type }, + } + } + + pub(super) fn written_tag(&self) -> Option { + match self.state { + State::FirstListElement { written_tag, .. 
} if written_tag != Tag::End => { + Some(written_tag) + } + _ => None, + } + } + + fn check_state(&mut self, tag: Tag) -> Result<(), Error> { + match &mut self.state { + State::Named(name) => { + self.writer.write_u8(tag as u8)?; + write_string(&mut *self.writer, *name)?; + } + State::FirstListElement { len, written_tag } => { + self.writer.write_u8(tag as u8)?; + self.writer.write_i32::(*len)?; + *written_tag = tag; + } + State::SeqElement { element_type } => { + if tag != *element_type { + return Err(Error(anyhow!( + "list/array elements must be homogeneous (got {tag}, expected \ + {element_type})", + ))); + } + } + State::Array(array_type) => { + let msg = match array_type { + ArrayType::Byte => "a byte array", + ArrayType::Int => "an int array", + ArrayType::Long => "a long array", + }; + + return Err(Error(anyhow!( + "expected a seq for {msg}, got {tag} instead" + ))); + } + } + + Ok(()) + } +} + +#[inline] +fn unsupported(typ: &str) -> Result { + Err(Error(anyhow!("{typ} is not supported"))) +} + +impl<'a, W: Write + ?Sized> Serializer for &'a mut PayloadSerializer<'_, '_, W> { + type Error = Error; + type Ok = (); + type SerializeMap = SerializeMap<'a, W>; + type SerializeSeq = SerializeSeq<'a, W>; + type SerializeStruct = SerializeStruct<'a, W>; + type SerializeStructVariant = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + + fn serialize_bool(self, v: bool) -> Result { + self.check_state(Tag::Byte)?; + Ok(self.writer.write_i8(v as i8)?) + } + + fn serialize_i8(self, v: i8) -> Result { + self.check_state(Tag::Byte)?; + Ok(self.writer.write_i8(v)?) + } + + fn serialize_i16(self, v: i16) -> Result { + self.check_state(Tag::Short)?; + Ok(self.writer.write_i16::(v)?) + } + + fn serialize_i32(self, v: i32) -> Result { + self.check_state(Tag::Int)?; + Ok(self.writer.write_i32::(v)?) 
+ } + + fn serialize_i64(self, v: i64) -> Result { + self.check_state(Tag::Long)?; + Ok(self.writer.write_i64::(v)?) + } + + fn serialize_u8(self, _v: u8) -> Result { + unsupported("u8") + } + + fn serialize_u16(self, _v: u16) -> Result { + unsupported("u16") + } + + fn serialize_u32(self, _v: u32) -> Result { + unsupported("u32") + } + + fn serialize_u64(self, _v: u64) -> Result { + unsupported("u64") + } + + fn serialize_f32(self, v: f32) -> Result { + self.check_state(Tag::Float)?; + Ok(self.writer.write_f32::(v)?) + } + + fn serialize_f64(self, v: f64) -> Result { + self.check_state(Tag::Double)?; + Ok(self.writer.write_f64::(v)?) + } + + fn serialize_char(self, _v: char) -> Result { + unsupported("char") + } + + fn serialize_str(self, v: &str) -> Result { + self.check_state(Tag::String)?; + write_string(&mut *self.writer, v) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + unsupported("&[u8]") + } + + fn serialize_none(self) -> Result { + unsupported("None") + } + + fn serialize_some(self, _value: &T) -> Result + where + T: Serialize, + { + unsupported("Some") + } + + fn serialize_unit(self) -> Result { + unsupported("()") + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + unsupported("unit struct") + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + unsupported("unit variant") + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + unsupported("newtype struct") + } + + fn serialize_newtype_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: Serialize, + { + let (array_tag, array_type) = match (name, variant) { + (crate::ARRAY_ENUM_NAME, crate::BYTE_ARRAY_VARIANT_NAME) => { + (Tag::ByteArray, ArrayType::Byte) + } + (crate::ARRAY_ENUM_NAME, crate::INT_ARRAY_VARIANT_NAME) => { + (Tag::IntArray, 
ArrayType::Int) + } + (crate::ARRAY_ENUM_NAME, crate::LONG_ARRAY_VARIANT_NAME) => { + (Tag::LongArray, ArrayType::Long) + } + _ => return unsupported("newtype variant"), + }; + + self.check_state(array_tag)?; + + value.serialize(&mut PayloadSerializer { + writer: self.writer, + state: State::Array(array_type), + }) + } + + fn serialize_seq(self, len: Option) -> Result { + if let State::Array(array_type) = self.state { + let len = match len { + Some(len) => len, + None => return Err(Error(anyhow!("array length must be known up front"))), + }; + + match len.try_into() { + Ok(len) => { + self.writer.write_i32::(len)?; + Ok(SerializeSeq::array( + self.writer, + array_type.element_tag(), + len, + )) + } + Err(_) => Err(Error(anyhow!("length of array exceeds i32::MAX"))), + } + } else { + self.check_state(Tag::List)?; + + let len = match len { + Some(len) => len, + None => return Err(Error(anyhow!("list length must be known up front"))), + }; + + match len.try_into() { + Ok(len) => Ok(SerializeSeq::list(self.writer, len)), + Err(_) => Err(Error(anyhow!("length of list exceeds i32::MAX"))), + } + } + } + + fn serialize_tuple(self, _len: usize) -> Result { + unsupported("tuple") + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unsupported("tuple struct") + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported("tuple variant") + } + + fn serialize_map(self, _len: Option) -> Result { + self.check_state(Tag::Compound)?; + + Ok(SerializeMap { + writer: self.writer, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + self.check_state(Tag::Compound)?; + + Ok(SerializeStruct { + writer: self.writer, + }) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported("struct variant") + 
} + + fn is_human_readable(&self) -> bool { + false + } +} diff --git a/valence_nbt/src/binary/ser/root.rs b/valence_nbt/src/binary/ser/root.rs new file mode 100644 index 0000000..fc2aeda --- /dev/null +++ b/valence_nbt/src/binary/ser/root.rs @@ -0,0 +1,213 @@ +use std::io::Write; + +use anyhow::anyhow; +use byteorder::WriteBytesExt; +use serde::{Serialize, Serializer}; + +use crate::binary::ser::map::SerializeMap; +use crate::binary::ser::structs::SerializeStruct; +use crate::binary::ser::{write_string, Impossible}; +use crate::{Error, Tag}; + +#[non_exhaustive] +pub struct RootSerializer<'n, W> { + pub writer: W, + pub root_name: &'n str, +} + +impl<'n, W: Write> RootSerializer<'n, W> { + pub fn new(writer: W, root_name: &'n str) -> Self { + Self { writer, root_name } + } + + fn write_header(&mut self) -> Result<(), Error> { + self.writer.write_u8(Tag::Compound as u8)?; + write_string(&mut self.writer, self.root_name) + } +} + +fn not_compound(typ: &str) -> Result { + Err(Error(anyhow!( + "root value must be a map or struct (got {typ})" + ))) +} + +impl<'a, W: Write> Serializer for &'a mut RootSerializer<'_, W> { + type Error = Error; + type Ok = (); + type SerializeMap = SerializeMap<'a, W>; + type SerializeSeq = Impossible; + type SerializeStruct = SerializeStruct<'a, W>; + type SerializeStructVariant = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + + fn serialize_bool(self, _v: bool) -> Result { + not_compound("bool") + } + + fn serialize_i8(self, _v: i8) -> Result { + not_compound("i8") + } + + fn serialize_i16(self, _v: i16) -> Result { + not_compound("i16") + } + + fn serialize_i32(self, _v: i32) -> Result { + not_compound("i32") + } + + fn serialize_i64(self, _v: i64) -> Result { + not_compound("i64") + } + + fn serialize_u8(self, _v: u8) -> Result { + not_compound("u8") + } + + fn serialize_u16(self, _v: u16) -> Result { + not_compound("u16") + } + + fn 
serialize_u32(self, _v: u32) -> Result { + not_compound("u32") + } + + fn serialize_u64(self, _v: u64) -> Result { + not_compound("u64") + } + + fn serialize_f32(self, _v: f32) -> Result { + not_compound("f32") + } + + fn serialize_f64(self, _v: f64) -> Result { + not_compound("f64") + } + + fn serialize_char(self, _v: char) -> Result { + not_compound("char") + } + + fn serialize_str(self, _v: &str) -> Result { + not_compound("str") + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + not_compound("&[u8]") + } + + fn serialize_none(self) -> Result { + not_compound("None") + } + + fn serialize_some(self, _value: &T) -> Result + where + T: Serialize, + { + not_compound("Some") + } + + fn serialize_unit(self) -> Result { + not_compound("()") + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + not_compound("unit struct") + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + not_compound("unit variant") + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + not_compound("newtype struct") + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + not_compound("newtype variant") + } + + fn serialize_seq(self, _len: Option) -> Result { + not_compound("seq") + } + + fn serialize_tuple(self, _len: usize) -> Result { + not_compound("tuple") + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + not_compound("tuple struct") + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + not_compound("tuple variant") + } + + fn serialize_map(self, _len: Option) -> Result { + self.write_header()?; + + Ok(SerializeMap { + writer: &mut self.writer, + }) + } + + fn 
serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + self.write_header()?; + + Ok(SerializeStruct { + writer: &mut self.writer, + }) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + not_compound("struct variant") + } + + fn is_human_readable(&self) -> bool { + false + } +} diff --git a/valence_nbt/src/binary/ser/seq.rs b/valence_nbt/src/binary/ser/seq.rs new file mode 100644 index 0000000..b00cdf9 --- /dev/null +++ b/valence_nbt/src/binary/ser/seq.rs @@ -0,0 +1,121 @@ +use std::io::Write; + +use anyhow::anyhow; +use byteorder::{BigEndian, WriteBytesExt}; +use serde::{ser, Serialize}; + +use crate::binary::ser::payload::PayloadSerializer; +use crate::{Error, Tag}; + +pub struct SerializeSeq<'w, W: ?Sized> { + writer: &'w mut W, + element_tag: Tag, + remaining: i32, + list_or_array: ListOrArray, +} + +#[derive(Copy, Clone)] +enum ListOrArray { + List, + Array, +} + +impl ListOrArray { + pub const fn name(self) -> &'static str { + match self { + ListOrArray::List => "list", + ListOrArray::Array => "array", + } + } +} + +impl<'w, W: Write + ?Sized> SerializeSeq<'w, W> { + pub(super) fn list(writer: &'w mut W, length: i32) -> Self { + Self { + writer, + element_tag: Tag::End, + remaining: length, + list_or_array: ListOrArray::List, + } + } + + pub(super) fn array(writer: &'w mut W, element_tag: Tag, length: i32) -> Self { + Self { + writer, + element_tag, + remaining: length, + list_or_array: ListOrArray::Array, + } + } +} + +impl ser::SerializeSeq for SerializeSeq<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + if self.remaining <= 0 { + return Err(Error(anyhow!( + "attempt to serialize more {} elements than specified", + self.list_or_array.name() + ))); + } + + match self.list_or_array { + ListOrArray::List => { + if self.element_tag == 
Tag::End { + let mut ser = + PayloadSerializer::first_list_element(self.writer, self.remaining); + + value.serialize(&mut ser)?; + + self.element_tag = ser.written_tag().expect("tag must have been written"); + } else { + value.serialize(&mut PayloadSerializer::seq_element( + self.writer, + self.element_tag, + ))?; + } + } + ListOrArray::Array => { + value.serialize(&mut PayloadSerializer::seq_element( + self.writer, + self.element_tag, + ))?; + } + } + + self.remaining -= 1; + Ok(()) + } + + fn end(self) -> Result { + if self.remaining > 0 { + return Err(Error(anyhow!( + "{} {} element(s) left to serialize", + self.remaining, + self.list_or_array.name(), + ))); + } + + match self.list_or_array { + ListOrArray::List => { + // Were any elements written? + if self.element_tag == Tag::End { + // Element type + self.writer.write_u8(Tag::End as u8)?; + // List length. + self.writer.write_i32::(0)?; + } + } + ListOrArray::Array => { + // Array length should be written by the serializer already. + } + } + + Ok(()) + } +} diff --git a/valence_nbt/src/binary/ser/structs.rs b/valence_nbt/src/binary/ser/structs.rs new file mode 100644 index 0000000..c8587a5 --- /dev/null +++ b/valence_nbt/src/binary/ser/structs.rs @@ -0,0 +1,33 @@ +use std::io::Write; + +use byteorder::WriteBytesExt; +use serde::{ser, Serialize}; + +use crate::binary::ser::payload::PayloadSerializer; +use crate::{Error, Tag}; + +pub struct SerializeStruct<'w, W: ?Sized> { + pub(super) writer: &'w mut W, +} + +impl ser::SerializeStruct for SerializeStruct<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> + where + T: Serialize, + { + value + .serialize(&mut PayloadSerializer::named(self.writer, key)) + .map_err(|e| e.context(format!("field `{key}`"))) + } + + fn end(self) -> Result { + Ok(self.writer.write_u8(Tag::End as u8)?) 
+ } +} diff --git a/serde_nbt/src/error.rs b/valence_nbt/src/error.rs similarity index 100% rename from serde_nbt/src/error.rs rename to valence_nbt/src/error.rs diff --git a/valence_nbt/src/lib.rs b/valence_nbt/src/lib.rs new file mode 100644 index 0000000..7e6d1d2 --- /dev/null +++ b/valence_nbt/src/lib.rs @@ -0,0 +1,153 @@ +use std::fmt; +use std::fmt::{Display, Formatter}; + +use anyhow::anyhow; +pub use array::*; +pub use error::*; +use serde::de::Visitor; +use serde::{Deserialize, Deserializer}; +pub use value::*; + +mod array; +mod error; +mod value; + +#[cfg(test)] +mod tests; + +/// (De)serialization support for the binary representation of NBT. +pub mod binary { + pub use de::*; + pub use ser::*; + + mod de; + mod ser; +} + +pub type Result = std::result::Result; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +enum Tag { + End, + Byte, + Short, + Int, + Long, + Float, + Double, + ByteArray, + String, + List, + Compound, + IntArray, + LongArray, +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +enum ArrayType { + Byte, + Int, + Long, +} + +impl ArrayType { + pub const fn element_tag(self) -> Tag { + match self { + ArrayType::Byte => Tag::Byte, + ArrayType::Int => Tag::Int, + ArrayType::Long => Tag::Long, + } + } +} + +impl<'de> Deserialize<'de> for ArrayType { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + struct ArrayTypeVisitor; + + impl<'de> Visitor<'de> for ArrayTypeVisitor { + type Value = ArrayType; + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + write!(formatter, "a u8 or string encoding an NBT array type") + } + + fn visit_u8(self, v: u8) -> std::result::Result + where + E: serde::de::Error, + { + match v { + 0 => Ok(ArrayType::Byte), + 1 => Ok(ArrayType::Int), + 2 => Ok(ArrayType::Long), + i => Err(E::custom(format!("invalid array type index `{i}`"))), + } + } + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + 
{ + match v { + BYTE_ARRAY_VARIANT_NAME => Ok(ArrayType::Byte), + INT_ARRAY_VARIANT_NAME => Ok(ArrayType::Int), + LONG_ARRAY_VARIANT_NAME => Ok(ArrayType::Long), + s => Err(E::custom(format!("invalid array type `{s}`"))), + } + } + } + + deserializer.deserialize_u8(ArrayTypeVisitor) + } +} + +impl Tag { + pub fn from_u8(id: u8) -> Result { + match id { + 0 => Ok(Tag::End), + 1 => Ok(Tag::Byte), + 2 => Ok(Tag::Short), + 3 => Ok(Tag::Int), + 4 => Ok(Tag::Long), + 5 => Ok(Tag::Float), + 6 => Ok(Tag::Double), + 7 => Ok(Tag::ByteArray), + 8 => Ok(Tag::String), + 9 => Ok(Tag::List), + 10 => Ok(Tag::Compound), + 11 => Ok(Tag::IntArray), + 12 => Ok(Tag::LongArray), + _ => Err(Error(anyhow!("invalid tag byte `{id}`"))), + } + } +} + +impl Display for Tag { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let name = match self { + Tag::End => "end", + Tag::Byte => "byte", + Tag::Short => "short", + Tag::Int => "int", + Tag::Long => "long", + Tag::Float => "float", + Tag::Double => "double", + Tag::ByteArray => "byte array", + Tag::String => "string", + Tag::List => "list", + Tag::Compound => "compound", + Tag::IntArray => "int array", + Tag::LongArray => "long array", + }; + + write!(f, "{name}") + } +} + +const ARRAY_ENUM_NAME: &str = "__array__"; + +const BYTE_ARRAY_VARIANT_NAME: &str = "__byte_array__"; +const INT_ARRAY_VARIANT_NAME: &str = "__int_array__"; +const LONG_ARRAY_VARIANT_NAME: &str = "__long_array__"; diff --git a/serde_nbt/src/tests.rs b/valence_nbt/src/tests.rs similarity index 59% rename from serde_nbt/src/tests.rs rename to valence_nbt/src/tests.rs index 7e47704..d093015 100644 --- a/serde_nbt/src/tests.rs +++ b/valence_nbt/src/tests.rs @@ -1,8 +1,8 @@ -use binary::{from_reader, to_writer, Deserializer, Serializer}; -use ordered_float::OrderedFloat; +use pretty_assertions::assert_eq; use serde::{Deserialize, Serialize}; -use super::*; +use crate::binary::{from_reader, to_writer, Deserializer, Serializer}; +use crate::{byte_array, int_array, 
long_array, Compound, List, Value}; const ROOT_NAME: &str = "The root name‽"; @@ -13,14 +13,22 @@ struct Struct { list_of_string: Vec, string: String, inner: Inner, - #[serde(serialize_with = "int_array")] + #[serde(with = "int_array")] int_array: Vec, - #[serde(serialize_with = "byte_array")] + #[serde(with = "byte_array")] byte_array: Vec, - #[serde(serialize_with = "long_array")] + #[serde(with = "long_array")] long_array: Vec, } +#[derive(PartialEq, Debug, Serialize, Deserialize)] +struct Inner { + int: i32, + long: i64, + float: f32, + double: f64, +} + impl Struct { pub fn new() -> Self { Self { @@ -31,8 +39,8 @@ impl Struct { inner: Inner { int: i32::MIN, long: i64::MAX, - nan_float: OrderedFloat(f32::NAN), - neg_inf_double: f64::NEG_INFINITY, + float: 1e10_f32, + double: f64::NEG_INFINITY, }, int_array: vec![5, -9, i32::MIN, 0, i32::MAX], byte_array: vec![0, 1, 2], @@ -43,7 +51,7 @@ impl Struct { pub fn value() -> Value { Value::Compound( Compound::from_iter([ - ("byte".into(), 123.into()), + ("byte".into(), 123_i8.into()), ("list_of_int".into(), List::Int(vec![3, -7, 5]).into()), ( "list_of_string".into(), @@ -55,8 +63,8 @@ impl Struct { Compound::from_iter([ ("int".into(), i32::MIN.into()), ("long".into(), i64::MAX.into()), - ("nan_float".into(), f32::NAN.into()), - ("neg_inf_double".into(), f64::NEG_INFINITY.into()), + ("float".into(), 1e10_f32.into()), + ("double".into(), f64::NEG_INFINITY.into()), ]) .into(), ), @@ -72,19 +80,12 @@ impl Struct { } } -#[derive(PartialEq, Debug, Serialize, Deserialize)] -struct Inner { - int: i32, - long: i64, - nan_float: OrderedFloat, - neg_inf_double: f64, -} - #[test] -fn round_trip() { +fn round_trip_binary_struct() { + let mut buf = Vec::new(); + let struct_ = Struct::new(); - let mut buf = Vec::new(); struct_ .serialize(&mut Serializer::new(&mut buf, ROOT_NAME)) .unwrap(); @@ -93,28 +94,45 @@ fn round_trip() { let mut de = Deserializer::new(reader, true); - let example_de = Struct::deserialize(&mut de).unwrap(); 
+ let struct_de = Struct::deserialize(&mut de).unwrap(); - assert_eq!(struct_, example_de); - - let (_, root) = de.into_inner(); - - assert_eq!(root.unwrap(), ROOT_NAME); + assert_eq!(struct_, struct_de); + assert_eq!(de.root_name, ROOT_NAME); } #[test] -fn serialize() { - let struct_ = Struct::new(); - +fn round_trip_binary_value() { let mut buf = Vec::new(); + let value = Struct::value(); + + value + .serialize(&mut Serializer::new(&mut buf, ROOT_NAME)) + .unwrap(); + + let reader = &mut buf.as_slice(); + + let mut de = Deserializer::new(reader, true); + + let value_de = Value::deserialize(&mut de).unwrap(); + + assert_eq!(value, value_de); + assert_eq!(de.root_name, ROOT_NAME); +} + +#[test] +fn to_hematite() { + let mut buf = Vec::new(); + + let struct_ = Struct::new(); + struct_ .serialize(&mut Serializer::new(&mut buf, ROOT_NAME)) .unwrap(); - let example_de: Struct = nbt::from_reader(&mut buf.as_slice()).unwrap(); + let struct_de: Struct = nbt::from_reader(&mut buf.as_slice()).unwrap(); - assert_eq!(struct_, example_de); + assert_eq!(struct_, struct_de); } #[test] @@ -126,10 +144,10 @@ fn root_requires_compound() { } #[test] -fn invalid_array_element() { +fn mismatched_array_element() { #[derive(Serialize)] struct Struct { - #[serde(serialize_with = "byte_array")] + #[serde(with = "byte_array")] data: Vec, } @@ -143,30 +161,40 @@ fn invalid_array_element() { .is_err()); } -// #[test] -// fn struct_to_value() { -// let mut buf = Vec::new(); -// -// to_writer(&mut buf, ROOT_NAME, &Struct::new()).unwrap(); -// -// let reader = &mut buf.as_slice(); -// -// let val: Value = from_reader(reader).unwrap(); -// -// eprintln!("{:#?}", Struct::value()); -// -// assert_eq!(val, Struct::value()); -// } +#[test] +fn struct_to_value() { + let mut buf = Vec::new(); + + let struct_ = Struct::new(); + + to_writer(&mut buf, &struct_).unwrap(); + + let val: Value = from_reader(&mut buf.as_slice()).unwrap(); + + assert_eq!(val, Struct::value()); +} #[test] fn value_to_struct() { 
let mut buf = Vec::new(); - to_writer(&mut buf, ROOT_NAME, &Struct::value()).unwrap(); + to_writer(&mut buf, &Struct::value()).unwrap(); - let reader = &mut buf.as_slice(); - - let struct_: Struct = from_reader(reader).unwrap(); + let struct_: Struct = from_reader(&mut buf.as_slice()).unwrap(); assert_eq!(struct_, Struct::new()); } + +#[test] +fn value_from_json() { + let mut struct_ = Struct::new(); + + // JSON numbers only allow finite floats. + struct_.inner.double = 12345.0; + + let string = serde_json::to_string_pretty(&struct_).unwrap(); + + let struct_de: Struct = serde_json::from_str(&string).unwrap(); + + assert_eq!(struct_, struct_de); +} diff --git a/serde_nbt/src/value.rs b/valence_nbt/src/value.rs similarity index 90% rename from serde_nbt/src/value.rs rename to valence_nbt/src/value.rs index 899c981..ba1bd00 100644 --- a/serde_nbt/src/value.rs +++ b/valence_nbt/src/value.rs @@ -1,12 +1,11 @@ use std::borrow::Cow; -use std::collections::HashMap; use std::fmt; -use std::fmt::Formatter; -use serde::de::{DeserializeSeed, Error, MapAccess, SeqAccess, Visitor}; +use indexmap::IndexMap; +use serde::de::{DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::{byte_array, int_array, long_array}; +use crate::{byte_array, int_array, long_array, ArrayType}; /// Represents an arbitrary NBT value. #[derive(Clone, PartialEq, Debug)] @@ -25,7 +24,7 @@ pub enum Value { LongArray(Vec), } -pub type Compound = HashMap; +pub type Compound = IndexMap; /// An NBT list value. 
/// @@ -246,12 +245,12 @@ impl Serialize for Value { Value::Long(v) => v.serialize(serializer), Value::Float(v) => v.serialize(serializer), Value::Double(v) => v.serialize(serializer), - Value::ByteArray(v) => byte_array(v, serializer), + Value::ByteArray(v) => byte_array::serialize(v, serializer), Value::String(v) => v.serialize(serializer), Value::List(v) => v.serialize(serializer), Value::Compound(v) => v.serialize(serializer), - Value::IntArray(v) => int_array(v, serializer), - Value::LongArray(v) => long_array(v, serializer), + Value::IntArray(v) => int_array::serialize(v, serializer), + Value::LongArray(v) => long_array::serialize(v, serializer), } } } @@ -292,7 +291,7 @@ struct ValueVisitor; impl<'de> Visitor<'de> for ValueVisitor { type Value = Value; - fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "a representable NBT value") } @@ -338,13 +337,6 @@ impl<'de> Visitor<'de> for ValueVisitor { Ok(Value::Double(v)) } - fn visit_string(self, v: String) -> Result - where - E: Error, - { - Ok(Value::String(v)) - } - fn visit_str(self, v: &str) -> Result where E: Error, @@ -352,6 +344,13 @@ impl<'de> Visitor<'de> for ValueVisitor { Ok(Value::String(v.to_owned())) } + fn visit_string(self, v: String) -> Result + where + E: Error, + { + Ok(Value::String(v)) + } + fn visit_seq(self, seq: A) -> Result where A: SeqAccess<'de>, @@ -365,6 +364,19 @@ impl<'de> Visitor<'de> for ValueVisitor { { visit_map(map).map(Value::Compound) } + + fn visit_enum(self, data: A) -> Result + where + A: EnumAccess<'de>, + { + let (array_type, variant) = data.variant()?; + + Ok(match array_type { + ArrayType::Byte => Value::ByteArray(variant.newtype_variant()?), + ArrayType::Int => Value::IntArray(variant.newtype_variant()?), + ArrayType::Long => Value::LongArray(variant.newtype_variant()?), + }) + } } impl<'de> Deserialize<'de> for List { @@ -381,7 +393,7 @@ struct ListVisitor; 
impl<'de> Visitor<'de> for ListVisitor { type Value = List; - fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "an NBT list") } @@ -430,7 +442,7 @@ macro_rules! visit { impl<'de, 'a> Visitor<'de> for DeserializeListElement<'a> { type Value = (); - fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "a valid NBT list element") }