[WIP] C++ API #6345
Conversation
To begin, this PR creates a prototypical C++ API for …
This is cool! Just making sure you're aware of https://sourcegraph.com/github.com/pytorch/pytorch/-/blob/torch/csrc/jit/script/module.h. We need to make sure that there's a single nn.Module abstraction in C++ which works with Python, script, tracing, etc.
/// Returns the name of the `Module`.
const std::string& name() const noexcept;

/// Performs a recursive clone of the entire module hierarchy. This is to …
virtual std::unique_ptr<Module> clone();

/// Takes a list of input variables and computes a list of output variables.
virtual std::vector<Tensor> forward(const std::vector<Tensor>& inputs) = 0;
}

std::vector<torch::Tensor> forward(
    const std::vector<torch::Tensor>& inputs) override {
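For context, here is a minimal self-contained sketch of the contract these two fragments express: a base class declares a pure virtual `forward` over vectors of tensors, and user modules override it. `Tensor` is stubbed as `double` so the sketch compiles on its own, and `Scale` is a hypothetical module, not part of this PR.

```cpp
#include <vector>

using Tensor = double;  // Stub; the PR uses torch::Tensor.

struct Module {
  virtual ~Module() = default;
  // Takes a list of input variables and computes a list of outputs.
  virtual std::vector<Tensor> forward(const std::vector<Tensor>& inputs) = 0;
};

// Hypothetical user module that doubles every input.
struct Scale : Module {
  std::vector<Tensor> forward(const std::vector<Tensor>& inputs) override {
    std::vector<Tensor> outputs;
    outputs.reserve(inputs.size());
    for (const Tensor& input : inputs) {
      outputs.push_back(input * 2);
    }
    return outputs;
  }
};
```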
namespace torch {
/// There will only be gradient recording tensors in the frontend API.
using Tensor = autograd::Variable;
torch/csrc/api/src/module.cpp
}

void Module::eval() {
  is_training_ = false;
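A sketch of the train/eval pattern this snippet implies, assuming submodules are tracked in a `children_` list (a hypothetical name; the PR may store them differently):

```cpp
#include <memory>
#include <vector>

class Module {
 public:
  virtual ~Module() = default;

  // Switch the module hierarchy into training or evaluation mode,
  // mirroring module.train()/module.eval() in the Python API.
  void train() { set_training(true); }
  void eval() { set_training(false); }
  bool is_training() const noexcept { return is_training_; }

 private:
  void set_training(bool on) {
    is_training_ = on;
    for (auto& child : children_) {
      child->set_training(on);  // Propagate the mode recursively.
    }
  }

  bool is_training_ = true;
  std::vector<std::shared_ptr<Module>> children_;
};
```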
class LSTM : public torch::nn::CloneableModule<LSTM> {
 public:
  LSTM(long input_features, long state_size)
      : CloneableModule<LSTM>("LSTM"),
class Module {
 public:
  /// Tells the base `Module` about the name of the submodule.
  explicit Module(std::string name);
Nice POC! I found a few worrying things and left comments.
This is perhaps a naive question, but why provide a clone() method rather than a copy constructor? clone() requires CRTP if you want to return the correct type, and in the current implementation it dictates the storage for the resulting module (unique_ptr).
It's for when all you have is an abstract Module*, i.e. a "polymorphic copy constructor". In that case you need to know the static type to invoke the copy constructor. @ebetica was showing me some neuroevolution stuff where you guys needed this
@jgehring see the virtual copy constructor paradigm: https://www.geeksforgeeks.org/advanced-c-virtual-copy-constructor/
Ah thanks, of course, I completely missed that. Indeed, we discussed CRTP quite a lot for @ebetica's autogradpp...
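To make the pattern under discussion concrete, here is a minimal sketch of clone() implemented through CRTP, i.e. the virtual copy constructor idiom. The names mirror the PR's snippets, but the bodies are illustrative, not the actual implementation:

```cpp
#include <memory>
#include <string>

class Module {
 public:
  explicit Module(std::string name) : name_(std::move(name)) {}
  virtual ~Module() = default;

  // The "virtual copy constructor": copies through an abstract Module*
  // without knowing the dynamic type at the call site.
  virtual std::unique_ptr<Module> clone() const = 0;

  const std::string& name() const noexcept { return name_; }

 private:
  std::string name_;
};

template <typename Derived>
class CloneableModule : public Module {
 public:
  using Module::Module;
  std::unique_ptr<Module> clone() const override {
    // CRTP: Derived is known statically here, so we can invoke its copy
    // constructor on behalf of the caller.
    return std::unique_ptr<Module>(
        new Derived(static_cast<const Derived&>(*this)));
  }
};

class LSTM : public CloneableModule<LSTM> {
 public:
  LSTM() : CloneableModule<LSTM>("LSTM") {}
};

// Usage:
//   std::unique_ptr<Module> original(new LSTM());
//   std::unique_ptr<Module> copy = original->clone();
```

This is exactly why clone() cannot be a plain copy constructor: through a Module* the static type is erased, so a virtual dispatch has to recover it.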
I've squashed my commits a bit to now have: …
@@ -543,6 +543,7 @@ def run(self):
include_dirs += [
    cwd,
    os.path.join(cwd, "torch", "csrc"),
    os.path.join(cwd, "torch", "csrc", "api", "include"),
if (auto method = methods.find(name)) {
  return method->get();
}
return nullptr;
#include "torch/csrc/jit/ir.h" | ||
#include "torch/csrc/jit/graph_executor.h" | ||
#include "torch/csrc/autograd/variable.h" | ||
#include <ATen/optional.h> | ||
|
||
#include "torch/csrc/api/include/torch/detail/ordered_dict.h" |
  append(std::forward<Tail>(tail)...);
}

std::vector<std::unique_ptr<Module>> modules_;
/// Adds a new `Module` to the `Sequential` container.
template <typename M>
void append(M&& module) {
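Taken together, the two fragments above show a variadic append. Here is a self-contained sketch of that pattern, with a base-case overload added so the recursion terminates and a stubbed Module type; storage as unique_ptrs follows the modules_ member shown earlier:

```cpp
#include <memory>
#include <utility>
#include <vector>

struct Module {
  virtual ~Module() = default;
};

class Sequential {
 public:
  // Recursion base case: nothing left to append.
  void append() {}

  // Peels off one module at a time; M must derive from Module.
  template <typename M, typename... Rest>
  void append(M&& module, Rest&&... rest) {
    modules_.push_back(std::unique_ptr<Module>(
        new typename std::decay<M>::type(std::forward<M>(module))));
    append(std::forward<Rest>(rest)...);
  }

 private:
  std::vector<std::unique_ptr<Module>> modules_;
};

// Usage (A and B being hypothetical Module subclasses):
//   Sequential seq;
//   seq.append(A{}, B{});  // stores two modules, in order
```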
CC @smessmer who is our resident C++ expert.
@@ -10,6 +10,11 @@ if (VERBOSE)
message(STATUS "ATEN_BUILD_PATH is ${ATEN_BUILD_PATH}")
endif()

set(CMAKE_CXX_FLAGS "--std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-variable ${CMAKE_CXX_FLAGS}")
Changes: …
Mostly LGTM. Let's get this in soon, because it's getting dangerously large.
## API tests

SET(TORCH_API_TEST_SRCS ${TORCH_SRC_DIR}/csrc/api/test/test.cpp)
"${TORCH_SRC_DIR}/../aten/src/ATen/utils/catch/single_include" | ||
"${COMMON_INCLUDES}") | ||
|
||
INSTALL(TARGETS test_api |
size_t size() const noexcept;

protected:
 std::vector<Item> items_;
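The items_ vector above belongs to the ordered dictionary that holds parameters and submodules. A simplified sketch of the idea, with hypothetical member names: values kept in insertion order in a vector (matching Python OrderedDict semantics), plus a side map from key to index for fast lookup:

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

template <typename Value>
class OrderedDict {
 public:
  struct Item {
    std::string key;
    Value value;
  };

  void insert(std::string key, Value value) {
    // Record the index first, while `key` is still valid, then move both
    // into the insertion-ordered vector.
    index_.emplace(key, items_.size());
    items_.push_back(Item{std::move(key), std::move(value)});
  }

  // Returns nullptr when the key is absent, mirroring the find() usage
  // seen elsewhere in this PR.
  Value* find(const std::string& key) {
    auto it = index_.find(key);
    return it == index_.end() ? nullptr : &items_[it->second].value;
  }

  std::size_t size() const noexcept { return items_.size(); }

 protected:
  std::vector<Item> items_;
  std::unordered_map<std::string, std::size_t> index_;
};
```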
void Module::type(at::ScalarType new_type) {
  // parameters().apply([=](Tensor& tensor) { tensor.toType_(new_type); });
  // buffers().apply([=](Tensor& tensor) { tensor.toType_(new_type); });
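A sketch of what the commented-out body above aims at: visit every parameter (and buffer) and convert it in place. Tensor and toType_ are stand-ins here, and the apply() visitor is the hypothetical helper named in the snippet, not the PR's implementation.

```cpp
#include <functional>
#include <vector>

struct Tensor {
  int scalar_type = 0;
  void toType_(int new_type) { scalar_type = new_type; }  // Stub conversion.
};

class ParameterCursor {
 public:
  explicit ParameterCursor(std::vector<Tensor>& tensors) : tensors_(tensors) {}

  // Applies fn to each tensor, as parameters().apply(...) intends.
  void apply(const std::function<void(Tensor&)>& fn) {
    for (auto& tensor : tensors_) {
      fn(tensor);
    }
  }

 private:
  std::vector<Tensor>& tensors_;
};
```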
#include "torch/csrc/jit/ir.h" | ||
#include "torch/csrc/jit/graph_executor.h" | ||
#include "torch/csrc/autograd/variable.h" | ||
#include <ATen/optional.h> | ||
|
||
#include "torch/csrc/api/include/torch/detail/ordered_dict.h" |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
py::tuple result(self.get_parameters().size());
size_t index = 0;
for (auto& parameter : self.get_parameters()) {
for (auto& method : self.get_methods()) {
  method_names.push_back((*method)->name());
}
return method_names;
@apaszke I was actually planning on doing all development in this PR, and squashing commits to make it manageable. This means we don't have to make it "public" yet (so people don't use it). And I think if I push it onto a branch, say …
I think it's easier to merge this thing in parts, even if it's going to undergo breaking changes. It's already 1.1k lines long, and reviewing PRs that have 2k+ lines is really hard. It's probably ok to merge this even before it's public. I doubt it's going to be used before we actually start advertising it, and we should make it clear that it's really experimental and will break your code.
Just to make sure this isn't merged before you remove the accidental ONNX submodule change.
Will be rebased on top of autogradpp.
This PR will be the development ground for PyTorch's C++ API for a while. To get CI feedback without pushing to master, I will post continuous updates to this PR, one isolated, squashed commit at a time, and request reviews of individual commits.
CC @ebetica @jgehring