Multiple changes

1. [WIP] database serialization/deserialization
2. ran `rustfmt` over sources.
3. updated Cargo.lock
This commit is contained in:
Naim A 2018-12-07 01:58:31 +02:00
parent 80a23981f4
commit 4608cdfa00
No known key found for this signature in database
GPG key ID: FD7948915D9EF8B9
7 changed files with 899 additions and 661 deletions

868
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -41,6 +41,7 @@ pub struct Configuration {
udp: UDPConfig, udp: UDPConfig,
http: Option<HTTPConfig>, http: Option<HTTPConfig>,
log_level: Option<String>, log_level: Option<String>,
db_path: Option<String>,
} }
#[derive(Debug)] #[derive(Debug)]
@ -67,12 +68,10 @@ impl Configuration {
pub fn load_file(path: &str) -> Result<Configuration, ConfigError> { pub fn load_file(path: &str) -> Result<Configuration, ConfigError> {
match std::fs::read(path) { match std::fs::read(path) {
Err(e) => Err(ConfigError::IOError(e)), Err(e) => Err(ConfigError::IOError(e)),
Ok(data) => { Ok(data) => match Self::load(data.as_slice()) {
match Self::load(data.as_slice()) {
Ok(cfg) => Ok(cfg), Ok(cfg) => Ok(cfg),
Err(e) => Err(ConfigError::ParseError(e)), Err(e) => Err(ConfigError::ParseError(e)),
} },
}
} }
} }
@ -91,18 +90,23 @@ impl Configuration {
pub fn get_http_config(&self) -> &Option<HTTPConfig> { pub fn get_http_config(&self) -> &Option<HTTPConfig> {
&self.http &self.http
} }
pub fn get_db_path(&self) -> &Option<String> {
&self.db_path
}
} }
impl Default for Configuration { impl Default for Configuration {
fn default() -> Configuration { fn default() -> Configuration {
Configuration{ Configuration {
log_level: None, log_level: None,
mode: TrackerMode::DynamicMode, mode: TrackerMode::DynamicMode,
udp: UDPConfig{ udp: UDPConfig {
announce_interval: 120, announce_interval: 120,
bind_address: String::from("0.0.0.0:6969"), bind_address: String::from("0.0.0.0:6969"),
}, },
http: None, http: None,
db_path: None,
} }
} }
} }

View file

