Commit a8a23a1e authored by Johannes Hayeß's avatar Johannes Hayeß
Browse files

Make used olm-rs types thread-safe

parent 5fd30797
Pipeline #62962 failed with stages
in 2 minutes and 27 seconds
......@@ -24,7 +24,7 @@ use futures::sync::mpsc;
use olm_rs::account::OlmAccount;
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::{Arc, RwLock};
use std::sync::{Arc, Mutex, RwLock};
/// Used to build a Metaolm object.
pub struct MetaolmBuilder {
......@@ -98,7 +98,7 @@ impl MetaolmBuilder {
let mut metaolm = Metaolm {
device_id: self.device_id.unwrap(),
user_id: self.user_id.unwrap(),
olm_account: Arc::new(self.olm_account),
olm_account: Arc::new(Mutex::new(self.olm_account)),
blocker_manager: self.blocker_manager,
blocked_items: self.blocked_items,
received_items: u_receiver,
......
......@@ -27,7 +27,7 @@ use olm_rs::account::OlmAccount;
use olm_rs::session::{OlmMessageType, OlmSession};
use serde_json;
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock};
use std::sync::{Arc, Mutex, RwLock};
pub enum OlmDecryptResult {
Plaintext(OlmPlaintext),
......@@ -44,7 +44,7 @@ pub enum OlmHandleType {
/// algorithm.
pub fn olm_decrypt(
event: &str,
account: &OlmAccount,
account: Arc<Mutex<OlmAccount>>,
session_cache: Arc<RwLock<SessionCache>>,
user_id: &str,
) -> Result<OlmDecryptResult, String> {
......@@ -53,10 +53,11 @@ pub fn olm_decrypt(
Err(_) => return Err(String::from("Error when serialising the received event")),
};
let sender_key = encrypted_event.content.sender_key.clone();
let public_key_pair: OlmIdentityKeys = match serde_json::from_str(&account.identity_keys()) {
Ok(x) => x,
Err(_) => return Err(String::from("Error when serialising the public key pair")),
};
let public_key_pair: OlmIdentityKeys =
match serde_json::from_str(&account.lock().unwrap().identity_keys()) {
Ok(x) => x,
Err(_) => return Err(String::from("Error when serialising the public key pair")),
};
let curve25519_key = &public_key_pair.curve25519;
// Extract specific ciphertext object for our device.
let ciphertext_info_result = encrypted_event.get_ciphertext_info(&public_key_pair.curve25519);
......@@ -88,8 +89,11 @@ pub fn olm_decrypt(
Some(x) => x,
None => unreachable!(),
};
let cs_read = cached_sessions.read().unwrap();
let mut matching_session = None;
for session in cached_sessions {
for session in cs_read.iter() {
if session
.matches_inbound_session(ciphertext_info.body.clone())
.unwrap()
......@@ -113,7 +117,7 @@ pub fn olm_decrypt(
// can decrypt the ciphertext.
} else if msg_type == OlmMessageType::Message && cached_sessions_result.is_some() {
let cached_sessions = cached_sessions_result.unwrap();
for session in cached_sessions {
for session in cached_sessions.read().unwrap().iter() {
let decrypted_result = session.decrypt(msg_type, ciphertext_info.body.clone());
if decrypted_result.is_ok() {
plaintext = Some(decrypted_result.unwrap());
......@@ -155,7 +159,7 @@ pub fn olm_decrypt(
} else if msg_type == OlmMessageType::PreKey {
// Everything else has failed, so we'll try to establish a new session.
let session = match OlmSession::create_inbound_session_from(
account,
&account.lock().unwrap(),
&encrypted_event.content.sender_key,
ciphertext_info.body.clone(),
) {
......@@ -163,7 +167,7 @@ pub fn olm_decrypt(
Err(x) => return Err(format!("Error when trying to create a new inbound session for olm decryption routine: {:#?}", x))
};
match account.remove_one_time_keys(&session) {
match account.lock().unwrap().remove_one_time_keys(&session) {
Ok(_) => {}
Err(_) => return Err("Error when trying to remove OTK from olm account".to_string()),
}
......@@ -184,7 +188,7 @@ pub fn handle_olm_event(
packet: &str,
packet_blocker_id: Option<u32>,
run_type: OlmHandleType,
account: &OlmAccount,
account: Arc<Mutex<OlmAccount>>,
session_cache: Arc<RwLock<SessionCache>>,
blocker_manager: &mut BlockerManager,
blocked_items: &mut BTreeMap<u32, M2CPacket>,
......@@ -223,7 +227,12 @@ pub fn handle_olm_event(
} else {
// If we do, we can immediately proceed
// towards decrypting the received event.
match olm_decrypt(&packet, account, Arc::clone(&session_cache), user_id) {
match olm_decrypt(
&packet,
Arc::clone(&account),
Arc::clone(&session_cache),
user_id,
) {
Ok(OlmDecryptResult::Plaintext(plaintext)) => {
// check for m.room_key event
if check_for_room_key_event_and_handle(
......@@ -285,11 +294,16 @@ pub fn handle_olm_event(
header: M2CPacketType::Store(StorageDesc::Account {
device_id: device_id.to_string(),
}),
body: account.pickle(&[]),
body: account.lock().unwrap().pickle(&[]),
blocker_id: None,
});
match olm_decrypt(&packet, account, Arc::clone(&session_cache), user_id) {
match olm_decrypt(
&packet,
Arc::clone(&account),
Arc::clone(&session_cache),
user_id,
) {
Ok(OlmDecryptResult::Plaintext(plaintext)) => {
// check for m.room_key event
if check_for_room_key_event_and_handle(
......@@ -391,6 +405,8 @@ pub fn handle_megolm_event(
// to prevent replay attacks
let (plaintext, _message_index) = session_with_info
.session
.lock()
.unwrap()
.decrypt(event.content.ciphertext.clone())
.unwrap();
......
......@@ -36,13 +36,13 @@ use olm_rs::errors::OlmAccountError;
use olm_rs::inbound_group_session::OlmInboundGroupSession;
use serde_json;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::{Arc, RwLock};
use std::sync::{Arc, Mutex, RwLock};
/// Monolithically encapsulates all functionallity of the Metaolm module.
pub struct Metaolm {
pub(crate) device_id: String,
pub(crate) user_id: String,
pub(crate) olm_account: Arc<OlmAccount>,
pub(crate) olm_account: Arc<Mutex<OlmAccount>>,
pub(crate) blocker_manager: BlockerManager,
pub(crate) blocked_items: BTreeMap<u32, M2CPacket>,
pub(crate) received_items: UnboundedReceiver<C2MPacket>,
......@@ -111,7 +111,7 @@ impl Metaolm {
match stored_account {
Ok(account) => {
self.olm_account = Arc::new(account);
self.olm_account = Arc::new(Mutex::new(account));
self.phase = ModulePhase::Running;
}
Err(OlmAccountError::InvalidBase64) => {
......@@ -135,7 +135,7 @@ impl Metaolm {
header: M2CPacketType::Store(StorageDesc::Account {
device_id: self.device_id.clone(),
}),
body: self.olm_account.pickle(&[]),
body: self.olm_account.lock().unwrap().pickle(&[]),
blocker_id: Some(store_request_blocker),
};
......@@ -156,12 +156,13 @@ impl Metaolm {
let register_request_blocker = self.blocker_manager.next_blocker_id();
let identity_keys: OlmIdentityKeys =
serde_json::from_str(&self.olm_account.identity_keys()).unwrap();
serde_json::from_str(&self.olm_account.lock().unwrap().identity_keys())
.unwrap();
let mut device_keys =
SignedDeviceKeys::new(&self.device_id, &self.user_id, identity_keys);
// dump constructed object as Canonical JSON
let device_keys_signature = self.olm_account.sign(
let device_keys_signature = self.olm_account.lock().unwrap().sign(
&serde_json::to_string(&device_keys.get_device_keys_only()).unwrap(),
);
......@@ -240,7 +241,7 @@ impl Metaolm {
&packet.body,
packet.blocker_id,
OlmHandleType::FirstRun,
&self.olm_account,
Arc::clone(&self.olm_account),
Arc::clone(&self.session_cache),
&mut self.blocker_manager,
&mut self.blocked_items,
......@@ -332,7 +333,7 @@ impl Metaolm {
&blocked_item.body,
blocked_item.blocker_id,
OlmHandleType::SecondRun,
&self.olm_account,
Arc::clone(&self.olm_account),
Arc::clone(&self.session_cache),
&mut self.blocker_manager,
&mut self.blocked_items,
......@@ -404,12 +405,12 @@ impl Metaolm {
}
}
M2CPacketType::SelfInfo(SelfInfoType::OTKListPublished) => {
self.olm_account.mark_keys_as_published();
self.olm_account.lock().unwrap().mark_keys_as_published();
self.schedule_for_sending(M2CPacket {
header: M2CPacketType::Store(StorageDesc::Account {
device_id: self.device_id.clone(),
}),
body: self.olm_account.pickle(&[]),
body: self.olm_account.lock().unwrap().pickle(&[]),
blocker_id: None,
});
}
......@@ -486,7 +487,7 @@ impl Metaolm {
match self.otk_replenisher.replenish_for(
otk_count as usize,
&self.olm_account,
Arc::clone(&self.olm_account),
&self.device_id,
&self.user_id,
blocker_id,
......
......@@ -20,6 +20,7 @@ use futures::try_ready;
use futures::{Async, Poll, Stream};
use olm_rs::account::OlmAccount;
use serde_json;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::timer::Interval;
......@@ -56,7 +57,7 @@ impl OTKReplenisher {
pub fn replenish_for(
&self,
otk_count: usize,
olm_account: &OlmAccount,
olm_account: Arc<Mutex<OlmAccount>>,
device_id: &str,
user_id: &str,
blocker_id: u32,
......@@ -65,9 +66,12 @@ impl OTKReplenisher {
None
} else {
// create amount of OTKs to satisfy the replenish goal
olm_account.generate_one_time_keys(self.replenish_goal - otk_count);
olm_account
.lock()
.unwrap()
.generate_one_time_keys(self.replenish_goal - otk_count);
let otks_json: serde_json::Value =
serde_json::from_str(&olm_account.one_time_keys()).unwrap();
serde_json::from_str(&olm_account.lock().unwrap().one_time_keys()).unwrap();
// get the direct mapping of OTK ID to OTK value
let otk_map = otks_json["curve25519"].as_object().unwrap();
......@@ -75,8 +79,10 @@ impl OTKReplenisher {
let mut signed_otks = OneTimeKeys::new();
for (index, one_time_key) in otk_map.iter() {
let mut key = SignedOneTimeKey::new(one_time_key.as_str().unwrap());
let key_signature =
olm_account.sign(&serde_json::to_string(&key.get_otk_only()).unwrap());
let key_signature = olm_account
.lock()
.unwrap()
.sign(&serde_json::to_string(&key.get_otk_only()).unwrap());
key.set_signature(user_id, device_id, &key_signature);
signed_otks.add_key(index, key);
}
......
......@@ -19,6 +19,7 @@ use olm_rs::inbound_group_session::OlmInboundGroupSession;
use olm_rs::session::OlmSession;
use serde_derive::{Deserialize, Serialize};
use serde_json;
use std::sync::{Arc, Mutex, RwLock};
static OLM_CACHE_SIZE: usize = 100;
static INBOUND_MEGOLM_CACHE_SIZE: usize = 100;
......@@ -27,9 +28,9 @@ static INBOUND_MEGOLM_CACHE_SIZE: usize = 100;
/// that will be needed by Metaolm over the course of its operation.
#[derive(Default)]
pub struct SessionCache {
cached_olm_sessions: IndexMap<String, Vec<OlmSession>>,
cached_olm_sessions: IndexMap<String, Arc<RwLock<Vec<OlmSession>>>>,
cached_inbound_megolm_sessions:
IndexMap<InboundMegolmSessionIdentifier, InboundGroupSessionWithKey>,
Arc<RwLock<IndexMap<InboundMegolmSessionIdentifier, InboundGroupSessionWithKey>>>,
}
#[derive(Hash, Eq, PartialEq, Serialize, Deserialize, Clone)]
......@@ -41,8 +42,9 @@ pub struct InboundMegolmSessionIdentifier {
/// An inbound megolm session with the associated ed25519
/// fingerprint key for verifying received messages.
#[derive(Clone)]
pub struct InboundGroupSessionWithKey {
pub session: OlmInboundGroupSession,
pub session: Arc<Mutex<OlmInboundGroupSession>>,
pub ed25519: String,
}
......@@ -51,7 +53,7 @@ impl SessionCache {
pub fn new() -> Self {
SessionCache {
cached_olm_sessions: IndexMap::new(),
cached_inbound_megolm_sessions: IndexMap::new(),
cached_inbound_megolm_sessions: Arc::new(RwLock::new(IndexMap::new())),
}
}
......@@ -71,11 +73,11 @@ impl SessionCache {
if self.cached_olm_sessions.contains_key(&curve25519_key) {
let cached_sessions =
self.cached_olm_sessions.get_mut(&curve25519_key).unwrap();
cached_sessions.push(x);
cached_sessions.write().unwrap().push(x);
} else {
let new_session_list = vec![x];
self.cached_olm_sessions
.insert(curve25519_key, new_session_list);
.insert(curve25519_key, Arc::new(RwLock::new(new_session_list)));
}
Ok(())
}
......@@ -94,7 +96,7 @@ impl SessionCache {
match self.get_cached_olm_sessions(curve25519_key) {
Some(sessions) => {
let mut sessions_serialised: Vec<String> = Vec::new();
for session in sessions {
for session in sessions.read().unwrap().iter() {
sessions_serialised.push(session.pickle(&[]));
}
......@@ -121,7 +123,7 @@ impl SessionCache {
}
self.cached_olm_sessions
.insert(curve25519_key, sessions_deserialised);
.insert(curve25519_key, Arc::new(RwLock::new(sessions_deserialised)));
}
/// Checks if there are any sessions cached for the provided
......@@ -132,7 +134,10 @@ impl SessionCache {
/// Get the sessions associated with the fingerprint key that was
/// provided.
pub fn get_cached_olm_sessions(&self, curve25519_key: &str) -> Option<&Vec<OlmSession>> {
pub fn get_cached_olm_sessions(
&self,
curve25519_key: &str,
) -> Option<&Arc<RwLock<Vec<OlmSession>>>> {
self.cached_olm_sessions.get(curve25519_key)
}
......@@ -149,12 +154,15 @@ impl SessionCache {
) {
// make sure we maintain the correct maximum amount of inbound sessions
while self.cached_olm_sessions.len() >= INBOUND_MEGOLM_CACHE_SIZE {
self.cached_inbound_megolm_sessions.pop();
self.cached_inbound_megolm_sessions.write().unwrap().pop();
}
self.cached_inbound_megolm_sessions.insert(
self.cached_inbound_megolm_sessions.write().unwrap().insert(
session_identifier,
InboundGroupSessionWithKey { session, ed25519 },
InboundGroupSessionWithKey {
session: Arc::new(Mutex::new(session)),
ed25519,
},
);
}
......@@ -162,8 +170,16 @@ impl SessionCache {
pub fn get_cached_inbound_megolm_session(
&self,
session_identifier: &InboundMegolmSessionIdentifier,
) -> Option<&InboundGroupSessionWithKey> {
self.cached_inbound_megolm_sessions.get(session_identifier)
) -> Option<InboundGroupSessionWithKey> {
match self
.cached_inbound_megolm_sessions
.read()
.unwrap()
.get(session_identifier)
{
Some(key) => Some(key.clone()),
None => None,
}
}
pub fn megolm_session_cached(
......@@ -171,6 +187,8 @@ impl SessionCache {
session_identifier: &InboundMegolmSessionIdentifier,
) -> bool {
self.cached_inbound_megolm_sessions
.read()
.unwrap()
.contains_key(session_identifier)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment