Add serde_nbt

This commit is contained in:
Ryan 2022-08-24 18:17:51 -07:00
parent 793a795732
commit 7568ff8b4b
9 changed files with 1976 additions and 1 deletions

View file

@ -67,4 +67,4 @@ num = "0.4"
protocol = [] protocol = []
[workspace] [workspace]
members = ["packet-inspector"] members = ["packet-inspector", "serde_nbt"]

15
serde_nbt/Cargo.toml Normal file
View file

@ -0,0 +1,15 @@
[package]
name = "serde_nbt"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1"
byteorder = "1.4.3"
cesu8 = "1.1.0"
serde = "1"
smallvec = { version = "1.9.0", features = ["union", "const_generics"] }
[dev-dependencies]
ordered-float = { version = "3.0.0", features = ["serde"] }
hematite-nbt = "0.5"

29
serde_nbt/src/array.rs Normal file
View file

@ -0,0 +1,29 @@
use serde::{Serialize, Serializer};
use serde::ser::SerializeTupleStruct;
use super::{BYTE_ARRAY_MAGIC, INT_ARRAY_MAGIC, LONG_ARRAY_MAGIC};
macro_rules! def {
($name:ident, $magic:ident) => {
pub fn $name<T, S>(array: T, serializer: S) -> Result<S::Ok, S::Error>
where
T: IntoIterator,
T::IntoIter: ExactSizeIterator,
T::Item: Serialize,
S: Serializer,
{
let it = array.into_iter();
let mut sts = serializer.serialize_tuple_struct($magic, it.len())?;
for item in it {
sts.serialize_field(&item)?;
}
sts.end()
}
}
}
def!(byte_array, BYTE_ARRAY_MAGIC);
def!(int_array, INT_ARRAY_MAGIC);
def!(long_array, LONG_ARRAY_MAGIC);

363
serde_nbt/src/binary/de.rs Normal file
View file

