I have a Rust program which contains a number of different structs which all implement a trait called `ApplyAction`. Another struct, `ActionList`, contains a vector of boxed objects which implement `ApplyAction`. I would like to create some unit tests which compare `ActionList`s with one another.
There are a few different SO questions which deal with `PartialEq` on boxed trait objects, and I've used these to get some way towards an implementation. However, in the (simplified) code below (and on the Playground), the assertions in `main()` fail because the type ids of the objects passed to `eq()` differ. Why?
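A minimal sketch that isolates the `TypeId` comparison. The names `Shape`, `Circle`, `concrete_type_id` and `self_type_id` are hypothetical stand-ins, not types from my program. It assumes the relevant detail is that inside an `impl` block written for the trait-object type, `Self` names `dyn Trait` itself rather than the concrete type behind the pointer:

```rust
use std::any::TypeId;

// Hypothetical trait standing in for `ApplyAction`.
trait Shape {
    fn concrete_type_id(&self) -> TypeId;
}

// Hypothetical implementor standing in for `MyAction<T>`.
struct Circle;

impl Shape for Circle {
    fn concrete_type_id(&self) -> TypeId {
        // Here `Self` is `Circle`.
        TypeId::of::<Self>()
    }
}

// An inherent impl on the trait-object type, analogous to the
// `PartialEq` impl written for the `ApplyAction<T>` trait object.
impl dyn Shape {
    fn self_type_id(&self) -> TypeId {
        // Here `Self` is `dyn Shape`, not the type behind the pointer.
        TypeId::of::<Self>()
    }
}

fn main() {
    let shape: Box<dyn Shape> = Box::new(Circle);
    let concrete = shape.concrete_type_id();
    let object = shape.self_type_id();
    assert_eq!(concrete, TypeId::of::<Circle>());
    assert_eq!(object, TypeId::of::<dyn Shape>());
    assert_ne!(concrete, object);
    println!("ids differ: {}", concrete != object);
}
```

If this reading is right, `TypeId::of::<Self>()` in the trait-object `PartialEq` impl refers to the trait-object type, so it can never equal the concrete type's id returned by `get_type()`.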
Also, this seems extremely complicated for such a simple use case -- is there an easier way to do this?
```rust
use std::any::TypeId;
use std::boxed::Box;
use std::fmt;
use std::mem::transmute;

#[derive(Debug, Eq, PartialEq)]
pub struct MyAction<T: fmt::Debug> {
    label: T,
}

impl<T: fmt::Debug> MyAction<T> {
    pub fn new(label: T) -> MyAction<T> {
        MyAction { label }
    }
}

pub trait ApplyAction<T: fmt::Debug + PartialEq>: fmt::Debug {
    fn get_type(&self) -> TypeId;
    fn is_eq(&self, other: &dyn ApplyAction<T>) -> bool;
}

impl<T: fmt::Debug + Eq + 'static> ApplyAction<T> for MyAction<T> {
    fn get_type(&self) -> TypeId {
        TypeId::of::<MyAction<T>>()
    }

    fn is_eq(&self, other: &dyn ApplyAction<T>) -> bool {
        if other.get_type() == TypeId::of::<Self>() {
            // Rust thinks that self and other are different types in the calls below.
            let other_ = unsafe { *transmute::<&&dyn ApplyAction<T>, &&Self>(&other) };
            self.label == other_.label
        } else {
            false
        }
    }
}

impl<T: fmt::Debug + Eq + PartialEq + 'static> PartialEq for dyn ApplyAction<T> {
    fn eq(&self, other: &dyn ApplyAction<T>) -> bool {
        if other.get_type() == TypeId::of::<Self>() {
            self.is_eq(other)
        } else {
            false
        }
    }
}

#[derive(Debug)]
pub struct ActionList<T: fmt::Debug> {
    actions: Vec<Box<dyn ApplyAction<T>>>,
}

impl<T: fmt::Debug + PartialEq> ActionList<T> {
    pub fn new() -> ActionList<T> {
        ActionList { actions: vec![] }
    }

    pub fn push<A: ApplyAction<T> + 'static>(&mut self, action: A) {
        self.actions.push(Box::new(action));
    }
}

impl<T: fmt::Debug + Eq + PartialEq + 'static> PartialEq for ActionList<T> {
    fn eq(&self, other: &ActionList<T>) -> bool {
        for (i, action) in self.actions.iter().enumerate() {
            if **action != *other.actions[i] {
                return false;
            }
        }
        true
    }
}

fn main() {
    let mut script1: ActionList<String> = ActionList::new();
    script1.push(MyAction::new("foo".to_string()));
    let mut script2: ActionList<String> = ActionList::new();
    script2.push(MyAction::new("foo".to_string()));
    let mut script3: ActionList<String> = ActionList::new();
    script3.push(MyAction::new("bar".to_string()));
    assert_eq!(script1, script2);
    assert_ne!(script1, script3);
}
```
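One possible simplification, sketched under assumptions rather than offered as the definitive approach: replace the hand-written `TypeId` check and `transmute` with `std::any::Any` downcasting. The `as_any` and `eq_dyn` methods are hypothetical additions to a trimmed-down `ApplyAction`, and the `PartialEq` for `ActionList` here also compares lengths so that lists of different sizes compare unequal instead of indexing past the end of the shorter one.

```rust
use std::any::Any;
use std::fmt;

// Hypothetical simplified trait: each implementor exposes itself as `&dyn Any`
// so the other side of a comparison can be downcast safely (no `transmute`).
pub trait ApplyAction<T>: fmt::Debug {
    fn as_any(&self) -> &dyn Any;
    fn eq_dyn(&self, other: &dyn ApplyAction<T>) -> bool;
}

#[derive(Debug, PartialEq)]
pub struct MyAction<T> {
    label: T,
}

impl<T: fmt::Debug + PartialEq + 'static> ApplyAction<T> for MyAction<T> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn eq_dyn(&self, other: &dyn ApplyAction<T>) -> bool {
        // `downcast_ref` yields `None` when `other` is a different concrete type.
        other
            .as_any()
            .downcast_ref::<Self>()
            .map_or(false, |o| self == o)
    }
}

#[derive(Debug)]
pub struct ActionList<T> {
    actions: Vec<Box<dyn ApplyAction<T>>>,
}

impl<T> ActionList<T> {
    pub fn new() -> Self {
        ActionList { actions: Vec::new() }
    }

    pub fn push<A: ApplyAction<T> + 'static>(&mut self, action: A) {
        self.actions.push(Box::new(action));
    }
}

impl<T> PartialEq for ActionList<T> {
    fn eq(&self, other: &Self) -> bool {
        // Lists of different lengths are never equal; otherwise compare pairwise.
        self.actions.len() == other.actions.len()
            && self
                .actions
                .iter()
                .zip(&other.actions)
                .all(|(a, b)| a.eq_dyn(&**b))
    }
}

fn main() {
    let mut script1: ActionList<String> = ActionList::new();
    script1.push(MyAction { label: "foo".to_string() });
    let mut script2: ActionList<String> = ActionList::new();
    script2.push(MyAction { label: "foo".to_string() });
    let mut script3: ActionList<String> = ActionList::new();
    script3.push(MyAction { label: "bar".to_string() });
    assert_eq!(script1, script2);
    assert_ne!(script1, script3);
}
```

Because `downcast_ref::<Self>()` only succeeds when the concrete types match, two different action types compare unequal without any `unsafe` code.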