-
-
Notifications
You must be signed in to change notification settings - Fork 751
Closed
Labels
Description
HI ,
when I use new version pytorch 2.0.1 in scala 2.11.10, meet the class convert error, the example java code I rewrite in scala, but it cannot work ,because in javacpp-pytorch new version ,Module.class has remove all the layerImpl convert to Module ,use register_module method, why remove them? now the error is
Exception in thread "main" java.lang.ClassCastException: class org.bytedeco.pytorch.Module cannot be cast to class org.bytedeco.pytorch.LinearImpl (org.bytedeco.pytorch.Module and org.bytedeco.pytorch.LinearImpl are in unnamed module of loader 'app')
at SimpleMNIST$Net.<init>(hell.scala:23)
at SimpleMNIST$.main(hell.scala:52)
at SimpleMNIST.main(hell.scala)
how to solve that error ,do I need import some method sugar dependency in scala code?
if I scala code remove asInstanceOf[LinearImpl] these code ,the scala code cannot compile, Please help me ,thanks
dependency:
ThisBuild / version := "0.1.0-SNAPSHOT"
ThisBuild / scalaVersion := "2.12.10"
lazy val root = (project in file("."))
.settings(
name := "torchSa"
)
scalaVersion := "2.12.10"
//idePackagePrefix := Some("org.example")
resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots"
val sparkVersion = "3.1.1"
//libraryDependencies ++= Seq(
// "org.apache.spark" %% "spark-core" % sparkVersion,
// "org.apache.spark" %% "spark-sql" % sparkVersion,
// "org.apache.spark" %% "spark-mllib" % sparkVersion,
// "org.apache.spark" %% "spark-streaming" % sparkVersion
//)
// https://mvnrepository.com/artifact/org.apache.parquet/parquet-common
libraryDependencies += "org.apache.parquet" % "parquet-common" % "1.12.3"
libraryDependencies += "org.bytedeco" % "pytorch" % "2.0.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" // "1.10.2-1.5.7"
// https://mvnrepository.com/artifact/org.bytedeco/pytorch-platform
libraryDependencies += "org.bytedeco" % "pytorch-platform" % "2.0.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" //"1.10.2-1.5.7"
//libraryDependencies += "org.bytedeco" % "pytorch-platform-gpu" % "2.0.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" // "1.10.2-1.5.7"
//// https://mvnrepository.com/artifact/org.bytedeco/pytorch-platform
libraryDependencies += "org.bytedeco" % "mkl-platform-redist" % "2023.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" //"1.10.2-1.5.7"
//
code : convert the example java code to scala
import org.bytedeco.javacpp._
import org.bytedeco.pytorch._
import org.bytedeco.pytorch.Module
import org.bytedeco.pytorch.global.torch._
import java.io.File
import scala.collection.mutable.ListBuffer
import scala.io.Source
object SimpleMNIST { // Define a new Module. :LinearImpl :LinearImpl=
class Net() extends Module { // Construct and register two Linear submodules.
//fc1 = register_module("fc1", new LinearImpl(784, 64));
var fc1 = register_module("fc1", new LinearImpl(784, 64)).asInstanceOf[LinearImpl]
var fc2 = register_module("fc2", new LinearImpl(64, 32)).asInstanceOf[LinearImpl]
var fc3 = register_module("fc3", new LinearImpl(32, 10)).asInstanceOf[LinearImpl]
// Implement the Net's algorithm.
def forward(xl: Tensor): Tensor = { // Use one of many tensor manipulation functions.
var x = xl
x = relu(fc1.forward(x.reshape(x.size(0), 784)))
x = dropout(x, 0.5, is_training)
x = relu(fc2.asInstanceOf[LinearImpl].forward(x))
x = log_softmax(fc3.asInstanceOf[LinearImpl].forward(x), 1)
x
}
// Use one of many "standard library" modules.
// var fc1: LinearImpl = null
// var fc2: LinearImpl = null
// var fc3: LinearImpl = null
}
@throws[Exception]
def main(args: Array[String]): Unit = {
/* try to use MKL when available */
System.setProperty("org.bytedeco.openblas.load", "mkl")
// Create a multi-threaded data loader for the MNIST dataset.
val data_set = new MNIST("./data").map(new ExampleStack)
val data_loader = new MNISTRandomDataLoader(data_set, new RandomSampler(data_set.size.get), new DataLoaderOptions(/*batch_size=*/ 64))
// Create a new Net.
val net = new SimpleMNIST.Net
// Instantiate an SGD optimization algorithm to update our Net's parameters.
val optimizer = new SGD(net.parameters, new SGDOptions(/*lr=*/ 0.01))
for (epoch <- 1 to 10) {
var batch_index = 0
// Iterate the data loader to yield batches from the dataset.
var it = data_loader.begin
while ( {
!(it == data_loader.end)
}) {
val batch = it.access
// Reset gradients.
optimizer.zero_grad()
// Execute the model on the input data.
val prediction = net.forward(batch.data)
// Compute a loss value to judge the prediction of our model.
val loss = nll_loss(prediction, batch.target)
// Compute gradients of the loss w.r.t. the parameters of our model.
loss.backward()
// Update the parameters based on the calculated gradients.
optimizer.step
// Output the loss and checkpoint every 100 batches.
if ( {
batch_index += 1; batch_index
} % 100 == 0) {
System.out.println("Epoch: " + epoch + " | Batch: " + batch_index + " | Loss: " + loss.item_float)
// Serialize your model periodically as a checkpoint.
val archive = new OutputArchive
net.save(archive)
archive.save_to("net.pt")
}
it = it.increment
}
}
}
}