8
votes

I have a Rust program which contains a number of different structs which all implement a trait called ApplyAction. Another struct, ActionList, contains a vector of boxed objects which implement ApplyAction. I would like to create some unit tests which compare ActionLists with one another.

There are a few different SO questions which deal with PartialEq on boxed traits, and I've used these to get some way towards an implementation. However, in the (simplified) code below (and on the Playground), the assertions in main() fail because the type ids of the objects passed to eq() differ. Why?

Also, this seems extremely complicated for such a simple use case -- is there an easier way to do this?

use std::any::TypeId;
use std::boxed::Box;
use std::fmt;
use std::mem::transmute;

/// A simple action that carries nothing but a label of type `T`.
///
/// Equality and `Debug` output are derived, so two actions compare equal
/// exactly when their labels do.
#[derive(Debug, Eq, PartialEq)]
pub struct MyAction<T>
where
    T: fmt::Debug,
{
    label: T,
}

impl<T> MyAction<T>
where
    T: fmt::Debug,
{
    /// Wraps `label` in a new action.
    pub fn new(label: T) -> Self {
        Self { label }
    }
}

/// Object-safe interface shared by every action type.
///
/// Implementors expose a `TypeId` and a comparison against an arbitrary
/// `ApplyAction<T>` trait object, so that boxed actions can be compared
/// without knowing their concrete types up front.
pub trait ApplyAction<T>: fmt::Debug
where
    T: fmt::Debug + PartialEq,
{
    /// Returns a `TypeId` chosen by the implementor to identify itself.
    fn get_type(&self) -> TypeId;

    /// Reports whether `self` and `other` should be considered equal.
    fn is_eq(&self, other: &ApplyAction<T>) -> bool;
}

/// `ApplyAction` for `MyAction`: two actions are equal when `other` is also a
/// `MyAction<T>` and both carry the same label.
impl<T: fmt::Debug + Eq + 'static> ApplyAction<T> for MyAction<T> {
    /// Returns the `TypeId` of the concrete `MyAction<T>` type.
    fn get_type(&self) -> TypeId {
        TypeId::of::<Self>()
    }

    /// Compares `self` with `other`.
    ///
    /// Returns `false` when `other` reports a different `TypeId`; otherwise
    /// `other` is viewed as a `MyAction<T>` and the labels are compared.
    fn is_eq(&self, other: &ApplyAction<T>) -> bool {
        if other.get_type() != TypeId::of::<Self>() {
            return false;
        }
        // Rust only knows `other` as a trait object, so it has to be cast back
        // to the concrete type by hand. A fat-to-thin raw pointer cast keeps
        // just the data pointer and, unlike `transmute` on a reference to the
        // fat pointer, does not depend on how trait-object pointers are laid
        // out in memory.
        //
        // SAFETY: `other.get_type()` reported `TypeId::of::<Self>()`, so the
        // value behind `other` is a `Self` -- provided every implementor
        // returns the `TypeId` of its own concrete type from `get_type`, as
        // this impl does. The resulting reference is only used within this
        // call, while `other` is still borrowed.
        let other_ = unsafe { &*(other as *const _ as *const Self) };
        self.label == other_.label
    }
}

// `==` / `!=` between two `ApplyAction<T>` trait objects.
//
// NOTE(review): inside this impl `Self` is the unsized trait-object type
// `ApplyAction<T>` itself, so `TypeId::of::<Self>()` is the id of the trait
// object, not of any concrete implementor. `MyAction::get_type` returns
// `TypeId::of::<MyAction<T>>()`, so for `MyAction` values the check below does
// not match and `eq` returns `false` without ever calling `is_eq`.
impl<T: fmt::Debug + Eq + PartialEq + 'static> PartialEq for ApplyAction<T> {
    fn eq(&self, other: &ApplyAction<T>) -> bool {
        // Compares `other`'s reported type with the trait-object type (see
        // the note above); `is_eq` is only reached when they match.
        if other.get_type() == TypeId::of::<Self>() {
            self.is_eq(other)
        } else {
            false
        }
    }
}

/// An ordered list of boxed actions that all share the parameter `T`.
#[derive(Debug)]
pub struct ActionList<T>
where
    T: fmt::Debug,
{
    actions: Vec<Box<ApplyAction<T>>>,
}