@ -0,0 +1,363 @@
// TODO: recursion limit.
// TODO: serialize and deserialize recursion limit wrappers. (crate:
// serde_reclimit).
use std::borrow::Cow;
use std::io::Read;
use anyhow::anyhow;
use byteorder::{BigEndian, ReadBytesExt};
use cesu8::from_java_cesu8;
use serde::de::{DeserializeOwned, DeserializeSeed, Visitor};
use serde::{de, forward_to_deserialize_any};
use smallvec::SmallVec;
use crate::{Error, Result, Tag};
pub fn from_reader<R, T>(reader: R) -> Result<T>
where
R: Read,
T: DeserializeOwned,
{
T::deserialize(&mut Deserializer::new(reader, false))
}
pub struct Deserializer<R> {
reader: R,
root_name: Option<String>,
}
impl<R: Read> Deserializer<R> {
pub fn new(reader: R, save_root_name: bool) -> Self {
Self {
reader,
root_name: if save_root_name {
Some(String::new())
} else {
None
},
}
}
pub fn into_inner(self) -> (R, Option<String>) {
(self.reader, self.root_name)
}
fn read_header(&mut self) -> Result<Tag> {
let tag = Tag::from_u8(self.reader.read_u8()?)?;
if tag != Tag::Compound {
return Err(Error(anyhow!(
"unexpected tag `{tag}` (root value must be a compound)"
)));
}
if let Some(name) = &mut self.root_name {
let mut buf = SmallVec::<[u8; 128]>::new();
for _ in 0..self.reader.read_u16::<BigEndian>()? {
buf.push(self.reader.read_u8()?);
}
*name = from_java_cesu8(&buf)
.map_err(|e| Error(anyhow!(e)))?
.into_owned();
} else {
for _ in 0..self.reader.read_u16::<BigEndian>()? {
self.reader.read_u8()?;
}
}
Ok(tag)
}
}
impl<'de: 'a, 'a, R: Read + 'de> de::Deserializer<'de> for &'a mut Deserializer<R> {
type Error = Error;
forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct seq tuple tuple_struct map
enum identifier ignored_any
}
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let tag = self.read_header()?;
PayloadDeserializer {
reader: &mut self.reader,
tag,
}
.deserialize_any(visitor)
}
fn deserialize_struct<V>(
self,
name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
let tag = self.read_header()?;
PayloadDeserializer {
reader: &mut self.reader,
tag,
}
.deserialize_struct(name, fields, visitor)
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn is_human_readable(&self) -> bool {
false
}
}
struct PayloadDeserializer<'a, R> {
reader: &'a mut R,
/// The type of payload to be deserialized.
tag: Tag,
}
impl<'de: 'a, 'a, R: Read> de::Deserializer<'de> for PayloadDeserializer<'a, R> {
type Error = Error;
forward_to_deserialize_any! {
i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option seq tuple tuple_struct map enum identifier
ignored_any
}
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
match self.tag {
Tag::End => unreachable!("invalid payload tag"),
Tag::Byte => visitor.visit_i8(self.reader.read_i8()?),
Tag::Short => visitor.visit_i16(self.reader.read_i16::<BigEndian>()?),
Tag::Int => visitor.visit_i32(self.reader.read_i32::<BigEndian>()?),
Tag::Long => visitor.visit_i64(self.reader.read_i64::<BigEndian>()?),
Tag::Float => visitor.visit_f32(self.reader.read_f32::<BigEndian>()?),
Tag::Double => visitor.visit_f64(self.reader.read_f64::<BigEndian>()?),
Tag::ByteArray => {
let len = self.reader.read_i32::<BigEndian>()?;
visitor.visit_seq(SeqAccess::new(self.reader, Tag::Byte, len)?)
}
Tag::String => {
let mut buf = SmallVec::<[u8; 128]>::new();
for _ in 0..self.reader.read_u16::<BigEndian>()? {
buf.push(self.reader.read_u8()?);
}
match from_java_cesu8(&buf).map_err(|e| Error(anyhow!(e)))? {
Cow::Borrowed(s) => visitor.visit_str(s),
Cow::Owned(string) => visitor.visit_string(string),
}
}
Tag::List => {
let element_type = Tag::from_u8(self.reader.read_u8()?)?;
let len = self.reader.read_i32::<BigEndian>()?;
visitor.visit_seq(SeqAccess::new(self.reader, element_type, len)?)
}
Tag::Compound => visitor.visit_map(MapAccess::new(self.reader, &[])),
Tag::IntArray => {
let len = self.reader.read_i32::<BigEndian>()?;
visitor.visit_seq(SeqAccess::new(self.reader, Tag::Int, len)?)
}
Tag::LongArray => {
let len = self.reader.read_i32::<BigEndian>()?;
visitor.visit_seq(SeqAccess::new(self.reader, Tag::Long, len)?)
}
}
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
if self.tag == Tag::Byte {
match self.reader.read_i8()? {
0 => visitor.visit_bool(false),
1 => visitor.visit_bool(true),
n => visitor.visit_i8(n),
}
} else {
self.deserialize_any(visitor)
}
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_struct<V>(
self,
_name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
if self.tag == Tag::Compound {
visitor.visit_map(MapAccess::new(self.reader, fields))
} else {
self.deserialize_any(visitor)
}
}
fn is_human_readable(&self) -> bool {
false
}
}
#[doc(hidden)]
pub struct SeqAccess<'a, R> {
reader: &'a mut R,
element_type: Tag,
remaining: u32,
}
impl<'a, R: Read> SeqAccess<'a, R> {
fn new(reader: &'a mut R, element_type: Tag, len: i32) -> Result<Self> {
if len < 0 {
return Err(Error(anyhow!("list with negative length")));
}
if element_type == Tag::End && len != 0 {
return Err(Error(anyhow!(
"list with TAG_End element type must have length zero"
)));
}
Ok(Self {
reader,
element_type,
remaining: len as u32,
})
}
// TODO: function to check if this is for an array or list.
}
impl<'de: 'a, 'a, R: Read> de::SeqAccess<'de> for SeqAccess<'a, R> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: DeserializeSeed<'de>,
{
if self.remaining > 0 {
self.remaining -= 1;
seed.deserialize(PayloadDeserializer {
reader: self.reader,
tag: self.element_type,
})
.map(Some)
} else {
Ok(None)
}
}
fn size_hint(&self) -> Option<usize> {
Some(self.remaining as usize)
}
}
#[doc(hidden)]
pub struct MapAccess<'a, R> {
reader: &'a mut R,
value_tag: Tag,
/// Provides error context when deserializing structs.
fields: &'static [&'static str],
}
impl<'a, R: Read> MapAccess<'a, R> {
fn new(reader: &'a mut R, fields: &'static [&'static str]) -> Self {
Self {
reader,
value_tag: Tag::End,
fields,
}
}
}
impl<'de: 'a, 'a, R: Read> de::MapAccess<'de> for MapAccess<'a, R> {
type Error = Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
where
K: DeserializeSeed<'de>,
{
self.value_tag = Tag::from_u8(self.reader.read_u8()?)?;
if self.value_tag == Tag::End {
return Ok(None);
}
seed.deserialize(PayloadDeserializer {
reader: self.reader,
tag: Tag::String,
})
.map(Some)
.map_err(|e| match self.fields {
[f, ..] => e.context(anyhow!("compound key (field `{f}`)")),
[] => e,
})
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
where
V: DeserializeSeed<'de>,
{
if self.value_tag == Tag::End {
return Err(Error(anyhow!("end of compound?")));
}
let field = match self.fields {
[field, rest @ ..] => {
self.fields = rest;
Some(*field)
}
[] => None,
};
seed.deserialize(PayloadDeserializer {
reader: self.reader,
tag: self.value_tag,
})
.map_err(|e| match field {
Some(f) => e.context(anyhow!("compound value (field `{f}`)")),
None => e,
})
}
}

733
serde_nbt/src/binary/ser.rs Normal file
View file

@ -0,0 +1,733 @@
use std::io::Write;
use std::result::Result as StdResult;
use anyhow::anyhow;
use byteorder::{BigEndian, WriteBytesExt};
use cesu8::to_java_cesu8;
use serde::{ser, Serialize};
use crate::{Error, Result, Tag};
pub fn to_writer<W, T>(mut writer: W, root_name: &str, value: &T) -> Result<()>
where
W: Write,
T: Serialize + ?Sized,
{
value.serialize(&mut Serializer::new(&mut writer, root_name))
}
pub struct Serializer<'w, 'n, W: ?Sized> {
writer: &'w mut W,
allowed_tag: AllowedTag,
ser_state: SerState<'n>,
}
#[derive(Copy, Clone)]
enum AllowedTag {
/// Any tag type is permitted to be serialized.
Any {
/// Set to the type that was serialized. Is unspecified if serialization
/// failed or has not taken place yet.
written_tag: Tag,
},
/// Only one specific tag type is permitted to be serialized.
One {
/// The permitted tag.
tag: Tag,
/// The error message if a tag mismatch happens.
errmsg: &'static str,
},
}
enum SerState<'n> {
/// Serialize just the payload and nothing else.
PayloadOnly,
/// Prefix the payload with the tag and a length.
/// Used for the first element of lists.
FirstListElement {
/// Length of the list being serialized.
len: i32,
},
/// Prefix the payload with the tag and a name.
/// Used for compound fields and the root compound.
Named { name: &'n str },
}
impl<'w, 'n, W: Write + ?Sized> Serializer<'w, 'n, W> {
pub fn new(writer: &'w mut W, root_name: &'n str) -> Self {
Self {
writer,
allowed_tag: AllowedTag::One {
tag: Tag::Compound,
errmsg: "root value must be a compound",
},
ser_state: SerState::Named { name: root_name },
}
}
pub fn writer(&mut self) -> &mut W {
self.writer
}
pub fn root_name(&self) -> &'n str {
match &self.ser_state {
SerState::Named { name } => *name,
_ => unreachable!(),
}
}
pub fn set_root_name(&mut self, root_name: &'n str) {
self.ser_state = SerState::Named { name: root_name };
}
fn write_header(&mut self, tag: Tag) -> Result<()> {
match &mut self.allowed_tag {
AllowedTag::Any { written_tag } => *written_tag = tag,
AllowedTag::One {
tag: expected_tag,
errmsg,
} => {
if tag != *expected_tag {
let e = anyhow!(*errmsg).context(format!(
"attempt to serialize {tag} where {expected_tag} was expected"
));
return Err(Error(e));
}
}
}
match &mut self.ser_state {
SerState::PayloadOnly => {}
SerState::FirstListElement { len } => {
self.writer.write_u8(tag as u8)?;
self.writer.write_i32::<BigEndian>(*len)?;
}
SerState::Named { name } => {
self.writer.write_u8(tag as u8)?;
write_string_payload(*name, self.writer)?;
}
}
Ok(())
}
}
type Impossible = ser::Impossible<(), Error>;
#[inline]
fn unsupported<T>(typ: &str) -> Result<T> {
Err(Error(anyhow!("{typ} is not supported")))
}
impl<'a, W: Write + ?Sized> ser::Serializer for &'a mut Serializer<'_, '_, W> {
type Error = Error;
type Ok = ();
type SerializeMap = SerializeMap<'a, W>;
type SerializeSeq = SerializeSeq<'a, W>;
type SerializeStruct = SerializeStruct<'a, W>;
type SerializeStructVariant = Impossible;
type SerializeTuple = Impossible;
type SerializeTupleStruct = SerializeArray<'a, W>;
type SerializeTupleVariant = Impossible;
fn serialize_bool(self, v: bool) -> Result<()> {
self.write_header(Tag::Byte)?;
Ok(self.writer.write_i8(v as i8)?)
}
fn serialize_i8(self, v: i8) -> Result<()> {
self.write_header(Tag::Byte)?;
Ok(self.writer.write_i8(v)?)
}
fn serialize_i16(self, v: i16) -> Result<()> {
self.write_header(Tag::Short)?;
Ok(self.writer.write_i16::<BigEndian>(v)?)
}
fn serialize_i32(self, v: i32) -> Result<()> {
self.write_header(Tag::Int)?;
Ok(self.writer.write_i32::<BigEndian>(v)?)
}
fn serialize_i64(self, v: i64) -> Result<()> {
self.write_header(Tag::Long)?;
Ok(self.writer.write_i64::<BigEndian>(v)?)
}
fn serialize_u8(self, _v: u8) -> Result<()> {
unsupported("u8")
}
fn serialize_u16(self, _v: u16) -> Result<()> {
unsupported("u16")
}
fn serialize_u32(self, _v: u32) -> Result<()> {
unsupported("u32")
}
fn serialize_u64(self, _v: u64) -> Result<()> {
unsupported("u64")
}
fn serialize_f32(self, v: f32) -> Result<()> {
self.write_header(Tag::Float)?;
Ok(self.writer.write_f32::<BigEndian>(v)?)
}
fn serialize_f64(self, v: f64) -> Result<()> {
self.write_header(Tag::Double)?;
Ok(self.writer.write_f64::<BigEndian>(v)?)
}
fn serialize_char(self, _v: char) -> Result<()> {
unsupported("char")
}
fn serialize_str(self, v: &str) -> Result<()> {
self.write_header(Tag::String)?;
write_string_payload(v, self.writer)
}
fn serialize_bytes(self, _v: &[u8]) -> Result<()> {
unsupported("&[u8]")
}
fn serialize_none(self) -> Result<()> {
unsupported("Option<T>")
}
fn serialize_some<T: ?Sized>(self, _value: &T) -> Result<()>
where
T: Serialize,
{
unsupported("Option<T>")
}
fn serialize_unit(self) -> Result<()> {
Ok(())
}
fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
Ok(())
}
fn serialize_unit_variant(
self,
_name: &'static str,
variant_index: u32,
_variant: &'static str,
) -> Result<()> {
match variant_index.try_into() {
Ok(idx) => self.serialize_i32(idx),
Err(_) => Err(Error(anyhow!(
"variant index of {variant_index} is out of range"
))),
}
}
fn serialize_newtype_struct<T: ?Sized>(self, _name: &'static str, value: &T) -> Result<()>
where
T: Serialize,
{
value.serialize(self)
}
fn serialize_newtype_variant<T: ?Sized>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result<()>
where
T: Serialize,
{
unsupported("newtype variant")
}
fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
self.write_header(Tag::List)?;
let len = match len {
Some(len) => len,
None => return Err(Error(anyhow!("list length must be known up front"))),
};
match len.try_into() {
Ok(len) => Ok(SerializeSeq {
writer: self.writer,
element_type: Tag::End,
remaining: len,
}),
Err(_) => Err(Error(anyhow!("length of list exceeds i32::MAX"))),
}
}
fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
unsupported("tuple")
}
fn serialize_tuple_struct(
self,
name: &'static str,
len: usize,
) -> Result<Self::SerializeTupleStruct> {
let element_type = match name {
crate::BYTE_ARRAY_MAGIC => {
self.write_header(Tag::ByteArray)?;
Tag::Byte
}
crate::INT_ARRAY_MAGIC => {
self.write_header(Tag::IntArray)?;
Tag::Int
}
crate::LONG_ARRAY_MAGIC => {
self.write_header(Tag::LongArray)?;
Tag::Long
}
_ => return unsupported("tuple struct"),
};
match len.try_into() {
Ok(len) => {
self.writer.write_i32::<BigEndian>(len)?;
Ok(SerializeArray {
writer: self.writer,
element_type,
remaining: len,
})
}
Err(_) => Err(Error(anyhow!("array length of {len} exceeds i32::MAX"))),
}
}
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant> {
unsupported("tuple variant")
}
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
self.write_header(Tag::Compound)?;
Ok(SerializeMap {
writer: self.writer,
})
}
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
self.write_header(Tag::Compound)?;
Ok(SerializeStruct {
writer: self.writer,
})
}
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
unsupported("struct variant")
}
fn is_human_readable(&self) -> bool {
false
}
}
#[doc(hidden)]
pub struct SerializeSeq<'w, W: ?Sized> {
writer: &'w mut W,
/// The element type of this list. TAG_End if unknown.
element_type: Tag,
/// Number of elements left to serialize.
remaining: i32,
}
impl<W: Write + ?Sized> ser::SerializeSeq for SerializeSeq<'_, W> {
type Error = Error;
type Ok = ();
fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<()>
where
T: Serialize,
{
if self.remaining <= 0 {
return Err(Error(anyhow!(
"attempt to serialize more list elements than specified"
)));
}
if self.element_type == Tag::End {
let mut ser = Serializer {
writer: self.writer,
allowed_tag: AllowedTag::Any {
written_tag: Tag::End,
},
ser_state: SerState::FirstListElement {
len: self.remaining,
},
};
value.serialize(&mut ser)?;
self.element_type = match ser.allowed_tag {
AllowedTag::Any { written_tag } => written_tag,
AllowedTag::One { .. } => unreachable!(),
};
} else {
value.serialize(&mut Serializer {
writer: self.writer,
allowed_tag: AllowedTag::One {
tag: self.element_type,
errmsg: "list elements must be homogeneous",
},
ser_state: SerState::PayloadOnly,
})?;
}
self.remaining -= 1;
Ok(())
}
fn end(self) -> Result<()> {
if self.remaining > 0 {
return Err(Error(anyhow!(
"{} list element(s) left to serialize",
self.remaining
)));
}
// Were any elements written?
if self.element_type == Tag::End {
self.writer.write_u8(Tag::End as u8)?;
// List length.
self.writer.write_i32::<BigEndian>(0)?;
}
Ok(())
}
}
#[doc(hidden)]
pub struct SerializeArray<'w, W: ?Sized> {
writer: &'w mut W,
element_type: Tag,
remaining: i32,
}
impl<W: Write + ?Sized> ser::SerializeTupleStruct for SerializeArray<'_, W> {
type Error = Error;
type Ok = ();
fn serialize_field<T: ?Sized>(&mut self, value: &T) -> Result<()>
where
T: Serialize,
{
if self.remaining <= 0 {
return Err(Error(anyhow!(
"attempt to serialize more array elements than specified"
)));
}
value.serialize(&mut Serializer {
writer: self.writer,
allowed_tag: AllowedTag::One {
tag: self.element_type,
errmsg: "mismatched array element type",
},
ser_state: SerState::PayloadOnly,
})?;
self.remaining -= 1;
Ok(())
}
fn end(self) -> Result<()> {
if self.remaining > 0 {
return Err(Error(anyhow!(
"{} array element(s) left to serialize",
self.remaining
)));
}
Ok(())
}
}
#[doc(hidden)]
pub struct SerializeMap<'w, W: ?Sized> {
writer: &'w mut W,
}
impl<W: Write + ?Sized> ser::SerializeMap for SerializeMap<'_, W> {
type Error = Error;
type Ok = ();
fn serialize_key<T: ?Sized>(&mut self, _key: &T) -> Result<()>
where
T: Serialize,
{
Err(Error(anyhow!("map keys cannot be serialized individually")))
}
fn serialize_value<T: ?Sized>(&mut self, _value: &T) -> Result<()>
where
T: Serialize,
{
Err(Error(anyhow!(
"map values cannot be serialized individually"
)))
}
fn serialize_entry<K: ?Sized, V: ?Sized>(&mut self, key: &K, value: &V) -> Result<()>
where
K: Serialize,
V: Serialize,
{
key.serialize(MapEntrySerializer {
writer: self.writer,
value,
})
}
fn end(self) -> Result<()> {
Ok(self.writer.write_u8(Tag::End as u8)?)
}
}
struct MapEntrySerializer<'w, 'v, W: ?Sized, V: ?Sized> {
writer: &'w mut W,
value: &'v V,
}
fn key_not_a_string<T>() -> Result<T> {
Err(Error(anyhow!("map keys must be strings")))
}
impl<W, V> ser::Serializer for MapEntrySerializer<'_, '_, W, V>
where
W: Write + ?Sized,
V: Serialize + ?Sized,
{
type Error = Error;
type Ok = ();
type SerializeMap = Impossible;
type SerializeSeq = Impossible;
type SerializeStruct = Impossible;
type SerializeStructVariant = Impossible;
type SerializeTuple = Impossible;
type SerializeTupleStruct = Impossible;
type SerializeTupleVariant = Impossible;
fn serialize_bool(self, _v: bool) -> Result<()> {
key_not_a_string()
}
fn serialize_i8(self, _v: i8) -> Result<()> {
key_not_a_string()
}
fn serialize_i16(self, _v: i16) -> Result<()> {
key_not_a_string()
}
fn serialize_i32(self, _v: i32) -> Result<()> {
key_not_a_string()
}
fn serialize_i64(self, _v: i64) -> Result<()> {
key_not_a_string()
}
fn serialize_u8(self, _v: u8) -> Result<()> {
key_not_a_string()
}
fn serialize_u16(self, _v: u16) -> Result<()> {
key_not_a_string()
}
fn serialize_u32(self, _v: u32) -> Result<()> {
key_not_a_string()
}
fn serialize_u64(self, _v: u64) -> Result<()> {
key_not_a_string()
}
fn serialize_f32(self, _v: f32) -> Result<()> {
key_not_a_string()
}
fn serialize_f64(self, _v: f64) -> Result<()> {
key_not_a_string()
}
fn serialize_char(self, _v: char) -> Result<()> {
key_not_a_string()
}
fn serialize_str(self, v: &str) -> Result<()> {
self.value
.serialize(&mut Serializer {
writer: self.writer,
allowed_tag: AllowedTag::Any {
written_tag: Tag::End,
},
ser_state: SerState::Named { name: v },
})
.map_err(|e| e.context(format!("map key `{v}`")))
}
fn serialize_bytes(self, _v: &[u8]) -> Result<()> {
key_not_a_string()
}
fn serialize_none(self) -> Result<()> {
key_not_a_string()
}
fn serialize_some<T: ?Sized>(self, _value: &T) -> Result<()>
where
T: Serialize,
{
key_not_a_string()
}
fn serialize_unit(self) -> Result<()> {
key_not_a_string()
}
fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
key_not_a_string()
}
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
) -> Result<()> {
key_not_a_string()
}
fn serialize_newtype_struct<T: ?Sized>(self, _name: &'static str, _value: &T) -> Result<()>
where
T: Serialize,
{
key_not_a_string()
}
fn serialize_newtype_variant<T: ?Sized>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result<()>
where
T: Serialize,
{
key_not_a_string()
}
fn serialize_seq(self, _len: Option<usize>) -> StdResult<Self::SerializeSeq, Self::Error> {
key_not_a_string()
}
fn serialize_tuple(self, _len: usize) -> StdResult<Self::SerializeTuple, Self::Error> {
key_not_a_string()
}
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> StdResult<Self::SerializeTupleStruct, Self::Error> {
key_not_a_string()
}
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> StdResult<Self::SerializeTupleVariant, Self::Error> {
key_not_a_string()
}
fn serialize_map(self, _len: Option<usize>) -> StdResult<Self::SerializeMap, Self::Error> {
key_not_a_string()
}
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> StdResult<Self::SerializeStruct, Self::Error> {
key_not_a_string()
}
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> StdResult<Self::SerializeStructVariant, Self::Error> {
key_not_a_string()
}
}
#[doc(hidden)]
pub struct SerializeStruct<'w, W: ?Sized> {
writer: &'w mut W,
}
impl<W: Write + ?Sized> ser::SerializeStruct for SerializeStruct<'_, W> {
type Error = Error;
type Ok = ();
fn serialize_field<T: ?Sized>(&mut self, key: &'static str, value: &T) -> Result<()>
where
T: Serialize,
{
value
.serialize(&mut Serializer {
writer: self.writer,
allowed_tag: AllowedTag::Any {
written_tag: Tag::End,
},
ser_state: SerState::Named { name: key },
})
.map_err(|e| e.context(format!("field `{key}`")))
}
fn end(self) -> Result<()> {
Ok(self.writer.write_u8(Tag::End as u8)?)
}
}
fn write_string_payload(string: &str, writer: &mut (impl Write + ?Sized)) -> Result<()> {
let data = to_java_cesu8(string);
match data.len().try_into() {
Ok(len) => writer.write_u16::<BigEndian>(len)?,
Err(_) => return Err(Error(anyhow!("string byte length exceeds u16::MAX"))),
};
writer.write_all(&data)?;
Ok(())
}

