Conversation

goldsborough (Contributor)

This PR will be the development ground of PyTorch's C++ API for a while. To get CI feedback without pushing to master, I will be posting continuous updates into this PR, one isolated, squashed commit at a time, and requesting reviews of individual commits.

CC @ebetica @jgehring

goldsborough (Contributor, Author) commented Apr 6, 2018

To begin, this PR creates a prototypical C++ API for torch::nn::Module. Lots of things are not yet implemented; however, I created a stub (which compiles) for an LSTM module that showcases the general look and feel of the API, and I have documented the methods I think torch::nn::Module should have. This PR also sets up the general directory structure, which I have chosen to be torch/csrc/api as a start. I have the build system updates in another PR and will push them here once ready.

dzhulgakov (Collaborator)

This is cool! Just making sure you're aware of https://sourcegraph.com/github.com/pytorch/pytorch/-/blob/torch/csrc/jit/script/module.h - we need to make sure that there's a single nn.Module abstraction in C++ which works with Python, script, tracing, etc.

dzhulgakov requested a review from jamesr66a, April 6, 2018 17:59
/// Returns the name of the `Module`.
const std::string& name() const noexcept;

/// Performs a recursive clone of the entire module hierarchy. This is to
virtual std::unique_ptr<Module> clone();

/// Takes a list of input variables and computes a list of output variables.
virtual std::vector<Tensor> forward(const std::vector<Tensor>& inputs) = 0;

}

std::vector<torch::Tensor> forward(
    const std::vector<torch::Tensor>& inputs) override {

namespace torch {
/// There will only be gradient recording tensors in the frontend API.
using Tensor = autograd::Variable;

} // namespace torch

void Module::eval() {
  is_training_ = false;
}

class LSTM : public torch::nn::CloneableModule<LSTM> {
 public:
  LSTM(long input_features, long state_size)
      : CloneableModule<LSTM>("LSTM"),

class Module {
 public:
  /// Tells the base `Module` about the name of the submodule.
  explicit Module(std::string name);

apaszke (Contributor) commented Apr 8, 2018

Nice POC! I found a few worrying things and left comments.

jgehring (Contributor) left a comment

This is perhaps a naive question, but why provide a clone() method rather than a copy constructor? clone() requires CRTP if you want to return the correct type, and in the current implementation it dictates the storage for the resulting module (unique_ptr).

goldsborough (Contributor, Author) commented Apr 9, 2018 via email

ebetica (Contributor) commented Apr 9, 2018

@jgehring see the virtual copy constructor paradigm:

https://www.geeksforgeeks.org/advanced-c-virtual-copy-constructor/
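
For readers following along, here is a minimal self-contained sketch of that paradigm combined with CRTP, roughly the shape that CloneableModule<LSTM> above suggests. The names and members below are illustrative assumptions, not this PR's exact API:

#include <memory>

struct Module {
  virtual ~Module() = default;
  // Virtual copy constructor: callers holding a Module* can copy the
  // concrete subclass without knowing its static type.
  virtual std::unique_ptr<Module> clone() const = 0;
};

// The CRTP layer implements clone() once for every subclass.
template <typename Derived>
struct Cloneable : Module {
  std::unique_ptr<Module> clone() const override {
    // Invokes Derived's ordinary copy constructor.
    return std::unique_ptr<Module>(
        new Derived(static_cast<const Derived&>(*this)));
  }
};

struct LSTM : Cloneable<LSTM> {
  long input_features = 0;
  long state_size = 0;
};

A plain copy constructor cannot do this through a Module*, since the dynamic type would be sliced away; that is what makes the virtual clone() worth the CRTP boilerplate.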

jgehring (Contributor) commented Apr 9, 2018

Ah thanks, of course, I completely missed that. Indeed, we discussed CRTP quite a lot for @ebetica's autogradpp...

goldsborough (Contributor, Author) commented Apr 10, 2018

I've squashed my commits a bit to now have:

  1. Initial commit + operator() (already reviewed)
  2. Stuff to make the code compile with cpp_build, so that CI is now building my changes
  3. Code move to put the OrderedDict data structure from torch/csrc/jit/module.cpp into the API folder
  4. Implemented the Sequential container (reviews welcome)

@@ -543,6 +543,7 @@ def run(self):
        include_dirs += [
            cwd,
            os.path.join(cwd, "torch", "csrc"),
            os.path.join(cwd, "torch", "csrc", "api", "include"),

if (auto method = methods.find(name)) {
  return method->get();
}
return nullptr;

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/autograd/variable.h"
#include <ATen/optional.h>

#include "torch/csrc/api/include/torch/detail/ordered_dict.h"

    append(std::forward<Tail>(tail)...);
  }

  std::vector<std::unique_ptr<Module>> modules_;

/// Adds a new `Module` to the `Sequential` container.
template <typename M>
void append(M&& module) {

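For orientation, here is a minimal self-contained sketch of the recursive variadic append pattern the two hunks above come from; the exact signatures are assumptions, not this PR's code:

#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

struct Module { virtual ~Module() = default; };

class Sequential {
 public:
  // Recursive case: peel off the first module, then recurse on the rest.
  template <typename Head, typename... Tail>
  void append(Head&& head, Tail&&... tail) {
    append(std::forward<Head>(head));
    append(std::forward<Tail>(tail)...);
  }

  // Base case for an empty argument pack.
  void append() {}

  // Single-module case: store a heap-allocated copy of the module.
  template <typename M>
  void append(M&& module) {
    modules_.push_back(std::unique_ptr<Module>(
        new typename std::decay<M>::type(std::forward<M>(module))));
  }

 private:
  std::vector<std::unique_ptr<Module>> modules_;
};

Overload resolution prefers the single-parameter template over the variadic one for one-argument calls, so the recursion always terminates.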

ezyang (Contributor) commented Apr 12, 2018

CC @smessmer who is our resident C++ expert

@@ -10,6 +10,11 @@ if (VERBOSE)
  message(STATUS "ATEN_BUILD_PATH is ${ATEN_BUILD_PATH}")
endif()

set(CMAKE_CXX_FLAGS "--std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-variable ${CMAKE_CXX_FLAGS}")

goldsborough (Contributor, Author)

Changes:

  1. OrderedDict now stores Items, to support iteration (otherwise there's no easy way to go from a value back to its key).
  2. Submodules are stored as shared_ptrs instead of Module*.
  3. Implementation of cursors, which allow hierarchical traversal of submodules, parameters, and buffers. Essentially a wrapper over std::vector<std::pair<std::string, Module>>, but with a nice API that lets you write e.g. module.parameters().find("key") or module.modules().apply([](Module& module) { ... }) (see the sketch after this list). -- @smessmer may want to have a look
  4. Quick hacky implementation of zero_grad() based on the cursors API. Need to think more about the Variable API, and implement toType_().
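
As a sketch of the cursor idea in item 3 (assumed names; the real cursors also support hierarchical traversal):

#include <functional>
#include <string>
#include <utility>
#include <vector>

struct Module;  // defined elsewhere

class ModuleCursor {
 public:
  using Item = std::pair<std::string, Module*>;

  explicit ModuleCursor(std::vector<Item> items) : items_(std::move(items)) {}

  // Linear search by key; returns nullptr if the key is absent.
  Module* find(const std::string& key) {
    for (auto& item : items_) {
      if (item.first == key) {
        return item.second;
      }
    }
    return nullptr;
  }

  // Applies fn to every module in the traversal, in order.
  void apply(const std::function<void(Module&)>& fn) {
    for (auto& item : items_) {
      fn(*item.second);
    }
  }

 private:
  std::vector<Item> items_;
};

With such a cursor over parameters, the zero_grad() mentioned in item 4 reduces to a single apply over each parameter's gradient.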

goldsborough mentioned this pull request Apr 12, 2018
apaszke (Contributor) left a comment

Mostly LGTM. Let's get this in soon, because it's getting dangerously large.


## API tests

SET(TORCH_API_TEST_SRCS ${TORCH_SRC_DIR}/csrc/api/test/test.cpp)

"${TORCH_SRC_DIR}/../aten/src/ATen/utils/catch/single_include"
"${COMMON_INCLUDES}")

INSTALL(TARGETS test_api

  size_t size() const noexcept;

 protected:
  std::vector<Item> items_;
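
The hunk above is from the OrderedDict change (item 1 of the changes listed earlier). A minimal sketch of the idea, where each Item keeps its key next to its value so iteration can recover both, under assumed names rather than this PR's exact class:

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

template <typename T>
class OrderedDict {
 public:
  // Entries remember their key, in insertion order.
  struct Item {
    std::string key;
    T value;
  };

  T& insert(std::string key, T value) {
    items_.push_back(Item{std::move(key), std::move(value)});
    return items_.back().value;
  }

  // Returns a pointer to the value stored under `key`, or nullptr.
  T* find(const std::string& key) {
    for (auto& item : items_) {
      if (item.key == key) {
        return &item.value;
      }
    }
    return nullptr;
  }

  std::size_t size() const noexcept { return items_.size(); }

  typename std::vector<Item>::iterator begin() { return items_.begin(); }
  typename std::vector<Item>::iterator end() { return items_.end(); }

 protected:
  std::vector<Item> items_;
};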

void Module::type(at::ScalarType new_type) {
  // parameters().apply([=](Tensor& tensor) { tensor.toType_(new_type); });
  // buffers().apply([=](Tensor& tensor) { tensor.toType_(new_type); });

py::tuple result(self.get_parameters().size());
size_t index = 0;
for (auto& parameter : self.get_parameters()) {

for (auto& method : self.get_methods()) {
method_names.push_back((*method)->name());
}
return method_names;

goldsborough (Contributor, Author)

@apaszke I was actually planning on doing all development in this PR and squashing commits to keep it manageable. That way we don't have to make it "public" yet (so people don't use it). And I think that if I push it onto a branch, say pytorch:cpp, CI won't test it when I make PRs (is that the case?). Is this a good idea?

apaszke (Contributor) commented Apr 15, 2018

I think it's easier to merge this thing in parts, even if it's going to undergo breaking changes. It's already 1.1k lines long, and reviewing PRs of 2k+ lines is really hard.

It's probably ok to merge this even before it's public. I doubt it's going to be used before we actually start advertising it, and we should make it clear that it's really experimental and will break your code.

apaszke (Contributor) left a comment

Just to make sure this isn't merged before you remove the accidental ONNX submodule change.

This was referenced Apr 28, 2018
goldsborough (Contributor, Author)

Will be rebased on top of autogradpp.
