2
votes

I want to write a Scala macro that overrides the field values of a case class based on map entries, with a simple type check: if the original field type and the override value's type are compatible, set the new value; otherwise keep the original value.

So far I have the following code:

    import language.experimental.macros
    import scala.reflect.macros.Context

    /** Compile-time helper that rebuilds a case class value through its `copy` method,
      * taking field values from an overrides map when they are meant to be type-compatible.
      */
    object ProductUtils {

        /** Returns a copy of `entity` in which each case field whose name is a key of `overrides`
          * is intended to take the mapped value when its type matches the field type, and keeps its
          * original value otherwise. Implemented as a macro expanded by `withOverridesImpl`.
          *
          * @param entity    case class instance to copy
          * @param overrides field name -> candidate new value
          * @tparam T        the case class type
          */
        def withOverrides[T](entity: T, overrides: Map[String, Any]): T =
            macro withOverridesImpl[T]

        /** Macro implementation: expands to `entity.copy(f1 = ..., f2 = ..., ...)`, with one named
          * argument per case accessor declared on `T`. Aborts compilation when `copy` does not
          * resolve to a method symbol.
          */
        def withOverridesImpl[T: c.WeakTypeTag](c: Context)
                                               (entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = {
            import c.universe._

            // NOTE(review): this tree is inserted into every accessor selection below and into the
            // final copy call, so the `entity` expression may be evaluated more than once at runtime.
            val originalEntityTree = reify(entity.splice).tree
            val originalEntityCopy = entity.actualType.member(newTermName("copy"))

            // (field name, expression reading that field from the entity, field type)
            // for each case accessor declared on T.
            val originalEntity =
                weakTypeOf[T].declarations.collect {
                    case m: MethodSymbol if m.isCaseAccessor =>
                        (m.name, c.Expr[T](Select(originalEntityTree, m.name)), m.returnType)
                }

            // One named argument `name = <override-or-original>` per field.
            val values =
                originalEntity.map {
                    case (name, value, ctype) =>
                        AssignOrNamedArg(
                            Ident(name),
                            {
                                // K stands for the field type (ctype) via the WeakTypeTag passed below.
                                def reifyWithType[K: WeakTypeTag] = reify {
                                    overrides
                                        .splice
                                        .asInstanceOf[Map[String, Any]]
                                        .get(c.literal(name.decoded).splice) match {
                                            // K is abstract here, so this type test is erased and unchecked
                                            // ("abstract type pattern K is unchecked"); it does not reliably
                                            // reject values of the wrong runtime type.
                                            case Some(newValue : K) => newValue
                                            case _                  => value.splice
                                        }
                                }

                                reifyWithType(c.WeakTypeTag(ctype)).tree
                            }
                        )
                }.toList

            // Only build the call when `copy` resolved to a method symbol; otherwise stop compilation.
            originalEntityCopy match {
                case s: MethodSymbol =>
                    c.Expr[T](
                        Apply(Select(originalEntityTree, originalEntityCopy), values))
                case _ => c.abort(c.enclosingPosition, "No eligible copy method!")
            }

        }

    }

Executed like this:

    import macros.ProductUtils

    /** Sample case class used to exercise `ProductUtils.withOverrides`. */
    case class Example(
        field1: String,
        field2: Int,
        filed3: String // NOTE(review): probably meant `field3`; kept because renaming changes the accessor
    )

    /** Small driver that exercises `ProductUtils.withOverrides` on an `Example`. */
    object MacrosTest {

        /** Applies one compatible override (`field1`) and one mismatched override (`field2`)
          * and prints the resulting `Example`.
          */
        def main(args: Array[String]): Unit = {
            // field1 gets a String (compatible); field2 gets a String where an Int is expected
            val overrides = Map("field1" -> "new value", "field2" -> "wrong type")
            // intended output: Example(new value,0,)
            println(ProductUtils.withOverrides(Example("", 0, ""), overrides))
        }
    }

As you can see, I've managed to get the type of the original field, and now I want to pattern match against it in `reifyWithType`.

Unfortunately, with the current implementation I'm getting a warning during compilation:

warning: abstract type pattern K is unchecked since it is eliminated by erasure case Some(newValue : K) => newValue

and a compiler crash in IntelliJ:

Exception in thread "main" java.lang.NullPointerException
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseAsInstanceOf$1(Erasure.scala:1032)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseNormalApply(Erasure.scala:1083)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseApply(Erasure.scala:1187)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preErase(Erasure.scala:1193)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1268)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018)
    at scala.reflect.internal.Trees$class.itransform(Trees.scala:1217)
    at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13)
    at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13)
    at scala.reflect.api.Trees$Transformer.transform(Trees.scala:2897)
    at scala.tools.nsc.transform.TypingTransformers$TypingTransformer.transform(TypingTransformers.scala:48)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1280)
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018)

So the questions are:
* Is it possible to compare the type received in the macro with the value's runtime type?
* Or is there a better approach to this task?

1
I'm not sure macros help you much here, since you're going to have to have the macro generate runtime reflection code (which is definitely possible but kind of unpleasant). – Travis Brown

1 Answer

0
votes

In the end I went with the following solution:

import language.experimental.macros
import scala.reflect.macros.Context

/** Macro utility that rebuilds a case class value through its `copy` method, replacing a field
  * with the value stored under the field's name in an overrides map when a type check passes.
  */
object ProductUtils {

    /** Returns `entity.copy(...)` where each case field takes `overrides(fieldName)` if that value
      * passes the type check generated by `withOverridesImpl`, and keeps its current value otherwise.
      *
      * @param entity    case class instance to copy
      * @param overrides field name -> candidate new value
      * @tparam T        the case class type
      */
    def withOverrides[T](entity: T, overrides: Map[String, Any]): T =
        macro withOverridesImpl[T]

    /** Macro implementation: generates one named `copy` argument per case accessor declared on `T`.
      * Aborts compilation when `copy` does not resolve to a method symbol.
      */
    def withOverridesImpl[T: c.WeakTypeTag](c: Context)(entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = {
        import c.universe._

        // NOTE(review): this tree is inserted into every accessor selection below and into the final
        // copy call, so the `entity` expression may be evaluated more than once at runtime; binding it
        // to a local val first would avoid that.
        val originalEntityTree = reify(entity.splice).tree
        val originalEntityCopy = entity.actualType.member(newTermName("copy"))

        // (field name, expression reading that field, field type) for each case accessor of T.
        // resetAllAttrs gives each selection an untyped copy of the entity tree; presumably this is
        // what avoids the compiler crash seen when one typed tree is reused — unconfirmed.
        val originalEntity =
            weakTypeOf[T].declarations.collect {
                case m: MethodSymbol if m.isCaseAccessor =>
                    (m.name, c.Expr[T](Select(c.resetAllAttrs(originalEntityTree), m.name)), m.returnType)
            }

        // One named argument `name = <override-or-original>` per field.
        val values =
            originalEntity.map {
                case (name, value, ctype) =>
                    AssignOrNamedArg(
                        Ident(name),
                        {

                            // Trees that, at runtime, evaluate to the field type's (erased) Class and to
                            // the field type as a runtime-reflection Type.
                            val ruClass = c.reifyRuntimeClass(ctype)
                            val mtag    = c.reifyType(treeBuild.mkRuntimeUniverseRef, Select(treeBuild.mkRuntimeUniverseRef, newTermName("rootMirror")), ctype)
                            val mtree   = Select(mtag, newTermName("tpe"))

                            // K is bound to the field type through the WeakTypeTag passed at the bottom.
                            def reifyWithType[K: c.WeakTypeTag] = reify {

                                // Some(value) if the looked-up candidate passes the type check, else None.
                                def tryNewValue[A: scala.reflect.runtime.universe.TypeTag](candidate: Option[A]): Option[K] =
                                    if (candidate.isEmpty) {
                                        None
                                    } else {
                                        val cc =  c.Expr[Class[_]](ruClass).splice
                                        val candidateValue = candidate.get
                                        val candidateType  = scala.reflect.runtime.universe.typeOf[A]
                                        val expectedType   = c.Expr[scala.reflect.runtime.universe.Type](mtree).splice

                                        // Primitive fields: a boxed candidate is accepted only when its box matches
                                        // the primitive exactly (no numeric widening, e.g. an Int is rejected for a
                                        // Long field).
                                        val ok = (cc.isPrimitive, candidateValue) match {
                                            case (true, _: java.lang.Integer)   => cc == java.lang.Integer.TYPE
                                            case (true, _: java.lang.Long)      => cc == java.lang.Long.TYPE
                                            case (true, _: java.lang.Double)    => cc == java.lang.Double.TYPE
                                            case (true, _: java.lang.Character) => cc == java.lang.Character.TYPE
                                            case (true, _: java.lang.Float)     => cc == java.lang.Float.TYPE
                                            case (true, _: java.lang.Byte)      => cc == java.lang.Byte.TYPE
                                            case (true, _: java.lang.Short)     => cc == java.lang.Short.TYPE
                                            case (true, _: java.lang.Boolean)   => cc == java.lang.Boolean.TYPE
                                            case (true, _: Unit)                => cc == java.lang.Void.TYPE
                                            case  _                             =>
                                                // NOTE(review): this cast would throw if candidateType were not a
                                                // TypeRef (e.g. a compound type) — confirm that cannot happen.
                                                val args = candidateType.asInstanceOf[scala.reflect.runtime.universe.TypeRefApi].args
                                                // If neither the candidate's static type nor its type arguments are Any,
                                                // compare full types; otherwise fall back to a runtime class check,
                                                // which sees only erased classes (e.g. any List passes for List[Int]).
                                                if (!args.contains(scala.reflect.runtime.universe.typeOf[Any])
                                                       && !(candidateType =:= scala.reflect.runtime.universe.typeOf[Any]))
                                                    candidateType =:= expectedType
                                                else cc.isInstance(candidateValue)
                                        }

                                        // Unchecked cast; only reached after the check above succeeded.
                                        if (ok)
                                            Some(candidateValue.asInstanceOf[K])
                                        else None
                                }

                                // Look the field up by its decoded name; fall back to the entity's current value.
                                tryNewValue(overrides.splice.get(c.literal(name.decoded).splice)).getOrElse(value.splice)
                            }

                            reifyWithType(c.WeakTypeTag(ctype)).tree
                        }
                    )
            }.toList

        // Only build the call when `copy` resolved to a method symbol; otherwise stop compilation.
        originalEntityCopy match {
            case s: MethodSymbol =>
                c.Expr[T](
                    Apply(Select(originalEntityTree, originalEntityCopy), values))
            case _ => c.abort(c.enclosingPosition, "No eligible copy method!")
        }

    }

}

It mostly satisfies the original requirements:

/** ScalaTest suite for `ProductUtils.withOverrides`.
  *
  * NOTE(review): the imports for `FunSuite` and `ProductUtils` are not shown in this snippet;
  * they must be in scope for the suite to compile.
  */
class ProductUtilsTest extends FunSuite {

    // Fixtures: A/B have simple field types, C/D generic collections, E nested case classes.
    case class A(a: String, b: String)
    case class B(a: String, b: Int)
    case class C(a: List[Int], b: List[String])
    case class D(a: Map[Int, String], b: Double)
    case class E(a: A, b: B)

    test("simple overrides works") {
        val overrides = Map("a" -> "A", "b" -> "B")
        assert(ProductUtils.withOverrides(A("", ""), overrides) === A("A", "B"))
    }

    test("simple overrides works 1") {
        // mixed value types: a String for the String field, an Int for the Int field
        val overrides = Map("a" -> "A", "b" -> 1)
        assert(ProductUtils.withOverrides(B("", 0), overrides) === B("A", 1))
    }

    test("do not override if types do not match") {
        // an Int for the String field and a List for the Int field are both rejected
        val overrides = Map("a" -> 0, "b" -> List("B"))
        assert(ProductUtils.withOverrides(B("", 0), overrides) === B("", 0))
    }

    test("complex types also works") {
        val overrides = Map("a" -> List(1), "b" -> List("A"))
        assert(ProductUtils.withOverrides(C(List(0), List("")), overrides) === C(List(1), List("A")))
    }

    test("complex types also works 1") {
        // A List is not a Map, so field `a` keeps its value; the Double for `b` is accepted.
        // java.util.Date is fully qualified because this snippet has no import for it.
        val overrides = Map("a" -> List(new java.util.Date()), "b" -> 2.0d)
        assert(ProductUtils.withOverrides(D(Map(), 1.0), overrides) === D(Map(), 2.0))
    }

    test("complex types also works 2") {
        // a nested case class value is accepted for `a`; a Double is rejected for the B-typed `b`
        val overrides = Map("a" -> A("AA", "BB"), "b" -> 2.0d)
        assert(ProductUtils.withOverrides(E(A("", ""), B("", 0)), overrides) === E(A("AA", "BB"), B("", 0)))
    }

}

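Unfortunately, because of type erasure on the JVM, it is hard to enforce exact type equality before replacing a value. The fallback check compares only erased runtime classes, so a `List[Date]` is accepted for a `List[Int]` field. As a result, you can still do something like this: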

scala> case class C(a: List[Int], b: List[String])
defined class C

scala> val overrides = Map("a" -> List(new Date()), "b" -> List(1.0))
overrides: scala.collection.immutable.Map[String,List[Any]] = Map(a -> List(Mon Aug 26 15:52:27 CEST 2013), b -> List(1.0))

scala> ProductUtils.withOverrides(C(List(0), List("")), overrides)
res0: C = C(List(Mon Aug 26 15:52:27 CEST 2013),List(1.0))

scala> res0.a.head + 1
java.lang.ClassCastException: java.util.Date cannot be cast to java.lang.Integer
    at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:106)
    at .<init>(<console>:14)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:606)
    at scala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:734)
    at scala.tools.nsc.interpreter.IMain$Request.loadAndRun(IMain.scala:983)
    at scala.tools.nsc.interpreter.IMain.loadAndRunReq$1(IMain.scala:573)
    at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:604)