
[pytorch] [java/scala] new version removed the layerImpl conversions from the Module class, so the Module returned by register_module cannot be cast back to a layerImpl #1393

@mullerhai

Description


Hi,

When I use the new PyTorch 2.0.1 presets with Scala 2.12.10, I hit a class-cast error. I rewrote the example Java code in Scala, but it does not work: in the new javacpp-pytorch version, the Module class no longer exposes the overloads that convert the registered layerImpl back from the Module returned by the register_module method. Why were they removed? The error is:

Exception in thread "main" java.lang.ClassCastException: class org.bytedeco.pytorch.Module cannot be cast to class org.bytedeco.pytorch.LinearImpl (org.bytedeco.pytorch.Module and org.bytedeco.pytorch.LinearImpl are in unnamed module of loader 'app')
	at SimpleMNIST$Net.<init>(hell.scala:23)
	at SimpleMNIST$.main(hell.scala:52)
	at SimpleMNIST.main(hell.scala)
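The exception itself is ordinary JVM behavior: `asInstanceOf` only succeeds when the runtime class of the object actually is the target type. A minimal self-contained sketch of the same failure mode (stand-in classes, not the real org.bytedeco.pytorch types) shows how a `register_module` that returns a plain `Module` wrapper, instead of the `LinearImpl` that was passed in, makes the downcast throw:

```scala
// Stand-in classes to illustrate the cast failure; these are NOT the
// org.bytedeco.pytorch types, just a model of the same situation.
object CastDemo {
  class Module
  class LinearImpl(val in: Int, val out: Int) extends Module

  // Models a register_module that returns a fresh Module wrapper
  // rather than the LinearImpl argument itself.
  def register(name: String, m: Module): Module = new Module

  def main(args: Array[String]): Unit = {
    val returned = register("fc1", new LinearImpl(784, 64))
    // returned's runtime class is Module, not LinearImpl, so this throws:
    try {
      returned.asInstanceOf[LinearImpl]
      println("cast succeeded")
    } catch {
      case _: ClassCastException => println("ClassCastException")
    }
  }
}
```

The same thing happens in the real code when the `Module` object coming back from `register_module` is a different Java wrapper than the `LinearImpl` that went in.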

How can I solve this error? Do I need to import some implicit-conversion dependency in my Scala code? If I remove the asInstanceOf[LinearImpl] casts, the Scala code does not compile. Please help, thanks.

Dependencies (build.sbt):

ThisBuild / version := "0.1.0-SNAPSHOT"

ThisBuild / scalaVersion := "2.12.10"

lazy val root = (project in file("."))
  .settings(
    name := "torchSa"
  )


scalaVersion := "2.12.10"

//idePackagePrefix := Some("org.example")
resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots"

val sparkVersion = "3.1.1"

//libraryDependencies ++= Seq(
//  "org.apache.spark" %% "spark-core" % sparkVersion,
//  "org.apache.spark" %% "spark-sql" % sparkVersion,
//  "org.apache.spark" %% "spark-mllib" % sparkVersion,
//  "org.apache.spark" %% "spark-streaming" % sparkVersion
//)
// https://mvnrepository.com/artifact/org.apache.parquet/parquet-common
libraryDependencies += "org.apache.parquet" % "parquet-common" % "1.12.3"

libraryDependencies += "org.bytedeco" % "pytorch" %  "2.0.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" // "1.10.2-1.5.7"
// https://mvnrepository.com/artifact/org.bytedeco/pytorch-platform
libraryDependencies += "org.bytedeco" % "pytorch-platform" % "2.0.1-1.5.10-SNAPSHOT"  //  "1.12.1-1.5.8" //"1.10.2-1.5.7"
//libraryDependencies += "org.bytedeco" % "pytorch-platform-gpu" %  "2.0.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" // "1.10.2-1.5.7"
//// https://mvnrepository.com/artifact/org.bytedeco/pytorch-platform

libraryDependencies += "org.bytedeco" % "mkl-platform-redist" % "2023.1-1.5.10-SNAPSHOT"  //  "1.12.1-1.5.8" //"1.10.2-1.5.7"
//

Code (the example Java code converted to Scala):

import org.bytedeco.javacpp._
import org.bytedeco.pytorch._
import org.bytedeco.pytorch.Module
import org.bytedeco.pytorch.global.torch._
import java.io.File
import scala.collection.mutable.ListBuffer
import scala.io.Source
object SimpleMNIST { // Define a new Module.
  class Net() extends Module { // Construct and register two Linear submodules.
    //fc1 = register_module("fc1", new LinearImpl(784, 64));
    var fc1 = register_module("fc1", new LinearImpl(784, 64)).asInstanceOf[LinearImpl]
    var fc2 = register_module("fc2", new LinearImpl(64, 32)).asInstanceOf[LinearImpl]
    var fc3 = register_module("fc3", new LinearImpl(32, 10)).asInstanceOf[LinearImpl]

    // Implement the Net's algorithm.
    // Implement the Net's algorithm.
    def forward(xl: Tensor): Tensor = { // Use one of many tensor manipulation functions.
      var x = xl
      x = relu(fc1.forward(x.reshape(x.size(0), 784)))
      x = dropout(x, 0.5, is_training)
      x = relu(fc2.forward(x))
      x = log_softmax(fc3.forward(x), 1)
      x
    }

//     Use one of many "standard library" modules.
//    var fc1: LinearImpl = null
//    var fc2: LinearImpl = null
//    var fc3: LinearImpl = null
  }

  @throws[Exception]
  def main(args: Array[String]): Unit = {
    /* try to use MKL when available */
    System.setProperty("org.bytedeco.openblas.load", "mkl")

    // Create a multi-threaded data loader for the MNIST dataset.
    val data_set = new MNIST("./data").map(new ExampleStack)
    val data_loader = new MNISTRandomDataLoader(data_set, new RandomSampler(data_set.size.get), new DataLoaderOptions(/*batch_size=*/ 64))
    // Create a new Net.
    val net = new SimpleMNIST.Net
    // Instantiate an SGD optimization algorithm to update our Net's parameters.
    val optimizer = new SGD(net.parameters, new SGDOptions(/*lr=*/ 0.01))
    for (epoch <- 1 to 10) {
      var batch_index = 0
      // Iterate the data loader to yield batches from the dataset.
      var it = data_loader.begin
      while (it != data_loader.end) {
        val batch = it.access
        // Reset gradients.
        optimizer.zero_grad()
        // Execute the model on the input data.
        val prediction = net.forward(batch.data)
        // Compute a loss value to judge the prediction of our model.
        val loss = nll_loss(prediction, batch.target)
        // Compute gradients of the loss w.r.t. the parameters of our model.
        loss.backward()
        // Update the parameters based on the calculated gradients.
        optimizer.step
        // Output the loss and checkpoint every 100 batches.
        batch_index += 1
        if (batch_index % 100 == 0) {
          System.out.println("Epoch: " + epoch + " | Batch: " + batch_index + " | Loss: " + loss.item_float)
          // Serialize your model periodically as a checkpoint.
          val archive = new OutputArchive
          net.save(archive)
          archive.save_to("net.pt")
        }

        it = it.increment
      }
    }
  }
}
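If it helps anyone hitting the same cast error: a workaround sketch (hedged, not verified against the 2.0.1-1.5.10-SNAPSHOT presets) is to construct each `LinearImpl` yourself, keep the typed reference as the field, and call `register_module` only for its side effect, discarding the `Module` it returns. This uses only the `register_module(String, Module)` overload that the code above already calls, so no cast is needed:

```scala
import org.bytedeco.pytorch._
import org.bytedeco.pytorch.Module
import org.bytedeco.pytorch.global.torch._

class Net extends Module {
  // Keep the typed references as fields; the Module returned by
  // register_module is deliberately ignored.
  val fc1 = new LinearImpl(784, 64)
  val fc2 = new LinearImpl(64, 32)
  val fc3 = new LinearImpl(32, 10)
  register_module("fc1", fc1)
  register_module("fc2", fc2)
  register_module("fc3", fc3)

  def forward(input: Tensor): Tensor = {
    var x = relu(fc1.forward(input.reshape(input.size(0), 784)))
    x = dropout(x, 0.5, is_training)
    x = relu(fc2.forward(x))
    log_softmax(fc3.forward(x), 1)
  }
}
```

This mirrors the idiomatic C++ frontend pattern where the module holds its own typed submodule handles and registration is a separate step.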
