discovery and reconnection actually works!!!

This commit is contained in:
Alex Janka 2024-02-25 11:35:58 +11:00
parent 878521fcfb
commit 8b738777a8
5 changed files with 213 additions and 131 deletions

View file

@ -3,20 +3,18 @@ use std::{collections::HashMap, time::Duration};
use chacha20poly1305::{ use chacha20poly1305::{
aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce, aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce,
}; };
use futures_util::{pin_mut, StreamExt};
use http::{Method, Request}; use http::{Method, Request};
use mdns::RecordKind;
use thiserror::Error; use thiserror::Error;
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream, net::TcpStream,
time::timeout,
}; };
use crate::{ use crate::{
pairing_data::{Accessory, ServiceCharacteristic}, pairing_data::{Accessory, ServiceCharacteristic},
spawn_discover_thread,
tlv8::TlvEncode, tlv8::TlvEncode,
HomekitError, ConnectionError, HomekitError, MdnsDiscoveredList,
}; };
pub(super) struct AccessorySocket { pub(super) struct AccessorySocket {
@ -44,82 +42,41 @@ impl SocketEncryption {
} }
} }
const SERVICE_NAME: &str = "_hap._tcp.local"; async fn reconnect(
pub async fn discover(
duration_seconds: u64,
pairing_id: &str, pairing_id: &str,
) -> Result<(String, u16), DiscoveryError> { discovered: &MdnsDiscoveredList,
let stream = mdns::discover::all(SERVICE_NAME, Duration::from_secs(1))?.listen(); ) -> Result<TcpStream, ConnectionError> {
let (hostname, port) = discovered
.read()
.await
.get(pairing_id)
.ok_or(ConnectionError::Discovery(DiscoveryError::NotFound))?
.clone();
pin_mut!(stream); let socket = match TcpStream::connect(format!("{hostname}:{port}")).await {
while let Ok(Some(Ok(response))) = Ok(socket) => socket,
timeout(Duration::from_secs(duration_seconds), stream.next()).await Err(_) => {
{ spawn_discover_thread(discovered.clone())?;
if let Some(name) = response.additional.iter().find_map(|record| { tokio::time::sleep(Duration::from_secs(1)).await;
if let RecordKind::TXT(v) = &record.kind { TcpStream::connect(format!("{hostname}:{port}")).await?
if v.contains(&format!("id={pairing_id}")) {
return Some(record.name.clone());
} }
} };
None
}) {
log::info!("got name {name}");
if let Some((target, port)) = response.additional.iter().find_map(|record| {
if record.name == name {
if let RecordKind::SRV {
priority: _,
weight: _,
port,
target,
} = &record.kind
{
return Some((target.clone(), *port));
}
}
None
}) {
if let Some(ip) = response.additional.iter().find_map(|record| {
if record.name == target {
if let RecordKind::A(ip) = record.kind {
return Some(ip);
}
}
None
}) {
return Ok((ip.to_string(), port));
}
}
}
// if let Some(addr) = addr {
// println!("found cast device at {}", addr);
// } else {
// println!("cast device does not advertise address");
// }
}
todo!()
}
async fn reconnect(pairing_id: &str) -> Result<TcpStream, HomekitError> {
log::warn!("error connecting to device...");
log::warn!("trying to find {pairing_id}'s ip/port via bonjour/mdns...");
let (hostname, port) = discover(20, pairing_id).await?;
log::info!("successfully found device at {hostname}:{port}");
let socket = TcpStream::connect(format!("{hostname}:{port}")).await?;
log::info!(" ...and connected!");
Ok(socket) Ok(socket)
} }
impl AccessorySocket { impl AccessorySocket {
pub async fn new(pairing_id: &str, ip: &str, port: usize) -> Result<Self, HomekitError> { pub async fn new(
pairing_id: &str,
ip: &str,
port: usize,
discovered: &MdnsDiscoveredList,
) -> Result<Self, ConnectionError> {
let socket = tokio::select! { let socket = tokio::select! {
stream = TcpStream::connect(format!("{ip}:{port}")) => match stream { stream = TcpStream::connect(format!("{ip}:{port}")) => match stream {
Ok(v) => v, Ok(v) => v,
Err(_) => reconnect(pairing_id).await? Err(_) => reconnect(pairing_id,discovered).await?
}, },
_ = tokio::time::sleep(Duration::from_secs(1)) => reconnect(pairing_id).await? _ = tokio::time::sleep(Duration::from_secs(1)) => reconnect(pairing_id,discovered).await?
}; };
Ok(Self { Ok(Self {
@ -268,7 +225,7 @@ impl AccessorySocket {
let header_size = match result { let header_size = match result {
httparse::Status::Complete(header_size) => Ok(header_size), httparse::Status::Complete(header_size) => Ok(header_size),
httparse::Status::Partial => Err(HomekitError::Http), httparse::Status::Partial => Err(ConnectionError::Http),
}?; }?;
let mut packet = packet[header_size..].to_vec(); let mut packet = packet[header_size..].to_vec();
@ -335,7 +292,7 @@ impl AccessorySocket {
while read_num == 0 { while read_num == 0 {
if tries > 20 { if tries > 20 {
log::error!("unsuccessfully tried to reconnect"); log::error!("unsuccessfully tried to reconnect");
return Err(HomekitError::Http); return Err(ConnectionError::Http.into());
} }
tries += 1; tries += 1;
log::info!("read 0 bytes - about to reconnect"); log::info!("read 0 bytes - about to reconnect");
@ -368,10 +325,16 @@ impl AccessorySocket {
Ok(buf[..read_num].to_vec()) Ok(buf[..read_num].to_vec())
} }
} }
pub(super) async fn disconnect(&mut self) -> Result<(), std::io::Error> {
self.socket.shutdown().await
}
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum DiscoveryError { pub enum DiscoveryError {
#[error("mdns")] #[error("mdns")]
Mdns(#[from] mdns::Error), Mdns(#[from] mdns::Error),
#[error("not found")]
NotFound,
} }

View file

@ -2,15 +2,18 @@ use chacha20poly1305::{
aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce, aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce,
}; };
use ed25519_dalek::{Signer, Verifier}; use ed25519_dalek::{Signer, Verifier};
use futures_util::{pin_mut, StreamExt};
use hkdf::Hkdf; use hkdf::Hkdf;
use homekit_http::DiscoveryError; use homekit_http::DiscoveryError;
use mdns::RecordKind;
use sha2::Sha512; use sha2::Sha512;
use std::{collections::HashMap, path::PathBuf}; use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use thiserror::Error; use thiserror::Error;
use tlv8::{HomekitState, TlvEncode, TlvError, TlvType}; use tlv8::{HomekitState, TlvEncode, TlvError, TlvType};
use tokio::sync::RwLock;
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
use pairing_data::{Accessory, DevicePairingData}; use pairing_data::{Accessory, PythonPairingData};
pub use crate::pairing_data::{CharacteristicType, Data, ServiceType}; pub use crate::pairing_data::{CharacteristicType, Data, ServiceType};
use crate::{ use crate::{
@ -22,21 +25,123 @@ mod homekit_http;
mod pairing_data; mod pairing_data;
mod tlv8; mod tlv8;
pub fn load(pairing_data: PathBuf) -> Result<HashMap<String, DevicePairingData>, HomekitError> { pub fn load(pairing_data: PathBuf) -> Result<HashMap<String, PythonPairingData>, HomekitError> {
Ok(serde_json::from_str(&std::fs::read_to_string( Ok(serde_json::from_str(&std::fs::read_to_string(
pairing_data, pairing_data,
)?)?) )?)?)
} }
impl DevicePairingData { impl PythonPairingData {
pub async fn connect(&self) -> Result<ConnectedDevice, HomekitError> { pub async fn connect(
&self,
discovered: &MdnsDiscoveredList,
) -> Result<DeviceConnection, HomekitError> {
let mut connected_device: DeviceConnection = self.to_connection(discovered.clone());
connected_device.connect().await?;
connected_device.characteristics_request(true).await?;
Ok(connected_device)
}
fn to_connection(&self, discovered: MdnsDiscoveredList) -> DeviceConnection {
DeviceConnection {
accessories: Default::default(),
discovered,
pairing_data: DevicePairingData {
accessory_pairing_id: self.accessory_pairing_id.clone(),
accessory_ip: self.accessory_ip.clone(),
accessory_port: self.accessory_port,
accessory_ltpk: self.accessory_ltpk,
ios_pairing_id: self.ios_pairing_id.clone(),
ios_device_ltsk: self.ios_device_ltsk,
},
socket: None,
}
}
}
pub type MdnsDiscoveredList = Arc<RwLock<HashMap<String, (String, u16)>>>;
pub fn spawn_discover_thread(discovered: MdnsDiscoveredList) -> Result<(), DiscoveryError> {
let stream = mdns::discover::all("_hap._tcp.local", Duration::from_secs(1))?.listen();
tokio::task::spawn(async move {
pin_mut!(stream);
while let Some(Ok(response)) = stream.next().await {
let all = response
.answers
.iter()
.chain(response.additional.iter())
.collect::<Vec<_>>();
if let Some((name, id)) = all.iter().find_map(|record| {
if let RecordKind::TXT(v) = &record.kind {
if let Some(id_string) = v.iter().find(|v| v.contains("id=")) {
let id = id_string[3..].to_string();
return Some((record.name.clone(), id));
}
}
None
}) {
if let Some((target, port)) = all.iter().find_map(|record| {
if record.name == name {
if let RecordKind::SRV {
priority: _,
weight: _,
port,
target,
} = &record.kind
{
return Some((target.clone(), *port));
}
}
None
}) {
if let Some(ip) = all.iter().find_map(|record| {
if record.name == target {
if let RecordKind::A(ip) = record.kind {
return Some(ip);
}
}
None
}) {
let mut connections = discovered.write().await;
connections.insert(id, (ip.to_string(), port));
}
}
}
}
});
Ok(())
}
struct DevicePairingData {
accessory_pairing_id: String,
accessory_ip: String,
accessory_port: usize,
accessory_ltpk: [u8; 32],
ios_pairing_id: String,
ios_device_ltsk: [u8; 32],
}
pub struct DeviceConnection {
pub accessories: HashMap<usize, Accessory>,
discovered: MdnsDiscoveredList,
pairing_data: DevicePairingData,
socket: Option<AccessorySocket>,
}
impl DeviceConnection {
async fn connect(&mut self) -> Result<(), HomekitError> {
if let Some(mut socket) = self.socket.take() {
socket.disconnect().await?;
}
let key = EphemeralSecret::random(); let key = EphemeralSecret::random();
let pubkey = PublicKey::from(&key); let pubkey = PublicKey::from(&key);
let mut socket = AccessorySocket::new( let mut socket = AccessorySocket::new(
&self.accessory_pairing_id, &self.pairing_data.accessory_pairing_id,
&self.accessory_ip, &self.pairing_data.accessory_ip,
self.accessory_port, self.pairing_data.accessory_port,
&self.discovered,
) )
.await?; .await?;
@ -113,12 +218,13 @@ impl DevicePairingData {
.get(&TlvType::Signature.into()) .get(&TlvType::Signature.into())
.ok_or(HomekitError::TlvNotFound)?; .ok_or(HomekitError::TlvNotFound)?;
if accessory_identifier != self.accessory_pairing_id.as_bytes() { if accessory_identifier != self.pairing_data.accessory_pairing_id.as_bytes() {
return Err(HomekitError::Auth); return Err(HomekitError::Auth);
} }
// 5. Get accessory LTPK // 5. Get accessory LTPK
let accessory_ltpk = ed25519_dalek::VerifyingKey::from_bytes(&self.accessory_ltpk)?; let accessory_ltpk =
ed25519_dalek::VerifyingKey::from_bytes(&self.pairing_data.accessory_ltpk)?;
let mut accessory_info = accessory_pubkey_bytes.to_vec(); let mut accessory_info = accessory_pubkey_bytes.to_vec();
accessory_info.extend_from_slice(accessory_identifier); accessory_info.extend_from_slice(accessory_identifier);
accessory_info.extend_from_slice(pubkey.as_bytes()); accessory_info.extend_from_slice(pubkey.as_bytes());
@ -135,19 +241,19 @@ impl DevicePairingData {
// 7. Construct iOSDeviceInfo // 7. Construct iOSDeviceInfo
let ios_device_info = { let ios_device_info = {
let mut buf = pubkey.as_bytes().to_vec(); let mut buf = pubkey.as_bytes().to_vec();
buf.extend_from_slice(self.ios_pairing_id.as_bytes()); buf.extend_from_slice(self.pairing_data.ios_pairing_id.as_bytes());
buf.extend_from_slice(response_pubkey); buf.extend_from_slice(response_pubkey);
buf buf
}; };
// 8. Use Ed25519 togenerate iOSDeviceSignature by signing iOSDeviceInfo with its long-term secret key, iOSDeviceLTSK // 8. Use Ed25519 togenerate iOSDeviceSignature by signing iOSDeviceInfo with its long-term secret key, iOSDeviceLTSK
let signing_key = ed25519_dalek::SigningKey::from_bytes(&self.ios_device_ltsk); let signing_key = ed25519_dalek::SigningKey::from_bytes(&self.pairing_data.ios_device_ltsk);
let signature = signing_key.sign(&ios_device_info); let signature = signing_key.sign(&ios_device_info);
// 9. Construct sub-TLV // 9. Construct sub-TLV
let mut encrypted_tlv = ([ let mut encrypted_tlv = ([
( (
TlvType::Identifier.into(), TlvType::Identifier.into(),
self.ios_pairing_id.encode_value(), self.pairing_data.ios_pairing_id.encode_value(),
), ),
( (
TlvType::Signature.into(), TlvType::Signature.into(),
@ -234,28 +340,21 @@ impl DevicePairingData {
)?; )?;
socket.set_encryption(controller_to_accessory_key, accessory_to_controller_key); socket.set_encryption(controller_to_accessory_key, accessory_to_controller_key);
self.accessories = socket.get_accessories().await?;
let mut connected_device = ConnectedDevice { self.socket = Some(socket);
accessories: socket.get_accessories().await?, Ok(())
socket,
};
connected_device.characteristics_request(true).await?;
Ok(connected_device)
}
} }
pub struct ConnectedDevice {
pub accessories: HashMap<usize, Accessory>,
socket: AccessorySocket,
}
impl ConnectedDevice {
pub async fn update_characteristics(&mut self) -> Result<(), HomekitError> { pub async fn update_characteristics(&mut self) -> Result<(), HomekitError> {
self.characteristics_request(true).await self.characteristics_request(true).await
} }
async fn characteristics_request(&mut self, additional_data: bool) -> Result<(), HomekitError> { async fn characteristics_request(&mut self, additional_data: bool) -> Result<(), HomekitError> {
'outer: loop {
if self.socket.is_none() {
self.connect().await?;
}
if let Some(socket) = &mut self.socket {
for (aid, data) in &mut self.accessories { for (aid, data) in &mut self.accessories {
for service in data.services.values_mut() { for service in data.services.values_mut() {
let characteristic_ids = service let characteristic_ids = service
@ -263,10 +362,15 @@ impl ConnectedDevice {
.keys() .keys()
.map(|k| format!("{aid}.{k}")) .map(|k| format!("{aid}.{k}"))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let characteristics = self let characteristics = match socket
.socket
.get_characteristics(&characteristic_ids, additional_data) .get_characteristics(&characteristic_ids, additional_data)
.await?; .await
{
Ok(val) => val,
Err(_e) => {
continue 'outer;
}
};
for (cid, c) in &characteristics { for (cid, c) in &characteristics {
if c.characteristic_type == CharacteristicType::Name { if c.characteristic_type == CharacteristicType::Name {
if let Some(Data::String(name)) = &c.value { if let Some(Data::String(name)) = &c.value {
@ -279,13 +383,16 @@ impl ConnectedDevice {
} }
} }
} }
}
Ok(()) return Ok(());
}
} }
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum HomekitError { pub enum HomekitError {
#[error("connection")]
Connection(#[from] ConnectionError),
#[error("io")] #[error("io")]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
#[error("serde_json")] #[error("serde_json")]
@ -318,8 +425,6 @@ pub enum HomekitError {
InvalidUri(#[from] http::uri::InvalidUri), InvalidUri(#[from] http::uri::InvalidUri),
#[error("parsing response")] #[error("parsing response")]
ResponseParse(#[from] httparse::Error), ResponseParse(#[from] httparse::Error),
#[error("http")]
Http,
#[error("something else")] #[error("something else")]
SomethingElse(String), SomethingElse(String),
#[error("addr parse")] #[error("addr parse")]
@ -328,12 +433,22 @@ pub enum HomekitError {
TlvDeviceError(TlvError), TlvDeviceError(TlvError),
#[error("parsing utf-8")] #[error("parsing utf-8")]
Utf8(#[from] std::string::FromUtf8Error), Utf8(#[from] std::string::FromUtf8Error),
#[error("discovery")]
Discovery(#[from] DiscoveryError),
#[error("device not found")] #[error("device not found")]
DeviceNotFound, DeviceNotFound,
#[error("timeout")] #[error("timeout")]
Timeout(#[from] tokio::time::error::Elapsed), Timeout(#[from] tokio::time::error::Elapsed),
#[error("discovery")]
Discovery(#[from] DiscoveryError),
}
#[derive(Debug, Error)]
pub enum ConnectionError {
#[error("http")]
Http,
#[error("discovery")]
Discovery(#[from] DiscoveryError),
#[error("io")]
Io(#[from] std::io::Error),
} }
impl From<TlvError> for HomekitError { impl From<TlvError> for HomekitError {

View file

@ -8,7 +8,7 @@ pub use characteristics::{CharacteristicType, ServiceType};
mod characteristics; mod characteristics;
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
pub struct DevicePairingData { pub struct PythonPairingData {
#[serde(rename = "AccessoryPairingID")] #[serde(rename = "AccessoryPairingID")]
pub accessory_pairing_id: String, pub accessory_pairing_id: String,
#[serde(with = "hex::serde", rename = "AccessoryLTPK")] #[serde(with = "hex::serde", rename = "AccessoryLTPK")]

View file

@ -1,11 +1,12 @@
#[macro_use] #[macro_use]
extern crate rocket; extern crate rocket;
use std::{collections::HashMap, path::PathBuf, time::Duration}; use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use clap::Parser; use clap::Parser;
use homekit_controller::{ConnectedDevice, HomekitError, ServiceType}; use homekit_controller::{spawn_discover_thread, DeviceConnection, HomekitError, ServiceType};
use server::launch; use server::launch;
use tokio::sync::RwLock;
mod server; mod server;
@ -40,14 +41,17 @@ async fn rocket() -> rocket::Rocket<rocket::Build> {
} }
} }
async fn init(pairing_data: PathBuf) -> Result<HashMap<String, ConnectedDevice>, HomekitError> { async fn init(pairing_data: PathBuf) -> Result<HashMap<String, DeviceConnection>, HomekitError> {
let discovered = Arc::new(RwLock::new(HashMap::new()));
spawn_discover_thread(discovered.clone())?;
tokio::time::sleep(Duration::from_secs(1)).await;
if pairing_data.is_file() { if pairing_data.is_file() {
let devices = homekit_controller::load(pairing_data)?; let devices = homekit_controller::load(pairing_data)?;
let mut connected_devices = HashMap::new(); let mut connected_devices = HashMap::new();
for (k, v) in devices { for (k, v) in devices {
let mut num = 0; let mut num = 0;
let connected = loop { let connected = loop {
if let Ok(v) = v.connect().await { if let Ok(v) = v.connect(&discovered).await {
break Some(v); break Some(v);
} }
num += 1; num += 1;

View file

@ -1,18 +1,18 @@
use homekit_controller::{ConnectedDevice, Data}; use homekit_controller::{Data, DeviceConnection};
use rocket::State; use rocket::State;
use std::{collections::HashMap, ops::DerefMut}; use std::{collections::HashMap, ops::DerefMut};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::SENSORS; use crate::SENSORS;
pub fn launch(paired: HashMap<String, ConnectedDevice>) -> rocket::Rocket<rocket::Build> { pub fn launch(paired: HashMap<String, DeviceConnection>) -> rocket::Rocket<rocket::Build> {
rocket::build() rocket::build()
.manage(Mutex::new(paired)) .manage(Mutex::new(paired))
.mount("/", routes![metrics]) .mount("/", routes![metrics])
} }
#[get("/metrics")] #[get("/metrics")]
pub async fn metrics(state: &State<Mutex<HashMap<String, ConnectedDevice>>>) -> Option<String> { pub async fn metrics(state: &State<Mutex<HashMap<String, DeviceConnection>>>) -> Option<String> {
let mut s = String::new(); let mut s = String::new();
let mut state = state.lock().await; let mut state = state.lock().await;
let mut shown_types = Vec::new(); let mut shown_types = Vec::new();