impl<T> ActionList<T>
where
    T: fmt::Debug + PartialEq,
{
    /// Creates a list with no actions in it.
    pub fn new() -> Self {
        Self {
            actions: Vec::new(),
        }
    }

    /// Boxes `action` and appends it to the end of the list.
    pub fn push<A>(&mut self, action: A)
    where
        A: ApplyAction<T> + 'static,
    {
        let boxed: Box<ApplyAction<T>> = Box::new(action);
        self.actions.push(boxed);
    }
}

/// Two lists are equal when they hold the same number of actions and every
/// pair of actions at the same position compares equal.
impl<T: fmt::Debug + Eq + PartialEq + 'static> PartialEq for ActionList<T> {
    fn eq(&self, other: &ActionList<T>) -> bool {
        // Check the lengths first: indexing `other` by `self`'s positions
        // would panic when `other` is shorter, and would ignore any extra
        // trailing actions when `other` is longer.
        self.actions.len() == other.actions.len()
            && self
                .actions
                .iter()
                .zip(other.actions.iter())
                .all(|(lhs, rhs)| **lhs == **rhs)
    }
}

fn main() {
    // Builds a list holding a single `MyAction` labelled `label`.
    fn single(label: &str) -> ActionList<String> {
        let mut list: ActionList<String> = ActionList::new();
        list.push(MyAction::new(label.to_string()));
        list
    }

    let script1 = single("foo");
    let script2 = single("foo");
    let script3 = single("bar");

    // Same label => expected equal; different label => expected unequal.
    assert_eq!(script1, script2);
    assert_ne!(script1, script3);
}
1
You can do something like this: playground. I cannot write a full answer now, so feel free to write your own. – red75prime

1 Answer

1
votes

In `impl<...> PartialEq for ApplyAction<T>` you used `TypeId::of::<Self>()`, i.e. the type id of the unsized trait object itself rather than that of the concrete type behind it, so the comparison with `other.get_type()` never matches. That isn't what you wanted; remove the `if` and call `self.is_eq(other)` directly, and your code should work.

Sadly your example requires a lot of code to implement ApplyAction<T> for MyAction<T> - and again for each other action type you might want to use.

I tried to remove that overhead, and with nightly features it is completely gone (and otherwise only a small stub remains):

Playground

// see `default impl` below
#![feature(specialization)]

// Any::<T>::downcast_ref only works for special trait objects (`Any` and
// `Any + Send`); having a trait `T` derive from `Any` doesn't allow you to
// coerce ("cast") `&T` into `&Any` (that might change in the future).
//
// Implementing a custom `downcast_ref` which takes any
// `T: Any + ?Sized + 'static` as input leads to another problem: if `T` is a
// trait that didn't inherit `Any` you still can call `downcast_ref`, but it
// won't work (it will use the `TypeId` of the trait object instead of the
// underlying (sized) type).
//
// Use `SizedAny` instead: it's only implemented for sized types by default;
// that prevents the problem above, and we can implement `downcast_ref` without
// worrying.

mod sized_any {
    use std::any::TypeId;

    // `SizedAny` is sealed: nobody outside this module can add impls, so the
    // only impls that exist are the blanket ones for sized types below.
    mod seal {
        // Has to be `pub` so it may appear as a supertrait of a public trait,
        // yet lives in a private module so outside code cannot name it.
        pub trait Seal {}
    }

    /// Reports the `TypeId` of the concrete (sized) type a value has.
    pub trait SizedAny: seal::Seal + 'static {
        /// Returns `TypeId::of::<Self>()` for the implementing type.
        fn get_type_id(&self) -> TypeId {
            TypeId::of::<Self>()
        }
    }

    impl<T: 'static> seal::Seal for T {}
    impl<T: 'static> SizedAny for T {}

    /// Borrows `v` as a `Dst` if that is the type `v` reports for itself,
    /// and returns `None` otherwise.
    ///
    /// `Src` may be unsized (e.g. a trait object whose trait inherits
    /// `SizedAny`); the id is then the one the underlying sized type reports.
    pub fn downcast_ref<Src, Dst>(v: &Src) -> Option<&Dst>
    where
        Src: SizedAny + ?Sized + 'static,
        Dst: 'static,
    {
        let actual = <Src as SizedAny>::get_type_id(v);
        if actual != TypeId::of::<Dst>() {
            return None;
        }
        let thin = v as *const Src as *const Dst;
        // SAFETY: `SizedAny` is sealed and only blanket-implemented for sized
        // types, so `actual` is the id of the concrete type behind `v`. It
        // equals `TypeId::of::<Dst>()`, hence `thin` points at a valid `Dst`
        // that stays borrowed for as long as `v` is.
        Some(unsafe { &*thin })
    }
}
use sized_any::*;

