Skip to content

Commit

Permalink
Insert parameters from server properly in iOS example (#2288)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: danielnugraha <[email protected]>
Co-authored-by: Taner Topal <[email protected]>
  • Loading branch information
3 people authored Sep 6, 2023
1 parent d266c73 commit e91635b
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 34 deletions.
38 changes: 18 additions & 20 deletions examples/ios/FLiOS.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
19764EED297D3ABE009F3E5D /* BenchmarkModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19764EEC297D3ABE009F3E5D /* BenchmarkModel.swift */; };
19764EEF297D3AC9009F3E5D /* BenchmarkSuite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19764EEE297D3AC9009F3E5D /* BenchmarkSuite.swift */; };
21075DC929788ED400311461 /* FLiOSModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21075DC829788ED400311461 /* FLiOSModel.swift */; };
216A90382A0CBE5400E0B532 /* flwr in Frameworks */ = {isa = PBXBuildFile; productRef = 216A90372A0CBE5400E0B532 /* flwr */; };
210BBB812AA5D10C00B283AC /* MLFlwrClientModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */; };
210BBB822AA5D10C00B283AC /* MLFlwrClient.swift in Sources */ = {isa = PBXBuildFile; fileRef = 210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */; };
21A77AC32865B22B0062EBD8 /* Array+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A6F2865B22A0062EBD8 /* Array+Extensions.swift */; };
21A77AC42865B22B0062EBD8 /* Math.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A702865B22A0062EBD8 /* Math.swift */; };
21A77AC52865B22B0062EBD8 /* UIImage+CVPixelBuffer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21A77A712865B22A0062EBD8 /* UIImage+CVPixelBuffer.swift */; };
Expand Down Expand Up @@ -127,6 +128,9 @@
19764EEC297D3ABE009F3E5D /* BenchmarkModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BenchmarkModel.swift; sourceTree = "<group>"; };
19764EEE297D3AC9009F3E5D /* BenchmarkSuite.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BenchmarkSuite.swift; sourceTree = "<group>"; };
21075DC829788ED400311461 /* FLiOSModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FLiOSModel.swift; sourceTree = "<group>"; };
210BBB7E2AA5D04500B283AC /* flwr */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = flwr; path = ../../src/swift/flwr; sourceTree = "<group>"; };
210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLFlwrClient.swift; sourceTree = "<group>"; };
210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLFlwrClientModel.swift; sourceTree = "<group>"; };
21A77A6F2865B22A0062EBD8 /* Array+Extensions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Array+Extensions.swift"; sourceTree = "<group>"; };
21A77A702865B22A0062EBD8 /* Math.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Math.swift; sourceTree = "<group>"; };
21A77A712865B22A0062EBD8 /* UIImage+CVPixelBuffer.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "UIImage+CVPixelBuffer.swift"; sourceTree = "<group>"; };
Expand Down Expand Up @@ -229,7 +233,6 @@
buildActionMask = 2147483647;
files = (
50B5A2C6294C848000B0ABC6 /* flwr in Frameworks */,
216A90382A0CBE5400E0B532 /* flwr in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand All @@ -246,6 +249,14 @@
path = Benchmark;
sourceTree = "<group>";
};
210BBB7D2AA5D04500B283AC /* Packages */ = {
isa = PBXGroup;
children = (
210BBB7E2AA5D04500B283AC /* flwr */,
);
name = Packages;
sourceTree = "<group>";
};
21A77A6E2865B22A0062EBD8 /* CoreMLHelpers */ = {
isa = PBXGroup;
children = (
Expand Down Expand Up @@ -346,6 +357,8 @@
21A77B162865B23D0062EBD8 /* CoreMLClient */ = {
isa = PBXGroup;
children = (
210BBB7F2AA5D10C00B283AC /* MLFlwrClient.swift */,
210BBB802AA5D10C00B283AC /* MLFlwrClientModel.swift */,
50CDF0E0299D807300AEBD55 /* DataLoader.swift */,
21A77B202865B3B10062EBD8 /* MLModelInspect.swift */,
50EB97C929AC9C94000AC771 /* Constants.swift */,
Expand All @@ -363,6 +376,7 @@
21B540642865B14000185DEE = {
isa = PBXGroup;
children = (
210BBB7D2AA5D04500B283AC /* Packages */,
21B5406F2865B14000185DEE /* FLiOS */,
21B5406E2865B14000185DEE /* Products */,
21A77B2E2865B5420062EBD8 /* Frameworks */,
Expand Down Expand Up @@ -440,7 +454,6 @@
name = FLiOS;
packageProductDependencies = (
50B5A2C5294C848000B0ABC6 /* flwr */,
216A90372A0CBE5400E0B532 /* flwr */,
);
productName = FlowerCoreML;
productReference = 21B5406D2865B14000185DEE /* FLiOS.app */;
Expand Down Expand Up @@ -471,7 +484,6 @@
);
mainGroup = 21B540642865B14000185DEE;
packageReferences = (
216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */,
);
productRefGroup = 21B5406E2865B14000185DEE /* Products */;
projectDirPath = "";
Expand Down Expand Up @@ -522,6 +534,7 @@
21A77ADC2865B22B0062EBD8 /* BayesianProbitRegressor.proto in Sources */,
21A77ADE2865B22B0062EBD8 /* Gazetteer.proto in Sources */,
21A77ADB2865B22B0062EBD8 /* SoundAnalysisPreprocessing.pb.swift in Sources */,
210BBB822AA5D10C00B283AC /* MLFlwrClient.swift in Sources */,
21A77AC32865B22B0062EBD8 /* Array+Extensions.swift in Sources */,
21A77B152865B22B0062EBD8 /* Scaler.pb.swift in Sources */,
21A77B092865B22B0062EBD8 /* Parameters.pb.swift in Sources */,
Expand Down Expand Up @@ -585,6 +598,7 @@
21A77AF72865B22B0062EBD8 /* ItemSimilarityRecommender.pb.swift in Sources */,
21A77AD92865B22B0062EBD8 /* FeatureVectorizer.proto in Sources */,
21A77B012865B22B0062EBD8 /* TextClassifier.proto in Sources */,
210BBB812AA5D10C00B283AC /* MLFlwrClientModel.swift in Sources */,
21A77B0F2865B22B0062EBD8 /* LinkedModel.proto in Sources */,
21A77ADF2865B22B0062EBD8 /* MIL.pb.swift in Sources */,
21A77AE52865B22B0062EBD8 /* GLMRegressor.pb.swift in Sources */,
Expand Down Expand Up @@ -804,23 +818,7 @@
};
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https:/adap/flower-swift.git";
requirement = {
branch = main;
kind = branch;
};
};
/* End XCRemoteSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
216A90372A0CBE5400E0B532 /* flwr */ = {
isa = XCSwiftPackageProductDependency;
package = 216A90362A0CBE5400E0B532 /* XCRemoteSwiftPackageReference "flower-swift" */;
productName = flwr;
};
21A77B312865B55D0062EBD8 /* Flower */ = {
isa = XCSwiftPackageProductDependency;
productName = Flower;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import NIOCore
import NIOPosix
import CoreML
import os
import flwr

/// Set of CoreML machine learning task options.
public enum MLTask {
Expand Down Expand Up @@ -55,6 +56,7 @@ public class MLFlwrClient: Client {

private var compiledModelUrl: URL
private var tempModelUrl: URL
private var modelUrl: URL

private let log = Logger(subsystem: Bundle.main.bundleIdentifier ?? "flwr.Flower",
category: String(describing: MLFlwrClient.self))
Expand All @@ -65,8 +67,9 @@ public class MLFlwrClient: Client {
/// - layerWrappers: A MLLayerWrapper struct that contains layer information.
/// - dataLoader: A MLDataLoader struct that contains train- and testdata batches.
/// - compiledModelUrl: An URL specifying the location or path of the compiled model.
public init(layerWrappers: [MLLayerWrapper], dataLoader: MLDataLoader, compiledModelUrl: URL) {
self.parameters = MLParameter(layerWrappers: layerWrappers)
public init(layerWrappers: [MLLayerWrapper], dataLoader: MLDataLoader, compiledModelUrl: URL, modelUrl: URL) {
self.modelUrl = modelUrl
self.parameters = MLParameter(layerWrappers: layerWrappers, modelUrl: modelUrl, compiledModelUrl: compiledModelUrl)
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
self.dataLoader = dataLoader
self.compiledModelUrl = compiledModelUrl
Expand All @@ -85,6 +88,7 @@ public class MLFlwrClient: Client {
///
/// - Returns: Parameters from the local model
public func getParameters() -> GetParametersRes {
parameters.initializeParameters()
let parameters = parameters.weightsToParameters()
let status = Status(code: .ok, message: String())

Expand Down Expand Up @@ -128,7 +132,7 @@ public class MLFlwrClient: Client {
}

let progressHandler: (MLUpdateContext) -> Void = { contextProgress in
let loss = String(format: "%.4f", contextProgress.metrics[.lossValue] as! Double)
let loss = String(format: "%.20f", contextProgress.metrics[.lossValue] as! Float)
switch task {
case .train:
self.log.info("Epoch \(contextProgress.metrics[.epochIndex] as! Int + 1) finished with loss \(loss)")
Expand All @@ -143,8 +147,8 @@ public class MLFlwrClient: Client {
self.saveModel(finalContext)
}

let loss = finalContext.metrics[.lossValue] as! Double
let result = MLResult(loss: loss, numSamples: dataset.count, accuracy: (1.0 - loss) * 100)
let loss = String(format: "%.20f", finalContext.metrics[.lossValue] as! Double)
let result = MLResult(loss: Double(loss)!, numSamples: dataset.count, accuracy: (1.0 - Double(loss)!) * 100)
promise?.succeed(result)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import Foundation
import CoreML
import os
import flwr

typealias Model = CoreML_Specification_Model
typealias NeuralNetwork = CoreML_Specification_NeuralNetwork
typealias NeuralNetworkLayer = CoreML_Specification_NeuralNetworkLayer

/// Container for train and test dataset.
public struct MLDataLoader {
Expand Down Expand Up @@ -47,15 +52,22 @@ public class MLParameter {
private var parameterConverter = ParameterConverter.shared

var layerWrappers: [MLLayerWrapper]
var model: Model?
let modelUrl: URL
let compiledModelUrl: URL

private let log = Logger(subsystem: Bundle.main.bundleIdentifier ?? "flwr.Flower",
category: String(describing: MLParameter.self))

/// Inits MLParameter class that contains information about the model parameters and implements routines for their update and transformation.
///
/// - Parameters:
/// - layerWrappers: Information about the layer provided with primitive data types.
public init(layerWrappers: [MLLayerWrapper]) {
public init(layerWrappers: [MLLayerWrapper], modelUrl: URL, compiledModelUrl: URL) {
self.layerWrappers = layerWrappers
self.modelUrl = modelUrl
self.compiledModelUrl = compiledModelUrl
self.model = try? Model(serializedData: try Data(contentsOf: modelUrl))
}

/// Converts the Parameters structure to MLModelConfiguration to interface with CoreML.
Expand All @@ -81,19 +93,28 @@ public class MLParameter {

layerWrappers[index].weights = weightsArray
if layerWrappers[index].isUpdatable {
if let weightsMultiArray = parameterConverter.dataToMultiArray(data: data) {
let weightsShape = weightsMultiArray.shape.map { Int16(truncating: $0) }
guard weightsShape == layerWrappers[index].shape else {
log.info("shape not the same")
guard model?.neuralNetwork.layers != nil else {
continue
}
for (indexB, neuralNetworkLayer) in model!.neuralNetwork.layers.enumerated() {
guard layerWrappers[index].name == neuralNetworkLayer.name else {
continue
}
switch neuralNetworkLayer.layer! {
case .convolution:
model!.neuralNetwork.layers[indexB].convolution.weights.floatValue = layerWrappers[index].weights
case .innerProduct:
model!.neuralNetwork.layers[indexB].innerProduct.weights.floatValue = layerWrappers[index].weights
default:
log.info("unexpected layer \(neuralNetworkLayer.name)")
continue
}
let paramKey = MLParameterKey.weights.scoped(to: layerWrappers[index].name)
config.parameters?[paramKey] = weightsMultiArray
}
}
}
}

exportModel()
return config
}

Expand All @@ -108,6 +129,47 @@ public class MLParameter {
return Parameters(tensors: dataArray, tensorType: "ndarray")
}

private func exportModel() {
let modelFileName = modelUrl.deletingPathExtension().lastPathComponent
let fileManager = FileManager.default
let tempModelUrl = appDirectory.appendingPathComponent("temp\(modelFileName).mlmodel")
try? model?.serializedData().write(to: tempModelUrl)
if let compiledTempModelUrl = try? MLModel.compileModel(at: tempModelUrl) {
_ = try? fileManager.replaceItemAt(compiledModelUrl, withItemAt: compiledTempModelUrl)
}
}

func initializeParameters() {
guard ((model?.neuralNetwork.layers) != nil) else {
return
}
for (indexA, neuralNetworkLayer) in model!.neuralNetwork.layers.enumerated() {
for (indexB, layer) in layerWrappers.enumerated() {
if layer.name != neuralNetworkLayer.name { continue }
switch neuralNetworkLayer.layer! {
case .convolution:
let convolution = neuralNetworkLayer.convolution
//shape definition = [outputChannels, kernelChannels, kernelHeight, kernelWidth]
let upperLower = Float(6.0 / Float(Int16(convolution.outputChannels) + Int16(convolution.kernelChannels) + Int16(convolution.kernelSize[0]) + Int16(convolution.kernelSize[1]))).squareRoot()
let initialise = (0..<(neuralNetworkLayer.convolution.weights.floatValue.count)).map { _ in Float.random(in: -upperLower...upperLower) }
model?.neuralNetwork.layers[indexA].convolution.weights.floatValue = initialise
layerWrappers[indexB].weights = initialise
case .innerProduct:
let innerProduct = neuralNetworkLayer.innerProduct
//shape definition = [C_out, C_in].
let upperLower = Float(6.0 / Float(Int16(innerProduct.outputChannels) + Int16(innerProduct.inputChannels))).squareRoot()
let initialise = (0..<(neuralNetworkLayer.innerProduct.weights.floatValue.count)).map { _ in Float.random(in: -upperLower...upperLower) }
model?.neuralNetwork.layers[indexA].innerProduct.weights.floatValue = initialise
layerWrappers[indexB].weights = initialise
default:
log.info("unexpected layer \(neuralNetworkLayer.name)")
continue
}
}
}
exportModel()
}

/// Updates the layers given the CoreML update context
///
/// - Parameters:
Expand Down
5 changes: 3 additions & 2 deletions examples/ios/FLiOS/FLiOSModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ public class FLiOSModel: ObservableObject {
let layerWrappers = modelInspect.getLayerWrappers()
self.mlFlwrClient = MLFlwrClient(layerWrappers: layerWrappers,
dataLoader: dataLoader,
compiledModelUrl: compiledModelUrl)
compiledModelUrl: compiledModelUrl,
modelUrl: url)
case .local:
self.localClient = LocalClient(dataLoader: dataLoader, compiledModelUrl: compiledModelUrl)
}
Expand Down Expand Up @@ -203,7 +204,7 @@ class LocalClient {

func runMLTask(statusHandler: @escaping (Constants.TaskStatus) -> Void,
numEpochs: Int,
task: flwr.MLTask
task: MLTask
) {
let dataset: MLBatchProvider
let configuration = MLModelConfiguration()
Expand Down

0 comments on commit e91635b

Please sign in to comment.