@ -1,32 +1,33 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
extern crate clap;
extern crate bincode; extern crate bincode;
extern crate clap;
extern crate serde; extern crate serde;
#[macro_use] extern crate serde_derive; #[macro_use]
extern crate serde_derive;
extern crate actix_web; extern crate actix_web;
extern crate binascii; extern crate binascii;
extern crate toml; extern crate toml;
#[macro_use] extern crate log; #[macro_use]
extern crate log;
extern crate bzip2;
extern crate fern; extern crate fern;
extern crate num_cpus; extern crate num_cpus;
extern crate serde_json; extern crate serde_json;
extern crate bzip2;
mod server;
mod tracker;
mod stackvec;
mod webserver;
mod config; mod config;
mod server;
mod stackvec;
mod tracker;
mod webserver;
use std::process::exit;
use config::Configuration; use config::Configuration;
use std::process::exit;
fn setup_logging(cfg: &Configuration) { fn setup_logging(cfg: &Configuration) {
let log_level = match cfg.get_log_level() { let log_level = match cfg.get_log_level() {
None => log::LevelFilter::Info, None => log::LevelFilter::Info,
Some(level) => { Some(level) => match level.as_str() {
match level.as_str() {
"off" => log::LevelFilter::Off, "off" => log::LevelFilter::Off,
"trace" => log::LevelFilter::Trace, "trace" => log::LevelFilter::Trace,
"debug" => log::LevelFilter::Debug, "debug" => log::LevelFilter::Debug,
@ -37,25 +38,26 @@ fn setup_logging(cfg: &Configuration) {
eprintln!("udpt: unknown log level encountered '{}'", level.as_str()); eprintln!("udpt: unknown log level encountered '{}'", level.as_str());
exit(-1); exit(-1);
} }
} },
}
}; };
if let Err(err) = fern::Dispatch::new() if let Err(err) = fern::Dispatch::new()
.format(|out, message, record| { .format(|out, message, record| {
out.finish( out.finish(format_args!(
format_args!(
"{}[{}][{}]\t{}", "{}[{}][{}]\t{}",
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(), std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
record.target(), record.target(),
record.level(), record.level(),
message message
) ))
)
}) })
.level(log_level) .level(log_level)
.chain(std::io::stdout()) .chain(std::io::stdout())
.apply() { .apply()
{
eprintln!("udpt: failed to initialize logging. {}", err); eprintln!("udpt: failed to initialize logging. {}", err);
std::process::exit(-1); std::process::exit(-1);
} }
@ -66,7 +68,13 @@ fn main() {
let parser = clap::App::new("udpt") let parser = clap::App::new("udpt")
.about("High performance, lightweight, udp based torrent tracker.") .about("High performance, lightweight, udp based torrent tracker.")
.author("Naim A. <naim94a@gmail.com>") .author("Naim A. <naim94a@gmail.com>")
.arg(clap::Arg::with_name("config").takes_value(true).short("-c").help("Configuration file to load.").required(true)); .arg(
clap::Arg::with_name("config")
.takes_value(true)
.short("-c")
.help("Configuration file to load.")
.required(true),
);
let matches = parser.get_matches(); let matches = parser.get_matches();
let cfg_path = matches.value_of("config").unwrap(); let cfg_path = matches.value_of("config").unwrap();
@ -81,7 +89,31 @@ fn main() {
setup_logging(&cfg); setup_logging(&cfg);
let tracker = std::sync::Arc::new(tracker::TorrentTracker::new(cfg.get_mode().clone())); let tracker_obj = match cfg.get_db_path() {
Some(path) => {
let file_path = std::path::Path::new(path);
if !file_path.exists() {
warn!("database file \"{}\" doesn't exist.", path);
}
let mut input_file = match std::fs::File::open(file_path) {
Ok(v) => v,
Err(err) => {
error!("failed to open \"{}\". error: {}", path.as_str(), err);
panic!("error opening file. check logs.");
}
};
match tracker::TorrentTracker::load_database(cfg.get_mode().clone(), &mut input_file) {
Ok(v) => v,
Err(err) => {
error!("failed to load database. error: {}", err);
panic!("failed to load database. check logs.");
}
}
}
None => tracker::TorrentTracker::new(cfg.get_mode().clone()),
};
let tracker = std::sync::Arc::new(tracker_obj);
// start http server: // start http server:
if cfg.get_http_config().is_some() { if cfg.get_http_config().is_some() {
@ -98,16 +130,14 @@ fn main() {
let logical_cpus = num_cpus::get(); let logical_cpus = num_cpus::get();
let mut threads = Vec::with_capacity(logical_cpus); let mut threads = Vec::with_capacity(logical_cpus);
for i in 0..logical_cpus { for i in 0..logical_cpus {
debug!("starting thread {}/{}", i+1, logical_cpus); debug!("starting thread {}/{}", i + 1, logical_cpus);
let server_handle = udp_server.clone(); let server_handle = udp_server.clone();
threads.push(std::thread::spawn(move || { threads.push(std::thread::spawn(move || loop {
loop {
match server_handle.accept_packet() { match server_handle.accept_packet() {
Err(e) => { Err(e) => {
error!("Failed to process packet. {}", e); error!("Failed to process packet. {}", e);
},
Ok(_) => {},
} }
Ok(_) => {}
} }
})); }));
} }

View file

@ -1,14 +1,14 @@
use std; use std;
use std::sync::Arc;
use std::net::{SocketAddr, UdpSocket};
use std::io::Write; use std::io::Write;
use std::net::{SocketAddr, UdpSocket};
use std::sync::Arc;
use bincode; use bincode;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use tracker;
use stackvec::StackVec;
use config::Configuration; use config::Configuration;
use stackvec::StackVec;
use tracker;
// maximum MTU is usually 1500, but our stack allows us to allocate the maximum - so why not? // maximum MTU is usually 1500, but our stack allows us to allocate the maximum - so why not?
const MAX_PACKET_SIZE: usize = 0xffff; const MAX_PACKET_SIZE: usize = 0xffff;
@ -105,7 +105,10 @@ pub struct UDPTracker {
} }
impl UDPTracker { impl UDPTracker {
pub fn new(config: Arc<Configuration>, tracker: std::sync::Arc<tracker::TorrentTracker>) -> Result<UDPTracker, std::io::Error> { pub fn new(
config: Arc<Configuration>,
tracker: std::sync::Arc<tracker::TorrentTracker>,
) -> Result<UDPTracker, std::io::Error> {
let cfg = config.clone(); let cfg = config.clone();
let server = match UdpSocket::bind(cfg.get_udp_config().get_address()) { let server = match UdpSocket::bind(cfg.get_udp_config().get_address()) {
@ -115,7 +118,7 @@ impl UDPTracker {
} }
}; };
Ok(UDPTracker{ Ok(UDPTracker {
server, server,
tracker, tracker,
config: cfg, config: cfg,
@ -123,7 +126,7 @@ impl UDPTracker {
} }
fn handle_packet(&self, remote_address: &SocketAddr, payload: &[u8]) { fn handle_packet(&self, remote_address: &SocketAddr, payload: &[u8]) {
let header : UDPRequestHeader = match unpack(payload) { let header: UDPRequestHeader = match unpack(payload) {
Some(val) => val, Some(val) => val,
None => { None => {
trace!("failed to parse packet from {}", remote_address); trace!("failed to parse packet from {}", remote_address);
@ -152,8 +155,8 @@ impl UDPTracker {
// send response... // send response...
let conn_id = self.get_connection_id(remote_addr); let conn_id = self.get_connection_id(remote_addr);
let response = UDPConnectionResponse{ let response = UDPConnectionResponse {
header: UDPResponseHeader{ header: UDPResponseHeader {
transaction_id: header.transaction_id, transaction_id: header.transaction_id,
action: Actions::Connect, action: Actions::Connect,
}, },
@ -187,22 +190,41 @@ impl UDPTracker {
let bep41_payload = &payload[plen..]; let bep41_payload = &payload[plen..];
// TODO: process BEP0041 payload. // TODO: process BEP0041 payload.
trace!("BEP0041 payload of {} bytes from {}", bep41_payload.len(), remote_addr); trace!(
"BEP0041 payload of {} bytes from {}",
bep41_payload.len(),
remote_addr
);
} }
} }
if packet.ip_address != 0 { if packet.ip_address != 0 {
// TODO: allow configurability of ip address // TODO: allow configurability of ip address
// for now, ignore request. // for now, ignore request.
trace!("announce request for other IP ignored. (from {})", remote_addr); trace!(
"announce request for other IP ignored. (from {})",
remote_addr
);
return; return;
} }
let client_addr = SocketAddr::new(remote_addr.ip(), packet.port); let client_addr = SocketAddr::new(remote_addr.ip(), packet.port);
let info_hash = packet.info_hash.into(); let info_hash = packet.info_hash.into();
match self.tracker.update_torrent_and_get_stats(&info_hash, &packet.peer_id, &client_addr, packet.uploaded, packet.downloaded, packet.left, packet.event) { match self.tracker.update_torrent_and_get_stats(
tracker::TorrentStats::Stats {leechers, complete: _, seeders} => { &info_hash,
&packet.peer_id,
&client_addr,
packet.uploaded,
packet.downloaded,
packet.left,
packet.event,
) {
tracker::TorrentStats::Stats {
leechers,
complete: _,
seeders,
} => {
let peers = match self.tracker.get_torrent_peers(&info_hash, &client_addr) { let peers = match self.tracker.get_torrent_peers(&info_hash, &client_addr) {
Some(v) => v, Some(v) => v,
None => { None => {
@ -213,7 +235,9 @@ impl UDPTracker {
let mut payload_buffer = [0u8; MAX_PACKET_SIZE]; let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
let mut payload = StackVec::from(&mut payload_buffer); let mut payload = StackVec::from(&mut payload_buffer);
match pack_into(&mut payload,&UDPAnnounceResponse { match pack_into(
&mut payload,
&UDPAnnounceResponse {
header: UDPResponseHeader { header: UDPResponseHeader {
action: Actions::Announce, action: Actions::Announce,
transaction_id: packet.header.transaction_id, transaction_id: packet.header.transaction_id,
@ -221,8 +245,9 @@ impl UDPTracker {
seeders, seeders,
interval: self.config.get_udp_config().get_announce_interval(), interval: self.config.get_udp_config().get_announce_interval(),
leechers, leechers,
}) { },
Ok(_) => {}, ) {
Ok(_) => {}
Err(_) => { Err(_) => {
return; return;
} }
@ -232,22 +257,23 @@ impl UDPTracker {
match peer { match peer {
SocketAddr::V4(ipv4) => { SocketAddr::V4(ipv4) => {
let _ = payload.write(&ipv4.ip().octets()); let _ = payload.write(&ipv4.ip().octets());
}, }
SocketAddr::V6(ipv6) => { SocketAddr::V6(ipv6) => {
let _ = payload.write(&ipv6.ip().octets()); let _ = payload.write(&ipv6.ip().octets());
} }
}; };
let port_hton = client_addr.port().to_be(); let port_hton = client_addr.port().to_be();
let _ = payload.write(&[(port_hton & 0xff) as u8, ((port_hton >> 8) & 0xff) as u8]); let _ =
payload.write(&[(port_hton & 0xff) as u8, ((port_hton >> 8) & 0xff) as u8]);
} }
let _ = self.send_packet(&client_addr, payload.as_slice()); let _ = self.send_packet(&client_addr, payload.as_slice());
}, }
tracker::TorrentStats::TorrentFlagged => { tracker::TorrentStats::TorrentFlagged => {
self.send_error(&client_addr, &packet.header, "torrent flagged."); self.send_error(&client_addr, &packet.header, "torrent flagged.");
return; return;
}, }
tracker::TorrentStats::TorrentNotRegistered => { tracker::TorrentStats::TorrentNotRegistered => {
self.send_error(&client_addr, &packet.header, "torrent not registered."); self.send_error(&client_addr, &packet.header, "torrent not registered.");
return; return;
@ -265,16 +291,16 @@ impl UDPTracker {
fn get_connection_id(&self, remote_address: &SocketAddr) -> u64 { fn get_connection_id(&self, remote_address: &SocketAddr) -> u64 {
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(duration) => { Ok(duration) => (duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36),
(duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36) Err(_) => 0x8000000000000000,
},
Err(_) => {
0x8000000000000000
}
} }
} }
fn send_packet(&self, remote_addr: &SocketAddr, payload: &[u8]) -> Result<usize, std::io::Error> { fn send_packet(
&self,
remote_addr: &SocketAddr,
payload: &[u8],
) -> Result<usize, std::io::Error> {
self.server.send_to(payload, remote_addr) self.server.send_to(payload, remote_addr)
} }
@ -282,10 +308,13 @@ impl UDPTracker {
let mut payload_buffer = [0u8; MAX_PACKET_SIZE]; let mut payload_buffer = [0u8; MAX_PACKET_SIZE];
let mut payload = StackVec::from(&mut payload_buffer); let mut payload = StackVec::from(&mut payload_buffer);
if let Ok(_) = pack_into(&mut payload, &UDPResponseHeader{ if let Ok(_) = pack_into(
&mut payload,
&UDPResponseHeader {
transaction_id: header.transaction_id, transaction_id: header.transaction_id,
action: Actions::Error, action: Actions::Error,
}) { },
) {
let msg_bytes = Vec::from(error_msg.as_bytes()); let msg_bytes = Vec::from(error_msg.as_bytes());
payload.extend(msg_bytes); payload.extend(msg_bytes);
@ -301,7 +330,7 @@ impl UDPTracker {
self.handle_packet(&remote_address, &packet[..size]); self.handle_packet(&remote_address, &packet[..size]);
Ok(()) Ok(())
}, }
Err(e) => Err(e), Err(e) => Err(e),
} }
} }
@ -322,7 +351,10 @@ mod tests {
assert!(pack_into(&mut payload, &mystruct).is_ok()); assert!(pack_into(&mut payload, &mystruct).is_ok());
assert_eq!(payload.len(), 16); assert_eq!(payload.len(), 16);
assert_eq!(payload.as_slice(), &[0, 0, 0, 0, 0, 0, 0, 200u8, 0, 0, 0, 0, 0, 1, 47, 203]); assert_eq!(
payload.as_slice(),
&[0, 0, 0, 0, 0, 0, 0, 200u8, 0, 0, 0, 0, 0, 1, 47, 203]
);
} }
#[test] #[test]
@ -330,9 +362,9 @@ mod tests {
let buf = [0u8, 0, 0, 0, 0, 0, 0, 200, 0, 0, 0, 1, 0, 1, 47, 203]; let buf = [0u8, 0, 0, 0, 0, 0, 0, 200, 0, 0, 0, 1, 0, 1, 47, 203];
match super::unpack(&buf) { match super::unpack(&buf) {
Some(obj) => { Some(obj) => {
let x : super::UDPResponseHeader = obj; let x: super::UDPResponseHeader = obj;
println!("conn_id={}", x.action as u32); println!("conn_id={}", x.action as u32);
}, }
None => { None => {
assert!(false); assert!(false);
} }

View file

@ -7,10 +7,7 @@ pub struct StackVec<'a, T: 'a> {
impl<'a, T> StackVec<'a, T> { impl<'a, T> StackVec<'a, T> {
pub fn from(data: &mut [T]) -> StackVec<T> { pub fn from(data: &mut [T]) -> StackVec<T> {
StackVec{ StackVec { data, length: 0 }
data,
length: 0,
}
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
@ -23,7 +20,7 @@ impl<'a, T> StackVec<'a, T> {
} }
impl<'a, T> Extend<T> for StackVec<'a, T> { impl<'a, T> Extend<T> for StackVec<'a, T> {
fn extend<I: IntoIterator<Item=T>>(&mut self, iter: I) { fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
for item in iter { for item in iter {
self.data[self.length] = item; self.data[self.length] = item;
self.length += 1; self.length += 1;

View file

@ -1,23 +1,22 @@
use std;
use serde;
use binascii; use binascii;
use serde;
use serde_json; use serde_json;
use std;
use server::Events; use server::Events;
#[derive(Deserialize, Clone)] #[derive(Deserialize, Clone, PartialEq)]
pub enum TrackerMode { pub enum TrackerMode {
/// In static mode torrents are tracked only if they were added ahead of time. /// In static mode torrents are tracked only if they were added ahead of time.
#[serde(rename="static")] #[serde(rename = "static")]
StaticMode, StaticMode,
/// In dynamic mode, torrents are tracked being added ahead of time. /// In dynamic mode, torrents are tracked being added ahead of time.
#[serde(rename="dynamic")] #[serde(rename = "dynamic")]
DynamicMode, DynamicMode,
/// Tracker will only serve authenticated peers. /// Tracker will only serve authenticated peers.
#[serde(rename="private")] #[serde(rename = "private")]
PrivateMode, PrivateMode,
} }
@ -43,16 +42,16 @@ impl std::cmp::PartialOrd<InfoHash> for InfoHash {
impl std::convert::Into<InfoHash> for [u8; 20] { impl std::convert::Into<InfoHash> for [u8; 20] {
fn into(self) -> InfoHash { fn into(self) -> InfoHash {
InfoHash{ InfoHash { info_hash: self }
info_hash: self,
}
} }
} }
impl serde::ser::Serialize for InfoHash { impl serde::ser::Serialize for InfoHash {
fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut buffer = [0u8; 40]; let mut buffer = [0u8; 40];
let bytes_out = binascii::bin2hex(&self.info_hash, &mut buffer).ok().unwrap(); let bytes_out = binascii::bin2hex(&self.info_hash, &mut buffer)
.ok()
.unwrap();
let str_out = std::str::from_utf8(bytes_out).unwrap(); let str_out = std::str::from_utf8(bytes_out).unwrap();
serializer.serialize_str(str_out) serializer.serialize_str(str_out)
@ -70,15 +69,21 @@ impl<'v> serde::de::Visitor<'v> for InfoHashVisitor {
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> { fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
if v.len() != 40 { if v.len() != 40 {
return Err(serde::de::Error::invalid_value(serde::de::Unexpected::Str(v), &"expected a 40 character long string")); return Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(v),
&"expected a 40 character long string",
));
} }
let mut res = InfoHash{ let mut res = InfoHash {
info_hash: [0u8; 20], info_hash: [0u8; 20],
}; };
if let Err(_) = binascii::hex2bin(v.as_bytes(), &mut res.info_hash) { if let Err(_) = binascii::hex2bin(v.as_bytes(), &mut res.info_hash) {
return Err(serde::de::Error::invalid_value(serde::de::Unexpected::Str(v), &"expected a hexadecimal string")); return Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(v),
&"expected a hexadecimal string",
));
} else { } else {
return Ok(res); return Ok(res);
} }
@ -107,8 +112,8 @@ pub struct TorrentEntry {
} }
impl TorrentEntry { impl TorrentEntry {
pub fn new() -> TorrentEntry{ pub fn new() -> TorrentEntry {
TorrentEntry{ TorrentEntry {
is_flagged: false, is_flagged: false,
peers: std::collections::BTreeMap::new(), peers: std::collections::BTreeMap::new(),
completed: 0, completed: 0,
@ -120,18 +125,29 @@ impl TorrentEntry {
self.is_flagged self.is_flagged
} }
pub fn update_peer(&mut self, peer_id: &PeerId, remote_address: &std::net::SocketAddr, uploaded: u64, downloaded: u64, left: u64, event: Events) { pub fn update_peer(
&mut self,
peer_id: &PeerId,
remote_address: &std::net::SocketAddr,
uploaded: u64,
downloaded: u64,
left: u64,
event: Events,
) {
let is_seeder = left == 0 && uploaded > 0; let is_seeder = left == 0 && uploaded > 0;
let mut was_seeder = false; let mut was_seeder = false;
let mut is_completed = left == 0 && (event as u32) == (Events::Complete as u32); let mut is_completed = left == 0 && (event as u32) == (Events::Complete as u32);
if let Some(prev) = self.peers.insert(*peer_id, TorrentPeer{ if let Some(prev) = self.peers.insert(
*peer_id,
TorrentPeer {
updated: std::time::SystemTime::now(), updated: std::time::SystemTime::now(),
left, left,
downloaded, downloaded,
uploaded, uploaded,
ip: *remote_address, ip: *remote_address,
event, event,
}) { },
) {
was_seeder = prev.left == 0 && prev.uploaded > 0; was_seeder = prev.left == 0 && prev.uploaded > 0;
if is_completed && (prev.event as u32) == (Events::Complete as u32) { if is_completed && (prev.event as u32) == (Events::Complete as u32) {
@ -153,7 +169,12 @@ impl TorrentEntry {
pub fn get_peers(&self, remote_addr: &std::net::SocketAddr) -> Vec<std::net::SocketAddr> { pub fn get_peers(&self, remote_addr: &std::net::SocketAddr) -> Vec<std::net::SocketAddr> {
let mut list = Vec::new(); let mut list = Vec::new();
for (_, peer) in self.peers.iter().filter(|e| e.1.ip.is_ipv4() == remote_addr.is_ipv4()).take(74) { for (_, peer) in self
.peers
.iter()
.filter(|e| e.1.ip.is_ipv4() == remote_addr.is_ipv4())
.take(74)
{
if peer.ip == *remote_addr { if peer.ip == *remote_addr {
continue; continue;
} }
@ -175,7 +196,7 @@ struct TorrentDatabase {
impl Default for TorrentDatabase { impl Default for TorrentDatabase {
fn default() -> Self { fn default() -> Self {
TorrentDatabase{ TorrentDatabase {
torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()), torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()),
} }
} }
@ -189,21 +210,40 @@ pub struct TorrentTracker {
pub enum TorrentStats { pub enum TorrentStats {
TorrentFlagged, TorrentFlagged,
TorrentNotRegistered, TorrentNotRegistered,
Stats{ Stats {
seeders: u32, seeders: u32,
leechers: u32, leechers: u32,
complete: u32, complete: u32,
} },
} }
impl TorrentTracker { impl TorrentTracker {
pub fn new(mode: TrackerMode) -> TorrentTracker { pub fn new(mode: TrackerMode) -> TorrentTracker {
TorrentTracker{ TorrentTracker {
mode, mode,
database: TorrentDatabase{ database: TorrentDatabase {
torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()), torrent_peers: std::sync::RwLock::new(std::collections::BTreeMap::new()),
},
} }
} }
pub fn load_database<R: std::io::Read>(
mode: TrackerMode,
reader: &mut R,
) -> serde_json::Result<TorrentTracker> {
use bzip2;
let decomp_reader = bzip2::read::BzDecoder::new(reader);
let result: serde_json::Result<std::collections::BTreeMap<InfoHash, TorrentEntry>> =
serde_json::from_reader(decomp_reader);
match result {
Ok(v) => Ok(TorrentTracker {
mode,
database: TorrentDatabase {
torrent_peers: std::sync::RwLock::new(v),
},
}),
Err(e) => Err(e),
}
} }
/// Adding torrents is not relevant to dynamic trackers. /// Adding torrents is not relevant to dynamic trackers.
@ -213,7 +253,7 @@ impl TorrentTracker {
std::collections::btree_map::Entry::Vacant(ve) => { std::collections::btree_map::Entry::Vacant(ve) => {
ve.insert(TorrentEntry::new()); ve.insert(TorrentEntry::new());
return Ok(()); return Ok(());
}, }
std::collections::btree_map::Entry::Occupied(_entry) => { std::collections::btree_map::Entry::Occupied(_entry) => {
return Err(()); return Err(());
} }
@ -229,20 +269,26 @@ impl TorrentTracker {
Entry::Vacant(_) => { Entry::Vacant(_) => {
// no entry, nothing to do... // no entry, nothing to do...
return Err(()); return Err(());
}, }
Entry::Occupied(entry) => { Entry::Occupied(entry) => {
if force || !entry.get().is_flagged() { if force || !entry.get().is_flagged() {
entry.remove(); entry.remove();
return Ok(()); return Ok(());
} }
return Err(()); return Err(());
}, }
} }
} }
/// flagged torrents will result in a tracking error. This is to allow enforcement against piracy. /// flagged torrents will result in a tracking error. This is to allow enforcement against piracy.
pub fn set_torrent_flag(&self, info_hash: &InfoHash, is_flagged: bool) { pub fn set_torrent_flag(&self, info_hash: &InfoHash, is_flagged: bool) {
if let Some(entry) = self.database.torrent_peers.write().unwrap().get_mut(info_hash) { if let Some(mut entry) = self
.database
.torrent_peers
.write()
.unwrap()
.get_mut(info_hash)
{
if is_flagged && !entry.is_flagged { if is_flagged && !entry.is_flagged {
// empty peer list. // empty peer list.
entry.peers.clear(); entry.peers.clear();
@ -251,7 +297,11 @@ impl TorrentTracker {
} }
} }
pub fn get_torrent_peers(&self, info_hash: &InfoHash, remote_addr: &std::net::SocketAddr) -> Option<Vec<std::net::SocketAddr>> { pub fn get_torrent_peers(
&self,
info_hash: &InfoHash,
remote_addr: &std::net::SocketAddr,
) -> Option<Vec<std::net::SocketAddr>> {
let read_lock = self.database.torrent_peers.read().unwrap(); let read_lock = self.database.torrent_peers.read().unwrap();
match read_lock.get(info_hash) { match read_lock.get(info_hash) {
None => { None => {
@ -263,26 +313,31 @@ impl TorrentTracker {
}; };
} }
pub fn update_torrent_and_get_stats(&self, info_hash: &InfoHash, peer_id: &PeerId, remote_address: &std::net::SocketAddr, uploaded: u64, downloaded: u64, left: u64, event: Events) -> TorrentStats { pub fn update_torrent_and_get_stats(
&self,
info_hash: &InfoHash,
peer_id: &PeerId,
remote_address: &std::net::SocketAddr,
uploaded: u64,
downloaded: u64,
left: u64,
event: Events,
) -> TorrentStats {
use std::collections::btree_map::Entry; use std::collections::btree_map::Entry;
let mut torrent_peers = self.database.torrent_peers.write().unwrap(); let mut torrent_peers = self.database.torrent_peers.write().unwrap();
let torrent_entry = match torrent_peers.entry(info_hash.clone()) { let torrent_entry = match torrent_peers.entry(info_hash.clone()) {
Entry::Vacant(vacant) => { Entry::Vacant(vacant) => match self.mode {
match self.mode { TrackerMode::DynamicMode => vacant.insert(TorrentEntry::new()),
TrackerMode::DynamicMode => {
vacant.insert(TorrentEntry::new())
},
_ => { _ => {
return TorrentStats::TorrentNotRegistered; return TorrentStats::TorrentNotRegistered;
} }
}
}, },
Entry::Occupied(entry) => { Entry::Occupied(entry) => {
if entry.get().is_flagged() { if entry.get().is_flagged() {
return TorrentStats::TorrentFlagged; return TorrentStats::TorrentFlagged;
} }
entry.into_mut() entry.into_mut()
}, }
}; };
torrent_entry.update_peer(peer_id, remote_address, uploaded, downloaded, left, event); torrent_entry.update_peer(peer_id, remote_address, uploaded, downloaded, left, event);
@ -296,7 +351,9 @@ impl TorrentTracker {
}; };
} }
pub (crate) fn get_database(&self) -> std::sync::RwLockReadGuard<std::collections::BTreeMap<InfoHash, TorrentEntry>>{ pub(crate) fn get_database(
&self,
) -> std::sync::RwLockReadGuard<std::collections::BTreeMap<InfoHash, TorrentEntry>> {
self.database.torrent_peers.read().unwrap() self.database.torrent_peers.read().unwrap()
} }
@ -311,6 +368,67 @@ impl TorrentTracker {
serde_json::to_writer(compressor, &db) serde_json::to_writer(compressor, &db)
} }
fn cleanup(&mut self) {
use std::ops::Add;
let now = std::time::SystemTime::now();
match self.database.torrent_peers.write() {
Err(err) => {
error!("failed to obtain write lock on database. err: {}", err);
return;
}
Ok(mut db) => {
let mut torrents_to_remove = Vec::new();
for (k, v) in db.iter_mut() {
// timed-out peers..
{
let mut peers_to_remove = Vec::new();
let torrent_peers = &mut v.peers;
for (peer_id, state) in torrent_peers.iter() {
if state.updated.add(std::time::Duration::new(3600 * 2, 0)) < now {
// over 2 hours past since last update...
peers_to_remove.push(*peer_id);
}
}
for peer_id in peers_to_remove.iter() {
torrent_peers.remove(peer_id);
}
}
if self.mode == TrackerMode::DynamicMode {
// peer-less torrents..
if v.peers.len() == 0 {
torrents_to_remove.push(k.clone());
}
}
}
for info_hash in torrents_to_remove {
db.remove(&info_hash);
}
}
}
}
pub fn periodic_task(&mut self, db_path: &str) {
// cleanup db
self.cleanup();
// save db.
match std::fs::File::open(db_path) {
Err(err) => {
error!("failed to open file '{}': {}", db_path, err);
return;
}
Ok(mut file) => {
self.save_database(&mut file);
}
}
}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -47,7 +47,7 @@ struct UdptState {
impl UdptState { impl UdptState {
fn new(tracker: Arc<tracker::TorrentTracker>, tokens: HashMap<String, String>) -> UdptState { fn new(tracker: Arc<tracker::TorrentTracker>, tokens: HashMap<String, String>) -> UdptState {
UdptState{ UdptState {
tracker, tracker,
access_tokens: tokens, access_tokens: tokens,
} }
@ -61,7 +61,7 @@ struct UdptRequestState {
impl Default for UdptRequestState { impl Default for UdptRequestState {
fn default() -> Self { fn default() -> Self {
UdptRequestState{ UdptRequestState {
current_user: Option::None, current_user: Option::None,
} }
} }
@ -73,26 +73,31 @@ impl UdptRequestState {
let req_state: Option<&UdptRequestState> = exts.get(); let req_state: Option<&UdptRequestState> = exts.get();
match req_state { match req_state {
None => None, None => None,
Option::Some(state) => { Option::Some(state) => match state.current_user {
match state.current_user {
Option::Some(ref v) => Option::Some(v.clone()), Option::Some(ref v) => Option::Some(v.clone()),
None => { None => {
error!("Invalid API token from {} @ {}", req.peer_addr().unwrap(), req.path()); error!(
"Invalid API token from {} @ {}",
req.peer_addr().unwrap(),
req.path()
);
return None; return None;
}
}, },
} }
} }
}
}
} }
struct UdptMiddleware; struct UdptMiddleware;
impl actix_web::middleware::Middleware<UdptState> for UdptMiddleware { impl actix_web::middleware::Middleware<UdptState> for UdptMiddleware {
fn start(&self, req: &actix_web::HttpRequest<UdptState>) -> actix_web::Result<actix_web::middleware::Started> { fn start(
&self,
req: &actix_web::HttpRequest<UdptState>,
) -> actix_web::Result<actix_web::middleware::Started> {
let mut req_state = UdptRequestState::default(); let mut req_state = UdptRequestState::default();
if let Option::Some(token) = req.query().get("token") { if let Option::Some(token) = req.query().get("token") {
let app_state : &UdptState = req.state(); let app_state: &UdptState = req.state();
if let Option::Some(v) = app_state.access_tokens.get(token) { if let Option::Some(v) = app_state.access_tokens.get(token) {
req_state.current_user = Option::Some(v.clone()); req_state.current_user = Option::Some(v.clone());
} }
@ -101,9 +106,15 @@ impl actix_web::middleware::Middleware<UdptState> for UdptMiddleware {
Ok(actix_web::middleware::Started::Done) Ok(actix_web::middleware::Started::Done)
} }
fn response(&self, _req: &actix_web::HttpRequest<UdptState>, mut resp: actix_web::HttpResponse) -> actix_web::Result<actix_web::middleware::Response> { fn response(
resp.headers_mut() &self,
.insert(actix_web::http::header::SERVER, actix_web::http::header::HeaderValue::from_static(SERVER)); _req: &actix_web::HttpRequest<UdptState>,
mut resp: actix_web::HttpResponse,
) -> actix_web::Result<actix_web::middleware::Response> {
resp.headers_mut().insert(
actix_web::http::header::SERVER,
actix_web::http::header::HeaderValue::from_static(SERVER),
);
Ok(actix_web::middleware::Response::Done(resp)) Ok(actix_web::middleware::Response::Done(resp))
} }
@ -119,7 +130,10 @@ impl WebServer {
} }
} }
pub fn new(tracker: Arc<tracker::TorrentTracker>, cfg: Arc<config::Configuration>) -> WebServer { pub fn new(
tracker: Arc<tracker::TorrentTracker>,
cfg: Arc<config::Configuration>,
) -> WebServer {
let cfg_cp = cfg.clone(); let cfg_cp = cfg.clone();
let server = actix_web::server::HttpServer::new(move || { let server = actix_web::server::HttpServer::new(move || {
@ -135,13 +149,16 @@ impl WebServer {
.middleware(UdptMiddleware) .middleware(UdptMiddleware)
.resource("/t", |r| r.f(Self::view_torrent_list)) .resource("/t", |r| r.f(Self::view_torrent_list))
.scope(r"/t/{info_hash:[\dA-Fa-f]{40,40}}", |scope| { .scope(r"/t/{info_hash:[\dA-Fa-f]{40,40}}", |scope| {
scope scope.resource("", |r| {
.resource("", |r| { r.method(actix_web::http::Method::GET)
r.method(actix_web::http::Method::GET).f(Self::view_torrent_stats); .f(Self::view_torrent_stats);
r.method(actix_web::http::Method::POST).f(Self::torrent_action); r.method(actix_web::http::Method::POST)
.f(Self::torrent_action);
}) })
}) })
.resource("/", |r| r.method(actix_web::http::Method::GET).f(Self::view_root)) .resource("/", |r| {
r.method(actix_web::http::Method::GET).f(Self::view_root)
})
}); });
if let Some(http_cfg) = cfg.get_http_config() { if let Some(http_cfg) = cfg.get_http_config() {
@ -149,17 +166,16 @@ impl WebServer {
match server.bind(bind_addr) { match server.bind(bind_addr) {
Ok(v) => { Ok(v) => {
v.run(); v.run();
}, }
Err(err) => { Err(err) => {
error!("Failed to bind http server. {}", err); error!("Failed to bind http server. {}", err);
} }
} }
} } else {
else {
unreachable!(); unreachable!();
} }
WebServer{} WebServer {}
} }
fn view_root(_req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse { fn view_root(_req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
@ -172,27 +188,25 @@ impl WebServer {
use std::str::FromStr; use std::str::FromStr;
if UdptRequestState::get_user(req).is_none() { if UdptRequestState::get_user(req).is_none() {
return actix_web::Json(http_responses::APIResponse::Error(String::from("access_denied"))); return actix_web::Json(http_responses::APIResponse::Error(String::from(
"access_denied",
)));
} }
let req_offset = match req.query().get("offset") { let req_offset = match req.query().get("offset") {
None => 0, None => 0,
Some(v) => { Some(v) => match u32::from_str(v.as_str()) {
match u32::from_str(v.as_str()) {
Ok(v) => v, Ok(v) => v,
Err(_) => 0, Err(_) => 0,
} },
}
}; };
let mut req_limit = match req.query().get("limit") { let mut req_limit = match req.query().get("limit") {
None => 0, None => 0,
Some(v) => { Some(v) => match u32::from_str(v.as_str()) {
match u32::from_str(v.as_str()) {
Ok(v) => v, Ok(v) => v,
Err(_) => 0, Err(_) => 0,
} },
}
}; };
if req_limit > 4096 { if req_limit > 4096 {
@ -208,38 +222,50 @@ impl WebServer {
let mut torrents = Vec::with_capacity(req_limit as usize); let mut torrents = Vec::with_capacity(req_limit as usize);
for (info_hash, _) in app_db.iter().skip(req_offset as usize).take(req_limit as usize) { for (info_hash, _) in app_db
.iter()
.skip(req_offset as usize)
.take(req_limit as usize)
{
torrents.push(info_hash.clone()); torrents.push(info_hash.clone());
} }
actix_web::Json(http_responses::APIResponse::TorrentList(http_responses::TorrentList{ actix_web::Json(http_responses::APIResponse::TorrentList(
http_responses::TorrentList {
total, total,
length: torrents.len() as u32, length: torrents.len() as u32,
offset: req_offset, offset: req_offset,
torrents, torrents,
})) },
))
} }
fn view_torrent_stats(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse { fn view_torrent_stats(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
use actix_web::FromRequest; use actix_web::FromRequest;
if UdptRequestState::get_user(req).is_none() { if UdptRequestState::get_user(req).is_none() {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED) return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED).json(
.json(http_responses::APIResponse::Error(String::from("access_denied"))); http_responses::APIResponse::Error(String::from("access_denied")),
);
} }
let path: actix_web::Path<String> = match actix_web::Path::extract(req) { let path: actix_web::Path<String> = match actix_web::Path::extract(req) {
Ok(v) => v, Ok(v) => v,
Err(_) => { Err(_) => {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR) return actix_web::HttpResponse::build(
.json(http_responses::APIResponse::Error(String::from("internal_error"))); actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
)
.json(http_responses::APIResponse::Error(String::from(
"internal_error",
)));
} }
}; };
let mut info_hash = [0u8; 20]; let mut info_hash = [0u8; 20];
if let Err(_) = binascii::hex2bin((*path).as_bytes(), &mut info_hash) { if let Err(_) = binascii::hex2bin((*path).as_bytes(), &mut info_hash) {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST) return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST).json(
.json(http_responses::APIResponse::Error(String::from("invalid_info_hash"))); http_responses::APIResponse::Error(String::from("invalid_info_hash")),
);
} }
let app_state: &UdptState = req.state(); let app_state: &UdptState = req.state();
@ -248,31 +274,32 @@ impl WebServer {
let entry = match db.get(&info_hash.into()) { let entry = match db.get(&info_hash.into()) {
Some(v) => v, Some(v) => v,
None => { None => {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::NOT_FOUND) return actix_web::HttpResponse::build(actix_web::http::StatusCode::NOT_FOUND).json(
.json(http_responses::APIResponse::Error(String::from("not_found"))); http_responses::APIResponse::Error(String::from("not_found")),
);
} }
}; };
let is_flagged = entry.is_flagged(); let is_flagged = entry.is_flagged();
let (seeders, completed, leechers) = entry.get_stats(); let (seeders, completed, leechers) = entry.get_stats();
return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK) return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).json(
.json(http_responses::APIResponse::TorrentInfo( http_responses::APIResponse::TorrentInfo(http_responses::TorrentInfo {
http_responses::TorrentInfo{
is_flagged, is_flagged,
seeder_count: seeders, seeder_count: seeders,
leecher_count: leechers, leecher_count: leechers,
completed, completed,
} }),
)); );
} }
fn torrent_action(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse { fn torrent_action(req: &actix_web::HttpRequest<UdptState>) -> actix_web::HttpResponse {
use actix_web::FromRequest; use actix_web::FromRequest;
if UdptRequestState::get_user(req).is_none() { if UdptRequestState::get_user(req).is_none() {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED) return actix_web::HttpResponse::build(actix_web::http::StatusCode::UNAUTHORIZED).json(
.json(http_responses::APIResponse::Error(String::from("access_denied"))); http_responses::APIResponse::Error(String::from("access_denied")),
);
} }
let query = req.query(); let query = req.query();
@ -281,7 +308,9 @@ impl WebServer {
Some(v) => v, Some(v) => v,
None => { None => {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST) return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST)
.json(http_responses::APIResponse::Error(String::from("action_required"))); .json(http_responses::APIResponse::Error(String::from(
"action_required",
)));
} }
}; };
@ -290,51 +319,65 @@ impl WebServer {
let path: actix_web::Path<String> = match actix_web::Path::extract(req) { let path: actix_web::Path<String> = match actix_web::Path::extract(req) {
Ok(v) => v, Ok(v) => v,
Err(_err) => { Err(_err) => {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR) return actix_web::HttpResponse::build(
.json(http_responses::APIResponse::Error(String::from("internal_error"))); actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
)
.json(http_responses::APIResponse::Error(String::from(
"internal_error",
)));
} }
}; };
let info_hash_str = &(*path); let info_hash_str = &(*path);
let mut info_hash = [0u8; 20]; let mut info_hash = [0u8; 20];
if let Err(_) = binascii::hex2bin(info_hash_str.as_bytes(), &mut info_hash) { if let Err(_) = binascii::hex2bin(info_hash_str.as_bytes(), &mut info_hash) {
return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST) return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST).json(
.json(http_responses::APIResponse::Error(String::from("invalid_info_hash"))); http_responses::APIResponse::Error(String::from("invalid_info_hash")),
);
} }
match action.as_str() { match action.as_str() {
"flag" => { "flag" => {
app_state.tracker.set_torrent_flag(&info_hash.into(), true); app_state.tracker.set_torrent_flag(&info_hash.into(), true);
info!("Flagged {}", info_hash_str.as_str()); info!("Flagged {}", info_hash_str.as_str());
return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK) return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).body("");
.body("") }
},
"unflag" => { "unflag" => {
app_state.tracker.set_torrent_flag(&info_hash.into(), false); app_state.tracker.set_torrent_flag(&info_hash.into(), false);
info!("Unflagged {}", info_hash_str.as_str()); info!("Unflagged {}", info_hash_str.as_str());
return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK) return actix_web::HttpResponse::build(actix_web::http::StatusCode::OK).body("");
.body("") }
},
"add" => { "add" => {
let success = app_state.tracker.add_torrent(&info_hash.into()).is_ok(); let success = app_state.tracker.add_torrent(&info_hash.into()).is_ok();
info!("Added {}, success={}", info_hash_str.as_str(), success); info!("Added {}, success={}", info_hash_str.as_str(), success);
let code = if success { actix_web::http::StatusCode::OK } else { actix_web::http::StatusCode::INTERNAL_SERVER_ERROR }; let code = if success {
actix_web::http::StatusCode::OK
} else {
actix_web::http::StatusCode::INTERNAL_SERVER_ERROR
};
return actix_web::HttpResponse::build(code) return actix_web::HttpResponse::build(code).body("");
.body("") }
},
"remove" => { "remove" => {
let success = app_state.tracker.remove_torrent(&info_hash.into(), true).is_ok(); let success = app_state
.tracker
.remove_torrent(&info_hash.into(), true)
.is_ok();
info!("Removed {}, success={}", info_hash_str.as_str(), success); info!("Removed {}, success={}", info_hash_str.as_str(), success);
let code = if success { actix_web::http::StatusCode::OK } else { actix_web::http::StatusCode::INTERNAL_SERVER_ERROR }; let code = if success {
actix_web::http::StatusCode::OK
} else {
actix_web::http::StatusCode::INTERNAL_SERVER_ERROR
};
return actix_web::HttpResponse::build(code) return actix_web::HttpResponse::build(code).body("");
.body("") }
},
_ => { _ => {
debug!("Invalid action {}", action.as_str()); debug!("Invalid action {}", action.as_str());
return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST) return actix_web::HttpResponse::build(actix_web::http::StatusCode::BAD_REQUEST)
.json(http_responses::APIResponse::Error(String::from("invalid_action"))); .json(http_responses::APIResponse::Error(String::from(
"invalid_action",
)));
} }
} }
} }