54
serde_nbt/src/error.rs Normal file
View file

@ -0,0 +1,54 @@
use std::error::Error as StdError;
use std::fmt::{Display, Formatter};
use std::io;
use anyhow::anyhow;
use serde::{de, ser};
#[derive(Debug)]
pub struct Error(pub(super) anyhow::Error);
impl Error {
pub(super) fn context<C>(self, ctx: C) -> Self
where
C: Display + Send + Sync + 'static,
{
Self(self.0.context(ctx))
}
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
self.0.source()
}
}
impl ser::Error for Error {
fn custom<T>(msg: T) -> Self
where
T: Display,
{
Error(anyhow!("{msg}"))
}
}
impl de::Error for Error {
fn custom<T>(msg: T) -> Self
where
T: Display,
{
Error(anyhow!("{msg}"))
}
}
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self {
Error(anyhow::Error::new(e))
}
}

90
serde_nbt/src/lib.rs Normal file
View file

@ -0,0 +1,90 @@
use std::fmt;
use std::fmt::{Display, Formatter};
use anyhow::anyhow;
pub use error::*;
pub use value::*;
mod error;
mod value;
mod array;
#[cfg(test)]
mod tests;
pub use array::*;
/// (De)serialization support for the binary representation of NBT.
pub mod binary {
pub use ser::*;
pub use de::*;
mod ser;
mod de;
}
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
enum Tag {
End,
Byte,
Short,
Int,
Long,
Float,
Double,
ByteArray,
String,
List,
Compound,
IntArray,
LongArray,
}
impl Tag {
pub fn from_u8(id: u8) -> Result<Self> {
match id {
0 => Ok(Tag::End),
1 => Ok(Tag::Byte),
2 => Ok(Tag::Short),
3 => Ok(Tag::Int),
4 => Ok(Tag::Long),
5 => Ok(Tag::Float),
6 => Ok(Tag::Double),
7 => Ok(Tag::ByteArray),
8 => Ok(Tag::String),
9 => Ok(Tag::List),
10 => Ok(Tag::Compound),
11 => Ok(Tag::IntArray),
12 => Ok(Tag::LongArray),
_ => Err(Error(anyhow!("invalid tag byte `{id}`")))
}
}
}
impl Display for Tag {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let s = match self {
Tag::End => "end",
Tag::Byte => "byte",
Tag::Short => "short",
Tag::Int => "int",
Tag::Long => "long",
Tag::Float => "float",
Tag::Double => "double",
Tag::ByteArray => "byte array",
Tag::String => "string",
Tag::List => "list",
Tag::Compound => "compound",
Tag::IntArray => "int array",
Tag::LongArray => "long array",
};
write!(f, "{s}")
}
}
const BYTE_ARRAY_MAGIC: &str = "__byte_array__";
const INT_ARRAY_MAGIC: &str = "__int_array__";
const LONG_ARRAY_MAGIC: &str = "__long_array__";

