Scala 在LMS(轻量级模块化暂存)中使用数字类型类

Scala 在LMS(轻量级模块化暂存)中使用数字类型类,scala,dsl,Scala,Dsl,以下是教程的网址: 我已经成功地编译并理解了大部分代码。虽然这些概念和示例非常吸引人,但我立即想将代码从硬编码“Double”作为标量集更改为从标准库实现类型类Numeric[T]的任何代码。然而,我没有成功 我尝试过向linearagebraexp特性添加以下代码: override type Scalar = Double override type Vector = Seq[Scalar] implicit val num:Numeric[Scalar] = implicit

以下是教程的网址:

我已经成功地编译并理解了大部分代码。虽然这些概念和示例非常吸引人,但我立即想将代码从硬编码“Double”作为标量集更改为从标准库实现类型类Numeric[T]的任何代码。然而,我没有成功

我尝试过向linearagebraexp特性添加以下代码:

  override type Scalar = Double
  override type Vector = Seq[Scalar]
  implicit val num:Numeric[Scalar] = implicitly[Numeric[Scalar]]
这是行不通的。我的下一个想法(可能更好)是将隐式数值参数添加到任何实现函数(即vector_scale的所有实际实现)。由于奇异的编译时错误,我仍然无法完全理解它


LMS目前是否支持使用数字类型?查看LMS的源代码,现在看起来可能真的很混乱。

我通过向程序中的contextbounds添加清单,获得了一些可用代码:

import scala.virtualization.lms.common._

trait LinearAlgebra extends Base {

  type Vector[T]
  type Matrix[T]

  def vector_scale[T:Manifest:Numeric](v: Rep[Vector[T]], k: Rep[T]): Rep[Vector[T]]

  // Tensor product between 2 matrices
  def tensor_prod[T:Manifest:Numeric](A:Rep[Matrix[T]],B:Rep[Matrix[T]]):Rep[Matrix[T]]

  // Concrete syntax
  implicit class VectorOps[T:Numeric:Manifest](v: Rep[Vector[T]]) {
    def *(k: Rep[T]):Rep[Vector[T]] = vector_scale[T](v, k)
  }
  implicit class MatrixOps[T:Numeric:Manifest](A:Rep[Matrix[T]]) {
    def |*(B:Rep[Matrix[T]]):Rep[Matrix[T]] = tensor_prod(A,B)
  }

  implicit def any2rep[T:Manifest](t:T) = unit(t)

}

trait Interpreter extends Base {
  override type Rep[+A] = A
  override protected def unit[A: Manifest](a: A) = a
}

trait LinearAlgebraInterpreter extends LinearAlgebra with Interpreter {

  override type Vector[T] = Array[T]
  override type Matrix[T] = Array[Array[T]]
  override def vector_scale[T:Manifest](v: Vector[T], k: T)(implicit num:Numeric[T]):Rep[Vector[T]] =  v map {x => num.times(x,k)}
  def tensor_prod[T:Manifest](A:Matrix[T],B:Matrix[T])(implicit num:Numeric[T]):Matrix[T] = {
    def smm(s:T,m:Matrix[T]) = m.map(_.map(x => num.times(x,s)))
    def concat(A:Matrix[T],B:Matrix[T]) = (A,B).zipped.map(_++_)
    A flatMap (row => row map ( s => smm(s,B)) reduce concat)
  }
}


trait LinearAlgebraExp extends LinearAlgebra with BaseExp {
  // Here we say how a Rep[Vector] will be bound to a Array[Scalar] in regular Scala code
  override type Vector[T] = Array[T]
  type Matrix[T] = Array[Array[T]]

  // Reification of the concept of scaling a vector `v` by a factor `k`
  case class VectorScale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]) extends Def[Vector[T]]

  override def vector_scale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]) = toAtom(VectorScale(v, k))
  def tensor_prod[T:Manifest:Numeric](A:Rep[Matrix[T]],B:Rep[Matrix[T]]):Rep[Matrix[T]] = ???
}

trait ScalaGenLinearAlgebra extends ScalaGenBase {
  // This code generator works with IR nodes defined by the LinearAlgebraExp trait
  val IR: LinearAlgebraExp
  import IR._

  override def emitNode(sym: Sym[Any], node: Def[Any]): Unit = node match {
    case VectorScale(v, k) => {
      emitValDef(sym, quote(v) + ".map(x => x * " + quote(k) + ")")
    }
    case _ => super.emitNode(sym, node)
  }
}

trait LinearAlgebraExpOpt extends LinearAlgebraExp {
  override def vector_scale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]) = k match {
    case Const(1.0) => v
    case _ => super.vector_scale(v, k)
  }
}

trait Prog extends LinearAlgebra {
  def f[T:Manifest](v: Rep[Vector[T]])(implicit num:Numeric[T]): Rep[Vector[T]] = v * unit(num.fromInt(3))
  def g[T:Manifest](v: Rep[Vector[T]])(implicit num:Numeric[T]): Rep[Vector[T]] = v * unit(num.fromInt(1))
  //def h(A:Rep[Matrix],B:Rep[Matrix]):Rep[Matrix] = A |* B

}

object TestLinAlg extends App {

  val interpretedProg = new Prog with LinearAlgebraInterpreter {
    println(g(Array(1.0, 2.0)).mkString(","))
  }

  val optProg = new Prog with LinearAlgebraExpOpt with EffectExp with CompileScala { self =>
    override val codegen = new ScalaGenEffect with ScalaGenLinearAlgebra { val IR: self.type = self }
    codegen.emitSource(g[Double], "optimizedG", new java.io.PrintWriter(System.out))
  }

  val nonOptProg = new Prog with LinearAlgebraExp with EffectExp with CompileScala { self =>
    override val codegen = new ScalaGenEffect with ScalaGenLinearAlgebra { val IR: self.type = self }
    codegen.emitSource(g[Double], "nonOptimizedG", new java.io.PrintWriter(System.out))
  }

  def compareInterpCompiled = {
    val optcomp = optProg.compile(optProg.g[Double])
    val nonOptComp = nonOptProg.compile(nonOptProg.g[Double])
    val a = Array(1.0,2.0)
    optcomp(a).toList == nonOptComp(a).toList
  }

  println(compareInterpCompiled)

}
我的目标是在上使用该示例,然后将其修改为使用数字类型。我想发现传递隐式类型类的开销被完全剥离了。上述程序的(成功)输出为

我们看到对num.times的调用确实被剥离了