discovery and reconnection actually works!!!

This commit is contained in:
Alex Janka 2024-02-25 11:35:58 +11:00
parent 878521fcfb
commit 8b738777a8
5 changed files with 213 additions and 131 deletions

View file

@ -3,20 +3,18 @@ use std::{collections::HashMap, time::Duration};
use chacha20poly1305::{
aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce,
};
use futures_util::{pin_mut, StreamExt};
use http::{Method, Request};
use mdns::RecordKind;
use thiserror::Error;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
time::timeout,
};
use crate::{
pairing_data::{Accessory, ServiceCharacteristic},
spawn_discover_thread,
tlv8::TlvEncode,
HomekitError,
ConnectionError, HomekitError, MdnsDiscoveredList,
};
pub(super) struct AccessorySocket {
@ -44,82 +42,41 @@ impl SocketEncryption {
}
}
const SERVICE_NAME: &str = "_hap._tcp.local";
pub async fn discover(
duration_seconds: u64,
async fn reconnect(
pairing_id: &str,
) -> Result<(String, u16), DiscoveryError> {
let stream = mdns::discover::all(SERVICE_NAME, Duration::from_secs(1))?.listen();
discovered: &MdnsDiscoveredList,
) -> Result<TcpStream, ConnectionError> {
let (hostname, port) = discovered
.read()
.await
.get(pairing_id)
.ok_or(ConnectionError::Discovery(DiscoveryError::NotFound))?
.clone();
pin_mut!(stream);
while let Ok(Some(Ok(response))) =
timeout(Duration::from_secs(duration_seconds), stream.next()).await
{
if let Some(name) = response.additional.iter().find_map(|record| {
if let RecordKind::TXT(v) = &record.kind {
if v.contains(&format!("id={pairing_id}")) {
return Some(record.name.clone());
}
}
None
}) {
log::info!("got name {name}");
if let Some((target, port)) = response.additional.iter().find_map(|record| {
if record.name == name {
if let RecordKind::SRV {
priority: _,
weight: _,
port,
target,
} = &record.kind
{
return Some((target.clone(), *port));
}
}
None
}) {
if let Some(ip) = response.additional.iter().find_map(|record| {
if record.name == target {
if let RecordKind::A(ip) = record.kind {
return Some(ip);
}
}
None
}) {
return Ok((ip.to_string(), port));
}
}
let socket = match TcpStream::connect(format!("{hostname}:{port}")).await {
Ok(socket) => socket,
Err(_) => {
spawn_discover_thread(discovered.clone())?;
tokio::time::sleep(Duration::from_secs(1)).await;
TcpStream::connect(format!("{hostname}:{port}")).await?
}
// if let Some(addr) = addr {
// println!("found cast device at {}", addr);
// } else {
// println!("cast device does not advertise address");
// }
}
todo!()
}
async fn reconnect(pairing_id: &str) -> Result<TcpStream, HomekitError> {
log::warn!("error connecting to device...");
log::warn!("trying to find {pairing_id}'s ip/port via bonjour/mdns...");
let (hostname, port) = discover(20, pairing_id).await?;
log::info!("successfully found device at {hostname}:{port}");
let socket = TcpStream::connect(format!("{hostname}:{port}")).await?;
log::info!(" ...and connected!");
};
Ok(socket)
}
impl AccessorySocket {
pub async fn new(pairing_id: &str, ip: &str, port: usize) -> Result<Self, HomekitError> {
pub async fn new(
pairing_id: &str,
ip: &str,
port: usize,
discovered: &MdnsDiscoveredList,
) -> Result<Self, ConnectionError> {
let socket = tokio::select! {
stream = TcpStream::connect(format!("{ip}:{port}")) => match stream {
Ok(v) => v,
Err(_) => reconnect(pairing_id).await?
Err(_) => reconnect(pairing_id,discovered).await?
},
_ = tokio::time::sleep(Duration::from_secs(1)) => reconnect(pairing_id).await?
_ = tokio::time::sleep(Duration::from_secs(1)) => reconnect(pairing_id,discovered).await?
};
Ok(Self {
@ -268,7 +225,7 @@ impl AccessorySocket {
let header_size = match result {
httparse::Status::Complete(header_size) => Ok(header_size),
httparse::Status::Partial => Err(HomekitError::Http),
httparse::Status::Partial => Err(ConnectionError::Http),
}?;
let mut packet = packet[header_size..].to_vec();
@ -335,7 +292,7 @@ impl AccessorySocket {
while read_num == 0 {
if tries > 20 {
log::error!("unsuccessfully tried to reconnect");
return Err(HomekitError::Http);
return Err(ConnectionError::Http.into());
}
tries += 1;
log::info!("read 0 bytes - about to reconnect");
@ -368,10 +325,16 @@ impl AccessorySocket {
Ok(buf[..read_num].to_vec())
}
}
pub(super) async fn disconnect(&mut self) -> Result<(), std::io::Error> {
self.socket.shutdown().await
}
}
#[derive(Debug, Error)]
pub enum DiscoveryError {
#[error("mdns")]
Mdns(#[from] mdns::Error),
#[error("not found")]
NotFound,
}

View file

@ -2,15 +2,18 @@ use chacha20poly1305::{
aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce,
};
use ed25519_dalek::{Signer, Verifier};
use futures_util::{pin_mut, StreamExt};
use hkdf::Hkdf;
use homekit_http::DiscoveryError;
use mdns::RecordKind;
use sha2::Sha512;
use std::{collections::HashMap, path::PathBuf};
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use thiserror::Error;
use tlv8::{HomekitState, TlvEncode, TlvError, TlvType};
use tokio::sync::RwLock;
use x25519_dalek::{EphemeralSecret, PublicKey};
use pairing_data::{Accessory, DevicePairingData};
use pairing_data::{Accessory, PythonPairingData};
pub use crate::pairing_data::{CharacteristicType, Data, ServiceType};
use crate::{
@ -22,21 +25,123 @@ mod homekit_http;
mod pairing_data;
mod tlv8;
pub fn load(pairing_data: PathBuf) -> Result<HashMap<String, DevicePairingData>, HomekitError> {
pub fn load(pairing_data: PathBuf) -> Result<HashMap<String, PythonPairingData>, HomekitError> {
Ok(serde_json::from_str(&std::fs::read_to_string(
pairing_data,
)?)?)
}
impl DevicePairingData {
pub async fn connect(&self) -> Result<ConnectedDevice, HomekitError> {
impl PythonPairingData {
pub async fn connect(
&self,
discovered: &MdnsDiscoveredList,
) -> Result<DeviceConnection, HomekitError> {
let mut connected_device: DeviceConnection = self.to_connection(discovered.clone());
connected_device.connect().await?;
connected_device.characteristics_request(true).await?;
Ok(connected_device)
}
fn to_connection(&self, discovered: MdnsDiscoveredList) -> DeviceConnection {
DeviceConnection {
accessories: Default::default(),
discovered,
pairing_data: DevicePairingData {
accessory_pairing_id: self.accessory_pairing_id.clone(),
accessory_ip: self.accessory_ip.clone(),
accessory_port: self.accessory_port,
accessory_ltpk: self.accessory_ltpk,
ios_pairing_id: self.ios_pairing_id.clone(),
ios_device_ltsk: self.ios_device_ltsk,
},
socket: None,
}
}
}
pub type MdnsDiscoveredList = Arc<RwLock<HashMap<String, (String, u16)>>>;
pub fn spawn_discover_thread(discovered: MdnsDiscoveredList) -> Result<(), DiscoveryError> {
let stream = mdns::discover::all("_hap._tcp.local", Duration::from_secs(1))?.listen();
tokio::task::spawn(async move {
pin_mut!(stream);
while let Some(Ok(response)) = stream.next().await {
let all = response
.answers
.iter()
.chain(response.additional.iter())
.collect::<Vec<_>>();
if let Some((name, id)) = all.iter().find_map(|record| {
if let RecordKind::TXT(v) = &record.kind {
if let Some(id_string) = v.iter().find(|v| v.contains("id=")) {
let id = id_string[3..].to_string();
return Some((record.name.clone(), id));
}
}
None
}) {
if let Some((target, port)) = all.iter().find_map(|record| {
if record.name == name {
if let RecordKind::SRV {
priority: _,
weight: _,
port,
target,
} = &record.kind
{
return Some((target.clone(), *port));
}
}
None
}) {
if let Some(ip) = all.iter().find_map(|record| {
if record.name == target {
if let RecordKind::A(ip) = record.kind {
return Some(ip);
}
}
None
}) {
let mut connections = discovered.write().await;
connections.insert(id, (ip.to_string(), port));
}
}
}
}
});
Ok(())
}
struct DevicePairingData {
accessory_pairing_id: String,
accessory_ip: String,
accessory_port: usize,
accessory_ltpk: [u8; 32],
ios_pairing_id: String,
ios_device_ltsk: [u8; 32],
}
pub struct DeviceConnection {
pub accessories: HashMap<usize, Accessory>,
discovered: MdnsDiscoveredList,
pairing_data: DevicePairingData,
socket: Option<AccessorySocket>,
}
impl DeviceConnection {
async fn connect(&mut self) -> Result<(), HomekitError> {
if let Some(mut socket) = self.socket.take() {
socket.disconnect().await?;
}
let key = EphemeralSecret::random();
let pubkey = PublicKey::from(&key);
let mut socket = AccessorySocket::new(
&self.accessory_pairing_id,
&self.accessory_ip,
self.accessory_port,
&self.pairing_data.accessory_pairing_id,
&self.pairing_data.accessory_ip,
self.pairing_data.accessory_port,
&self.discovered,
)
.await?;
@ -113,12 +218,13 @@ impl DevicePairingData {
.get(&TlvType::Signature.into())
.ok_or(HomekitError::TlvNotFound)?;
if accessory_identifier != self.accessory_pairing_id.as_bytes() {
if accessory_identifier != self.pairing_data.accessory_pairing_id.as_bytes() {
return Err(HomekitError::Auth);
}
// 5. Get accessory LTPK
let accessory_ltpk = ed25519_dalek::VerifyingKey::from_bytes(&self.accessory_ltpk)?;
let accessory_ltpk =
ed25519_dalek::VerifyingKey::from_bytes(&self.pairing_data.accessory_ltpk)?;
let mut accessory_info = accessory_pubkey_bytes.to_vec();
accessory_info.extend_from_slice(accessory_identifier);
accessory_info.extend_from_slice(pubkey.as_bytes());
@ -135,19 +241,19 @@ impl DevicePairingData {
// 7. Construct iOSDeviceInfo
let ios_device_info = {
let mut buf = pubkey.as_bytes().to_vec();
buf.extend_from_slice(self.ios_pairing_id.as_bytes());
buf.extend_from_slice(self.pairing_data.ios_pairing_id.as_bytes());
buf.extend_from_slice(response_pubkey);
buf
};
// 8. Use Ed25519 togenerate iOSDeviceSignature by signing iOSDeviceInfo with its long-term secret key, iOSDeviceLTSK
let signing_key = ed25519_dalek::SigningKey::from_bytes(&self.ios_device_ltsk);
let signing_key = ed25519_dalek::SigningKey::from_bytes(&self.pairing_data.ios_device_ltsk);
let signature = signing_key.sign(&ios_device_info);
// 9. Construct sub-TLV
let mut encrypted_tlv = ([
(
TlvType::Identifier.into(),
self.ios_pairing_id.encode_value(),
self.pairing_data.ios_pairing_id.encode_value(),
),
(
TlvType::Signature.into(),
@ -234,58 +340,59 @@ impl DevicePairingData {
)?;
socket.set_encryption(controller_to_accessory_key, accessory_to_controller_key);
let mut connected_device = ConnectedDevice {
accessories: socket.get_accessories().await?,
socket,
};
connected_device.characteristics_request(true).await?;
Ok(connected_device)
self.accessories = socket.get_accessories().await?;
self.socket = Some(socket);
Ok(())
}
}
pub struct ConnectedDevice {
pub accessories: HashMap<usize, Accessory>,
socket: AccessorySocket,
}
impl ConnectedDevice {
pub async fn update_characteristics(&mut self) -> Result<(), HomekitError> {
self.characteristics_request(true).await
}
async fn characteristics_request(&mut self, additional_data: bool) -> Result<(), HomekitError> {
for (aid, data) in &mut self.accessories {
for service in data.services.values_mut() {
let characteristic_ids = service
.characteristics
.keys()
.map(|k| format!("{aid}.{k}"))
.collect::<Vec<_>>();
let characteristics = self
.socket
.get_characteristics(&characteristic_ids, additional_data)
.await?;
for (cid, c) in &characteristics {
if c.characteristic_type == CharacteristicType::Name {
if let Some(Data::String(name)) = &c.value {
service.name = Some(name.clone());
'outer: loop {
if self.socket.is_none() {
self.connect().await?;
}
if let Some(socket) = &mut self.socket {
for (aid, data) in &mut self.accessories {
for service in data.services.values_mut() {
let characteristic_ids = service
.characteristics
.keys()
.map(|k| format!("{aid}.{k}"))
.collect::<Vec<_>>();
let characteristics = match socket
.get_characteristics(&characteristic_ids, additional_data)
.await
{
Ok(val) => val,
Err(_e) => {
continue 'outer;
}
};
for (cid, c) in &characteristics {
if c.characteristic_type == CharacteristicType::Name {
if let Some(Data::String(name)) = &c.value {
service.name = Some(name.clone());
}
}
if let Some(prev) = service.characteristics.get_mut(cid) {
prev.update_from(c);
}
}
}
if let Some(prev) = service.characteristics.get_mut(cid) {
prev.update_from(c);
}
}
}
return Ok(());
}
Ok(())
}
}
#[derive(Debug, Error)]
pub enum HomekitError {
#[error("connection")]
Connection(#[from] ConnectionError),
#[error("io")]
Io(#[from] std::io::Error),
#[error("serde_json")]
@ -318,8 +425,6 @@ pub enum HomekitError {
InvalidUri(#[from] http::uri::InvalidUri),
#[error("parsing response")]
ResponseParse(#[from] httparse::Error),
#[error("http")]
Http,
#[error("something else")]
SomethingElse(String),
#[error("addr parse")]
@ -328,12 +433,22 @@ pub enum HomekitError {
TlvDeviceError(TlvError),
#[error("parsing utf-8")]
Utf8(#[from] std::string::FromUtf8Error),
#[error("discovery")]
Discovery(#[from] DiscoveryError),
#[error("device not found")]
DeviceNotFound,
#[error("timeout")]
Timeout(#[from] tokio::time::error::Elapsed),
#[error("discovery")]
Discovery(#[from] DiscoveryError),
}
#[derive(Debug, Error)]
pub enum ConnectionError {
#[error("http")]
Http,
#[error("discovery")]
Discovery(#[from] DiscoveryError),
#[error("io")]
Io(#[from] std::io::Error),
}
impl From<TlvError> for HomekitError {

View file

@ -8,7 +8,7 @@ pub use characteristics::{CharacteristicType, ServiceType};
mod characteristics;
#[derive(Deserialize, Debug, Clone)]
pub struct DevicePairingData {
pub struct PythonPairingData {
#[serde(rename = "AccessoryPairingID")]
pub accessory_pairing_id: String,
#[serde(with = "hex::serde", rename = "AccessoryLTPK")]

View file

@ -1,11 +1,12 @@
#[macro_use]
extern crate rocket;
use std::{collections::HashMap, path::PathBuf, time::Duration};
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use clap::Parser;
use homekit_controller::{ConnectedDevice, HomekitError, ServiceType};
use homekit_controller::{spawn_discover_thread, DeviceConnection, HomekitError, ServiceType};
use server::launch;
use tokio::sync::RwLock;
mod server;
@ -40,14 +41,17 @@ async fn rocket() -> rocket::Rocket<rocket::Build> {
}
}
async fn init(pairing_data: PathBuf) -> Result<HashMap<String, ConnectedDevice>, HomekitError> {
async fn init(pairing_data: PathBuf) -> Result<HashMap<String, DeviceConnection>, HomekitError> {
let discovered = Arc::new(RwLock::new(HashMap::new()));
spawn_discover_thread(discovered.clone())?;
tokio::time::sleep(Duration::from_secs(1)).await;
if pairing_data.is_file() {
let devices = homekit_controller::load(pairing_data)?;
let mut connected_devices = HashMap::new();
for (k, v) in devices {
let mut num = 0;
let connected = loop {
if let Ok(v) = v.connect().await {
if let Ok(v) = v.connect(&discovered).await {
break Some(v);
}
num += 1;

View file

@ -1,18 +1,18 @@
use homekit_controller::{ConnectedDevice, Data};
use homekit_controller::{Data, DeviceConnection};
use rocket::State;
use std::{collections::HashMap, ops::DerefMut};
use tokio::sync::Mutex;
use crate::SENSORS;
pub fn launch(paired: HashMap<String, ConnectedDevice>) -> rocket::Rocket<rocket::Build> {
pub fn launch(paired: HashMap<String, DeviceConnection>) -> rocket::Rocket<rocket::Build> {
rocket::build()
.manage(Mutex::new(paired))
.mount("/", routes![metrics])
}
#[get("/metrics")]
pub async fn metrics(state: &State<Mutex<HashMap<String, ConnectedDevice>>>) -> Option<String> {
pub async fn metrics(state: &State<Mutex<HashMap<String, DeviceConnection>>>) -> Option<String> {
let mut s = String::new();
let mut state = state.lock().await;
let mut shown_types = Vec::new();