172
serde_nbt/src/tests.rs Normal file
View file

@ -0,0 +1,172 @@
use binary::{from_reader, to_writer, Deserializer, Serializer};
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use super::*;
const ROOT_NAME: &str = "The root name‽";
#[derive(PartialEq, Debug, Serialize, Deserialize)]
struct Struct {
byte: i8,
list_of_int: Vec<i32>,
list_of_string: Vec<String>,
string: String,
inner: Inner,
#[serde(serialize_with = "int_array")]
int_array: Vec<i32>,
#[serde(serialize_with = "byte_array")]
byte_array: Vec<i8>,
#[serde(serialize_with = "long_array")]
long_array: Vec<i64>,
}
impl Struct {
pub fn new() -> Self {
Self {
byte: 123,
list_of_int: vec![3, -7, 5],
list_of_string: vec!["foo".to_owned(), "bar".to_owned(), "baz".to_owned()],
string: "aé日".to_owned(),
inner: Inner {
int: i32::MIN,
long: i64::MAX,
nan_float: OrderedFloat(f32::NAN),
neg_inf_double: f64::NEG_INFINITY,
},
int_array: vec![5, -9, i32::MIN, 0, i32::MAX],
byte_array: vec![0, 1, 2],
long_array: vec![123, 456, 789],
}
}
pub fn value() -> Value {
Value::Compound(
Compound::from_iter([
("byte".into(), 123.into()),
("list_of_int".into(), List::Int(vec![3, -7, 5]).into()),
(
"list_of_string".into(),
List::String(vec!["foo".into(), "bar".into(), "baz".into()]).into(),
),
("string".into(), "aé日".into()),
(
"inner".into(),
Compound::from_iter([
("int".into(), i32::MIN.into()),
("long".into(), i64::MAX.into()),
("nan_float".into(), f32::NAN.into()),
("neg_inf_double".into(), f64::NEG_INFINITY.into()),
])
.into(),
),
(
"int_array".into(),
vec![5, -9, i32::MIN, 0, i32::MAX].into(),
),
("byte_array".into(), vec![0_i8, 1, 2].into()),
("long_array".into(), vec![123_i64, 456, 789].into()),
])
.into(),
)
}
}
#[derive(PartialEq, Debug, Serialize, Deserialize)]
struct Inner {
int: i32,
long: i64,
nan_float: OrderedFloat<f32>,
neg_inf_double: f64,
}
#[test]
fn round_trip() {
let struct_ = Struct::new();
let mut buf = Vec::new();
struct_
.serialize(&mut Serializer::new(&mut buf, ROOT_NAME))
.unwrap();
let reader = &mut buf.as_slice();
let mut de = Deserializer::new(reader, true);
let example_de = Struct::deserialize(&mut de).unwrap();
assert_eq!(struct_, example_de);
let (_, root) = de.into_inner();
assert_eq!(root.unwrap(), ROOT_NAME);
}
#[test]
fn serialize() {
let struct_ = Struct::new();
let mut buf = Vec::new();
struct_
.serialize(&mut Serializer::new(&mut buf, ROOT_NAME))
.unwrap();
let example_de: Struct = nbt::from_reader(&mut buf.as_slice()).unwrap();
assert_eq!(struct_, example_de);
}
#[test]
fn root_requires_compound() {
let mut buf = Vec::new();
assert!(123
.serialize(&mut Serializer::new(&mut buf, ROOT_NAME))
.is_err());
}
#[test]
fn invalid_array_element() {
#[derive(Serialize)]
struct Struct {
#[serde(serialize_with = "byte_array")]
data: Vec<i32>,
}
let struct_ = Struct {
data: vec![1, 2, 3],
};
let mut buf = Vec::new();
assert!(struct_
.serialize(&mut Serializer::new(&mut buf, ROOT_NAME))
.is_err());
}
// #[test]
// fn struct_to_value() {
// let mut buf = Vec::new();
//
// to_writer(&mut buf, ROOT_NAME, &Struct::new()).unwrap();
//
// let reader = &mut buf.as_slice();
//
// let val: Value = from_reader(reader).unwrap();
//
// eprintln!("{:#?}", Struct::value());
//
// assert_eq!(val, Struct::value());
// }
#[test]
fn value_to_struct() {
let mut buf = Vec::new();
to_writer(&mut buf, ROOT_NAME, &Struct::value()).unwrap();
let reader = &mut buf.as_slice();
let struct_: Struct = from_reader(reader).unwrap();
assert_eq!(struct_, Struct::new());
}

