Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Insert parameters from server properly in iOS example #2288

Merged
merged 9 commits into from
Sep 6, 2023
Merged
38 changes: 18 additions & 20 deletions examples/ios/FLiOS.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
19764EED297D3ABE009F3E5D /* BenchmarkModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19764EEC297D3ABE009F3E5D /* BenchmarkModel.swift */; };
19764EEF297D3AC9009F3E5D /* BenchmarkSuite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19764EEE297D3AC9009F3E5D /* BenchmarkSuite.swift */; };
21075DC929788ED400311461 /* FLiOSModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21075DC829788ED400311461 /* FLiOSModel.swift */; };
216A90382A0CBE5400E0B532 /* flwr in Frameworks */ = {isa = PBXBuildFile; productRef = 216A90372A0CBE5400E0B532 /* flwr */; };
210BBB812AA5D10C00B283AC /* MLFlwrClientModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */; };
210BBB822AA5D10C00B283AC /* MLFlwrClient.swift in Sources */ = {isa = PBXBuildFile; fileRef = 210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */; };
21A77AC32865B22B0062EBD8 /* Array+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A6F2865B22A0062EBD8 /* Array+Extensions.swift */; };
21A77AC42865B22B0062EBD8 /* Math.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A702865B22A0062EBD8 /* Math.swift */; };
21A77AC52865B22B0062EBD8 /* UIImage+CVPixelBuffer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A712865B22A0062EBD8 /* UIImage+CVPixelBuffer.swift */; };
Expand Down Expand Up @@ -127,6 +128,9 @@
19764EEC297D3ABE009F3E5D /* BenchmarkModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BenchmarkModel.swift; sourceTree = "<group>"; };
19764EEE297D3AC9009F3E5D /* BenchmarkSuite.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BenchmarkSuite.swift; sourceTree = "<group>"; };
21075DC829788ED400311461 /* FLiOSModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FLiOSModel.swift; sourceTree = "<group>"; };
210BBB7E2AA5D04500B283AC /* flwr */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = flwr; path = ../../src/swift/flwr; sourceTree = "<group>"; };
210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLFlwrClient.swift; sourceTree = "<group>"; };
210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLFlwrClientModel.swift; sourceTree = "<group>"; };
21A77A6F2865B22A0062EBD8 /* Array+Extensions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Array+Extensions.swift"; sourceTree = "<group>"; };
21A77A702865B22A0062EBD8 /* Math.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Math.swift; sourceTree = "<group>"; };
21A77A712865B22A0062EBD8 /* UIImage+CVPixelBuffer.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "UIImage+CVPixelBuffer.swift"; sourceTree = "<group>"; };
Expand Down Expand Up @@ -229,7 +233,6 @@
buildActionMask = 2147483647;
files = (
50B5A2C6294C848000B0ABC6 /* flwr in Frameworks */,
216A90382A0CBE5400E0B532 /* flwr in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand All @@ -246,6 +249,14 @@
path = Benchmark;
sourceTree = "<group>";
};
210BBB7D2AA5D04500B283AC /* Packages */ = {
isa = PBXGroup;
children = (
210BBB7E2AA5D04500B283AC /* flwr */,
);
name = Packages;
sourceTree = "<group>";
};
21A77A6E2865B22A0062EBD8 /* CoreMLHelpers */ = {
isa = PBXGroup;
children = (
Expand Down Expand Up @@ -346,6 +357,8 @@
21A77B162865B23D0062EBD8 /* CoreMLClient */ = {
isa = PBXGroup;
children = (
210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */,
210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */,
50CDF0E0299D807300AEBD55 /* DataLoader.swift */,
21A77B202865B3B10062EBD8 /* MLModelInspect.swift */,
50EB97C929AC9C94000AC771 /* Constants.swift */,
Expand All @@ -363,6 +376,7 @@
21B540642865B14000185DEE = {
isa = PBXGroup;
children = (
210BBB7D2AA5D04500B283AC /* Packages */,
21B5406F2865B14000185DEE /* FLiOS */,
21B5406E2865B14000185DEE /* Products */,
21A77B2E2865B5420062EBD8 /* Frameworks */,
Expand Down Expand Up @@ -440,7 +454,6 @@
name = FLiOS;
packageProductDependencies = (
50B5A2C5294C848000B0ABC6 /* flwr */,
216A90372A0CBE5400E0B532 /* flwr */,
);
productName = FlowerCoreML;
productReference = 21B5406D2865B14000185DEE /* FLiOS.app */;
Expand Down Expand Up @@ -471,7 +484,6 @@
);
mainGroup = 21B540642865B14000185DEE;
packageReferences = (
216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */,
);
productRefGroup = 21B5406E2865B14000185DEE /* Products */;
projectDirPath = "";
Expand Down Expand Up @@ -522,6 +534,7 @@
21A77ADC2865B22B0062EBD8 /* BayesianProbitRegressor.proto in Sources */,
21A77ADE2865B22B0062EBD8 /* Gazetteer.proto in Sources */,
21A77ADB2865B22B0062EBD8 /* SoundAnalysisPreprocessing.pb.swift in Sources */,
210BBB822AA5D10C00B283AC /* MLFlwrClient.swift in Sources */,
21A77AC32865B22B0062EBD8 /* Array+Extensions.swift in Sources */,
21A77B152865B22B0062EBD8 /* Scaler.pb.swift in Sources */,
21A77B092865B22B0062EBD8 /* Parameters.pb.swift in Sources */,
Expand Down Expand Up @@ -585,6 +598,7 @@
21A77AF72865B22B0062EBD8 /* ItemSimilarityRecommender.pb.swift in Sources */,
21A77AD92865B22B0062EBD8 /* FeatureVectorizer.proto in Sources */,
21A77B012865B22B0062EBD8 /* TextClassifier.proto in Sources */,
210BBB812AA5D10C00B283AC /* MLFlwrClientModel.swift in Sources */,
21A77B0F2865B22B0062EBD8 /* LinkedModel.proto in Sources */,
21A77ADF2865B22B0062EBD8 /* MIL.pb.swift in Sources */,
21A77AE52865B22B0062EBD8 /* GLMRegressor.pb.swift in Sources */,
Expand Down Expand Up @@ -804,23 +818,7 @@
};
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https:/adap/flower-swift.git";
requirement = {
branch = main;
kind = branch;
};
};
/* End XCRemoteSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
216A90372A0CBE5400E0B532 /* flwr */ = {
isa = XCSwiftPackageProductDependency;
package = 216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */;
productName = flwr;
};
21A77B312865B55D0062EBD8 /* Flower */ = {
isa = XCSwiftPackageProductDependency;
productName = Flower;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import NIOCore
import NIOPosix
import CoreML
import os
import flwr

/// Set of CoreML machine learning task options.
public enum MLTask {
Expand Down Expand Up @@ -55,6 +56,7 @@ public class MLFlwrClient: Client {

private var compiledModelUrl: URL
private var tempModelUrl: URL
private var modelUrl: URL

private let log = Logger(subsystem: Bundle.main.bundleIdentifier ?? "flwr.Flower",
category: String(describing: MLFlwrClient.self))
Expand All @@ -65,8 +67,9 @@ public class MLFlwrClient: Client {
/// - layerWrappers: A MLLayerWrapper struct that contains layer information.
/// - dataLoader: A MLDataLoader struct that contains train- and testdata batches.
/// - compiledModelUrl: An URL specifying the location or path of the compiled model.
public init(layerWrappers: [MLLayerWrapper], dataLoader: MLDataLoader, compiledModelUrl: URL) {
self.parameters = MLParameter(layerWrappers: layerWrappers)
public init(layerWrappers: [MLLayerWrapper], dataLoader: MLDataLoader, compiledModelUrl: URL, modelUrl: URL) {
self.modelUrl = modelUrl
self.parameters = MLParameter(layerWrappers: layerWrappers, modelUrl: modelUrl, compiledModelUrl: compiledModelUrl)
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
self.dataLoader = dataLoader
self.compiledModelUrl = compiledModelUrl
Expand All @@ -85,6 +88,7 @@ public class MLFlwrClient: Client {
///
/// - Returns: Parameters from the local model
public func getParameters() -> GetParametersRes {
parameters.initializeParameters()
let parameters = parameters.weightsToParameters()
let status = Status(code: .ok, message: String())

Expand Down Expand Up @@ -128,7 +132,7 @@ public class MLFlwrClient: Client {
}

let progressHandler: (MLUpdateContext) -> Void = { contextProgress in
let loss = String(format: "%.4f", contextProgress.metrics[.lossValue] as! Double)
let loss = String(format: "%.20f", contextProgress.metrics[.lossValue] as! Float)
switch task {
case .train:
self.log.info("Epoch \(contextProgress.metrics[.epochIndex] as! Int + 1) finished with loss \(loss)")
Expand All @@ -143,8 +147,8 @@ public class MLFlwrClient: Client {
self.saveModel(finalContext)
}

let loss = finalContext.metrics[.lossValue] as! Double
let result = MLResult(loss: loss, numSamples: dataset.count, accuracy: (1.0 - loss) * 100)
let loss = String(format: "%.20f", finalContext.metrics[.lossValue] as! Double)
let result = MLResult(loss: Double(loss)!, numSamples: dataset.count, accuracy: (1.0 - Double(loss)!) * 100)
promise?.succeed(result)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import Foundation
import CoreML
import os
import flwr

typealias Model = CoreML_Specification_Model
typealias NeuralNetwork = CoreML_Specification_NeuralNetwork
typealias NeuralNetworkLayer = CoreML_Specification_NeuralNetworkLayer

/// Container for train and test dataset.
public struct MLDataLoader {
Expand Down Expand Up @@ -47,15 +52,22 @@ public class MLParameter {
private var parameterConverter = ParameterConverter.shared

var layerWrappers: [MLLayerWrapper]
var model: Model?
let modelUrl: URL
let compiledModelUrl: URL

private let log = Logger(subsystem: Bundle.main.bundleIdentifier ?? "flwr.Flower",
category: String(describing: MLParameter.self))

/// Inits MLParameter class that contains information about the model parameters and implements routines for their update and transformation.
///
/// - Parameters:
/// - layerWrappers: Information about the layer provided with primitive data types.
public init(layerWrappers: [MLLayerWrapper]) {
public init(layerWrappers: [MLLayerWrapper], modelUrl: URL, compiledModelUrl: URL) {
self.layerWrappers = layerWrappers
self.modelUrl = modelUrl
self.compiledModelUrl = compiledModelUrl
self.model = try? Model(serializedData: try Data(contentsOf: modelUrl))
}

/// Converts the Parameters structure to MLModelConfiguration to interface with CoreML.
Expand All @@ -81,19 +93,28 @@ public class MLParameter {

layerWrappers[index].weights = weightsArray
if layerWrappers[index].isUpdatable {
if let weightsMultiArray = parameterConverter.dataToMultiArray(data: data) {
let weightsShape = weightsMultiArray.shape.map { Int16(truncating: $0) }
guard weightsShape == layerWrappers[index].shape else {
log.info("shape not the same")
guard model?.neuralNetwork.layers != nil else {
continue
}
for (indexB, neuralNetworkLayer) in model!.neuralNetwork.layers.enumerated() {
guard layerWrappers[index].name == neuralNetworkLayer.name else {
continue
}
switch neuralNetworkLayer.layer! {
case .convolution:
model!.neuralNetwork.layers[indexB].convolution.weights.floatValue = layerWrappers[index].weights
case .innerProduct:
model!.neuralNetwork.layers[indexB].innerProduct.weights.floatValue = layerWrappers[index].weights
default:
log.info("unexpected layer \(neuralNetworkLayer.name)")
continue
}
let paramKey = MLParameterKey.weights.scoped(to: layerWrappers[index].name)
config.parameters?[paramKey] = weightsMultiArray
}
}
}
}

exportModel()
return config
}

Expand All @@ -108,6 +129,47 @@ public class MLParameter {
return Parameters(tensors: dataArray, tensorType: "ndarray")
}

private func exportModel() {
let modelFileName = modelUrl.deletingPathExtension().lastPathComponent
let fileManager = FileManager.default
let tempModelUrl = appDirectory.appendingPathComponent("temp\(modelFileName).mlmodel")
try? model?.serializedData().write(to: tempModelUrl)
if let compiledTempModelUrl = try? MLModel.compileModel(at: tempModelUrl) {
_ = try? fileManager.replaceItemAt(compiledModelUrl, withItemAt: compiledTempModelUrl)
}
}

func initializeParameters() {
guard ((model?.neuralNetwork.layers) != nil) else {
return
}
for (indexA, neuralNetworkLayer) in model!.neuralNetwork.layers.enumerated() {
for (indexB, layer) in layerWrappers.enumerated() {
if layer.name != neuralNetworkLayer.name { continue }
switch neuralNetworkLayer.layer! {
case .convolution:
let convolution = neuralNetworkLayer.convolution
//shape definition = [outputChannels, kernelChannels, kernelHeight, kernelWidth]
let upperLower = Float(6.0 / Float(Int16(convolution.outputChannels) + Int16(convolution.kernelChannels) + Int16(convolution.kernelSize[0]) + Int16(convolution.kernelSize[1]))).squareRoot()
let initialise = (0..<(neuralNetworkLayer.convolution.weights.floatValue.count)).map { _ in Float.random(in: -upperLower...upperLower) }
model?.neuralNetwork.layers[indexA].convolution.weights.floatValue = initialise
layerWrappers[indexB].weights = initialise
case .innerProduct:
let innerProduct = neuralNetworkLayer.innerProduct
//shape definition = [C_out, C_in].
let upperLower = Float(6.0 / Float(Int16(innerProduct.outputChannels) + Int16(innerProduct.inputChannels))).squareRoot()
let initialise = (0..<(neuralNetworkLayer.innerProduct.weights.floatValue.count)).map { _ in Float.random(in: -upperLower...upperLower) }
model?.neuralNetwork.layers[indexA].innerProduct.weights.floatValue = initialise
layerWrappers[indexB].weights = initialise
default:
log.info("unexpected layer \(neuralNetworkLayer.name)")
continue
}
}
}
exportModel()
}

/// Updates the layers given the CoreML update context
///
/// - Parameters:
Expand Down
5 changes: 3 additions & 2 deletions examples/ios/FLiOS/FLiOSModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ public class FLiOSModel: ObservableObject {
let layerWrappers = modelInspect.getLayerWrappers()
self.mlFlwrClient = MLFlwrClient(layerWrappers: layerWrappers,
dataLoader: dataLoader,
compiledModelUrl: compiledModelUrl)
compiledModelUrl: compiledModelUrl,
modelUrl: url)
case .local:
self.localClient = LocalClient(dataLoader: dataLoader, compiledModelUrl: compiledModelUrl)
}
Expand Down Expand Up @@ -203,7 +204,7 @@ class LocalClient {

func runMLTask(statusHandler: @escaping (Constants.TaskStatus) -> Void,
numEpochs: Int,
task: flwr.MLTask
task: MLTask
) {
let dataset: MLBatchProvider
let configuration = MLModelConfiguration()
Expand Down
Loading