# coding=utf-8
__author__ = "Dimitris Karkalousos"
# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/classes/exportable.py
from abc import ABC
from typing import List, Union
import torch
from pytorch_lightning.core.module import _jit_is_scripting
from torch.onnx import TrainingMode
from atommic.core.classes.common import typecheck
from atommic.core.utils.neural_type_utils import get_dynamic_axes, get_io_names
from atommic.utils import logging
from atommic.utils.export_utils import (
ExportFormat,
augment_filename,
get_export_format,
parse_input_example,
replace_for_export,
verify_runtime,
verify_torchscript,
wrap_forward_method,
)
# Public API of this module. ExportFormat is re-exported from atommic.utils.export_utils for convenience.
__all__ = ["ExportFormat", "Exportable"]
class Exportable(ABC):
    """This Interface should be implemented by particular classes derived from atommic.core.ModelPT. It gives these
    entities ability to be exported for deployment to formats such as ONNX.

    By default a single subnet (``"self"``) is exported, and both ONNX and TorchScript are listed as supported formats.
    The concrete format is inferred from the output file name via ``get_export_format``.
    """

    @property
    def input_module(self):
        """Implement this method to return the input module. Defaults to ``self``."""
        return self

    @property
    def output_module(self):
        """Implement this method to return the output module. Defaults to ``self``."""
        return self

    def export(
        self,
        output: str,
        input_example=None,
        verbose=False,
        do_constant_folding=True,
        onnx_opset_version=None,
        check_trace: Union[bool, List[torch.Tensor]] = False,
        dynamic_axes=None,
        check_tolerance=0.01,
        export_modules_as_functions: bool = False,
        keep_initializers_as_inputs=None,
    ):
        """Export every subnet listed by ``list_export_subnets`` to a file.

        Parameters
        ----------
        output : str
            The output file path. For each subnet it is augmented with the subnet name via ``augment_filename``.
        input_example : dict
            A dictionary of input names and values.
        verbose : bool
            If True, print out the export process.
        do_constant_folding : bool
            If True, do constant folding.
        onnx_opset_version : int
            The ONNX opset version to use.
        check_trace : bool or list of torch.Tensor
            If True, check the trace of the exported model.
        dynamic_axes : dict
            A dictionary of input names and dynamic axes.
        check_tolerance : float
            The tolerance for the check_trace.
        export_modules_as_functions : bool
            If True, export modules as functions.
        keep_initializers_as_inputs : bool
            If True, keep initializers as inputs.

        Returns
        -------
        tuple of (list, list)
            The output file paths and the description strings returned by ``_export`` for each subnet, in order.
        """
        all_out = []
        all_descr = []
        for subnet_name in self.list_export_subnets():
            model = self.get_export_subnet(subnet_name)
            out_name = augment_filename(output, subnet_name)
            out, descr, out_example = model._export(  # pylint: disable=protected-access
                out_name,
                input_example=input_example,
                verbose=verbose,
                do_constant_folding=do_constant_folding,
                onnx_opset_version=onnx_opset_version,
                check_trace=check_trace,
                dynamic_axes=dynamic_axes,
                check_tolerance=check_tolerance,
                export_modules_as_functions=export_modules_as_functions,
                keep_initializers_as_inputs=keep_initializers_as_inputs,
            )
            # Propagate this subnet's output example as the next subnet's input example. This only happens when the
            # caller supplied an input example (default scenario, may need to be overridden).
            if input_example is not None:
                input_example = out_example
            all_out.append(out)
            all_descr.append(descr)
            logging.info(f"Successfully exported {model.__class__.__name__} to {out_name}")
        return all_out, all_descr

    @staticmethod
    def _as_output_tuple(outputs):
        """Normalize the value returned by ``forward`` into a tuple of outputs.

        A single ``torch.Tensor`` is wrapped in a one-element tuple. Calling ``tuple()`` on it directly would split it
        along its first dimension. Any other iterable (tuple, list, ...) is converted with ``tuple()``.
        """
        if isinstance(outputs, torch.Tensor):
            return (outputs,)
        return tuple(outputs)

    def _export(
        self,
        output: str,
        input_example=None,
        verbose=False,
        do_constant_folding=True,
        onnx_opset_version=None,
        training=TrainingMode.EVAL,  # pylint: disable=unused-argument
        check_trace: Union[bool, List[torch.Tensor]] = False,
        dynamic_axes=None,
        check_tolerance=0.01,
        export_modules_as_functions: bool = False,
        keep_initializers_as_inputs=None,
    ):
        """Helper to export the module to a file.

        Parameters
        ----------
        output : str
            The output file path.
        input_example : dict
            A dictionary of input names and values. If None, ``self.input_module.input_example()`` is used.
        verbose : bool
            If True, print out the export process.
        do_constant_folding : bool
            If True, do constant folding.
        onnx_opset_version : int
            The ONNX opset version to use. Defaults to 16 when None.
        training : TrainingMode
            Training mode for the export. Not used directly here; it is only forwarded to ``_prepare_for_export``.
        check_trace : bool or list of torch.Tensor
            If True, check the trace of the exported model.
        dynamic_axes : dict
            A dictionary of input names and dynamic axes.
        check_tolerance : float
            The tolerance for the check_trace.
        export_modules_as_functions : bool
            If True, export modules as functions.
        keep_initializers_as_inputs : bool
            If True, keep initializers as inputs.

        Returns
        -------
        tuple
            ``(output, output_descr, output_example)``: the output path, a description string, and the outputs of
            ``forward`` on the input example, as a tuple.

        Raises
        ------
        ValueError
            If the export format inferred from ``output`` is neither ONNX nor TorchScript.
        """
        # Capture the call arguments first, before any other local is bound, so they can be forwarded to
        # ``_prepare_for_export`` of the enclosed Exportables.
        my_args = locals().copy()
        my_args.pop("self")

        self.eval()  # type: ignore
        for param in self.parameters():  # type: ignore
            param.requires_grad = False

        exportables = [m for m in self.modules() if isinstance(m, Exportable)]  # type: ignore

        qual_name = f"{self.__module__}.{self.__class__.__qualname__}"
        exp_format = get_export_format(output)
        output_descr = f"{qual_name} exported to {exp_format}"

        # Pytorch's default opset version is too low, using reasonable latest one
        if onnx_opset_version is None:
            onnx_opset_version = 16

        # Bound before ``try``. Otherwise, if disabling type checks or wrapping the forward method fails, the
        # ``finally`` clause would raise UnboundLocalError and mask the original error.
        forward_method, old_forward_method = None, None
        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            forward_method, old_forward_method = wrap_forward_method(self)

            with torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting():
                if input_example is None:
                    input_example = self.input_module.input_example()

                # Remove i/o examples from args we propagate to enclosed Exportables
                my_args.pop("output")
                my_args.pop("input_example")

                # Run (possibly overridden) prepare methods before calling forward()
                for ex in exportables:
                    ex._prepare_for_export(**my_args, noreplace=True)  # pylint: disable=protected-access
                self._prepare_for_export(output=output, input_example=input_example, **my_args)

                input_list, input_dict = parse_input_example(input_example)
                input_names = self.input_names
                output_names = self.output_names
                output_example = self._as_output_tuple(self.forward(*input_list, **input_dict))  # type: ignore

                if check_trace:
                    if isinstance(check_trace, bool):
                        check_trace_input = [input_example]
                    else:
                        check_trace_input = check_trace

                jitted_model = self
                if exp_format == ExportFormat.TORCHSCRIPT:
                    jitted_model = torch.jit.trace_module(
                        self,
                        {"forward": tuple(input_list) + tuple(input_dict.values())},
                        strict=True,
                        check_trace=check_trace,
                        check_tolerance=check_tolerance,
                    )
                    jitted_model = torch.jit.freeze(jitted_model)
                    if verbose:
                        logging.info(f"JIT code:\n{jitted_model.code}")  # type: ignore
                    jitted_model.save(output)  # type: ignore
                    jitted_model = torch.jit.load(output)

                    if check_trace:
                        verify_torchscript(jitted_model, output, check_trace_input, check_tolerance)
                elif exp_format == ExportFormat.ONNX:
                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                    if dynamic_axes is None:
                        dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names)
                        dynamic_axes.update(
                            get_dynamic_axes(self.output_module.output_types_for_export, output_names)
                        )
                    torch.onnx.export(
                        jitted_model,
                        input_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        do_constant_folding=do_constant_folding,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        export_modules_as_functions=export_modules_as_functions,
                    )

                    if check_trace:
                        verify_runtime(self, output, check_trace_input, input_names, check_tolerance=check_tolerance)
                else:
                    raise ValueError(f"Encountered unknown export format {exp_format}.")
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                # Restore the original forward that wrap_forward_method replaced.
                type(self).forward = old_forward_method  # type: ignore
        self._export_teardown()
        return (output, output_descr, output_example)

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        return set()

    @property
    def disabled_deployment_output_names(self):
        """Implement this method to return a set of output names disabled for export"""
        return set()

    @property
    def supported_export_formats(self):
        """Implement this method to return a set of export formats supported. Default is all types."""
        return {ExportFormat.ONNX, ExportFormat.TORCHSCRIPT}

    def _prepare_for_export(self, **kwargs):
        """Override this method to prepare module for export. This is in-place operation. Base version does common
        necessary module replacements (Apex etc)

        Module replacement via ``replace_for_export`` is skipped when a ``noreplace`` keyword is passed (with any
        value).
        """
        if "noreplace" not in kwargs:
            replace_for_export(self)

    def _export_teardown(self):
        """Override this method for any teardown code after export."""

    @property
    def input_names(self):
        """Implement this method to return a list of input names"""
        return get_io_names(self.input_module.input_types_for_export, self.disabled_deployment_input_names)

    @property
    def output_names(self):
        """Implement this method to return a list of output names"""
        return get_io_names(self.output_module.output_types_for_export, self.disabled_deployment_output_names)

    @property
    def input_types_for_export(self):
        """Implement this method to return a list of input types. Defaults to ``self.input_types``."""
        return self.input_types

    @property
    def output_types_for_export(self):
        """Implement this method to return a list of output types. Defaults to ``self.output_types``."""
        return self.output_types

    def get_export_subnet(self, subnet=None):
        """Returns Exportable subnet model/module to export.

        ``None`` or ``"self"`` returns this object; any other name is looked up as an attribute.
        """
        return self if subnet is None or subnet == "self" else getattr(self, subnet)

    @staticmethod
    def list_export_subnets():
        """Returns default set of subnet names exported for this model. First goes the one receiving input
        (input_example).
        """
        return ["self"]

    def get_export_config(self):
        """Returns export_config dictionary (an empty dict if none has been set)."""
        return getattr(self, 'export_config', {})

    def set_export_config(self, args):
        """Sets/updates export_config dictionary by merging ``args`` into the current config."""
        ex_config = self.get_export_config()
        ex_config.update(args)
        self.export_config = ex_config