Skip to content

Commit

Permalink
Fix a small issue that on MPS we don't need to reshape the output.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jun 19, 2024
1 parent a33f509 commit a08e4b4
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions lib/nnc/mps/ccv_nnc_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ MPSGraphTensorNamedDataLayout ccv_nnc_mps_tensor_data_layout(const int format)
for (i = 0; i < nd; i++)
[shape addObject:@(dim[i])];
}
MPSGraphTensorData* data = [[MPSGraphTensorData alloc] initWithMTLBuffer:(id<MTLBuffer>)buffer shape:shape dataType:ccv_nnc_mps_datatype(tensor_view->info.datatype)];
MPSGraphTensorData* data = [[MPSGraphTensorData alloc] initWithMTLBuffer:(id<MTLBuffer>)buffer shape:shape dataType:ccv_nnc_mps_datatype(datatype)];
[shape release];
return [data autorelease];
}
Expand All @@ -1164,6 +1164,22 @@ MPSGraphTensorNamedDataLayout ccv_nnc_mps_tensor_data_layout(const int format)
return ccv_nnc_mps_graph_tensor_data_with_buffer(tensor_view, dim, stride, mpgetbuffer((ccv_nnc_tensor_t*)tensor_view), mpgetoffset((ccv_nnc_tensor_t*)tensor_view));
}

static MPSGraphTensorData* ccv_nnc_mps_graph_output_tensor_data(const ccv_nnc_tensor_view_t* tensor_view, const int dim[CCV_NNC_MAX_DIM_ALLOC], const int stride[CCV_NNC_MAX_DIM_ALLOC])
{
const int nd = ccv_nnc_tensor_nd(dim);
int i;
NSMutableArray<NSNumber*>* shape = [NSMutableArray new];
const int datatype = CCV_GET_DATA_TYPE(tensor_view->info.datatype) == CCV_QX ? ((tensor_view->info.datatype & 0xff) << 12) : tensor_view->info.datatype;
assert(CCV_IS_TENSOR_CONTIGUOUS(tensor_view));
assert(mpgetoffset((ccv_nnc_tensor_t*)tensor_view) == 0);
for (i = 0; i < nd; i++)
[shape addObject:@(dim[i])];
void* buffer = mpgetbuffer((ccv_nnc_tensor_t*)tensor_view);
MPSGraphTensorData* data = [[MPSGraphTensorData alloc] initWithMTLBuffer:(id<MTLBuffer>)buffer shape:shape dataType:ccv_nnc_mps_datatype(datatype)];
[shape release];
return [data autorelease];
}

MPSGraphTensorData* ccv_nnc_mps_graph_constant_data(const float val, const int datatype)
{
id<MTLBuffer> buffer;
Expand Down Expand Up @@ -1248,7 +1264,7 @@ void ccv_nnc_mps_graph_executable_result(MPSGraphExecutable* executable, MPSComm
{
NSMutableArray<MPSGraphTensorData*>* results = [NSMutableArray new];
for (i = 0; i < size; i++)
[results addObject:ccv_nnc_mps_graph_tensor_data(data[i], dim[i], stride[i])];
[results addObject:ccv_nnc_mps_graph_output_tensor_data(data[i], dim[i], stride[i])];
[executable encodeToCommandBuffer:command_buffer inputsArray:inputsArray resultsArray:results executionDescriptor:nil];
[results release];
return;
Expand Down

0 comments on commit a08e4b4

Please sign in to comment.