ONNX
ONNX is an open representation format for machine learning models, which enables AI developers to use models across different libraries and tools. SINGA supports loading ONNX format models for training and inference, and saving models defined using SINGA APIs (e.g., Module) into ONNX format.
SINGA has been tested with the following version of ONNX.
ONNX version | File format version | Opset version ai.onnx | Opset version ai.onnx.ml | Opset version ai.onnx.training |
---|---|---|---|---|
1.6.0 | 6 | 11 | 2 | - |
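As a rough illustration of how the table can be used, the sketch below compares a model's version numbers against the tested ones. The helper name check_versions and the TESTED dictionary are hypothetical and not part of SINGA. On a real model loaded with onnx.load, the values would come from onnx_model.ir_version and onnx_model.opset_import.

```python
# Versions taken from the compatibility table above.
TESTED = {"ir_version": 6, "opsets": {"ai.onnx": 11, "ai.onnx.ml": 2}}


def check_versions(ir_version, opsets):
    """Return a list of mismatches against the tested versions.

    `opsets` maps an operator-set domain to its version. The default
    domain "" is treated as "ai.onnx", as ONNX itself does.
    """
    problems = []
    if ir_version != TESTED["ir_version"]:
        problems.append(
            f"file format version {ir_version} != {TESTED['ir_version']}")
    for domain, version in opsets.items():
        domain = domain or "ai.onnx"
        expected = TESTED["opsets"].get(domain)
        if expected is not None and version != expected:
            problems.append(f"{domain} opset {version} != {expected}")
    return problems


# A model matching the table produces no warnings.
print(check_versions(6, {"": 11, "ai.onnx.ml": 2}))
# A newer model is reported as untested rather than rejected.
print(check_versions(7, {"": 12}))
```

A mismatch does not necessarily mean the model will fail to load; it only means that the combination is outside what the table lists as tested.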
General usage
Loading an ONNX Model into SINGA
After loading an ONNX model from disk with onnx.load, you only need to update the batch size of the input using tensor.PlaceHolder. Since SINGA v3.0, the shapes of internal tensors are inferred automatically.
Then, define a class that inherits from sonnx.SONNXModel and implements two methods: forward for the forward pass and train_one_batch for training.
After you call model.compile, SONNX iterates over all the nodes in the ONNX model's graph and translates them into SINGA operators, loads all stored weights, and infers the shape of each intermediate tensor.
import onnx
from singa import device
from singa import sonnx
from singa import tensor


class MyModel(sonnx.SONNXModel):

    def __init__(self, onnx_model):
        super(MyModel, self).__init__(onnx_model)

    def forward(self, *x):
        y = super(MyModel, self).forward(*x)
        # A SINGA model returns its outputs as a list, so if there is
        # only one output, take the first element.
        return y[0]

    def train_one_batch(self, x, y):
        pass


model_path = "PATH/To/ONNX/MODEL"
onnx_model = onnx.load(model_path)

# convert the ONNX model into a SINGA model
dev = device.create_cuda_gpu()
# INPUT is assumed to be a NumPy array holding one batch of input data
x = tensor.PlaceHolder(INPUT.shape, device=dev)
model = MyModel(onnx_model)
model.compile([x], is_train=False, use_graph=True, sequential=True)
Inference with a SINGA model
Once the model is created, you can run inference by calling model.forward. The input and output must be SINGA Tensor instances.
x = tensor.Tensor(device=dev, data=INPUT)
y = model.forward(x)
Saving SINGA model into ONNX Format
Given the input tensors and the output tensors generated by the operators of the model, you can trace back all internal operations. A SINGA model is therefore defined by its input and output tensors. To export a SINGA model into ONNX format, you just need to provide the lists of input and output tensors.
# x is the input tensor, y is the output tensor
sonnx.to_onnx([x], [y])
Re-training an ONNX model
To train (or refine) an ONNX model using SINGA, you need to implement the train_one_batch method of sonnx.SONNXModel and set is_train=True when calling model.compile.
from singa import opt
from singa import autograd


class MyModel(sonnx.SONNXModel):

    def __init__(self, onnx_model):
        super(MyModel, self).__init__(onnx_model)

    def forward(self, *x):
        y = super(MyModel, self).forward(*x)
        return y[0]

    def train_one_batch(self, x, y, dist_option, spars):
        out = self.forward(x)
        loss = autograd.softmax_cross_entropy(out, y)
        # choose the update strategy according to dist_option
        if dist_option == 'fp32':
            self.optimizer.backward_and_update(loss)
        elif dist_option == 'fp16':
            self.optimizer.backward_and_update_half(loss)
        elif dist_option == 'partialUpdate':
            self.optimizer.backward_and_partial_update(loss)
        elif dist_option == 'sparseTopK':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=True,
                                                      spars=spars)
        elif dist_option == 'sparseThreshold':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=False,
                                                      spars=spars)
        return out, loss

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer


# onnx_model is the model loaded with onnx.load, as in the loading example
model = MyModel(onnx_model)
sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)
model.set_optimizer(sgd)
# tx (the input placeholder tensor) and graph (a bool) are assumed to be
# defined by the caller
model.compile([tx], is_train=True, use_graph=graph, sequential=True)
Transfer-learning an ONNX model
You can also append layers to the end of an ONNX model for transfer learning. The last_layers argument accepts a negative integer indicating the layer at which to cut the model off. For example, -1 means cutting off after the final output (that is, no layer is removed), and -2 means cutting off after the second-to-last layer.
from singa import opt
from singa import autograd
from singa import layer


class MyModel(sonnx.SONNXModel):

    def __init__(self, onnx_model):
        super(MyModel, self).__init__(onnx_model)
        self.linear = layer.Linear(1000, 3)

    def forward(self, *x):
        # cut off after the third-to-last layer
        # and append a linear layer
        y = super(MyModel, self).forward(*x, last_layers=-3)[0]
        y = self.linear(y)
        return y

    def train_one_batch(self, x, y, dist_option, spars):
        out = self.forward(x)
        loss = autograd.softmax_cross_entropy(out, y)
        # choose the update strategy according to dist_option
        if dist_option == 'fp32':
            self.optimizer.backward_and_update(loss)
        elif dist_option == 'fp16':
            self.optimizer.backward_and_update_half(loss)
        elif dist_option == 'partialUpdate':
            self.optimizer.backward_and_partial_update(loss)
        elif dist_option == 'sparseTopK':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=True,
                                                      spars=spars)
        elif dist_option == 'sparseThreshold':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=False,
                                                      spars=spars)
        return out, loss

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer


# onnx_model is the model loaded with onnx.load, as in the loading example
model = MyModel(onnx_model)
sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)
model.set_optimizer(sgd)
# tx (the input placeholder tensor) and graph (a bool) are assumed to be
# defined by the caller
model.compile([tx], is_train=True, use_graph=graph, sequential=True)
ONNX model zoo
The ONNX Model Zoo is a collection of pre-trained, state-of-the-art models in the ONNX format contributed by community members. SINGA currently supports several CV and NLP models from the zoo, and more models will be supported soon.
Image Classification
This collection of models takes images as input and classifies the major objects in the images into 1000 object categories such as keyboard, mouse, pencil, and many animals.
Model Class | Reference | Description | Link |
---|---|---|---|
MobileNet | Sandler et al. | Light-weight deep neural network best suited for mobile and embedded vision applications. Top-5 error from paper - ~10% | |
ResNet18 | He et al. | A CNN model (up to 152 layers). Uses shortcut connections to achieve higher accuracy when classifying images. Top-5 error from paper - ~3.6% | |
VGG16 | Simonyan et al. | Deep CNN model (up to 19 layers). Similar to AlexNet but uses multiple smaller kernel-sized filters that provide more accuracy when classifying images. Top-5 error from paper - ~8% | |
ShuffleNet_V2 | Ma et al. | Extremely computation-efficient CNN model designed specifically for mobile devices. This network architecture design considers direct metrics such as speed instead of indirect metrics like FLOPs. Top-1 error from paper - ~30.6% | |
We also provide some re-training examples using VGG and ResNet; please check examples/onnx/training.
Object Detection
Object detection models detect the presence of multiple objects in an image and segment out areas of the image where the objects are detected.
Model Class | Reference | Description | Link |
---|---|---|---|
Tiny YOLOv2 | Redmon et al. | A real-time CNN for object detection that detects 20 different classes. A smaller version of the more complex full YOLOv2 network. | |
Face Analysis
Face detection models identify and/or recognize human faces and emotions in given images.
Model Class | Reference | Description | Link |
---|---|---|---|
ArcFace | Deng et al. | A CNN based model for face recognition which learns discriminative features of faces and produces embeddings for input face images. | |
Emotion FerPlus | Barsoum et al. | Deep CNN for emotion recognition trained on images of faces. | |
Machine Comprehension
This subset of natural language processing models answers questions about a given context paragraph.
Model Class | Reference | Description | Link |
---|---|---|---|
BERT-Squad | Devlin et al. | This model answers questions based on the context of the given input paragraph. | |
RoBERTa | Liu et al. | A large transformer-based model that predicts sentiment based on given input text. | |
GPT-2 | Radford et al. | A large transformer-based language model that, given a sequence of words within some text, predicts the next word. | |
Supported operators
The following operators are supported:
- Acos
- Acosh
- Add
- And
- Asin
- Asinh
- Atan
- Atanh
- AveragePool
- BatchNormalization
- Cast
- Ceil
- Clip
- Concat
- ConstantOfShape
- Conv
- Cos
- Cosh
- Div
- Dropout
- Elu
- Equal
- Erf
- Expand
- Flatten
- Gather
- Gemm
- GlobalAveragePool
- Greater
- HardSigmoid
- Identity
- LeakyRelu
- Less
- Log
- MatMul
- Max
- MaxPool
- Mean
- Min
- Mul
- Neg
- NonZero
- Not
- OneHot
- Or
- Pad
- Pow
- PRelu
- Reciprocal
- ReduceMean
- ReduceSum
- Relu
- Reshape
- ScatterElements
- Selu
- Shape
- Sigmoid
- Sign
- Sin
- Sinh
- Slice
- Softmax
- Softplus
- Softsign
- Split
- Sqrt
- Squeeze
- Sub
- Sum
- Tan
- Tanh
- Tile
- Transpose
- Unsqueeze
- Upsample
- Where
- Xor
Special comments for ONNX backend
- Conv, MaxPool and AveragePool: The input must have a 1d (N*C*H) or 2d (N*C*H*W) shape, and dilation must be 1.
- BatchNormalization: epsilon is 1e-05 and cannot be changed.
- Cast: Only float32 and int32 are supported; other types are cast to these two types.
- Squeeze and Unsqueeze: If you encounter errors when you Squeeze or Unsqueeze between Tensor and Scalar, please report them to us.
- Empty tensor: An empty tensor is illegal in SINGA.
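The first restriction above can be checked before a model is handed to the backend. The following is a minimal sketch of such a check; check_conv_like is a hypothetical helper written for this page, not a SINGA function.

```python
def check_conv_like(input_shape, dilations):
    """Validate the input of a Conv, MaxPool or AveragePool node.

    Accepts a 1d input of shape (N, C, H) or a 2d input of shape
    (N, C, H, W) and requires every dilation to be 1. Returns the
    number of spatial dimensions.
    """
    rank = len(input_shape)
    if rank not in (3, 4):
        raise ValueError(
            f"expected an (N, C, H) or (N, C, H, W) input, got rank {rank}")
    if any(d != 1 for d in dilations):
        raise ValueError(f"dilation must be 1, got {tuple(dilations)}")
    return rank - 2


print(check_conv_like((8, 3, 32, 32), (1, 1)))  # 2d input
print(check_conv_like((8, 3, 100), (1,)))       # 1d input
```

Running a check like this on each node gives an early, readable error for a model that falls outside the supported cases.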
Implementation
The code of SINGA ONNX is located at python/singa/sonnx.py. There are four main classes: SingaFrontend, SingaBackend, SingaRep and SONNXModel.
SingaFrontend translates a SINGA model into an ONNX model. SingaBackend translates an ONNX model into a SingaRep object, which stores all SINGA operators and tensors (in this document, "tensor" means a SINGA Tensor). A SingaRep object can be run like a SINGA model. SONNXModel inherits from model.Model, which defines a unified API for SINGA.
SingaFrontend
The entry function of SingaFrontend is singa_to_onnx_model, which is also available as to_onnx. singa_to_onnx_model creates the ONNX model and also creates an ONNX graph by calling singa_to_onnx_graph.
singa_to_onnx_graph accepts the outputs of the model and recursively traverses the SINGA model's graph from the outputs to collect all operators into a queue. The input tensors and the intermediate tensors (i.e., the trainable weights) of the SINGA model are collected at the same time. The inputs are stored in onnx_model.graph.input, the outputs in onnx_model.graph.output, and the trainable weights in onnx_model.graph.initializer.
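The traversal can be pictured with a small, self-contained sketch. The Tensor and Op classes below are toy stand-ins invented for this illustration (they are not SINGA's classes), and the traversal is one plausible way to walk a graph from its outputs, not a copy of the code in sonnx.py.

```python
class Tensor:
    """Toy tensor: remembers which operator created it, if any."""

    def __init__(self, name, creator=None, is_weight=False):
        self.name = name
        self.creator = creator
        self.is_weight = is_weight


class Op:
    """Toy operator with a single output tensor."""

    def __init__(self, name, inputs):
        self.name = name
        self.inputs = inputs
        self.output = Tensor(name + "_out", creator=self)


def collect(outputs):
    """Walk back from the outputs; return (operator queue, inputs, weights)."""
    ops, inputs, weights, seen = [], [], [], set()

    def visit(t):
        if id(t) in seen:
            return
        seen.add(id(t))
        if t.creator is None:
            # Leaf tensor: either a model input or a trainable weight.
            (weights if t.is_weight else inputs).append(t.name)
            return
        for src in t.creator.inputs:
            visit(src)
        # Appending after the inputs are visited yields execution order.
        ops.append(t.creator.name)

    for t in outputs:
        visit(t)
    return ops, inputs, weights


# y = relu(x @ W + b)
x = Tensor("x")
W = Tensor("W", is_weight=True)
b = Tensor("b", is_weight=True)
matmul = Op("matmul", [x, W])
add = Op("add", [matmul.output, b])
relu = Op("relu", [add.output])

ops, inputs, weights = collect([relu.output])
print(ops)      # ['matmul', 'add', 'relu']
print(inputs)   # ['x']
print(weights)  # ['W', 'b']
```

In this sketch the three returned lists play the roles of the operator queue, graph.input and graph.initializer respectively.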
Then the SINGA operators in the queue are translated into ONNX operators one by one. _rename_operators defines the mapping of operator names between SINGA and ONNX. _special_operators defines which function is used to translate a given operator.
In addition, some operators are defined differently in SINGA and in ONNX; that is, ONNX regards some attributes of SINGA operators as inputs. _unhandled_operators defines which function handles each of these special operators.
Since the bool type is regarded as int32 in SINGA, _bool_operators defines the operators whose type must be changed to bool.
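The table-driven translation described above can be sketched as follows. The table contents, the dictionary layout of an operator, and the function names are all hypothetical; they only illustrate the idea of a name mapping plus per-operator handlers. The Reshape case reflects the fact that ONNX Reshape takes the target shape as a second input rather than as an attribute.

```python
# Hypothetical name mapping: SINGA operator name -> ONNX operator type.
RENAME = {"ReLU": "Relu", "Matmul": "MatMul"}


def translate_default(op):
    """Copy the operator over, renaming it when the two names differ."""
    return {
        "op_type": RENAME.get(op["type"], op["type"]),
        "inputs": list(op["inputs"]),
        "attrs": dict(op["attrs"]),
        "initializers": {},
    }


def translate_reshape(op):
    """Move the 'shape' attribute into an extra input tensor."""
    node = translate_default(op)
    shape_name = op["name"] + "_shape"
    node["initializers"][shape_name] = node["attrs"].pop("shape")
    node["inputs"].append(shape_name)
    return node


# Hypothetical dispatch table: operators needing their own translation.
SPECIAL = {"Reshape": translate_reshape}


def translate(op):
    return SPECIAL.get(op["type"], translate_default)(op)


queue = [
    {"name": "relu0", "type": "ReLU", "inputs": ["x"], "attrs": {}},
    {"name": "reshape0", "type": "Reshape", "inputs": ["relu0_out"],
     "attrs": {"shape": [1, -1]}},
]
nodes = [translate(op) for op in queue]
for node in nodes:
    print(node["op_type"], node["inputs"], node["attrs"])
```

Keeping the special cases in lookup tables means that supporting a new operator usually requires only a new table entry and, if needed, one small handler function.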
SingaBackend
The entry function of SingaBackend is prepare, which checks the version of the ONNX model and then calls _onnx_model_to_singa_ops.
The purpose of _onnx_model_to_singa_ops is to get the SINGA tensors and operators. The tensors are stored in a dictionary keyed by their names in ONNX, and the operators are stored in a queue in the form of namedtuple('SingaOps', ['node', 'operator']). For each operator, node is an instance of OnnxNode, which is defined to store some basic information about an ONNX node, and operator is the SINGA operator's forward function.
_onnx_model_to_singa_ops has four steps. First, it calls _parse_graph_params to get all tensors, which are stored as params. Then it calls _parse_graph_inputs_outputs to get all input and output information, which is stored as inputs and outputs. Finally, it iterates over all nodes in the ONNX graph, parses each of them with _onnx_node_to_singa_op into SINGA operators or layers, and stores them as layers. Some weights are stored within an ONNX node called Constant; SONNX handles them with _onnx_constant_to_np and stores them in params.
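The steps above can be sketched on a toy graph. Here the ONNX graph is replaced by a plain dictionary and the SINGA operators by Python lambdas, so every name except the SingaOps namedtuple is an assumption made for illustration.

```python
from collections import namedtuple

SingaOps = namedtuple("SingaOps", ["node", "operator"])

# Hypothetical operator table: ONNX op type -> forward function.
OPERATORS = {
    "Add": lambda a, b: a + b,
    "Mul": lambda a, b: a * b,
    "Relu": lambda a: max(a, 0.0),
}


def onnx_graph_to_ops(graph):
    """Return (params, inputs, outputs, layers) for a toy graph."""
    # Stored weights, keyed by their name in the graph.
    params = dict(graph["initializer"])
    # Graph inputs that are not stored weights are real model inputs.
    inputs = [name for name in graph["input"] if name not in params]
    outputs = list(graph["output"])
    layers = []
    for node in graph["node"]:
        if node["op_type"] == "Constant":
            # A Constant node carries a value, so treat it as a param.
            params[node["outputs"][0]] = node["value"]
            continue
        layers.append(SingaOps(node=node,
                               operator=OPERATORS[node["op_type"]]))
    return params, inputs, outputs, layers


# y = relu(x * w + c), where c is stored in a Constant node
graph = {
    "initializer": {"w": 3.0},
    "input": ["x", "w"],
    "output": ["y"],
    "node": [
        {"op_type": "Constant", "inputs": [], "outputs": ["c"], "value": 1.0},
        {"op_type": "Mul", "inputs": ["x", "w"], "outputs": ["t0"]},
        {"op_type": "Add", "inputs": ["t0", "c"], "outputs": ["t1"]},
        {"op_type": "Relu", "inputs": ["t1"], "outputs": ["y"]},
    ],
}

params, inputs, outputs, layers = onnx_graph_to_ops(graph)
print(params)
print(inputs, outputs)
print([layer.node["op_type"] for layer in layers])
```

The four returned values correspond to the params, inputs, outputs and layers mentioned in the text.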
Finally, this class returns a SingaRep object that stores the above params, inputs, outputs and layers.
SingaRep
SingaRep stores all SINGA tensors and operators. run accepts the inputs of the model and runs the SINGA operators one by one, following the order of the operators' queue. The user can use last_layers to cut the model off before its last few layers.
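A minimal sketch of this execution loop is shown below. It follows the last_layers convention stated in this document (-1 keeps every layer, -2 stops after the second-to-last layer), but the function and the layer tuples are hypothetical and do not reproduce the actual SingaRep.run code.

```python
def run(layers, params, inputs, last_layers=-1):
    """Execute toy layers in queue order and return the last value computed.

    Each layer is a tuple (name, function, input names, output name).
    """
    tensors = dict(params)
    tensors.update(inputs)
    # -1 keeps every layer, -2 drops the final one, and so on.
    end = len(layers) + last_layers + 1
    result = None
    for name, fn, srcs, dst in layers[:end]:
        tensors[dst] = fn(*(tensors[s] for s in srcs))
        result = tensors[dst]
    return result


layers = [
    ("scale", lambda a, b: a * b, ["x", "w"], "t0"),
    ("shift", lambda a, b: a + b, ["t0", "b"], "t1"),
    ("relu", lambda a: max(a, 0.0), ["t1"], "y"),
]
params = {"w": 3.0, "b": 1.0}

print(run(layers, params, {"x": -2.0}))                  # full model
print(run(layers, params, {"x": -2.0}, last_layers=-2))  # stop before relu
print(run(layers, params, {"x": -2.0}, last_layers=-3))  # stop after scale
```

Stopping early like this is what makes it possible to reuse the front of a pre-trained model and attach new layers to the intermediate output.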
SONNXModel
SONNXModel inherits from model.Model and implements the forward method to provide a unified API with other SINGA models.