519
serde_nbt/src/value.rs Normal file
View file

@ -0,0 +1,519 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::fmt::Formatter;
use serde::de::{DeserializeSeed, Error, MapAccess, SeqAccess, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::{byte_array, int_array, long_array};
/// Represents an arbitrary NBT value.
#[derive(Clone, PartialEq, Debug)]
pub enum Value {
Byte(i8),
Short(i16),
Int(i32),
Long(i64),
Float(f32),
Double(f64),
ByteArray(Vec<i8>),
String(String),
List(List),
Compound(Compound),
IntArray(Vec<i32>),
LongArray(Vec<i64>),
}
pub type Compound = HashMap<String, Value>;
/// An NBT list value.
///
/// NBT lists are homogeneous, meaning each list element must be of the same
/// type. This is opposed to a format like JSON where lists can be
/// heterogeneous:
///
/// ```json
/// [42, "hello", {}]
/// ```
///
/// Every possible element type has its own variant in this enum. As a result,
/// heterogeneous lists are unrepresentable.
#[derive(Clone, PartialEq, Debug)]
pub enum List {
Byte(Vec<i8>),
Short(Vec<i16>),
Int(Vec<i32>),
Long(Vec<i64>),
Float(Vec<f32>),
Double(Vec<f64>),
ByteArray(Vec<Vec<i8>>),
String(Vec<String>),
List(Vec<List>),
Compound(Vec<Compound>),
IntArray(Vec<Vec<i32>>),
LongArray(Vec<Vec<i64>>),
}
impl List {
pub fn len(&self) -> usize {
match self {
List::Byte(l) => l.len(),
List::Short(l) => l.len(),
List::Int(l) => l.len(),
List::Long(l) => l.len(),
List::Float(l) => l.len(),
List::Double(l) => l.len(),
List::ByteArray(l) => l.len(),
List::String(l) => l.len(),
List::List(l) => l.len(),
List::Compound(l) => l.len(),
List::IntArray(l) => l.len(),
List::LongArray(l) => l.len(),
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl From<i8> for Value {
fn from(v: i8) -> Self {
Self::Byte(v)
}
}
impl From<i16> for Value {
fn from(v: i16) -> Self {
Self::Short(v)
}
}
impl From<i32> for Value {
fn from(v: i32) -> Self {
Self::Int(v)
}
}
impl From<i64> for Value {
fn from(v: i64) -> Self {
Self::Long(v)
}
}
impl From<f32> for Value {
fn from(v: f32) -> Self {
Self::Float(v)
}
}
impl From<f64> for Value {
fn from(v: f64) -> Self {
Self::Double(v)
}
}
impl From<Vec<i8>> for Value {
fn from(v: Vec<i8>) -> Self {
Self::ByteArray(v)
}
}
impl From<String> for Value {
fn from(v: String) -> Self {
Self::String(v)
}
}
impl<'a> From<&'a str> for Value {
fn from(v: &'a str) -> Self {
Self::String(v.to_owned())
}
}
impl<'a> From<Cow<'a, str>> for Value {
fn from(v: Cow<'a, str>) -> Self {
Self::String(v.into_owned())
}
}
impl From<List> for Value {
fn from(v: List) -> Self {
Self::List(v)
}
}
impl From<Compound> for Value {
fn from(v: Compound) -> Self {
Self::Compound(v)
}
}
impl From<Vec<i32>> for Value {
fn from(v: Vec<i32>) -> Self {
Self::IntArray(v)
}
}
impl From<Vec<i64>> for Value {
fn from(v: Vec<i64>) -> Self {
Self::LongArray(v)
}
}
impl From<Vec<i8>> for List {
fn from(v: Vec<i8>) -> Self {
List::Byte(v)
}
}
impl From<Vec<i16>> for List {
fn from(v: Vec<i16>) -> Self {
List::Short(v)
}
}
impl From<Vec<i32>> for List {
fn from(v: Vec<i32>) -> Self {
List::Int(v)
}
}
impl From<Vec<i64>> for List {
fn from(v: Vec<i64>) -> Self {
List::Long(v)
}
}
impl From<Vec<f32>> for List {
fn from(v: Vec<f32>) -> Self {
List::Float(v)
}
}
impl From<Vec<f64>> for List {
fn from(v: Vec<f64>) -> Self {
List::Double(v)
}
}
impl From<Vec<Vec<i8>>> for List {
fn from(v: Vec<Vec<i8>>) -> Self {
List::ByteArray(v)
}
}
impl From<Vec<String>> for List {
fn from(v: Vec<String>) -> Self {
List::String(v)
}
}
impl From<Vec<List>> for List {
fn from(v: Vec<List>) -> Self {
List::List(v)
}
}
impl From<Vec<Compound>> for List {
fn from(v: Vec<Compound>) -> Self {
List::Compound(v)
}
}
impl From<Vec<Vec<i32>>> for List {
fn from(v: Vec<Vec<i32>>) -> Self {
List::IntArray(v)
}
}
impl From<Vec<Vec<i64>>> for List {
fn from(v: Vec<Vec<i64>>) -> Self {
List::LongArray(v)
}
}
impl Serialize for Value {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Value::Byte(v) => v.serialize(serializer),
Value::Short(v) => v.serialize(serializer),
Value::Int(v) => v.serialize(serializer),
Value::Long(v) => v.serialize(serializer),
Value::Float(v) => v.serialize(serializer),
Value::Double(v) => v.serialize(serializer),
Value::ByteArray(v) => byte_array(v, serializer),
Value::String(v) => v.serialize(serializer),
Value::List(v) => v.serialize(serializer),
Value::Compound(v) => v.serialize(serializer),
Value::IntArray(v) => int_array(v, serializer),
Value::LongArray(v) => long_array(v, serializer),
}
}
}
impl Serialize for List {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
List::Byte(l) => l.serialize(serializer),
List::Short(l) => l.serialize(serializer),
List::Int(l) => l.serialize(serializer),
List::Long(l) => l.serialize(serializer),
List::Float(l) => l.serialize(serializer),
List::Double(l) => l.serialize(serializer),
List::ByteArray(l) => l.serialize(serializer),
List::String(l) => l.serialize(serializer),
List::List(l) => l.serialize(serializer),
List::Compound(l) => l.serialize(serializer),
List::IntArray(l) => l.serialize(serializer),
List::LongArray(l) => l.serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for Value {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(ValueVisitor)
}
}
struct ValueVisitor;
impl<'de> Visitor<'de> for ValueVisitor {
type Value = Value;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "a representable NBT value")
}
fn visit_i8<E>(self, v: i8) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::Byte(v))
}
fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::Short(v))
}
fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::Int(v))
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::Long(v))
}
fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::Float(v))
}
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::Double(v))
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::String(v))
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Value::String(v.to_owned()))
}
fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
ListVisitor.visit_seq(seq).map(Value::List)
}
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
visit_map(map).map(Value::Compound)
}
}
impl<'de> Deserialize<'de> for List {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_seq(ListVisitor)
}
}
struct ListVisitor;
impl<'de> Visitor<'de> for ListVisitor {
type Value = List;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "an NBT list")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut list = List::Byte(Vec::new());
while seq
.next_element_seed(DeserializeListElement(&mut list))?
.is_some()
{}
Ok(list)
}
}
struct DeserializeListElement<'a>(&'a mut List);
impl<'de, 'a> DeserializeSeed<'de> for DeserializeListElement<'a> {
type Value = ();
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(self)
}
}
macro_rules! visit {
($self:expr, $variant:ident, $value:expr, $error:ty) => {
if $self.0.is_empty() {
*$self.0 = List::$variant(vec![$value]);
Ok(())
} else if let List::$variant(elems) = $self.0 {
elems.push($value);
Ok(())
} else {
Err(<$error>::custom("NBT lists must be homogenous"))
}
};
}
impl<'de, 'a> Visitor<'de> for DeserializeListElement<'a> {
type Value = ();
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
write!(formatter, "a valid NBT list element")
}
fn visit_i8<E>(self, v: i8) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, Byte, v, E)
}
fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, Short, v, E)
}
fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, Int, v, E)
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, Long, v, E)
}
fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, Float, v, E)
}
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, Double, v, E)
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, String, v, E)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
visit!(self, String, v.to_owned(), E)
}
fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
visit!(self, List, ListVisitor.visit_seq(seq)?, A::Error)
}
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
visit!(self, Compound, visit_map(map)?, A::Error)
}
}
fn visit_map<'de, A>(mut map: A) -> Result<Compound, A::Error>
where
A: MapAccess<'de>,
{
let mut compound = Compound::new();
while let Some((k, v)) = map.next_entry()? {
compound.insert(k, v);
}
Ok(compound)
}