diff --git a/Cargo.toml b/Cargo.toml index c0535a5..29f2ca8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ lto = "fat" serde = {version = "1.0", features = ["derive"]} bincode = "1.3" warp = {version = "0.2", default-features = false} -tokio = {version = "0.3", features = ["macros", "net", "time", "stream", "rt-threaded", "fs", "sync", "blocking", "signal"]} +tokio = {version = "0.3", features = ["macros", "io-util", "net", "time", "stream", "rt-multi-thread", "fs", "sync", "signal"]} tokio-util = {version = "0.4", features = ["compat"]} binascii = "0.1" toml = "0.5" @@ -23,7 +23,3 @@ serde_json = "1.0" futures = "0.3" async-compression = {version = "0.3", features = ["bzip2", "futures-bufread", "futures-write"]} chrono = "0.4" - -[patch.crates-io] -tokio = {git = "https://github.com/tokio-rs/tokio", rev = "3fd043931e6d37f211e682980edc6e12e9d4fc54", features = ["macros", "net", "time", "stream", "rt-threaded", "fs", "sync", "blocking", "signal"]} -tokio-util = {git = "https://github.com/tokio-rs/tokio", rev = "3fd043931e6d37f211e682980edc6e12e9d4fc54", features = ["compat"]} diff --git a/src/server.rs b/src/server.rs index 8583c8e..3444c6d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -107,8 +107,7 @@ struct UDPScrapeResponseEntry { } pub struct UDPTracker { - srv_send: tokio::net::udp::SendHalf, - srv_recv: Option, + srv: tokio::net::UdpSocket, tracker: std::sync::Arc, config: Arc, } @@ -119,12 +118,10 @@ impl UDPTracker { ) -> Result { let cfg = config.clone(); - let server = UdpSocket::bind(cfg.get_udp_config().get_address()).await?; - let (srv_recv, srv_send) = server.split(); + let srv = UdpSocket::bind(cfg.get_udp_config().get_address()).await?; Ok(UDPTracker { - srv_send, - srv_recv: Some(srv_recv), + srv, tracker, config: cfg, }) @@ -360,12 +357,13 @@ impl UDPTracker { } async fn send_packet(&self, remote_addr: &SocketAddr, payload: &[u8]) -> Result { - tokio::future::poll_fn(|cx| self.srv_send.as_ref().poll_send_to(cx, payload, remote_addr)) - .await - .map_err(|e| { - debug!("failed to send a packet: {}", e); - e - }) + match self.srv.send_to(payload, remote_addr).await { + Err(err) => { + debug!("failed to send a packet: {}", err); + Err(err) + }, + Ok(sz) => Ok(sz), + } } async fn send_error(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, error_msg: &str) { @@ -383,13 +381,12 @@ impl UDPTracker { } } - pub async fn accept_packets(mut self) -> Result<(), std::io::Error> { - let mut recv = self.srv_recv.take().unwrap(); + pub async fn accept_packets(self) -> Result<(), std::io::Error> { let tracker = Arc::new(self); loop { let mut packet = vec![0u8; MAX_PACKET_SIZE]; - let (size, remote_address) = recv.recv_from(packet.as_mut_slice()).await?; + let (size, remote_address) = tracker.srv.recv_from(packet.as_mut_slice()).await?; let tracker = tracker.clone(); tokio::spawn(async move {