I have implemented my Pix2Pix GAN model in TensorRT using the ONNX format, but I do not know how to perform inference on the TensorRT model, because the input to the model is a (3, 512, 512) image and the output is also a (3, 512, 512) image. The samples do not clearly show how to feed an image into a TensorRT engine and how to read an image back out of it.
This is my code:
def get_image():
    """Load the hard-coded profile picture and prepare it for the network.

    The file is first read and resized by ``_load_and_resize``; the result
    is then handed to ``_shuffle_and_normalize`` and whatever that helper
    returns is passed back to the caller unchanged.
    """
    resized = _load_and_resize("/home/abdullah/Desktop/Profile.png")
    return _shuffle_and_normalize(resized)
def _load_and_resize(input_image_path):
    """Open an image file and return it as a 512x512 RGB float32 array.

    Args:
        input_image_path: Path of the image file to read.

    Returns:
        A C-ordered ``float32`` array of shape ``(512, 512, 3)`` (HWC).
        The values are not rescaled here.
    """
    # The context manager closes the file once the pixels have been read.
    with Image.open(input_image_path) as image_raw:
        # PNG files are frequently RGBA, palette or greyscale. Force three
        # channels so the later HWC -> CHW step always sees (H, W, 3).
        image_rgb = image_raw.convert("RGB")
        # convention (width, height) in PIL:
        image_resized = image_rgb.resize(
            (512, 512), resample=Image.BICUBIC)
    return np.array(image_resized, dtype=np.float32, order='C')
def _shuffle_and_normalize(image):
    """Scale an HWC image by 1/255 and lay it out as an NCHW batch of one.

    Args:
        image: Array-like of shape ``(H, W, C)``. Any numeric dtype is
            accepted; the caller's array is left untouched.

    Returns:
        A new C-ordered ``float32`` array of shape ``(1, C, H, W)``.
    """
    # Work on a private float32 copy: an in-place ``/=`` would overwrite the
    # caller's data and fails outright on integer arrays.
    scaled = np.array(image, dtype=np.float32)
    scaled /= 255.0
    # NOTE(review): Pix2Pix generators are often trained on inputs scaled to
    # [-1, 1] rather than [0, 1] -- confirm against the training pipeline.
    # HWC to CHW format:
    chw = np.transpose(scaled, [2, 0, 1])
    # CHW to NCHW format
    nchw = np.expand_dims(chw, axis=0)
    # Convert the image to row-major order, also known as "C order":
    return np.array(nchw, dtype=np.float32, order='C')
def _reshape_output(output):
    """Turn a flat or batched CHW network output into a 512x512 HWC image.

    Args:
        output: Array-like holding exactly ``3 * 512 * 512`` values in
            channel-first (CHW) order, e.g. a flat buffer, a
            ``(1, 3, 512, 512)`` array, or a one-element list of either.

    Returns:
        A C-ordered array of shape ``(512, 512, 3)``.
    """
    # The data is channel-first, so it has to be viewed as (C, H, W) first.
    # Reshaping straight to (H, W, C) reinterprets the memory without moving
    # the channel axis and scrambles the picture.
    chw = np.reshape(output, (3, 512, 512))
    # CHW to HWC, the layout image viewers expect.
    hwc = np.transpose(chw, (1, 2, 0))
    return np.array(hwc, order='C')
def main():
    """Run the generator engine on one image and display the result.

    Builds or loads the engine through ``get_engine``, feeds it the array
    returned by ``get_image``, reshapes the single output tensor and shows
    it with matplotlib.
    """
    onnx_file_path = 'generator.onnx'
    engine_file_path = "generator.trt"
    output_shapes = [(1, 3, 512, 512)]
    image = get_image()
    print('IMage shape is ', image.shape)
    # Do inference with TensorRT
    with get_engine(onnx_file_path, engine_file_path) as engine, \
            engine.create_execution_context() as context:
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        print('Running inference on image {}...'.format(1))
        inputs[0].host = image
        print(inputs)
        trt_outputs = common.do_inference(
            context, bindings=bindings, inputs=inputs, outputs=outputs,
            stream=stream)
    # Before doing post-processing, we need to reshape the outputs as the
    # common.do_inference will give us flat arrays.
    trt_outputs = [output.reshape(shape)
                   for output, shape in zip(trt_outputs, output_shapes)]
    # trt_outputs is a list with one entry per engine output; the generator
    # has a single output, so post-process that tensor rather than the list.
    generated = trt_outputs[0]
    image = _reshape_output(generated)
    image = np.array(image, dtype='float32')
    # NOTE(review): imshow clips float data to [0, 1]. If the generator ends
    # in tanh, its output lies in [-1, 1] and should be rescaled with
    # (image + 1) / 2 before display -- confirm against the model.
    plt.imshow(image)
    plt.show()
    # A list has no ``.shape``; report the shape of the output tensor itself.
    print('Output shape is ', generated.shape)
But the output is not what I expected. Please tell me the best way to feed an image into the TensorRT engine and how to get the proper output from it. Thanks.