From 8b738777a8c3b6294e20f017685bb58f56ea74ee Mon Sep 17 00:00:00 2001 From: Alex Janka Date: Sun, 25 Feb 2024 11:35:58 +1100 Subject: [PATCH] discovery and reconnection actually works!!! --- homekit-controller/src/homekit_http.rs | 105 ++++------ homekit-controller/src/lib.rs | 219 ++++++++++++++++----- homekit-controller/src/pairing_data/mod.rs | 2 +- homekit-exporter/src/main.rs | 12 +- homekit-exporter/src/server.rs | 6 +- 5 files changed, 213 insertions(+), 131 deletions(-) diff --git a/homekit-controller/src/homekit_http.rs b/homekit-controller/src/homekit_http.rs index addc465..a4a110e 100644 --- a/homekit-controller/src/homekit_http.rs +++ b/homekit-controller/src/homekit_http.rs @@ -3,20 +3,18 @@ use std::{collections::HashMap, time::Duration}; use chacha20poly1305::{ aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce, }; -use futures_util::{pin_mut, StreamExt}; use http::{Method, Request}; -use mdns::RecordKind; use thiserror::Error; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, - time::timeout, }; use crate::{ pairing_data::{Accessory, ServiceCharacteristic}, + spawn_discover_thread, tlv8::TlvEncode, - HomekitError, + ConnectionError, HomekitError, MdnsDiscoveredList, }; pub(super) struct AccessorySocket { @@ -44,82 +42,41 @@ impl SocketEncryption { } } -const SERVICE_NAME: &str = "_hap._tcp.local"; - -pub async fn discover( - duration_seconds: u64, +async fn reconnect( pairing_id: &str, -) -> Result<(String, u16), DiscoveryError> { - let stream = mdns::discover::all(SERVICE_NAME, Duration::from_secs(1))?.listen(); + discovered: &MdnsDiscoveredList, +) -> Result { + let (hostname, port) = discovered + .read() + .await + .get(pairing_id) + .ok_or(ConnectionError::Discovery(DiscoveryError::NotFound))? + .clone(); - pin_mut!(stream); - while let Ok(Some(Ok(response))) = - timeout(Duration::from_secs(duration_seconds), stream.next()).await - { - if let Some(name) = response.additional.iter().find_map(|record| { - if let RecordKind::TXT(v) = &record.kind { - if v.contains(&format!("id={pairing_id}")) { - return Some(record.name.clone()); - } - } - None - }) { - log::info!("got name {name}"); - if let Some((target, port)) = response.additional.iter().find_map(|record| { - if record.name == name { - if let RecordKind::SRV { - priority: _, - weight: _, - port, - target, - } = &record.kind - { - return Some((target.clone(), *port)); - } - } - None - }) { - if let Some(ip) = response.additional.iter().find_map(|record| { - if record.name == target { - if let RecordKind::A(ip) = record.kind { - return Some(ip); - } - } - None - }) { - return Ok((ip.to_string(), port)); - } - } + let socket = match TcpStream::connect(format!("{hostname}:{port}")).await { + Ok(socket) => socket, + Err(_) => { + spawn_discover_thread(discovered.clone())?; + tokio::time::sleep(Duration::from_secs(1)).await; + TcpStream::connect(format!("{hostname}:{port}")).await? } - - // if let Some(addr) = addr { - // println!("found cast device at {}", addr); - // } else { - // println!("cast device does not advertise address"); - // } - } - - todo!() -} - -async fn reconnect(pairing_id: &str) -> Result { - log::warn!("error connecting to device..."); - log::warn!("trying to find {pairing_id}'s ip/port via bonjour/mdns..."); - let (hostname, port) = discover(20, pairing_id).await?; - log::info!("successfully found device at {hostname}:{port}"); - let socket = TcpStream::connect(format!("{hostname}:{port}")).await?; - log::info!(" ...and connected!"); + }; Ok(socket) } impl AccessorySocket { - pub async fn new(pairing_id: &str, ip: &str, port: usize) -> Result { + pub async fn new( + pairing_id: &str, + ip: &str, + port: usize, + discovered: &MdnsDiscoveredList, + ) -> Result { let socket = tokio::select! { stream = TcpStream::connect(format!("{ip}:{port}")) => match stream { Ok(v) => v, - Err(_) => reconnect(pairing_id).await? + Err(_) => reconnect(pairing_id,discovered).await? }, - _ = tokio::time::sleep(Duration::from_secs(1)) => reconnect(pairing_id).await? + _ = tokio::time::sleep(Duration::from_secs(1)) => reconnect(pairing_id,discovered).await? }; Ok(Self { @@ -268,7 +225,7 @@ impl AccessorySocket { let header_size = match result { httparse::Status::Complete(header_size) => Ok(header_size), - httparse::Status::Partial => Err(HomekitError::Http), + httparse::Status::Partial => Err(ConnectionError::Http), }?; let mut packet = packet[header_size..].to_vec(); @@ -335,7 +292,7 @@ impl AccessorySocket { while read_num == 0 { if tries > 20 { log::error!("unsuccessfully tried to reconnect"); - return Err(HomekitError::Http); + return Err(ConnectionError::Http.into()); } tries += 1; log::info!("read 0 bytes - about to reconnect"); @@ -368,10 +325,16 @@ impl AccessorySocket { Ok(buf[..read_num].to_vec()) } } + + pub(super) async fn disconnect(&mut self) -> Result<(), std::io::Error> { + self.socket.shutdown().await + } } #[derive(Debug, Error)] pub enum DiscoveryError { #[error("mdns")] Mdns(#[from] mdns::Error), + #[error("not found")] + NotFound, } diff --git a/homekit-controller/src/lib.rs b/homekit-controller/src/lib.rs index 30c2785..e7eefea 100644 --- a/homekit-controller/src/lib.rs +++ b/homekit-controller/src/lib.rs @@ -2,15 +2,18 @@ use chacha20poly1305::{ aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce, }; use ed25519_dalek::{Signer, Verifier}; +use futures_util::{pin_mut, StreamExt}; use hkdf::Hkdf; use homekit_http::DiscoveryError; +use mdns::RecordKind; use sha2::Sha512; -use std::{collections::HashMap, path::PathBuf}; +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; use thiserror::Error; use tlv8::{HomekitState, TlvEncode, TlvError, TlvType}; +use tokio::sync::RwLock; use x25519_dalek::{EphemeralSecret, PublicKey}; -use pairing_data::{Accessory, DevicePairingData}; +use pairing_data::{Accessory, PythonPairingData}; pub use crate::pairing_data::{CharacteristicType, Data, ServiceType}; use crate::{ @@ -22,21 +25,123 @@ mod homekit_http; mod pairing_data; mod tlv8; -pub fn load(pairing_data: PathBuf) -> Result, HomekitError> { +pub fn load(pairing_data: PathBuf) -> Result, HomekitError> { Ok(serde_json::from_str(&std::fs::read_to_string( pairing_data, )?)?) } -impl DevicePairingData { - pub async fn connect(&self) -> Result { +impl PythonPairingData { + pub async fn connect( + &self, + discovered: &MdnsDiscoveredList, + ) -> Result { + let mut connected_device: DeviceConnection = self.to_connection(discovered.clone()); + connected_device.connect().await?; + connected_device.characteristics_request(true).await?; + + Ok(connected_device) + } + + fn to_connection(&self, discovered: MdnsDiscoveredList) -> DeviceConnection { + DeviceConnection { + accessories: Default::default(), + discovered, + pairing_data: DevicePairingData { + accessory_pairing_id: self.accessory_pairing_id.clone(), + accessory_ip: self.accessory_ip.clone(), + accessory_port: self.accessory_port, + accessory_ltpk: self.accessory_ltpk, + ios_pairing_id: self.ios_pairing_id.clone(), + ios_device_ltsk: self.ios_device_ltsk, + }, + socket: None, + } + } +} + +pub type MdnsDiscoveredList = Arc>>; + +pub fn spawn_discover_thread(discovered: MdnsDiscoveredList) -> Result<(), DiscoveryError> { + let stream = mdns::discover::all("_hap._tcp.local", Duration::from_secs(1))?.listen(); + tokio::task::spawn(async move { + pin_mut!(stream); + while let Some(Ok(response)) = stream.next().await { + let all = response + .answers + .iter() + .chain(response.additional.iter()) + .collect::>(); + if let Some((name, id)) = all.iter().find_map(|record| { + if let RecordKind::TXT(v) = &record.kind { + if let Some(id_string) = v.iter().find(|v| v.contains("id=")) { + let id = id_string[3..].to_string(); + return Some((record.name.clone(), id)); + } + } + None + }) { + if let Some((target, port)) = all.iter().find_map(|record| { + if record.name == name { + if let RecordKind::SRV { + priority: _, + weight: _, + port, + target, + } = &record.kind + { + return Some((target.clone(), *port)); + } + } + None + }) { + if let Some(ip) = all.iter().find_map(|record| { + if record.name == target { + if let RecordKind::A(ip) = record.kind { + return Some(ip); + } + } + None + }) { + let mut connections = discovered.write().await; + connections.insert(id, (ip.to_string(), port)); + } + } + } + } + }); + Ok(()) +} + +struct DevicePairingData { + accessory_pairing_id: String, + accessory_ip: String, + accessory_port: usize, + accessory_ltpk: [u8; 32], + ios_pairing_id: String, + ios_device_ltsk: [u8; 32], +} + +pub struct DeviceConnection { + pub accessories: HashMap, + discovered: MdnsDiscoveredList, + pairing_data: DevicePairingData, + socket: Option, +} + +impl DeviceConnection { + async fn connect(&mut self) -> Result<(), HomekitError> { + if let Some(mut socket) = self.socket.take() { + socket.disconnect().await?; + } let key = EphemeralSecret::random(); let pubkey = PublicKey::from(&key); let mut socket = AccessorySocket::new( - &self.accessory_pairing_id, - &self.accessory_ip, - self.accessory_port, + &self.pairing_data.accessory_pairing_id, + &self.pairing_data.accessory_ip, + self.pairing_data.accessory_port, + &self.discovered, ) .await?; @@ -113,12 +218,13 @@ impl DevicePairingData { .get(&TlvType::Signature.into()) .ok_or(HomekitError::TlvNotFound)?; - if accessory_identifier != self.accessory_pairing_id.as_bytes() { + if accessory_identifier != self.pairing_data.accessory_pairing_id.as_bytes() { return Err(HomekitError::Auth); } // 5. Get accessory LTPK - let accessory_ltpk = ed25519_dalek::VerifyingKey::from_bytes(&self.accessory_ltpk)?; + let accessory_ltpk = + ed25519_dalek::VerifyingKey::from_bytes(&self.pairing_data.accessory_ltpk)?; let mut accessory_info = accessory_pubkey_bytes.to_vec(); accessory_info.extend_from_slice(accessory_identifier); accessory_info.extend_from_slice(pubkey.as_bytes()); @@ -135,19 +241,19 @@ impl DevicePairingData { // 7. Construct iOSDeviceInfo let ios_device_info = { let mut buf = pubkey.as_bytes().to_vec(); - buf.extend_from_slice(self.ios_pairing_id.as_bytes()); + buf.extend_from_slice(self.pairing_data.ios_pairing_id.as_bytes()); buf.extend_from_slice(response_pubkey); buf }; // 8. Use Ed25519 togenerate iOSDeviceSignature by signing iOSDeviceInfo with its long-term secret key, iOSDeviceLTSK - let signing_key = ed25519_dalek::SigningKey::from_bytes(&self.ios_device_ltsk); + let signing_key = ed25519_dalek::SigningKey::from_bytes(&self.pairing_data.ios_device_ltsk); let signature = signing_key.sign(&ios_device_info); // 9. Construct sub-TLV let mut encrypted_tlv = ([ ( TlvType::Identifier.into(), - self.ios_pairing_id.encode_value(), + self.pairing_data.ios_pairing_id.encode_value(), ), ( TlvType::Signature.into(), @@ -234,58 +340,59 @@ impl DevicePairingData { )?; socket.set_encryption(controller_to_accessory_key, accessory_to_controller_key); - - let mut connected_device = ConnectedDevice { - accessories: socket.get_accessories().await?, - socket, - }; - connected_device.characteristics_request(true).await?; - - Ok(connected_device) + self.accessories = socket.get_accessories().await?; + self.socket = Some(socket); + Ok(()) } -} -pub struct ConnectedDevice { - pub accessories: HashMap, - socket: AccessorySocket, -} - -impl ConnectedDevice { pub async fn update_characteristics(&mut self) -> Result<(), HomekitError> { self.characteristics_request(true).await } async fn characteristics_request(&mut self, additional_data: bool) -> Result<(), HomekitError> { - for (aid, data) in &mut self.accessories { - for service in data.services.values_mut() { - let characteristic_ids = service - .characteristics - .keys() - .map(|k| format!("{aid}.{k}")) - .collect::>(); - let characteristics = self - .socket - .get_characteristics(&characteristic_ids, additional_data) - .await?; - for (cid, c) in &characteristics { - if c.characteristic_type == CharacteristicType::Name { - if let Some(Data::String(name)) = &c.value { - service.name = Some(name.clone()); + 'outer: loop { + if self.socket.is_none() { + self.connect().await?; + } + if let Some(socket) = &mut self.socket { + for (aid, data) in &mut self.accessories { + for service in data.services.values_mut() { + let characteristic_ids = service + .characteristics + .keys() + .map(|k| format!("{aid}.{k}")) + .collect::>(); + let characteristics = match socket + .get_characteristics(&characteristic_ids, additional_data) + .await + { + Ok(val) => val, + Err(_e) => { + continue 'outer; + } + }; + for (cid, c) in &characteristics { + if c.characteristic_type == CharacteristicType::Name { + if let Some(Data::String(name)) = &c.value { + service.name = Some(name.clone()); + } + } + if let Some(prev) = service.characteristics.get_mut(cid) { + prev.update_from(c); + } } } - if let Some(prev) = service.characteristics.get_mut(cid) { - prev.update_from(c); - } } } + return Ok(()); } - - Ok(()) } } #[derive(Debug, Error)] pub enum HomekitError { + #[error("connection")] + Connection(#[from] ConnectionError), #[error("io")] Io(#[from] std::io::Error), #[error("serde_json")] @@ -318,8 +425,6 @@ pub enum HomekitError { InvalidUri(#[from] http::uri::InvalidUri), #[error("parsing response")] ResponseParse(#[from] httparse::Error), - #[error("http")] - Http, #[error("something else")] SomethingElse(String), #[error("addr parse")] @@ -328,12 +433,22 @@ pub enum HomekitError { TlvDeviceError(TlvError), #[error("parsing utf-8")] Utf8(#[from] std::string::FromUtf8Error), - #[error("discovery")] - Discovery(#[from] DiscoveryError), #[error("device not found")] DeviceNotFound, #[error("timeout")] Timeout(#[from] tokio::time::error::Elapsed), + #[error("discovery")] + Discovery(#[from] DiscoveryError), +} + +#[derive(Debug, Error)] +pub enum ConnectionError { + #[error("http")] + Http, + #[error("discovery")] + Discovery(#[from] DiscoveryError), + #[error("io")] + Io(#[from] std::io::Error), } impl From for HomekitError { diff --git a/homekit-controller/src/pairing_data/mod.rs b/homekit-controller/src/pairing_data/mod.rs index 8767f5b..9f62121 100644 --- a/homekit-controller/src/pairing_data/mod.rs +++ b/homekit-controller/src/pairing_data/mod.rs @@ -8,7 +8,7 @@ pub use characteristics::{CharacteristicType, ServiceType}; mod characteristics; #[derive(Deserialize, Debug, Clone)] -pub struct DevicePairingData { +pub struct PythonPairingData { #[serde(rename = "AccessoryPairingID")] pub accessory_pairing_id: String, #[serde(with = "hex::serde", rename = "AccessoryLTPK")] diff --git a/homekit-exporter/src/main.rs b/homekit-exporter/src/main.rs index 4e58f28..e3de484 100644 --- a/homekit-exporter/src/main.rs +++ b/homekit-exporter/src/main.rs @@ -1,11 +1,12 @@ #[macro_use] extern crate rocket; -use std::{collections::HashMap, path::PathBuf, time::Duration}; +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; use clap::Parser; -use homekit_controller::{ConnectedDevice, HomekitError, ServiceType}; +use homekit_controller::{spawn_discover_thread, DeviceConnection, HomekitError, ServiceType}; use server::launch; +use tokio::sync::RwLock; mod server; @@ -40,14 +41,17 @@ async fn rocket() -> rocket::Rocket { } } -async fn init(pairing_data: PathBuf) -> Result, HomekitError> { +async fn init(pairing_data: PathBuf) -> Result, HomekitError> { + let discovered = Arc::new(RwLock::new(HashMap::new())); + spawn_discover_thread(discovered.clone())?; + tokio::time::sleep(Duration::from_secs(1)).await; if pairing_data.is_file() { let devices = homekit_controller::load(pairing_data)?; let mut connected_devices = HashMap::new(); for (k, v) in devices { let mut num = 0; let connected = loop { - if let Ok(v) = v.connect().await { + if let Ok(v) = v.connect(&discovered).await { break Some(v); } num += 1; diff --git a/homekit-exporter/src/server.rs b/homekit-exporter/src/server.rs index cb247b4..094ef98 100644 --- a/homekit-exporter/src/server.rs +++ b/homekit-exporter/src/server.rs @@ -1,18 +1,18 @@ -use homekit_controller::{ConnectedDevice, Data}; +use homekit_controller::{Data, DeviceConnection}; use rocket::State; use std::{collections::HashMap, ops::DerefMut}; use tokio::sync::Mutex; use crate::SENSORS; -pub fn launch(paired: HashMap) -> rocket::Rocket { +pub fn launch(paired: HashMap) -> rocket::Rocket { rocket::build() .manage(Mutex::new(paired)) .mount("/", routes![metrics]) } #[get("/metrics")] -pub async fn metrics(state: &State>>) -> Option { +pub async fn metrics(state: &State>>) -> Option { let mut s = String::new(); let mut state = state.lock().await; let mut shown_types = Vec::new();