diff --git a/Cargo.toml b/Cargo.toml index a15c814..f0f3f63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ num-integer = "0.1.45" owo-colors = "3.5.0" parking_lot = "0.12.1" paste = "1.0.11" +pretty_assertions = "1.3.0" proc-macro2 = "1.0.56" quote = "1.0.26" rand = "0.8.5" @@ -84,7 +85,7 @@ valence_dimension.path = "crates/valence_dimension" valence_entity.path = "crates/valence_entity" valence_instance.path = "crates/valence_instance" valence_inventory.path = "crates/valence_inventory" -valence_nbt = { path = "crates/valence_nbt", features = ["uuid"] } +valence_nbt = { path = "crates/valence_nbt" } valence_network.path = "crates/valence_network" valence_player_list.path = "crates/valence_player_list" valence_registry.path = "crates/valence_registry" diff --git a/crates/valence_anvil/src/lib.rs b/crates/valence_anvil/src/lib.rs index 2b635ee..5b5d31e 100644 --- a/crates/valence_anvil/src/lib.rs +++ b/crates/valence_anvil/src/lib.rs @@ -56,7 +56,7 @@ pub enum ReadChunkError { #[error(transparent)] Io(#[from] io::Error), #[error(transparent)] - Nbt(#[from] valence_nbt::Error), + Nbt(#[from] valence_nbt::binary::Error), #[error("invalid chunk sector offset")] BadSectorOffset, #[error("invalid chunk size")] @@ -180,7 +180,7 @@ impl AnvilWorld { b => return Err(ReadChunkError::UnknownCompressionScheme(b)), }; - let (data, _) = valence_nbt::from_binary_slice(&mut nbt_slice)?; + let (data, _) = Compound::from_binary(&mut nbt_slice)?; if !nbt_slice.is_empty() { return Err(ReadChunkError::IncompleteNbtRead); diff --git a/crates/valence_core/Cargo.toml b/crates/valence_core/Cargo.toml index 74a8baa..4430321 100644 --- a/crates/valence_core/Cargo.toml +++ b/crates/valence_core/Cargo.toml @@ -24,7 +24,7 @@ serde_json.workspace = true thiserror.workspace = true tracing.workspace = true uuid = { workspace = true, features = ["serde"] } -valence_nbt.workspace = true +valence_nbt = { workspace = true, features = ["binary"] } valence_core_macros.workspace = true url.workspace = true base64.workspace = true diff --git a/crates/valence_core/src/protocol/impls.rs b/crates/valence_core/src/protocol/impls.rs index 81aa9b4..e14bb21 100644 --- a/crates/valence_core/src/protocol/impls.rs +++ b/crates/valence_core/src/protocol/impls.rs @@ -1,5 +1,6 @@ //! [`Encode`] and [`Decode`] impls on foreign types. +use core::slice; use std::borrow::Cow; use std::collections::{BTreeSet, HashSet}; use std::hash::{BuildHasher, Hash}; @@ -27,8 +28,8 @@ impl Encode for bool { fn encode_slice(slice: &[bool], mut w: impl Write) -> Result<()> { // SAFETY: Bools have the same layout as u8. + let bytes = unsafe { slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len()) }; // Bools are guaranteed to have the correct bit pattern. - let bytes: &[u8] = unsafe { mem::transmute(slice) }; Ok(w.write_all(bytes)?) } } @@ -64,7 +65,7 @@ impl Encode for i8 { fn encode_slice(slice: &[i8], mut w: impl Write) -> Result<()> { // SAFETY: i8 has the same layout as u8. - let bytes: &[u8] = unsafe { mem::transmute(slice) }; + let bytes = unsafe { slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len()) }; Ok(w.write_all(bytes)?) } } @@ -471,7 +472,6 @@ impl Encode for [T; N] { impl<'a, const N: usize, T: Decode<'a>> Decode<'a> for [T; N] { fn decode(r: &mut &'a [u8]) -> Result { // TODO: rewrite using std::array::try_from_fn when stabilized? - // TODO: specialization for [f64; 3] improved performance. let mut data: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; @@ -539,9 +539,12 @@ impl<'a> Decode<'a> for &'a [u8] { impl<'a> Decode<'a> for &'a [i8] { fn decode(r: &mut &'a [u8]) -> Result { - let unsigned_bytes = <&[u8]>::decode(r)?; - let signed_bytes: &[i8] = unsafe { mem::transmute(unsigned_bytes) }; - Ok(signed_bytes) + let bytes = <&[u8]>::decode(r)?; + + // SAFETY: i8 and u8 have the same layout. + let bytes = unsafe { slice::from_raw_parts(bytes.as_ptr() as *const i8, bytes.len()) }; + + Ok(bytes) } } @@ -765,12 +768,12 @@ impl<'a> Decode<'a> for Uuid { impl Encode for Compound { fn encode(&self, w: impl Write) -> Result<()> { - Ok(valence_nbt::to_binary_writer(w, self, "")?) + Ok(self.to_binary(w, "")?) } } impl Decode<'_> for Compound { fn decode(r: &mut &[u8]) -> Result { - Ok(valence_nbt::from_binary_slice(r)?.0) + Ok(Self::from_binary(r)?.0) } } diff --git a/crates/valence_nbt/Cargo.toml b/crates/valence_nbt/Cargo.toml index 8b95af2..1ec15ee 100644 --- a/crates/valence_nbt/Cargo.toml +++ b/crates/valence_nbt/Cargo.toml @@ -10,11 +10,20 @@ version = "0.5.0" edition.workspace = true [features] +binary = ["dep:byteorder", "dep:cesu8"] +snbt = [] # When enabled, the order of fields in compounds are preserved. preserve_order = ["dep:indexmap"] +serde = ["dep:serde", "dep:thiserror", "indexmap?/serde"] [dependencies] -byteorder.workspace = true -cesu8.workspace = true +byteorder = { workspace = true, optional = true } +cesu8 = { workspace = true, optional = true } indexmap = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"], optional = true } +thiserror = { workspace = true, optional = true } uuid = { workspace = true, optional = true } + +[dev-dependencies] +pretty_assertions.workspace = true +serde_json.workspace = true diff --git a/crates/valence_nbt/README.md b/crates/valence_nbt/README.md index 4aeba9e..1d07ce1 100644 --- a/crates/valence_nbt/README.md +++ b/crates/valence_nbt/README.md @@ -5,48 +5,10 @@ format. [Named Binary Tag]: https://minecraft.fandom.com/wiki/NBT_format -# Examples - -Encode NBT data to its binary form. We are using the [`compound!`] macro to -conveniently construct [`Compound`] values. - -```rust -use valence_nbt::{compound, to_binary_writer, List}; - -let c = compound! { - "byte" => 5_i8, - "string" => "hello", - "list_of_float" => List::Float(vec![ - 3.1415, - 2.7182, - 1.4142 - ]), -}; - -let mut buf = vec![]; - -to_binary_writer(&mut buf, &c, "").unwrap(); -``` - -Decode NBT data from its binary form. - -```rust -use valence_nbt::{compound, from_binary_slice}; - -let some_bytes = [10, 0, 0, 3, 0, 3, 105, 110, 116, 0, 0, 222, 173, 0]; - -let expected_value = compound! { - "int" => 0xdead -}; - -let (nbt, root_name) = from_binary_slice(&mut some_bytes.as_slice()).unwrap(); - -assert_eq!(nbt, expected_value); -assert_eq!(root_name, ""); -``` - # Features - +- `binary`: Adds support for serializing and deserializing in Java edition's binary format. +- `snbt`: Adds support for serializing and deserializing in "stringified" format. - `preserve_order`: Causes the order of fields in [`Compound`]s to be preserved during insertion and deletion at a slight cost to performance. The iterators on `Compound` can then implement [`DoubleEndedIterator`]. +- `serde` Adds support for [`serde`](https://docs.rs/serde/latest/serde/) diff --git a/crates/valence_nbt/src/binary.rs b/crates/valence_nbt/src/binary.rs new file mode 100644 index 0000000..7e07465 --- /dev/null +++ b/crates/valence_nbt/src/binary.rs @@ -0,0 +1,48 @@ +//! Support for serializing and deserializing compounds in Java edition's binary +//! format. +//! +//! # Examples +//! +//! ``` +//! use valence_nbt::{compound, Compound, List}; +//! +//! let c = compound! { +//! "byte" => 5_i8, +//! "string" => "hello", +//! "list_of_float" => List::Float(vec![ +//! 3.1415, +//! 2.7182, +//! 1.4142 +//! ]), +//! }; +//! +//! let mut buf = vec![]; +//! +//! c.to_binary(&mut buf, "").unwrap(); +//! ``` +//! +//! Decode NBT data from its binary form. +//! +//! ``` +//! use valence_nbt::{compound, Compound}; +//! +//! let some_bytes = [10, 0, 0, 3, 0, 3, 105, 110, 116, 0, 0, 222, 173, 0]; +//! +//! let expected_value = compound! { +//! "int" => 0xdead +//! }; +//! +//! let (nbt, root_name) = Compound::from_binary(&mut some_bytes.as_slice()).unwrap(); +//! +//! assert_eq!(nbt, expected_value); +//! assert_eq!(root_name, ""); +//! ``` + +mod decode; +mod encode; +mod error; +mod modified_utf8; +#[cfg(test)] +mod tests; + +pub use error::*; diff --git a/crates/valence_nbt/src/to_binary_writer.rs b/crates/valence_nbt/src/binary/decode.rs similarity index 67% rename from crates/valence_nbt/src/to_binary_writer.rs rename to crates/valence_nbt/src/binary/decode.rs index 9003ad0..b740e56 100644 --- a/crates/valence_nbt/src/to_binary_writer.rs +++ b/crates/valence_nbt/src/binary/decode.rs @@ -1,77 +1,88 @@ use std::io::Write; +use std::slice; use byteorder::{BigEndian, WriteBytesExt}; +use super::{modified_utf8, Error, Result}; use crate::tag::Tag; -use crate::{modified_utf8, Compound, Error, List, Result, Value}; +use crate::{Compound, List, Value}; -/// Encodes uncompressed NBT binary data to the provided writer. -/// -/// Only compounds are permitted at the top level. This is why the function -/// accepts a [`Compound`] reference rather than a [`Value`]. -/// -/// Additionally, the root compound can be given a name. Typically the empty -/// string `""` is used. -pub fn to_binary_writer(writer: W, compound: &Compound, root_name: &str) -> Result<()> { - let mut state = EncodeState { writer }; +impl Compound { + /// Encodes uncompressed NBT binary data to the provided writer. + /// + /// Only compounds are permitted at the top level. This is why the function + /// accepts a [`Compound`] reference rather than a [`Value`]. + /// + /// Additionally, the root compound can be given a name. Typically the empty + /// string `""` is used. + pub fn to_binary(&self, writer: W, root_name: &str) -> Result<()> { + let mut state = EncodeState { writer }; - state.write_tag(Tag::Compound)?; - state.write_string(root_name)?; - state.write_compound(compound)?; + state.write_tag(Tag::Compound)?; + state.write_string(root_name)?; + state.write_compound(self)?; - Ok(()) -} + Ok(()) + } -pub(crate) fn written_size(compound: &Compound, root_name: &str) -> usize { - fn value_size(val: &Value) -> usize { - match val { - Value::Byte(_) => 1, - Value::Short(_) => 2, - Value::Int(_) => 4, - Value::Long(_) => 8, - Value::Float(_) => 4, - Value::Double(_) => 8, - Value::ByteArray(ba) => 4 + ba.len(), - Value::String(s) => string_size(s), - Value::List(l) => list_size(l), - Value::Compound(c) => compound_size(c), - Value::IntArray(ia) => 4 + ia.len() * 4, - Value::LongArray(la) => 4 + la.len() * 8, + /// Returns the number of bytes that will be written when + /// [`Compound::to_binary`] is called with this compound and root name. + /// + /// If `to_binary` results in `Ok`, the exact number of bytes + /// reported by this function will have been written. If the result is + /// `Err`, then the reported count will be greater than or equal to the + /// number of bytes that have actually been written. + pub fn written_size(&self, root_name: &str) -> usize { + fn value_size(val: &Value) -> usize { + match val { + Value::Byte(_) => 1, + Value::Short(_) => 2, + Value::Int(_) => 4, + Value::Long(_) => 8, + Value::Float(_) => 4, + Value::Double(_) => 8, + Value::ByteArray(ba) => 4 + ba.len(), + Value::String(s) => string_size(s), + Value::List(l) => list_size(l), + Value::Compound(c) => compound_size(c), + Value::IntArray(ia) => 4 + ia.len() * 4, + Value::LongArray(la) => 4 + la.len() * 8, + } } + + fn list_size(l: &List) -> usize { + let elems_size = match l { + List::End => 0, + List::Byte(b) => b.len(), + List::Short(s) => s.len() * 2, + List::Int(i) => i.len() * 4, + List::Long(l) => l.len() * 8, + List::Float(f) => f.len() * 4, + List::Double(d) => d.len() * 8, + List::ByteArray(ba) => ba.iter().map(|b| 4 + b.len()).sum(), + List::String(s) => s.iter().map(|s| string_size(s)).sum(), + List::List(l) => l.iter().map(list_size).sum(), + List::Compound(c) => c.iter().map(compound_size).sum(), + List::IntArray(i) => i.iter().map(|i| 4 + i.len() * 4).sum(), + List::LongArray(l) => l.iter().map(|l| 4 + l.len() * 8).sum(), + }; + + 1 + 4 + elems_size + } + + fn string_size(s: &str) -> usize { + 2 + modified_utf8::encoded_len(s) + } + + fn compound_size(c: &Compound) -> usize { + c.iter() + .map(|(k, v)| 1 + string_size(k) + value_size(v)) + .sum::() + + 1 + } + + 1 + string_size(root_name) + compound_size(self) } - - fn list_size(l: &List) -> usize { - let elems_size = match l { - List::End => 0, - List::Byte(b) => b.len(), - List::Short(s) => s.len() * 2, - List::Int(i) => i.len() * 4, - List::Long(l) => l.len() * 8, - List::Float(f) => f.len() * 4, - List::Double(d) => d.len() * 8, - List::ByteArray(ba) => ba.iter().map(|b| 4 + b.len()).sum(), - List::String(s) => s.iter().map(|s| string_size(s)).sum(), - List::List(l) => l.iter().map(list_size).sum(), - List::Compound(c) => c.iter().map(compound_size).sum(), - List::IntArray(i) => i.iter().map(|i| 4 + i.len() * 4).sum(), - List::LongArray(l) => l.iter().map(|l| 4 + l.len() * 8).sum(), - }; - - 1 + 4 + elems_size - } - - fn string_size(s: &str) -> usize { - 2 + modified_utf8::encoded_len(s) - } - - fn compound_size(c: &Compound) -> usize { - c.iter() - .map(|(k, v)| 1 + string_size(k) + value_size(v)) - .sum::() - + 1 - } - - 1 + string_size(root_name) + compound_size(compound) } struct EncodeState { @@ -136,7 +147,7 @@ impl EncodeState { } // SAFETY: i8 has the same layout as u8. - let bytes: &[u8] = unsafe { std::mem::transmute(bytes) }; + let bytes = unsafe { slice::from_raw_parts(bytes.as_ptr() as *const u8, bytes.len()) }; Ok(self.writer.write_all(bytes)?) } @@ -187,7 +198,7 @@ impl EncodeState { } // SAFETY: i8 has the same layout as u8. - let bytes: &[u8] = unsafe { std::mem::transmute(bl.as_slice()) }; + let bytes = unsafe { slice::from_raw_parts(bl.as_ptr() as *const u8, bl.len()) }; Ok(self.writer.write_all(bytes)?) } diff --git a/crates/valence_nbt/src/from_binary_slice.rs b/crates/valence_nbt/src/binary/encode.rs similarity index 89% rename from crates/valence_nbt/src/from_binary_slice.rs rename to crates/valence_nbt/src/binary/encode.rs index 2ac1f3d..5478097 100644 --- a/crates/valence_nbt/src/from_binary_slice.rs +++ b/crates/valence_nbt/src/binary/encode.rs @@ -3,36 +3,39 @@ use std::mem; use byteorder::{BigEndian, ReadBytesExt}; use cesu8::Cesu8DecodingError; +use super::{Error, Result}; use crate::tag::Tag; -use crate::{Compound, Error, List, Result, Value}; +use crate::{Compound, List, Value}; -/// Decodes uncompressed NBT binary data from the provided slice. -/// -/// The string returned is the name of the root compound. -pub fn from_binary_slice(slice: &mut &[u8]) -> Result<(Compound, String)> { - let mut state = DecodeState { slice, depth: 0 }; +impl Compound { + /// Decodes uncompressed NBT binary data from the provided slice. + /// + /// The string returned in the tuple is the name of the root compound + /// (typically the empty string). + pub fn from_binary(slice: &mut &[u8]) -> Result<(Self, String)> { + let mut state = DecodeState { slice, depth: 0 }; - let root_tag = state.read_tag()?; + let root_tag = state.read_tag()?; - // For cases such as Block Entity Data in the - // ChunkUpdateAndUpdateLight Packet - // https://wiki.vg/Protocol#Chunk_Data_and_Update_Light - if root_tag == Tag::End { - return Ok((Compound::new(), String::new())); + // For cases such as Block Entity Data in the chunk packet. + // https://wiki.vg/Protocol#Chunk_Data_and_Update_Light + if root_tag == Tag::End { + return Ok((Compound::new(), String::new())); + } + + if root_tag != Tag::Compound { + return Err(Error::new_owned(format!( + "expected root tag for compound (got {root_tag})", + ))); + } + + let root_name = state.read_string()?; + let root = state.read_compound()?; + + debug_assert_eq!(state.depth, 0); + + Ok((root, root_name)) } - - if root_tag != Tag::Compound { - return Err(Error::new_owned(format!( - "expected root tag for compound (got {root_tag})", - ))); - } - - let root_name = state.read_string()?; - let root = state.read_compound()?; - - debug_assert_eq!(state.depth, 0); - - Ok((root, root_name)) } /// Maximum recursion depth to prevent overflowing the call stack. diff --git a/crates/valence_nbt/src/error.rs b/crates/valence_nbt/src/binary/error.rs similarity index 91% rename from crates/valence_nbt/src/error.rs rename to crates/valence_nbt/src/binary/error.rs index a3f3421..8c5c4ef 100644 --- a/crates/valence_nbt/src/error.rs +++ b/crates/valence_nbt/src/binary/error.rs @@ -2,7 +2,9 @@ use std::error::Error as StdError; use std::fmt::{Display, Formatter}; use std::io; -/// Errors that can occur when encoding or decoding. +pub type Result = std::result::Result; + +/// Errors that can occur when encoding or decoding binary NBT. #[derive(Debug)] pub struct Error { /// Box this to keep the size of `Result` small. diff --git a/crates/valence_nbt/src/modified_utf8.rs b/crates/valence_nbt/src/binary/modified_utf8.rs similarity index 100% rename from crates/valence_nbt/src/modified_utf8.rs rename to crates/valence_nbt/src/binary/modified_utf8.rs diff --git a/crates/valence_nbt/src/tests.rs b/crates/valence_nbt/src/binary/tests.rs similarity index 79% rename from crates/valence_nbt/src/tests.rs rename to crates/valence_nbt/src/binary/tests.rs index ee93bc7..897cabb 100644 --- a/crates/valence_nbt/src/tests.rs +++ b/crates/valence_nbt/src/binary/tests.rs @@ -1,5 +1,5 @@ use crate::tag::Tag; -use crate::{compound, from_binary_slice, to_binary_writer, Compound, List, Value}; +use crate::{compound, Compound, List, Value}; const ROOT_NAME: &str = "The root name‽"; @@ -9,11 +9,11 @@ fn round_trip() { let compound = example_compound(); - to_binary_writer(&mut buf, &compound, ROOT_NAME).unwrap(); + compound.to_binary(&mut buf, ROOT_NAME).unwrap(); println!("{buf:?}"); - let (decoded, root_name) = from_binary_slice(&mut buf.as_slice()).unwrap(); + let (decoded, root_name) = Compound::from_binary(&mut buf.as_slice()).unwrap(); assert_eq!(root_name, ROOT_NAME); assert_eq!(compound, decoded); @@ -28,7 +28,7 @@ fn check_min_sizes() { let dbg = format!("{min_val:?}"); let mut buf = vec![]; - to_binary_writer(&mut buf, &compound!("" => min_val), "").unwrap(); + compound!("" => min_val).to_binary(&mut buf, "").unwrap(); assert_eq!( expected_size, @@ -65,7 +65,7 @@ fn deeply_nested_compound_decode() { buf.push(Tag::End as u8); // End root compound // Should not overflow the stack - let _ = from_binary_slice(&mut buf.as_slice()); + let _ = Compound::from_binary(&mut buf.as_slice()); } #[test] @@ -84,7 +84,7 @@ fn deeply_nested_list_decode() { buf.push(Tag::End as u8); // End root compound // Should not overflow the stack - let _ = from_binary_slice(&mut buf.as_slice()); + let _ = Compound::from_binary(&mut buf.as_slice()); } #[test] @@ -92,26 +92,11 @@ fn correct_length() { let c = example_compound(); let mut buf = vec![]; - to_binary_writer(&mut buf, &c, "abc").unwrap(); + c.to_binary(&mut buf, "abc").unwrap(); assert_eq!(c.written_size("abc"), buf.len()); } -#[cfg(feature = "preserve_order")] -#[test] -fn preserves_order() { - let letters = ["g", "b", "d", "e", "h", "z", "m", "a", "q"]; - - let mut c = Compound::new(); - for l in letters { - c.insert(l, 0_i8); - } - - for (k, l) in c.keys().zip(letters) { - assert_eq!(k, l); - } -} - fn example_compound() -> Compound { fn inner() -> Compound { compound! { diff --git a/crates/valence_nbt/src/compound.rs b/crates/valence_nbt/src/compound.rs index f6d65c8..2d2649d 100644 --- a/crates/valence_nbt/src/compound.rs +++ b/crates/valence_nbt/src/compound.rs @@ -4,11 +4,15 @@ use std::hash::Hash; use std::iter::FusedIterator; use std::ops::{Index, IndexMut}; -use crate::to_binary_writer::written_size; use crate::Value; /// A map type with [`String`] keys and [`Value`] values. #[derive(Clone, PartialEq, Default)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize), + serde(transparent) +)] pub struct Compound { map: Map, } @@ -19,74 +23,6 @@ type Map = std::collections::BTreeMap; #[cfg(feature = "preserve_order")] type Map = indexmap::IndexMap; -impl Compound { - /// Returns the number of bytes that will be written when - /// [`to_binary_writer`] is called with this compound and root name. - /// - /// If [`to_binary_writer`] results in `Ok`, the exact number of bytes - /// reported by this function will have been written. If the result is - /// `Err`, then the reported count will be greater than or equal to the - /// number of bytes that have actually been written. - /// - /// [`to_binary_writer`]: crate::to_binary_writer() - pub fn written_size(&self, root_name: &str) -> usize { - written_size(self, root_name) - } - - /// Inserts all items from `other` into `self` recursively. - /// - /// # Example - /// - /// ``` - /// use valence_nbt::compound; - /// - /// let mut this = compound! { - /// "foo" => 10, - /// "bar" => compound! { - /// "baz" => 20, - /// } - /// }; - /// - /// let other = compound! { - /// "foo" => 15, - /// "bar" => compound! { - /// "quux" => "hello", - /// } - /// }; - /// - /// this.merge(other); - /// - /// assert_eq!( - /// this, - /// compound! { - /// "foo" => 15, - /// "bar" => compound! { - /// "baz" => 20, - /// "quux" => "hello", - /// } - /// } - /// ); - /// ``` - pub fn merge(&mut self, other: Compound) { - for (k, v) in other { - match (self.entry(k), v) { - (Entry::Occupied(mut oe), Value::Compound(other)) => { - if let Value::Compound(this) = oe.get_mut() { - // Insert compound recursively. - this.merge(other); - } - } - (Entry::Occupied(mut oe), value) => { - oe.insert(value); - } - (Entry::Vacant(ve), value) => { - ve.insert(value); - } - } - } - } -} - impl fmt::Debug for Compound { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.map.fmt(f) @@ -241,6 +177,59 @@ impl Compound { { self.map.retain(f) } + + /// Inserts all items from `other` into `self` recursively. + /// + /// # Example + /// + /// ``` + /// use valence_nbt::compound; + /// + /// let mut this = compound! { + /// "foo" => 10, + /// "bar" => compound! { + /// "baz" => 20, + /// } + /// }; + /// + /// let other = compound! { + /// "foo" => 15, + /// "bar" => compound! { + /// "quux" => "hello", + /// } + /// }; + /// + /// this.merge(other); + /// + /// assert_eq!( + /// this, + /// compound! { + /// "foo" => 15, + /// "bar" => compound! { + /// "baz" => 20, + /// "quux" => "hello", + /// } + /// } + /// ); + /// ``` + pub fn merge(&mut self, other: Compound) { + for (k, v) in other { + match (self.entry(k), v) { + (Entry::Occupied(mut oe), Value::Compound(other)) => { + if let Value::Compound(this) = oe.get_mut() { + // Insert compound recursively. + this.merge(other); + } + } + (Entry::Occupied(mut oe), value) => { + oe.insert(value); + } + (Entry::Vacant(ve), value) => { + ve.insert(value); + } + } + } + } } impl Extend<(String, Value)> for Compound { @@ -502,3 +491,23 @@ pub struct ValuesMut<'a> { } impl_iterator_traits!((ValuesMut<'a>) => &'a mut Value); + +#[cfg(test)] +mod tests { + #[cfg(feature = "preserve_order")] + #[test] + fn compound_preserves_order() { + use super::*; + + let letters = ["g", "b", "d", "e", "h", "z", "m", "a", "q"]; + + let mut c = Compound::new(); + for l in letters { + c.insert(l, 0_i8); + } + + for (k, l) in c.keys().zip(letters) { + assert_eq!(k, l); + } + } +} diff --git a/crates/valence_nbt/src/lib.rs b/crates/valence_nbt/src/lib.rs index 5038f92..addcee0 100644 --- a/crates/valence_nbt/src/lib.rs +++ b/crates/valence_nbt/src/lib.rs @@ -18,25 +18,18 @@ )] pub use compound::Compound; -pub use error::Error; -pub use from_binary_slice::from_binary_slice; pub use tag::Tag; -pub use to_binary_writer::to_binary_writer; pub use value::{List, Value}; +#[cfg(feature = "binary")] +pub mod binary; pub mod compound; -mod error; -mod from_binary_slice; -mod modified_utf8; +#[cfg(feature = "serde")] +pub mod serde; +#[cfg(feature = "snbt")] pub mod snbt; -mod to_binary_writer; -pub mod value; - mod tag; -#[cfg(test)] -mod tests; - -type Result = std::result::Result; +pub mod value; /// A convenience macro for constructing [`Compound`]s. /// diff --git a/crates/valence_nbt/src/serde.rs b/crates/valence_nbt/src/serde.rs new file mode 100644 index 0000000..7dd780f --- /dev/null +++ b/crates/valence_nbt/src/serde.rs @@ -0,0 +1,160 @@ +use std::fmt; +use std::mem::ManuallyDrop; + +pub use ser::*; +use thiserror::Error; + +mod de; +mod ser; + +/// Errors that can occur while serializing or deserializing. +#[derive(Clone, Error, Debug)] +#[error("{0}")] + +pub struct Error(Box); + +impl Error { + fn new(s: impl Into>) -> Self { + Self(s.into()) + } +} + +impl serde::de::Error for Error { + fn custom(msg: T) -> Self + where + T: fmt::Display, + { + Self::new(format!("{msg}")) + } +} + +impl serde::ser::Error for Error { + fn custom(msg: T) -> Self + where + T: fmt::Display, + { + Self::new(format!("{msg}")) + } +} + +#[inline] +fn u8_vec_to_i8_vec(vec: Vec) -> Vec { + // SAFETY: Layouts of u8 and i8 are the same and we're being careful not to drop + // the original vec after calling Vec::from_raw_parts. + unsafe { + let mut vec = ManuallyDrop::new(vec); + Vec::from_raw_parts(vec.as_mut_ptr() as *mut i8, vec.len(), vec.capacity()) + } +} + +#[inline] +fn i8_vec_to_u8_vec(vec: Vec) -> Vec { + // SAFETY: Layouts of u8 and i8 are the same and we're being careful not to drop + // the original vec after calling Vec::from_raw_parts. + unsafe { + let mut vec = ManuallyDrop::new(vec); + Vec::from_raw_parts(vec.as_mut_ptr() as *mut u8, vec.len(), vec.capacity()) + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + use super::*; + use crate::{compound, Compound, List}; + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Struct { + foo: i32, + bar: StructInner, + baz: String, + quux: Vec, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct StructInner { + a: bool, + b: i64, + c: Vec>, + d: Vec, + } + + fn make_struct() -> Struct { + Struct { + foo: i32::MIN, + bar: StructInner { + a: true, + b: 123456789, + c: vec![vec![1, 2, 3], vec![4, 5, 6]], + d: vec![], + }, + baz: "🤨".into(), + quux: vec![std::f32::consts::PI, f32::MAX, f32::MIN], + } + } + + fn make_compound() -> Compound { + compound! { + "foo" => i32::MIN, + "bar" => compound! { + "a" => true, + "b" => 123456789_i64, + "c" => List::IntArray(vec![vec![1, 2, 3], vec![4, 5, 6]]), + "d" => List::End, + }, + "baz" => "🤨", + "quux" => List::Float(vec![ + std::f32::consts::PI, + f32::MAX, + f32::MIN, + ]), + } + } + + fn make_json() -> serde_json::Value { + json!({ + "foo": i32::MIN, + "bar": { + "a": true, + "b": 123456789_i64, + "c": [[1, 2, 3], [4, 5, 6]], + "d": [] + }, + "baz": "🤨", + "quux": [ + std::f32::consts::PI, + f32::MAX, + f32::MIN, + ] + }) + } + + #[test] + fn struct_to_compound() { + let c = make_struct().serialize(CompoundSerializer).unwrap(); + + assert_eq!(c, make_compound()); + } + + #[test] + fn compound_to_struct() { + let s = Struct::deserialize(make_compound()).unwrap(); + + assert_eq!(s, make_struct()); + } + + #[test] + fn compound_to_json() { + let mut j = serde_json::to_value(make_compound()).unwrap(); + + // Bools map to bytes in NBT, but the result should be the same otherwise. + let p = j.pointer_mut("/bar/a").unwrap(); + assert_eq!(*p, serde_json::Value::from(1)); + *p = true.into(); + + assert_eq!(j, make_json()); + } +} diff --git a/crates/valence_nbt/src/serde/de.rs b/crates/valence_nbt/src/serde/de.rs new file mode 100644 index 0000000..199618a --- /dev/null +++ b/crates/valence_nbt/src/serde/de.rs @@ -0,0 +1,356 @@ +use std::{fmt, slice}; + +use serde::de::value::{MapAccessDeserializer, MapDeserializer, SeqAccessDeserializer}; +use serde::de::{self, IntoDeserializer, SeqAccess, Visitor}; +use serde::{forward_to_deserialize_any, Deserialize, Deserializer}; + +use super::Error; +use crate::serde::{i8_vec_to_u8_vec, u8_vec_to_i8_vec}; +use crate::{Compound, List, Value}; + +impl<'de> Deserialize<'de> for Value { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ValueVisitor; + + impl<'de> Visitor<'de> for ValueVisitor { + type Value = Value; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a valid NBT type") + } + + fn visit_bool(self, v: bool) -> Result + where + E: de::Error, + { + Ok(Value::Byte(v as _)) + } + + fn visit_i8(self, v: i8) -> Result + where + E: de::Error, + { + Ok(Value::Byte(v)) + } + + fn visit_i16(self, v: i16) -> Result + where + E: de::Error, + { + Ok(Value::Short(v)) + } + + fn visit_i32(self, v: i32) -> Result + where + E: de::Error, + { + Ok(Value::Int(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: de::Error, + { + Ok(Value::Long(v)) + } + + fn visit_u8(self, v: u8) -> Result + where + E: de::Error, + { + Ok(Value::Byte(v as _)) + } + + fn visit_u16(self, v: u16) -> Result + where + E: de::Error, + { + Ok(Value::Short(v as _)) + } + + fn visit_u32(self, v: u32) -> Result + where + E: de::Error, + { + Ok(Value::Int(v as _)) + } + + fn visit_u64(self, v: u64) -> Result + where + E: de::Error, + { + Ok(Value::Long(v as _)) + } + + fn visit_f32(self, v: f32) -> Result + where + E: de::Error, + { + Ok(Value::Float(v)) + } + + fn visit_f64(self, v: f64) -> Result + where + E: de::Error, + { + Ok(Value::Double(v)) + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + Ok(Value::String(v.into())) + } + + fn visit_string(self, v: String) -> Result + where + E: de::Error, + { + Ok(Value::String(v)) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + let slice: &[i8] = + unsafe { slice::from_raw_parts(v.as_ptr() as *const i8, v.len()) }; + + Ok(Value::ByteArray(slice.into())) + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: de::Error, + { + Ok(Value::ByteArray(u8_vec_to_i8_vec(v))) + } + + fn visit_seq(self, seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + Ok(List::deserialize(SeqAccessDeserializer::new(seq))?.into()) + } + + fn visit_map(self, map: A) -> Result + where + A: de::MapAccess<'de>, + { + Ok(Compound::deserialize(MapAccessDeserializer::new(map))?.into()) + } + } + + deserializer.deserialize_any(ValueVisitor) + } +} + +impl<'de> Deserialize<'de> for List { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ListVisitor; + + impl<'de> Visitor<'de> for ListVisitor { + type Value = List; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a sequence or bytes") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + match seq.next_element::()? { + Some(v) => match v { + Value::Byte(v) => deserialize_seq_remainder(v, seq), + Value::Short(v) => deserialize_seq_remainder(v, seq), + Value::Int(v) => deserialize_seq_remainder(v, seq), + Value::Long(v) => deserialize_seq_remainder(v, seq), + Value::Float(v) => deserialize_seq_remainder(v, seq), + Value::Double(v) => deserialize_seq_remainder(v, seq), + Value::ByteArray(v) => deserialize_seq_remainder(v, seq), + Value::String(v) => deserialize_seq_remainder(v, seq), + Value::List(v) => deserialize_seq_remainder(v, seq), + Value::Compound(v) => deserialize_seq_remainder(v, seq), + Value::IntArray(v) => deserialize_seq_remainder(v, seq), + Value::LongArray(v) => deserialize_seq_remainder(v, seq), + }, + None => Ok(List::End), + } + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: de::Error, + { + Ok(List::Byte(u8_vec_to_i8_vec(v))) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + let bytes: &[i8] = + unsafe { slice::from_raw_parts(v.as_ptr() as *const i8, v.len()) }; + + Ok(List::Byte(bytes.into())) + } + } + + deserializer.deserialize_seq(ListVisitor) + } +} + +/// Deserializes the remainder of a sequence after having +/// determined the type of the first element. +fn deserialize_seq_remainder<'de, T, A, R>(first: T, mut seq: A) -> Result +where + T: Deserialize<'de>, + Vec: Into, + A: de::SeqAccess<'de>, +{ + let mut vec = match seq.size_hint() { + Some(n) => Vec::with_capacity(n + 1), + None => Vec::new(), + }; + + vec.push(first); + + while let Some(v) = seq.next_element()? { + vec.push(v); + } + + Ok(vec.into()) +} + +impl<'de> Deserializer<'de> for Compound { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(MapDeserializer::new(self.into_iter())) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum identifier ignored_any + } +} + +impl<'de> IntoDeserializer<'de, Error> for Compound { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'de> Deserializer<'de> for Value { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self { + Value::Byte(v) => visitor.visit_i8(v), + Value::Short(v) => visitor.visit_i16(v), + Value::Int(v) => visitor.visit_i32(v), + Value::Long(v) => visitor.visit_i64(v), + Value::Float(v) => visitor.visit_f32(v), + Value::Double(v) => visitor.visit_f64(v), + Value::ByteArray(v) => visitor.visit_byte_buf(i8_vec_to_u8_vec(v)), + Value::String(v) => visitor.visit_string(v), + Value::List(v) => v.deserialize_any(visitor), + Value::Compound(v) => v.into_deserializer().deserialize_any(visitor), + Value::IntArray(v) => v.into_deserializer().deserialize_any(visitor), + Value::LongArray(v) => v.into_deserializer().deserialize_any(visitor), + } + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self { + Value::Byte(b) => visitor.visit_bool(b != 0), + _ => self.deserialize_any(visitor), + } + } + + forward_to_deserialize_any! { + i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum identifier ignored_any + } +} + +impl<'de> IntoDeserializer<'de, Error> for Value { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'de> Deserializer<'de> for List { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + struct EndSeqAccess; + + impl<'de> SeqAccess<'de> for EndSeqAccess { + type Error = Error; + + fn next_element_seed(&mut self, _seed: T) -> Result, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + Ok(None) + } + } + + match self { + List::End => visitor.visit_seq(EndSeqAccess), + List::Byte(v) => visitor.visit_byte_buf(i8_vec_to_u8_vec(v)), + List::Short(v) => v.into_deserializer().deserialize_any(visitor), + List::Int(v) => v.into_deserializer().deserialize_any(visitor), + List::Long(v) => v.into_deserializer().deserialize_any(visitor), + List::Float(v) => v.into_deserializer().deserialize_any(visitor), + List::Double(v) => v.into_deserializer().deserialize_any(visitor), + List::ByteArray(v) => v.into_deserializer().deserialize_any(visitor), + List::String(v) => v.into_deserializer().deserialize_any(visitor), + List::List(v) => v.into_deserializer().deserialize_any(visitor), + List::Compound(v) => v.into_deserializer().deserialize_any(visitor), + List::IntArray(v) => v.into_deserializer().deserialize_any(visitor), + List::LongArray(v) => v.into_deserializer().deserialize_any(visitor), + } + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum identifier ignored_any + } +} + +impl<'de> IntoDeserializer<'de, Error> for List { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} diff --git a/crates/valence_nbt/src/serde/ser.rs b/crates/valence_nbt/src/serde/ser.rs new file mode 100644 index 0000000..941beff --- /dev/null +++ b/crates/valence_nbt/src/serde/ser.rs @@ -0,0 +1,625 @@ +use core::slice; +use std::marker::PhantomData; + +use serde::ser::{Impossible, SerializeMap, SerializeSeq, SerializeStruct}; +use serde::{Serialize, Serializer}; + +use super::{u8_vec_to_i8_vec, Error}; +use crate::{Compound, List, Value}; + +impl Serialize for Value { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Value::Byte(v) => serializer.serialize_i8(*v), + Value::Short(v) => serializer.serialize_i16(*v), + Value::Int(v) => serializer.serialize_i32(*v), + Value::Long(v) => serializer.serialize_i64(*v), + Value::Float(v) => serializer.serialize_f32(*v), + Value::Double(v) => serializer.serialize_f64(*v), + Value::ByteArray(v) => { + // SAFETY: i8 has the same layout as u8. + let bytes = unsafe { slice::from_raw_parts(v.as_ptr() as *const u8, v.len()) }; + + serializer.serialize_bytes(bytes) + } + Value::String(v) => serializer.serialize_str(v), + Value::List(v) => v.serialize(serializer), + Value::Compound(v) => v.serialize(serializer), + Value::IntArray(v) => v.serialize(serializer), + Value::LongArray(v) => v.serialize(serializer), + } + } +} + +impl Serialize for List { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + List::End => serializer.serialize_seq(Some(0))?.end(), + List::Byte(v) => v.serialize(serializer), + List::Short(v) => v.serialize(serializer), + List::Int(v) => v.serialize(serializer), + List::Long(v) => v.serialize(serializer), + List::Float(v) => v.serialize(serializer), + List::Double(v) => v.serialize(serializer), + List::ByteArray(v) => v.serialize(serializer), + List::String(v) => v.serialize(serializer), + List::List(v) => v.serialize(serializer), + List::Compound(v) => v.serialize(serializer), + List::IntArray(v) => v.serialize(serializer), + List::LongArray(v) => v.serialize(serializer), + } + } +} + +macro_rules! unsupported { + ($lit:literal) => { + Err(Error::new(concat!("unsupported type: ", $lit))) + }; +} + +/// [`Serializer`] whose output is [`Compound`]. +pub struct CompoundSerializer; + +impl Serializer for CompoundSerializer { + type Ok = Compound; + + type Error = Error; + + type SerializeSeq = Impossible; + + type SerializeTuple = Impossible; + + type SerializeTupleStruct = Impossible; + + type SerializeTupleVariant = Impossible; + + type SerializeMap = GenericSerializeMap; + + type SerializeStruct = GenericSerializeStruct; + + type SerializeStructVariant = Impossible; + + fn serialize_bool(self, _v: bool) -> Result { + unsupported!("bool") + } + + fn serialize_i8(self, _v: i8) -> Result { + unsupported!("i8") + } + + fn serialize_i16(self, _v: i16) -> Result { + unsupported!("i16") + } + + fn serialize_i32(self, _v: i32) -> Result { + unsupported!("i32") + } + + fn serialize_i64(self, _v: i64) -> Result { + unsupported!("i64") + } + + fn serialize_u8(self, _v: u8) -> Result { + unsupported!("u8") + } + + fn serialize_u16(self, _v: u16) -> Result { + unsupported!("u16") + } + + fn serialize_u32(self, _v: u32) -> Result { + unsupported!("u32") + } + + fn serialize_u64(self, _v: u64) -> Result { + unsupported!("u64") + } + + fn serialize_f32(self, _v: f32) -> Result { + unsupported!("f32") + } + + fn serialize_f64(self, _v: f64) -> Result { + unsupported!("f64") + } + + fn serialize_char(self, _v: char) -> Result { + unsupported!("char") + } + + fn serialize_str(self, _v: &str) -> Result { + unsupported!("str") + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + unsupported!("bytes") + } + + fn serialize_none(self) -> Result { + unsupported!("none") + } + + fn serialize_some(self, _value: &T) -> Result + where + T: Serialize, + { + unsupported!("some") + } + + fn serialize_unit(self) -> Result { + unsupported!("unit") + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + unsupported!("unit struct") + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + unsupported!("unit variant") + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + unsupported!("newtype struct") + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + unsupported!("newtype variant") + } + + fn serialize_seq(self, _len: Option) -> Result { + unsupported!("seq") + } + + fn serialize_tuple(self, _len: usize) -> Result { + unsupported!("tuple") + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unsupported!("tuple struct") + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported!("tuple variant") + } + + fn serialize_map(self, len: Option) -> Result { + Ok(GenericSerializeMap::new(len)) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + Ok(GenericSerializeStruct::new(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported!("struct variant") + } +} + +/// [`Serializer`] whose output is [`Value`]. +struct ValueSerializer; + +impl Serializer for ValueSerializer { + type Ok = Value; + + type Error = Error; + + type SerializeSeq = ValueSerializeSeq; + + type SerializeTuple = Impossible; + + type SerializeTupleStruct = Impossible; + + type SerializeTupleVariant = Impossible; + + type SerializeMap = GenericSerializeMap; + + type SerializeStruct = GenericSerializeStruct; + + type SerializeStructVariant = Impossible; + + fn serialize_bool(self, v: bool) -> Result { + Ok(Value::Byte(v as _)) + } + + fn serialize_i8(self, v: i8) -> Result { + Ok(Value::Byte(v)) + } + + fn serialize_i16(self, v: i16) -> Result { + Ok(Value::Short(v)) + } + + fn serialize_i32(self, v: i32) -> Result { + Ok(Value::Int(v)) + } + + fn serialize_i64(self, v: i64) -> Result { + Ok(Value::Long(v)) + } + + fn serialize_u8(self, v: u8) -> Result { + Ok(Value::Byte(v as _)) + } + + fn serialize_u16(self, v: u16) -> Result { + Ok(Value::Short(v as _)) + } + + fn serialize_u32(self, v: u32) -> Result { + Ok(Value::Int(v as _)) + } + + fn serialize_u64(self, v: u64) -> Result { + Ok(Value::Long(v as _)) + } + + fn serialize_f32(self, v: f32) -> Result { + Ok(Value::Float(v)) + } + + fn serialize_f64(self, v: f64) -> Result { + Ok(Value::Double(v)) + } + + fn serialize_char(self, v: char) -> Result { + Ok(Value::String(v.into())) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(Value::String(v.into())) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + Ok(Value::ByteArray(u8_vec_to_i8_vec(v.into()))) + } + + fn serialize_none(self) -> Result { + unsupported!("none") + } + + fn serialize_some(self, _value: &T) -> Result + where + T: Serialize, + { + unsupported!("some") + } + + fn serialize_unit(self) -> Result { + unsupported!("unit") + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + unsupported!("unit variant") + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + unsupported!("newtype variant") + } + + fn serialize_seq(self, len: Option) -> Result { + Ok(ValueSerializeSeq::End { + len: len.unwrap_or(0), + }) + } + + fn serialize_tuple(self, _len: usize) -> Result { + unsupported!("tuple") + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unsupported!("tuple struct") + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported!("tuple variant") + } + + fn serialize_map(self, len: Option) -> Result { + Ok(GenericSerializeMap::new(len)) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + Ok(GenericSerializeStruct::new(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unsupported!("struct variant") + } +} + +enum ValueSerializeSeq { + End { len: usize }, + Byte(Vec), + Short(Vec), + Int(Vec), + Long(Vec), + Float(Vec), + Double(Vec), + ByteArray(Vec>), + String(Vec), + List(Vec), + Compound(Vec), + IntArray(Vec>), + LongArray(Vec>), +} + +impl SerializeSeq for ValueSerializeSeq { + type Ok = Value; + + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + macro_rules! serialize_variant { + ($variant:ident, $vec:ident, $elem:ident) => {{ + match $elem.serialize(ValueSerializer)? { + Value::$variant(val) => { + $vec.push(val); + Ok(()) + } + _ => Err(Error::new(concat!( + "heterogeneous NBT list (expected `", + stringify!($variant), + "` element)" + ))), + } + }}; + } + + match self { + Self::End { len } => { + fn vec(elem: T, len: usize) -> Vec { + let mut vec = Vec::with_capacity(len); + vec.push(elem); + vec + } + + // Set the first element of the list. + *self = match value.serialize(ValueSerializer)? { + Value::Byte(v) => Self::Byte(vec(v, *len)), + Value::Short(v) => Self::Short(vec(v, *len)), + Value::Int(v) => Self::Int(vec(v, *len)), + Value::Long(v) => Self::Long(vec(v, *len)), + Value::Float(v) => Self::Float(vec(v, *len)), + Value::Double(v) => Self::Double(vec(v, *len)), + Value::ByteArray(v) => Self::ByteArray(vec(v, *len)), + Value::String(v) => Self::String(vec(v, *len)), + Value::List(v) => Self::List(vec(v, *len)), + Value::Compound(v) => Self::Compound(vec(v, *len)), + Value::IntArray(v) => Self::IntArray(vec(v, *len)), + Value::LongArray(v) => Self::LongArray(vec(v, *len)), + }; + Ok(()) + } + Self::Byte(v) => serialize_variant!(Byte, v, value), + Self::Short(v) => serialize_variant!(Short, v, value), + Self::Int(v) => serialize_variant!(Int, v, value), + Self::Long(v) => serialize_variant!(Long, v, value), + Self::Float(v) => serialize_variant!(Float, v, value), + Self::Double(v) => serialize_variant!(Double, v, value), + Self::ByteArray(v) => serialize_variant!(ByteArray, v, value), + Self::String(v) => serialize_variant!(String, v, value), + Self::List(v) => serialize_variant!(List, v, value), + Self::Compound(v) => serialize_variant!(Compound, v, value), + Self::IntArray(v) => serialize_variant!(IntArray, v, value), + Self::LongArray(v) => serialize_variant!(LongArray, v, value), + } + } + + fn end(self) -> Result { + Ok(match self { + Self::End { .. } => List::End.into(), + Self::Byte(v) => v.into(), + Self::Short(v) => List::Short(v).into(), + Self::Int(v) => v.into(), + Self::Long(v) => List::Long(v).into(), + Self::Float(v) => List::Float(v).into(), + Self::Double(v) => List::Double(v).into(), + Self::ByteArray(v) => List::ByteArray(v).into(), + Self::String(v) => List::String(v).into(), + Self::List(v) => List::List(v).into(), + Self::Compound(v) => List::Compound(v).into(), + Self::IntArray(v) => List::IntArray(v).into(), + Self::LongArray(v) => List::LongArray(v).into(), + }) + } +} + +#[doc(hidden)] +pub struct GenericSerializeMap { + /// Temp storage for `serialize_key`. + key: Option, + res: Compound, + _marker: PhantomData, +} + +impl GenericSerializeMap { + pub fn new(len: Option) -> Self { + Self { + key: None, + res: Compound::with_capacity(len.unwrap_or(0)), + _marker: PhantomData, + } + } +} + +impl SerializeMap for GenericSerializeMap +where + Compound: Into, +{ + type Ok = Ok; + + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + debug_assert!( + self.key.is_none(), + "call to `serialize_key` must be followed by `serialize_value`" + ); + + match key.serialize(ValueSerializer)? { + Value::String(s) => { + self.key = Some(s); + Ok(()) + } + _ => Err(Error::new("invalid map key type (expected string)")), + } + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + let key = self + .key + .take() + .expect("missing previous call to `serialize_key`"); + self.res.insert(key, value.serialize(ValueSerializer)?); + Ok(()) + } + + fn end(self) -> Result { + Ok(self.res.into()) + } +} + +#[doc(hidden)] +pub struct GenericSerializeStruct { + c: Compound, + _marker: PhantomData, +} + +impl GenericSerializeStruct { + fn new(len: usize) -> Self { + Self { + c: Compound::with_capacity(len), + _marker: PhantomData, + } + } +} + +impl SerializeStruct for GenericSerializeStruct +where + Compound: Into, +{ + type Ok = Ok; + + type Error = Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> + where + T: Serialize, + { + self.c.insert(key, value.serialize(ValueSerializer)?); + Ok(()) + } + + fn end(self) -> Result { + Ok(self.c.into()) + } +} diff --git a/crates/valence_registry/src/lib.rs b/crates/valence_registry/src/lib.rs index 3fa4205..0d9bdaa 100644 --- a/crates/valence_registry/src/lib.rs +++ b/crates/valence_registry/src/lib.rs @@ -113,7 +113,7 @@ impl RegistryCodec { impl Default for RegistryCodec { fn default() -> Self { let codec = include_bytes!("../../../extracted/registry_codec_1.19.4.dat"); - let compound = valence_nbt::from_binary_slice(&mut codec.as_slice()) + let compound = Compound::from_binary(&mut codec.as_slice()) .expect("failed to decode vanilla registry codec") .0;