use std::boxed::Box;
use std::fmt;

// `ApplyAction`

/// Compares `a` with a value `b` whose concrete type is not statically known.
///
/// `b` is first downcast to `T`; when that succeeds the two values are
/// compared with `==`, and when it fails the result is `false`.
fn foreign_eq<T, U>(a: &T, b: &U) -> bool
where
    T: PartialEq + 'static,
    U: SizedAny + ?Sized + 'static,
{
    match downcast_ref::<U, T>(b) {
        Some(same_type) => a == same_type,
        None => false,
    }
}

/// Object-safe action interface.
///
/// Inheriting `SizedAny` lets a trait object be downcast with
/// `downcast_ref`, which is what `foreign_eq` implementations rely on.
pub trait ApplyAction<T>: fmt::Debug + SizedAny + 'static
where
    T: 'static,
{
    /// Reports whether `self` equals `other`, whatever concrete type `other`
    /// has behind the trait object.
    fn foreign_eq(&self, other: &ApplyAction<T>) -> bool;
}

// Default `foreign_eq` for every sized `A: PartialEq`: it forwards to the free
// function `foreign_eq`, so concrete action types can use an empty `impl`.
//
// Requires `#![feature(specialization)]` and a nightly compiler. Without it,
// the body below could instead be copied manually into each `impl`.
//
// This implementation only works with sized `A` types; we cannot make
// `ApplyAction<T>` inherit `Sized`, as that would destroy object safety.
default impl<T: 'static, A: PartialEq + 'static> ApplyAction<T> for A {
    fn foreign_eq(&self, other: &ApplyAction<T>) -> bool {
        // Resolves to the free function `foreign_eq`, not to this method.
        foreign_eq(self, other)
    }
}

/// `==` / `!=` on `ApplyAction<T>` trait objects, delegated to the
/// left-hand side's `foreign_eq`.
impl<T> PartialEq for ApplyAction<T>
where
    T: 'static,
{
    fn eq(&self, rhs: &ApplyAction<T>) -> bool {
        self.foreign_eq(rhs)
    }
}

// `MyAction`

/// An action identified solely by its label.
///
/// `Debug`, `Eq` and `PartialEq` are derived, so equality is label equality.
#[derive(Debug, Eq, PartialEq)]
pub struct MyAction<T>
where
    T: fmt::Debug,
{
    label: T,
}

impl<T> MyAction<T>
where
    T: fmt::Debug,
{
    /// Creates an action carrying `label`.
    pub fn new(label: T) -> Self {
        Self { label }
    }
}

// Intentionally empty: no method is written out here, so `foreign_eq` has to
// be supplied by the blanket `default impl` of `ApplyAction`.
impl<T> ApplyAction<T> for MyAction<T> where T: fmt::Debug + PartialEq + 'static {}

// `ActionList`

/// An ordered list of boxed `ApplyAction<T>` trait objects.
#[derive(Debug)]
pub struct ActionList<T> {
    actions: Vec<Box<ApplyAction<T>>>,
}

impl<T> ActionList<T>
where
    T: 'static,
{
    /// Creates a list with no actions in it.
    pub fn new() -> Self {
        Self {
            actions: Vec::new(),
        }
    }

    /// Boxes `action` and appends it to the end of the list.
    pub fn push<A>(&mut self, action: A)
    where
        A: ApplyAction<T> + 'static,
    {
        let boxed: Box<ApplyAction<T>> = Box::new(action);
        self.actions.push(boxed);
    }
}

/// Two lists are equal when they have the same length and every pair of
/// actions at the same position compares equal.
impl<T> PartialEq for ActionList<T>
where
    T: 'static,
{
    fn eq(&self, rhs: &ActionList<T>) -> bool {
        // Lengths are compared first, so the pairwise walk below never has a
        // left-over element on either side.
        self.actions.len() == rhs.actions.len()
            && self
                .actions
                .iter()
                .zip(rhs.actions.iter())
                .all(|(mine, theirs)| **mine == **theirs)
    }
}

// `main`

fn main() {
    // One single-action list per label, built in the order listed.
    let scripts: Vec<ActionList<String>> = ["foo", "foo", "bar"]
        .iter()
        .map(|label| {
            let mut list: ActionList<String> = ActionList::new();
            list.push(MyAction::new(label.to_string()));
            list
        })
        .collect();

    // Same label => equal lists; different label => unequal lists.
    assert_eq!(scripts[0], scripts[1]);
    assert_ne!(scripts[0], scripts[2]);
}

See also: