make everything async

Naim A 2020-05-06 03:06:14 +03:00
parent 71ea023e5b
commit d030871337
8 changed files with 948 additions and 2189 deletions

Cargo.lock (generated): 1541 lines changed; diff suppressed because it is too large.

Cargo.toml

@@ -11,15 +11,11 @@ lto = "fat"
 [dependencies]
 serde = {version = "1.0", features = ["derive"]}
 bincode = "1.2"
-actix-web = "0.7"
-actix-net = "0.2"
+warp = {version = "0.2", default-features = false}
+tokio = {version = "0.2", features = ["macros", "net", "rt-threaded", "fs", "sync", "blocking", "signal"]}
 binascii = "0.1"
 toml = "0.5"
 clap = "2.33"
 log = "0.4"
 fern = "0.6"
-num_cpus = "1.13"
 serde_json = "1.0"
-bzip2 = "0.3"
-futures = "0.1"
-lazy_static = "1.4"
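The dependency swap above is the whole migration in miniature: actix-web, num_cpus, bzip2, futures 0.1 and lazy_static drop out, and warp plus tokio 0.2 come in. A minimal sketch (not from this commit) of what the "macros" and "rt-threaded" features enable, namely an async main on tokio's multi-threaded scheduler:

// Illustration only; this function is not part of udpt.
#[tokio::main]
async fn main() {
    // tasks go to tokio's threaded scheduler instead of hand-rolled
    // std::thread pools sized with num_cpus
    let task = tokio::spawn(async { 1 + 1 });
    assert_eq!(task.await.unwrap(), 2);
}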

src/config.rs

@@ -1,119 +1,121 @@
-use std;
-use std::collections::HashMap;
-use toml;
 pub use crate::tracker::TrackerMode;
 use serde::Deserialize;
+use std;
+use std::collections::HashMap;
+use toml;

 #[derive(Deserialize)]
 pub struct UDPConfig {
     bind_address: String,
     announce_interval: u32,
 }

 impl UDPConfig {
     pub fn get_address(&self) -> &str {
         self.bind_address.as_str()
     }

     pub fn get_announce_interval(&self) -> u32 {
         self.announce_interval
     }
 }

 #[derive(Deserialize)]
 pub struct HTTPConfig {
     bind_address: String,
     access_tokens: HashMap<String, String>,
 }

 impl HTTPConfig {
     pub fn get_address(&self) -> &str {
         self.bind_address.as_str()
     }

     pub fn get_access_tokens(&self) -> &HashMap<String, String> {
         &self.access_tokens
     }
 }

 #[derive(Deserialize)]
 pub struct Configuration {
     mode: TrackerMode,
     udp: UDPConfig,
     http: Option<HTTPConfig>,
     log_level: Option<String>,
     db_path: Option<String>,
     cleanup_interval: Option<u64>,
 }

 #[derive(Debug)]
 pub enum ConfigError {
     IOError(std::io::Error),
     ParseError(toml::de::Error),
 }

 impl std::fmt::Display for ConfigError {
     fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
             ConfigError::IOError(e) => e.fmt(formatter),
             ConfigError::ParseError(e) => e.fmt(formatter),
         }
     }
 }

 impl std::error::Error for ConfigError {}

 impl Configuration {
     pub fn load(data: &[u8]) -> Result<Configuration, toml::de::Error> {
         toml::from_slice(data)
     }

     pub fn load_file(path: &str) -> Result<Configuration, ConfigError> {
         match std::fs::read(path) {
             Err(e) => Err(ConfigError::IOError(e)),
-            Ok(data) => match Self::load(data.as_slice()) {
-                Ok(cfg) => Ok(cfg),
-                Err(e) => Err(ConfigError::ParseError(e)),
-            },
+            Ok(data) => {
+                match Self::load(data.as_slice()) {
+                    Ok(cfg) => Ok(cfg),
+                    Err(e) => Err(ConfigError::ParseError(e)),
+                }
+            }
         }
     }

     pub fn get_mode(&self) -> &TrackerMode {
         &self.mode
     }

     pub fn get_udp_config(&self) -> &UDPConfig {
         &self.udp
     }

     pub fn get_log_level(&self) -> &Option<String> {
         &self.log_level
     }

-    pub fn get_http_config(&self) -> &Option<HTTPConfig> {
-        &self.http
+    pub fn get_http_config(&self) -> Option<&HTTPConfig> {
+        self.http.as_ref()
     }

     pub fn get_db_path(&self) -> &Option<String> {
         &self.db_path
     }

-    pub fn get_cleanup_interval(&self) -> &Option<u64> {
-        &self.cleanup_interval
+    pub fn get_cleanup_interval(&self) -> Option<u64> {
+        self.cleanup_interval
     }
 }

 impl Default for Configuration {
     fn default() -> Configuration {
         Configuration {
             log_level: None,
             mode: TrackerMode::DynamicMode,
             udp: UDPConfig {
                 announce_interval: 120,
                 bind_address: String::from("0.0.0.0:6969"),
             },
             http: None,
             db_path: None,
             cleanup_interval: None,
         }
     }
 }
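The getter changes matter for callers: get_http_config now hands out Option<&HTTPConfig> and get_cleanup_interval a plain Option<u64>, so main.rs can chain unwrap_or without double-dereferencing. A hedged sketch of feeding this structure through Configuration::load; the TOML keys mirror the fields above, while the "DynamicMode" spelling, the bind address and the token are assumptions for illustration:

// Sketch, not from the commit: parse an inline TOML config.
fn example_config() -> Result<Configuration, toml::de::Error> {
    Configuration::load(br#"
mode = "DynamicMode"            # assumes serde's default variant naming
log_level = "info"
db_path = "tracker.db"
cleanup_interval = 600

[udp]
bind_address = "0.0.0.0:6969"
announce_interval = 120

[http]
bind_address = "127.0.0.1:1212" # illustrative address

[http.access_tokens]
admin = "sometoken"             # illustrative token
"#)
}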

src/main.rs

@@ -1,10 +1,8 @@
 #![forbid(unsafe_code)]

 use clap;
-use log::{trace, warn, info, debug, error};
 use fern;
-use num_cpus;
-use lazy_static::lazy_static;
+use log::{error, info, trace, warn};

 mod config;
 mod server;
@@ -15,25 +13,23 @@ mod webserver;
 use config::Configuration;
 use std::process::exit;

-lazy_static!{
-    static ref term_mutex: std::sync::Arc<std::sync::atomic::AtomicBool> = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
-}
-
 fn setup_logging(cfg: &Configuration) {
     let log_level = match cfg.get_log_level() {
         None => log::LevelFilter::Info,
-        Some(level) => match level.as_str() {
-            "off" => log::LevelFilter::Off,
-            "trace" => log::LevelFilter::Trace,
-            "debug" => log::LevelFilter::Debug,
-            "info" => log::LevelFilter::Info,
-            "warn" => log::LevelFilter::Warn,
-            "error" => log::LevelFilter::Error,
-            _ => {
-                eprintln!("udpt: unknown log level encountered '{}'", level.as_str());
-                exit(-1);
-            }
-        },
+        Some(level) => {
+            match level.as_str() {
+                "off" => log::LevelFilter::Off,
+                "trace" => log::LevelFilter::Trace,
+                "debug" => log::LevelFilter::Debug,
+                "info" => log::LevelFilter::Info,
+                "warn" => log::LevelFilter::Warn,
+                "error" => log::LevelFilter::Error,
+                _ => {
+                    eprintln!("udpt: unknown log level encountered '{}'", level.as_str());
+                    exit(-1);
+                }
+            }
+        }
     };

     if let Err(err) = fern::Dispatch::new()
@@ -59,11 +55,8 @@ fn setup_logging(cfg: &Configuration) {
     info!("logging initialized.");
 }

-fn signal_termination() {
-    term_mutex.store(true, std::sync::atomic::Ordering::Relaxed);
-}
-
-fn main() {
+#[tokio::main]
+async fn main() {
     let parser = clap::App::new(env!("CARGO_PKG_NAME"))
         .about(env!("CARGO_PKG_DESCRIPTION"))
         .author(env!("CARGO_PKG_AUTHORS"))
@@ -95,17 +88,19 @@ fn main() {
             if !file_path.exists() {
                 warn!("database file \"{}\" doesn't exist.", path);
                 tracker::TorrentTracker::new(cfg.get_mode().clone())
-            }
-            else {
-                let mut input_file = match std::fs::File::open(file_path) {
+            } else {
+                let mut input_file = match tokio::fs::File::open(file_path).await {
                     Ok(v) => v,
                     Err(err) => {
                         error!("failed to open \"{}\". error: {}", path.as_str(), err);
                         panic!("error opening file. check logs.");
                     }
                 };
-                match tracker::TorrentTracker::load_database(cfg.get_mode().clone(), &mut input_file) {
-                    Ok(v) => v,
+                match tracker::TorrentTracker::load_database(cfg.get_mode().clone(), &mut input_file).await {
+                    Ok(v) => {
+                        info!("database loaded.");
+                        v
+                    }
                     Err(err) => {
                         error!("failed to load database. error: {}", err);
                         panic!("failed to load database. check logs.");
@@ -116,111 +111,67 @@ fn main() {
         None => tracker::TorrentTracker::new(cfg.get_mode().clone()),
     };

-    let mut threads = Vec::new();
-
     let tracker = std::sync::Arc::new(tracker_obj);

-    let http_server = if cfg.get_http_config().is_some() {
-        let http_tracker_ref = tracker.clone();
-        let cfg_ref = cfg.clone();
-        Some(webserver::WebServer::new(http_tracker_ref, cfg_ref))
-    } else {
-        None
-    };
-
-    let udp_server = std::sync::Arc::new(server::UDPTracker::new(cfg.clone(), tracker.clone()).unwrap());
+    if cfg.get_http_config().is_some() {
+        let https_tracker = tracker.clone();
+        let http_cfg = cfg.clone();
+
+        info!("Starting http server");
+        tokio::spawn(async move {
+            let http_cfg = http_cfg.get_http_config().unwrap();
+            let bind_addr = http_cfg.get_address();
+            let tokens = http_cfg.get_access_tokens();
+
+            let server = webserver::build_server(https_tracker, tokens.clone());
+            server.bind(bind_addr.parse::<std::net::SocketAddr>().unwrap()).await;
+        });
+    }
+
+    let mut udp_server = server::UDPTracker::new(cfg.clone(), tracker.clone())
+        .await
+        .expect("failed to bind udp socket");

     trace!("Waiting for UDP packets");
-    let logical_cpus = num_cpus::get();
-    for i in 0..logical_cpus {
-        debug!("starting thread {}/{}", i + 1, logical_cpus);
-        let server_handle = udp_server.clone();
-        let thread_term_ref = term_mutex.clone();
-        threads.push(std::thread::spawn(move || loop {
-            match server_handle.accept_packet() {
-                Err(e) => {
-                    if thread_term_ref.load(std::sync::atomic::Ordering::Relaxed) == true {
-                        debug!("Thread terminating...");
-                        break;
-                    }
-                    match e.kind() {
-                        std::io::ErrorKind::TimedOut => {},
-                        std::io::ErrorKind::WouldBlock => {},
-                        _ => {
-                            error!("Failed to process packet. {}", e);
-                        }
-                    }
-                }
-                Ok(_) => {}
-            }
-        }));
-    }
-
-    match cfg.get_db_path() {
-        Some(db_path) => {
-            let db_p = db_path.clone();
-            let tracker_clone = tracker.clone();
-            let cleanup_interval = match *cfg.get_cleanup_interval() {
-                Some(v) => v,
-                None => 10 * 60,
-            };
-            let thread_term_mutex = term_mutex.clone();
-            threads.push(std::thread::spawn(move || {
-                let timeout = std::time::Duration::new(cleanup_interval, 0);
-                let timeout_start = std::time::Instant::now();
-                let mut timeout_remaining = timeout;
-                loop {
-                    std::thread::park_timeout(std::time::Duration::new(cleanup_interval, 0));
-                    if thread_term_mutex.load(std::sync::atomic::Ordering::Relaxed) {
-                        debug!("Maintenance thread terminating.");
-                        break;
-                    }
-                    let elapsed = std::time::Instant::now() - timeout_start;
-                    if elapsed < timeout_remaining {
-                        timeout_remaining = timeout - elapsed;
-                        continue;
-                    }
-                    else {
-                        timeout_remaining = timeout;
-                    }
-                    debug!("periodically saving database.");
-                    tracker_clone.periodic_task(db_p.as_str());
-                    debug!("database saved.");
-                }
-            }));
-        },
-        None => {}
-    }
-
-    loop {
-        if term_mutex.load(std::sync::atomic::Ordering::Relaxed) {
-            // termination signaled. start cleanup.
-            break;
-        }
-        std::thread::sleep(std::time::Duration::from_secs(1));
-    }
-
-    match http_server {
-        Some(v) => v.shutdown(),
-        None => {},
-    };
-
-    while !threads.is_empty() {
-        if let Some(thread) = threads.pop() {
-            thread.thread().unpark();
-            let _ = thread.join();
-        }
-    }
-
+    let udp_server = tokio::spawn(async move {
+        loop {
+            if let Err(err) = udp_server.accept_packet().await {
+                eprintln!("error: {}", err);
+            }
+        }
+    });
+
+    let weak_tracker = std::sync::Arc::downgrade(&tracker);
     if let Some(db_path) = cfg.get_db_path() {
-        info!("running final cleanup & saving database...");
-        tracker.periodic_task(db_path.as_str());
+        let db_path = db_path.clone();
+        let interval = cfg.get_cleanup_interval().unwrap_or(600);
+
+        tokio::spawn(async move {
+            let interval = std::time::Duration::from_secs(interval);
+            let mut interval = tokio::time::interval(interval);
+            interval.tick().await; // first tick is immediate...
+            loop {
+                interval.tick().await;
+                if let Some(tracker) = weak_tracker.upgrade() {
+                    tracker.periodic_task(&db_path).await;
+                } else {
+                    break;
+                }
+            }
+        });
     }
+
+    let ctrl_c = tokio::signal::ctrl_c();
+
+    tokio::select! {
+        _ = udp_server => { warn!("udp server exited.") },
+        _ = ctrl_c => { info!("CTRL-C, exiting...") },
+    }
+
+    if let Some(path) = cfg.get_db_path() {
+        info!("saving database...");
+        tracker.periodic_task(path).await;
+    }
+
     info!("goodbye.");
 }
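The maintenance job above replaces the park/unpark bookkeeping thread with a tokio interval task that holds only a Weak pointer to the tracker, so it ends on its own once the last Arc is dropped. A self-contained sketch of that pattern (tokio 0.2 APIs; the names are illustrative, not from the commit):

use std::sync::{Arc, Weak};
use std::time::Duration;

// Sketch: run `work` every `every` until the shared state is gone.
async fn periodic<T>(state: Weak<T>, every: Duration, work: fn(Arc<T>)) {
    let mut interval = tokio::time::interval(every);
    interval.tick().await; // the first tick completes immediately
    loop {
        interval.tick().await;
        match state.upgrade() {
            Some(strong) => work(strong), // still alive: do the job
            None => break,                // every Arc dropped: stop the task
        }
    }
}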

src/server.rs

@@ -1,8 +1,9 @@
+use log::{debug, error, trace};
 use std;
 use std::io::Write;
-use std::net::{SocketAddr, UdpSocket};
+use std::net::SocketAddr;
 use std::sync::Arc;
-use log::{error, trace, debug};
+use tokio::net::UdpSocket;
 use bincode;
 use serde::{Deserialize, Serialize};
@@ -107,31 +108,18 @@ struct UDPScrapeResponseEntry {
 }

 pub struct UDPTracker {
-    server: std::net::UdpSocket,
+    server: UdpSocket,
     tracker: std::sync::Arc<tracker::TorrentTracker>,
     config: Arc<Configuration>,
 }

 impl UDPTracker {
-    pub fn new(
-        config: Arc<Configuration>,
-        tracker: std::sync::Arc<tracker::TorrentTracker>
+    pub async fn new(
+        config: Arc<Configuration>, tracker: std::sync::Arc<tracker::TorrentTracker>,
     ) -> Result<UDPTracker, std::io::Error> {
         let cfg = config.clone();

-        let server = match UdpSocket::bind(cfg.get_udp_config().get_address()) {
-            Ok(s) => s,
-            Err(e) => {
-                return Err(e);
-            }
-        };
-
-        match server.set_read_timeout(Some(std::time::Duration::from_secs(1))) {
-            Ok(_) => {},
-            Err(err) => {
-                error!("Failed to set read timeout on socket; will try to continue anyway. err: {}", err);
-            }
-        }
+        let server = UdpSocket::bind(cfg.get_udp_config().get_address()).await?;

         Ok(UDPTracker {
             server,
@@ -140,7 +128,8 @@ impl UDPTracker {
         })
     }

-    fn handle_packet(&self, remote_address: &SocketAddr, payload: &[u8]) {
+    // TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+    async fn handle_packet(&mut self, remote_address: &SocketAddr, payload: &[u8]) {
         let header: UDPRequestHeader = match unpack(payload) {
             Some(val) => val,
             None => {
@@ -150,9 +139,9 @@ impl UDPTracker {
         };

         match header.action {
-            Actions::Connect => self.handle_connect(remote_address, &header, payload),
-            Actions::Announce => self.handle_announce(remote_address, &header, payload),
-            Actions::Scrape => self.handle_scrape(remote_address, &header, payload),
+            Actions::Connect => self.handle_connect(remote_address, &header, payload).await,
+            Actions::Announce => self.handle_announce(remote_address, &header, payload).await,
+            Actions::Scrape => self.handle_scrape(remote_address, &header, payload).await,
             _ => {
                 trace!("invalid action from {}", remote_address);
                 // someone is playing around... ignore request.
@@ -161,7 +150,8 @@ impl UDPTracker {
         }
     }

-    fn handle_connect(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, _payload: &[u8]) {
+    // TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+    async fn handle_connect(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, _payload: &[u8]) {
         if header.connection_id != PROTOCOL_ID {
             trace!("Bad protocol magic from {}", remote_addr);
             return;
@@ -178,15 +168,16 @@ impl UDPTracker {
             connection_id: conn_id,
         };

-        let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
-        let mut payload = StackVec::from(&mut payload_buffer);
+        let mut payload_buffer = vec![0u8; MAX_PACKET_SIZE];
+        let mut payload = StackVec::from(payload_buffer.as_mut_slice());

         if let Ok(_) = pack_into(&mut payload, &response) {
-            let _ = self.send_packet(remote_addr, payload.as_slice());
+            let _ = self.send_packet(remote_addr, payload.as_slice()).await;
         }
     }

-    fn handle_announce(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
+    // TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+    async fn handle_announce(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
         if header.connection_id != self.get_connection_id(remote_addr) {
             return;
         }
@@ -205,63 +196,57 @@ impl UDPTracker {
             let bep41_payload = &payload[plen..];
             // TODO: process BEP0041 payload.
-            trace!(
-                "BEP0041 payload of {} bytes from {}",
-                bep41_payload.len(),
-                remote_addr
-            );
+            trace!("BEP0041 payload of {} bytes from {}", bep41_payload.len(), remote_addr);
             }
         }

         if packet.ip_address != 0 {
             // TODO: allow configurability of ip address
             // for now, ignore request.
-            trace!(
-                "announce request for other IP ignored. (from {})",
-                remote_addr
-            );
+            trace!("announce request for other IP ignored. (from {})", remote_addr);
             return;
         }

         let client_addr = SocketAddr::new(remote_addr.ip(), packet.port);
         let info_hash = packet.info_hash.into();

-        match self.tracker.update_torrent_and_get_stats(
-            &info_hash,
-            &packet.peer_id,
-            &client_addr,
-            packet.uploaded,
-            packet.downloaded,
-            packet.left,
-            packet.event,
-        ) {
+        match self
+            .tracker
+            .update_torrent_and_get_stats(
+                &info_hash,
+                &packet.peer_id,
+                &client_addr,
+                packet.uploaded,
+                packet.downloaded,
+                packet.left,
+                packet.event,
+            )
+            .await
+        {
             tracker::TorrentStats::Stats {
                 leechers,
                 complete: _,
                 seeders,
             } => {
-                let peers = match self.tracker.get_torrent_peers(&info_hash, &client_addr) {
+                let peers = match self.tracker.get_torrent_peers(&info_hash, &client_addr).await {
                     Some(v) => v,
                     None => {
                         return;
                     }
                 };

-                let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
+                let mut payload_buffer = vec![0u8; MAX_PACKET_SIZE];
                 let mut payload = StackVec::from(&mut payload_buffer);

-                match pack_into(
-                    &mut payload,
-                    &UDPAnnounceResponse {
-                        header: UDPResponseHeader {
-                            action: Actions::Announce,
-                            transaction_id: packet.header.transaction_id,
-                        },
-                        seeders,
-                        interval: self.config.get_udp_config().get_announce_interval(),
-                        leechers,
-                    },
-                ) {
+                match pack_into(&mut payload, &UDPAnnounceResponse {
+                    header: UDPResponseHeader {
+                        action: Actions::Announce,
+                        transaction_id: packet.header.transaction_id,
+                    },
+                    seeders,
+                    interval: self.config.get_udp_config().get_announce_interval(),
+                    leechers,
+                }) {
                     Ok(_) => {}
                     Err(_) => {
                         return;
@@ -279,24 +264,24 @@ impl UDPTracker {
                 };

                 let port_hton = client_addr.port().to_be();
-                let _ =
-                    payload.write(&[(port_hton & 0xff) as u8, ((port_hton >> 8) & 0xff) as u8]);
+                let _ = payload.write(&[(port_hton & 0xff) as u8, ((port_hton >> 8) & 0xff) as u8]);
                 }

-                let _ = self.send_packet(&client_addr, payload.as_slice());
+                let _ = self.send_packet(&client_addr, payload.as_slice()).await;
             }
             tracker::TorrentStats::TorrentFlagged => {
-                self.send_error(&client_addr, &packet.header, "torrent flagged.");
+                self.send_error(&client_addr, &packet.header, "torrent flagged.").await;
                 return;
             }
             tracker::TorrentStats::TorrentNotRegistered => {
-                self.send_error(&client_addr, &packet.header, "torrent not registered.");
+                self.send_error(&client_addr, &packet.header, "torrent not registered.").await;
                 return;
             }
         }
     }

-    fn handle_scrape(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
+    // TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+    async fn handle_scrape(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, payload: &[u8]) {
         if header.connection_id != self.get_connection_id(remote_addr) {
             return;
         }
@@ -306,16 +291,17 @@ impl UDPTracker {
         let mut response_buffer = [0u8; 8 + MAX_SCRAPE * 12];
         let mut response = StackVec::from(&mut response_buffer);

-        if pack_into(&mut response, &UDPResponseHeader{
+        if pack_into(&mut response, &UDPResponseHeader {
             action: Actions::Scrape,
             transaction_id: header.transaction_id,
-        }).is_err() {
+        })
+        .is_err()
+        {
             // not much we can do...
             error!("failed to encode udp scrape response header.");
             return;
         }

         // skip first 16 bytes for header...
         let info_hash_array = &payload[16..];
@@ -323,45 +309,47 @@ impl UDPTracker {
             trace!("received weird length for scrape info_hash array (!mod20).");
         }

-        let db = self.tracker.get_database();
-
-        for torrent_index in 0..MAX_SCRAPE {
-            let info_hash_start = torrent_index * 20;
-            let info_hash_end = (torrent_index + 1) * 20;
-
-            if info_hash_end > info_hash_array.len() {
-                break;
-            }
-
-            let info_hash = &info_hash_array[info_hash_start..info_hash_end];
-            let ih = tracker::InfoHash::from(info_hash);
-            let result = match db.get(&ih) {
-                Some(torrent_info) => {
-                    let (seeders, completed, leechers) = torrent_info.get_stats();
-                    UDPScrapeResponseEntry{
-                        seeders,
-                        completed,
-                        leechers,
-                    }
-                },
-                None => {
-                    UDPScrapeResponseEntry{
-                        seeders: 0,
-                        completed: 0,
-                        leechers: 0,
-                    }
-                }
-            };
-
-            if pack_into(&mut response, &result).is_err() {
-                debug!("failed to encode scrape entry.");
-                return;
-            }
-        }
+        {
+            let db = self.tracker.get_database().await;
+
+            for torrent_index in 0..MAX_SCRAPE {
+                let info_hash_start = torrent_index * 20;
+                let info_hash_end = (torrent_index + 1) * 20;
+
+                if info_hash_end > info_hash_array.len() {
+                    break;
+                }
+
+                let info_hash = &info_hash_array[info_hash_start..info_hash_end];
+                let ih = tracker::InfoHash::from(info_hash);
+                let result = match db.get(&ih) {
+                    Some(torrent_info) => {
+                        let (seeders, completed, leechers) = torrent_info.get_stats();
+                        UDPScrapeResponseEntry {
+                            seeders,
+                            completed,
+                            leechers,
+                        }
+                    }
+                    None => {
+                        UDPScrapeResponseEntry {
+                            seeders: 0,
+                            completed: 0,
+                            leechers: 0,
+                        }
+                    }
+                };
+
+                if pack_into(&mut response, &result).is_err() {
+                    debug!("failed to encode scrape entry.");
+                    return;
+                }
+            }
+        }

         // if sending fails, not much we can do...
-        let _ = self.send_packet(&remote_addr, &response.as_slice());
+        let _ = self.send_packet(&remote_addr, &response.as_slice()).await;
     }

     fn get_connection_id(&self, remote_address: &SocketAddr) -> u64 {
@@ -371,43 +359,37 @@ impl UDPTracker {
         }
     }

-    fn send_packet(
-        &self,
-        remote_addr: &SocketAddr,
-        payload: &[u8],
-    ) -> Result<usize, std::io::Error> {
-        self.server.send_to(payload, remote_addr)
+    // TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+    async fn send_packet(&mut self, remote_addr: &SocketAddr, payload: &[u8]) -> Result<usize, std::io::Error> {
+        self.server.send_to(payload, remote_addr).await
     }

-    fn send_error(&self, remote_addr: &SocketAddr, header: &UDPRequestHeader, error_msg: &str) {
-        let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
+    // TODO: remove `mut` once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+    async fn send_error(&mut self, remote_addr: &SocketAddr, header: &UDPRequestHeader, error_msg: &str) {
+        let mut payload_buffer = vec![0u8; MAX_PACKET_SIZE];
         let mut payload = StackVec::from(&mut payload_buffer);

-        if let Ok(_) = pack_into(
-            &mut payload,
-            &UDPResponseHeader {
-                transaction_id: header.transaction_id,
-                action: Actions::Error,
-            },
-        ) {
+        if let Ok(_) = pack_into(&mut payload, &UDPResponseHeader {
+            transaction_id: header.transaction_id,
+            action: Actions::Error,
+        }) {
             let msg_bytes = Vec::from(error_msg.as_bytes());
             payload.extend(msg_bytes);

-            let _ = self.send_packet(remote_addr, payload.as_slice());
+            let _ = self.send_packet(remote_addr, payload.as_slice()).await;
         }
     }

-    pub fn accept_packet(&self) -> Result<(), std::io::Error> {
-        let mut packet = [0u8; MAX_PACKET_SIZE];
-        match self.server.recv_from(&mut packet) {
-            Ok((size, remote_address)) => {
-                debug!("Received {} bytes from {}", size, remote_address);
-                self.handle_packet(&remote_address, &packet[..size]);
-                Ok(())
-            }
-            Err(e) => Err(e),
-        }
+    // TODO: remove `mut` for `accept_packet`, and spawn once https://github.com/tokio-rs/tokio/issues/1624 is resolved
+    pub async fn accept_packet(&mut self) -> Result<(), std::io::Error> {
+        let mut packet = vec![0u8; MAX_PACKET_SIZE];
+        let (size, remote_address) = self.server.recv_from(packet.as_mut_slice()).await?;

+        // tokio::spawn(async {
+        debug!("Received {} bytes from {}", size, remote_address);
+        self.handle_packet(&remote_address, &packet[..size]).await;
+        // });
+
+        Ok(())
     }
 }
@@ -426,10 +408,7 @@ mod tests {
         assert!(pack_into(&mut payload, &mystruct).is_ok());

         assert_eq!(payload.len(), 16);
-        assert_eq!(
-            payload.as_slice(),
-            &[0, 0, 0, 0, 0, 0, 0, 200u8, 0, 0, 0, 0, 0, 1, 47, 203]
-        );
+        assert_eq!(payload.as_slice(), &[0, 0, 0, 0, 0, 0, 0, 200u8, 0, 0, 0, 0, 0, 1, 47, 203]);
     }

     #[test]
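Every handler grows `&mut self` plus a TODO pointing at tokio-rs/tokio#1624: in tokio 0.2, UdpSocket::recv_from and send_to take &mut self, so one socket cannot yet be shared across spawned tasks the way the old blocking socket was shared across threads (hence the commented-out tokio::spawn in accept_packet). A minimal sketch, not from the commit, of the sequential receive/reply shape this forces:

use tokio::net::UdpSocket;

// Sketch: one receive/reply round on an exclusively borrowed socket.
async fn echo_once(socket: &mut UdpSocket) -> std::io::Result<()> {
    let mut buf = vec![0u8; 2048];
    let (n, peer) = socket.recv_from(&mut buf).await?; // &mut self in tokio 0.2
    socket.send_to(&buf[..n], &peer).await?;           // likewise &mut self
    Ok(())
}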

src/stackvec.rs

@@ -10,10 +10,6 @@ impl<'a, T> StackVec<'a, T> {
         StackVec { data, length: 0 }
     }

-    pub fn len(&self) -> usize {
-        self.length
-    }
-
     pub fn as_slice(&self) -> &[T] {
         &self.data[0..self.length]
     }

src/tracker.rs

@@ -1,7 +1,11 @@
 use crate::server::Events;
-use serde::{Serialize, Deserialize};
 use log::{error, trace};
+use serde::{Deserialize, Serialize};
+use std::borrow::Cow;
+use std::collections::BTreeMap;
+use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
+use tokio::stream::StreamExt;
+use tokio::sync::RwLock;
@@ -18,6 +22,7 @@ pub enum TrackerMode {
     PrivateMode,
 }

+#[derive(Clone)]
 struct TorrentPeer {
     ip: std::net::SocketAddr,
     uploaded: u64,
@@ -32,6 +37,27 @@ pub struct InfoHash {
     info_hash: [u8; 20],
 }

+impl std::fmt::Display for InfoHash {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let mut chars = [0u8; 40];
+        binascii::bin2hex(&self.info_hash, &mut chars).expect("failed to hexlify");
+        write!(f, "{}", std::str::from_utf8(&chars).unwrap())
+    }
+}
+
+impl std::str::FromStr for InfoHash {
+    type Err = binascii::ConvertError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let mut i = Self { info_hash: [0u8; 20] };
+        if s.len() != 40 {
+            return Err(binascii::ConvertError::InvalidInputLength);
+        }
+        binascii::hex2bin(s.as_bytes(), &mut i.info_hash)?;
+        Ok(i)
+    }
+}
+
 impl std::cmp::PartialOrd<InfoHash> for InfoHash {
     fn partial_cmp(&self, other: &InfoHash) -> Option<std::cmp::Ordering> {
         self.info_hash.partial_cmp(&other.info_hash)
@@ -41,9 +67,7 @@ impl std::cmp::PartialOrd<InfoHash> for InfoHash {
 impl std::convert::From<&[u8]> for InfoHash {
     fn from(data: &[u8]) -> InfoHash {
         assert_eq!(data.len(), 20);
-        let mut ret = InfoHash{
-            info_hash: [0u8; 20],
-        };
+        let mut ret = InfoHash { info_hash: [0u8; 20] };
         ret.info_hash.clone_from_slice(data);
         return ret;
     }
@@ -58,9 +82,7 @@ impl std::convert::Into<InfoHash> for [u8; 20] {
 impl serde::ser::Serialize for InfoHash {
     fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
         let mut buffer = [0u8; 40];
-        let bytes_out = binascii::bin2hex(&self.info_hash, &mut buffer)
-            .ok()
-            .unwrap();
+        let bytes_out = binascii::bin2hex(&self.info_hash, &mut buffer).ok().unwrap();
         let str_out = std::str::from_utf8(bytes_out).unwrap();

         serializer.serialize_str(str_out)
@@ -84,9 +106,7 @@ impl<'v> serde::de::Visitor<'v> for InfoHashVisitor {
             ));
         }

-        let mut res = InfoHash {
-            info_hash: [0u8; 20],
-        };
+        let mut res = InfoHash { info_hash: [0u8; 20] };

         if let Err(_) = binascii::hex2bin(v.as_bytes(), &mut res.info_hash) {
             return Err(serde::de::Error::invalid_value(
@@ -107,7 +127,7 @@ impl<'de> serde::de::Deserialize<'de> for InfoHash {

 pub type PeerId = [u8; 20];

-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 pub struct TorrentEntry {
     is_flagged: bool,
@@ -135,28 +155,20 @@ impl TorrentEntry {
     }

     pub fn update_peer(
-        &mut self,
-        peer_id: &PeerId,
-        remote_address: &std::net::SocketAddr,
-        uploaded: u64,
-        downloaded: u64,
-        left: u64,
+        &mut self, peer_id: &PeerId, remote_address: &std::net::SocketAddr, uploaded: u64, downloaded: u64, left: u64,
         event: Events,
     ) {
         let is_seeder = left == 0 && uploaded > 0;
         let mut was_seeder = false;
         let mut is_completed = left == 0 && (event as u32) == (Events::Complete as u32);
-        if let Some(prev) = self.peers.insert(
-            *peer_id,
-            TorrentPeer {
-                updated: std::time::SystemTime::now(),
-                left,
-                downloaded,
-                uploaded,
-                ip: *remote_address,
-                event,
-            },
-        ) {
+        if let Some(prev) = self.peers.insert(*peer_id, TorrentPeer {
+            updated: std::time::SystemTime::now(),
+            left,
+            downloaded,
+            uploaded,
+            ip: *remote_address,
+            event,
+        }) {
             was_seeder = prev.left == 0 && prev.uploaded > 0;

             if is_completed && (prev.event as u32) == (Events::Complete as u32) {
@@ -200,13 +212,13 @@ impl TorrentEntry {
     }

 struct TorrentDatabase {
-    torrent_peers: std::sync::RwLock<std::collections::BTreeMap<InfoHash, TorrentEntry>>,
+    torrent_peers: tokio::sync::RwLock<std::collections::BTreeMap<InfoHash, TorrentEntry>>,
 }

 impl Default for TorrentDatabase {
     fn default() -> Self {
         TorrentDatabase {
-            torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()),
+            torrent_peers: tokio::sync::RwLock::new(std::collections::BTreeMap::new()),
         }
     }
 }
@@ -216,14 +228,16 @@ pub struct TorrentTracker {
     database: TorrentDatabase,
 }

+#[derive(Serialize, Deserialize)]
+struct DatabaseRow<'a> {
+    info_hash: InfoHash,
+    entry: Cow<'a, TorrentEntry>,
+}
+
 pub enum TorrentStats {
     TorrentFlagged,
     TorrentNotRegistered,
-    Stats {
-        seeders: u32,
-        leechers: u32,
-        complete: u32,
-    },
+    Stats { seeders: u32, leechers: u32, complete: u32 },
 }

 impl TorrentTracker {
@@ -231,32 +245,52 @@ impl TorrentTracker {
         TorrentTracker {
             mode,
             database: TorrentDatabase {
-                torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()),
+                torrent_peers: RwLock::new(std::collections::BTreeMap::new()),
             },
         }
     }

-    pub fn load_database<R: std::io::Read>(
-        mode: TrackerMode,
-        reader: &mut R,
-    ) -> serde_json::Result<TorrentTracker> {
-        let decomp_reader = bzip2::read::BzDecoder::new(reader);
-        let result: serde_json::Result<std::collections::BTreeMap<InfoHash, TorrentEntry>> =
-            serde_json::from_reader(decomp_reader);
-        match result {
-            Ok(v) => Ok(TorrentTracker {
-                mode,
-                database: TorrentDatabase {
-                    torrent_peers: std::sync::RwLock::new(v),
-                },
-            }),
-            Err(e) => Err(e),
-        }
+    pub async fn load_database<R: tokio::io::AsyncRead + Unpin>(
+        mode: TrackerMode, reader: &mut R,
+    ) -> Result<TorrentTracker, std::io::Error> {
+        let reader = tokio::io::BufReader::new(reader);
+        let mut tmp: Vec<u8> = Vec::with_capacity(4096);
+        tmp.resize(tmp.capacity(), 0);
+
+        let res = TorrentTracker::new(mode);
+        let mut db = res.database.torrent_peers.write().await;
+
+        let mut records = reader
+            .lines()
+            .filter_map(|res| {
+                if let Err(ref err) = res {
+                    error!("failed to read lines! {}", err);
+                }
+                res.ok()
+            })
+            .map(|line| serde_json::from_str::<DatabaseRow>(&line))
+            .filter_map(|jsr| {
+                if let Err(ref err) = jsr {
+                    error!("failed to parse json: {}", err);
+                }
+                jsr.ok()
+            });
+
+        while let Some(entry) = records.next().await {
+            let x = || (entry.entry, entry.info_hash);
+            let (a, b) = x();
+            let a = a.into_owned();
+            db.insert(b, a);
+        }
+
+        drop(db);
+        Ok(res)
     }

     /// Adding torrents is not relevant to dynamic trackers.
-    pub fn add_torrent(&self, info_hash: &InfoHash) -> Result<(), ()> {
-        let mut write_lock = self.database.torrent_peers.write().unwrap();
+    pub async fn add_torrent(&self, info_hash: &InfoHash) -> Result<(), ()> {
+        let mut write_lock = self.database.torrent_peers.write().await;
         match write_lock.entry(info_hash.clone()) {
             std::collections::btree_map::Entry::Vacant(ve) => {
                 ve.insert(TorrentEntry::new());
@@ -269,9 +303,9 @@ impl TorrentTracker {
     }

     /// If the torrent is flagged, it will not be removed unless force is set to true.
-    pub fn remove_torrent(&self, info_hash: &InfoHash, force: bool) -> Result<(), ()> {
+    pub async fn remove_torrent(&self, info_hash: &InfoHash, force: bool) -> Result<(), ()> {
         use std::collections::btree_map::Entry;
-        let mut entry_lock = self.database.torrent_peers.write().unwrap();
+        let mut entry_lock = self.database.torrent_peers.write().await;
         let torrent_entry = entry_lock.entry(info_hash.clone());
         match torrent_entry {
             Entry::Vacant(_) => {
@@ -289,28 +323,23 @@ impl TorrentTracker {
     }

     /// flagged torrents will result in a tracking error. This is to allow enforcement against piracy.
-    pub fn set_torrent_flag(&self, info_hash: &InfoHash, is_flagged: bool) {
-        if let Some(entry) = self
-            .database
-            .torrent_peers
-            .write()
-            .unwrap()
-            .get_mut(info_hash)
-        {
+    pub async fn set_torrent_flag(&self, info_hash: &InfoHash, is_flagged: bool) -> bool {
+        if let Some(entry) = self.database.torrent_peers.write().await.get_mut(info_hash) {
             if is_flagged && !entry.is_flagged {
                 // empty peer list.
                 entry.peers.clear();
             }
             entry.is_flagged = is_flagged;
+            true
+        } else {
+            false
         }
     }

-    pub fn get_torrent_peers(
-        &self,
-        info_hash: &InfoHash,
-        remote_addr: &std::net::SocketAddr,
+    pub async fn get_torrent_peers(
+        &self, info_hash: &InfoHash, remote_addr: &std::net::SocketAddr,
     ) -> Option<Vec<std::net::SocketAddr>> {
-        let read_lock = self.database.torrent_peers.read().unwrap();
+        let read_lock = self.database.torrent_peers.read().await;
         match read_lock.get(info_hash) {
             None => {
                 return None;
@@ -321,25 +350,21 @@ impl TorrentTracker {
         };
     }

-    pub fn update_torrent_and_get_stats(
-        &self,
-        info_hash: &InfoHash,
-        peer_id: &PeerId,
-        remote_address: &std::net::SocketAddr,
-        uploaded: u64,
-        downloaded: u64,
-        left: u64,
-        event: Events,
+    pub async fn update_torrent_and_get_stats(
+        &self, info_hash: &InfoHash, peer_id: &PeerId, remote_address: &std::net::SocketAddr, uploaded: u64,
+        downloaded: u64, left: u64, event: Events,
     ) -> TorrentStats {
         use std::collections::btree_map::Entry;
-        let mut torrent_peers = self.database.torrent_peers.write().unwrap();
+        let mut torrent_peers = self.database.torrent_peers.write().await;
         let torrent_entry = match torrent_peers.entry(info_hash.clone()) {
-            Entry::Vacant(vacant) => match self.mode {
-                TrackerMode::DynamicMode => vacant.insert(TorrentEntry::new()),
-                _ => {
-                    return TorrentStats::TorrentNotRegistered;
+            Entry::Vacant(vacant) => {
+                match self.mode {
+                    TrackerMode::DynamicMode => vacant.insert(TorrentEntry::new()),
+                    _ => {
+                        return TorrentStats::TorrentNotRegistered;
+                    }
                 }
-            },
+            }
             Entry::Occupied(entry) => {
                 if entry.get().is_flagged() {
                     return TorrentStats::TorrentFlagged;
@@ -359,70 +384,75 @@ impl TorrentTracker {
         };
     }

-    pub(crate) fn get_database(
-        &self,
-    ) -> std::sync::RwLockReadGuard<std::collections::BTreeMap<InfoHash, TorrentEntry>> {
-        self.database.torrent_peers.read().unwrap()
+    pub(crate) async fn get_database<'a>(&'a self) -> tokio::sync::RwLockReadGuard<'a, BTreeMap<InfoHash, TorrentEntry>> {
+        self.database.torrent_peers.read().await
     }

-    pub fn save_database<W: std::io::Write>(&self, writer: &mut W) -> serde_json::Result<()> {
-        let compressor = bzip2::write::BzEncoder::new(writer, bzip2::Compression::Best);
-        let db_lock = self.database.torrent_peers.read().unwrap();
-        let db = &*db_lock;
-
-        serde_json::to_writer(compressor, &db)
+    pub async fn save_database<W: tokio::io::AsyncWrite + Unpin>(&self, mut writer: W) -> Result<(), std::io::Error> {
+        // TODO: find async friendly compressor
+
+        let db_lock = self.database.torrent_peers.read().await;
+        let db: &BTreeMap<InfoHash, TorrentEntry> = &*db_lock;
+        let mut tmp = Vec::with_capacity(4096);
+
+        for row in db {
+            let entry = DatabaseRow {
+                info_hash: row.0.clone(),
+                entry: Cow::Borrowed(row.1),
+            };
+            tmp.clear();
+            if let Err(err) = serde_json::to_writer(&mut tmp, &entry) {
+                error!("failed to serialize: {}", err);
+                continue;
+            };
+            tmp.push(b'\n');
+            writer.write_all(&tmp).await?;
+        }
+        Ok(())
     }

-    fn cleanup(&self) {
+    async fn cleanup(&self) {
         use std::ops::Add;
         let now = std::time::SystemTime::now();
-        match self.database.torrent_peers.write() {
-            Err(err) => {
-                error!("failed to obtain write lock on database. err: {}", err);
-                return;
-            }
-            Ok(mut db) => {
-                let mut torrents_to_remove = Vec::new();
-
-                for (k, v) in db.iter_mut() {
-                    // timed-out peers..
-                    {
-                        let mut peers_to_remove = Vec::new();
-                        let torrent_peers = &mut v.peers;
-
-                        for (peer_id, state) in torrent_peers.iter() {
-                            if state.updated.add(std::time::Duration::new(3600 * 2, 0)) < now {
-                                // over 2 hours past since last update...
-                                peers_to_remove.push(*peer_id);
-                            }
-                        }
-
-                        for peer_id in peers_to_remove.iter() {
-                            torrent_peers.remove(peer_id);
-                        }
-                    }
-
-                    if self.mode == TrackerMode::DynamicMode {
-                        // peer-less torrents..
-                        if v.peers.len() == 0 {
-                            torrents_to_remove.push(k.clone());
-                        }
-                    }
-                }
-
-                for info_hash in torrents_to_remove {
-                    db.remove(&info_hash);
-                }
-            }
-        }
+        let mut lock = self.database.torrent_peers.write().await;
+        let db: &mut BTreeMap<InfoHash, TorrentEntry> = &mut *lock;
+        let mut torrents_to_remove = Vec::new();
+
+        for (k, v) in db.iter_mut() {
+            // timed-out peers..
+            {
+                let mut peers_to_remove = Vec::new();
+                let torrent_peers = &mut v.peers;
+
+                for (peer_id, state) in torrent_peers.iter() {
+                    if state.updated.add(std::time::Duration::new(3600 * 2, 0)) < now {
+                        // over 2 hours past since last update...
+                        peers_to_remove.push(*peer_id);
+                    }
+                }
+
+                for peer_id in peers_to_remove.iter() {
+                    torrent_peers.remove(peer_id);
+                }
+            }
+
+            if self.mode == TrackerMode::DynamicMode {
+                // peer-less torrents..
+                if v.peers.len() == 0 {
+                    torrents_to_remove.push(k.clone());
+                }
+            }
+        }
+
+        for info_hash in torrents_to_remove {
+            db.remove(&info_hash);
+        }
     }

-    pub fn periodic_task(&self, db_path: &str) {
+    pub async fn periodic_task(&self, db_path: &str) {
         // cleanup db
-        self.cleanup();
+        self.cleanup().await;

         // save journal db.
         let mut journal_path = std::path::PathBuf::from(db_path);
@@ -435,7 +465,7 @@ impl TorrentTracker {
         // scope to make sure backup file is dropped/closed.
         {
-            let mut file = match std::fs::File::create(jp_str) {
+            let mut file = match tokio::fs::File::create(jp_str).await {
                 Err(err) => {
                     error!("failed to open file '{}': {}", db_path, err);
                     return;
@@ -443,7 +473,7 @@ impl TorrentTracker {
                 Ok(v) => v,
             };
             trace!("writing database to {}", jp_str);
-            if let Err(err) = self.save_database(&mut file) {
+            if let Err(err) = self.save_database(&mut file).await {
                 error!("failed saving database. {}", err);
                 return;
             }
@@ -451,7 +481,7 @@ impl TorrentTracker {

         // overwrite previous db
         trace!("renaming '{}' to '{}'", jp_str, db_path);
-        if let Err(err) = std::fs::rename(jp_str, db_path) {
+        if let Err(err) = tokio::fs::rename(jp_str, db_path).await {
             error!("failed to move db backup. {}", err);
         }
     }
@@ -474,14 +504,14 @@ mod tests {
         is_sync::<TorrentTracker>();
     }

-    #[test]
-    fn test_save_db() {
+    #[tokio::test]
+    async fn test_save_db() {
         let tracker = TorrentTracker::new(TrackerMode::DynamicMode);
         tracker.add_torrent(&[0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0].into());
         let mut out = Vec::new();

-        tracker.save_database(&mut out).unwrap();
+        tracker.save_database(&mut out).await.expect("db save failed");
         assert!(out.len() > 0);
     }

src/webserver.rs

@@ -1,419 +1,193 @@
-use std::collections::HashMap;
-use std::sync::Arc;
-
-use actix_web;
-use actix_net;
-use binascii;
-
-use crate::config;
-use crate::tracker::TorrentTracker;
-use log::{error, warn, info, debug};
-
-const SERVER: &str = concat!("udpt/", env!("CARGO_PKG_VERSION"));
-
-pub struct WebServer {
-    thread: std::thread::JoinHandle<()>,
-    addr: Option<actix_web::actix::Addr<actix_net::server::Server>>,
-}
-
-mod http_responses {
-    use serde::Serialize;
-    use crate::tracker::InfoHash;
-
-    #[derive(Serialize)]
-    pub struct TorrentInfo {
-        pub is_flagged: bool,
-        pub leecher_count: u32,
-        pub seeder_count: u32,
-        pub completed: u32,
-    }
-
-    #[derive(Serialize)]
-    pub struct TorrentList {
-        pub offset: u32,
-        pub length: u32,
-        pub total: u32,
-        pub torrents: Vec<InfoHash>,
-    }
-
-    #[derive(Serialize)]
-    #[serde(rename_all = "snake_case")]
-    pub enum APIResponse {
-        Error(String),
-        TorrentList(TorrentList),
-        TorrentInfo(TorrentInfo),
-    }
-}
-
-struct UdptState {
-    // k=token, v=username.
-    access_tokens: HashMap<String, String>,
-    tracker: Arc<TorrentTracker>,
-}
-
-impl UdptState {
-    fn new(tracker: Arc<TorrentTracker>, tokens: HashMap<String, String>) -> UdptState {
-        UdptState {
-            tracker,
-            access_tokens: tokens,
-        }
-    }
-}
-
-#[derive(Debug)]
-struct UdptRequestState {
-    current_user: Option<String>,
-}
-
-impl Default for UdptRequestState {
-    fn default() -> Self {
-        UdptRequestState {
-            current_user: Option::None,
-        }
-    }
-}
-
-impl UdptRequestState {
-    fn get_user<S>(req: &actix_web::HttpRequest<S>) -> Option<String> {
-        let exts = req.extensions();
-        let req_state: Option<&UdptRequestState> = exts.get();
-        match req_state {
-            None => None,
-            Option::Some(state) => match state.current_user {
-                Option::Some(ref v) => Option::Some(v.clone()),
-                None => {
-                    error!(
-                        "Invalid API token from {} @ {}",
-                        req.peer_addr().unwrap(),
-                        req.path()
-                    );
-                    return None;
-                }
-            },
-        }
-    }
-}
-
-struct UdptMiddleware;
-
-impl actix_web::middleware::Middleware<UdptState> for UdptMiddleware {
-    fn start(
-        &self,
-        req: &actix_web::HttpRequest<UdptState>,
-    ) -> actix_web::Result<actix_web::middleware::Started> {
-        let mut req_state = UdptRequestState::default();
-        if let Option::Some(token) = req.query().get("token") {
-            let app_state: &UdptState = req.state();
-            if let Option::Some(v) = app_state.access_tokens.get(token) {
-                req_state.current_user = Option::Some(v.clone());
-            }
-        }
-        req.extensions_mut().insert(req_state);
-        Ok(actix_web::middleware::Started::Done)
-    }
-
-    fn response(
-        &self,
-        _req: &actix_web::HttpRequest<UdptState>,
-        mut resp: actix_web::HttpResponse,
-    ) -> actix_web::Result<actix_web::middleware::Response> {
-        resp.headers_mut().insert(
-            actix_web::http::header::SERVER,
-            actix_web::http::header::HeaderValue::from_static(SERVER),
-        );
-
-        Ok(actix_web::middleware::Response::Done(resp))
-    }
-}
-
-impl WebServer {
-    fn get_access_tokens(cfg: &config::HTTPConfig, tokens: &mut HashMap<String, String>) {
-        for (user, token) in cfg.get_access_tokens().iter() {
-            tokens.insert(token.clone(), user.clone());
-        }
-        if tokens.len() == 0 {
-            warn!("No access tokens provided. HTTP API will not be useful.");
-        }
-    }
-
-    pub fn shutdown(self) {
-        match self.addr {
-            Some(v) => {
-                use futures::future::Future;
-
-                v.send(actix_web::actix::signal::Signal(actix_web::actix::signal::SignalType::Term)).wait().unwrap();
-            },
-            None => {},
-        };
-
-        self.thread.thread().unpark();
-        let _ = self.thread.join();
-    }
-
-    pub fn new(
-        tracker: Arc<TorrentTracker>,
-        cfg: Arc<config::Configuration>,
-    ) -> WebServer {
-        let cfg_cp = cfg.clone();
-
-        let (tx_addr, rx_addr) = std::sync::mpsc::channel();
-
-        let thread = std::thread::spawn(move || {
-            let server = actix_web::server::HttpServer::new(move || {
-                let mut access_tokens = HashMap::new();
-
-                if let Some(http_cfg) = cfg_cp.get_http_config() {
-                    Self::get_access_tokens(http_cfg, &mut access_tokens);
-                }
-
-                let state = UdptState::new(tracker.clone(), access_tokens);
-
-                actix_web::App::<UdptState>::with_state(state)
-                    .middleware(UdptMiddleware)
-                    .resource("/t", |r| r.f(Self::view_torrent_list))
-                    .scope(r"/t/{info_hash:[\dA-Fa-f]{40,40}}", |scope| {
-                        scope.resource("", |r| {
-                            r.method(actix_web::http::Method::GET)
-                                .f(Self::view_torrent_stats);
-                            r.method(actix_web::http::Method::POST)
-                                .f(Self::torrent_action);
-                        })
-                    })
-                    .resource("/", |r| {
-                        r.method(actix_web::http::Method::GET).f(Self::view_root)
-                    })
-            });
-
-            if let Some(http_cfg) = cfg.get_http_config() {
-                let bind_addr = http_cfg.get_address();
-                match server.bind(bind_addr) {
-                    Ok(v) => {
-                        let sys = actix_web::actix::System::new("http-server");
-                        let addr = v.start();
-                        let _ = tx_addr.send(addr);
-                        sys.run();
-                    }
-                    Err(err) => {
-                        error!("Failed to bind http server. {}", err);
-                    }
-                }
-            } else {
-                unreachable!();
-            }
-        });
-
-        let addr = match rx_addr.recv() {
-            Ok(v) => Some(v),
-            Err(_) => None
-        };
-
-        WebServer {
-            thread,
-            addr,
-        }
-    }
-
-    fn view_root(_req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
-        actix_web::HttpResponse::build(actix_web::http::StatusCode::OK)
-            .content_type("text/html")
-            .body(r#"Powered by <a href="https://github.com/naim94a/udpt">https://github.com/naim94a/udpt</a>"#)
-    }
-
-    fn view_torrent_list(req: &actix_web::HttpRequest<UdptState>) -> impl actix_web::Responder {
-        use std::str::FromStr;
-
-        if UdptRequestState::get_user(req).is_none() {
-            return actix_web::Json(http_responses::APIResponse::Error(String::from(
-                "access_denied",
-            )));
-        }
-
-        let req_offset = match req.query().get("offset") {
-            None => 0,
-            Some(v) => match u32::from_str(v.as_str()) {
-                Ok(v) => v,
-                Err(_) => 0,
-            },
-        };
-
-        let mut req_limit = match req.query().get("limit") {
-            None => 0,
-            Some(v) => match u32::from_str(v.as_str()) {
-                Ok(v) => v,
-                Err(_) => 0,
-            },
-        };
-
-        if req_limit > 4096 {
-            req_limit = 4096;
-        } else if req_limit == 0 {
-            req_limit = 1000;
-        }
-
-        let app_state: &UdptState = req.state();
-        let app_db = app_state.tracker.get_database();
-
-        let total = app_db.len() as u32;
-
-        let mut torrents = Vec::with_capacity(req_limit as usize);
-
-        for (info_hash, _) in app_db
-            .iter()
-            .skip(req_offset as usize)
-            .take(req_limit as usize)
-        {
-            torrents.push(info_hash.clone());
-        }
-
-        actix_web::Json(http_responses::APIResponse::TorrentList(
-            http_responses::TorrentList {
-                total,
-                length: torrents.len() as u32,
-                offset: req_offset,
-                torrents,
-            },
-        ))
-    }
-
-    fn view_torrent_stats(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
-        use actix_web::FromRequest;
-
-        if UdptRequestState::get_user(req).is_none() {
-            return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED).json(
-                http_responses::APIResponse::Error(String::from("access_denied")),
-            );
-        }
-
-        let path: actix_web::Path<String> = match actix_web::Path::extract(req) {
-            Ok(v) => v,
-            Err(_) => {
-                return actix_web::HttpResponse::build(
-                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
-                )
-                .json(http_responses::APIResponse::Error(String::from(
-                    "internal_error",
-                )));
-            }
-        };
-
-        let mut info_hash = [0u8; 20];
-        if let Err(_) = binascii::hex2bin((*path).as_bytes(), &mut info_hash) {
-            return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST).json(
-                http_responses::APIResponse::Error(String::from("invalid_info_hash")),
-            );
-        }
-
-        let app_state: &UdptState = req.state();
-
-        let db = app_state.tracker.get_database();
-        let entry = match db.get(&info_hash.into()) {
-            Some(v) => v,
-            None => {
-                return actix_web::HttpResponse::build(actix_web::http::StatusCode::NOT_FOUND).json(
-                    http_responses::APIResponse::Error(String::from("not_found")),
-                );
-            }
-        };
-
-        let is_flagged = entry.is_flagged();
-        let (seeders, completed, leechers) = entry.get_stats();
-
-        return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).json(
-            http_responses::APIResponse::TorrentInfo(http_responses::TorrentInfo {
-                is_flagged,
-                seeder_count: seeders,
-                leecher_count: leechers,
-                completed,
-            }),
-        );
-    }
-
-    fn torrent_action(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
-        use actix_web::FromRequest;
-
-        if UdptRequestState::get_user(req).is_none() {
-            return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED).json(
-                http_responses::APIResponse::Error(String::from("access_denied")),
-            );
-        }
-
-        let query = req.query();
-        let action_opt = query.get("action");
-        let action = match action_opt {
-            Some(v) => v,
-            None => {
-                return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST)
-                    .json(http_responses::APIResponse::Error(String::from(
-                        "action_required",
-                    )));
-            }
-        };
-
-        let app_state: &UdptState = req.state();
-
-        let path: actix_web::Path<String> = match actix_web::Path::extract(req) {
-            Ok(v) => v,
-            Err(_err) => {
-                return actix_web::HttpResponse::build(
-                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
-                )
-                .json(http_responses::APIResponse::Error(String::from(
-                    "internal_error",
-                )));
-            }
-        };
-
-        let info_hash_str = &(*path);
-        let mut info_hash = [0u8; 20];
-        if let Err(_) = binascii::hex2bin(info_hash_str.as_bytes(), &mut info_hash) {
-            return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST).json(
-                http_responses::APIResponse::Error(String::from("invalid_info_hash")),
-            );
-        }
-
-        match action.as_str() {
-            "flag" => {
-                app_state.tracker.set_torrent_flag(&info_hash.into(), true);
-                info!("Flagged {}", info_hash_str.as_str());
-                return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).body("");
-            }
-            "unflag" => {
-                app_state.tracker.set_torrent_flag(&info_hash.into(), false);
-                info!("Unflagged {}", info_hash_str.as_str());
-                return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).body("");
-            }
-            "add" => {
-                let success = app_state.tracker.add_torrent(&info_hash.into()).is_ok();
-                info!("Added {}, success={}", info_hash_str.as_str(), success);
-                let code = if success {
-                    actix_web::http::StatusCode::OK
-                } else {
-                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR
-                };
-
-                return actix_web::HttpResponse::build(code).body("");
-            }
-            "remove" => {
-                let success = app_state
-                    .tracker
-                    .remove_torrent(&info_hash.into(), true)
-                    .is_ok();
-                info!("Removed {}, success={}", info_hash_str.as_str(), success);
-                let code = if success {
-                    actix_web::http::StatusCode::OK
-                } else {
-                    actix_web::http::StatusCode::INTERNAL_SERVER_ERROR
-                };
-
-                return actix_web::HttpResponse::build(code).body("");
-            }
-            _ => {
-                debug!("Invalid action {}", action.as_str());
-                return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST)
-                    .json(http_responses::APIResponse::Error(String::from(
-                        "invalid_action",
-                    )));
-            }
-        }
-    }
-}
+use crate::tracker::{InfoHash, TorrentTracker};
+use serde::{Deserialize, Serialize};
+use std::cmp::min;
+use std::collections::{HashMap, HashSet};
+use std::sync::Arc;
+use warp::{filters, reply, reply::Reply, serve, Filter, Server};
+
+fn view_root() -> impl Reply {
+    reply::html(concat!(
+        r#"<html>
+    <head>
+        <title>udpt/"#,
+        env!("CARGO_PKG_VERSION"),
+        r#"</title>
+    </head>
+    <body>
+        This is your <a href="https://github.com/naim94a/udpt">udpt</a> torrent tracker.
+    </body>
+</html>"#
+    ))
+}
+
+#[derive(Deserialize, Debug)]
+struct TorrentInfoQuery {
+    offset: Option<u32>,
+    limit: Option<u32>,
+}
+
+#[derive(Serialize)]
+struct TorrentEntry<'a> {
+    info_hash: &'a InfoHash,
+    #[serde(flatten)]
+    data: &'a crate::tracker::TorrentEntry,
+}
+
+#[derive(Serialize, Deserialize)]
+struct TorrentFlag {
+    is_flagged: bool,
+}
+
+#[derive(Serialize, Debug)]
+#[serde(tag = "status", rename_all = "snake_case")]
+enum ActionStatus<'a> {
+    Ok,
+    Err { reason: std::borrow::Cow<'a, str> },
+}
+
+impl warp::reject::Reject for ActionStatus<'static> {}
+
+fn authenticate(tokens: HashMap<String, String>) -> impl Filter<Extract = (), Error = warp::reject::Rejection> + Clone {
+    #[derive(Deserialize)]
+    struct AuthToken {
+        token: Option<String>,
+    }
+
+    let tokens: HashSet<String> = tokens.into_iter().map(|(_, v)| v).collect();
+
+    let tokens = Arc::new(tokens);
+    warp::filters::any::any()
+        .map(move || tokens.clone())
+        .and(filters::query::query::<AuthToken>())
+        .and_then(|tokens: Arc<HashSet<String>>, token: AuthToken| {
+            async move {
+                if let Some(token) = token.token {
+                    if tokens.contains(&token) {
+                        return Ok(());
+                    }
+                }
+                Err(warp::reject::custom(ActionStatus::Err {
+                    reason: "Access Denied".into(),
+                }))
+            }
+        })
+        .untuple_one()
+}
+
+pub fn build_server(
+    tracker: Arc<TorrentTracker>, tokens: HashMap<String, String>,
+) -> Server<impl Filter<Extract = impl Reply> + Clone + Send + Sync + 'static> {
+    let root = filters::path::end().map(|| view_root());
+
+    let t1 = tracker.clone();
+    // view_torrent_list -> GET /t/?offset=:u32&limit=:u32 HTTP/1.1
+    let view_torrent_list = filters::path::end()
+        .and(filters::method::get())
+        .and(filters::query::query())
+        .map(move |limits| {
+            let tracker = t1.clone();
+            (limits, tracker)
+        })
+        .and_then(|(limits, tracker): (TorrentInfoQuery, Arc<TorrentTracker>)| {
+            async move {
+                let offset = limits.offset.unwrap_or(0);
+                let limit = min(limits.limit.unwrap_or(1000), 4000);
+
+                let db = tracker.get_database().await;
+                let results: Vec<_> = db
+                    .iter()
+                    .map(|(k, v)| TorrentEntry { info_hash: k, data: v })
+                    .skip(offset as usize)
+                    .take(limit as usize)
+                    .collect();
+
+                Result::<_, warp::reject::Rejection>::Ok(reply::json(&results))
+            }
+        });
+
+    let t2 = tracker.clone();
+    // view_torrent_info -> GET /t/:infohash HTTP/*
+    let view_torrent_info = filters::method::get()
+        .and(filters::path::param())
+        .map(move |info_hash: InfoHash| {
+            let tracker = t2.clone();
+            (info_hash, tracker)
+        })
+        .and_then(|(info_hash, tracker): (InfoHash, Arc<TorrentTracker>)| {
+            async move {
+                let db = tracker.get_database().await;
+                let info = match db.get(&info_hash) {
+                    Some(v) => v,
+                    None => return Err(warp::reject::reject()),
+                };
+
+                Ok(reply::json(&TorrentEntry {
+                    info_hash: &info_hash,
+                    data: info,
+                }))
+            }
+        });
+
+    // DELETE /t/:info_hash
+    let t3 = tracker.clone();
+    let delete_torrent = filters::method::post()
+        .and(filters::path::param())
+        .map(move |info_hash: InfoHash| {
+            let tracker = t3.clone();
+            (info_hash, tracker)
+        })
+        .and_then(|(info_hash, tracker): (InfoHash, Arc<TorrentTracker>)| {
+            async move {
+                let resp = match tracker.remove_torrent(&info_hash, true).await.is_ok() {
+                    true => ActionStatus::Ok,
+                    false => {
+                        ActionStatus::Err {
+                            reason: "failed to delete torrent".into(),
+                        }
+                    }
+                };
+
+                Result::<_, warp::Rejection>::Ok(reply::json(&resp))
+            }
+        });
+
+    let t4 = tracker.clone();
+    // add_torrent/alter: POST /t/:info_hash
+    // (optional) BODY: json: {"is_flagged": boolean}
+    let change_torrent = filters::method::post()
+        .and(filters::path::param())
+        .and(filters::body::content_length_limit(4096))
+        .and(filters::body::json())
+        .map(move |info_hash: InfoHash, body: Option<TorrentFlag>| {
+            let tracker = t4.clone();
+            (info_hash, tracker, body)
+        })
+        .and_then(
+            |(info_hash, tracker, body): (InfoHash, Arc<TorrentTracker>, Option<TorrentFlag>)| {
+                async move {
+                    let is_flagged = body.map(|e| e.is_flagged).unwrap_or(false);
+                    if !tracker.set_torrent_flag(&info_hash, is_flagged).await {
+                        // torrent doesn't exist, add it...
+
+                        if is_flagged {
+                            if tracker.add_torrent(&info_hash).await.is_ok() {
+                                tracker.set_torrent_flag(&info_hash, is_flagged).await;
+                            } else {
+                                return Err(warp::reject::custom(ActionStatus::Err {
+                                    reason: "failed to flag torrent".into(),
+                                }));
+                            }
+                        }
+                    }
+
+                    Result::<_, warp::Rejection>::Ok(reply::json(&ActionStatus::Ok))
+                }
+            },
+        );
+
+    let torrent_mgmt =
+        filters::path::path("t").and(view_torrent_list.or(delete_torrent).or(view_torrent_info).or(change_torrent));
+
+    let server = root.or(authenticate(tokens).and(torrent_mgmt));
+
+    serve(server)
+}
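Routing-wise, everything except the index page now hangs off /t behind the token filter, so the old ?action=flag/unflag/add/remove endpoints collapse into REST-ish calls. A sketch of how the composed filters appear to map onto requests; the bind address and token here are assumptions for illustration, not values from the commit:

// Assumed for illustration: server bound on 127.0.0.1:1212, and
// "sometoken" present as a value in the access_tokens table.
//
// GET  /                                      -> HTML index (no token needed)
// GET  /t/?token=sometoken&offset=0&limit=50  -> JSON list of torrent entries
// GET  /t/<40-hex-infohash>?token=sometoken   -> JSON entry for one torrent
// POST /t/<40-hex-infohash>?token=sometoken   -> remove the torrent, or, with a
//      JSON body {"is_flagged": true}            flag/add it via change_torrent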