make everything async
parent 71ea023e5b
commit d030871337

Cargo.lock (generated, 1541 lines changed): file diff suppressed because it is too large.

Cargo.toml
@@ -11,15 +11,11 @@ lto = "fat"
 [dependencies]
 serde = {version = "1.0", features = ["derive"]}
 bincode = "1.2"
-actix-web = "0.7"
-actix-net = "0.2"
+warp = {version = "0.2", default-features = false}
+tokio = {version = "0.2", features = ["macros", "net", "rt-threaded", "fs", "sync", "blocking", "signal"]}
 binascii = "0.1"
 toml = "0.5"
 clap = "2.33"
 log = "0.4"
 fern = "0.6"
-num_cpus = "1.13"
 serde_json = "1.0"
-bzip2 = "0.3"
-futures = "0.1"
-lazy_static = "1.4"
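For orientation (not part of this commit): swapping actix-web/actix-net for warp and tokio means the binary now starts from a tokio runtime entry point. A minimal sketch of that shape, assuming warp 0.2+ and a tokio with the "macros" feature enabled:

// Illustrative only — not from this commit.
use warp::Filter;

#[tokio::main]
async fn main() {
    // a single route returning a static body
    let hello = warp::path::end().map(|| "udpt is running");

    // warp::serve(...).run(...) drives the HTTP server on the tokio runtime
    warp::serve(hello).run(([127, 0, 0, 1], 1212)).await;
}

src/config.rs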
@@ -1,8 +1,8 @@
+pub use crate::tracker::TrackerMode;
+use serde::Deserialize;
 use std;
 use std::collections::HashMap;
 use toml;
-pub use crate::tracker::TrackerMode;
-use serde::Deserialize;
 
 #[derive(Deserialize)]
 pub struct UDPConfig {
@@ -70,10 +70,12 @@ impl Configuration {
 pub fn load_file(path: &str) -> Result<Configuration, ConfigError> {
 match std::fs::read(path) {
 Err(e) => Err(ConfigError::IOError(e)),
-Ok(data) => match Self::load(data.as_slice()) {
-Ok(cfg) => Ok(cfg),
-Err(e) => Err(ConfigError::ParseError(e)),
-},
+Ok(data) => {
+match Self::load(data.as_slice()) {
+Ok(cfg) => Ok(cfg),
+Err(e) => Err(ConfigError::ParseError(e)),
+}
+}
 }
 }
 
@@ -89,16 +91,16 @@ impl Configuration {
 &self.log_level
 }
 
-pub fn get_http_config(&self) -> &Option<HTTPConfig> {
-&self.http
+pub fn get_http_config(&self) -> Option<&HTTPConfig> {
+self.http.as_ref()
 }
 
 pub fn get_db_path(&self) -> &Option<String> {
 &self.db_path
 }
 
-pub fn get_cleanup_interval(&self) -> &Option<u64> {
-&self.cleanup_interval
+pub fn get_cleanup_interval(&self) -> Option<u64> {
+self.cleanup_interval
 }
 }
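Side note (not from the commit): the getter changes above are the usual Option::as_ref / return-by-value idioms, which make call sites like cfg.get_cleanup_interval().unwrap_or(600) possible. A tiny sketch with a hypothetical Config type:

// Illustrative only — Config and its field are made up for the example.
struct Config {
    http: Option<String>,
}

impl Config {
    fn get_http_config(&self) -> Option<&String> {
        // borrow the contents instead of exposing &Option<...>
        self.http.as_ref()
    }
}

fn main() {
    let cfg = Config { http: Some("0.0.0.0:1212".to_string()) };
    if let Some(addr) = cfg.get_http_config() {
        println!("http bound to {}", addr);
    }
}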
src/main.rs (181 lines changed)

@@ -1,10 +1,8 @@
 #![forbid(unsafe_code)]
 
 use clap;
-use log::{trace, warn, info, debug, error};
 use fern;
-use num_cpus;
-use lazy_static::lazy_static;
+use log::{error, info, trace, warn};
 
 mod config;
 mod server;
@@ -15,14 +13,11 @@ mod webserver;
 use config::Configuration;
 use std::process::exit;
 
-lazy_static!{
-static ref term_mutex: std::sync::Arc<std::sync::atomic::AtomicBool> = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
-}
-
 fn setup_logging(cfg: &Configuration) {
 let log_level = match cfg.get_log_level() {
 None => log::LevelFilter::Info,
-Some(level) => match level.as_str() {
+Some(level) => {
+match level.as_str() {
 "off" => log::LevelFilter::Off,
 "trace" => log::LevelFilter::Trace,
 "debug" => log::LevelFilter::Debug,
@@ -33,7 +28,8 @@ fn setup_logging(cfg: &Configuration) {
 eprintln!("udpt: unknown log level encountered '{}'", level.as_str());
 exit(-1);
 }
-},
+}
+}
 };
 
 if let Err(err) = fern::Dispatch::new()
@@ -59,11 +55,8 @@ fn setup_logging(cfg: &Configuration) {
 info!("logging initialized.");
 }
 
-fn signal_termination() {
-term_mutex.store(true, std::sync::atomic::Ordering::Relaxed);
-}
-
-fn main() {
+#[tokio::main]
+async fn main() {
 let parser = clap::App::new(env!("CARGO_PKG_NAME"))
 .about(env!("CARGO_PKG_DESCRIPTION"))
 .author(env!("CARGO_PKG_AUTHORS"))
@@ -95,17 +88,19 @@ fn main() {
 if !file_path.exists() {
 warn!("database file \"{}\" doesn't exist.", path);
 tracker::TorrentTracker::new(cfg.get_mode().clone())
-}
-else {
-let mut input_file = match std::fs::File::open(file_path) {
+} else {
+let mut input_file = match tokio::fs::File::open(file_path).await {
 Ok(v) => v,
 Err(err) => {
 error!("failed to open \"{}\". error: {}", path.as_str(), err);
 panic!("error opening file. check logs.");
 }
 };
-match tracker::TorrentTracker::load_database(cfg.get_mode().clone(), &mut input_file) {
-Ok(v) => v,
+match tracker::TorrentTracker::load_database(cfg.get_mode().clone(), &mut input_file).await {
+Ok(v) => {
+info!("database loaded.");
+v
+}
 Err(err) => {
 error!("failed to load database. error: {}", err);
 panic!("failed to load database. check logs.");
@@ -116,111 +111,67 @@ fn main() {
 None => tracker::TorrentTracker::new(cfg.get_mode().clone()),
 };
 
-let mut threads = Vec::new();
-
 let tracker = std::sync::Arc::new(tracker_obj);
 
-let http_server = if cfg.get_http_config().is_some() {
-let http_tracker_ref = tracker.clone();
-let cfg_ref = cfg.clone();
-
-Some(webserver::WebServer::new(http_tracker_ref, cfg_ref))
-} else {
-None
-};
-
-let udp_server = std::sync::Arc::new(server::UDPTracker::new(cfg.clone(), tracker.clone()).unwrap());
+if cfg.get_http_config().is_some() {
+let https_tracker = tracker.clone();
+let http_cfg = cfg.clone();
+
+info!("Starting http server");
+tokio::spawn(async move {
+let http_cfg = http_cfg.get_http_config().unwrap();
+let bind_addr = http_cfg.get_address();
+let tokens = http_cfg.get_access_tokens();
+
+let server = webserver::build_server(https_tracker, tokens.clone());
+server.bind(bind_addr.parse::<std::net::SocketAddr>().unwrap()).await;
+});
+}
+
+let mut udp_server = server::UDPTracker::new(cfg.clone(), tracker.clone())
+.await
+.expect("failed to bind udp socket");
 
 trace!("Waiting for UDP packets");
-let logical_cpus = num_cpus::get();
-for i in 0..logical_cpus {
-debug!("starting thread {}/{}", i + 1, logical_cpus);
-let server_handle = udp_server.clone();
-let thread_term_ref = term_mutex.clone();
-threads.push(std::thread::spawn(move || loop {
-match server_handle.accept_packet() {
-Err(e) => {
-if thread_term_ref.load(std::sync::atomic::Ordering::Relaxed) == true {
-debug!("Thread terminating...");
-break;
-}
-match e.kind() {
-std::io::ErrorKind::TimedOut => {},
-std::io::ErrorKind::WouldBlock => {},
-_ => {
-error!("Failed to process packet. {}", e);
-}
-}
-}
-Ok(_) => {}
-}
-}));
-}
-
-match cfg.get_db_path() {
-Some(db_path) => {
-let db_p = db_path.clone();
-let tracker_clone = tracker.clone();
-let cleanup_interval = match *cfg.get_cleanup_interval() {
-Some(v) => v,
-None => 10 * 60,
-};
-
-let thread_term_mutex = term_mutex.clone();
-threads.push(std::thread::spawn(move || {
-let timeout = std::time::Duration::new(cleanup_interval, 0);
-
-let timeout_start = std::time::Instant::now();
-let mut timeout_remaining = timeout;
+let udp_server = tokio::spawn(async move {
 loop {
-std::thread::park_timeout(std::time::Duration::new(cleanup_interval, 0));
-
-if thread_term_mutex.load(std::sync::atomic::Ordering::Relaxed) {
-debug!("Maintenance thread terminating.");
-break;
-}
-
-let elapsed = std::time::Instant::now() - timeout_start;
-if elapsed < timeout_remaining {
-timeout_remaining = timeout - elapsed;
-continue;
-}
-else {
-timeout_remaining = timeout;
-}
-
-debug!("periodically saving database.");
-tracker_clone.periodic_task(db_p.as_str());
-debug!("database saved.");
-}
-}));
-},
-None => {}
-}
-
-loop {
-if term_mutex.load(std::sync::atomic::Ordering::Relaxed) {
-// termination signaled. start cleanup.
-break;
-}
-std::thread::sleep(std::time::Duration::from_secs(1));
-}
-
-match http_server {
-Some(v) => v.shutdown(),
-None => {},
-};
-
-while !threads.is_empty() {
-if let Some(thread) = threads.pop() {
-thread.thread().unpark();
-let _ = thread.join();
-}
-}
-
+if let Err(err) = udp_server.accept_packet().await {
+eprintln!("error: {}", err);
+}
+}
+});
+
+let weak_tracker = std::sync::Arc::downgrade(&tracker);
 if let Some(db_path) = cfg.get_db_path() {
-info!("running final cleanup & saving database...");
-tracker.periodic_task(db_path.as_str());
+let db_path = db_path.clone();
+let interval = cfg.get_cleanup_interval().unwrap_or(600);
+
+tokio::spawn(async move {
+let interval = std::time::Duration::from_secs(interval);
+let mut interval = tokio::time::interval(interval);
+interval.tick().await; // first tick is immediate...
+loop {
+interval.tick().await;
+if let Some(tracker) = weak_tracker.upgrade() {
+tracker.periodic_task(&db_path).await;
+} else {
+break;
+}
 }
+});
+}
+
+let ctrl_c = tokio::signal::ctrl_c();
+
+tokio::select! {
+_ = udp_server => { warn!("udp server exited.") },
+_ = ctrl_c => { info!("CTRL-C, exiting...") },
+}
+
+if let Some(path) = cfg.get_db_path() {
+info!("saving database...");
+tracker.periodic_task(path).await;
+}
 
 info!("goodbye.");
 }
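For orientation (not part of the commit): the new async main boils down to the tokio::select! shutdown pattern below, where the worker future stands in for the real UDP accept loop.

// Illustrative only — not from this commit.
#[tokio::main]
async fn main() {
    let worker = tokio::spawn(async {
        // stand-in for the real loop (udp_server.accept_packet().await)
        std::future::pending::<()>().await;
    });

    tokio::select! {
        _ = worker => println!("worker exited."),
        _ = tokio::signal::ctrl_c() => println!("CTRL-C, exiting..."),
    }

    // final cleanup, e.g. tracker.periodic_task(path).await, would run here.
}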
src/server.rs (153 lines changed)

@@ -1,8 +1,9 @@
+use log::{debug, error, trace};
 use std;
 use std::io::Write;
-use std::net::{SocketAddr, UdpSocket};
+use std::net::SocketAddr;
 use std::sync::Arc;
-use log::{error, trace, debug};
+use tokio::net::UdpSocket;
 
 use bincode;
 use serde::{Deserialize, Serialize};
@@ -107,31 +108,18 @@ struct UDPScrapeResponseEntry {
 }
 
 pub struct UDPTracker {
-server: std::net::UdpSocket,
+server: UdpSocket,
 tracker: std::sync::Arc<tracker::TorrentTracker>,
 config: Arc<Configuration>,
 }
 
 impl UDPTracker {
-pub fn new(
-config: Arc<Configuration>,
-tracker: std::sync::Arc<tracker::TorrentTracker>
+pub async fn new(
+config: Arc<Configuration>, tracker: std::sync::Arc<tracker::TorrentTracker>,
 ) -> Result<UDPTracker, std::io::Error> {
 let cfg = config.clone();
 
-let server = match UdpSocket::bind(cfg.get_udp_config().get_address()) {
-Ok(s) => s,
-Err(e) => {
-return Err(e);
-}
-};
-
-match server.set_read_timeout(Some(std::time::Duration::from_secs(1))) {
-Ok(_) => {},
-Err(err) => {
-error!("Failed to set read timeout on socket; will try to continue anyway. err: {}", err);
-}
-}
+let server = UdpSocket::bind(cfg.get_udp_config().get_address()).await?;
 
 Ok(UDPTracker {
 server,
@@ -140,7 +128,8 @@ impl UDPTracker {
 })
 }
 
-fn handle_packet(&self, remote_address: &SocketAddr, payload: &[u8]) {
+// TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+async fn handle_packet(&mut self, remote_address: &SocketAddr, payload: &[u8]) {
 let header: UDPRequestHeader = match unpack(payload) {
 Some(val) => val,
 None => {
@@ -150,9 +139,9 @@ impl UDPTracker {
 };
 
 match header.action {
-Actions::Connect => self.handle_connect(remote_address, &header, payload),
-Actions::Announce => self.handle_announce(remote_address, &header, payload),
-Actions::Scrape => self.handle_scrape(remote_address, &header, payload),
+Actions::Connect => self.handle_connect(remote_address, &header, payload).await,
+Actions::Announce => self.handle_announce(remote_address, &header, payload).await,
+Actions::Scrape => self.handle_scrape(remote_address, &header, payload).await,
 _ => {
 trace!("invalid action from {}", remote_address);
 // someone is playing around... ignore request.
@@ -161,7 +150,8 @@ impl UDPTracker {
 }
 }
 
-fn handle_connect(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, _payload: &[u8]) {
+// TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+async fn handle_connect(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, _payload: &[u8]) {
 if header.connection_id != PROTOCOL_ID {
 trace!("Bad protocol magic from {}", remote_addr);
 return;
@@ -178,15 +168,16 @@ impl UDPTracker {
 connection_id: conn_id,
 };
 
-let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
-let mut payload = StackVec::from(&mut payload_buffer);
+let mut payload_buffer = vec![0u8; MAX_PACKET_SIZE];
+let mut payload = StackVec::from(payload_buffer.as_mut_slice());
 
 if let Ok(_) = pack_into(&mut payload, &response) {
-let _ = self.send_packet(remote_addr, payload.as_slice());
+let _ = self.send_packet(remote_addr, payload.as_slice()).await;
 }
 }
 
-fn handle_announce(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
+// TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+async fn handle_announce(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
 if header.connection_id != self.get_connection_id(remote_addr) {
 return;
 }
@@ -205,28 +196,23 @@ impl UDPTracker {
 let bep41_payload = &payload[plen..];
 
 // TODO: process BEP0041 payload.
-trace!(
-"BEP0041 payload of {} bytes from {}",
-bep41_payload.len(),
-remote_addr
-);
+trace!("BEP0041 payload of {} bytes from {}", bep41_payload.len(), remote_addr);
 }
 }
 
 if packet.ip_address != 0 {
 // TODO: allow configurability of ip address
 // for now, ignore request.
-trace!(
-"announce request for other IP ignored. (from {})",
-remote_addr
-);
+trace!("announce request for other IP ignored. (from {})", remote_addr);
 return;
 }
 
 let client_addr = SocketAddr::new(remote_addr.ip(), packet.port);
 let info_hash = packet.info_hash.into();
 
-match self.tracker.update_torrent_and_get_stats(
+match self
+.tracker
+.update_torrent_and_get_stats(
 &info_hash,
 &packet.peer_id,
 &client_addr,
@@ -234,25 +220,25 @@ impl UDPTracker {
 packet.downloaded,
 packet.left,
 packet.event,
-) {
+)
+.await
+{
 tracker::TorrentStats::Stats {
 leechers,
 complete: _,
 seeders,
 } => {
-let peers = match self.tracker.get_torrent_peers(&info_hash, &client_addr) {
+let peers = match self.tracker.get_torrent_peers(&info_hash, &client_addr).await {
 Some(v) => v,
 None => {
 return;
 }
 };
 
-let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
+let mut payload_buffer = vec![0u8; MAX_PACKET_SIZE];
 let mut payload = StackVec::from(&mut payload_buffer);
 
-match pack_into(
-&mut payload,
-&UDPAnnounceResponse {
+match pack_into(&mut payload, &UDPAnnounceResponse {
 header: UDPResponseHeader {
 action: Actions::Announce,
 transaction_id: packet.header.transaction_id,
@@ -260,8 +246,7 @@ impl UDPTracker {
 seeders,
 interval: self.config.get_udp_config().get_announce_interval(),
 leechers,
-},
-) {
+}) {
 Ok(_) => {}
 Err(_) => {
 return;
@@ -279,24 +264,24 @@ impl UDPTracker {
 };
 
 let port_hton = client_addr.port().to_be();
-let _ =
-payload.write(&[(port_hton & 0xff) as u8, ((port_hton >> 8) & 0xff) as u8]);
+let _ = payload.write(&[(port_hton & 0xff) as u8, ((port_hton >> 8) & 0xff) as u8]);
 }
 
-let _ = self.send_packet(&client_addr, payload.as_slice());
+let _ = self.send_packet(&client_addr, payload.as_slice()).await;
 }
 tracker::TorrentStats::TorrentFlagged => {
-self.send_error(&client_addr, &packet.header, "torrent flagged.");
+self.send_error(&client_addr, &packet.header, "torrent flagged.").await;
 return;
 }
 tracker::TorrentStats::TorrentNotRegistered => {
-self.send_error(&client_addr, &packet.header, "torrent not registered.");
+self.send_error(&client_addr, &packet.header, "torrent not registered.").await;
 return;
 }
 }
 }
 
-fn handle_scrape(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
+// TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+async fn handle_scrape(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
 if header.connection_id != self.get_connection_id(remote_addr) {
 return;
 }
@@ -306,16 +291,17 @@ impl UDPTracker {
 let mut response_buffer = [0u8; 8 + MAX_SCRAPE * 12];
 let mut response = StackVec::from(&mut response_buffer);
 
-if pack_into(&mut response, &UDPResponseHeader{
+if pack_into(&mut response, &UDPResponseHeader {
 action: Actions::Scrape,
 transaction_id: header.transaction_id,
-}).is_err() {
+})
+.is_err()
+{
 // not much we can do...
 error!("failed to encode udp scrape response header.");
 return;
 }
-
 
 // skip first 16 bytes for header...
 let info_hash_array = &payload[16..];
 
@@ -323,7 +309,8 @@ impl UDPTracker {
 trace!("received weird length for scrape info_hash array (!mod20).");
 }
 
-let db = self.tracker.get_database();
+{
+let db = self.tracker.get_database().await;
 
 for torrent_index in 0..MAX_SCRAPE {
 let info_hash_start = torrent_index * 20;
@@ -339,14 +326,14 @@ impl UDPTracker {
 Some(torrent_info) => {
 let (seeders, completed, leechers) = torrent_info.get_stats();
 
-UDPScrapeResponseEntry{
+UDPScrapeResponseEntry {
 seeders,
 completed,
 leechers,
 }
-},
+}
 None => {
-UDPScrapeResponseEntry{
+UDPScrapeResponseEntry {
 seeders: 0,
 completed: 0,
 leechers: 0,
@@ -359,9 +346,10 @@ impl UDPTracker {
 return;
 }
 }
+}
 
 // if sending fails, not much we can do...
-let _ = self.send_packet(&remote_addr, &response.as_slice());
+let _ = self.send_packet(&remote_addr, &response.as_slice()).await;
 }
 
 fn get_connection_id(&self, remote_address: &SocketAddr) -> u64 {
@@ -371,44 +359,38 @@ impl UDPTracker {
 }
 }
 
-fn send_packet(
-&self,
-remote_addr: &SocketAddr,
-payload: &[u8],
-) -> Result<usize, std::io::Error> {
-self.server.send_to(payload, remote_addr)
+// TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+async fn send_packet(&mut self, remote_addr: &SocketAddr, payload: &[u8]) -> Result<usize, std::io::Error> {
+self.server.send_to(payload, remote_addr).await
 }
 
-fn send_error(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, error_msg: &str) {
-let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
+// TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+async fn send_error(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, error_msg: &str) {
+let mut payload_buffer = vec![0u8; MAX_PACKET_SIZE];
 let mut payload = StackVec::from(&mut payload_buffer);
 
-if let Ok(_) = pack_into(
-&mut payload,
-&UDPResponseHeader {
+if let Ok(_) = pack_into(&mut payload, &UDPResponseHeader {
 transaction_id: header.transaction_id,
 action: Actions::Error,
-},
-) {
+}) {
 let msg_bytes = Vec::from(error_msg.as_bytes());
 payload.extend(msg_bytes);
 
-let _ = self.send_packet(remote_addr, payload.as_slice());
+let _ = self.send_packet(remote_addr, payload.as_slice()).await;
 }
 }
 
-pub fn accept_packet(&self) -> Result<(), std::io::Error> {
-let mut packet = [0u8; MAX_PACKET_SIZE];
-match self.server.recv_from(&mut packet) {
-Ok((size, remote_address)) => {
+// TODO: remove `mut` for `accept_packet`, and spawn once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+pub async fn accept_packet(&mut self) -> Result<(), std::io::Error> {
+let mut packet = vec![0u8; MAX_PACKET_SIZE];
+let (size, remote_address) = self.server.recv_from(packet.as_mut_slice()).await?;
+
+// tokio::spawn(async {
 debug!("Received {} bytes from {}", size, remote_address);
-self.handle_packet(&remote_address, &packet[..size]);
+self.handle_packet(&remote_address, &packet[..size]).await;
+// });
 Ok(())
 }
-Err(e) => Err(e),
-}
-}
 }
 
 #[cfg(test)]
@@ -426,10 +408,7 @@ mod tests {
 
 assert!(pack_into(&mut payload, &mystruct).is_ok());
 assert_eq!(payload.len(), 16);
-assert_eq!(
-payload.as_slice(),
-&[0, 0, 0, 0, 0, 0, 0, 200u8, 0, 0, 0, 0, 0, 1, 47, 203]
-);
+assert_eq!(payload.as_slice(), &[0, 0, 0, 0, 0, 0, 0, 200u8, 0, 0, 0, 0, 0, 1, 47, 203]);
 }
 
 #[test]
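Sketch (not from the commit) of the async UDP socket usage server.rs now relies on, assuming tokio 0.2 where recv_from/send_to take &mut self — which is why the diff carries the `mut` TODOs:

// Illustrative only — a minimal bind/recv/send echo loop.
use tokio::net::UdpSocket;

async fn echo(addr: &str) -> std::io::Result<()> {
    let mut socket = UdpSocket::bind(addr).await?;
    let mut buf = vec![0u8; 2048];
    loop {
        // await a datagram instead of blocking a thread with a read timeout
        let (len, peer) = socket.recv_from(&mut buf).await?;
        // echo it back to the sender
        socket.send_to(&buf[..len], &peer).await?;
    }
}

src/stackvec.rs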
@@ -10,10 +10,6 @@ impl<'a, T> StackVec<'a, T> {
 StackVec { data, length: 0 }
 }
 
-pub fn len(&self) -> usize {
-self.length
-}
-
 pub fn as_slice(&self) -> &[T] {
 &self.data[0..self.length]
 }
src/tracker.rs (234 lines changed)

@@ -1,7 +1,11 @@
-
 use crate::server::Events;
-use serde::{Serialize, Deserialize};
 use log::{error, trace};
+use serde::{Deserialize, Serialize};
+use std::borrow::Cow;
+use std::collections::BTreeMap;
+use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
+use tokio::stream::StreamExt;
+use tokio::sync::RwLock;
 
 #[derive(Deserialize, Clone, PartialEq)]
 pub enum TrackerMode {
@@ -18,6 +22,7 @@ pub enum TrackerMode {
 PrivateMode,
 }
 
+#[derive(Clone)]
 struct TorrentPeer {
 ip: std::net::SocketAddr,
 uploaded: u64,
@@ -32,6 +37,27 @@ pub struct InfoHash {
 info_hash: [u8; 20],
 }
 
+impl std::fmt::Display for InfoHash {
+fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+let mut chars = [0u8; 40];
+binascii::bin2hex(&self.info_hash, &mut chars).expect("failed to hexlify");
+write!(f, "{}", std::str::from_utf8(&chars).unwrap())
+}
+}
+
+impl std::str::FromStr for InfoHash {
+type Err = binascii::ConvertError;
+
+fn from_str(s: &str) -> Result<Self, Self::Err> {
+let mut i = Self { info_hash: [0u8; 20] };
+if s.len() != 40 {
+return Err(binascii::ConvertError::InvalidInputLength);
+}
+binascii::hex2bin(s.as_bytes(), &mut i.info_hash)?;
+Ok(i)
+}
+}
+
 impl std::cmp::PartialOrd<InfoHash> for InfoHash {
 fn partial_cmp(&self, other: &InfoHash) -> Option<std::cmp::Ordering> {
 self.info_hash.partial_cmp(&other.info_hash)
@@ -41,9 +67,7 @@ impl std::cmp::PartialOrd<InfoHash> for InfoHash {
 impl std::convert::From<&[u8]> for InfoHash {
 fn from(data: &[u8]) -> InfoHash {
 assert_eq!(data.len(), 20);
-let mut ret = InfoHash{
-info_hash: [0u8; 20],
-};
+let mut ret = InfoHash { info_hash: [0u8; 20] };
 ret.info_hash.clone_from_slice(data);
 return ret;
 }
@@ -58,9 +82,7 @@ impl std::convert::Into<InfoHash> for [u8; 20] {
 impl serde::ser::Serialize for InfoHash {
 fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
 let mut buffer = [0u8; 40];
-let bytes_out = binascii::bin2hex(&self.info_hash, &mut buffer)
-.ok()
-.unwrap();
+let bytes_out = binascii::bin2hex(&self.info_hash, &mut buffer).ok().unwrap();
 let str_out = std::str::from_utf8(bytes_out).unwrap();
 
 serializer.serialize_str(str_out)
@@ -84,9 +106,7 @@ impl<'v> serde::de::Visitor<'v> for InfoHashVisitor {
 ));
 }
 
-let mut res = InfoHash {
-info_hash: [0u8; 20],
-};
+let mut res = InfoHash { info_hash: [0u8; 20] };
 
 if let Err(_) = binascii::hex2bin(v.as_bytes(), &mut res.info_hash) {
 return Err(serde::de::Error::invalid_value(
@@ -107,7 +127,7 @@ impl<'de> serde::de::Deserialize<'de> for InfoHash {
 
 pub type PeerId = [u8; 20];
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 pub struct TorrentEntry {
 is_flagged: bool,
 
@@ -135,28 +155,20 @@ impl TorrentEntry {
 }
 
 pub fn update_peer(
-&mut self,
-peer_id: &PeerId,
-remote_address: &std::net::SocketAddr,
-uploaded: u64,
-downloaded: u64,
-left: u64,
+&mut self, peer_id: &PeerId, remote_address: &std::net::SocketAddr, uploaded: u64, downloaded: u64, left: u64,
 event: Events,
 ) {
 let is_seeder = left == 0 && uploaded > 0;
 let mut was_seeder = false;
 let mut is_completed = left == 0 && (event as u32) == (Events::Complete as u32);
-if let Some(prev) = self.peers.insert(
-*peer_id,
-TorrentPeer {
+if let Some(prev) = self.peers.insert(*peer_id, TorrentPeer {
 updated: std::time::SystemTime::now(),
 left,
 downloaded,
 uploaded,
 ip: *remote_address,
 event,
-},
-) {
+}) {
 was_seeder = prev.left == 0 && prev.uploaded > 0;
 
 if is_completed && (prev.event as u32) == (Events::Complete as u32) {
@@ -200,13 +212,13 @@ impl TorrentEntry {
 }
 
 struct TorrentDatabase {
-torrent_peers: std::sync::RwLock<std::collections::BTreeMap<InfoHash, TorrentEntry>>,
+torrent_peers: tokio::sync::RwLock<std::collections::BTreeMap<InfoHash, TorrentEntry>>,
 }
 
 impl Default for TorrentDatabase {
 fn default() -> Self {
 TorrentDatabase {
-torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()),
+torrent_peers: tokio::sync::RwLock::new(std::collections::BTreeMap::new()),
 }
 }
 }
@@ -216,14 +228,16 @@ pub struct TorrentTracker {
 database: TorrentDatabase,
 }
 
+#[derive(Serialize, Deserialize)]
+struct DatabaseRow<'a> {
+info_hash: InfoHash,
+entry: Cow<'a, TorrentEntry>,
+}
+
 pub enum TorrentStats {
 TorrentFlagged,
 TorrentNotRegistered,
-Stats {
-seeders: u32,
-leechers: u32,
-complete: u32,
-},
+Stats { seeders: u32, leechers: u32, complete: u32 },
 }
 
 impl TorrentTracker {
@@ -231,32 +245,52 @@ impl TorrentTracker {
 TorrentTracker {
 mode,
 database: TorrentDatabase {
-torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()),
+torrent_peers: RwLock::new(std::collections::BTreeMap::new()),
 },
 }
 }
 
-pub fn load_database<R: std::io::Read>(
-mode: TrackerMode,
-reader: &mut R,
-) -> serde_json::Result<TorrentTracker> {
-let decomp_reader = bzip2::read::BzDecoder::new(reader);
-let result: serde_json::Result<std::collections::BTreeMap<InfoHash, TorrentEntry>> =
-serde_json::from_reader(decomp_reader);
-match result {
-Ok(v) => Ok(TorrentTracker {
-mode,
-database: TorrentDatabase {
-torrent_peers: std::sync::RwLock::new(v),
-},
-}),
-Err(e) => Err(e),
+pub async fn load_database<R: tokio::io::AsyncRead + Unpin>(
+mode: TrackerMode, reader: &mut R,
+) -> Result<TorrentTracker, std::io::Error> {
+let reader = tokio::io::BufReader::new(reader);
+let mut tmp: Vec<u8> = Vec::with_capacity(4096);
+tmp.resize(tmp.capacity(), 0);
+
+let res = TorrentTracker::new(mode);
+let mut db = res.database.torrent_peers.write().await;
+
+let mut records = reader
+.lines()
+.filter_map(|res| {
+if let Err(ref err) = res {
+error!("failed to read lines! {}", err);
 }
+res.ok()
+})
+.map(|line| serde_json::from_str::<DatabaseRow>(&line))
+.filter_map(|jsr| {
+if let Err(ref err) = jsr {
+error!("failed to parse json: {}", err);
+}
+jsr.ok()
+});
+
+while let Some(entry) = records.next().await {
+let x = || (entry.entry, entry.info_hash);
+let (a, b) = x();
+let a = a.into_owned();
+db.insert(b, a);
+}
+
+drop(db);
+
+Ok(res)
 }
 
 /// Adding torrents is not relevant to dynamic trackers.
-pub fn add_torrent(&self, info_hash: &InfoHash) -> Result<(), ()> {
-let mut write_lock = self.database.torrent_peers.write().unwrap();
+pub async fn add_torrent(&self, info_hash: &InfoHash) -> Result<(), ()> {
+let mut write_lock = self.database.torrent_peers.write().await;
 match write_lock.entry(info_hash.clone()) {
 std::collections::btree_map::Entry::Vacant(ve) => {
 ve.insert(TorrentEntry::new());
@@ -269,9 +303,9 @@ impl TorrentTracker {
 }
 
 /// If the torrent is flagged, it will not be removed unless force is set to true.
-pub fn remove_torrent(&self, info_hash: &InfoHash, force: bool) -> Result<(), ()> {
+pub async fn remove_torrent(&self, info_hash: &InfoHash, force: bool) -> Result<(), ()> {
 use std::collections::btree_map::Entry;
-let mut entry_lock = self.database.torrent_peers.write().unwrap();
+let mut entry_lock = self.database.torrent_peers.write().await;
 let torrent_entry = entry_lock.entry(info_hash.clone());
 match torrent_entry {
 Entry::Vacant(_) => {
@@ -289,28 +323,23 @@ impl TorrentTracker {
 }
 
 /// flagged torrents will result in a tracking error. This is to allow enforcement against piracy.
-pub fn set_torrent_flag(&self, info_hash: &InfoHash, is_flagged: bool) {
-if let Some(entry) = self
-.database
-.torrent_peers
-.write()
-.unwrap()
-.get_mut(info_hash)
-{
+pub async fn set_torrent_flag(&self, info_hash: &InfoHash, is_flagged: bool) -> bool {
+if let Some(entry) = self.database.torrent_peers.write().await.get_mut(info_hash) {
 if is_flagged && !entry.is_flagged {
 // empty peer list.
 entry.peers.clear();
 }
 entry.is_flagged = is_flagged;
+true
+} else {
+false
 }
 }
 
-pub fn get_torrent_peers(
-&self,
-info_hash: &InfoHash,
-remote_addr: &std::net::SocketAddr,
+pub async fn get_torrent_peers(
+&self, info_hash: &InfoHash, remote_addr: &std::net::SocketAddr,
 ) -> Option<Vec<std::net::SocketAddr>> {
-let read_lock = self.database.torrent_peers.read().unwrap();
+let read_lock = self.database.torrent_peers.read().await;
 match read_lock.get(info_hash) {
 None => {
 return None;
@@ -321,25 +350,21 @@ impl TorrentTracker {
 };
 }
 
-pub fn update_torrent_and_get_stats(
-&self,
-info_hash: &InfoHash,
-peer_id: &PeerId,
-remote_address: &std::net::SocketAddr,
-uploaded: u64,
-downloaded: u64,
-left: u64,
-event: Events,
+pub async fn update_torrent_and_get_stats(
+&self, info_hash: &InfoHash, peer_id: &PeerId, remote_address: &std::net::SocketAddr, uploaded: u64,
+downloaded: u64, left: u64, event: Events,
 ) -> TorrentStats {
 use std::collections::btree_map::Entry;
-let mut torrent_peers = self.database.torrent_peers.write().unwrap();
+let mut torrent_peers = self.database.torrent_peers.write().await;
 let torrent_entry = match torrent_peers.entry(info_hash.clone()) {
-Entry::Vacant(vacant) => match self.mode {
+Entry::Vacant(vacant) => {
+match self.mode {
 TrackerMode::DynamicMode => vacant.insert(TorrentEntry::new()),
 _ => {
 return TorrentStats::TorrentNotRegistered;
 }
-},
+}
+}
 Entry::Occupied(entry) => {
 if entry.get().is_flagged() {
 return TorrentStats::TorrentFlagged;
@@ -359,32 +384,39 @@ impl TorrentTracker {
 };
 }
 
-pub(crate) fn get_database(
-&self,
-) -> std::sync::RwLockReadGuard<std::collections::BTreeMap<InfoHash, TorrentEntry>> {
-self.database.torrent_peers.read().unwrap()
+pub(crate) async fn get_database<'a>(&'a self) -> tokio::sync::RwLockReadGuard<'a, BTreeMap<InfoHash, TorrentEntry>> {
+self.database.torrent_peers.read().await
 }
 
-pub fn save_database<W: std::io::Write>(&self, writer: &mut W) -> serde_json::Result<()> {
-let compressor = bzip2::write::BzEncoder::new(writer, bzip2::Compression::Best);
+pub async fn save_database<W: tokio::io::AsyncWrite + Unpin>(&self, mut writer: W) -> Result<(), std::io::Error> {
+// TODO: find async friendly compressor
 
-let db_lock = self.database.torrent_peers.read().unwrap();
+let db_lock = self.database.torrent_peers.read().await;
 
-let db = &*db_lock;
+let db: &BTreeMap<InfoHash, TorrentEntry> = &*db_lock;
+let mut tmp = Vec::with_capacity(4096);
 
-serde_json::to_writer(compressor, &db)
+for row in db {
+let entry = DatabaseRow {
+info_hash: row.0.clone(),
+entry: Cow::Borrowed(row.1),
+};
+tmp.clear();
+if let Err(err) = serde_json::to_writer(&mut tmp, &entry) {
+error!("failed to serialize: {}", err);
+continue;
+};
+tmp.push(b'\n');
+writer.write_all(&tmp).await?;
+}
+Ok(())
 }
 
-fn cleanup(&self) {
+async fn cleanup(&self) {
 use std::ops::Add;
 
 let now = std::time::SystemTime::now();
-match self.database.torrent_peers.write() {
-Err(err) => {
-error!("failed to obtain write lock on database. err: {}", err);
-return;
-}
-Ok(mut db) => {
+let mut lock = self.database.torrent_peers.write().await;
+let db: &mut BTreeMap<InfoHash, TorrentEntry> = &mut *lock;
 let mut torrents_to_remove = Vec::new();
 
 for (k, v) in db.iter_mut() {
@@ -417,12 +449,10 @@ impl TorrentTracker {
 db.remove(&info_hash);
 }
 }
-}
-}
 
-pub fn periodic_task(&self, db_path: &str) {
+pub async fn periodic_task(&self, db_path: &str) {
 // cleanup db
-self.cleanup();
+self.cleanup().await;
 
 // save journal db.
 let mut journal_path = std::path::PathBuf::from(db_path);
@@ -435,7 +465,7 @@ impl TorrentTracker {
 
 // scope to make sure backup file is dropped/closed.
 {
-let mut file = match std::fs::File::create(jp_str) {
+let mut file = match tokio::fs::File::create(jp_str).await {
 Err(err) => {
 error!("failed to open file '{}': {}", db_path, err);
 return;
@@ -443,7 +473,7 @@ impl TorrentTracker {
 Ok(v) => v,
 };
 trace!("writing database to {}", jp_str);
-if let Err(err) = self.save_database(&mut file) {
+if let Err(err) = self.save_database(&mut file).await {
 error!("failed saving database. {}", err);
 return;
 }
@@ -451,7 +481,7 @@ impl TorrentTracker {
 
 // overwrite previous db
 trace!("renaming '{}' to '{}'", jp_str, db_path);
-if let Err(err) = std::fs::rename(jp_str, db_path) {
+if let Err(err) = tokio::fs::rename(jp_str, db_path).await {
 error!("failed to move db backup. {}", err);
 }
 }
@@ -474,14 +504,14 @@ mod tests {
 is_sync::<TorrentTracker>();
 }
 
-#[test]
-fn test_save_db() {
+#[tokio::test]
+async fn test_save_db() {
 let tracker = TorrentTracker::new(TrackerMode::DynamicMode);
 tracker.add_torrent(&[0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0].into());
 
 let mut out = Vec::new();
 
-tracker.save_database(&mut out).unwrap();
+tracker.save_database(&mut out).await.expect("db save failed");
 assert!(out.len() > 0);
 }
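Sketch (not from the commit) of the tokio::sync::RwLock idiom behind all of the .unwrap() → .await changes in tracker.rs:

// Illustrative only — guards are awaited, not unwrapped.
use std::collections::BTreeMap;
use tokio::sync::RwLock;

#[tokio::main]
async fn main() {
    let db: RwLock<BTreeMap<String, u32>> = RwLock::new(BTreeMap::new());

    {
        // acquiring the write guard can yield to the scheduler
        let mut write_guard = db.write().await;
        write_guard.insert("example".into(), 1);
    } // guard dropped here so the lock isn't held across later awaits

    let read_guard = db.read().await;
    assert_eq!(read_guard.get("example"), Some(&1));
}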
604
src/webserver.rs
604
src/webserver.rs
|
@ -1,419 +1,193 @@
|
||||||
use std::collections::HashMap;
|
use crate::tracker::{InfoHash, TorrentTracker};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use warp::{filters, reply, reply::Reply, serve, Filter, Server};
|
||||||
|
|
||||||
use actix_web;
|
fn view_root() -> impl Reply {
|
||||||
use actix_net;
|
reply::html(concat!(
|
||||||
use binascii;
|
r#"<html>
|
||||||
|
<head>
|
||||||
use crate::config;
|
<title>udpt/"#,
|
||||||
use crate::tracker::TorrentTracker;
|
env!("CARGO_PKG_VERSION"),
|
||||||
use log::{error, warn, info, debug};
|
r#"</title>
|
||||||
|
</head>
|
||||||
const SERVER: &str = concat!("udpt/", env!("CARGO_PKG_VERSION"));
|
<body>
|
||||||
|
This is your <a href="https://github.com/naim94a/udpt">udpt</a> torrent tracker.
|
||||||
pub struct WebServer {
|
</body>
|
||||||
thread: std::thread::JoinHandle<()>,
|
</html>"#
|
||||||
addr: Option<actix_web::actix::Addr<actix_net::server::Server>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
mod http_responses {
|
|
||||||
use serde::Serialize;
|
|
||||||
use crate::tracker::InfoHash;
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct TorrentInfo {
|
|
||||||
pub is_flagged: bool,
|
|
||||||
pub leecher_count: u32,
|
|
||||||
pub seeder_count: u32,
|
|
||||||
pub completed: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct TorrentList {
|
|
||||||
pub offset: u32,
|
|
||||||
pub length: u32,
|
|
||||||
pub total: u32,
|
|
||||||
pub torrents: Vec<InfoHash>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum APIResponse {
|
|
||||||
Error(String),
|
|
||||||
TorrentList(TorrentList),
|
|
||||||
TorrentInfo(TorrentInfo),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct UdptState {
|
|
||||||
// k=token, v=username.
|
|
||||||
access_tokens: HashMap<String, String>,
|
|
||||||
tracker: Arc<TorrentTracker>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UdptState {
|
|
||||||
fn new(tracker: Arc<TorrentTracker>, tokens: HashMap<String, String>) -> UdptState {
|
|
||||||
UdptState {
|
|
||||||
tracker,
|
|
||||||
access_tokens: tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct UdptRequestState {
|
|
||||||
current_user: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for UdptRequestState {
|
|
||||||
fn default() -> Self {
|
|
||||||
UdptRequestState {
|
|
||||||
current_user: Option::None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UdptRequestState {
|
|
||||||
fn get_user<S>(req: &actix_web::HttpRequest<S>) -> Option<String> {
|
|
||||||
let exts = req.extensions();
|
|
||||||
let req_state: Option<&UdptRequestState> = exts.get();
|
|
||||||
match req_state {
|
|
||||||
None => None,
|
|
||||||
Option::Some(state) => match state.current_user {
|
|
||||||
Option::Some(ref v) => Option::Some(v.clone()),
|
|
||||||
None => {
|
|
||||||
error!(
|
|
||||||
"Invalid API token from {} @ {}",
|
|
||||||
req.peer_addr().unwrap(),
|
|
||||||
req.path()
|
|
||||||
);
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct UdptMiddleware;
|
|
||||||
|
|
||||||
impl actix_web::middleware::Middleware<UdptState> for UdptMiddleware {
|
|
||||||
fn start(
|
|
||||||
&self,
|
|
||||||
req: &actix_web::HttpRequest<UdptState>,
|
|
||||||
) -> actix_web::Result<actix_web::middleware::Started> {
|
|
||||||
let mut req_state = UdptRequestState::default();
|
|
||||||
if let Option::Some(token) = req.query().get("token") {
|
|
||||||
let app_state: &UdptState = req.state();
|
|
||||||
if let Option::Some(v) = app_state.access_tokens.get(token) {
|
|
||||||
req_state.current_user = Option::Some(v.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
req.extensions_mut().insert(req_state);
|
|
||||||
Ok(actix_web::middleware::Started::Done)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn response(
|
|
||||||
&self,
|
|
||||||
_req: &actix_web::HttpRequest<UdptState>,
|
|
||||||
mut resp: actix_web::HttpResponse,
|
|
||||||
) -> actix_web::Result<actix_web::middleware::Response> {
|
|
||||||
resp.headers_mut().insert(
|
            actix_web::http::header::SERVER,
            actix_web::http::header::HeaderValue::from_static(SERVER),
        );

        Ok(actix_web::middleware::Response::Done(resp))
    }
}

impl WebServer {
    fn get_access_tokens(cfg: &config::HTTPConfig, tokens: &mut HashMap<String, String>) {
        for (user, token) in cfg.get_access_tokens().iter() {
            tokens.insert(token.clone(), user.clone());
        }
        if tokens.len() == 0 {
            warn!("No access tokens provided. HTTP API will not be useful.");
        }
    }

    pub fn shutdown(self) {
        match self.addr {
            Some(v) => {
                use futures::future::Future;

                v.send(actix_web::actix::signal::Signal(actix_web::actix::signal::SignalType::Term)).wait().unwrap();
            },
            None => {},
        };

        self.thread.thread().unpark();
        let _ = self.thread.join();
    }

    pub fn new(
        tracker: Arc<TorrentTracker>,
        cfg: Arc<config::Configuration>,
    ) -> WebServer {
        let cfg_cp = cfg.clone();

        let (tx_addr, rx_addr) = std::sync::mpsc::channel();

        let thread = std::thread::spawn(move || {
            let server = actix_web::server::HttpServer::new(move || {
                let mut access_tokens = HashMap::new();

                if let Some(http_cfg) = cfg_cp.get_http_config() {
                    Self::get_access_tokens(http_cfg, &mut access_tokens);
                }

                let state = UdptState::new(tracker.clone(), access_tokens);

                actix_web::App::<UdptState>::with_state(state)
                    .middleware(UdptMiddleware)
                    .resource("/t", |r| r.f(Self::view_torrent_list))
                    .scope(r"/t/{info_hash:[\dA-Fa-f]{40,40}}", |scope| {
                        scope.resource("", |r| {
                            r.method(actix_web::http::Method::GET)
                                .f(Self::view_torrent_stats);
                            r.method(actix_web::http::Method::POST)
                                .f(Self::torrent_action);
                        })
                    })
                    .resource("/", |r| {
                        r.method(actix_web::http::Method::GET).f(Self::view_root)
                    })
            });

            if let Some(http_cfg) = cfg.get_http_config() {
                let bind_addr = http_cfg.get_address();
                match server.bind(bind_addr) {
                    Ok(v) => {
                        let sys = actix_web::actix::System::new("http-server");
                        let addr = v.start();
                        let _ = tx_addr.send(addr);
                        sys.run();
                    }
                    Err(err) => {
                        error!("Failed to bind http server. {}", err);
                    }
                }
            } else {
                unreachable!();
            }
        });

        let addr = match rx_addr.recv() {
            Ok(v) => Some(v),
            Err(_) => None
        };

        WebServer {
            thread,
            addr,
        }
    }

    fn view_root(_req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
        actix_web::HttpResponse::build(actix_web::http::StatusCode::OK)
            .content_type("text/html")
            .body(r#"Powered by <a href="https://github.com/naim94a/udpt">https://github.com/naim94a/udpt</a>"#)
    }

    fn view_torrent_list(req: &actix_web::HttpRequest<UdptState>) -> impl actix_web::Responder {
        use std::str::FromStr;

        if UdptRequestState::get_user(req).is_none() {
            return actix_web::Json(http_responses::APIResponse::Error(String::from(
                "access_denied",
            )));
        }

        let req_offset = match req.query().get("offset") {
            None => 0,
            Some(v) => match u32::from_str(v.as_str()) {
                Ok(v) => v,
                Err(_) => 0,
            },
        };

        let mut req_limit = match req.query().get("limit") {
            None => 0,
            Some(v) => match u32::from_str(v.as_str()) {
                Ok(v) => v,
                Err(_) => 0,
            },
        };

        if req_limit > 4096 {
            req_limit = 4096;
        } else if req_limit == 0 {
            req_limit = 1000;
        }

        let app_state: &UdptState = req.state();
        let app_db = app_state.tracker.get_database();

        let total = app_db.len() as u32;

        let mut torrents = Vec::with_capacity(req_limit as usize);

        for (info_hash, _) in app_db
            .iter()
            .skip(req_offset as usize)
            .take(req_limit as usize)
        {
            torrents.push(info_hash.clone());
        }

        actix_web::Json(http_responses::APIResponse::TorrentList(
            http_responses::TorrentList {
                total,
                length: torrents.len() as u32,
                offset: req_offset,
                torrents,
            },
        ))
    }
    fn view_torrent_stats(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
        use actix_web::FromRequest;

        if UdptRequestState::get_user(req).is_none() {
            return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED).json(
                http_responses::APIResponse::Error(String::from("access_denied")),
            );
        }

        let path: actix_web::Path<String> = match actix_web::Path::extract(req) {
            Ok(v) => v,
            Err(_) => {
                return actix_web::HttpResponse::build(
                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
                )
                .json(http_responses::APIResponse::Error(String::from(
                    "internal_error",
                )));
            }
        };

        let mut info_hash = [0u8; 20];
        if let Err(_) = binascii::hex2bin((*path).as_bytes(), &mut info_hash) {
            return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST).json(
                http_responses::APIResponse::Error(String::from("invalid_info_hash")),
            );
        }

        let app_state: &UdptState = req.state();

        let db = app_state.tracker.get_database();
        let entry = match db.get(&info_hash.into()) {
            Some(v) => v,
            None => {
                return actix_web::HttpResponse::build(actix_web::http::StatusCode::NOT_FOUND).json(
                    http_responses::APIResponse::Error(String::from("not_found")),
                );
            }
        };

        let is_flagged = entry.is_flagged();
        let (seeders, completed, leechers) = entry.get_stats();

        return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).json(
            http_responses::APIResponse::TorrentInfo(http_responses::TorrentInfo {
                is_flagged,
                seeder_count: seeders,
                leecher_count: leechers,
                completed,
            }),
        );
    }

    fn torrent_action(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
        use actix_web::FromRequest;

        if UdptRequestState::get_user(req).is_none() {
            return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED).json(
                http_responses::APIResponse::Error(String::from("access_denied")),
            );
        }

        let query = req.query();
        let action_opt = query.get("action");
        let action = match action_opt {
            Some(v) => v,
            None => {
                return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST)
                    .json(http_responses::APIResponse::Error(String::from(
                        "action_required",
                    )));
            }
        };

        let app_state: &UdptState = req.state();

        let path: actix_web::Path<String> = match actix_web::Path::extract(req) {
            Ok(v) => v,
            Err(_err) => {
                return actix_web::HttpResponse::build(
                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
                )
                .json(http_responses::APIResponse::Error(String::from(
                    "internal_error",
                )));
            }
        };

        let info_hash_str = &(*path);
        let mut info_hash = [0u8; 20];
        if let Err(_) = binascii::hex2bin(info_hash_str.as_bytes(), &mut info_hash) {
            return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST).json(
                http_responses::APIResponse::Error(String::from("invalid_info_hash")),
            );
        }

        match action.as_str() {
            "flag" => {
                app_state.tracker.set_torrent_flag(&info_hash.into(), true);
                info!("Flagged {}", info_hash_str.as_str());
                return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).body("");
            }
            "unflag" => {
                app_state.tracker.set_torrent_flag(&info_hash.into(), false);
                info!("Unflagged {}", info_hash_str.as_str());
                return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).body("");
            }
            "add" => {
                let success = app_state.tracker.add_torrent(&info_hash.into()).is_ok();
                info!("Added {}, success={}", info_hash_str.as_str(), success);
                let code = if success {
                    actix_web::http::StatusCode::OK
                } else {
                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR
                };

                return actix_web::HttpResponse::build(code).body("");
            }
            "remove" => {
                let success = app_state
                    .tracker
                    .remove_torrent(&info_hash.into(), true)
                    .is_ok();
                info!("Removed {}, success={}", info_hash_str.as_str(), success);
                let code = if success {
                    actix_web::http::StatusCode::OK
                } else {
                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR
                };

                return actix_web::HttpResponse::build(code).body("");
            }
            _ => {
                debug!("Invalid action {}", action.as_str());
                return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST)
                    .json(http_responses::APIResponse::Error(String::from(
                        "invalid_action",
                    )));
            }
        }
    }
}

#[derive(Deserialize, Debug)]
struct TorrentInfoQuery {
    offset: Option<u32>,
    limit: Option<u32>,
}

#[derive(Serialize)]
struct TorrentEntry<'a> {
    info_hash: &'a InfoHash,
    #[serde(flatten)]
    data: &'a crate::tracker::TorrentEntry,
}

#[derive(Serialize, Deserialize)]
struct TorrentFlag {
    is_flagged: bool,
}

#[derive(Serialize, Debug)]
#[serde(tag = "status", rename_all = "snake_case")]
enum ActionStatus<'a> {
    Ok,
    Err { reason: std::borrow::Cow<'a, str> },
}

impl warp::reject::Reject for ActionStatus<'static> {}

fn authenticate(tokens: HashMap<String, String>) -> impl Filter<Extract = (), Error = warp::reject::Rejection> + Clone {
    #[derive(Deserialize)]
    struct AuthToken {
        token: Option<String>,
    }

    let tokens: HashSet<String> = tokens.into_iter().map(|(_, v)| v).collect();

    let tokens = Arc::new(tokens);
    warp::filters::any::any()
        .map(move || tokens.clone())
        .and(filters::query::query::<AuthToken>())
        .and_then(|tokens: Arc<HashSet<String>>, token: AuthToken| {
            async move {
                if let Some(token) = token.token {
                    if tokens.contains(&token) {
                        return Ok(());
                    }
                }
                Err(warp::reject::custom(ActionStatus::Err {
                    reason: "Access Denied".into(),
                }))
            }
        })
        .untuple_one()
}

pub fn build_server(
    tracker: Arc<TorrentTracker>, tokens: HashMap<String, String>,
) -> Server<impl Filter<Extract = impl Reply> + Clone + Send + Sync + 'static> {
    let root = filters::path::end().map(|| view_root());

    let t1 = tracker.clone();
    // view_torrent_list -> GET /t/?offset=:u32&limit=:u32 HTTP/1.1
    let view_torrent_list = filters::path::end()
        .and(filters::method::get())
        .and(filters::query::query())
        .map(move |limits| {
            let tracker = t1.clone();
            (limits, tracker)
        })
        .and_then(|(limits, tracker): (TorrentInfoQuery, Arc<TorrentTracker>)| {
            async move {
                let offset = limits.offset.unwrap_or(0);
                let limit = min(limits.limit.unwrap_or(1000), 4000);

                let db = tracker.get_database().await;
                let results: Vec<_> = db
                    .iter()
                    .map(|(k, v)| TorrentEntry { info_hash: k, data: v })
                    .skip(offset as usize)
                    .take(limit as usize)
                    .collect();

                Result::<_, warp::reject::Rejection>::Ok(reply::json(&results))
            }
        });

    let t2 = tracker.clone();
    // view_torrent_info -> GET /t/:infohash HTTP/*
    let view_torrent_info = filters::method::get()
        .and(filters::path::param())
        .map(move |info_hash: InfoHash| {
            let tracker = t2.clone();
            (info_hash, tracker)
        })
        .and_then(|(info_hash, tracker): (InfoHash, Arc<TorrentTracker>)| {
            async move {
                let db = tracker.get_database().await;
                let info = match db.get(&info_hash) {
                    Some(v) => v,
                    None => return Err(warp::reject::reject()),
                };

                Ok(reply::json(&TorrentEntry {
                    info_hash: &info_hash,
                    data: info,
                }))
            }
        });

    // DELETE /t/:info_hash
    let t3 = tracker.clone();
    let delete_torrent = filters::method::post()
        .and(filters::path::param())
        .map(move |info_hash: InfoHash| {
            let tracker = t3.clone();
            (info_hash, tracker)
        })
        .and_then(|(info_hash, tracker): (InfoHash, Arc<TorrentTracker>)| {
            async move {
                let resp = match tracker.remove_torrent(&info_hash, true).await.is_ok() {
                    true => ActionStatus::Ok,
                    false => {
                        ActionStatus::Err {
                            reason: "failed to delete torrent".into(),
                        }
                    }
                };

                Result::<_, warp::Rejection>::Ok(reply::json(&resp))
            }
        });

    let t4 = tracker.clone();
    // add_torrent/alter: POST /t/:info_hash
    // (optional) BODY: json: {"is_flagged": boolean}
    let change_torrent = filters::method::post()
        .and(filters::path::param())
        .and(filters::body::content_length_limit(4096))
        .and(filters::body::json())
        .map(move |info_hash: InfoHash, body: Option<TorrentFlag>| {
            let tracker = t4.clone();
            (info_hash, tracker, body)
        })
        .and_then(
            |(info_hash, tracker, body): (InfoHash, Arc<TorrentTracker>, Option<TorrentFlag>)| {
                async move {
                    let is_flagged = body.map(|e| e.is_flagged).unwrap_or(false);
                    if !tracker.set_torrent_flag(&info_hash, is_flagged).await {
                        // torrent doesn't exist, add it...

                        if is_flagged {
                            if tracker.add_torrent(&info_hash).await.is_ok() {
                                tracker.set_torrent_flag(&info_hash, is_flagged).await;
                            } else {
                                return Err(warp::reject::custom(ActionStatus::Err {
                                    reason: "failed to flag torrent".into(),
                                }));
                            }
                        }
                    }

                    Result::<_, warp::Rejection>::Ok(reply::json(&ActionStatus::Ok))
                }
            },
        );

    let torrent_mgmt =
        filters::path::path("t").and(view_torrent_list.or(delete_torrent).or(view_torrent_info).or(change_torrent));

    let server = root.or(authenticate(tokens).and(torrent_mgmt));

    serve(server)
}
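A minimal usage sketch, not part of this diff: build_server() hands back warp's Server, so the caller is expected to await it on the tokio runtime this commit introduces. The bind address and the user-to-token map below are assumptions for illustration; authenticate() only compares the map's values against the ?token= query parameter.

// Hypothetical driver for the filter tree above; the address and tokens are made up.
async fn run_api(tracker: Arc<TorrentTracker>) {
    let mut tokens = HashMap::new();
    tokens.insert(String::from("admin"), String::from("MySecretToken"));

    // Serves GET / plus the /t routes defined in build_server(), behind the token filter.
    build_server(tracker, tokens).run(([127, 0, 0, 1], 1212)).await;
}

With a valid token, GET /t/ should reply with a JSON array of TorrentEntry objects (the info_hash plus the flattened tracker data), and the ActionStatus replies serialize as {"status":"ok"} or {"status":"err","reason":"..."} because of the internally tagged serde derive.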